In [1]:
import torch
from torch import nn
from torch import fx
from torch.fx import GraphModule, Graph, Node

import brt
from brt.trace import symbolic_trace, GraphTracer

# from brt.trace.graph import symbolic_trace
from brt.passes import VerticalFusePass
from brt.router import ScatterRouter, GatherRouter


[2023-01-16 17:05:36] INFO (numexpr.utils/MainThread) Note: detected 96 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2023-01-16 17:05:36] INFO (numexpr.utils/MainThread) Note: NumExpr detected 96 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[2023-01-16 17:05:36] INFO (numexpr.utils/MainThread) NumExpr defaulting to 8 threads.


In [2]:
class TestModule(nn.Module):
    """Small two-branch routing model used as a tracing / fusion test fixture.

    Structure, as built in ``__init__``:

    * ``router``  -- a ``ScatterRouter``; ``forward`` unpacks its result into
      two values (``y0``, ``y1``).
    * ``expert0`` -- 1x1 ``Conv2d(3, 3)`` followed by ``ReLU``.
    * ``expert1`` -- 1x1 ``Conv2d(3, 3)``, ``ReLU``, then a second 1x1
      ``Conv2d(3, 3)``.
    * ``gather``  -- a ``GatherRouter`` that receives ``[y0, y1]``.

    Tensors created inside this class are moved to CUDA explicitly with
    ``.cuda()``, so a CUDA device is needed to construct and to run it.
    """

    def __init__(self) -> None:
        """Create the router, the extra bias tensor, both experts and the gather."""
        super().__init__()
        # NOTE(review): the meaning of the fabric options is defined by brt, not
        # visible here. "fixed_capacity" is a 2-element int32 tensor (4, 6) --
        # presumably one capacity per branch; confirm against ScatterRouter.
        self.router = ScatterRouter(
            fabric_type="_fused_dispatch",
            fabric_kwargs={
                "fixed_capacity": torch.tensor((4, 6), dtype=torch.int32).cuda()
            },
            capturing=True,
            capture_mode="max",
        )
        # self.y1_bias = torch.nn.Parameter(torch.randn(3, 32, 32))
        # Plain tensor attribute (not an nn.Parameter and not a registered
        # buffer), which is why it is moved to CUDA by hand here.
        self.y1_bias = torch.randn(3, 1, 1).cuda()
        self.expert0 = nn.Sequential(nn.Conv2d(3, 3, 1), nn.ReLU())
        self.expert1 = nn.Sequential(nn.Conv2d(3, 3, 1), nn.ReLU(), nn.Conv2d(3, 3, 1),)
        self.gather = GatherRouter()

    def forward(self, x=None):
        """Scatter ``x`` into two branches, run one expert per branch, gather.

        Args:
            x: Input handed to ``self.router``. Both experts start with
                ``Conv2d(3, 3, 1)``, so the routed data presumably has 3
                channels -- TODO confirm against the router's output layout.
                Defaults to ``None``.

        Returns:
            The result of ``self.gather([y0, y1])``.

        Note:
            In eager execution every call changes module state:
            ``self.y1_bias`` is doubled in place before it is added to ``y1``.
        """
        # The second router argument is a hard-coded 3x2 int32 tensor --
        # presumably per-sample routing scores for the two branches; verify
        # against ScatterRouter.
        y0, y1 = self.router(
            x,
            torch.tensor(((3, 0), (2, 3), (0, 1)), dtype=torch.int32).cuda(),
            # torch.randint(0, 3, (3, 2)).cuda(),
        )
        # Residual-style update: the expert output is added onto y0 itself.
        y0 += self.expert0(y0)
        y1 = self.expert1(y1)
        # In-place self-addition: doubles the stored bias on each eager call.
        self.y1_bias += self.y1_bias
        y1 += self.y1_bias
        y = self.gather([y0, y1])
        return y


In [3]:
# Build the toy model in inference mode on the GPU and print its traced graph
# before any fusion is applied.
m = TestModule().eval().cuda()
gm = symbolic_trace(m)
print(gm.graph)

