In [1]:
import numpy as np

import torch
from torch import nn
from torch import fx
from torch.fx import GraphModule, Graph, Node

import brt
from brt.runtime import log
from brt.router import ScatterRouter, GatherRouter
from brt.router.fabric import make_fabric
from brt.trace import symbolic_trace, GraphTracer

# from brt.trace.graph import symbolic_trace
from brt.passes import VerticalFusePass, RouterFixPass

# Enable DEBUG-level logging for the "BRT" logger. The pass diagnostics printed
# in later cell outputs (e.g. "fuse node ...", "can't find jit module") are
# presumably emitted at this level -- raise it to quiet them.
log.set_level("BRT", "DEBUG")


In [2]:
import sys
from brt.runtime import BRT_CACHE_PATH

# Make the LiveSR benchmark directory importable so `archs.livesr` and `dataset`
# can be imported. Guard the append so re-running this cell does not keep
# pushing duplicate entries onto sys.path.
_livesr_bench_dir = str(BRT_CACHE_PATH.parent / "benchmark/livesr/")
if _livesr_bench_dir not in sys.path:
    sys.path.append(_livesr_bench_dir)

# from nas_mdsr import SingleNetwork as nas_mdsr
from archs.livesr import LiveSR
from dataset import get_dataloader


In [3]:
# Feature width shared by LiveSR (num_feature) and the TestModule experts.
channels: int = 8


In [4]:
# Build a 10-subnet LiveSR model in eval mode on the GPU.
livesr = LiveSR(n_subnets=10, subnet_num_block=3, num_feature=channels).eval().cuda()

dataloader = get_dataloader(
    str(BRT_CACHE_PATH.parent / "benchmark/livesr/dataset/cam1/LQ")
)

# Take exactly one batch. The former `for x in dataloader: break` silently left
# `x` undefined (NameError in later cells) when the loader was empty.
try:
    x = next(iter(dataloader))
except StopIteration:
    raise RuntimeError("LiveSR dataloader for cam1/LQ yielded no batches") from None

# One forward pass; afterwards the scatter router's load_history is populated
# (printed below), presumably recorded during this call.
livesr(x)
print(livesr.scatter.load_history)
# Replace the captured history with a pinned per-subnet load (10 entries, one
# per subnet). NOTE(review): the reason for these exact values is not visible here.
livesr.scatter.load_history = np.array([6, 7, 12, 27, 8, 8, 8, 12, 12, 4], dtype=int)




[ 6.  5. 10. 24.  6.  6.  6. 11. 10.  4.]


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [5]:
# Trace LiveSR with shape tracking against the sample batch and show the graph.
livesr_trace_options = dict(tracing_shape=True, sample_inputs={"inputs": x})
gm_livesr = symbolic_trace(livesr, **livesr_trace_options)
print(gm_livesr.graph)

# Run the router-fixing pass on the original module and keep the rewritten result.
router_fix_pass = RouterFixPass(livesr)
router_fix_pass.run_on_graph()
livesr_rf = router_fix_pass.finalize()


