In [None]:
import torch
import torch.nn.functional as F
from torchcontrib.optim import SWA
from torch.autograd import Variable

import numpy as np
import pandas as pd
import mne
import time
import datetime

from matplotlib import pyplot as plt
from numpy.random import RandomState


import braindecode
from braindecode.torch_ext.util import np_to_var, var_to_np
from braindecode.experiments.monitors import compute_preds_per_trial_from_crops
from braindecode.datautil.signal_target import SignalAndTarget
from braindecode.datautil.iterators import CropsFromTrialsIterator
from braindecode.models.util import to_dense_prediction_model
from braindecode.models.eegnet import EEGNetv4
from braindecode.torch_ext.optimizers import AdamW


%load_ext autoreload

%autoreload 1

%aimport braindecode.models.deep4, LoadEEG, mne_interface, OnlineToOffline, torchcontrib.optim

## Load Data

parameters_data contains the settings for the loader LoadEEG. <br>
Target sampling frequency: 256 Hz <br>
Time range: from 0.5 s before trial onset to 3 s after trial onset <br>
Filters: high-pass filter with 1 Hz cut-off, low-pass filter with 30 Hz cut-off, 50 Hz notch filter <br>
Channels: the 'eeg_eog' mode returns EOG signals referenced to each other or to the Fp1/Fp2 channels: <br>
\begin{align*}
\text{result\_channel}_1 &= \text{EOG}_{\text{left}} - \text{EOG}_{\text{right}} \\
\text{result\_channel}_2 &= \text{Fp1} - \text{EOG}_{\text{down}} \\
\text{result\_channel}_3 &= \text{Fp2} - \text{EOG}_{\text{up}}
\end{align*}

'eeg' - extracts only the EEG channels <br>
'MI' - takes only the motor-imagery relevant channels; this mode was used for most of the trained models.<br>
LoadEEG.load_subjects() loads the data of several subjects, epochs it, preprocesses it and splits it according to the training design. <br>
For more details please see the code, its comments and the function descriptions in the file LoadEEG.py <br>
__NOTE__ All models were trained on all data from the subjects (including runs that were considered bad according to the protocols).

In [None]:
# Data-loading configuration consumed by LoadEEG.load_subjects.
parameters_data = dict(
    # With the 'leave_one_out' design the data is organized in the order of
    # this list: the last subject is used for testing, the others for
    # training and validation.
    subject_files=['subjH', 'subjB', 'subjD', 'subjM'],
    training_design='leave_one_out',  # 'leave_one_out' or 'mix'
    target_fs=256,                    # target sampling frequency (Hz)
    tmin=-0.5,                        # epoch starts 0.5 s before trial onset
    tmax=3,                           # epoch ends 3 s after trial onset
    filters=True,                     # apply the preprocessing filters (True/False)
    channels='MI',                    # 'eeg_eog', 'eeg' or 'MI' (motor-imagery relevant channels)
)

In [None]:
# Load, epoch, preprocess and split all subjects. Per LoadEEG, the loader returns
# numpy arrays unless return_only_test=True (then it returns SignalAndTarget objects).
train, validation, test, train_labels, validation_labels, test_labels = LoadEEG.load_subjects(
    parameters_data['subject_files'],
    training_design=parameters_data['training_design'],
    target_fs=parameters_data['target_fs'],
    tmin=parameters_data['tmin'],
    tmax=parameters_data['tmax'],
    filters=parameters_data['filters'],
    channels=parameters_data['channels'],
)

# Wrap copies of the arrays so later in-place edits of the sets leave the raw arrays intact.
train_set, validation_set, test_set = (
    SignalAndTarget(signal.copy(), target)
    for signal, target in ((train, train_labels),
                           (validation, validation_labels),
                           (test, test_labels))
)

In [None]:
# Check what is in your data (especially if you are using filters): for every
# channel, plot the time course of the last validation trial and its magnitude
# spectrum, using the sampling rate the data was resampled to.
target_fs = parameters_data['target_fs']
last_val_trial = validation_set.X[-1]

# Iterate over the channels of the data that is actually plotted.
for channel in range(last_val_trial.shape[0]):
    plt.figure(figsize=(20, 5))
    plt.plot(last_val_trial[channel, :])
    plt.title('Last validation trial, channel {} (time domain)'.format(channel))
    plt.xlabel('sample')

    plt.figure(figsize=(10, 5))
    plt.magnitude_spectrum(last_val_trial[channel, :], Fs=target_fs, color='m')
    plt.title('Last validation trial, channel {} (magnitude spectrum)'.format(channel))

## Topological plot of variance

In [None]:
from braindecode.datasets.sensor_positions import get_channelpos, CHANNEL_10_20_APPROX

# NOTE(review): `ch_types` and `epochs_left` are not defined anywhere in this
# notebook; they must come from an earlier (removed?) cell. Restart & Run All fails here.
eeg_chan = np.argwhere(ch_types == 'eeg')

# Names of the EEG channels with trailing/leading dots removed, as expected by get_channelpos.
ch_names = [epochs_left.ch_names[idx[0]].strip('.') for idx in eeg_chan]

# Approximate 10-20 positions of those channels, one row per channel.
positions = np.array([get_channelpos(name, CHANNEL_10_20_APPROX) for name in ch_names])

In [None]:
from matplotlib import cm
%matplotlib inline
lab = {
    0: 'left hand',
    1: 'right hand'
}

