In [None]:
# NOTE(review): --trusted-host tells pip to accept the listed hosts without valid HTTPS.
# Drop these flags unless a proxy on this machine really requires them.
%pip install ipywidgets --trusted-host pypi.org --trusted-host pypi.python.org --trusted-host=files.pythonhosted.org
# NOTE(review): none of these installs is version-pinned, so a fresh run may pull different
# releases than the ones whose versions are printed by the import cell below.
%pip install wget
%pip install torch_tensorrt
%pip install tensorrt
%pip install timm
%pip install torchsummary
# pytorch-quantization is fetched through the additional NVIDIA package index given here.
%pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com

In [None]:
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
print(torch.__version__)

import torchvision.transforms as transforms
from torchvision import models, datasets
import torch_tensorrt
print(torch_tensorrt.__version__)

import pytorch_quantization
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization import calib
print(pytorch_quantization.__version__)

from tqdm import tqdm
import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

import os
import sys
import warnings
import time
import numpy as np
import wget
import tarfile
import shutil
from PIL import Image
warnings.simplefilter('ignore')

2.0.0+nv23.05
1.4.0
2.1.2


## load model skeleton definition

In [None]:
# Import the EfficientFormerV2 model definition straight from its source file, because it is
# not an installed package. Recipe:
# https://stackoverflow.com/questions/67631/how-can-i-import-a-module-dynamically-given-the-full-path
import importlib.util
# The file is looked up under $HOME; os.environ['HOME'] raises KeyError when HOME is unset.
spec = importlib.util.spec_from_file_location("module.name", os.environ['HOME']+"/work/transfer-learning/EfficientFormerV2/model.py")
efficientformerv2 = importlib.util.module_from_spec(spec)
# Register the module under the same placeholder name before executing it.
sys.modules["module.name"] = efficientformerv2
# Run the file's top-level code so its factories become attributes of `efficientformerv2`.
spec.loader.exec_module(efficientformerv2)

In [None]:
# Name of the factory function looked up in the EfficientFormerV2 module; also used as the
# prefix of every checkpoint / TorchScript file written by this notebook.
model_name = "efficientformerv2_s1"
# Pretrained checkpoint, resolved relative to the current user's home directory instead of a
# hard-coded /home/<user> path (the model source is also located under $HOME).
orig_wight = os.path.join(os.path.expanduser("~"), "work", "orin-demo", "eformer_s1_450.pth")

In [None]:
# Build the 10-class network on the GPU from the factory selected by `model_name`.
create_model = getattr(efficientformerv2, model_name)
model = create_model(num_classes=10).cuda()

# The checkpoint is either a bare state dict or a dict that wraps it under "model".
weights_dict = torch.load(orig_wight)
if "model" in weights_dict:
    weights_dict = weights_dict["model"]

# Remove the pretrained classifier tensors so the freshly initialised 10-class head is kept.
head_keys = [k for k in weights_dict if "head.weight" in k or "head.bias" in k]
for k in head_keys:
    del weights_dict[k]

# Alternative that was tried and left disabled: timm.create_model(model_name+'.snap_dist_in1k',
# pretrained=True) with `head` / `head_dist` replaced by nn.Linear(224, 10).

# strict=False tolerates the keys deleted above; the returned report lists what was missing.
model.load_state_dict(weights_dict, strict=False)

_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])

### demo

In [None]:
# Evaluation-style preprocessing: resize, centre-crop to 224x224, convert to tensor, normalise.
data_transform = transforms.Compose(
    [transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)])
# convert("RGB") guarantees three channels even for grayscale / RGBA / palette images.
img = data_transform(Image.open("../daisy.jpg").convert("RGB"))
# [C, H, W] -> [1, C, H, W]
img = torch.unsqueeze(img, dim=0).cuda()
model.eval()
with torch.no_grad():
    out = model(img)
    # Softmax converts the logits into the probabilities that the printed label promises.
    predict = torch.softmax(torch.squeeze(out), dim=0).cpu()
    for i in [985]: # daisy (ImageNet-1k index)
        if i >= predict.numel():
            # The model above is built with num_classes=10, so the ImageNet index does not
            # exist in its output; report the top-1 class instead of raising IndexError.
            i = int(torch.argmax(predict))
        print("class: {:10}   prob: {:.11}".format(str(i), predict[i].numpy()))

