In [1]:
!pip install onnxruntime-gpu
!pip install onnxconverter-common
!pip install apache-tvm
!pip install torch
!pip install numpy
!pip install onnxruntime
!pip install onnx
!pip install timm
!pip install tqdm
!pip install torchvision


Collecting onnxruntime-gpu
  Downloading onnxruntime_gpu-1.17.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (4.3 kB)
Collecting coloredlogs (from onnxruntime-gpu)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime-gpu)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnxruntime_gpu-1.17.1-cp310-cp310-manylinux_2_28_x86_64.whl (192.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m192.1/192.1 MB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: huma

In [2]:
# import libraries
import os
import torch
import json
import time
from tqdm.auto import tqdm
import timm
import numpy as np
import onnxruntime as ort
from onnxruntime import quantization
from onnxconverter_common import float16
import onnx
import tvm
from tvm import relay



In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device = ", device)


device =  cuda


In [4]:
def preprocess(model):
    data_config = timm.data.resolve_model_data_config(model)
    transforms = timm.data.create_transform(**data_config, is_training = False)
    print(transforms)
    
    val_dataset = timm.data.ImageDataset('/kaggle/input/imagenet/imagenet-mini', transform = transforms)
    val_loader = timm.data.create_loader(val_dataset, (1, 3, 224, 224), 1)
    
    val_dataset_sub = torch.utils.data.Subset(val_dataset, list(range(1000)))
    val_loader_sub = timm.data.create_loader(val_dataset_sub, (1, 3, 224, 224), 1)
    
    return val_loader, val_loader_sub

print("Preprocess Done ...!")

Preprocess Done ...!


In [5]:
class OnnxStaticQuantization:
    def __init__(self) -> None:
        self.enum_data = None
        self.calibration_technique = {
            "MinMax": ort.quantization.calibrate.CalibrationMethod.MinMax,
            "Entropy": ort.quantization.calibrate.CalibrationMethod.Entropy,
            "Percentile": ort.quantization.calibrate.CalibrationMethod.Percentile,
            "Distribution": ort.quantization.calibrate.CalibrationMethod.Distribution
        }

    def get_next(self, EP_list = ['CPUExecutionProvider']):
        if self.enum_data is None:
            session = ort.InferenceSession(self.fp32_onnx_path, providers = EP_list)
            input_name = session.get_inputs()[0].name
            calib_list = []
            count = 0
            for nhwc_data, _ in self.calibration_loader:
                nhwc_data=nhwc_data.cpu()
                calib_list.append({input_name: nhwc_data.numpy()}) 
                if self.sample == count: break
                count = count + 1
            self.enum_data = iter(calib_list)
        return next(self.enum_data, None)
    
    def quantization(self, fp32_onnx_path, future_int8_onnx_path, calib_method, calibration_loader, sample = 100):
        self.sample = sample
        self.calibration_loader = calibration_loader 
        _ = ort.quantization.quantize_static(
                model_input = fp32_onnx_path,
                model_output = future_int8_onnx_path,
                calibrate_method = self.calibration_technique[calib_method],
                activation_type=ort.quantization.QuantType.QInt16,
                weight_type=ort.quantization.QuantType.QInt8,
                per_channel = True, 
                reduce_range = True,
                calibration_data_reader = self
            )
        return self


print("Quantization Done ...!")

Quantization Done ...!


In [6]:
# Quantization Investigation
def quant_investigation(quant_model_name):
    _model = onnx.load(quant_model_name + ".onnx")
    initializers = _model.graph.initializer

    for node_i in _model.graph.node:
        if node_i.output and "QuantizeLinear" not in node_i.output[0] and "DequantizeLinear" not in node_i.name:
            for node_j in _model.graph.node:
                if node_j.input and node_i.output[0] == node_j.input[0] and "QuantizeLinear" not in node_j.output[0]:
                    print(node_i.name)

print("Quantization Investigation Done ...!")


Quantization Investigation Done ...!


In [7]:
def validate(model, val_loader, model_name, ONNX = False, quant = "", sample_size = 100, quant_invest = False, TVM = False):
    correct = 0
    total = 0
    elapsed_time = 0
    top5_correct = 0
    
    if ONNX:    # 1
        if not TVM:    # 2  
            if quant == "fp16":    # 3
                model = onnx.load(model_name + ".onnx")
                model_fp16 = float16.convert_float_to_float16(model)
                onnx.save(model_fp16,model_name + quant + ".onnx")
                
            elif quant == "int8":    # 3
                ort.quantization.shape_inference.quant_pre_process(model_name + ".onnx", "preprocess.onnx")
                module = OnnxStaticQuantization()
                module.fp32_onnx_path = "preprocess.onnx"
                
                module.quantization(
                    fp32_onnx_path = "preprocess.onnx",
                    future_int8_onnx_path = model_name + quant + ".onnx",
#                     calib_method = "MinMax",
#                     calib_method = "Entropy",
                    calib_method = "Percentile",
#                     calib_method = "Distribution",
                    calibration_loader = val_loader,
                    sample = sample_size
                )
                
            else:    # 3
                dummy_input = torch.randn(1, 3, 224, 224).to(device)
                torch.onnx.export(model, dummy_input, model_name + ".onnx", export_params = True, opset_version = 14, do_constant_folding = True)
                
                
        else:    # 2
            onnx_model = onnx.load(model_name + quant + ".onnx")
            shape_dict = {"input.1": (1, 3, 224, 224)}
            mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
            target = "llvm -mcpu=core-avx2"
            with tvm.transform.PassContext(opt_level = 2):
                executor = relay.build_module.create_executor("graph", mod, tvm.cpu(0), target, params).evaluate()
        
        # -------------------------------------------------------------------------------------------------------------------------
        
        sess = ort.InferenceSession(model_name + quant + ".onnx", providers = ['CPUExecutionProvider'])
        for batch_idx, (inputs, labels) in enumerate(tqdm(val_loader)):
            if quant == "fp16":
                inputs = inputs.half()
            inputs = inputs.cpu().numpy()
            
            if TVM:
                inputs = tvm.nd.array(inputs.astype("float32"))
                start_time = time.time()    
                outputs = [executor(inputs).numpy()]
                end_time = time.time()
                
            else:
                start_time = time.time()    
                outputs = sess.run(None, {sess.get_inputs()[0].name: inputs})
                end_time = time.time()
            
            elapsed_time += end_time - start_time

            predicted_labels = np.argmax(outputs[0], axis=1)

            correct += (predicted_labels == labels.cpu().numpy()).sum()
            total += labels.size(0)

            top5_predicted = np.argsort(outputs[0], axis=1)[:, -5:]  
            top5_correct += np.sum(np.equal(top5_predicted, np.expand_dims(labels.cpu().numpy(), axis=1)))
    
    
    elif ONNX == False and quant_invest == True and quant != "":    # 1
        if quant == "fp16" :
            print("fp16 quantization investigation : ")
            quant_investigation(model_name + quant)
            print()
        elif quant == "int8" :
            print("int8 quantization investigation : ")
            quant_investigation(model_name + quant)
            print()
        else :
            print("Error : quant is empty.")
        return
    
    else:  # 1
        model = model.eval()
        with torch.inference_mode():
            for batch_idx, (images, labels) in enumerate(tqdm(val_loader)):
                images, labels = images.to(device), labels.to(device)
                
                start_time = time.time()
                outputs = model(images)
                end_time = time.time()
                elapsed_time += end_time - start_time
                
                _, predicted = torch.max(outputs.softmax(dim = 1) * 100, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                _, top5_predicted = torch.topk(outputs.softmax(dim = 1) * 100, k = 5)
                top5_correct += sum(labels[i].item() in top5_predicted[i] for i in range(len(labels)))
    
    # =============================================================================================================================

    # Calculate accuracy
    single_inference_runtime = elapsed_time / total
    top1Accuracy = correct / total
    top5Accuracy = top5_correct / total
    
    print('Single Inference Runtime: {:.4f} seconds\nTop 1 Accuracy : {:.2f}%\nTop 5 Accuracy : {:.2f}%\n'.format(single_inference_runtime, 100 * top1Accuracy, 100 * top5Accuracy))


print("Validation Done ...!")

Validation Done ...!


In [8]:
# import torchvision.models as tm
# model = tm.maxvit_t('vit_base_patch16_384', pretrained=True).to(device)
# from torchvision.models import MaxVit_T_Weights
# model = timm.create_model('vit_base_patch16_384', pretrained=True).to(device)

# model load
model = timm.create_model('vit_base_patch16_224', pretrained=True).to(device)
torch.save(model.state_dict(), 'vit_base_patch16_224.pth')
print("Model Saved ...!")

val_loader, val_loader_sub = preprocess(model)
print("Data Preprocessing Done.")


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Model Saved ...!
Compose(
    Resize(size=248, interpolation=bicubic, max_size=None, antialias=warn)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.5000, 0.5000, 0.5000]), std=tensor([0.5000, 0.5000, 0.5000]))
)
Data Preprocessing Done.


