# JetBot Behavior Cloning - JetBot Notebook

This notebook runs on your **JetBot** for data collection, inference, and DAgger correction collection.

## Cells
1. **Config** - All parameters (run first, always)
2. **Crop Tuning** - Adjust crop visually (run once during setup)
3. **Data Collection** - Collect training data with joystick
4. **Inference** - Run trained model autonomously
5. **DAgger** - Collect corrections while model drives

In [None]:
# =============================================================================
# CELL 1: CONFIGURATION (Run this first!)
# =============================================================================

import torch

# -----------------------------------------------------------------------------
# PATHS
# -----------------------------------------------------------------------------
# Cell 3 writes recorded frames to DATASET_DIR, Cell 5 writes human-correction
# frames to DAGGER_DIR, and Cells 4/5 load a state_dict from MODEL_PATH.
DATASET_DIR = 'dataset_v1'
DAGGER_DIR = 'dataset_dagger'
MODEL_PATH = 'steering_model_v1.pth'

# -----------------------------------------------------------------------------
# CAMERA
# -----------------------------------------------------------------------------
# Resolution passed to Camera.instance(...) in every cell that opens the camera.
CAMERA_WIDTH = 640
CAMERA_HEIGHT = 480

# -----------------------------------------------------------------------------
# PREPROCESSING (must match PC training notebook)
# -----------------------------------------------------------------------------
# Each CROP_* value is the fraction of the frame trimmed from that edge before
# the remaining region is resized to INPUT_SIZE with cv2.resize.
CROP_TOP = 0.20
CROP_BOTTOM = 0.00
CROP_LEFT = 0.08
CROP_RIGHT = 0.12
INPUT_SIZE = (224, 224)

# ImageNet normalization
# Applied per RGB channel after pixel values are scaled to [0, 1].
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# -----------------------------------------------------------------------------
# MOTOR CONTROL
# -----------------------------------------------------------------------------
# Motor commands are computed as
#   left  = FORWARD_SPEED + steering * STEERING_GAIN
#   right = FORWARD_SPEED - steering * STEERING_GAIN
# and clamped to [-1, 1] before being written to the motors.
FORWARD_SPEED = 0.12
STEERING_GAIN = 0.08

# -----------------------------------------------------------------------------
# TIMING
# -----------------------------------------------------------------------------
# Target control-loop rate; each loop sleeps for whatever is left of LOOP_SLEEP
# seconds after its own work.
LOOP_HZ = 20
LOOP_SLEEP = 1.0 / LOOP_HZ

# -----------------------------------------------------------------------------
# JOYSTICK
# -----------------------------------------------------------------------------
# AXIS_STEERING indexes controller.axes and RB_BUTTON indexes
# controller.buttons; steering magnitudes below DEADZONE are treated as 0.0.
AXIS_STEERING = 0  # Left stick X
RB_BUTTON = 5      # Right bumper
DEADZONE = 0.05

# -----------------------------------------------------------------------------
# DEVICE
# -----------------------------------------------------------------------------
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Configuration loaded. Device: {DEVICE}")

In [None]:
# =============================================================================
# CELL 2: CROP TUNING (Run once during setup)
# =============================================================================
# Adjust sliders to find optimal crop values, then update Cell 1

import cv2
import numpy as np
import ipywidgets as widgets
from IPython.display import display
from jetbot import Camera, bgr8_to_jpeg

# Camera at the resolution configured in Cell 1.
camera = Camera.instance(width=CAMERA_WIDTH, height=CAMERA_HEIGHT)

# Widgets
# raw_widget shows the full frame with the crop rectangle drawn on it;
# crop_widget shows the cropped-and-resized result. The sliders start at the
# crop fractions currently set in Cell 1.
raw_widget = widgets.Image(format='jpeg', width=320, height=240)
crop_widget = widgets.Image(format='jpeg', width=224, height=224)
top_slider = widgets.FloatSlider(value=CROP_TOP, min=0, max=0.6, step=0.02, description='Top:')
bottom_slider = widgets.FloatSlider(value=CROP_BOTTOM, min=0, max=0.5, step=0.02, description='Bottom:')
left_slider = widgets.FloatSlider(value=CROP_LEFT, min=0, max=0.3, step=0.02, description='Left:')
right_slider = widgets.FloatSlider(value=CROP_RIGHT, min=0, max=0.3, step=0.02, description='Right:')
info_label = widgets.Label()
stop_btn = widgets.Button(description='STOP', button_style='danger')

