In [1]:
!pip install onnxruntime-gpu
!pip install onnxconverter-common
!pip install apache-tvm
!pip install torch
!pip install numpy
!pip install onnxruntime
!pip install onnx
!pip install onnxsim
!pip install timm
!pip install tqdm
!pip install torchvision


Collecting onnxruntime-gpu
  Downloading onnxruntime_gpu-1.17.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (4.3 kB)
Collecting coloredlogs (from onnxruntime-gpu)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime-gpu)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnxruntime_gpu-1.17.1-cp310-cp310-manylinux_2_28_x86_64.whl (192.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m192.1/192.1 MB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected package

In [2]:
# import libraries
import os
import torch
import json
import time
from tqdm.auto import tqdm
import timm
import numpy as np
import onnxruntime as ort
from onnxruntime import quantization
from onnxconverter_common import float16
from onnxsim import simplify 
import onnx
import tvm
from tvm import relay


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device = ", device)


device =  cuda


In [4]:
def preprocess(model):
    data_config = timm.data.resolve_model_data_config(model)
    transforms = timm.data.create_transform(**data_config, is_training = False)
    print(transforms)
    
    val_dataset = timm.data.ImageDataset('/kaggle/input/imagenet/imagenet-mini', transform = transforms)
    val_loader = timm.data.create_loader(val_dataset, (1, 3, 224, 224), 1)
    
    val_dataset_sub = torch.utils.data.Subset(val_dataset, list(range(1000)))
    val_loader_sub = timm.data.create_loader(val_dataset_sub, (1, 3, 224, 224), 1)
    
    return val_loader, val_loader_sub

print("Preprocess Done ...!")


Preprocess Done ...!


In [5]:
class OnnxStaticQuantization:
    def __init__(self) -> None:
        self.enum_data = None
        self.calibration_technique = {
            "MinMax": ort.quantization.calibrate.CalibrationMethod.MinMax,
            "Entropy": ort.quantization.calibrate.CalibrationMethod.Entropy,
            "Percentile": ort.quantization.calibrate.CalibrationMethod.Percentile,
            "Distribution": ort.quantization.calibrate.CalibrationMethod.Distribution
        }

    def get_next(self, EP_list = ['CPUExecutionProvider']):
        if self.enum_data is None:
            session = ort.InferenceSession(self.fp32_onnx_path, providers = EP_list)
            input_name = session.get_inputs()[0].name
            calib_list = []
            count = 0
            for nhwc_data, _ in self.calibration_loader:
                nhwc_data=nhwc_data.cpu()
                calib_list.append({input_name: nhwc_data.numpy()}) 
                if self.sample == count: break
                count = count + 1
            self.enum_data = iter(calib_list)
        return next(self.enum_data, None)
    
    def quantization(self, fp32_onnx_path, future_onnx_path, calib_method, calibration_loader, sample = 100):
        self.sample = sample
        self.calibration_loader = calibration_loader 
        _ = ort.quantization.quantize_static(
                model_input = fp32_onnx_path,
                model_output = future_onnx_path,
                calibrate_method = self.calibration_technique[calib_method],
                activation_type=ort.quantization.QuantType.QInt8,
                weight_type=ort.quantization.QuantType.QInt8,
                per_channel = True, 
                reduce_range = True,
                calibration_data_reader = self
            )
        return self


print("Quantization Done ...!")


Quantization Done ...!


In [6]:
# Quantization Investigation
def quant_investigation(quant_model_name):
    _model = onnx.load(quant_model_name + ".onnx")
    initializers = _model.graph.initializer

    for node_i in _model.graph.node:
        if node_i.output and "QuantizeLinear" not in node_i.output[0] and "DequantizeLinear" not in node_i.name:
            for node_j in _model.graph.node:
                if node_j.input and node_i.output[0] == node_j.input[0] and "QuantizeLinear" not in node_j.output[0]:
                    print(node_i.name)

print("Quantization Investigation Done ...!")


Quantization Investigation Done ...!


In [14]:
# Validation 
def validate(model, val_loader, model_name, ONNX = False, quant = "", sample_size = 100, quant_invest = False, TVM = False):
    correct = 0
    total = 0
    elapsed_time = 0
    top5_correct = 0
    
    if ONNX:    # 1
        if not TVM:    # 2  
            if quant == "fp16":    # 3
                model = onnx.load(model_name + ".onnx")
                model_fp16 = float16.convert_float_to_float16(model)
                onnx.save(model_fp16,model_name + quant + ".onnx")
                
            elif quant == "int8":    # 3
                ort.quantization.shape_inference.quant_pre_process(model_name + ".onnx", "preprocess.onnx")
                module = OnnxStaticQuantization()
                module.fp32_onnx_path = "preprocess.onnx"
                
                module.quantization(
                    fp32_onnx_path = "preprocess.onnx",
                    future_onnx_path = model_name + quant + ".onnx",
#                     calib_method = "MinMax",
#                     calib_method = "Entropy",
                    calib_method = "Percentile",
#                     calib_method = "Distribution",
                    calibration_loader = val_loader,
                    sample = sample_size)
                
            elif quant == "fp32" :  # 3
                dummy_input = torch.randn(1, 3, 224, 224).to(device)
                torch.onnx.export(model, dummy_input, model_name + quant + ".onnx", export_params = True, opset_version = 14, do_constant_folding = True)
                
            else:    # 3
                dummy_input = torch.randn(1, 3, 224, 224).to(device)
                torch.onnx.export(model, dummy_input, model_name + ".onnx", export_params = True, opset_version = 14, do_constant_folding = True)
                
                
        else:    # 2
            onnx_model = onnx.load(model_name + quant + ".onnx")
            shape_dict = {"input.1": (1, 3, 224, 224)}
            mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
            target = "llvm -mcpu=core-avx2"
            with tvm.transform.PassContext(opt_level = 2):
                executor = relay.build_module.create_executor("graph", mod, tvm.cpu(0), target, params).evaluate()
        
        # -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------
        
        sess = ort.InferenceSession(model_name + quant + ".onnx", providers = ['CPUExecutionProvider'])
        for batch_idx, (inputs, labels) in enumerate(tqdm(val_loader)):
            if quant == "fp16":
                inputs = inputs.half()
            inputs = inputs.cpu().numpy()
            
            if TVM:
                if quant == "fp16":
                    inputs = tvm.nd.array(inputs.astype("float16"))
                    start_time = time.time()    
                    outputs = [executor(inputs).numpy()]
                    end_time = time.time()
                else:
                    inputs = tvm.nd.array(inputs.astype("float32"))
                    start_time = time.time()    
                    outputs = [executor(inputs).numpy()]
                    end_time = time.time()
                
            else:
                start_time = time.time()    
                outputs = sess.run(None, {sess.get_inputs()[0].name: inputs})
                end_time = time.time()
            
            elapsed_time += end_time - start_time

            predicted_labels = np.argmax(outputs[0], axis=1)

            correct += (predicted_labels == labels.cpu().numpy()).sum()
            total += labels.size(0)

            top5_predicted = np.argsort(outputs[0], axis=1)[:, -5:]  
            top5_correct += np.sum(np.equal(top5_predicted, np.expand_dims(labels.cpu().numpy(), axis=1)))
    
    
    elif ONNX == False and quant_invest == True and quant != "":    # 1
        if quant == "fp16" :
            print("fp16 quantization investigation : ")
            quant_investigation(model_name + quant)
            print()
        elif quant == "int8" :
            print("int8 quantization investigation : ")
            quant_investigation(model_name + quant)
            print()
        elif quant == "fp32" : 
            print("fp32 quantization investigation : ")
            quant_investigation(model_name + quant)
            print()
        else :
            print("Error : quant is empty.")
        return
    
    
    else:  # 1
        print("Default Model Accuracy Calculate : ")
        model = model.eval()
        with torch.inference_mode():
            for batch_idx, (images, labels) in enumerate(tqdm(val_loader)):
                images, labels = images.to(device), labels.to(device)
                
                start_time = time.time()
                outputs = model(images)
                end_time = time.time()
                elapsed_time += end_time - start_time
                
                _, predicted = torch.max(outputs.softmax(dim = 1) * 100, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                _, top5_predicted = torch.topk(outputs.softmax(dim = 1) * 100, k = 5)
                top5_correct += sum(labels[i].item() in top5_predicted[i] for i in range(len(labels)))
    
    # ==================================================================================================================================================================================
    
    # Calculate accuracy
    if total != 0 :
        single_inference_runtime = elapsed_time / total
        top1Accuracy = correct / total
        top5Accuracy = top5_correct / total

        print('Single Inference Runtime: {:.4f} seconds\nTop 1 Accuracy : {:.2f}%\nTop 5 Accuracy : {:.2f}%\n'.format(single_inference_runtime, 100 * top1Accuracy, 100 * top5Accuracy))


print("Validation Done ...!")


Validation Done ...!


## **Load Model**

In [8]:
# model load
model = timm.create_model('vit_base_patch16_224', pretrained=True).to(device)
torch.save(model.state_dict(), 'vit_base_patch16_224.pth')
print("Model Saved ...!")

val_loader, val_loader_sub = preprocess(model)
print("Data Preprocessing Done.")


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Model Saved ...!
Compose(
    Resize(size=248, interpolation=bicubic, max_size=None, antialias=warn)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.5000, 0.5000, 0.5000]), std=tensor([0.5000, 0.5000, 0.5000]))
)
Data Preprocessing Done.


