### Import library

In [1]:
%matplotlib notebook
# %load_ext autoreload
# %autoreload 2
#%matplotlib qt 


# Setup

In [2]:
# !pip install moabb[full]
# !pip install braindecode

### Select Dataset 


In [3]:
from nu_smrutils import loaddat
import pandas as pd
import pickle
import os
from sklearn.model_selection import train_test_split
import torch
import mne



In [4]:
# !pip install moabb[full]
# !pip install braindecode
# !pip install matplotlib==3.7.1
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117

Looking in indexes: https://download.pytorch.org/whl/cu117




In [5]:
# # itemname is one of : ['BNCI2014004', 'BNCI2014001', 'Weibo2014', 'Physionet']
# itemname = 'BNCI2014004'
# filename = dname[itemname]
# iname = itemname + '__'    

# Load datasets

In [46]:
from moabb.datasets import (
    BNCI2014001, 
    BNCI2014004, 
    PhysionetMI, 
    Weibo2014
)
from moabb.paradigms import LeftRightImagery

# Preprocesses datasets and saves the results locally.
# If local copyis available, returns
class EEGDatasets:
    # 
    def __init__(self, local_dir=None):
        self.local_dir = local_dir or os.path.join(os.getcwd(), 'eeg-data')
    
    def moabb_dataset(self, name, subjects):
        if  name == 'BNCI2014001':
            raw_dataset = BNCI2014001()
            paradigm = LeftRightImagery()
            X, labels, meta = paradigm.get_data(dataset=raw_dataset, subjects=subjects)
            return X, labels
        elif  name == 'BNCI2014004':
            raw_dataset = BNCI2014004()
            paradigm = LeftRightImagery()
            X, labels, meta = paradigm.get_data(dataset=raw_dataset, subjects=subjects)
            return X, labels
        else:
            raise ValueError(f'unknown dataset name {name}')
                    
datasets = EEGDatasets()        
itemname = 'BNCI2014001'
iname = itemname + '__'   
data, labels = datasets.moabb_dataset(itemname, [1])

print(data.shape)
print(labels.shape)

(288, 22, 1001)
(288,)


### Load pooled data


In [48]:
from nu_smrutils import load_pooled, augment_dataset, crop_data

In [49]:
def load_pooled1(data, labels, test_size=0.15):
    """
    Creates pooled data from all subject specific EEG dataset.          

    Parameters:
    -------------------------
    Input: a python list containing MNE EEG data objects. 

    For instance, a list with the following elements:    
    [<Epochs  |   720 events, 'left_hand': 360  'right_hand': 360>,
     <Epochs   |  680 events, 'left_hand': 340, 'right_hand': 340>]

    Returns:
    -------------------------
    A dictionary :
        X_train, X_valid, X_test: 
        np.array of shape >>>  (samples, channel, times), 

        Data labels: 
        y_train, y_valid, y_test
    -------------------------
    output = dict(xtrain = X_train, xvalid = X_valid, xtest = X_test,
                  ytrain = y_train, yvalid = y_valid, ytest = y_test)
    -------------------------    
    """

    X = data
    Y = (labels == 'right_hand').astype(int)

    # split the data using sklearn split function
    x_rest, x_test, y_rest, y_test =\
        train_test_split(X, Y, test_size=test_size, random_state=42,
                         stratify=Y)

    x_train, x_valid, y_train, y_valid =\
        train_test_split(x_rest, y_rest, test_size=0.2, random_state=42,
                         stratify=y_rest)

    # Convert to Pytorch tensors
    X_train, X_valid, X_test = map(torch.FloatTensor,(x_train, x_valid, x_test))
    y_train, y_valid, y_test = map(torch.FloatTensor,(y_train, y_valid, y_test))

    return dict(xtrain=X_train, xvalid=X_valid, xtest=X_test,
                ytrain=y_train, yvalid=y_valid, ytest=y_test)

dat = load_pooled1(data, labels, test_size = 0.15)


### Data augmentation 

In [50]:
print(dat.keys())
dat['xtrain'].shape

dict_keys(['xtrain', 'xvalid', 'xtest', 'ytrain', 'yvalid', 'ytest'])


torch.Size([195, 22, 1001])

In [51]:
augment_dataset?

In [52]:
augdata = dict(std_dev = 0.01, multiple = 2)

In [53]:
xtrain, ytrain = augment_dataset(dat['xtrain'], dat['ytrain'], 
                                 augdata['std_dev'], augdata['multiple'])

print("Shape after data augmentation :", xtrain.shape)
dat['xtrain'], dat['ytrain'] = xtrain, ytrain

Shape after data augmentation : torch.Size([390, 22, 1001])


### Data Cropping

In [54]:
fs = 80 # sampling frequency 
crop_len = 1.5 #or None
crop = dict(fs = fs, crop_len = crop_len)

#if crop['crop_len']:
X_train,y_train = crop_data(crop['fs'],crop['crop_len'], 
                            dat['xtrain'], dat['ytrain'], 
                            xpercent = 50)

X_valid,y_valid = crop_data(crop['fs'],crop['crop_len'], 
                            dat['xvalid'], dat['yvalid'], 
                            xpercent = 50)

X_test, y_test  = crop_data(crop['fs'],crop['crop_len'], 
                            dat['xtest'], dat['ytest'], 
                            xpercent = 50)

dat = dict(xtrain = X_train, xvalid = X_valid, xtest = X_test,
           ytrain = y_train, yvalid = y_valid, ytest = y_test)

In [55]:
print('data shape after cropping :',dat['xtrain'].shape)

data shape after cropping : torch.Size([2730, 22, 180])


### Pytorch dataloaders 

In [56]:
import torch 
from torch.utils.data import TensorDataset, DataLoader  

def get_data_loaders(dat, batch_size, EEGNET = None):    
    # convert data dimensions to into to gray scale image format
    if EEGNET: ### EEGNet model requires the last dimension to be 1 
        ff = lambda dat: torch.unsqueeze(dat, dim = -1)    
    else:
        ff = lambda dat: torch.unsqueeze(dat, dim = 1)    
    
    x_train, x_valid, x_test = map(ff,(dat['xtrain'], dat['xvalid'],dat['xtest']))    
    y_train, y_valid, y_test = dat['ytrain'], dat['yvalid'], dat['ytest']
    print('Input data shape', x_train.shape)       
    
    # TensorDataset & Dataloader    
    train_dat    = TensorDataset(x_train, y_train) 
    val_dat      = TensorDataset(x_valid, y_valid) 
    
    train_loader = DataLoader(train_dat, batch_size = batch_size, shuffle = True, generator=torch.Generator(device='cuda'))
    val_loader   = DataLoader(val_dat,   batch_size = batch_size, shuffle = False, generator=torch.Generator(device='cuda'))

    output = dict(dset_loaders = {'train': train_loader, 'val': val_loader}, 
                  dset_sizes  =  {'train': len(x_train), 'val': len(x_valid)},
                  test_data   =  {'x_test' : x_test, 'y_test' : y_test})          
    return output 

