In [1]:
import numpy as np

import torch
from torch import nn
from torch import fx
from torch.fx import GraphModule, Graph, Node

import brt
from brt.runtime import log
from brt.router import ScatterRouter, GatherRouter
from brt.router.fabric import make_fabric
from brt.trace import symbolic_trace, GraphTracer

# from brt.trace.graph import symbolic_trace
from brt.passes import (
    HorizFusePass,
    OperatorReorderPass,
    DeadPathEliminatePass,
    ConstantPropagationPass,
)

# Set the "BRT" log level to DEBUG for this notebook session.
log.set_level("BRT", "DEBUG")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import sys
from brt.runtime import BRT_CACHE_PATH

# Make the LiveSR benchmark directory importable; the `archs.livesr` and
# `dataset` imports below are presumably resolved from it.
sys.path.append(str(BRT_CACHE_PATH.parent / "benchmark/livesr/"))
# from nas_mdsr import SingleNetwork as nas_mdsr
from archs.livesr import LiveSR
from dataset import get_dataloader

from argparse import Namespace

# Same for the MSDNet benchmark directory; the three imports below are
# presumably resolved from it.
sys.path.append(str(BRT_CACHE_PATH.parent / "benchmark/msdnet/"))
from msdnet import MSDNet
from theshold_inference import threshold_dynamic_evaluate
from dataloader import get_dataloaders as msdnet_get_dataloaders


In [3]:
# Settings handed to `MSDNet(args, ...)` and `msdnet_get_dataloaders(args)` below,
# packed into an argparse Namespace.
# NOTE(review): `data_root`, `evaluate_from` and `save` are absolute paths on one
# particular machine; they have to be adapted before this cell runs elsewhere.
args = Namespace(
    arch="msdnet",
    base=4,
    batch_size=256,
    benchmark=["all_opt"],
    bnFactor=[1, 2, 4, 4],
    bottleneck=True,
    data="ImageNet",
    data_root="/home/v-louyang/dataset/ILSVRC2012",
    decay_rate=0.1,
    epochs=90,
    evalmode="threshold",
    evaluate_from="/home/v-louyang/brainstorm_project/brainstorm/benchmark/msdnet/msdnet-step=4-block=5.pth.tar",
    gpu="0,1,2,3",
    grFactor=[1, 2, 4, 4],
    growthRate=16,
    init_routers=True,
    lr=0.1,
    lr_type="multistep",
    momentum=0.9,
    nBlocks=5,
    nChannels=32,
    nScales=4,
    num_classes=1000,
    optimizer="sgd",
    parallel=True,
    print_freq=10,
    prune="max",
    reduction=0.5,
    resume=False,
    save="/home/v-louyang/brainstorm_project/brainstorm/benchmark/msdnet/saveresult",
    seed=0,
    splits=["val", "test"],
    start_epoch=0,
    step=4,
    stepmode="even",
    # thresholds=[0.44246858, -1, -1, -1], # 0.5 0.5 0 0
    thresholds=[0.34071380, 0.47392023, 0.37517136, 0.22579938],  # 0.6 0.1 0.1 0.1 0.1
    use_valid=True,
    weight_decay=0.0001,
    workers=16,
)

# Build MSDNet from `args`, switch it to eval mode and move it to the GPU.
msdnet: nn.Module = MSDNet(args, False).eval().cuda()
# pretrained_dict = torch.load("/home/v-louyang/brainstorm_project/brainstorm/benchmark/msdnet/MSDNet.pth")
# NOTE(review): the checkpoint path is hard-coded and machine-specific;
# `args.evaluate_from` is not used when loading the weights here.
state_dict = torch.load(
    "/home/v-louyang/brainstorm_project/brainstorm/benchmark/msdnet/MSDNet.pth"
)
# print([k for k, v in msdnet.named_parameters()])
# print([k for k, v in state_dict.items()])
msdnet.load_state_dict(state_dict)

# The first returned loader is discarded; only the second and third are kept.
_, val_dataloader, test_dataloader = msdnet_get_dataloaders(args)

# print(msdnet)
# print(input.shape)


def print_load_history(m: nn.Module):
    """Print the ``load_history`` of every scatter/gather router inside ``m``.

    An empty line is printed first, then one ``print`` per router, in the
    order in which ``m.named_modules()`` yields the routers.
    """
    print("")
    router_types = (ScatterRouter, GatherRouter)
    # Lazy filter: modules are still visited (and printed) one at a time.
    routers = (mod for _, mod in m.named_modules() if isinstance(mod, router_types))
    for router in routers:
        print(router.load_history)


