In [None]:
from torchvision import transforms
from PIL import Image

import time
import numpy as np

import torch
import torchvision

from tvm.contrib.download import download_testdata

# Load a pretrained torchvision model in eval mode and TorchScript-trace it so
# the TVM PyTorch frontend can import it.
model_name = "resnet18"
model = torchvision.models.__dict__[model_name](pretrained=True)
model = model.eval()

# Dummy batch with the shape the model will be traced for. Named
# `example_input` (not `input`) so the Python builtin is not shadowed.
example_input = torch.randn(1, 3, 224, 224)
script_module = torch.jit.trace(model, example_input).eval()


image_path = download_testdata("https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true",
                               "cat.png", module="data")
# convert("RGB") guarantees exactly 3 channels: a PNG may be RGBA or
# palette-based, which would break the 3-channel Normalize below.
img = Image.open(image_path).convert("RGB").resize((224, 224))

# Standard ImageNet-style preprocessing: resize, center-crop to 224, convert
# to a float tensor, and normalize with the ImageNet channel mean/std.
transform = transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])

# Preprocessed single-image tensor; batched and fed to TVM in the next cell.
data = transform(img)

In [None]:
# Batch the preprocessed tensor produced by the previous cell.
# Deriving `img` from `data` (instead of rebinding `img` from itself) has two
# benefits. It keeps this cell idempotent on re-run. It also feeds the model
# the normalized input rather than raw 0-255 pixels. ToTensor already yields
# channels-first data, so no manual transpose is needed.
img = np.expand_dims(data.numpy(), axis=0)
print(img.shape)  # expected (1, 3, 224, 224) for an RGB image

from tvm import relay

# Name/shape of the graph input expected by the Relay PyTorch frontend.
input_name = "input0"
shape_list = [(input_name, img.shape)]

mod, param = relay.frontend.from_pytorch(script_module, shape_list)

In [None]:
import tvm 

# Compile for the local CPU. `target` stays a plain string because it is also
# handed to auto_scheduler.extract_tasks in a later cell.
target = 'llvm'
target_host = 'llvm'
ctx = tvm.cpu(0)

# Passing `target_host=` to relay.build is deprecated; attach the host to the
# Target object instead.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=tvm.target.Target(target, host=target_host), params=param)

In [None]:
# from tvm.contrib import graph_runtime
import tvm.contrib.graph_executor as runtime

def bench_tvm():
    for i in range(100):
        dtype = "float32"

        # m = graph_runtime.GraphModule(lib['default'](ctx))
        m = runtime.GraphModule(lib['default'](ctx))
        m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
        m.run()

        _ = m.get_output(0)

%timeit bench_tvm()

In [None]:
from tvm import auto_scheduler
from tvm.relay import data_dep_optimization as ddo

In [None]:
# Split the Relay `main` function into auto-scheduler search tasks for
# `target`; `wei` holds the per-task weights later given to TaskScheduler.
tasks, wei = auto_scheduler.extract_tasks(
    mod["main"],
    target=target,
    params=param,
)

In [None]:
# Only a cell's last expression is displayed, so show both values together
# (a bare `tasks` line on its own would be silently discarded).
tasks, wei

In [None]:
# Dump each extracted task's compute DAG, prefixed by its index.
for task_idx, task in enumerate(tasks):
    dag = task.compute_dag
    print(task_idx, dag)

In [None]:
# Scheduler that drives tuning over all extracted tasks, weighted by `wei`.
tunner = auto_scheduler.TaskScheduler(tasks, task_weights=wei)

In [23]:
log_file = 'learn_auto_sche.log'

# Local measurement runner: one repeat per measurement, with CPU cache flushing
# enabled; every measurement record is appended to `log_file`.
measure_runner = auto_scheduler.LocalRunner(repeat=1, enable_cpu_cache_flush=True)
tun_option = auto_scheduler.TuningOptions(
    num_measure_trials=200,
    runner=measure_runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)

# NOTE(review): the recorded run of this cell failed with
# "TypeError: callback must be an instance of `TrainingCallback`". That message
# appears to come from xgboost's callback API. This suggests the installed
# xgboost is newer than TVM's auto-scheduler cost model expects. Confirm, then
# pin xgboost or upgrade TVM.
tunner.tune(tun_option)

|  ID  |                       Task Description                        | Latency (ms) | Speed (GFLOPS) | Trials |----------------------------------------------------------------------
------------------------------  [ Task Scheduler ]

-----------------------------------------------------------------------------------------------------------------
|    0 |                    vm_mod_fused_nn_contrib_conv2d_NCHWc_add_1 |            - |              - |      0 |
|    1 |              vm_mod_fused_nn_contrib_conv2d_NCHWc_add_nn_relu |            - |              - |      0 |
----------------------------------------------------------------------
|    2 |        vm_mod_fused_nn_contrib_conv2d_NCHWc_add_add_nn_relu_1 |            - |              - |      0 |
|    3 |                                     vm_mod_fused_nn_dense_add |            - |              - |      0 |
|    4 |            vm_mod_fused_nn_contrib_conv2d_NCHWc_add_nn_relu_5 |            - |              - |      0 |
|    5 | 

TypeError: callback must be an instance of `TrainingCallback`.