<a href="https://colab.research.google.com/github/GhBlg/Others/blob/main/sweeps_with_blocks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Imports

In [None]:
!pip install wandb
!pip install scipy -U
!wandb login acf4cf1e91a5b92ebf1613195d6d05c09db63b4e
import math
import numpy as np
from math import ceil
import wandb
from torch import optim
import torch.nn.functional as F
from sklearn.metrics import f1_score
import torch
from torch import nn
from torch.nn import init
from torch.nn.utils import weight_norm
%matplotlib 
import matplotlib.pyplot as plt
import os
from google.colab import drive
drive.mount('/content/drive/')
import sys
sys.path.insert(0,'/content/drive/MyDrive/Colab Notebooks/mne_data')
sys.path.append(os.path.abspath('/content/drive/MyDrive/Colab Notebooks/mne_data'))
!mkdir /content/resultdir
!pip install torchinfo
from torchinfo import summary
resultdir='/content/resultdir'
!pip install braindecode
!pip install tqdm
!pip install mne
!pip install moabb

###########     REPRODUCIBILITY      #######################################
random_seed=42
#torch.use_deterministic_algorithms(True)
torch.manual_seed(random_seed)
############################################################################

############################################################################

cuda = torch.cuda.is_available()  # check if GPU is available, if True chooses to use it (will take the gpu specified by the argument args.device)
device = 'cuda' if cuda else 'cpu'
os.makedirs(resultdir,exist_ok=True)
print(f"Will use device {device}")
print(f"Will save checkpoints to {resultdir}")
################################################################################################################################# 


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.13.1-py2.py3-none-any.whl (1.8 MB)
[K     |████████████████████████████████| 1.8 MB 34.0 MB/s 
[?25hCollecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.9.5-py2.py3-none-any.whl (157 kB)
[K     |████████████████████████████████| 157 kB 68.4 MB/s 
[?25hCollecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.9-py3-none-any.whl (9.4 kB)
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting GitPython>=1.0.0
  Downloading GitPython-3.1.27-py3-none-any.whl (181 kB)
[K     |████████████████████████████████| 181 kB 80.2 MB/s 
Collecting setproctitle
  Downloading setproctitle-1.3.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting gitdb<5,>=4.0.1
  Downloading gitdb-

Will use device cuda
Will save checkpoints to /content/resultdir


### Utils

In [None]:
import numpy as np
import torch
# Import necessary library
from scipy.linalg import sqrtm, inv 
import numpy as np
from braindecode.datasets.moabb import MOABBDataset
from braindecode.datautil.windowers import create_windows_from_events


def resample(eeg, src_fs=250.0, dst_fs=125):
    """Resample EEG along its last (time) axis from ``src_fs`` to ``dst_fs`` Hz.

    The defaults reproduce the previous hard-coded behaviour (250 Hz -> 125 Hz).
    ``torch.nn.functional.interpolate`` is used in its default ``'nearest'``
    mode, i.e. samples are picked without any anti-aliasing low-pass filter.

    Args:
        eeg: tensor whose last dimension is time; expected to be 3-D
            ``(batch, channels, time)`` (with a 4-D input an int ``size``
            would resize the two trailing dims).
        src_fs: sampling rate of ``eeg`` in Hz.
        dst_fs: target sampling rate in Hz.

    Returns:
        Tensor with ``int(n_samples / src_fs * dst_fs)`` time samples.
    """
    secs = eeg.shape[-1] / src_fs   # duration in seconds
    samps = int(secs * dst_fs)      # samples at the target rate
    return torch.nn.functional.interpolate(eeg, size=samps)


## Dataloader ###
def loaders(removed_subject, T_x, T_y, V_x, V_y, Test_x, Test_y, batch_size):
    """Wrap the train / validation / test tensors in DataLoaders.

    Args:
        removed_subject: unused; kept for call-site compatibility.
        T_x, T_y: training inputs and labels (zipped sample by sample).
        V_x, V_y: validation inputs and labels (may be empty).
        Test_x, Test_y: test inputs and labels.
        batch_size: batch size of the train and validation loaders.

    Returns:
        ``(train_loader, valid_loader, test_loader)``. The test loader uses a
        fixed batch size of 10000 and no shuffling.
    """
    train_data = list(zip(T_x, T_y))
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

    # Bug fix: the validation loader used to wrap ``train_data``, so every
    # validation metric was computed on training samples.
    valid_data = list(zip(V_x, V_y))
    # RandomSampler rejects an empty dataset, so only shuffle when there is
    # validation data (e.g. callers that disable validation pass empty tensors).
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=bool(valid_data))

    test_loader = torch.utils.data.DataLoader(list(zip(Test_x, Test_y)), batch_size=10000)
    return train_loader, valid_loader, test_loader


# Apply Euclidean Alignment
def apply_EA(data, device='cuda'):
    '''
    Apply Euclidean Alignment (EA) to the trials of one subject.

    The reference matrix is the mean of the per-trial channel covariance
    matrices, R = mean_i cov(X_i) (same definition as np.cov(trial, rowvar=True)).
    Every trial is then left-multiplied by R^(-1/2), so that the average
    covariance of the aligned trials is the identity.

    PARAMETERS:
    data:
        Trials of one subject: torch tensor, ndarray or list of arrays with
        shape (n_trials, n_channels, n_times).
    device:
        Unused; kept for backward compatibility.

    OUTPUT:
        float32 CPU tensor of the same shape holding the aligned trials.
    '''
    # Work on a float64 NumPy copy (np.cov also computes in float64).
    if torch.is_tensor(data):
        data = data.cpu().detach().numpy()
    trials = np.asarray(data, dtype=np.float64)

    print('Found %d trial(s) in which EEG data is stored' % len(trials))
    print('Computing reference matrix RefEA')

    # Vectorised equivalent of averaging np.cov(trial, rowvar=True) over the
    # trials: channel-wise demeaning and the unbiased (n_times - 1) normaliser.
    centered = trials - trials.mean(axis=-1, keepdims=True)
    covs = np.matmul(centered, np.swapaxes(centered, -1, -2)) / (trials.shape[-1] - 1)
    RefEA = covs.mean(axis=0)

    # R^(-1/2). sqrtm can return a complex array whose imaginary parts are
    # numerical noise; the final real cast discarded them before as well.
    R_inv = np.real(sqrtm(inv(RefEA)))

    # (C, C) @ (N, C, T) broadcasts the alignment over all trials.
    aligned = np.matmul(R_inv, trials)
    return torch.tensor(aligned).float()
        

def standardize(X, mean=None, std=None):
    """Z-score ``X`` along its last (time) axis, separately for every row.

    NOTE: ``mean`` and ``std`` are accepted only for call-site compatibility and
    are ignored. The statistics are always recomputed from ``X`` itself, using
    torch's default unbiased ``std``.

    Returns:
        ``(normalized, mean, std)``; the statistics keep the reduced dim (size 1).
    """
    mu = X.mean(dim=-1, keepdim=True)
    sigma = X.std(dim=-1, keepdim=True)
    centred = X - mu
    return centred / sigma, mu, sigma



################### Load Physionet data as done in the TIDNet paper ###################################################################
################## calling load_eeg_bci will return an object having data[subjects][records][0/1] 0 for Xs and 1 for Ys ###############
from torch.utils.data import Dataset
from mne.datasets import eegbci
import mne 
import tqdm
from collections import OrderedDict

# EEGBCI subjects are addressed by 0-based index in this file (load_eeg_bci
# passes ``index + 1`` to eegbci.load_data); these four indices are skipped.
BAD_SUBJECTS_EEGBCI = [87, 89, 91, 99]
SUBJECTS_EEGBCI = [idx for idx in range(109) if idx not in BAD_SUBJECTS_EEGBCI]
EVENTS_EEGBCI = {'hands': 2, 'feet': 3}

# Run numbers passed to eegbci.load_data.
BASELINE_EYES_OPEN = [1]
BASELINE_EYES_CLOSED = [2]

MOTOR_FISTS = (3, 7, 11)
IMAGERY_FISTS = (4, 8, 12)
MOTOR_FEET = (5, 9, 13)
IMAGERY_FEET_V_FISTS = (6, 10, 14)

def zscore(data: np.ndarray, axis=-1):
    """Standardise ``data`` along ``axis`` (population std; +1e-12 avoids 0/0)."""
    centred = data - data.mean(axis, keepdims=True)
    spread = data.std(axis, keepdims=True) + 1e-12
    return centred / spread


def one_hot(y: torch.Tensor, num_classes):
    """Return a float32 one-hot encoding of the integer labels ``y``.

    A trailing singleton dimension is squeezed first, so ``(N,)`` and ``(N, 1)``
    inputs both give ``(N, num_classes)``. The result lives on ``y.device``.
    """
    if y.dim() > 0 and y.shape[-1] == 1:
        y = y.squeeze(-1)
    encoded = torch.zeros(*y.shape, num_classes, device=y.device)
    index = y.view(*y.shape, 1)
    return encoded.scatter_(-1, index, 1)

class EpochsDataset(Dataset):
    """Map-style dataset over the epochs of an ``mne.Epochs`` object.

    Items are ``(x, y)``, or ``(x, y, run_one_hot)`` when ``runs`` is given.
    ``x`` is built lazily on first access (channel pick, preprocessors,
    normaliser, float32 conversion) and cached in ``loaded_x``.
    """

    def __init__(self, epochs: mne.Epochs, force_label=None, picks=None, preproccesors=None, normalizer=zscore,
                 runs=None, train_mode=False):
        self.mode = train_mode
        self.epochs = epochs
        self._t_len = epochs.tmax - epochs.tmin
        self.loaded_x = [None for _ in range(len(epochs.events))]
        self.runs = runs
        self.picks = picks
        self.force_label = force_label if force_label is None else torch.tensor(force_label)
        self.normalizer = normalizer
        # Accept None, one preprocessor factory, or a list/tuple of them.
        # Bug fix: None used to become [None] and crash on ``None(self.epochs)``,
        # and a tuple crashed on item assignment. A fresh list is built so the
        # caller's sequence is never mutated.
        if preproccesors is None:
            factories = []
        elif isinstance(preproccesors, (list, tuple)):
            factories = list(preproccesors)
        else:
            factories = [preproccesors]
        # Each factory receives the epochs and returns a callable applied to x.
        self.preprocessors = [factory(self.epochs) for factory in factories]

    @property
    def channels(self):
        """Number of channels delivered per item (all channels if no picks)."""
        if self.picks is None:
            return len(self.epochs.ch_names)
        else:
            return len(self.picks)

    @property
    def sfreq(self):
        """Sampling frequency of the underlying epochs."""
        return self.epochs.info['sfreq']

    def train_mode(self, mode=False):
        self.mode = mode

    def __getitem__(self, index):
        """Return ``(x, y)`` or ``(x, y, run_one_hot)`` for epoch ``index``.

        ``y`` is the epoch's event code as a long tensor, or ``force_label``.
        """
        ep = self.epochs[index]
        if self.loaded_x[index] is None:
            x = ep.get_data()
            if len(x.shape) != 3 or 0 in x.shape:
                # Best effort for a malformed/empty epoch: log it and fall back to
                # the previous index. NOTE(review): if no earlier epoch is valid
                # this keeps recursing into negative indices.
                print("I don't know why: {} index{}/{}".format(self.epochs, index, len(self)))
                print(self.epochs.info['description'])
                # raise AttributeError()
                return self.__getitem__(index - 1)
            # Keep the first epoch. With picks=None, ``x[0, None, :]`` keeps all
            # channels under a new leading axis that ``.squeeze(0)`` removes.
            x = x[0, self.picks, :]
            for p in self.preprocessors:
                x = p(x)
            x = torch.from_numpy(self.normalizer(x).astype('float32')).squeeze(0)
            self.loaded_x[index] = x
        else:
            x = self.loaded_x[index]

        y = torch.from_numpy(ep.events[..., -1]).long() if self.force_label is None else self.force_label

        if self.runs is not None:
            # Run id: the epoch's position split into ``runs`` equal bins, one-hot.
            return x, y, one_hot(torch.tensor(self.runs * index / len(self)).long(), self.runs)

        return x, y

    def __len__(self):
        return len(self.epochs.events)

def same(x):
    """Identity normaliser: hand ``x`` back untouched (load_eeg_bci's default)."""
    return x

def load_eeg_bci(targets=4, tmin=0, tlen=3, t_ev=0, t_sub=None, normalizer=same, low_f=None, high_f=None, #zscore
                 alignment=False, path_mne=None):
    """Load the EEGBCI (PhysioNet motor imagery) data through ``mne.datasets.eegbci``.

    For every 0-based subject index in ``SUBJECTS_EEGBCI``:
      * labels 0/1 -- T1/T2 annotations of the IMAGERY_FISTS runs (first 41 events);
      * label 2    -- if ``targets > 2``: ``n_imagery_events // 2`` evenly spaced
                      windows over the first ``60 - 2 * tlen`` s of the
                      BASELINE_EYES_OPEN run (eyes-open epochs come first);
      * label 3    -- if ``targets > 3``: T2 annotations of the
                      IMAGERY_FEET_V_FISTS runs (first 20 events, appended last).
    Every window lasts ``tlen`` seconds starting at ``tmin``.

    Args:
        targets: number of classes to build (2, 3 or 4).
        tmin, tlen: window start offset and length in seconds.
        t_ev, t_sub: unused; kept for call-site compatibility.
        normalizer: per-window normaliser handed to EpochsDataset.
        low_f, high_f: optional FIR band-pass edges in Hz (no filtering if both falsy).
        alignment: if True, ``EuclideanAlignment`` is passed as preprocessor.
            NOTE(review): that name is not defined in the visible code.
        path_mne: data directory for eegbci.load_data.

    Returns:
        OrderedDict mapping subject index -> EpochsDataset (runs=3). Subjects
        whose imagery recording is not sampled at 160 Hz are skipped.
    """
    def _prepare(raw):
        # Strip '.' from channel-name ends; optionally band-pass (in place).
        raw.rename_channels(lambda name: name.strip('.'))
        if low_f or high_f:
            raw.filter(low_f, high_f, fir_design='firwin', skip_by_annotation='edge')
        return raw

    def _read_runs(subject, runs):
        # ``subject`` is 0-based here; eegbci numbers subjects from 1.
        run_paths = eegbci.load_data(subject + 1, runs, path=path_mne, update_path=False)
        return mne.io.concatenate_raws([mne.io.read_raw_edf(p, preload=True) for p in run_paths])

    def _epochs(raw, events):
        # EEG channels only, no baseline correction, annotations not used to reject.
        # (.drop_bad() is not called.)
        picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude='bads')
        return mne.Epochs(raw, events, tmin=tmin, tmax=tmin + tlen - 1 / raw.info['sfreq'], picks=picks,
                          baseline=None, reject_by_annotation=False)

    paths = [eegbci.load_data(s + 1, IMAGERY_FISTS, path=path_mne, update_path=False) for s in SUBJECTS_EEGBCI]
    raws = [mne.io.concatenate_raws([mne.io.read_raw_edf(p, preload=True) for p in path])
            for path in tqdm.tqdm(paths, unit='subj', desc='Loading')]
    datasets = OrderedDict()
    for i, raw in tqdm.tqdm(list(zip(SUBJECTS_EEGBCI, raws)), desc='Preprocessing'):
        if raw.info['sfreq'] != 160:
            tqdm.tqdm.write('Skipping..., sampling frequency: {}'.format(raw.info['sfreq']))
            continue
        _prepare(raw)
        events, _ = mne.events_from_annotations(raw, event_id=dict(T1=0, T2=1))
        epochs = _epochs(raw, events[:41, ...])

        if targets > 2:
            eyes_raw = _prepare(_read_runs(i, BASELINE_EYES_OPEN))
            eyes_events = np.zeros((events.shape[0] // 2, 3), dtype=int)
            eyes_events[:, -1] = 2
            # Bug fix: ``np.int`` was removed in NumPy 1.24; builtin int replaces it.
            eyes_events[:, 0] = np.linspace(0, eyes_raw.info['sfreq'] * (60 - 2 * tlen),
                                            num=eyes_events.shape[0]).astype(int)
            epochs = mne.concatenate_epochs([_epochs(eyes_raw, eyes_events), epochs])

        if targets > 3:
            feet_raw = _prepare(_read_runs(i, IMAGERY_FEET_V_FISTS))
            feet_events, _ = mne.events_from_annotations(feet_raw, event_id=dict(T2=3))
            epochs = mne.concatenate_epochs([epochs, _epochs(feet_raw, feet_events[:20, ...])])

        datasets[i] = EpochsDataset(epochs, preproccesors=EuclideanAlignment if alignment else [],
                                    normalizer=normalizer, runs=3)

    return datasets


####################################################################################################################



def _moabb_session_windows(dataset_name, subject_id):
    """Window one MOABB subject's trials (onset to 1 s before trial end) and split by session.

    Returns:
        ``(dataset, splits)``: the MOABBDataset and ``windows_dataset.split('session')``.
    """
    dataset = MOABBDataset(dataset_name=dataset_name, subject_ids=[subject_id])

    trial_start_offset_seconds = 0  # -0.5 is the commented-out alternative
    # Extract sampling frequency, check that it is the same in all recordings.
    sfreq = dataset.datasets[0].raw.info['sfreq']
    assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])
    trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)
    trial_stop_offset_samples = int(-1 * sfreq)

    # Create windows using braindecode; the offsets define which part of each
    # trial is kept.
    windows_dataset = create_windows_from_events(
        dataset,
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=trial_stop_offset_samples,
        preload=True,
    )
    return dataset, windows_dataset.split('session')


def load_subjects(path, number_of_subjects, device, bad_subjects=[], apply_euclidean=True, with_eog=True):
    """Load per-subject trials of one of three supported datasets.

    ``number_of_subjects`` selects the dataset:
      * 9   -> BNCI2014001, sessions T and E concatenated, resampled by ``resample``;
      * 52  -> Cho2017, ``session_0``;
      * 109 -> EEGBCI via ``load_eeg_bci(path_mne=path)``.

    Args:
        path: MNE data path (only used for the 109-subject dataset).
        number_of_subjects: 9, 52 or 109 (see above).
        device: unused; kept for call-site compatibility.
        bad_subjects: 1-based subject ids to skip.
        apply_euclidean: apply ``apply_EA`` to each subject's trials.
        with_eog: if False, also drop the EOG channels (BNCI2014001) or the
            Fp1/Fp2/Fpz channels (Cho2017).

    Returns:
        ``(sbj_x, sbj_y)``: one trials tensor and one labels tensor per kept
        subject, in increasing subject order.

    Raises:
        ValueError: for any other ``number_of_subjects`` (two empty lists used
            to be returned silently).
    """
    sbj_x = []
    sbj_y = []
    kept_ids = [e for e in range(1, number_of_subjects + 1) if e not in bad_subjects]

    if number_of_subjects == 9:
        # Drop the trailing stim channel, or without EOG the trailing 4 (EOG + stim).
        dlt = -1 if with_eog else -4
        for subject_id in kept_ids:
            _, splits = _moabb_session_windows("BNCI2014001", subject_id)
            train_set = splits['session_T']
            valid_set = splits['session_E']

            train_x = np.array([ele[0][:dlt] for ele in train_set])
            train_y = np.array([ele[1] for ele in train_set])
            valid_x = np.array([ele[0][:dlt] for ele in valid_set])
            valid_y = np.array([ele[1] for ele in valid_set])

            T_x = torch.tensor(np.append(train_x, valid_x, axis=0))
            T_y = torch.tensor(np.append(train_y, valid_y, axis=0))
            T_x = resample(T_x)  # downsampling

            sbj_x.append(apply_EA(T_x) if apply_euclidean else T_x)
            sbj_y.append(T_y)

    elif number_of_subjects == 52:
        for subject_id in kept_ids:
            dataset, splits = _moabb_session_windows("Cho2017", subject_id)
            session = splits['session_0']

            if not with_eog:
                # Drop Fp1/Fp2/Fpz and channels 64-68 (EMG + stim).
                ch = dataset.datasets[0].raw.ch_names
                dropped = {ch.index('Fp1'), ch.index('Fp2'), ch.index('Fpz'), 64, 65, 66, 67, 68}
                keep = [e for e in range(len(ch)) if e not in dropped]
                T_x = np.array([ele[0][keep] for ele in session])
            else:
                # Drop the trailing 5 channels (EMG + stim).
                T_x = np.array([ele[0][:-5] for ele in session])
            T_y = np.array([ele[1] for ele in session])

            T_x = torch.tensor(T_x)
            T_y = torch.tensor(T_y)
            sbj_x.append(apply_EA(T_x) if apply_euclidean else T_x)
            sbj_y.append(T_y)

    elif number_of_subjects == 109:
        zero_based_bad = [s - 1 for s in bad_subjects]
        data = load_eeg_bci(path_mne=path)
        # NOTE(review): ``data`` only holds SUBJECTS_EEGBCI; any other id that is
        # not listed in bad_subjects raises KeyError here.
        for subject_id in [e for e in range(number_of_subjects) if e not in zero_based_bad]:
            subject_data = data[subject_id]
            records = [subject_data[r] for r in range(len(subject_data))]
            xx = torch.stack([rec[0] for rec in records])
            yy = torch.tensor([rec[1] for rec in records])
            sbj_x.append(apply_EA(xx) if apply_euclidean else xx)
            sbj_y.append(yy)

    else:
        raise ValueError('Unsupported number_of_subjects {}; expected 9, 52 or 109'.format(number_of_subjects))

    return sbj_x, sbj_y



#################### Leave One Subject Out ############################
def loso(x, y, number_of_subjects, loso, device, bad_subjects=[], with_validation=True):
    """Leave-one-subject-out split.

    Args:
        x, y: per-subject trial tensors ``(n_trials, C, T)`` and label tensors,
            assumed to be indexed by 0-based subject id (``x[i]`` holds subject
            ``i + 1``). NOTE(review): load_subjects omits bad subjects from its
            lists, which breaks this assumption when bad_subjects is non-empty.
        number_of_subjects: number of subject ids considered.
        loso: 1-based id of the held-out test subject.
        device: device the returned tensors are moved to.
        bad_subjects: 1-based ids excluded from training.
        with_validation: if True, a random 10% of the training trials
            (``torch.randperm``) become the validation set; otherwise V_x/V_y
            are empty tensors.

    Returns:
        ``(T_x, T_y, V_x, V_y, Test_x, Test_y)`` on ``device``; labels are int64
        and every x tensor is passed through ``standardize``.
    """
    test_idx = loso - 1  # because lists start from 0
    # Bug fix: this 1-based -> 0-based conversion used to be a discarded
    # expression, so bad subject ids were compared against the wrong indices.
    excluded = {test_idx} | {s - 1 for s in bad_subjects}
    train_ids = [e for e in range(number_of_subjects) if e not in excluded]

    T_x = torch.cat([x[i] for i in train_ids], 0)
    T_y = torch.cat([y[i] for i in train_ids], 0)
    T_x = T_x.reshape([-1, T_x.shape[1], T_x.shape[2]])
    T_y = T_y.reshape([-1])

    k_f = int(0.9 * T_x.shape[0])

    if with_validation:
        data_perm = torch.randperm(T_x.shape[0])
        temp_x, temp_y = T_x[data_perm], T_y[data_perm]
        T_x, T_y = temp_x[:k_f], temp_y[:k_f]
        V_x, V_y = temp_x[k_f:], temp_y[k_f:]
    else:
        V_x, V_y = torch.tensor([]), torch.tensor([])

    Test_x = x[test_idx]
    Test_y = y[test_idx]

    T_x, mean, std = standardize(T_x)
    if with_validation:
        V_x, _, _ = standardize(V_x, mean, std)
    Test_x, _, _ = standardize(Test_x, mean, std)

    return (T_x.to(device), T_y.to(device).to(torch.int64),
            V_x.to(device), V_y.to(device).to(torch.int64),
            Test_x.to(device), Test_y.to(device).to(torch.int64))


#################### Leave Multiple Subjects Out ############################
import random
from sklearn.model_selection import KFold

def lmso(x, y, number_of_subjects, kfold, device, bad_subjects=[], with_validation=True, random_seed=42):
    """Leave-multiple-subjects-out split: fold ``kfold`` of a 10-fold CV over subjects.

    Subjects are shuffled with ``random.Random(random_seed)`` and partitioned with
    ``KFold(n_splits=10)``; fold ``kfold`` (1-based) supplies the test subjects.

    Args:
        x, y: sequences with one trials tensor ``(n_trials, C, T)`` and one label
            tensor per subject. Subjects may have different numbers of trials.
        number_of_subjects, bad_subjects: unused; kept for call-site compatibility.
        kfold: 1-based fold index, ``1..10``.
        device: device the returned tensors are moved to.
        with_validation: if True, a random 10% of the training trials
            (``torch.randperm``) become the validation set; otherwise V_x/V_y
            are empty tensors.
        random_seed: seed of the subject shuffle.

    Returns:
        ``(T_x, T_y, V_x, V_y, Test_x, Test_y)`` on ``device``; labels are int64.

    Raises:
        ValueError: if ``kfold`` is outside ``1..10`` (the last fold used to be
            returned silently).
    """
    if not 1 <= kfold <= 10:
        raise ValueError('kfold must be in 1..10, got {}'.format(kfold))

    pairs = list(zip(x, y))
    random.Random(random_seed).shuffle(pairs)
    x, y = zip(*pairs)

    # Split subject *indices*: stacking all subjects into one ndarray is ragged
    # (and raises on NumPy >= 1.24) when trial counts differ between subjects.
    kf = KFold(n_splits=10)
    train_index, test_index = list(kf.split(np.arange(len(x))))[kfold - 1]

    T_x = torch.cat([torch.as_tensor(x[i]) for i in train_index], 0)
    T_y = torch.cat([torch.as_tensor(y[i]) for i in train_index], 0)
    T_x = T_x.reshape([-1, T_x.shape[1], T_x.shape[2]])
    T_y = T_y.reshape([-1])

    k_f = int(0.9 * T_x.shape[0])

    if with_validation:
        data_perm = torch.randperm(T_x.shape[0])
        temp_x, temp_y = T_x[data_perm], T_y[data_perm]
        T_x, T_y = temp_x[:k_f], temp_y[:k_f]
        V_x, V_y = temp_x[k_f:], temp_y[k_f:]
    else:
        V_x, V_y = torch.tensor([]), torch.tensor([])

    Test_x = torch.cat([torch.as_tensor(x[i]) for i in test_index], 0)
    Test_y = torch.cat([torch.as_tensor(y[i]) for i in test_index], 0)
    Test_x = Test_x.reshape([-1, Test_x.shape[1], Test_x.shape[2]])
    Test_y = Test_y.reshape([-1])

    T_x, mean, std = standardize(T_x)
    if with_validation:
        V_x, _, _ = standardize(V_x, mean, std)
    Test_x, _, _ = standardize(Test_x, mean, std)

    return (T_x.to(device), T_y.to(device).to(torch.int64),
            V_x.to(device), V_y.to(device).to(torch.int64),
            Test_x.to(device), Test_y.to(device).to(torch.int64))


#############Data Augmentation ###########################################################
from scipy.signal import butter, lfilter
from random import gauss
import numpy as np
import math
from matplotlib import pyplot

#add noise
def chua(n, alpha=15.6, beta=31, m0=-1.143, m1=-0.714, x0=0.7, dt=0.01):
    """Return ``n`` samples of the x-state of a Chua circuit (forward Euler).

    Starts at ``(x, y, z) = (x0, 0, 0)`` and integrates with step ``dt``; the
    piecewise-linear nonlinearity is
    ``phi(x) = m1*x + 0.5*(m0 - m1)*(|x + 1| - |x - 1|)``.
    The defaults reproduce the previously hard-coded constants.

    Returns:
        list of ``n`` floats (empty for ``n <= 0``).
    """
    x, y, z = x0, 0, 0
    sig = []
    for _ in range(n):
        phi = m1 * x + 0.5 * (m0 - m1) * (abs(x + 1) - abs(x - 1))
        # All derivatives use the state from the previous step.
        dx = alpha * (y - x - phi)
        dy = x - y + z
        dz = -beta * y
        x = dx * dt + x
        y = dy * dt + y
        z = dz * dt + z
        sig.append(x)
    return sig

def SNR_Set(Signal, Desired_SNR_dB):
    """Add scaled Poisson noise to ``Signal`` so that the result has ``Desired_SNR_dB``.

    Noise is drawn with ``np.random.poisson(5, len(Signal))`` and scaled by
    sqrt(K), where K = (P_signal / P_noise) * 10**(-SNR_dB / 10) and P is the
    mean squared value. Poisson noise is not zero-mean, so it also offsets the signal.

    Args:
        Signal: numpy array (expected 1-D).
        Desired_SNR_dB: target signal-to-noise ratio in dB.

    Returns:
        numpy array ``Signal + scaled_noise``.
    """
    n_points = len(Signal)
    # Alternative noise sources: Gaussian [gauss(0.0, 1.0) for _ in range(n_points)]
    # or chaotic noise chua(n_points).
    noise = np.random.poisson(5, n_points)

    magnitude = abs(Signal)
    signal_power = sum(magnitude * magnitude) / n_points
    noise_power = np.sum(noise * noise) / n_points

    K = (signal_power / noise_power) * 10 ** (-Desired_SNR_dB / 10)
    return Signal + math.sqrt(K) * noise

#flip channels
def flip_channels(data):
    """Randomly permute dim 1 (channels) of a 3-D ``(batch, channels, time)`` tensor.

    Despite the name this is a shuffle, not a reversal. One ``torch.randperm``
    permutation is shared by every sample in the batch.
    """
    order = torch.randperm(data.shape[1])
    return data[:, order, :]

#time inverse
def time_inverse(data):
    """Reverse ``data`` along dim 2 (time, for ``(batch, channels, time)`` input)."""
    return data.flip(dims=(2,))

#Masking in %
def masking(data,p):
    """Zero each element of ``data`` independently with probability ``p``.

    The mask is drawn directly on ``data.device``. Previously it was drawn on
    the CPU and moved to the global ``device``, which failed whenever ``data``
    lived on another device.

    Args:
        data: tensor of any shape.
        p: drop probability in [0, 1].

    Returns:
        ``data * mask`` with ``mask`` a float32 0/1 tensor of the same shape.
    """
    keep = torch.rand(data.shape, device=data.device) > p
    return data * keep.float()


def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter and return its ``(b, a)`` coefficients."""
    band_edges = [lowcut, highcut]
    return butter(order, band_edges, btype='band', fs=fs)


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Causally band-pass ``data`` along its last axis with ``scipy.signal.lfilter``."""
    coefficients = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(*coefficients, data)

selection=['fc','ti','mask','filter']
def apply_da(data, selection, fs=250):
    """Apply one augmentation picked with ``random.choice(selection)`` to ``data``.

    Options: 'fc' -> flip_channels, 'ti' -> time_inverse,
    'mask' -> masking(data, 0.2), 'filter' -> 0.1-40 Hz Butterworth band-pass.
    Any other name leaves ``data`` unchanged.

    Args:
        data: tensor batch.
        selection: list of augmentation names.
        fs: sampling rate (Hz) used to design the band-pass. Defaults to the
            previously hard-coded 250; pass the real rate for resampled data
            (``resample`` produces 125 Hz by default), otherwise the effective
            cut-offs are scaled by fs_true / 250.

    Returns:
        The augmented tensor.
    """
    s = random.choice(selection)
    if s == 'fc':
        data = flip_channels(data)
    elif s == 'ti':
        data = time_inverse(data)
    elif s == 'mask':
        data = masking(data, 0.2)
    elif s == 'filter':
        filtered = butter_bandpass_filter(data.detach().cpu().numpy(), 0.1, 40, fs)
        # Return on the input's own device (the global ``device`` was used before).
        data = torch.tensor(filtered).float().to(data.device)
    return data

#####################################################################################

def weight_attack(model, p=0.1):
    """Randomly overwrite part of the weights of every layer in ``model``.

    For each layer with a ``weight`` tensor, a uniform random tensor ``M1`` in
    ``[min(w), max(w))`` is drawn. Weights at positions where
    ``|M1| < p * (max(w) - min(w))`` are replaced by the matching ``M1`` values.
    Layers without a weight tensor are skipped.

    Previously every layer was overwritten with (a reshape of)
    ``model[2].weight``, weights were cast to float64, and the random tensor was
    always created on the CPU (mixing it with CUDA scalars can fail). A bare
    ``except`` hid all of this. Weights are now updated in place, which keeps
    their dtype, device and any optimizer references.

    Args:
        model: iterable of modules (e.g. ``nn.Sequential``); modified in place.
        p: fraction of each layer's weight range defining the replacement window.

    Returns:
        ``model``.
    """
    with torch.no_grad():
        for layer in model:
            weight = getattr(layer, 'weight', None)
            if not isinstance(weight, torch.Tensor) or weight.numel() == 0:
                continue
            hi = weight.max()
            lo = weight.min()
            candidate = (hi - lo) * torch.rand(weight.shape, device=weight.device, dtype=weight.dtype) + lo
            replace = candidate.abs() < p * (hi - lo)
            weight.copy_(torch.where(replace, candidate, weight))
    return model

### CoatNets

In [None]:
from math import ceil
import torch
from torch import nn
from torch.nn.utils import weight_norm
import torch.nn.functional as F
import math

##############  EEG_CoatNet models #############################################################

class PrintLayer(nn.Module):
    """Debugging pass-through layer: prints the input's shape and returns it unchanged."""

    def forward(self, x):
        print(x.shape)
        return x

class Ensure4d(nn.Module):
    """Append trailing singleton dimensions until the input has at least 4 dims."""

    def forward(self, x):
        missing = max(0, 4 - x.dim())
        return x[(Ellipsis,) + (None,) * missing]


class Expression(nn.Module):
    """Wrap an arbitrary callable as a module, invoked on every forward pass.

    Parameters
    ----------
    expression_fn : callable
        Receives all positional inputs of ``forward`` and returns the output.
    """

    def __init__(self, expression_fn):
        super().__init__()
        self.expression_fn = expression_fn

    def forward(self, *x):
        return self.expression_fn(*x)

    def __repr__(self):
        fn = self.expression_fn
        if hasattr(fn, "func") and hasattr(fn, "kwargs"):
            description = "{:s} {:s}".format(fn.func.__name__, str(fn.kwargs))
        elif hasattr(fn, "__name__"):
            description = fn.__name__
        else:
            description = repr(fn)
        return "{}(expression={}) ".format(self.__class__.__name__, description)

############## Multi-head Attention Mechanism #################################################

def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    Positions where ``mask == 0`` receive a logit of -9e15 before the softmax.

    Returns:
        ``(values, attention)``.
    """
    scale = math.sqrt(q.size()[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)) / scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -9e15)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v), weights

class MultiheadAttention(nn.Module):
    """Multi-head self-attention with a single fused Q/K/V projection.

    Args:
        input_dim: size of the last dimension of the input.
        embed_dim: total attention width, split evenly across the heads.
        num_heads: number of heads; must divide ``embed_dim``.
    """

    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # One Linear yields Q, K and V for every head (3 * embed_dim outputs).
        self.qkv_proj = nn.Linear(input_dim, 3 * embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform weights and zero biases for both projections."""
        for projection in (self.qkv_proj, self.o_proj):
            nn.init.xavier_uniform_(projection.weight)
            projection.bias.data.fill_(0)

    def forward(self, x, mask=None, return_attention=False):
        """Attend over ``x`` of shape ``(batch, seq_len, input_dim)``.

        Returns the ``(batch, seq_len, embed_dim)`` output, plus the
        ``(batch, heads, seq_len, seq_len)`` attention weights when
        ``return_attention`` is True.
        """
        batch_size, seq_length, _ = x.size()
        qkv = self.qkv_proj(x).reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        # [Batch, Head, SeqLen, 3*HeadDim] -> three [Batch, Head, SeqLen, HeadDim]
        q, k, v = qkv.permute(0, 2, 1, 3).chunk(3, dim=-1)

        values, attention = scaled_dot_product(q, k, v, mask=mask)
        merged = values.permute(0, 2, 1, 3).reshape(batch_size, seq_length, self.embed_dim)
        output = self.o_proj(merged)

        if return_attention:
            return output, attention
        return output



############## stem_stage (reducing temporal dimension) #################################################
class stem_stage(nn.Module):
    """Stem: temporal convolution with kernel ``(1, k)`` followed by BatchNorm and GELU.

    The input is padded to 4-D with trailing singleton dims (``Ensure4d``) and
    permuted so that time becomes the last axis. The convolution has no padding,
    so the time axis shrinks by ``k - 1`` samples.
    (An unused debug ``PrintLayer`` instance was removed.)
    """
    def __init__(self, in_channels, out_channels, k ):
        super().__init__()
        def _permute(x):
            '''
            Permutes data:
            from dim:
            batch, chans, time, 1
            to dim:
            batch, chans, 1, time'''
            return x.permute([0, 1, 3, 2])

        layers = [Ensure4d(),
                  Expression(_permute),
                  nn.Conv2d(in_channels, out_channels, kernel_size=(1, k)),
                  nn.BatchNorm2d(out_channels),
                  nn.GELU()
                  ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

############## conv_stage #################################################
class conv_stage(nn.Module):
    """Separable-convolution block with a 1x1 residual projection, average-pooled in time.

    Main path: BatchNorm -> 1x1 conv (in->out) -> depthwise ``(1, conv_kernel)``
    conv over ``n_ch`` channels -> 1x1 conv (out->out) -> ReLU.
    Residual path: 1x1 conv (in->out). The sum is pooled with
    ``AvgPool2d((1, pool_kernel))``.

    Raises:
        ValueError: if ``n_ch != out_channels``. The depthwise conv consumes the
            output of the first 1x1 conv, so any other value fails in forward.
    """
    def __init__(self, in_channels, out_channels, n_ch, conv_kernel, pool_kernel):
        super().__init__()
        if n_ch != out_channels:
            raise ValueError('conv_stage needs n_ch == out_channels, got n_ch={} and out_channels={}'
                             .format(n_ch, out_channels))

        self.pool = nn.AvgPool2d(kernel_size=(1, pool_kernel))
        self.res = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), padding='same')
        layers = [nn.BatchNorm2d(in_channels),
                  nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), padding='same'),
                  nn.Conv2d(in_channels=n_ch, out_channels=n_ch, kernel_size=(1, conv_kernel), groups=n_ch, padding='same'),
                  nn.Conv2d(out_channels, out_channels, kernel_size=(1, 1), padding='same'),
                  nn.ReLU()
                  ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Average pooling is linear: pooling the sum equals summing the two
        # pooled branches, with one pooling pass instead of two.
        return self.pool(self.res(x) + self.model(x))

