<a href="https://colab.research.google.com/github/abhaymise/tutorials/blob/main/ai/deployment/onnx_export_and_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.onnx
import onnx
import onnxruntime

# Load your PyTorch model.
# NOTE: torch.load unpickles arbitrary Python objects -- only load checkpoints you trust.
# This expects a file produced by torch.save(model) (a full nn.Module, not a state_dict);
# recent PyTorch releases default to weights_only=True and may need
# torch.load(..., weights_only=False) for such files.
model = torch.load('path/to/your/model.pth')
model.eval()

# Shape of a single input sample -- TODO: replace with the shape your model expects.
input_channels, input_height, input_width = 3, 224, 224

# Create a dummy input with batch size 1. The batch dimension is marked as dynamic
# at export time, so the exported graph accepts any batch size.
dummy_input = torch.randn(1, input_channels, input_height, input_width)
input_names = ['input']
# Assumes the model returns a single tensor -- name it so it can be marked dynamic too.
output_names = ['output']

# Export the PyTorch model to ONNX with dynamic batch size support
output_onnx_path = 'path/to/your/model.onnx'

# Mark dimension 0 (batch) of both the input and the output as dynamic, so the
# declared output shape is not frozen at the traced batch size of 1.
dynamic_axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
torch.onnx.export(
    model,
    dummy_input,
    output_onnx_path,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
)

# Load the exported ONNX model with ONNX Runtime. Since ORT 1.9, builds that ship
# several execution providers require `providers` to be passed explicitly.
# Swap in e.g. 'CUDAExecutionProvider' to run on GPU.
ort_session = onnxruntime.InferenceSession(output_onnx_path, providers=['CPUExecutionProvider'])

# The graph's input name does not change between runs, so look it up once.
dynamic_input_name = ort_session.get_inputs()[0].name

# Perform inference with a variable batch size
batch_sizes = [1, 2, 4]  # Example: Different batch sizes
for batch_size in batch_sizes:
    dynamic_input = torch.randn(batch_size, input_channels, input_height, input_width)
    output = ort_session.run(None, {dynamic_input_name: dynamic_input.numpy()})
    print(f"Inference result for batch size {batch_size}: {output}")


## Sample ResNet Model Case Study


In [None]:
import torch
import onnx
import onnxruntime
import numpy as np
from torchvision import models, transforms

# Step 1: Load your PyTorch model
# NOTE: torchvision >= 0.13 deprecates `pretrained=True` in favour of
# `weights=models.ResNet18_Weights.DEFAULT`. It is kept here so older versions still work.
model = models.resnet18(pretrained=True)
model.eval()

# Step 2: Specify the input shape and dummy input
input_shape = (3, 224, 224)  # (channels, height, width) of one RGB 224x224 image
dummy_input = torch.randn(1, *input_shape)  # batch of 1, used only for tracing

# Step 3: Export the PyTorch model to ONNX with dynamic batch size support.
# A concrete file name is used so this case study runs as-is: the placeholder
# directory 'path/to/your/' does not exist.
output_onnx_path = 'resnet18.onnx'
# Mark the batch dimension of both the input and the output as dynamic. Otherwise
# the declared output shape stays fixed at the traced batch size of 1.
dynamic_axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
torch.onnx.export(
    model,
    dummy_input,
    output_onnx_path,
    verbose=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes=dynamic_axes,
)

# Step 4: Load the ONNX model with ONNX Runtime. Since ORT 1.9, builds with
# several execution providers require `providers` to be set explicitly.
ort_session = onnxruntime.InferenceSession(output_onnx_path, providers=['CPUExecutionProvider'])

# Step 5: Preprocess the input data
def preprocess_image(image_path):
    """Load an image from disk and turn it into a ResNet-ready NumPy batch.

    The image is converted to RGB, resized to 224x224 and converted to a float
    tensor. It is then normalized with the ImageNet channel mean/std that
    torchvision's pretrained models expect.

    Args:
        image_path: Path to an image file that Pillow can open.

    Returns:
        A float32 NumPy array of shape (1, 3, 224, 224).
    """
    # `Image` is not imported anywhere in this notebook. Import it locally so the
    # function is self-contained; torchvision's transforms already rely on Pillow.
    from PIL import Image

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Open the file in a context manager so its handle is released promptly.
    # convert() loads the pixel data into a new image before the file closes.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    input_data = transform(image)
    input_data = input_data.unsqueeze(0)  # Add batch dimension
    return input_data.numpy()

# Example usage: point image_path at any image file on disk.
image_path = 'path/to/your/image.jpg'
input_data = preprocess_image(image_path)  # NumPy batch of one image

# Step 6: Run inference with ONNX Runtime.
# Query the graph for the name of its first input and its first output.
input_name, output_name = (
    ort_session.get_inputs()[0].name,
    ort_session.get_outputs()[0].name,
)
feeds = {input_name: input_data}
result = ort_session.run([output_name], feeds)  # list holding the requested output

# Step 7: Post-process the result as needed; here we just take the single requested output.
output_result = result[0]

# Display or save the result as needed
print("Model output:", output_result)


## ONNX Runtime Optimization Options

In [None]:
from onnxruntime import SessionOptions, ExecutionMode
from onnxruntime import GraphOptimizationLevel

session_options = SessionOptions()
# ORT_PARALLEL lets independent graph nodes run concurrently.
# inter_op_num_threads is ignored in sequential mode.
session_options.execution_mode = ExecutionMode.ORT_PARALLEL
session_options.intra_op_num_threads = 4  # threads used to parallelize work inside a single operator
session_options.inter_op_num_threads = 2  # threads used to run independent operators concurrently

# Enable the basic and extended graph optimizations (e.g. node fusions).
session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_EXTENDED

# Reuse the model path and preprocessed image from the ResNet case study above.
# `onnx_model_path` and `input_image` were never defined; the earlier cell names
# them `output_onnx_path` and `input_data`.
ort_session = onnxruntime.InferenceSession(
    output_onnx_path,
    sess_options=session_options,
    providers=['CPUExecutionProvider'],  # explicit providers are required since ORT 1.9
)
# Run inference
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
result = ort_session.run([output_name], {input_name: input_data})
