- QUESTION : NORMALIZE AT K FOLD OR NORMALIZE WHOLE ?

In [1]:
import numpy as np

import torch
from torch import nn
from torch.autograd import Variable

# customized libraries
import dlc_bci as bci
import plot_lib as plib
import preprocess as prep
from nn_models import ConvNet2, LSTM
from dlc_practical_prologue import *

if __name__ == "__main__":

    # load dataset
    tr_input_org, tr_target_org = bci.load("bci", train=True, one_khz=False)
    te_input_org, te_target_org = bci.load("bci", train=False, one_khz=False)

    # create outputs with one hot encoding
    tr_target_onehot = convert_to_one_hot_labels(tr_input_org, tr_target_org)
    te_target_onehot = convert_to_one_hot_labels(te_input_org, te_target_org)
    
    # convert output to variable
    tr_target_onehot = Variable(tr_target_onehot)
    te_target_onehot = Variable(te_target_onehot)
    tr_target = Variable(tr_target_org)
    te_target = Variable(te_target_org)

    # normalization
    tr_input_org = torch.nn.functional.normalize(tr_input_org, p=2, dim=0) 
    te_input_org = torch.nn.functional.normalize(te_input_org, p=2, dim=0) 

    # Convert to 4D tensor for CNN [dataset size, number of channels, rows, cols]
    tr_input = tr_input_org[:, np.newaxis, :, :]
    te_input = te_input_org[:, np.newaxis, :, :]

    # convert input to pytorch variable
    tr_input = Variable(tr_input)
    te_input = Variable(te_input)

    # load best convolutional neural network
    # ---------------------------------------
    model = ConvNet2()
    model.load_state_dict(torch.load('models/best_conv2.pth'))

    # compute accuracy
    nb_errors = bci.compute_nb_errors(model, te_input, te_target_onehot, 10)
    print('accuracy = {:0.2f}'.format((te_input.shape[0]-nb_errors)/te_input.shape[0]))
    
    # compute accuracy of best performing LSTM
    # ---------------------------------------
    
    # rearrange data for LSTM [dataset size, timesteps, channels]
    te_input = np.transpose(te_input_org, (0,2,1))
    te_input = Variable(te_input)
    
    model = LSTM(feature_dim = 28, hidden_size=25, batch_size=10)
    model.load_state_dict(torch.load("models/best_lstm.pth"))
    
    # compute accuracy
    nb_errors = bci.compute_nb_errors(model, te_input, te_target_onehot, 10)
    print('accuracy = {:0.2f}'.format((te_input.shape[0]-nb_errors)/te_input.shape[0]))
    
    # --------------------------------------------------------------------------
    # We will train a Long-Short-Term-Memory (LSTM) architecture on the data
    # --------------------------------------------------------------------------
    
    # 1. Prepare data
    # ------------
    
    # Append training set by replicating 34 random samples since the batch size the PyTorch LSTM cannot vary
    tr_input, tr_target = prep.replicate_samples(tr_input_org, tr_target_org, 34)
        
    # Rearrange data for the LSTM to be (dataset size, timesteps, channels)
    tr_input = np.transpose(tr_input, (0,2,1))
    te_input = np.transpose(te_input_org, (0,2,1))
    
    # Prepare one-hot encoding for PyTorch computation of loss
    tr_target_onehot = convert_to_one_hot_labels(tr_input_org, tr_target)
    te_target_onehot = convert_to_one_hot_labels(te_input_org, te_target_org)
    
    # Convert to pytorch variable
    tr_input, tr_target, tr_target_onehot = Variable(tr_input), Variable(tr_target), Variable(tr_target_onehot)
    te_input, te_target, te_target_onehot = Variable(te_input), Variable(te_target_org), Variable(te_target_onehot)
    
    # 2. Begin the first phase of training
    #    where we collect the epochs that correspond to the max validation accuracy in each iteration
    # --------------------------------------------------------------------------------------------
    
    # Get the indices for 5-fold cross validation
    tr_indices, val_indices = prep.cross_validation_batch(tr_input.shape[0], 5) 
    
    # collect the epoch that corresponds to the best validation accuracy at each iterationb
    epochs = []
    
    # iterate through the disjoint sets
    for tr_ind, val_ind in zip(tr_indices, val_indices):
        
        # create LSTM
        model = LSTM(feature_dim = 28, hidden_size=25, batch_size=10)

        # train LSTM
        tr_acc, te_acc = bci.train_model_haziq(model, 
                                           tr_input[tr_ind.numpy()[None]],  tr_target[tr_ind.numpy()[None]],  tr_target_onehot[tr_ind.numpy()[None]],  10, 
                                           tr_input[val_ind.numpy()[None]], tr_target[val_ind.numpy()[None]], tr_target_onehot[val_ind.numpy()[None]], 10, 
                                           200)
    
        # get epoch corresponding to the best validation accuracy
        epoch_max_val_acc = np.argmax(te_acc)

        # show best performing model
        print('tr acc {:.2f}% val acc {:.2f}% at epoch {:d}'.format(tr_acc[epoch_max_val_acc]*100, 
                                                                    te_acc[epoch_max_val_acc]*100,
                                                                    epoch_max_val_acc))

        epochs.append(epoch_max_val_acc)

    # 3. Begin the second and final phase of training
    #   where we train the model with the average of the epochs collected beforehand
    # --------------------------------------------------------------------------------------------
    
    # create LSTM
    model = LSTM(feature_dim = 28, hidden_size=25, batch_size=10)

    # train LSTM
    tr_acc, te_acc = bci.train_model_haziq(model, 
                                      tr_input,  tr_target,  tr_target_onehot,  10, 
                                      te_input,  te_target,  te_target_onehot,  10, 
                                      int(np.average(epochs)))
    
    # 4. Compute accuracy of final model
    # -------------------------------
    nb_errors = bci.compute_nb_errors(model, te_input, te_target_onehot, 10)
    print('accuracy = {:0.2f}'.format((te_input.shape[0]-nb_errors)/te_input.shape[0]))
        

