<a href="https://colab.research.google.com/github/GhBlg/Others/blob/main/sweep_encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install moabb
!pip install braindecode
!pip install wandb
!wandb login 33081462b5f17d9fed5c252c3fcb2071d2250425

Collecting moabb
  Downloading moabb-0.4.4-py3-none-any.whl (130 kB)
[K     |████████████████████████████████| 130 kB 5.5 MB/s 
Collecting pyriemann>=0.2.6
  Downloading pyriemann-0.2.7.tar.gz (42 kB)
[K     |████████████████████████████████| 42 kB 1.1 MB/s 
[?25hCollecting mne>=0.19
  Downloading mne-0.24.1-py3-none-any.whl (7.4 MB)
[K     |████████████████████████████████| 7.4 MB 39.5 MB/s 
[?25hCollecting coverage<6.0,>=5.5
  Downloading coverage-5.5-cp37-cp37m-manylinux2010_x86_64.whl (242 kB)
[K     |████████████████████████████████| 242 kB 48.3 MB/s 
[?25hCollecting PyYAML<6.0,>=5.0
  Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)
[K     |████████████████████████████████| 636 kB 40.7 MB/s 
Collecting scipy<2.0,>=1.5
  Downloading scipy-1.7.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (38.1 MB)
[K     |████████████████████████████████| 38.1 MB 29.4 MB/s 
Building wheels for collected packages: pyriemann
  Building wheel for pyriemann (se

In [2]:
import numpy as np
from braindecode.datasets.moabb import MOABBDataset
from braindecode.datautil.windowers import create_windows_from_events
from tqdm import tqdm
from torchsummary import summary
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score
from torch.autograd import Variable
import matplotlib.pyplot as plt
from scipy.linalg import sqrtm, inv 

from torch.utils.data import Dataset

# Apply Euclidean Alignment
def apply_EA(data):
    '''
    Apply Euclidean aligment on array-like objects for 1 subject
    
    PARAMETER:
    data: 
        Data of one subject.
    
    
    OUTPUT:
        Aligned data with Euclidean Alignment
    '''
    
    # So that this function can handles separated or combined left and right trials
    # If they are separated

    # If they are not separated

    print('Found %d trial(s) in which EEG data is stored' %len(data))
    all_trials = data
    
    # Calculate reference matrix
    RefEA = 0
    print('Computing reference matrix RefEA')

    # Iterate over all trials, compute reference EA
    for trial in all_trials:
        cov = np.cov(trial, rowvar=True)
        RefEA += cov

    # Average over all trials
    RefEA = RefEA/all_trials.shape[0]
    
    # Adding reference EA as a new key in data
    data_dict={}
    print('Add RefEA as a new key in data')
    data_dict['RefEA'] = RefEA 
    
    # Compute R^(-0.5)
    R_inv = sqrtm(inv(RefEA))
    data_dict['R_inv'] = R_inv
    
        
    # Perform EA on each trial
    all_trials_EA = []
        
    for t in all_trials:
        all_trials_EA.append(R_inv@t)
        
    # Return all_trials_EA
    return np.array(all_trials_EA)
        

######################## Load all data ################################
class EEGDataset(Dataset):
    def __init__(self, X, labels=None, transforms=None):
        self.X = X
        self.y = labels
        self.transforms = transforms

    def __len__(self):
        return (len(self.X))

    def __getitem__(self, i):
        data = self.X[i,:,:]

        if self.transforms:
            data = self.transforms(data)

        if self.y is not None:
            return (data, self.y[i])
        else:
            return data
    ############################################################################
class TrainObject(object):
    def __init__(self, X, y, euclidean_alignment=True):
        assert len(X) == len(y)
        if euclidean_alignment:
            X=apply_EA(X)
        mean = np.mean(X, axis=2, keepdims=True)
        # Here standardize across the window, when channel size is not large enough
        # In motor imagery kit, we put axis = 1, across channel as an example
        std = np.std(X, axis=2, keepdims=True)
        X = (X - mean) / std
        # we scale it to 1000 as a better training scale of the shallow CNN
        # according to the orignal work of the paper referenced above
        self.X = X*1e3
        self.y = y



  warn('datautil.windowers module is deprecated and is now under '


In [3]:
######################## Load all data ################################

def load_data(loso): 

    ######################## Load all data ################################
    T_x=[]
    T_y=[]
    V_x=[]
    V_y=[]
    Test_x=[]
    Test_y=[]

    for subject_id in [e for e in range(1,10) if e not in (loso, loso-1)]:
        dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])


        trial_start_offset_seconds = -0.5
        # Extract sampling frequency, check that they are same in all datasets
        sfreq = dataset.datasets[0].raw.info['sfreq']
        assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])
        # Calculate the trial start offset in samples.
        trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

        # Create windows using braindecode function for this. It needs parameters to define how
        # trials should be used.
        windows_dataset = create_windows_from_events(
            dataset,
            trial_start_offset_samples=trial_start_offset_samples,
            trial_stop_offset_samples=0,
            preload=True,
        )


        splitted = windows_dataset.split('session')
        train_set = splitted['session_T']
        valid_set = splitted['session_E']

        train_x=np.array([ele[0][:-1] for ele in train_set])
        train_y=np.array([ele[1] for ele in train_set])

        valid_x=np.array([ele[0][:-1] for ele in valid_set])
        valid_y=np.array([ele[1] for ele in valid_set])

        train_set = TrainObject(train_x, y=train_y)
        valid_set = TrainObject(valid_x, y=valid_y)

        [T_x.append(el) for el in train_set.X]
        [T_y.append(el) for el in train_set.y]
        [T_x.append(el) for el in valid_set.X]
        [T_y.append(el) for el in valid_set.y]
    T_x=np.array(T_x)
    T_y=np.array(T_y)

    ##################### Validation Set ###################################


    subject_id = loso-1
    dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])


    trial_start_offset_seconds = -0.5
    # Extract sampling frequency, check that they are same in all datasets
    sfreq = dataset.datasets[0].raw.info['sfreq']
    assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])
    # Calculate the trial start offset in samples.
    trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

    # Create windows using braindecode function for this. It needs parameters to define how
    # trials should be used.
    windows_dataset = create_windows_from_events(
        dataset,
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=0,
        preload=True,
    )


    splitted = windows_dataset.split('session')
    train_set = splitted['session_T']
    valid_set = splitted['session_E']

    train_x=np.array([ele[0][:-1] for ele in train_set])
    train_y=np.array([ele[1] for ele in train_set])

    valid_x=np.array([ele[0][:-1] for ele in valid_set])
    valid_y=np.array([ele[1] for ele in valid_set])
        
    train_set = TrainObject(train_x, y=train_y)
    valid_set = TrainObject(valid_x, y=valid_y)

    [V_x.append(el) for el in train_set.X]
    [V_y.append(el) for el in train_set.y]
    [V_x.append(el) for el in valid_set.X]
    [V_y.append(el) for el in valid_set.y]

    V_x=np.array(V_x)
    V_y=np.array(V_y)


    ############################### Test Set #############################################

    subject_id = loso
    dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])


    trial_start_offset_seconds = -0.5
    # Extract sampling frequency, check that they are same in all datasets
    sfreq = dataset.datasets[0].raw.info['sfreq']
    assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])
    # Calculate the trial start offset in samples.
    trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

    # Create windows using braindecode function for this. It needs parameters to define how
    # trials should be used.
    windows_dataset = create_windows_from_events(
        dataset,
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=0,
        preload=True,
    )


    splitted = windows_dataset.split('session')
    train_set = splitted['session_T']
    valid_set = splitted['session_E']

    train_x=np.array([ele[0][:-1] for ele in train_set])
    train_y=np.array([ele[1] for ele in train_set])

    valid_x=np.array([ele[0][:-1] for ele in valid_set])
    valid_y=np.array([ele[1] for ele in valid_set])

    train_set = TrainObject(train_x, y=train_y)
    valid_set = TrainObject(valid_x, y=valid_y)

    [Test_x.append(el) for el in train_set.X]
    [Test_y.append(el) for el in train_set.y]
    [Test_x.append(el) for el in valid_set.X]
    [Test_y.append(el) for el in valid_set.y]

    Test_x=np.array(Test_x)
    Test_y=np.array(Test_y)


    ############################################################################
    ##  LOSO  : 9 total subjects = 7 for training - 1 for validation - 1 for testing 
    return(T_x,T_y,V_x,V_y,Test_x,Test_y)


