In [None]:
# Load IPython's autoreload extension and switch it to mode 2, so imported
# modules are re-imported before each cell runs (edits to project modules are
# picked up without restarting the kernel).
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import xdf
import matplotlib.pyplot as plt
import re
import resampy
import mne
import xdf_interface as xif
from braindecode.datautil.signalproc import exponential_running_standardize
from braindecode.torch_ext.util import np_to_var, var_to_np
import seaborn as sns

## Load Network

In [None]:
from braindecode.models.deep4 import Deep4Net

# Deep ConvNet wrapper: 64 input channels, 2 output classes, 600-sample input
# windows and a final convolution length of 2.
model = Deep4Net(
    in_chans=64,
    n_classes=2,
    input_time_length=600,
    final_conv_length=2,
)

In [None]:
from braindecode.torch_ext.optimizers import AdamW
import torch.nn.functional as F

# Learning rate and weight decay noted by the author as good values for the
# deep model.
# Alternative configuration kept for reference:
#   AdamW(model.parameters(), lr=0.0625 * 0.01, weight_decay=0)
optimizer = AdamW(
    model.parameters(),
    lr=.8 * 0.01,
    weight_decay=0.5 * 0.001,
)

# Compile for cropped decoding with negative log-likelihood loss.
model.compile(
    loss=F.nll_loss,
    optimizer=optimizer,
    iterator_seed=1,
    cropped=True,
)

In [None]:
# Replace the braindecode wrapper with the network it creates. The guard keeps
# this cell idempotent: once `model` has been rebound it no longer has
# `create_network`, so running the cell again is a no-op instead of an
# AttributeError.
if hasattr(model, "create_network"):
    model = model.create_network()

In [None]:
# Locations are joined with plain string concatenation further down
# (path + folder + filename), so each non-empty entry needs its own trailing
# separator.

# Base directory for the recordings and the model folders.
path = ''
# Folder of the initial model.
modelfolder = ""
# Folder of the adapted models.
gradfolder = ""

In [None]:
import torch

# Restore the initial network weights from <path><modelfolder>deep_4_params,
# the same checkpoint the evaluation loop uses for the first trials.
model.load_state_dict(torch.load(path + modelfolder + 'deep_4_params'))

In [None]:
from braindecode.models.util import to_dense_prediction_model
# Convert the network in place for dense (cropped) prediction; the return value
# is not used. NOTE(review): per the braindecode docs this replaces strides by
# dilations - confirm for the installed version.
to_dense_prediction_model(model)
# Move the network to the GPU; test_network() sends its inputs to CUDA too.
model.cuda()

## Load Test Set

In [None]:
# Both values are passed straight to xif.bdonline_extract() when the
# recordings are loaded.
# NOTE(review): presumably seconds and Hz respectively - confirm against
# xdf_interface.
timeframe_start = 1.5
target_framerate = 250

In [None]:
# Recordings that make up the test set; handed to xif.bdonline_extract()
# together with `path`.
files = np.array([
    'data_1.xdf',
    'data_2.xdf',
])

In [None]:
# X is indexed per trial and y holds the matching label for each trial (that
# is how trials2supercrops() consumes them).
X, y = xif.bdonline_extract(
    path,
    files,
    timeframe_start,
    target_framerate,
)

## Extract Supercrops from Trials

In [None]:
def trials2supercrops(X, y, crop_len=600, stride=125, n_channels=64):
    """Cut every trial into overlapping fixed-length windows ("supercrops").

    For each trial, windows of ``crop_len`` samples are taken whose end index
    starts at ``crop_len`` and advances by ``stride`` while it is smaller than
    the trial length. One extra window aligned with the end of the trial is
    always added, so the last samples are covered.

    Parameters
    ----------
    X : sequence of 2-D arrays
        One array per trial; axis 0 must have ``n_channels`` entries and
        axis 1 is the axis the windows are cut along.
    y : sequence
        Label of each trial, indexed like ``X``.
    crop_len : int, optional
        Window length in samples (default 600).
    stride : int, optional
        Step between consecutive window ends in samples (default 125).
    n_channels : int, optional
        Size of axis 0 of every trial (default 64).

    Returns
    -------
    data : ndarray, shape (n_channels, crop_len, n_crops)
        All windows stacked along the last axis.
    classes : ndarray, shape (n_crops,)
        Label of the trial each window was cut from.

    Raises
    ------
    ValueError
        If ``crop_len`` or ``stride`` is not positive, if a trial has fewer
        than ``crop_len`` samples, or if a trial's channel count does not
        match ``n_channels`` (raised by the final concatenation).
    """
    if crop_len <= 0 or stride <= 0:
        raise ValueError(
            "crop_len and stride must be positive, got crop_len=%r, stride=%r"
            % (crop_len, stride)
        )

    # Seeding with empty float arrays keeps the result shape for empty input
    # and the float64 promotion that callers have always received.
    crops = [np.empty([n_channels, crop_len, 0])]
    labels = [np.array([])]

    n_trials = len(X)
    for i in range(n_trials):
        print("Reading trial ", i+1, " of ", n_trials)

        trial = X[i]
        n_samples = trial.shape[1]
        if n_samples < crop_len:
            # A negative slice start would silently yield a shorter window.
            raise ValueError(
                "Trial %d has only %d samples, fewer than crop_len=%d"
                % (i, n_samples, crop_len)
            )

        # Regularly spaced window ends, plus one window flush with the end.
        ends = list(range(crop_len, n_samples, stride))
        ends.append(n_samples)
        for end in ends:
            crops.append(trial[:, end - crop_len:end, None])
            labels.append(np.ravel(y[i]))

    # Concatenate once instead of re-allocating the array for every window.
    data = np.concatenate(crops, axis=2)
    classes = np.concatenate(labels)
    return (data, classes)

