In [1]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import glob
import PIL.Image
import os
import numpy as np

In [2]:
# Quietly extract the image archive into the current working directory.
# NOTE(review): presumably this creates the dataset_xy_circuitlaunch/ folder read below — confirm.
!unzip -q dataset.zip

In [2]:
def get_x(path):
    """Decode the x target from an image filename.

    The three characters at ``path[3:6]`` are parsed as an integer ``v`` and
    mapped linearly with ``(v - 50) / 50``: 0 -> -1.0, 50 -> 0.0, 100 -> 1.0.
    Raises ``ValueError`` if that slice is not an integer.

    NOTE(review): assumes ``path`` is a basename laid out like
    ``xy_XXX_YYY_...`` (3-char prefix, then a 3-digit x field) — confirm
    against the data-collection code.
    """
    raw_value = int(path[3:6])
    return (raw_value - 50) / 50.0

def get_y(path):
    """Decode the y target from an image filename.

    The three characters at ``path[7:10]`` are parsed as an integer ``v`` and
    mapped linearly with ``(v - 50) / 50``: 0 -> -1.0, 50 -> 0.0, 100 -> 1.0.
    Raises ``ValueError`` if that slice is not an integer.

    NOTE(review): assumes ``path`` is a basename laid out like
    ``xy_XXX_YYY_...`` (y field after the x field and a separator) — confirm
    against the data-collection code.
    """
    raw_value = int(path[7:10])
    return (raw_value - 50) / 50.0

class XYDataset(torch.utils.data.Dataset):
    """Regression dataset of ``*.jpg`` images whose (x, y) targets are in the filename.

    Each item is ``(image, target)``:

    - ``image`` is a normalized 3x224x224 float tensor.
    - ``target`` is ``torch.tensor([x, y])`` as float32, with ``x`` / ``y``
      decoded from the file's basename by ``get_x`` / ``get_y``.

    Args:
        directory: folder whose top-level ``*.jpg`` files form the dataset.
        random_hflips: when True, each item is mirrored horizontally with
            probability 0.5 and its x target negated to match; when False no
            flip is applied.
    """

    def __init__(self, directory, random_hflips=False):
        self.directory = directory
        self.random_hflips = random_hflips
        # glob() order is filesystem-dependent; sorting keeps index -> file stable
        # so a seeded split selects the same images on every run.
        self.image_paths = sorted(glob.glob(os.path.join(self.directory, '*.jpg')))
        # Random brightness / contrast / saturation / hue jitter, applied to every item.
        self.color_jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.3)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        filename = os.path.basename(image_path)
        x = float(get_x(filename))
        y = float(get_y(filename))

        # convert('RGB') guarantees three channels, which the channel flip and
        # the 3-channel normalize below require. Grayscale/CMYK JPEGs would
        # otherwise fail. The context manager closes the underlying file.
        with PIL.Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        # Flip only when requested. Use torch's RNG rather than numpy's: older
        # PyTorch releases did not reseed numpy inside DataLoader workers, so
        # every worker made identical "random" choices.
        if self.random_hflips and torch.rand(1).item() < 0.5:
            image = transforms.functional.hflip(image)
            x = -x  # mirroring the image mirrors the horizontal target

        image = self.color_jitter(image)
        image = transforms.functional.resize(image, (224, 224))
        image = transforms.functional.to_tensor(image)
        # Reverse the channel axis (RGB -> BGR).
        # NOTE(review): presumably this matches the channel order fed to the
        # model at inference — confirm before removing.
        image = torch.flip(image, dims=[0])
        image = transforms.functional.normalize(
            image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        )

        return image, torch.tensor([x, y], dtype=torch.float32)

In [3]:
# Build the dataset and split it into train / held-out test subsets.
DATASET_DIR = 'dataset_xy_circuitlaunch'
SPLIT_SEED = 0

dataset = XYDataset(DATASET_DIR, random_hflips=True)
# glob() silently returns nothing for a missing or misnamed folder; fail loudly
# here instead of deep inside DataLoader or the training loop.
assert len(dataset) > 0, 'no .jpg images found in %s' % DATASET_DIR

