# Jetbot Demo

In [None]:
# importing necessary libraries

import os
import torch
import torchvision.transforms as transforms
from PIL import Image
from jetbot import Robot
import numpy as np
from JetbotUsbCamera import JetbotUsbCamera as usbCam
import time
import threading
import ipywidgets as widgets
from IPython.display import display
import cv2
import torch.nn as nn
import torch.optim as optim
from torchvision import models

In [None]:
# Classifier architecture: a ResNet-18 whose output layer is resized
class SimpleCNN(nn.Module):
    """ResNet-18 whose final fully connected layer emits `num_classes` scores.

    The backbone is created with ``pretrained=True`` and stored as
    ``self.features``, so state-dict keys are prefixed with ``features.``.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Swap the stock classification head for one sized to our classes.
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.features = backbone

    def forward(self, x):
        """Pass `x` through the backbone and return its output unchanged."""
        logits = self.features(x)
        return logits


In [None]:
# Restore the trained weights into a fresh network and switch it to eval mode
model_path = '/workspace/jetbot/notebooks/testing/training/road_following_model.pth'
model = SimpleCNN(num_classes=5)
# The checkpoint tensors are mapped onto CUDA when it is available, else CPU.
load_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
state_dict = torch.load(model_path, map_location=load_device)
model.load_state_dict(state_dict)
model.eval()

In [None]:
# Input pipeline: resize to 224x224, convert to a tensor, normalise per channel
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

_transform_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = transforms.Compose(_transform_steps)

In [None]:
# Create the robot handle and open camera device 0 at 224x224
robot = Robot()
camera = usbCam(
    device=0,
    width=224,
    height=224,
)

In [None]:
# Turn a raw camera frame into a model-ready batch
def preprocess(image):
    """Convert a camera frame into a single-sample input batch.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
            NOTE(review): the channel order of the frame is not converted
            here - confirm it matches what the model was trained on.

    Returns:
        The output of ``transform`` with a leading batch dimension added.
    """
    pil_frame = Image.fromarray(image)
    tensor = transform(pil_frame)
    return tensor.unsqueeze(0)

In [None]:
# Predict the direction based on the image
def predict_direction(image):
    """Return the index of the class the model scores highest for `image`.

    Args:
        image: Input batch passed straight to ``model``. The result is read
            with ``.item()``, so the batch must contain exactly one sample.

    Returns:
        int: Index of the highest-scoring class along dimension 1.
    """
    # Inference only: without no_grad() each call would record an autograd
    # graph, costing memory and time on every frame of the control loop.
    with torch.no_grad():
        output = model(image)
    _, pred = torch.max(output, 1)
    return pred.item()

In [None]:
# Display arrows on the image based on the prediction
def display_arrows(camera_image, direction):
    """Return a copy of `camera_image` annotated with the predicted direction.

    The annotation is drawn on an overlay and blended 60/40 with the frame.
    Arrow positions are derived from the frame size (24 px margin, centred on
    the other axis), which reproduces the previous fixed coordinates exactly
    for a 224x224 frame.

    Args:
        camera_image: Image array; it is copied and never modified.
        direction: Predicted class index. 0 draws an up arrow, 2 a left
            arrow, 3 a right arrow and 4 the text "Obstacle"; any other
            value leaves the frame without annotation.

    Returns:
        The annotated copy of the frame.
    """
    image = np.copy(camera_image)
    # Yellow if the frame is in OpenCV's BGR order - TODO confirm channel order.
    glow_color = (0, 255, 255)

    height, width = image.shape[:2]
    mid_x, mid_y = width // 2, height // 2
    margin = 24

    overlay = image.copy()
    if direction == 0:  # up
        cv2.arrowedLine(overlay, (mid_x, height - margin), (mid_x, margin), glow_color, 3, tipLength=0.5)
    elif direction == 2:  # left
        cv2.arrowedLine(overlay, (width - margin, mid_y), (margin, mid_y), glow_color, 3, tipLength=0.5)
    elif direction == 3:  # right
        cv2.arrowedLine(overlay, (margin, mid_y), (width - margin, mid_y), glow_color, 3, tipLength=0.5)
    elif direction == 4:  # obstacle
        cv2.putText(overlay, "Obstacle", (50, mid_y), cv2.FONT_HERSHEY_SIMPLEX, 1, glow_color, 2, cv2.LINE_AA)

    # Blend in place into `image` so the annotation appears semi-transparent.
    cv2.addWeighted(overlay, 0.6, image, 0.4, 0, image)
    return image

In [None]:
# Translate a predicted class index into a motor command
def move_robot(prediction):
    """Drive the robot according to `prediction`.

    Classes 0, 2 and 3 map to forward, left and right at the speed currently
    set on ``speed_slider``. Every other value, including 4 (obstacle), stops
    the robot.
    """
    drive_actions = {
        0: robot.forward,  # up
        2: robot.left,     # left
        3: robot.right,    # right
    }
    drive = drive_actions.get(prediction)
    if drive is None:
        # Obstacle (4) and any unmapped class both halt the motors.
        robot.stop()
        return
    drive(speed_slider.value)

In [None]:
# Update the image and make predictions
def update_image():
    """Camera loop: read frames, optionally drive the robot, refresh the preview.

    Runs until the module-level ``stop_requested`` flag becomes true. On each
    iteration it reads one frame; when ``manual_mode`` is off it classifies
    the frame and overlays the prediction. The prediction is only turned into
    a motor command while ``robot_active`` is set, so the Start/Stop Robot
    buttons actually gate autonomous driving. The (possibly annotated) frame
    is JPEG-encoded into ``target_widget``.

    The motors are always stopped when the loop exits, whether it ends
    normally or because an exception escaped, so the robot cannot be left
    driving on a stale command.
    """
    try:
        while not stop_requested:
            frame = camera.read()
            if frame is not None:
                if not manual_mode:
                    image = preprocess(frame)
                    prediction = predict_direction(image)
                    # Re-check the flags after inference: a stop may have been
                    # requested while the model was running, and issuing a
                    # drive command now would override it. This narrows the
                    # race window; it does not remove it entirely.
                    if robot_active and not stop_requested:
                        move_robot(prediction)
                    frame = display_arrows(frame, prediction)
                target_widget.value = cv2.imencode('.jpg', frame)[1].tobytes()
            time.sleep(0.1)
    finally:
        robot.stop()

In [None]:
# Define the functions for the buttons
def start_camera(_):
    """Start Camera button: start the camera and the background update loop.

    At most one update loop runs at a time. If a loop is already running the
    click is ignored. If a stop was requested but the previous loop has not
    exited yet, this waits briefly for it before starting a new one, because
    clearing ``stop_requested`` too early would keep the old loop alive
    alongside the new one.

    Args:
        _: The clicked button (unused).
    """
    global stop_requested
    worker_name = 'jetbot-update-image'
    previous = next(
        (t for t in threading.enumerate() if t.name == worker_name), None
    )
    if previous is not None:
        if not stop_requested:
            # Already running: a second loop would send competing motor
            # commands and read from the same camera.
            return
        previous.join(timeout=2.0)
        if previous.is_alive():
            # Old loop is stuck; do not stack another one on top of it.
            return
    stop_requested = False
    camera.start()
    update_thread = threading.Thread(target=update_image, name=worker_name)
    update_thread.start()
    camera_active_indicator.button_style = 'success'

def stop_camera(_):
    """Stop Camera button: end the update loop, stop the camera and the robot.

    Args:
        _: The clicked button (unused).
    """
    global stop_requested
    stop_requested = True
    try:
        camera.stop()
    finally:
        # The motors must be halted even if shutting the camera down fails.
        robot.stop()
        camera_active_indicator.button_style = ''

def start_robot(_):
    """Start Robot button: light the robot indicator and set ``robot_active``.

    Args:
        _: The clicked button (unused).
    """
    global robot_active
    robot_active_indicator.button_style = 'success'
    robot_active = True

def stop_robot(_):
    """Stop Robot button: clear ``robot_active``, stop the motors, reset the indicator.

    Args:
        _: The clicked button (unused).
    """
    global robot_active
    robot_active = False
    robot.stop()
    robot_active_indicator.button_style = ''

def _toggle_motion(direction, drive):
    """Toggle one manual direction and keep the four direction flags consistent.

    Only one direction can be engaged at a time. Engaging a direction clears
    the other three flags, so a later press of a previously used button starts
    that motion again instead of acting on a stale "already on" flag.

    Args:
        direction: One of 'up', 'down', 'left', 'right'.
        drive: Callable taking the speed; invoked when the direction engages.
    """
    global move_up, move_down, move_left, move_right
    was_engaged = {
        'up': move_up,
        'down': move_down,
        'left': move_left,
        'right': move_right,
    }[direction]
    engage = not was_engaged

    move_up = engage and direction == 'up'
    move_down = engage and direction == 'down'
    move_left = engage and direction == 'left'
    move_right = engage and direction == 'right'

    if engage:
        drive(speed_slider.value)
    else:
        robot.stop()


def move_forward(_):
    """Up button: toggle driving forward at the slider speed."""
    _toggle_motion('up', robot.forward)


def move_backward(_):
    """Down button: toggle driving backward at the slider speed."""
    _toggle_motion('down', robot.backward)


def turn_left(_):
    """Left button: toggle turning left at the slider speed."""
    _toggle_motion('left', robot.left)


def turn_right(_):
    """Right button: toggle turning right at the slider speed."""
    _toggle_motion('right', robot.right)

def manual_control_toggle(_):
    """Manual Control button: flip ``manual_mode`` and mirror it on the indicator.

    Args:
        _: The clicked button (unused).
    """
    global manual_mode
    if manual_mode:
        manual_mode = False
        manual_control_indicator.button_style = ''
    else:
        manual_mode = True
        manual_control_indicator.button_style = 'success'

In [None]:
# Create the widgets
def _make_indicator():
    """Return a small disabled, unlabelled button used as a status light."""
    return widgets.Button(
        description='',
        button_style='',
        layout=widgets.Layout(width='20px', height='20px'),
        disabled=True,
    )


def _make_arrow_button(symbol):
    """Return a 50px square button labelled with `symbol`."""
    return widgets.Button(
        description=symbol,
        layout=widgets.Layout(width='50px', height='50px'),
    )


# Live preview plus a speed slider and text box linked in the browser
target_widget = widgets.Image(format='jpeg', width=224, height=224)
speed_slider = widgets.FloatSlider(min=0.0, max=1.0, step=0.01, value=0.3, description='Speed')
speed_text = widgets.FloatText(value=0.3, description='Speed')
widgets.jslink((speed_slider, 'value'), (speed_text, 'value'))

# Status lights; the button handlers set their style to 'success' when active
camera_active_indicator = _make_indicator()
robot_active_indicator = _make_indicator()
manual_control_indicator = _make_indicator()

start_camera_button = widgets.Button(description='Start Camera', button_style='success')
stop_camera_button = widgets.Button(description='Stop Camera', button_style='danger')
start_robot_button = widgets.Button(description='Start Robot', button_style='success')
stop_robot_button = widgets.Button(description='Stop Robot', button_style='warning')
manual_control_button = widgets.Button(description='Manual Control', button_style='info')

# Manual driving pad
up_button = _make_arrow_button('↑')
down_button = _make_arrow_button('↓')
left_button = _make_arrow_button('←')
right_button = _make_arrow_button('→')

In [None]:
# Attach each button to its click handler
for _button, _handler in (
    (start_camera_button, start_camera),
    (stop_camera_button, stop_camera),
    (start_robot_button, start_robot),
    (stop_robot_button, stop_robot),
    (manual_control_button, manual_control_toggle),
    (up_button, move_forward),
    (down_button, move_backward),
    (left_button, turn_left),
    (right_button, turn_right),
):
    _button.on_click(_handler)

In [None]:
# Show the preview, then one horizontal row per control group
display(target_widget)
for _row in (
    [speed_slider, speed_text],
    [start_camera_button, stop_camera_button, camera_active_indicator],
    [start_robot_button, stop_robot_button, robot_active_indicator],
    [manual_control_button, manual_control_indicator],
    [left_button, up_button, right_button, down_button],
):
    display(widgets.HBox(_row))

# Shared state flags read and written by the button handlers; all start off
stop_requested = robot_active = manual_mode = False
move_up = move_down = move_left = move_right = False