---
# **NEW (future train.py)**
---

In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import dlc_practical_prologue as prologue
import numpy as np
# Our modules
import models # contains all our torch models classes
import plots # custom ploting functions to produce figures of the report
import training_functions # all our functions and classes for training
%matplotlib inline


# Experiment switches: set a flag to True to run the matching experiment below.
# Running them all takes a long time, so enable only the ones you need.

# Compares the MSE loss against the cross-entropy loss on a linear model,
# for both the "PairSetup" and the "AuxiliarySetup".
run_LossCompare = False
# Experiments for "PairSetup": in this setup the pairs are fed directly to the
# network. The inputs are N samples of shape [2,14,14] (two 14*14 pictures) and
# the output is a class, 0 or 1, indicating whether the first digit is less
# than or equal to the second.
run_PairSetup_SimpleLinear = True
run_PairSetup_MLP = False
run_PairSetup_LeNetLike5 = False
run_PairSetup_LeNetLike3 = False
# Experiments for "AuxiliarySetup": in this setup the inputs are N individual
# 14*14 pictures. The network uses an auxiliary loss to learn to classify each
# picture as a digit from 0 to 9. The two predicted digits of a pair are then
# compared to reach the original goal, i.e. predicting whether the first digit
# is less than or equal to the second.
# NOTE(review): these two flags are not read anywhere in the visible part of
# the file - confirm whether the matching experiments exist elsewhere.
run_AuxiliarySetup_SimpleLinear = False
run_AuxiliarySetup_MLP = False

# Fixed seed, passed to torch.manual_seed before each experiment so that runs
# are reproducible.
random_seed = 42

