### Camera

In [None]:
from jetcam.csi_camera import CSICamera
# from jetcam.usb_camera import USBCamera

# Open the CSI camera at 224x224, the same size the TRANSFORMS pipeline
# resizes training images to.
camera = CSICamera(width=224, height=224)
# Alternative for a USB webcam (also uncomment the USBCamera import above):
# camera = USBCamera(width=224, height=224)

# NOTE(review): presumably this starts jetcam's background capture so that
# camera.value keeps updating for the widgets below — confirm against jetcam.
camera.running = True

### Task

In [None]:
import torchvision.transforms as transforms
from dataset import XYDataset, HeatmapGenerator

# Name of the active task; used as the prefix of every dataset name below.
TASK = 'face'
# TASK = 'fingers'

# Keypoint categories that can be labelled for the task. Later cells size the
# model output from dataset.categories (one heatmap channel per category).
CATEGORIES = ['nose', 'left_eye', 'right_eye']
# CATEGORIES = ['thumbs', 'index', 'middle', 'ring', 'pinky']

# Names of the datasets selectable from the data-collection dropdown.
DATASETS = ['A', 'B']

# Image transform handed to every XYDataset: colour jitter, resize to 224x224,
# conversion to a tensor and per-channel normalisation.
# NOTE(review): the mean/std values look like the usual ImageNet statistics
# expected by torchvision's pretrained models — confirm.
TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Build one XYDataset per entry of DATASETS, keyed by its short name.
datasets = {}
for name in DATASETS:
    # the first argument is TASK + '_' + name, e.g. 'face_A'
    datasets[name] = XYDataset(TASK + '_' + name, CATEGORIES, TRANSFORMS)

### Data Collection

In [None]:
import cv2
import ipywidgets
import traitlets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg
from jupyter_clickable_image_widget import ClickableImageWidget


# initialize the active dataset to the first entry of DATASETS; set_dataset()
# rebinds this global when the dropdown changes
dataset = datasets[DATASETS[0]]

# remove all observers from the camera in case this cell is being run a second time
camera.unobserve_all()

# create image preview: a clickable live view plus a still image that shows
# the most recently saved snapshot
camera_widget = ClickableImageWidget(width=camera.width, height=camera.height)
snapshot_widget = ipywidgets.Image(width=camera.width, height=camera.height)
# one-way link: every new camera.value is passed through bgr8_to_jpeg and
# written to the live view
traitlets.dlink((camera, 'value'), (camera_widget, 'value'), transform=bgr8_to_jpeg)

# create widgets for choosing the dataset/category and showing the sample count
dataset_widget = ipywidgets.Dropdown(options=DATASETS, description='dataset')
category_widget = ipywidgets.Dropdown(options=dataset.categories, description='category')
count_widget = ipywidgets.IntText(description='count')

# manually update counts at initialization (the observers only fire on change)
count_widget.value = dataset.get_count(category_widget.value)

# switch the active dataset whenever the dataset dropdown changes
def set_dataset(change):
    """Rebind the global ``dataset`` to the newly selected one and refresh the count.

    Args:
        change: traitlets change dict; ``change['new']`` is the selected
            dataset name and is used as a key into ``datasets``.
    """
    global dataset
    selected_name = change['new']
    dataset = datasets[selected_name]
    # the category selection is kept, so re-query its count in the new dataset
    current_category = category_widget.value
    count_widget.value = dataset.get_count(current_category)
dataset_widget.observe(set_dataset, names='value')

# refresh the displayed sample count whenever a different category is chosen
def update_counts(change):
    """Show how many entries the active dataset holds for the selected category.

    Args:
        change: traitlets change dict; ``change['new']`` is the category
            that was just selected.
    """
    selected_category = change['new']
    count_widget.value = dataset.get_count(selected_category)
category_widget.observe(update_counts, names='value')


def save_snapshot(_, content, msg):
    """Save the current frame with the clicked (x, y) label and preview it.

    Registered as a message handler on ``camera_widget``; messages whose
    ``content['event']`` is not ``'click'`` are ignored.

    Args:
        _: the widget that sent the message (unused).
        content: message payload; for clicks, ``content['eventData']``
            carries ``offsetX`` / ``offsetY`` of the click.
        msg: extra message buffers (unused).
    """
    if content['event'] != 'click':
        return

    data = content['eventData']
    x = data['offsetX']
    y = data['offsetY']

    # Read the frame exactly once: camera.value may be replaced by a newer
    # frame between two reads, and the preview must show the same image
    # that was written to disk.
    frame = camera.value

    # save to disk
    dataset.save_entry(category_widget.value, frame, x, y)

    # display saved snapshot with the clicked point marked in green
    snapshot = frame.copy()
    snapshot = cv2.circle(snapshot, (x, y), 8, (0, 255, 0), 3)
    snapshot_widget.value = bgr8_to_jpeg(snapshot)
    count_widget.value = dataset.get_count(category_widget.value)

