In [None]:
# Camera stuff
# https://github.com/NVIDIA-AI-IOT/jetcam
from jetcam.usb_camera import USBCamera

# width/height passed to USBCamera (output frame size) ...
IMG_W, IMG_H = 224, 224
# ... versus the resolution requested from the capture device itself.
CAPTURE_W, CAPTURE_H = 640, 480
# Index of the capture device handed to jetcam.
CAPTURE_DEVICE = 0

camera = USBCamera(
    width=IMG_W,
    height=IMG_H,
    capture_width=CAPTURE_W,
    capture_height=CAPTURE_H,
    capture_device=CAPTURE_DEVICE,
)
# Setting `running` enables callback mode (new frames are published on camera.value).
camera.running = True
print("Camera working")

In [None]:
# Prepping some variables for training
# Need dataset.py
import torchvision.transforms as transforms
from dataset import ImageClassificationDataset

# Name of the classification task; used as the prefix of every dataset folder.
TASK = "dog_breed"
# Class labels shared by all datasets.
CATEGORIES = ["golden retreiver", "chow chow", "sheepdoggo"]
# Keys of the independent datasets that can be collected into.
DATASETS = ["A", "B", "C"]

# Transform pipeline handed to every dataset: colour jitter, resize to
# 224x224, tensor conversion, then per-channel normalisation.
# NOTE(review): the mean/std values look like the usual ImageNet statistics — confirm.
TRANSFORMS = transforms.Compose(
    [
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# One dataset object per key, each rooted in "<TASK>_<key>" under the data folder.
dataset_dict = {
    name: ImageClassificationDataset(
        "../data/classification/" + TASK + "_" + name, CATEGORIES, TRANSFORMS
    )
    for name in DATASETS
}

print("{} task with {} categories defined.".format(TASK, CATEGORIES))

# Set up the data directory if it does not exist yet.
# `!mkdir -p` is an IPython shell escape (not plain Python); {DATA_DIR} is
# expanded from the Python variable before the command runs.
# NOTE(review): the dataset folders above are built from the relative path
# "../data/classification/" rather than from DATA_DIR — confirm both resolve
# to the same directory.
DATA_DIR = "/nvdli-nano/data/classification/"
!mkdir -p {DATA_DIR}

In [None]:
# Data collection widget
import ipywidgets
import traitlets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg

# Dataset currently being collected into; starts as the first entry of DATASETS.
working_dataset = dataset_dict[DATASETS[0]]

# Remove every observer currently attached to the camera, then mirror each
# new frame into an Image widget, converting it to JPEG on the way.
camera.unobserve_all()
camera_widget = ipywidgets.Image()
traitlets.dlink((camera, "value"), (camera_widget, "value"), transform=bgr8_to_jpeg)

# Controls: dataset selector, category selector, sample counter and "add" button.
# (The dataset dropdown previously used the misspelled keyword `descriptions`,
# which is not the `description` trait and therefore never labelled the widget.)
dataset_widget = ipywidgets.Dropdown(options=DATASETS, description="dataset")
category_widget = ipywidgets.Dropdown(options=working_dataset.categories, description="category")
count_widget = ipywidgets.IntText(description="count")
save_widget = ipywidgets.Button(description="add")

# Show the number of samples already stored for the initially selected category.
count_widget.value = working_dataset.get_count(category_widget.value)

def set_dataset(p_dataset):
    """Make the dataset chosen in the dropdown the working dataset.

    p_dataset: change dictionary delivered by ``observe``; its "new" entry is
               the key into ``dataset_dict``.

    Also refreshes ``count_widget`` with the count of the currently selected
    category in the newly chosen dataset.
    """
    global working_dataset
    selected_key = p_dataset["new"]
    working_dataset = dataset_dict[selected_key]
    refreshed_count = working_dataset.get_count(category_widget.value)
    count_widget.value = refreshed_count

def update_count(p_count):
    """Refresh ``count_widget`` from the working dataset.

    p_count: change dictionary delivered by ``observe``; its "new" entry is
             passed to ``working_dataset.get_count`` and the result is shown
             in ``count_widget``.
    """
    new_selection = p_count["new"]
    count_widget.value = working_dataset.get_count(new_selection)

def save(img):
    """Store the current camera frame under the selected category.

    img: unused; ``on_click`` passes the clicked button as the only argument.

    The frame is read from ``camera.value`` and saved into the working dataset
    under ``category_widget.value``; ``count_widget`` is then refreshed for
    that category. (The widget name was previously misspelled in both
    statements, so every click raised NameError.)
    """
    selected_category = category_widget.value
    working_dataset.save_entry(camera.value, selected_category)
    count_widget.value = working_dataset.get_count(selected_category)

# Switching dataset swaps the working dataset and refreshes the count.
dataset_widget.observe(set_dataset, names = "value")
# update_count looks up the count for the value it receives, so it must watch
# the *category* dropdown. It was previously attached to count_widget, which
# fed the count itself back in as if it were a category name.
category_widget.observe(update_count, names = "value")
# The "add" button stores the current frame.
save_widget.on_click(save)

# Vertical stack: camera preview on top, then the dataset/category selectors,
# the sample counter and the "add" button.
data_collection_widget = ipywidgets.VBox(
    children=[
        ipywidgets.HBox(children=[camera_widget]),
        dataset_widget,
        category_widget,
        count_widget,
        save_widget,
    ]
)

print("data collection widget created.")

In [None]:
# Model selection stuff
import torch
import torchvision
from datetime import datetime

# All tensors and the model live on the GPU.
device = torch.device("cuda")

# RESNET 18
model = torchvision.models.resnet18(pretrained = True)
# model = torchvision.models.resnet34(pretrained = True) for RESNET 34
# Replace the final fully connected layer with one output per category.
# The input width is read from the existing head instead of hard-coding 512,
# so swapping in a backbone with a different feature width keeps working.
model.fc = torch.nn.Linear(model.fc.in_features, len(working_dataset.categories))

# ALEXNET
# model = torchvision.models.alexnet(pretrained = True)
# model.classifier[-1] = torch.nn.Linear(4096, len(working_dataset.categories))

# SQUEEZENET
# model = torchvision.models.squeezenet1_1(pretrained = True)
# model.classifier[1] = torch.nn.Conv2d(512, len(working_dataset.categories), kernel_size = 1)
# model.num_classes = len(working_dataset.categories)

model = model.to(device)
# Buttons that persist / restore the model weights.
model_save_button = ipywidgets.Button(description="save model")
model_load_button = ipywidgets.Button(description="load model")

# Default checkpoint path, stamped with the time this cell was run.
model_path = "".join(
    [
        "/nvdli-nano/data/classification/model-",
        datetime.now().strftime("%d-%m-%Y-%H:%M:%S"),
        ".pth",
    ]
)
# Editable text box holding the path used by the load/save callbacks.
model_path_widget = ipywidgets.Text(description="model path", value=model_path)

def load_model_cb(param):
    """Load weights into ``model`` from the path shown in ``model_path_widget``.

    param: unused; ``on_click`` passes the clicked button.
    """
    checkpoint_path = model_path_widget.value
    state_dict = torch.load(checkpoint_path)
    model.load_state_dict(state_dict)

def save_model_cb(param):
    """Write the weights of ``model`` to the path shown in ``model_path_widget``.

    param: unused; ``on_click`` passes the clicked button.
    """
    destination = model_path_widget.value
    weights = model.state_dict()
    torch.save(weights, destination)

# Hook the buttons up to their callbacks.
model_load_button.on_click(load_model_cb)
model_save_button.on_click(save_model_cb)

# Path box on top, load/save buttons side by side underneath.
model_widget = ipywidgets.VBox(
    children=[
        model_path_widget,
        ipywidgets.HBox(children=[model_load_button, model_save_button]),
    ]
)

print("Model configured and model widget created.")
print(model_path)

In [None]:
# Live demo widget setup
import threading
import time
from utils import preprocess 
import torch.nn.functional as F

# Toggle between idle ("stop") and continuous prediction ("live"); starts idle.
state_widget = ipywidgets.ToggleButtons(
    options=["stop", "live"],
    description="state",
    value="stop",
)
# Text box showing the name of the predicted category.
prediction_widget = ipywidgets.Text(description="prediction")

# One vertical 0..1 slider per category, in the dataset's category order.
score_widget_list = []
for category in working_dataset.categories:
    score_widget = ipywidgets.FloatSlider(
        min=0.0,
        max=1.0,
        description=category,
        orientation="vertical",
    )
    score_widget_list += [score_widget]

def live_cb(p_state_widget, p_model, p_camera, p_prediction_widget, p_score_widget):
    """Classify camera frames continuously while the state widget reads "live".

    p_state_widget:      widget whose ``value`` is polled before every frame;
                         the loop ends as soon as it is no longer "live".
    p_model:             callable applied to the preprocessed frame.
    p_camera:            object whose ``value`` is the current frame.
    p_prediction_widget: widget that receives the predicted category name.
    p_score_widget:      unused; scores are written to the module-level
                         ``score_widget_list``.
    """
    global working_dataset
    while p_state_widget.value == "live":
        frame = p_camera.value
        logits = p_model(preprocess(frame))
        # Softmax turns the raw model output into per-category probabilities.
        probabilities = F.softmax(logits, dim = 1).detach().cpu().numpy().flatten()
        best_index = probabilities.argmax()
        p_prediction_widget.value = working_dataset.categories[best_index]
        # Mirror each probability onto the slider with the same index.
        for position in range(len(probabilities)):
            score_widget_list[position].value = probabilities[position]
            
def start_live_cb(state):
    """Start the live-prediction loop in a background thread.

    state: change dictionary delivered by ``observe``; a thread is started
           only when its "new" entry is "live".

    The thread runs ``live_cb`` (the target previously named a non-existent
    ``live`` function, so switching to "live" raised NameError) and receives
    the full ``score_widget_list`` rather than a single slider.
    """
    if state["new"] != "live":
        return
    execute_thread = threading.Thread(
        target = live_cb,
        args = (state_widget, model, camera, prediction_widget, score_widget_list),
    )
    execute_thread.start()
    
# Flipping the toggle notifies start_live_cb.
state_widget.observe(start_live_cb, names="value")

# Score sliders in a row, then the prediction text, then the stop/live toggle.
live_execution_widget = ipywidgets.VBox(
    children=[
        ipywidgets.HBox(children=score_widget_list),
        prediction_widget,
        state_widget,
    ]
)

print("live_execution_widget created!")

In [None]:
# Training and Evaluation
# Number of samples per batch. DataLoader rejects a batch size of 0, so this
# must be a positive integer.
BATCH_SIZE = 8
optimizer = torch.optim.Adam(model.parameters())
# optimizer = torch.optim.SGD(model.parameters(), lr = 1e-3, momentum = 0.9)

# Controls and read-outs for training / evaluation.
epochs_widget = ipywidgets.IntText(description = "epochs", value = 1)
eval_button = ipywidgets.Button(description = "evaluate")
train_button = ipywidgets.Button(description = "train")
loss_widget = ipywidgets.FloatText(description = "loss")
accuracy_widget = ipywidgets.FloatText(description = "accuracy")
progress_widget = ipywidgets.FloatProgress(min = 0.0, max = 1.0, description = "progress")

def train_eval_button_cb(is_training):
    """Train the model, or evaluate it once, on the working dataset.

    is_training: True  -> train for ``epochs_widget.value`` epochs; the widget
                          is decremented after every completed epoch.
                 False -> run a single pass in eval mode without updating
                          parameters or tracking gradients.

    Nothing is iterated if ``epochs_widget.value`` is not positive. The
    progress, loss and accuracy widgets are updated after every batch. The
    train/evaluate buttons are disabled while the pass runs. Afterwards, even
    if the pass failed, the model is put back in eval mode, the buttons are
    re-enabled and the state widget is set to "live". Failures are printed
    instead of being silently discarded.
    """
    global BATCH_SIZE, LEARNING_RATE, MOMENTUM, model, working_dataset, optimizer
    global eval_button, train_button, accuracy_widget, loss_widget, progress_widget, state_widget

    try:
        train_loader = torch.utils.data.DataLoader(
            working_dataset,
            batch_size = BATCH_SIZE,
            shuffle = True
        )

        # Leave "live" mode and lock the buttons for the duration of the pass.
        state_widget.value = "stop"
        train_button.disabled = eval_button.disabled = True
        # NOTE(review): presumably gives a running live-prediction loop time to stop.
        time.sleep(1)

        model = model.train() if is_training else model.eval()

        while epochs_widget.value > 0:
            i = 0                # samples processed so far in this epoch
            loss_sum = 0.0       # running sum of the per-batch loss values
            error_count = 0      # misclassified samples so far in this epoch
            for images, labels in train_loader:
                images = images.to(device)
                labels = labels.to(device)

                if is_training:
                    optimizer.zero_grad() # zero gradients of parameters

                # Gradients are only needed when the parameters will be updated.
                with torch.set_grad_enabled(bool(is_training)):
                    output_list = model(images)
                    loss = F.cross_entropy(output_list, labels)

                if is_training:
                    loss.backward() # backpropagation
                    optimizer.step() # adjust parameters with optimizer

                # A non-zero difference between predicted and true index is an error.
                error_count += len(torch.nonzero(output_list.argmax(1) - labels).flatten())
                count = len(labels.flatten())
                i += count
                loss_sum += float(loss)

                # update progress
                progress_widget.value = i / len(working_dataset)
                loss_widget.value = loss_sum / i
                accuracy_widget.value = 1.0 - error_count / i

            if is_training:
                epochs_widget.value = epochs_widget.value - 1
            else:
                break
    except Exception as error:
        # Surface the failure rather than hiding it, then fall through to the
        # clean-up below so the UI does not stay locked.
        print("train/eval failed: {!r}".format(error))

    model = model.eval()
    train_button.disabled = eval_button.disabled = False
    state_widget.value = "live"

# Both buttons share one callback; the flag selects training vs. evaluation.
train_button.on_click(lambda button: train_eval_button_cb(is_training=True))
eval_button.on_click(lambda button: train_eval_button_cb(is_training=False))

# Epoch count, progress bar, loss and accuracy read-outs, then the two buttons.
train_eval_widget = ipywidgets.VBox(
    children=[
        epochs_widget,
        progress_widget,
        loss_widget,
        accuracy_widget,
        ipywidgets.HBox(children=[train_button, eval_button]),
    ]
)

print("Trainer configured and training widget created!")

In [None]:
# Full UI: data collection next to live prediction, with the training and
# model save/load panels stacked underneath.
all_widget = ipywidgets.VBox(
    children=[
        ipywidgets.HBox(children=[data_collection_widget, live_execution_widget]),
        train_eval_widget,
        model_widget,
    ]
)

display(all_widget)

In [None]:
# Disconnect the kernel and the camera. The USB camera simply disconnects
# WITH the kernel, so there is nothing else to release here.
# NOTE: os._exit() ends the process immediately, without running Python's
# normal clean-up handlers; every variable in this notebook is lost.
import os
import IPython
os._exit(00)  # exit status 0