def count_parameters(model):
    """Return the number of trainable parameters of the model.

    Walks over `model.parameters()` and adds up `numel()` for every
    parameter whose `requires_grad` attribute is truthy. Parameters that do
    not require a gradient are left out of the count.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total


if run_LossCompare:
    # Loss comparison: train the same linear model with the MSE loss and with
    # the cross-entropy loss, once per setup, and report the test error of each.
    print('******************** Running Loss comparaison ********************')

    # One row per experiment:
    # (setup name, loss label, input size, output size, plot file, use_crossentropy)
    loss_experiments = [
        ('PairSetup', 'MSE', 14*14*2, 2, './plots/test1.svg', False),
        ('PairSetup', 'cross-entropy', 14*14*2, 2, './plots/test2.svg', True),
        ('AuxiliarySetup', 'MSE', 14*14, 10, './plots/test1Aux.svg', False),
        ('AuxiliarySetup', 'cross-entropy', 14*14, 10, './plots/test2Aux.svg', True),
    ]

    for setup_name, loss_label, in_dim, out_dim, plot_path, use_ce in loss_experiments:
        print('--- {}, {} loss : ---'.format(setup_name, loss_label))
        # Re-seed before building each model so that every experiment starts
        # from the same random state.
        torch.manual_seed(random_seed)
        model = models.SimpleLinear(in_dim, out_dim)
        test_err_mean, test_err_std, _, _ = training_functions.rounds_train(
            model,
            setup_name,
            plot_title='Test',
            plot_file_path=plot_path,
            lr=0.00004,
            epochs=300,
            use_crossentropy=use_ce)
        print('mean minimum test error : {0:.{1}f} %'.format(test_err_mean,1))
        print('std minimum test error : {0:.{1}f} %'.format(test_err_std,1))
    

if run_PairSetup_SimpleLinear:
    # Linear classifier on the pair setup: train it, then report the mean/std
    # of the minimum test error and the size of the model.
    torch.manual_seed(random_seed)  # reproducible run
    print('******************** Running SimpleLinear model (for PairSetup) ********************')

    # Input size is the pixel count of two 14*14 pictures; two output classes.
    in_dim, out_dim = 14*14*2, 2
    model = models.SimpleLinear(in_dim, out_dim)

    # Keyword options forwarded to rounds_train.
    simple_linear_options = dict(
        plot_title='Pair Setup Linear Classifier error history',
        plot_file_path='./plots/pairSetup_SimpleLinear.svg',
        lr=0.00004,
        epochs=300,
        use_crossentropy=True,
    )
    test_err_mean, test_err_std, _, _ = training_functions.rounds_train(
        model, 'PairSetup', **simple_linear_options)

    print('mean minimum test error : {0:.{1}f} %'.format(test_err_mean,1))
    print('std minimum test error : {0:.{1}f} %'.format(test_err_std,1))
    print('number of trainable parameters : {}'.format(count_parameters(model)))


if run_PairSetup_MLP:
    # Grid search over MLP sizes on the pair setup: every combination of
    # hidden layer count (L) and neurons per layer (h) is trained, and the
    # mean/std of its minimum test error is stored for the final heat table.
    torch.manual_seed(random_seed)
    print('******************** Running MLP model (for PairSetup) ********************')
    in_dim, out_dim = 14*14*2, 2
    # We test for different hidden layer count vs. neurons per layer count
    L_range = list(range(2, 18, 2))
    h_range = list(range(4, 65, 8))
    # Result tables: rows follow h_range, columns follow L_range.
    test_error_means = np.zeros((len(h_range), len(L_range)))
    test_error_std = np.zeros((len(h_range), len(L_range)))
    for L_idx, L in enumerate(L_range):
        for h_idx, h in enumerate(h_range):
            # Re-seed for every configuration so that all of them start from
            # the same random state.
            torch.manual_seed(random_seed)
            print('testing with L = {} and h = {}'.format(L, h))
            model = models.MLP(L, h, in_dim, out_dim)
            # NOTE(review): every configuration is given the same plot path.
            test_err_mean, test_err_std, _, _ = training_functions.rounds_train(
                model,
                'PairSetup',
                lr=0.0001,
                plot_title='test',
                plot_file_path='./plots/test.svg',
                epochs=200,
                use_crossentropy=True,
                rounds=4)
            test_error_means[h_idx, L_idx] = test_err_mean
            test_error_std[h_idx, L_idx] = test_err_std
    # Plot heat table
    plots.plot_error_table(h_range, L_range, test_error_means,
                           test_error_std,
                           'Pair Setup MLP models mean/std minimum test error',
                           './plots/pairSetup_MLP.svg')
    
if run_PairSetup_LeNetLike5:
    # LeNetLike5 model on the pair setup: train it (verbose), then report the
    # mean/std of the minimum test error and the size of the model.
    torch.manual_seed(random_seed)  # reproducible run
    print('******************** Running LetNetLike5 model (for PairSetup) ********************')

    # The model is built from an input depth and an output size.
    in_depth, out_dim = 2, 2
    model = models.LeNetLike5(in_depth, out_dim)

    # Keyword options forwarded to rounds_train.
    lenet5_options = dict(
        plot_title='Pair Setup LeNetLike5 error history',
        plot_file_path='./plots/pairSetup_LeNetLike5.svg',
        lr=0.0001,
        epochs=300,
        use_crossentropy=True,
        verbose=True,
    )
    test_err_mean, test_err_std, _, _ = training_functions.rounds_train(
        model, 'PairSetup', **lenet5_options)

    print('mean minimum test error : {0:.{1}f} %'.format(test_err_mean,1))
    print('std minimum test error : {0:.{1}f} %'.format(test_err_std,1))
    print('number of trainable parameters : {}'.format(count_parameters(model)))
    

if run_PairSetup_LeNetLike3:
    # LeNetLike3 model on the pair setup: train it (verbose), then report the
    # mean/std of the minimum test error and the size of the model.
    torch.manual_seed(random_seed)  # reproducible run
    print('******************** Running LetNetLike3 model (for PairSetup) ********************')

    # The model is built from an input depth and an output size.
    in_depth, out_dim = 2, 2
    model = models.LeNetLike3(in_depth, out_dim)

    # Keyword options forwarded to rounds_train.
    lenet3_options = dict(
        plot_title='Pair Setup LeNetLike3 error history',
        plot_file_path='./plots/pairSetup_LeNetLike3.svg',
        lr=0.00015,
        epochs=300,
        use_crossentropy=True,
        verbose=True,
    )
    test_err_mean, test_err_std, _, _ = training_functions.rounds_train(
        model, 'PairSetup', **lenet3_options)

    print('mean minimum test error : {0:.{1}f} %'.format(test_err_mean,1))
    print('std minimum test error : {0:.{1}f} %'.format(test_err_std,1))
    print('number of trainable parameters : {}'.format(count_parameters(model)))

---
# **OLD (DRAFT)**
---

Goal: predict whether the pair's first digit is less than or equal to the second (= 0) or greater than the second (= 1).

Ideas of architectures to test :
- Simple MLP (fully connected)
- LeNet5
- AlexNet
- VGGNet19
- Residual Net
- Use cross entropy
- Use dropout

General framework to test :
- The network is trained to directly predict "lesser or equal" vs. "greater"
- The network is trained to predict each digit, then we take the difference of the two predictions

In this notebook we explore the first architecture only