# The fusion pass is given the original module `m` (not `gm`) together with one
# random sample input for the `x` argument of `forward`.
# NOTE(review): the recorded run of this cell ended with
# "ValueError: No kernel found in database ..." for op_type Conv2dBiasReLUAdd,
# so the second graph print was presumably never reached -- confirm the kernel
# database is populated for this device before relying on the output below.
fuse_sample_inputs = {"x": torch.randn(3, 3, 1, 1).cuda()}
vertical_fuse_pass = VerticalFusePass(
    m,
    sample_inputs=fuse_sample_inputs,
    fixing_scatters=True,
    fixed_inputs=True,
)
vertical_fuse_pass.run_on_graph()

# Replace the unfused graph module with the pass result and show it.
gm = vertical_fuse_pass.finalize()
print(gm.graph)


graph():
    %x : [#users=1] = placeholder[target=x](default=None) | unfixed
    %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0] | unfixed
    %router : [#users=2] = call_module[target=router](args = (%x, %_tensor_constant0), kwargs = {}) | unfixed
    %getitem : [#users=2] = call_function[target=operator.getitem](args = (%router, 0), kwargs = {}) | unfixed
    %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%router, 1), kwargs = {}) | unfixed
    %expert0_0 : [#users=1] = call_module[target=expert0.0](args = (%getitem,), kwargs = {}) | unfixed
    %expert0_1 : [#users=1] = call_module[target=expert0.1](args = (%expert0_0,), kwargs = {}) | unfixed
    %add : [#users=1] = call_function[target=operator.add](args = (%getitem, %expert0_1), kwargs = {}) | unfixed
    %expert1_0 : [#users=1] = call_module[target=expert1.0](args = (%getitem_1,), kwargs = {}) | unfixed
    %expert1_1 : [#users=1] = call_module[target=expert1.1](args = (%expert1_0,)

ValueError: No kernel found in database with identifier = '{"device_name": "NVIDIA_GeForce_RTX_3090", "input_infos": {"input_0": [4, 3, 1, 1]}, "method": "forward", "op_type": "Conv2dBiasReLUAdd", "output_infos": {"output_0": [4, 3, 1, 1]}, "parameters": {"dilation": [1, 1], "groups": 1, "in_channels": 3, "kernel_size": [1, 1], "out_channels": 3, "padding": [0, 0], "stride": [1, 1]}}', objective_func = 'fastest'

In [None]:
# Fresh model instance, in inference mode on the GPU.
m = TestModule().eval().cuda()

# Run a couple of forward passes on random inputs, then show what the router
# recorded in `load_history`.
# NOTE(review): presumably the history is filled because the router was built
# with capturing=True -- confirm against ScatterRouter.
WARMUP_RUNS = 2
for _ in range(WARMUP_RUNS):
    m(torch.randn(3, 3, 1, 1).cuda())
print(m.router.load_history)

# Trace with shape tracing enabled, using one random sample for `x`.
trace_options = {
    "tracing_shape": True,
    "sample_inputs": {"x": torch.randn(3, 3, 1, 1).cuda()},
    "fixed_inputs": True,
}
gm = symbolic_trace(m, **trace_options)
print(gm.graph)

# One line per graph node: name, fixed-in/out flag, input shape, output shape.
for node in gm.graph.nodes:
    print(node.name, node.is_fixed_inout, node.inshape, node.outshape)


In [None]:
from benchmark.livesr.archs.nas_mdsr import SingleNetwork as nas_mdsr

# Positional constructor arguments for the LiveSR network.
# NOTE(review): their meaning is defined by SingleNetwork, not visible here.
LIVESR_ARGS = (1, 3, 3, 4, 2)

# Build the network, switch it to inference mode and show its module tree.
livesr = nas_mdsr(*LIVESR_ARGS)
livesr = livesr.eval()
print(livesr)

# Trace it and print the graph before fusion.
gm_livesr = symbolic_trace(livesr)
print(gm_livesr.graph)

# Here the pass is given the already-traced graph module, with no sample
# inputs or extra options.
vertical_fuse_pass = VerticalFusePass(gm_livesr)
vertical_fuse_pass.run_on_graph()

# Swap in the fused graph module and print the result.
gm_livesr = vertical_fuse_pass.finalize()
print(gm_livesr.graph)
