In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import numpy as np

# Load the pre-trained ResNet-18 model and switch it to inference mode
# (eval() makes the batch-norm layers use their stored running statistics,
# which is what we want baked into the exported graph).
res_net = models.resnet18(weights="IMAGENET1K_V1")
res_net.eval()

# Dummy input used only to trace the graph - its values are irrelevant.
# 224x224 is the usual resolution for the ImageNet ResNet weights.
test_input = torch.randn(1, 3, 224, 224)

# Export the model to ONNX format.
# Batch size AND spatial size are marked dynamic, so the resulting session
# accepts any N x 3 x H x W input rather than only the traced 224x224 shape.
torch.onnx.export(
    res_net,
    test_input,
    "resnet18.onnx",
    opset_version=12,
    export_params=True,          # store the trained weights inside the file
    do_constant_folding=True,    # pre-compute constant sub-expressions
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size', 2: 'height', 3: 'width'},
                  'output': {0: 'batch_size'}},
)

In [None]:
import numpy as np
import onnxruntime
import onnx
from PIL import Image
import json
import time

## You will need to install onnx and onnxruntime
# pip install onnx

# If you don't have a GPU, install the CPU version
# pip install onnxruntime

# If you have a GPU, install the GPU version
# pip install onnxruntime-gpu

# Make sure you install the correct version for your version of CUDA!
# Also check dependencies!
# https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html
# E.g. for CUDA version 12.2 use
# pip install onnxruntime-gpu==1.17

# NOTE: PyTorch ships its own cuDNN, which is installed together with torch.
# Other applications that need cuDNN, like onnxruntime-gpu (when used without importing torch),
# require cuDNN to be installed separately (it does not come with the NVIDIA CUDA Toolkit).
# NOTE: at the time of writing only cuDNN 8.x versions are supported!!
# https://docs.nvidia.com/deeplearning/cudnn/archives/cudnn-890/install-guide/index.html

# Load the exported ONNX model and validate its structure before using it
onnx_model = onnx.load("resnet18.onnx")
onnx.checker.check_model(onnx_model)

# Prefer the CUDA provider but keep CPU as a fallback, so the notebook also runs
# with the CPU-only `onnxruntime` package. Only providers this onnxruntime build
# actually offers are requested, avoiding "provider not available" warnings.
preferred_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
available_providers = onnxruntime.get_available_providers()
session_providers = [p for p in preferred_providers if p in available_providers]

ort_session = onnxruntime.InferenceSession("resnet18.onnx", providers=session_providers)
print("ONNX Runtime providers in use:", ort_session.get_providers())

# Label lookup table; presumably maps a stringified class index to its
# ImageNet label - verify against the JSON file.
with open("../data/imagenet_classes.json", "r", encoding="utf-8") as file:
    img_net_classes = json.load(file)

In [None]:
def crop_resize(image, new_size):
    """Centre-crop a PIL image to a square, then resize the square.

    The square's side equals the image's shorter dimension, so the longer
    dimension is trimmed (almost) equally from both ends.

    Args:
        image: PIL ``Image`` (anything exposing ``size``, ``crop`` and ``resize``).
        new_size: side length, in pixels, of the returned square image.

    Returns:
        A new ``new_size`` x ``new_size`` image.
    """
    img_w, img_h = image.size
    side = min(img_h, img_w)

    # Top-left corner of the crop window; floor division leaves any odd
    # leftover pixel on the right/bottom edge.
    x0 = (img_w - side) // 2
    y0 = (img_h - side) // 2
    crop_box = (x0, y0, x0 + side, y0 + side)

    return image.crop(crop_box).resize((new_size, new_size))

In [None]:
def image_normalise_reshape(image, mean, std, eps=1e-6):
    """Normalise an (H, W, C) image and turn it into a 1 x C x H x W float32 batch.

    Pixel values are divided by 255 (so ``image`` is presumably uint8 in
    [0, 255] - TODO confirm for other dtypes), moved from HWC to CHW layout,
    then standardised per channel as ``(x - mean) / (std + eps)``.

    Args:
        image: array-like of shape (H, W, C), e.g. ``np.array(pil_rgb_image)``.
        mean: per-channel means (length C) on the [0, 1] scale.
        std: per-channel standard deviations (length C) on the [0, 1] scale.
        eps: small constant added to ``std`` to guard against division by zero.

    Returns:
        np.ndarray of shape (1, C, H, W) and dtype float32.

    Raises:
        ValueError: if ``image`` is not 3-D, or its channel count does not
            match the length of ``mean``/``std``.
    """
    image = np.asarray(image)
    if image.ndim != 3:
        raise ValueError(
            f"expected an (H, W, C) image, got shape {image.shape}; "
            "convert grayscale/palette images to RGB first")
    channels = image.shape[2]

    np_means = np.asarray(mean, dtype=np.float64).reshape(-1, 1, 1)
    np_stds = np.asarray(std, dtype=np.float64).reshape(-1, 1, 1)
    if np_means.shape[0] != channels or np_stds.shape[0] != channels:
        raise ValueError(
            f"image has {channels} channels but mean/std have "
            f"{np_means.shape[0]}/{np_stds.shape[0]} entries")

    # HWC -> CHW and scale to [0, 1] before applying the statistics
    chw_image = image.transpose((2, 0, 1)) / 255
    norm_image = (chw_image - np_means) / (np_stds + eps)

    # Add the leading batch dimension expected by the model
    return np.expand_dims(norm_image, 0).astype(np.float32)
    

In [None]:
# Open the sample image and force 3-channel RGB: grayscale, palette, RGBA or
# CMYK files would otherwise break the 3-channel normalisation step.
# The `with` block closes the file handle once the converted copy exists.
# NOTE(review): 244 may be a typo for the usual 224; it must match a spatial
# size the exported ONNX graph accepts, so change the two together.
with Image.open("../data/dog.jpg") as raw_image:
    test_image = crop_resize(raw_image.convert("RGB"), 244)
test_image

In [None]:
np_image = np.array(test_image)
# Per-channel ImageNet statistics (RGB order, [0, 1] scale)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
norm_image = image_normalise_reshape(np_image, mean, std)
# Should also work with a batch of images, e.g.
# norm_image_batch = np.concatenate((norm_image, norm_image), 0)

# Prepare the inputs for ONNX Runtime, keyed by the graph's input name
onnxruntime_input = {ort_session.get_inputs()[0].name: norm_image}

# perf_counter is monotonic and higher-resolution than time.time()
start_time = time.perf_counter()
# Run inference (None = return every graph output)
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
end_time = time.perf_counter()

# Print the outputs
print("ONNX Runtime outputs:")
for output in onnxruntime_outputs:
    # ResNet-18 logits have shape (batch, classes). Take the argmax per row so a
    # batched input gives one prediction per image; a plain argmax would index
    # into the flattened batch*classes array.
    for class_index in np.argmax(output, axis=1):
        print("Class index:", class_index)
        print("Class Label:", img_net_classes[str(class_index)])

print("Time to run: %.4fs" % (end_time - start_time))