<a href="https://colab.research.google.com/github/GhBlg/Others/blob/main/sweeps_with_blocks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Imports

In [1]:
!pip install wandb
!pip install scipy -U
!wandb login acf4cf1e91a5b92ebf1613195d6d05c09db63b4e
import math
import numpy as np
from math import ceil
import wandb
from torch import optim
import torch.nn.functional as F
from sklearn.metrics import f1_score
import torch
from torch import nn
from torch.nn import init
from torch.nn.utils import weight_norm
%matplotlib 
import matplotlib.pyplot as plt
import os
from google.colab import drive
drive.mount('/content/drive/')
import sys
sys.path.insert(0,'/content/drive/MyDrive/Colab Notebooks/mne_data')
sys.path.append(os.path.abspath('/content/drive/MyDrive/Colab Notebooks/mne_data'))
!mkdir /content/resultdir
!pip install torchinfo
from torchinfo import summary
resultdir='/content/resultdir'
!pip install braindecode
!pip install tqdm
!pip install mne
!pip install moabb

###########     REPRODUCIBILITY      #######################################
# Seed PyTorch's global RNG so that torch-based randomness repeats across runs.
# NOTE(review): only torch is seeded in this cell; NumPy and `random` are not.
random_seed=42
#torch.use_deterministic_algorithms(True)
torch.manual_seed(random_seed)
############################################################################

############################################################################

# Use the GPU when one is available, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
# exist_ok=True keeps this from failing when the directory already exists.
os.makedirs(resultdir,exist_ok=True)
print(f"Will use device {device}")
print(f"Will save checkpoints to {resultdir}")
################################################################################################################################# 


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.13.1-py2.py3-none-any.whl (1.8 MB)
[K     |████████████████████████████████| 1.8 MB 5.1 MB/s 
[?25hCollecting GitPython>=1.0.0
  Downloading GitPython-3.1.27-py3-none-any.whl (181 kB)
[K     |████████████████████████████████| 181 kB 71.1 MB/s 
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.9-py3-none-any.whl (9.4 kB)
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.9.3-py2.py3-none-any.whl (157 kB)
[K     |████████████████████████████████| 157 kB 69.8 MB/s 
Collecting setproctitle
  Downloading setproctitle-1.3.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting gitdb<5,>=4.0.1
  Downloading gitdb-4.0.9-p

Will use device cuda
Will save checkpoints to /content/resultdir


### Utils

In [2]:
import numpy as np
import torch
# Import necessary library
from scipy.linalg import sqrtm, inv 
import numpy as np
from braindecode.datasets.moabb import MOABBDataset
from braindecode.datautil.windowers import create_windows_from_events


# Apply Euclidean Alignment
def apply_EA(data, device='cuda'):
    '''
    Apply Euclidean Alignment (EA) to the trials of one subject.

    The reference matrix is the mean of the per-trial covariance matrices
    (rows of a trial are treated as variables); every trial is then
    left-multiplied by the inverse square root of that reference matrix.

    PARAMETER:
    data:
        Trials of one subject as a torch tensor or array-like, indexed as
        (trial, row, column).
    device:
        Unused; kept for backward compatibility. The result is a CPU tensor.

    OUTPUT:
        float32 torch tensor with the same shape as the input, aligned with EA.

    RAISES:
        ValueError if `data` contains no trials.
    '''
    # Work in NumPy whatever container the trials arrive in
    if torch.is_tensor(data):
        data = data.cpu().detach().numpy()
    all_trials = np.asarray(data)

    print('Found %d trial(s) in which EEG data is stored' %len(all_trials))
    if len(all_trials) == 0:
        raise ValueError('apply_EA needs at least one trial to compute the reference matrix')

    # Reference matrix: arithmetic mean of the trial covariances
    print('Computing reference matrix RefEA')
    RefEA = np.mean([np.cov(trial, rowvar=True) for trial in all_trials], axis=0)

    # Compute R^(-0.5). sqrtm may return a complex array whose imaginary part
    # is only numerical noise, so keep the real part explicitly.
    R_inv = np.real(sqrtm(inv(RefEA)))

    # (C, C) @ (N, C, T) broadcasts over the trial axis
    all_trials_EA = np.matmul(R_inv, all_trials)

    return torch.tensor(all_trials_EA).float()
        

def standardize(X, mean=None, std=None):
  """Standardize X along its last dimension.

  NOTE: the `mean` and `std` arguments are accepted but ignored; both
  statistics are always recomputed from X itself.

  Returns the standardized tensor together with the mean and std that were
  used (both keep the reduced dimension).
  """
  reduce_last = dict(dim=-1, keepdim=True)
  mean = X.mean(**reduce_last)
  std = X.std(**reduce_last)
  scaled = (X - mean) / std
  return scaled, mean, std



################### Load Physionet data as done in the TIDNet paper ###################################################################
################## calling load_eeg_bci will return an object having data[subjects][records][0/1] 0 for Xs and 1 for Ys ###############
from torch.utils.data import Dataset
from mne.datasets import eegbci
import mne 
import tqdm
from collections import OrderedDict

# Subject indices that are left out of SUBJECTS_EEGBCI
BAD_SUBJECTS_EEGBCI = [87, 89, 91, 99]
# Every index in range(109) except the bad ones
SUBJECTS_EEGBCI = [idx for idx in range(109) if idx not in BAD_SUBJECTS_EEGBCI]
EVENTS_EEGBCI = {'hands': 2, 'feet': 3}

# Run numbers grouped by recording type
BASELINE_EYES_OPEN = [1]
BASELINE_EYES_CLOSED = [2]

MOTOR_FISTS = (3, 7, 11)
IMAGERY_FISTS = (4, 8, 12)
MOTOR_FEET = (5, 9, 13)
IMAGERY_FEET_V_FISTS = (6, 10, 14)

def zscore(data: np.ndarray, axis=-1):
    """Z-score `data` along `axis`; a 1e-12 offset on the std avoids division by zero."""
    centered = data - data.mean(axis, keepdims=True)
    spread = data.std(axis, keepdims=True) + 1e-12
    return centered / spread


def one_hot(y: torch.Tensor, num_classes):
    """One-hot encode an index tensor, appending a trailing axis of size `num_classes`.

    A trailing singleton axis on `y` is dropped first. The result is a float
    tensor created on the same device as `y`.
    """
    has_trailing_singleton = len(y.shape) > 0 and y.shape[-1] == 1
    if has_trailing_singleton:
        y = y.squeeze(-1)
    encoded = torch.zeros(*y.shape, num_classes, device=y.device)
    encoded.scatter_(-1, y.unsqueeze(-1), 1)
    return encoded

class EpochsDataset(Dataset):
    """Torch dataset serving single epochs out of an ``mne.Epochs`` object.

    Items are ``(x, y)`` or, when ``runs`` is given, ``(x, y, run_one_hot)``.
    Each epoch's data is extracted on first access and cached in memory.
    """

    def __init__(self, epochs: mne.Epochs, force_label=None, picks=None, preproccesors=None, normalizer=zscore,
                 runs=None, train_mode=False):
        """
        Args:
            epochs: the epochs to serve.
            force_label: if given, returned as the label of every item.
            picks: index applied to the channel axis of each epoch (None keeps all).
            preproccesors: a factory, a list/tuple of factories, or None. Each
                factory is called once with `epochs` and must return a callable
                that is applied to every epoch's data.
            normalizer: callable applied to the data after the preprocessors.
            runs: if given, a one-hot run indicator of this size is appended to
                each item.
            train_mode: initial value of the mode flag.
        """
        self.mode = train_mode
        self.epochs = epochs
        self._t_len = epochs.tmax - epochs.tmin
        self.loaded_x = [None for _ in range(len(epochs.events))]
        self.runs = runs
        self.picks = picks
        self.force_label = force_label if force_label is None else torch.tensor(force_label)
        self.normalizer = normalizer
        # Normalise the argument to a fresh list: None means "no preprocessing",
        # tuples are accepted, and the caller's own list is never mutated.
        if preproccesors is None:
            factories = []
        elif isinstance(preproccesors, (list, tuple)):
            factories = list(preproccesors)
        else:
            factories = [preproccesors]
        self.preprocessors = [factory(self.epochs) for factory in factories]

    @property
    def channels(self):
        """Number of channels served per item."""
        if self.picks is None:
            return len(self.epochs.ch_names)
        else:
            return len(self.picks)

    @property
    def sfreq(self):
        """Sampling frequency taken from the epochs' info."""
        return self.epochs.info['sfreq']

    def train_mode(self, mode=False):
        """Set the mode flag."""
        self.mode = mode

    def __getitem__(self, index):
        ep = self.epochs[index]
        if self.loaded_x[index] is None:
            x = ep.get_data()
            if len(x.shape) != 3 or 0 in x.shape:
                # Unexpected/empty epoch: report it and serve the previous item instead
                print("I don't know why: {} index{}/{}".format(self.epochs, index, len(self)))
                print(self.epochs.info['description'])
                # raise AttributeError()
                return self.__getitem__(index - 1)
            # With picks=None this indexing inserts a leading axis that the
            # squeeze(0) below removes again.
            x = x[0, self.picks, :]
            for p in self.preprocessors:
                x = p(x)
            x = torch.from_numpy(self.normalizer(x).astype('float32')).squeeze(0)
            self.loaded_x[index] = x
        else:
            x = self.loaded_x[index]

        y = torch.from_numpy(ep.events[..., -1]).long() if self.force_label is None else self.force_label

        if self.runs is not None:
            # The position in the dataset is mapped proportionally onto `runs` buckets
            return x, y, one_hot(torch.tensor(self.runs * index / len(self)).long(), self.runs)

        return x, y

    def __len__(self):
        return len(self.epochs.events)

def same(x):
    """Identity function: hand the argument back untouched."""
    return x

def load_eeg_bci(targets=4, tmin=0, tlen=3, t_ev=0, t_sub=None, normalizer=same, low_f=None, high_f=None, #zscore
                 alignment=False, path_mne=None):
    """Load the EEGBCI recordings and build one EpochsDataset per subject.

    Args:
        targets: number of classes. The imagery-fists runs (T1 -> 0, T2 -> 1)
            are always used; > 2 adds epochs labelled 2 cut from the eyes-open
            baseline run; > 3 adds T2 epochs labelled 3 from the
            feet-vs-fists imagery runs.
        tmin: epoch start relative to the event, in seconds.
        tlen: epoch length in seconds.
        t_ev, t_sub: unused; kept for interface compatibility.
        normalizer: forwarded to EpochsDataset.
        low_f, high_f: optional band edges; filtering is applied when either is set.
        alignment: if True, `EuclideanAlignment` is passed as the preprocessor.
            NOTE(review): that name is not defined in the visible part of this
            file — confirm it is in scope before enabling this flag.
        path_mne: download/cache directory handed to `eegbci.load_data`.

    Returns:
        OrderedDict mapping each index of SUBJECTS_EEGBCI to its EpochsDataset.
        Subjects whose sampling frequency is not 160 are skipped.
    """
    paths = [eegbci.load_data(s+1, IMAGERY_FISTS, path=path_mne, update_path=False) for s in SUBJECTS_EEGBCI]
    raws = [mne.io.concatenate_raws([mne.io.read_raw_edf(p, preload=True) for p in path])
            for path in tqdm.tqdm(paths, unit='subj', desc='Loading')]
    datasets = OrderedDict()
    for i, raw in tqdm.tqdm(list(zip(SUBJECTS_EEGBCI, raws)), desc='Preprocessing'):
        if raw.info['sfreq'] != 160:
            tqdm.tqdm.write('Skipping..., sampling frequency: {}'.format(raw.info['sfreq']))
            continue
        raw.rename_channels(lambda x: x.strip('.'))
        if low_f or high_f:
            raw.filter(low_f, high_f, fir_design='firwin', skip_by_annotation='edge')
        events, _ = mne.events_from_annotations(raw, event_id=dict(T1=0, T2=1))
        picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude='bads')
        epochs = mne.Epochs(raw, events[:41, ...], tmin=tmin, tmax=tmin + tlen - 1 / raw.info['sfreq'], picks=picks,
                            baseline=None, reject_by_annotation=False)#.drop_bad()
        if targets > 2:
            # `run_paths` avoids shadowing the outer `paths` list
            run_paths = eegbci.load_data(i + 1, BASELINE_EYES_OPEN, path=path_mne, update_path=False)
            raw = mne.io.concatenate_raws([mne.io.read_raw_edf(p, preload=True) for p in run_paths])
            raw.rename_channels(lambda x: x.strip('.'))
            if low_f or high_f:
                raw.filter(low_f, high_f, fir_design='firwin', skip_by_annotation='edge')
            # Synthetic events (label 2), half as many as the imagery events,
            # spread evenly between sample 0 and sfreq * (60 - 2 * tlen)
            events = np.zeros((events.shape[0] // 2, 3)).astype('int')
            events[:, -1] = 2
            # builtin int: the np.int alias was removed in NumPy 1.24
            events[:, 0] = np.linspace(0, raw.info['sfreq'] * (60 - 2 * tlen), num=events.shape[0]).astype(int)
            picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude='bads')
            eyes_epochs = mne.Epochs(raw, events, tmin=tmin, tmax=tmin + tlen - 1 / raw.info['sfreq'], picks=picks,
                                     baseline=None, reject_by_annotation=False)#.drop_bad()
            epochs = mne.concatenate_epochs([eyes_epochs, epochs])
        if targets > 3:
            run_paths = eegbci.load_data(i+1, IMAGERY_FEET_V_FISTS, path=path_mne, update_path=False)
            raw = mne.io.concatenate_raws([mne.io.read_raw_edf(p, preload=True) for p in run_paths])
            raw.rename_channels(lambda x: x.strip('.'))
            if low_f or high_f:
                raw.filter(low_f, high_f, fir_design='firwin', skip_by_annotation='edge')
            events, _ = mne.events_from_annotations(raw, event_id=dict(T2=3))
            picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude='bads')
            feet_epochs = mne.Epochs(raw, events[:20, ...], tmin=tmin, tmax=tmin + tlen - 1 / raw.info['sfreq'],
                                     picks=picks, baseline=None, reject_by_annotation=False)#.drop_bad()
            epochs = mne.concatenate_epochs([epochs, feet_epochs])

        datasets[i] = EpochsDataset(epochs, preproccesors=EuclideanAlignment if alignment else [],
                                    normalizer=normalizer, runs=3)

    return datasets


####################################################################################################################



def load_subjects(path, number_of_subjects, device, bad_subjects=None, apply_euclidean=True, with_eog=True):
    """Load every subject of a dataset as a pair of per-subject tensor lists.

    Args:
        path: download/cache directory, used only by the 109-subject branch.
        number_of_subjects: 9 selects MOABB "BNCI2014001", 109 selects the
            data returned by `load_eeg_bci`. Any other value yields two empty lists.
        device: unused; kept for interface compatibility.
        bad_subjects: 1-based subject numbers to leave out (default: none).
        apply_euclidean: if true, each subject's trials go through `apply_EA`.
        with_eog: 9-subject branch only. If true only the last channel is
            dropped, otherwise the last four are.

    Returns:
        (sbj_x, sbj_y): lists holding one data tensor and one label tensor per
        loaded subject.
    """
    bad_subjects = [] if bad_subjects is None else list(bad_subjects)
    sbj_x = []
    sbj_y = []
    if number_of_subjects == 9:
        for subject_id in range(1, number_of_subjects + 1):
            if subject_id in bad_subjects:
                continue
            dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])

            trial_start_offset_seconds = -0.5
            # Extract sampling frequency, check that they are same in all datasets
            sfreq = dataset.datasets[0].raw.info['sfreq']
            assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])
            # Calculate the trial start offset in samples.
            trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

            # Create windows using braindecode function for this. It needs parameters to define how
            # trials should be used.
            windows_dataset = create_windows_from_events(
                dataset,
                trial_start_offset_samples=trial_start_offset_samples,
                trial_stop_offset_samples=0,
                preload=True,
            )

            splitted = windows_dataset.split('session')
            train_set = splitted['session_T']
            valid_set = splitted['session_E']

            # delete stim channel and eog channels if wanted
            dlt = -1 if with_eog else -4

            train_x = np.array([ele[0][:dlt] for ele in train_set])
            train_y = np.array([ele[1] for ele in train_set])

            valid_x = np.array([ele[0][:dlt] for ele in valid_set])
            valid_y = np.array([ele[1] for ele in valid_set])

            # Both sessions are merged; train/validation/test splits happen later
            T_x = torch.tensor(np.concatenate([train_x, valid_x], axis=0))
            T_y = torch.tensor(np.concatenate([train_y, valid_y], axis=0))

            sbj_x.append(apply_EA(T_x) if apply_euclidean else T_x)
            sbj_y.append(T_y)
            del T_x, T_y

    elif number_of_subjects == 109:
        # load_eeg_bci keys its result by 0-based subject index
        bad_zero_based = {s - 1 for s in bad_subjects}
        data = load_eeg_bci(path_mne=path)
        for subject_id in range(number_of_subjects):
            # load_eeg_bci leaves some subjects out of its result; indexing
            # them would raise KeyError, so they are skipped here.
            if subject_id in bad_zero_based or subject_id not in data:
                continue
            subject_data = data[subject_id]
            # Fetch each record once and split it into data / label
            samples = [subject_data[record] for record in range(len(subject_data))]
            xx = torch.stack([sample[0] for sample in samples])
            yy = torch.tensor([sample[1] for sample in samples])
            sbj_x.append(apply_EA(xx) if apply_euclidean else xx)
            sbj_y.append(yy)

    return sbj_x, sbj_y



