In [1]:
import numpy as np

import torch
from torch import nn
from torch import fx
from torch.fx import GraphModule, Graph, Node

import brt
from brt.runtime import log
from brt.router import ScatterRouter, GatherRouter
from brt.router.fabric import make_fabric
from brt.trace import symbolic_trace, GraphTracer

# from brt.trace.graph import symbolic_trace
from brt.passes import VerticalFusePass

# Set the log level of the "BRT" logger to DEBUG for this session.
log.set_level("BRT", "DEBUG")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import sys
from brt.runtime import BRT_CACHE_PATH

# Put the livesr benchmark directory (located relative to BRT_CACHE_PATH's parent) on the
# import path; the `archs.livesr` and `dataset` imports below are presumably resolved
# from there — confirm against the repository layout.
sys.path.append(str(BRT_CACHE_PATH.parent / "benchmark/livesr/"))

# from nas_mdsr import SingleNetwork as nas_mdsr
from archs.livesr import LiveSR
from dataset import get_dataloader


In [3]:
# Channel width shared across this notebook: passed to LiveSR as `num_feature` and used
# as the Conv2d in/out channel count and the input channel dimension for TestModule.
channels = 8


In [4]:
# Build the LiveSR model (10 subnets, 3 blocks each) in eval mode on the GPU.
livesr = LiveSR(n_subnets=10, subnet_num_block=3, num_feature=channels).eval().cuda()

dataloader = get_dataloader(
    str(BRT_CACHE_PATH.parent / "benchmark/livesr/dataset/cam1/LQ")
)

# Take only the first batch; it is the sample input `x` used by the cells below.
# `next(iter(...))` raises StopIteration on an empty loader, whereas a
# `for x in dataloader: break` loop would leave `x` unbound — or silently reuse a
# stale `x` from an earlier run in the same kernel.
x = next(iter(dataloader))

# Run one forward pass, then show the load history held by the scatter router.
livesr(x)
print(livesr.scatter.load_history)
# Replace the recorded history with hand-picked values (10 entries, matching n_subnets).
# NOTE(review): presumably these act as the per-subnet loads used when the fusion pass
# fixes the scatter — confirm against VerticalFusePass.
livesr.scatter.load_history = np.array([6, 7, 12, 27, 8, 8, 8, 12, 12, 4], dtype=int)


[ 5.  4. 10. 26.  7.  5.  8.  8. 11.  4.]


In [5]:
# Print the graph from a plain symbolic trace of the model (default tracing options).
print(symbolic_trace(livesr).graph)

# Trace again with shape tracing enabled, feeding the sample batch as `inputs`.
# NOTE(review): this `gm_livesr` is rebound by `finalize()` below without being read in
# between. It is left in place because the trace runs with sample inputs and may touch
# state on `livesr` — confirm before removing.
gm_livesr = symbolic_trace(
    livesr,
    tracing_shape=True,
    sample_inputs={"inputs": x},
)
# print(gm_livesr.graph)
# Run the vertical fusion pass on the model with the same sample batch, then replace
# `gm_livesr` with the finalized (fused) graph module and print its graph.
vertical_fuse_pass = VerticalFusePass(
    livesr, sample_inputs={"inputs": x}, fixing_scatters=True
)
vertical_fuse_pass.run_on_graph()
gm_livesr = vertical_fuse_pass.finalize()
print(gm_livesr.graph)

# NOTE(review): mid-cell import; consider moving it to the notebook's import cell.
from torch.fx.passes.graph_drawer import FxGraphDrawer

# Render the fused graph as SVG and write it to "vfused_livesr.svg" in the current
# working directory.
graph_drawer = FxGraphDrawer(gm_livesr, "new_backbone")
with open("vfused_livesr.svg", "wb") as f:
    f.write(graph_drawer.get_dot_graph().create_svg())