# Cleared by stop(); update() returns immediately once this is False.
running = True

def update(change=None):
    """Redraw the raw and cropped previews from the latest frame and sliders.

    Registered as an observer on the camera value and on all four crop
    sliders. Does nothing once ``running`` has been cleared by ``stop``.

    Args:
        change: Change payload passed by the observer machinery; unused.
    """
    if not running:
        return
    raw = camera.value
    h, w = raw.shape[:2]

    top, bottom = top_slider.value, bottom_slider.value
    left, right = left_slider.value, right_slider.value
    y0, y1 = int(h * top), int(h * (1 - bottom))
    x0, x1 = int(w * left), int(w * (1 - right))

    # The slider ranges allow top + bottom >= 1.0 (0.6 + 0.5), which gives an
    # empty slice; cv2.resize raises on an empty image, so show the raw frame
    # and report the problem instead of crashing the observer callback.
    if y1 <= y0 or x1 <= x0:
        raw_widget.value = bgr8_to_jpeg(raw)
        info_label.value = f'{w}x{h} -> empty crop, reduce the sliders'
        return

    # Resize to the same INPUT_SIZE the collection/inference cells use so the
    # preview matches what will actually be saved and fed to the model.
    cropped = cv2.resize(raw[y0:y1, x0:x1], INPUT_SIZE)

    # Draw on a copy so the camera's own frame buffer is left untouched.
    raw_viz = raw.copy()
    cv2.rectangle(raw_viz, (x0, y0), (x1, y1), (0, 255, 0), 2)

    raw_widget.value = bgr8_to_jpeg(raw_viz)
    crop_widget.value = bgr8_to_jpeg(cropped)
    info_label.value = (
        f'{w}x{h} -> {x1-x0}x{y1-y0} -> {INPUT_SIZE[0]}x{INPUT_SIZE[1]}'
    )

def stop(b):
    """Halt the preview and print the chosen crop values for Cell 1.

    Args:
        b: The button that was clicked; unused.
    """
    global running
    running = False
    camera.stop()

    print("\nCopy these values to Cell 1:")
    chosen = (
        ("CROP_TOP", top_slider),
        ("CROP_BOTTOM", bottom_slider),
        ("CROP_LEFT", left_slider),
        ("CROP_RIGHT", right_slider),
    )
    for constant_name, slider in chosen:
        print(f"{constant_name} = {slider.value:.2f}")

# Wire the callbacks: STOP ends the preview, and update() runs whenever a
# slider value or the camera's value changes.
stop_btn.on_click(stop)
for s in [top_slider, bottom_slider, left_slider, right_slider]:
    s.observe(update, names='value')
camera.observe(update, names='value')

# Layout: previews side by side, then the sliders, then info label and STOP.
display(widgets.HBox([raw_widget, crop_widget]))
display(top_slider, bottom_slider, left_slider, right_slider)
display(info_label, stop_btn)

In [None]:
# =============================================================================
# CELL 3: DATA COLLECTION
# =============================================================================
# Hold RB to drive forward and record. Use left stick to steer.

import os
import time
import threading
import cv2
import numpy as np
import ipywidgets as widgets
from IPython.display import display
from jetbot import Camera, Robot, bgr8_to_jpeg

# Create directory
# Count the .jpg files already present so the image counter continues from
# there instead of restarting at zero.
os.makedirs(DATASET_DIR, exist_ok=True)
existing = len([f for f in os.listdir(DATASET_DIR) if f.endswith('.jpg')])
print(f"Dataset: {DATASET_DIR} ({existing} existing images)")