#################### Leave One Subject Out ############################
def loso(x, y, number_of_subjects, loso, device, bad_subjects=None, with_validation=True):
    """Build a leave-one-subject-out split.

    Args:
        x, y: per-subject lists of data / label tensors.
        number_of_subjects: number of list positions to consider.
        loso: 1-based number of the subject held out for testing.
        device: device the returned tensors are moved to.
        bad_subjects: 1-based subject numbers excluded from training (default: none).
        with_validation: if true, a random 10% of the training trials is split
            off as validation set; otherwise the validation tensors are empty.

    Returns:
        (T_x, T_y, V_x, V_y, Test_x, Test_y) on `device`, labels as int64.
    """
    # The original computed this conversion but discarded the result, so
    # bad_subjects were compared 1-based against 0-based list positions.
    bad_zero_based = [] if bad_subjects is None else [s - 1 for s in bad_subjects]

    held_out = loso - 1  # because lists start from 0
    excluded = set(bad_zero_based) | {held_out}
    train_ids = [e for e in range(number_of_subjects) if e not in excluded]

    # The leading empty float tensor keeps the dtype promotion of the original
    # incremental concatenation.
    T_x = torch.cat([torch.tensor([])] + [x[i] for i in train_ids], 0)
    T_y = torch.cat([torch.tensor([])] + [y[i] for i in train_ids], 0)

    T_x = T_x.reshape([-1, T_x.shape[1], T_x.shape[2]])
    T_y = T_y.reshape([-1])

    k_f = int(0.9 * T_x.shape[0])

    if with_validation:
        data_perm = torch.randperm(T_x.shape[0])
        shuffled_x, shuffled_y = T_x[data_perm], T_y[data_perm]
        T_x, T_y = shuffled_x[:k_f], shuffled_y[:k_f]
        V_x, V_y = shuffled_x[k_f:], shuffled_y[k_f:]
    else:
        V_x, V_y = torch.tensor([]), torch.tensor([])

    Test_x = x[held_out]
    Test_y = y[held_out]

    T_x, mean, std = standardize(T_x)
    if with_validation:
        V_x, _, _ = standardize(V_x, mean, std)
    Test_x, _, _ = standardize(Test_x, mean, std)

    ############################################################################
    return (T_x.to(device), T_y.to(device).to(torch.int64),
            V_x.to(device), V_y.to(device).to(torch.int64),
            Test_x.to(device), Test_y.to(device).to(torch.int64))


#################### Leave Multiple Subjects Out ############################
import random
from sklearn.model_selection import KFold

def lmso(x, y, number_of_subjects, kfold, device, bad_subjects=None, with_validation=True, random_seed=42):
    """Build a leave-multiple-subjects-out split from a 10-fold partition of the subjects.

    Args:
        x, y: per-subject sequences of data / label tensors.
        number_of_subjects: unused; kept for interface compatibility.
        kfold: 1-based number of the fold whose subjects form the test set (1..10).
        device: device the returned tensors are moved to.
        bad_subjects: unused; kept for interface compatibility.
        with_validation: if true, a random 10% of the training trials is split
            off as validation set; otherwise the validation tensors are empty.
        random_seed: seed of the subject shuffle performed before folding.

    Returns:
        (T_x, T_y, V_x, V_y, Test_x, Test_y) on `device`, labels as int64.

    Raises:
        ValueError: if `kfold` is outside 1..10 (the original silently fell
            back to the last fold).
    """
    n_folds = 10
    fold_number = int(kfold)
    if not 1 <= fold_number <= n_folds:
        raise ValueError('kfold must be in 1..%d, got %r' % (n_folds, kfold))

    # Shuffle subjects reproducibly, keeping data and labels paired
    pairs = list(zip(x, y))
    random.Random(random_seed).shuffle(pairs)

    # 10-fold crossvalidation over subject positions. Splitting an index range
    # instead of a stacked array also works when subjects differ in trial count.
    folds = list(KFold(n_splits=n_folds).split(np.arange(len(pairs))))
    train_index, test_index = folds[fold_number - 1]

    def _concat(tensors):
        # The leading empty float tensor keeps the dtype promotion of the
        # original incremental concatenation.
        return torch.cat([torch.tensor([])] + list(tensors), 0)

    T_x = _concat(pairs[i][0] for i in train_index)
    T_y = _concat(pairs[i][1] for i in train_index)
    T_x = T_x.reshape([-1, T_x.shape[1], T_x.shape[2]])
    T_y = T_y.reshape([-1])

    k_f = int(0.9 * T_x.shape[0])

    if with_validation:
        data_perm = torch.randperm(T_x.shape[0])
        shuffled_x, shuffled_y = T_x[data_perm], T_y[data_perm]
        T_x, T_y = shuffled_x[:k_f], shuffled_y[:k_f]
        V_x, V_y = shuffled_x[k_f:], shuffled_y[k_f:]
    else:
        V_x, V_y = torch.tensor([]), torch.tensor([])

    Test_x = _concat(pairs[i][0] for i in test_index)
    Test_y = _concat(pairs[i][1] for i in test_index)
    Test_x = Test_x.reshape([-1, Test_x.shape[1], Test_x.shape[2]])
    Test_y = Test_y.reshape([-1])

    T_x, mean, std = standardize(T_x)
    if with_validation:
        V_x, _, _ = standardize(V_x, mean, std)
    Test_x, _, _ = standardize(Test_x, mean, std)

    ############################################################################
    return (T_x.to(device), T_y.to(device).to(torch.int64),
            V_x.to(device), V_y.to(device).to(torch.int64),
            Test_x.to(device), Test_y.to(device).to(torch.int64))