In [None]:
# Empty float containers; joining onto them keeps the float64 dtype.
data = np.empty((64, 600, 0))
classes = np.array([])

# Cut all trials into supercrops and collect them with their labels.
d, c = trials2supercrops(X, y)
data = np.concatenate((data, d), axis=2)
classes = np.concatenate((classes.ravel(), np.ravel(c)))

## Calculate Misclassification for each Net

In [None]:
def test_network(model, data, classes):
    correct = 0
    for i in np.arange(data.shape[2]):
        in_np = data[:,:,i].T
        in_var = np_to_var(in_np.T[None,:,:,None], dtype=np.float32)
        in_var = in_var.cuda()
        pred = var_to_np(model(in_var))
        pred = np.exp(pred)
        if pred.ndim > 2:
            pred = np.mean(pred, axis=2).squeeze()
            
        if np.argmax(pred) == classes[i]:
            correct += 1
            
        
    misclass = 1-(float(correct)/len(classes))
    return misclass

In [None]:
# One misclassification rate per trial. NaN-initialised (rather than left as
# uninitialised memory) so that entries never filled in - e.g. after an
# interrupted evaluation loop - cannot be mistaken for results.
accs = np.full(len(X), np.nan)

In [None]:
# The first N_INITIAL_TRIALS trials are scored with the initial network; every
# later trial i uses the adapted checkpoint saved for it in `gradfolder`.
N_INITIAL_TRIALS = 10
initial_net = path + modelfolder + 'deep_4_params'
# The initial network never changes, so its score on the (fixed) test set is
# computed once and reused instead of being recomputed for each early trial.
initial_misclass = None

for i in np.arange(len(X)):
    print("Evaluating Trial ",i+1," of ", len(X))

    if i < N_INITIAL_TRIALS:
        net = initial_net
        if initial_misclass is None:
            model.load_state_dict(torch.load(net))
            model.eval()
            initial_misclass = test_network(model, data, classes)
        accs[i] = initial_misclass
    else:
        net = path+gradfolder+'state_dict_Trial-'+str(i)+'_Epoch-4'
        model.load_state_dict(torch.load(net))
        # Evaluation mode before scoring.
        model.eval()
        accs[i] = test_network(model, data, classes)
print("Done")

## Plot Results

In [None]:
# Show figures in separate Qt windows instead of inline (needs a Qt binding
# and a desktop session).
%matplotlib qt
# Use seaborn's 'colorblind' palette for the plots that follow.
sns.set_palette('colorblind')

In [None]:
fig, ax = plt.subplots()

# Misclassification rate per trial. The curve takes its colour from the active
# colour cycle; the former `color=` expression referred to names that do not
# exist for this single-curve plot.
ax.plot(accs[:], '+-', lw=2.5)

ax.set_xlim(0, len(X))

# Vertical marker at trial 10, where the adapted networks take over, and a
# dashed horizontal reference line at 0.5.
ax.vlines(10, 0, 1, lw=.8)
ax.hlines(.5, 0, len(X), alpha=.4, linestyles='dashed')

ax.set_ylim(0, 1)

ax.text(11, .8, 'start training', size=16)

# Optional figure title: fig.suptitle("TITLE", size=20)
ax.set_ylabel('Misclassification', size=24)
ax.set_xlabel('Trial', size=24)
ax.tick_params(labelsize=20)

plt.show()