accuracy = 0.85
accuracy = 0.77
epoch 0 tr loss 19.41 val loss 4.86 tr acc 0.521429 val acc 0.485714
epoch 1 tr loss 19.36 val loss 4.85 tr acc 0.521429 val acc 0.528571
epoch 2 tr loss 19.34 val loss 4.85 tr acc 0.578571 val acc 0.485714
epoch 3 tr loss 19.31 val loss 4.85 tr acc 0.571429 val acc 0.528571
epoch 4 tr loss 19.29 val loss 4.85 tr acc 0.553571 val acc 0.542857
epoch 5 tr loss 19.26 val loss 4.85 tr acc 0.557143 val acc 0.542857
epoch 6 tr loss 19.21 val loss 4.85 tr acc 0.557143 val acc 0.514286
epoch 7 tr loss 19.13 val loss 4.90 tr acc 0.560714 val acc 0.528571
epoch 8 tr loss 19.11 val loss 4.89 tr acc 0.575000 val acc 0.571429
epoch 9 tr loss 18.98 val loss 4.92 tr acc 0.578571 val acc 0.571429
epoch 10 tr loss 18.92 val loss 4.94 tr acc 0.589286 val acc 0.557143
epoch 11 tr loss 18.88 val loss 4.99 tr acc 0.582143 val acc 0.571429
epoch 12 tr loss 19.20 val loss 4.83 tr acc 0.571429 val acc 0.485714
epoch 13 tr loss 19.03 val loss 4.81 tr acc 0.596429 val acc 0.55714

epoch 117 tr loss 9.18 val loss 3.53 tr acc 0.882143 val acc 0.757143
epoch 118 tr loss 10.16 val loss 3.30 tr acc 0.860714 val acc 0.757143
epoch 119 tr loss 9.25 val loss 3.58 tr acc 0.857143 val acc 0.742857
epoch 120 tr loss 8.97 val loss 3.81 tr acc 0.875000 val acc 0.757143
epoch 121 tr loss 9.16 val loss 3.93 tr acc 0.889286 val acc 0.757143
epoch 122 tr loss 8.61 val loss 4.05 tr acc 0.882143 val acc 0.728571
epoch 123 tr loss 8.36 val loss 4.11 tr acc 0.885714 val acc 0.728571
epoch 124 tr loss 8.44 val loss 4.06 tr acc 0.889286 val acc 0.728571
epoch 125 tr loss 8.50 val loss 3.90 tr acc 0.892857 val acc 0.742857
epoch 126 tr loss 8.32 val loss 3.64 tr acc 0.882143 val acc 0.728571
epoch 127 tr loss 8.44 val loss 3.40 tr acc 0.871429 val acc 0.771429
epoch 128 tr loss 8.90 val loss 3.55 tr acc 0.860714 val acc 0.771429
epoch 129 tr loss 8.99 val loss 3.83 tr acc 0.875000 val acc 0.757143
epoch 130 tr loss 8.56 val loss 3.71 tr acc 0.871429 val acc 0.742857
epoch 131 tr loss 8

