In [None]:
%load_ext autoreload
%autoreload 2



In [2]:
import torch
import torchvision.models as models
from torchvision.io import decode_image

# Load pretrained ResNet18 model
resnet18 = models.resnet18(pretrained=True)
resnet18.eval()  # Set to evaluation mode



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
import requests

img = decode_image("./_static/img/cat.jpg")
# Convert to float and scale to [0, 1]
img = img.float() / 255.0

# Add batch dimension and ensure 3 channels (C,H,W format)
img = img.unsqueeze(0)

# Apply standard ImageNet normalization
normalize = torch.nn.functional.normalize
mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
img = (img - mean) / std

output = resnet18(img)

_, prediction = torch.max(output, 1)
prediction

# Download ImageNet labels

# Get the labels from a standard ImageNet labels file
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
response = requests.get(LABELS_URL)
labels = [line.strip() for line in response.text.split("\n")]

# Get prediction index and corresponding label
_, prediction = torch.max(output, 1)
class_idx = prediction.item()
class_name = labels[class_idx]

print(f"Predicted class index: {class_idx}")
print(f"Predicted class name: {class_name}")
# resnet18(img)
img.shape

torch.Size([1, 3, 224, 224])

In [23]:

batch_size = 4
# Input to the model
x = torch.randn(batch_size, 3, 224, 224, requires_grad=True)

scripted_model = torch.jit.trace(resnet18, x)
# Optional: optimize the traced model
optimized = torch.jit.optimize_for_inference(scripted_model)

torch_out = optimized(x)

# Export the model
torch.onnx.export(optimized,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "resnet18-fused.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=20,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})



In [17]:
import onnx

onnx_model = onnx.load("resnet18.onnx")
onnx.checker.check_model(onnx_model)

In [21]:
import onnxruntime
import numpy as np

session_options = onnxruntime.SessionOptions()
session_options.enable_profiling = True  # Enable profiling
ort_session = onnxruntime.InferenceSession("resnet18.onnx", providers=["CPUExecutionProvider"], sess_options=session_options)

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {'input': to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

profile_file = ort_session.end_profiling()
print(f"Profiling data saved to: {profile_file}")

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

Profiling data saved to: onnxruntime_profile__2025-02-03_08-21-51.json
Exported model has been tested with ONNXRuntime, and the result looks good!


In [20]:
# Get predictions from ONNX model
predictions = ort_session.run(None, {'input': to_numpy(img)})[0]

# Find max probability and class index (equivalent to torch.max)
max_prob_idx = np.argmax(predictions, axis=1)
max_prob = np.max(predictions, axis=1)

print(f"Predicted class index: {max_prob_idx}")
print(f"Prediction probability: {max_prob}")


Predicted class index: [281]
Prediction probability: [10.75602]


In [4]:

from onnxruntime.quantization import quantize_dynamic, QuantType, quant_pre_process

quant_pre_process("resnet18.onnx",
                  "resnet18-preprocessed.onnx",
                  auto_merge=True,
                  guess_output_rank=True,
                  verbose=True)

# Quantize the model to 8-bit integers
quantize_dynamic("resnet18-preprocessed.onnx", 
                 "resnet18-quantized.onnx",
                 weight_type=QuantType.QUInt8)