# Hardware
# Camera at the resolution configured in Cell 1, plus the motor interface.
camera = Camera.instance(width=CAMERA_WIDTH, height=CAMERA_HEIGHT)
robot = Robot()

# Preprocessing
def preprocess(img):
    """Trim the configured margins off ``img`` and resize it to ``INPUT_SIZE``.

    Args:
        img: Image array whose first two dimensions are (height, width).

    Returns:
        The cropped region resized with ``cv2.resize`` to ``INPUT_SIZE``.
    """
    height, width = img.shape[:2]

    # Each CROP_* constant is a fraction of the frame removed from that edge.
    row_start = int(height * CROP_TOP)
    row_end = int(height * (1 - CROP_BOTTOM))
    col_start = int(width * CROP_LEFT)
    col_end = int(width * (1 - CROP_RIGHT))

    region = img[row_start:row_end, col_start:col_end]
    return cv2.resize(region, INPUT_SIZE)

# Widgets
# The steering slider and image counter are disabled: they only mirror state
# written by collection_loop. `controller` is the gamepad widget the loop reads.
image_widget = widgets.Image(format='jpeg', width=224, height=224)
steering_slider = widgets.FloatSlider(value=0, min=-1, max=1, description='Steering:', disabled=True)
count_widget = widgets.IntText(value=existing, description='Images:', disabled=True)
status_label = widgets.Label(value='Ready - Press START')
start_btn = widgets.Button(description='START', button_style='success')
stop_btn = widgets.Button(description='STOP', button_style='danger')
controller = widgets.Controller()

# Loop state: `running` gates collection_loop; `image_count` continues from the
# number of images already found in DATASET_DIR.
running = False
image_count = existing

def collection_loop():
    """Drive and record joystick-labelled frames until ``running`` is cleared.

    Each iteration reads the steering axis and RB button, preprocesses the
    current camera frame and, while RB is held, drives forward and saves the
    frame as ``<timestamp_ms>_<steering>.jpg`` in ``DATASET_DIR``. When RB is
    released the robot is stopped and nothing is saved. The loop is paced to
    ``LOOP_SLEEP`` seconds per iteration.

    If anything inside the loop raises, the motors are stopped and ``running``
    is reset before the exception propagates, so the robot is not left moving
    and START can be pressed again.
    """
    global running, image_count

    try:
        while running:
            t0 = time.time()

            # Read joystick. Indexing fails when the requested axis/button is
            # not present (presumably the case until a gamepad is connected);
            # treat that as "no input" rather than an error.
            try:
                steering = controller.axes[AXIS_STEERING].value
                rb = controller.buttons[RB_BUTTON].value > 0.5
            except (IndexError, AttributeError, TypeError):
                steering, rb = 0.0, False

            if abs(steering) < DEADZONE:
                steering = 0.0

            # Process frame
            raw = camera.value
            processed = preprocess(raw)

            if rb:
                # Drive and record
                left = FORWARD_SPEED + steering * STEERING_GAIN
                right = FORWARD_SPEED - steering * STEERING_GAIN
                robot.left_motor.value = max(-1, min(1, left))
                robot.right_motor.value = max(-1, min(1, right))

                # Save. The steering label is encoded in the filename.
                filename = f"{int(time.time()*1000)}_{steering:.3f}.jpg"
                saved = cv2.imwrite(os.path.join(DATASET_DIR, filename), processed)
                if saved:
                    image_count += 1
                    status_label.value = f'RECORDING ({image_count} images)'
                else:
                    # Only count frames that actually reached the disk.
                    status_label.value = (
                        f'WRITE FAILED: {filename} ({image_count} images)'
                    )
            else:
                robot.stop()
                status_label.value = f'PAUSED - Hold RB ({image_count} images)'

            # Update UI
            image_widget.value = bgr8_to_jpeg(processed)
            steering_slider.value = steering
            count_widget.value = image_count

            # Sleep only for what is left of this iteration's time budget.
            time.sleep(max(0, LOOP_SLEEP - (time.time() - t0)))
    except Exception:
        # A crash must not leave the motors at their last commanded speed,
        # and `running` must be cleared or on_start would refuse to restart.
        running = False
        robot.stop()
        status_label.value = f'ERROR - loop crashed ({image_count} images)'
        raise

    robot.stop()
    status_label.value = f'STOPPED ({image_count} images)'