test_percent = 0.1
num_test = int(test_percent * len(dataset))
# An empty test split would make the per-epoch test loss undefined.
assert num_test > 0, 'dataset too small for a %.0f%% test split' % (100 * test_percent)

# Fixed seed so the train/test split is identical on every Restart & Run All.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [len(dataset) - num_test, num_test], generator=split_generator
)
# NOTE(review): both subsets index the same XYDataset instance, so test samples
# also receive color jitter and random horizontal flips. For a clean evaluation,
# consider a second, non-augmented dataset over the same test indices.

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4
)

# Evaluation order does not affect the loss, so no shuffling is needed.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4
)

In [4]:
# ImageNet-pretrained ResNet-18 whose classifier is replaced by a 2-output
# linear head regressing the (x, y) target.
# NOTE: `pretrained=True` is deprecated in torchvision >= 0.13 in favour of
# `weights=`; it is kept here for compatibility with older torchvision.
model = models.resnet18(pretrained=True)
# Read the head's input width from the existing layer instead of hard-coding 512.
model.fc = torch.nn.Linear(model.fc.in_features, 2)
# Fall back to CPU so the notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# train regression:

In [5]:
# Adam over every model parameter (pretrained backbone and the new head),
# using PyTorch's default hyper-parameters.
optimizer = optim.Adam(params=model.parameters())

In [6]:
NUM_EPOCHS = 70
BEST_MODEL_PATH = 'best_steering_model.pth'
best_loss = float('inf')


def _run_epoch(model, loader, device, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-sample MSE loss.

    With an ``optimizer``, the model is trained: train mode, backward pass, and
    an optimizer step per batch. Without one, it is evaluated in eval mode with
    gradient tracking disabled.

    Returns:
        The sample-weighted mean of ``F.mse_loss`` over the loader, as a Python
        float, or ``float('nan')`` if the loader yielded no samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    num_samples = 0
    # No autograd graph is built during evaluation (saves memory and time).
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = F.mse_loss(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            # .item() yields a detached Python float. Accumulating the loss tensor
            # itself would keep every batch's autograd graph alive all epoch.
            # Weighting by batch size keeps a short final batch from skewing the mean.
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size
    return total_loss / num_samples if num_samples else float('nan')


for epoch in range(NUM_EPOCHS):
    train_loss = _run_epoch(model, train_loader, device, optimizer)
    test_loss = _run_epoch(model, test_loader, device)

    print('%f, %f' % (train_loss, test_loss))
    # Checkpoint whenever the held-out loss improves (a NaN loss never does).
    if test_loss < best_loss:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_loss = test_loss

0.823004, 1.204488
0.174736, 2.766335
0.094046, 1.054770
0.046999, 0.311390
0.046917, 0.138634
0.041596, 0.054409
0.045934, 0.032579
0.025213, 0.036667
0.019342, 0.054106
0.020445, 0.053553
0.021594, 0.015138
0.012808, 0.020826
0.010818, 0.010255
0.008438, 0.016040
0.011277, 0.019981
0.007359, 0.012663
0.007887, 0.009259
0.006764, 0.015335
0.008961, 0.012016
0.008861, 0.010902
0.009332, 0.006472
0.004761, 0.007985
0.010770, 0.006781
0.014343, 0.024147
0.021354, 0.005343
0.017153, 0.020461
0.010697, 0.012314
0.007760, 0.009137
0.006347, 0.006487
0.005761, 0.007441
0.005959, 0.009119
0.015088, 0.010405
0.008613, 0.008489
0.010845, 0.007104
0.010172, 0.016222
0.013410, 0.010640
0.006793, 0.008927
0.005220, 0.007308
0.003848, 0.006757
0.004711, 0.004913
0.011253, 0.010160
0.017309, 0.007785
0.012248, 0.014320
0.013776, 0.007978
0.008262, 0.007350
0.013271, 0.009128
0.009480, 0.006401
0.007379, 0.009416
0.006468, 0.006665
0.008706, 0.004083
0.013776, 0.007487
0.009652, 0.006663
0.014193, 0.