In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms
from torch.utils.data import random_split
from collections import Counter
from matplotlib import pyplot as plt
from datetime import datetime

torch.manual_seed(123)
torch.set_default_dtype(torch.double)

device = (torch.device('cuda') if torch.cuda.is_available()
          else torch.device('cpu'))
print(f"Training on device {device}.")

In [None]:
data_path = '../data/'

# We download the training/validation data to investigate so that we can make educated 
# decisions about our preprocessing.
cifar10_train_val = datasets.CIFAR10(data_path, train=True, download=True)

In [None]:
n_train = int(len(cifar10_train_val)*0.90)
n_val =  len(cifar10_train_val) - n_train

cifar10_train, cifar10_val = random_split(cifar10_train_val, 
                                          [n_train, n_val],
                                          generator=torch.Generator().manual_seed(123))

print("Size of the train dataset:        ", len(cifar10_train))
print("Size of the validation dataset:   ", len(cifar10_val))

In [None]:
sample = cifar10_train[50]

print("Type of the ith element in the dataset  :  ", type(sample))
print("Tuple length of this element            :  ", len(sample))
print("Type of the 1st element of the tuple    :  ", type(sample[0]))
print("Type of the 2nd element of the tuple    :  ", type(sample[1]))

Counter(sorted([int(label) for _, label in cifar10_train]))

In [None]:
# One fig with 2*5 subplots since our labels go from 0-9
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(8,3))
for i, ax in enumerate(axes.flat):
    img = next(img for img, label in cifar10_train if label == i)
    ax.imshow(img)
    ax.set_title(i)
    ax.axis('off')
    
plt.show()

In [None]:
# We normalize only on the values we are keeping for the assignment.
label_map = {0: 0, 2: 1}
class_map = ["planes", "birds"]
cifar2_train = [(img, label_map[label]) for (img, label) in cifar10_train if label in [0, 2]]

print("Size of the train dataset:        ", len(cifar2_train))

# Compose a transform to use. 
imgs = torch.stack([transforms.functional.to_tensor(img) for img, _ in cifar2_train])
print(imgs.shape) # [9017, 3, 32, 32]
imgs_mean = imgs.mean(dim=(0, 2, 3))
imgs_std  = imgs.std(dim=(0, 2, 3))

preprocessor = transforms.Compose([transforms.ToTensor(),
                                   transforms.Normalize(imgs_mean, imgs_std)])

cifar10_train_val = datasets.CIFAR10(data_path, train=True, download=True, transform=preprocessor)
cifar10_test = datasets.CIFAR10(data_path, train=False, download=True, transform=preprocessor) 

cifar10_train, cifar10_val = random_split(cifar10_train_val, 
                                          [n_train, n_val],
                                          generator=torch.Generator().manual_seed(123))

In [None]:
# We are only keeping the birds and the planes for this assignment.
label_map = {0: 0, 2: 1}
class_map = ["planes", "birds"]
cifar2_train = [(img, label_map[label]) for (img, label) in cifar10_train if label in [0, 2]]
cifar2_val   = [(img, label_map[label]) for (img, label) in cifar10_val   if label in [0, 2]]
cifar2_test  = [(img, label_map[label]) for (img, label) in cifar10_test  if label in [0, 2]]

print("Size of the train dataset:        ", len(cifar2_train))
print("Size of the validation dataset:   ", len(cifar2_val))
print("Size of the test dataset:         ", len(cifar2_test))