camera_widget.on_msg(save_snapshot)

# Layout for the data-collection UI: live view next to the last snapshot,
# with the dataset / category / count controls stacked underneath.
# This widget is reused by the combined layout in the last cell.
data_collection_widget = ipywidgets.VBox([
    ipywidgets.HBox([camera_widget, snapshot_widget]),
    dataset_widget,
    category_widget,
    count_widget
])

# render the data-collection UI in this cell's output
display(data_collection_widget)

### Model

In [None]:
import torch
import torchvision

# Device that the model and the training batches are moved to.
device = torch.device('cuda')
# `dataset` is the active dataset chosen in the Data Collection cell.
output_dim = len(dataset.categories)  # heatmap for each category

# RESNET 18
class HeatmapModel(torch.nn.Module):
    """ResNet-18 backbone with a gated transposed-convolution heatmap head.

    The backbone features are upsampled by a stack of stride-2 transposed
    convolutions, multiplied element-wise by a sigmoid "attention" map
    computed directly from the same features, and projected to
    ``output_dim`` channels by a 1x1 convolution.

    Args:
        output_dim: number of output heatmap channels.
        upsample_dims: channel width of each stride-2 upsampling stage.
            Any sequence is accepted; it is copied, never mutated.

    Note:
        The attention branch always upsamples by a factor of 4, whereas the
        upsample branch scales by ``2 ** len(upsample_dims)``. The product in
        ``forward`` therefore only has matching shapes when ``upsample_dims``
        has exactly two entries.
    """

    def __init__(self, output_dim, upsample_dims=(128, 128)):
        super(HeatmapModel, self).__init__()
        self.feature_extractor = torchvision.models.resnet18(pretrained=True)
        # Keep a private copy so the attribute never aliases the caller's
        # list (or a shared default argument).
        self.upsample_dims = list(upsample_dims)
        # 512 is the channel count coming out of ResNet-18's layer4.
        channels = [512] + self.upsample_dims
        upsample_layers = []
        for in_channels, out_channels in zip(channels[:-1], channels[1:]):
            # kernel 4 / stride 2 / padding 1 doubles the spatial size
            upsample_layers += [
                torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
                torch.nn.BatchNorm2d(out_channels),
                torch.nn.ReLU(),
            ]
        # kernel 4 / stride 4 / padding 1 / output_padding 2 quadruples the
        # spatial size in a single step: (H - 1) * 4 - 2 + 4 + 2 = 4 * H
        self.attention = torch.nn.ConvTranspose2d(channels[0], channels[-1], kernel_size=4, stride=4, padding=1, output_padding=2)
        self.upsample = torch.nn.Sequential(*upsample_layers)
        self.final = torch.nn.Conv2d(channels[-1], output_dim, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return one heatmap channel per category for an image batch ``x``."""
        # Run the ResNet-18 stem and residual stages by hand so the average
        # pool and fully-connected classifier are skipped.
        x = self.feature_extractor.conv1(x)
        x = self.feature_extractor.bn1(x)
        x = self.feature_extractor.relu(x)
        x = self.feature_extractor.maxpool(x)

        x = self.feature_extractor.layer1(x)
        x = self.feature_extractor.layer2(x)
        x = self.feature_extractor.layer3(x)
        x = self.feature_extractor.layer4(x)  # 512x7x7 for a 224x224 input

        upsampled = self.upsample(x)
        # gate the upsampled features with a per-pixel, per-channel weight in (0, 1)
        gated = torch.sigmoid(self.attention(x)) * upsampled

        return self.final(gated)

        
# Build the network with two 128-channel upsampling stages and move it to `device`.
model = HeatmapModel(output_dim, [128, 128])
model = model.to(device)


# Controls for saving / loading the model weights at a user-editable path.
model_save_button = ipywidgets.Button(description='save model')
model_load_button = ipywidgets.Button(description='load model')
model_path_widget = ipywidgets.Text(description='model path', value='my_heatmap_model.pth')

def load_model(c):
    """Load weights from the path in ``model_path_widget`` into ``model``.

    Args:
        c: the button that was clicked (unused).
    """
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    checkpoint_path = model_path_widget.value
    state_dict = torch.load(checkpoint_path)
    model.load_state_dict(state_dict)
model_load_button.on_click(load_model)
    
def save_model(c):
    """Write the current weights of ``model`` to the path in ``model_path_widget``.

    Args:
        c: the button that was clicked (unused).
    """
    state_dict = model.state_dict()
    checkpoint_path = model_path_widget.value
    torch.save(state_dict, checkpoint_path)
model_save_button.on_click(save_model)

# Programmatically click "save model", which runs save_model() right away.
# NOTE(review): every run of this cell therefore writes the freshly
# initialised weights to the path in model_path_widget, overwriting any
# existing file there (including a previously trained model) — confirm
# this is intended.
model_save_button.click()

# Layout for the model controls; reused by the combined layout in the last cell.
model_widget = ipywidgets.VBox([
    model_path_widget,
    ipywidgets.HBox([model_load_button, model_save_button])
])


# render the model controls in this cell's output
display(model_widget)

### Live Execution

In [None]:
import threading
import time
from utils import preprocess
import numpy as np
import torch.nn.functional as F

# Toggle between 'stop' and 'live'; switching to 'live' starts the prediction
# thread (see start_live) and the thread exits once the value leaves 'live'.
state_widget = ipywidgets.ToggleButtons(options=['stop', 'live'], description='state', value='stop')
# Image widget that shows the latest heatmap overlay, sized like the camera frames.
prediction_widget = ipywidgets.Image(format='jpeg', width=camera.width, height=camera.height)

def live(state_widget, model, camera, prediction_widget):
    """Render heatmap predictions for the selected category while in 'live' state.

    Loops until ``state_widget.value`` is no longer ``'live'``. Each pass
    runs the model on the current camera frame, resizes the heatmap of the
    category chosen in ``category_widget`` to the frame size, blends it over
    the frame, marks the heatmap maximum with a filled red dot and pushes
    the result to ``prediction_widget``.

    Args:
        state_widget: widget whose ``value`` controls the loop.
        model: network called on the preprocessed frame.
        camera: camera object whose ``value`` is the current frame.
        prediction_widget: image widget that receives the encoded overlay.
    """
    while state_widget.value == 'live':
        image = camera.value
        preprocessed = preprocess(image)
        output = model(preprocessed).detach().cpu().numpy()

        # looked up on every frame so dataset/category changes apply immediately
        category_index = dataset.categories.index(category_widget.value)
        heatmap = output[0][category_index]

        # cv2.resize takes dsize as (width, height); the returned array is
        # indexed [row, col], i.e. has shape (height, width)
        heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_CUBIC)

        # Unravel against the heatmap's own shape. Passing (width, height)
        # here would give wrong coordinates whenever the frame is not square.
        y, x = np.unravel_index(heatmap.argmax(), heatmap.shape)

        # blend: 20% camera frame, 80% heatmap replicated over the colour channels
        prediction = 0.2 * image + 0.8 * heatmap[:, :, None] * 255.0
        # plain Python ints keep cv2.circle independent of numpy integer types
        prediction = cv2.circle(prediction, (int(x), int(y)), 8, (0, 0, 255), -1)
        prediction_widget.value = bgr8_to_jpeg(prediction)
            
def start_live(change):
    """Start a thread running ``live`` when the state toggle becomes 'live'.

    Args:
        change: traitlets change dict; only ``change['new'] == 'live'``
            triggers a new thread, any other value is ignored.
    """
    if change['new'] != 'live':
        return
    worker = threading.Thread(
        target=live,
        args=(state_widget, model, camera, prediction_widget),
    )
    worker.start()

state_widget.observe(start_live, names='value')

# Layout for live execution: the prediction overlay above the stop/live toggle.
# Reused by the combined layout in the last cell.
live_execution_widget = ipywidgets.VBox([
    prediction_widget,
    state_widget
])

# render the live-execution UI in this cell's output
display(live_execution_widget)

### Training and Evaluation

In [None]:
# Number of samples per DataLoader batch in train_eval().
BATCH_SIZE = 8
# Size of the target heatmaps built by HeatmapGenerator.
# NOTE(review): (28, 28) appears to match the model output for 224x224 input
# (7x7 backbone features upsampled 4x) — confirm if the input size or the
# number of upsample stages changes.
HEATMAP_SHAPE = (28, 28)
# Spread of the target heatmap peak, passed to HeatmapGenerator.
# NOTE(review): units (fraction of heatmap size vs. pixels) are defined in
# dataset.HeatmapGenerator — confirm there.
HEATMAP_STD = 0.1

heatmap_generator = HeatmapGenerator(HEATMAP_SHAPE, HEATMAP_STD)

# Adam with PyTorch's default hyper-parameters over all model parameters.
optimizer = torch.optim.Adam(model.parameters())
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# Training controls: remaining epochs, evaluate/train buttons, running loss
# and per-epoch progress in the range [0, 1].
epochs_widget = ipywidgets.IntText(description='epochs', value=1)
eval_button = ipywidgets.Button(description='evaluate')
train_button = ipywidgets.Button(description='train')
loss_widget = ipywidgets.FloatText(description='loss')
progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='progress')

def train_eval(is_training):
    """Train on, or evaluate against, the active dataset and update the UI.

    When training, one pass over the dataset is run per remaining epoch and
    ``epochs_widget`` is decremented after each pass until it reaches zero.
    When evaluating, a single pass is run without gradient tracking or
    parameter updates. ``progress_widget`` and ``loss_widget`` are refreshed
    after every batch.

    The live preview is stopped and both buttons are disabled while this
    runs. Afterwards the model is put back in eval mode, the buttons are
    re-enabled and the preview is switched to 'live' — also when an error
    occurred, in which case the traceback is printed instead of being
    silently dropped.

    Args:
        is_training: True to optimise the model, False to only measure the loss.
    """
    global model

    # local import: only needed for error reporting inside this function
    import traceback

    try:
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True
        )
        # stop the live thread and give it a moment to exit before the
        # model's mode is changed underneath it
        state_widget.value = 'stop'
        train_button.disabled = True
        eval_button.disabled = True
        time.sleep(1)

        if is_training:
            model = model.train()
        else:
            model = model.eval()

        # only track gradients when they are actually needed
        with torch.set_grad_enabled(is_training):
            while epochs_widget.value > 0:
                i = 0
                sum_loss = 0.0
                for images, category_idx, xy in train_loader:
                    # send data to device
                    images = images.to(device)
                    xy = xy.to(device)

                    if is_training:
                        # zero gradients of parameters
                        optimizer.zero_grad()

                    # execute model to get outputs
                    outputs = model(images)

                    # compute pixel-wise MSE against the target heatmap, using
                    # only the output channel of each sample's labelled category
                    loss = 0.0
                    for batch_idx, cat_idx in enumerate(list(category_idx.flatten())):
                        heatmap = heatmap_generator.generate_heatmap(list(xy[batch_idx])).to(device)
                        loss += torch.mean((outputs[batch_idx][cat_idx] - heatmap)**2)
                    loss /= len(category_idx)

                    if is_training:
                        # run backpropagation to accumulate gradients
                        loss.backward()

                        # step optimizer to adjust parameters
                        optimizer.step()

                    # increment progress
                    count = len(category_idx.flatten())
                    i += count
                    # `loss` is a per-sample mean for this batch; weight it by
                    # the batch size so sum_loss / i is the mean over all
                    # samples seen so far
                    sum_loss += float(loss) * count
                    progress_widget.value = i / len(dataset)
                    loss_widget.value = sum_loss / i

                if is_training:
                    epochs_widget.value = epochs_widget.value - 1
                else:
                    # evaluation is always a single pass
                    break
    except Exception:
        # report the failure but still restore the UI below
        traceback.print_exc()
    finally:
        model = model.eval()

        train_button.disabled = False
        eval_button.disabled = False
        state_widget.value = 'live'
    
# Wire the buttons to train_eval; the lambdas discard the button argument
# that on_click passes to its callback.
train_button.on_click(lambda c: train_eval(is_training=True))
eval_button.on_click(lambda c: train_eval(is_training=False))
    
# Layout for the training controls; reused by the combined layout in the last cell.
train_eval_widget = ipywidgets.VBox([
    epochs_widget,
    progress_widget,
    loss_widget,
    ipywidgets.HBox([train_button, eval_button])
])

# render the training controls in this cell's output
display(train_eval_widget)

### All together!

In [None]:
# Combine the widgets built in the previous cells into one view: data
# collection next to live execution, then the training and model controls.
all_widget = ipywidgets.VBox([
    ipywidgets.HBox([data_collection_widget, live_execution_widget]), 
    train_eval_widget,
    model_widget
])

# render the combined UI in this cell's output
display(all_widget)