epoch 34 tr loss 15.69 val loss 4.17 tr acc 0.728571 val acc 0.642857
epoch 35 tr loss 15.40 val loss 4.02 tr acc 0.742857 val acc 0.657143
epoch 36 tr loss 15.09 val loss 4.03 tr acc 0.739286 val acc 0.657143
epoch 37 tr loss 14.88 val loss 4.45 tr acc 0.692857 val acc 0.628571
epoch 38 tr loss 16.39 val loss 4.05 tr acc 0.728571 val acc 0.728571
epoch 39 tr loss 15.10 val loss 4.10 tr acc 0.721429 val acc 0.657143
epoch 40 tr loss 14.73 val loss 3.66 tr acc 0.750000 val acc 0.742857
epoch 41 tr loss 14.54 val loss 3.88 tr acc 0.735714 val acc 0.700000
epoch 42 tr loss 14.05 val loss 3.61 tr acc 0.753571 val acc 0.757143
epoch 43 tr loss 14.76 val loss 3.70 tr acc 0.782143 val acc 0.742857
epoch 44 tr loss 14.23 val loss 3.71 tr acc 0.792857 val acc 0.742857
epoch 45 tr loss 13.21 val loss 3.78 tr acc 0.760714 val acc 0.771429
epoch 46 tr loss 13.12 val loss 3.83 tr acc 0.750000 val acc 0.728571
epoch 47 tr loss 14.01 val loss 4.08 tr acc 0.739286 val acc 0.657143
epoch 48 tr loss 15.

epoch 152 tr loss 8.16 val loss 4.08 tr acc 0.892857 val acc 0.814286
epoch 153 tr loss 10.61 val loss 4.64 tr acc 0.807143 val acc 0.728571
epoch 154 tr loss 11.40 val loss 3.86 tr acc 0.867857 val acc 0.771429
epoch 155 tr loss 8.35 val loss 3.62 tr acc 0.900000 val acc 0.814286
epoch 156 tr loss 8.10 val loss 4.06 tr acc 0.867857 val acc 0.757143
epoch 157 tr loss 8.27 val loss 4.39 tr acc 0.878571 val acc 0.771429
epoch 158 tr loss 7.82 val loss 4.06 tr acc 0.892857 val acc 0.771429
epoch 159 tr loss 7.25 val loss 4.49 tr acc 0.885714 val acc 0.728571
epoch 160 tr loss 7.37 val loss 4.12 tr acc 0.885714 val acc 0.757143
epoch 161 tr loss 6.96 val loss 3.95 tr acc 0.889286 val acc 0.757143
epoch 162 tr loss 7.90 val loss 3.98 tr acc 0.875000 val acc 0.742857
epoch 163 tr loss 8.24 val loss 3.94 tr acc 0.882143 val acc 0.742857
epoch 164 tr loss 7.32 val loss 3.71 tr acc 0.896429 val acc 0.785714
epoch 165 tr loss 7.27 val loss 3.77 tr acc 0.910714 val acc 0.800000
epoch 166 tr loss 