#############Data Augmentation ###########################################################
from scipy.signal import butter, lfilter
    
#flip channels
def flip_channels(data):
  """Return `data` with dimension 1 reordered by one random permutation."""
  order = torch.randperm(data.shape[1])
  return data[:, order, :]

#time inverse
def time_inverse(data):
  """Return `data` reversed along dimension 2."""
  return data.flip(2)

#Masking in %
def masking(data,p):
  """Randomly zero elements of `data`.

  Every element is kept where an independent uniform [0, 1) draw exceeds `p`
  and set to zero otherwise, so `p` is the probability of masking an element.
  The mask is created on the device of `data` rather than on a module-level
  `device`, so the function works for tensors on any device.
  """
  keep = torch.rand(data.shape, device=data.device) > p
  return data * keep.float()


def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter and return its (b, a) coefficients."""
    band_edges = [lowcut, highcut]
    return butter(order, band_edges, fs=fs, btype='band')


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Filter `data` with the band-pass filter designed by `butter_bandpass`."""
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(numerator, denominator, data)

# Names of the available augmentations
selection=['fc','ti','mask','filter']
def apply_da(data,selection):
  """Apply one randomly chosen augmentation to `data`.

  `selection` is a sequence of augmentation names; one is drawn with
  `random.choice`:
    'fc'     -> flip_channels
    'ti'     -> time_inverse
    'mask'   -> masking with p=0.2
    'filter' -> butter_bandpass_filter(0.1, 40, fs=250)
  Any other name returns `data` unchanged.
  """
  s=random.choice(selection)
  if s=='fc':
    data=flip_channels(data)
  elif s=='ti':
    data=time_inverse(data)
  elif s=='mask':
    data=masking(data,0.2)
  elif s=='filter':
    # Filtering runs in NumPy on the CPU; the result goes back to the device
    # the input lives on rather than to a module-level `device`.
    filtered=butter_bandpass_filter(data.detach().cpu().numpy(),0.1, 40, 250)
    data=torch.tensor(filtered).float().to(data.device)
  return data

#####################################################################################

def weight_attack(model, p=0.1):
  """Randomly overwrite part of the weights of every layer in `model`.

  For each layer exposing a floating-point `weight` tensor, a candidate value
  is drawn uniformly between the layer's minimum and maximum weight for every
  element; the element is replaced by its candidate wherever
  |candidate| < p * (max - min). Layers without such a weight are skipped.

  Fixes over the previous version: every layer is perturbed using its own
  weights (the old code always read `model[2]`), weights keep their dtype and
  device, the update is vectorised, and errors are no longer swallowed by a
  bare `except`.

  Args:
    model: iterable of layers, e.g. an nn.Sequential. Modified in place.
    p: fraction of the weight range used as the replacement threshold.

  Returns:
    The same `model` object.
  """
  for layer in model:
    weight = getattr(layer, 'weight', None)
    if not isinstance(weight, torch.Tensor):
      continue
    if weight.numel() == 0 or not weight.is_floating_point():
      continue
    with torch.no_grad():
      upper = weight.max()
      lower = weight.min()
      span = upper - lower
      candidates = span * torch.rand_like(weight) + lower
      replace = candidates.abs() < p * span
      # In-place update keeps the Parameter object (and optimizer references) intact
      weight.copy_(torch.where(replace, candidates, weight))
  return model

  warn('datautil.windowers module is deprecated and is now under '


### coatnets

In [None]:
from math import ceil
import torch
from torch import nn
from torch.nn.utils import weight_norm
import torch.nn.functional as F
import math

##############  EEG_CoatNet models #############################################################

class PrintLayer(nn.Module):
    """Debug layer: prints the shape of its input and passes it through unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        shape = x.shape
        print(shape)
        return x

class Ensure4d(nn.Module):
    """Append trailing singleton dimensions until the input has at least 4 dimensions."""

    def forward(self, x):
        missing_dims = 4 - len(x.shape)
        for _ in range(missing_dims):
            x = x.unsqueeze(-1)
        return x


class Expression(nn.Module):
    """Compute given expression on forward pass.
    Parameters
    ----------
    expression_fn : callable
        Should accept variable number of objects of type
        `torch.autograd.Variable` to compute its output.
    """

    def __init__(self, expression_fn):
        super(Expression, self).__init__()
        self.expression_fn = expression_fn

    def forward(self, *x):
        return self.expression_fn(*x)

    def __repr__(self):
        if hasattr(self.expression_fn, "func") and hasattr(
            self.expression_fn, "kwargs"
        ):
            expression_str = "{:s} {:s}".format(
                self.expression_fn.func.__name__, str(self.expression_fn.kwargs)
            )
        elif hasattr(self.expression_fn, "__name__"):
            expression_str = self.expression_fn.__name__
        else:
            expression_str = repr(self.expression_fn)
        return (
            self.__class__.__name__ +
            "(expression=%s) " % expression_str
        )

############## Multi-head Attention Mechanism #################################################

def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention.

    Logits are q @ k^T divided by sqrt(d_k), where d_k is the last dimension
    of q. Positions where `mask == 0` are filled with -9e15 before the softmax
    over the last dimension.

    Returns:
        (values, attention): the attention-weighted sum of v and the attention weights.
    """
    scale = math.sqrt(q.size()[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)) / scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -9e15)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v), weights

class MultiheadAttention(nn.Module):
    """Multi-head self-attention using one fused Q/K/V projection."""

    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # A single linear layer produces queries, keys and values of all heads
        self.qkv_proj = nn.Linear(input_dim, 3*embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-uniform weights and zero biases for both projections
        for projection in (self.qkv_proj, self.o_proj):
            nn.init.xavier_uniform_(projection.weight)
            projection.bias.data.fill_(0)

    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_length, _ = x.size()

        # [Batch, SeqLen, 3*Embed] -> [Batch, Head, SeqLen, 3*HeadDim]
        fused = self.qkv_proj(x)
        fused = fused.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim)
        fused = fused.permute(0, 2, 1, 3)
        q, k, v = fused.chunk(3, dim=-1)

        values, attention = scaled_dot_product(q, k, v, mask=mask)

        # [Batch, Head, SeqLen, HeadDim] -> [Batch, SeqLen, Embed]
        merged = values.permute(0, 2, 1, 3).reshape(batch_size, seq_length, self.embed_dim)
        o = self.o_proj(merged)

        return (o, attention) if return_attention else o



############## stem_stage (reducing temporal dimension) #################################################
class stem_stage(nn.Module):
    """Stem: make the input 4-D, swap its last two axes, then Conv2d -> BatchNorm2d -> GELU.

    The convolution uses a (1, k) kernel without padding, so the last axis
    shrinks by k - 1.
    """

    def __init__(self, in_channels, out_channels, k ):
        super().__init__()
        def _permute(x):
            '''
            Permutes data:
            from dim:
            batch, chans, time, 1
            to dim:
            batch, chans, 1, time'''
            return x.permute([0, 1, 3, 2])

        # (The unused PrintLayer instance that used to be created here was dropped.)
        layers = [Ensure4d(),
                Expression(_permute),
                nn.Conv2d(in_channels, out_channels, kernel_size=(1,k)),
                nn.BatchNorm2d(out_channels),
                nn.GELU()
                ]

        self.model= nn.Sequential(*layers)

    def forward(self, x):
      return self.model(x)

############## conv_stage #################################################
class conv_stage(nn.Module):
    """Convolutional stage with a 1x1-conv residual branch; both branches are average-pooled.

    Main branch: BatchNorm2d -> 1x1 conv (in -> out) -> depthwise (1, conv_kernel)
    conv over `n_ch` channels -> 1x1 conv (out -> out) -> ReLU.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        n_ch: channel count of the depthwise convolution. It is applied to the
            output of the first 1x1 conv, so it must equal `out_channels` for
            the forward pass to work.
        conv_kernel: width of the depthwise kernel.
        pool_kernel: width of the average-pooling window on the last axis.
    """

    def __init__(self, in_channels, out_channels, n_ch, conv_kernel, pool_kernel):
        super().__init__()

        # (The unused PrintLayer instance that used to be created here was dropped.)
        self.pool=nn.AvgPool2d(kernel_size=(1, pool_kernel))
        self.res=nn.Conv2d(in_channels, out_channels, kernel_size=(1,1), padding='same')
        layers = [nn.BatchNorm2d(in_channels),
                  nn.Conv2d(in_channels, out_channels, kernel_size=(1,1), padding='same'),
                  nn.Conv2d(in_channels=n_ch, out_channels=n_ch, kernel_size=(1,conv_kernel),groups=n_ch, padding='same'),
                  nn.Conv2d(out_channels, out_channels, kernel_size=(1,1), padding='same'),
                  nn.ReLU()
                  ]

        self.model= nn.Sequential(*layers)

    def forward(self, x):
      # Residual sum of the pooled projection and the pooled main branch
      return self.pool(self.res(x))+self.pool(self.model(x))

############## InceptionBlock #################################################
class InceptionBlock(nn.Module):
    """Four parallel conv_stage branches whose outputs are concatenated along dimension 1.

    The branches use fixed kernel widths 100, 5, 3 and 2 with pooling width 2;
    the `conv_kernel` and `pool_kernel` arguments are accepted but not used.
    """

    def __init__(self, in_channels, out_channels, n_ch, conv_kernel, pool_kernel):
        super().__init__()

        def make_branch(width):
            return conv_stage(in_channels, out_channels, n_ch, conv_kernel=width, pool_kernel=2)

        self.branch1 = make_branch(100)
        self.branch2 = make_branch(5)
        self.branch3 = make_branch(3)
        self.branch4 = make_branch(2)

    def forward(self, x):
        outputs = []
        for branch in (self.branch1, self.branch2, self.branch3, self.branch4):
            outputs.append(branch(x))
        return torch.cat(outputs, 1)


