In [None]:
import os

import numpy as np

import torch
from torch import nn
from torch import fx
from torch.fx import GraphModule, Graph, Node

import brt
from brt.runtime import log
from brt.router import ScatterRouter, GatherRouter
from brt.router.fabric import make_fabric
from brt.trace import symbolic_trace, GraphTracer

# from brt.trace.graph import symbolic_trace
from brt.passes import (
    HorizFusePass,
    OperatorReorderPass,
    DeadPathEliminatePass,
    ConstantPropagationPass,
)

# Request DEBUG-level output for the "BRT" logger (presumably makes the brt internals
# verbose while the passes run — confirm against brt.runtime.log).
log.set_level("BRT", "DEBUG")


In [None]:
import sys
from brt.runtime import BRT_CACHE_PATH

# Put the LiveSR benchmark directory (resolved relative to BRT_CACHE_PATH's parent) on
# sys.path so that `archs.livesr` and `dataset` below can be imported — presumably because
# the benchmark code is not an installed package.
sys.path.append(str(BRT_CACHE_PATH.parent / "benchmark/livesr/"))
# from nas_mdsr import SingleNetwork as nas_mdsr
from archs.livesr import LiveSR
from dataset import get_dataloader

from argparse import Namespace

# Same approach for the MSDNet benchmark directory.
# NOTE(review): both benchmark directories end up on sys.path, so generically named
# modules (`dataset`, `dataloader`) resolve by path order — confirm there is no clash.
sys.path.append(str(BRT_CACHE_PATH.parent / "benchmark/msdnet/"))
from msdnet import MSDNet
from theshold_inference import threshold_dynamic_evaluate
from dataloader import get_dataloaders as msdnet_get_dataloaders


In [None]:
# Machine-specific locations. The defaults are the paths this notebook was originally
# written with, so behaviour is unchanged when the variables are unset; export
# BRT_MSDNET_DIR / BRT_IMAGENET_ROOT to run the notebook elsewhere without editing it.
MSDNET_DIR = os.environ.get(
    "BRT_MSDNET_DIR",
    "/home/v-louyang/brainstorm_project/brainstorm/benchmark/msdnet",
)
IMAGENET_ROOT = os.environ.get(
    "BRT_IMAGENET_ROOT",
    "/home/v-louyang/dataset/ILSVRC2012",
)

# Hand-built argparse-style namespace; it is passed to MSDNet(...) and to
# msdnet_get_dataloaders(...) later in this cell.
args = Namespace(
    arch="msdnet",
    base=4,
    batch_size=256,
    benchmark=["all_opt"],
    bnFactor=[1, 2, 4, 4],
    bottleneck=True,
    data="ImageNet",
    data_root=IMAGENET_ROOT,
    decay_rate=0.1,
    epochs=90,
    evalmode="threshold",
    evaluate_from=os.path.join(MSDNET_DIR, "msdnet-step=4-block=5.pth.tar"),
    gpu="0,1,2,3",
    grFactor=[1, 2, 4, 4],
    growthRate=16,
    init_routers=True,
    lr=0.1,
    lr_type="multistep",
    momentum=0.9,
    nBlocks=5,
    nChannels=32,
    nScales=4,
    num_classes=1000,
    optimizer="sgd",
    parallel=True,
    print_freq=10,
    prune="max",
    reduction=0.5,
    resume=False,
    save=os.path.join(MSDNET_DIR, "saveresult"),
    seed=0,
    splits=["val", "test"],
    start_epoch=0,
    step=4,
    stepmode="even",
    thresholds=[0.44246858, -1, -1, -1],  # 0.5 0.5 0 0
    # thresholds=[0.34071380, 0.47392023, 0.37517136, 0.22579938],  # 0.6 0.1 0.1 0.1 0.1
    use_valid=True,
    weight_decay=0.0001,
    workers=16,
)

# Directory holding the MSDNet checkpoint. Defaults to the original hard-coded location so
# nothing changes when BRT_MSDNET_DIR is unset; set it to run on another machine.
msdnet_checkpoint_dir = os.environ.get(
    "BRT_MSDNET_DIR",
    "/home/v-louyang/brainstorm_project/brainstorm/benchmark/msdnet",
)

# Build the network in eval mode on the GPU, then restore the saved weights.
msdnet: nn.Module = MSDNet(args, False).eval().cuda()
state_dict = torch.load(os.path.join(msdnet_checkpoint_dir, "MSDNet.pth"))
# Uncomment to compare parameter names when load_state_dict reports key mismatches:
# print([k for k, v in msdnet.named_parameters()])
# print([k for k, v in state_dict.items()])
msdnet.load_state_dict(state_dict)