epoch 69 tr loss 15.44 val loss 3.83 tr acc 0.735714 val acc 0.700000
epoch 70 tr loss 14.91 val loss 3.61 tr acc 0.814286 val acc 0.800000
epoch 71 tr loss 13.16 val loss 3.21 tr acc 0.803571 val acc 0.828571
epoch 72 tr loss 13.72 val loss 3.57 tr acc 0.796429 val acc 0.742857
epoch 73 tr loss 13.26 val loss 3.26 tr acc 0.810714 val acc 0.785714
epoch 74 tr loss 12.58 val loss 3.37 tr acc 0.796429 val acc 0.785714
epoch 75 tr loss 14.05 val loss 4.06 tr acc 0.771429 val acc 0.700000
epoch 76 tr loss 13.23 val loss 3.21 tr acc 0.810714 val acc 0.828571
epoch 77 tr loss 12.39 val loss 3.31 tr acc 0.810714 val acc 0.814286
epoch 78 tr loss 13.37 val loss 3.32 tr acc 0.792857 val acc 0.742857
epoch 79 tr loss 12.69 val loss 3.00 tr acc 0.817857 val acc 0.828571
epoch 80 tr loss 12.45 val loss 3.52 tr acc 0.803571 val acc 0.828571
epoch 81 tr loss 12.98 val loss 3.29 tr acc 0.807143 val acc 0.828571
epoch 82 tr loss 12.69 val loss 3.21 tr acc 0.800000 val acc 0.842857
epoch 83 tr loss 11.

epoch 186 tr loss 8.32 val loss 2.96 tr acc 0.896429 val acc 0.814286
epoch 187 tr loss 7.83 val loss 3.25 tr acc 0.896429 val acc 0.828571
epoch 188 tr loss 7.15 val loss 3.58 tr acc 0.885714 val acc 0.814286
epoch 189 tr loss 7.07 val loss 4.56 tr acc 0.900000 val acc 0.785714
epoch 190 tr loss 6.77 val loss 4.73 tr acc 0.900000 val acc 0.785714
epoch 191 tr loss 6.97 val loss 5.18 tr acc 0.900000 val acc 0.785714
epoch 192 tr loss 7.13 val loss 3.96 tr acc 0.892857 val acc 0.828571
epoch 193 tr loss 7.87 val loss 4.87 tr acc 0.903571 val acc 0.842857
epoch 194 tr loss 7.07 val loss 4.17 tr acc 0.914286 val acc 0.785714
epoch 195 tr loss 7.37 val loss 4.78 tr acc 0.903571 val acc 0.828571
epoch 196 tr loss 6.61 val loss 4.37 tr acc 0.892857 val acc 0.828571
epoch 197 tr loss 7.50 val loss 5.10 tr acc 0.892857 val acc 0.814286
epoch 198 tr loss 6.87 val loss 4.26 tr acc 0.875000 val acc 0.771429
epoch 199 tr loss 7.37 val loss 5.12 tr acc 0.885714 val acc 0.800000
tr acc 81.79% val ac

epoch 103 tr loss 13.17 val loss 3.54 tr acc 0.810714 val acc 0.714286
epoch 104 tr loss 10.83 val loss 3.46 tr acc 0.839286 val acc 0.785714
epoch 105 tr loss 11.06 val loss 3.35 tr acc 0.810714 val acc 0.757143
epoch 106 tr loss 12.83 val loss 3.29 tr acc 0.828571 val acc 0.728571
epoch 107 tr loss 11.65 val loss 3.38 tr acc 0.821429 val acc 0.742857
epoch 108 tr loss 11.97 val loss 3.45 tr acc 0.817857 val acc 0.757143
epoch 109 tr loss 10.63 val loss 3.32 tr acc 0.810714 val acc 0.771429
epoch 110 tr loss 11.16 val loss 3.49 tr acc 0.835714 val acc 0.742857
epoch 111 tr loss 10.31 val loss 3.19 tr acc 0.782143 val acc 0.728571
epoch 112 tr loss 11.94 val loss 3.43 tr acc 0.760714 val acc 0.785714
epoch 113 tr loss 11.95 val loss 3.52 tr acc 0.835714 val acc 0.742857
epoch 114 tr loss 11.66 val loss 3.61 tr acc 0.817857 val acc 0.771429
epoch 115 tr loss 11.76 val loss 3.48 tr acc 0.842857 val acc 0.757143
epoch 116 tr loss 10.48 val loss 3.21 tr acc 0.796429 val acc 0.757143
epoch 

