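First, we create the pre-trained ImageNet model.  We'll use ``resnet18`` from the torchvision package.  Make sure to move the model to ``cuda``, since the devices of the inputs and parameters are inferred from the model.  Also make sure to call ``eval()`` so that batch norm uses its fixed running statistics instead of per-batch statistics.  The model is also converted to half precision with ``half()``, matching the ``fp16_mode=True`` setting used for the conversion later in this notebook.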

In [1]:
import torchvision

model = torchvision.models.resnet18(pretrained=True).cuda().half().eval()

Next, we create some sample input that will be used to infer the input shapes and data types for our TensorRT engine.  Like the model, it is placed on the GPU and converted to half precision.

In [2]:
import torch

data = torch.randn((1, 3, 224, 224)).cuda().half()

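Finally, we create the optimized TensorRT engine.  The sample input is passed as a list, and ``fp16_mode=True`` asks for the engine to be built with half precision enabled.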

In [3]:
from torch2trt import torch2trt

model_trt = torch2trt(model, [data], fp16_mode=True)

We can execute the network like this:

In [4]:
output_trt = model_trt(data)

And check the result against the output of the original PyTorch model:

In [8]:
output = model(data)

print(output.flatten()[0:10])
print(output_trt.flatten()[0:10])
print('max error: %f' % float(torch.max(torch.abs(output - output_trt))))

tensor([ 0.7231,  3.0195,  3.1016,  3.1152,  4.7539,  3.8301,  3.9180,  0.3086,
        -0.8726, -0.2261], device='cuda:0', dtype=torch.float16,
       grad_fn=<SliceBackward>)
tensor([ 0.7202,  3.0234,  3.1074,  3.1133,  4.7539,  3.8340,  3.9141,  0.3081,
        -0.8716, -0.2227], device='cuda:0', dtype=torch.float16)
max error: 0.011719
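
To put the reported max error in context, it helps to compare it with the spacing between adjacent half-precision values near the printed logits. The snippet below is a minimal sketch using only the Python standard library; the helper names ``fp16_round`` and ``fp16_spacing`` are hypothetical and not part of torch2trt or PyTorch. It assumes the largest discrepancy occurs for a logit of similar magnitude to the ones printed above, which this notebook does not show directly.

```python
import struct


def fp16_round(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def fp16_spacing(x: float) -> float:
    """Gap between fp16(x) and the next fp16 value of larger magnitude.

    Illustrative helper: valid for finite values below the fp16 maximum.
    """
    # Reinterpret the half-precision bit pattern as an unsigned 16-bit integer.
    bits = struct.unpack("<H", struct.pack("<e", x))[0]
    # Incrementing the bit pattern yields the next value away from zero.
    neighbour = struct.unpack("<e", struct.pack("<H", bits + 1))[0]
    return abs(neighbour - fp16_round(x))


max_error = 0.011719  # value printed by the comparison cell above
logit = 4.7539        # one of the larger logits printed above (assumed representative)

spacing = fp16_spacing(logit)
print("fp16 spacing near %.4f: %f" % (logit, spacing))
print("max error in units of that spacing: %.2f" % (max_error / spacing))
```

Under that assumption, the max error corresponds to roughly three half-precision steps for logits between 4 and 8, which suggests the two outputs agree to within a few rounding steps of the ``float16`` format rather than differing in any systematic way.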


We can save the optimized model's ``state_dict`` like this:

In [None]:
torch.save(model_trt.state_dict(), 'resnet18_trt.pth')

And load it into an empty ``TRTModule`` like this:

In [None]:
from torch2trt import TRTModule

model_trt = TRTModule()

model_trt.load_state_dict(torch.load('resnet18_trt.pth'))

That's it for this notebook!  Try out the live demo to see real-time classification on a video feed.