In [None]:
import torch
import torchvision
import torch2trt

In [None]:
# Build torchvision's DeepLabV3 segmentation network with a ResNet-101 backbone.
# pretrained=True asks torchvision for its pretrained weights rather than a
# randomly initialised model.
model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)

In [None]:
# Prepare the network for FP16 GPU inference, one step at a time.
model = model.cuda()  # move parameters and buffers to the GPU
model = model.eval()  # switch to evaluation mode
model = model.half()  # cast floating-point parameters to float16

In [None]:
class ModelWrapper(torch.nn.Module):
    """Adapter that exposes only the ``'out'`` entry of a wrapped model's output.

    The wrapped model's return value is indexed with ``'out'``, so calling the
    wrapper yields that single entry instead of the whole container.
    """

    def __init__(self, model):
        """Register ``model`` as the ``model`` submodule of this wrapper."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its ``'out'`` entry."""
        outputs = self.model(x)
        return outputs['out']

In [None]:
# Wrap the segmentation model so that a call returns only its 'out' entry,
# then cast the wrapper to float16.
model_w = ModelWrapper(model)
model_w = model_w.half()

In [None]:
# Example input: an all-ones tensor of shape (1, 3, 224, 224), moved to the
# GPU and cast to float16.
data = torch.ones((1, 3, 224, 224))
data = data.cuda().half()

In [None]:
# Convert the wrapped model with torch2trt, passing `data` as the single
# example input and enabling fp16_mode.
model_trt = torch2trt.torch2trt(model_w, [data], fp16_mode=True)

# Live demo

In [None]:
# Choose the import/constructor pair that matches the attached camera.
# The CSI variant is left commented out; the USB camera is the active choice.
# from jetcam.csi_camera import CSICamera
from jetcam.usb_camera import USBCamera

# camera = CSICamera(width=224, height=224)
# Frames are requested at 224x224, the same spatial size as the example input `data`.
camera = USBCamera(width=224, height=224)

# NOTE(review): presumably starts continuous capture so camera.value keeps
# updating — confirm against the jetcam documentation.
camera.running = True

In [None]:
from jetcam.utils import bgr8_to_jpeg
import traitlets
import ipywidgets

# Image widget that shows the raw camera feed.
image_w = ipywidgets.Image()

# One-directional link: each camera.value is passed through bgr8_to_jpeg and
# written to image_w.value.
traitlets.dlink((camera, 'value'), (image_w, 'value'), transform=bgr8_to_jpeg)

display(image_w)

In [None]:
import cv2
import numpy as np
import torchvision

# Per-channel mean / standard deviation, multiplied by 255 so they apply to
# pixel values that have not been rescaled to [0, 1].
mean = np.array([0.485, 0.456, 0.406]) * 255.0
stdev = np.array([0.229, 0.224, 0.225]) * 255.0
normalize = torchvision.transforms.Normalize(mean, stdev)

# Device that preprocessed tensors are moved to.
device = torch.device('cuda')

def preprocess(camera_value):
    """Turn a camera frame into a normalized, batched tensor on ``device``.

    Steps, in order: BGR -> RGB colour conversion, move the last axis to the
    front with ``transpose((2, 0, 1))``, convert to a float tensor, apply the
    module-level ``normalize`` transform, move to the module-level ``device``
    and prepend a leading dimension of size 1.

    Args:
        camera_value: image array accepted by ``cv2.cvtColor`` with
            ``COLOR_BGR2RGB`` (assumed to be an H x W x 3 BGR frame from the
            camera — confirm against the caller).

    Returns:
        The preprocessed float tensor with one extra leading dimension.
    """
    rgb = cv2.cvtColor(camera_value, cv2.COLOR_BGR2RGB)
    channels_first = rgb.transpose((2, 0, 1))
    tensor = torch.from_numpy(channels_first).float()
    on_device = normalize(tensor).to(device)
    return on_device[None, ...]

In [None]:
# Image widget that receives the masked (segmented) frame; its value is
# assigned inside execute().
seg_image = ipywidgets.Image()

display(seg_image)

In [None]:
def execute(change, class_index=15):
    """Segment one frame and show the masked result in ``seg_image``.

    Inference runs on the frame carried by ``change`` so that the mask and the
    image it is applied to always come from the same frame. Reading
    ``camera.value`` again could return a newer frame than the one being
    displayed.

    Args:
        change: mapping whose ``'new'`` entry is the frame to process (the
            form in which ``camera.observe`` delivers value changes, and the
            form used by the manual smoke-test call).
        class_index: index of the class to keep. The default of 15 is listed
            as 'person' in torchvision's documented category ordering for its
            pretrained segmentation models.

    Returns:
        None. The JPEG-encoded masked frame is written to ``seg_image.value``.
    """
    image = change['new']
    # [0] selects the first (only) batch element of the model output.
    scores = model_trt(preprocess(image).half())[0].detach().cpu().float().numpy()
    # argmax over the first axis picks the best-scoring class per pixel
    # (assumes that axis is the class axis — confirm against the model output);
    # the comparison keeps only the requested class as a 0.0 / 1.0 mask.
    mask = 1.0 * (scores.argmax(0) == class_index)
    # Broadcast the mask over the colour channels to black out other pixels.
    seg_image.value = bgr8_to_jpeg(mask[:, :, None] * image)
    
    
# Run the pipeline once on the current frame as a smoke test.
# execute() returns None (it only updates seg_image.value), so the result is
# deliberately not bound to a name.
execute({'new': camera.value})

In [None]:
# Register execute as an observer of camera.value so it is called with each
# change of that trait.
camera.observe(execute, names='value')

In [None]:
# Remove execute from the observers of camera.value, stopping the per-frame
# callback.
camera.unobserve(execute, names='value')

In [None]:
import time

# Throughput benchmark: frames per second for preprocess + forward pass.
# NOTE(review): this times the PyTorch wrapper `model_w`; substitute
# `model_trt` in the loop to measure the TensorRT engine instead.
NUM_BENCHMARK_FRAMES = 100

# Drain pending GPU work so it is not attributed to the timed region.
torch.cuda.current_stream().synchronize()
# perf_counter is a monotonic, high-resolution clock meant for interval timing.
t0 = time.perf_counter()
# Disable autograd bookkeeping: only inference is being measured, and graph
# construction would otherwise add time and memory to every iteration.
with torch.no_grad():
    for i in range(NUM_BENCHMARK_FRAMES):
        output = model_w(preprocess(camera.value).half())
# Wait for the queued kernels to finish before stopping the clock.
torch.cuda.current_stream().synchronize()
t1 = time.perf_counter()

print(NUM_BENCHMARK_FRAMES / (t1 - t0))