# Apex Tracking - Train Model

After following the data collection notebook, you should now have a dataset of images along with x, y coordinates of the apex.

Now we'll train a neural network to predict these x, y coordinates given the image as input. To do this, we'll minimize the mean squared error (MSE) between the target point and the predicted point.
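For reference, with its default ``reduction='mean'``, ``F.mse_loss`` averages the squared error over every element of the batch, so the x and y errors of every sample contribute equally. Below is a minimal plain-Python sketch of that quantity. The ``mse`` helper is hypothetical and is written without PyTorch so it runs anywhere.

```python
def mse(predicted, target):
    """Mean squared error over all coordinates of a batch of (x, y) points.

    Hypothetical helper: mirrors what F.mse_loss computes with its default
    reduction='mean', i.e. the mean over every element, not per point.
    """
    squared_errors = [
        (p - t) ** 2
        for pred_point, target_point in zip(predicted, target)
        for p, t in zip(pred_point, target_point)
    ]
    return sum(squared_errors) / len(squared_errors)


# Two predicted points vs. their targets (normalized coordinates)
predicted = [(0.1, 0.2), (-0.5, 0.0)]
target = [(0.0, 0.0), (-0.5, 0.4)]
# About 0.0525 = (0.01 + 0.04 + 0.0 + 0.16) / 4
print(mse(predicted, target))
```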

First, let's create the dataset using the ``XYDataset`` class. Make sure that it points to the same directory that you set in the ``data_collection`` notebook.

The ``XYDataset`` class also accepts a ``torchvision`` transform that is applied to each image. We add ``ColorJitter``, which randomly varies the brightness, contrast, saturation and hue of the input image, followed by resizing to 224x224 and normalization. The normalization uses the same per-channel mean and standard deviation as the torchvision models pretrained on ImageNet.
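To make the normalization step concrete: ``Normalize`` maps each channel value ``v`` (already scaled to [0, 1] by ``ToTensor``) to ``(v - mean) / std``, using that channel's mean and standard deviation. Here is a plain-Python sketch for a single RGB pixel, using the ImageNet statistics passed to ``Normalize`` in this notebook. The ``normalize_pixel`` name is hypothetical.

```python
# ImageNet per-channel statistics, as passed to torchvision.transforms.Normalize
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Apply (v - mean) / std to each channel of one RGB pixel in [0, 1]."""
    return tuple(
        (value - mean) / std
        for value, mean, std in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)
    )


# A pixel equal to the ImageNet mean maps to zero in every channel,
# while pure white maps to about (2.25, 2.43, 2.64).
print(normalize_pixel((0.485, 0.456, 0.406)))
print(normalize_pixel((1.0, 1.0, 1.0)))
```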

We've also added a parameter to ``XYDataset`` called ``random_hflip``, which indicates whether to randomly flip the input image horizontally. When an image is flipped, the ``x`` value of its target is negated so that the target falls on the same point in the mirrored image.
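Negating ``x`` only mirrors the point correctly if ``x`` is measured relative to the horizontal center of the image. The sketch below assumes that pixel coordinates are normalized to [-1, 1] with 0 at the center. This is an assumption about ``XYDataset``'s internals, not a description of its actual code. The helper names ``to_normalized`` and ``hflip_sample`` are hypothetical, and the "image" is a plain list of pixel rows.

```python
def to_normalized(px, py, width, height):
    """Map pixel coordinates to [-1, 1], with (0, 0) at the image center.

    Assumption: one common convention; the real XYDataset may differ.
    """
    return 2.0 * (px / width - 0.5), 2.0 * (py / height - 0.5)


def hflip_sample(rows, x, y):
    """Mirror an image (list of pixel rows) left-to-right, plus its target.

    Reversing each row flips the image horizontally; negating the centered
    x coordinate moves the target to the mirrored position. y is unchanged.
    """
    flipped_rows = [list(reversed(row)) for row in rows]
    return flipped_rows, -x, y


width, height = 4, 4
image = [[c + 10 * r for c in range(width)] for r in range(height)]

# Target at pixel (1, 2); treating coordinates as continuous, its mirror
# lies at pixel (width - 1, 2) = (3, 2).
x, y = to_normalized(1, 2, width, height)
flipped, fx, fy = hflip_sample(image, x, y)
print((x, y), (fx, fy))                     # (-0.5, 0.0) (0.5, 0.0)
print(to_normalized(3, 2, width, height))   # (0.5, 0.0)
```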

In [1]:
import torchvision
from xy_dataset import XYDataset


transform = torchvision.transforms.Compose([
    torchvision.transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


dataset = XYDataset('apex_dataset', transform=transform, random_hflip=True)

Next, we'll create a loader that uses this dataset and generates batches of shuffled samples to train our neural network.

In [2]:
import torch

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=True
)
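With ``shuffle=True``, the loader draws a new random permutation of the dataset each epoch and cuts it into batches of ``batch_size``. By default (``drop_last=False``), the final batch keeps the leftover samples, so ``len(loader)`` is ``ceil(len(dataset) / batch_size)``. That is the value the training loop divides by to get the epoch loss. Below is a plain-Python sketch of the batching logic. ``shuffled_batches`` is a hypothetical helper, not the DataLoader implementation.

```python
import math
import random


def shuffled_batches(num_samples, batch_size, rng):
    """Yield lists of sample indices: one random permutation, cut into batches.

    Like DataLoader(shuffle=True, drop_last=False), the last batch may be
    smaller than batch_size.
    """
    indices = list(range(num_samples))
    rng.shuffle(indices)
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]


rng = random.Random(0)
batches = list(shuffled_batches(100, 16, rng))
print(len(batches), math.ceil(100 / 16))   # 7 7
print([len(b) for b in batches])           # [16, 16, 16, 16, 16, 16, 4]
```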

Now, let's define our neural network. We'll start from the ResNet-18 architecture, using weights pretrained on ImageNet for classification.

We replace the final fully connected layer (which originally has 1000 outputs, one per ImageNet class) with a linear layer that has just 2 outputs, for our x, y coordinates.

In [5]:
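model = torchvision.models.resnet18(pretrained=True)
# Swap the 1000-class ImageNet head for a 2-output regression head (x, y)
model.fc = torch.nn.Linear(model.fc.in_features, 2)

# Train on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)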
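The replacement head is an ordinary fully connected layer. It multiplies the 512-dimensional feature vector that ResNet-18 produces after global average pooling by a 2x512 weight matrix and adds a 2-element bias, giving 512 * 2 + 2 = 1026 trainable parameters. Below is a plain-Python sketch of that computation, using tiny hypothetical sizes so the arithmetic is easy to check. ``linear_head`` and ``num_parameters`` are illustrative names, not torchvision APIs.

```python
def linear_head(features, weight, bias):
    """Compute y = W f + b, like torch.nn.Linear on a single feature vector.

    weight has one row per output (here: x and y), each as long as features.
    """
    return [
        sum(w * f for w, f in zip(row, features)) + b
        for row, b in zip(weight, bias)
    ]


def num_parameters(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


# Toy 3-dimensional features instead of ResNet-18's 512
features = [1.0, 2.0, 3.0]
weight = [[0.5, 0.0, -0.5],   # row producing x
          [0.0, 1.0, 0.0]]    # row producing y
bias = [0.1, -0.2]
print(linear_head(features, weight, bias))   # approximately [-0.9, 1.8]
print(num_parameters(512, 2))                # 1026
```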

Next, we'll create the optimizer that we'll use to train the neural network.  We'll use the ``Adam`` optimizer with default parameters.

In [6]:
optimizer = torch.optim.Adam(model.parameters())
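For reference, ``torch.optim.Adam``'s defaults are ``lr=0.001``, ``betas=(0.9, 0.999)`` and ``eps=1e-8``. Below is a plain-Python sketch of the standard Adam update for a single scalar parameter with those defaults. ``adam_step`` is a hypothetical helper; PyTorch's implementation is vectorized and supports more options, such as weight decay.

```python
import math


def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; t is the 1-based step count.

    Returns the updated (param, m, v).
    """
    m = beta1 * m + (1 - beta1) * grad           # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad    # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                 # bias correction for zero init
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# First step: the bias-corrected ratio m_hat / sqrt(v_hat) is about sign(grad),
# so the parameter moves by roughly lr regardless of the gradient's size.
p, m, v = adam_step(param=1.0, grad=0.5, m=0.0, v=0.0, t=1)
print(p)  # about 0.999
```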

Next, execute the following cell to train the neural network for the number of epochs specified by ``EPOCHS``.

In [7]:
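import torch.nn.functional as F

EPOCHS = 30

model.train()  # make sure layers such as batch norm are in training mode

for epoch in range(EPOCHS):

    epoch_loss = 0.0

    for image, xy in loader:

        # Move the batch to the same device as the model
        image = image.to(device)
        xy = xy.to(device)

        optimizer.zero_grad()

        # Forward pass: predict (x, y) for each image in the batch
        xy_out = model(image)

        # Mean squared error between predicted and target points
        loss = F.mse_loss(xy_out, xy)

        loss.backward()
        optimizer.step()

        # .item() returns a Python float, so the autograd graph is not kept alive
        epoch_loss += loss.item()

    # Average of the per-batch losses over the epoch
    epoch_loss /= len(loader)

    print('%d: %f' % (epoch, epoch_loss))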

0: 0.445279
1: 0.077040
2: 0.043039
3: 0.057193
4: 0.053531
5: 0.044843
6: 0.040853
7: 0.058367
8: 0.052599
9: 0.044711
10: 0.021138
11: 0.030148
12: 0.020463
13: 0.043705
14: 0.036735
15: 0.020968
16: 0.016151
17: 0.011030
18: 0.013973
19: 0.014218
20: 0.014213
21: 0.011286
22: 0.004724
23: 0.012168
24: 0.008409
25: 0.013980
26: 0.018960
27: 0.013675
28: 0.010431
29: 0.015584
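The printed value is the average of per-batch mean losses on augmented training batches, not a held-out validation metric. Because dividing by ``len(loader)`` weights every batch equally, a smaller final batch counts as much as a full one. If you want a per-sample average instead, you could weight each batch loss by its size, as in this plain-Python sketch. The helper names and the numbers below are hypothetical.

```python
def per_batch_mean(batch_losses):
    """What the training loop reports: every batch weighted equally."""
    return sum(batch_losses) / len(batch_losses)


def per_sample_mean(batch_losses, batch_sizes):
    """Weight each batch's mean loss by its number of samples."""
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Hypothetical epoch: two full batches of 16 and a final batch of 4 samples
batch_losses = [0.02, 0.04, 0.30]
batch_sizes = [16, 16, 4]
print(per_batch_mean(batch_losses))                 # about 0.12
print(per_sample_mean(batch_losses, batch_sizes))   # about 0.06
```

In the training loop itself, this would mean accumulating ``loss.item() * image.size(0)`` and dividing by ``len(dataset)`` at the end of the epoch.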


Finally, we save the model for use in the live demo notebook.

In [8]:
torch.save(model.state_dict(), 'apex_model.pth')