---
# **NEW (future train.py)**
---

In [None]:
import torch
from torch import Tensor
from torch import nn
from torch.utils.data import Dataset, DataLoader
import dlc_practical_prologue as prologue
import numpy as np
# Our modules
import models # contains all our torch models classes
import plots # custom plotting functions to produce figures of the report


# Select here which experiments to run (running all of them takes a long time to compute).

# Experiments for the "PairSetup": the network directly takes the pairs as input, i.e. N samples
# of shape [2, 14, 14], each made of two 14x14 pictures. The output is the class 0 or 1
# indicating whether the first digit is less than or equal to the second.
run_PairSetup_SimpleLinear = True
run_PairSetup_MLP = True
# ... (further PairSetup experiments)

# Experiments for the "AuxiliarySetup": the network takes N individual 14x14 pictures as input
# and uses an auxiliary loss to learn to classify each of them from 0 to 9 (the auxiliary output
# is the class of the digit shown in the picture). The predicted digits are then compared to
# reach the original goal: predicting whether the first digit is less than or equal to the second.
run_AuxiliarySetup_SimpleLinear = True
run_AuxiliarySetup_MLP = True
# ... (further AuxiliarySetup experiments)

class TrainPairsDataset(Dataset):
    """ 
    PyTorch Dataset for holding MNIST train pairs. Each item is a dict
    {'pair': train_input[idx], 'target': train_target[idx]}.
    Arguments: 
        - train_input: a torch tensor of size [N, 2, 14, 14] containing the N training pairs
        - train_target: hot-encoded target torch tensor of size [N , 2]
        - augment_data: boolean, if True, data will be augmented by inversing pairs (the N
        inversed pairs are appended after the N original ones)
        - train_classes: optional torch tensor of size [N, 2] holding the digit class of both
        pictures of each pair. Only used when augment_data is True: inversing a pair of two
        equal digits gives back the same pair of digits, so its target must stay unchanged
        instead of being swapped. If None (default, previous behaviour), the target columns of
        every inversed pair are swapped, which mislabels the inversed equal-digit pairs.
    """
    def __init__(self, train_input, train_target, augment_data=False, train_classes=None):
        if augment_data :
            # Create the inversed pairs (data augmentation): swap the two pictures of each pair
            train_input_rev = train_input[:,[1,0],:,:]
            # For two different digits, inversing the pair flips the answer of the comparison,
            # hence the swapped hot-encoded columns. Advanced indexing returns a copy, so the
            # caller's train_target is not modified below.
            train_target_rev = train_target[:,[1,0]]
            if train_classes is not None:
                # (a, a) inversed is still (a, a): keep the original target for those pairs
                same_digit = train_classes[:, 0] == train_classes[:, 1]
                train_target_rev[same_digit] = train_target[same_digit]
            self.train_input = torch.cat((train_input,train_input_rev))
            self.train_target = torch.cat((train_target,train_target_rev))
        else :
            self.train_input = train_input
            self.train_target = train_target
            
    def __len__(self):
        return len(self.train_input)

    def __getitem__(self, idx):
        return {'pair': self.train_input[idx], 'target': self.train_target[idx]}

    
def hot_encode(data):
    """ 
    2-class hot encoding: value 0 becomes [1, 0] and value 1 becomes [0, 1].
    Arguments:
        - data: torch tensor of N values (any shape, it is flattened) expected to be 0s and 1s;
        any other value yields an all-zero row
    Returns:
        - float torch tensor of size [N,2]
    """
    flat = data.view(-1)
    is_zero = flat == 0
    is_one = flat == 1
    return torch.stack((is_zero, is_one), dim=1).float()


def hot_decode(data):
    """
    2-class hot decoding, the inverse of hot_encode().
    Arguments:
        - data: hot-encoded torch tensor of size [N,2]
    Returns:
        - long torch tensor of size N holding, for each row, the index of its largest column
    """
    winning_column = data.argmax(dim=1)
    return winning_column.to(torch.long)


def compute_errors(output, target):
    """ 
    Computes the error percentage of the predictions against the targets.
    Arguments:
        - output: torch tensor of [N,2] of predicted scores for each class
        - target: hot-encoded target torch tensor of size [N,2]
    Returns:
        - error %: share of rows whose highest-scoring class differs from the target class, x100
    """
    predicted = output.argmax(dim=1)
    expected = target.argmax(dim=1)
    wrong = predicted.ne(expected).sum().item()
    total = output.size(0)
    return wrong / total * 100


def train(model, train_input, train_target, test_input, test_target, 
          use_crossentropy = False, lr=1e-3, epochs = 200, verbose=False, batch_size=100) :
    """ 
    Trains the given model using the given train and test dataset. Returns the
    train & test error % history.
    Arguments:
        - model: torch model to train
        - train_input: torch tensor of train input data
        - train_target: hot-encoded train target class
        - test_input: torch tensor of test input data
        - test_target: hot-encoded test target class
        - use_crossentropy: boolean, if True, crossentropy loss will be used (train 
        target data will be decoded in order to use this loss). If False MSE loss 
        will be used.
        - lr: learning rate (Adam optimizer)
        - epochs: number of epochs to train with
        - verbose: if True, a dot '.' will be printed at each new epoch (and a newline at the end)
        - batch_size: size of the shuffled mini-batches used for the parameter updates
    Returns:
        - (train_errors, test_errors), the train and test error % histories (one value per epoch)
    Note: the model is put in training mode for the parameter updates and in evaluation mode
    while the errors are measured (this matters for dropout / batch-norm layers). Its original
    mode is restored before returning.
    """ 
    was_training = model.training

    optimizer = torch.optim.Adam(model.parameters(), lr=lr) 
    criterion = nn.CrossEntropyLoss() if use_crossentropy else nn.MSELoss()
    
    trainPairsDataset = TrainPairsDataset(train_input, train_target)
    dataloader = DataLoader(trainPairsDataset, batch_size=batch_size, shuffle=True, 
                            num_workers=0)
    train_errors = []
    test_errors = []
    
    for _ in range(epochs):
        if verbose :
            print('.', end='')
        model.train()
        for batch in dataloader:
            inputPairs = batch['pair']
            # If we use crossentropy, then we don't want hot-encoding of train target class but 
            # directly their class.
            target = hot_decode(batch['target']) if use_crossentropy else batch['target']
            # Forward pass
            output = model(inputPairs)
            # Compute loss
            loss = criterion(output, target)
            # Backprop & update parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        # Evaluation mode: dropout disabled, batch-norm uses its running statistics
        model.eval()
        with torch.no_grad():
            # Compute train error
            train_errors.append(compute_errors(model(train_input), train_target))
            # Compute test error
            test_errors.append(compute_errors(model(test_input), test_target))

    if verbose:
        print()
    model.train(was_training)
    return train_errors, test_errors


def rounds_train(model, rounds=10, augment_data=False, use_crossentropy=True,
                 lr=1e-3, epochs=200, verbose=False, plot_title=None,
                 plot_file_path=None):
    """
    Trains the model multiple times with randomized fresh new data and freshly re-initialized
    weights. For each training, we keep only the best test error (early stopping) with its
    corresponding train error. Then we compute and return the average and standard deviation of
    those best test errors and of their corresponding train errors.
    Arguments:
        - model: torch model to train. At the beginning of every round, each sub-module defining
        reset_parameters() is re-initialized, so that the rounds are independent trainings
        instead of one long training continued on new data.
        - rounds: number of experiments (new trainings) to perform
        - augment_data: boolean, if True, the training pairs are augmented by inversing pairs
        (see TrainPairsDataset); the train errors are then measured on this augmented set
        - use_crossentropy: boolean, if True, crossentropy loss will be used. If False MSE loss will 
        be used.
        - lr: learning rate
        - epochs: number of epochs to train with at each round
        - verbose: if True, the current round will be printed at each round
        - plot_title: str, if plot_title and plot_file_path are not None, then a figure with 
        the given title will be saved at the given path
        - plot_file_path: str, if plot_title and plot_file_path are not None, then a figure with 
        the given title will be saved at the given path
    Returns:
        - (min_test_mean, min_test_std, train_mean, train_std): the average and standard
        deviation of the per-round minimal test errors %, and of the train errors % measured
        at the same epochs.
    """
    min_test_errors = []
    corresponding_train_errors = []
    train_errors_histories = []
    test_errors_histories = []
    
    # Number of samples to generate each round (=1000 as per the project instructions)
    N = 1000
    
    for i in range(rounds):
        if verbose:
            print('round n°{}'.format(i+1))
        # Fresh weights for every round (modules without reset_parameters() are left untouched)
        for module in model.modules():
            reset = getattr(module, 'reset_parameters', None)
            if callable(reset):
                reset()
        # Load new data (data is randomized at each round as required in the project instructions)
        train_input, train_target, _ , test_input, test_target, _ = prologue.generate_pair_sets(N)
        # Hot encoding (train targets will be later decoded if the loss is cross-entropy)
        train_target = hot_encode(train_target)
        test_target = hot_encode(test_target)
        # Build the (optionally augmented) training set and train on its tensors, so that
        # augment_data actually takes effect
        pairs_dataset = TrainPairsDataset(train_input, train_target, augment_data=augment_data)
        train_errors, test_errors = train(model, pairs_dataset.train_input,
                                          pairs_dataset.train_target, test_input, test_target,
                                          use_crossentropy=use_crossentropy, lr=lr, epochs=epochs)
        # Store those histories
        train_errors_histories.append(train_errors)
        test_errors_histories.append(test_errors)
        # Early stopping: keep the epoch with the smallest test error (first one on ties)
        best_epoch = test_errors.index(min(test_errors))
        min_test_errors.append(test_errors[best_epoch])
        corresponding_train_errors.append(train_errors[best_epoch])
    
    # Plot figure
    if plot_title is not None and plot_file_path is not None:
        plots.plot_errors(train_errors_histories, test_errors_histories, plot_title, plot_file_path)
    
    # Compute and return mean/std of min test error and its corresponding train error    
    return (np.mean(min_test_errors), np.std(min_test_errors), 
            np.mean(corresponding_train_errors), np.std(corresponding_train_errors))


# Fixed seed, re-applied at the start of each experiment so that experiments can be enabled or
# disabled independently without changing each other's results.
random_seed = 42

if run_PairSetup_SimpleLinear :
    print('******************** Running SimpleLinear model (for PairSetup) ********************')
    torch.manual_seed(random_seed)
    in_dim = 14*14*2  # one pair = two 14x14 pictures (presumably flattened by the model)
    out_dim = 2       # two classes, matching the hot-encoded targets
    model = models.SimpleLinear(in_dim, out_dim)
    test_err_mean, test_err_std, train_err_mean, train_err_std = rounds_train(
        model,
        plot_title='Linear Train & Test errors',
        plot_file_path='./plots/pairSetup_SimpleLinear.eps',
        verbose=True)
    # Report the statistics instead of silently discarding them
    print('min test error: {:.2f}% (std {:.2f}), corresponding train error: {:.2f}% (std {:.2f})'
          .format(test_err_mean, test_err_std, train_err_mean, train_err_std))
    print('done')

if run_PairSetup_MLP :
    print('******************** Running MLP model (for PairSetup) ********************')
    torch.manual_seed(random_seed)
    # TODO: build and evaluate the MLP model (not implemented yet)
    

---
# **OLD (DRAFT)**
---

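Goal: predict whether the pair's first digit is less than or equal to the second (target = 1) or greater than the second (target = 0). (`prologue.generate_pair_sets` computes the target as `classes[:, 0] <= classes[:, 1]`.)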

Ideas to test (architectures and training choices):
- Simple MLP (fully connected)
- LeNet5
- AlexNet
- VGGNet19
- Residual Net
- Use cross-entropy loss
- Use dropout

General frameworks to test:
- The network is trained to directly predict "less than or equal" vs. "greater"
- The network is trained to predict each digit, and the two predicted digits are then compared

In this notebook we explore only the first framework.

# Dataset

In [None]:
import torch
from torch import Tensor
from torch import nn
from torch.utils.data import Dataset, DataLoader
import dlc_practical_prologue as prologue

import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

# Seed PyTorch's random number generator once so that every run of this notebook is reproducible
random_seed = 42
torch.manual_seed(seed=random_seed)

In [None]:
class PairsDataset(Dataset):
    """
    Dataset of MNIST train pairs; each item is a dict {'pair': ..., 'target': ...}.
    Arguments:
        - train_input: tensor of size [N, 2, ...] holding the N pairs (pair axis is dim 1)
        - train_target: hot-encoded targets of size [N, 2]
        - augment_data: if True, the N inversed pairs are appended after the original ones
        - train_classes: optional tensor of size [N, 2] with the digit class of both pictures of
        each pair. Only used with augment_data: an inversed pair of two equal digits keeps its
        original target instead of the swapped one. If None, every inversed target is swapped.
    """
    
    def __init__(self, train_input, train_target, augment_data=False, train_classes=None):
        if augment_data :
            # TODO : Full data augmentation instead (randomized pairs)
            # Create the inversed pairs (data augmentation)
            train_input_rev = train_input[:,[1,0],:,:]
            # Two different digits: inversing the pair flips the answer, hence the swapped
            # columns (advanced indexing returns a copy, so train_target is not modified below)
            train_target_rev = train_target[:,[1,0]]
            if train_classes is not None:
                # (a, a) inversed is still (a, a): its target must stay unchanged
                same_digit = train_classes[:, 0] == train_classes[:, 1]
                train_target_rev[same_digit] = train_target[same_digit]
            self.train_input = torch.cat((train_input,train_input_rev))
            self.train_target = torch.cat((train_target,train_target_rev))
        else :
            self.train_input = train_input
            self.train_target = train_target

    def __len__(self):
        return len(self.train_input)

    def __getitem__(self, idx):
        return {'pair': self.train_input[idx], 'target': self.train_target[idx]}

In [None]:
def hot_encode(data):
    """ 2-class hot encoding of target: 0 -> [1, 0], 1 -> [0, 1] (float tensor of size [N, 2]) """
    values = data.view(-1)
    columns = (values == 0, values == 1)
    return torch.stack(columns, dim=1).float()

In [None]:
def hot_deencode(data):
    """ From hot-encoding back to class labels: index of each row's largest column, as long """
    labels = data.argmax(1)
    return labels.to(dtype=torch.long)

# Models

# Train & Test

In [None]:
def compute_errors(output, target):
    """ Error percentage: share of rows whose argmax differs between output and target, x100 """
    mismatches = output.argmax(dim=1) != target.argmax(dim=1)
    n_errors = int(mismatches.sum().item())
    return n_errors / output.shape[0] * 100

In [None]:
def train(model, train_input, train_target, test_input, test_target, 
          use_crossentropy = False, lr=1e-3, epochs = 200, verbose=False, batch_size=200) :
    """
    Train `model` with Adam on (train_input, train_target) and record the train and test
    error % after every epoch.
    Arguments:
        - use_crossentropy: if True use CrossEntropyLoss (hot-encoded targets are decoded to
        class labels), otherwise MSELoss on the hot-encoded targets
        - lr: Adam learning rate; epochs: number of passes over the training data
        - verbose: if True print a dot per epoch (and a newline at the end)
        - batch_size: size of the shuffled mini-batches
    Returns:
        - (train_errors, test_errors): one error % per epoch for each set
    The model is in training mode for the updates and in evaluation mode (dropout off) while
    the errors are measured; its original mode is restored before returning.
    """
    was_training = model.training
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr) 
    criterion = nn.CrossEntropyLoss() if use_crossentropy else nn.MSELoss()
    
    pairsDataset = PairsDataset(train_input, train_target)
    dataloader = DataLoader(pairsDataset, batch_size=batch_size, shuffle=True, 
                            num_workers=0)
    train_errors = []
    test_errors = []
    
    for _ in range(epochs):
        if verbose :
            print('.', end='')
        model.train()
        for batch in dataloader:
            inputPairs = batch['pair']
            # If we use crossentropy, then we don't want hot-encoding of target class but 
            # directly their class.
            target = hot_deencode(batch['target']) if use_crossentropy else batch['target']
            # Forward pass
            output = model(inputPairs)
            # Compute loss
            loss = criterion(output, target)
            # Backprop & update parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        # Measure errors in evaluation mode (dropout disabled, batch-norm running stats)
        model.eval()
        with torch.no_grad():
            # Compute train error
            train_errors.append(compute_errors(model(train_input), train_target))
            # Compute test error
            test_errors.append(compute_errors(model(test_input), test_target))

    if verbose:
        print()
    model.train(was_training)
    return train_errors, test_errors

In [None]:
def plot_errors(train_errors, test_errors, title):
    """
    Plot the per-epoch train and test error % curves, then print the last and the smallest
    test error %.
    """
    plt.title(title)
    plt.xlabel('epoch')
    plt.ylabel('error %')
    plt.plot(train_errors, label='train error %')
    plt.plot(test_errors, label='test error %')
    plt.ylim(0, 60)
    plt.legend()
    plt.grid()
    plt.show()
    # These values come from test_errors, so they must be labelled as test errors
    print('Last test error : {}'.format(test_errors[-1]))
    print('Smallest test error : {}'.format(min(test_errors)))

In [None]:
def multiple_rounds_train(model, rounds=10, N=1000, augment_data=False, use_crossentropy=True, 
                          lr=1e-3, epochs=200, verbose=False):
    """
    Train `model` `rounds` times on freshly generated data (N pairs) with freshly re-initialized
    weights, keep for each round the smallest test error (early stopping) and the train error at
    that same epoch, and return (mean, std) of both: (test_mean, test_std, train_mean, train_std).
    Every sub-module defining reset_parameters() is re-initialized at the start of each round,
    so that the rounds are independent trainings.
    If augment_data is True, training is done (and train errors measured) on the pairs plus
    their inversed versions (see PairsDataset).
    """
    min_test_errors = []
    corresponding_train_errors = []
    
    for i in range(rounds):
        if verbose :
            print('round {}'.format(i+1))
        # Fresh weights for this round
        for module in model.modules():
            reset = getattr(module, 'reset_parameters', None)
            if callable(reset):
                reset()
        # Load new data (data is randomized at each round as required in the project instructions)
        train_input, train_target, _ , test_input, test_target, _ = prologue.generate_pair_sets(N)
        # Hot encoding (target will be de-encoded if the loss is cross-entropy)
        train_target = hot_encode(train_target)
        test_target = hot_encode(test_target)
        # Build the (optionally augmented) training set and train on its tensors
        pairsDataset = PairsDataset(train_input, train_target, augment_data=augment_data)
        train_errors, test_errors = train(model, pairsDataset.train_input,
                                          pairsDataset.train_target, test_input, test_target,
                                          use_crossentropy=use_crossentropy, lr=lr, epochs=epochs)
        # Early stopping: keep the epoch with the smallest test error (first one on ties)
        best_epoch = test_errors.index(min(test_errors))
        min_test_errors.append(test_errors[best_epoch])
        corresponding_train_errors.append(train_errors[best_epoch])
    
    # Compute and return mean/std of min test error and its corresponding train error    
    return (np.mean(min_test_errors), np.std(min_test_errors), 
            np.mean(corresponding_train_errors), np.std(corresponding_train_errors))

In [None]:
# One fresh random draw of N train/test pairs (data is randomized as required by the project)
N = 1000
(train_input, train_target, _, test_input, test_target, _) = prologue.generate_pair_sets(N)
# Hot-encode both target sets (train() decodes them again when cross-entropy is used)
train_target, test_target = hot_encode(train_target), hot_encode(test_target)
# Dataset wrapper over the training pairs, without augmentation
pairsDataset = PairsDataset(train_input, train_target, augment_data=False)

In [None]:
# Single MLP run with example hyper-parameters: L hidden layers of h neurons each (values taken
# from the grid explored below). They must be bound here, otherwise this cell raises a NameError
# when the notebook is run top to bottom.
# NOTE(review): MLP is not defined in this notebook; presumably it comes from the models module.
L, h = 4, 16
model = MLP(L, h)
train_errors, test_errors = train(model, train_input, train_target, test_input, test_target, 
                                  use_crossentropy=False)
plot_errors(train_errors, test_errors, 'MLP, {} hidden layers, {} neurons per layer'.format(L, h))

In [None]:
# MLP models: grid search over the depth L (hidden layers) and the width h (neurons per layer).
# Result tables are indexed [h index, L index].
L_range = list(range(2, 17, 2))
h_range = list(range(4, 34, 4))
test_error_means = np.zeros((len(h_range), len(L_range)))
test_error_std = np.zeros_like(test_error_means)

print('********** MLP models **********')
for L_idx, L in enumerate(L_range):
    for h_idx, h in enumerate(h_range):
        print('testing with L = {} and h = {}'.format(L, h))
        model = MLP(L, h)
        test_err_mean, test_err_std, _, _ = multiple_rounds_train(model, use_crossentropy=False)
        test_error_means[h_idx, L_idx], test_error_std[h_idx, L_idx] = test_err_mean, test_err_std

In [None]:
def plot_error_table(h_range, L_range, table_mean, table_std, title):
    """
    Show table_mean as an image (rows: h values, columns: L values) and annotate every cell
    with its mean (white, anchored above the cell centre) and its std in parentheses
    (semi-transparent white, anchored below the cell centre), both with one decimal.
    """
    fig, ax = plt.subplots(figsize=(6,6))
    ax.imshow(table_mean)

    column_labels = ['L = {}'.format(L) for L in L_range]
    row_labels = ['h = {}'.format(h) for h in h_range]
    ax.set_xticks(np.arange(len(column_labels)))
    ax.set_yticks(np.arange(len(row_labels)))
    ax.set_xticklabels(column_labels)
    ax.set_yticklabels(row_labels)

    # Slanted column labels, right-aligned on their tick
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')

    for row in range(len(row_labels)):
        for col in range(len(column_labels)):
            ax.text(col, row, '{:.1f}'.format(table_mean[row, col]),
                    ha='center', va='bottom', color='w')
            ax.text(col, row, '({:.1f})'.format(table_std[row, col]),
                    ha='center', va='top', color='#FFFFFF80')

    ax.set_title(title)
    fig.tight_layout()
    plt.show()
    
plot_error_table(h_range, L_range, test_error_means, 
                 test_error_std, 'MLP Min-Error % Mean & std for L hidden layers with h neurons each')

In [None]:
# Reset the seed before the experiments below so that they can be run in any order
# without impacting reproducibility
torch.manual_seed(random_seed)

In [None]:
# LeNet5Like model, trained once without and once with dropout (cross-entropy loss)
for dropout in (False, True):
    model = LeNet5Like(dropout=dropout)
    train_errors, test_errors = train(model,
                                      train_input, train_target,
                                      test_input, test_target,
                                      use_crossentropy=True)
    plot_errors(train_errors, test_errors, 'LeNet5Like, DropOut = %s' % dropout)
    # Drop the reference before the next configuration is built
    del model

In [None]:
# Custom "LeNet3" model, trained with cross-entropy loss
model = LeNet3()
train_errors, test_errors = train(model, train_input, train_target, test_input, test_target,
                                  use_crossentropy = True)
# The title only names the model: the module-level L and h are leftovers from the MLP grid
# search and do not describe LeNet3
plot_errors(train_errors, test_errors, 'LeNet3')
del model

In [None]:
# VGGNetLike model, trained once without and once with dropout (cross-entropy loss)
for dropout in (False, True):
    model = VGGNetLike(dropout=dropout)
    train_errors, test_errors = train(model,
                                      train_input, train_target,
                                      test_input, test_target,
                                      use_crossentropy=True)
    plot_errors(train_errors, test_errors, 'VGGNetLike, DropOut = %s' % dropout)
    # Drop the reference before the next configuration is built
    del model

In [None]:
# Custom "ConvResNet" model, trained with cross-entropy loss
model = ConvResNet()
train_errors, test_errors = train(model, train_input, train_target, test_input, test_target,
                                  use_crossentropy = True)
# The title only names the model: the module-level L and h are leftovers from the MLP grid
# search and do not describe ConvResNet
plot_errors(train_errors, test_errors, 'ConvResNet')
del model