In [92]:
import tvm
from tvm import relay

import numpy as np

from tvm.contrib.download import download_testdata
import tvm
from tvm import relay
from tvm.contrib import graph_executor
import numpy as np

import cv2

# PyTorch imports
import torch
import torchvision

In [93]:
model_name = "resnet18"
# torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`; kept for older installs.
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.eval()

# Seed torch so the random example input (and every output derived from it below) is reproducible.
SEED = 0
torch.manual_seed(SEED)

# We grab the TorchScripted model via tracing; Relay's PyTorch frontend consumes the traced graph.
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()

In [102]:
# Stem: conv1 -> bn1 -> relu -> maxpool (relu is shape-neutral but keeps the chain faithful to forward()).
model.maxpool(model.relu(model.bn1(model.conv1(input_data)))).shape

torch.Size([1, 64, 56, 56])

In [104]:
# Stem (conv1 -> bn1 -> relu -> maxpool) followed by layer1.
model.layer1(model.maxpool(model.relu(model.bn1(model.conv1(input_data))))).shape

torch.Size([1, 64, 56, 56])

In [105]:
# Stem (conv1 -> bn1 -> relu -> maxpool) followed by layer1 and layer2.
model.layer2(model.layer1(model.maxpool(model.relu(model.bn1(model.conv1(input_data)))))).shape

torch.Size([1, 128, 28, 28])

In [106]:
# Stem (conv1 -> bn1 -> relu -> maxpool) followed by layer1 .. layer3.
model.layer3(model.layer2(model.layer1(model.maxpool(model.relu(model.bn1(model.conv1(input_data))))))).shape

torch.Size([1, 256, 14, 14])

In [108]:
# Stem (conv1 -> bn1 -> relu -> maxpool) followed by layer1 .. layer4.
model.layer4(model.layer3(model.layer2(model.layer1(model.maxpool(model.relu(model.bn1(model.conv1(input_data)))))))).shape

torch.Size([1, 512, 7, 7])

In [109]:
# Stem (conv1 -> bn1 -> relu -> maxpool), layer1 .. layer4, then the global average pool.
model.avgpool(model.layer4(model.layer3(model.layer2(model.layer1(model.maxpool(model.relu(model.bn1(model.conv1(input_data))))))))).shape

torch.Size([1, 512, 1, 1])

In [54]:
# Display the PyTorch model's module hierarchy (at this point `model` is still the torchvision ResNet).
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [111]:
in_size = 224

img_url = (
    "https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/detection/street_small.jpg"
)
img_path = download_testdata(img_url, "test_street_small.jpg", module="data")

# cv2.imread signals failure by returning None rather than raising; fail loudly here instead of
# with an opaque AttributeError on `.astype` below.
img_bgr = cv2.imread(img_path)
if img_bgr is None:
    raise FileNotFoundError(f"cv2 could not read image at {img_path!r}")

img = cv2.resize(img_bgr.astype("float32"), (in_size, in_size))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# HWC in [0, 255] -> CHW in [0, 1], then prepend a batch axis.
img = np.transpose(img / 255.0, [2, 0, 1])
img = np.expand_dims(img, axis=0)
# NOTE(review): `img` is not consumed by any later cell shown (they feed the random `input_data`).
# torchvision's pretrained weights are normally used with ImageNet mean/std normalization, which is
# not applied here — TODO confirm before using `img` as model input.

In [112]:
# Relay's PyTorch frontend needs every graph input as a (name, static shape) pair;
# this model has a single image input.
input_name = "input0"
shape_list = [
    (input_name, input_shape),
]
mod, params = relay.frontend.from_pytorch(
    scripted_model,
    shape_list,
)

In [113]:
target = "llvm"

# Compile the whole Relay module once for the CPU target.
with tvm.transform.PassContext(opt_level=1):
    lib = relay.build(mod, target=target, params=params)

# Instantiate the compiled graph on device 0 of the target.
runtime_module = lib["default"](tvm.device(target, 0))
# NOTE(review): this rebinds `model` (previously the PyTorch ResNet) to the TVM GraphModule;
# a distinct name such as `tvm_model` would avoid confusing the two.
model = graph_executor.GraphModule(runtime_module)

In [114]:
# `model` was rebound by the build cell above: it is now the TVM GraphModule, not the PyTorch ResNet.
model

<tvm.contrib.graph_executor.GraphModule at 0xff2a3876d3c0>

In [194]:
# Index the parameters of the Relay `main` function by their position (as strings).
# Despite the name, these are the function's inputs: the graph input first, then the weights.
layer_dict = {str(position): param for position, param in enumerate(mod["main"].params)}

layer_dict