############## att_stage #################################################

class FeedForward(nn.Module):
    """Two-layer perceptron: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1):
        super().__init__()
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        out = self.net(x)
        return out
        
class att_stage(nn.Module):
    """Attention stage: LayerNorm + randomly masked multi-head attention, plus a feed-forward branch.

    Args:
        embed_dim: output feature size of the attention and feed-forward branches.
        num_heads: number of attention heads.
        in_channels: with `incep`, fixes the side length (in_channels * incep)
            of the square attention mask.
        temp_dim: size of the last input axis (input size of LayerNorm,
            attention and feed-forward).
        pool_kernel: unused; kept for interface compatibility.
        incep: multiplier applied to `in_channels`.
        masking_rate: probability that an attention entry is masked out.
    """

    def __init__(self, embed_dim, num_heads, in_channels, temp_dim, pool_kernel, incep=1, masking_rate=0.4):
        super().__init__()

        # (The unused PrintLayer instance that used to be created here was dropped.)
        self.mr=masking_rate
        self.incep=incep
        self.ed=embed_dim
        self.ic= in_channels
        self.ln=nn.LayerNorm(temp_dim)
        self.mha=MultiheadAttention(temp_dim ,embed_dim, num_heads)
        self.ffn=FeedForward(temp_dim, temp_dim, embed_dim)

    def forward(self, x):
      # Drop singleton axes but never the batch axis: a bare squeeze() would
      # also remove it for a batch of size 1.
      for axis in range(x.dim() - 1, 0, -1):
        if x.shape[axis] == 1:
          x = x.squeeze(axis)
      x1=self.ffn(x)
      x=self.ln(x)
      # A fresh random mask is drawn on every call, on the input's device, so
      # the stage runs on CPU as well as on GPU.
      # NOTE(review): the mask is also drawn in eval mode — confirm that is intended.
      side=self.ic*self.incep
      mask=(torch.rand(side, side, device=x.device) > self.mr).type(torch.uint8)
      x,s=self.mha(x, mask, return_attention=True)
      return x+x1.reshape([-1,side,self.ed]),  s
      






#residual block
class residual_block(nn.Module):
    """Residual block made of two (1, k) convolutions with average pooling in between."""

    def __init__(self, in_channels, out_channels, k, p,stride=1):
        """
        Args:
          in_channels (int):  Number of input channels.
          out_channels (int): Number of output channels.
          k (int):            Width of the (1, k) convolution kernels.
          p (int):            Width of the (1, p) average-pooling window.
          stride (int):       Only decides whether a projection shortcut is
                              built (any value other than 1 forces one).
        """
        super().__init__()

        # Projection shortcut when the shapes of both paths would differ,
        # identity shortcut (None) otherwise.
        shortcut = None
        if stride != 1 or in_channels != out_channels:
            shortcut = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding='same', bias=False),
                nn.AvgPool2d(kernel_size=(1,p)),
                nn.BatchNorm2d(out_channels),
            )
        self.add_module('skip', shortcut)

        main_path = [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1,k), padding='same', bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1,p)),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(1,k), padding='same', bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.block = nn.Sequential(*main_path)

    def forward(self, x):
        main = self.block(x)
        residual = x if self.skip is None else self.skip(x)
        return F.relu(main + residual)


############## CoatNet #################################################
class coatnet(nn.Module):
    """EEG CoAtNet-style classifier: stem -> three residual blocks -> attention stage -> linear head.

    Args:
        n_classes: number of output classes.
        n_channels: number of input channels.
        embed_dim: feature size produced by the attention stage.
        num_heads: number of attention heads.
        conv_kernel: forwarded to the (currently unused) inception block.
        pool_kernel: width of the stem convolution kernel; also forwarded to
            the sub-stages.
        incep: forwarded to the attention stages.
        n_times: length of the time axis of the input. The default 1125 gives
            the previously hard-coded attention input size of 140.
    """

    def __init__(self, n_classes, n_channels, embed_dim, num_heads, conv_kernel, pool_kernel, incep=4, n_times=1125):
        super().__init__()

        n_ch=n_channels
        self.s0=stem_stage(n_ch, n_ch, pool_kernel)
        self.res1=residual_block(n_ch, n_ch, 64, 2,stride=2)
        self.res2=residual_block(n_ch, n_ch, 32, 2,stride=2)
        self.res3=residual_block(n_ch, n_ch, 16, 2,stride=2)
        # s11 and s22 are built but not used in forward; they are kept so the
        # set of parameters stays as before.
        self.s11= InceptionBlock(n_ch, n_ch, n_ch, conv_kernel, pool_kernel)
        # The stem shortens the time axis by pool_kernel - 1 and each of the
        # three residual blocks halves it (floor division).
        temp_dim=(n_times - pool_kernel + 1) // 2 // 2 // 2
        self.s2=att_stage(embed_dim, num_heads, n_ch, temp_dim, pool_kernel, incep)
        temp_dim=256
        self.s22=att_stage(embed_dim, num_heads, n_ch, temp_dim, pool_kernel, incep)

        # The attention stage emits (batch, n_ch, embed_dim), flattened before
        # the head; 22 * 256 gives the previously hard-coded 5632.
        self.classifier=nn.Linear(n_ch * embed_dim, n_classes)

    def forward(self, x):
      x=self.s0(x)
      x=self.res1(x)
      x=self.res2(x)
      x=self.res3(x)
      #x=self.s11(x)
      x,s= self.s2(x)
      #x,s= self.s22(x)

      x=torch.flatten(x, 1)
      x=self.classifier(x)
      return x

In [None]:
# Smoke test: run one random batch of shape (60, 22, 1125) through the network.
eeg = torch.randn(60,22,1125)
network=coatnet(n_classes=4, n_channels=22, embed_dim=256, num_heads=4, conv_kernel=10, pool_kernel=4, incep=1)
out=network(eeg)

torch.Size([60, 22, 1, 140])
torch.Size([60, 22, 256])


### Variational Auto-Encoder test

In [None]:
class VariationalEncoder(nn.Module):
    """Convolutional encoder of the VAE.

    Maps an input that can be reshaped to (-1, 1, 1125, 25) to a latent sample
    z = mu + sigma * eps and stores the KL term of the last forward pass in
    `self.kl`.

    Args:
        n_ch: unused; kept for interface compatibility.
        latent_dims: size of the latent vector.
    """

    def __init__(self, n_ch, latent_dims):
        super(VariationalEncoder, self).__init__()

        # NOTE(review): the second positional argument of BatchNorm2d is `eps`,
        # so `False` sets eps=0 in the three layers below — confirm this is
        # intended (affine=False may have been meant).
        # Layer 1
        self.conv1 = nn.Conv2d(1, 16, (1, 25), padding = 0)
        self.batchnorm1 = nn.BatchNorm2d(16, False)

        # Layer 2
        self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))
        self.conv2 = nn.Conv2d(1, 4, (2, 13))
        self.batchnorm2 = nn.BatchNorm2d(4, False)
        self.pooling2 = nn.MaxPool2d(2, 4)

        # Layer 3
        self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))
        self.conv3 = nn.Conv2d(4, 4, (8, 4))
        self.batchnorm3 = nn.BatchNorm2d(4, False)
        self.pooling3 = nn.MaxPool2d((2, 4))

        self.linear1 = nn.Linear(568, 128)
        self.linear2 = nn.Linear(128, latent_dims)
        self.linear3 = nn.Linear(128, latent_dims)

        # Kept as an attribute for backward compatibility. The noise is now
        # drawn with randn_like in forward, so no unconditional .cuda() call is
        # needed and the encoder also works on CPU-only machines.
        self.N = torch.distributions.Normal(0, 1)
        self.kl = 0

    def forward(self, x):
        # Layer 1
        # NOTE(review): this is a reshape, not a permute — a (batch, 25, 1125)
        # input is reinterpreted in memory order; confirm this is intended.
        x=x.reshape([-1,1,1125,25])
        x = F.elu(self.conv1(x))
        x = self.batchnorm1(x)
        # Dropout follows the module's train/eval mode (it used to stay active in eval)
        x = F.dropout(x, 0.25, training=self.training)
        x = x.permute(0, 3, 1, 2)

        # Layer 2
        x = self.padding1(x)
        x = F.elu(self.conv2(x))
        x = self.batchnorm2(x)
        x = F.dropout(x, 0.25, training=self.training)
        x = self.pooling2(x)

        # Layer 3
        x = self.padding2(x)
        x = F.elu(self.conv3(x))
        x = self.batchnorm3(x)
        x = F.dropout(x, 0.25, training=self.training)
        x = self.pooling3(x)
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        mu =  self.linear2(x)
        sigma = torch.exp(self.linear3(x))
        # Standard-normal noise on the same device and dtype as mu
        z = mu + sigma*torch.randn_like(mu)
        self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum()
        return z

class Decoder(nn.Module):
    """Decode a latent vector into a multi-channel 1-D signal.

    Parameters
    ----------
    n_ch : int
        Number of output channels of the last transposed convolution.
    latent_dims : int
        Size of the latent vector fed to the first linear layer.
    """

    def __init__(self, n_ch, latent_dims):
        super().__init__()

        # latent -> 128 -> 568 features
        self.decoder_lin = nn.Sequential(
            nn.Linear(latent_dims, 128),
            nn.ReLU(True),
            nn.Linear(128, 568),
            nn.ReLU(True)
        )

        # (batch, 568) -> (batch, 1, 568) so the 1-D transposed convs can run.
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(1, 568))

        # NOTE(review): the two stride-1 layers below use output_padding=1.
        # PyTorch documents that output_padding must be smaller than the
        # stride or the dilation, so this stack may be rejected at run time,
        # and the resulting length does not obviously equal the 1125 samples
        # the callers reshape to - confirm the intended geometry.
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose1d(1, 16, 4, stride=2, output_padding=0),
            nn.BatchNorm1d(16),
            nn.ReLU(True),
            nn.ConvTranspose1d(16, 8, 3, stride=1, padding=1, output_padding=1),
            nn.BatchNorm1d(8),
            nn.ReLU(True),
            nn.ConvTranspose1d(8, n_ch, 4, stride=1, padding=1, output_padding=1)
        )

    def forward(self, x):
        """Return the decoded signal for the latent batch ``x``."""
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        # The debug print of the output shape was removed: it fired on every
        # batch of every epoch.
        return self.decoder_conv(x)
        
class VariationalAutoencoder(nn.Module):
    """Chain a ``VariationalEncoder`` and a ``Decoder``.

    Both sub-modules receive the same ``n_ch`` and ``latent_dims``. The KL
    term of the latest forward pass is available as ``self.encoder.kl``.
    """

    def __init__(self, n_ch, latent_dims):
        super().__init__()
        self.encoder = VariationalEncoder(n_ch, latent_dims)
        self.decoder = Decoder(n_ch, latent_dims)

    def forward(self, x):
        """Encode ``x`` to a latent sample and return its reconstruction."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return reconstruction
      