In [None]:
class MyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32*32*3, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 32)
        self.fc4 = nn.Linear(32, 2)
        
    def forward(self, x):
        out = torch.flatten(x, 1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        out = self.fc4(out)
        return out

In [None]:
def train(n_epochs, optimizer, model, loss_fn, train_loader):
    
    n_batch = len(train_loader)
    losses_train = []
    model.train()
    optimizer.zero_grad(set_to_none=True)
    
    
    for epoch in range(1, n_epochs + 1):
        
        loss_train = 0.0
        for imgs, labels in train_loader:
            
            imgs = imgs.to(device=device, dtype=torch.double)
            labels = labels.to(device=device)

            outputs = model(imgs)
            
            loss = loss_fn(outputs, labels)
            loss.backward()
            
            optimizer.step()
            optimizer.zero_grad()

            loss_train += loss.item()
            
        losses_train.append(loss_train / n_batch)

        if epoch == 1 or epoch % 10 == 0:
            print('{}  |  Epoch {}  |  Training loss {:.3f}'.format(
                datetime.now().time(), epoch, loss_train / n_batch))
    return losses_train

In [132]:
def train_manual_update(n_epochs, lr, model, loss_fn, train_loader, weight_decay = 0, momentum = 0):
    
    n_batch = len(train_loader)
    losses_train = []
    model.train()

    with torch.no_grad():
        for epoch in range(1, n_epochs + 1):
            
            loss_train = 0.0
            change = {}
            for name, p in model.named_parameters():
                change[name] = 0.0
            
            for imgs, labels in train_loader:
                
                imgs = imgs.to(device=device, dtype=torch.double) 
                labels = labels.to(device=device)

                outputs = model(imgs)
                
                loss = loss_fn(outputs, labels)
                for p in model.parameters():
                    if p.grad is not None:
                        p.grad = p.grad.add(p, alpha=weight_decay)
                loss.backward()
                
                for name, p in model.named_parameters():
                    new_change = momentum * change[name] + lr * p.grad
                    p.data = p.data - new_change
                    p.grad = torch.zeros(p.grad.shape)
                    change[name] = new_change

                loss_train += loss.item()
                
            losses_train.append(loss_train / n_batch)

            if epoch == 1 or epoch % 10 == 0:
                print('{}  |  Epoch {}  |  Training loss {:.3f}'.format(
                    datetime.now().time(), epoch, loss_train / n_batch))
    return losses_train

In [133]:
train_loader = torch.utils.data.DataLoader(cifar2_train, batch_size=64, shuffle=False)
loss_fn = nn.CrossEntropyLoss()

lrs = [0.01]
decays = [0.001]
momentums = [0, 0.9, 0.99]

hparams = [
    {
        "lr" : lr,
        "weight_decay": w,
        "momentum": m,
    } for lr in lrs for w in decays for m in momentums
]

for param in hparams:
    print(param)
    torch.manual_seed(123)
    model01 = MyMLP()
    optimizer = optim.SGD(model01.parameters(), **param)
    print("======== Regular Train ========")
    train_loss = train(n_epochs = 10, 
                       optimizer = optimizer, 
                       model = model01, 
                       loss_fn = loss_fn, 
                       train_loader = train_loader)
    print()
    print(f"======== Manual update Train ========")
    torch.manual_seed(123)
    model02 = MyMLP()
    train_manual_loss = train_manual_update(n_epochs = 10,  
                                            model = model02, 
                                            loss_fn = loss_fn, 
                                            train_loader = train_loader, 
                                            **param)
    print()

{'lr': 0.01, 'weight_decay': 0.001, 'momentum': 0}
11:57:34.668182  |  Epoch 1  |  Training loss 0.640
11:57:48.254967  |  Epoch 10  |  Training loss 0.322

11:57:50.155520  |  Epoch 1  |  Training loss 0.640
11:58:06.881745  |  Epoch 10  |  Training loss 0.322

{'lr': 0.01, 'weight_decay': 0.001, 'momentum': 0.9}
11:58:08.527557  |  Epoch 1  |  Training loss 0.506
11:58:23.068624  |  Epoch 10  |  Training loss 0.183

11:58:24.947555  |  Epoch 1  |  Training loss 0.505
11:58:41.467659  |  Epoch 10  |  Training loss 0.182

{'lr': 0.01, 'weight_decay': 0.001, 'momentum': 0.99}
11:58:43.130030  |  Epoch 1  |  Training loss 0.519
11:58:58.028237  |  Epoch 10  |  Training loss 0.575

11:58:59.962976  |  Epoch 1  |  Training loss 0.521
11:59:17.120754  |  Epoch 10  |  Training loss 0.356



In [None]:
train_loader = torch.utils.data.DataLoader(cifar2_train, batch_size=64, shuffle=False)
torch.manual_seed(123)
model01 = MyMLP()
optimizer = optim.SGD(model01.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

print("======== Regular Train ========")
train_loss = train(n_epochs = 10, 
                   optimizer = optimizer, 
                   model = model01, 
                   loss_fn = loss_fn, 
                   train_loader = train_loader)

weight_decays = [0.01, 0.001]
print()
for decay in weight_decays:
    print(f"======== Manual update Train with weight decay: {decay} ========")
    torch.manual_seed(123)
    model02 = MyMLP()
    train_manual_loss = train_manual_update(n_epochs = 10, 
                                            lr = 1e-2, 
                                            model = model02, 
                                            loss_fn = loss_fn, 
                                            train_loader = train_loader, 
                                            weight_decay = decay)

momentums = [0.5, 0.8, 0.9, 0.99]
print()
for momentum in momentums:
    print(f"======== Manual update Train with momentum: {momentum} ========")
    torch.manual_seed(123)
    model02 = MyMLP()
    train_manual_loss = train_manual_update(n_epochs = 10, 
                                            lr = 1e-2, 
                                            model = model02, 
                                            loss_fn = loss_fn, 
                                            train_loader = train_loader, 
                                            momentum = momentum)

In [131]:
train_loader = torch.utils.data.DataLoader(cifar2_train, batch_size=64, shuffle=False)
loss_fn = nn.CrossEntropyLoss()

lrs = [0.01, 0.005]
decays = [0, 0.001]
momentums = [0]#, 0.9]

hparams = [
    {
        "lr" : lr,
        "weight_decay": w,
        "momentum": m,
    } for lr in lrs for w in decays for m in momentums
]

for param in hparams:
    print(param)
    torch.manual_seed(123)
    model01 = MyMLP()
    optimizer = optim.SGD(model01.parameters(), **param)
    print("======== Regular Train ========")
    train_loss = train(n_epochs = 10, 
                       optimizer = optimizer, 
                       model = model01, 
                       loss_fn = loss_fn, 
                       train_loader = train_loader)
    print()
    print(f"======== Manual update Train ========")
    torch.manual_seed(123)
    model02 = MyMLP()
    train_manual_loss = train_manual_update(n_epochs = 10,  
                                            model = model02, 
                                            loss_fn = loss_fn, 
                                            train_loader = train_loader, 
                                            **param)
    print()


{'lr': 0.01, 'weight_decay': 0, 'momentum': 0}
11:53:55.815221  |  Epoch 1  |  Training loss 0.640
11:54:08.240339  |  Epoch 10  |  Training loss 0.320

11:54:10.080135  |  Epoch 1  |  Training loss 0.640
11:54:26.642473  |  Epoch 10  |  Training loss 0.320

{'lr': 0.01, 'weight_decay': 0.001, 'momentum': 0}
11:54:28.160908  |  Epoch 1  |  Training loss 0.640
11:54:41.635843  |  Epoch 10  |  Training loss 0.322

11:54:43.519438  |  Epoch 1  |  Training loss 0.640
11:55:00.281533  |  Epoch 10  |  Training loss 0.322

{'lr': 0.005, 'weight_decay': 0, 'momentum': 0}
11:55:01.698132  |  Epoch 1  |  Training loss 0.670
11:55:14.525148  |  Epoch 10  |  Training loss 0.406

11:55:16.424130  |  Epoch 1  |  Training loss 0.670
11:55:33.634166  |  Epoch 10  |  Training loss 0.406

{'lr': 0.005, 'weight_decay': 0.001, 'momentum': 0}
11:55:35.299620  |  Epoch 1  |  Training loss 0.670
11:55:50.025009  |  Epoch 10  |  Training loss 0.407

11:55:52.065648  |  Epoch 1  |  Training loss 0.670
11:56:11