# Only the validation and test loaders are kept; the first returned loader is discarded.
_, val_dataloader, test_dataloader = msdnet_get_dataloaders(args)

# print(msdnet)
# print(input.shape)


def print_load_history(m: nn.Module):
    """Print a blank line, then the ``load_history`` of every router inside ``m``.

    Walks ``m`` and all of its descendants (via ``named_modules``) and, for each
    sub-module that is a ``ScatterRouter`` or ``GatherRouter``, prints its
    ``load_history`` attribute on its own line.
    """
    print("")
    routers = (
        child
        for _, child in m.named_modules()
        if isinstance(child, (ScatterRouter, GatherRouter))
    )
    for router in routers:
        print(router.load_history)


# Index of the last batch that is forwarded through the network (batches 0..100).
LAST_PROFILED_BATCH = 100

# Feed a slice of the test set through MSDNet, then print the routers' load history
# (presumably the forward passes are what populate `load_history` — confirm in brt.router).
for i, (input, target) in enumerate(test_dataloader):
    # The transfer happens before the bound check on purpose: a later cell reuses `input`
    # as a sample input, so the tensor left behind by `break` must already be on the GPU.
    input = input.cuda()
    if i > LAST_PROFILED_BATCH:
        break
    # if i % 1000 == 0:
    #     print_load_history(msdnet)
    print("*", end="")  # one progress mark per forwarded batch
    # Inference only: nothing is back-propagated, so skip autograd bookkeeping instead of
    # keeping every intermediate activation alive for a backward pass that never happens.
    with torch.no_grad():
        output = msdnet(input)

print("")
print_load_history(msdnet)


In [7]:
# Optimisation pipeline applied to MSDNet. Every pass is used the same way: construct it
# with the current module, call run_on_graph(), and take the module returned by finalize()
# as the input of the next pass. Order here: DeadPathEliminate -> OperatorReorder -> HorizFuse.
# NOTE(review): the saved output of this cell stops inside the `print(msdnet)` dump and ends
# with "RuntimeError: cuDNN error: CUDNN_STATUS_MAPPING_ERROR"; which statement raised it is
# not visible in the recorded output.
eliminate_pass = DeadPathEliminatePass(msdnet)
eliminate_pass.run_on_graph()
msdnet = eliminate_pass.finalize()
# Disabled constant-propagation step.
# NOTE(review): it references `n_batch`, which is not defined in the cells shown here.
# constant_propagation_pass = ConstantPropagationPass(
#     msdnet, upper_perm_load=args.batch_size * n_batch
# )
# constant_propagation_pass.run_on_graph()
# msdnet = constant_propagation_pass.finalize()
operator_reorder_pass = OperatorReorderPass(msdnet, False)
operator_reorder_pass.run_on_graph()
msdnet = operator_reorder_pass.finalize()
# print(GraphModule.__repr__(msdnet))
# Dump the module and its graph as they look after the first two passes.
print("########## before HorizFusePass")
print(msdnet)
print(msdnet.graph)
# `input` is the last batch left over from the profiling loop in the previous cell
# (already moved to the GPU there); it is passed as the sample input named "x".
horiz_fusion_pass = HorizFusePass(
    msdnet, sample_inputs={"x": input}, fixing_scatters=True
)
print("########## before HorizFusePass runs")
print(horiz_fusion_pass.origin_graph)
horiz_fusion_pass.run_on_graph()
msdnet = horiz_fusion_pass.finalize()
print("########## after HorizFusePass")
print(msdnet.graph)
# Disabled end-to-end evaluation of the transformed model on random inputs.
# threshold_dynamic_evaluate(
#     msdnet,
#     [[torch.randn(1, 3, 256, 256), torch.tensor((1,))]],
#     [[torch.randn(1, 3, 256, 256), torch.tensor((1,))]],
#     args,
# )


