## Initialize

In [1]:
# IPython magic: switch the shell's color scheme off so output and tracebacks
# are emitted without ANSI color codes.
%colors nocolor


In [2]:
import os
import inspect
import numpy as np

import torch
from torch import nn
from torch import fx
from torch.fx import GraphModule, Graph, Node
from torch.utils.benchmark import Timer

import brt
from brt.runtime import log
from brt.runtime import ProtoTensor
from brt.runtime.benchmark import profile
from brt.router import ScatterRouter, GatherRouter
from brt.router.fabric import make_fabric
from brt.trace import symbolic_trace, GraphTracer

# from brt.trace.graph import symbolic_trace
from brt.passes import (
    HorizFusePass,
    OperatorReorderPass,
    DeadPathEliminatePass,
    ConstantPropagationPass,
    RouterFixPass,
)

# Set the level of the "BRT" logger to DEBUG for this session.
log.set_level("BRT", "DEBUG")

# Debugging aid, disabled by default: uncomment to export CUDA_LAUNCH_BLOCKING=1
# (presumably to make CUDA kernel launches synchronous while chasing errors).
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [3]:
import sys
from brt.runtime import BRT_CACHE_PATH
# Put the LiveSR benchmark directory on sys.path so the imports below resolve
# (presumably `archs.livesr` and `dataset` live there -- TODO confirm).
sys.path.append(str(BRT_CACHE_PATH.parent / "benchmark/livesr/"))
# from nas_mdsr import SingleNetwork as nas_mdsr
from archs.livesr import LiveSR
from dataset import get_dataloader

from argparse import Namespace

# Same for the MSDNet benchmark directory. NOTE: these imports must stay below
# the sys.path.append call they depend on.
sys.path.append(str(BRT_CACHE_PATH.parent / "benchmark/msdnet/"))
from msdnet import MSDNet
from theshold_inference import threshold_dynamic_evaluate
from dataloader import get_dataloaders as msdnet_get_dataloaders


IndentationError: unexpected indent (livesr.py, line 49)

In [None]:
# Notebook-wide switches.
# IS_PROFILING: when True, the timing cells additionally call `profile(...)`
# on the model before measuring it.
IS_PROFILING = False
# IS_FUSING_HEAD: forwarded as `fusing_head=` to HorizFusePass for MSDNet.
IS_FUSING_HEAD = True

## MSDNet

In [None]:
# Configuration object handed to MSDNet(...) and msdnet_get_dataloaders(...).
# The fields are grouped by topic; the values are unchanged.
args = Namespace(
    # --- network architecture ---
    arch="msdnet",
    base=4,
    step=4,
    stepmode="even",
    nBlocks=5,
    nChannels=32,
    nScales=4,
    growthRate=16,
    grFactor=[1, 2, 4, 4],
    bnFactor=[1, 2, 4, 4],
    bottleneck=True,
    prune="max",
    reduction=0.5,
    num_classes=1000,
    init_routers=True,
    # --- data ---
    data="ImageNet",
    data_root="/home/v-louyang/dataset/imagenet",
    splits=["val", "test"],
    use_valid=True,
    batch_size=256,
    workers=16,
    # --- optimizer / schedule ---
    optimizer="sgd",
    lr=0.1,
    lr_type="multistep",
    decay_rate=0.1,
    momentum=0.9,
    weight_decay=0.0001,
    epochs=90,
    start_epoch=0,
    resume=False,
    # --- evaluation ---
    evalmode="threshold",
    evaluate_from="/home/v-louyang/brainstorm_project/brainstorm/benchmark/msdnet/msdnet-step=4-block=5.pth.tar",
    benchmark=["all_opt"],
    # Alternative threshold sets kept for reference (the trailing numbers are
    # presumably the intended share of samples per exit -- TODO confirm):
    #   [0.44246858, -1, -1, -1]                          -> 0.5 0.5 0 0
    #   [0.34071380, 0.47392023, 0.37517136, 0.22579938]  -> 0.6 0.1 0.1 0.1 0.1
    # Active set                                          -> 0, 0, 0, 0, 1
    thresholds=[1000000, 100000, 1000000, 100000],
    # --- runtime / bookkeeping ---
    gpu="0,1,2,3",
    parallel=True,
    print_freq=10,
    save="/home/v-louyang/brainstorm_project/brainstorm/benchmark/msdnet/saveresult",
    seed=0,
)

# Build the MSDNet model, switch it to eval mode and move it to the GPU.
msdnet: nn.Module = MSDNet(args, False)
msdnet.eval()
msdnet.cuda()