i=0
k=0
nr_trials = 10
fig, axes = plt.subplots((nr_trials//2),2)
fig.set_figwidth(20)
fig.set_figheight(50)

for trial in range(nr_trials):
    ax = axes[i, k]
    if k==0:
        k+=1
    else:
        k=0
        i+=1
    tr = np.random.choice(epochs.shape[0])
    tr_label = labels[tr]
    trial_var = np.var(epochs[tr,:,:], axis=1)
    max_val = np.max(trial_var)
    min_val = np.min(trial_var)

    mne.viz.plot_topomap(trial_var, positions, vmin=min_val, 
                         vmax=max_val, contours=0, cmap=cm.coolwarm, axes=ax, show=False)
    ax.set_title(lab[tr_label], fontsize=18)
    

# Cropped training

This section performs cross-subject (or single-subject) offline training. <br>
*parameters* contains the attributes of the training process; they are also used afterwards to save the model description to an .xlsx file. <br>
**model** - deep4 or eegnet (for comparison): in general you can achieve the same validation performance with deep4 as with eegnet, but the latter is more robust to changes of the optimizer and the learning rate. <br>
**NOTE**: if using cross-entropy loss, the Softmax layer must be removed from the model (`CrossEntropyLoss` applies log-softmax itself). Also, the Dropout layer is switched off in the original braindecode model file; it was switched on for these trainings. <br> 
**input_time_length** - for now derived from the shape of the training data; if whole trials are extracted during epoching, it should be changed to a fixed sliding-window length <br>
If you use learning-rate scheduling (especially cosine annealing), keep in mind 5 main parameters, because they are tightly connected:<br>
**n_epochs** - number of training epochs, **batch_size** - size of the minibatch (a power of two, $2^n$, is recommended), **lr** - starting learning rate (recommended to start with >1e-3 for models that are not pretrained), **eta_min** - minimum learning rate the scheduler will reach, **optimizer** - three main optimizers were tested. <br>
These parameters should always be balanced with respect to each other: <br>
SGD+momentum gave the best results in most of the trainings. It is also much more stable than Adam and AdamW: several runs with the same parameters reach the same result. However, SGD is slower, so it makes sense to use more epochs, which in turn stretches the learning-rate schedule. The learning rate is also closely (proportionally) connected to the batch size: a bigger batch size works better with a bigger lr, and this combination is claimed to lead to better generalization. <br>
Different sets of parameters were tested; the most successful were: <br>
**n_epochs**: 500, <br>
**batch_size**: 64 <br>
**optimizer**: SGD+momentum <br>
**lr**: 0.01 <br>
**eta_min**: 0.0001 <br>
**scheduler**: cosine <br>
An additional scheduling technique was also implemented: warm restarts. It did not yield better results, though. <br>
More information on scheduling techniques and how to work with them can be found here:<br>
https://towardsdatascience.com/https-medium-com-reina-wang-tw-stochastic-gradient-descent-with-restarts-5f511975163 <br>
https://pytorch.org/docs/stable/optim.html <br>

The best model was trained with early-stopping regularization: during training the validation loss is tracked, and if it has not decreased for 15 epochs after its minimum, training is stopped. This also shows how good your training parameters are: if your model stops training too quickly, you should consider tuning the parameters again.

Also Stochastic Weight Averaging was implemented in a pipeline below. More information on it can be found here: <br>
https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/ <br>
https://arxiv.org/abs/1803.05407 <br>
It is not clear how promising this technique is for this model (it was not tuned), but the code below contains all the necessary steps for it. <br>
**NOTE** after setting the model weights to the SWA weights, the batch-normalization layers must also be updated by passing all the data through the network again. For this, torchcontrib.optim.SWA has a function bn_update, but it expects a DataLoader object as input, which this iterator is not. To make it work anyway, edit that function: *b=input.size(0) -> b=input.size* and add the line *input = torch.cuda.FloatTensor(input)*<br>
SWA works together with the early-stopping regime. As soon as early stopping is triggered (the validation loss has not improved for `early_stop` epochs), the scheduler is stopped and the learning rate stays constant. Training then continues until the end, while SWA averages the model weights a certain number of times (about `n_swa`). <br>

Evaluation: there are plots of the loss function (training and validation) and of the accuracy, which show the behaviour of the model. However, the most interesting evaluation here is how well the model performs on new, unseen data from a new subject, so this should also be a key criterion when choosing a model for further online testing. 
    

In [None]:
def preds_per_input(model, net_input):
    """Return the number of predictions (crops) a dense-mode model emits per input window.

    The model must already be in dense-prediction mode (see
    `to_dense_prediction_model`); the size of output dimension 2 is returned.

    Parameters
    ----------
    model : torch.nn.Module
        Network to probe. The input is placed on the device of its parameters
        (CPU if it has none), so this works with and without a GPU.
    net_input : np.ndarray, indexed as (trials, channels, time)
        Only the first trial is forwarded: the number of predictions depends
        on the time length, not on the batch size.

    Returns
    -------
    int or None
        Number of predictions per input. None (after printing a warning) when
        the output has fewer than 3 dimensions.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    probe = torch.as_tensor(np.asarray(net_input[:1, :, :, None]),
                            dtype=torch.float32, device=device)

    # A shape probe must not update batch-norm running statistics, apply
    # dropout or build an autograd graph, so run it in eval mode without grad
    # and restore the caller's train/eval mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            out = model(probe)
    finally:
        model.train(was_training)

    try:
        return out.shape[2]
    except IndexError:
        print('Incorrect shape of output, check if network is in dense mode')

In [None]:
parameters = {
    # network 
    'model': 'eegnet', # 'deep4', 'eegnet'
    'input_time_length': train_set.X.shape[2], #length of trial
    'in_chans': train_set.X.shape[1], #number of channels
    'n_classes': 2, # number of classes
    #training
    'n_epochs': 500, # number of epochs to train
    'batch_size': 64, #size of minibatch
    'criterion': 'cross_entr', # loss function
    'early_stop': 15, # patience for early stop, False - when no early stop is needed
    # optimizer
    'optimizer': 'SGD', # 'Adam', 'AdamW', 'SGD' - SGD+momentum
    'sched': True, # True/False - learning rate scheduling
    'sched_type': 'cosine', # type of scheduling (for the records only, does not change scheduling automatically)
    'lr': 0.01, # learning rate (starting learning rate in case of scheduling)
    'b1': 0.9, # b1 parameter for Adam and AdamW
    'b2': 0.999, # b2 parameter for Adam and AdamW
    'eta_min': 0.0001, #minimum learning rate for scheduling
    'Tmax': None, #number of steps for scheduler (computed below)
    'WR': False, # True/False - Warm Restart switch
    'restart_lr': 50, # number of epochs to wait till restarting learning rate
    'SWA': False, # True/False - Stochastic Weight Averaging switch
    'n_swa': 20, # number of savings for SWA (in fact - a bit less)
    'cuda': True # True/False - cuda 
}

# Create model 

cuda = parameters['cuda']

if parameters['model'] == 'deep4': 
    model_cropp = braindecode.models.deep4.Deep4Net(in_chans=parameters['in_chans'], 
                                                    n_classes=parameters['n_classes'],
                                                    input_time_length=parameters['input_time_length'], 
                                                    final_conv_length=1, 
                                                    batch_norm_alpha=0.1).create_network()
elif parameters['model'] == 'eegnet':
    model_cropp = EEGNetv4(in_chans=parameters['in_chans'], 
                           n_classes=parameters['n_classes'], 
                           final_conv_length=1, 
                           input_time_length=parameters['input_time_length']).create_network()
else:
    raise ValueError("Unknown model {!r}; expected 'deep4' or 'eegnet'".format(parameters['model']))

# Both architectures are trained on crops, so switch them to dense predictions.
to_dense_prediction_model(model_cropp)
# Honour the 'cuda' switch for every architecture (eegnet used to ignore it).
if cuda:
    model_cropp.cuda()

n_preds_per_input = preds_per_input(model_cropp, train_set.X)


#Create iterator
iterator = CropsFromTrialsIterator(batch_size=parameters['batch_size'],input_time_length=parameters['input_time_length'],
                                  n_preds_per_input=n_preds_per_input)

# Number of minibatches (optimizer steps) in one pass over the training set.
n_updates_per_epoch = sum(1 for _ in iterator.get_batches(train_set, True))

# The scheduler is stepped once per minibatch, so T_max is expressed in steps:
# one warm-restart period, or the whole training run without warm restarts.
epochs_per_cycle = parameters['restart_lr'] if parameters['WR'] else parameters['n_epochs']
parameters['Tmax'] = epochs_per_cycle * n_updates_per_epoch


In [None]:
# Create the table for today's recording: one row of hyper-parameters and
# results per training run. An existing table is only reset after confirmation.
if 'todayRecord' not in globals() or input('Are you sure you want to rewrite this database?') == 'yes':
    name_database = datetime.date.today()
    col = [*parameters_data, *parameters,
           'temporal filter size', 'min loss', 'epochs final',
           'acc train', 'acc validation', 'acc test', 'curve file', 'model file']
    todayRecord = pd.DataFrame(columns=col)

In [None]:
#Set optimizer
optimizer_name = parameters['optimizer']
if optimizer_name == 'AdamW':
    optimizer = AdamW(model_cropp.parameters(), 
                      lr = parameters['lr'], 
                      weight_decay=5*1e-5)
    print('Optimizer: AdamW')
elif optimizer_name == 'Adam':
    optimizer = torch.optim.Adam(model_cropp.parameters(), 
                                 lr = parameters['lr'],
                                 betas = (parameters['b1'],
                                 parameters['b2']), 
                                 weight_decay = 1e-5)
    print('Optimizer: Adam')
elif optimizer_name == 'SGD':
    optimizer = torch.optim.SGD(model_cropp.parameters(), 
                                lr=parameters['lr'],
                                momentum=0.9, 
                                weight_decay = 1e-5, 
                                nesterov=True)
    print('Optimizer: SGD+momentum')
else:
    # Fail here instead of with a NameError in the training loop.
    raise ValueError("Unknown optimizer {!r}; expected 'AdamW', 'Adam' or 'SGD'".format(optimizer_name))

if parameters['sched']:
    # T_max is in optimizer steps; the training loop steps the scheduler per minibatch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 
                                                           T_max=parameters['Tmax'],
                                                           eta_min=parameters['eta_min'], 
                                                           last_epoch=-1)
    print('Schedule learning rate')
else:
    print("No scheduling")

    
if parameters['criterion']=='cross_entr':
    criterion = torch.nn.CrossEntropyLoss()
    print("Loss function: cross entropy")
else:
    raise ValueError("Unknown criterion {!r}; expected 'cross_entr'".format(parameters['criterion']))
if torch.cuda.is_available():
    criterion = criterion.cuda()

# Tensor constructors on the GPU when one is available, otherwise on the CPU.
FloatTensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor



In [None]:
## housekeeping ##
# Fresh metric histories for a new training run. The main training loop appends
# one value per epoch, except `lr`, which gets one value per optimizer step.
train_loss, val_loss = np.array([]), np.array([])
train_acc, val_acc = np.array([]), np.array([])
lr, cross_points = np.array([]), np.array([])

# Early-stopping state: best validation loss so far, and epochs since it last improved.
min_loss, early_stop = np.inf, 0

In [None]:
# Training
import copy

# When training_design is 'leave_one_out' the network never sees the test data.
# To check how the network would perform on the test subject when a chunk of its
# data takes part in offline training, set check_network_test=True: the first 50
# test trials are moved to the training set and the next 20 to the validation set.
check_network_test = False

if check_network_test:
    train_set.X = np.concatenate((train_set.X, test_set.X[:50,:,:]), axis=0)
    train_set.y = np.concatenate((train_set.y, test_set.y[:50]), axis=0)

    test_set.X = test_set.X[50:,:,:]
    test_set.y = test_set.y[50:]

    validation_set.X = np.concatenate((validation_set.X, test_set.X[:20,:,:]), axis=0)
    validation_set.y = np.concatenate((validation_set.y, test_set.y[:20]), axis=0)

    test_set.X = test_set.X[20:,:,:]
    test_set.y = test_set.y[20:]
else:
    # Rebuild the sets from the raw arrays so re-running this cell starts clean.
    train_set = SignalAndTarget(train.copy(), train_labels)
    validation_set = SignalAndTarget(validation.copy(), validation_labels)
    test_set = SignalAndTarget(test.copy(), test_labels)


patience = parameters['early_stop']  # False/0 disables early stopping
best_state = None  # snapshot of the weights with the lowest validation loss

start = time.time()
n_step = 0

for i_epoch in range(parameters['n_epochs']):

    print("Epoch {:d}".format(i_epoch))

    # ---- training pass over the crops of the training set ----
    model_cropp.train()
    epoch_loss_sum, epoch_correct, epoch_seen = 0.0, 0, 0

    for batch_X, batch_y in iterator.get_batches(train_set, shuffle=True):

        net_in = FloatTensor(batch_X)
        net_target = LongTensor(batch_y)

        optimizer.zero_grad()
        # Average the dense predictions over the time axis: one output per crop window.
        outputs = torch.mean(model_cropp(net_in), dim=2, keepdim=False)

        loss = criterion(outputs, net_target)
        loss.backward()
        optimizer.step()

        if parameters['sched']:
            scheduler.step()  # the cosine schedule is stepped once per minibatch

        lr = np.append(lr, optimizer.param_groups[0]['lr'])
        n_step += 1

        # Epoch statistics over all minibatches (previously only the last one was reported).
        n_batch = len(batch_y)
        epoch_loss_sum += loss.item() * n_batch
        epoch_correct += int(np.sum(np.argmax(outputs.detach().cpu().numpy(), axis=1) == batch_y))
        epoch_seen += n_batch

    # Warm restart: every `restart_lr` epochs go back to the initial lr and start a
    # new cosine cycle. param_groups exists on plain torch optimizers and on the SWA
    # wrapper; the old `optimizer.optimizer` only existed on the wrapper.
    if parameters['WR'] and i_epoch > 0 and i_epoch % parameters['restart_lr'] == 0:
        for group in optimizer.param_groups:
            group['lr'] = parameters['lr']
        if parameters['sched']:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=parameters['Tmax'], 
                eta_min=parameters['eta_min'], last_epoch=-1)

    epoch_seen = max(epoch_seen, 1)  # guard against an empty training set
    train_epoch_loss = epoch_loss_sum / epoch_seen
    accuracy = epoch_correct / epoch_seen
    train_loss = np.append(train_loss, train_epoch_loss)
    train_acc = np.append(train_acc, accuracy)
    
    print("{:6s} Loss: {:.5f}".format('Train', train_epoch_loss))
    print("{:6s} Accuracy: {:.1f}%".format('Train', accuracy * 100))

    # ---- validation on whole trials (no autograd graph needed) ----
    model_cropp.eval()
    with torch.no_grad():
        net_in = FloatTensor(validation_set.X[:,:,:,None])
        net_target = LongTensor(validation_set.y)
        outputs = torch.mean(model_cropp(net_in), dim=2, keepdim=False)
        loss = criterion(outputs, net_target).item()

    accuracy = np.mean(np.argmax(outputs.cpu().numpy(), axis=1) == validation_set.y)

    # `and` rather than `&`: with an even patience, `patience & True` evaluates to 0
    # and early stopping was silently disabled.
    if patience and loss < min_loss:
        min_loss = loss
        save_loss = train_loss[-1]
        early_stop = 0
        # Copy the weights; `model = model_cropp` only aliased the live network.
        best_state = copy.deepcopy(model_cropp.state_dict())
    else:
        early_stop += 1

    print("{:6s} Loss: {:.5f}".format('Valid', loss))
    print("{:6s} Accuracy: {:.1f}%".format(
        'Valid', accuracy * 100))

    val_loss = np.append(val_loss, loss)    
    val_acc = np.append(val_acc, accuracy)

    if patience and early_stop == patience:
        print('Early stop triggered')
        if parameters['SWA']:
            # Spread about n_swa weight snapshots over the remaining steps (at least 1 step apart).
            parameters['swa_freq'] = max(1, n_updates_per_epoch*(parameters['n_epochs']-i_epoch)//parameters['n_swa'])
            parameters['sched'] = False
            optimizer = SWA(optimizer, swa_start=0, 
                            swa_lr=optimizer.param_groups[0]['lr'], 
                            swa_freq=parameters['swa_freq'])
        else:
            # Roll back to the best epoch in place, so `optimizer` keeps pointing at
            # the parameters of `model_cropp`.
            if best_state is not None:
                model_cropp.load_state_dict(best_state)
            break

            
end = time.time()
print('%d epochs executed in %f seconds' % (i_epoch+1,end-start))



The cell below retrains the model on all the data (training + validation) after early stopping. Retraining stops once the training loss saved during the previous training is reached.

In [None]:
# Retrain on training + validation data after early stopping, until a minibatch
# reaches the training loss `save_loss` recorded at the best epoch of the
# previous training.
if 'save_loss' not in globals():
    raise RuntimeError('save_loss is not defined: run the training cell with early stopping enabled first')

check_epoch = 1  # log train/validation metrics every `check_epoch` epochs (was never defined)
break_train = False

# NOTE(review): this mutates train_set in place; running the cell twice appends
# the validation trials a second time.
train_set.X = np.concatenate((train_set.X, validation_set.X), axis=0)
train_set.y = np.concatenate((train_set.y, validation_set.y), axis=0)

start = time.time()
i_epoch = 0

for i_epoch in range(parameters['n_epochs']):

    print("Epoch {:d}".format(i_epoch))

    model_cropp.train()
    for batch_X, batch_y in iterator.get_batches(train_set, shuffle=True):

        net_in = FloatTensor(batch_X)
        net_target = LongTensor(batch_y)

        optimizer.zero_grad()
        # Average the dense predictions over the time axis: one output per crop window.
        outputs = torch.mean(model_cropp(net_in), dim=2, keepdim=False)
    
        loss = criterion(outputs, net_target)   
        
        # Stop (before updating) as soon as a minibatch reaches the target loss.
        if loss.item() < save_loss:
            print('Reached the loss of training')
            break_train = True
            break
        loss.backward()
        optimizer.step()
        
        if parameters['sched']:
            scheduler.step()
    
    if break_train:
        break

    # Warm restart, handled like in the main training loop for every optimizer.
    # Previously SGD was never restarted, the recreated Adam/AdamW used weight
    # decays different from the optimizer cell, and epoch 0 triggered a restart.
    if parameters['WR'] and i_epoch > 0 and i_epoch % parameters['restart_lr'] == 0:
        for group in optimizer.param_groups:
            group['lr'] = parameters['lr']
        if parameters['sched']:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=parameters['Tmax'], 
                eta_min=parameters['eta_min'], last_epoch=-1)

    if i_epoch % check_epoch == 0:
        # Training metrics of the last minibatch of this epoch.
        train_batch_loss = loss.item()
        train_loss = np.append(train_loss, train_batch_loss)
        accuracy = np.mean(np.argmax(outputs.detach().cpu().numpy(), axis=1) == batch_y)
        train_acc = np.append(train_acc, accuracy)
        print("{:6s} Loss: {:.5f}".format('Train', train_batch_loss))
        print("{:6s} Accuracy: {:.1f}%".format('Train', accuracy * 100))

        # The validation trials are now part of the training data, so this is
        # a sanity check rather than a held-out estimate.
        model_cropp.eval()
        with torch.no_grad():
            net_in = FloatTensor(validation_set.X[:,:,:,None])
            net_target = LongTensor(validation_set.y)
            outputs = torch.mean(model_cropp(net_in), dim=2, keepdim=False)
            loss = criterion(outputs, net_target).item()

        accuracy = np.mean(np.argmax(outputs.cpu().numpy(), axis=1) == validation_set.y)
        
        lr = np.append(lr, optimizer.param_groups[0]['lr'])

        print("{:6s} Loss: {:.5f}".format('Valid', loss))
        print("{:6s} Accuracy: {:.1f}%".format(
            'Valid', accuracy * 100))

        val_loss = np.append(val_loss, loss)    
        val_acc = np.append(val_acc, accuracy)
            
            
end = time.time()
print('%d epochs executed in %f seconds' % (i_epoch+1,end-start))

In [None]:
save_im = False

# Loss curves; the first 3 entries are skipped (presumably to hide the large
# initial values). The x values keep the true epoch indices.
skip = 3
plt.figure(figsize=(10,5))
plt.plot(np.arange(len(train_loss))[skip:], train_loss[skip:], label='training loss')
plt.plot(np.arange(len(val_loss))[skip:], val_loss[skip:], label='validation loss')
plt.xlabel('epoch')
plt.title('Loss function')
plt.legend()

if input('Do you want to save images for this training?') == 'yes':
    save_im = True
    now = datetime.datetime.now()
    fig_name = str(now.day)+'-'+str(now.month)+'_'+str(now.hour)+'-'+str(now.minute)
    plt.savefig(fig_name+'_loss')


# Learning rate: the main training loop logs it once per optimizer step.
# `np.int` was removed in NumPy 1.24; the builtin int is the replacement.
swa_idx = cross_points.astype(int)
plt.figure(figsize=(10,5))
plt.plot(lr, label='learning_rate')
plt.scatter(swa_idx, lr[swa_idx], marker='x', color='r', label='SWA')
plt.xlabel('training step')
plt.ylabel('lr')
plt.title('Learning rate')
plt.legend()
if save_im:
    plt.savefig(fig_name+'_lr')


plt.figure(figsize=(10,5))
plt.plot(train_acc, label='training accuracy')
plt.plot(val_acc, label='validation accuracy')
plt.xlabel('epoch')
plt.title('Accuracy')
plt.legend()
if save_im:
    plt.savefig(fig_name+'_accuracy')


In [None]:
## Test ###
print('Best validation accuracy:', np.around(np.max(val_acc),2))

net_in = FloatTensor(test_set.X[:,:,:,None])
net_target = LongTensor(test_set.y)


def _evaluate_on_test(net):
    """Return (test loss, predicted labels) of `net` on the whole test set.

    Uses the training `criterion`, so the number is comparable with the
    training/validation losses. The previous F.nll_loss is only correct for
    log-probability outputs, and the notebook notes that the softmax layer is
    removed when training with cross-entropy.
    """
    net.eval()
    with torch.no_grad():
        output = torch.mean(net(net_in), dim=2, keepdim=False)
        test_loss = criterion(output, net_target).item()
    return test_loss, np.argmax(output.cpu().numpy(), axis=1)


loss, predicted_labels = _evaluate_on_test(model_cropp)
print('Test loss:', loss)
accuracy = np.mean(test_set.y  == predicted_labels)
print('Test accuracy:', np.around(accuracy,2))


if parameters['SWA']:
    # Swap in the averaged weights and recompute the batch-norm statistics
    # (bn_update needs the patched torchcontrib described in the markdown above).
    optimizer.swap_swa_sgd()
    optimizer.bn_update(iterator.get_batches(train_set, shuffle=True), model_cropp)

    loss, predicted_labels = _evaluate_on_test(model_cropp)
    print('Test loss with SWA:', loss)
    accuracy_swa = np.mean(test_set.y  == predicted_labels)
    print('Test accuracy with SWA:', np.around(accuracy_swa,2))

    
if input('Do you want me to save the model description?') == 'yes':
    fill_line = todayRecord.shape[0]
    for param in parameters_data:
        todayRecord.at[fill_line, param] = parameters_data[param]
    for param in parameters:
        todayRecord.at[fill_line, param] = parameters[param]
    # NOTE(review): assumes layer 1 of the Sequential is the temporal convolution
    # and dim 2 of its weight is the kernel length; confirm for each architecture.
    todayRecord.at[fill_line, 'temporal filter size'] = model_cropp[1].weight.shape[2]
    todayRecord.at[fill_line, 'min loss'] = min_loss
    todayRecord.at[fill_line, 'epochs final'] = i_epoch+1
    todayRecord.at[fill_line, 'acc train'] = np.max(train_acc)
    todayRecord.at[fill_line, 'acc validation'] = np.max(val_acc)
    todayRecord.at[fill_line, 'acc test'] = np.around(accuracy,2)
    
    if parameters['SWA']:
        todayRecord.at[fill_line, 'acc test with SWA'] = np.around(accuracy_swa,2)
    
    if save_im:
        todayRecord.at[fill_line, 'curve file'] = fig_name
    else:
        todayRecord.at[fill_line, 'curve file'] = 'no'
    
    if input('Do you want to save the model?') == 'yes':
        name_model = input('Name this model please:')
        torch.save(model_cropp, name_model)
        print('This model is saved')
        todayRecord.at[fill_line, 'model file'] = name_model
    
    print('All saved and you can keep training. You got it!')

        
    todayRecord.to_excel(str(name_database)+'1'+'.xlsx')

In [None]:
# One figure per kernel of layer 1 of the network (singleton dimensions squeezed away).
time_filters = model_cropp[1].weight.detach().cpu().numpy().squeeze()
nr = 0
for nr, filt in enumerate(time_filters, start=1):
    plt.figure()
    plt.plot(filt)
    plt.title('kernel of the filter {}'.format(nr))

## Pseudo-online training

This section performs pseudo-online training on unseen data. <br>
The batch size should again match the learning rate: a batch size of 8 with a learning rate of 1e-3 or 1e-4 gave the best results for the tested models.<br>
In general it is hard to make the network learn on new data from the first couple of trials. However, if the pretrained model is good enough, the accuracy during the online run can vary from 55 to 70 percent. <br>

In [None]:
parameters_online = {
    'model file': ['model_01-07_10-29', 'model_01-07_10-11', 'model_25-06_19-23', 'model_19-06_11-16'],
    # list of models' files (make sure the order is the same as for 'test subject files':
    # each subject must correspond to the model that has never seen that subject's data)
    'test subject files': ['subjB', 'subjD', 'subjH', 'subjM'], # list of files with subjects' data
    'training layers': [1,2,26], # indices of trainable layers; all other layers will be frozen
    # training
    'batch_size': 8, # size of the minibatch
    # optimizer
    'optimizer': 'SGD', # 'Adam', 'AdamW', 'SGD' - SGD+momentum
    'lr': 1e-3, # learning rate
    'b1': 0.9, # b1 parameter for Adam and AdamW
    'b2': 0.999, # b2 parameter for Adam and AdamW
    'cuda': True,
    # learning-rate scheduling switch; the online training cell reads this key.
    # When setting it to True, also add 'Tmax' and 'eta_min' (cosine annealing).
    'sched': False,
}

# Same loss as in offline training; moved to the GPU only when one is available.
criterion = torch.nn.CrossEntropyLoss()
if torch.cuda.is_available():
    criterion = criterion.cuda()
print("Loss function: cross entropy")

FloatTensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor

In [None]:
# Load every test subject separately; return_only_test=True asks LoadEEG for the
# test data only.
subject_loader_kwargs = dict(
    training_design=parameters_data['training_design'],
    target_fs=parameters_data['target_fs'],
    tmin=parameters_data['tmin'],
    tmax=parameters_data['tmax'],
    filters=parameters_data['filters'],
    channels=parameters_data['channels'],
    return_only_test=True,
)

test_subject_data = {}
for subject in parameters_online['test subject files']:
    print(subject)
    test_subject_data[subject] = LoadEEG.load_subjects([subject], **subject_loader_kwargs)

In [None]:
# Create the table for today's online recording: one row per pseudo-online run.
# An existing table is only reset after confirmation.
if 'todayRecord_online' not in globals() or input('Are you sure you want to rewrite this database?') == 'yes':
    name_database = datetime.date.today()
    # 'curve file' (singular) is the column the saving code writes to. The
    # fresh-start branch used to create 'curve files', leaving an empty extra column.
    col = list(parameters_online.keys()) + ['curve file']
    todayRecord_online = pd.DataFrame(columns = col)

In [None]:
# Run pseudo-online training for every pretrained model on the data of the
# subject it has never seen.
# NOTE(review): `train_pseudo_online` is defined in a later cell; that cell must be
# executed before this one. It reads the globals `optimizer`, `criterion` and
# `n_preds_per_input`, which this cell sets per model.

figure = plt.figure(figsize=(20,20))
ax_train_loss = figure.add_subplot(411)
ax_loss = figure.add_subplot(412)
ax_success_rate = figure.add_subplot(413)
ax_accuracy = figure.add_subplot(414)

# If True, the average performance is shown: training runs on the minimum number of
# trials among subjects, so all curves have the same length. If False, each
# subject's data is used almost completely for training.
average_mode = True

if average_mode:
    min_trials = np.min([test_subject_data[subject].X.shape[0] for 
                         subject in parameters_online['test subject files']])
    line_opaque = 0.3
else:
    min_trials = None
    line_opaque = 1

# Per-model curves, in the order returned by train_pseudo_online.
curve_names = ('train_loss', 'train_accum_loss', 'online_loss',
               'online_acc', 'online_accum_acc', 'success_rate')
curves = {name: [] for name in curve_names}
    
for i_model, model in enumerate(parameters_online['model file']):
    subject_name = parameters_online['test subject files'][i_model]
    subject_data = test_subject_data[subject_name]
    
    # Set model
    model_cropp = torch.load(model)

    # Freeze the weights and biases of every layer not listed in 'training layers'.
    for ind, layer in enumerate(model_cropp):
        if ind in parameters_online['training layers']:
            continue
        for param_name in ('weight', 'bias'):
            param = getattr(layer, param_name, None)
            if hasattr(param, 'requires_grad'):
                param.requires_grad = False

    # Crop predictions per input window for *this* model. The old code probed the
    # offline model and relied on the loop variable leaked from the loading cell.
    model_cropp.eval()
    probe = FloatTensor(np.ones((2, subject_data.X.shape[1], subject_data.X.shape[2], 1),
                                dtype=np.float32))
    with torch.no_grad():
        n_preds_per_input = model_cropp(probe).shape[2]
                    
    #Set optimizer
    if parameters_online['optimizer']=='AdamW':
        optimizer = AdamW(model_cropp.parameters(), 
                          lr = parameters_online['lr'], weight_decay=0)
        print('Optimizer: AdamW')
    elif parameters_online['optimizer']=='Adam':
        optimizer = torch.optim.Adam(model_cropp.parameters(), lr = parameters_online['lr'], 
                                 betas = (parameters_online['b1'], parameters_online['b2']), weight_decay = 0)
        print('Optimizer: Adam')
    elif parameters_online['optimizer']=='SGD':
        optimizer = torch.optim.SGD(model_cropp.parameters(), lr=parameters_online['lr'], 
                                momentum=0.9, weight_decay = 1e-5, nesterov=True)
        print('Optimizer: SGD+momentum')
    else:
        raise ValueError("Unknown optimizer {!r}; expected 'AdamW', 'Adam' or 'SGD'".format(parameters_online['optimizer']))

    # 'sched' may be missing from older parameter dicts; default to no scheduling.
    if parameters_online.get('sched', False):
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=parameters_online['Tmax'], eta_min=parameters_online['eta_min'], last_epoch=-1)
        print('Schedule learning rate')
    else:
        print("No scheduling")
    
    # Run training
    results = train_pseudo_online(model_cropp, subject_data, average_mode, min_trials)
    train_loss, train_accum_loss, online_loss, online_acc, online_accum_acc, success_rate = results
    print(subject_name) 
    
    ax_train_loss.plot(train_loss, color='grey', linewidth=1, alpha=line_opaque)
    ax_loss.plot(train_accum_loss, color='blue', linewidth=1, alpha=line_opaque)
    ax_loss.plot(online_loss, color='red', linewidth=1, alpha=line_opaque)
    ax_success_rate.plot(success_rate, color='grey', linewidth=1, alpha=line_opaque)
    ax_accuracy.plot(online_accum_acc, color='blue', linewidth=1, alpha=line_opaque)
    ax_accuracy.plot(online_acc, color='red', linewidth=1, alpha=line_opaque)

    for name, curve in zip(curve_names, results):
        curves[name].append(curve)

if average_mode:
    # Equal-weight mean over all models. The previous pairwise running mean gave the
    # last model weight 1/2, the one before 1/4, and so on.
    avg_train_loss = np.mean(curves['train_loss'], axis=0)
    avg_train_accum_loss = np.mean(curves['train_accum_loss'], axis=0)
    avg_online_loss = np.mean(curves['online_loss'], axis=0)
    avg_online_acc = np.mean(curves['online_acc'], axis=0)
    avg_online_accum_acc = np.mean(curves['online_accum_acc'], axis=0)
    avg_success_rate = np.mean(curves['success_rate'], axis=0)
        

save_im = False
if input('Do you want to save images for this training?') == 'yes':
    save_im = True
    now = datetime.datetime.now()
    fig_name = str(now.day)+'-'+str(now.month)+'_'+str(now.hour)+'-'+str(now.minute)
    
if average_mode:
    ax_train_loss.plot(avg_train_loss, color='orange', linewidth=2)
    ax_loss.plot(avg_online_loss, label='loss on remaining data', color='r', linewidth=4)
    ax_loss.plot(avg_train_accum_loss, label='accumulative training loss', color='b', linewidth=4)
    ax_success_rate.plot(avg_success_rate, color='orange', linewidth=2)
    ax_accuracy.plot(avg_online_acc, color='r', label='accuracy on remaining data', linewidth=4)
    ax_accuracy.plot(avg_online_accum_acc, color='b', label ='accumulative accuracy', linewidth=4)
    ax_loss.legend()
    ax_accuracy.legend()

ax_train_loss.set_xlabel('training step')
ax_train_loss.set_title('Loss during online training')

ax_loss.set_xlabel('training step')
ax_loss.set_title('Loss on training and remaining data')

ax_success_rate.set_xlabel('training step')
ax_success_rate.set_title('Success rate')

ax_accuracy.set_xlabel('training step')
ax_accuracy.set_title('Accuracy on training and remaining data')


if save_im:
    plt.savefig(fig_name+'_training curves')


if input('Do you want me to save the training description?') == 'yes':
    fill_line = todayRecord_online.shape[0]
    for param in parameters_online:
        todayRecord_online.at[fill_line, param] = parameters_online[param]    
    if save_im:
        todayRecord_online.at[fill_line, 'curve file'] = fig_name
    else:
        todayRecord_online.at[fill_line, 'curve file'] = 'no'
    
    print('All saved and you can keep training. You got it!')

        
    todayRecord_online.to_excel(str(name_database)+'_online'+'.xlsx')


In [None]:
# function for training

def _evaluate_trials(model, X, y):
    """Return ``(loss, accuracy)`` of ``model`` on trials ``X`` / labels ``y`` without changing it.

    The pass runs in eval mode inside ``torch.no_grad()``: no autograd graph is built
    and, if the model has dropout / batch-norm layers, they behave as at inference
    (batch-norm running statistics are not updated). The model is left in eval mode.
    Uses the notebook globals ``FloatTensor``, ``LongTensor`` and ``criterion``.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(Variable(FloatTensor(X)))
        # average the dense predictions over dim 2 -> one score vector per trial
        outputs = torch.mean(outputs, dim=2, keepdim=False)
        loss = criterion(outputs, Variable(LongTensor(y)))
    accuracy = np.mean(np.argmax(outputs.cpu().numpy(), axis=1) == y)
    return loss.item(), accuracy


def train_pseudo_online(model_cropp, test_set, averaged_mode, min_trial, n_unseen_trials=20):
    """Simulate online training on ``test_set`` by growing the set of "seen" trials.

    The window of seen trials grows by ``parameters_online['batch_size']`` per step.
    At each step the model is first evaluated on all seen trials (accumulative
    metrics), then trained for one pass of a ``CropsFromTrialsIterator`` over them,
    and finally evaluated on the trials that have not been seen yet.

    Parameters
    ----------
    model_cropp : torch.nn.Module
        Cropped (dense-prediction) model; it is trained in place.
    test_set : SignalAndTarget
        Trials used for pseudo-online training; ``X`` is 3-D with time on axis 2.
    averaged_mode : bool
        If True, leave ``len(test_set.X) - min_trial`` trailing trials unseen so that
        runs on several subjects yield curves of the same length.
    min_trial : int
        Number of trials to use in ``averaged_mode``.
    n_unseen_trials : int, default 20
        Trailing trials never trained on when ``averaged_mode`` is False.

    Returns
    -------
    tuple of 1-D float64 np.ndarray
        ``(train_loss, train_accum_loss, online_loss, online_acc, online_accum_acc,
        success_rate)``. ``train_loss`` / ``success_rate`` have one entry per training
        batch; the others have one entry per growing-window step.

    Notes
    -----
    Relies on the notebook globals ``parameters_online``, ``n_preds_per_input``,
    ``FloatTensor``, ``LongTensor``, ``criterion``, ``optimizer`` and ``scheduler``.
    """
    train_loss, train_accum_loss = [], []
    online_loss, online_acc, online_accum_acc = [], [], []
    success_rate = []

    batch_size = parameters_online['batch_size']
    n_trials = test_set.X.shape[0]
    iterator_online = CropsFromTrialsIterator(batch_size=batch_size,
                                              input_time_length=test_set.X.shape[2],
                                              n_preds_per_input=n_preds_per_input)

    if averaged_mode:
        skip_trials = n_trials - min_trial
    else:
        skip_trials = n_unseen_trials

    print('Total test trials for subject:', n_trials)
    print('Left unseen trials:', skip_trials)
    for trial in range(batch_size, n_trials - skip_trials, batch_size):
        seen_X = test_set.X[:trial, :, :]
        seen_y = test_set.y[:trial]
        seen_set = SignalAndTarget(seen_X, seen_y)

        # accumulative metrics on everything seen so far, before this step's updates
        accum_loss, accum_acc = _evaluate_trials(model_cropp, seen_X[:, :, :, None], seen_y)
        train_accum_loss.append(accum_loss)
        online_accum_acc.append(accum_acc)

        model_cropp.train()
        for batch_X, batch_y in iterator_online.get_batches(seen_set, shuffle=True):
            net_in = Variable(FloatTensor(batch_X))
            net_target = Variable(LongTensor(batch_y))

            optimizer.zero_grad()
            outputs = torch.mean(model_cropp(net_in), dim=2, keepdim=False)
            loss = criterion(outputs, net_target)
            loss.backward()
            optimizer.step()
            if parameters_online['sched']:
                scheduler.step()

            batch_pred = np.argmax(outputs.detach().cpu().numpy(), axis=1)
            success_rate.append(np.mean(batch_pred == batch_y))
            train_loss.append(loss.item())

        # check on the remaining, not yet seen data
        leftover_loss, leftover_acc = _evaluate_trials(
            model_cropp, test_set.X[trial:, :, :, None], test_set.y[trial:])
        online_loss.append(leftover_loss)
        online_acc.append(leftover_acc)

    return tuple(np.asarray(values, dtype=np.float64) for values in (
        train_loss, train_accum_loss, online_loss, online_acc, online_accum_acc, success_rate))
    

## Test saved training states (replay of saved steps)

In [None]:
import sys
from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
# Make the braindecode-online code importable (presumably where `trainers` lives).
# NOTE(review): hard-coded absolute local path — adjust for your machine.
sys.path.insert(0, '/home/tanja/braindecode-online/bdonline')


# autoreload is already loaded in the first cell; re-running this only prints a notice.
%load_ext autoreload

%autoreload 1

# Register these modules for automatic reloading under `%autoreload 1`.
%aimport trainers, braindecode.models.deep4

In [None]:
# Build the bdonline batch trainer around the current cropped model.
# To restart from a checkpoint instead: model_cropp = torch.load('model_19-06_11-16')
optimizer = AdamW(model_cropp.parameters(),
                  weight_decay=0,
                  lr=parameters['lr'])

trainer = trainers.BatchCntTrainer(
    # network, objective and optimisation
    model=model_cropp,
    optimizer=optimizer,
    loss_function=criterion,
    model_loss_function=None,
    model_constraint=MaxNormDefaultConstraint(),
    # input geometry
    n_classes=parameters['n_classes'],
    input_time_length=parameters['input_time_length'],
    n_preds_per_input=n_preds_per_input,
    # online update schedule (values used for this experiment)
    batch_size=30,
    n_min_trials=10,
    n_updates_per_break=5,
    trial_start_offset=500,
    break_start_offset=1000,
    break_stop_offset=-1000,
    # presumably also stores gradients with each saved step (cf. the grad_* files)
    savegrad=True,
)

In [None]:
# Feed the offline training set through the online trainer batch by batch and
# record the loss returned for every batch.
batch_losses = []
for batch_X, batch_y in iterator.get_batches(train_set, shuffle=True):
    loss, outputs = trainer.train_on_batch(batch_X, batch_y)
    batch_losses.append(loss.detach().cpu().numpy())
# Convert once at the end (np.append in a loop copies the whole array every time);
# float64 + flattening reproduce what the former np.append accumulation returned.
loss_list = np.array(batch_losses, dtype=np.float64).reshape(-1)


In [None]:
import glob
import re

# Directory with the per-step files (data_*, weights_*, grad_*, optimizer_*),
# presumably written by the online trainer with savegrad=True.
# NOTE(review): hard-coded absolute local path — adjust for your machine.
path = '/home/tanja/DLVR/savedInfo/'


def sorted_paths(pattern):
    """Return the files matching ``pattern`` as a numpy array in natural (numeric-aware) order.

    Plain lexicographic sorting puts 'data_10' before 'data_2'. Here digit runs are
    compared as integers, so unpadded step counters keep their chronological order.
    Zero-padded or fixed-width names sort the same way as before.
    """
    def natural_key(file_path):
        parts = re.split(r'(\d+)', file_path)
        # with a capturing group, re.split alternates text / digit runs: odd indices are digits
        return [int(part) if i % 2 else part for i, part in enumerate(parts)]
    return np.array(sorted(glob.glob(pattern), key=natural_key))


data_list = sorted_paths(path + 'data_*')
weights_list = sorted_paths(path + 'weights_*')
grad_list = sorted_paths(path + 'grad_*')
optimizer_list = sorted_paths(path + 'optimizer_*')

# The replay loop indexes all four lists with the same step index, so they must pair up.
assert len(data_list) == len(weights_list) == len(grad_list) == len(optimizer_list), \
    'saved data/weights/grad/optimizer files do not pair up'
 

In [None]:
# Reload the model saved during online training.
# Alternative: build a fresh Deep4Net(in_chans=parameters['in_chans'],
#   n_classes=parameters['n_classes'], input_time_length=parameters['input_time_length'],
#   final_conv_length=1, batch_norm_alpha=0.1).create_network(), then
#   to_dense_prediction_model(model) and model.cuda().
# NOTE(review): torch.load unpickles the whole module — only load trusted files.
model = torch.load('model_19-06_11-16')

# Child-module names in registration order; combined with suffixes such as
# '_weight' / '_bias' to look up entries of the saved per-step dicts.
layer_names = [name for name, _ in model.named_children()]

# The optimizer must wrap the parameters of the reloaded `model`; it previously
# wrapped model_cropp.parameters(), i.e. a different network.
# NOTE(review): replacing layer.weight / layer.bias with new objects later detaches
# them from this optimizer; rebuild it after loading weights if it is to be stepped.
optimizer = AdamW(model.parameters(),
                  lr = parameters['lr'], weight_decay=0)

In [None]:
# Replay the saved online-training steps: for every step load the saved batch,
# overwrite the layers' weights/gradients with the saved ones and record the loss.
loss_check = np.array([])
# NOTE(review): data_list / grad_list / optimizer_list are indexed with the same
# index, so they are assumed to be at least as long as weights_list and in matching order.
for item in range(len(weights_list)):
    
    # saved batch dict with 'inputs' and 'targets'
    batch = torch.load(data_list[item])
    inputs = Variable(FloatTensor(batch['inputs']))
    targets = Variable(LongTensor(batch['targets']))
    
    # train mode: any dropout / batch-norm layers behave as during training
    model.train()
    weight_step = torch.load(weights_list[item])
    grad_step = torch.load(grad_list[item])    
    # loaded but currently unused (the optimizer replay below is commented out)
    optimizer_step = torch.load(optimizer_list[item])
    
    # Keys are '<layer name>_weight', '_weight_grad', '_bias', '_bias_grad';
    # enumerate(model) must visit children in the same order as named_children()
    # for layer_names[l_nr] to name the right layer.
    for l_nr, layer in enumerate(model):
        if hasattr(layer, 'weight'):
            # NOTE(review): nn.Module only accepts an nn.Parameter (or None) for a
            # registered parameter; assumes the saved values are Parameters — confirm.
            layer.weight = weight_step[layer_names[l_nr] + '_weight']
            layer.weight.grad = grad_step[layer_names[l_nr] + '_weight_grad']
        # NOTE(review): layers built with bias=False still have a `bias` attribute (None);
        # this assumes a '_bias' key was saved for them as well — verify.
        if hasattr(layer, 'bias'):
            layer.bias = weight_step[layer_names[l_nr] + '_bias']
            if hasattr(layer.bias, 'grad'):
                layer.bias.grad = grad_step[layer_names[l_nr] + '_bias_grad'] 
    
    outputs = model(inputs)
    # average the dense predictions over dim 2 (presumably one prediction per trial)
    outputs = torch.mean(outputs, dim=2, keepdim=False)
    loss = criterion(outputs, targets)
    # Optimizer replay disabled; re-enable to also reproduce the optimizer step:
#     optimizer = AdamW(model.parameters(), 
#                   lr = parameters_online['lr'], weight_decay=0)
#     optimizer.load_state_dict(optimizer_step)
#     optimizer.step()
    
    loss_check = np.append(loss_check, loss.cpu().detach().numpy())
    


In [None]:
# Hand the saved per-step files to OnlineToOffline for offline replay of the online run.
process = OnlineToOffline.OfflineProcess(
    model_cropp,
    data=data_list,
    weights=weights_list,
    gradients=grad_list,
    optimizer_state=optimizer_list,
)


## Autotraining (braindecode `compile` / `fit` API)


In [None]:
# Create a fresh cropped Deep4Net sized from the training array
# (assumed layout: trials x channels x time — TODO confirm).
from braindecode.torch_ext.util import set_random_seeds
from braindecode.models.util import to_dense_prediction_model

cuda = True
# seed before the network is constructed so weight initialisation is reproducible
set_random_seeds(seed=20170629, cuda=cuda)

in_chans = train.shape[1]
input_time_length = train.shape[2]
n_classes = 2

model_cropp = braindecode.models.deep4.Deep4Net(
    in_chans=in_chans,
    n_classes=n_classes,
    input_time_length=input_time_length,
    final_conv_length=1,
    batch_norm_alpha=0.1,
)

if cuda:
    model_cropp.cuda()

In [None]:
# AdamW with light weight decay; compile for cropped decoding with a fixed iterator seed.
optimizer = AdamW(model_cropp.parameters(), weight_decay=1e-5, lr=0.001)
model_cropp.compile(cropped=True,
                    iterator_seed=1,
                    optimizer=optimizer,
                    loss=F.nll_loss)

In [None]:
# 100 epochs with a cosine learning-rate schedule, monitoring on the validation set.
input_time_length = train.shape[2]
model_cropp.fit(
    train_set.X, train_set.y,
    epochs=100,
    batch_size=60,
    scheduler='cosine',
    input_time_length=input_time_length,
    validation_data=(valid_set.X, valid_set.y),
)

In [None]:
# Per-epoch monitoring table, presumably filled by fit(); its train_/valid_ columns are plotted below.
model_cropp.epochs_df

In [None]:
# Score the fitted model on the test trials.
model_cropp.evaluate(test_set.X, test_set.y)

In [None]:
# Learning curves from model_cropp.epochs_df: one figure for loss, one for misclassification.
for metric, ylabel in (('loss', 'loss'), ('misclass', 'misclassification rate')):
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(model_cropp.epochs_df['train_' + metric], label='training ' + metric)
    ax.plot(model_cropp.epochs_df['valid_' + metric], label='validation ' + metric)
    ax.set(xlabel='epoch', ylabel=ylabel, title='Training vs. validation ' + ylabel)
    ax.legend()
# show both figures and suppress the stray Legend repr in the cell output
plt.show()