{'0': Var(input0, ty=TensorType([1, 3, 224, 224], float32)),
 '1': Var(aten::_convolution_0.weight, ty=TensorType([64, 3, 7, 7], float32)),
 '2': Var(aten::batch_norm_0.weight, ty=TensorType([64], float32)),
 '3': Var(aten::batch_norm_0.bias, ty=TensorType([64], float32)),
 '4': Var(aten::batch_norm_0.running_mean, ty=TensorType([64], float32)),
 '5': Var(aten::batch_norm_0.running_var, ty=TensorType([64], float32)),
 '6': Var(aten::_convolution_1.weight, ty=TensorType([64, 64, 3, 3], float32)),
 '7': Var(aten::batch_norm_1.weight, ty=TensorType([64], float32)),
 '8': Var(aten::batch_norm_1.bias, ty=TensorType([64], float32)),
 '9': Var(aten::batch_norm_1.running_mean, ty=TensorType([64], float32)),
 '10': Var(aten::batch_norm_1.running_var, ty=TensorType([64], float32)),
 '11': Var(aten::_convolution_2.weight, ty=TensorType([64, 64, 3, 3], float32)),
 '12': Var(aten::batch_norm_2.weight, ty=TensorType([64], float32)),
 '13': Var(aten::batch_norm_2.bias, ty=TensorType([64], float32)),


In [207]:
main_func = mod["main"]
expr = main_func.body
set_layers = set()
layers = []

# Relay op names treated as "layers" whose outputs we want to probe individually.
LAYER_OPS = ("nn.conv2d", "nn.dense")


def collect_layer_calls(root, layer_ops=LAYER_OPS):
    """Return the Relay calls to any op in ``layer_ops`` reachable from ``root``.

    Calls are returned in depth-first pre-order starting at ``root``, with each
    call listed once.

    The walk is iterative and remembers every node it has expanded. Nodes with
    several consumers (e.g. the input of a residual block, used by both the
    conv path and the identity path) are therefore explored only once. A naive
    recursive walk re-explores them for every consumer, which roughly doubles
    the work per residual block. It can also exceed Python's recursion limit on
    deep graphs.

    Only Call, Tuple and TupleGetItem nodes are descended into. Every other node
    kind (Var, Constant, Let, If, ...) is treated as a leaf.

    Args:
        root: Relay expression to start from (typically a function body).
        layer_ops: op names whose Call nodes should be collected.

    Returns:
        list of matching ``relay.expr.Call`` nodes.
    """
    found = []
    visited = set()
    stack = [root]
    while stack:
        node = stack.pop()
        if node in visited:
            continue
        visited.add(node)

        if isinstance(node, relay.expr.Call):
            # Callees can be Functions/GlobalVars, which have no `.name`; only primitive Ops match.
            if isinstance(node.op, tvm.ir.Op) and node.op.name in layer_ops:
                found.append(node)
            children = list(node.args)
        elif isinstance(node, relay.expr.TupleGetItem):
            # e.g. batch_norm returns a tuple; walk into whatever produced it.
            children = [node.tuple_value]
        elif isinstance(node, relay.expr.Tuple):
            children = list(node.fields)
        else:
            children = []

        # Push reversed so children are popped in their original order (matches recursive pre-order).
        stack.extend(reversed(children))
    return found


def traverse_expr(expr):
    """Append layer calls reachable from ``expr`` to the global ``layers`` list.

    Calls already recorded in the global ``set_layers`` are skipped, so repeated
    calls never add duplicates.
    """
    for call in collect_layer_calls(expr):
        if call not in set_layers:
            set_layers.add(call)
            layers.append(call)


traverse_expr(expr)

In [208]:
# resnet18: 20 conv2d calls (stem + 16 inside the basic blocks + 3 downsample convs) plus 1 dense (fc).
EXPECTED_NUM_LAYERS = 21
assert len(layers) == EXPECTED_NUM_LAYERS, (
    f"expected {EXPECTED_NUM_LAYERS} conv2d/dense calls, found {len(layers)}"
)

In [214]:
# Turn every collected conv2d/dense call into its own Relay function. Each function takes
# main_func's full parameter list, so every compiled layer still accepts the "input0" graph input.
relay_layer_functions = [relay.Function(main_func.params, layer) for layer in layers]
layer_mods = [tvm.IRModule.from_expr(rlf) for rlf in relay_layer_functions]

# CPU target; switching to a GPU target (e.g. "cuda") requires a TVM build with that backend.
target = "llvm"
with tvm.transform.PassContext(opt_level=1):
    # Build inside a comprehension so the module-level `lib` (the full-model build from the
    # earlier cell) is not silently overwritten by the last per-layer library.
    layer_libs = [relay.build(layer_mod, target=target, params=params) for layer_mod in layer_mods]

In [215]:
target = "llvm"

# One execution device shared by every per-layer GraphModule.
device = tvm.device(target, 0)
layer_models = []
for layer_lib in layer_libs:
    layer_models.append(graph_executor.GraphModule(layer_lib["default"](device)))

In [217]:
# Run every per-layer module on the same example input and keep each layer's output as numpy.
# Convert the torch tensor once, explicitly, instead of once per module.
tvm_input = tvm.nd.array(input_data.numpy())

dummy_output_list = []
for layer_model in layer_models:
    # Reuse the input name declared for the Relay import instead of repeating the literal.
    layer_model.set_input(input_name, tvm_input)
    layer_model.run()
    # NDArray.asnumpy() is deprecated in favour of NDArray.numpy().
    dummy_output_list.append(layer_model.get_output(0).numpy())

In [219]:
# One line per probed layer, in the order `layers` was collected (the traversal starts at the
# graph output, so the fc layer is listed first).
shape_lines = [str(layer_output.shape) for layer_output in dummy_output_list]
if shape_lines:
    print("\n".join(shape_lines))

(1, 1000)
(1, 512, 7, 7)
(1, 512, 7, 7)
(1, 512, 7, 7)
(1, 512, 7, 7)
(1, 256, 14, 14)
(1, 256, 14, 14)
(1, 256, 14, 14)
(1, 256, 14, 14)
(1, 128, 28, 28)
(1, 128, 28, 28)
(1, 128, 28, 28)
(1, 128, 28, 28)
(1, 64, 56, 56)
(1, 64, 56, 56)
(1, 64, 56, 56)
(1, 64, 56, 56)
(1, 64, 112, 112)
(1, 128, 28, 28)
(1, 256, 14, 14)
(1, 512, 7, 7)
