In [None]:
# camera stuffs
!ls -ltrh /dev/video*

from jetcam.usb_camera import USBCamera
# Capture settings: requested frame width/height and the capture device index.
CAMERA_W, CAMERA_H = 224, 224
CAPTURE_DEVICE = 0

camera = USBCamera(width=CAMERA_W, height=CAMERA_H, capture_device=CAPTURE_DEVICE)
# NOTE(review): presumably switches the camera to continuous capture so that
# `camera.value` keeps updating — confirm against jetcam.
camera.running = True
print("Camera detected and created!")

In [None]:
# prepping variables for training
import torchvision.transforms as transforms
from dataset import XYDataset

# Task name, the categories to annotate, and the dataset names to create.
TASK = "doggo"
CATEGORIES = ["nose", "tail", "trunk"]
DATASETS = ['A', 'B']

# Transform pipeline handed to every XYDataset: colour jitter, resize to
# 224x224, tensor conversion, then per-channel normalisation.
TRANSFORMS = transforms.Compose(
    [
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# One XYDataset per dataset name, each rooted at "<TASK>_<name>" under the
# regression data directory.
dataset_dict = {}
for name in DATASETS:
    dataset_path = "../data/regression/" + "_".join((TASK, name))
    dataset_dict[name] = XYDataset(dataset_path, CATEGORIES, TRANSFORMS)
print("{} task with {} categories defined.".format(TASK, CATEGORIES))

In [None]:
# Set up data directory if needed
import os

# Directory for regression artifacts; the default model checkpoint path in this
# notebook is built under the same location.
DATA_DIR = "/nvdli-nano/data/regression"
# Pure-Python equivalent of `mkdir -p`: creates missing parent directories and
# does nothing when the directory already exists. Unlike the `!mkdir` shell
# magic, it is valid Python and involves no shell interpolation.
os.makedirs(DATA_DIR, exist_ok=True)

In [None]:
# Data collection stuffs
import cv2
import ipywidgets
import traitlets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg
from jupyter_clickable_image_widget import ClickableImageWidget

# Start with the first dataset; set_dataset_cb rebinds g_dataset when the
# dataset dropdown changes.
g_dataset = dataset_dict[DATASETS[0]]

# Reset callbacks on the camera before linking it to a new widget.
camera.unobserve_all()

# Placeholder image shown in the snapshot pane until the first click.
with open("../images/ready_img.jpg", "rb") as file:
    default_image = file.read()

# Live view (clickable) and snapshot pane, both sized like the camera frame.
camera_widget = ClickableImageWidget(width=camera.width, height=camera.height)
snapshot_widget = ipywidgets.Image(
    value=default_image,
    width=camera.width,
    height=camera.height,
)
# One-way link: every new camera frame is JPEG-encoded into the live view.
traitlets.dlink(
    (camera, "value"),
    (camera_widget, "value"),
    transform=bgr8_to_jpeg,
)

# Selection controls plus a counter for the currently selected category.
dataset_widget = ipywidgets.Dropdown(options=DATASETS, description="dataset")
category_widget = ipywidgets.Dropdown(options=g_dataset.categories, description="category")
count_widget = ipywidgets.IntText(description="count")
count_widget.value = g_dataset.get_count(category_widget.value)

def set_dataset_cb(p_dataset):
    """Switch the active dataset when the dataset dropdown changes.

    Args:
        p_dataset: Change dict from the observer; ``p_dataset["new"]`` is the
            key of the newly selected dataset in ``dataset_dict``.

    Rebinds the module-level ``g_dataset`` and refreshes the count box for the
    currently selected category.
    """
    global g_dataset
    selected_name = p_dataset["new"]
    g_dataset = dataset_dict[selected_name]
    current_category = category_widget.value
    count_widget.value = g_dataset.get_count(current_category)

def update_count_cb(p_count):
    """Refresh the count box when the category dropdown changes.

    Args:
        p_count: Change dict from the observer; ``p_count["new"]`` is the newly
            selected category, which is looked up in the active dataset.
    """
    selected_category = p_count["new"]
    count_widget.value = g_dataset.get_count(selected_category)

def save_snapshot_cb(_, content, msg):
    """Record a labelled sample when the live camera image is clicked.

    Registered with ``camera_widget.on_msg``. Messages whose
    ``content["event"]`` is not ``"click"`` are ignored. For a click, the
    current camera frame and the clicked offsets are passed to
    ``g_dataset.save_entry`` under the selected category; the snapshot pane
    then shows that same frame with the clicked point circled, and the count
    box is refreshed.

    Args:
        _: First callback argument (unused).
        content: Message dict; reads ``"event"`` and, for clicks,
            ``"eventData"`` with ``"offsetX"`` / ``"offsetY"``.
        msg: Third callback argument (unused).
    """
    if content["event"] != "click":
        return

    data = content["eventData"]
    x = data["offsetX"]
    y = data["offsetY"]
    category = category_widget.value

    # Read the frame once so the image handed to the dataset and the preview
    # are guaranteed to be the same frame (the camera keeps updating).
    frame = camera.value
    g_dataset.save_entry(category, frame, x, y)

    # Draw the marker on a copy so the frame given to the dataset stays clean.
    snapshot = cv2.circle(frame.copy(), (x, y), 8, (0, 255, 0), 3)
    snapshot_widget.value = bgr8_to_jpeg(snapshot)
    count_widget.value = g_dataset.get_count(category)
        
# Layout: live view next to the last snapshot, with the controls stacked below.
data_collection_widget = ipywidgets.VBox(
    [
        ipywidgets.HBox([camera_widget, snapshot_widget]),
        dataset_widget,
        category_widget,
        count_widget,
    ]
)

# Wire the callbacks: dropdown changes update the active dataset / count, and
# messages from the clickable live view go to the snapshot handler.
dataset_widget.observe(set_dataset_cb, names="value")
category_widget.observe(update_count_cb, names="value")
camera_widget.on_msg(save_snapshot_cb)

print("Created widget for data collection!")

In [None]:
import torch
import torchvision
from datetime import datetime

# NOTE(review): CUDA is requested unconditionally — presumably this fails on a
# machine without a GPU; confirm that is acceptable for the target device.
device = torch.device("cuda")

# One (x, y) coordinate pair per category. The category list comes from
# g_dataset, the active dataset bound in the data-collection cell; the name
# `dataset` previously used here is not defined anywhere in this notebook.
output_dim = 2 * len(g_dataset.categories)

# ALEXNET
# model = torchvision.models.alexnet(pretrained=True)
# model.classifier[-1] = torch.nn.Linear(4096, output_dim)

# SQUEEZENET 
# model = torchvision.models.squeezenet1_1(pretrained=True)
# model.classifier[1] = torch.nn.Conv2d(512, output_dim, kernel_size=1)
# model.num_classes = len(g_dataset.categories)

# RESNET 18
# Pretrained backbone with the final fully connected layer replaced so that it
# emits output_dim regression values.
model = torchvision.models.resnet18(pretrained = True)
model.fc = torch.nn.Linear(512, output_dim)

# RESNET 34
# model = torchvision.models.resnet34(pretrained=True)
# model.fc = torch.nn.Linear(512, output_dim)

model = model.to(device)

# Default checkpoint file name, stamped with the time this cell was executed.
model_path = "".join(
    [
        "/nvdli-nano/data/regression/",
        datetime.now().strftime("%d-%m-%Y-%H:%M:%S"),
        "_xy_model.pth",
    ]
)

# Editable path box plus the two buttons that act on it.
model_path_widget = ipywidgets.Text(description="model path", value=model_path)
model_save_button = ipywidgets.Button(description="save model")
model_load_button = ipywidgets.Button(description="load model")

def load_model(c):
    """Load weights into ``model`` from the file named in the model-path box.

    Args:
        c: Argument supplied by the button click handler (unused).
    """
    checkpoint_file = model_path_widget.value
    # NOTE(review): torch.load presumably unpickles the file — only load
    # checkpoints from a trusted source.
    state = torch.load(checkpoint_file)
    model.load_state_dict(state)


model_load_button.on_click(load_model)
    
def save_model(c):
    """Write the current ``model`` weights to the file named in the model-path box.

    Args:
        c: Argument supplied by the button click handler (unused).
    """
    checkpoint_file = model_path_widget.value
    weights = model.state_dict()
    torch.save(weights, checkpoint_file)


model_save_button.on_click(save_model)

# Trigger the save button programmatically, which runs the registered
# save_model handler against the default path as soon as this cell executes.
model_save_button.click()

# Path box on top, load/save buttons side by side underneath.
model_widget = ipywidgets.VBox(
    [
        model_path_widget,
        ipywidgets.HBox([model_load_button, model_save_button]),
    ]
)

# display(model_widget)
print("Configured model and created model widget")

In [1]:
# Demo
import threading
import time
from utils import preprocess
import torch.nn.functional as F

# Two-state toggle that drives the live demo; it starts in "stop".
state_widget = ipywidgets.ToggleButtons(
    options=["stop", "live"],
    description="state",
    value="stop",
)

# Placeholder image shown until the first prediction frame is produced.
with open("../images/ready_img.jpg", "rb") as file:
    default_image = file.read()

# Output pane for annotated frames, sized like the camera frame.
prediction_widget = ipywidgets.Image(
    format="jpeg",
    width=camera.width,
    height=camera.height,
    value=default_image,
)

def live_cb(state_widget, model, camera, prediction_widget):
    """Run live inference until the state toggle leaves "live".

    On every iteration the current camera frame is preprocessed and passed
    through the model. The flattened output is read as consecutive (x, y)
    pairs, one pair per entry in ``g_dataset.categories``; the pair for the
    category currently selected in ``category_widget`` is drawn as a circle on
    a copy of the frame, which is then shown in ``prediction_widget``.

    Args:
        state_widget: Widget whose ``value`` is polled; the loop runs while it
            equals ``"live"``.
        model: Callable applied to the preprocessed frame.
        camera: Object providing ``value`` (current frame), ``width`` and
            ``height``.
        prediction_widget: Widget whose ``value`` receives the JPEG-encoded
            annotated frame.
    """
    # g_dataset and category_widget are module-level names that are only read
    # here, so no `global` declaration is needed.
    while state_widget.value == "live":
        image = camera.value
        preprocessed = preprocess(image)
        output = model(preprocessed).detach().cpu().numpy().flatten()

        # Re-read the selected category on every frame so that changing the
        # dropdown while running takes effect immediately.
        category_index = g_dataset.categories.index(category_widget.value)

        # The formula maps a model output of -1 to pixel 0 and +1 to the full
        # frame width/height.
        x = output[2 * category_index]
        x = int(camera.width * (x / 2.0 + 0.5))
        y = output[2 * category_index + 1]
        y = int(camera.height * (y / 2.0 + 0.5))

        # Annotate a copy so the camera's own frame is left untouched.
        prediction = image.copy()
        prediction = cv2.circle(prediction, (x, y), 8, (255, 0, 0), 3)
        prediction_widget.value = bgr8_to_jpeg(prediction)
        
def start_live_cb(p_state):
    """Start the live-inference loop when the state toggle switches to "live".

    Args:
        p_state: Change dict from the observer; a worker thread is started only
            when ``p_state["new"]`` equals ``"live"``. Any other value does
            nothing here — the running loop ends by itself because live_cb
            polls the toggle.
    """
    if p_state["new"] != "live":
        return
    # The loop runs in its own thread so the callback returns immediately.
    # The target must be live_cb; the name `live` used previously is not
    # defined in this notebook.
    execute_thread = threading.Thread(
        target=live_cb,
        args=(state_widget, model, camera, prediction_widget),
    )
    execute_thread.start()

# Layout: annotated prediction frame above the stop/live toggle.
live_execution_widget = ipywidgets.VBox([prediction_widget, state_widget])

# Flipping the toggle notifies start_live_cb.
state_widget.observe(start_live_cb, names="value")

print("Created demo widgets!")

SyntaxError: unexpected EOF while parsing (839596498.py, line 32)