In [57]:
dat = get_data_loaders(dat, batch_size = 64)
dat.keys()

Input data shape torch.Size([2730, 1, 22, 180])


dict_keys(['dset_loaders', 'dset_sizes', 'test_data'])

In [58]:
# Sanity check begin 
dset_loaders = dat['dset_loaders']
dset_sizes = dat['dset_sizes']
dset_sizes

dtrain = dset_loaders['train']
dval   = dset_loaders['val']

dtr = iter(dtrain)
dv  = iter(dval)

In [59]:
inputs, labels = next(dtr)
print(inputs.shape, labels.shape)
# Sanity check end 

torch.Size([64, 1, 22, 180]) torch.Size([64])


## CNN model

In [60]:
import torch.nn as nn
import numpy as np

class CNN2D(torch.nn.Module):  
    def __init__(self, input_size, kernel_size, conv_channels, 
                 dense_size, dropout):         
        super(CNN2D, self).__init__()                  
        self.cconv   = []  
        self.MaxPool = nn.MaxPool2d((1, 2), (1, 2))  
        self.ReLU    = nn.ReLU()
        self.Dropout = nn.Dropout(dropout)        
        self.batchnorm = []                
        # ############ batchnorm ###########
        for jj in conv_channels:
            self.batchnorm.append(nn.BatchNorm2d(jj, eps=0.001, momentum=0.01,
                                                 affine=True, track_running_stats=True).cuda())     
        ii = 0 ##### define CONV layer architecture: #####
        for in_channels, out_channels in zip(conv_channels, conv_channels[1:]):                           
            conv_i = torch.nn.Conv2d(in_channels = in_channels, out_channels = out_channels,
                                     kernel_size = kernel_size[ii], #stride = (1, 2),
                                     padding     = (kernel_size[ii][0]//2, kernel_size[ii][1]//2))            
            self.cconv.append(conv_i)                
            self.add_module('CNN_K{}_O{}'.format(kernel_size[ii], out_channels), conv_i)
            ii += 1                            
        self.flat_dim = self.get_output_dim(input_size, self.cconv)    
        self.fc1 = torch.nn.Linear(self.flat_dim, dense_size)
        self.fc2 = torch.nn.Linear(dense_size, 2)                

    def get_output_dim(self, input_size, cconv):        
        with torch.no_grad():
            input = torch.ones(1,*input_size)              
            for conv_i in cconv:                
                input = self.MaxPool(conv_i(input))        
                flatout = int(np.prod(input.size()[1:]))
                print("Input shape : {} and flattened : {}".format(input.shape, flatout))
        return flatout 
        
    def forward(self, input):        
        for jj, conv_i in enumerate(self.cconv):
            input = conv_i(input)
            input = self.batchnorm[jj+1](input)
            input = self.ReLU(input)        
            input = self.MaxPool(input)                   
        # flatten the CNN output     
        out = input.view(-1, self.flat_dim) 
        out = self.fc1(out)                       
        out = self.Dropout(out)        
        out = self.fc2(out)      
        return out        

### Hyperparameter settings

In [61]:
import torch 
torch.manual_seed(0)

from nu_smrutils import train_model  

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if dev.type == 'cuda':
   print('Your GPU device name :', torch.cuda.get_device_name())

Your GPU device name : NVIDIA GeForce RTX 2080


In [62]:
num_epochs = 150 
learning_rate = 1e-3
weight_decay = 1e-4  
batch_size = 64
verbose = 1

#% used to save the results table 
results = {}        
table = pd.DataFrame(columns = ['Train_Acc', 'Val_Acc', 'Test_Acc', 'Epoch'])   

#% get input size (channel x timepoints)
input_size = (1, dat['test_data']['x_test'].shape[-2], 
                 dat['test_data']['x_test'].shape[-1])
print(input_size)

(1, 22, 180)


Relate the **kernel width** hyperparameter to a temporal window in milliseconds    

- If we want to convolve 100 ms >>> set time_window = 100 #ms
- width = (time_window_of_interest * sampling_frequency)/one_second 

In [63]:
# define kernel size in terms of ms length 
fs = 80 #Hz
time_window = 100 #ms
width = time_window*fs//1000  

# width = 8 #timelength//chans         
# convolution parameters 
h1, w1 = 3, 1
h2, w2 = 3, 3
h3, w3 = 3, 5       

In [64]:
# one should run this script twice with ConvDown = True or False 
# to have different convolutional layer patterns  
ConvDOWN = True  

if ConvDOWN:            
    params = {'conv_channels': [
                                [1, 16, 8],                                               
                                [1, 32, 16, 8],
                                [1, 64, 32, 16, 8],
                                [1, 128, 64, 32, 16, 8],
                                [1, 256, 128, 64, 32, 16, 8]                                     
                                ],                         

              'kernel_size':    [[(h1, w1*width), (h1, w1*width), (h1, w1*width),
                                  (h1, w1*width),(h1, w1*width),(h1, w1*width)],
                                 
                                 [(h2, w2*width), (h2, w2*width), (h2, w2*width),
                                  (h2, w2*width),(h2, w2*width),(h2, w2*width)],
                                 
                                 [(h3, w3*width), (h3, w3*width), (h3, w3*width),
                                  (h3, w3*width),(h3, w3*width),(h3, w3*width)]]                                                                      
              }                      
else:                      
    params = {'conv_channels': [
                                [1, 8, 16],                                                  
                                [1, 8, 16, 32],
                                [1, 8, 16, 32, 64],
                                [1, 8, 16, 32, 64, 128],
                                [1, 8, 16, 32, 64, 128, 256]
                                ],      		

              'kernel_size':    [[(h1, w1*width), (h1, w1*width), (h1, w1*width),
                                  (h1, w1*width),(h1, w1*width),(h1, w1*width)],
                                 
                                 [(h2, w2*width), (h2, w2*width), (h2, w2*width),
                                  (h2, w2*width),(h2, w2*width),(h2, w2*width)],
                                 
                                 [(h3, w3*width), (h3, w3*width), (h3, w3*width),
                                  (h3, w3*width),(h3, w3*width),(h3, w3*width)]]                     
              }    
keys = list(params)

------------
## Training loop 

In [65]:
import itertools

In [None]:
verbose=2

for values in itertools.product(*map(params.get, keys)):
    d = dict(zip(keys, values))    
    description = 'C{}_K{}'.format(d['conv_channels'], d['kernel_size'][0])    
    print('\n\n##### ' + description + ' #####')

    # Define the architecture
    model = CNN2D(input_size    = input_size,
                  kernel_size   = d['kernel_size'], 
                  conv_channels = d['conv_channels'],
                  dense_size    = 256,
                  dropout       = 0.5)               
    print("Model architecture >>>", model)

    # optimizer and the loss function definition 
    optimizer = torch.optim.Adam(model.parameters(), 
                                 lr = learning_rate,
                                 weight_decay = weight_decay)
    criterion = torch.nn.CrossEntropyLoss()

    # move the model to GPU/CPU
    model.to(dev)  
    criterion.to(dev)       

    #******** Training loop *********    
    best_model, train_losses, val_losses, train_accs, val_accs, info =\
        train_model(model, dat['dset_loaders'], dat['dset_sizes'], 
                    criterion, optimizer, dev, lr_scheduler = None, 
                    num_epochs = num_epochs, verbose = verbose)    

    test_samples = 100
    x_test = dat['test_data']['x_test'][:test_samples,:,:,:] 
    y_test = dat['test_data']['y_test'][:test_samples] 
    
    # predict test data 
    preds = best_model(x_test.to(dev)) 
    preds_class = preds.data.max(1)[1]

    # get the accuracy 
    corrects = torch.sum(preds_class == y_test.data.to(dev))     
    test_acc = corrects.cpu().numpy()/x_test.shape[0]
    print("Test Accuracy :", test_acc) 

    # save results       
    tab = dict(Train_Acc= train_accs[info['best_epoch']],
               Val_Acc  = val_accs[info['best_epoch']],   
               Test_Acc = test_acc, Epoch = info['best_epoch'] + 1)         

    table.loc[description] = tab  
    val_acc = np.max(val_accs)

    print(table)
    results[description] = dict(train_accs = train_accs, val_accs =  val_accs,                                
                                ytrain = info['ytrain'], yval= info['yval'])      
    
    # save the best model weights
    fname = 'm_' + iname + 'CNN_POOLED' + description + '_' + str(val_acc)[:4]
    torch.save(best_model.state_dict(), fname) 



##### C[1, 16, 8]_K(3, 8) #####
Input shape : torch.Size([1, 16, 22, 90]) and flattened : 31680
Input shape : torch.Size([1, 8, 22, 45]) and flattened : 7920
Model architecture >>> CNN2D(
  (MaxPool): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0, dilation=1, ceil_mode=False)
  (ReLU): ReLU()
  (Dropout): Dropout(p=0.5, inplace=False)
  (CNN_K(3, 8)_O16): Conv2d(1, 16, kernel_size=(3, 8), stride=(1, 1), padding=(1, 4))
  (CNN_K(3, 8)_O8): Conv2d(16, 8, kernel_size=(3, 8), stride=(1, 1), padding=(1, 4))
  (fc1): Linear(in_features=7920, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=2, bias=True)
)
Epoch 1/150
train loss: 0.0209, acc: 0.5626
val loss: 0.0133, acc: 0.5656
Epoch 2/150
train loss: 0.0104, acc: 0.6875
val loss: 0.0206, acc: 0.5423
Epoch 3/150
train loss: 0.0082, acc: 0.7454
val loss: 0.0173, acc: 0.5510
Epoch 4/150
train loss: 0.0070, acc: 0.7846
val loss: 0.0152, acc: 0.6093
Epoch 5/150
train loss: 0.0055, acc: 0.8399
val loss: 0.0154

train loss: 0.0003, acc: 0.9974
val loss: 0.0332, acc: 0.6443
Epoch 103/150
train loss: 0.0002, acc: 0.9996
val loss: 0.0355, acc: 0.6443
Epoch 104/150
train loss: 0.0001, acc: 1.0000
val loss: 0.0365, acc: 0.6560
Epoch 105/150
train loss: 0.0001, acc: 1.0000
val loss: 0.0368, acc: 0.6501
Epoch 106/150
train loss: 0.0001, acc: 1.0000
val loss: 0.0384, acc: 0.6385
Epoch 107/150
train loss: 0.0001, acc: 1.0000
val loss: 0.0377, acc: 0.6531
Epoch 108/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0385, acc: 0.6531
Epoch 109/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0388, acc: 0.6501
Epoch 110/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0392, acc: 0.6531
Epoch 111/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0399, acc: 0.6531
Epoch 112/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0400, acc: 0.6501
Epoch 113/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0405, acc: 0.6501
Epoch 114/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0406, acc: 0.6385
Epoch 115/150
train loss: 

train loss: 0.0000, acc: 1.0000
val loss: 0.0403, acc: 0.6297
Epoch 44/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0403, acc: 0.6239
Epoch 45/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0402, acc: 0.6268
Epoch 46/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0403, acc: 0.6297
Epoch 47/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0405, acc: 0.6297
Epoch 48/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0410, acc: 0.6239
Epoch 49/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0409, acc: 0.6210
Epoch 50/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0407, acc: 0.6239
Epoch 51/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0407, acc: 0.6239
Epoch 52/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0407, acc: 0.6239
Epoch 53/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0407, acc: 0.6239
Epoch 54/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0405, acc: 0.6210
Epoch 55/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0401, acc: 0.6268
Epoch 56/150
train loss: 0.0000, acc: 

train loss: 0.0453, acc: 0.5077
val loss: 0.0169, acc: 0.4869
Epoch 2/150
train loss: 0.0122, acc: 0.5755
val loss: 0.0137, acc: 0.5219
Epoch 3/150
train loss: 0.0103, acc: 0.6253
val loss: 0.0152, acc: 0.5248
Epoch 4/150
train loss: 0.0099, acc: 0.6524
val loss: 0.0123, acc: 0.5685
Epoch 5/150
train loss: 0.0081, acc: 0.7425
val loss: 0.0129, acc: 0.6152
Epoch 6/150
train loss: 0.0076, acc: 0.7608
val loss: 0.0133, acc: 0.5802
Epoch 7/150
train loss: 0.0062, acc: 0.8253
val loss: 0.0166, acc: 0.5948
Epoch 8/150
train loss: 0.0063, acc: 0.8165
val loss: 0.0144, acc: 0.6472
Epoch 9/150
train loss: 0.0054, acc: 0.8407
val loss: 0.0181, acc: 0.6356
Epoch 10/150
train loss: 0.0051, acc: 0.8535
val loss: 0.0172, acc: 0.6239
Epoch 11/150
train loss: 0.0031, acc: 0.9249
val loss: 0.0207, acc: 0.5948
Epoch 12/150
train loss: 0.0025, acc: 0.9333
val loss: 0.0251, acc: 0.6064
Epoch 13/150
train loss: 0.0031, acc: 0.9161
val loss: 0.0254, acc: 0.5860
Epoch 14/150
train loss: 0.0028, acc: 0.9253
v

train loss: 0.0000, acc: 1.0000
val loss: 0.0454, acc: 0.6297
Epoch 112/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0453, acc: 0.6472
Epoch 113/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0459, acc: 0.6327
Epoch 114/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0459, acc: 0.6472
Epoch 115/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0453, acc: 0.6472
Epoch 116/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0454, acc: 0.6443
Epoch 117/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0454, acc: 0.6385
Epoch 118/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0460, acc: 0.6414
Epoch 119/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0459, acc: 0.6268
Epoch 120/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0452, acc: 0.6385
Epoch 121/150
train loss: 0.0761, acc: 0.7502
val loss: 0.0583, acc: 0.4927
Epoch 122/150
train loss: 0.0177, acc: 0.6359
val loss: 0.0124, acc: 0.6327
Epoch 123/150
train loss: 0.0083, acc: 0.7403
val loss: 0.0132, acc: 0.6443
Epoch 124/150
train loss: 

train loss: 0.0000, acc: 1.0000
val loss: 0.0326, acc: 0.6589
Epoch 51/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0328, acc: 0.6735
Epoch 52/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0326, acc: 0.6706
Epoch 53/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0329, acc: 0.6764
Epoch 54/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0326, acc: 0.6706
Epoch 55/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0326, acc: 0.6706
Epoch 56/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0326, acc: 0.6706
Epoch 57/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0325, acc: 0.6618
Epoch 58/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0326, acc: 0.6647
Epoch 59/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0325, acc: 0.6706
Epoch 60/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0327, acc: 0.6764
Epoch 61/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0326, acc: 0.6647
Epoch 62/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0324, acc: 0.6706
Epoch 63/150
train loss: 0.0000, acc: 

train loss: 0.0303, acc: 0.5143
val loss: 0.0146, acc: 0.5160
Epoch 2/150
train loss: 0.0125, acc: 0.5209
val loss: 0.0130, acc: 0.5452
Epoch 3/150
train loss: 0.0112, acc: 0.5637
val loss: 0.0128, acc: 0.5306
Epoch 4/150
train loss: 0.0103, acc: 0.6172
val loss: 0.0115, acc: 0.6268
Epoch 5/150
train loss: 0.0094, acc: 0.6714
val loss: 0.0117, acc: 0.6327
Epoch 6/150
train loss: 0.0087, acc: 0.7172
val loss: 0.0131, acc: 0.6035
Epoch 7/150
train loss: 0.0085, acc: 0.7201
val loss: 0.0126, acc: 0.6239
Epoch 8/150
train loss: 0.0075, acc: 0.7663
val loss: 0.0127, acc: 0.6210
Epoch 9/150
train loss: 0.0072, acc: 0.7864
val loss: 0.0123, acc: 0.6210
Epoch 10/150
train loss: 0.0058, acc: 0.8385
val loss: 0.0130, acc: 0.6385
Epoch 11/150
train loss: 0.0048, acc: 0.8777
val loss: 0.0155, acc: 0.6268
Epoch 12/150
train loss: 0.0051, acc: 0.8564
val loss: 0.0179, acc: 0.6268
Epoch 13/150
train loss: 0.0034, acc: 0.9154
val loss: 0.0185, acc: 0.6647
Epoch 14/150
train loss: 0.0015, acc: 0.9696
v

train loss: 0.0000, acc: 1.0000
val loss: 0.0366, acc: 0.6531
Epoch 112/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0367, acc: 0.6560
Epoch 113/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0367, acc: 0.6501
Epoch 114/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0367, acc: 0.6531
Epoch 115/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0369, acc: 0.6356
Epoch 116/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0371, acc: 0.6443
Epoch 117/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0369, acc: 0.6385
Epoch 118/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0366, acc: 0.6501
Epoch 119/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0372, acc: 0.6531
Epoch 120/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0368, acc: 0.6618
Epoch 121/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0369, acc: 0.6385
Epoch 122/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0367, acc: 0.6618
Epoch 123/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0370, acc: 0.6560
Epoch 124/150
train loss: 

train loss: 0.0000, acc: 1.0000
val loss: 0.0388, acc: 0.6268
Epoch 51/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0387, acc: 0.6297
Epoch 52/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0386, acc: 0.6327
Epoch 53/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0385, acc: 0.6356
Epoch 54/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0386, acc: 0.6356
Epoch 55/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0386, acc: 0.6327
Epoch 56/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0385, acc: 0.6356
Epoch 57/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0385, acc: 0.6356
Epoch 58/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0384, acc: 0.6385
Epoch 59/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0383, acc: 0.6356
Epoch 60/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0383, acc: 0.6356
Epoch 61/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0383, acc: 0.6356
Epoch 62/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0382, acc: 0.6356
Epoch 63/150
train loss: 0.0000, acc: 

train loss: 0.0144, acc: 0.5363
val loss: 0.0126, acc: 0.5335
Epoch 2/150
train loss: 0.0108, acc: 0.5799
val loss: 0.0115, acc: 0.5948
Epoch 3/150
train loss: 0.0095, acc: 0.6773
val loss: 0.0120, acc: 0.6210
Epoch 4/150
train loss: 0.0079, acc: 0.7440
val loss: 0.0121, acc: 0.6210
Epoch 5/150
train loss: 0.0070, acc: 0.7857
val loss: 0.0113, acc: 0.6472
Epoch 6/150
train loss: 0.0069, acc: 0.8044
val loss: 0.0134, acc: 0.6181
Epoch 7/150
train loss: 0.0049, acc: 0.8670
val loss: 0.0150, acc: 0.6064
Epoch 8/150
train loss: 0.0044, acc: 0.8802
val loss: 0.0187, acc: 0.6268
Epoch 9/150
train loss: 0.0023, acc: 0.9513
val loss: 0.0192, acc: 0.6385
Epoch 10/150
train loss: 0.0014, acc: 0.9692
val loss: 0.0216, acc: 0.6327
Epoch 11/150
train loss: 0.0007, acc: 0.9868
val loss: 0.0236, acc: 0.6589
Epoch 12/150
train loss: 0.0005, acc: 0.9894
val loss: 0.0263, acc: 0.6210
Epoch 13/150
train loss: 0.0004, acc: 0.9905
val loss: 0.0303, acc: 0.6356
Epoch 14/150
train loss: 0.0004, acc: 0.9927
v

train loss: 0.0000, acc: 1.0000
val loss: 0.0326, acc: 0.6356
Epoch 112/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0328, acc: 0.6327
Epoch 113/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0322, acc: 0.6443
Epoch 114/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0321, acc: 0.6443
Epoch 115/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0324, acc: 0.6443
Epoch 116/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0324, acc: 0.6414
Epoch 117/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0328, acc: 0.6356
Epoch 118/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0325, acc: 0.6268
Epoch 119/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0332, acc: 0.6327
Epoch 120/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0328, acc: 0.6501
Epoch 121/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0329, acc: 0.6443
Epoch 122/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0325, acc: 0.6356
Epoch 123/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0324, acc: 0.6356
Epoch 124/150
train loss: 

train loss: 0.0000, acc: 1.0000
val loss: 0.0354, acc: 0.7114
Epoch 49/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0354, acc: 0.7143
Epoch 50/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0355, acc: 0.7143
Epoch 51/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0354, acc: 0.7114
Epoch 52/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0353, acc: 0.7114
Epoch 53/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0353, acc: 0.7201
Epoch 54/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0352, acc: 0.7143
Epoch 55/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0351, acc: 0.7172
Epoch 56/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0351, acc: 0.7172
Epoch 57/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0350, acc: 0.7172
Epoch 58/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0349, acc: 0.7201
Epoch 59/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0348, acc: 0.7114
Epoch 60/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0347, acc: 0.7172
Epoch 61/150
train loss: 0.0000, acc: 

train loss: 0.0138, acc: 0.5117
val loss: 0.0126, acc: 0.4927
Epoch 2/150
train loss: 0.0118, acc: 0.5216
val loss: 0.0132, acc: 0.5131
Epoch 3/150
train loss: 0.0109, acc: 0.5557
val loss: 0.0118, acc: 0.5714
Epoch 4/150
train loss: 0.0100, acc: 0.6275
val loss: 0.0111, acc: 0.6385
Epoch 5/150
train loss: 0.0092, acc: 0.6714
val loss: 0.0112, acc: 0.6880
Epoch 6/150
train loss: 0.0084, acc: 0.7190
val loss: 0.0108, acc: 0.6706
Epoch 7/150
train loss: 0.0079, acc: 0.7557
val loss: 0.0111, acc: 0.6647
Epoch 8/150
train loss: 0.0068, acc: 0.8114
val loss: 0.0125, acc: 0.6647
Epoch 9/150
train loss: 0.0057, acc: 0.8447
val loss: 0.0120, acc: 0.6793
Epoch 10/150
train loss: 0.0044, acc: 0.8875
val loss: 0.0163, acc: 0.6618
Epoch 11/150
train loss: 0.0033, acc: 0.9117
val loss: 0.0248, acc: 0.6385
Epoch 12/150
train loss: 0.0018, acc: 0.9549
val loss: 0.0250, acc: 0.6822
Epoch 13/150
train loss: 0.0008, acc: 0.9832
val loss: 0.0294, acc: 0.6501
Epoch 14/150
train loss: 0.0005, acc: 0.9890
v

val loss: 0.0421, acc: 0.6268
Epoch 111/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0418, acc: 0.6297
Epoch 112/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0415, acc: 0.6268
Epoch 113/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0411, acc: 0.6297
Epoch 114/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0409, acc: 0.6239
Epoch 115/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0411, acc: 0.6356
Epoch 116/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0407, acc: 0.6239
Epoch 117/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0426, acc: 0.6327
Epoch 118/150
train loss: 0.0167, acc: 0.6791
val loss: 0.0137, acc: 0.5743
Epoch 119/150
train loss: 0.0106, acc: 0.6403
val loss: 0.0118, acc: 0.6764
Epoch 120/150
train loss: 0.0091, acc: 0.6927
val loss: 0.0111, acc: 0.6618
Epoch 121/150
train loss: 0.0089, acc: 0.7106
val loss: 0.0107, acc: 0.6589
Epoch 122/150
train loss: 0.0088, acc: 0.7066
val loss: 0.0107, acc: 0.6997
Epoch 123/150
train loss: 0.0083, acc: 0.7264
val loss: 0.

train loss: 0.0000, acc: 1.0000
val loss: 0.0364, acc: 0.6035
Epoch 46/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0364, acc: 0.6064
Epoch 47/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0364, acc: 0.6006
Epoch 48/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0364, acc: 0.5977
Epoch 49/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0362, acc: 0.6035
Epoch 50/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0362, acc: 0.6035
Epoch 51/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0361, acc: 0.5977
Epoch 52/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0360, acc: 0.5948
Epoch 53/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0359, acc: 0.5918
Epoch 54/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0358, acc: 0.5889
Epoch 55/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0358, acc: 0.5889
Epoch 56/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0358, acc: 0.5889
Epoch 57/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0359, acc: 0.5918
Epoch 58/150
train loss: 0.0000, acc: 

train loss: 0.0125, acc: 0.4927
val loss: 0.0124, acc: 0.4956
Epoch 2/150
train loss: 0.0112, acc: 0.5172
val loss: 0.0122, acc: 0.4985
Epoch 3/150
train loss: 0.0110, acc: 0.5260
val loss: 0.0119, acc: 0.5481
Epoch 4/150
train loss: 0.0106, acc: 0.5703
val loss: 0.0117, acc: 0.5773
Epoch 5/150
train loss: 0.0092, acc: 0.6857
val loss: 0.0112, acc: 0.6531
Epoch 6/150
train loss: 0.0079, acc: 0.7601
val loss: 0.0110, acc: 0.6851
Epoch 7/150
train loss: 0.0069, acc: 0.7993
val loss: 0.0107, acc: 0.6822
Epoch 8/150
train loss: 0.0056, acc: 0.8451
val loss: 0.0129, acc: 0.6968
Epoch 9/150
train loss: 0.0036, acc: 0.9106
val loss: 0.0161, acc: 0.6880
Epoch 10/150
train loss: 0.0019, acc: 0.9557
val loss: 0.0222, acc: 0.6851
Epoch 11/150
train loss: 0.0010, acc: 0.9755
val loss: 0.0257, acc: 0.6910
Epoch 12/150
train loss: 0.0013, acc: 0.9707
val loss: 0.0272, acc: 0.6589
Epoch 13/150
train loss: 0.0005, acc: 0.9886
val loss: 0.0304, acc: 0.6706
Epoch 14/150
train loss: 0.0002, acc: 0.9967
v

train loss: 0.0007, acc: 0.9850
val loss: 0.0328, acc: 0.6880
Epoch 112/150
train loss: 0.0005, acc: 0.9901
val loss: 0.0363, acc: 0.7026
Epoch 113/150
train loss: 0.0004, acc: 0.9905
val loss: 0.0374, acc: 0.6735
Epoch 114/150
train loss: 0.0002, acc: 0.9967
val loss: 0.0474, acc: 0.6385
Epoch 115/150
train loss: 0.0001, acc: 0.9993
val loss: 0.0448, acc: 0.6822
Epoch 116/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0479, acc: 0.6589
Epoch 117/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0480, acc: 0.6676
Epoch 118/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0484, acc: 0.6618
Epoch 119/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0489, acc: 0.6647
Epoch 120/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0487, acc: 0.6647
Epoch 121/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0490, acc: 0.6647
Epoch 122/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0489, acc: 0.6647
Epoch 123/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0487, acc: 0.6618
Epoch 124/150
train loss: 

val loss: 0.0462, acc: 0.6647
Epoch 43/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0462, acc: 0.6589
Epoch 44/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0462, acc: 0.6589
Epoch 45/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0462, acc: 0.6647
Epoch 46/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0462, acc: 0.6647
Epoch 47/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0461, acc: 0.6647
Epoch 48/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0461, acc: 0.6618
Epoch 49/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0460, acc: 0.6589
Epoch 50/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0459, acc: 0.6589
Epoch 51/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0459, acc: 0.6589
Epoch 52/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0458, acc: 0.6589
Epoch 53/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0458, acc: 0.6618
Epoch 54/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0457, acc: 0.6618
Epoch 55/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0456, acc: 0.

train loss: 0.0115, acc: 0.5209
val loss: 0.0127, acc: 0.4898
Epoch 2/150
train loss: 0.0106, acc: 0.5952
val loss: 0.0125, acc: 0.5219
Epoch 3/150
train loss: 0.0095, acc: 0.6788
val loss: 0.0128, acc: 0.5685
Epoch 4/150
train loss: 0.0067, acc: 0.8110
val loss: 0.0168, acc: 0.5773
Epoch 5/150
train loss: 0.0044, acc: 0.8813
val loss: 0.0225, acc: 0.5598
Epoch 6/150
train loss: 0.0029, acc: 0.9264
val loss: 0.0243, acc: 0.5918
Epoch 7/150
train loss: 0.0011, acc: 0.9751
val loss: 0.0316, acc: 0.5918
Epoch 8/150
train loss: 0.0007, acc: 0.9817
val loss: 0.0312, acc: 0.5831
Epoch 9/150
train loss: 0.0007, acc: 0.9824
val loss: 0.0393, acc: 0.5743
Epoch 10/150
train loss: 0.0012, acc: 0.9736
val loss: 0.0400, acc: 0.5743
Epoch 11/150
train loss: 0.0012, acc: 0.9703
val loss: 0.0305, acc: 0.6064
Epoch 12/150
train loss: 0.0003, acc: 0.9941
val loss: 0.0374, acc: 0.5918
Epoch 13/150
train loss: 0.0001, acc: 0.9985
val loss: 0.0390, acc: 0.5831
Epoch 14/150
train loss: 0.0000, acc: 1.0000
v

train loss: 0.0000, acc: 1.0000
val loss: 0.0565, acc: 0.5190
Epoch 112/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0564, acc: 0.5248
Epoch 113/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0566, acc: 0.5190
Epoch 114/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0566, acc: 0.5190
Epoch 115/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0567, acc: 0.5248
Epoch 116/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0568, acc: 0.5248
Epoch 117/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0568, acc: 0.5219
Epoch 118/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0569, acc: 0.5277
Epoch 119/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0569, acc: 0.5277
Epoch 120/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0570, acc: 0.5248
Epoch 121/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0571, acc: 0.5219
Epoch 122/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0570, acc: 0.5190
Epoch 123/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0569, acc: 0.5277
Epoch 124/150
train loss: 

val loss: 0.0439, acc: 0.6501
Epoch 38/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0438, acc: 0.6531
Epoch 39/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0437, acc: 0.6501
Epoch 40/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0436, acc: 0.6472
Epoch 41/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0435, acc: 0.6472
Epoch 42/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0434, acc: 0.6472
Epoch 43/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0433, acc: 0.6472
Epoch 44/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0431, acc: 0.6472
Epoch 45/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0430, acc: 0.6472
Epoch 46/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0430, acc: 0.6443
Epoch 47/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0429, acc: 0.6443
Epoch 48/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0428, acc: 0.6385
Epoch 49/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0427, acc: 0.6356
Epoch 50/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0425, acc: 0.

val loss: 0.0488, acc: 0.6764
Epoch 147/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0486, acc: 0.6793
Epoch 148/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0492, acc: 0.6706
Epoch 149/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0492, acc: 0.6735
Epoch 150/150
train loss: 0.0034, acc: 0.9542
val loss: 0.0191, acc: 0.5802
Training complete in 98m 45s
Best val Acc: 0.693878
Best Epoch : 88
Test Accuracy : 0.78
                                        Train_Acc   Val_Acc  Test_Acc  Epoch
C[1, 16, 8]_K(3, 8)                      1.000000  0.661808      0.74    131
C[1, 16, 8]_K(3, 24)                     0.993773  0.667638      0.68    125
C[1, 16, 8]_K(3, 40)                     0.889744  0.667638      0.70    128
C[1, 32, 16, 8]_K(3, 8)                  0.955311  0.688047      0.78     14
C[1, 32, 16, 8]_K(3, 24)                 0.915385  0.664723      0.76     13
C[1, 32, 16, 8]_K(3, 40)                 0.952015  0.661808      0.66     19
C[1, 64, 32, 16, 8]_K(3, 8)           

train loss: 0.0000, acc: 1.0000
val loss: 0.0492, acc: 0.6181
Epoch 73/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0498, acc: 0.6122
Epoch 74/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0507, acc: 0.6035
Epoch 75/150
train loss: 0.0000, acc: 1.0000
val loss: 0.0498, acc: 0.6152
Epoch 76/150
train loss: 0.0103, acc: 0.7143
val loss: 0.0135, acc: 0.4927
Epoch 77/150
train loss: 0.0108, acc: 0.5575
val loss: 0.0118, acc: 0.5627
Epoch 78/150
train loss: 0.0105, acc: 0.5850
val loss: 0.0119, acc: 0.5948
Epoch 79/150
train loss: 0.0100, acc: 0.6227
val loss: 0.0122, acc: 0.5656
Epoch 80/150
train loss: 0.0097, acc: 0.6516
val loss: 0.0123, acc: 0.6268
Epoch 81/150
train loss: 0.0090, acc: 0.7040
val loss: 0.0124, acc: 0.6035
Epoch 82/150
train loss: 0.0087, acc: 0.7121
val loss: 0.0114, acc: 0.6706
Epoch 83/150
train loss: 0.0084, acc: 0.7253
val loss: 0.0111, acc: 0.6414
Epoch 84/150
train loss: 0.0081, acc: 0.7484
val loss: 0.0120, acc: 0.6501
Epoch 85/150
train loss: 0.0075, acc: 

### Save results

In [None]:
import pickle

In [None]:
rTable = dict(table = table)
ij = str(np.random.randint(101))
filename = iname + "_CNNPOOLEDRES_"+description +ij+ 'ConvDown' + str(ConvDOWN)         

with open(filename, 'wb') as ffile:
    pickle.dump(rTable, ffile)   

### Results obtained

In [None]:
rTable['table']

In [None]:
"""
Demo single-file script to train a ConvNet on CIFAR10 using SoftHebb, an unsupervised, efficient and bio-plausible
learning algorithm
"""
import math
import warnings

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from torch.optim.lr_scheduler import StepLR
import torchvision


class SoftHebbConv2d(nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            padding: int = 0,
            dilation: int = 1,
            groups: int = 1,
            t_invert: float = 12,
    ) -> None:
        """
        Simplified implementation of Conv2d learnt with SoftHebb; an unsupervised, efficient and bio-plausible
        learning algorithm.
        This simplified implementation omits certain configurable aspects, like using a bias, groups>1, etc. which can
        be found in the full implementation in hebbconv.py
        """
        super(SoftHebbConv2d, self).__init__()
        assert groups == 1, "Simple implementation does not support groups > 1."
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.padding_mode = 'reflect'
        self.F_padding = (padding, padding, padding, padding)
        weight_range = 25 / math.sqrt((in_channels / groups) * kernel_size * kernel_size)
        self.weight = nn.Parameter(weight_range * torch.randn((out_channels, in_channels // groups, *self.kernel_size)))
        self.t_invert = torch.tensor(t_invert)

    def forward(self, x):
        x = F.pad(x, self.F_padding, self.padding_mode)  # pad input
        # perform conv, obtain weighted input u \in [B, OC, OH, OW]
        weighted_input = F.conv2d(x, self.weight, None, self.stride, 0, self.dilation, self.groups)

        if self.training:
            # ===== find post-synaptic activations y = sign(u)*softmax(u, dim=C), s(u)=1 - 2*I[u==max(u,dim=C)] =====
            # Post-synaptic activation, for plastic update, is weighted input passed through a softmax.
            # Non-winning neurons (those not with the highest activation) receive the negated post-synaptic activation.
            batch_size, out_channels, height_out, width_out = weighted_input.shape
            # Flatten non-competing dimensions (B, OC, OH, OW) -> (OC, B*OH*OW)
            flat_weighted_inputs = weighted_input.transpose(0, 1).reshape(out_channels, -1)
            # Compute the winner neuron for each batch element and pixel
            flat_softwta_activs = torch.softmax(self.t_invert * flat_weighted_inputs, dim=0)
            flat_softwta_activs = - flat_softwta_activs  # Turn all postsynaptic activations into anti-Hebbian
            win_neurons = torch.argmax(flat_weighted_inputs, dim=0)  # winning neuron for each pixel in each input
            competing_idx = torch.arange(flat_weighted_inputs.size(1))  # indeces of all pixel-input elements
            # Turn winner neurons' activations back to hebbian
            flat_softwta_activs[win_neurons, competing_idx] = - flat_softwta_activs[win_neurons, competing_idx]
            softwta_activs = flat_softwta_activs.view(out_channels, batch_size, height_out, width_out).transpose(0, 1)
            # ===== compute plastic update Δw = y*(x - u*w) = y*x - (y*u)*w =======================================
            # Use Convolutions to apply the plastic update. Sweep over inputs with postynaptic activations.
            # Each weighting of an input pixel & an activation pixel updates the kernel element that connected them in
            # the forward pass.
            yx = F.conv2d(
                x.transpose(0, 1),  # (B, IC, IH, IW) -> (IC, B, IH, IW)
                softwta_activs.transpose(0, 1),  # (B, OC, OH, OW) -> (OC, B, OH, OW)
                padding=0,
                stride=self.dilation,
                dilation=self.stride,
                groups=1
            ).transpose(0, 1)  # (IC, OC, KH, KW) -> (OC, IC, KH, KW)

            # sum over batch, output pixels: each kernel element will influence all batches and output pixels.
            yu = torch.sum(torch.mul(softwta_activs, weighted_input), dim=(0, 2, 3))
            delta_weight = yx - yu.view(-1, 1, 1, 1) * self.weight
            delta_weight.div_(torch.abs(delta_weight).amax() + 1e-30)  # Scale [min/max , 1]
            self.weight.grad = delta_weight  # store in grad to be used with common optimizers

        return weighted_input


class DeepSoftHebb(nn.Module):
    def __init__(self):
        super(DeepSoftHebb, self).__init__()
        # block 1
        self.bn1 = nn.BatchNorm2d(3, affine=False)
        self.conv1 = SoftHebbConv2d(in_channels=3, out_channels=96, kernel_size=5, padding=2, t_invert=1,)
        self.activ1 = Triangle(power=0.7)
        self.pool1 = nn.MaxPool2d(kernel_size=4, stride=2, padding=1)
        # block 2
        self.bn2 = nn.BatchNorm2d(96, affine=False)
        self.conv2 = SoftHebbConv2d(in_channels=96, out_channels=384, kernel_size=3, padding=1, t_invert=0.65,)
        self.activ2 = Triangle(power=1.4)
        self.pool2 = nn.MaxPool2d(kernel_size=4, stride=2, padding=1)
        # block 3
        self.bn3 = nn.BatchNorm2d(384, affine=False)
        self.conv3 = SoftHebbConv2d(in_channels=384, out_channels=1536, kernel_size=3, padding=1, t_invert=0.25,)
        self.activ3 = Triangle(power=1.)
        self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        # block 4
        self.flatten = nn.Flatten()
        self.classifier = nn.Linear(24576, 10)
        self.classifier.weight.data = 0.11048543456039805 * torch.rand(10, 24576)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # block 1
        out = self.pool1(self.activ1(self.conv1(self.bn1(x))))
        # block 2
        out = self.pool2(self.activ2(self.conv2(self.bn2(out))))
        # block 3
        out = self.pool3(self.activ3(self.conv3(self.bn3(out))))
        # block 4
        return self.classifier(self.dropout(self.flatten(out)))


class Triangle(nn.Module):
    def __init__(self, power: float = 1, inplace: bool = True):
        super(Triangle, self).__init__()
        self.inplace = inplace
        self.power = power

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = input - torch.mean(input.data, axis=1, keepdims=True)
        return F.relu(input, inplace=self.inplace) ** self.power


class WeightNormDependentLR(optim.lr_scheduler._LRScheduler):
    """
    Custom Learning Rate Scheduler for unsupervised training of SoftHebb Convolutional blocks.
    Difference between current neuron norm and theoretical converged norm (=1) scales the initial lr.
    """

    def __init__(self, optimizer, power_lr, last_epoch=-1, verbose=False):
        self.optimizer = optimizer
        self.initial_lr_groups = [group['lr'] for group in self.optimizer.param_groups]  # store initial lrs
        self.power_lr = power_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        new_lr = []
        for i, group in enumerate(self.optimizer.param_groups):
            for param in group['params']:
                # difference between current neuron norm and theoretical converged norm (=1) scales the initial lr
                # initial_lr * |neuron_norm - 1| ** 0.5
                norm_diff = torch.abs(torch.linalg.norm(param.view(param.shape[0], -1), dim=1, ord=2) - 1) + 1e-10
                new_lr.append(self.initial_lr_groups[i] * (norm_diff ** self.power_lr)[:, None, None, None])
        return new_lr


class TensorLRSGD(optim.SGD):
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step, using a non-scalar (tensor) learning rate.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad
                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.add_(-group['lr'] * d_p)
        return loss


class CustomStepLR(StepLR):
    """
    Custom Learning Rate schedule with step functions for supervised training of linear readout (classifier)
    """

    def __init__(self, optimizer, nb_epochs):
        threshold_ratios = [0.2, 0.35, 0.5, 0.6, 0.7, 0.8, 0.9]
        self.step_thresold = [int(nb_epochs * r) for r in threshold_ratios]
        super().__init__(optimizer, -1, False)

    def get_lr(self):
        if self.last_epoch in self.step_thresold:
            return [group['lr'] * 0.5
                    for group in self.optimizer.param_groups]
        return [group['lr'] for group in self.optimizer.param_groups]


class FastCIFAR10(torchvision.datasets.CIFAR10):
    """
    Improves performance of training on CIFAR10 by removing the PIL interface and pre-loading on the GPU (2-3x speedup).

    Taken from https://github.com/y0ast/pytorch-snippets/tree/main/fast_mnist
    """

    def __init__(self, *args, **kwargs):
        device = kwargs.pop('device', "cpu")
        super().__init__(*args, **kwargs)

        self.data = torch.tensor(self.data, dtype=torch.float, device=device).div_(255)
        self.data = torch.movedim(self.data, -1, 1)  # -> set dim to: (batch, channels, height, width)
        self.targets = torch.tensor(self.targets, device=device)

    def __getitem__(self, index: int):
        """
        Parameters
        ----------
        index : int
            Index of the element to be returned

        Returns
        -------
            tuple: (image, target) where target is the index of the target class
        """
        img = self.data[index]
        target = self.targets[index]

        return img, target


# Main training loop CIFAR10
if __name__ == "__main__":
    device = torch.device('cuda:0')
    model = DeepSoftHebb()
    model.to(device)

    unsup_optimizer = TensorLRSGD([
        {"params": model.conv1.parameters(), "lr": -0.08, },  # SGD does descent, so set lr to negative
        {"params": model.conv2.parameters(), "lr": -0.005, },
        {"params": model.conv3.parameters(), "lr": -0.01, },
    ], lr=0)
    unsup_lr_scheduler = WeightNormDependentLR(unsup_optimizer, power_lr=0.5)

    sup_optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)
    sup_lr_scheduler = CustomStepLR(sup_optimizer, nb_epochs=50)
    criterion = nn.CrossEntropyLoss()

    trainset = FastCIFAR10('./data', train=True, download=True)
    unsup_trainloader = torch.utils.data.DataLoader(trainset, batch_size=10, shuffle=True, )
    sup_trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, )

    testset = FastCIFAR10('./data', train=False)
    testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False)

    # Unsupervised training with SoftHebb
    running_loss = 0.0
    for i, data in enumerate(unsup_trainloader, 0):
        inputs, _ = data
        inputs = inputs.to(device)

        # zero the parameter gradients
        unsup_optimizer.zero_grad()

        # forward + update computation
        with torch.no_grad():
            outputs = model(inputs)

        # optimize
        unsup_optimizer.step()
        unsup_lr_scheduler.step()

    # Supervised training of classifier
    # set requires grad false and eval mode for all modules but classifier
    unsup_optimizer.zero_grad()
    model.conv1.requires_grad = False
    model.conv2.requires_grad = False
    model.conv3.requires_grad = False
    model.conv1.eval()
    model.conv2.eval()
    model.conv3.eval()
    model.bn1.eval()
    model.bn2.eval()
    model.bn3.eval()
    for epoch in range(50):
        model.classifier.train()
        model.dropout.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for i, data in enumerate(sup_trainloader, 0):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            sup_optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            sup_optimizer.step()

            # compute training statistics
            running_loss += loss.item()
            if epoch % 10 == 0 or epoch == 49:
                total += labels.size(0)
                _, predicted = torch.max(outputs.data, 1)
                correct += (predicted == labels).sum().item()
        sup_lr_scheduler.step()
        # Evaluation on test set
        if epoch % 10 == 0 or epoch == 49:
            print(f'Accuracy of the network on the train images: {100 * correct // total} %')
            print(f'[{epoch + 1}] loss: {running_loss / total:.3f}')

            # on the test set
            model.eval()
            running_loss = 0.
            correct = 0
            total = 0
            # since we're not training, we don't need to calculate the gradients for our outputs
            with torch.no_grad():
                for data in testloader:
                    images, labels = data
                    images = images.to(device)
                    labels = labels.to(device)
                    # calculate outputs by running images through the network
                    outputs = model(images)
                    # the class with the highest energy is what we choose as prediction
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                    loss = criterion(outputs, labels)
                    running_loss += loss.item()

            print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')
            print(f'test loss: {running_loss / total:.3f}')