############## InceptionBlock #################################################
class InceptionBlock(nn.Module):
    """Four parallel ``conv_stage`` branches concatenated along dim 1.

    Branch ``i`` uses ``kernel_list[i]`` as its depthwise kernel and a fixed
    pooling kernel of 2, so the output has ``4 * out_channels`` channels.
    NOTE: the ``pool_kernel`` argument is currently ignored.
    """
    def __init__(self, in_channels, out_channels, n_ch, kernel_list, pool_kernel):
        super().__init__()
        # Registered as branch1..branch4 (nn.Module.__setattr__ records submodules).
        for idx in range(4):
            setattr(self, 'branch{}'.format(idx + 1),
                    conv_stage(in_channels, out_channels, n_ch, conv_kernel=kernel_list[idx], pool_kernel=2))

    def forward(self, x):
        outputs = [self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)]
        return torch.cat(outputs, 1)


############## att_stage #################################################

class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1):
        super().__init__()
        hidden = [nn.Linear(input_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout)]
        head = [nn.Linear(hidden_dim, output_dim), nn.Dropout(dropout)]
        self.net = nn.Sequential(*hidden, *head)

    def forward(self, x):
        return self.net(x)
        
class att_stage(nn.Module):
    """Attention stage: LayerNorm -> randomly masked multi-head attention, plus a parallel FFN.

    ``forward`` returns ``(attention_out + ffn_out, attention_weights)``.

    Args:
        embed_dim: output width of the attention and FFN branches.
        num_heads: number of attention heads.
        in_channels, incep: the sequence length is ``in_channels * incep``.
        temp_dim: size of the last input dimension (LayerNorm / projection input).
        pool_kernel: unused.
        masking_rate: probability of masking each attention logit.
    """
    def __init__(self, embed_dim, num_heads, in_channels, temp_dim, pool_kernel, incep=1, masking_rate=0.1):
        super().__init__()
        self.mr = masking_rate
        self.incep = incep
        self.ed = embed_dim
        self.ic = in_channels
        self.ln = nn.LayerNorm(temp_dim)
        self.mha = MultiheadAttention(temp_dim, embed_dim, num_heads)
        self.ffn = FeedForward(temp_dim, temp_dim, embed_dim)

    def forward(self, x):
        # Drop singleton dims but keep the batch dim: a bare torch.squeeze also
        # removed it when batch_size == 1.
        x = x.reshape(x.shape[0], *[d for d in x.shape[1:] if d != 1])
        x1 = self.ffn(x)
        x = self.ln(x)
        seq_len = self.ic * self.incep
        # Random attention mask drawn on x's device (torch.cuda.FloatTensor used
        # to crash on CPU). NOTE(review): the mask is applied in eval mode too.
        mask = (torch.rand(seq_len, seq_len, device=x.device) > self.mr).type(torch.uint8)
        x, s = self.mha(x, mask, return_attention=True)
        return x + x1.reshape([-1, seq_len, self.ed]), s
      






