In [1]:
import numpy as np

import torch
from torch import nn
from torch import fx
from torch.fx import GraphModule, Graph, Node

import brt
from brt.runtime import log
from brt.router import ScatterRouter, GatherRouter
from brt.router.fabric import make_fabric
from brt.trace import symbolic_trace, GraphTracer

# from brt.trace.graph import symbolic_trace
from brt.passes import VerticalFusePass

# Set the "BRT" logger to DEBUG level for the rest of this session.
log.set_level("BRT", "DEBUG")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import sys
from brt.runtime import BRT_CACHE_PATH

# Put the LiveSR benchmark directory (a sibling of the BRT cache directory) on
# sys.path; the `archs.livesr` and `dataset` imports that follow presumably
# resolve from there, so this line must run before them.
sys.path.append(str(BRT_CACHE_PATH.parent / "benchmark/livesr/"))

# from nas_mdsr import SingleNetwork as nas_mdsr
from archs.livesr import LiveSR
from dataset import get_dataloader

In [3]:
# Channel width shared across the notebook: passed to LiveSR as `num_feature`
# and used as the in/out channel count of the TestModule convolutions and of
# the random sample inputs.
channels = 8


In [4]:
# LiveSR model under test, switched to eval mode and moved to the GPU.
livesr = LiveSR(n_subnets=10, subnet_num_block=3, num_feature=channels).eval().cuda()

dataloader = get_dataloader(
    str(BRT_CACHE_PATH.parent / "benchmark/livesr/dataset/cam1/LQ")
)

# Only the first batch is needed as a sample input. `next(iter(...))` raises
# StopIteration right here when the loader is empty, instead of leaving `x`
# unbound (or silently stale from an earlier run) as a `for ...: break` would.
x = next(iter(dataloader))

# One eager forward pass, then show the load history recorded by the scatter router.
livesr(x)
print(livesr.scatter.load_history)
# Replace the recorded history with a hand-picked array of 10 entries (one per
# subnet, matching n_subnets=10).
# NOTE(review): presumably these are per-subnet capacities consumed by the
# fusion pass further down; confirm against the router implementation.
livesr.scatter.load_history = np.array([6, 7, 12, 27, 8, 8, 8, 12, 12, 4], dtype=int)


[ 5.  4. 10. 26.  7.  5.  8.  8. 11.  4.]


In [5]:
# Baseline: graph of the model traced with default options.
print(symbolic_trace(livesr).graph)

# Shape-annotated trace driven by the sample batch.
# NOTE(review): this binding is replaced by the fused module below before it is
# read; the call is kept in case tracing has side effects on `livesr`.
gm_livesr = symbolic_trace(livesr, tracing_shape=True, sample_inputs={"inputs": x})

# Run the vertical-fuse pass on the original module and keep the finalized
# graph module it returns.
vertical_fuse_pass = VerticalFusePass(
    livesr,
    sample_inputs={"inputs": x},
    fixing_scatters=True,
)
vertical_fuse_pass.run_on_graph()
gm_livesr = vertical_fuse_pass.finalize()
print(gm_livesr.graph)


