In [1]:
import numpy as np
def get_dataset(names, samples, classes=2, cross=False, data_dir="data/preprocessed/data"):
    """Load preprocessed recordings and build aligned (data, labels, groups) arrays.

    For each subject ``names[i]`` and each session id in ``samples[i]``, the file
    ``{data_dir}/{name}_{session}.npy`` is loaded. Rows (axis 0) of each file are
    treated as trials. They are selected and labelled in consecutive blocks of 10:

    * ``classes == 2``: rows ``[:10]`` -> label 0, rows ``[40:]`` -> label 1.
    * ``classes == 3``: rows ``[:10]`` -> 0, ``[10:20]`` -> 1, last 10 rows -> 2.
    * any other value: all rows are kept and labelled 0..4 in blocks of 10.

    Parameters
    ----------
    names : sequence of str
        Subject identifiers.
    samples : sequence of sequences
        ``samples[i]`` lists the session ids to load for ``names[i]``.
    classes : int, default 2
        Selects the row/label layout described above.
    cross : bool, default False
        If True, every trial's group id is its subject index (position in
        ``names``). Otherwise it is the index of the file it came from, in
        load order.
    data_dir : str, default "data/preprocessed/data"
        Directory containing the ``.npy`` files.

    Returns
    -------
    data : ndarray
        Selected rows of all files, concatenated along axis 0.
    labels : ndarray
        One integer label per row of ``data``.
    groups : ndarray
        One group id per row of ``data``.

    Raises
    ------
    ValueError
        If ``names`` and ``samples`` differ in length, no file is selected, or
        a file's row count does not fit the chosen label layout.
    """
    if len(names) != len(samples):
        raise ValueError(
            f"names and samples must have the same length "
            f"({len(names)} != {len(samples)})"
        )

    block = 10  # trials per class block in every recording
    data, labels, groups = [], [], []
    file_idx = 0
    for subject_idx, (name, sessions) in enumerate(zip(names, samples)):
        for s in sessions:
            datum = np.load(f"{data_dir}/{name}_{s}.npy")
            if classes == 2:
                datum = np.concatenate((datum[:10], datum[40:]))
                n_labels = 2
            elif classes == 3:
                datum = np.concatenate((datum[:10], datum[10:20], datum[-10:]))
                n_labels = 3
            else:
                n_labels = 5
            label = np.repeat(np.arange(n_labels), block)
            # Fail loudly rather than returning data/labels of different lengths.
            if datum.shape[0] != label.shape[0]:
                raise ValueError(
                    f"{name}_{s}.npy yields {datum.shape[0]} trials but "
                    f"{label.shape[0]} labels for classes={classes}"
                )
            data.append(datum)
            labels.append(label)
            # Group sizes come from the actual file, not a hard-coded count.
            groups.append(np.full(datum.shape[0], subject_idx if cross else file_idx))
            file_idx += 1

    if not data:
        raise ValueError("no recordings selected: every entry of samples is empty")
    return np.concatenate(data), np.concatenate(labels), np.concatenate(groups)
    

In [12]:
# Subjects and the recording sessions to load for each (index-aligned lists).
# NOTE(review): session 6 is absent for "onno" — presumably excluded on purpose; TODO confirm.
names = ["onno", "yoyo", "emma"]
samples = [[1, 2, 3, 4, 5, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2]]
# cross=False: group ids are per recording file, so LeaveOneGroupOut holds out one session at a time.
data, labels, groups = get_dataset(names, samples, classes=2, cross=False)

# Every trial must have exactly one label and one group id.
assert len(data) == len(labels) == len(groups), "data/labels/groups length mismatch"
print(data.shape)
print(labels.shape)
print(groups.shape)

(360, 16, 1250)
(360,)
(360,)


In [13]:
# Class-balance check: one row per distinct label -> [label, number of trials].
unique_labels, counts = np.unique(labels, return_counts=True)
print(np.column_stack((unique_labels, counts)))

[[  0 180]
 [  1 180]]


In [14]:
import random

import torch
from classify import EEGNet, classify_torch
from sklearn.model_selection import LeaveOneGroupOut

# Seed every RNG the training run may draw from so cross-validation results are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

NUM_EPOCHS = 30

# Derive the class count from `labels` here rather than from `unique_labels` defined
# in an earlier cell, so this cell does not depend on hidden kernel state.
kwargs = {
    "num_electrodes": data.shape[1],
    "chunk_size": data.shape[-1],
    "num_classes": len(np.unique(labels)),
}
metrics_dict = classify_torch(
    data,
    labels,
    EEGNet,
    kwargs=kwargs,
    cv_splitter=LeaveOneGroupOut(),
    groups=groups,
    bayesian=True,
    num_epochs=NUM_EPOCHS,
)
    

Device: cuda:0
Epoch [1/30], Loss: 1.9175
Epoch [2/30], Loss: 2.0146
Epoch [3/30], Loss: 1.5916
Epoch [4/30], Loss: 1.5323
Epoch [5/30], Loss: 1.5099
Epoch [6/30], Loss: 1.1850
Epoch [7/30], Loss: 1.2228
Epoch [8/30], Loss: 1.1341
Epoch [9/30], Loss: 1.1231
Epoch [10/30], Loss: 1.0679
Epoch [11/30], Loss: 1.0184
Epoch [12/30], Loss: 1.0273
Epoch [13/30], Loss: 0.8715
Epoch [14/30], Loss: 0.9444
Epoch [15/30], Loss: 0.8040
Epoch [16/30], Loss: 0.7738
Epoch [17/30], Loss: 0.7366
Epoch [18/30], Loss: 0.6637
Epoch [19/30], Loss: 0.6533
Epoch [20/30], Loss: 0.6652
Epoch [21/30], Loss: 0.6098
Epoch [22/30], Loss: 0.5625
Epoch [23/30], Loss: 0.6265
Epoch [24/30], Loss: 0.5918
Epoch [25/30], Loss: 0.5543
Epoch [26/30], Loss: 0.5591
Epoch [27/30], Loss: 0.5246
Epoch [28/30], Loss: 0.6375
Epoch [29/30], Loss: 0.5481
Epoch [30/30], Loss: 0.5184
Mean accuracy for current fold: 1.0
Epoch [1/30], Loss: 2.1748
Epoch [2/30], Loss: 1.5716
Epoch [3/30], Loss: 1.5314
Epoch [4/30], Loss: 1.2328
Epoch [5/3

In [15]:
# Summary metrics dict returned by classify_torch in the training cell.
print(f"{metrics_dict}")

{'mean_accuracy': 0.9194444444444443, 'best_accuracy': 1.0, 'worst_accuracy': 0.45, 'mean_precision': 0.91455116377717, 'best_precision': 1.0, 'worst_precision': 0.23684210526315788, 'mean_recall': 0.9194444444444443, 'best_recall': 1.0, 'worst_recall': 0.45, 'mean_f1': 0.9078778222642823, 'best_f1': 1.0, 'worst_f1': 0.3103448275862069, 'mean_difference': 0.08055555555555556, 'median_difference': 0.0}