#residual block
class residual_block(nn.Module):
    def __init__(self, in_channels, out_channels, k, p, stride=1):
        """Two (1, k) convolutions with (1, p) average pooling and a residual shortcut.

        Args:
          in_channels (int):  Number of input channels.
          out_channels (int): Number of output channels.
          k (int):            Width of both (1, k) convolution kernels.
          p (int):            Width of the (1, p) average pooling applied on the main path
                              (and on the projection shortcut).
          stride (int):       No strided convolution is applied; any value other than 1 only
                              forces the projection shortcut to be used.
        """
        super(residual_block, self).__init__()

        # The identity shortcut is only shape-compatible when channels match AND the main
        # path does not pool (p == 1). The original used identity whenever in == out and
        # stride == 1, which crashed on `out += x` for p > 1.
        if stride != 1 or in_channels != out_channels or p != 1:
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding='same', bias=False),
                nn.AvgPool2d(kernel_size=(1, p)),
                nn.BatchNorm2d(out_channels))
        else:
            self.skip = None

        self.block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, k), padding='same', bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1, p)),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(1, k), padding='same', bias=False),
            nn.BatchNorm2d(out_channels))

    def forward(self, x):
        out = self.block(x)
        out += (x if self.skip is None else self.skip(x))
        out = F.relu(out)
        return out