## **Baseline Setup**

In [10]:
# Baseline Setup : Pytorch Model
validate(model, val_loader_sub, 'vit_base_patch16_224')


Default Model Accuracy Calculate : 


  0%|          | 0/1000 [00:00<?, ?it/s]

Single Inference Runtime: 0.0053 seconds
Top 1 Accuracy : 94.90%
Top 5 Accuracy : 99.70%



## **Onnx Export**

In [11]:
# Onnx Export : ONNX FP32
validate(model, val_loader_sub, 'vit_base_patch16_224', ONNX=True)


  assert condition, message


  0%|          | 0/1000 [00:00<?, ?it/s]

Single Inference Runtime: 0.1942 seconds
Top 1 Accuracy : 94.90%
Top 5 Accuracy : 99.70%



In [None]:
# Onnx Export : ONNX FP16
validate(model, val_loader_sub, 'vit_base_patch16_224', ONNX=True, sample_size=100, quant="fp16")


### ***Model Quantization int8***

In [None]:
# Model Quantization int8 : QDQ Int8
validate(model, val_loader_sub, 'vit_base_patch16_224', ONNX=True, sample_size=50, quant="int8") 


In [None]:
# Model QDQ Investigation
validate(model, val_loader_sub, 'vit_base_patch16_224', ONNX=False, quant="int8", quant_invest=True)


## **TVM**

In [None]:
# Model TVM with fp16 : ONNX FP16
validate(model, val_loader_sub, 'vit_base_patch16_224', ONNX = True, sample_size = 50, quant = 'fp16', TVM = True)


In [None]:
# Model TVM with quant int8 : QDQ Int8
validate(model, val_loader_sub, 'vit_base_patch16_224', ONNX=True, sample_size=50, quant='int8', TVM=True)


In [16]:
# Model TVM with quant fp32 : ONNX FP32
validate(model, val_loader_sub, 'vit_base_patch16_224', ONNX=True, sample_size=100, quant="fp32")
validate(model, val_loader_sub, 'vit_base_patch16_224', ONNX=True, sample_size=50, quant='fp32', TVM=True)


  0%|          | 0/1000 [00:00<?, ?it/s]

Single Inference Runtime: 1.4753 seconds
Top 1 Accuracy : 94.90%
Top 5 Accuracy : 99.70%