# Restore the stored weights.
# NOTE(review): torch.load unpickles the file, so only point this at a
# checkpoint from a trusted source.
msdnet_weights_path = (
    "/home/v-louyang/brainstorm_project/brainstorm/benchmark/msdnet/MSDNet.pth"
)
state_dict = torch.load(msdnet_weights_path)
msdnet.load_state_dict(state_dict)

# The first loader returned is not used in this notebook.
_, val_dataloader, test_dataloader = msdnet_get_dataloaders(args)

# print(msdnet)
# print(input.shape)


def print_load_history(m: nn.Module):
    """Print `load_history` of every scatter/gather router found inside `m`.

    An empty line is printed first; then one `load_history` value per
    ScatterRouter / GatherRouter, in `named_modules()` order.
    """
    print("")
    routers = (
        module
        for _, module in m.named_modules()
        if isinstance(module, (ScatterRouter, GatherRouter))
    )
    for router in routers:
        print(router.load_history)


# Run the model over the first 101 test batches (presumably so the routers
# accumulate their `load_history`), printing one "*" per processed batch.
for batch_idx, (input, target) in enumerate(test_dataloader):
    input = input.cuda()
    if batch_idx >= 101:
        break
    print("*", end="")
    output = msdnet(input)

# Terminate the progress line, then show what the routers recorded.
print("")
print_load_history(msdnet)

# Select the sample batch used by the following cells: the batch at index 13
# of the test loader (or the last batch if the loader is shorter).
# Only the selected batch is moved to the GPU; the skipped batches stay on the
# host instead of being uploaded and immediately discarded.
for i, (input, target) in enumerate(test_dataloader):
    if i == 13:
        break
input = input.cuda()

# Reference forward pass of the unmodified model on the sample batch.
y = msdnet(input)



In [None]:
# Collect the first 10 test batches on the GPU for the timing cells.
# The loop stops as soon as enough batches were gathered instead of iterating
# over (and loading) the remainder of the test set.
test_inputs = []

for i, (test_input, target) in enumerate(test_dataloader):
    if i >= 10:
        break
    test_inputs.append(test_input.cuda())

In [None]:
# Optional profiling run of the original model on the sample batch.
if IS_PROFILING:
    profile(lambda: msdnet(input))

# Mean latency of the original model for every collected test batch
# (10 timed runs each).
# NOTE(review): `10e6` equals 1e7, so these values are not microseconds (that
# would be 1e6). `hf_time` is scaled by the same factor, so the speedup ratio
# computed later is unaffected -- change both together if this is fixed.
raw_time = [
    Timer(
        "model(x)",
        setup="import torch; torch.cuda.synchronize()",
        globals={"model": msdnet, "x": batch},
    )
    .timeit(10)
    .mean
    * 10e6
    for batch in test_inputs
]

In [None]:
# gm_msdnet = symbolic_trace(
#     msdnet,
#     tracing_shape=True,
#     sample_inputs={"x": input},
# )
# print(gm_msdnet.graph)

In [None]:
# print(msdnet)

In [None]:
def _apply_pass(graph_pass):
    """Run `graph_pass` over its graph and return the finalized module."""
    graph_pass.run_on_graph()
    return graph_pass.finalize()


# Pipeline: dead-path elimination -> operator reordering -> horizontal fusion.
# Each stage consumes the module produced by the previous one.
eliminate_pass = DeadPathEliminatePass(msdnet)
msdnet_dpe = _apply_pass(eliminate_pass)

# Constant propagation is currently disabled:
# constant_propagation_pass = ConstantPropagationPass(
#     msdnet, upper_perm_load=args.batch_size * n_batch
# )
# constant_propagation_pass.run_on_graph()
# msdnet = constant_propagation_pass.finalize()

operator_reorder_pass = OperatorReorderPass(msdnet_dpe, False)
msdnet_reorder = _apply_pass(operator_reorder_pass)

# The fusion pass receives the sample batch as `sample_inputs`.
horiz_fusion_pass = HorizFusePass(
    msdnet_reorder,
    sample_inputs={"x": input},
    fusing_head=IS_FUSING_HEAD,
)
msdnet_hf = _apply_pass(horiz_fusion_pass)


In [None]:
# Show the graph of the module returned by the horizontal fusion pass.
print(msdnet_hf.graph)

In [None]:
# Show the `code` attribute of the fused module (presumably the Python source
# generated for its forward -- TODO confirm msdnet_hf is an fx.GraphModule).
print(msdnet_hf.code)


