In [1]:
from argparse import Namespace
import os
import inspect
import numpy as np

import torch
from torch import nn
from torch import fx
from torch.fx import GraphModule, Graph, Node
from torch.utils.benchmark import Timer

import brt
from brt.runtime import log
from brt.runtime.grid_tensor import GridTensor
from brt.runtime.benchmark import profile
from brt.router import ScatterRouter, GatherRouter, switch_capture
from brt.router.base import reset_router_stats
from brt.router.fabric import make_fabric
from brt.trace import symbolic_trace, GraphTracer

# from brt.trace.graph import symbolic_trace
from brt.passes import (
    HorizFusePass,
    VerticalFusePass,
    OperatorReorderPass,
    DeadPathEliminatePass,
    ConstantPropagationPass,
)

# Verbosity of brainstorm's "BRT" logger; "WARNING" gives much quieter pass output.
BRT_LOG_LEVEL = "DEBUG"
log.set_level("BRT", BRT_LOG_LEVEL)

# Uncomment to make CUDA kernel launches synchronous while debugging device errors.
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [2]:
import sys
from brt.runtime import BRT_CACHE_PATH

# Make the MSDNet benchmark sources (model, threshold evaluation, dataloaders) importable.
# Guarded so that re-running this cell does not append the same directory to sys.path again.
MSDNET_BENCHMARK_DIR = str(BRT_CACHE_PATH.parent / "benchmark/msdnet/")
if MSDNET_BENCHMARK_DIR not in sys.path:
    sys.path.append(MSDNET_BENCHMARK_DIR)

from msdnet import MSDNet
# NOTE(review): "theshold" spelling must match the module file name; verify before renaming.
from theshold_inference import threshold_dynamic_evaluate
from dataloader import get_dataloaders as msdnet_get_dataloaders


In [3]:
# Benchmark toggles.
# When True, each timing cell first runs `profile(...)` on the fixed sample batch.
IS_PROFILING = False
# Forwarded as `fusing_head` to the vertical/horizontal fusion passes (runs were also done with False).
IS_FUSING_HEAD = True

In [4]:
# Evaluation configuration for the pretrained MSDNet benchmark.
# NOTE(review): the absolute /home/v-louyang/... paths are machine-specific; adjust before running elsewhere.
args = Namespace(
    arch="msdnet",
    base=4,
    batch_size=256,
    benchmark=["all_opt"],
    bnFactor=[1, 2, 4, 4],
    bottleneck=True,
    data="ImageNet",
    data_root="/home/v-louyang/dataset/imagenet",
    decay_rate=0.1,
    epochs=90,
    evalmode="threshold",
    evaluate_from="/home/v-louyang/brainstorm_project/brainstorm/benchmark/msdnet/msdnet-step=4-block=5.pth.tar",
    gpu="0,1,2,3",
    grFactor=[1, 2, 4, 4],
    growthRate=16,
    init_routers=True,
    lr=0.1,
    lr_type="multistep",
    momentum=0.9,
    nBlocks=5,
    nChannels=32,
    nScales=4,
    num_classes=1000,
    optimizer="sgd",
    parallel=True,
    print_freq=10,
    prune="max",
    reduction=0.5,
    resume=False,
    save="/home/v-louyang/brainstorm_project/brainstorm/benchmark/msdnet/saveresult",
    seed=0,
    splits=["val", "test"],
    start_epoch=0,
    step=4,
    stepmode="even",
    thresholds=[-1, -1, -1, -1],                                    # 1.0, 0.0, 0.0, 0.0, 0.0
    # thresholds=[0.44246858, -1, -1, -1],                          # 0.5, 0.5, 0.0, 0.0, 0.0
    # thresholds=[0.44246849, 0.26682281, -1, -1],                  # 0.5, 0.3, 0.2, 0.0, 0.0
    # thresholds=[0.44246864, 0.39881980, 0.19329087, -1],          # 0.5, 0.2, 0.2, 0.1, 0.0
    # thresholds=[0.96616900, 0.95113075, 0.80969042, 0.45410264],  # 0.1, 0.1, 0.2, 0.3, 0.3
    # thresholds=[1000, 1000, 0.90728849, 0.57961094],              # 0.0, 0.0, 0.3, 0.3, 0.4
    # thresholds=[1000, 1000, 1000, 0.83451331],                    # 0.0, 0.0, 0.0, 0.4, 0.6
    # thresholds=[1000, 1000, 1000, 1000],                          # 0.0, 0.0, 0.0, 0.0, 1.0
    use_valid=True,
    weight_decay=0.0001,
    workers=16,
)

