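First, we create the pre-trained ImageNet model. We'll use ``resnet18`` from the torchvision package. Make sure to move the model to the ``cuda`` device, because the devices of the inputs and parameters are inferred from the model. Also make sure to call ``eval()`` so that the batch norm layers use their fixed running statistics instead of per-batch statistics. We also cast the model to half precision with ``half()`` to match the FP16 engine we build later.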

In [15]:
import torchvision

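# Load ImageNet-pretrained weights, move to the GPU, cast to FP16, and freeze batch norm statistics.
# Note: torchvision 0.13+ deprecates pretrained=True in favor of the weights= argument.
model = torchvision.models.resnet18(pretrained=True).cuda().half().eval()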
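Why ``eval()`` matters: in training mode a batch norm layer normalizes with the mean and variance of the current batch, so one sample's output depends on whatever else is in the batch. In eval mode it uses the stored running statistics, so the output for a given input is fixed, which is the behaviour you want baked into the converted engine. The stdlib sketch below illustrates this on scalar activations only. Real batch norm works per channel and also applies learned scale and shift parameters, and the running statistics here are made-up values.

```python
import math

def batch_norm(xs, mean, var, eps=1e-5):
    """Normalize scalar activations with the given statistics (no scale/shift)."""
    return [(x - mean) / math.sqrt(var + eps) for x in xs]

def batch_stats(xs):
    """Mean and biased variance of the current batch, as used in training mode."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return mean, var

# Hypothetical running statistics, standing in for values accumulated during training.
RUNNING_MEAN, RUNNING_VAR = 0.5, 2.0

def bn_train(xs):
    # Training mode: statistics are computed from the batch itself.
    return batch_norm(xs, *batch_stats(xs))

def bn_eval(xs):
    # Eval mode: fixed running statistics, independent of the rest of the batch.
    return batch_norm(xs, RUNNING_MEAN, RUNNING_VAR)

sample = 1.0
# The same sample, placed in two batches with different companions.
train_a = bn_train([sample, 3.0])[0]
train_b = bn_train([sample, -3.0])[0]
eval_a = bn_eval([sample, 3.0])[0]
eval_b = bn_eval([sample, -3.0])[0]

print("train mode:", train_a, train_b)  # depends on the other sample
print("eval mode: ", eval_a, eval_b)    # identical
```

In training mode the same input normalizes to roughly -1 or +1 depending on its batch companion, while in eval mode both results are identical.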

Next, we create a sample input that will be used to infer the shapes and data types of the TensorRT engine's inputs. It is cast with ``half()`` so that it matches the half-precision model.

In [16]:
import torch

# A random batch of one 3x224x224 image on the GPU, in FP16 to match the model.
data = torch.randn((1, 3, 224, 224)).cuda().half()
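Because the engine is built from this one example, its shape and data type determine how large each input is. As plain arithmetic (stdlib Python only, not a torch2trt API), the sketch below computes the size of this input in FP16 and, for comparison, FP32.

```python
import struct
from math import prod  # Python 3.8+

shape = (1, 3, 224, 224)  # (batch, channels, height, width), as passed to torch.randn above
num_elements = prod(shape)

# struct's 'e' and 'f' format characters are IEEE 754 half and single precision.
bytes_per_fp16 = struct.calcsize('<e')
bytes_per_fp32 = struct.calcsize('<f')

fp16_bytes = num_elements * bytes_per_fp16
fp32_bytes = num_elements * bytes_per_fp32
print(f"{num_elements} elements: {fp16_bytes} bytes in FP16, {fp32_bytes} bytes in FP32")
```

That is 150,528 values, or 294 KiB in FP16 versus 588 KiB in FP32.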

Finally, we create the optimized TensorRT engine. Passing ``fp16_mode=True`` allows TensorRT to use half-precision kernels.

In [17]:
from torch2trt import torch2trt

model_trt = torch2trt(model, [data], fp16_mode=True)
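Both the PyTorch model (cast with ``half()``) and the FP16 engine compute in half precision, but with different kernels, so their outputs can differ by a few FP16 rounding steps. As a stdlib-only illustration (no GPU or TensorRT involved), the sketch below rounds values to IEEE 754 binary16 with Python's ``struct`` module and measures the spacing between neighbouring FP16 numbers. The helper names ``to_fp16`` and ``fp16_spacing`` are our own.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    # struct's 'e' format packs and unpacks binary16 (Python 3.6+).
    return struct.unpack('<e', struct.pack('<e', x))[0]

def fp16_spacing(x: float) -> float:
    """Gap between adjacent FP16 values near x (valid in the normal range only)."""
    # frexp gives abs(x) = m * 2**e with 0.5 <= m < 1, so the binary exponent is e - 1.
    # FP16 stores a 10-bit fraction, so the step size is 2**(exponent - 10).
    _, e = math.frexp(abs(x))
    return 2.0 ** (e - 1 - 10)

# Two corresponding logits from the PyTorch vs. TensorRT comparison in this notebook.
a, b = to_fp16(4.0781), to_fp16(4.0703)
steps = (a - b) / fp16_spacing(a)
print(a, b, fp16_spacing(a), steps)
```

Near 4.0, adjacent FP16 values are 2^-8 (about 0.0039) apart, so the logits 4.0781 and 4.0703 differ by exactly two representable steps.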

We can execute the TensorRT module just like the original PyTorch model:

In [18]:
output_trt = model_trt(data)

Then we compare its output with the output of the original PyTorch model:

In [19]:
output = model(data)

print(output.flatten()[0:10])
print(output_trt.flatten()[0:10])
print('max error: %f' % float(torch.max(torch.abs(output - output_trt))))

tensor([ 0.6851,  3.0117,  2.8340,  2.6797,  4.4141,  3.6270,  4.0781,  0.2634,
        -0.8716, -0.4800], device='cuda:0', dtype=torch.float16,
       grad_fn=<SliceBackward0>)
tensor([ 0.6851,  3.0098,  2.8340,  2.6777,  4.4141,  3.6309,  4.0703,  0.2661,
        -0.8691, -0.4841], device='cuda:0', dtype=torch.float16)
max error: 0.014648
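The printed ``max error`` is the largest absolute element-wise difference (the L-infinity norm of ``output - output_trt``) over all 1000 logits. The stdlib sketch below recomputes it on just the ten values printed above and adds two checks that are often useful for a classifier: whether the top-scoring index agrees, and an ``allclose``-style test using the same rule as ``torch.allclose`` (|a - b| <= atol + rtol * |b|). The tolerances ``atol=1e-2`` and ``rtol=1e-3`` are assumed values for FP16, not something the notebook prescribes.

```python
# First ten logits as printed above (already rounded to 4 decimals by the printout).
pytorch_out = [0.6851, 3.0117, 2.8340, 2.6797, 4.4141, 3.6270, 4.0781, 0.2634, -0.8716, -0.4800]
trt_out = [0.6851, 3.0098, 2.8340, 2.6777, 4.4141, 3.6309, 4.0703, 0.2661, -0.8691, -0.4841]

def max_abs_error(a, b):
    """Largest absolute element-wise difference (L-infinity norm of a - b)."""
    assert len(a) == len(b), "outputs must have the same length"
    return max(abs(x - y) for x, y in zip(a, b))

def argmax(xs):
    """Index of the largest value (the predicted class, for logits)."""
    return max(range(len(xs)), key=lambda i: xs[i])

def all_close(a, b, atol, rtol):
    """Element-wise |a - b| <= atol + rtol * |b|, the rule torch.allclose uses."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

err = max_abs_error(pytorch_out, trt_out)
print(f"max error on slice: {err:.4f}")
print("same top index:", argmax(pytorch_out) == argmax(trt_out))
print("all close:", all_close(pytorch_out, trt_out, atol=1e-2, rtol=1e-3))
```

On this slice the largest difference is about 0.0078, below the 0.014648 reported over the full output, and the top index agrees. In the notebook itself you would apply the same checks to the full tensors, for example with ``torch.allclose`` and ``argmax``.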


We can save the state dicts of both the original and the TensorRT model like this:

In [20]:
torch.save(model.state_dict(), 'resnet18.pth')
torch.save(model_trt.state_dict(), 'resnet18_trt.pth')

And load the TensorRT model back into a ``TRTModule`` like this:

In [21]:
from torch2trt import TRTModule

model_trt = TRTModule()

model_trt.load_state_dict(torch.load('resnet18_trt.pth'))

<All keys matched successfully>

That's it for this notebook!  Try out the live demo to see real-time classification on a video feed.