In [1]:
import traitlets
import os
from jetbot import Camera, bgr8_to_jpeg
import ipywidgets.widgets as widgets
from IPython.display import display
from uuid import uuid1
import json
import glob
import datetime
import torch
import torchvision

In [2]:
# Open the shared camera; frames are 224x224 and preprocess() does not
# resize, so this is also the resolution the classifier sees.
camera = Camera.instance(width=224, height=224)

# Live preview: each new camera frame is JPEG-encoded by bgr8_to_jpeg and
# pushed one-way into the image widget.
image_widget = widgets.Image(format='jpeg', width=300, height=300)
traitlets.dlink(
    source=(camera, 'value'),
    target=(image_widget, 'value'),
    transform=bgr8_to_jpeg,
)

display(image_widget)

Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x02\x01\x0…

In [3]:
# Directory containing one preview image per class, named <class_name>.jpg.
THUMBNAIL_DIR = 'thumbnails'

# Class labels. update() writes model output i into the i-th slider built
# below, so this order must match the label order the model was trained with
# (NOTE(review): confirm against the training notebook).
class_names = [
    'cross',
    'left',
    'right',
    'straight',
    't_left',
    't_right',
    't_straight'
]


# Build one column per class: thumbnail image on top, probability slider below.
class_widgets = []
prob_widgets = []
for name in class_names:
    # Thumbnail preview read from THUMBNAIL_DIR as raw JPEG bytes.
    thumbnail_widget = widgets.Image(format='jpeg', width=80, height=80)
    thumbnail_path = os.path.join(THUMBNAIL_DIR, name + '.jpg')
    with open(thumbnail_path, 'rb') as f:
        thumbnail_widget.value = f.read()

    # Probability slider in [0, 1]; update() writes this class's softmax score here.
    prob_widget = widgets.FloatSlider(min=0.0, max=1.0, step=0.001, orientation='vertical')
    prob_widgets.append(prob_widget)

    class_widgets.append(widgets.VBox([thumbnail_widget, prob_widget]))

display(widgets.HBox(class_widgets))

HBox(children=(VBox(children=(Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00`\x00`\x00\x00\xff…

In [4]:
import cv2
import numpy as np

# Standard ImageNet per-channel mean/std, scaled by 255 because frames are
# cast straight from uint8 to float (there is no division by 255 below).
mean = 255.0 * np.array([0.485, 0.456, 0.406])
stdev = 255.0 * np.array([0.229, 0.224, 0.225])

normalize = torchvision.transforms.Normalize(mean, stdev)


def preprocess(camera_value):
    """Turn one BGR camera frame into a normalized single-image batch.

    Parameters
    ----------
    camera_value : numpy.ndarray
        A BGR frame laid out height x width x channels (presumably uint8,
        as produced by the camera -- confirm).

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (1, 3, H, W), RGB channel order, normalized
        on the CPU and then moved to the module-level ``device``.
    """
    rgb_frame = cv2.cvtColor(camera_value, cv2.COLOR_BGR2RGB)
    # HWC -> CHW, the channel-first layout torch models expect.
    chw_tensor = torch.from_numpy(rgb_frame.transpose((2, 0, 1))).float()
    normalized = normalize(chw_tensor).to(device)
    # Prepend a batch dimension of size 1.
    return normalized.unsqueeze(0)

In [5]:
# Inference device; preprocess() also moves its input here.
device = torch.device('cuda')

# ResNet-18 backbone with the classification head resized to one output per
# entry in class_names (derived instead of hard-coding 512 / 7, so the head
# stays in sync with the class list and the backbone's feature width).
model = torchvision.models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))

# map_location lets the checkpoint load regardless of the device it was saved on.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
model.load_state_dict(torch.load('best_model_blocks.pth', map_location=device))
model = model.to(device).eval()

In [6]:
import torch.nn.functional as F
import time

def update(change):
    """Classify the newest camera frame and display per-class probabilities.

    Parameters
    ----------
    change : dict
        traitlets change notification; only ``change['new']`` (the latest
        camera frame) is read.
    """
    x = preprocess(change['new'])

    # Inference only: avoid building an autograd graph for every frame.
    with torch.no_grad():
        logits = model(x)
    probs = F.softmax(logits, dim=1).cpu().numpy().flatten()

    # Slider i shows output i; assumes model outputs follow class_names order.
    for widget, prob in zip(prob_widgets, probs):
        widget.value = float(prob)


# Run once on the current frame so the sliders show a result immediately.
update({'new': camera.value})

In [7]:
# Re-run the classifier every time the camera publishes a new frame.
camera.observe(handler=update, names='value')

In [None]:
# Stop live classification by detaching only the update() callback.
# unobserve_all() would also remove the traitlets.dlink observer that feeds the
# preview image widget, freezing the live camera view.
try:
    camera.unobserve(update, names='value')
except ValueError:
    pass  # update was not attached (e.g. this cell was already run)