In [None]:
from torchsummary import summary
summary(model, (3, 224, 224))
model.eval()

## prepare dataset and begin fine-tuning & transfer learning

### Download data

In [None]:
def download_data(DATA_DIR):
    """Download and unpack the Imagenette (320 px) archive into ``DATA_DIR``.

    Nothing happens when ``DATA_DIR/imagenette2-320`` already exists. When
    ``DATA_DIR`` itself is missing, a hint is printed and the function returns
    without downloading. The archive is stored in the current working
    directory as ``imagenette2-320.tgz``.

    Args:
        DATA_DIR: Existing directory that will receive the extracted dataset.

    Returns:
        None.
    """
    if not os.path.exists(DATA_DIR):
        print("This directory doesn't exist. Create the directory and run again")
        return
    if os.path.exists(os.path.join(DATA_DIR, 'imagenette2-320')):
        # Dataset already extracted by an earlier run.
        return

    url = 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz'
    archive = 'imagenette2-320.tgz'
    # Reuse an archive left behind by an earlier run instead of downloading it again.
    if not os.path.exists(archive):
        wget.download(url)
    # The context manager closes the archive even when extraction fails.
    with tarfile.open(archive) as file:
        if hasattr(tarfile, 'data_filter'):
            # Extraction filters are available: refuse absolute paths, "..", device nodes, etc.
            file.extractall(DATA_DIR, filter='data')
        else:
            file.extractall(DATA_DIR)

In [None]:
# exist_ok=True makes the cell safe to re-run and removes the check-then-create race.
os.makedirs("./data", exist_ok=True)
download_data("./data")

### construct train and valid data

In [None]:
# Root of the extracted Imagenette archive, followed by its training and validation folders.
DATA_DIR = './data/imagenette2-320'
TRAIN_DIR, VAL_DIR = (os.path.join(DATA_DIR, split) for split in ('train', 'val'))

In [None]:
# Deterministic preprocessing shared by the training and validation datasets.
transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
            ])
train_dataset = datasets.ImageFolder(TRAIN_DIR, transform=transform)
val_dataset = datasets.ImageFolder(VAL_DIR, transform=transform)

# Hold out up to 1024 validation images for INT8 calibration. The remainder is derived from
# the dataset length (random_split requires the lengths to add up exactly), and the split is
# seeded so that the calibration subset is the same on every run.
# NOTE(review): the calibration images are also part of val_dataset, so accuracy measured
# after calibration is not on strictly unseen data.
calib_size = min(1024, len(val_dataset))
calib_dataset = data.random_split(
    val_dataset,
    [len(val_dataset) - calib_size, calib_size],
    generator=torch.Generator().manual_seed(0),
)[1]

train_dataloader = data.DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_dataloader = data.DataLoader(val_dataset, batch_size=64, shuffle=False, drop_last=True)
calib_dataloader = data.DataLoader(calib_dataset, batch_size=64, shuffle=False, drop_last=True)

# Use cross entropy loss for classification
criterion = nn.CrossEntropyLoss()

In [None]:
# Define helper functions for training, evaluation, checkpoint saving and latency benchmarking
def train(model, dataloader, crit, opt, epoch):
    """Run one optimisation pass over ``dataloader`` on the GPU.

    Args:
        model: Network to optimise; switched to training mode first.
        dataloader: Iterable of ``(inputs, labels)`` batches that supports ``len()``.
        crit: Loss callable invoked as ``crit(outputs, labels)``.
        opt: Optimizer stepped once per batch.
        epoch: Accepted for the callers' convenience; not used in the body.

    Prints the mean loss of every completed window of 100 batches.
    """
    model.train()
    window_loss = 0.0
    for step, (inputs, targets) in enumerate(dataloader, start=1):
        inputs = inputs.cuda()
        targets = targets.cuda(non_blocking=True)
        opt.zero_grad()
        batch_loss = crit(model(inputs), targets)
        batch_loss.backward()
        opt.step()
        window_loss += batch_loss.item()
        if step % 100 != 0:
            continue
        # A full window of 100 batches is complete: report its mean and start a new one.
        print("Batch: [%5d | %5d] loss: %.3f" % (step, len(dataloader), window_loss / 100))
        window_loss = 0.0
        