state_dict = torch.load(
    "/home/v-louyang/brainstorm_project/brainstorm/benchmark/msdnet/MSDNet.pth"
)
_, val_dataloader, test_dataloader = msdnet_get_dataloaders(args)

# Number of leading test batches copied to the GPU and reused as benchmark inputs
# (50 was also tried).
NUM_TEST_BATCHES = 30

test_inputs = []
for i, (test_input, target) in enumerate(test_dataloader):
    # Stop once enough batches are collected instead of iterating the whole test split.
    if i >= NUM_TEST_BATCHES:
        break
    test_inputs.append(test_input.cuda())


!!!!!! Load train_set_index !!!!!!


In [5]:
# Build MSDNet from `args`, switch it to eval mode on the GPU (both calls act in place),
# then load the pretrained weights. The last expression displays the load report.
msdnet: nn.Module = MSDNet(args, False)
msdnet.eval()
msdnet.cuda()
msdnet.load_state_dict(state_dict)

building network of steps: 
[4, 4, 4, 4, 4] 20
 ********************** Block 1  **********************
|		inScales 4 outScales 4 inChannels 32 outChannels 16		|

|		inScales 4 outScales 4 inChannels 48 outChannels 16		|

|		inScales 4 outScales 4 inChannels 64 outChannels 16		|

|		inScales 4 outScales 4 inChannels 80 outChannels 16		|

 ********************** Block 2  **********************
|		inScales 4 outScales 4 inChannels 96 outChannels 16		|

|		inScales 4 outScales 3 inChannels 112 outChannels 16		|
|		Transition layer inserted! (max), inChannels 128, outChannels 64	|

|		inScales 3 outScales 3 inChannels 64 outChannels 16		|

|		inScales 3 outScales 3 inChannels 80 outChannels 16		|

 ********************** Block 3  **********************
|		inScales 3 outScales 3 inChannels 96 outChannels 16		|

|		inScales 3 outScales 3 inChannels 112 outChannels 16		|

|		inScales 3 outScales 2 inChannels 128 outChannels 16		|
|		Transition layer inserted! (max), inChannels 144, outChannels

<All keys matched successfully>

In [6]:
def print_load_history(m: nn.Module):
    """Print the captured routing statistics of every router inside ``m``.

    For each ``ScatterRouter`` / ``GatherRouter`` found via ``m.named_modules()``,
    prints a header with the router's qualified module name and class, followed by
    its ``capturing``, ``load_history``, ``ptu_decision_history`` and
    ``ptu_grain_history`` attributes.

    Args:
        m: module whose submodules are scanned for routers.
    """
    print("")
    for subn, subm in m.named_modules():
        if isinstance(subm, (ScatterRouter, GatherRouter)):
            # Header so each group of statistics can be traced back to its router.
            print(f"[{subn}] {type(subm).__name__}")
            print(f"{subm.capturing=}")
            print(f"{subm.load_history=}")
            print(f"{subm.ptu_decision_history=}")
            print(f"{subm.ptu_grain_history=}")

# Record router statistics while running every collected test batch through the model.
reset_router_stats(msdnet)
switch_capture(msdnet, True)
try:
    for test_input in test_inputs:
        # test_inputs were already moved to the GPU when collected; no extra .cuda() needed.
        print("*", end="")
        test_output = msdnet(test_input)
finally:
    # Always disable capturing, even if a forward pass raises, so the model
    # is not left in capture mode for the passes and timings below.
    switch_capture(msdnet, False)

print("")
print_load_history(msdnet)

# Fixed sample batch used by the fusion passes and profiling in later cells.
# NOTE: `input` shadows the Python builtin; later cells refer to it by this name.
input = test_inputs[13]
y = msdnet(input)

******************************

