# Model Training

I apply transfer learning to a pretrained AlexNet model, using the images labeled free or blocked that I captured in the Data Collection notebook.

My model uses monocular depth perception to distinguish between the two classes of situations. Specifically, it separates the classes by how much floor it can see between the bottom of the image and the bottom of the nearest object in front of the camera.

This can be confirmed empirically by moving the camera and retesting what the Jetbot considers blocked or free. If the camera points down toward the floor, the Jetbot gets closer to an obstacle before it considers the path blocked, because it sees more floor and believes there is more safe space ahead. If the camera points straight ahead, it sees less of the floor and hence stops farther from the obstacle.

Start by importing the necessary libraries.

In [1]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

Note: Do not run the next cell if the dataset directory already exists. The cell may hang while unzip waits for your response to its overwrite prompt.

In [2]:
!unzip -q dataset.zip
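One way to avoid the overwrite prompt altogether is to extract from Python and skip the step when the directory is already there. The following is a minimal sketch, not part of the original notebook: `extract_once` is a hypothetical helper, and it assumes the archive contains a top-level `dataset` folder.

```python
import os
import zipfile


def extract_once(zip_path, target_dir, extract_to='.'):
    """Extract zip_path into extract_to unless target_dir already exists.

    Returns True if the archive was extracted, False if it was skipped.
    """
    if os.path.isdir(target_dir):
        return False
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(extract_to)
    return True


# Only attempt the extraction when the archive is actually present.
if os.path.exists('dataset.zip'):
    extracted = extract_once('dataset.zip', 'dataset')
    print('extracted' if extracted else 'dataset directory already present, skipped')
```

Because the check happens before any file is written, the cell can be rerun safely without blocking on input.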

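Set up the dataset folder to be read in the desired format: each image gets a small random color jitter, is resized to 224x224, converted to a tensor, and normalized with the ImageNet channel means and standard deviations.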

In [3]:
dataset = datasets.ImageFolder(
    'dataset',
    transforms.Compose([
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
)
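To make the normalization step concrete, here is a small sketch of the per-channel arithmetic, (pixel - mean) / std, written with plain tensors. The helper names `normalize` and `denormalize` are hypothetical and only illustrate the formula; the random tensor stands in for a real image.

```python
import torch

# Same channel statistics as in the transform, shaped to broadcast over H x W.
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def normalize(image):
    """Apply (pixel - mean) / std independently to each colour channel."""
    return (image - MEAN) / STD


def denormalize(image):
    """Invert normalize(), e.g. before displaying a tensor as an image."""
    return image * STD + MEAN


# Stand-in for a ToTensor() result: 3 channels, values in [0, 1].
image = torch.rand(3, 224, 224)
normalized = normalize(image)
restored = denormalize(normalized)
```

The inverse is handy when inspecting what the network actually receives, since normalized tensors no longer lie in the displayable [0, 1] range.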

Split the dataset into train and test images, holding out 50 randomly chosen images for the test set.

In [4]:
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) - 50, 50])
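The split above is different on every run. If a repeatable split is wanted, a seeded generator can be passed to `random_split`. This is a sketch under assumptions: the random `TensorDataset` stands in for the image dataset, and `TEST_SIZE` and the seed value are illustrative names and numbers, not taken from the notebook.

```python
import torch
from torch.utils.data import TensorDataset, random_split

TEST_SIZE = 50  # number of held-out images, as in the cell above

# Stand-in dataset: 200 fake samples with binary labels.
demo_dataset = TensorDataset(torch.rand(200, 3), torch.randint(0, 2, (200,)))

# Guard against a dataset too small to leave anything for training.
if len(demo_dataset) <= TEST_SIZE:
    raise ValueError('dataset must contain more than %d samples' % TEST_SIZE)

# A seeded generator makes the same samples land in the test set every run.
generator = torch.Generator().manual_seed(0)
train_part, test_part = random_split(
    demo_dataset,
    [len(demo_dataset) - TEST_SIZE, TEST_SIZE],
    generator=generator,
)
```

A fixed split makes accuracy numbers comparable between training runs, at the cost of always evaluating on the same 50 images.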

Set up the DataLoaders so that images are loaded in shuffled batches of 16 during training, instead of bringing all files into memory at once.

In [5]:
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4
)
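As a quick illustration of how batching works out for the 50-image test set, the sketch below builds a loader over a stand-in `TensorDataset` and records the batch sizes. The tiny tensors and `num_workers=0` are assumptions chosen to keep the example light.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the 50 held-out test images.
demo_test = TensorDataset(torch.rand(50, 3, 8, 8), torch.randint(0, 2, (50,)))

demo_loader = DataLoader(demo_test, batch_size=16, shuffle=False, num_workers=0)

# Collect the size of every batch the loader yields.
batch_sizes = [labels.size(0) for _, labels in demo_loader]
num_batches = len(demo_loader)
```

With 50 samples and a batch size of 16, the last batch is smaller than the others, which is why the accuracy in the training cell is divided by the dataset length rather than by the number of batches.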

Download the AlexNet model if not already cached on the Jetbot and open it for use

In [6]:
model = models.alexnet(pretrained=True)

The AlexNet model was originally trained on a dataset with 1000 class labels. Because my dataset has only 2 class labels, I replace the final layer with a new, untrained layer that has only two outputs.

In [11]:
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 2)
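The effect of swapping the last layer can be checked without downloading any weights. The sketch below uses a hypothetical stand-in for the AlexNet classifier: it has the same seven-module layout with the final linear layer at index 6, but much smaller layer sizes that are assumptions for illustration only.

```python
import torch

# Stand-in classifier head: same module order as AlexNet's, smaller sizes.
classifier = torch.nn.Sequential(
    torch.nn.Dropout(),
    torch.nn.Linear(64, 32),
    torch.nn.ReLU(inplace=True),
    torch.nn.Dropout(),
    torch.nn.Linear(32, 32),
    torch.nn.ReLU(inplace=True),
    torch.nn.Linear(32, 1000),  # original 1000-class output layer
)

# Replace the output layer, reusing its input width, with a 2-class layer.
classifier[6] = torch.nn.Linear(classifier[6].in_features, 2)

# A batch of 4 fake feature vectors now yields 4 x 2 logits.
classifier.eval()
with torch.no_grad():
    logits = classifier(torch.rand(4, 64))
```

Only the new layer starts from random weights; every other layer keeps whatever it had before the swap, which is what lets the pretrained features carry over.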

Transfer the model to the GPU for training.

In [8]:
device = torch.device('cuda')
model = model.to(device)
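The cell above assumes a CUDA device is present, which holds on the Jetbot. When the notebook is run elsewhere, a fallback to the CPU may be useful. A minimal sketch, using a small stand-in layer instead of the full model:

```python
import torch

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stand-in layer and batch; both must live on the same device.
layer = torch.nn.Linear(4, 2).to(device)
batch = torch.rand(3, 4).to(device)
output = layer(batch)
```

The rest of the training code already moves each batch with `.to(device)`, so it would run unchanged with either device.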

Train the model on the dataset for 30 epochs. After each epoch, print the test accuracy and save the model whenever that accuracy exceeds the best seen so far.

In [9]:
NUM_EPOCHS = 30
BEST_MODEL_PATH = 'best_model.pth'
best_accuracy = 0.0

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(NUM_EPOCHS):

    # Training pass: one SGD update per batch
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

    # Test pass: labels are 0 or 1, so |label - prediction| is 1 for each error
    test_error_count = 0.0

    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        test_error_count += float(torch.sum(torch.abs(labels - outputs.argmax(1))))

    # Accuracy is the fraction of test images classified correctly
    test_accuracy = 1.0 - float(test_error_count) / float(len(test_dataset))
    print('%d: %f' % (epoch, test_accuracy))

    # Save the weights only when the test accuracy improves on the best so far
    if test_accuracy > best_accuracy:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_accuracy = test_accuracy

0: 0.960000
1: 0.980000
2: 0.960000
3: 0.980000
4: 0.980000
5: 0.960000
6: 0.980000
7: 1.000000
8: 0.960000
9: 0.960000
10: 0.980000
11: 0.980000
12: 0.980000
13: 0.980000
14: 0.980000
15: 0.960000
16: 0.980000
17: 0.980000
18: 0.980000
19: 1.000000
20: 0.960000
21: 0.980000
22: 0.980000
23: 0.980000
24: 0.980000
25: 0.960000
26: 0.980000
27: 0.980000
28: 0.980000
29: 0.960000
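The test pass in the loop above runs with the model still in training mode, so AlexNet's dropout layers stay active, and the absolute-difference error count only works for two classes labeled 0 and 1. A possible alternative, sketched here under assumptions, is a hypothetical `evaluate` helper that switches to evaluation mode, disables gradient tracking, and counts matching predictions directly. The tiny model and random data are stand-ins, and the accuracies printed above were not produced with this helper.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, device):
    """Return the fraction of samples in loader that model classifies correctly."""
    was_training = model.training
    model.eval()  # disable dropout for deterministic predictions
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            predictions = model(images).argmax(dim=1)
            correct += int((predictions == labels).sum())
            total += labels.size(0)
    model.train(was_training)  # restore the previous mode
    return correct / total if total else 0.0


# Stand-in model with a dropout layer, and 10 fake labeled samples.
torch.manual_seed(0)
demo_model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Dropout(0.5),
    torch.nn.Linear(12, 2),
)
demo_data = TensorDataset(torch.rand(10, 3, 2, 2), torch.randint(0, 2, (10,)))
demo_loader = DataLoader(demo_data, batch_size=4)
demo_accuracy = evaluate(demo_model, demo_loader, torch.device('cpu'))
```

Inside the training loop this would be called once per epoch as `evaluate(model, test_loader, device)`. With only 50 test images each misclassification moves the accuracy by 0.02, so small differences between epochs in the log above are within the noise of a single image or two.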