def on_start(b):
    """Start the collection loop in a daemon thread unless it is already running.

    Args:
        b: The button that was clicked; unused.
    """
    global running
    if running:
        return
    running = True
    worker = threading.Thread(target=collection_loop, daemon=True)
    worker.start()

def on_stop(b):
    """Ask the collection loop to exit and stop the motors right away.

    Args:
        b: The button that was clicked; unused.
    """
    global running
    # Clear the flag first so the loop exits at its next `while running` check.
    running = False
    robot.stop()

# Hook the buttons up to the start/stop handlers.
start_btn.on_click(on_start)
stop_btn.on_click(on_stop)

# Show the gamepad widget first, then the live view/status, then the buttons.
print("Connect controller and press START")
display(controller)
display(widgets.VBox([image_widget, steering_slider, count_widget, status_label]))
display(widgets.HBox([start_btn, stop_btn]))

In [None]:
# Cleanup after data collection
# Clear the loop flag before stopping the hardware so the background loop
# exits at its next `while running` check, then stop the motors and camera.
running = False
robot.stop()
camera.stop()
print(f"Done. {image_count} images in {DATASET_DIR}")

In [None]:
# =============================================================================
# CELL 4: INFERENCE (Autonomous Driving)
# =============================================================================

import os
import time
import threading
import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import ipywidgets as widgets
from IPython.display import display
from jetbot import Camera, Robot, bgr8_to_jpeg

# Model
def get_model():
    """Build a ResNet-18 without pretrained weights and a one-output head.

    Returns:
        The network with its final fully connected layer replaced by a
        ``Linear`` layer producing a single value.
    """
    net = models.resnet18(pretrained=False)
    head_inputs = net.fc.in_features
    net.fc = nn.Linear(head_inputs, 1)
    return net

# Preprocessing
def preprocess_for_inference(img):
    """Turn a camera frame into a normalized model input plus a display image.

    Args:
        img: BGR image array whose first two dimensions are (height, width).

    Returns:
        A ``(tensor, cropped)`` pair: ``tensor`` is the float32 input of shape
        (1, 3, H, W) normalized with ``IMAGENET_MEAN``/``IMAGENET_STD``, and
        ``cropped`` is the cropped, resized BGR image it was built from.
    """
    height, width = img.shape[:2]
    row_start = int(height * CROP_TOP)
    row_end = int(height * (1 - CROP_BOTTOM))
    col_start = int(width * CROP_LEFT)
    col_end = int(width * (1 - CROP_RIGHT))
    cropped = cv2.resize(img[row_start:row_end, col_start:col_end], INPUT_SIZE)

    # BGR -> RGB, scaled to [0, 1] as float32.
    rgb = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    # Normalize all channels in one broadcast; the float32 constants keep the
    # result float32, matching the per-channel arithmetic on a float32 array.
    mean = np.asarray(IMAGENET_MEAN, dtype=np.float32)
    std = np.asarray(IMAGENET_STD, dtype=np.float32)
    normalized = (rgb - mean) / std

    # HWC -> CHW, then add the batch dimension.
    chw = normalized.transpose(2, 0, 1)
    tensor = torch.from_numpy(chw).unsqueeze(0)
    return tensor, cropped

# Load model
# Weights are mapped onto DEVICE while loading; eval() switches the network to
# inference mode.
print(f"Loading {MODEL_PATH}...")
model = get_model()
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE).eval()
print("Model loaded.")

# Hardware
# Camera at the resolution configured in Cell 1, plus the motor interface.
camera = Camera.instance(width=CAMERA_WIDTH, height=CAMERA_HEIGHT)
robot = Robot()

