In [1]:
import numpy as np

import torch
from torch import nn
from torch import fx
from torch.fx import GraphModule, Graph, Node

import brt
from brt.runtime import log
from brt.router import ScatterRouter, GatherRouter
from brt.router.fabric import make_fabric
from brt.trace import symbolic_trace, GraphTracer

# from brt.trace.graph import symbolic_trace
from brt.passes import VerticalFusePass

# Set the level of "BRT" logging to "DEBUG" for the rest of the notebook.
# NOTE(review): presumably this makes the tracer and passes used below emit
# verbose output -- confirm, and consider lowering it before sharing.
log.set_level("BRT", "DEBUG")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import sys
from brt.runtime import BRT_CACHE_PATH

# Put the LiveSR benchmark directory (located relative to BRT_CACHE_PATH's
# parent) on sys.path; the `archs.livesr` and `dataset` imports below are
# presumably resolved from there -- TODO confirm.
# NOTE(review): the same entry is appended again each time this cell is re-run.
sys.path.append(str(BRT_CACHE_PATH.parent / "benchmark/livesr/"))

# from nas_mdsr import SingleNetwork as nas_mdsr
from archs.livesr import LiveSR
from dataset import get_dataloader

In [3]:
# Channel width shared across the notebook: passed to LiveSR as `num_feature`
# and used as the in/out channel count of TestModule's convolutions, its bias
# tensor, and the random sample inputs.
channels = 8


In [4]:
# Build the LiveSR model in eval mode on the GPU; `channels` sets num_feature.
livesr = LiveSR(n_subnets=10, subnet_num_block=3, num_feature=channels).eval().cuda()

dataloader = get_dataloader(
    str(BRT_CACHE_PATH.parent / "benchmark/livesr/dataset/cam1/LQ")
)

# Take only the first batch. next(iter(...)) raises StopIteration right here
# when the dataloader is empty, instead of leaving `x` unbound (or silently
# reusing a stale `x` left in the kernel by an earlier run), which is what the
# previous `for x in dataloader: break` idiom did.
x = next(iter(dataloader))

# Run one forward pass, then show the scatter router's load history.
# NOTE(review): presumably the forward pass is what populates load_history
# -- confirm against the router implementation.
livesr(x)
print(livesr.scatter.load_history)

# Replace the load history with hand-picked per-subnet values (one entry per
# subnet, n_subnets=10) before tracing.
# NOTE(review): the rationale for these specific numbers is not visible here.
livesr.scatter.load_history = np.array([6, 7, 12, 27, 8, 8, 8, 12, 12, 4], dtype=int)


[ 5.  4. 10. 26.  7.  5.  8.  8. 11.  4.]


In [5]:
# Trace LiveSR with shape tracing enabled, using the batch `x` as sample input.
# NOTE(review): this result is not read anywhere in this cell -- `gm_livesr` is
# reassigned from `vertical_fuse_pass.finalize()` below, and the pass itself is
# constructed from `livesr`, not from this traced module. Confirm whether the
# call is kept for a side effect or can be removed.
gm_livesr = symbolic_trace(
    livesr,
    tracing_shape=True,
    sample_inputs={"inputs": x},
)
# print(gm_livesr.graph)
# Build the vertical fuse pass on the original module with the same sample
# input and `fixing_scatters=True`, run it, and print the finalized graph.
vertical_fuse_pass = VerticalFusePass(livesr, sample_inputs={"inputs": x}, fixing_scatters=True)
vertical_fuse_pass.run_on_graph()
gm_livesr = vertical_fuse_pass.finalize()
print(gm_livesr.graph)


