In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from xy_dataset import XYDataset
import torch.utils.data
from torch.utils.data import DataLoader
import cv2
import ipywidgets
import traitlets
from IPython.display import display
from utils import preprocess
import threading
import time
from utils import preprocess
import torch.nn.functional as F

In [2]:
CATEGORIES = ['apex']  # one (x, y) regression target per category

# Image pipeline applied to every sample the dataset yields:
#   1. colour jitter (random augmentation),
#   2. resize to 224x224,
#   3. tensor conversion,
#   4. normalisation with the standard ImageNet per-channel mean/std
#      (the statistics torchvision's pretrained models use).
# NOTE(review): the same dataset (and therefore the random jitter) is also
# used when evaluating in train_eval — confirm that is intended.
TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Project dataset; train_eval unpacks each batch as (images, category_idx, xy).
# random_hflip presumably mirrors images (and their x labels) at random — verify in xy_dataset.
dataset = XYDataset(
    "data_obstacle_detection/road_following_D",
    CATEGORIES,
    TRANSFORMS,
    random_hflip=True,
)

In [3]:
# Compute device for the model and batches; use torch.device('cuda') to run on a GPU.
device = torch.device('cpu')

# Two outputs (x, y) for every category in the dataset.
output_dim = 2 * len(dataset.categories)

# ResNet-18 with pretrained weights; its final fully connected layer is swapped
# for a regression head producing output_dim values (fc.in_features is 512 here).
model = torchvision.models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, output_dim)
model = model.to(device)

# Widgets for choosing a checkpoint path and saving/loading the state_dict.
model_path_widget = ipywidgets.Text(value='road_following_model_resnet18.pth', description='model path')
model_save_button = ipywidgets.Button(description='save model')
model_load_button = ipywidgets.Button(description='load model')

def load_model(c):
    """Button callback: load weights from the path in ``model_path_widget`` into ``model``.

    Args:
        c: the clicked ipywidgets Button (unused).

    ``map_location=device`` remaps every stored tensor onto the current device,
    so a checkpoint saved from a CUDA run can still be loaded on this CPU setup
    (without it, ``torch.load`` tries to restore tensors onto their original device).
    """
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints;
    # on torch >= 1.13 consider passing weights_only=True.
    state_dict = torch.load(model_path_widget.value, map_location=device)
    model.load_state_dict(state_dict)
model_load_button.on_click(load_model)
    
def save_model(c):
    """Button callback: write ``model.state_dict()`` to the path in ``model_path_widget``.

    Args:
        c: the clicked ipywidgets Button (unused).
    """
    checkpoint_path = model_path_widget.value
    torch.save(model.state_dict(), checkpoint_path)
model_save_button.on_click(save_model)

# Layout: path text box on top, load/save buttons side by side underneath.
model_widget = ipywidgets.VBox(children=[
    model_path_widget,
    ipywidgets.HBox(children=[model_load_button, model_save_button]),
])

display(model_widget)



VBox(children=(Text(value='road_following_model_resnet18.pth', description='model path'), HBox(children=(Butto…

In [4]:
BATCH_SIZE = 8

# Adam with its default hyper-parameters over every model parameter.
optimizer = torch.optim.Adam(model.parameters())

# Controls and read-outs for the train/evaluate loop.
epochs_widget = ipywidgets.IntText(value=1, description='epochs')
eval_button = ipywidgets.Button(description='evaluate')
train_button = ipywidgets.Button(description='train')
loss_widget = ipywidgets.FloatText(description='loss')
progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='progress')
state_widget = ipywidgets.ToggleButtons(
    options=['stop', 'live'],
    value='stop',
    description='state',
)

def _run_epoch(loader, is_training):
    """Run one pass over ``loader``, updating ``progress_widget`` and ``loss_widget``.

    Reads the notebook-level ``model``, ``optimizer``, ``device``,
    ``progress_widget`` and ``loss_widget``. When ``is_training`` is true, the
    model parameters are updated after every batch. Gradients are only tracked
    in training mode.

    Returns:
        The mean per-sample loss over the pass (0.0 if the loader is empty).
    """
    total = len(loader.dataset)
    samples_seen = 0
    weighted_loss = 0.0
    with torch.set_grad_enabled(is_training):
        for images, category_idx, xy in loader:
            images = images.to(device)
            xy = xy.to(device)

            if is_training:
                optimizer.zero_grad()

            outputs = model(images)

            # MSE over the (x, y) pair belonging to each sample's category:
            # category c owns output columns 2c and 2c+1.
            cats = category_idx.flatten().tolist()
            loss = 0.0
            for b, c in enumerate(cats):
                loss = loss + torch.mean((outputs[b][2 * c:2 * c + 2] - xy[b]) ** 2)
            count = len(cats)
            loss = loss / count

            if is_training:
                loss.backward()
                optimizer.step()

            # `loss` is a per-batch mean, so weight it by the batch size to get a
            # true per-sample running mean (batches may differ in size).
            samples_seen += count
            weighted_loss += float(loss) * count
            progress_widget.value = samples_seen / total
            loss_widget.value = weighted_loss / samples_seen
    return weighted_loss / samples_seen if samples_seen else 0.0


def train_eval(is_training):
    """Train for ``epochs_widget.value`` epochs, or run a single evaluation pass.

    Args:
        is_training: True to train (decrementing ``epochs_widget`` after each
            epoch until it reaches 0); False to evaluate once over the dataset,
            regardless of the epochs counter.

    While running, ``state_widget`` is set to 'stop' and both buttons are
    disabled. The buttons are re-enabled, the state is set to 'live' and the
    model is left in eval mode even if the loop raises. Any exception still
    propagates.
    """
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True
    )

    state_widget.value = 'stop'
    train_button.disabled = True
    eval_button.disabled = True
    try:
        # NOTE(review): presumably gives a live-inference loop time to notice
        # the 'stop' state before the model is used here — confirm.
        time.sleep(1)

        if is_training:
            model.train()
            while epochs_widget.value > 0:
                _run_epoch(loader, is_training=True)
                epochs_widget.value = epochs_widget.value - 1
        else:
            # Evaluate exactly once; must not depend on the epochs counter,
            # which training leaves at 0.
            model.eval()
            _run_epoch(loader, is_training=False)
    finally:
        model.eval()
        train_button.disabled = False
        eval_button.disabled = False
        state_widget.value = 'live'
    
# Both buttons drive the same loop; only the mode flag differs.
train_button.on_click(lambda _button: train_eval(is_training=True))
eval_button.on_click(lambda _button: train_eval(is_training=False))

# Layout: epochs input, progress bar and loss read-out stacked above the two buttons.
train_eval_widget = ipywidgets.VBox(children=[
    epochs_widget,
    progress_widget,
    loss_widget,
    ipywidgets.HBox(children=[train_button, eval_button]),
])

display(train_eval_widget)

VBox(children=(IntText(value=1, description='epochs'), FloatProgress(value=0.0, description='progress', max=1.…

In [5]:
# Number of samples in the dataset (also the denominator of the progress bar in train_eval).
len(dataset)

780