# Profiling run: feed the first 101 test batches through the model, presumably so
# that the routers record their `load_history`, which is printed afterwards.
# The model outputs are not used, so autograd is switched off to avoid building
# and keeping a graph for every forward pass.
with torch.no_grad():
    for i, (input, target) in enumerate(test_dataloader):
        # Keep `.cuda()` ahead of the break check: the last fetched batch stays
        # bound to `input` and is reused later as `sample_inputs={"x": input}`;
        # swapping the two statements would leave that batch on the CPU.
        input = input.cuda()
        if i > 100:
            break
        # if i % 1000 == 0:
        #     print_load_history(msdnet)
        print("*", end="")
        output = msdnet(input)

print("")
print_load_history(msdnet)


building network of steps: 
[4, 4, 4, 4, 4] 20
 ********************** Block 1  **********************
|		inScales 4 outScales 4 inChannels 32 outChannels 16		|

|		inScales 4 outScales 4 inChannels 48 outChannels 16		|

|		inScales 4 outScales 4 inChannels 64 outChannels 16		|

|		inScales 4 outScales 4 inChannels 80 outChannels 16		|

 ********************** Block 2  **********************
|		inScales 4 outScales 4 inChannels 96 outChannels 16		|

|		inScales 4 outScales 3 inChannels 112 outChannels 16		|
|		Transition layer inserted! (max), inChannels 128, outChannels 64	|

|		inScales 3 outScales 3 inChannels 64 outChannels 16		|

|		inScales 3 outScales 3 inChannels 80 outChannels 16		|

 ********************** Block 3  **********************
|		inScales 3 outScales 3 inChannels 96 outChannels 16		|

|		inScales 3 outScales 3 inChannels 112 outChannels 16		|

|		inScales 3 outScales 2 inChannels 128 outChannels 16		|
|		Transition layer inserted! (max), inChannels 144, outChannels

In [4]:
# gm_msdnet = symbolic_trace(
#     msdnet,
#     tracing_shape=True,
#     sample_inputs={"x": input},
# )
# print(gm_msdnet.graph)

In [5]:
# print(msdnet)

In [6]:
# Apply the passes one after another: each pass is constructed on the current
# module, run on its graph, and `finalize()` yields the module for the next step.
eliminate_pass = DeadPathEliminatePass(msdnet)
eliminate_pass.run_on_graph()
msdnet = eliminate_pass.finalize()
# constant_propagation_pass = ConstantPropagationPass(
#     msdnet, upper_perm_load=args.batch_size * n_batch
# )
# constant_propagation_pass.run_on_graph()
# msdnet = constant_propagation_pass.finalize()
# NOTE(review): the meaning of the positional `False` is not visible in this
# notebook; confirm against the OperatorReorderPass signature.
operator_reorder_pass = OperatorReorderPass(msdnet, False)
operator_reorder_pass.run_on_graph()
msdnet = operator_reorder_pass.finalize()
# Dump the graph as it is handed to the horizontal-fusion step.
print("########## before HorizFusePass")
print(msdnet.graph)