In [None]:
# Inspect the first fused module in the graph: the first `call_module` node
# with `is_fixed_inout` set whose target is not a router and whose name
# contains "BRT_HF". Its target, input shapes/devices, module name and CUDA
# source are printed, then the search stops.
for node in msdnet_hf.graph.nodes:
    if node.op != "call_module" or not node.is_fixed_inout:
        continue
    submodule = msdnet_hf.get_submodule(node.target)
    if isinstance(submodule, (ScatterRouter, GatherRouter)):
        continue
    if "BRT_HF" not in node.name:
        continue
    print(f"{node.target}")
    # NOTE(review): relies on a private helper of the graph object.
    submodule_input = msdnet_hf.graph._get_output_from_node_or_list(node.args)
    print([getattr(ii, "shape", None) for ii in submodule_input])
    print([ii.is_cuda for ii in submodule_input])
    print(submodule._module_name)
    print(submodule.cuda_code)
    break
        

In [None]:
from torch.fx.passes.graph_drawer import FxGraphDrawer

# Render the fused module's graph and store it as an SVG file in the current
# working directory.
graph_drawer = FxGraphDrawer(msdnet_hf, "msdnet")
with open("msdnet_hfused.svg", "wb") as svg_file:
    dot_graph = graph_drawer.get_dot_graph()
    svg_file.write(dot_graph.create_svg())


In [None]:
# Rebuild a fresh MSDNet from the same configuration and weights (presumably
# because the passes above modified the previous instance -- TODO confirm).
msdnet: nn.Module = MSDNet(args, False)
msdnet.eval()
msdnet.cuda()
msdnet.load_state_dict(state_dict)


In [None]:
# Compare original and fused model on the sample batch (absolute tolerance
# 1e-2; rtol is effectively zero).
y = msdnet(input)
hy = msdnet_hf(input)
print(torch.allclose(y, hy, rtol=1e-100, atol=1e-2))

# Repeat the comparison on the first 101 test batches with a looser absolute
# tolerance of 1e-1; print error magnitudes for every mismatching batch.
for i, (input, target) in enumerate(test_dataloader):
    input = input.cuda()
    if i > 100:
        break
    y = msdnet(input)
    hy = msdnet_hf(input)
    outputs_match = torch.allclose(y, hy, rtol=1e-100, atol=1e-1)
    print(outputs_match)
    if outputs_match:
        continue
    difference = y - hy
    print(torch.sum(torch.abs(y)))
    print(torch.sum(torch.abs(difference)))
    print(torch.max(difference))

# print(torch.sum(y))
# print(torch.sum(y - hy))
# print(torch.sum(torch.abs(y)))
# print(torch.max(y))
# print(torch.max(y - hy))
# print(torch.min(torch.abs(y)))
# print(torch.min(torch.abs(y - hy)))
# print(torch.min(y))
# print(torch.min(y - hy))

In [None]:
# Optional profiling run of the fused model on the current `input` batch.
if IS_PROFILING:
    profile(lambda: msdnet_hf(input))

# Mean latency of the fused model for every collected test batch
# (10 timed runs each).
# NOTE(review): `10e6` equals 1e7, so these values are not microseconds (that
# would be 1e6). `raw_time` is scaled by the same factor, so the speedup ratio
# is unaffected -- change both together if this is fixed.
hf_time = [
    Timer(
        "model(x)",
        setup="import torch; torch.cuda.synchronize()",
        globals={"model": msdnet_hf, "x": batch},
    )
    .timeit(10)
    .mean
    * 10e6
    for batch in test_inputs
]


In [None]:

# Per-batch speedup of the fused model over the original one
# (values above 1 mean the fused model was faster).
speedup = [baseline / fused for baseline, fused in zip(raw_time, hf_time)]

# Report best, worst and mean speedup, in that order.
best_speedup = max(speedup)
worst_speedup = min(speedup)
mean_speedup = sum(speedup) / len(speedup)
print(best_speedup)
print(worst_speedup)
print(mean_speedup)

In [None]:
# Always raises AssertionError. Presumably a deliberate stop so that "Run All"
# does not continue into the LiveSR section below -- TODO confirm / remove.
assert False

## LiveSR

In [None]:
# Passed to LiveSR as `num_feature` when the model is built.
channels = 8


In [None]:
# Build the LiveSR model (10 subnets, 3 blocks each) in eval mode on the GPU.
livesr = LiveSR(n_subnets=10, subnet_num_block=3, num_feature=channels)
livesr = livesr.eval().cuda()

# Data loader over the low-quality frames shipped with the benchmark.
livesr_lq_dir = BRT_CACHE_PATH.parent / "benchmark/livesr/dataset/cam1/LQ"
dataloader = get_dataloader(str(livesr_lq_dir))