graph():
    %inputs : torch.Tensor [#users=2] = placeholder[target=inputs] | fixed
    %classifier : [#users=1] = call_module[target=classifier](args = (%inputs,), kwargs = {}) | fixed
    %scatter : [#users=10] = call_module[target=scatter](args = (%inputs, %classifier), kwargs = {}) | fixed
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%scatter, 0), kwargs = {}) | fixed
    %subnets_0_head_0 : [#users=3] = call_module[target=subnets.0.head.0](args = (%getitem,), kwargs = {}) | fixed
    %subnets_0_body_0_0_body_0 : [#users=1] = call_module[target=subnets.0.body.0.0.body.0](args = (%subnets_0_head_0,), kwargs = {}) | fixed
    %subnets_0_body_0_0_body_1 : [#users=1] = call_module[target=subnets.0.body.0.0.body.1](args = (%subnets_0_body_0_0_body_0,), kwargs = {}) | fixed
    %subnets_0_body_0_0_body_2 : [#users=1] = call_module[target=subnets.0.body.0.0.body.2](args = (%subnets_0_body_0_0_body_1,), kwargs = {}) | fixed
    %add : [#users=2] = call_functio

In [6]:
# Vertically fuse the router-fixed LiveSR, using the sample batch for shapes.
vertical_fuse_pass = VerticalFusePass(
    livesr_rf,
    sample_inputs={"inputs": x},
)
vertical_fuse_pass.run_on_graph()
livesr_vf = vertical_fuse_pass.finalize()


start node `inputs` should be a fixed module node
can't find jit module
node `%scatter` is a router node
start node `getitem` should be a fixed module node
node `%subnets_0_head_0` has more than 1 users, last try
fuse node `%subnets_0_head_0`

module name: Conv2dBias
input_infos: {'input_0': [6, 3, 32, 32]}
output_infos: {'output_0': [6, 8, 32, 32]}
parameters: {'in_channels': 3, 'out_channels': 8, 'kernel_size': (3, 3), 'stride': (1, 1), 'padding': (1, 1), 'dilation': (1, 1), 'groups': 1}

No kernel found in database with identifier = '{"device_name": "NVIDIA_A100_80GB_PCIe", "input_infos": {"input_0": [6, 3, 32, 32]}, "method": "forward", "op_type": "Conv2dBias", "output_infos": {"output_0": [6, 8, 32, 32]}, "parameters": {"dilation": [1, 1], "groups": 1, "in_channels": 3, "kernel_size": [3, 3], "out_channels": 8, "padding": [1, 1], "stride": [1, 1]}}', objective_func = 'fastest'
Fail to make jit module for nodes ['subnets_0_head_0']. Is this kernel already tuned?
fuse node `%subnets

In [7]:
vf_graph = livesr_vf.graph
print(vf_graph)

# Optional: render the fused graph as SVG (FxGraphDrawer requires pydot):
# from torch.fx.passes.graph_drawer import FxGraphDrawer
# drawer = FxGraphDrawer(livesr_vf, "new_backbone")
# with open("vfused_livesr.svg", "wb") as svg_file:
#     svg_file.write(drawer.get_dot_graph().create_svg())


graph():
    %inputs : torch.Tensor [#users=2] = placeholder[target=inputs] | fixed
    %classifier : [#users=1] = call_module[target=classifier](args = (%inputs,), kwargs = {}) | fixed
    %scatter : [#users=10] = call_module[target=scatter](args = (%inputs, %classifier), kwargs = {}) | fixed
    %getitem : [#users=2] = call_function[target=operator.getitem](args = (%scatter, 0), kwargs = {}) | fixed
    %deinit_grid_tensor_9 : [#users=1] = call_function[target=brt.runtime.grid_tensor.deinit_grid_tensor](args = (%getitem, False), kwargs = {}) | fixed
    %subnets_0_head_0 : [#users=3] = call_module[target=subnets.0.head.0](args = (%deinit_grid_tensor_9,), kwargs = {}) | fixed
    %BRT_VF__subnets_0_body_0_0_body_0__subnets_0_body_0_0_body_1 : [#users=1] = call_module[target=BRT_VF__subnets_0_body_0_0_body_0__subnets_0_body_0_0_body_1](args = (%subnets_0_head_0,), kwargs = {}) | fixed
    %BRT_VF__subnets_0_body_0_0_body_2__add : [#users=2] = call_module[target=BRT_VF__subnets_0_body_0

In [None]:
print(type(x))

y = livesr(x)

livesr_vf.delete_all_unused_submodules()


def print_hook(m: nn.Module, i, o):
    """Forward hook: print the module's qualified name, class, input types and output type."""
    name = target_of_module[m]
    if isinstance(o, (tuple, list)):
        # The old code indexed o[0] unconditionally and raised on empty outputs.
        out_desc = "[" + o[0].__class__.__name__ + ", ...]" if len(o) > 0 else "[]"
    else:
        out_desc = o.__class__.__name__
    in_desc = str([ii.__class__.__name__ for ii in i])
    print(f"{name:50} {m._get_name():20}{in_desc:30}{out_desc:20s}")


def get_scatter_outputs(m, i, o):
    """Forward hook: keep the most recent ScatterRouter output in scatter_outputs[0]."""
    scatter_outputs[0] = o


all_hooks = []
target_of_module = {}
scatter_outputs = [None]
try:
    # Hooks are defined once above instead of being re-created per iteration.
    for subn, subm in livesr_vf.named_modules():
        if "classifier" in subn:
            continue
        target_of_module[subm] = subn
        all_hooks.append(subm.register_forward_hook(print_hook))
        if isinstance(subm, ScatterRouter):
            all_hooks.append(subm.register_forward_hook(get_scatter_outputs))

    vy = livesr_vf(x)
finally:
    # Always detach hooks so later forward passes are not traced/printed.
    for hook in all_hooks:
        hook.remove()


In [None]:
# Re-run the kernel of the first vertically-fused body block in isolation.
fused_name = "BRT_VF__subnets_0_body_0_0_body_0__subnets_0_body_0_0_body_1"
subm = livesr_vf.get_submodule(fused_name)
func = subm.function
print(func)

# Capture the fused module's real positional inputs. In the fused graph this
# module consumes `%subnets_0_head_0` (8 channels), not the scatter output that
# feeds the 3-channel head conv. The old cell passed scatter_outputs[0][0],
# which handed the kernel a tensor of the wrong shape.
captured_inputs = []
pre_hook = subm.register_forward_pre_hook(lambda mod, inp: captured_inputs.append(inp))
try:
    livesr_vf(x)
finally:
    pre_hook.remove()
assert captured_inputs, f"{fused_name} was not invoked during livesr_vf(x)"
module_inputs = captured_inputs[0]

# Argument order noted for the fused call: inputs[0], self.weight, self.bias, *inputs[1:]
func_in = (module_inputs[0], subm.weight, subm.bias, *module_inputs[1:])
print([type(t) for t in func_in])
func_out = func.apply(*func_in)
print([type(t) for t in func_out])


In [None]:
# Compare the original and vertically-fused models on the same batch.
y = livesr(x)
vy = livesr_vf(x)

for shape in (y.shape, vy.shape):
    print(shape)
# rtol is effectively zero, so this is an absolute-tolerance (0.1) check.
print(torch.allclose(y, vy, rtol=1e-100, atol=1e-1))

delta = y - vy
# For each statistic, print it on the reference output, then on the error.
statistics = (
    torch.sum,
    lambda t: torch.sum(torch.abs(t)),
    torch.max,
    lambda t: torch.min(torch.abs(t)),
    torch.min,
)
for stat in statistics:
    print(stat(y))
    print(stat(delta))


In [None]:
class TestModule(nn.Module):
    """Small two-branch scatter/gather model used to exercise tracing and VerticalFusePass.

    A ``ScatterRouter`` splits the input into two outputs (``y0``, ``y1``), each
    branch runs its own expert, and a ``GatherRouter`` merges them back.
    """

    def __init__(self) -> None:
        super().__init__()
        # Fused-dispatch scatter router with fixed_capacity=(4, 6); presumably one
        # capacity per output path -- TODO confirm against brt's fabric docs.
        # capturing=True / capture_mode="max" presumably make the router record
        # its load history (read later via `router.load_history`).
        self.router = ScatterRouter(
            fabric_type="_fused_dispatch",
            fabric_kwargs={
                "fixed_capacity": torch.tensor((4, 6), dtype=torch.int32).cuda()
            },
            capturing=True,
            capture_mode="max",
        )
        # self.y1_bias = torch.nn.Parameter(torch.randn(channels, 32, 32))
        # NOTE(review): y1_bias is a plain CUDA tensor attribute, not a registered
        # Parameter/buffer, so it is absent from state_dict() and module-level
        # .to()/.cuda() calls do not move it.
        self.y1_bias = torch.randn(channels, 32, 32).cuda()
        # Branch 0: single conv + ReLU (used as a residual below).
        self.expert0 = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1), nn.ReLU()
        )
        # Branch 1: conv -> ReLU -> conv.
        self.expert1 = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )
        self.gather = GatherRouter()

    def forward(self, x=None):
        """Route ``x`` to the two experts with fixed scores and gather the results.

        The exact sequence of (in-place) ops here is what gets traced, so keep it
        unchanged when editing.
        """
        # Hard-coded (3, 2) int32 score tensor; presumably one row per sample
        # and one column per output path -- TODO confirm brt routing semantics.
        y0, y1 = self.router(
            x,
            torch.tensor(((3, 0), (2, 3), (0, 1)), dtype=torch.int32).cuda(),
            # torch.randint(0, 3, (3, 2)).cuda(),
        )
        # In-place residual update on branch 0.
        y0 += self.expert0(y0)
        y1 = self.expert1(y1)
        # NOTE(review): this doubles the stored bias on every forward call, so
        # successive calls (e.g. warm-up runs before tracing) use a different
        # bias and produce different outputs -- confirm this is intentional.
        self.y1_bias += self.y1_bias
        y1 += self.y1_bias
        y = self.gather([y0, y1])
        return y


In [None]:
# Warm up a fresh TestModule so its capturing router records load history,
# then trace it with shape tracking and fixed inputs.
# NOTE(review): each forward call also doubles m.y1_bias (see TestModule.forward).
m = TestModule().eval().cuda()
for _ in range(2):
    # Allocate directly on the GPU instead of building on CPU and copying over.
    m(torch.randn(3, channels, 32, 32, device="cuda"))
print(m.router.load_history)
gm = symbolic_trace(
    m,
    tracing_shape=True,
    sample_inputs={"x": torch.randn(3, channels, 32, 32, device="cuda")},
    fixed_inputs=True,
)
print(gm.graph)
# Per-node fixed-ness and the shapes recorded by tracing_shape=True.
for node in gm.graph.nodes:
    print(node.name, node.is_fixed_inout, node.inshape, node.outshape)


In [None]:
# Trace a fresh, un-warmed TestModule with default symbolic_trace options.
# NOTE(review): this rebinds `m`, so the following VerticalFusePass cell uses this
# fresh module rather than the warmed-up one from the previous cell -- confirm intended.
m = (
    TestModule()
    .eval()
    .cuda()
)
gm = symbolic_trace(m)
print(gm.graph)


In [None]:
# Vertically fuse TestModule, fixing its scatter router and inputs.
vf_options = {
    "sample_inputs": {"x": torch.randn(3, channels, 32, 32).cuda()},
    "fixing_scatters": True,
    "fixed_inputs": True,
}
vertical_fuse_pass = VerticalFusePass(m, **vf_options)
vertical_fuse_pass.run_on_graph()
gm = vertical_fuse_pass.finalize()
print(gm.graph)
