In [3]:
# Importing all the required modules
import tensorrt as trt
import onnx
import torch
import numpy as np
import time
import torchvision.models as models

### Step 1: Convert the model to ONNX format

In [4]:
# Load the pretrained ResNet-50 model from the torchvision library.
# pretrained=True requests the published checkpoint, which is downloaded to the
# local torch hub cache the first time this cell runs.
model_resnet_50 = models.resnet50(pretrained=True)
# Switch the module to evaluation mode before exporting it. eval() returns the
# module itself, so leaving it as the last expression also displays the model
# architecture as the cell output.
model_resnet_50.eval()

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/lkondap/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 97.8M/97.8M [00:00<00:00, 152MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [6]:
# Build a random dummy input and use it to export the model to an ONNX file.
input_shape = (3, 224, 224)

# Prepend the batch dimension (batch_size = 1) to the per-sample shape.
dummy_input = torch.randn((1,) + input_shape)

# Export the model to ONNX.
torch.onnx.export(
    model_resnet_50,
    dummy_input,
    "resnet_50.onnx",
    verbose=False,  # Note: verbose=True was observed to make the export run a lot slower.
)

### Step 2: Optimize ONNX model with TensorRT

In [18]:
# Load the exported ONNX model.
onnx_model = onnx.load("resnet_50.onnx")

# Create the TensorRT logger, builder, network definition, builder config and ONNX parser.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
# The literal 1 is the network-creation flag bit mask. NOTE(review): presumably this equals
# 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH); confirm against the installed
# TensorRT version, since that flag is reported as deprecated from TensorRT 10.0.
network = builder.create_network(1)
config = builder.create_builder_config()
parser = trt.OnnxParser(network, TRT_LOGGER)

# Allow the builder to use FP16 kernels. The workspace size is left at the builder's default.
config.set_flag(trt.BuilderFlag.FP16)

# Parse the ONNX model into the TensorRT network definition.
success = parser.parse(onnx_model.SerializeToString())
if not success:
    # Print every parser error, then fail the cell with an exception. Calling exit() here
    # would shut down the notebook kernel instead of reporting a normal error.
    for error_index in range(parser.num_errors):
        print(parser.get_error(error_index))
    raise RuntimeError(
        f"TensorRT could not parse resnet_50.onnx ({parser.num_errors} parser error(s), printed above)."
    )

# Build the TensorRT engine. build_serialized_network returns the engine already serialized,
# so `engine` is the byte buffer to write to disk and needs no further serialize() call.
engine = builder.build_serialized_network(network, config)
if engine is None:
    # A failed build yields None; stop here rather than failing later when writing the file.
    raise RuntimeError(
        "TensorRT failed to build a serialized engine for resnet_50.onnx; see the logger output."
    )

In [19]:
# Save the serialized TensorRT engine bytes to disk.
with open("resnet_50.engine", mode="wb") as engine_file:
    engine_file.write(engine)

#### Notes:
1. `EXPLICIT_BATCH` is deprecated as of TensorRT 10.0. On earlier versions, the argument `1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)` needs to be passed to `create_network`.
2. 