graph():
    %inputs : torch.Tensor [#users=2] = placeholder[target=inputs] | unfixed
    %classifier : [#users=1] = call_module[target=classifier](args = (%inputs,), kwargs = {}) | unfixed
    %scatter : [#users=10] = call_module[target=scatter](args = (%inputs, %classifier), kwargs = {}) | fixed
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%scatter, 0), kwargs = {}) | fixed
    %subnets_0_head_0 : [#users=3] = call_module[target=subnets.0.head.0](args = (%getitem,), kwargs = {}) | fixed
    %subnets_0_body_0_0_body_0 : [#users=1] = call_module[target=subnets.0.body.0.0.body.0](args = (%subnets_0_head_0,), kwargs = {}) | fixed
    %subnets_0_body_0_0_body_1 : [#users=1] = call_module[target=subnets.0.body.0.0.body.1](args = (%subnets_0_body_0_0_body_0,), kwargs = {}) | fixed
    %subnets_0_body_0_0_body_2 : [#users=1] = call_module[target=subnets.0.body.0.0.body.2](args = (%subnets_0_body_0_0_body_1,), kwargs = {}) | fixed
    %add : [#users=2] = call_fun

In [None]:
class TestModule(nn.Module):
    """Small two-branch scatter/gather module used as a tracing test subject.

    A ScatterRouter splits the input into two outputs, each output goes through
    its own convolutional branch (`expert0`, `expert1`), and a GatherRouter
    combines the two results.

    The module requires CUDA: several tensors are created with `.cuda()` inside
    `__init__` and `forward`. It also reads the notebook-level `channels`
    variable, so that must be defined before the class is instantiated.
    """

    def __init__(self) -> None:
        """Create the router, the bias tensor, both branches and the gather."""
        super().__init__()
        # Scatter router using the "_fused_dispatch" fabric with a fixed
        # capacity tensor of (4, 6) and capturing enabled in "max" mode.
        # NOTE(review): presumably the two capacities correspond to the two
        # router outputs consumed in forward() -- TODO confirm.
        self.router = ScatterRouter(
            fabric_type="_fused_dispatch",
            fabric_kwargs={
                "fixed_capacity": torch.tensor((4, 6), dtype=torch.int32).cuda()
            },
            capturing=True,
            capture_mode="max",
        )
        # self.y1_bias = torch.nn.Parameter(torch.randn(channels, 32, 32))
        # Plain tensor attribute (neither a Parameter nor a registered buffer),
        # so it is not part of parameters() / state_dict() and is not moved by
        # Module.cuda() / Module.to(); it is placed on the GPU explicitly here.
        self.y1_bias = torch.randn(channels, 32, 32).cuda()
        # Branch 0: a single 3x3 conv (padding=1) followed by ReLU.
        self.expert0 = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1), nn.ReLU()
        )
        # Branch 1: conv -> ReLU -> conv, all 3x3 with padding=1.
        self.expert1 = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )
        self.gather = GatherRouter()

    def forward(self, x=None):
        """Scatter `x` into two outputs, run both branches, gather the results.

        Args:
            x: Input handed to the scatter router. The notebook calls this with
                a CUDA tensor of shape (3, channels, 32, 32). The default of
                None is forwarded to the router unchanged.

        Returns:
            The output of `self.gather` applied to the two branch results.

        Note:
            This method mutates `self.y1_bias` in place (see below), so repeated
            calls are not equivalent to one another.
        """
        # The second router argument is a hard-coded 3x2 int32 tensor rebuilt on
        # every call.
        # NOTE(review): presumably one row per sample and one column per
        # router output -- TODO confirm against ScatterRouter.
        y0, y1 = self.router(
            x,
            torch.tensor(((3, 0), (2, 3), (0, 1)), dtype=torch.int32).cuda(),
            # torch.randint(0, 3, (3, 2)).cuda(),
        )
        # In-place residual add: modifies the tensor returned by the router.
        y0 += self.expert0(y0)
        y1 = self.expert1(y1)
        # Doubles the stored bias in place on every forward call.
        # NOTE(review): the value therefore grows as 2**n over n calls; confirm
        # this is a deliberate stress case for the tracer and not a bug.
        self.y1_bias += self.y1_bias
        y1 += self.y1_bias
        y = self.gather([y0, y1])
        return y


In [None]:
# Fresh TestModule in eval mode on the GPU.
m = TestModule().eval().cuda()
# Two forward passes on random inputs before inspecting the router's load
# history. Each pass also doubles `m.y1_bias` (see TestModule.forward).
for _ in range(2):
    m(torch.randn(3, channels, 32, 32).cuda())
print(m.router.load_history)
# Trace with shape tracing enabled, a random sample input for `x`, and
# `fixed_inputs=True`.
gm = symbolic_trace(
    m,
    tracing_shape=True,
    sample_inputs={"x": torch.randn(3, channels, 32, 32).cuda()},
    fixed_inputs=True,
)
# Show the traced graph, then each node's fixed flag and in/out shapes.
print(gm.graph)
for node in gm.graph.nodes:
    print(node.name, node.is_fixed_inout, node.inshape, node.outshape)


In [None]:
# Rebind `m` to a brand-new TestModule (no forward passes have been run on this
# instance) and trace it with default arguments: no shape tracing and no sample
# inputs are passed here.
m = TestModule().eval().cuda()
gm = symbolic_trace(m)
print(gm.graph)


In [None]:
# Run the vertical fuse pass on `m` with a random sample input, with both
# `fixing_scatters=True` and `fixed_inputs=True`, then print the finalized graph.
# NOTE(review): `m` is whichever TestModule instance was bound most recently,
# so the result depends on which earlier cell was executed last (the warmed-up
# instance or the fresh one).
vertical_fuse_pass = VerticalFusePass(
    m,
    sample_inputs={"x": torch.randn(3, channels, 32, 32).cuda()},
    fixing_scatters=True,
    fixed_inputs=True,
)
vertical_fuse_pass.run_on_graph()
gm = vertical_fuse_pass.finalize()
print(gm.graph)
