### Run this code once

In [None]:
%pip install torchvision

In [None]:
%pip install ipywidgets

In [None]:
%pip install opencv-python

In [None]:
%pip install jupyterlab_widgets

### Task

In [2]:
import ipywidgets as widgets
from IPython.display import display
widgets.IntSlider()

IntSlider(value=0)

In [3]:
import os
import cv2
import torch
import torchvision
import torchvision.transforms as transforms
import ipywidgets
from IPython.display import display
from PIL import Image
import numpy as np
from xy_dataset import XYDataset
import time


TASK = 'road_following'

CATEGORIES = ['apex']

DATASETS = ['A', 'B', 'C', 'D', 'E', 'F', 'G']



# Configuration
TASK = 'road_following'
CATEGORIES = ['apex']
DATASETS = ['A', 'B', 'C', 'D', 'E', 'F', 'G']





### Data Collection

In [4]:
IMAGE_DIRECTORY = 'csubjetracer/notebooks/modifiedtraining/road_following_G/apex'  # Update this path to your image directory

# Define transforms
TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


# Initialize datasets
datasets = {}
for name in DATASETS:
    datasets[name] = XYDataset(TASK + '_' + name, CATEGORIES, TRANSFORMS, random_hflip=True)
dataset = datasets[DATASETS[0]]

# Create widgets for dataset selection
dataset_widget = ipywidgets.Dropdown(options=DATASETS, description='dataset')
info_widget = ipywidgets.HTML(value='<b>Dataset Information</b>')

def update_info():
    """Update information display about the current dataset"""
    total_count = 0
    info_text = f'<b>Dataset: {dataset_widget.value}</b><br>'
    for category in dataset.categories:
        count = dataset.get_count(category)
        total_count += count
        info_text += f'{category}: {count} images<br>'
    info_text += f'<b>Total: {total_count} images</b>'
    info_widget.value = info_text

def set_dataset(change):
    """Sets the active dataset"""
    global dataset
    dataset = datasets[change['new']]
    update_info()

dataset_widget.observe(set_dataset, names='value')
update_info()

# Display dataset selection widget
dataset_selection_widget = ipywidgets.VBox([
    dataset_widget,
    info_widget
])
display(dataset_selection_widget)

# Load images from directory (if you want to visualize existing images)
def load_images_from_directory(directory_path, dataset_name):
    """Load existing images from directory into dataset"""
    dataset_path = os.path.join(directory_path, f'{TASK}_{dataset_name}')
    
    if os.path.exists(dataset_path):
        for category in CATEGORIES:
            category_path = os.path.join(dataset_path, category)
            if os.path.exists(category_path):
                image_files = [f for f in os.listdir(category_path) if f.endswith('.jpg')]
                print(f"Found {len(image_files)} images in {category_path}")
                
                # Note: This assumes your images are already saved with annotations
                # If you need to load the annotations, you'll need to parse the filenames
                # or load from a separate annotation file

# Optional: Load existing images
load_button = ipywidgets.Button(description='Load Existing Images')
status_widget = ipywidgets.HTML(value='')

def on_load_click(b):
    status_widget.value = '<i>Loading images...</i>'
    try:
        for dataset_name in DATASETS:
            load_images_from_directory(IMAGE_DIRECTORY, dataset_name)
        status_widget.value = '<b style="color:green">Images loaded successfully!</b>'
    except Exception as e:
        status_widget.value = f'<b style="color:red">Error: {str(e)}</b>'

load_button.on_click(on_load_click)

load_widget = ipywidgets.VBox([
    ipywidgets.HTML(value=f'<b>Image Directory:</b> {IMAGE_DIRECTORY}'),
    load_button,
    status_widget
])
display(load_widget)