subm.capturing=False
subm.load_history=array([1., 0.])
subm.ptu_decision_history=[array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=int32), array([], dtype=int32)]
subm.ptu_grain_history=[torch.Size([1, 96, 56, 56]), torch.Size([1, 192, 28, 28]), torch.Size([1, 384, 14, 14]), torch.Size([1, 384, 7, 7]), torch.Size([1, 1000])]
subm.capturing=False
subm.load_history=array([0., 0.])
subm.ptu_decision_history=[array([], dtype=int32), array([], dtype=int32)]
subm.ptu_grain_history=[torch.Size([0, 192, 28, 28]), torch.Size([0, 384, 14, 14]), torch.Size([0, 384, 7, 7]), torch.Size([0, 384, 7, 7]), torch.Size([0, 1000])]
subm.capturing=False
subm.load_history=array([0., 0.])
subm.ptu_decision_history=[array([], dtype=int32), array([], dtype=int32)]
subm.ptu_grain_history=[torch.Size([0, 352, 14, 14]), torch.Size([0, 352, 7, 7]), torch.Size([0, 352, 7, 7]), torch.Size([0, 3

In [None]:
if IS_PROFILING:
    profile(lambda: msdnet(input))

# Mean latency (microseconds, 10 runs) of the unoptimized model on each test batch.
raw_time = []
for test_input in test_inputs:
    timer = Timer(
        "model(x)",
        setup="import torch; torch.cuda.synchronize()",
        globals={"model": msdnet, "x": test_input},
    )
    measurement = timer.timeit(10)
    cost = measurement.mean * 1e6
    print(cost)
    raw_time.append(cost)

In [None]:
# Vertical fusion applied directly to the unoptimized model.
vf_pass_kwargs = dict(sample_inputs={"x": input}, fusing_head=IS_FUSING_HEAD)
verti_fusion_pass = VerticalFusePass(msdnet, **vf_pass_kwargs)
verti_fusion_pass.run_on_graph()
msdnet_vf = verti_fusion_pass.finalize()

In [None]:
# Optionally profile the vertically fused model (only when IS_PROFILING is set;
# the previous unconditional extra `profile` call ignored the toggle), then measure
# its mean latency in microseconds over 10 runs per test batch.
if IS_PROFILING:
    profile(lambda: msdnet_vf(input))

vf_time = []
for test_input in test_inputs:
    cost = Timer(
            "model(x)",
            setup="import torch; torch.cuda.synchronize()",
            globals={"model": msdnet_vf, "x": test_input},
        ).timeit(10).mean * 1e6
    print(cost)
    vf_time.append(cost)

In [None]:
# Dump the FX graph of the vertically fused model.
vf_graph = msdnet_vf.graph
print(vf_graph)

In [7]:
# Pass pipeline: dead-path elimination -> constant propagation -> operator reordering.
# Each pass is constructed, run over the graph, and finalized into a new module.

# Stage 1: dead-path elimination on the captured model.
eliminate_pass = DeadPathEliminatePass(msdnet)
eliminate_pass.run_on_graph()
msdnet_dpe = eliminate_pass.finalize()

# Stage 2: constant propagation on the pruned model.
constant_propagation_pass = ConstantPropagationPass(msdnet_dpe, upper_perm_load=1)
constant_propagation_pass.run_on_graph()
msdnet_cp = constant_propagation_pass.finalize()

# Stage 3: operator reordering (second positional argument passed as False).
operator_reorder_pass = OperatorReorderPass(msdnet_cp, False)
operator_reorder_pass.run_on_graph()
msdnet_reorder = operator_reorder_pass.finalize()

reorder_graph = msdnet_reorder.graph
print(reorder_graph)


graph():
    %x : [#users=1] = placeholder[target=x]
    %_is_measure : [#users=0] = placeholder[target=_is_measure](default=False)
    %blocks_0_0_layers_0_0 : [#users=1] = call_module[target=blocks.0.0.layers.0.0](args = (%x,), kwargs = {})
    %blocks_0_0_layers_0_1 : [#users=1] = call_module[target=blocks.0.0.layers.0.1](args = (%blocks_0_0_layers_0_0,), kwargs = {})
    %blocks_0_0_layers_0_2 : [#users=1] = call_module[target=blocks.0.0.layers.0.2](args = (%blocks_0_0_layers_0_1,), kwargs = {})
    %blocks_0_0_layers_0_3 : [#users=4] = call_module[target=blocks.0.0.layers.0.3](args = (%blocks_0_0_layers_0_2,), kwargs = {})
    %blocks_0_0_layers_1_net_0 : [#users=1] = call_module[target=blocks.0.0.layers.1.net.0](args = (%blocks_0_0_layers_0_3,), kwargs = {})
    %blocks_0_1_layers_0_conv_normal_net_0 : [#users=1] = call_module[target=blocks.0.1.layers.0.conv_normal.net.0](args = (%blocks_0_0_layers_0_3,), kwargs = {})
    %blocks_0_1_layers_1_conv_down_net_0 : [#users=1] = call_m

In [None]:
if IS_PROFILING:
    profile(lambda: msdnet_reorder(input))

# Mean latency (microseconds, 10 runs) of the reordered model on each test batch.
opr_time = []
for test_input in test_inputs:
    timer = Timer(
        "model(x)",
        setup="import torch; torch.cuda.synchronize()",
        globals={"model": msdnet_reorder, "x": test_input},
    )
    cost = timer.timeit(10).mean * 1e6
    print(cost)
    opr_time.append(cost)

In [8]:
# "sp" variant: vertical fusion applied on top of the reordered model.
sp_pass_kwargs = dict(sample_inputs={"x": input}, fusing_head=IS_FUSING_HEAD)
sp_verti_fusion_pass = VerticalFusePass(msdnet_reorder, **sp_pass_kwargs)
sp_verti_fusion_pass.run_on_graph()
msdnet_sp = sp_verti_fusion_pass.finalize()

Currently not support nodes with kwargs (`cat`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_1`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_2`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_3`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_4`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_5`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_6`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_7`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_8`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_9`), the info of kwargs won't be traced
start node `x` should be a fixed module node
start node `_is_measure` should be a fixed module node
fuse node `%blocks_0_0_layers_0_0`
fuse no

In [None]:
if IS_PROFILING:
    profile(lambda: msdnet_sp(input))

# Mean latency (microseconds, 10 runs) of the "sp" model on each test batch.
sp_time = []
for test_input in test_inputs:
    timer = Timer(
        "model(x)",
        setup="import torch; torch.cuda.synchronize()",
        globals={"model": msdnet_sp, "x": test_input},
    )
    cost = timer.timeit(10).mean * 1e6
    print(cost)
    sp_time.append(cost)

In [None]:
# Dump the FX graph of the "sp" (reorder + vertical fusion) model.
sp_graph = msdnet_sp.graph
print(sp_graph)

In [9]:
# "hf" variant: horizontal fusion applied on top of the reordered model.
hf_pass_kwargs = dict(sample_inputs={"x": input}, fusing_head=IS_FUSING_HEAD)
horiz_fusion_pass = HorizFusePass(msdnet_reorder, **hf_pass_kwargs)
horiz_fusion_pass.run_on_graph()
msdnet_hf = horiz_fusion_pass.finalize()

Currently not support nodes with kwargs (`cat`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_1`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_2`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_3`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_4`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_5`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_6`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_7`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_8`), the info of kwargs won't be traced
Currently not support nodes with kwargs (`cat_9`), the info of kwargs won't be traced
fuse node `%blocks_0_0_layers_0_0`
fuse node `%blocks_0_0_layers_0_1`
fuse node `%blocks_0_0_layers_0_2`
node `%blocks_0_0_layers_0_3` has mor

In [12]:
if IS_PROFILING:
    profile(lambda: msdnet_hf(input))

# Mean latency (microseconds) of the "hf" model on each test batch.
# NOTE: 100 runs per batch here, versus 10 in the other timing cells.
hf_time = []
for test_input in test_inputs:
    timer = Timer(
        "model(x)",
        setup="import torch; torch.cuda.synchronize()",
        globals={"model": msdnet_hf, "x": test_input},
    )
    cost = timer.timeit(100).mean * 1e6
    print(cost)
    hf_time.append(cost)

967.6736999244895
1031.9762700237334
1031.6697899543215
1031.1972099589184
1009.3839900218883
974.9920800095424
1010.4992600099649
976.9320899795275
959.5742898818571
1004.9260601226707
995.4220100189559
1000.2832900499925
999.3678799946792
979.3596400413662
998.0500300298444
961.6752099827863
936.7701200244483
911.1273399321362
921.2570601084735
927.7683099207934
971.8470099323895
994.5325499575121
988.5518900409807
946.0269998817239
938.5710700007621
919.7030500217807
941.888969973661
920.324220060138
971.0026199172717
947.8899800160434


In [11]:
# Dump the FX graph of the horizontally fused model.
hf_graph = msdnet_hf.graph
print(hf_graph)

graph():
    %x : [#users=1] = placeholder[target=x] | fixed
    %_is_measure : [#users=0] = placeholder[target=_is_measure](default=False) | unfixed
    %blocks_0_0_layers_0_0 : [#users=1] = call_module[target=blocks.0.0.layers.0.0](args = (%x,), kwargs = {}) | fixed
    %blocks_0_0_layers_0_1 : [#users=1] = call_module[target=blocks.0.0.layers.0.1](args = (%blocks_0_0_layers_0_0,), kwargs = {}) | fixed
    %blocks_0_0_layers_0_2 : [#users=1] = call_module[target=blocks.0.0.layers.0.2](args = (%blocks_0_0_layers_0_1,), kwargs = {}) | fixed
    %blocks_0_0_layers_0_3 : [#users=2] = call_module[target=blocks.0.0.layers.0.3](args = (%blocks_0_0_layers_0_2,), kwargs = {}) | fixed
    %BRT_HF__V_blocks_0_1_layers_0_conv_normal_net_0__blocks_0_1_layers_0_conv_normal_net_1__blocks_0_1_layers_0_conv_normal_net_2__V_blocks_0_0_layers_1_net_0__blocks_0_0_layers_1_net_1__blocks_0_0_layers_1_net_2__V_blocks_0_1_layers_1_conv_down_net_0__blocks_0_1_layers_1_conv_down_net_1__blocks_0_1_layers_1_con

In [None]:
# Per-batch speed-up of each optimized variant over the unoptimized model
# (raw latency / variant latency; > 1 means faster). All *_time lists are
# measured over the same test_inputs, so the tensors align element-wise.
t_raw_time = torch.tensor(raw_time)
t_vf_time = torch.tensor(vf_time)
t_opr_time = torch.tensor(opr_time)
t_sp_time = torch.tensor(sp_time)
t_hf_time = torch.tensor(hf_time)

vf_speed_up = t_raw_time / t_vf_time
opr_speed_up = t_raw_time / t_opr_time  # previously measured but never reported
sp_speed_up = t_raw_time / t_sp_time
hf_speed_up = t_raw_time / t_hf_time

speed_ups = {
    "vf": vf_speed_up,
    "opr": opr_speed_up,
    "sp": sp_speed_up,
    "hf": hf_speed_up,
}
# Same layout as before: a statistic name, then one "label: value" line per variant.
for stat in ("mean", "max", "min"):
    print(stat)
    for label, ratios in speed_ups.items():
        print(f"{label}: {getattr(ratios, stat)()}")

In [None]:
# Recorded speed-ups (raw_time / <variant>_time) printed by the summary cell in earlier runs.
# Each group is labelled with the thresholds setting it was measured with
# (presumably args.thresholds for that run).

# thresholds = [1, -1, -1, -1]
# mean
# vf: 1.2304408550262451
# sp: 11.934501647949219
# hf: 13.882946014404297
# max
# vf: 1.3284677267074585
# sp: 13.092482566833496
# hf: 15.02823257446289
# min
# vf: 1.1286181211471558
# sp: 11.222148895263672
# hf: 13.1246976852417

# thresholds = [-1, -1, -1, -1]
# mean
# vf: 1.363097071647644
# sp: 5.163495063781738
# hf: 8.713125228881836
# max
# vf: 1.4690160751342773
# sp: 6.095772743225098
# hf: 9.53192138671875
# min
# vf: 1.2727227210998535
# sp: 4.845872402191162
# hf: 8.095858573913574