In [9]:
# Baseline Setup
validate(model, val_loader_sub, 'vit_base_patch16_224')


  0%|          | 0/1000 [00:00<?, ?it/s]

Single Inference Runtime: 0.0059 seconds
Top 1 Accuracy : 94.90%
Top 5 Accuracy : 99.70%



In [10]:
# Onnx Validation
validate(model, val_loader_sub, 'vit_base_patch16_224', ONNX=True)


  assert condition, message


  0%|          | 0/1000 [00:00<?, ?it/s]

Single Inference Runtime: 0.2128 seconds
Top 1 Accuracy : 94.90%
Top 5 Accuracy : 99.70%



In [11]:
# Model Quantization fp16
validate(model, val_loader_sub, 'vit_base_patch16_224', ONNX=True, quant="fp16")


[0;93m2024-04-06 13:08:56.185434244 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/blocks/blocks.0/attn/Sqrt'[m
[0;93m2024-04-06 13:08:56.185706000 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/blocks/blocks.1/attn/Sqrt'[m
[0;93m2024-04-06 13:08:56.185950038 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/blocks/blocks.2/attn/Sqrt'[m
[0;93m2024-04-06 13:08:56.186206724 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/blocks/blocks.3/attn/Sqrt'[m
[0;93m2024-04-06 13:08:56.186423491 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/blocks/blocks.4/attn/Sqrt'[m
[0;93m2024-04-06 13:08:56.186637530 [W:onnxruntime:, c

  0%|          | 0/1000 [00:00<?, ?it/s]

Single Inference Runtime: 0.4523 seconds
Top 1 Accuracy : 94.90%
Top 5 Accuracy : 99.70%



In [12]:
# Model QDQ Investigation
validate(model, val_loader_sub, 'vit_base_patch16_224', ONNX=False, quant="fp16", quant_invest=True)


fp16 quantization investigation : 
/patch_embed/proj/Conv
/patch_embed/proj/Conv
/patch_embed/Shape
/patch_embed/Slice
/patch_embed/Reshape
/Constant_1
/Constant_2
/ConstantOfShape
/Equal
/Expand
/Concat
/Add
/Add
/Add
/blocks/blocks.0/norm1/Sub
/blocks/blocks.0/norm1/Sub
/blocks/blocks.0/norm1/Pow
/blocks/blocks.0/norm1/ReduceMean_1
/blocks/blocks.0/norm1/Add
/blocks/blocks.0/norm1/Div
/blocks/blocks.0/norm1/Mul
/blocks/blocks.0/norm1/Add_1
/blocks/blocks.0/attn/qkv/Add
/blocks/blocks.0/attn/Reshape
/blocks/blocks.0/attn/Transpose
/blocks/blocks.0/attn/Split
/blocks/blocks.0/attn/Squeeze
/blocks/blocks.0/attn/Squeeze
/blocks/blocks.0/attn/Squeeze_1
/blocks/blocks.0/attn/Shape
/blocks/blocks.0/attn/Slice
/blocks/blocks.0/attn/Cast
/blocks/blocks.0/attn/Constant_7
/blocks/blocks.0/attn/Div
/blocks/blocks.0/attn/Div
/blocks/blocks.0/attn/Transpose_1
/blocks/blocks.0/attn/Mul
/blocks/blocks.0/attn/MatMul
/blocks/blocks.0/attn/Softmax
/blocks/blocks.0/attn/MatMul_1
/blocks/blocks.0/attn/Tr

In [13]:
# Model Quantization int8
validate(model, val_loader_sub, 'vit_base_patch16_224', ONNX=True, sample_size=50, quant="int8")


Collecting tensor data and making histogram ...
Finding optimal threshold for each tensor using 'percentile' algorithm ...
Number of tensors : 505
Number of histogram bins : 2048
Percentile : (0.0010000000000047748,99.999)


  quantized_data = (np.asarray(bias_data) / bias_scale).round().astype(np.int32)


  0%|          | 0/1000 [00:00<?, ?it/s]

Single Inference Runtime: 0.5787 seconds
Top 1 Accuracy : 87.00%
Top 5 Accuracy : 96.60%



In [14]:
# Model QDQ Investigation
validate(model, val_loader_sub, 'vit_base_patch16_224', ONNX=False, quant="int8", quant_invest=True)


int8 quantization investigation : 
/blocks/blocks.0/norm1/Sub
/blocks/blocks.0/norm1/Sub
/blocks/blocks.0/norm1/Pow
/blocks/blocks.0/norm2/Sub
/blocks/blocks.0/norm2/Sub
/blocks/blocks.0/norm2/Pow
/blocks/blocks.0/mlp/act/Div
/blocks/blocks.1/norm1/Sub
/blocks/blocks.1/norm1/Sub
/blocks/blocks.1/norm1/Pow
/blocks/blocks.1/norm2/Sub
/blocks/blocks.1/norm2/Sub
/blocks/blocks.1/norm2/Pow
/blocks/blocks.1/mlp/act/Div
/blocks/blocks.2/norm1/Sub
/blocks/blocks.2/norm1/Sub
/blocks/blocks.2/norm1/Pow
/blocks/blocks.2/norm2/Sub
/blocks/blocks.2/norm2/Sub
/blocks/blocks.2/norm2/Pow
/blocks/blocks.2/mlp/act/Div
/blocks/blocks.3/norm1/Sub
/blocks/blocks.3/norm1/Sub
/blocks/blocks.3/norm1/Pow
/blocks/blocks.3/norm2/Sub
/blocks/blocks.3/norm2/Sub
/blocks/blocks.3/norm2/Pow
/blocks/blocks.3/mlp/act/Div
/blocks/blocks.4/norm1/Sub
/blocks/blocks.4/norm1/Sub
/blocks/blocks.4/norm1/Pow
/blocks/blocks.4/norm2/Sub
/blocks/blocks.4/norm2/Sub
/blocks/blocks.4/norm2/Pow
/blocks/blocks.4/mlp/act/Div
/blocks/bl

In [15]:
# Quantization Investigation
import onnx

def quant_investigation(quant_model_name):
    _model = onnx.load(quant_model_name + ".onnx")
    initializers = _model.graph.initializer

    for node_i in _model.graph.node:
        if node_i.output and "QuantizeLinear" not in node_i.output[0] and "DequantizeLinear" not in node_i.name:
            for node_j in _model.graph.node:
                if node_j.input and node_i.output[0] == node_j.input[0] and "QuantizeLinear" not in node_j.output[0]:
                    print(node_i.name)

print("fp16 quantization investigation : ")
quant_investigation("vit_base_patch16_224fp16")
print()


fp16 quantization investigation : 
/patch_embed/proj/Conv
/patch_embed/proj/Conv
/patch_embed/Shape
/patch_embed/Slice
/patch_embed/Reshape
/Constant_1
/Constant_2
/ConstantOfShape
/Equal
/Expand
/Concat
/Add
/Add
/Add
/blocks/blocks.0/norm1/Sub
/blocks/blocks.0/norm1/Sub
/blocks/blocks.0/norm1/Pow
/blocks/blocks.0/norm1/ReduceMean_1
/blocks/blocks.0/norm1/Add
/blocks/blocks.0/norm1/Div
/blocks/blocks.0/norm1/Mul
/blocks/blocks.0/norm1/Add_1
/blocks/blocks.0/attn/qkv/Add
/blocks/blocks.0/attn/Reshape
/blocks/blocks.0/attn/Transpose
/blocks/blocks.0/attn/Split
/blocks/blocks.0/attn/Squeeze
/blocks/blocks.0/attn/Squeeze
/blocks/blocks.0/attn/Squeeze_1
/blocks/blocks.0/attn/Shape
/blocks/blocks.0/attn/Slice
/blocks/blocks.0/attn/Cast
/blocks/blocks.0/attn/Constant_7
/blocks/blocks.0/attn/Div
/blocks/blocks.0/attn/Div
/blocks/blocks.0/attn/Transpose_1
/blocks/blocks.0/attn/Mul
/blocks/blocks.0/attn/MatMul
/blocks/blocks.0/attn/Softmax
/blocks/blocks.0/attn/MatMul_1
/blocks/blocks.0/attn/Tr

In [16]:
# Model TVM with quant int8
validate(model, val_loader_sub, 'vit_base_patch16_224', ONNX = True, sample_size = 50, quant = 'int8', TVM = True)


  0%|          | 0/1000 [00:00<?, ?it/s]

Single Inference Runtime: 2.2313 seconds
Top 1 Accuracy : 86.90%
Top 5 Accuracy : 96.60%



In [17]:
# Model TVM with fp16
validate(model, val_loader_sub, 'vit_base_patch16_224', ONNX = True, sample_size = 50, quant = 'fp16', TVM = True)


[0;93m2024-04-06 14:13:23.062272959 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/blocks/blocks.0/attn/Sqrt'[m
[0;93m2024-04-06 14:13:23.062549787 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/blocks/blocks.1/attn/Sqrt'[m
[0;93m2024-04-06 14:13:23.062802720 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/blocks/blocks.2/attn/Sqrt'[m
[0;93m2024-04-06 14:13:23.063042316 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/blocks/blocks.3/attn/Sqrt'[m
[0;93m2024-04-06 14:13:23.063286738 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/blocks/blocks.4/attn/Sqrt'[m
[0;93m2024-04-06 14:13:23.063523432 [W:onnxruntime:, c

  0%|          | 0/1000 [00:00<?, ?it/s]

TVMError: Traceback (most recent call last):
  191: 0x00005c68b85dfd10
  190: __libc_start_main
  189: Py_BytesMain
        at /usr/local/src/conda/python-3.10.13/Modules/main.c:1090
  188: Py_RunMain
        at /usr/local/src/conda/python-3.10.13/Modules/main.c:670
  187: pymain_run_python
        at /usr/local/src/conda/python-3.10.13/Modules/main.c:585
  186: pymain_run_module
        at /usr/local/src/conda/python-3.10.13/Modules/main.c:297
  185: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  184: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  183: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  182: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4213
  181: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  180: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  179: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  178: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  177: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  176: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  175: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4213
  174: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  173: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  172: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  171: cfunction_vectorcall_FASTCALL
        at /usr/local/src/conda/python-3.10.13/Objects/methodobject.c:430
  170: builtin_exec
        at /usr/local/src/conda/python-3.10.13/Python/clinic/bltinmodule.c.h:371
  169: builtin_exec_impl
        at /usr/local/src/conda/python-3.10.13/Python/bltinmodule.c:1058
  168: PyEval_EvalCode
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:1134
  167: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  166: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  165: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4181
  164: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  163: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  162: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  161: method_vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/classobject.c:53
  160: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  159: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  158: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  157: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  156: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4198
  155: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  154: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  153: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  152: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  151: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  150: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  149: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4198
  148: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  147: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  146: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  145: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  144: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  143: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  142: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4198
  141: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  140: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  139: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  138: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  137: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  136: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  135: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4198
  134: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  133: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  132: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  131: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  130: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  129: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  128: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4198
  127: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  126: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  125: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  124: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  123: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  122: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  121: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4277
  120: do_call_core
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5917
  119: cfunction_vectorcall_FASTCALL_KEYWORDS
        at /usr/local/src/conda/python-3.10.13/Objects/methodobject.c:446
  118: context_run
        at /usr/local/src/conda/python-3.10.13/Python/context.c:665
  117: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  116: cfunction_vectorcall_O
        at /usr/local/src/conda/python-3.10.13/Objects/methodobject.c:516
  115: task_step
        at /usr/local/src/conda/python-3.10.13/Modules/_asynciomodule.c:2950
  114: task_step_impl
        at /usr/local/src/conda/python-3.10.13/Modules/_asynciomodule.c:2653
  113: gen_send_ex2
        at /usr/local/src/conda/python-3.10.13/Objects/genobject.c:213
  112: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  111: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:2586
  110: gen_send_ex2
        at /usr/local/src/conda/python-3.10.13/Objects/genobject.c:213
  109: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  108: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:2586
  107: gen_send_ex2
        at /usr/local/src/conda/python-3.10.13/Objects/genobject.c:213
  106: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  105: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:2586
  104: gen_send_ex2
        at /usr/local/src/conda/python-3.10.13/Objects/genobject.c:213
  103: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  102: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:2586
  101: gen_send_ex2
        at /usr/local/src/conda/python-3.10.13/Objects/genobject.c:213
  100: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  99: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4231
  98: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  97: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  96: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  95: method_vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/classobject.c:53
  94: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  93: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  92: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  91: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  90: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4277
  89: do_call_core
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5945
  88: PyObject_Call
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:317
  87: _PyObject_Call
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:290
  86: PyVectorcall_Call
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:267
  85: method_vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/classobject.c:53
  84: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  83: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  82: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  81: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  80: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4198
  79: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  78: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  77: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  76: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  75: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  74: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  73: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4213
  72: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  71: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  70: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  69: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  68: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  67: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  66: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4198
  65: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  64: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  63: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  62: method_vectorcall_O
        at /usr/local/src/conda/python-3.10.13/Objects/descrobject.c:460
  61: gen_send_ex
        at /usr/local/src/conda/python-3.10.13/Objects/genobject.c:279
  60: gen_send_ex2
        at /usr/local/src/conda/python-3.10.13/Objects/genobject.c:213
  59: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  58: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:2586
  57: gen_send_ex2
        at /usr/local/src/conda/python-3.10.13/Objects/genobject.c:213
  56: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  55: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:2586
  54: gen_send_ex2
        at /usr/local/src/conda/python-3.10.13/Objects/genobject.c:213
  53: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  52: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4213
  51: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  50: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  49: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  48: cfunction_vectorcall_FASTCALL
        at /usr/local/src/conda/python-3.10.13/Objects/methodobject.c:430
  47: builtin_exec
        at /usr/local/src/conda/python-3.10.13/Python/clinic/bltinmodule.c.h:371
  46: builtin_exec_impl
        at /usr/local/src/conda/python-3.10.13/Python/bltinmodule.c:1058
  45: PyEval_EvalCode
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:1134
  44: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  43: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  42: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4231
  41: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  40: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  39: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  38: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  37: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  36: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  35: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4213
  34: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  33: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  32: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  31: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  30: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  29: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  28: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4198
  27: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  26: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  25: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  24: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  23: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  22: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  21: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4198
  20: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  19: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  18: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  17: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  16: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  15: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  14: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4198
  13: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  12: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  11: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  10: _PyFunction_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Objects/call.c:342
  9: _PyEval_Vector
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5067
  8: _PyEval_EvalFrame
        at /usr/local/src/conda/python-3.10.13/Include/internal/pycore_ceval.h:46
  7: _PyEval_EvalFrameDefault
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:4198
  6: call_function
        at /usr/local/src/conda/python-3.10.13/Python/ceval.c:5893
  5: PyObject_Vectorcall
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:123
  4: _PyObject_VectorcallTstate
        at /usr/local/src/conda/python-3.10.13/Include/cpython/abstract.h:114
  3: method_vectorcall_O
        at /usr/local/src/conda/python-3.10.13/Objects/descrobject.c:460
  2: __pyx_pw_3tvm_4_ffi_4_cy3_4core_11NDArrayBase_5_copyto(_object*, _object*)
  1: TVMArrayCopyFromTo
  0: tvm::runtime::NDArray::CopyFromTo(DLTensor const*, DLTensor*, void*)
  File "/workspace/tvm/src/runtime/ndarray.cc", line 270
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------

  Check failed: from_size == to_size (602112 vs. 301056) : TVMArrayCopyFromTo: The size must exactly match

In [None]:
# import torch
# import torchvision.models as models
# import urllib
# from PIL import Image
# from torchvision import transforms

# # Create an instance of the MaxViT-T model with pre-trained weights on ImageNet
# model = models.maxvit_t(weights='imagenet')

# # Define a preprocessing pipeline for the input image
# preprocess = transforms.Compose([
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])

# # Load and preprocess an example image
# url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
# urllib.request.urlretrieve(url, filename)
# input_image = Image.open(filename)
# input_tensor = preprocess(input_image)
# input_batch = input_tensor.unsqueeze(0)  # Add a batch dimension

# # Put the model in evaluation mode
# model.eval()

# # Perform inference using the pre-trained model
# with torch.no_grad():
#     output = model(input_batch)

# # Print the top 5 predicted classes
# _, indices = torch.sort(output, descending=True)
# imagenet_labels = urllib.request.urlopen("https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json")
# labels = [line.strip() for line in imagenet_labels.readlines()]
# percentage = torch.nn.functional.softmax(output, dim=1)[0] * 100
# for i in range(5):
#     print(labels[indices[0][i]], percentage[indices[0][i]].item())