# Keep only the first batch as the sample input `x`.
for x in dataloader:
    break

# One forward pass, then show the load history the scatter router recorded.
livesr(x)
print(livesr.scatter.load_history)

# Replace the recorded history with fixed per-subnet values.
# NOTE(review): the reason for this override is not visible here -- presumably
# it pins the loads seen by the passes applied afterwards; confirm.
fixed_load_history = [6, 7, 12, 27, 8, 8, 8, 12, 12, 4]
livesr.scatter.load_history = np.array(fixed_load_history, dtype=int)




[ 5.  3. 14. 26. 10.  9.  6.  7.  2.  6.]


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [None]:
# Graph of the model when traced with default options, for reference.
plain_graph = symbolic_trace(livesr).graph
print(plain_graph)

# Trace again with shape tracing enabled, using the sample batch.
gm_livesr = symbolic_trace(livesr, tracing_shape=True, sample_inputs=dict(inputs=x))

# Apply the router-fix pass ...
router_fix_pass = RouterFixPass(gm_livesr)
router_fix_pass.run_on_graph()
gm_livesr = router_fix_pass.finalize()

# ... followed by horizontal fusion on its result.
horizontal_fuse_pass = HorizFusePass(gm_livesr, sample_inputs=dict(inputs=x))
horizontal_fuse_pass.run_on_graph()
gm_livesr = horizontal_fuse_pass.finalize()

# Graph after both passes.
print(gm_livesr.graph)