############## CoatNet #################################################
class coatnet(nn.Module):
    """Experimental EEG classifier.

    ``forward`` uses only the three inception-style branches (block11-13), ``pool1`` and
    ``classifier``. ``s0``, ``s11``, ``s12``, ``s2`` and ``res1`` are constructed but never
    called. They are kept so the parameter count, the random-initialisation order and the
    ``state_dict`` keys stay unchanged.

    Args:
        n_classes: number of output classes.
        n_channels: size of the input's electrode dimension (height of the depthwise kernel).
        embed_dim, num_heads, pool_kernel, incep: configure the unused stages.
        conv_kernel: not used.
        input_window_samples: time samples per window, used to size ``classifier`` as
            14 * (input_window_samples // 4). The default 1125 reproduces the original
            hard-coded ``Linear(3934, ...)``.
    """

    def __init__(self, n_classes, n_channels, embed_dim, num_heads, conv_kernel, pool_kernel, incep=4,
                 input_window_samples=1125):
        super().__init__()

        n_ch = n_channels
        # Unused stages (see class docstring).
        self.s0 = stem_stage(n_ch, n_ch, pool_kernel)
        self.s11 = InceptionBlock(n_ch, n_ch, n_ch, [8, 12, 16, 24], pool_kernel)
        self.s12 = InceptionBlock(n_ch * 4, n_ch * 4, n_ch * 4, [4, 8, 16, 32], pool_kernel)
        temp_dim = 281
        self.s2 = att_stage(128, num_heads, 14, temp_dim, pool_kernel, incep)
        self.res1 = residual_block(4, 1, 3, 1, stride=2)

        # This first classifier is immediately replaced below. It is still constructed so
        # the random initialisation of every later layer is unchanged.
        self.classifier = nn.Linear(n_ch * embed_dim * incep, n_classes)
        # 14 = 2 + 4 + 8 maps from block11-13; pool1 floors the time axis by 4.
        self.classifier = nn.Linear(14 * (input_window_samples // 4), n_classes)

        #================================
        n_input_planes = 1  # forward reshapes the input to a single plane
        in_channels = n_channels
        drop_rate = 0.4
        self.block11 = nn.Sequential(*[nn.Conv2d(n_input_planes, 2, kernel_size=(1, 16), padding='same', bias=False),
                                       nn.BatchNorm2d(2),
                                       DepthwiseConv2d(2, kernel_size=(in_channels, 1), depth_multiplier=1, bias=False, padding='valid'),
                                       nn.BatchNorm2d(2),
                                       nn.ELU()])
        self.block12 = nn.Sequential(*[torch.nn.Conv2d(n_input_planes, 4, kernel_size=(1, 32), padding='same', bias=False),
                                       nn.BatchNorm2d(4),
                                       DepthwiseConv2d(4, kernel_size=(in_channels, 1), depth_multiplier=1, bias=False, padding='valid'),
                                       nn.BatchNorm2d(4),
                                       nn.ELU()])
        self.block13 = nn.Sequential(*[torch.nn.Conv2d(n_input_planes, 8, kernel_size=(1, 64), padding='same', bias=False),
                                       nn.BatchNorm2d(8),
                                       DepthwiseConv2d(8, kernel_size=(in_channels, 1), depth_multiplier=1, bias=False, padding='valid'),
                                       nn.BatchNorm2d(8),
                                       nn.ELU()])

        self.pool1 = nn.Sequential(*[nn.AvgPool2d(kernel_size=(1, 4)),
                                     nn.Dropout(drop_rate)])
        #================================

    def forward(self, x):
        # Assumed input (batch, n_channels, time) -> (batch, 1, n_channels, time).
        # The original hard-coded [-1, 1, 25, 1125], which only fit 25 electrodes x 1125 samples.
        x = x.reshape([-1, 1, x.shape[-2], x.shape[-1]])
        branches = (self.block11, self.block12, self.block13)
        x = torch.cat([branch(x) for branch in branches], 1)
        x = self.pool1(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [None]:
# Smoke test: build coatnet, print its trainable-parameter count and run one forward pass.
eeg = torch.randn(60, 25, 1125).to(device)
network = coatnet(
    n_classes=4, n_channels=25, embed_dim=64, num_heads=8,
    conv_kernel=10, pool_kernel=2, incep=1,
).to(device)
print(sum(param.numel() for param in network.parameters() if param.requires_grad))
out = network(eeg)

396868


### Yassine's CNN

In [None]:
class ConvNet(torch.nn.Module):
    """1-D CNN: average pooling, a stem conv, ``n_convs`` conv blocks, global time-average, linear head.

    Args:
        n_chan: number of input channels (dim 1 of the input).
        fm: feature maps of the stem conv and of the first block; each later block
            multiplies the width by 1.414 (truncated to int).
        n_convs: number of conv blocks; each halves the time axis with MaxPool1d(2).
        init_pool: kernel of the initial AvgPool1d over time.
        kernel_size: kernel length of every Conv1d (padding is kernel_size // 2).
        n_classes: number of outputs of the final linear layer (previously hard-coded to 4).
    """

    def __init__(self, n_chan, fm, n_convs, init_pool, kernel_size, n_classes=4):
        super(ConvNet, self).__init__()
        pad = kernel_size // 2
        self.pool = torch.nn.AvgPool1d(init_pool)
        self.conv = torch.nn.Conv1d(n_chan, fm, kernel_size=kernel_size, padding=pad, bias=False)
        self.bn = torch.nn.BatchNorm1d(fm)
        blocks = []
        in_fm = out_fm = fm
        for i in range(n_convs):
            if i > 0:
                out_fm = int(1.414 * out_fm)
            # The original chose `MaxPool1d(2) if i > 0 - 1 else MaxPool1d(1)`; since
            # 0 - 1 == -1 that condition was always true, so every block pools by 2.
            blocks.append(torch.nn.Sequential(
                torch.nn.Conv1d(in_fm, out_fm, kernel_size=kernel_size, padding=pad, bias=False),
                torch.nn.BatchNorm1d(out_fm),
                torch.nn.MaxPool1d(2),
                torch.nn.ReLU(),
                torch.nn.Conv1d(out_fm, out_fm, kernel_size=kernel_size, padding=pad, bias=False),
                torch.nn.BatchNorm1d(out_fm),
                torch.nn.ReLU(),
            ))
            in_fm = out_fm
        self.blocks = torch.nn.ModuleList(blocks)
        self.fc = torch.nn.Linear(in_fm, n_classes)

    def forward(self, x):
        y = torch.relu(self.bn(self.conv(self.pool(x))))
        for block in self.blocks:
            y = block(y)
        y = y.mean(dim=2)  # average over the time axis
        return self.fc(y)
    

In [None]:
# Smoke test: build ConvNet, print its trainable-parameter count and run one forward pass.
eeg = torch.randn(60, 25, 1125).to(device)
params = [64, 4, 4, 7]  # fm, n_convs, init_pool, kernel_size
network = ConvNet(eeg.shape[1], *params).to(device)
print(sum(weight.numel() for weight in network.parameters() if weight.requires_grad))
out = network(eeg)

744583


### EEG-ITNet

In [None]:
class DepthwiseConv2d(torch.nn.Conv2d):
    """Depthwise 2-D convolution.

    A grouped ``Conv2d`` with ``groups=in_channels``: every input channel is convolved with
    its own ``depth_multiplier`` filters, giving ``in_channels * depth_multiplier`` outputs.
    All remaining arguments are forwarded unchanged to ``torch.nn.Conv2d``.
    """

    def __init__(self, in_channels, depth_multiplier=2, kernel_size=3, stride=1, padding=0,
                 dilation=1, bias=True, padding_mode='zeros'):
        conv_kwargs = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
            padding_mode=padding_mode,
        )
        super().__init__(
            in_channels=in_channels,
            out_channels=depth_multiplier * in_channels,
            groups=in_channels,
            **conv_kwargs,
        )


#===================     EEG-ITNet       =====================================================================   
   
class EEG_ITNET(nn.Module):
    """EEG-ITNet: three inception-style temporal/spatial branches followed by a dilated causal TCN.

    Args:
        n_classes: number of output classes.
        in_channels: size of dim 1 of the input (electrodes); height of the depthwise spatial kernel.
        input_window_samples: time samples per window, used to size the classifier as
            28 * ((input_window_samples // 4) // 4). The default 375 reproduces the original
            hard-coded ``Linear(644, ...)``.
        drop_rate: dropout probability used throughout (original value 0.4).
        return_logits: if True, ``forward`` returns raw class scores. The default False keeps
            the original softmax output. ``F.cross_entropy`` applies its own log-softmax, so
            pass True when training with it.
    """

    def __init__(self, n_classes, in_channels, input_window_samples=375, drop_rate=0.4, return_logits=False):
        super().__init__()
        self.return_logits = return_logits
        n_maps = 14  # 2 + 4 + 8 feature maps from the three inception branches

        # Module creation order matches the original, so seeded initialisation and
        # state_dict keys are unchanged.
        self.block11 = self._inception_branch(2, 16, in_channels)
        self.block12 = self._inception_branch(4, 32, in_channels)
        self.block13 = self._inception_branch(8, 64, in_channels)

        self.pool1 = nn.Sequential(nn.AvgPool2d(kernel_size=(1, 4)), nn.Dropout(drop_rate))

        # Residual TCN: two depthwise conv blocks per dilation level 1, 2, 4, 8.
        self.block2 = self._tcn_block(n_maps, 1, drop_rate)
        self.block3 = self._tcn_block(n_maps, 1, drop_rate)
        self.block4 = self._tcn_block(n_maps, 2, drop_rate)
        self.block5 = self._tcn_block(n_maps, 2, drop_rate)
        self.block6 = self._tcn_block(n_maps, 4, drop_rate)
        self.block7 = self._tcn_block(n_maps, 4, drop_rate)
        self.block8 = self._tcn_block(n_maps, 8, drop_rate)
        self.block9 = self._tcn_block(n_maps, 8, drop_rate)

        self.block_reduce = nn.Sequential(
            nn.Conv2d(n_maps, 28, kernel_size=(1, 1)),
            nn.BatchNorm2d(28),
            nn.ELU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(drop_rate),
        )

        # pool1 and block_reduce each floor the time axis by 4.
        self.classifier = nn.Linear(28 * ((input_window_samples // 4) // 4), n_classes)
        self.m = nn.Softmax(dim=1)

    @staticmethod
    def _inception_branch(n_filters, kernel_len, in_channels):
        """Temporal (1, kernel_len) conv -> BN -> depthwise (in_channels, 1) spatial conv -> BN -> ELU."""
        return nn.Sequential(
            nn.Conv2d(1, n_filters, kernel_size=(1, kernel_len), padding='same', bias=False),
            nn.BatchNorm2d(n_filters),
            DepthwiseConv2d(n_filters, kernel_size=(in_channels, 1), depth_multiplier=1, bias=False, padding='valid'),
            nn.BatchNorm2d(n_filters),
            nn.ELU(),
        )

    @staticmethod
    def _tcn_block(n_maps, dilation, drop_rate):
        """Depthwise (1, 4) conv with time dilation ``dilation`` -> BN -> ELU -> Dropout."""
        return nn.Sequential(
            DepthwiseConv2d(n_maps, kernel_size=(1, 4), depth_multiplier=1, dilation=(1, dilation), bias=False, padding='valid'),
            nn.BatchNorm2d(n_maps),
            nn.ELU(),
            nn.Dropout(drop_rate),
        )

    def forward(self, x):
        x = x.reshape([-1, 1, x.shape[1], x.shape[2]])
        branches = (self.block11, self.block12, self.block13)
        x = torch.cat([branch(x) for branch in branches], 1)
        x = self.pool1(x)

        stages = ((self.block2, self.block3, 1), (self.block4, self.block5, 2),
                  (self.block6, self.block7, 4), (self.block8, self.block9, 8))
        for first, second, dilation in stages:
            # Left-pad the time axis by (kernel_size - 1) * dilation = 3 * dilation so each
            # 'valid' conv is causal and length-preserving, as the residual sum requires.
            pad = (3 * dilation, 0)
            residual = x
            x = first(nn.functional.pad(x, pad))
            x = second(nn.functional.pad(x, pad)) + residual

        x = self.block_reduce(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x if self.return_logits else self.m(x)


In [None]:
# Smoke test: build EEG_ITNET, print its trainable-parameter count and run one forward pass.
eeg = torch.randn(60, 22, 375).to(device)
network = EEG_ITNET(n_classes=4, in_channels=22).to(device)
print(sum(tensor.numel() for tensor in network.parameters() if tensor.requires_grad))
out = network(eeg)

4764


### Models

In [None]:
from math import ceil
import torch
from torch import nn
from torch.nn.utils import weight_norm
import torch.nn.functional as F
import math
##############  Attention mechanism (still on testing phase) #############################################################

def attention(q, k, v, d_k, mask=None, dropout=None):
    """Scaled dot-product attention.

    Args:
        q, k, v: query/key/value tensors; scores are ``q @ k^T`` over the last two dims.
        d_k: key dimension used for the 1/sqrt(d_k) scaling.
        mask: optional mask; it gets a new dim 1 (``unsqueeze(1)``) so it broadcasts across
            heads, and positions where it equals 0 receive a score of -1e9 before the softmax.
        dropout: optional callable applied to the attention weights.

    Returns:
        The attention-weighted sum of ``v``.
    """
    logits = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e9)
    weights = F.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return weights @ v
###############################################################
class MultiHeadAttention(nn.Module):
    """Multi-head attention with separate q/k/v projections and an output projection.

    Args:
        heads: number of heads; each head gets ``d_model // heads`` features.
        d_model: feature size of q, k and v (and of the output).
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.h = heads
        self.d_k = d_model // heads
        # Creation order (q, v, k, dropout, out) is kept so seeded initialisation is unchanged.
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch = q.size(0)

        def heads_first(tensor, projection):
            # Project, split the feature dim into (h, d_k), then move heads to dim 1:
            # assumed (batch, seq, d_model) -> (batch, h, seq, d_k).
            return projection(tensor).view(batch, -1, self.h, self.d_k).transpose(1, 2)

        keys = heads_first(k, self.k_linear)
        queries = heads_first(q, self.q_linear)
        values = heads_first(v, self.v_linear)

        context = attention(queries, keys, values, self.d_k, mask, self.dropout)

        # Undo the head split: (batch, h, seq, d_k) -> (batch, seq, d_model).
        merged = context.transpose(1, 2).contiguous().view(batch, -1, self.d_model)
        return self.out(merged)


##############  TIDNet modules #############################################################

class PrintLayer(nn.Module):
    """Debugging pass-through layer: prints the shape of its input and returns it unchanged."""

    def forward(self, x):
        print(x.shape)
        return x

class Ensure4d(nn.Module):
    """Append trailing singleton dimensions until the input has at least 4 dimensions."""

    def forward(self, x):
        missing = 4 - x.dim()
        for _ in range(missing):
            x = x.unsqueeze(-1)
        return x


class Expression(nn.Module):
    """Compute given expression on forward pass.
    Parameters
    ----------
    expression_fn : callable
        Should accept variable number of objects of type
        `torch.autograd.Variable` to compute its output.
    """

    def __init__(self, expression_fn):
        super(Expression, self).__init__()
        self.expression_fn = expression_fn

    def forward(self, *x):
        return self.expression_fn(*x)

    def __repr__(self):
        if hasattr(self.expression_fn, "func") and hasattr(
            self.expression_fn, "kwargs"
        ):
            expression_str = "{:s} {:s}".format(
                self.expression_fn.func.__name__, str(self.expression_fn.kwargs)
            )
        elif hasattr(self.expression_fn, "__name__"):
            expression_str = self.expression_fn.__name__
        else:
            expression_str = repr(self.expression_fn)
        return (
            self.__class__.__name__ +
            "(expression=%s) " % expression_str
        )


class _TemporalFilter(nn.Module):
    def __init__(self, in_chans, filters, depth, temp_len, drop_prob=0., activation=nn.LeakyReLU,
                 residual='netwise'):
        super().__init__()
        temp_len = temp_len + 1 - temp_len % 2
        self.residual_style = str(residual)
        net = list()

        for i in range(depth):
            dil = depth - i
            conv = weight_norm(nn.Conv2d(in_chans if i == 0 else filters, filters,
                                         kernel_size=(1, temp_len), dilation=dil,
                                         padding=(0, dil * (temp_len - 1) // 2)))
            net.append(nn.Sequential(
                conv,
                activation(),
                nn.Dropout2d(drop_prob)
            ))
        if self.residual_style.lower() == 'netwise':
            self.net = nn.Sequential(*net)
            self.residual = nn.Conv2d(in_chans, filters, (1, 1))
        elif residual.lower() == 'dense':
            self.net = net

    def forward(self, x):
        if self.residual_style.lower() == 'netwise':
            return self.net(x) + self.residual(x)
        elif self.residual_style.lower() == 'dense':
            for layer in self.net:
                x = torch.cat((x, layer(x)), dim=1)
            return x


############## MLP blocks #################################################
class MLP_blocks(nn.Module):
    """Stack of Linear -> BatchNorm1d -> ReLU blocks, one per entry of ``arch``.

    Args:
        arch: sequence of layer widths.
        in_channels: number of input features of the first Linear layer.
    """

    def __init__(self, arch, in_channels):
        super().__init__()
        # The original also created an unused PrintLayer() here; it was dropped so this
        # block no longer depends on that debug helper.
        layers = []
        for width in arch:
            layers += [nn.Linear(in_channels, width),
                       nn.BatchNorm1d(width),
                       nn.ReLU()]
            in_channels = width

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

############## spatial convs(electrodes,1) (electrode aggregator)  (still on testing phase) #################################################
class spatial_aggregator(nn.Module):
    """(k, 1) convolution over dim 2 -> BatchNorm2d -> ReLU -> AvgPool2d(kernel=(7, 1), stride=(1, 2)).

    Args:
        out: output feature maps of the convolution.
        in_channels: input feature maps.
        k: kernel height along dim 2 (presumably the electrode axis — confirm with callers).
    """

    def __init__(self, out, in_channels, k):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, out, (k, 1)),
            nn.BatchNorm2d(out),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(7, 1), stride=(1, 2)),
        )

    def forward(self, x):
        return self.model(x)

############## electrode weight-sharing convs(1,1) (a.k.a 1x1encoders) #################################################
class SharedSpaceTimeConv1x1(nn.Module):
    """1x1 convolutions (optionally interleaved with average pooling) shared over both spatial axes.

    Args:
        mode: ``'1x1'`` -> the input is first made 4-D (Ensure4d) and permuted from
            (batch, chans, time, 1) to (batch, 1, chans, time), and the first conv takes one
            input channel; any other mode uses the input as-is.
        arch: sequence of entries. An int ``n`` adds Conv2d(kernel_size=1, n outputs) ->
            BatchNorm2d -> ReLU; a string ``'A<k>'`` adds AvgPool2d(kernel=(1, k), stride=(1, 2)).
        in_channels: input channels (replaced by 1 in ``'1x1'`` mode).
    """

    def __init__(self, mode, arch, in_channels):
        super().__init__()

        def _permute(x):
            """
            Permutes data:
            from dim:
            batch, chans, time, 1
            to dim:
            batch, 1, chans, time
            """
            return x.permute([0, 3, 1, 2])

        # The original also created an unused PrintLayer() here; it was dropped.
        if mode == '1x1':
            layers = [Ensure4d(),
                      Expression(_permute), ]
            in_channels = 1
        else:
            layers = []
        for x in arch:
            s = str(x)
            if s[0] == 'A':
                layers += [nn.AvgPool2d(kernel_size=(1, int(s[1:])), stride=(1, 2))]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU()]
                in_channels = x

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


############## TIDNet temporal block #################################################
class TemporalTIDNet(nn.Module):
    """TIDNet temporal front-end.

    Pipeline: Ensure4d -> permute (batch, chans, time, 1) to (batch, 1, chans, time) ->
    _TemporalFilter(1 -> t_filters) -> MaxPool2d((1, pooling)) -> Dropout2d(drop_prob).

    Args:
        t_filters: output feature maps of the temporal filter.
        input_window_samples: time samples per window; with ``temp_span`` sets the kernel length.
        drop_prob: Dropout2d probability after pooling.
        pooling: max-pooling width along time.
        temp_layers: depth of the temporal filter.
        temp_span: kernel length as a fraction of the window (``ceil(temp_span * input_window_samples)``).
    """

    def __init__(self, t_filters, input_window_samples, drop_prob, pooling,
                 temp_layers, temp_span):
        super().__init__()
        self.temp_len = ceil(temp_span * input_window_samples)

        def _permute(x):
            """Reorder (batch, chans, time, 1) -> (batch, 1, chans, time)."""
            return x.permute([0, 3, 1, 2])

        stages = [
            Ensure4d(),
            Expression(_permute),
            _TemporalFilter(1, t_filters, depth=temp_layers, temp_len=self.temp_len),
            nn.MaxPool2d((1, pooling)),
            nn.Dropout2d(drop_prob),
        ]
        self.temporal = nn.Sequential(*stages)

    def forward(self, x):
        return self.temporal(x)


###################################################################################### 

class TemporalShareSpaceTime(nn.Module):
    """TIDNet temporal front-end followed by a shared 1x1-conv head or an MLP head.

    Args:
        mode: ``'1x1'`` -> SharedSpaceTimeConv1x1 on the raw input (no temporal front-end);
            ``'t+1x1'`` -> TemporalTIDNet then SharedSpaceTimeConv1x1;
            ``'mlp'`` -> TemporalTIDNet, flatten, then MLP_blocks.
        arch: head layer spec (ints = filter counts / widths, ``'A<k>'`` = average pooling,
            the latter only for the conv modes).
        n_classes: number of output classes.
        in_chans: number of electrodes.
        input_window_samples: time samples per window.
        t_filters, drop_prob, pooling, temp_layers, temp_span: TemporalTIDNet settings.

    Raises:
        ValueError: for an unknown ``mode`` (the original left ``self.model`` undefined and
            failed later in ``forward``), or when ``arch`` contains no filter count.
    """

    def __init__(self, mode, arch, n_classes, in_chans, input_window_samples, t_filters,
                 drop_prob, pooling, temp_layers, temp_span):
        super().__init__()
        if mode not in ('1x1', 't+1x1', 'mlp'):
            raise ValueError(f"unknown mode {mode!r}; expected '1x1', 't+1x1' or 'mlp'")

        self.mode = mode
        self.n_classes = n_classes
        self.in_chans = in_chans
        self.input_window_samples = input_window_samples
        self.temp_len = ceil(temp_span * input_window_samples)
        # The temporal filter keeps the time length (odd kernel, symmetric padding) and
        # MaxPool2d((1, pooling)) then floors it. The original used ceil(W/p - 1) here and
        # ceil(W/p) for the MLP input, which mis-sized the Linear layers when W % p == 0
        # ('t+1x1') or W % p != 0 ('mlp').
        pooled_width = input_window_samples // pooling
        self.params = self._flat_features(arch, in_chans, pooled_width)

        ######################################################

        if self.mode != '1x1':
            self.tidnet_temp = TemporalTIDNet(t_filters=t_filters,
                                              input_window_samples=input_window_samples,
                                              drop_prob=drop_prob, pooling=pooling, temp_layers=temp_layers,
                                              temp_span=temp_span)

        if self.mode in ('1x1', 't+1x1'):
            self.model = SharedSpaceTimeConv1x1(self.mode, arch, t_filters)
            if self.mode == '1x1':
                # No temporal front-end: the head sees the full (chans, time) grid.
                self.params = self._flat_features(arch, in_chans, input_window_samples)
            self.fc_block = nn.Linear(self.params, self.n_classes)
        else:  # 'mlp'
            self.model = MLP_blocks(arch, in_chans * pooled_width * t_filters)
            self.fc_block = nn.Linear(arch[-1], self.n_classes)

    @staticmethod
    def _flat_features(arch, h, w):
        """Flattened feature count after a SharedSpaceTimeConv1x1 head on an (h, w) grid.

        Each ``'A<k>'`` entry (AvgPool2d kernel (1, k), stride (1, 2)) maps
        ``w -> (w - k) // 2 + 1``; ``h`` is unchanged. The channel count is the last
        integer (filter-count) entry of ``arch``.
        """
        for entry in arch:
            s = str(entry)
            if s[0] == 'A':
                w = (w - int(s[1:])) // 2 + 1
        filter_counts = [entry for entry in arch if str(entry)[0] != 'A']
        if not filter_counts:
            raise ValueError("arch must contain at least one filter count")
        return int(filter_counts[-1] * h * w)

    def forward(self, x):
        """Forward pass.
        Parameters
        ----------
        x: torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).
        """
        if self.mode != '1x1':
            x = self.tidnet_temp(x)

        if self.mode != 'mlp':
            x = self.model(x)
            x = torch.flatten(x, 1)
            out = self.fc_block(x)
        else:
            x = torch.flatten(x, 1)
            x = self.model(x)
            out = self.fc_block(x)
        return out

    def get_emb(self, x):
        """Temporal front-end output. Not available in '1x1' mode (no ``tidnet_temp``)."""
        return self.tidnet_temp(x)

############################################################################


### Train

In [None]:
import wandb
import os
import torch 
from torch import optim
from sklearn.metrics import f1_score
import torch.nn.functional as F
from torch.autograd import Variable
from torchinfo import summary
import gc



from braindecode.models import EEGNetv4,TIDNet, EEGResNet
def build_network(mode, arch, n_classes=4, in_chans=25, input_window_samples=1125, t_filters=32,
                  drop_prob=0.4, pooling=15, temp_layers=2, temp_span=0.05):
    """Build the model used by the sweep.

    Currently always returns braindecode's ``EEGNetv4(in_chans, n_classes, input_window_samples)``.
    ``mode``, ``arch``, ``t_filters``, ``drop_prob``, ``pooling``, ``temp_layers`` and
    ``temp_span`` are accepted for the alternative models below but are not used by it.

    Alternatives previously toggled here:
        coatnet(n_classes=4, n_channels=25, embed_dim=64, num_heads=8, conv_kernel=10, pool_kernel=2, incep=1).to(device)
        RNN(input_size=1125, hidden_size=128, num_layers=4, num_classes=4).to('cuda')
        TemporalShareSpaceTime(mode, arch, n_classes, in_chans, input_window_samples, t_filters,
                               drop_prob, pooling, temp_layers, temp_span)
        ConvNet(in_chans, 64, 4, 4, 7)
        EEG_ITNET(n_classes=4, in_channels=22).to(device)
    """
    return EEGNetv4(in_chans, n_classes, input_window_samples)


def build_optimizer(network, optimizer, learning_rate):
    """Create the optimizer named by ``optimizer`` over ``network``'s parameters.

    Args:
        network: model whose ``parameters()`` are optimised.
        optimizer: ``"sgd"`` (momentum 0.9, weight decay 5e-4) or ``"adamw"``
            (weight decay 0.01, AMSGrad).
        learning_rate: initial learning rate.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: for any other name. The original returned the name string itself,
            which only failed later with an AttributeError in the training loop.
    """
    if optimizer == "sgd":
        return optim.SGD(network.parameters(),
                         lr=learning_rate, momentum=0.9, weight_decay=0.5 * 0.001)
    if optimizer == "adamw":
        return optim.AdamW(network.parameters(),
                           lr=learning_rate, weight_decay=0.01, amsgrad=True)
    raise ValueError(f"unsupported optimizer {optimizer!r}; expected 'sgd' or 'adamw'")



####### Training using mixup #####################################################################
def mixup_data(x, y, alpha=5.0, beta=5.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda.

    Each sample is blended with a randomly permuted sample from the same batch.

    Args:
        x: input batch (first dim is the batch).
        y: targets aligned with ``x`` along the first dim.
        alpha, beta: Beta-distribution parameters for the mixing weight ``lam``;
            ``alpha <= 0`` disables mixing (``lam = 1``).
        use_cuda: kept for backward compatibility; no longer needed. The permutation is
            always moved to ``x.device``. The original called ``.cuda()`` unconditionally by
            default, which crashed on CPU.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` with ``mixed_x = lam * x + (1 - lam) * x[index]``,
        ``y_a = y`` and ``y_b = y[index]``.
    '''
    lam = np.random.beta(alpha, beta) if alpha > 0 else 1

    batch_size = x.size()[0]
    # Drawn on the CPU generator (same random stream as before), then moved next to the data.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss: ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b


def train_mixup(network, optimizer, train_loader):
    """Train ``network`` for one epoch with mixup augmentation and cross-entropy loss.

    Args:
        network: model returning class scores of shape (batch, n_classes).
        optimizer: optimizer over ``network``'s parameters.
        train_loader: iterable of ``(data, target)`` batches.

    Returns:
        ``(mean_loss, accuracy)``:
            - ``mean_loss`` is the average of the per-batch mixup losses.
            - ``accuracy`` is the lam-weighted fraction of predictions matching the two
              mixed target sets (a tensor).

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    cumu_loss = 0.0
    correct = 0.0
    total = 0.0
    n_batches = 0
    criterion = torch.nn.CrossEntropyLoss()

    network.train()
    for data, target in train_loader:
        optimizer.zero_grad()

        # Mix inputs/targets using mixup_data's default alpha/beta.
        inputs, targets_a, targets_b, lam = mixup_data(data, target, use_cuda=True)
        outputs = network(inputs)
        loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
        cumu_loss += loss.item()
        n_batches += 1

        loss.backward()
        optimizer.step()

        _, predicted = torch.max(outputs.data, 1)
        total += target.size(0)
        correct += (lam * predicted.eq(targets_a.data).cpu().sum().float()
                    + (1 - lam) * predicted.eq(targets_b.data).cpu().sum().float())

    if n_batches == 0:
        raise ValueError("train_loader yielded no batches")
    # The original returned `cumu_loss / batch_idx*data.shape[0]`. That divided by the last
    # batch *index* (off by one; ZeroDivisionError with a single batch) and then multiplied
    # by the last batch size.
    return cumu_loss / n_batches, correct / total
############################################################################

def train_epoch(network, optimizer, train_loader):
    """Train ``network`` for one epoch with cross-entropy loss.

    Args:
        network: model returning class scores of shape (batch, n_classes).
        optimizer: optimizer over ``network``'s parameters.
        train_loader: iterable of ``(data, target)`` batches.

    Returns:
        ``(mean_loss, accuracy)``: the average of the per-batch losses, and the fraction of
        correctly classified training samples (a tensor).

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    cumu_loss = 0.0
    correct = 0.0
    total = 0.0
    n_batches = 0

    network.train()
    for data, target in train_loader:
        # Disabled augmentation hook from the original (it drew an unused
        # random.uniform(0, 1) each batch, which also required `random` to be imported):
        #   if r < 0.52: data = apply_da(data, selection)
        optimizer.zero_grad()

        outputs = network(data)
        loss = F.cross_entropy(outputs, target)
        cumu_loss += loss.item()
        n_batches += 1

        loss.backward()
        optimizer.step()

        _, predicted = torch.max(outputs.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum()

    if n_batches == 0:
        raise ValueError("train_loader yielded no batches")
    # The original `cumu_loss / batch_idx*data.shape[0]` divided by the last batch index
    # (off by one) and multiplied by the last batch size.
    return cumu_loss / n_batches, correct / total


def validate_epoch(network, valid_loader):
    """Evaluate ``network`` on ``valid_loader`` without gradient tracking.

    Args:
        network: model returning class scores of shape (batch, n_classes).
        valid_loader: iterable of ``(data, target)`` batches.

    Returns:
        ``(mean_loss, accuracy)``: the average of the per-batch cross-entropy losses, and the
        fraction of correctly classified samples (a tensor).

    Raises:
        ValueError: if ``valid_loader`` yields no batches.
    """
    cumu_loss = 0.0
    correct = 0.0
    total = 0.0
    n_batches = 0

    network.eval()
    with torch.no_grad():
        for data, target in valid_loader:
            # One forward pass serves both loss and accuracy (the original ran the
            # network twice per batch).
            outputs = network(data)
            cumu_loss += F.cross_entropy(outputs, target).item()
            n_batches += 1

            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum()

    if n_batches == 0:
        raise ValueError("valid_loader yielded no batches")
    # The original `cumu_loss / batch_idx*data.shape[0]` divided by the last batch index
    # (off by one) and multiplied by the last batch size.
    return cumu_loss / n_batches, correct / total


def test(network, test_loader , classes):
    """Evaluate ``network`` on ``test_loader``, log macro F1 to wandb and print the metrics.

    Args:
        network: model returning class scores of shape (batch, len(classes)).
        test_loader: iterable of ``(data, target)`` batches.
        classes: class names; only ``len(classes)`` is used (and by the commented-out plots).

    Returns:
        ``(accuracy, f1)``: the overall accuracy (a tensor, ``correct / total``) and the
        macro-averaged F1 score over all test samples.
    """
    n_classes = len(classes)
    correct = 0.0
    total = 0.0
    y_true = []
    y_pred = []
    pred_probs = []

    network.eval()
    with torch.no_grad():
        for data, target in test_loader:
            outputs = network(data)

            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum()
            y_true.append(target.cpu().numpy())
            y_pred.append(predicted.cpu().numpy())
            pred_probs.append(outputs.cpu().numpy())

    # Concatenate every batch. The original used np.array(lst[:-1]), which silently dropped
    # the last batch (its size usually differs, so it could not be stacked). F1 was therefore
    # computed on fewer samples than the accuracy.
    y_true = np.concatenate(y_true).reshape([-1])
    y_pred = np.concatenate(y_pred).reshape([-1])
    pred_probs = np.concatenate(pred_probs).reshape([-1, n_classes])

    # Optional wandb plots (disabled):
    # wandb.log({"conf_mat": wandb.plot.confusion_matrix(probs=None, y_true=y_true, preds=y_pred,
    #                                                    class_names=classes)})
    # wandb.log({"roc": wandb.plot.roc_curve(y_true, pred_probs, labels=classes)})
    # wandb.log({"pr": wandb.plot.pr_curve(y_true, pred_probs, labels=classes, classes_to_plot=None)})

    f1 = f1_score(y_true, y_pred, average='macro')
    wandb.log({'Test macro F1-Score': f1})
    print(f1)
    accuracy = correct / total
    print('TEST ACCURACY {} '.format(accuracy))

    return accuracy, f1



def train_sweep(x, y, number_of_subjects, device, bad_subjects, config, run, resultdir,
                *, batch_size=128, learning_rate=1e-3, optimizer_name='adamw',
                epochs=100, mode='mlp', arch=None, classes=None,
                in_chans=61, input_window_samples=1024):
    """Train one model on one subject fold and evaluate it on the held-out test split.

    The fold is selected by ``config.subject``. This is the ONLY value read from
    ``config``: batch size, learning rate, optimizer and epoch count come from the
    keyword arguments below, not from the sweep config. The defaults reproduce
    the values this notebook has been running with.

    Args:
        x, y, number_of_subjects, device, bad_subjects: forwarded to ``lmso`` to
            build the train/validation/test split.
        config: wandb run config; ``config.subject`` is the fold index.
        run: the active wandb run (kept for interface compatibility; not used).
        resultdir: checkpoint directory (kept for interface compatibility; no
            checkpoints are written because no validation split is used).
        batch_size: batch size passed to ``loaders``.
        learning_rate: learning rate passed to ``build_optimizer``.
        optimizer_name: optimizer identifier passed to ``build_optimizer``.
        epochs: number of training epochs (no early stopping).
        mode: network family passed to ``build_network``.
        arch: layer specification passed to ``build_network``;
            defaults to ``[16, 16, 16, 16, 16]``.
        classes: class names; ``len(classes)`` sets the network's output size.
            Defaults to ``['left fist', 'right fist']``.
        in_chans: number of input channels passed to ``build_network``.
        input_window_samples: input window length passed to ``build_network``.

    Returns:
        ``(test_acc, f1)`` as returned by ``test``.
    """
    # Free memory left over from the previous sweep run in the same process.
    gc.collect()
    torch.cuda.empty_cache()

    # Fresh lists per call (avoid shared mutable defaults).
    if arch is None:
        arch = [16, 16, 16, 16, 16]
    if classes is None:
        classes = ['left fist', 'right fist']

    network = build_network(mode, arch, n_classes=len(classes), in_chans=in_chans,
                            input_window_samples=input_window_samples).to(device)
    trainable_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
    wandb.log({"total number of parameters": trainable_params})
    # `optimizer_name` (not `optim`) so the torch.optim module is not shadowed.
    optimizer = build_optimizer(network, optimizer_name, learning_rate)

    kfold = config.subject

    # No validation split: training runs for a fixed number of epochs, and the
    # final weights are evaluated on the test split (valid_loader is unused).
    T_x, T_y, V_x, V_y, Test_x, Test_y = lmso(x, y, number_of_subjects, kfold, device, bad_subjects, with_validation=False)
    train_loader, valid_loader, test_loader = loaders(kfold, T_x, T_y, V_x, V_y, Test_x, Test_y, batch_size)

    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(network, optimizer, train_loader)
        print('train loss {} accuracy {} epoch {} done'.format(train_loss, train_acc, epoch))
        wandb.log({'Training accuracy': train_acc, 'Training loss': train_loss})

    test_acc, f1 = test(network, test_loader, classes)
    wandb.log({'Test accuracy': test_acc})

    return test_acc, f1

###Main

In [None]:
project_name = "cho17_tests"  # other projects: PhysioNet_tests, bci4_tests, sandbox

# Grid sweep: one run per (subject, runs) combination -> 10 x 10 = 100 runs.
# NOTE(review): train_sweep only reads config.subject. It hard-codes its batch
# size, learning rate, optimizer and epoch count, so 'epochs', 'batch_size',
# 'patience', 'learning_rate', 'optimizer' and 'loss' below are recorded in each
# run's config but do not affect training -- confirm before comparing runs by them.
sweep_config = {
    'method': 'grid',
    # The metric must be a key the runs actually log. train_sweep logs
    # 'Test accuracy' (plus training metrics); nothing logs a plain 'loss'.
    # With 'grid' the metric does not change which runs are launched.
    'metric': {
        'name': 'Test accuracy',
        'goal': 'maximize',
    },
    'parameters': {
        'epochs': {
            'values': [3000]
        },
        'batch_size': {
            'values': [64]
        },
        'patience': {
            'values': [60]
        },
        'learning_rate': {
            'values': [1e-3]
        },
        'optimizer': {
            'values': ['adamw']
        },
        'loss': {
            'values': ['CrossEntropyLoss'],
        },
        # Fold index, read by train_sweep as config.subject.
        'subject': {
            'values': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        },
        # Repetition index; not read by train_sweep, so it only repeats each fold.
        'runs': {
            'values': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        },
    },
}

sweep_id = wandb.sweep(sweep_config, project=project_name)

# Dataset selection: keep exactly one of the blocks below active.

# BCI Competition IV-2a
#number_of_subjects=9
#bad_subjects=[] #bad subject for BCI-VI-2a
#x, y= load_subjects('/content/drive/MyDrive/Colab Notebooks/bci4_subjects/', number_of_subjects, device, bad_subjects, apply_euclidean=False, with_eog=False)

# Cho2017
number_of_subjects = 52
bad_subjects = [32, 46, 49]  # bad subjects for Cho2017
x, y = load_subjects('/content/drive/MyDrive/Colab Notebooks/mne_data/', number_of_subjects, device, bad_subjects, apply_euclidean=False, with_eog=False)

# PhysioNet MI
#number_of_subjects=109
#bad_subjects=[88,90,92,100] #bad subject for Physionet MI
#x, y= load_subjects('/content/drive/MyDrive/Colab Notebooks/mne_data/', number_of_subjects, device, bad_subjects, apply_euclidean=True)

def train_wandb():
    """Run one sweep trial inside a fresh wandb run and return its test metrics.

    Used as the ``wandb.agent`` callback. It returns ``(test_acc, f1)`` from
    ``train_sweep`` so it can also be called directly
    (e.g. ``test_acc, f1 = train_wandb()``).
    """
    run = wandb.init(project=project_name, entity="brain-imt", config=sweep_config)
    # Sanity check: the run we hold is the globally active one that
    # wandb.config / wandb.log refer to.
    assert run is wandb.run
    with run:
        config = wandb.config
        test_acc, f1 = train_sweep(x, y, number_of_subjects, device, bad_subjects, config, run, resultdir)
    return test_acc, f1

# Offline mode (no syncing to the wandb servers): uncomment the line below.
# os.environ["WANDB_MODE"] = "offline"

# Single run outside the sweep instead of the agent:
# Test_acc = train_wandb()

# Start a sweep agent that calls train_wandb once per grid configuration.
wandb.agent(sweep_id, function=train_wandb)

Create sweep with ID: j09ylk92
Sweep URL: https://wandb.ai/brain-imt/cho17_tests/sweeps/j09ylk92




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




240 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
240 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 240 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




240 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
240 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 240 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped




200 events found
Event IDs: [1 2]
Used Annotations descriptions: ['left_hand', 'right_hand']
Adding metadata with 4 columns
200 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 200 events and 1024 original time points ...
0 bad epochs dropped


[34m[1mwandb[0m: Agent Starting Run: 4aehu1j0 with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 3000
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	loss: CrossEntropyLoss
[34m[1mwandb[0m: 	optimizer: adamw
[34m[1mwandb[0m: 	patience: 60
[34m[1mwandb[0m: 	runs: 1
[34m[1mwandb[0m: 	subject: 1




train loss 31.13577813687532 accuracy 0.6161035895347595 epoch 0 done
train loss 28.02446417186571 accuracy 0.6731982231140137 epoch 1 done
train loss 26.70374354072239 accuracy 0.6972973346710205 epoch 2 done
train loss 26.026610872019894 accuracy 0.7182432413101196 epoch 3 done
train loss 25.724472605663795 accuracy 0.7159910202026367 epoch 4 done
train loss 25.48301107987114 accuracy 0.718693733215332 epoch 5 done
train loss 25.30910989512568 accuracy 0.728378415107727 epoch 6 done
train loss 24.8648795666902 accuracy 0.7331081032752991 epoch 7 done
train loss 24.773213199947193 accuracy 0.7319819927215576 epoch 8 done
train loss 24.430958001509957 accuracy 0.7391892075538635 epoch 9 done
train loss 24.312580108642578 accuracy 0.7425675988197327 epoch 10 done
train loss 24.012443583944574 accuracy 0.748085618019104 epoch 11 done
train loss 23.783791790837828 accuracy 0.7495495676994324 epoch 12 done
train loss 23.591963166775912 accuracy 0.7524774670600891 epoch 13 done
train loss 2

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Test accuracy,▁
Training accuracy,▁▄▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████████████
Training loss,█▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total number of parameters,▁

0,1
Test accuracy,0.793
Test macro F1-Score,
Training accuracy,0.81171
Training loss,19.28686
total number of parameters,3106.0


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: ljh55he5 with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 3000
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	loss: CrossEntropyLoss
[34m[1mwandb[0m: 	optimizer: adamw
[34m[1mwandb[0m: 	patience: 60
[34m[1mwandb[0m: 	runs: 1
[34m[1mwandb[0m: 	subject: 2




train loss 30.35101376409116 accuracy 0.6298423409461975 epoch 0 done
train loss 27.048768209374472 accuracy 0.68840092420578 epoch 1 done
train loss 25.825302559396494 accuracy 0.713738739490509 epoch 2 done
train loss 25.376026319420852 accuracy 0.7198198437690735 epoch 3 done
train loss 24.97093032753986 accuracy 0.7277027368545532 epoch 4 done
train loss 24.722265989884086 accuracy 0.735247790813446 epoch 5 done
train loss 24.402558160864785 accuracy 0.7365990877151489 epoch 6 done
train loss 24.38740632845008 accuracy 0.7393018007278442 epoch 7 done
train loss 24.159968127375066 accuracy 0.7413288354873657 epoch 8 done
train loss 23.845868753350302 accuracy 0.7445946335792542 epoch 9 done
train loss 23.569176901941717 accuracy 0.7461711764335632 epoch 10 done
train loss 23.332626093988836 accuracy 0.7510135173797607 epoch 11 done
train loss 23.118610796721086 accuracy 0.7514640092849731 epoch 12 done
train loss 22.93440700613934 accuracy 0.7594594955444336 epoch 13 done
train loss

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Test accuracy,▁
Training accuracy,▁▄▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇████████████
Training loss,█▅▅▅▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁
total number of parameters,▁

0,1
Test accuracy,0.717
Test macro F1-Score,
Training accuracy,0.82128
Training loss,18.19552
total number of parameters,3106.0


[34m[1mwandb[0m: Agent Starting Run: mitm7mtd with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 3000
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	loss: CrossEntropyLoss
[34m[1mwandb[0m: 	optimizer: adamw
[34m[1mwandb[0m: 	patience: 60
[34m[1mwandb[0m: 	runs: 1
[34m[1mwandb[0m: 	subject: 3




train loss 30.412210277889088 accuracy 0.6373873949050903 epoch 0 done
train loss 27.475532718326736 accuracy 0.6849099397659302 epoch 1 done
train loss 26.19307644470878 accuracy 0.7087838053703308 epoch 2 done
train loss 25.54390235569166 accuracy 0.7211712002754211 epoch 3 done
train loss 25.13020152631013 accuracy 0.7257882952690125 epoch 4 done
train loss 24.76834139616593 accuracy 0.7301802039146423 epoch 5 done
train loss 24.661256292591926 accuracy 0.7381756901741028 epoch 6 done
train loss 24.024023698723834 accuracy 0.7438063025474548 epoch 7 done
train loss 23.7020931865858 accuracy 0.7497748136520386 epoch 8 done
train loss 23.480392310930334 accuracy 0.7522522807121277 epoch 9 done
train loss 23.192387601603635 accuracy 0.756869375705719 epoch 10 done
train loss 22.99893337747325 accuracy 0.757770299911499 epoch 11 done
train loss 22.807035777879797 accuracy 0.7621622085571289 epoch 12 done
train loss 22.81820023578146 accuracy 0.759684681892395 epoch 13 done
train loss 22

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Test accuracy,▁
Training accuracy,▁▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██▇▇█████████████
Training loss,█▅▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
total number of parameters,▁

0,1
Test accuracy,0.764
Test macro F1-Score,
Training accuracy,0.8089
Training loss,19.30872
total number of parameters,3106.0


[34m[1mwandb[0m: Agent Starting Run: waf90r9b with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 3000
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	loss: CrossEntropyLoss
[34m[1mwandb[0m: 	optimizer: adamw
[34m[1mwandb[0m: 	patience: 60
[34m[1mwandb[0m: 	runs: 1
[34m[1mwandb[0m: 	subject: 4




train loss 30.446094554403558 accuracy 0.6305180191993713 epoch 0 done
train loss 27.00018555185069 accuracy 0.680630624294281 epoch 1 done
train loss 25.82811817915543 accuracy 0.7111486792564392 epoch 2 done
train loss 24.968871241030488 accuracy 0.7266892194747925 epoch 3 done
train loss 24.75003555546636 accuracy 0.7322072386741638 epoch 4 done
train loss 24.251204158948816 accuracy 0.7405405640602112 epoch 5 done
train loss 24.114101326983906 accuracy 0.7397522926330566 epoch 6 done
train loss 23.928686826125436 accuracy 0.7442567944526672 epoch 7 done
train loss 23.929958613022514 accuracy 0.7438063025474548 epoch 8 done
train loss 23.544482625049092 accuracy 0.7548423409461975 epoch 9 done
train loss 23.65666325195976 accuracy 0.7525901198387146 epoch 10 done
train loss 23.417159785395082 accuracy 0.7511261701583862 epoch 11 done
train loss 23.422689707382865 accuracy 0.7548423409461975 epoch 12 done
train loss 23.153299829234246 accuracy 0.753716230392456 epoch 13 done
train lo

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Test accuracy,▁
Training accuracy,▁▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████████
Training loss,█▅▅▄▄▄▄▄▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
total number of parameters,▁

0,1
Test accuracy,0.692
Test macro F1-Score,
Training accuracy,0.8232
Training loss,18.12622
total number of parameters,3106.0


[34m[1mwandb[0m: Agent Starting Run: 9frnsa9x with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 3000
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	loss: CrossEntropyLoss
[34m[1mwandb[0m: 	optimizer: adamw
[34m[1mwandb[0m: 	patience: 60
[34m[1mwandb[0m: 	runs: 1
[34m[1mwandb[0m: 	subject: 5




train loss 30.09120816769807 accuracy 0.6414414644241333 epoch 0 done
train loss 27.125402761542283 accuracy 0.6842342615127563 epoch 1 done
train loss 25.91217808101488 accuracy 0.7091216444969177 epoch 2 done
train loss 25.344102382659912 accuracy 0.7191441655158997 epoch 3 done
train loss 24.727963654891305 accuracy 0.7341216206550598 epoch 4 done
train loss 24.70589380678923 accuracy 0.7314189076423645 epoch 5 done
train loss 24.46412345637446 accuracy 0.7370495796203613 epoch 6 done
train loss 24.189800697824225 accuracy 0.7451576590538025 epoch 7 done
train loss 23.816924364670463 accuracy 0.7463964223861694 epoch 8 done
train loss 23.56481572856074 accuracy 0.7525901198387146 epoch 9 done
train loss 23.43820694218511 accuracy 0.754954993724823 epoch 10 done
train loss 23.06514066198598 accuracy 0.7560811042785645 epoch 11 done
train loss 22.89889018431954 accuracy 0.7617117166519165 epoch 12 done
train loss 22.668330545010775 accuracy 0.7654279470443726 epoch 13 done
train loss 

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Test accuracy,▁
Training accuracy,▁▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇█▇▇▇▇▇▇█████████████████
Training loss,█▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total number of parameters,▁

0,1
Test accuracy,0.694
Test macro F1-Score,
Training accuracy,0.81858
Training loss,18.437
total number of parameters,3106.0


[34m[1mwandb[0m: Agent Starting Run: lr5mt1dq with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 3000
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	loss: CrossEntropyLoss
[34m[1mwandb[0m: 	optimizer: adamw
[34m[1mwandb[0m: 	patience: 60
[34m[1mwandb[0m: 	runs: 1
[34m[1mwandb[0m: 	subject: 6




train loss 30.26246916729471 accuracy 0.6353603601455688 epoch 0 done
train loss 27.371312473131262 accuracy 0.6922297477722168 epoch 1 done
train loss 26.132864413054094 accuracy 0.7100225687026978 epoch 2 done
train loss 25.42656647640726 accuracy 0.7248874306678772 epoch 3 done
train loss 24.930038203363836 accuracy 0.7262387275695801 epoch 4 done
train loss 24.783239447552226 accuracy 0.7386261224746704 epoch 5 done
train loss 24.551229207412057 accuracy 0.7372747659683228 epoch 6 done
train loss 24.09126445521479 accuracy 0.7454954981803894 epoch 7 done
train loss 23.92397167371667 accuracy 0.7467342615127563 epoch 8 done
train loss 23.720080251279086 accuracy 0.7550675868988037 epoch 9 done
train loss 23.42531861429629 accuracy 0.753716230392456 epoch 10 done
train loss 23.23101124556168 accuracy 0.7574324607849121 epoch 11 done
train loss 23.20965511902519 accuracy 0.7620495557785034 epoch 12 done
train loss 22.95969766119252 accuracy 0.7599099278450012 epoch 13 done
train loss 

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Test accuracy,▁
Training accuracy,▁▄▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇█▇████████████
Training loss,█▆▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁
total number of parameters,▁

0,1
Test accuracy,0.718
Test macro F1-Score,
Training accuracy,0.81543
Training loss,18.7717
total number of parameters,3106.0


[34m[1mwandb[0m: Agent Starting Run: urx0zdda with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 3000
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	loss: CrossEntropyLoss
[34m[1mwandb[0m: 	optimizer: adamw
[34m[1mwandb[0m: 	patience: 60
[34m[1mwandb[0m: 	runs: 1
[34m[1mwandb[0m: 	subject: 7




train loss 30.402361413706906 accuracy 0.6323198080062866 epoch 0 done
train loss 27.339606430219565 accuracy 0.6751126050949097 epoch 1 done
train loss 26.067598736804467 accuracy 0.7027027010917664 epoch 2 done
train loss 25.391042730082635 accuracy 0.7153153419494629 epoch 3 done
train loss 24.94503914791605 accuracy 0.7317567467689514 epoch 4 done
train loss 24.747097284897514 accuracy 0.7307432889938354 epoch 5 done
train loss 24.45906236897344 accuracy 0.7382882833480835 epoch 6 done
train loss 24.288767441459328 accuracy 0.7438063025474548 epoch 7 done
train loss 24.225655555725098 accuracy 0.7441441416740417 epoch 8 done
train loss 23.86943767381751 accuracy 0.7462838292121887 epoch 9 done
train loss 23.957150023916494 accuracy 0.7468468546867371 epoch 10 done
train loss 24.115855569424834 accuracy 0.7438063025474548 epoch 11 done
train loss 23.71028056352035 accuracy 0.7495495676994324 epoch 12 done
train loss 23.611215819483217 accuracy 0.7493243217468262 epoch 13 done
train 

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Test accuracy,▁
Training accuracy,▁▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████
Training loss,█▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁
total number of parameters,▁

0,1
Test accuracy,0.729
Test macro F1-Score,
Training accuracy,0.81227
Training loss,19.34583
total number of parameters,3106.0


[34m[1mwandb[0m: Agent Starting Run: emgq1aht with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 3000
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	loss: CrossEntropyLoss
[34m[1mwandb[0m: 	optimizer: adamw
[34m[1mwandb[0m: 	patience: 60
[34m[1mwandb[0m: 	runs: 1
[34m[1mwandb[0m: 	subject: 8




train loss 31.113447438115656 accuracy 0.610585629940033 epoch 0 done
train loss 28.26682296006576 accuracy 0.6666666865348816 epoch 1 done
train loss 27.092388982358187 accuracy 0.6931306719779968 epoch 2 done
train loss 26.455907717995018 accuracy 0.7084459662437439 epoch 3 done
train loss 25.7611925498299 accuracy 0.7248874306678772 epoch 4 done
train loss 25.47875995221345 accuracy 0.723761260509491 epoch 5 done
train loss 24.85911265663479 accuracy 0.7327702641487122 epoch 6 done
train loss 24.422638976055644 accuracy 0.7434684634208679 epoch 7 done
train loss 24.395288571067475 accuracy 0.7434684634208679 epoch 8 done
train loss 23.95740094392196 accuracy 0.7462838292121887 epoch 9 done
train loss 23.634822161301322 accuracy 0.7496621608734131 epoch 10 done
train loss 23.409320789834727 accuracy 0.7551801800727844 epoch 11 done
train loss 23.06557626309602 accuracy 0.761824369430542 epoch 12 done
train loss 22.72225732388704 accuracy 0.7649775147438049 epoch 13 done
train loss 22

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Test accuracy,▁
Training accuracy,▁▄▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████▇████████
Training loss,█▆▅▄▄▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total number of parameters,▁

0,1
Test accuracy,0.808
Test macro F1-Score,
Training accuracy,0.80766
Training loss,19.47546
total number of parameters,3106.0


[34m[1mwandb[0m: Agent Starting Run: n43irxoa with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 3000
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	loss: CrossEntropyLoss
[34m[1mwandb[0m: 	optimizer: adamw
[34m[1mwandb[0m: 	patience: 60
[34m[1mwandb[0m: 	runs: 1
[34m[1mwandb[0m: 	subject: 9




train loss 61.47571328107048 accuracy 0.6273863911628723 epoch 0 done
train loss 53.80547109772178 accuracy 0.6871591210365295 epoch 1 done
train loss 51.02374417641583 accuracy 0.7182954549789429 epoch 2 done
train loss 49.70652900022619 accuracy 0.729204535484314 epoch 3 done
train loss 48.65539316570057 accuracy 0.7368181943893433 epoch 4 done
train loss 48.00342758964089 accuracy 0.7419317960739136 epoch 5 done
train loss 47.78616352642284 accuracy 0.7438636422157288 epoch 6 done
train loss 46.86183887369492 accuracy 0.753636360168457 epoch 7 done
train loss 46.71508164966808 accuracy 0.7555682063102722 epoch 8 done
train loss 46.52703182837543 accuracy 0.7561363577842712 epoch 9 done
train loss 45.62708628878874 accuracy 0.7621591091156006 epoch 10 done
train loss 45.431169874527875 accuracy 0.7627272605895996 epoch 11 done
train loss 45.00708965694203 accuracy 0.7662500143051147 epoch 12 done
train loss 44.684686702840466 accuracy 0.7681818008422852 epoch 13 done
train loss 44.05

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Test accuracy,▁
Training accuracy,▁▄▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇██▇█████████████████
Training loss,█▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total number of parameters,▁

0,1
Test accuracy,0.63241
Test macro F1-Score,
Training accuracy,0.83102
Training loss,34.90602
total number of parameters,3106.0


[34m[1mwandb[0m: Agent Starting Run: hayr8ea9 with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 3000
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	loss: CrossEntropyLoss
[34m[1mwandb[0m: 	optimizer: adamw
[34m[1mwandb[0m: 	patience: 60
[34m[1mwandb[0m: 	runs: 1
[34m[1mwandb[0m: 	subject: 10




train loss 74.37151735169547 accuracy 0.6484581828117371 epoch 0 done
train loss 66.68425282410213 accuracy 0.6959251165390015 epoch 1 done
train loss 63.71790759904044 accuracy 0.7194933891296387 epoch 2 done
train loss 62.63754573890141 accuracy 0.7246696352958679 epoch 3 done
train loss 61.86646342277527 accuracy 0.7289648056030273 epoch 4 done
train loss 61.0097497190748 accuracy 0.7379956245422363 epoch 5 done
train loss 60.04591604641506 accuracy 0.7448238134384155 epoch 6 done
train loss 59.87995444025312 accuracy 0.7422907948493958 epoch 7 done
train loss 59.243744015693665 accuracy 0.7509912252426147 epoch 8 done
train loss 58.416324734687805 accuracy 0.7546255588531494 epoch 9 done
train loss 58.214766076632905 accuracy 0.7540749311447144 epoch 10 done
train loss 57.12016148226601 accuracy 0.7599118947982788 epoch 11 done
train loss 57.03033261639731 accuracy 0.7620044350624084 epoch 12 done
train loss 56.058409452438354 accuracy 0.766740083694458 epoch 13 done
train loss 55.

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Test accuracy,▁
Training accuracy,▁▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇█████████
Training loss,█▅▅▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▂▁▁▁▁▁▁▁▁▁
total number of parameters,▁

0,1
Test accuracy,0.68
Test macro F1-Score,
Training accuracy,0.82478
Training loss,45.37
total number of parameters,3106.0


[34m[1mwandb[0m: Agent Starting Run: u9iuhjac with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 3000
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	loss: CrossEntropyLoss
[34m[1mwandb[0m: 	optimizer: adamw
[34m[1mwandb[0m: 	patience: 60
[34m[1mwandb[0m: 	runs: 2
[34m[1mwandb[0m: 	subject: 1




train loss 32.05361163097879 accuracy 0.6217342615127563 epoch 0 done
train loss 28.163392605988875 accuracy 0.678716242313385 epoch 1 done
train loss 26.76347259853197 accuracy 0.7031531929969788 epoch 2 done
train loss 26.205646805141285 accuracy 0.7131757140159607 epoch 3 done
train loss 25.576910433561906 accuracy 0.7202702760696411 epoch 4 done
train loss 25.36480314835258 accuracy 0.7245495915412903 epoch 5 done
train loss 25.10478108862172 accuracy 0.7304054498672485 epoch 6 done
train loss 24.783949209296186 accuracy 0.7382882833480835 epoch 7 done
train loss 24.34609201680059 accuracy 0.7476351261138916 epoch 8 done
train loss 24.329773902893066 accuracy 0.7424549460411072 epoch 9 done
train loss 24.00716470635456 accuracy 0.7459459900856018 epoch 10 done
train loss 23.94376777565998 accuracy 0.7497748136520386 epoch 11 done
train loss 23.58054360099461 accuracy 0.7556306719779968 epoch 12 done
train loss 23.391121905782946 accuracy 0.7586711645126343 epoch 13 done
train loss 

[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


### Yassine's code

In [None]:

# Hyperparameters
batch_size = 144

# Dataset selection. The path suggests BCI Competition IV data (TODO confirm).
number_of_subjects = 9
bad_subjects = []  # subjects to exclude (used for Physionet MI); none excluded here
x, y = load_subjects(
    '/content/drive/MyDrive/Colab Notebooks/bci4_subjects/',
    number_of_subjects,
    device,
    bad_subjects,
    apply_euclidean=True,
    with_eog=True,
)

# Number of input channels = dimension 1 of the first subject's array.
n_chan = x[0].shape[1]

def loaders(removed_subject):
    """Build train/test DataLoaders for one leave-one-subject-out fold.

    ``loso`` (defined elsewhere) holds out ``removed_subject`` as the test set;
    no validation split is requested. The training tensor's shape is printed.

    Returns:
        (train_loader, test_loader): shuffled training batches of the global
        ``batch_size`` (incomplete last batch dropped) and test batches of 1000.
    """
    train_x, train_y, _val_x, _val_y, test_x, test_y = loso(
        x, y, number_of_subjects, removed_subject, device, bad_subjects,
        with_validation=False,
    )
    print(train_x.shape)
    train_pairs = list(zip(train_x, train_y))
    train_loader = torch.utils.data.DataLoader(
        train_pairs,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    test_pairs = list(zip(test_x, test_y))
    test_loader = torch.utils.data.DataLoader(test_pairs, batch_size=1000)
    return train_loader, test_loader





class ConvNet(torch.nn.Module):
    """1-D CNN: average-pool stem, stacked conv blocks, global average pooling, linear head.

    Args:
        fm: feature maps produced by the stem convolution and the first block.
        n_convs: number of conv blocks. Every block after the first widens the
            feature maps by ~sqrt(2) (``int(1.414 * fm)``), and every block
            halves the time axis with ``MaxPool1d(2)``.
        init_pool: kernel size of the ``AvgPool1d`` applied to the raw input.
        kernel_size: kernel size of every ``Conv1d``; padding is
            ``kernel_size // 2``.
        in_channels: number of input channels. Defaults to the module-level
            ``n_chan``, looked up when the model is constructed.
        n_classes: number of output logits (default 4).
    """

    def __init__(self, fm, n_convs, init_pool, kernel_size, in_channels=None, n_classes=4):
        super().__init__()
        if in_channels is None:
            in_channels = n_chan  # module-level global derived from the loaded data
        padding = kernel_size // 2
        # Submodules are created in the same order as before, so seeded
        # parameter initialisation and state_dict keys are unchanged.
        self.pool = torch.nn.AvgPool1d(init_pool)
        self.conv = torch.nn.Conv1d(in_channels, fm, kernel_size=kernel_size, padding=padding, bias=False)
        self.bn = torch.nn.BatchNorm1d(fm)
        blocks = []
        newfm = oldfm = fm
        for i in range(n_convs):
            if i > 0:
                newfm = int(1.414 * newfm)
            blocks.append(torch.nn.Sequential(
                torch.nn.Conv1d(oldfm, newfm, kernel_size=kernel_size, padding=padding, bias=False),
                torch.nn.BatchNorm1d(newfm),
                # Every block pools, including the first; the former guard
                # `i > 0 - 1` was always true. NOTE(review): if the first block
                # was meant to keep full resolution, make this depend on `i > 0`.
                torch.nn.MaxPool1d(2),
                torch.nn.ReLU(),
                torch.nn.Conv1d(newfm, newfm, kernel_size=kernel_size, padding=padding, bias=False),
                torch.nn.BatchNorm1d(newfm),
                torch.nn.ReLU(),
            ))
            oldfm = newfm
        self.blocks = torch.nn.ModuleList(blocks)
        self.fc = torch.nn.Linear(oldfm, n_classes)

    def forward(self, x):
        """Map ``x`` of shape (batch, channels, time) to (batch, n_classes) logits."""
        y = torch.relu(self.bn(self.conv(self.pool(x))))
        for seq in self.blocks:
            y = seq(y)
        y = y.mean(dim = 2)  # global average over the time axis
        return self.fc(y)
    
def train(epoch, model, criterion, optimizer, train_loader, mixup = False):
    """Run one training epoch, optionally with mixup augmentation.

    Args:
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        model: network to optimise; switched to training mode here.
        criterion: loss called as ``criterion(logits, targets)``.
        optimizer: optimiser over ``model``'s parameters.
        train_loader: iterable of ``(data, target)`` batches, moved to the
            module-level ``device``.
        mixup: if True, each batch is blended with a shuffled copy of itself
            using one coefficient ``mm`` drawn from ``random.random()``. The
            loss is the same convex combination of the losses against both
            label sets.

    Returns:
        Mean per-batch training accuracy, measured against the unshuffled
        targets. ``np.mean`` of an empty list is ``nan`` if the loader yields
        no batches.
    """
    import random  # stdlib; this cell does not import it at module level

    losses, scores = [], []
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        if mixup:
            mm = random.random()
            perm = torch.randperm(data.shape[0])
            output = model(mm * data + (1 - mm) * data[perm])
        else:
            output = model(data)
        decisions = torch.argmax(output, dim = 1)
        # Under mixup this only compares with the first label set, so it is a
        # rough progress signal rather than a true accuracy.
        scores.append((decisions == target).float().mean().item())
        if mixup:
            loss = mm * criterion(output, target) + (1 - mm) * criterion(output, target[perm])
        else:
            loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    print("\r{:3d} {:3.3f} {:3.3f} ".format(epoch + 1, np.mean(losses), np.mean(scores)), end='')
    return np.mean(scores)

def test(epoch, model, test_loader, confusions = False):
    """Return the classification accuracy of ``model`` on ``test_loader``.

    Prints a carriage-return-prefixed progress line with the epoch and the
    accuracy (no trailing newline). When ``confusions`` is true, it also prints
    a 4x4 count matrix: row = true class, column = predicted class.

    NOTE(review): the model is put in ``train()`` mode, not ``eval()``. BatchNorm
    layers therefore normalise with each test batch's own statistics (and update
    their running statistics), so predictions depend on how the test set is
    batched. This may be an intentional test-time BN adaptation; use
    ``model.eval()`` for standard inference.
    """
    confusion = torch.zeros((4, 4)) if confusions else None
    correct = total = 0
    model.train()
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            predicted = model(data).argmax(dim=1)
            if confusion is not None:
                for true_cls in range(4):
                    hits = predicted[torch.where(target == true_cls)[0]]
                    for pred_cls in range(4):
                        confusion[true_cls][pred_cls] += (hits == pred_cls).int().sum().item()
            correct += (predicted == target).int().sum().item()
            total += target.shape[0]
    accuracy = correct / total
    print("\r{:3d} test: {:.3f} ".format(epoch, accuracy), end = '')
    if confusion is not None:
        print(confusion)
    return accuracy

def train_test(params, runs = 3):
    """Leave-one-subject-out evaluation of one ``ConvNet`` configuration.

    Each subject (ids 1..number_of_subjects, as passed to ``loso``) is held out
    in turn. The model is trained from scratch ``runs`` times for 50 epochs
    (Adam, StepLR multiplying the LR by 0.1 every 40 epochs, mixup enabled),
    and the test accuracy after the last epoch is recorded.

    Args:
        params: ``[fm, n_convs, init_pool, kernel_size]`` forwarded to ``ConvNet``.
        runs: independent trainings per held-out subject.

    Returns:
        Mean final-epoch test accuracy over all subjects and runs.
    """
    model_args = (params[0], params[1], params[2], params[3])
    # Throwaway model, built only to report the parameter count.
    model = ConvNet(*model_args).to(device)
    print(params, np.sum([m.numel() for m in model.parameters()]), "params")
    scores = []
    # Tie the subject range to the loaded dataset instead of a hard-coded 1..9.
    for removed_subject in range(1, number_of_subjects + 1):
        print("Removed subject:", removed_subject)
        train_loader, test_loader = loaders(removed_subject)
        criterion = torch.nn.CrossEntropyLoss()
        for n_run in range(runs):
            print("number of the run ",n_run)
            model = ConvNet(*model_args).to(device)
            optimizer = torch.optim.Adam(model.parameters())
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer = optimizer, step_size = 40, gamma = 0.1)

            for epoch in range(50):
                train(epoch, model, criterion, optimizer, train_loader, mixup = True)
                # Evaluated every epoch for progress display; only the last score is kept.
                score = test(epoch, model, test_loader)
                scheduler.step()
            scores.append(score)
            print()
            if n_run == runs - 1:
                print(" average: {:.3f}".format(np.mean(scores[-runs:])))
    print("score is {:.3f}".format(np.mean(scores)))
    return np.mean(scores)


variations = [16, 1, 1, 1, 2]  # NOTE(review): not used in this cell
# Earlier 5-element configurations, kept for reference (layout: ending fm,
# n_blocks, depth_block, pool, kernel_size): [448, 5, 1, 5, 9], [32, 1, 2, 2, 3]
# Current layout: [fm, n_convs, init_pool, kernel_size], i.e. the ConvNet arguments.
best_params = [64, 4, 4, 7]
best_score = train_test(best_params)


48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 48 events and 1125 original time points ...
0 bad epochs dropped
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 48 events and 1125 original time poin



torch.Size([4608, 25, 1125])
number of the run  0
 49 test: 0.814 
number of the run  1
 49 test: 0.776 
number of the run  2
 49 test: 0.800 
 average: 0.797
Removed subject: 2
torch.Size([4608, 25, 1125])
number of the run  0
 49 test: 0.637 
number of the run  1
 49 test: 0.634 
number of the run  2
 49 test: 0.630 
 average: 0.634
Removed subject: 3
torch.Size([4608, 25, 1125])
number of the run  0
 49 test: 0.766 
number of the run  1
 49 test: 0.783 
number of the run  2
 49 test: 0.780 
 average: 0.776
Removed subject: 4
torch.Size([4608, 25, 1125])
number of the run  0
 49 test: 0.715 
number of the run  1
 49 test: 0.703 
number of the run  2
 49 test: 0.682 
 average: 0.700
Removed subject: 5
torch.Size([4608, 25, 1125])
number of the run  0
 49 test: 0.760 
number of the run  1
 49 test: 0.759 
number of the run  2
 49 test: 0.786 
 average: 0.769
Removed subject: 6
torch.Size([4608, 25, 1125])
number of the run  0
 49 test: 0.689 
number of the run  1
 49 test: 0.696 
numbe

In [None]:
import torch
import torchvision
import numpy as np
import random
from  braindecode.models.eegnet import EEGNetv4
import sys
# Optional GPU index from the command line (e.g. `python script.py 1` -> "cuda:1").
# Under Jupyter/Colab, sys.argv holds the kernel launcher's arguments (typically
# "-f <connection file>"), so only a purely numeric first argument counts as a GPU index.
_gpu_arg = sys.argv[1] if len(sys.argv) > 1 else ""
if torch.cuda.is_available():
    device = ("cuda:" + _gpu_arg) if _gpu_arg.isdigit() else "cuda"
else:
    device = "cpu"

#hyperparameters
batch_size = 144

    
# Per-subject tensors from machine-specific paths. torch.load unpickles its
# input, so only load files you trust.
X = torch.load("/users/local/eeg_data/bnci_X_offset_eog_EA.pt")
Y = torch.load("/users/local/eeg_data/bnci_Y_offset_eog_EA.pt")

n_chan = X[0].shape[1]

# Per-channel mean/std over every time sample of every trial of every subject.
# The dataset is concatenated once and reused for both statistics.
# NOTE(review): this includes the held-out subject of each LOSO fold (a mild
# test-set leak); compute the statistics per fold from training subjects for a
# strict evaluation.
_all_samples = torch.cat(X).transpose(1,2).reshape(-1, n_chan)
mean = _all_samples.mean(dim = 0)
std = _all_samples.std(dim = 0)
del _all_samples

def loaders(removed_subject):
    """Build train/test DataLoaders for one leave-one-subject-out fold.

    ``X[removed_subject]`` / ``Y[removed_subject]`` form the test set; all other
    subjects are concatenated for training. Both sets are standardised per
    channel with the module-level ``mean`` and ``std``.

    Returns:
        (train_loader, test_loader): shuffled training batches of the global
        ``batch_size`` (incomplete last batch dropped, 8 worker processes) and
        test batches of up to 10000 trials.
    """
    # (n_chan,) -> (1, n_chan, 1), to broadcast over the trial and time dimensions.
    mu = mean.unsqueeze(0).unsqueeze(2)
    sigma = std.unsqueeze(0).unsqueeze(2)

    def standardize(t):
        return (t - mu) / sigma

    others_X = X[:removed_subject] + X[removed_subject+1:]
    others_Y = Y[:removed_subject] + Y[removed_subject+1:]
    train_X = torch.unbind(standardize(torch.cat(others_X)))
    train_Y = torch.unbind(torch.cat(others_Y))
    train_loader = torch.utils.data.DataLoader(
        list(zip(train_X, train_Y)),
        batch_size = batch_size,
        shuffle = True,
        drop_last = True,
        num_workers = 8,
    )

    test_X = torch.unbind(standardize(X[removed_subject]))
    test_Y = torch.unbind(Y[removed_subject])
    test_loader = torch.utils.data.DataLoader(list(zip(test_X, test_Y)), batch_size = 10000)
    return train_loader, test_loader


class ConvNet(torch.nn.Module):
    """1-D CNN: average-pool stem, stacked conv blocks, global average pooling, linear head.

    Args:
        fm: feature maps produced by the stem convolution and the first block.
        n_convs: number of conv blocks. Every block after the first widens the
            feature maps by ~sqrt(2) (``int(1.414 * fm)``), and every block
            halves the time axis with ``MaxPool1d(2)``.
        init_pool: kernel size of the ``AvgPool1d`` applied to the raw input.
        kernel_size: kernel size of every ``Conv1d``; padding is
            ``kernel_size // 2``.
        in_channels: number of input channels. Defaults to the module-level
            ``n_chan``, looked up when the model is constructed.
        n_classes: number of output logits (default 4).
    """

    def __init__(self, fm, n_convs, init_pool, kernel_size, in_channels=None, n_classes=4):
        super().__init__()
        if in_channels is None:
            in_channels = n_chan  # module-level global derived from the loaded data
        padding = kernel_size // 2
        # Submodules are created in the same order as before, so seeded
        # parameter initialisation and state_dict keys are unchanged.
        self.pool = torch.nn.AvgPool1d(init_pool)
        self.conv = torch.nn.Conv1d(in_channels, fm, kernel_size=kernel_size, padding=padding, bias=False)
        self.bn = torch.nn.BatchNorm1d(fm)
        blocks = []
        newfm = oldfm = fm
        for i in range(n_convs):
            if i > 0:
                newfm = int(1.414 * newfm)
            blocks.append(torch.nn.Sequential(
                torch.nn.Conv1d(oldfm, newfm, kernel_size=kernel_size, padding=padding, bias=False),
                torch.nn.BatchNorm1d(newfm),
                # Every block pools, including the first; the former guard
                # `i > 0 - 1` was always true. NOTE(review): if the first block
                # was meant to keep full resolution, make this depend on `i > 0`.
                torch.nn.MaxPool1d(2),
                torch.nn.ReLU(),
                torch.nn.Conv1d(newfm, newfm, kernel_size=kernel_size, padding=padding, bias=False),
                torch.nn.BatchNorm1d(newfm),
                torch.nn.ReLU(),
            ))
            oldfm = newfm
        self.blocks = torch.nn.ModuleList(blocks)
        self.fc = torch.nn.Linear(oldfm, n_classes)

    def forward(self, x):
        """Map ``x`` of shape (batch, channels, time) to (batch, n_classes) logits."""
        y = torch.relu(self.bn(self.conv(self.pool(x))))
        for seq in self.blocks:
            y = seq(y)
        y = y.mean(dim = 2)  # global average over the time axis
        return self.fc(y)
    
def train(epoch, model, criterion, optimizer, train_loader, mixup = False):
    """Run one training epoch, optionally with mixup augmentation.

    Args:
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        model: network to optimise; switched to training mode here.
        criterion: loss called as ``criterion(logits, targets)``.
        optimizer: optimiser over ``model``'s parameters.
        train_loader: iterable of ``(data, target)`` batches, moved to the
            module-level ``device``.
        mixup: if True, each batch is blended with a shuffled copy of itself
            using one coefficient ``mm`` drawn from ``random.random()``. The
            loss is the same convex combination of the losses against both
            label sets.

    Returns:
        Mean per-batch training accuracy, measured against the unshuffled
        targets. ``np.mean`` of an empty list is ``nan`` if the loader yields
        no batches.
    """
    losses, scores = [], []
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        # Disabled augmentation experiments, kept for reference:
        # # random crop
        # data = data[:,:,random.randint(0,300):-1-random.randint(0,300)]
        # # resize (700 is a target to provide both dilations and contractions)
        # data = torchvision.transforms.Resize((data.shape[1],700))(data)
        # # drop out sensors
        # data = data * (torch.rand(1,data.shape[1],1) > 0.1).to(device).float()
        # # add random noise
        # data = data + 2 * torch.randn_like(data)
        optimizer.zero_grad()
        if mixup:
            mm = random.random()
            perm = torch.randperm(data.shape[0])
            output = model(mm * data + (1 - mm) * data[perm])
        else:
            output = model(data)
        decisions = torch.argmax(output, dim = 1)
        # Under mixup this only compares with the first label set, so it is a
        # rough progress signal rather than a true accuracy.
        scores.append((decisions == target).float().mean().item())
        if mixup:
            loss = mm * criterion(output, target) + (1 - mm) * criterion(output, target[perm])
        else:
            loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    print("\r{:3d} {:3.3f} {:3.3f} ".format(epoch + 1, np.mean(losses), np.mean(scores)), end='')
    return np.mean(scores)

def test(epoch, model, test_loader, confusions = False):
    """Return the classification accuracy of ``model`` on ``test_loader``.

    Prints a carriage-return-prefixed progress line with the epoch and the
    accuracy (no trailing newline). When ``confusions`` is true, it also prints
    a 4x4 count matrix: row = true class, column = predicted class.

    NOTE(review): the model is put in ``train()`` mode, not ``eval()``. BatchNorm
    layers therefore normalise with each test batch's own statistics (and update
    their running statistics), so predictions depend on how the test set is
    batched. This may be an intentional test-time BN adaptation; use
    ``model.eval()`` for standard inference.
    """
    confusion = torch.zeros((4, 4)) if confusions else None
    correct = total = 0
    model.train()
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            predicted = model(data).argmax(dim=1)
            if confusion is not None:
                for true_cls in range(4):
                    hits = predicted[torch.where(target == true_cls)[0]]
                    for pred_cls in range(4):
                        confusion[true_cls][pred_cls] += (hits == pred_cls).int().sum().item()
            correct += (predicted == target).int().sum().item()
            total += target.shape[0]
    accuracy = correct / total
    print("\r{:3d} test: {:.3f} ".format(epoch, accuracy), end = '')
    if confusion is not None:
        print(confusion)
    return accuracy

def train_test(params, runs = 3):
    """Leave-one-subject-out evaluation of one ``ConvNet`` configuration.

    Every subject index in ``range(len(X))`` is held out in turn. The model is
    trained from scratch ``runs`` times for 50 epochs (Adam, StepLR multiplying
    the LR by 0.1 every 40 epochs, mixup enabled) and evaluated once at the end.

    Args:
        params: ``[fm, n_convs, init_pool, kernel_size]`` forwarded to ``ConvNet``.
        runs: independent trainings per held-out subject.

    Returns:
        Mean test accuracy over all subjects and runs.
    """
    def build_model():
        return ConvNet(params[0], params[1], params[2], params[3]).to(device)

    # Throwaway model, built only to report the parameter count.
    probe = build_model()
    print(params, np.sum([p.numel() for p in probe.parameters()]), "params")

    n_epochs = 50
    all_scores = []
    for held_out in range(len(X)):
        print("Removed subject:", held_out)
        train_loader, test_loader = loaders(held_out)
        criterion = torch.nn.CrossEntropyLoss()
        for run_idx in range(runs):
            net = build_model()
            optimizer = torch.optim.Adam(net.parameters())
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer = optimizer, step_size = 40, gamma = 0.1)
            # Alternative tried: SGD(lr=0.1, momentum=0.9, nesterov=True,
            # weight_decay=1e-4) with CosineAnnealingWarmRestarts(T_0=10).
            for epoch in range(n_epochs):
                train(epoch, net, criterion, optimizer, train_loader, mixup = True)
                scheduler.step()

            all_scores.append(test(epoch, net, test_loader))
            is_last_run = run_idx == runs - 1
            if not is_last_run:
                print()
            else:
                print(" average: {:.3f}".format(np.mean(all_scores[-runs:])))
    overall = np.mean(all_scores)
    print("{:.3f}".format(overall))
    return overall


variations = [16, 1, 1, 1, 2]  # NOTE(review): not used in this cell
# Earlier 5-element configurations, kept for reference (layout: ending fm,
# n_blocks, depth_block, pool, kernel_size): [448, 5, 1, 5, 9], [32, 1, 2, 2, 3]
# Current layout: [fm, n_convs, init_pool, kernel_size], i.e. the ConvNet arguments.
best_params = [64, 4, 4, 7]
best_score = train_test(best_params)


In [None]:
import numpy as np
# Hard-coded per-subject accuracies. "ea" presumably refers to Euclidean
# alignment (TODO confirm).
accs_no_ea = [0.7024, 0.6289, 0.6145, 0.6795, 0.6892, 0.621, 0.6451, 0.6594, 0.6407, 0.5998]
accs_ea = [0.7378, 0.7378, 0.8438, 0.6563, 0.6788, 0.599, 0.6059, 0.6233, 0.651]
accs_bci = [0.73, 0.72, 0.84, 0.65, 0.67, 0.599, 0.6, 0.62, 0.65]
mlp_bci_accs = [0.57, 0.58, 0.55, 0.6, 0.74, 0.66, 0.81, 0.67, 0.58]

# Summarise the selected list. ndarray.std() is the population std (ddof=0).
accs = np.array(mlp_bci_accs)
print('mean ', accs.mean())
print('std ', accs.std())

In [None]:
!nvidia-smi