First, we create the pre-trained ImageNet model. We'll use ``resnet18`` from the torchvision package. Make sure to move the model to the ``cuda`` device, since the devices of the inputs and parameters are inferred from the model. We also convert the model to half precision with ``half()`` and call ``eval()`` so that the batch norm statistics are fixed.

In [1]:
import torchvision

# Load the ImageNet-pretrained ResNet-18, move it to the GPU, cast its weights
# to float16, and switch it to inference mode (fixes batch-norm statistics).
model = torchvision.models.resnet18(pretrained=True)
model = model.cuda().half().eval()

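Next, we create some sample input that will be used to infer the shapes and data types of our TensorRT engine. It must live on the same device and use the same dtype (here ``float16`` on ``cuda``) as the model.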

In [2]:
import torch

# Example input batch: one 3-channel 224x224 image, placed on the GPU in
# float16 so it matches the model's device and dtype.
data = torch.randn(1, 3, 224, 224)
data = data.cuda().half()

Finally, create the optimized TensorRT engine.

In [3]:
from torch2trt import torch2trt

# Build a TensorRT engine from ``model``. ``[data]`` is the list of sample
# inputs from which the engine's input shapes and dtypes are inferred.
# fp16_mode=True requests an FP16 engine, matching the half-precision model
# and input.
model_trt = torch2trt(model, [data], fp16_mode=True)

We can execute the TensorRT network just like the original PyTorch module:

In [4]:
# Run the TensorRT module on the same sample input; it is called like a
# regular PyTorch module.
output_trt = model_trt(data)

And compare the result against the output of the original PyTorch model:

In [8]:
# Run the original PyTorch model on the same input for comparison.
# no_grad() avoids building an autograd graph, which inference does not need.
with torch.no_grad():
    output = model(data)

print(output.flatten()[0:10])
print(output_trt.flatten()[0:10])
# Largest absolute element-wise difference between PyTorch and TensorRT outputs.
max_error = (output - output_trt).abs().max().item()
print(f'max error: {max_error:f}')

tensor([ 0.7231,  3.0195,  3.1016,  3.1152,  4.7539,  3.8301,  3.9180,  0.3086,
        -0.8726, -0.2261], device='cuda:0', dtype=torch.float16,
       grad_fn=<SliceBackward>)
tensor([ 0.7202,  3.0234,  3.1074,  3.1133,  4.7539,  3.8340,  3.9141,  0.3081,
        -0.8716, -0.2227], device='cuda:0', dtype=torch.float16)
max error: 0.011719


In [15]:
import json

# Load the label lookup that ``execute`` indexes by predicted class index
# (presumably a list of ImageNet class names — verify against the file).
# The encoding is explicit, so decoding does not depend on the system locale.
with open('imagenet_labels.json', 'r', encoding='utf-8') as f:
    labels = json.load(f)

In [18]:
import cv2
import numpy as np

device = torch.device('cuda')
# ImageNet channel mean/std (RGB order), scaled by 255 because the frames are
# normalized in the 0-255 pixel range rather than 0-1.
mean = 255.0 * np.array([0.485, 0.456, 0.406])
stdev = 255.0 * np.array([0.229, 0.224, 0.225])

normalize = torchvision.transforms.Normalize(mean, stdev)


def preprocess(camera_value):
    """Convert a camera frame into a normalized model input batch.

    Args:
        camera_value: a BGR image as a NumPy array in (height, width, channel)
            layout; presumably uint8 frames from the JetBot camera — confirm.

    Returns:
        A float32 tensor on ``device`` with shape (1, 3, height, width),
        normalized with the ImageNet ``mean``/``stdev`` above.
    """
    x = cv2.cvtColor(camera_value, cv2.COLOR_BGR2RGB)
    x = x.transpose((2, 0, 1))  # HWC -> CHW, as torchvision expects
    # Upload the compact frame first (4x smaller than float32 for uint8 input),
    # then convert and normalize on the device the model runs on.
    x = torch.from_numpy(x).to(device).float()
    x = normalize(x)
    return x.unsqueeze(0)  # add the batch dimension

In [17]:
from jetbot import Camera
import ipywidgets

# Camera producing 224x224 frames, the same spatial size as the sample input
# (1, 3, 224, 224) that the TensorRT engine was built with.
camera = Camera(width=224, height=224)

In [19]:
from jetbot import bgr8_to_jpeg
import traitlets

image_w = ipywidgets.Image()

# One-way link: every new camera frame is JPEG-encoded by bgr8_to_jpeg and
# pushed into the widget. Keep the link object so the stream can be stopped
# later with ``camera_link.unlink()``.
camera_link = traitlets.dlink((camera, 'value'), (image_w, 'value'), transform=bgr8_to_jpeg)

display(image_w)

Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x02\x01\x0…

In [20]:
# Text area whose value ``execute`` sets to the predicted label.
text = ipywidgets.Textarea()
display(text)

Textarea(value='')

In [78]:
def execute(change):
    """Classify a camera frame and show its label in the ``text`` widget.

    Args:
        change: a mapping whose ``'new'`` entry holds the camera frame. This is
            presumably the traitlets change dict produced when this function
            observes ``camera``'s ``value`` — confirm where it is registered.
    """
    image = change['new']
    with torch.no_grad():
        output = model_trt(preprocess(image).half())
    # Take the argmax on the device so only a single index is copied to the
    # host. int() turns it into a plain Python index for ``labels``.
    idx = int(output.flatten().argmax())
    text.value = labels[idx]

# Classify the current frame once, immediately.
execute({'new': camera.value})