# Widgets
# The steering slider and FPS box are disabled: they only mirror values
# written by inference_loop.
image_widget = widgets.Image(format='jpeg', width=224, height=224)
steering_slider = widgets.FloatSlider(value=0, min=-1, max=1, description='Steering:', disabled=True)
fps_widget = widgets.FloatText(value=0, description='FPS:', disabled=True)
status_label = widgets.Label(value='Ready')
start_btn = widgets.Button(description='START', button_style='success')
stop_btn = widgets.Button(description='STOP', button_style='danger')

# Gates inference_loop; set by on_start, cleared by on_stop.
running = False

def inference_loop():
    """Drive the robot from model predictions until ``running`` is cleared.

    Each iteration preprocesses the current camera frame, predicts a steering
    value, clamps it to [-1, 1], writes the motor commands, and refreshes the
    UI. The loop is paced to ``LOOP_SLEEP`` seconds per iteration.

    If anything inside the loop raises, the motors are stopped and ``running``
    is reset before the exception propagates, so the robot is not left driving
    on its last command and START can be pressed again.
    """
    global running
    from collections import deque

    # Rolling window of the last 30 per-iteration work times.
    frame_times = deque(maxlen=30)

    try:
        while running:
            t0 = time.time()

            # Get prediction
            raw = camera.value
            tensor, display_img = preprocess_for_inference(raw)

            with torch.no_grad():
                steering = model(tensor.to(DEVICE)).item()
            steering = max(-1, min(1, steering))

            # Drive
            left = FORWARD_SPEED + steering * STEERING_GAIN
            right = FORWARD_SPEED - steering * STEERING_GAIN
            robot.left_motor.value = max(-1, min(1, left))
            robot.right_motor.value = max(-1, min(1, right))

            # FPS. This is the rate implied by the work time measured up to
            # here; the UI update and pacing sleep below are not included.
            frame_times.append(time.time() - t0)
            mean_time = sum(frame_times) / len(frame_times)
            fps = 1.0 / mean_time if mean_time > 0 else 0.0

            # UI
            image_widget.value = bgr8_to_jpeg(display_img)
            steering_slider.value = steering
            fps_widget.value = round(fps, 1)
            status_label.value = f'RUNNING | Steer: {steering:.2f}'

            # Sleep only for what is left of this iteration's time budget.
            time.sleep(max(0, LOOP_SLEEP - (time.time() - t0)))
    except Exception:
        # A crash must not leave the motors at their last commanded speed,
        # and `running` must be cleared or on_start would refuse to restart.
        running = False
        robot.stop()
        status_label.value = 'ERROR - loop crashed, motors stopped'
        raise

    robot.stop()
    status_label.value = 'STOPPED'

def on_start(b):
    """Start the inference loop in a daemon thread unless it is already running.

    Args:
        b: The button that was clicked; unused.
    """
    global running
    if running:
        return
    running = True
    worker = threading.Thread(target=inference_loop, daemon=True)
    worker.start()

def on_stop(b):
    """Ask the inference loop to exit and stop the motors right away.

    Args:
        b: The button that was clicked; unused.
    """
    global running
    # Clear the flag first so the loop exits at its next `while running` check.
    running = False
    robot.stop()

# Hook the buttons up to the start/stop handlers.
start_btn.on_click(on_start)
stop_btn.on_click(on_stop)

# Live view and read-outs on top, START/STOP buttons underneath.
display(widgets.VBox([image_widget, steering_slider, fps_widget, status_label]))
display(widgets.HBox([start_btn, stop_btn]))

In [None]:
# Cleanup after inference
# Clear the loop flag before stopping the hardware so the background loop
# exits at its next `while running` check, then stop the motors and camera.
running = False
robot.stop()
camera.stop()
print("Stopped.")

In [None]:
# =============================================================================
# CELL 5: DAgger (Dataset Aggregation)
# =============================================================================
# Model drives. Hold RB to take over and save corrections.

import os
import time
import threading
import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import ipywidgets as widgets
from IPython.display import display
from jetbot import Camera, Robot, bgr8_to_jpeg

