In [None]:
import traitlets
import ipywidgets.widgets as widgets
from IPython.display import display
from jetbot import Camera, Robot, bgr8_to_jpeg
import torchvision
import torch
import torchvision.transforms as transforms
import cv2
import PIL.Image

In [None]:
# Camera instance producing 224x224 BGR frames.
camera = Camera.instance(width=224, height=224)

# Live preview: a JPEG image widget kept in sync with the camera through a
# one-way link that re-encodes every new frame with bgr8_to_jpeg.
image_widget = widgets.Image(format='jpeg', width=224, height=224)
traitlets.dlink(
    (camera, 'value'),
    (image_widget, 'value'),
    transform=bgr8_to_jpeg,
)

display(image_widget)

In [None]:
# Inference device; the model and its weights both live here.
device = torch.device('cuda')

# ResNet-18 backbone whose classifier is replaced by a single-output head:
# the one value execute() reads back as the steering prediction.
model = torchvision.models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(512, 1)
model = model.to(device)

# map_location makes loading independent of the device the checkpoint was
# saved from, and puts the tensors directly where the model already is.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load('best_steering_model_1.pth', map_location=device))
model = model.eval()

In [None]:
# Standard ImageNet per-channel normalization statistics (RGB order).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def preprocess(image):
    """Convert a camera frame into a normalized, batched model input.

    Args:
        image: frame in BGR channel order as delivered by the camera
            (presumably an HxWx3 uint8 numpy array -- TODO confirm).

    Returns:
        A float tensor of shape (1, 3, H, W) on the global ``device``.
    """
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    tensor = transforms.functional.to_tensor(PIL.Image.fromarray(rgb))
    tensor = transforms.functional.normalize(tensor, IMAGENET_MEAN, IMAGENET_STD)
    # Add the leading batch dimension the model expects.
    return tensor[None, ...].to(device)

In [None]:
# Shows the most recent steering prediction, in the range [-1, 1].
steering_slider = widgets.FloatSlider(
    description='steering',
    min=-1.0,
    max=1.0,
    value=0.0,
)
display(steering_slider)

In [None]:
# Operator controls: base speed and steering gain, both in [0, 1] with fine
# 0.001 steps. Both start at 0, so execute() commands zero motor values
# until they are raised.
speed_slider = widgets.FloatSlider(
    description='speed', min=0.0, max=1.0, step=0.001, value=0.0,
)
gain_slider = widgets.FloatSlider(
    description='gain', min=0.0, max=1.0, step=0.001, value=0.0,
)

display(speed_slider, gain_slider)

In [None]:
# Motor controller driven by execute() and stopped in the final cell.
# `Robot` comes from the imports cell at the top of the notebook.
robot = Robot()

In [None]:
def execute(change):
    """Camera observer: predict steering from a frame and drive the motors.

    Args:
        change: traitlets-style change dict; ``change['new']`` is the latest
            camera frame (BGR, as consumed by ``preprocess``).

    Side effects:
        Writes the prediction to ``steering_slider`` and calls
        ``robot.set_motors`` with left/right values of
        ``speed +/- gain * steering``.
    """
    frame = change['new']

    # Pure inference: no_grad stops autograd from recording a graph on every
    # frame, which otherwise wastes memory and time in this hot callback.
    with torch.no_grad():
        output = model(preprocess(frame))
    steering_slider.value = float(output[0, 0])

    speed = speed_slider.value
    gain = gain_slider.value
    # Read back through the slider: ipywidgets clamps the assigned value to
    # the slider's [min, max], so steering stays within [-1, 1].
    steering = steering_slider.value

    robot.set_motors(
        speed + gain * steering,
        speed - gain * steering
    )


# Prime once with the current frame; this also surfaces model/preprocess
# errors in this cell, before the observer is attached.
execute({'new': camera.value})

In [None]:
# Re-run execute() every time the camera publishes a new frame.
camera.observe(execute, names=['value'])

In [None]:
import time

# Detach only the steering callback. unobserve_all() would also remove the
# observer installed by traitlets.dlink for image_widget, freezing the preview.
try:
    camera.unobserve(execute, names='value')
except ValueError:
    pass  # execute was already detached (e.g. this cell was run twice)

# Presumably gives an in-flight execute() call time to finish, so its
# set_motors() cannot land after the stop below.
time.sleep(0.5)
robot.stop()