In [None]:
# Smoke test: push a random batch through the VAE. Uses the notebook-wide
# `device` (set in the imports cell) instead of a hard-coded 'cuda', so the
# cell also runs on a CPU-only runtime.
eeg = torch.randn(10, 25, 1125).to(device)
model = VariationalAutoencoder(25, 1024).to(device)
out = model(eeg)

In [None]:
import pandas as pd
from sklearn.manifold import TSNE
import plotly.express as px

# Notebook-level VAE instance and its Adam optimiser, consumed by the
# (currently commented-out) vae_training call in the training cell.
lr = 1e-3
vae = VariationalAutoencoder(n_ch=25, latent_dims=1024)
vae = vae.to(device)
optim_vae = torch.optim.Adam(
    vae.parameters(),
    lr=lr,
    weight_decay=1e-5,
)

def train_vae(vae, optimizer, T_x, T_y, batch_size):
    """Train the VAE for one epoch and return the mean loss per sample.

    Parameters
    ----------
    vae : nn.Module
        Model whose forward pass returns a reconstruction of its input and
        which exposes the KL term of that pass as ``vae.encoder.kl``.
    optimizer : torch.optim.Optimizer
        Optimiser over ``vae``'s parameters.
    T_x : torch.Tensor
        Training inputs, indexed along dimension 0.
    T_y : torch.Tensor
        Unused (reconstruction needs no labels); kept so the signature
        matches the other training helpers.
    batch_size : int
        Mini-batch size. The incomplete final batch is skipped, as before.

    Returns
    -------
    float
        Summed (reconstruction + KL) loss divided by the number of samples
        actually processed; 0.0 when no full batch fits.
    """
    vae.train()
    n_samples = T_x.shape[0]
    order = torch.randperm(n_samples)

    total_loss = 0.0
    n_seen = 0
    for batch_no in range(n_samples // batch_size):
        x = T_x[order[batch_size * batch_no: batch_size * (batch_no + 1)]]

        # Forward pass and loss: squared reconstruction error plus KL term.
        x_hat = vae(x)
        loss = ((x - x_hat) ** 2).sum() + vae.encoder.kl

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_seen += x.shape[0]

    # Normalise by the samples that contributed, not by the dataset size:
    # the skipped remainder would otherwise bias the average downwards.
    return total_loss / n_seen if n_seen else 0.0

def test_vae(vae,T_x, T_y, batch_size):
    """Evaluate the VAE and return the mean loss per sample.

    Parameters
    ----------
    vae : nn.Module
        Model whose forward pass returns a reconstruction of its input and
        which exposes the KL term of that pass as ``vae.encoder.kl``.
    T_x : torch.Tensor
        Evaluation inputs, indexed along dimension 0.
    T_y : torch.Tensor
        Unused; kept so the signature matches the other helpers.
    batch_size : int
        Number of samples evaluated per forward pass.

    Returns
    -------
    float
        Summed (reconstruction + KL) loss over all samples divided by the
        number of samples; 0.0 for an empty input.
    """
    vae.eval()
    n_samples = T_x.shape[0]
    total_loss = 0.0
    n_seen = 0
    with torch.no_grad():  # no gradients needed for evaluation
        # Every sample is evaluated, including the final partial batch. The
        # loss is a sum, so the order is irrelevant and no shuffle is needed.
        for start in range(0, n_samples, batch_size):
            x = T_x[start:start + batch_size]
            # One forward pass is enough: it runs the encoder (which updates
            # vae.encoder.kl) and the decoder.
            x_hat = vae(x)
            loss = ((x - x_hat) ** 2).sum() + vae.encoder.kl
            total_loss += loss.item()
            n_seen += x.shape[0]

    return total_loss / n_seen if n_seen else 0.0

def vae_training(vae, optimizer, T_x, T_y, V_x, V_y,batch_size):
    """Train the VAE for two epochs, then visualise the result.

    After training, every validation sample is encoded, the latent vectors
    are projected to 2-D with t-SNE and shown as a plotly scatter coloured by
    label. Finally the first validation sample and its reconstruction are
    plotted as MNE raw arrays (25 channels, 1125 samples, 128 Hz).
    """
    num_epochs = 2
    report = '\n EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}'
    for epoch_no in range(1, num_epochs + 1):
        train_loss = train_vae(vae, optimizer, T_x, T_y, batch_size)
        val_loss = test_vae(vae, V_x, V_y, batch_size)
        print(report.format(epoch_no, num_epochs, train_loss, val_loss))

    # Encode the validation set one sample at a time: one row per sample,
    # one column per latent dimension, plus the label.
    latent_rows = []
    for sample_idx in range(V_x.shape[0]):
        vae.eval()
        with torch.no_grad():
            latent = vae.encoder(V_x[sample_idx].unsqueeze(0))
        latent = latent.flatten().cpu().numpy()
        row = {f"Enc. Variable {dim}": value for dim, value in enumerate(latent)}
        row['label'] = V_y[sample_idx]
        latent_rows.append(row)
    latent_frame = pd.DataFrame(latent_rows)

    # 2-D t-SNE projection of the latent space.
    projector = TSNE(n_components=2)
    embedding = projector.fit_transform(latent_frame.drop(['label'], axis=1))
    fig = px.scatter(
        embedding,
        x=0,
        y=1,
        color=latent_frame.label.astype(str),
        labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'},
    )
    fig.show()

    # Compare the first validation sample with its reconstruction.
    data1 = V_x[0]
    print(data1.shape)

    sampling_rate = 128
    info = mne.create_info(25, sampling_rate)
    data2 = vae(data1.reshape([1, 25, 1125]))
    raw1 = mne.io.RawArray(data1.reshape([25, 1125]).detach().cpu().numpy(), info)
    raw2 = mne.io.RawArray(data2.reshape([25, 1125]).detach().cpu().numpy(), info)
    for title, raw in (('plot 1', raw1), ('plot 2', raw2)):
        print(title)
        raw.plot()

### Quick only-MLP test

In [None]:
############## MLP blocks #################################################
class MLP(nn.Module):
    """MLP classifier stacked on the encoder of a (pre-trained) VAE.

    Parameters
    ----------
    arch : sequence of int
        Width of each Linear -> BatchNorm1d -> ReLU block, in order.
    vae : object
        Anything with an ``encoder`` attribute; the encoder output is
        expected to have 1024 features (the first Linear's input size).
    """

    def __init__(self, arch, vae):
        super().__init__()
        self.va = vae.encoder
        # Registered but not used in forward; kept so the module layout
        # is unchanged.
        self.p2 = nn.MaxPool2d(kernel_size=(1, 2))

        # The unused PrintLayer() instance was dropped: it was never added to
        # the network and made this cell depend on a later one.
        in_channels = 1024
        layers = []
        for width in arch:
            layers += [nn.Linear(in_channels, width),
                       nn.BatchNorm1d(width),
                       nn.ReLU()]
            in_channels = width

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` with the VAE encoder and classify the latent vector."""
        x = self.va(x)
        return self.model(x)


In [None]:
class RNN(nn.Module):
    """Multi-layer LSTM followed by a linear classifier on the last time step.

    Parameters
    ----------
    input_size : int
        Number of features per sequence step.
    hidden_size : int
        LSTM hidden size.
    num_layers : int
        Number of stacked LSTM layers.
    num_classes : int
        Size of the output layer.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # batch_first=True -> x must be (batch_size, seq_length, input_size)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class scores of shape (batch_size, num_classes)."""
        # Zero initial hidden and cell states, created on the input's device
        # and dtype so the module does not depend on a global `device`.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)

        # out: (batch_size, seq_length, hidden_size)
        out, _ = self.lstm(x, (h0, c0))

        # Keep only the hidden state of the last time step.
        out = out[:, -1, :]
        return self.fc(out)

In [None]:
# Smoke test: the 25 channels act as sequence steps and the 1125 time samples
# as features. Uses the notebook-wide `device` instead of a hard-coded 'cuda'
# so the cell also runs on a CPU-only runtime.
eeg = torch.randn(10, 25, 1125).to(device)
model = RNN(input_size=1125, hidden_size=128, num_layers=4, num_classes=4).to(device)
out = model(eeg)
print(out.shape)

In [None]:
# Leftover from an earlier MLP experiment: a random (64, 480) input and a
# two-layer width list. The MLP calls below are commented out, so this cell
# only rebinds the notebook-level names `eeg` and `arch`.
eeg = torch.randn(64,480)
arch=[512,4]
#model=MLP(arch)
#out=model(eeg)

### Models

In [4]:
from math import ceil
import torch
from torch import nn
from torch.nn.utils import weight_norm
import torch.nn.functional as F
import math

##############  TIDNet modules #############################################################

class PrintLayer(nn.Module):
    """Debugging pass-through: print the input's shape, return it untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Drop this layer into an nn.Sequential to inspect intermediate shapes.
        shape = x.shape
        print(shape)
        return x

class Ensure4d(nn.Module):
    """Append trailing singleton axes until the input has at least 4 dims."""

    def forward(self, x):
        # range() of a non-positive count is empty, so inputs that already
        # have 4 or more dimensions are returned unchanged.
        missing_axes = 4 - len(x.shape)
        for _ in range(missing_axes):
            x = x.unsqueeze(-1)
        return x


class Expression(nn.Module):
    """Compute given expression on forward pass.
    Parameters
    ----------
    expression_fn : callable
        Should accept variable number of objects of type
        `torch.autograd.Variable` to compute its output.
    """

    def __init__(self, expression_fn):
        super(Expression, self).__init__()
        self.expression_fn = expression_fn

    def forward(self, *x):
        return self.expression_fn(*x)

    def __repr__(self):
        if hasattr(self.expression_fn, "func") and hasattr(
            self.expression_fn, "kwargs"
        ):
            expression_str = "{:s} {:s}".format(
                self.expression_fn.func.__name__, str(self.expression_fn.kwargs)
            )
        elif hasattr(self.expression_fn, "__name__"):
            expression_str = self.expression_fn.__name__
        else:
            expression_str = repr(self.expression_fn)
        return (
            self.__class__.__name__ +
            "(expression=%s) " % expression_str
        )


class _TemporalFilter(nn.Module):
    """Stack of dilated, weight-normalised temporal convolutions.

    Parameters
    ----------
    in_chans : int
        Number of input feature maps.
    filters : int
        Number of feature maps produced by every convolution.
    depth : int
        Number of convolution blocks; block ``i`` uses dilation ``depth - i``.
    temp_len : int
        Temporal kernel length; rounded up to the next odd number so the
        padding keeps the time axis length unchanged.
    drop_prob : float
        Dropout2d probability inside each block.
    activation : callable
        Factory for the activation module of each block.
    residual : {'netwise', 'dense'}
        'netwise' adds a 1x1-conv projection of the input to the stack output;
        'dense' concatenates each block's output to its input along the
        channel axis.

    Raises
    ------
    ValueError
        If ``residual`` is neither 'netwise' nor 'dense'.
    """

    def __init__(self, in_chans, filters, depth, temp_len, drop_prob=0., activation=nn.LeakyReLU,
                 residual='netwise'):
        super().__init__()
        temp_len = temp_len + 1 - temp_len % 2
        self.residual_style = str(residual)
        style = self.residual_style.lower()
        if style not in ('netwise', 'dense'):
            # Previously this built a module without `net` whose forward
            # silently returned None.
            raise ValueError(
                "residual must be 'netwise' or 'dense', got {!r}".format(residual))
        dense = style == 'dense'

        net = list()
        for i in range(depth):
            dil = depth - i
            if i == 0:
                conv_in = in_chans
            elif dense:
                # Dense blocks see the input plus every earlier block's output.
                conv_in = in_chans + i * filters
            else:
                conv_in = filters
            conv = weight_norm(nn.Conv2d(conv_in, filters,
                                         kernel_size=(1, temp_len), dilation=dil,
                                         padding=(0, dil * (temp_len - 1) // 2)))
            net.append(nn.Sequential(
                conv,
                activation(),
                nn.Dropout2d(drop_prob)
            ))

        if dense:
            # ModuleList registers the blocks, so their parameters are
            # trained, moved by .to() and saved in the state dict.
            self.net = nn.ModuleList(net)
        else:
            self.net = nn.Sequential(*net)
            self.residual = nn.Conv2d(in_chans, filters, (1, 1))

    def forward(self, x):
        if self.residual_style.lower() == 'netwise':
            return self.net(x) + self.residual(x)
        for layer in self.net:
            x = torch.cat((x, layer(x)), dim=1)
        return x


############## MLP blocks #################################################
class MLP_blocks(nn.Module):
    """Stack of Linear -> BatchNorm1d -> ReLU blocks.

    Parameters
    ----------
    arch : sequence of int
        Output width of each block, in order.
    in_channels : int
        Number of input features of the first block.
    """

    def __init__(self, arch, in_channels):
        super().__init__()
        # The unused PrintLayer() instance was removed; it was never part of
        # the network.
        layers = []
        for width in arch:
            layers += [nn.Linear(in_channels, width),
                       nn.BatchNorm1d(width),
                       nn.ReLU()]
            in_channels = width

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

############## spatial convs(electrodes,1) (electrode aggregator)  (still on testing phase) #################################################
class spatial_aggregator(nn.Module):
    """Per-channel (depthwise) temporal convolutions with average pooling.

    Still experimental.

    Parameters
    ----------
    arch : sequence
        Layer specification. An entry whose string form starts with 'A'
        (e.g. 'A10') adds ``AvgPool2d`` with kernel ``(1, 10)`` and stride
        ``(1, 2)``. Any other entry ``x`` adds a depthwise
        ``Conv2d(x, x, (1, k), groups=x)`` followed by ``BatchNorm2d(x)`` and
        ReLU, so ``x`` must equal the channel count of the incoming tensor.
    in_channels : int
        Acts as a flag: when equal to 1 the input is first made 4-D and its
        last two axes are swapped; otherwise the input is used as given.
    k : int
        Temporal kernel width of the depthwise convolutions.
    """

    def __init__(self, arch, in_channels, k ):
        super().__init__()

        def _permute(x):
            """Swap the last two axes: (batch, chans, time, 1) -> (batch, chans, 1, time)."""
            return x.permute([0, 1, 3, 2])

        if in_channels == 1:
            layers = [Ensure4d(),
                      Expression(_permute)]
        else:
            layers = []

        for x in arch:
            s = str(x)
            if s[0] == 'A':
                layers += [nn.AvgPool2d(kernel_size=(1, int(s[1:])), stride=(1, 2))]
            else:
                # Channel count and kernel width were hard-coded to 64 and 10;
                # they now follow `x` and `k` (identical for x=64, k=10).
                layers += [nn.Conv2d(in_channels=x, out_channels=x,
                                     kernel_size=(1, k), groups=x),
                           nn.BatchNorm2d(x),
                           nn.ReLU()]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

############## electrode weight-sharing convs(1,1) (a.k.a 1x1encoders) #################################################
class SharedSpaceTimeConv1x1(nn.Module):
    """Stack of 1x1 convolutions whose weights are shared across electrodes.

    Parameters
    ----------
    mode : str
        With '1x1' the raw input is made 4-D and permuted so that a single
        feature map enters the stack (``in_channels`` is then forced to 1).
        Any other value applies the stack to the input as given.
    arch : sequence
        Layer specification. An entry whose string form starts with 'A'
        (e.g. 'A4') adds ``AvgPool2d`` with kernel ``(1, 4)`` and stride
        ``(1, 2)``; any other entry ``x`` adds
        ``Conv2d(in_channels, x, kernel_size=1)`` + ``BatchNorm2d(x)`` + ReLU.
    in_channels : int
        Number of input feature maps (ignored when ``mode == '1x1'``).
    """

    def __init__(self, mode, arch, in_channels):
        super().__init__()

        def _permute(x):
            """
            Permutes data:
            from dim:
            batch, chans, time, 1
            to dim:
            batch, 1, chans, time
            """
            return x.permute([0, 3, 1, 2])

        # The unused PrintLayer() instance was removed.
        if mode == '1x1':
            layers = [Ensure4d(),
                      Expression(_permute)]
            in_channels = 1
        else:
            layers = []

        for x in arch:
            s = str(x)
            if s[0] == 'A':
                layers += [nn.AvgPool2d(kernel_size=(1, int(s[1:])), stride=(1, 2))]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU()]
                in_channels = x

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


############## TIDNet temporal block #################################################
class TemporalTIDNet(nn.Module):
    """Temporal stage of TIDNet: dilated temporal filter, max-pool, dropout.

    ``temp_span`` is a fraction of ``input_window_samples``; their product,
    rounded up, is the temporal kernel length handed to ``_TemporalFilter``.
    """

    def __init__(self, t_filters, input_window_samples, drop_prob, pooling,
                 temp_layers,  temp_span):
        super().__init__()
        self.temp_len = ceil(temp_span * input_window_samples)

        def _permute(x):
            """Reorder (batch, chans, time, 1) into (batch, 1, chans, time)."""
            return x.permute([0, 3, 1, 2])

        stages = [
            Ensure4d(),
            Expression(_permute),
            _TemporalFilter(1, t_filters, depth=temp_layers, temp_len=self.temp_len),
            nn.MaxPool2d((1, pooling)),
            nn.Dropout2d(drop_prob),
        ]
        self.temporal = nn.Sequential(*stages)

    def forward(self, x):
        return self.temporal(x)


###################################################################################### 

class TemporalShareSpaceTime(nn.Module):
    """Experimental classifier combining the TIDNet temporal stage with other blocks.

    ``__init__`` builds several alternative sub-networks (temporal TIDNet
    stage, attention layers, a spatial aggregator and, depending on ``mode``,
    a 1x1-conv or MLP head), but the current ``forward`` only uses
    ``self.rd`` -> ``self.MHAtn2`` -> flatten -> ``self.fc_block2``. The
    earlier mode-dependent ``forward`` is kept as an inert string literal
    right after this class.

    NOTE(review): ``MultiheadAttention`` (without the ``nn.`` prefix) is not
    defined in this part of the file; presumably a custom class from another
    cell - confirm it is in scope before instantiating this model.
    """

    def __init__(self, mode, arch, n_classes, in_chans, input_window_samples, t_filters,
                 drop_prob, pooling, temp_layers, temp_span):
        super().__init__()
        def compute_params(arch,h,w):
            """Return last_filter*h*w, the flattened feature count fed to ``fc_block``.

            ``w`` is shrunk once per 'A<k>' entry of ``arch`` using the
            stride-2 average-pooling output formula; ``last_filter`` is the
            last entry of ``arch`` that is not a pooling entry (only the last
            two positions are checked).
            """
            ### height and width of the input convolutions of each feature map
            #### compute the number of inputs after flatten ######
            for i in arch:
                s=str(i)
                if s[0] == 'A':
                    w=int((w-int(s[1:]))/2 +1 )     #for average pooling we apply:  [(Output width + padding width right + padding width left - kernel width) / (stride width)] + 1
            # `i` still holds the last entry of `arch` here.
            s=str(i)
            if s[0] == 'A':
                last_filter=arch[-2]
            else:
                last_filter=arch[-1]
            return int(last_filter*h*w)

        self.mode=mode
        self.n_classes = n_classes
        self.in_chans = in_chans
        self.input_window_samples = input_window_samples
        self.temp_len = ceil(temp_span * input_window_samples)
        # Feature count after the temporal stage; recomputed below for mode '1x1'.
        self.params=compute_params(arch,in_chans,ceil((input_window_samples/pooling)-1))  

        #self.TrnsEnc= TransformerEncoder(num_layers=8 ,
                      #                        input_dim=input_window_samples,
                      #                        dim_feedforward=56,
                       #                       num_heads=8,
                         #                     dropout=0.2)# CoAtNet(2,480)


        self.tidnet_temp = TemporalTIDNet(t_filters=t_filters,
                                     input_window_samples=input_window_samples,
                                     drop_prob=drop_prob, pooling=pooling, temp_layers=temp_layers,
                                     temp_span=temp_span)
        
        
        self.MHAtn=MultiheadAttention(input_dim=input_window_samples, embed_dim=512, num_heads=8)
        # NOTE(review): this assignment is immediately overwritten two lines below.
        self.MHAtn2=MultiheadAttention(input_dim=64, embed_dim=128, num_heads=4)

        # Built without batch_first, so per the PyTorch documentation it treats
        # its input as (seq, batch, embed) - TODO confirm this matches forward().
        self.MHAtn2=nn.MultiheadAttention(embed_dim=107, num_heads=1)

        # Rebinds the local name `mode` only; self.mode (set above) keeps the
        # constructor argument and is what the checks below use.
        mode=[64,'A10',64,'A10']
        self.rd= spatial_aggregator(mode, 1, 10)
        #self.rd2= spatial_aggregator([16,1], 1, 5)
        



        
        ######################################################
        
        # Rebuilds tidnet_temp with the same arguments as above when mode != '1x1'.
        if self.mode!='1x1':
            self.tidnet_temp = TemporalTIDNet(t_filters=t_filters,
                                     input_window_samples=input_window_samples,
                                     drop_prob=drop_prob, pooling=pooling, temp_layers=temp_layers,
                                     temp_span=temp_span)     

        if self.mode=='1x1' or self.mode=='t+1x1':
            self.model = SharedSpaceTimeConv1x1(self.mode, arch, t_filters) 
            if self.mode=='1x1':
                self.params=compute_params(arch,in_chans,input_window_samples)
            self.fc_block = nn.Linear(self.params, self.n_classes)
        elif self.mode=='mlp':
            self.model = MLP_blocks(arch, in_chans*ceil((input_window_samples/pooling))*t_filters)
            self.fc_block = nn.Linear(arch[-1], self.n_classes)

        # Head used by the current forward(); the input size 3200 is hard-coded.
        self.fc_block2 = nn.Linear(3200, self.n_classes)

        #for transformer test
    def forward(self, x):
      """Experimental path: spatial aggregator -> attention -> flatten -> fc_block2.

      Prints the aggregator output shape on every call.
      """

      x=self.rd(x)
      print(x.shape)
      # squeeze() without a dim removes every size-1 axis (including a batch
      # axis of size 1).
      x=torch.squeeze(x)
      x,_ = self.MHAtn2(x,x,x)
      #print(x.shape)
      #x=x.reshape([-1,25,238])
      #x= self.MHAtn2(x)
      #x=self.rd2(x)
      x = torch.flatten(x, 1)
      x = self.fc_block2(x)
      return x
    

    def get_emb(self, x):
        """Return the output of the TIDNet temporal stage for ``x``."""
        return self.tidnet_temp(x)

    
'''
    def forward(self, x):
        """Forward pass.
        Parameters
        ----------
        x: torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).
        """
        if self.mode!='1x1':
            x = self.tidnet_temp(x)
        
        if self.mode!='mlp':
            x = self.model(x)
            x = torch.flatten(x, 1)
            out = self.fc_block(x)

        else :
            x = torch.flatten(x, 1)
            x = self.model(x)
            out = self.fc_block(x)
        return out
    
    def get_emb(self, x):
        return self.tidnet_temp(x)'''

############################################################################


'\n    def forward(self, x):\n        """Forward pass.\n        Parameters\n        ----------\n        x: torch.Tensor\n            Batch of EEG windows of shape (batch_size, n_channels, n_times).\n        """\n        if self.mode!=\'1x1\':\n            x = self.tidnet_temp(x)\n        \n        if self.mode!=\'mlp\':\n            x = self.model(x)\n            x = torch.flatten(x, 1)\n            out = self.fc_block(x)\n\n        else :\n            x = torch.flatten(x, 1)\n            x = self.model(x)\n            out = self.fc_block(x)\n        return out\n    \n    def get_emb(self, x):\n        return self.tidnet_temp(x)'

### Train

In [5]:
import wandb
import os
import torch 
from torch import optim
from sklearn.metrics import f1_score
import torch.nn.functional as F
from torch.autograd import Variable
from torchinfo import summary
import gc



from braindecode.models import EEGNetv4,TIDNet, EEGResNet
def build_network(mode, arch, n_classes=4, in_chans=25, input_window_samples=1125, t_filters=32,
                 drop_prob=0.4, pooling=15, temp_layers=2, temp_span=0.05):
    """Return the model to train.

    Currently always a braindecode ``EEGNetv4`` built from ``in_chans``,
    ``n_classes`` and ``input_window_samples``; every other argument is
    accepted but ignored. The alternatives tried earlier are kept below,
    commented out.
    """
    #model=coatnet(n_classes=4, n_channels=22, embed_dim=256, num_heads=4, conv_kernel=10, pool_kernel=4, incep=1)
    #model=RNN(input_size=1125, hidden_size=128 , num_layers=4, num_classes=4).to('cuda')
    #model=TemporalShareSpaceTime(mode, arch, n_classes, in_chans, input_window_samples, t_filters,
     #            drop_prob, pooling, temp_layers, temp_span)
    #model=MLP(arch, vae)
    return EEGNetv4(in_chans, n_classes, input_window_samples)


def build_optimizer(network, optimizer, learning_rate):
    """Create the optimiser named by ``optimizer`` for ``network``.

    Parameters
    ----------
    network : nn.Module
        Model whose parameters are optimised.
    optimizer : str
        'sgd' (momentum 0.9, weight decay 5e-4) or 'adamw' (weight decay
        0.01, AMSGrad enabled).
    learning_rate : float
        Learning rate of the optimiser.

    Returns
    -------
    torch.optim.Optimizer

    Raises
    ------
    ValueError
        If ``optimizer`` is not a supported name. Previously the name string
        itself was returned and the failure only surfaced later, at the first
        ``optimizer.zero_grad()`` call.
    """
    params = network.parameters()
    if optimizer == "sgd":
        return optim.SGD(params, lr=learning_rate, momentum=0.9,
                         weight_decay=0.5*0.001)
    if optimizer == "adamw":
        return optim.AdamW(params, lr=learning_rate, weight_decay=0.01,
                           amsgrad=True)
    raise ValueError(
        f"Unsupported optimizer {optimizer!r}; expected 'sgd' or 'adamw'")



####### Training using mixup #####################################################################
def mixup_data(x, y, alpha=8.0, use_cuda=True):
    """Return mixed inputs, the two target sets and the mixing weight.

    Parameters
    ----------
    x : torch.Tensor
        Input batch with at least two dimensions (batch first).
    y : torch.Tensor
        Targets aligned with ``x`` along dimension 0.
    alpha : float
        Both parameters of the Beta distribution the mixing weight is drawn
        from; a value <= 0 disables mixing (weight 1).
    use_cuda : bool
        Kept for backward compatibility. The permutation index is now placed
        on ``x``'s device, so the default no longer crashes on CPU-only runs.

    Returns
    -------
    mixed_x, y_a, y_b, lam
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a = y``,
        ``y_b = y[perm]``.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    # Drawn on the CPU (same random stream as before), then moved next to x.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss against both target sets with weights lam and 1 - lam."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b


def train_mixup(network, optimizer, T_x, T_y, batch_size, tr_iter):
    """Train ``network`` for one epoch with mixup augmentation.

    Parameters
    ----------
    network : nn.Module
        Model being trained.
    optimizer : torch.optim.Optimizer
        Optimiser over the model parameters.
    T_x, T_y : torch.Tensor
        Training inputs and integer class targets, aligned on dimension 0.
    batch_size : int
        Mini-batch size; the final partial batch is included.
    tr_iter : int
        Running count of training iterations, incremented once per batch.

    Returns
    -------
    (loss, accuracy, tr_iter)
        ``loss`` is the sum of the per-batch mean losses divided by the number
        of samples; ``accuracy`` is the lam-weighted mixup accuracy.
        NOTE(review): the loss is divided by samples rather than batches, so
        its scale depends on ``batch_size``; left unchanged to stay comparable
        with the other epoch helpers.
    """
    cumu_loss = 0
    correct = 0.0
    total = 0.0
    n_samples = T_x.shape[0]
    data_perm = torch.randperm(n_samples)
    criterion = torch.nn.CrossEntropyLoss()

    network.train()

    for start in range(0, n_samples, batch_size):
        batch_idx = data_perm[start:start + batch_size]
        data, target = T_x[batch_idx], T_y[batch_idx]

        # Clear the gradients
        optimizer.zero_grad()

        # Mix the batch (mixup_data's default alpha applies). Follow the
        # data's device rather than forcing CUDA; the former
        # torch.autograd.Variable wrapping was a deprecated no-op.
        inputs, targets_a, targets_b, lam = mixup_data(
            data, target, use_cuda=data.is_cuda)

        outputs = network(inputs)
        loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
        cumu_loss += loss.item()

        # Backward pass + weight update
        loss.backward()
        optimizer.step()

        # Accuracy: a prediction is credited lam for matching the first
        # target and (1 - lam) for matching the second.
        _, predicted = torch.max(outputs.detach(), 1)
        total += target.size(0)
        correct += (lam * predicted.eq(targets_a).cpu().sum().float()
                    + (1 - lam) * predicted.eq(targets_b).cpu().sum().float())

        tr_iter = tr_iter + 1

    return cumu_loss / T_x.shape[0], correct/total, tr_iter
############################################################################

def train_epoch(network, optimizer, T_x, T_y, batch_size, tr_iter):
    """Train ``network`` for one epoch with cross-entropy loss.

    Parameters
    ----------
    network : nn.Module
        Model being trained.
    optimizer : torch.optim.Optimizer
        Optimiser over the model parameters.
    T_x, T_y : torch.Tensor
        Training inputs and integer class targets, aligned on dimension 0.
    batch_size : int
        Mini-batch size; the incomplete final batch is skipped.
    tr_iter : int
        Running count of training iterations, incremented once per batch.

    Returns
    -------
    (loss, accuracy, tr_iter)
        ``loss`` is the sum of the per-batch mean losses divided by the
        number of samples; ``accuracy`` is correct / processed samples.
        NOTE(review): the loss is divided by samples rather than batches, so
        its scale depends on ``batch_size``; left unchanged to stay comparable
        with ``validate_epoch``.
    """
    cumu_loss = 0
    correct = 0.0
    total = 0.0
    data_perm = torch.randperm(T_x.shape[0])

    network.train()

    for i in range(T_x.shape[0] // batch_size):
        batch_idx = data_perm[batch_size * i: batch_size * (i + 1)]
        data, target = T_x[batch_idx], T_y[batch_idx]

        # The dead `r = random.uniform(0, 1)` draw for the disabled data
        # augmentation was removed: its result was never used and it needed a
        # `random` import that this file does not visibly provide.

        # Clear the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = network(data)
        loss = F.cross_entropy(outputs, target)
        cumu_loss += loss.item()

        # Backward pass + weight update
        loss.backward()
        optimizer.step()

        # Accuracy from the arg-max class
        _, predicted = torch.max(outputs.detach(), 1)
        total += target.size(0)
        correct += (predicted == target).sum()

        tr_iter = tr_iter + 1

    return cumu_loss / T_x.shape[0], correct/total, tr_iter


def validate_epoch(network, T_x, T_y, batch_size, v_itr):
    """Evaluate ``network`` on a validation set.

    Parameters
    ----------
    network : nn.Module
        Model to evaluate (switched to eval mode).
    T_x, T_y : torch.Tensor
        Validation inputs and integer class targets, aligned on dimension 0.
    batch_size : int
        Batch size; the final partial batch is included.
    v_itr : int
        Running count of validation iterations, incremented once per batch.

    Returns
    -------
    (loss, accuracy, v_itr)
        ``loss`` is the sum of the per-batch mean losses divided by the
        number of samples (same convention as ``train_epoch``).
    """
    cumu_loss = 0
    correct = 0.0
    total = 0.0
    n_samples = T_x.shape[0]
    data_perm = torch.randperm(n_samples)

    network.eval()

    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            batch_idx = data_perm[start:start + batch_size]
            data, target = T_x[batch_idx], T_y[batch_idx]

            # One forward pass serves both loss and accuracy; the network
            # used to be evaluated twice per batch.
            outputs = network(data)
            loss = F.cross_entropy(outputs, target)
            cumu_loss += loss.item()

            # Accuracy from the arg-max class
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            correct += (predicted == target).sum()

            v_itr = v_itr + 1

    return cumu_loss / T_x.shape[0], correct/total, v_itr


def test(network, T_x, T_y, batch_size, classes, t_itr):
    """Evaluate ``network`` on the test set and log the results to W&B.

    Logs a confusion matrix, ROC and precision-recall curves and the macro
    F1 score, and prints the F1 score and the accuracy.

    Parameters
    ----------
    network : nn.Module
        Model to evaluate (switched to eval mode).
    T_x, T_y : torch.Tensor
        Test inputs and integer class targets, aligned on dimension 0.
    batch_size : int
        Batch size; the final partial batch is included.
    classes : list of str
        Class names, indexed by label.
    t_itr : int
        Unused; kept for signature compatibility with existing callers.

    Returns
    -------
    (accuracy, f1)
        Overall accuracy and macro-averaged F1 score.
    """
    n_classes = len(classes)
    correct = 0.0
    total = 0.0
    y_true = []
    y_pred = []
    pred_probs = []

    n_samples = T_x.shape[0]
    data_perm = torch.randperm(n_samples)

    network.eval()

    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            batch_idx = data_perm[start:start + batch_size]
            data, target = T_x[batch_idx], T_y[batch_idx]

            outputs = network(data)

            # Predicted class = arg-max of the network outputs
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            correct += (predicted == target).sum()
            y_true.append(target.cpu().numpy())
            y_pred.append(predicted.cpu().numpy())
            pred_probs.append(outputs.cpu().numpy())

    # Concatenate instead of dropping the last (possibly ragged) batch, so
    # F1 and the plots cover the same samples as the accuracy.
    y_true = np.concatenate(y_true).reshape([-1])
    y_pred = np.concatenate(y_pred).reshape([-1])
    # NOTE(review): these are the raw network outputs, used as per-class
    # scores for the ROC / PR curves; they are not normalised here.
    pred_probs = np.concatenate(pred_probs).reshape([-1, n_classes])

    # Confusion matrix
    wandb.log({"conf_mat" : wandb.plot.confusion_matrix(probs=None,
                        y_true=y_true , preds=y_pred ,
                        class_names=classes)})

    # ROC
    wandb.log({"roc" : wandb.plot.roc_curve(  y_true , pred_probs ,
                            labels=classes)})

    # Precision-recall curve
    wandb.log({"pr" : wandb.plot.pr_curve( y_true , pred_probs ,
                        labels=classes, classes_to_plot=None)})

    f1 = f1_score(y_true, y_pred, average='macro')
    wandb.log({'Test macro F1-Score': f1})
    print(f1)
    accuracy = correct / total
    print('TEST ACCURACY {} '.format(accuracy))

    return accuracy, f1



def train_sweep(x, y, number_of_subjects, device, bad_subjects, config, run, resultdir):
    """Run one leave-one-subject-out trial of the W&B sweep.

    Builds the network and optimiser, trains with early stopping on the
    validation accuracy, checkpoints the best model, reloads it and evaluates
    it on the held-out subject.

    Parameters
    ----------
    x, y
        Per-subject data and labels, forwarded to ``loso``.
    number_of_subjects : int
        Number of subjects in ``x`` / ``y``.
    device : str or torch.device
        Device the network is moved to.
    bad_subjects : list
        Subjects to exclude, forwarded to ``loso``.
    config
        W&B config; ``config.subject`` selects the held-out subject and
        ``config.epochs`` bounds the number of epochs.
    run
        W&B run; ``run.name`` names the checkpoint file.
    resultdir : str
        Directory the best-validation checkpoint is written to.

    Returns
    -------
    (test_acc, f1)
        Test accuracy and macro F1 of the best-validation model.
    """
    gc.collect()
    torch.cuda.empty_cache()

    run_name = run.name
    # Hyper-parameters are fixed here; the sweep values are deliberately
    # not read (see the trailing comments).
    batch_size= 64 #config.batch_size
    patience = 20 #config.patience
    LR=1e-3 #config.learning_rate
    # Renamed from `optim` so the local no longer shadows torch's optim module.
    optimizer_name='adamw' #config.optimizer

    # Disabled sweep-driven architecture builder, kept for reference.
    '''mode=config.mode
    design=config.design
    p= config.power_2_of_filters
    v=1
    if mode!='mlp':
        v=config.power_2_of_Avg_pooling


    
    if design=='inverse_bottleneck':
        p=p-2
        v=1
    elif design=='bottleneck':
        p=p+3
    elif design=='fixed':
        v=1
    arch=[]

    # Building the design
    for i in range(config.number_of_layers):
        arch.append(2**p)
        if mode!='mlp':
            arch.append('A'+str(2**v))
        if design=='bottleneck':
            p-=1
            v-=1
            if v<1:
                v=1
        elif design=='inverse_bottleneck':
            p+=1
            if p>6:
                p=6
    print(arch)'''

    mode='1x1'
    arch=[1024,512,128,64,32,4]
    network = build_network(mode, arch, n_classes=4, in_chans=22, input_window_samples=1125)

    network.to(device)
    pytorch_total_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
    wandb.log({"total number of parameters": pytorch_total_params})
    optimizer = build_optimizer(network, optimizer_name, LR)

    classes=['left fist', 'right fist', 'eyes_open', 'MI O/C feet']

    kfold= config.subject

    T_x,T_y,V_x,V_y,Test_x,Test_y=loso(x, y, number_of_subjects, kfold, device, bad_subjects, with_validation=True)

    #vae_training(vae, optim_vae, T_x, T_y, V_x, V_y,batch_size)

    Train_acc=[]
    Val_acc=[]
    Train_loss=[]
    Val_loss=[]
    ## Iteration counters
    tr_iter=0
    v_itr=0
    # -inf guarantees the first epoch is checkpointed even when its
    # validation accuracy is 0, so the reload below always finds a file.
    maxv=float('-inf')
    cpt_early = 0
    ckpt_path = os.path.join(resultdir, f"{run_name}_bestval.pth")

    for epoch in range(config.epochs):
        try:
            train_loss,train_acc,tr_iter = train_epoch(network, optimizer, T_x, T_y, batch_size, tr_iter)
        except Exception as err:
            # Stay best-effort as before, but make the failure visible and log
            # NaN instead of stale (or, on the first epoch, undefined) metrics.
            # KeyboardInterrupt is no longer swallowed.
            print(f"Epoch {epoch}: training step failed ({err!r}); logging NaN training metrics")
            train_loss, train_acc = float('nan'), float('nan')

        val_loss,val_acc,v_itr = validate_epoch(network, V_x, V_y, batch_size, v_itr)

        Train_acc.append(train_acc)
        Val_acc.append(val_acc)
        Train_loss.append(train_loss)
        Val_loss.append(val_loss)
        wandb.log({'Training accuracy': train_acc,'Training loss': train_loss,'Validation accuracy': val_acc,'Validation loss': val_loss})

        if maxv<val_acc:
            print(f"Epoch {epoch}, new best val accuracy {val_acc} and loss {val_loss}")
            maxv=val_acc

            ckpt_dict = {
                'weights': network.state_dict(),
                'train_acc': Train_acc,
                'val_acc': Val_acc,
                'train_loss': Train_loss,
                # Was a second 'val_acc' key, which overwrote the accuracies
                # with the losses.
                'val_loss': Val_loss,
                'epoch': epoch
            }
            torch.save(ckpt_dict, ckpt_path)
            cpt_early = 0
        else:
            cpt_early +=1

        if cpt_early == patience:
            print("Early Stopping")
            break

    # Logged whether training stopped early or ran out of epochs.
    wandb.log({'Maximum validation accuracy': maxv})

    print("Reloading best validation model")
    ckpt_dict = torch.load(ckpt_path)

    print(f"Reloading best model at epoch {ckpt_dict['epoch']}")
    network.load_state_dict(ckpt_dict['weights'])

    test_acc,f1=test(network, Test_x, Test_y, batch_size, classes, epoch)
    wandb.log({'Test accuracy': test_acc})

    return test_acc, f1

### Main

In [9]:
project_name="sandbox"

# Grid sweep: every combination of the value lists below is run once.
# NOTE(review): the metric is named 'loss', but the training code logs
# 'Training loss' / 'Validation loss' - confirm which key the sweep should track.
sweep_config = {
    'method': 'grid', #grid
    'metric': {'name': 'loss', 'goal': 'minimize'},
    'parameters': {
        name: {'values': values}
        for name, values in {
            'epochs': [3000],
            'batch_size': [64],
            'patience': [60],
            'learning_rate': [1e-3],
            'optimizer': ['adamw'],
            'loss': ['CrossEntropyLoss'],
            'subject': [1,2,3,4,5,6,7,8,9],
            'runs': [1,2,3],
        }.items()
    },
}

sweep_id = wandb.sweep(sweep_config, project=project_name)

# Load the 9-subject dataset once; every sweep trial reuses x and y.
number_of_subjects=9
bad_subjects=[] #bad subject for Physionet MI
x, y= load_subjects(
    '/content/drive/MyDrive/Colab Notebooks/bci4_subjects/',
    number_of_subjects,
    device,
    bad_subjects,
    apply_euclidean=False,
    with_eog=False,
)

#number_of_subjects=109
#bad_subjects=[88,90,92,100] #bad subject for Physionet MI
#x, y= load_subjects('/content/drive/MyDrive/Colab Notebooks/mne_data/', number_of_subjects, device, bad_subjects, apply_euclidean=False)

def train_wandb():
    """Entry point the W&B agent calls once per sweep trial."""
    # Start a new run; inside an agent, wandb.config carries the
    # hyper-parameters chosen for this trial.
    active_run = wandb.init(project=project_name, entity="brain-imt" , config=sweep_config)
    assert active_run is wandb.run
    with active_run:
        train_sweep(x, y, number_of_subjects, device, bad_subjects,
                    wandb.config, active_run, resultdir)

#import os
#os.environ["WANDB_MODE"]="offline"

#Test_acc= train_wandb()
wandb.agent(sweep_id, train_wandb)

KeyboardInterrupt: ignored

In [None]:
import numpy as np
# Test accuracies collected from earlier experiments, one list per setup.
accs_no_ea = [0.7024, 0.6289, 0.6145, 0.6795, 0.6892, 0.621, 0.6451, 0.6594, 0.6407, 0.5998]
accs_ea = [0.7378, 0.7378, 0.8438, 0.6563, 0.6788, 0.599, 0.6059, 0.6233, 0.651]
accs_bci = [0.73, 0.72, 0.84, 0.65, 0.67, 0.599, 0.6, 0.62, 0.65]
mlp_bci_accs = [0.57, 0.58, 0.55, 0.6, 0.74, 0.66, 0.81, 0.67, 0.58]

# Summarise the list selected here (change the right-hand side to switch setup).
accs = np.array(mlp_bci_accs)
for label, statistic in (('mean ', accs.mean()), ('std ', accs.std())):
    print(label, statistic)

In [None]:
!nvidia-smi