# Create DAgger directory
# Count the .jpg files already present so the correction counter continues
# from there instead of restarting at zero.
os.makedirs(DAGGER_DIR, exist_ok=True)
existing = len([f for f in os.listdir(DAGGER_DIR) if f.endswith('.jpg')])
print(f"DAgger dir: {DAGGER_DIR} ({existing} existing corrections)")

# Model
def get_model():
    """Build a ResNet-18 without pretrained weights and a one-output head.

    Returns:
        The network with its final fully connected layer replaced by a
        ``Linear`` layer producing a single value.
    """
    net = models.resnet18(pretrained=False)
    head_inputs = net.fc.in_features
    net.fc = nn.Linear(head_inputs, 1)
    return net

# Preprocessing
def preprocess_for_save(img):
    """Trim the configured margins off ``img`` and resize it to ``INPUT_SIZE``.

    Args:
        img: Image array whose first two dimensions are (height, width).

    Returns:
        The cropped region resized with ``cv2.resize`` to ``INPUT_SIZE``.
    """
    height, width = img.shape[:2]

    # Each CROP_* constant is a fraction of the frame removed from that edge.
    row_start = int(height * CROP_TOP)
    row_end = int(height * (1 - CROP_BOTTOM))
    col_start = int(width * CROP_LEFT)
    col_end = int(width * (1 - CROP_RIGHT))

    region = img[row_start:row_end, col_start:col_end]
    return cv2.resize(region, INPUT_SIZE)

def preprocess_for_inference(img):
    """Turn a camera frame into a normalized model input plus a display image.

    Args:
        img: BGR image array whose first two dimensions are (height, width).

    Returns:
        A ``(tensor, cropped)`` pair: ``tensor`` is the float32 input of shape
        (1, 3, H, W) normalized with ``IMAGENET_MEAN``/``IMAGENET_STD``, and
        ``cropped`` is the output of ``preprocess_for_save(img)``.
    """
    cropped = preprocess_for_save(img)

    # BGR -> RGB, scaled to [0, 1] as float32.
    rgb = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    # Normalize all channels in one broadcast; the float32 constants keep the
    # result float32, matching the per-channel arithmetic on a float32 array.
    mean = np.asarray(IMAGENET_MEAN, dtype=np.float32)
    std = np.asarray(IMAGENET_STD, dtype=np.float32)
    normalized = (rgb - mean) / std

    # HWC -> CHW, then add the batch dimension.
    chw = normalized.transpose(2, 0, 1)
    tensor = torch.from_numpy(chw).unsqueeze(0)
    return tensor, cropped

# Load model
# Weights are mapped onto DEVICE while loading; eval() switches the network to
# inference mode.
print(f"Loading {MODEL_PATH}...")
model = get_model()
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE).eval()
print("Model loaded.")

# Hardware
# Camera at the resolution configured in Cell 1, plus the motor interface.
camera = Camera.instance(width=CAMERA_WIDTH, height=CAMERA_HEIGHT)
robot = Robot()

# Widgets
# The two sliders and the counter are disabled: they only mirror the model's
# prediction, the human's stick input and the number of saved corrections.
image_widget = widgets.Image(format='jpeg', width=224, height=224)
model_slider = widgets.FloatSlider(value=0, min=-1, max=1, description='Model:', disabled=True)
human_slider = widgets.FloatSlider(value=0, min=-1, max=1, description='Human:', disabled=True)
count_widget = widgets.IntText(value=existing, description='Corrections:', disabled=True)
status_label = widgets.Label(value='Ready')
start_btn = widgets.Button(description='START', button_style='success')
stop_btn = widgets.Button(description='STOP', button_style='danger')
controller = widgets.Controller()

# Loop state: `running` gates dagger_loop; `correction_count` continues from
# the number of images already found in DAGGER_DIR.
running = False
correction_count = existing