epoch 20 tr loss 17.83 val loss 4.56 tr acc 0.657143 val acc 0.671429
epoch 21 tr loss 18.04 val loss 4.49 tr acc 0.739286 val acc 0.671429
epoch 22 tr loss 17.46 val loss 4.48 tr acc 0.657143 val acc 0.657143
epoch 23 tr loss 17.84 val loss 4.43 tr acc 0.739286 val acc 0.671429
epoch 24 tr loss 17.15 val loss 4.40 tr acc 0.675000 val acc 0.657143
epoch 25 tr loss 17.46 val loss 4.33 tr acc 0.689286 val acc 0.671429
epoch 26 tr loss 17.06 val loss 4.30 tr acc 0.707143 val acc 0.700000
epoch 27 tr loss 16.87 val loss 4.33 tr acc 0.735714 val acc 0.714286
epoch 28 tr loss 16.70 val loss 4.30 tr acc 0.728571 val acc 0.685714
epoch 29 tr loss 16.78 val loss 4.28 tr acc 0.707143 val acc 0.700000
epoch 30 tr loss 16.45 val loss 4.36 tr acc 0.714286 val acc 0.671429
epoch 31 tr loss 16.68 val loss 4.62 tr acc 0.660714 val acc 0.585714
epoch 32 tr loss 16.75 val loss 4.21 tr acc 0.710714 val acc 0.700000
epoch 33 tr loss 16.17 val loss 4.29 tr acc 0.721429 val acc 0.714286
epoch 34 tr loss 16.

epoch 137 tr loss 9.45 val loss 3.08 tr acc 0.889286 val acc 0.828571
epoch 138 tr loss 8.89 val loss 3.17 tr acc 0.846429 val acc 0.771429
epoch 139 tr loss 9.30 val loss 4.28 tr acc 0.846429 val acc 0.714286
epoch 140 tr loss 9.26 val loss 3.17 tr acc 0.882143 val acc 0.814286
epoch 141 tr loss 8.83 val loss 3.50 tr acc 0.896429 val acc 0.714286
epoch 142 tr loss 8.29 val loss 2.99 tr acc 0.871429 val acc 0.828571
epoch 143 tr loss 8.44 val loss 3.02 tr acc 0.850000 val acc 0.757143
epoch 144 tr loss 9.05 val loss 2.80 tr acc 0.871429 val acc 0.814286
epoch 145 tr loss 8.38 val loss 2.93 tr acc 0.857143 val acc 0.785714
epoch 146 tr loss 8.92 val loss 3.03 tr acc 0.842857 val acc 0.771429
epoch 147 tr loss 8.73 val loss 3.23 tr acc 0.867857 val acc 0.714286
epoch 148 tr loss 8.70 val loss 4.20 tr acc 0.875000 val acc 0.728571
epoch 149 tr loss 9.10 val loss 3.23 tr acc 0.896429 val acc 0.728571
epoch 150 tr loss 8.21 val loss 2.85 tr acc 0.882143 val acc 0.785714
epoch 151 tr loss 8.

epoch 54 tr loss 13.21 val loss 7.54 tr acc 0.840000 val acc 0.760000
epoch 55 tr loss 14.31 val loss 7.52 tr acc 0.822857 val acc 0.650000
epoch 56 tr loss 13.91 val loss 8.37 tr acc 0.825714 val acc 0.690000
epoch 57 tr loss 13.21 val loss 8.36 tr acc 0.845714 val acc 0.740000
epoch 58 tr loss 13.23 val loss 8.04 tr acc 0.851429 val acc 0.690000
epoch 59 tr loss 12.48 val loss 8.01 tr acc 0.845714 val acc 0.680000
epoch 60 tr loss 12.52 val loss 8.09 tr acc 0.857143 val acc 0.720000
epoch 61 tr loss 12.91 val loss 8.39 tr acc 0.865714 val acc 0.660000
epoch 62 tr loss 12.47 val loss 8.13 tr acc 0.845714 val acc 0.710000
epoch 63 tr loss 12.17 val loss 8.54 tr acc 0.854286 val acc 0.660000
epoch 64 tr loss 12.51 val loss 9.05 tr acc 0.860000 val acc 0.630000
epoch 65 tr loss 12.33 val loss 8.65 tr acc 0.860000 val acc 0.680000
epoch 66 tr loss 12.05 val loss 9.12 tr acc 0.862857 val acc 0.650000
epoch 67 tr loss 12.08 val loss 9.35 tr acc 0.862857 val acc 0.640000
epoch 68 tr loss 11.