graph():
    %inputs : torch.Tensor [#users=2] = placeholder[target=inputs]
    %classifier : [#users=1] = call_module[target=classifier](args = (%inputs,), kwargs = {})
    %scatter : [#users=10] = call_module[target=scatter](args = (%inputs, %classifier), kwargs = {})
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%scatter, 0), kwargs = {})
    %subnets_0_head_0 : [#users=3] = call_module[target=subnets.0.head.0](args = (%getitem,), kwargs = {})
    %subnets_0_body_0_0_body_0 : [#users=1] = call_module[target=subnets.0.body.0.0.body.0](args = (%subnets_0_head_0,), kwargs = {})
    %subnets_0_body_0_0_body_1 : [#users=1] = call_module[target=subnets.0.body.0.0.body.1](args = (%subnets_0_body_0_0_body_0,), kwargs = {})
    %subnets_0_body_0_0_body_2 : [#users=1] = call_module[target=subnets.0.body.0.0.body.2](args = (%subnets_0_body_0_0_body_1,), kwargs = {})
    %add : [#users=2] = call_function[target=operator.add](args = (%subnets_0_body_0_0_body_2, %subn

In [None]:
# Reference output of the original model.
y = livesr(x)

gm_livesr.delete_all_unused_submodules()

all_hooks = []            # handles of every registered hook, removed at the end
target_of_module = {}     # module object -> qualified name inside gm_livesr
scatter_outputs = [None]  # slot 0 receives the output of the scatter router


def _print_proto_shapes(tensors):
    """Print the shape of every ProtoTensor contained in `tensors`."""
    for tensor in tensors:
        if isinstance(tensor, ProtoTensor):
            # print("\t\t", [tag.squeeze().cpu() for tag in tensor.tag_stack])
            print("\t\t", tensor.shape)


def print_pre_hook(m: nn.Module, i):
    """Forward pre-hook: print module name, class and the input type names."""
    name = target_of_module[m]
    input_types = str({ii.__class__.__name__ for ii in i})
    print(f"{name:50.50} {m._get_name():20} {input_types:30}")


def print_hook(m: nn.Module, i, o):
    """Forward hook: print module name, class, input and output type names.

    Shapes of ProtoTensor outputs are listed when the output is a list or
    tuple; for the module named "gather.fabric" the ProtoTensor shapes of its
    first input are listed as well.
    """
    name = target_of_module[m]
    input_types = str({ii.__class__.__name__ for ii in i})
    output_types = str({oo.__class__.__name__ for oo in o})
    print(f"{name:50.50} {m._get_name():20} {input_types:30} {output_types:30} ")
    # print("\t\t", getattr(o, "shape", None))
    if isinstance(o, (list, tuple)):
        _print_proto_shapes(o)
    if name == "gather.fabric":
        _print_proto_shapes(i[0])


def get_scatter_outputs(m, i, o):
    """Forward hook: remember the most recent output of a scatter router."""
    scatter_outputs[0] = o


try:
    # Hook every submodule except those whose name contains "classifier".
    for subn, subm in gm_livesr.named_modules():
        if "classifier" in subn:
            continue
        target_of_module[subm] = subn

        all_hooks.append(subm.register_forward_hook(print_hook))
        # all_hooks.append(subm.register_forward_pre_hook(print_pre_hook))

        if isinstance(subm, ScatterRouter):
            all_hooks.append(subm.register_forward_hook(get_scatter_outputs))

    # Forward pass of the fused model; the hooks print a trace while it runs.
    hy = gm_livesr(x)
finally:
    # Always detach the hooks, even when the forward pass raised.
    for hook in all_hooks:
        hook.remove()


scatter.protocol                                   TopKProtocol         {'Tensor'}                     {'Tensor'}                     
scatter.fabric                                     DispatchFabric       {'Tensor'}                     {'ProtoTensor'}                
		 torch.Size([5, 3, 32, 32])
		 torch.Size([3, 3, 32, 32])
		 torch.Size([14, 3, 32, 32])
		 torch.Size([26, 3, 32, 32])
		 torch.Size([10, 3, 32, 32])
		 torch.Size([9, 3, 32, 32])
		 torch.Size([6, 3, 32, 32])
		 torch.Size([7, 3, 32, 32])
		 torch.Size([2, 3, 32, 32])
		 torch.Size([6, 3, 32, 32])
scatter                                            ScatterRouter        {'Tensor'}                     {'ProtoTensor'}                
		 torch.Size([5, 3, 32, 32])
		 torch.Size([3, 3, 32, 32])
		 torch.Size([14, 3, 32, 32])
		 torch.Size([26, 3, 32, 32])
		 torch.Size([10, 3, 32, 32])
		 torch.Size([9, 3, 32, 32])
		 torch.Size([6, 3, 32, 32])
		 torch.Size([7, 3, 32, 32])
		 torch.Size([2, 3, 32, 32])
		 torch.Size([6, 3

In [None]:
# Output of the original LiveSR model on the sample batch.
y = livesr(x)


In [None]:
# Output of the traced and fused module on the same sample batch.
hy = gm_livesr(x)


Bad pipe message: %s [b'\x87\x8cB\xc5q4\x91\x95\xb9\x9ak\xbfPYGZ\xf7\xdb e\x04\xb9\xa3\x10\x88\xf0\xf1fr\xcds\x13\xae1\xac,R\xbb\xb8L\xa9\x93\xa7\xd0)-/\x7f4y\x0e\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00']
Bad pipe message: %s [b'']
Bad pipe message: %s [b'\xd2]\x7f\x84]+\x0b\x1f\xf9\x15\xf23\xfc\x9c\xfcT\x9a\xb5 \xeeK\xf8\x9f[\x16\xb2+\xdc\xba\xa5\xee8\xe8\xc2\xbf\xed\xbal\x91]|\x05K`I\r\xaa\nE\x89\xfc\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08', b'\x08\x08\t\x08\n\x08']
Bad pipe message: %s [b"y\x06|\x81\xff\xcf\x02A\xf2\xbf\xed\x93\xdb\xaa\xeb\xa59+\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\

In [None]:
# Shapes of the reference and the fused output.
print(y.shape)
print(hy.shape)

# Element-wise agreement within an absolute tolerance of 1e-2.
print(torch.allclose(y, hy, rtol=1e-100, atol=1e-2))

# Summary statistics, each for the reference output and for the difference:
# sum, sum of magnitudes, max, smallest magnitude, min.
difference = y - hy
abs_reference = torch.abs(y)
abs_difference = torch.abs(difference)
print(torch.sum(y))
print(torch.sum(difference))
print(torch.sum(abs_reference))
print(torch.sum(abs_difference))
print(torch.max(y))
print(torch.max(difference))
print(torch.min(abs_reference))
print(torch.min(abs_difference))
print(torch.min(y))
print(torch.min(difference))


In [None]:
def _mean_latency_us(model, sample, runs=100):
    """Return the mean latency of `model(sample)` in microseconds.

    `Timer.timeit(runs).mean` is reported in seconds, so it is scaled by 1e6.
    (The previous factor `10e6` is 1e7 and inflated the values tenfold.)
    """
    measurement = Timer(
        "model(x)",
        setup="import torch; torch.cuda.synchronize()",
        globals={"model": model, "x": sample},
    ).timeit(runs)
    return measurement.mean * 1e6


# Mean latency (microseconds) of the original and of the fused LiveSR model.
raw_time = _mean_latency_us(livesr, x)
hf_time = _mean_latency_us(gm_livesr, x)

print(raw_time)
print(hf_time)
# Speedup of the fused model over the original one.
print(raw_time / hf_time)