def evaluate(model, dataloader, crit, non_blocking=True):
    """Compute the mean per-sample loss and the top-1 accuracy of ``model``.

    Args:
        model: Callable network; switched to eval mode before the pass.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        crit: Loss callable invoked as ``crit(outputs, labels)``. A ``reduction``
            attribute equal to ``"sum"`` is honoured; anything else is treated
            as a per-batch mean and re-weighted by the batch size.
        non_blocking: Forwarded to ``labels.cuda()``; any truthy/falsy value is
            accepted and coerced to ``bool``.

    Returns:
        ``(mean_loss, accuracy)`` as plain Python floats. The loss is averaged
        per sample over the whole loader, not per batch.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    # Several callers pass an int here, and Tensor.cuda() expects a real bool.
    non_blocking = bool(non_blocking)
    # With reduction="sum" the criterion already returns the batch total.
    batch_is_summed = getattr(crit, "reduction", "mean") == "sum"

    total = 0
    correct = 0
    loss_sum = 0.0
    model.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.cuda(), labels.cuda(non_blocking=non_blocking)
            out = model(inputs)
            batch_size = labels.size(0)
            # .item() keeps the accumulator a Python float instead of a GPU tensor.
            batch_loss = crit(out, labels).item()
            loss_sum += batch_loss if batch_is_summed else batch_loss * batch_size
            preds = torch.max(out, 1)[1]
            total += batch_size
            correct += (preds == labels).sum().item()

    if total == 0:
        raise ValueError("evaluate() received a dataloader that yielded no samples")
    return loss_sum / total, correct / total

def save_checkpoint(state, ckpt_path="checkpoint.pth"):
    """Serialise ``state`` to ``ckpt_path`` with ``torch.save`` and report it on stdout."""
    torch.save(obj=state, f=ckpt_path)
    print("Checkpoint saved")
    
# Turn on cuDNN's benchmark mode for the rest of the notebook.
cudnn.benchmark = True
def cpu_timestamp(*_args, **_kwargs):
    """Return the current ``time.perf_counter()`` reading in seconds; arguments are ignored."""
    now = time.perf_counter()
    return now

# Helper function to benchmark the model
def benchmark(model, dtype='fp32', nwarmup=50, nruns=1000):
    """Measure single-image inference latency of ``model`` on the GPU.

    The input is the first image of the first batch of the notebook-level
    ``val_dataloader``, expanded to ``[1, C, H, W]``.

    Args:
        model: Callable taking one CUDA tensor.
        dtype: ``'fp16'`` converts the input to half precision; any other value
            leaves it unchanged.
        nwarmup: Number of untimed warm-up calls.
        nruns: Number of timed calls; must be positive.

    Prints the minimum, maximum and average latency in milliseconds.

    Raises:
        ValueError: If ``nruns`` is not positive (no average can be computed).
    """
    if nruns <= 0:
        raise ValueError("nruns must be a positive integer, got {!r}".format(nruns))

    images, _ = next(iter(val_dataloader))
    # expand batch dimension to [N, C, H, W]
    img = torch.unsqueeze(images[0], dim=0).cuda()

    if dtype == 'fp16':
        img = img.half()  # FIXME? (carried over from the original author)

    with torch.no_grad():
        for _ in range(nwarmup):
            model(img)
    # Make sure the warm-up work has finished before timing starts.
    torch.cuda.synchronize()

    latencies = []
    with torch.no_grad():
        for _ in range(nruns):
            start_time = cpu_timestamp()
            model(img)
            # Wait for the GPU so that the wall-clock span covers the whole inference.
            torch.cuda.synchronize()
            latencies.append(cpu_timestamp() - start_time)

    time_min = min(latencies)
    time_max = max(latencies)
    time_avg = sum(latencies)
    print("min = {:7.2f} ms\tmax = {:7.2f} ms\tavg = {:7.2f} ms".format(1000*time_min, 1000*time_max, 1000*time_avg/nruns))

### training...

In [None]:
# Fine-tune the model for `num_epochs` epochs and keep the checkpoint with the best
# validation accuracy.
num_epochs=10
best_acc = 0.
# Learning rate handed to the optimizer; it stays constant because no scheduler is used.
lr = 0.0001
# Plain SGD over all model parameters
optimizer = optim.SGD(model.parameters(), lr=lr)

for epoch in range(num_epochs):
    print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, num_epochs, lr))

    train(model, train_dataloader, criterion, optimizer, epoch)
    test_loss, test_acc = evaluate(model, val_dataloader, criterion)

    print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
    # Save only when validation accuracy improves; the file is overwritten each time.
    if best_acc < test_acc:
        best_acc = test_acc
        save_checkpoint({
            'epoch': epoch + 1,
            'model': model.state_dict(),
            'acc': best_acc,
            'opt_state_dict': optimizer.state_dict()
            }, ckpt_path=model_name+"_base_ckpt.pth")

Epoch: [    1 /    10] LR: 0.000100
Batch: [  100 |   295] loss: 2.291
Batch: [  200 |   295] loss: 2.232
Test Loss: 0.03353 Test Acc: 32.33%
Checkpoint saved
Epoch: [    2 /    10] LR: 0.000100
Batch: [  100 |   295] loss: 2.130
Batch: [  200 |   295] loss: 2.079
Test Loss: 0.03132 Test Acc: 48.39%
Checkpoint saved
Epoch: [    3 /    10] LR: 0.000100
Batch: [  100 |   295] loss: 1.975
Batch: [  200 |   295] loss: 1.913
Test Loss: 0.02881 Test Acc: 60.02%
Checkpoint saved
Epoch: [    4 /    10] LR: 0.000100
Batch: [  100 |   295] loss: 1.809
Batch: [  200 |   295] loss: 1.743
Test Loss: 0.02603 Test Acc: 66.52%
Checkpoint saved
Epoch: [    5 /    10] LR: 0.000100
Batch: [  100 |   295] loss: 1.635
Batch: [  200 |   295] loss: 1.580
Test Loss: 0.02353 Test Acc: 72.87%
Checkpoint saved
Epoch: [    6 /    10] LR: 0.000100
Batch: [  100 |   295] loss: 1.486
Batch: [  200 |   295] loss: 1.431
Test Loss: 0.02113 Test Acc: 76.20%
Checkpoint saved
Epoch: [    7 /    10] LR: 0.000100
Batch: [  

## begin eval and benchmark

In [None]:
weights_dict = torch.load(model_name+"_base_ckpt.pth") # map_location=args.device
weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
create_model = getattr(efficientformerv2, model_name)
model = create_model(num_classes=10).cuda()
model.load_state_dict(weights_dict)
model.eval()

EfficientFormerV2(
  (patch_embed): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): GELU(approximate='none')
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): GELU(approximate='none')
  )
  (network): ModuleList(
    (0): Sequential(
      (0): FFN(
        (mlp): Mlp(
          (fc1): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1))
          (act): GELU(approximate='none')
          (fc2): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1))
          (drop): Dropout(p=0.0, inplace=False)
          (mid): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)
          (mid_norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (norm1): BatchNorm2d(128, eps=1e-05

### raw model precision and performance

In [None]:
#Evaluate the baseline model
test_loss, test_acc = evaluate(model, val_dataloader, criterion)
print(model_name+" Baseline accuracy: {:.2f}%".format(100 * test_acc))

efficientformerv2_s1 Baseline accuracy: 83.56%


In [None]:
#benchmark the performance of the baseline model
benchmark(model, nruns=400)

min =   26.74 ms	max =   43.05 ms	avg =   27.03 ms


### trace model precision and performance

In [None]:
with torch.no_grad():
    traced_model = torch.jit.trace(model, torch.randn((1,3,224,224)).cuda())
    torch.jit.save(traced_model, model_name+"_base.jit.pt")
traced_model = torch.jit.load(model_name+"_base.jit.pt").eval()

In [None]:
# Evaluate the traced TorchScript model that the previous cell saved and reloaded.
# Tracing again here would discard that reloaded module and repeat the work.
test_loss, test_acc = evaluate(traced_model, val_dataloader, criterion)
print(model_name+" trace accuracy: {:.2f}%".format(100 * test_acc))

efficientformerv2_s1 trace accuracy: 83.56%


In [None]:
benchmark(traced_model, nruns=100)

min =   14.45 ms	max =   16.71 ms	avg =   14.70 ms


### script model precision and performance

In [None]:
script_model = torch.jit.script(model).eval()
test_loss, test_acc = evaluate(script_model, val_dataloader, criterion)
print(model_name+" script accuracy: {:.2f}%".format(100 * test_acc))

efficientformerv2_s1 script accuracy: 83.56%


In [None]:
benchmark(script_model, nruns=100)

min =   15.39 ms	max =   16.88 ms	avg =   15.64 ms


### Compile with Torch-TensorRT

In [None]:
# redefine val dataloader batch size
val_dataloader = data.DataLoader(val_dataset, batch_size=1, shuffle=False, drop_last=True)

#### TRT FP32 model on CUDA GPU

In [None]:
# Compile the scripted model to a TensorRT engine with default (FP32) precision and a fixed
# 1x3x224x224 input, then measure its accuracy.
compile_spec = {
    "inputs": [torch_tensorrt.Input([1, 3, 224, 224])],
    }

script_fp32_model = torch_tensorrt.compile(script_model, **compile_spec)
test_loss, test_acc = evaluate(script_fp32_model, val_dataloader, criterion)
# Label names the engine actually evaluated (it previously read "script accuracy").
print(model_name+" TRT FP32 (from script) accuracy: {:.2f}%".format(100 * test_acc))



efficientformerv2_s1 script accuracy: 83.49%


In [None]:
benchmark(script_fp32_model, nruns=1000)

min =    2.39 ms	max =    6.42 ms	avg =    2.63 ms


In [None]:
# Same compile spec as the previous TensorRT cell, starting from the traced TorchScript model.
trace_fp32_model = torch_tensorrt.compile(traced_model, **compile_spec)
test_loss, test_acc = evaluate(trace_fp32_model, val_dataloader, criterion)
# Label names the engine actually evaluated (it previously read "script accuracy").
print(model_name+" TRT FP32 (from trace) accuracy: {:.2f}%".format(100 * test_acc))



efficientformerv2_s1 script accuracy: 83.49%


In [None]:
benchmark(trace_fp32_model, nruns=1000)

min =    2.33 ms	max =    5.90 ms	avg =    2.56 ms


In [None]:
# Hand the eager nn.Module directly to torch_tensorrt.compile with the same compile spec.
trt_fp32_model = torch_tensorrt.compile(model, **compile_spec)
test_loss, test_acc = evaluate(trt_fp32_model, val_dataloader, criterion)
# Label names the engine actually evaluated (it previously read "script accuracy").
print(model_name+" TRT FP32 (from nn.Module) accuracy: {:.2f}%".format(100 * test_acc))



efficientformerv2_s1 script accuracy: 83.49%


In [None]:
benchmark(trt_fp32_model, nruns=1000)

min =    2.33 ms	max =    5.37 ms	avg =    2.61 ms


如上可以看出直接输入model，让`torch_tensorrt.compile`来自动处理是最方便的，性能方面也无需顾虑。

#### TRT FP16 model on CUDA

In [None]:
script_model = torch.jit.script(model).eval()

In [None]:
# Compile the model to a TensorRT engine that may use FP16 kernels, then measure accuracy.
compile_spec = {
    "inputs": [torch_tensorrt.Input([1, 3, 224, 224])],
    "enabled_precisions": {torch.half},  # Run with fp16
    }
trt_gpu_fp16_model = torch_tensorrt.compile(model, **compile_spec)
test_loss, test_acc = evaluate(trt_gpu_fp16_model, val_dataloader, criterion)
# Label names the engine actually evaluated (it previously read "script accuracy").
print(model_name+" TRT FP16 accuracy: {:.2f}%".format(100 * test_acc))



efficientformerv2_s1 script accuracy: 83.52%


In [None]:
benchmark(trt_gpu_fp16_model, nruns=2000)

min =    1.86 ms	max =    5.96 ms	avg =    2.22 ms


In [None]:
# Build the FP16 engine through the TorchScript "tensorrt" backend (to_backend API) instead
# of torch_tensorrt.compile. The spec maps the method name to its TensorRT compile settings.
spec = {
    "forward": torch_tensorrt.ts.TensorRTCompileSpec(
        **{
            "inputs": [torch_tensorrt.Input([1, 3, 224, 224])],
            "enabled_precisions": {torch.half},
            "refit": False,
            "debug": False,
            "device": {
                "device_type": torch_tensorrt.DeviceType.GPU,
                "gpu_id": 0,
                "dla_core": 0,
                "allow_gpu_fallback": True,
            },
            "capability": torch_tensorrt.EngineCapability.default,
            "num_avg_timing_iters": 1,
        }
    )
}
gpu_fp16_model = torch._C._jit_to_backend("tensorrt", script_model, spec)
test_loss, test_acc = evaluate(gpu_fp16_model, val_dataloader, criterion)
# Label names the engine actually evaluated (it previously read "script accuracy").
print(model_name+" TRT FP16 (to_backend) accuracy: {:.2f}%".format(100 * test_acc))



efficientformerv2_s1 script accuracy: 83.57%


In [None]:
benchmark(gpu_fp16_model, nruns=2000)

min =    1.88 ms	max =    4.32 ms	avg =    2.33 ms


#### TRT FP16 model on DLA

In [None]:
# Evaluate a TensorRT FP16 engine that targets the DLA with GPU fallback.
if False:
    # Disabled on purpose: the original author observed this in-process compile hanging and
    # did not find the cause.
    compile_spec = {
        "inputs": [torch_tensorrt.Input([1, 3, 224, 224])],
        "enabled_precisions": {torch.half},  # Run with FP16
        "device": torch_tensorrt.Device("dla:0", allow_gpu_fallback=True)  # Run with DLA
        }
    trt_dla_fp16_model = torch_tensorrt.compile(model, **compile_spec)
else:
    # Load a pre-built module instead; presumably it was produced with the command below
    # (TODO confirm that predict.py writes trt_dla_fp16_model.ts):
    # python ~/work/transfer-learning/EfficientFormerV2/predict.py --benchmark --weights efficientformerv2_s1_base_ckpt.pth --factor s1 --num_classes 10 --device cuda --mode trt --test_iter 1000
    trt_dla_fp16_model = torch.jit.load("trt_dla_fp16_model.ts")
test_loss, test_acc = evaluate(trt_dla_fp16_model, val_dataloader, criterion, False)
# Label names the engine actually evaluated (it previously read "script accuracy").
print(model_name+" TRT DLA FP16 accuracy: {:.2f}%".format(100 * test_acc))

efficientformerv2_s1 script accuracy: 82.83%


In [None]:
benchmark(trt_dla_fp16_model, nruns=400)

min =   25.87 ms	max =   64.58 ms	avg =   26.93 ms


In [None]:
# Build the DLA FP16 engine through the TorchScript "tensorrt" backend (to_backend API).
# Unlike the GPU spec, device_type is DLA; allow_gpu_fallback is enabled.
spec = {
    "forward": torch_tensorrt.ts.TensorRTCompileSpec(
        **{
            "inputs": [torch_tensorrt.Input([1, 3, 224, 224])],
            "enabled_precisions": {torch.half},
            "refit": False,
            "debug": False,
            "device": {
                "device_type": torch_tensorrt.DeviceType.DLA,
                "gpu_id": 0,
                "dla_core": 0,
                "allow_gpu_fallback": True,
            },
            "capability": torch_tensorrt.EngineCapability.default,
            "num_avg_timing_iters": 1,
        }
    )
}
dla_fp16_model = torch._C._jit_to_backend("tensorrt", script_model, spec)
test_loss, test_acc = evaluate(dla_fp16_model, val_dataloader, criterion, False)
# Label names the engine actually evaluated (it previously read "script accuracy").
print(model_name+" TRT DLA FP16 (to_backend) accuracy: {:.2f}%".format(100 * test_acc))
# Accuracy recorded by the original author for this cell: 82.83%

In [None]:
benchmark(dla_fp16_model, nruns=400)

min =   26.01 ms	max =   52.46 ms	avg =   26.70 ms


#### TRT INT8 model on CUDA | Post Training Quantization (PTQ)

In [None]:
# Post-training quantization: calibrate INT8 scales with TensorRT's entropy calibrator.
# The engine is built for a static input shape and benchmark() always feeds one image, so the
# calibration loader, the validation loader and the engine input share a single batch size.
PTQ_BATCH_SIZE = 1
calib_dataloader = data.DataLoader(calib_dataset, batch_size=PTQ_BATCH_SIZE, shuffle=False, drop_last=True)
val_dataloader = data.DataLoader(val_dataset, batch_size=PTQ_BATCH_SIZE, shuffle=False, drop_last=True)
criterion = nn.CrossEntropyLoss()

calibrator = torch_tensorrt.ptq.DataLoaderCalibrator(
    calib_dataloader,
    use_cache=False,
    algo_type=torch_tensorrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
    device=torch.device('cuda'))

compile_spec = {
         "inputs": [torch_tensorrt.Input([PTQ_BATCH_SIZE, 3, 224, 224])],
         # A set, matching the FP16 cells and the documented form of this option.
         "enabled_precisions": {torch.int8},
         "calibrator": calibrator,
     }
ptq_int8_model = torch_tensorrt.compile(model, **compile_spec)

# non_blocking must be a real bool; the previous literal 0 was an int.
test_loss, test_acc = evaluate(ptq_int8_model, val_dataloader, criterion, False)
# Label names the engine actually evaluated (it previously read "script accuracy").
print(model_name+" TRT PTQ INT8 accuracy: {:.2f}%".format(100 * test_acc))

In [None]:
benchmark(ptq_int8_model, nruns=2000)

#### TRT INT8 model on CUDA | Quantization Aware Training (QAT)

In [None]:
quant_modules.initialize()

In [None]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    Frozen parameters receive no gradients, so a later training step leaves
    them unchanged. With a falsy flag the model is not touched.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad_(False)

In [None]:
# All the regular conv, FC layers will be converted to their quantized counterparts due to quant_modules.initialize()
# feature_extract = False
create_model = getattr(efficientformerv2, model_name)
q_model = create_model(num_classes=10).cuda()
# set_parameter_requires_grad(q_model, feature_extract)

# <model_name>_base_ckpt.pth is the checkpoint written by the baseline training loop above.
ckpt = torch.load(model_name+"_base_ckpt.pth")
# Strip a leading 'module.' prefix from the key names. Matching the full prefix, dot included,
# keeps keys that merely start with the letters "module" intact.
prefix = 'module.'
modified_state_dict = {
    (key[len(prefix):] if key.startswith(prefix) else key): val
    for key, val in ckpt["model"].items()
}

# Load the pre-trained checkpoint
q_model.load_state_dict(modified_state_dict)
lr = 0.0001
optimizer = optim.SGD(q_model.parameters(), lr=lr)
# Restoring the optimizer state also restores the hyper-parameters stored in its param groups.
optimizer.load_state_dict(ckpt["opt_state_dict"])

In [None]:
def compute_amax(model, **kwargs):
    """Load calibrated amax values into each calibrated TensorQuantizer, then move the model to the GPU.

    Keyword arguments are forwarded to ``load_calib_amax`` only for quantizers
    whose calibrator is not a ``MaxCalibrator``.
    """
    for _, quantizer in model.named_modules():
        if not isinstance(quantizer, quant_nn.TensorQuantizer):
            continue
        calibrator = quantizer._calibrator
        if calibrator is None:
            continue
        # MaxCalibrator quantizers are loaded without any extra options.
        options = {} if isinstance(calibrator, calib.MaxCalibrator) else kwargs
        quantizer.load_calib_amax(**options)
    model.cuda()

def collect_stats(model, data_loader, num_batches):
    """Feed data to the network and collect calibration statistics.

    Quantizers that own a calibrator are switched from quantizing to
    calibrating, and quantizers without one are disabled. Exactly
    ``num_batches`` batches (fewer if the loader is shorter) are then pushed
    through the model, and the quantizers are switched back afterwards, even
    when the forward pass raises.

    Args:
        model: Module whose ``named_modules()`` may contain TensorQuantizer instances.
        data_loader: Iterable of ``(image, label)`` batches; labels are ignored.
        num_batches: Maximum number of batches to run.
    """
    from itertools import islice

    quantizers = [module for _, module in model.named_modules()
                  if isinstance(module, quant_nn.TensorQuantizer)]

    # Enable calibrators
    for module in quantizers:
        if module._calibrator is not None:
            module.disable_quant()
            module.enable_calib()
        else:
            module.disable()

    try:
        # islice stops after num_batches batches; breaking only after the forward pass, as the
        # loop used to do, ran one batch too many.
        for image, _ in tqdm(islice(data_loader, num_batches), total=num_batches):
            model(image.cuda())
    finally:
        # Disable calibrators
        for module in quantizers:
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()

In [None]:
# Calibrate the quantized model: collect statistics on training batches with quantization
# switched off, then load the resulting amax values into the quantizers.
# NOTE(review): the previous comment said "percentile", but the call passes method="max", and
# compute_amax forwards keyword arguments only to quantizers whose calibrator is not a
# MaxCalibrator. Confirm which calibration method is intended.
with torch.no_grad():
    collect_stats(q_model, train_dataloader, num_batches=32)
    compute_amax(q_model, method="max")

In [None]:
# Fine-tune the QAT model for `num_epochs` epochs and keep the best checkpoint.
train_dataloader = data.DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_dataloader = data.DataLoader(val_dataset, batch_size=64, shuffle=False, drop_last=True)
num_epochs=10
# Report the learning rate the optimizer really uses. Assigning a new number to `lr` here
# (previously 0.001) never reached the optimizer, so the printed value was wrong. To train at
# a different rate, set group['lr'] on every entry of optimizer.param_groups instead.
lr = optimizer.param_groups[0]['lr']
best_acc = 0.
for epoch in range(num_epochs):
    print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, num_epochs, lr))

    train(q_model, train_dataloader, criterion, optimizer, epoch)
    # evaluate()'s fourth parameter is non_blocking, not an epoch index, so it is left at its
    # default.
    test_loss, test_acc = evaluate(q_model, val_dataloader, criterion)

    print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
    if best_acc < test_acc:
        best_acc = test_acc
        save_checkpoint({'epoch': epoch + 1,
                 'model': q_model.state_dict(),
                 'acc': best_acc,
                 'opt_state_dict': optimizer.state_dict()
                },
                ckpt_path=model_name+"_qat_ckpt.pth")

In [None]:
q_model.eval()

In [None]:
# Export the fake-quantized model to TorchScript using PyTorch's native fake-quant ops.
quant_nn.TensorQuantizer.use_fb_fake_quant = True
with torch.no_grad():
    # Trace with an example input. torch.jit.script does not take example data as its second
    # positional argument (that slot is the deprecated `optimize` flag).
    jit_model = torch.jit.trace(q_model, torch.randn((1,3,224,224)).cuda())
    torch.jit.save(jit_model, model_name+"_qat.jit.pt")

In [None]:
# Reload the exported QAT TorchScript module and compile it to a TensorRT INT8 engine.
qat_model = torch.jit.load(model_name+"_qat.jit.pt").eval()
compile_spec = {"inputs": [torch_tensorrt.Input([1, 3, 224, 224])],
                # A set, matching the FP16 cells and the documented form of this option.
                "enabled_precisions": {torch.int8},
                # "truncate_long_and_double": True,
               }
# Compile the exported module that was just loaded; it used to be loaded and then ignored in
# favour of the eager q_model.
qat_int8_model = torch_tensorrt.compile(qat_model, **compile_spec)
val_dataloader = data.DataLoader(val_dataset, batch_size=1, shuffle=False, drop_last=True)
# non_blocking must be a real bool; the previous literal 0 was an int.
test_loss, test_acc = evaluate(qat_int8_model, val_dataloader, criterion, False)
# Label names the engine actually evaluated (it previously read "script accuracy").
print(model_name+" TRT QAT INT8 accuracy: {:.2f}%".format(100 * test_acc))

In [None]:
benchmark(qat_int8_model, nruns=2000)