VBox(children=(Dropdown(description='dataset', options=('A', 'B', 'C', 'D', 'E', 'F', 'G'), value='A'), HTML(v…

VBox(children=(HTML(value='<b>Image Directory:</b> csubjetracer/notebooks/modifiedtraining/road_following_G/ap…

### Model

In [5]:
import torch

# Model setup
device = torch.device('cuda')
output_dim = 2 * len(dataset.categories)  # x, y coordinate for each category

# RESNET 18 (you can uncomment other models as needed)
model = torchvision.models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(512, output_dim)

# Alternative models (uncomment to use):
# ALEXNET
# model = torchvision.models.alexnet(pretrained=True)
# model.classifier[-1] = torch.nn.Linear(4096, output_dim)

# SQUEEZENET 
# model = torchvision.models.squeezenet1_1(pretrained=True)
# model.classifier[1] = torch.nn.Conv2d(512, output_dim, kernel_size=1)

# RESNET 34
# model = torchvision.models.resnet34(pretrained=True)
# model.fc = torch.nn.Linear(512, output_dim)

# DENSENET 121
# model = torchvision.models.densenet121(pretrained=True)
# model.classifier = torch.nn.Linear(model.num_features, output_dim)

model = model.to(device)

# Model save/load widgets
model_save_button = ipywidgets.Button(description='save model')
model_load_button = ipywidgets.Button(description='load model')
model_path_widget = ipywidgets.Text(description='model path', value='road_following_model.pth')

def load_model(c):
    try:
        model.load_state_dict(torch.load(model_path_widget.value))
        print(f"Model loaded from {model_path_widget.value}")
    except Exception as e:
        print(f"Error loading model: {e}")

model_load_button.on_click(load_model)
    
def save_model(c):
    try:
        torch.save(model.state_dict(), model_path_widget.value)
        print(f"Model saved to {model_path_widget.value}")
    except Exception as e:
        print(f"Error saving model: {e}")

model_save_button.on_click(save_model)

model_widget = ipywidgets.VBox([
    model_path_widget,
    ipywidgets.HBox([model_load_button, model_save_button])
])
display(model_widget)



VBox(children=(Text(value='road_following_model.pth', description='model path'), HBox(children=(Button(descrip…

### Live Execution

In [6]:
import torch

# Training setup
BATCH_SIZE = 8
optimizer = torch.optim.Adam(model.parameters())
# Alternative optimizer:
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# Training widgets
epochs_widget = ipywidgets.IntText(description='epochs', value=1)
eval_button = ipywidgets.Button(description='evaluate')
train_button = ipywidgets.Button(description='train')
loss_widget = ipywidgets.FloatText(description='loss')
progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='progress')
state_widget = ipywidgets.Text(description='state', value='ready')

def train_eval(is_training):
    global BATCH_SIZE, model, dataset, optimizer
    
    try:
        # Create data loader
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True
        )
        
        state_widget.value = 'running'
        train_button.disabled = True
        eval_button.disabled = True
        
        if is_training:
            model = model.train()
        else:
            model = model.eval()
            
        while epochs_widget.value > 0:
            i = 0
            sum_loss = 0.0
            error_count = 0.0
            
            for images, category_idx, xy in iter(train_loader):
                # Send data to device
                images = images.to(device)
                xy = xy.to(device)
                
                if is_training:
                    # Zero gradients
                    optimizer.zero_grad()
                
                # Forward pass
                outputs = model(images)
                
                # Compute MSE loss over x, y coordinates for associated categories
                loss = 0.0
                for batch_idx, cat_idx in enumerate(list(category_idx.flatten())):
                    loss += torch.mean((outputs[batch_idx][2 * cat_idx:2 * cat_idx+2] - xy[batch_idx])**2)
                loss /= len(category_idx)
                
                if is_training:
                    # Backward pass and optimization
                    loss.backward()
                    optimizer.step()
                
                # Update progress
                count = len(category_idx.flatten())
                i += count
                sum_loss += float(loss)
                progress_widget.value = i / len(dataset)
                loss_widget.value = sum_loss / i
                
            if is_training:
                epochs_widget.value = epochs_widget.value - 1
                print(f"Epoch completed. Loss: {sum_loss / i:.4f}")
            else:
                break
                
    except Exception as e:
        print(f"Error during training/evaluation: {e}")
        state_widget.value = 'error'
    
    model = model.eval()
    train_button.disabled = False
    eval_button.disabled = False
    state_widget.value = 'ready'
    
train_button.on_click(lambda c: train_eval(is_training=True))
eval_button.on_click(lambda c: train_eval(is_training=False))
    
train_eval_widget = ipywidgets.VBox([
    epochs_widget,
    progress_widget,
    loss_widget,
    state_widget,
    ipywidgets.HBox([train_button, eval_button])
])
display(train_eval_widget)

VBox(children=(IntText(value=1, description='epochs'), FloatProgress(value=0.0, description='progress', max=1.…

In [7]:
import torch

BATCH_SIZE = 8

optimizer = torch.optim.Adam(model.parameters())
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

epochs_widget = ipywidgets.IntText(description='epochs', value=1)
eval_button = ipywidgets.Button(description='evaluate')
train_button = ipywidgets.Button(description='train')
loss_widget = ipywidgets.FloatText(description='loss')
progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='progress')

def train_eval(is_training):
    global BATCH_SIZE, LEARNING_RATE, MOMENTUM, model, dataset, optimizer, eval_button, train_button, accuracy_widget, loss_widget, progress_widget, state_widget
    
    try:
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True
        )

        state_widget.value = 'stop'
        train_button.disabled = True
        eval_button.disabled = True
        time.sleep(1)

        if is_training:
            model = model.train()
        else:
            model = model.eval()

        while epochs_widget.value > 0:
            i = 0
            sum_loss = 0.0
            error_count = 0.0
            for images, category_idx, xy in iter(train_loader):
                # send data to device
                images = images.to(device)
                xy = xy.to(device)

                if is_training:
                    # zero gradients of parameters
                    optimizer.zero_grad()

                # execute model to get outputs
                outputs = model(images)

                # compute MSE loss over x, y coordinates for associated categories
                loss = 0.0
                for batch_idx, cat_idx in enumerate(list(category_idx.flatten())):
                    loss += torch.mean((outputs[batch_idx][2 * cat_idx:2 * cat_idx+2] - xy[batch_idx])**2)
                loss /= len(category_idx)

                if is_training:
                    # run backpropogation to accumulate gradients
                    loss.backward()

                    # step optimizer to adjust parameters
                    optimizer.step()

                # increment progress
                count = len(category_idx.flatten())
                i += count
                sum_loss += float(loss)
                progress_widget.value = i / len(dataset)
                loss_widget.value = sum_loss / i
                
            if is_training:
                epochs_widget.value = epochs_widget.value - 1
            else:
                break
    except e:
        pass
    model = model.eval()

    train_button.disabled = False
    eval_button.disabled = False
    state_widget.value = 'live'
    
train_button.on_click(lambda c: train_eval(is_training=True))
eval_button.on_click(lambda c: train_eval(is_training=False))
    
train_eval_widget = ipywidgets.VBox([
    epochs_widget,
    progress_widget,
    loss_widget,
    ipywidgets.HBox([train_button, eval_button])
,
model_widget
])
display(train_eval_widget)


VBox(children=(IntText(value=1, description='epochs'), FloatProgress(value=0.0, description='progress', max=1.…

### All together!

The following widget can be used to label a multi-class x, y dataset.  It supports labeling only one instance of each class per image (ie: only one dog), but multiple classes (ie: dog, cat, horse) per image are possible.

Click the image on the top left to save an image of ``category`` to ``dataset`` at the clicked location.

| Widget | Description |
|--------|-------------|
| dataset | Selects the active dataset |
| category | Selects the active category |
| epochs | Sets the number of epochs to train for |
| train | Trains on the active dataset for the number of epochs specified |
| evaluate | Evaluates the accuracy on the active dataset over one epoch |
| model path | Sets the active model path |
| load | Loads a model from the active model path |
| save | Saves a model to the active model path |
| stop | Disables the live demo |
| live | Enables the live demo |

In [None]:
summary_widget = ipywidgets.HTML(value=f'''
<h3>Training Configuration Summary</h3>
<ul>
<li><b>Task:</b> {TASK}</li>
<li><b>Categories:</b> {', '.join(CATEGORIES)}</li>
<li><b>Datasets:</b> {', '.join(DATASETS)}</li>
<li><b>Batch Size:</b> {BATCH_SIZE}</li>
<li><b>Model:</b> ResNet18</li>
<li><b>Device:</b> {device}</li>
<li><b>Image Directory:</b> {IMAGE_DIRECTORY}</li>
<li><b>Optimizer:</b> Adam</li>
</ul>
''')
display(summary_widget)

print("\n" + "="*50)
print("✓ SETUP COMPLETE!")
print("="*50)
print("\nNext steps:")
print("1. Click 'Check Image Directories' to verify your images are found")
print("2. Select a dataset from the dropdown if needed")
print("3. Set the number of epochs for training")
print("4. Click 'train' to start training or 'evaluate' to test")
print("\nNote: Make sure your images are organized as:")
print(f"  {IMAGE_DIRECTORY}/road_following_A/apex/*.jpg")
print(f"  {IMAGE_DIRECTORY}/road_following_B/apex/*.jpg")
print("  etc...")