graph():
    %inputs : torch.Tensor [#users=2] = placeholder[target=inputs] | unfixed
    %classifier : [#users=1] = call_module[target=classifier](args = (%inputs,), kwargs = {}) | unfixed
    %scatter : [#users=10] = call_module[target=scatter](args = (%inputs, %classifier), kwargs = {}) | unfixed
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%scatter, 0), kwargs = {}) | unfixed
    %subnets_0_head_0 : [#users=3] = call_module[target=subnets.0.head.0](args = (%getitem,), kwargs = {}) | unfixed
    %subnets_0_body_0_0_body_0 : [#users=1] = call_module[target=subnets.0.body.0.0.body.0](args = (%subnets_0_head_0,), kwargs = {}) | unfixed
    %subnets_0_body_0_0_body_1 : [#users=1] = call_module[target=subnets.0.body.0.0.body.1](args = (%subnets_0_body_0_0_body_0,), kwargs = {}) | unfixed
    %subnets_0_body_0_0_body_2 : [#users=1] = call_module[target=subnets.0.body.0.0.body.2](args = (%subnets_0_body_0_0_body_1,), kwargs = {}) | unfixed
    %mul : [#users=1

In [6]:
# Compare the original model with the vertically fused graph module on the same
# batch. This is a pure inference check, so autograd tracking is disabled: the
# outputs no longer carry a grad_fn and no backward graph is kept alive.
with torch.no_grad():
    y = livesr(x)
    vy = gm_livesr(x)

# Element-wise difference, computed once and reused by every statistic below.
diff = y - vy

# rtol is effectively zero, so this is an absolute-tolerance check at 1e-2.
print(torch.allclose(y, vy, rtol=1e-100, atol=1e-2))
# Each statistic is printed for the reference output first, then for the difference.
print(torch.sum(y))
print(torch.sum(diff))
print(torch.sum(torch.abs(y)))
print(torch.sum(torch.abs(diff)))
print(torch.max(y))
print(torch.max(diff))
print(torch.min(torch.abs(y)))
print(torch.min(torch.abs(diff)))
print(torch.min(y))
print(torch.min(diff))

False
tensor(-1624864.5000, device='cuda:0', grad_fn=<SumBackward0>)
tensor(-0.1697, device='cuda:0', grad_fn=<SumBackward0>)
tensor(23430906., device='cuda:0', grad_fn=<SumBackward0>)
tensor(54.6917, device='cuda:0', grad_fn=<SumBackward0>)
tensor(526.6124, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(0.0145, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(0., device='cuda:0', grad_fn=<MinBackward1>)
tensor(0., device='cuda:0', grad_fn=<MinBackward1>)
tensor(-616.9142, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0155, device='cuda:0', grad_fn=<MinBackward1>)


In [7]:
class TestModule(nn.Module):
    """Small two-branch routing module used to exercise tracing and fusion.

    A ``ScatterRouter`` splits the input into two streams, each stream goes
    through its own convolutional "expert", and a ``GatherRouter`` merges the
    two results back into a single output.

    Note that ``forward`` is stateful: it mutates ``self.y1_bias`` in place on
    every call (see the comments in ``forward``).
    """

    def __init__(self) -> None:
        super().__init__()
        # Scatter router using the "_fused_dispatch" fabric with capture enabled.
        # NOTE(review): `fixed_capacity` (4, 6) presumably caps the number of
        # samples sent to path 0 and path 1, and `capture_mode="max"` presumably
        # keeps the maximum observed load in `load_history` -- confirm against
        # the brt router documentation.
        self.router = ScatterRouter(
            fabric_type="_fused_dispatch",
            fabric_kwargs={
                "fixed_capacity": torch.tensor((4, 6), dtype=torch.int32).cuda()
            },
            capturing=True,
            capture_mode="max",
        )
        # Plain tensor attribute, NOT an nn.Parameter or registered buffer, so
        # nn.Module does not track it (hence the explicit .cuda() here).
        # self.y1_bias = torch.nn.Parameter(torch.randn(channels, 32, 32))
        self.y1_bias = torch.randn(channels, 32, 32).cuda()
        # Expert for the first routed stream: one 3x3 conv followed by ReLU.
        self.expert0 = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1), nn.ReLU()
        )
        # Expert for the second routed stream: conv -> ReLU -> conv.
        self.expert1 = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )
        self.gather = GatherRouter()

    def forward(self, x=None):
        """Route ``x`` through the two experts and gather the results.

        The routing input is a hard-coded constant tensor, so the dispatch
        decision does not depend on the values in ``x``.

        Args:
            x: Input passed to the scatter router. The notebook calls this with
                tensors of shape ``(3, channels, 32, 32)``.

        Returns:
            The output of ``self.gather`` applied to the two expert results.
        """
        # Constant int32 routing tensor of shape (3, 2).
        # NOTE(review): presumably one row per sample and one column per path --
        # confirm against the ScatterRouter contract.
        y0, y1 = self.router(
            x,
            torch.tensor(((3, 0), (2, 3), (0, 1)), dtype=torch.int32).cuda(),
            # torch.randint(0, 3, (3, 2)).cuda(),
        )
        # Residual connection around expert0, written as an augmented assignment.
        y0 += self.expert0(y0)
        y1 = self.expert1(y1)
        # Doubles the stored bias in place, so each forward call changes the
        # module's state and repeated calls see a different bias.
        # NOTE(review): looks intentional, to exercise the tracer on mutated
        # attributes -- confirm.
        self.y1_bias += self.y1_bias
        y1 += self.y1_bias
        y = self.gather([y0, y1])
        return y


In [8]:
# Fresh module plus two eager warm-up passes, so the capturing router has a
# load history to report.
m = TestModule().eval().cuda()
for _ in range(2):
    m(torch.randn(3, channels, 32, 32).cuda())
print(m.router.load_history)

# Trace with shape propagation on a random sample input and fixed inputs.
gm = symbolic_trace(
    m,
    tracing_shape=True,
    sample_inputs={"x": torch.randn(3, channels, 32, 32).cuda()},
    fixed_inputs=True,
)
print(gm.graph)

# Per-node summary: name, fixed-in/out flag, and the traced input/output shapes.
for node in gm.graph.nodes:
    print(node.name, node.is_fixed_inout, node.inshape, node.outshape)


[1. 2.]
graph():
    %x : [#users=1] = placeholder[target=x](default=None) | fixed
    %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0] | fixed
    %router : [#users=2] = call_module[target=router](args = (%x, %_tensor_constant0), kwargs = {}) | fixed
    %getitem : [#users=2] = call_function[target=operator.getitem](args = (%router, 0), kwargs = {}) | fixed
    %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%router, 1), kwargs = {}) | fixed
    %expert0_0 : [#users=1] = call_module[target=expert0.0](args = (%getitem,), kwargs = {}) | fixed
    %expert0_1 : [#users=1] = call_module[target=expert0.1](args = (%expert0_0,), kwargs = {}) | fixed
    %add : [#users=1] = call_function[target=operator.add](args = (%getitem, %expert0_1), kwargs = {}) | fixed
    %expert1_0 : [#users=1] = call_module[target=expert1.0](args = (%getitem_1,), kwargs = {}) | fixed
    %expert1_1 : [#users=1] = call_module[target=expert1.1](args = (%expert1_0,), kwargs =

In [9]:
# Rebuild a fresh TestModule (no warm-up forward passes this time) and trace it
# with default options; the printed graph marks every node as "unfixed".
# This rebinds both `m` and `gm`, so later cells operate on this new instance.
m = TestModule().eval().cuda()
gm = symbolic_trace(m)
print(gm.graph)


graph():
    %x : [#users=1] = placeholder[target=x](default=None) | unfixed
    %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0] | unfixed
    %router : [#users=2] = call_module[target=router](args = (%x, %_tensor_constant0), kwargs = {}) | unfixed
    %getitem : [#users=2] = call_function[target=operator.getitem](args = (%router, 0), kwargs = {}) | unfixed
    %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%router, 1), kwargs = {}) | unfixed
    %expert0_0 : [#users=1] = call_module[target=expert0.0](args = (%getitem,), kwargs = {}) | unfixed
    %expert0_1 : [#users=1] = call_module[target=expert0.1](args = (%expert0_0,), kwargs = {}) | unfixed
    %add : [#users=1] = call_function[target=operator.add](args = (%getitem, %expert0_1), kwargs = {}) | unfixed
    %expert1_0 : [#users=1] = call_module[target=expert1.0](args = (%getitem_1,), kwargs = {}) | unfixed
    %expert1_1 : [#users=1] = call_module[target=expert1.1](args = (%expert1_0,)

In [10]:
# Run the vertical-fuse pass on the current TestModule instance `m`, using a
# random sample input with the same shape as the earlier forward calls.
# NOTE(review): `fixing_scatters` / `fixed_inputs` presumably mark scatter
# routers and graph inputs as fixed for the pass (the printed graph shows every
# node as "fixed") -- confirm against the brt pass documentation.
vertical_fuse_pass = VerticalFusePass(
    m,
    sample_inputs={"x": torch.randn(3, channels, 32, 32).cuda()},
    fixing_scatters=True,
    fixed_inputs=True,
)
vertical_fuse_pass.run_on_graph()
# finalize() returns the resulting graph module; it replaces the earlier `gm`.
gm = vertical_fuse_pass.finalize()
print(gm.graph)


graph():
    %x : [#users=1] = placeholder[target=x](default=None) | fixed
    %_tensor_constant1 : [#users=1] = get_attr[target=_tensor_constant1] | fixed
    %router : [#users=2] = call_module[target=router](args = (%x, %_tensor_constant1), kwargs = {}) | fixed
    %getitem : [#users=2] = call_function[target=operator.getitem](args = (%router, 0), kwargs = {}) | fixed
    %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%router, 1), kwargs = {}) | fixed
    %expert0_0 : [#users=1] = call_module[target=expert0.0](args = (%getitem,), kwargs = {}) | fixed
    %expert0_1 : [#users=1] = call_module[target=expert0.1](args = (%expert0_0,), kwargs = {}) | fixed
    %add : [#users=1] = call_function[target=operator.add](args = (%getitem, %expert0_1), kwargs = {}) | fixed
    %expert1_0 : [#users=1] = call_module[target=expert1.0](args = (%getitem_1,), kwargs = {}) | fixed
    %expert1_1 : [#users=1] = call_module[target=expert1.1](args = (%expert1_0,), kwargs = {}) | f