In [4]:


from math import ceil
import torch
from torch import nn
from torch.nn import init
from torch.nn.utils import weight_norm
from braindecode.models.modules import Expression, Ensure4d

from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        param = parameter.numel()
        table.add_row([name, param])
        total_params+=param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params   

class PrintLayer(nn.Module):
    def __init__(self):
        super(PrintLayer, self).__init__()
    
    def forward(self, x):
        # Do your print / debug stuff here
        print(x.shape)
        return x


class _BatchNormZG(nn.BatchNorm2d):
    def reset_parameters(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
        if self.affine:
            self.weight.data.zero_()
            self.bias.data.zero_()


class _ConvBlock2D(nn.Module):
    """Implements Convolution block with order:
    Convolution, dropout, activation, batch-norm
    """
    def __init__(self, in_filters, out_filters, kernel, stride=(1, 1), padding=0, dilation=1,
                 groups=1, drop_prob=0.5, batch_norm=True, activation=nn.LeakyReLU, residual=False):
        super().__init__()
        self.kernel = kernel
        self.activation = activation()
        self.residual = residual

        self.conv = nn.Conv2d(in_filters, out_filters, kernel, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=not batch_norm)
        self.dropout = nn.Dropout2d(p=drop_prob)
        self.batch_norm = _BatchNormZG(out_filters) if residual else nn.BatchNorm2d(out_filters) if\
            batch_norm else lambda x: x

    def forward(self, input):
        res = input
        input = self.conv(input,)
        input = self.dropout(input)
        input = self.activation(input)
        input = self.batch_norm(input)
        return input + res if self.residual else input
    



class _DenseFilter(nn.Module):
    def __init__(self, in_features, growth_rate, filter_len=5, drop_prob=0.5, bottleneck=2,
                 activation=nn.LeakyReLU, dim=-2):
        super().__init__()
        dim = dim if dim > 0 else dim + 4
        if dim < 2 or dim > 3:
            raise ValueError('Only last two dimensions supported')
        kernel = (filter_len, 1) if dim == 2 else (1, filter_len)

        self.net = nn.Sequential(
            nn.BatchNorm2d(in_features),
            activation(),
            nn.Conv2d(in_features, bottleneck * growth_rate, 1),
            nn.BatchNorm2d(bottleneck * growth_rate),
            activation(),
            nn.Conv2d(bottleneck * growth_rate, growth_rate, kernel,
                      padding=tuple((k // 2 for k in kernel))),
            nn.Dropout2d(drop_prob)
        )

    def forward(self, x):
        return torch.cat((x, self.net(x)), dim=1)


class _DenseSpatialFilter(nn.Module):
    def __init__(self, in_chans, growth, depth, in_ch=1, bottleneck=4, drop_prob=0.0,
                 activation=nn.LeakyReLU, collapse=True):
        super().__init__()
        self.net = nn.Sequential(*[
            _DenseFilter(in_ch + growth * d, growth, bottleneck=bottleneck, drop_prob=drop_prob,
                         activation=activation) for d in range(depth)
        ])
        n_filters = in_ch + growth * depth
        self.collapse = collapse
        if collapse:
            self.channel_collapse = _ConvBlock2D(n_filters, n_filters, (in_chans, 1), drop_prob=0)

    def forward(self, x):
        if len(x.shape) < 4:
            x = x.unsqueeze(1).permute([0, 1, 3, 2])
        x = self.net(x)
        if self.collapse:
            return self.channel_collapse(x).squeeze(-2)
        return x


class _TemporalFilter(nn.Module):
    def __init__(self, in_chans, filters, depth, temp_len, drop_prob=0., activation=nn.LeakyReLU,
                 residual='netwise'):
        super().__init__()
        temp_len = temp_len + 1 - temp_len % 2
        self.residual_style = str(residual)
        net = list()

        for i in range(depth):
            dil = depth - i
            conv = weight_norm(nn.Conv2d(in_chans if i == 0 else filters, filters,
                                         kernel_size=(1, temp_len), dilation=dil,
                                         padding=(0, dil * (temp_len - 1) // 2)))
            net.append(nn.Sequential(
                conv,
                activation(),
                nn.Dropout2d(drop_prob)
            ))
        if self.residual_style.lower() == 'netwise':
            self.net = nn.Sequential(*net)
            self.residual = nn.Conv2d(in_chans, filters, (1, 1))
        elif residual.lower() == 'dense':
            self.net = net

    def forward(self, x):
        if self.residual_style.lower() == 'netwise':
            return self.net(x) + self.residual(x)
        elif self.residual_style.lower() == 'dense':
            for layer in self.net:
                x = torch.cat((x, layer(x)), dim=1)
            return x


class _TIDNetFeatures(nn.Module):
    def __init__(self,  oo, loop ,s_growth, t_filters, in_chans, input_window_samples, drop_prob, pooling,
                 temp_layers, spat_layers, temp_span, bottleneck, summary):
        super().__init__()
        self.in_chans = in_chans
        self.input_windows_samples = input_window_samples
        self.temp_len = ceil(temp_span * input_window_samples)

        def _permute(x):
            """
            Permutes data:
            from dim:
            batch, chans, time, 1
            to dim:
            batch, 1, chans, time
            """
            return x.permute([0, 3, 1, 2])

        self.temporal = nn.Sequential(
            Ensure4d(),
            Expression(_permute),
            _TemporalFilter(1, t_filters, depth=temp_layers, temp_len=self.temp_len),
            nn.MaxPool2d((1, pooling)),
            nn.Dropout2d(drop_prob),
        )
        summary = input_window_samples // pooling if summary == -1 else summary

        self.spatial = _DenseSpatialFilter(in_chans, s_growth, spat_layers, in_ch=t_filters,
                                           drop_prob=drop_prob, bottleneck=bottleneck)
        self.extract_features = nn.Sequential(
            nn.AdaptiveAvgPool1d(int(summary)),
            nn.Flatten(start_dim=1)
        )

        self._num_features = (t_filters + s_growth * spat_layers) * summary

        
        self.loop=loop
        self.pool = nn.MaxPool2d(2, 2)

        self.enc0 = nn.Conv2d(32, oo, (1,1))
        self.enc1 = nn.Conv2d(oo, oo, (1,1))    

        ### compute input of fc after flatten
        #after enc0
        h= ((25-1)/1)+1  
        w= ((75-1)/1)+1  
        h=int(  (h-2)/2 +1 )
        w=int( (w-2)/2 +1 )

        #after enc1 loop
        for i in range(self.loop):
            h=((h-1)/1)+1  
            w= ((w-1)/1)+1  
            h= int( (h-2)/2 +1 )
            w=int( (w-2)/2 +1 )
        ##################

          
        self.fc_block2 = nn.Linear(oo*h*w, 4)

    @property
    def num_features(self):
        return self._num_features

    def forward(self, x):
        x = self.temporal(x)
        #x = self.spatial(x)
        #return self.extract_features(x)

        x=F.relu(self.pool(self.enc0(x)))

        for i in range(self.loop):
            x=F.relu(self.pool(self.enc1(x)))

        x = torch.flatten(x, 1)
        out = self.fc_block2(x)
        
        return out


########################################### 

class TIDNet_features(nn.Module):

    def __init__(self, oo, loop, in_chans, n_classes, input_window_samples, s_growth=24, t_filters=32,
                 drop_prob=0.4, pooling=15, temp_layers=2, spat_layers=2, temp_span=0.05,
                 bottleneck=3, summary=-1):
        super().__init__()
        self.n_classes = n_classes
        self.in_chans = in_chans
        self.input_window_samples = input_window_samples
        self.temp_len = ceil(temp_span * input_window_samples)

        self.dscnn = _TIDNetFeatures( oo, loop,s_growth=s_growth, t_filters=t_filters, in_chans=in_chans,
                                     input_window_samples=input_window_samples,
                                     drop_prob=drop_prob, pooling=pooling, temp_layers=temp_layers,
                                     spat_layers=spat_layers, temp_span=temp_span,
                                     bottleneck=bottleneck, summary=summary)

        self._num_features = self.dscnn.num_features

        self.classify = self._create_classifier(self.num_features, n_classes)

    def _create_classifier(self, incoming, n_classes):
        classifier = nn.Linear(incoming, n_classes)
        init.xavier_normal_(classifier.weight)
        classifier.bias.data.zero_()
        return nn.Sequential(nn.Flatten(start_dim=1), classifier, nn.LogSoftmax(dim=-1))

    def forward(self, x):
        """Forward pass.
        Parameters
        ----------
        x: torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).
        """

        x = self.dscnn(x)
        return x #self.classify(x)
    
    def get_emb(self, x):
        return self.dscnn(x)

    @property
    def num_features(self):
        return self._num_features
############################################################################


In [5]:
############################################################################

cuda = torch.cuda.is_available()  # check if GPU is available, if True chooses to use it
device = 'cuda' if cuda else 'cpu'

############################################################################


def build_network(oo,loop):

    model=TIDNet_features(oo,loop, n_classes=4, in_chans=25, input_window_samples=1125, s_growth=24, t_filters=32,
                 drop_prob=0.4, pooling=15, temp_layers=2, spat_layers=2, temp_span=0.05,
                 bottleneck=3, summary=-1)
    return model.to(device)


def build_optimizer(network, optimizer, learning_rate):
    if optimizer == "sgd":
        optimizer = optim.SGD(network.parameters(),
                              lr=learning_rate, momentum=0.9, weight_decay=0.5*0.001)
    elif optimizer == "adamw":
        optimizer = optim.AdamW(network.parameters(),
                               lr=learning_rate,  weight_decay=0.01, amsgrad=True)
    return optimizer


def train_epoch(network, loader, optimizer, loss_config, batch_size, tr_iter):
    cumu_loss = 0
    correct = 0.0
    total = 0.0
    
    for i, (data, target) in tqdm(enumerate(loader), ncols = 100, total=68,
               desc ="Training"):
        data, target = data.to(device), target.to(device)
        if data.shape[0]==batch_size:

            data.double()
            target.long()
            optimizer.zero_grad()
            network.double()



            # ➡ Forward pass
            if loss_config == "nll_loss":
                #loss = floss.forward(network(data.double()), target.long())
                loss = F.nll_loss(network(data.double()), target.long())
                cumu_loss += loss.item()
            elif loss_config =='CrossEntropyLoss':
                loss = F.cross_entropy(network(data.double()), target.long())
                cumu_loss += loss.item()
        
            # ⬅ Backward pass + weight update
            loss.backward()
            optimizer.step()
        
             # compute accuracy
            outputs = network(data.double())

            # Get predictions from the maximum value
            _, predicted = torch.max(outputs.data, 1)

            # Total number of labels
            total += target.size(0)
            correct += (predicted == target).sum()


            tr_iter=tr_iter+1

    return cumu_loss / len(loader), correct/total, tr_iter


def validate_epoch(network, loader, optimizer, loss_config, batch_size, v_itr):
    cumu_loss = 0.0
    correct = 0.0
    total = 0.0
   
    for _, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        if data.shape[0]==batch_size:
            data.double()
            target.long()
            
        
            optimizer.zero_grad()
            network.double()
            network.eval()  
            torch.no_grad()
            
            # Compute loss
            if loss_config == "nll_loss":
                #loss = F.nll_loss(network(data.double()), target.long())
                cumu_loss += loss.item()
            elif loss_config =='CrossEntropyLoss':
                loss = F.cross_entropy(network(data.double()), target.long())
                cumu_loss += loss.item()      

          # compute accuracy
            outputs = network(data.double())

            # Get predictions from the maximum value
            _, predicted = torch.max(outputs.data, 1)

            # Total number of labels
            total += target.size(0)
            correct += (predicted == target).sum()   

            v_itr=v_itr+1
    return cumu_loss / len(loader), correct/total, v_itr


def test(network, loader, batch_size, n_classes, t_itr):
    # Calculate Accuracy
    correct = 0.0
    correct_arr = [0.0] * n_classes
    total = 0.0
    total_arr = [0.0] * n_classes
    y_true=[]
    y_pred=[]
    # Iterate through test dataset
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        if data.shape[0] == batch_size:  #condition to avoid taking trials length < batch size (problematic for confusion matrix)
            target.long()
            network.double()
            outputs = network(data.double())
                    # Get predictions from the maximum value
            _, predicted = torch.max(outputs.data, 1)
            # Total number of labels
            total += target.size(0)
            correct += (predicted == target).sum()
            y_true.append(target)
            y_pred.append(predicted)
           
            for label in range(n_classes):
                correct_arr[label] += (((predicted == target) & (target==label)).sum())
                total_arr[label] += (target == label).sum()

    accuracy = correct / total
    print('TEST ACCURACY {} '.format(accuracy))
    
    #print('TEST F1-Score {} '.format(f1_score(torch.tensor(np.array(y_true,'int32')).view(-1), torch.tensor(np.array(y_pred,'int32')).view(-1),  average='macro')))
    t_itr=t_itr+1               
    return accuracy, t_itr

import wandb

project_name="sweep_encoder_f_test"


sweep_config = {
    'method': 'grid', #grid
    'metric': {
      'name': 'loss',
      'goal': 'minimize'   
    },
    'parameters': {
        'epochs': {
            'values': [9]
        },
        'batch_size': {
            'values': [60]
        },
     
        'learning_rate': {
            'values': [1e-3]
        },
        'optimizer': {
            'values': ['adamw']
        },
         'loss': {
            'values': ['CrossEntropyLoss'],
        },
        'subject': {
            'values': [3,4,5,6,7,8,9]
        },
        'number_of_layers': {
            'values': [1,2,3,4]
        },
        'number_of_filters': {
            'values': [1,5,32,64]
        },
      
    }
}

sweep_id = wandb.sweep(sweep_config, project=project_name)

def train_wandb():
    # Initialize a new wandb run
    with wandb.init(project=project_name, #entity="brain-imt", 
    config=sweep_config):
        config = wandb.config

        n_classes=4
        batch_size=config.batch_size
        LR=config.learning_rate
        optim= config.optimizer
        loss=config.loss


        network = build_network(1,1)#config.number_of_filters, config.number_of_layers)
        
        pytorch_total_params = sum(p.numel() for p in network.parameters())
        print(pytorch_total_params)
        #wandb.log({"total number of parameters": pytorch_total_params})
        optimizer = build_optimizer(network, optim, LR)

        T_x,T_y, V_x,V_y,Test_x,Test_y=load_data(5)#config.subject)

        tr_x=T_x
        tr_y=T_y


        val_x=V_x
        val_y=V_y


        test_x=Test_x
        test_y=Test_y

        train_data = EEGDataset(tr_x, tr_y, transforms=None)
        valid_data = EEGDataset(val_x, val_y, transforms=None)
        test_data = EEGDataset(test_x, test_y, transforms=None)

        del tr_x, tr_y, val_x, val_y, test_x, test_y



        train_loader = DataLoader(train_data, batch_size=batch_size,shuffle=True)
        valid_loader = DataLoader(valid_data, batch_size=batch_size,shuffle=True)
        test_loader = DataLoader(test_data, batch_size=batch_size,shuffle=True)

        Train_acc=[]
        Val_acc=[]
        Train_loss=[]
        Val_loss=[]
        Test_acc=[]
    ## Tensorboard iterators
        tr_iter=0
        v_itr=0
        maxv=0

        summary(network.cuda(), (25, 1125))


        for epoch in range(config.epochs):

            train_loss,train_acc,tr_iter = train_epoch(network, train_loader, optimizer, loss, batch_size, tr_iter)
            print('train loss {} accuracy {} epoch {} done'.format(train_loss,train_acc,epoch))
            val_loss,val_acc,v_itr = validate_epoch(network, valid_loader, optimizer, loss, batch_size, v_itr)
            print('val loss {} epoch {} done'.format(val_loss,epoch))
            Train_acc.append(train_acc)
            Val_acc.append(val_acc)
            Train_loss.append(train_loss)
            Val_loss.append(val_loss)
            wandb.log({'Training accuracy': train_acc, 'Training loss': train_loss})
            wandb.log({'Validation accuracy': val_acc, 'Validation loss': val_loss})
            if epoch % 3 == 0:
                test_acc,_=test(network, test_loader, batch_size, n_classes, epoch)
                Test_acc.append(test_acc)
                wandb.log({'Test accuracy': test_acc})
            if maxv<val_acc:
                wandb.log({'Maximum validation accuracy': val_acc})
                maxv=val_acc
            


        
        '''x_np=[]
        y_np=[]
        for data, target in train_loader:
            network.to('cpu')
            x_np.append(network.get_emb(data).detach().numpy().tolist())
            y_np.append(target)

        v=[]
        for i in x_np :
            for j in i:
                v.append(np.array(j))  
        vv=[]
        for i in y_np :
            for j in i:
                vv.append(np.array(j))
        print('X shape  ',np.array(v).shape)
        print('Y shape  ',np.array(vv).shape)
        np.savez_compressed('train.npz',x=np.array(v),y=np.array(vv))'''

        return Train_acc, Val_acc, Test_acc

    ############################################################################


Create sweep with ID: twd32prx
Sweep URL: https://wandb.ai/ghblg/sweep_encoder_f_test/sweeps/twd32prx


In [None]:
#import os
#os.environ["WANDB_MODE"]="offline"

wandb.agent(sweep_id, train_wandb)

[34m[1mwandb[0m: Agent Starting Run: 45eevl5w with config:
[34m[1mwandb[0m: 	batch_size: 60
[34m[1mwandb[0m: 	epochs: 9
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	loss: CrossEntropyLoss
[34m[1mwandb[0m: 	number_of_filters: 1
[34m[1mwandb[0m: 	number_of_layers: 1
[34m[1mwandb[0m: 	optimizer: adamw
[34m[1mwandb[0m: 	subject: 3
[34m[1mwandb[0m: Currently logged in as: [33mghblg[0m (use `wandb login --relogin` to force relogin)


  set_config(key, get_config("MNE_DATA"))
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A01T.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A01T.mat'.


269291
MNE_DATA is not already configured. It will be set to default location in the home directory - /root/mne_data
All datasets will be downloaded to this location, if anything is already downloaded, please move manually to this location


SHA256 hash of downloaded file: 054f02e70cf9c4ada1517e9b9864f45407939c1062c6793516585c6f511d0325
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A01E.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A01E.mat'.
SHA256 hash of downloaded file: 53d415f39c3d7b0c88b894d7b08d99bcdfe855ede63831d3691af1a45607fb62
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.


48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Loading data for 48 events and 1125 original time points ...
0 bad epochs dropped
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Lo

Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A02T.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A02T.mat'.
SHA256 hash of downloaded file: 5ddd5cb520b1692c3ba1363f48d98f58f0e46f3699ee50d749947950fc39db27
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A02E.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A02E.mat'.
SHA256 hash of downloaded file: d63c454005d3a9b41d8440629482e855afc823339bdd0b5721842a7ee9cc7b12
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.


48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Loading data for 48 events and 1125 original time points ...
0 bad epochs dropped
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Lo

Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A03T.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A03T.mat'.
SHA256 hash of downloaded file: 7e731ee8b681d5da6ecb11ae1d4e64b1653c7f15aad5d6b7620b25ce53141e80
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A03E.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A03E.mat'.
SHA256 hash of downloaded file: d4229267ec7624fa8bd3af5cbebac17f415f7c722de6cb676748f8cb3b717d97
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.


48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Loading data for 48 events and 1125 original time points ...
0 bad epochs dropped
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Lo

Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A06T.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A06T.mat'.
SHA256 hash of downloaded file: 4dc3be1b0d60279134d1220323c73c68cf73799339a7fb224087a3c560a9a7e2
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A06E.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A06E.mat'.
SHA256 hash of downloaded file: bf67a40621b74b6af7a986c2f6edfff7fc2bbbca237aadd07b575893032998d1
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.


48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Loading data for 48 events and 1125 original time points ...
0 bad epochs dropped
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Lo

Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A07T.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A07T.mat'.
SHA256 hash of downloaded file: 43b6bbef0be78f0ac2b66cb2d9679091f1f5b7f0a5d4ebef73d2c7cc8e11aa96
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A07E.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A07E.mat'.
SHA256 hash of downloaded file: b9aaec73dcee002fab84ee98e938039a67bf6a3cbf4fc86d5d8df198cfe4c323
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.


48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Loading data for 48 events and 1125 original time points ...
0 bad epochs dropped
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Lo

Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A08T.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A08T.mat'.
SHA256 hash of downloaded file: 7a4b3bd602d5bc307d3f4527fca2cf076659e94aca584dd64f6286fd413a82f2
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A08E.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A08E.mat'.
SHA256 hash of downloaded file: 0eedbd89790c7d621c8eef68065ddecf80d437bbbcf60321d9253e2305f294f7
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.


48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Loading data for 48 events and 1125 original time points ...
0 bad epochs dropped
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Lo

Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A09T.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A09T.mat'.
SHA256 hash of downloaded file: b28d8a262c779c8cad9cc80ee6aa9c5691cfa6617c03befe490a090347ebd15c
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A09E.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A09E.mat'.
SHA256 hash of downloaded file: 5d79649a42df9d51215def8ffbdaf1c3f76c54b88b9bbaae721e8c6fd972cc36
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.


48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Loading data for 48 events and 1125 original time points ...
0 bad epochs dropped
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Lo

Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A04T.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A04T.mat'.
SHA256 hash of downloaded file: 15850d81b95fc88cc8b9589eb9b713d49fa071e28adaf32d675b3eaa30591d6e
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A04E.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A04E.mat'.
SHA256 hash of downloaded file: 81916dff2c12997974ba50ffc311da006ea66e525010d010765f0047e771c86a
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.


48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Loading data for 48 events and 1125 original time points ...
0 bad epochs dropped
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Lo

Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A05T.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A05T.mat'.
SHA256 hash of downloaded file: 77387d3b669f4ed9a7c1dac4dcba4c2c40c8910bae20fb961bb7cf5a94912950
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A05E.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A05E.mat'.
SHA256 hash of downloaded file: 8b357470865610c28b2f1d351beac247a56a856f02b2859d650736eb2ef77808
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.


48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Loading data for 48 events and 1125 original time points ...
0 bad epochs dropped
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
48 matching events found
No baseline correction applied
0 projection items activated
Lo

Training: 100%|█████████████████████████████████████████████████████| 68/68 [02:23<00:00,  2.11s/it]


train loss 1.3671447039096092 accuracy 0.2497512549161911 epoch 0 done
val loss 1.248501957565023 epoch 0 done
TEST ACCURACY 0.2518518567085266 


Training: 100%|█████████████████████████████████████████████████████| 68/68 [02:22<00:00,  2.10s/it]


train loss 1.3666763270446074 accuracy 0.2492537498474121 epoch 1 done
val loss 1.247765044170562 epoch 1 done


Training:  71%|█████████████████████████████████████▍               | 48/68 [01:40<00:42,  2.14s/it]