In [None]:
# install progress bar
!pip install tqdm

In [16]:
# imports
import math
import zipfile
import os

import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

import PIL.Image as Image

import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [None]:
# upload observation_model.py from your assignment folder before running this
# run this each time you want to reload the observation_model code
import importlib

import observation_model

# re-execute the module so edits to observation_model.py are picked up
# without restarting the kernel
observation_model = importlib.reload(observation_model)

In [None]:
# get the train and test datasets
# -O writes each archive to a fixed filename in the working directory (replacing any
# existing file of that name); the dataset cell below opens them by exactly these names.
!wget https://courses.cs.washington.edu/courses/csep590a/23sp/hw1_train_dataset.zip -O hw1_train_dataset.zip
!wget https://courses.cs.washington.edu/courses/csep590a/23sp/hw1_test_dataset.zip -O hw1_test_dataset.zip

In [17]:
# minimized angle function from utils
def minimized_angle(angle):
    """Normalize an angle (radians) to the half-open interval [-pi, pi).

    Runs in constant time regardless of the magnitude of `angle`.

    Args:
        angle: angle in radians (Python or numpy float).

    Returns:
        The equivalent angle in [-pi, pi). NaN is returned unchanged; +/-inf
        has no defined direction, so NaN is returned for it as well.
    """
    if math.isinf(angle):
        # math.fmod raises on inf, and repeated +/- 2*pi never terminates
        return math.nan
    # fmod is exact and keeps the sign of `angle`, so the result lies in
    # (-2*pi, 2*pi); a single correction then lands it in [-pi, pi).
    angle = math.fmod(angle, 2 * np.pi)
    if angle < -np.pi:
        angle += 2 * np.pi
    elif angle >= np.pi:
        angle -= 2 * np.pi
    return angle

In [18]:
# dataset class for the car data
class CarDataset(Dataset):
    '''
    Dataset of (image, label) pairs read lazily from a single zip archive.

    Every `.png` member is treated as an image and every `.npy` member as a
    label. Both lists are sorted by member name and paired by position, so the
    archive must name its files such that the i-th image sorts alongside the
    i-th label.

    Args:
        path: path (or file-like object) of the zip archive.
        subset: if not None, keep only the first `subset` pairs.

    Raises:
        ValueError: if the archive holds a different number of images and labels.

    Returns (from __getitem__):
        image as a numpy array decoded from the PNG (expected 32x128x3, h x w x c)
        label as the numpy array stored in the .npy member (expected 6 values)
    '''
    def __init__(self, path, subset=None):
        self.zip = zipfile.ZipFile(path)
        files = self.zip.namelist()
        self.images = sorted(f for f in files if f.endswith('.png'))
        self.labels = sorted(f for f in files if f.endswith('.npy'))
        if len(self.images) != len(self.labels):
            # pairing is purely positional, so a count mismatch would silently
            # attach the wrong label to images
            self.zip.close()
            raise ValueError(
                'Archive %r has %d .png images but %d .npy labels'
                % (path, len(self.images), len(self.labels)))
        if subset is not None:
            self.images = self.images[:subset]
            self.labels = self.labels[:subset]

    def __getitem__(self, i):
        # close the zip member streams (and the PIL image) once the data is copied out
        with self.zip.open(self.images[i]) as f, Image.open(f) as img:
            image = np.array(img)
        with self.zip.open(self.labels[i]) as f:
            label = np.load(f)
        return image, label

    def __len__(self):
        return len(self.images)

In [19]:
# Build both datasets first, then wrap each in a loader: training batches are
# reshuffled every epoch, while the test loader keeps a fixed order.
batch_size = 64

train_dataset = CarDataset('hw1_train_dataset.zip')
test_dataset = CarDataset('hw1_test_dataset.zip')

train_loader = DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [20]:
# Instantiate the observation model, move its parameters to the GPU, and
# optimize them with Adam at a learning rate of 3e-4.
model = observation_model.ObservationModel()
model = model.cuda()
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-4)

In [None]:
# train
# supervision_mode selects the regression target:
#   'theta'     -> the model outputs the 6 angles directly (6 values per sample)
#   'direction' -> the model outputs (cos, sin) of each angle (12 values per sample:
#                  6 cosines followed by 6 sines); angles are recovered with atan2
supervision_mode = 'direction'
num_epochs = 10
if supervision_mode not in ('theta', 'direction'):
    # fail fast, before any data is loaded or any forward pass is run
    raise ValueError('Unknown supervision_mode: %s'%supervision_mode)


def _prepare_batch(x, y):
    """Convert a loader batch into the tensors the model and loss expect.

    Images are scaled from [0, 255] to [0, 1], moved to the GPU and reordered
    from (b, h, w, c) to (b, c, h, w); labels become a (b, 6) float tensor on
    the GPU. Returns (images, labels).
    """
    images = (x.float().cuda() / 255.).permute(0, 3, 1, 2)
    labels = y.float().view(images.shape[0], 6).cuda()
    return images, labels


def _angles_from_output(output, mode):
    """Turn raw model output into (b, 6) angle predictions for `mode`."""
    if mode == 'theta':
        return output
    # columns 0-5 hold cosines and 6-11 hold sines, matching the training target
    return torch.atan2(output[:, 6:], output[:, :6])


losses = []
for epoch in range(1, num_epochs+1):
    print('Epoch: %i'%epoch)
    print('Training')
    model.train()
    iterate = tqdm(train_loader)
    for x, y in iterate:
        x, y = _prepare_batch(x, y)

        x = model(x)

        if supervision_mode == 'theta':
            assert x.shape[1] == 6
        else:
            # (cos, sin) is continuous across the +/-pi wrap point, unlike the raw angle
            y = torch.cat([torch.cos(y), torch.sin(y)], dim=1)
            assert x.shape[1] == 12
        loss = F.mse_loss(x, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(float(loss))
        recent_losses = losses[-100:]
        # progress bar shows the mean loss over the last (up to) 100 iterations
        running_loss = sum(recent_losses)/len(recent_losses)
        iterate.set_description('Loss: %.04f'%running_loss)

    fig, ax = plt.subplots()
    ax.plot(np.arange(len(losses)), losses, label='loss')
    ax.set_title('Training loss through epoch %i' % epoch)
    ax.set_xlabel('iteration')
    ax.set_ylabel('MSE loss')
    ax.legend()
    plt.show()

    print('Evaluating')
    all_errors = []
    model.eval()
    with torch.no_grad():
        for x, y in tqdm(test_loader):
            x, y = _prepare_batch(x, y)
            theta = _angles_from_output(model(x), supervision_mode)
            error = (theta-y).view(-1).cpu().numpy()
            # wrap each difference with minimized_angle so the error is the
            # shortest angular distance, then take its magnitude
            error = [abs(minimized_angle(float(e))) for e in error]
            all_errors.extend(error)
    error_mean = np.mean(all_errors)
    error_std = np.std(all_errors)
    print('Error Mean: %f'%float(error_mean))
    print('Error Std: %f'%float(error_std))

    print('Saving Checkpoint')
    state_dict = model.state_dict()
    torch.save(state_dict, 'checkpoint.pt')