def dagger_loop():
    """Let the model drive and record human corrections while RB is held.

    Each iteration predicts a steering value from the current frame and reads
    the joystick. While RB is held the human's steering drives the robot and
    the frame is saved as ``<timestamp_ms>_<human_steer>.jpg`` in
    ``DAGGER_DIR``; otherwise the model's clamped prediction drives. The loop
    is paced to ``LOOP_SLEEP`` seconds per iteration.

    If anything inside the loop raises, the motors are stopped and ``running``
    is reset before the exception propagates, so the robot is not left driving
    on its last command and START can be pressed again.
    """
    global running, correction_count

    try:
        while running:
            t0 = time.time()

            raw = camera.value
            tensor, display_img = preprocess_for_inference(raw)

            # Model prediction
            with torch.no_grad():
                model_steer = model(tensor.to(DEVICE)).item()
            model_steer = max(-1, min(1, model_steer))

            # Human input. Indexing fails when the requested axis/button is
            # not present (presumably the case until a gamepad is connected);
            # treat that as "no input" so the model keeps driving.
            try:
                human_steer = controller.axes[AXIS_STEERING].value
                rb = controller.buttons[RB_BUTTON].value > 0.5
            except (IndexError, AttributeError, TypeError):
                human_steer, rb = 0.0, False

            if abs(human_steer) < DEADZONE:
                human_steer = 0.0

            # Who controls?
            if rb:
                active_steer = human_steer
                # Save correction. display_img is the cropped, resized BGR
                # frame returned alongside the tensor, so it is saved
                # directly instead of cropping and resizing `raw` again.
                filename = f"{int(time.time()*1000)}_{human_steer:.3f}.jpg"
                saved = cv2.imwrite(os.path.join(DAGGER_DIR, filename), display_img)
                if saved:
                    correction_count += 1
                    status_label.value = (
                        f'CORRECTING - RB held ({correction_count})'
                    )
                else:
                    # Only count corrections that actually reached the disk.
                    status_label.value = (
                        f'WRITE FAILED: {filename} ({correction_count})'
                    )
            else:
                active_steer = model_steer
                status_label.value = (
                    f'MODEL DRIVING ({correction_count} corrections)'
                )

            # Drive
            left = FORWARD_SPEED + active_steer * STEERING_GAIN
            right = FORWARD_SPEED - active_steer * STEERING_GAIN
            robot.left_motor.value = max(-1, min(1, left))
            robot.right_motor.value = max(-1, min(1, right))

            # UI
            image_widget.value = bgr8_to_jpeg(display_img)
            model_slider.value = model_steer
            human_slider.value = human_steer
            count_widget.value = correction_count

            # Sleep only for what is left of this iteration's time budget.
            time.sleep(max(0, LOOP_SLEEP - (time.time() - t0)))
    except Exception:
        # A crash must not leave the motors at their last commanded speed,
        # and `running` must be cleared or on_start would refuse to restart.
        running = False
        robot.stop()
        status_label.value = (
            f'ERROR - loop crashed ({correction_count} corrections)'
        )
        raise

    robot.stop()
    status_label.value = f'STOPPED ({correction_count} corrections)'

def on_start(b):
    """Start the DAgger loop in a daemon thread unless it is already running.

    Args:
        b: The button that was clicked; unused.
    """
    global running
    if running:
        return
    running = True
    worker = threading.Thread(target=dagger_loop, daemon=True)
    worker.start()

def on_stop(b):
    """Ask the DAgger loop to exit and stop the motors right away.

    Args:
        b: The button that was clicked; unused.
    """
    global running
    # Clear the flag first so the loop exits at its next `while running` check.
    running = False
    robot.stop()

# Hook the buttons up to the start/stop handlers.
start_btn.on_click(on_start)
stop_btn.on_click(on_stop)

# Show the gamepad widget first, then the live view/read-outs, then the buttons.
print("Connect controller. Model drives, hold RB to correct.")
display(controller)
display(widgets.VBox([image_widget, model_slider, human_slider, count_widget, status_label]))
display(widgets.HBox([start_btn, stop_btn]))

In [None]:
# Cleanup after DAgger
# Clear the loop flag before stopping the hardware so the background loop
# exits at its next `while running` check, then stop the motors and camera.
running = False
robot.stop()
camera.stop()
print(f"Done. {correction_count} corrections in {DAGGER_DIR}")
print("Copy DAgger folder to PC, combine with original data, and retrain.")