########## before HorizFusePass
graph():
    %x : [#users=1] = placeholder[target=x]
    %_is_measure : [#users=0] = placeholder[target=_is_measure](default=False)
    %blocks_0_0_layers_0_0 : [#users=1] = call_module[target=blocks.0.0.layers.0.0](args = (%x,), kwargs = {})
    %blocks_0_0_layers_0_1 : [#users=1] = call_module[target=blocks.0.0.layers.0.1](args = (%blocks_0_0_layers_0_0,), kwargs = {})
    %blocks_0_0_layers_0_2 : [#users=1] = call_module[target=blocks.0.0.layers.0.2](args = (%blocks_0_0_layers_0_1,), kwargs = {})
    %blocks_0_0_layers_0_3 : [#users=4] = call_module[target=blocks.0.0.layers.0.3](args = (%blocks_0_0_layers_0_2,), kwargs = {})
    %blocks_0_0_layers_1_net_0 : [#users=1] = call_module[target=blocks.0.0.layers.1.net.0](args = (%blocks_0_0_layers_0_3,), kwargs = {})
    %blocks_0_1_layers_0_conv_normal_net_0 : [#users=1] = call_module[target=blocks.0.1.layers.0.conv_normal.net.0](args = (%blocks_0_0_layers_0_3,), kwargs = {})
    %blocks_0_1_layers_1_conv_

In [7]:

# Horizontal fusion on the module produced by the previous passes. `input` is the
# last batch left bound by the profiling loop and is used as the sample input.
# NOTE(review): the saved output of this cell ends in KeyboardInterrupt, so in the
# recorded run the lines after `run_on_graph()` may never have executed.
horiz_fusion_pass = HorizFusePass(
    msdnet, sample_inputs={"x": input}
)
horiz_fusion_pass.run_on_graph()
msdnet = horiz_fusion_pass.finalize()
print("########## after HorizFusePass")
print(msdnet.graph)
# threshold_dynamic_evaluate(
#     msdnet,
#     [[torch.randn(1, 3, 256, 256), torch.tensor((1,))]],
#     [[torch.randn(1, 3, 256, 256), torch.tensor((1,))]],
#     args,
# )


Currently not support nodes with kwargs (`cat`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_1`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_2`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_3`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_4`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_5`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_6`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_7`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_8`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_9`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_10`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_11`), the

KeyboardInterrupt: 

In [None]:
# Always fails: presumably a deliberate stop so that a "Run All" does not
# continue into the debugging / LiveSR cells below.
assert False

In [None]:
def find_cycle_inner(node, visited, stack: list):
    """Depth-first search for a cycle reachable from ``node`` via ``users`` edges.

    Args:
        node: Start node; only its ``users`` iterable is accessed.
        visited: Set of nodes already explored; updated in place.
        stack: Nodes on the current DFS path, oldest first; updated in place.

    Returns:
        A new list with the nodes of the first cycle found, starting at the
        node the back edge points to, or ``None`` if no cycle is reachable.
        When a cycle is returned, ``stack`` is left as it was at that moment
        (it is not unwound).
    """
    visited.add(node)
    # Push onto the END of the path. `insert(-1, node)` would place the node
    # before the last element, scrambling the path order and making the final
    # pop remove a different node than the one pushed here.
    stack.append(node)
    for user in node.users:
        if user not in visited:
            cycle = find_cycle_inner(user, visited, stack)
            if cycle is not None:
                return cycle
        elif user in stack:
            # Back edge to a node still on the current path: the path suffix
            # starting at that node is the cycle.
            return stack[stack.index(user):]
    stack.pop()
    return None

def find_cycle(graph):
    """Return the first cycle found in ``graph`` along ``users`` edges.

    Args:
        graph: Object whose ``nodes`` iterable yields nodes exposing ``users``.

    Returns:
        The list of nodes reported by ``find_cycle_inner`` for the first cycle
        encountered, or ``None`` when no cycle is found. The result is returned
        rather than printed so that the caller can display or inspect it.
    """
    visited = set()
    for node in graph.nodes:
        if node not in visited:
            # A fresh path stack per start node: find_cycle_inner does not
            # unwind the stack it was given when it returns a cycle.
            cycle = find_cycle_inner(node, visited, [])
            if cycle is not None:
                return cycle
    return None

from pprint import pprint
# Run the cycle search on the graph held by the fusion pass and pretty-print
# whatever find_cycle returns.
pprint(find_cycle(horiz_fusion_pass.graph_mod.graph))
    

In [None]:
def topological_inner(cur: Node, visited, sequence):
    """Append ``cur`` and everything reachable via ``users`` to ``sequence`` in DFS post-order.

    Implemented with an explicit stack instead of recursion so that long
    ``users`` chains do not hit Python's recursion limit. The visiting order
    and the resulting ``sequence`` are the same as for the recursive form.

    Args:
        cur: Start node; must not be in ``visited`` yet.
        visited: Set of nodes already handled; updated in place.
        sequence: Output list; each node is appended after all of its
            not-yet-visited users have been appended.

    Raises:
        AssertionError: If ``cur`` is already in ``visited``.
    """
    assert cur not in visited
    visited.add(cur)
    # Each frame keeps the node plus a live iterator over its users, so the
    # scan resumes where it stopped once a child has been fully processed.
    pending = [(cur, iter(cur.users))]
    while pending:
        node, remaining_users = pending[-1]
        for user in remaining_users:
            if user not in visited:
                visited.add(user)
                pending.append((user, iter(user.users)))
                break
        else:
            # All users handled: emit the node (post-order) and drop its frame.
            pending.pop()
            sequence.append(node)

def topoligical_sort(graph: Graph):
    """Re-link the node list of ``graph`` in place, following a DFS over ``users``.

    A DFS post-order over all nodes is collected with ``topological_inner``
    and printed. Its reverse (for an acyclic graph: every node ahead of its
    users) is then written into the ``_next`` / ``_prev`` pointers, forming a
    ring that starts and ends at ``graph._root``.
    """
    # Collect the DFS post-order of every node
    seen = set()
    post_order = []
    for start in graph.nodes:
        if start in seen:
            continue
        topological_inner(start, seen, post_order)
    print(post_order)
    # Rebuild the ring: root -> first -> ... -> last -> root
    root = graph._root
    ring = [root, *reversed(post_order), root]
    for left, right in zip(ring, ring[1:]):
        left._next = right
        right._prev = left

# Re-link, in place, the node list of the graph held by the fusion pass.
topoligical_sort(horiz_fusion_pass.graph_mod.graph)

In [None]:
# Passed as `num_feature` when the LiveSR model is constructed.
channels = 8


In [None]:
# Build the LiveSR model in eval mode on the GPU.
livesr = LiveSR(n_subnets=10, subnet_num_block=3, num_feature=channels).eval().cuda()

dataloader = get_dataloader(
    str(BRT_CACHE_PATH.parent / "benchmark/livesr/dataset/cam1/LQ")
)

# Grab only the first batch; it stays bound to `x` and is the sample input below.
for x in dataloader:
    break

# One forward pass, then show the load history recorded on the scatter router.
livesr(x)
print(livesr.scatter.load_history)
# Overwrite the recorded history with fixed values (10 entries, presumably one per
# subnet). NOTE(review): the origin of these numbers is not visible in this notebook.
livesr.scatter.load_history = np.array([6, 7, 12, 27, 8, 8, 8, 12, 12, 4], dtype=int)


In [None]:
# Show the plainly traced graph of the LiveSR model.
print(symbolic_trace(livesr).graph)

# NOTE(review): this traced module is not used in this cell; `gm_livesr` is
# re-assigned from `horizontal_fuse_pass.finalize()` further down.
gm_livesr = symbolic_trace(
    livesr,
    tracing_shape=True,
    sample_inputs={"inputs": x},
)
# print(gm_livesr.graph)
# Horizontal fusion is run on `livesr` itself (not on the traced module above),
# with the first dataloader batch as sample input.
horizontal_fuse_pass = HorizFusePass(
    livesr, sample_inputs={"inputs": x}, fixing_scatters=True
)
horizontal_fuse_pass.run_on_graph()
gm_livesr = horizontal_fuse_pass.finalize()
print(gm_livesr.graph)


In [None]:
# Reference output of the original model for the sample input.
y = livesr(x)

gm_livesr.delete_all_unused_submodules()

all_hooks = []  # removable handles of every hook registered below
target_of_module = {}  # module object -> its qualified name inside gm_livesr
scatter_outputs = [None]  # one-slot box holding the latest ScatterRouter output


# The callbacks only read the module-level containers above and no loop variable,
# so they are defined once here instead of being re-created for every submodule.
def print_pre_hook(m: nn.Module, i):
    """Forward pre-hook: print the module name, its class and the input type names."""
    name = target_of_module[m]
    print(
        f"{name:50.50} {m._get_name():20} {str(set(ii.__class__.__name__ for ii in i)):30}"
    )


def print_hook(m: nn.Module, i, o):
    """Forward hook: print the module name, its class and the input/output type names."""
    name = target_of_module[m]
    print(
        f"{name:50.50} {m._get_name():20} {str(set(ii.__class__.__name__ for ii in i)):30} "
        f"{str(set(oo.__class__.__name__ for oo in o)):30} "
        # f"{'[' + o[0].__class__.__name__ + ', ...]':20s}"
        # if isinstance(o, (tuple, list))
        # else f"{str(o.__class__.__name__) + ' | ' + str(o.load) + ' | ' + str(o.tag):20s}"
        # else f"{str(o.__class__.__name__):20s}"
    )


def get_scatter_outputs(m, i, o):
    """Forward hook: remember the most recent output of a ScatterRouter."""
    scatter_outputs[0] = o


try:
    for subn, subm in gm_livesr.named_modules():
        # Modules whose qualified name contains "classifier" are left unhooked.
        if "classifier" in subn:
            continue
        target_of_module[subm] = subn

        all_hooks.append(subm.register_forward_hook(print_hook))
        # all_hooks.append(subm.register_forward_pre_hook(print_pre_hook))

        if isinstance(subm, ScatterRouter):
            all_hooks.append(subm.register_forward_hook(get_scatter_outputs))

    # Output of the fused module for the same input, with the hooks printing a trace.
    hy = gm_livesr(x)
finally:
    # Always detach the hooks, even if registration or the forward pass failed.
    for hook in all_hooks:
        hook.remove()


In [None]:
# Compare the reference output `y` with the fused-module output `hy`:
# first the tolerance check, then each statistic of `y` followed by the same
# statistic of the difference `y - hy`.
delta = y - hy
abs_y = torch.abs(y)
abs_delta = torch.abs(delta)
print(torch.allclose(y, hy, rtol=1e-100, atol=1e-2))
for reduce_fn, reference, error in (
    (torch.sum, y, delta),
    (torch.sum, abs_y, abs_delta),
    (torch.max, y, delta),
    (torch.min, abs_y, abs_delta),
    (torch.min, y, delta),
):
    print(reduce_fn(reference))
    print(reduce_fn(error))
