In [23]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import glob
import PIL.Image
import os
import numpy as np
from collections import namedtuple

In [2]:
BATCH_SIZE = 64
Task = namedtuple('Task', ['dataset', 'loss_function', 'loss_multiplier', 'output_dim'])

# Datasets

## Blocks Dataset

In [3]:
dataset_blocks = datasets.ImageFolder(
    'dataset_blocks',
    transforms.Compose([
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
)

## Transition Dataset

In [4]:
dataset_transition = datasets.ImageFolder(
    'dataset_transition',
    transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
)

## Steering Datasets

In [24]:
def get_steering(path):
    return (float(int(path[9:12])) - 50.0) / 50.0

class MergedSteeringDataset(torch.utils.data.Dataset):
    
    def __init__(self, all_dir, task_dir, flipped=False):
        self.flipped = flipped
        self.all_dir = all_dir
        self.task_dir = task_dir
        self.all_image_paths = glob.glob(os.path.join(self.all_dir, '*.jpg'))
        self.task_image_paths = glob.glob(os.path.join(self.task_dir, '*.jpg'))
        self.color_jitter = transforms.ColorJitter(0.1, 0.1, 0.1, 0.1)
    
    def __len__(self):
        return len(self.all_image_paths) + len(self.task_image_paths)
    
    def __getitem__(self, idx):
        flip_all = False
        if idx < len(self.all_image_paths):
            image_path = self.all_image_paths[idx]
            
            if float(np.random.rand(1)) > 0.5:
                flip_all = True
        else:
            image_path = self.task_image_paths[idx - len(self.all_image_paths)]
        
        image = PIL.Image.open(image_path)
        steering = float(get_steering(os.path.basename(image_path)))
        
        if self.flipped:
            image = transforms.functional.hflip(image)
            steering = -steering
        
        if flip_all:
            image = transforms.functional.hflip(image)
            steering = -steering
            
        image = self.color_jitter(image)
        image = transforms.functional.resize(image, (224, 224))
        image = transforms.functional.to_tensor(image)
        image = transforms.functional.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        #target = get_gaussian(steering, self.stdev, self.nbins)
        #target = get_bin(steering, self.nbins)
        
        return image, torch.tensor([steering]).float()

dataset_forward = MergedSteeringDataset('dataset_all', 'dataset_forward', flipped=False)
dataset_left = MergedSteeringDataset('dataset_all', 'dataset_left', flipped=False)
dataset_right = MergedSteeringDataset('dataset_all', 'dataset_left', flipped=True)

## Combine

In [48]:
tasks = [
    Task(dataset_blocks, torch.nn.functional.cross_entropy, 1.0, 7),
    Task(dataset_transition, torch.nn.functional.cross_entropy, 1.0, 2),
    #Task(dataset_forward, torch.nn.functional.mse_loss, 1.0, 1),
    #Task(dataset_left, torch.nn.functional.mse_loss, 1.0, 1),
    #Task(dataset_right, torch.nn.functional.mse_loss, 1.0, 1)
]

loaders = [torch.utils.data.DataLoader(t.dataset, batch_size=BATCH_SIZE, shuffle=True) for t in tasks]

# Model

In [49]:
output_dim = 0
for t in tasks:
    output_dim += t.output_dim
    
model = models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(512, output_dim)
device = torch.device('cuda')
model = model.to(device)

# Optimizer

In [50]:
optimizer = optim.Adam(model.parameters())

# Train

In [53]:
NUM_BATCHES = 200

iters = [iter(l) for l in loaders]

model.train()

for batch in range(NUM_BATCHES):
    
#     net_loss = 0.0
    
    start_idx = 0
    
    losses = []
    # compute loss for each tasks next batch
    for i, task in enumerate(tasks):
        optimizer.zero_grad()
        
        try:
            image, label = next(iters[i])
        except:
            iters[i] = iter(loaders[i]) # create new iterator, end reached of previous
            image, label = next(iters[i])
        
        image = image.to(device)
        label = label.to(device)
        
        output = model(image)[:, start_idx:start_idx + task.output_dim]
        
        loss = task.loss_function(output, label)
        
#         net_loss += task.loss_multiplier * loss
        
        start_idx += task.output_dim
        losses += [float(loss)]
    
#         net_loss.backward()
        loss.backward()
        optimizer.step()
    
    print('%s' % losses)
    print('%d, %f' % (batch, float(net_loss)))

[0.005575467366725206, 0.010622795671224594, 0.008507998660206795]
0, 0.024871
[0.005525545217096806, 0.004213564097881317, 0.011882449500262737]
1, 0.024871
[0.004807657562196255, 0.009023392572999, 0.004948530346155167]
2, 0.024871
[0.023161552846431732, 0.011391853913664818, 0.013667738065123558]
3, 0.024871
[0.020563384518027306, 0.033288754522800446, 0.013314269483089447]
4, 0.024871
[0.01103946752846241, 0.016444789245724678, 0.009933341294527054]
5, 0.024871
[0.008462561294436455, 0.0077315135858953, 0.031208762899041176]
6, 0.024871
[0.025809641927480698, 0.013420622795820236, 0.01141023263335228]
7, 0.024871
[0.005238179583102465, 0.029687602072954178, 0.005655947141349316]
8, 0.024871
[0.027094285935163498, 0.015508221462368965, 0.004364926367998123]
9, 0.024871
[0.004546883516013622, 0.007655059918761253, 0.003508918220177293]
10, 0.024871
[0.008974513970315456, 0.08767539262771606, 0.0064724222756922245]
11, 0.024871
[0.021273519843816757, 0.015175367705523968, 0.0117744347

In [54]:
torch.save(model.state_dict(), 'best_model_multitask_steering_400.pth')