# Convert the trained pytorch model into ONNX format


In [1]:
import torch
import time
import matplotlib.pyplot as plt

import sys
import os

# Add the parent directory to the system path
sys.path.append(os.path.dirname(os.getcwd()))
print(os.getcwd())
# Now you can import the module using an absolute import
from src.models.CenterSpeed import *

model_name = 'onnx/tinycs_test.onnx'
input_file = '../src/trained_models/CenterSpeedDense.pt'

os.makedirs(os.path.dirname(model_name), exist_ok=True)
# Use the wanted model to export here: 
net = CenterSpeedDenseResidual(image_size=64)

print("Model Created")
net.load_state_dict(torch.load(input_file, map_location='cpu'))
print("Params Loaded")

# Create a random input tensor
randn_input = torch.randn(1, 6, 64, 64)

try:
    torch.onnx.export(net, randn_input, model_name)
    print("ONNX Exported")
except Exception as e:
    print(e)
    print("Failed to export ONNX")

/home/neil/catkin_opensource/src/os_racestack/perception/TinyCenterSpeed/deploy
Model Created
Params Loaded
ONNX Exported


  net.load_state_dict(torch.load(input_file, map_location='cpu'))


# Convert the ONNX model into a Nvidia TensorRT 
IMPORTANT: This needs to be run on the Jetson!

In [15]:
import tensorrt as trt
import onnx
import torch
import numpy as np

onnx_filename = "onnx/tinycs.onnx"
model_onnx = onnx.load(onnx_filename)
print("Model Loaded")

# Create a TensorRT builder and network
builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

# Create an ONNX-TensorRT backend
parser = trt.OnnxParser(network, builder.logger)
success = parser.parse_from_file(onnx_filename)
for idx in range(parser.num_errors):
    print(parser.get_error(idx))

if not success:
    print("ERROR")

#Optimization Config, THIS COULD BE OPTIMIZED
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)
serial_engine = builder.build_serialized_network(network, config)
with open("onnx/tinycs.engine", "wb") as f:
    f.write(serial_engine)

print("Engine Built")

ModuleNotFoundError: No module named 'tensorrt'