graph():
    %inputs : torch.Tensor [#users=2] = placeholder[target=inputs] | unfixed
    %classifier : [#users=1] = call_module[target=classifier](args = (%inputs,), kwargs = {}) | unfixed
    %scatter : [#users=10] = call_module[target=scatter](args = (%inputs, %classifier), kwargs = {}) | unfixed
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%scatter, 0), kwargs = {}) | unfixed
    %subnets_0_head_0 : [#users=3] = call_module[target=subnets.0.head.0](args = (%getitem,), kwargs = {}) | unfixed
    %subnets_0_body_0_0_body_0 : [#users=1] = call_module[target=subnets.0.body.0.0.body.0](args = (%subnets_0_head_0,), kwargs = {}) | unfixed
    %subnets_0_body_0_0_body_1 : [#users=1] = call_module[target=subnets.0.body.0.0.body.1](args = (%subnets_0_body_0_0_body_0,), kwargs = {}) | unfixed
    %subnets_0_body_0_0_body_2 : [#users=1] = call_module[target=subnets.0.body.0.0.body.2](args = (%subnets_0_body_0_0_body_1,), kwargs = {}) | unfixed
    %add : [#users=2

In [6]:
print(type(x))

# Output of the original (unfused) model on the sample batch; `vy` below is the fused
# graph module's output on the same batch.
y = livesr(x)

gm_livesr.delete_all_unused_submodules()

all_hooks = []  # handles of every registered hook, removed in the `finally` below
target_of_module = {}  # submodule object -> its qualified name in `gm_livesr`
scatter_outputs = [None]  # one-slot box filled with the ScatterRouter's forward output


# Both hooks are defined once, outside the loop: they only close over the containers
# above, so re-creating them for every submodule was redundant.
def print_hook(m: nn.Module, i, o):
    """Forward hook that prints one aligned row describing a submodule call.

    The row holds the submodule's qualified name (looked up in `target_of_module`),
    its `_get_name()`, the class names of the positional inputs, and the class name
    of the output (or of its first element when the output is a tuple or list).
    Returns None, so the module output is left untouched.
    """
    name = target_of_module[m]
    print(
        f"{name:50} {m._get_name():20}{str([ii.__class__.__name__ for ii in i]):30}{'[' + o[0].__class__.__name__ + ', ...]' if isinstance(o, (tuple, list)) else o.__class__.__name__:20s}"
    )


def get_scatter_outputs(m, i, o):
    """Forward hook that stores the module's output in `scatter_outputs[0]`."""
    scatter_outputs[0] = o


try:
    for subn, subm in gm_livesr.named_modules():
        # Submodules whose qualified name contains "classifier" are not instrumented.
        if "classifier" in subn:
            continue
        target_of_module[subm] = subn
        all_hooks.append(subm.register_forward_hook(print_hook))
        if isinstance(subm, ScatterRouter):
            all_hooks.append(subm.register_forward_hook(get_scatter_outputs))

    vy = gm_livesr(x)
finally:
    # Always detach the hooks, even if the forward pass raised.
    for hook in all_hooks:
        hook.remove()


<class 'torch.Tensor'>
scatter.protocol                                   TopKProtocol        ['Tensor']                    [Tensor, ...]       
scatter.fabric                                     FusedDispatchFabric ['Tensor', 'Tensor', 'Tensor', 'Tensor'][ProtoTensor, ...]  
scatter                                            ScatterRouter       ['Tensor', 'Tensor']          [ProtoTensor, ...]  
subnets.0.head.0                                   Conv2d              ['ProtoTensor']               ProtoTensor         
BRT_VF__subnets_0_body_0_0_body_0__subnets_0_body_0_0_body_1 BRT.Conv2dBiasReLU  ['ProtoTensor']               ProtoTensor         
BRT_VF__subnets_0_body_0_0_body_2__add             BRT.Conv2dBiasAdd   ['ProtoTensor', 'ProtoTensor']ProtoTensor         
BRT_VF__subnets_0_body_1_0_body_0__subnets_0_body_1_0_body_1 BRT.Conv2dBiasReLU  ['ProtoTensor']               ProtoTensor         
BRT_VF__subnets_0_body_1_0_body_2__add_1           BRT.Conv2dBiasAdd   ['ProtoTensor', 'Proto

In [13]:
# Look up one fused submodule by its qualified name and invoke its `function` directly.
subm = gm_livesr.get_submodule(
    "BRT_VF__subnets_0_body_0_0_body_0__subnets_0_body_0_0_body_1"
)
func = subm.function
print(func)
# Arguments: first captured scatter output, then the module's weight and bias.
# Original author's note on argument order: inputs[0], self.weight, self.bias, *inputs[1:]
func_in = (scatter_outputs[0][0], subm.weight, subm.bias)
print(list(map(type, func_in)))
func_out = func.apply(*func_in)
print(list(map(type, func_out)))

<class 'brt.jit.modules.atom.AtomModule.make_function.<locals>.JitFunction'>
[<class 'brt.runtime.proto_tensor.ProtoTensor'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>]
[<class 'brt.runtime.proto_tensor.ProtoTensor'>]


In [14]:
# Compare the reference output `y` with the fused-graph output `vy`.
print(y.shape)
print(vy.shape)
print(torch.allclose(y, vy, rtol=1e-100, atol=1e-1))

# For each reduction, print it for `y` and then for the difference `y - vy`, in the
# order: sum, sum of |.|, max, min of |.|, min.
reductions = (
    torch.sum,
    lambda t: torch.sum(torch.abs(t)),
    torch.max,
    lambda t: torch.min(torch.abs(t)),
    torch.min,
)
for reduction in reductions:
    print(reduction(y))
    print(reduction(y - vy))


torch.Size([88, 3, 128, 128])
torch.Size([88, 3, 128, 128])
True
tensor(-219514.5312, device='cuda:0', grad_fn=<SumBackward0>)
tensor(-0.0451, device='cuda:0', grad_fn=<SumBackward0>)
tensor(20435528., device='cuda:0', grad_fn=<SumBackward0>)
tensor(52.0959, device='cuda:0', grad_fn=<SumBackward0>)
tensor(816.0132, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(0.0100, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(0., device='cuda:0', grad_fn=<MinBackward1>)
tensor(0., device='cuda:0', grad_fn=<MinBackward1>)
tensor(-437.6592, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0080, device='cuda:0', grad_fn=<MinBackward1>)


In [9]:
class TestModule(nn.Module):
    """Small scatter -> two experts -> gather model used to try out tracing and fusion.

    The input is scattered into two branches by a `ScatterRouter` using a hard-coded
    score tensor, each branch goes through its own convolutional expert, and the
    results are merged by a `GatherRouter`.

    NOTE(review): `forward` mutates `self.y1_bias` in place (it is doubled on every
    call), so repeated calls are not idempotent. This looks intentional for exercising
    the tracer — confirm.
    """

    def __init__(self) -> None:
        super().__init__()
        # Handed to the fabric as `fixed_capacity`; presumably one entry per branch — confirm.
        capacity = torch.tensor((4, 6), dtype=torch.int32).cuda()
        self.router = ScatterRouter(
            fabric_type="_fused_dispatch",
            fabric_kwargs={"fixed_capacity": capacity},
            capturing=True,
            capture_mode="max",
        )
        # Stored as a plain CUDA tensor attribute. Parameter variant kept for reference:
        # self.y1_bias = torch.nn.Parameter(torch.randn(channels, 32, 32))
        self.y1_bias = torch.randn(channels, 32, 32).cuda()
        # Expert 0: conv + ReLU.
        self.expert0 = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # Expert 1: conv + ReLU + conv.
        expert1_layers = [
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        ]
        self.expert1 = nn.Sequential(*expert1_layers)
        self.gather = GatherRouter()

    def forward(self, x=None):
        """Route `x` through both experts and gather the branch outputs."""
        # Fixed 3x2 int32 score tensor passed to the router alongside `x`.
        # Random alternative from the original: torch.randint(0, 3, (3, 2)).cuda()
        scores = torch.tensor(((3, 0), (2, 3), (0, 1)), dtype=torch.int32).cuda()
        y0, y1 = self.router(x, scores)
        # Branch 0: add the expert output onto the branch input.
        y0 += self.expert0(y0)
        # Branch 1: expert output plus the bias, which is doubled first.
        y1 = self.expert1(y1)
        self.y1_bias += self.y1_bias
        y1 += self.y1_bias
        return self.gather([y0, y1])


In [10]:
# Fresh TestModule in eval mode on the GPU.
m = TestModule().eval().cuda()
# Two forward passes on random batches of 3 samples before reading the router's
# load_history (presumably so the capturing router has recorded loads — confirm).
for _ in range(2):
    m(torch.randn(3, channels, 32, 32).cuda())
print(m.router.load_history)
# Trace with shape tracing on a random sample input, with `fixed_inputs=True`.
gm = symbolic_trace(
    m,
    tracing_shape=True,
    sample_inputs={"x": torch.randn(3, channels, 32, 32).cuda()},
    fixed_inputs=True,
)
print(gm.graph)
# Per-node summary: name, `is_fixed_inout` flag, and the traced `inshape` / `outshape`.
for node in gm.graph.nodes:
    print(node.name, node.is_fixed_inout, node.inshape, node.outshape)


[1. 2.]
graph():
    %x : [#users=1] = placeholder[target=x](default=None) | fixed
    %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0] | fixed
    %router : [#users=2] = call_module[target=router](args = (%x, %_tensor_constant0), kwargs = {}) | fixed
    %getitem : [#users=2] = call_function[target=operator.getitem](args = (%router, 0), kwargs = {}) | fixed
    %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%router, 1), kwargs = {}) | fixed
    %expert0_0 : [#users=1] = call_module[target=expert0.0](args = (%getitem,), kwargs = {}) | fixed
    %expert0_1 : [#users=1] = call_module[target=expert0.1](args = (%expert0_0,), kwargs = {}) | fixed
    %add : [#users=1] = call_function[target=operator.add](args = (%getitem, %expert0_1), kwargs = {}) | fixed
    %expert1_0 : [#users=1] = call_module[target=expert1.0](args = (%getitem_1,), kwargs = {}) | fixed
    %expert1_1 : [#users=1] = call_module[target=expert1.1](args = (%expert1_0,), kwargs =

In [11]:
# Rebuild the module and trace it with default options (no sample inputs, no shape
# tracing), then print the resulting graph.
# NOTE: this rebinds the notebook-level names `m` and `gm`.
m = TestModule().eval().cuda()
gm = symbolic_trace(m)
print(gm.graph)


graph():
    %x : [#users=1] = placeholder[target=x](default=None) | unfixed
    %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0] | unfixed
    %router : [#users=2] = call_module[target=router](args = (%x, %_tensor_constant0), kwargs = {}) | unfixed
    %getitem : [#users=2] = call_function[target=operator.getitem](args = (%router, 0), kwargs = {}) | unfixed
    %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%router, 1), kwargs = {}) | unfixed
    %expert0_0 : [#users=1] = call_module[target=expert0.0](args = (%getitem,), kwargs = {}) | unfixed
    %expert0_1 : [#users=1] = call_module[target=expert0.1](args = (%expert0_0,), kwargs = {}) | unfixed
    %add : [#users=1] = call_function[target=operator.add](args = (%getitem, %expert0_1), kwargs = {}) | unfixed
    %expert1_0 : [#users=1] = call_module[target=expert1.0](args = (%getitem_1,), kwargs = {}) | unfixed
    %expert1_1 : [#users=1] = call_module[target=expert1.1](args = (%expert1_0,)

In [12]:
# Run the vertical fusion pass on the current `m` with a random 3-sample input, with
# both `fixing_scatters` and `fixed_inputs` enabled.
# NOTE: this rebinds `vertical_fuse_pass` (previously the LiveSR pass) and `gm`.
vertical_fuse_pass = VerticalFusePass(
    m,
    sample_inputs={"x": torch.randn(3, channels, 32, 32).cuda()},
    fixing_scatters=True,
    fixed_inputs=True,
)
vertical_fuse_pass.run_on_graph()
# Replace `gm` with the finalized graph module and print its graph.
gm = vertical_fuse_pass.finalize()
print(gm.graph)


graph():
    %x : [#users=1] = placeholder[target=x](default=None) | fixed
    %_tensor_constant1 : [#users=1] = get_attr[target=_tensor_constant1] | fixed
    %router : [#users=2] = call_module[target=router](args = (%x, %_tensor_constant1), kwargs = {}) | fixed
    %getitem : [#users=2] = call_function[target=operator.getitem](args = (%router, 0), kwargs = {}) | fixed
    %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%router, 1), kwargs = {}) | fixed
    %expert0_0 : [#users=1] = call_module[target=expert0.0](args = (%getitem,), kwargs = {}) | fixed
    %expert0_1 : [#users=1] = call_module[target=expert0.1](args = (%expert0_0,), kwargs = {}) | fixed
    %add : [#users=1] = call_function[target=operator.add](args = (%getitem, %expert0_1), kwargs = {}) | fixed
    %expert1_0 : [#users=1] = call_module[target=expert1.0](args = (%getitem_1,), kwargs = {}) | fixed
    %expert1_1 : [#users=1] = call_module[target=expert1.1](args = (%expert1_0,), kwargs = {}) | f