########## before HorizFusePass
MSDNet(
  (blocks): Module(
    (0): Module(
      (0): Module(
        (layers): Module(
          (0): Module(
            (0): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
          )
          (1): Module(
            (net): Module(
              (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
          (2): Module(
            (net): Module(
              (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, trac

RuntimeError: cuDNN error: CUDNN_STATUS_MAPPING_ERROR

In [None]:
# Value passed to LiveSR as `num_feature` when the model is built in the next cell.
channels = 8


In [None]:
# LiveSR model in eval mode on the GPU; `num_feature` comes from `channels` defined above.
livesr = LiveSR(n_subnets=10, subnet_num_block=3, num_feature=channels).eval().cuda()

dataloader = get_dataloader(
    str(BRT_CACHE_PATH.parent / "benchmark/livesr/dataset/cam1/LQ")
)

# Take exactly one batch as the sample input. Unlike `for x in dataloader: break`, this
# raises immediately on an empty loader instead of leaving `x` unbound — or, worse in a
# notebook, silently reusing a stale `x` from an earlier run of the kernel.
x = next(iter(dataloader))

# One forward pass, then show the load history the scatter router recorded for it.
livesr(x)
print(livesr.scatter.load_history)
# Replace the recorded history with a fixed 10-entry load vector (one entry per subnet,
# matching n_subnets=10) so the following cells run against a known load distribution.
livesr.scatter.load_history = np.array([6, 7, 12, 27, 8, 8, 8, 12, 12, 4], dtype=int)


In [None]:
# Print the graph of a plain symbolic trace of LiveSR (no shape tracing requested).
print(symbolic_trace(livesr).graph)

# Trace again, this time with shape tracing enabled and the sample batch `x` supplied
# under the input name "inputs".
# NOTE(review): this `gm_livesr` is overwritten by `finalize()` below before it is read
# (its only use, the print, is commented out) — confirm whether the trace is still needed,
# e.g. for side effects on `livesr`.
gm_livesr = symbolic_trace(
    livesr,
    tracing_shape=True,
    sample_inputs={"inputs": x},
)
# print(gm_livesr.graph)
# The fusion pass is constructed from the original `livesr`, not from the traced module above.
horizontal_fuse_pass = HorizFusePass(
    livesr, sample_inputs={"inputs": x}, fixing_scatters=True
)
horizontal_fuse_pass.run_on_graph()
gm_livesr = horizontal_fuse_pass.finalize()
print(gm_livesr.graph)


In [None]:
# Output of `livesr` on the sample batch; compared against `hy` in the next cell.
y = livesr(x)

gm_livesr.delete_all_unused_submodules()

all_hooks = []  # handles of every registered hook, removed again in `finally`
target_of_module = {}  # module object -> its qualified name inside gm_livesr
scatter_outputs = [None]  # one-slot box holding the latest ScatterRouter output


def _class_names(items):
    """Return the ``str`` of the set of class names of the elements of ``items``."""
    return str(set(item.__class__.__name__ for item in items))


def print_pre_hook(m: nn.Module, i):
    """Forward pre-hook: print the module's name, class and the input class names."""
    name = target_of_module[m]
    print(f"{name:50.50} {m._get_name():20} {_class_names(i):30}")


def print_hook(m: nn.Module, i, o):
    """Forward hook: print the module's name, class, input and output class names."""
    name = target_of_module[m]
    print(
        f"{name:50.50} {m._get_name():20} {_class_names(i):30} "
        f"{_class_names(o):30} "
    )


def get_scatter_outputs(m, i, o):
    """Forward hook: remember the most recent output of the hooked module."""
    scatter_outputs[0] = o


try:
    for subn, subm in gm_livesr.named_modules():
        # Modules whose qualified name mentions "classifier" are not instrumented.
        if "classifier" in subn:
            continue
        target_of_module[subm] = subn

        all_hooks.append(subm.register_forward_hook(print_hook))
        # The pre-hook variant is available but intentionally not registered:
        # all_hooks.append(subm.register_forward_pre_hook(print_pre_hook))

        if isinstance(subm, ScatterRouter):
            all_hooks.append(subm.register_forward_hook(get_scatter_outputs))

    # Forward pass through the fused module with all hooks active.
    hy = gm_livesr(x)
finally:
    # Always detach the hooks, even if registration or the forward pass failed.
    for hook in all_hooks:
        hook.remove()


In [None]:
# Element-wise gap between `y` (from livesr) and `hy` (from gm_livesr).
delta = y - hy

# rtol is vanishingly small, so this is effectively an absolute-tolerance check at 1e-2.
print(torch.allclose(y, hy, rtol=1e-100, atol=1e-2))


def abs_sum(t):
    """Sum of absolute values of ``t``."""
    return torch.sum(torch.abs(t))


def abs_min(t):
    """Smallest absolute value in ``t``."""
    return torch.min(torch.abs(t))


# For each statistic, print it for the reference output first and for the gap second.
for summarize in (torch.sum, abs_sum, torch.max, abs_min, torch.min):
    print(summarize(y))
    print(summarize(delta))
