In [1]:
from torchvision import transforms
from PIL import Image

import time
import numpy as np

import torch
import torchvision

from tvm.contrib.download import download_testdata

model_name = "resnet18"
model = torchvision.models.__dict__[model_name](pretrained=True)
model = model.eval()

input = torch.randn(1, 3, 224, 224)
script_module = torch.jit.trace(model, input).eval()


image_path = download_testdata("https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true",
                               "cat.png", module="data")
img = Image.open(image_path).resize((224, 224))

transform = transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])

data = transform(img)



In [2]:
# print(img)
img = np.expand_dims(img, axis=0)
print(img.shape)
img = np.transpose(img, (0, 3, 1, 2))

print(img.shape)

(1, 224, 224, 3)
(1, 3, 224, 224)


In [3]:
from tvm import relay

input_name = "input0"
shape_list = [(input_name, img.shape)]

mod,param = relay.frontend.from_pytorch(script_module, shape_list)

In [None]:
shape_list

In [None]:
mod

In [None]:
param

In [4]:
import tvm 

target = 'llvm'
target_host = 'llvm'
ctx = tvm.cpu(0)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, target_host=target_host,params=param)

One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.


In [None]:
lib.__dict__

In [None]:
lib['default']

In [5]:
# from tvm.contrib import graph_runtime
import tvm.contrib.graph_executor as runtime

def bench_tvm():
    for i in range(100):
        dtype = "float32"

        # m = graph_runtime.GraphModule(lib['default'](ctx))
        m = runtime.GraphModule(lib['default'](ctx))
        m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
        m.run()

        _ = m.get_output(0)

%timeit bench_tvm()

3.52 s ± 105 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
type(lib)

In [None]:
lib.export_library("module.so")

In [None]:
def bench_torch():
    for i in range(100):
        torch_img = torch.from_numpy(img.astype(np.float32))
        _ = model(torch_img)

%timeit bench_torch()

In [6]:
import tvm 

target = 'llvm'
target_host = 'llvm'
ctx = tvm.cpu(0)

with tvm.transform.PassContext(opt_level=3):
    graph, lib, params  = relay.build_module.build(mod, target=target, target_host=target_host,params=param)

  graph, lib, params  = relay.build_module.build(mod, target=target, target_host=target_host,params=param)


In [None]:
lib.export_library('resnet18.so')

In [None]:
with open("resnet18.graph","w") as f:
    f.write(graph)

In [None]:
from tvm  import runtime

with open("resnet18.params","wb") as f:
    f.write(runtime.save_param_dict(params))