In [None]:
import pickle
import torch
import numpy as np
import sys
import time

from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce


from sklearn.model_selection import train_test_split,cross_val_score, KFold
from sklearn.preprocessing import StandardScaler
import torch.utils.data as Data
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix, accuracy_score,balanced_accuracy_score #roc_auc_score,precision_score,recall_score,f1_score,classification_report

from sklearn.utils import class_weight
import pyriemann
import ot

import matplotlib.pyplot as plt
import mne

import myimporter
from BCI_functions import *  # BCI_functions.ipynb contains some functions we might use multiple times in this tutorial
import warnings
warnings.filterwarnings('ignore')
import os
os.getcwd()

In [None]:
# TODO: accept a list of subject ids so that several subjects can be combined.
# TODO: share the loading/splitting code with EEGMMIDTsSet instead of duplicating it.

class EEGMMIDTrSet(Data.Dataset):
    """Training split of one EEGMMID subject for the two-class fist task.

    Pipeline (repeated by ``EEGMMIDTsSet``, which keeps the other split):

    1. Load ``<root_dir>/<subject_id>.npy``: one column per EEG channel plus the
       event label in the last column.
    2. Filter every EEG column with ``butter_bandpass_filter`` (8-30 Hz,
       fs = 160 Hz, order 3).
    3. Keep only labels 2-5 and remap {2, 4} -> 0, {3, 5} -> 1.
    4. Cut non-overlapping 160-sample windows with ``extract``.
    5. Stratified ``train_test_split`` (default 75/25 split, ``random_state=0``).
    6. Standardise each channel with statistics of the training windows only.

    Attributes:
        data: training windows, shape (n_train, 64, 160) = (trial, electrode, time).
        targets: 1-D array of training labels (0 or 1) aligned with ``data``.
        data_seg_label: labels of *all* windows (train and test), shape (n, 1).

    Args:
        subject_id: subject number; selects the ``.npy`` file.
        transform: unused, kept for API compatibility.
    """

    def __init__(self, subject_id, transform=None):
        root_dir = "../Deep-Learning-for-BCI/dataset/"
        dataset_raw = np.load(root_dir + str(subject_id) + '.npy')

        # Band-pass every EEG column (all but the trailing label column). 8-30 Hz is the
        # mu/beta range, not gamma.
        fs, lowcut, highcut = 160.0, 8.0, 30.0
        filtered = [butter_bandpass_filter(dataset_raw[:, ch], lowcut, highcut, fs, order=3)
                    for ch in range(dataset_raw.shape[1] - 1)]
        dataset = np.hstack((np.array(filtered).T, dataset_raw[:, -1:]))
        print(dataset.shape)

        # Keep only labels 2-5 (the left/right fist classes described in 1-Data.ipynb).
        removed_label = [0, 1, 6, 7, 8, 9, 10]  # [0,1,2,3,4,5,10] for hf # [0,1,6,7,8,9,10] for lr
        dataset = dataset[~np.isin(dataset[:, -1], removed_label)]

        # PyTorch needs labels 0..n_class-1: {2, 4} -> 0 and {3, 5} -> 1.
        labels = dataset[:, -1]  # a view: the assignments below update `dataset`
        to_zero = np.isin(labels, [2, 4])
        to_one = np.isin(labels, [3, 5])
        labels[to_zero] = 0
        labels[to_one] = 1

        # Windowing. moving == time_window means no overlap, so no sample can end up in
        # both the training and the test split.
        n_class = 2
        no_feature = 64  # EEG channels
        segment_length = 160  # samples per window (1 s at fs = 160 Hz)
        data_seg = extract(dataset, n_classes=n_class, n_fea=no_feature,
                           time_window=segment_length, moving=segment_length)
        print('After segmentation, the shape of the data:', data_seg.shape)

        # Each row holds no_feature * segment_length values followed by the label.
        no_longfeature = no_feature * segment_length
        data_seg_feature = data_seg[:, :no_longfeature]
        self.data_seg_label = data_seg[:, no_longfeature:no_longfeature + 1]

        # random_state must stay equal to EEGMMIDTsSet's so both classes see the same split.
        train_feature, test_feature, train_label, test_label = train_test_split(
            data_seg_feature, self.data_seg_label, random_state=0, shuffle=True,
            stratify=self.data_seg_label)

        # Report class balance over all windows, then per split.
        unique, counts = np.unique(self.data_seg_label, return_counts=True)
        left_perc = counts[0] / counts.sum()
        if left_perc < 0.4 or left_perc > 0.6:
            print("Imbalanced dataset with split of: ", left_perc, 1 - left_perc)
        else:
            print("Classes balanced.")
        unique, counts = np.unique(train_label, return_counts=True)
        print("Class label splits in training set \n ", np.asarray((unique, counts)).T)
        unique, counts = np.unique(test_label, return_counts=True)
        print("Class label splits in test set\n ", np.asarray((unique, counts)).T)

        # Standardise per channel with training statistics only (no test-set leakage).
        # Assumes each flattened window is time-major, i.e. it reshapes to (time, channel).
        train_feature_2d = train_feature.reshape(-1, no_feature)
        test_feature_2d = test_feature.reshape(-1, no_feature)
        scaler1 = StandardScaler().fit(train_feature_2d)
        train_fea_norm1 = scaler1.transform(train_feature_2d)
        test_fea_norm1 = scaler1.transform(test_feature_2d)
        print('After normalization, the shape of training feature:', train_fea_norm1.shape,
              '\nAfter normalization, the shape of test feature:', test_fea_norm1.shape)

        # Back to (trial, time, channel) ...
        train_fea_norm1 = train_fea_norm1.reshape(-1, segment_length, no_feature)
        test_fea_norm1 = test_fea_norm1.reshape(-1, segment_length, no_feature)
        print('After reshape, the shape of training feature:', train_fea_norm1.shape,
              '\nAfter reshape, the shape of test feature:', test_fea_norm1.shape)

        # ... then (trial, channel, time), the layout pyriemann's Covariances consumes.
        train_fea_reshape1 = np.swapaxes(train_fea_norm1, 1, 2)
        test_fea_reshape1 = np.swapaxes(test_fea_norm1, 1, 2)
        print('After expand dims, the shape of training feature:', train_fea_reshape1.shape,
              '\nAfter expand dims, the shape of test feature:', test_fea_reshape1.shape)

        self.data = train_fea_reshape1
        self.targets = train_label.flatten()

        print("data and target type:", type(self.data), type(self.targets))

    def __len__(self):
        """Number of training windows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(window(s), label(s))``; ``idx`` may be an int or a slice."""
        data, target = self.data[idx], self.targets[idx]
        return data, target

    def get_class_weights(self):
        """'Balanced' class weights computed over all windows (train + test).

        Returns:
            ndarray with one weight per class, in ``np.unique`` order of the labels.
        """
        labels = self.data_seg_label[:, 0]
        # `classes` and `y` are keyword-only since scikit-learn 1.0; passing them
        # positionally raises TypeError.
        return class_weight.compute_class_weight(
            'balanced', classes=np.unique(labels), y=labels)



In [None]:
class EEGMMIDTsSet(Data.Dataset):
    """Held-out (test) split of one EEGMMID subject for the two-class fist task.

    Repeats the preprocessing of ``EEGMMIDTrSet``:

    - 8-30 Hz band-pass;
    - labels {2, 4} -> 0 and {3, 5} -> 1;
    - non-overlapping 160-sample windows;
    - stratified ``train_test_split`` with ``random_state=0``.

    It then keeps the test part, standardised with a scaler fitted on the
    training part.

    Attributes:
        data: test windows, shape (n_test, 64, 160) = (trial, electrode, time).
        targets: 1-D array of test labels aligned with ``data``.

    Args:
        subject_id: subject number; selects the ``.npy`` file.
        transform: unused, kept for API compatibility.
    """

    def __init__(self, subject_id, transform=None):
        root_dir = "../Deep-Learning-for-BCI/dataset/"
        raw = np.load(root_dir + str(subject_id) + '.npy')

        # 8-30 Hz band-pass of every column except the trailing label column.
        sampling_rate, band_low, band_high = 160.0, 8.0, 30.0
        n_eeg_cols = raw.shape[1] - 1
        filtered_cols = [
            butter_bandpass_filter(raw[:, col], band_low, band_high, sampling_rate, order=3)
            for col in range(n_eeg_cols)
        ]
        dataset = np.hstack((np.array(filtered_cols).T, raw[:, -1:]))

        # Only labels 2-5 (left/right fist classes, see 1-Data.ipynb) survive.
        dropped_labels = [0, 1, 6, 7, 8, 9, 10]  # [0,1,2,3,4,5,10] for hf
        dataset = dataset[~np.isin(dataset[:, -1], dropped_labels)]

        # Remap to 0/1 for PyTorch: {2, 4} -> 0, {3, 5} -> 1 (label_col is a view).
        label_col = dataset[:, -1]
        zero_mask = np.isin(label_col, [2, 4])
        one_mask = np.isin(label_col, [3, 5])
        label_col[zero_mask] = 0
        label_col[one_mask] = 1

        # Non-overlapping windows (stride equals window length).
        n_chan, win_len = 64, 160
        data_seg = extract(dataset, n_classes=2, n_fea=n_chan,
                           time_window=win_len, moving=win_len)
        print('After segmentation, the shape of the data:', data_seg.shape)

        flat_len = n_chan * win_len
        features = data_seg[:, :flat_len]
        seg_labels = data_seg[:, flat_len:flat_len + 1]
        # Same random_state as EEGMMIDTrSet, so the two datasets are complementary.
        train_feature, test_feature, train_label, test_label = train_test_split(
            features, seg_labels, random_state=0, shuffle=True, stratify=seg_labels)

        _, overall_counts = np.unique(seg_labels, return_counts=True)
        left_perc = overall_counts[0] / overall_counts.sum()
        if left_perc < 0.4 or left_perc > 0.6:
            print("Imbalanced dataset with split of: ", left_perc, 1 - left_perc)
        else:
            print("Classes balanced.")
        uniq, cnt = np.unique(train_label, return_counts=True)
        print("Class label splits in training set \n ", np.asarray((uniq, cnt)).T)
        uniq, cnt = np.unique(test_label, return_counts=True)
        print("Class label splits in test set\n ", np.asarray((uniq, cnt)).T)

        # Per-channel standardisation fitted on the training windows only.
        scaler = StandardScaler().fit(train_feature.reshape(-1, n_chan))
        train_norm = scaler.transform(train_feature.reshape(-1, n_chan))
        test_norm = scaler.transform(test_feature.reshape(-1, n_chan))
        print('After normalization, the shape of training feature:', train_norm.shape,
              '\nAfter normalization, the shape of test feature:', test_norm.shape)

        train_norm = train_norm.reshape(-1, win_len, n_chan)
        test_norm = test_norm.reshape(-1, win_len, n_chan)
        print('After reshape, the shape of training feature:', train_norm.shape,
              '\nAfter reshape, the shape of test feature:', test_norm.shape)

        # (trial, time, channel) -> (trial, channel, time)
        train_ct = train_norm.transpose(0, 2, 1)
        test_ct = test_norm.transpose(0, 2, 1)
        print('After expand dims, the shape of training feature:', train_ct.shape,
              '\nAfter expand dims, the shape of test feature:', test_ct.shape)

        self.data = test_ct
        self.targets = test_label.flatten()

        print("data and target type:", type(self.data), type(self.targets))

    def __len__(self):
        """Number of test windows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(window(s), label(s))``; ``idx`` may be an int or a slice."""
        return self.data[idx], self.targets[idx]


In [None]:
def find_topFR_channels(sub_id, top_channel_dict, nelec=21):
    """Riemannian electrode selection for one subject.

    OAS covariance matrices are computed for the subject's training windows and
    pyriemann's ``ElectrodeSelection`` keeps ``nelec`` electrodes.

    Args:
        sub_id: subject id passed to ``EEGMMIDTrSet``.
        top_channel_dict: dict updated in place with
            ``(sub_id, -1)`` -> ``subelec_`` (selected channel indices) and
            ``(sub_id, 'riem_distance')`` -> ``di_`` of the fitted selector.
        nelec: number of electrodes to keep (defaults to the previous fixed 21).
    """
    train_ds = EEGMMIDTrSet(subject_id=sub_id)
    train_data, labels = train_ds[:]

    # One OAS-shrunk spatial covariance matrix per training window.
    cov_64 = pyriemann.estimation.Covariances('oas').transform(train_data)

    selector = pyriemann.channelselection.ElectrodeSelection(
        nelec=nelec, metric='riemann', n_jobs=1)
    selector = selector.fit(cov_64, y=labels, sample_weight=None)

    # Category -1 marks a class-independent selection (not specific to class 0 or 1).
    top_channel_dict[(sub_id, -1)] = selector.subelec_
    top_channel_dict[(sub_id, 'riem_distance')] = selector.di_
    print("Selected channels:", selector.subelec_)
    print(f"DI list of len {len(selector.di_)}:", selector.di_)

In [None]:
# Quick single-subject check: OAS covariance matrix of every training window of subject 1.
train_ds = EEGMMIDTrSet(subject_id=1)
oas_cov = pyriemann.estimation.Covariances(estimator='oas')
cov_64 = oas_cov.transform(train_ds.data)

In [None]:
# Only the last expression of a cell is displayed, so show both shapes together.
labels = train_ds[:][1]
cov_64.shape, labels.shape

In [None]:
# Riemannian electrode selection for every subject id 1..109; results accumulate in
# top_channel_dict.
top_channel_dict = {}
n_repeats = 1
for s_id in range(1, 109 + 1):
    print("\n --------------------------------------------------- \n")
    print("Starting for subject id:", s_id)
    for itr in range(n_repeats):
        find_topFR_channels(top_channel_dict=top_channel_dict, sub_id=s_id)



In [None]:
# Persist the per-subject channel selections; create the output folder if needed
# (open() does not create missing directories).
os.makedirs("./results/riem", exist_ok=True)
with open("./results/riem/eegmmid_ws_results_topchannels.pkl", "wb") as outfile:
    pickle.dump(top_channel_dict, outfile)

## Read top channel dict and regenerate importance matrix

In [None]:
# Reload the selections saved by the cell above. Use the same file name it writes, so
# Restart & Run All reads this run's results instead of a different (possibly missing) file.
# pickle.load is only safe because this file is produced locally by this notebook.
with open("./results/riem/eegmmid_ws_results_topchannels.pkl", "rb") as infile:
    top_channel_dict = pickle.load(infile)

In [None]:
from collections import Counter, OrderedDict

# Pool the channel selections of the subjects listed below. Keys of top_channel_dict are
# (sub_id, category). Category 0 / 1 would be class-specific lists and -1 is the
# class-independent list written by find_topFR_channels. Other categories, such as
# 'riem_distance', are ignored.
selected_subjects = (7, 12, 22, 42, 43, 48, 49, 53, 70, 80, 82, 85, 94, 102)
left_list, right_list, full_list = [], [], []
list_for_category = {0: left_list, 1: right_list, -1: full_list}

for key, channels in top_channel_dict.items():
    if key[0] not in selected_subjects:
        continue
    target = list_for_category.get(key[1])
    if target is not None:
        target.extend(channels)

# How often each channel index was selected across the pooled subjects.
freq_full = Counter(np.asarray(full_list))
print(freq_full)

In [None]:
# Indices of the 21 most frequently selected channels (Riemannian backward elimination),
# printed in ascending order.
print(sorted(ch for ch, _count in freq_full.most_common(21)))

In [None]:
# (channel index, selection count) pairs of the 21 most frequent channels.
most_common21 = freq_full.most_common(21)
for channel_count in most_common21:
    print(channel_count)

In [None]:
# biosemi64 montage whose channel-name list is re-ordered to the EEGMMID data-column order
# (data column k <-> biosemi64 channel index[k]). Only ch_names is permuted, not positions.
biosemi_montage = mne.channels.make_standard_montage('biosemi64')
index = [8, 9, 10, 46, 45, 44, 43, 13, 12, 11, 47, 48, 49, 50, 16, 17, 18,
         31, 55, 54, 53, 0, 32, 33, 1, 2, 36, 35, 34, 6, 5, 4, 3, 37, 38,
         39, 40, 41, 7, 42, 14, 51, 23, 60, 15, 52, 22, 21, 20, 19, 30, 56,
         57, 58, 59, 24, 25, 29, 62, 61, 26, 28, 63, 27]
biosemi_montage.ch_names = [biosemi_montage.ch_names[i] for i in index]

# Selection count of each top-21 channel, keyed by channel name.
topfr_list = {biosemi_montage.ch_names[ch]: count for ch, count in most_common21}

In [None]:
# Every top-21 channel index must map to a distinct channel name.
assert len(topfr_list) == 21, f"expected 21 channels, got {len(topfr_list)}"
topfr_list

In [None]:
# Rescale the selection counts so the 21 weights sum to 21 (mean weight 1).
unit_weight = 21 / sum(topfr_list.values())
print(unit_weight, sum(topfr_list.values()))

for key in topfr_list:
    topfr_list[key] = topfr_list[key] * unit_weight

print(sum(topfr_list.values()))

# biosemi64 calls these positions P9/P10. The scalp grid (channel_loc_mat.csv) presumably
# labels them T9/T10, so rename to match it. Names that were not selected, or were already
# renamed by an earlier run of this cell, are skipped instead of raising KeyError.
for biosemi_name, grid_name in (('P9', 'T9'), ('P10', 'T10')):
    if biosemi_name in topfr_list:
        topfr_list[grid_name] = topfr_list.pop(biosemi_name)

In [None]:
import pandas as pd
import numpy as np

# Scalp layout grid from a header-less CSV. Its cells are compared against channel names
# further down, so they are presumably channel names.
file_path = 'channel_loc_mat.csv'
data_frame = pd.read_csv(file_path, header=None)
topo_ref = data_frame.to_numpy()
print(topo_ref)

In [None]:
# Scalp grid size; expected to match the shape of topo_ref.
rows, cols = 11, 11

# Start from an empty map and write each channel's weight into the cell(s) named after it.
topo_pred = np.zeros((rows, cols), dtype=float)
for ch_name, ch_weight in topfr_list.items():
    topo_pred = np.where(topo_ref == ch_name, ch_weight, topo_pred)
print(sum(sum(topo_pred)))

In [None]:
# Relevance map of the top-21 channels on the scalp grid; the trailing ';' hides the repr.
fig, ax = plt.subplots()
im = ax.imshow(topo_pred)
fig.colorbar(im, ax=ax, label="normalised selection weight")
ax.set(title="Top-21 channel relevance on the scalp grid",
       xlabel="grid column", ylabel="grid row");

In [None]:
# Reference map on the 11x11 grid: ones on a central 3x7 patch (rows 4-6, columns 2-8),
# zeros elsewhere. It sums to 21, the same total as the rescaled top-21 weights.
topo_true = np.zeros((11, 11))
topo_true[4:7, 2:9] = 1.0

def compare_topomaps(topo_pred, topo_true):
    """Earth mover's (Wasserstein-1) distance between two scalp-grid maps.

    Both maps are 2-D arrays of non-negative weights on the same rows x cols grid.
    Grid cells are placed on [-1, 1] along each axis, and the ground cost is the
    Euclidean distance between cell centres. A constant 1e-5 is added to every cell
    of both maps so that no bin is empty.

    ``ot.emd2`` expects both histograms to carry the same total mass. In this
    notebook ``topo_true`` sums to 21 and the top-21 weights are rescaled to sum to 21.

    Args:
        topo_pred: predicted relevance map, shape (rows, cols).
        topo_true: reference map with the same shape as ``topo_pred``.

    Returns:
        The optimal transport cost returned by ``ot.emd2``.

    Raises:
        ValueError: if the maps are not 2-D arrays of identical shape.
    """
    topo_pred = np.asarray(topo_pred, dtype=np.float64)
    topo_true = np.asarray(topo_true, dtype=np.float64)
    if topo_pred.ndim != 2 or topo_pred.shape != topo_true.shape:
        raise ValueError(
            f"maps must be 2-D and of equal shape, got {topo_pred.shape} and {topo_true.shape}")
    rows, cols = topo_pred.shape

    # Cell centres on [-1, 1] x [-1, 1]; row-major ravel() matches the maps' ravel() below.
    X, Y = np.meshgrid(np.linspace(-1, 1, cols), np.linspace(-1, 1, rows))
    coords = np.column_stack([X.ravel(), Y.ravel()])

    # Pairwise Euclidean distances via |a|^2 + |b|^2 - 2 a.b, clipped at 0 against round-off.
    sq_norms = np.sum(coords ** 2, axis=1)
    M = sq_norms[:, None] + sq_norms[None, :] - 2 * coords @ coords.T
    M = np.sqrt(np.maximum(M, 0))

    # Only (a, b, M) are passed: a 4th positional argument would be emd2's `processes`.
    return ot.emd2(1e-5 + topo_pred.ravel(), 1e-5 + topo_true.ravel(), M)


compare_topomaps(topo_pred, topo_true)

## Train on the top feature-relevance (topFR) channels selected from the chosen subjects

In [None]:
# TODO: accept a list of subject ids so that several subjects can be combined.
# TODO: share the loading/splitting code with EEGMMIDTsSet instead of duplicating it.

class EEGMMIDTrSet(Data.Dataset):
    """Training split of one EEGMMID subject, restricted to the topFR channels.

    Same pipeline as the 64-channel version:

    - 8-30 Hz band-pass;
    - labels {2, 4} -> 0 and {3, 5} -> 1;
    - non-overlapping 160-sample windows;
    - stratified ``train_test_split`` with ``random_state=0``;
    - per-channel standardisation fitted on the training windows.

    After standardisation only the channels in ``TOP_FR_CHANNELS`` are kept.

    Attributes:
        data: training windows, shape (n_train, 21, 160) = (trial, electrode, time).
        targets: 1-D array of training labels (0 or 1) aligned with ``data``.
        data_seg_label: labels of *all* windows (train and test), shape (n, 1).

    Args:
        subject_id: subject number; selects the ``.npy`` file.
        transform: unused, kept for API compatibility.
    """

    # Column indices (EEGMMID data order) of the 21 top feature-relevance channels.
    TOP_FR_CHANNELS = [0, 3, 6, 8, 12, 15, 23, 24, 25, 28, 30, 37, 39, 42, 43, 45, 52, 55, 59, 60, 63]

    def __init__(self, subject_id, transform=None):
        root_dir = "../Deep-Learning-for-BCI/dataset/"
        dataset_raw = np.load(root_dir + str(subject_id) + '.npy')

        # Band-pass every EEG column (all but the trailing label column). 8-30 Hz is the
        # mu/beta range, not gamma.
        fs, lowcut, highcut = 160.0, 8.0, 30.0
        filtered = [butter_bandpass_filter(dataset_raw[:, ch], lowcut, highcut, fs, order=3)
                    for ch in range(dataset_raw.shape[1] - 1)]
        dataset = np.hstack((np.array(filtered).T, dataset_raw[:, -1:]))
        print(dataset.shape)

        # Keep only labels 2-5 (the left/right fist classes described in 1-Data.ipynb).
        removed_label = [0, 1, 6, 7, 8, 9, 10]  # [0,1,2,3,4,5,10] for hf # [0,1,6,7,8,9,10] for lr
        dataset = dataset[~np.isin(dataset[:, -1], removed_label)]

        # PyTorch needs labels 0..n_class-1: {2, 4} -> 0 and {3, 5} -> 1.
        labels = dataset[:, -1]  # a view: the assignments below update `dataset`
        to_zero = np.isin(labels, [2, 4])
        to_one = np.isin(labels, [3, 5])
        labels[to_zero] = 0
        labels[to_one] = 1

        # Windowing. moving == time_window means no overlap, so no sample can end up in
        # both the training and the test split.
        n_class = 2
        no_feature = 64  # EEG channels before selection
        segment_length = 160  # samples per window (1 s at fs = 160 Hz)
        data_seg = extract(dataset, n_classes=n_class, n_fea=no_feature,
                           time_window=segment_length, moving=segment_length)
        print('After segmentation, the shape of the data:', data_seg.shape)

        # Each row holds no_feature * segment_length values followed by the label.
        no_longfeature = no_feature * segment_length
        data_seg_feature = data_seg[:, :no_longfeature]
        self.data_seg_label = data_seg[:, no_longfeature:no_longfeature + 1]

        # random_state must stay equal to EEGMMIDTsSet's so both classes see the same split.
        train_feature, test_feature, train_label, test_label = train_test_split(
            data_seg_feature, self.data_seg_label, random_state=0, shuffle=True,
            stratify=self.data_seg_label)

        # Report class balance over all windows, then per split.
        unique, counts = np.unique(self.data_seg_label, return_counts=True)
        left_perc = counts[0] / counts.sum()
        if left_perc < 0.4 or left_perc > 0.6:
            print("Imbalanced dataset with split of: ", left_perc, 1 - left_perc)
        else:
            print("Classes balanced.")
        unique, counts = np.unique(train_label, return_counts=True)
        print("Class label splits in training set \n ", np.asarray((unique, counts)).T)
        unique, counts = np.unique(test_label, return_counts=True)
        print("Class label splits in test set\n ", np.asarray((unique, counts)).T)

        # Standardise per channel (all 64) with training statistics only.
        # Assumes each flattened window is time-major, i.e. it reshapes to (time, channel).
        train_feature_2d = train_feature.reshape(-1, no_feature)
        test_feature_2d = test_feature.reshape(-1, no_feature)
        scaler1 = StandardScaler().fit(train_feature_2d)
        train_fea_norm1 = scaler1.transform(train_feature_2d)
        test_fea_norm1 = scaler1.transform(test_feature_2d)
        print('After normalization, the shape of training feature:', train_fea_norm1.shape,
              '\nAfter normalization, the shape of test feature:', test_fea_norm1.shape)

        # Keep only the topFR channel columns.
        train_fea_norm1 = train_fea_norm1[:, self.TOP_FR_CHANNELS]
        test_fea_norm1 = test_fea_norm1[:, self.TOP_FR_CHANNELS]
        no_feature = len(self.TOP_FR_CHANNELS)

        # Back to (trial, time, channel) ...
        train_fea_norm1 = train_fea_norm1.reshape(-1, segment_length, no_feature)
        test_fea_norm1 = test_fea_norm1.reshape(-1, segment_length, no_feature)
        print('After reshape, the shape of training feature:', train_fea_norm1.shape,
              '\nAfter reshape, the shape of test feature:', test_fea_norm1.shape)

        # ... then (trial, channel, time), the layout pyriemann's Covariances consumes.
        train_fea_reshape1 = np.swapaxes(train_fea_norm1, 1, 2)
        test_fea_reshape1 = np.swapaxes(test_fea_norm1, 1, 2)
        print('After expand dims, the shape of training feature:', train_fea_reshape1.shape,
              '\nAfter expand dims, the shape of test feature:', test_fea_reshape1.shape)

        self.data = train_fea_reshape1
        self.targets = train_label.flatten()

        print("data and target type:", type(self.data), type(self.targets))

    def __len__(self):
        """Number of training windows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(window(s), label(s))``; ``idx`` may be an int or a slice."""
        data, target = self.data[idx], self.targets[idx]
        return data, target

    def get_class_weights(self):
        """'Balanced' class weights computed over all windows (train + test).

        Returns:
            ndarray with one weight per class, in ``np.unique`` order of the labels.
        """
        labels = self.data_seg_label[:, 0]
        # `classes` and `y` are keyword-only since scikit-learn 1.0; passing them
        # positionally raises TypeError.
        return class_weight.compute_class_weight(
            'balanced', classes=np.unique(labels), y=labels)



In [None]:
class EEGMMIDTsSet(Data.Dataset):
    """Held-out (test) split of one EEGMMID subject, restricted to the topFR channels.

    Repeats the preprocessing of ``EEGMMIDTrSet``:

    - 8-30 Hz band-pass;
    - labels {2, 4} -> 0 and {3, 5} -> 1;
    - non-overlapping 160-sample windows;
    - stratified ``train_test_split`` with ``random_state=0``;
    - standardisation fitted on the training part;
    - selection of 21 channels.

    It keeps the test part.

    Attributes:
        data: test windows, shape (n_test, 21, 160) = (trial, electrode, time).
        targets: 1-D array of test labels aligned with ``data``.

    Args:
        subject_id: subject number; selects the ``.npy`` file.
        transform: unused, kept for API compatibility.
    """

    def __init__(self, subject_id, transform=None):
        root_dir = "../Deep-Learning-for-BCI/dataset/"
        raw = np.load(root_dir + str(subject_id) + '.npy')

        # 8-30 Hz band-pass of every column except the trailing label column.
        sampling_rate, band_low, band_high = 160.0, 8.0, 30.0
        n_eeg_cols = raw.shape[1] - 1
        filtered_cols = [
            butter_bandpass_filter(raw[:, col], band_low, band_high, sampling_rate, order=3)
            for col in range(n_eeg_cols)
        ]
        dataset = np.hstack((np.array(filtered_cols).T, raw[:, -1:]))

        # Only labels 2-5 (left/right fist classes, see 1-Data.ipynb) survive.
        dropped_labels = [0, 1, 6, 7, 8, 9, 10]  # [0,1,2,3,4,5,10] for hf
        dataset = dataset[~np.isin(dataset[:, -1], dropped_labels)]

        # Remap to 0/1 for PyTorch: {2, 4} -> 0, {3, 5} -> 1 (label_col is a view).
        label_col = dataset[:, -1]
        zero_mask = np.isin(label_col, [2, 4])
        one_mask = np.isin(label_col, [3, 5])
        label_col[zero_mask] = 0
        label_col[one_mask] = 1

        # Non-overlapping windows (stride equals window length).
        n_chan, win_len = 64, 160
        data_seg = extract(dataset, n_classes=2, n_fea=n_chan,
                           time_window=win_len, moving=win_len)
        print('After segmentation, the shape of the data:', data_seg.shape)

        flat_len = n_chan * win_len
        features = data_seg[:, :flat_len]
        seg_labels = data_seg[:, flat_len:flat_len + 1]
        # Same random_state as EEGMMIDTrSet, so the two datasets are complementary.
        train_feature, test_feature, train_label, test_label = train_test_split(
            features, seg_labels, random_state=0, shuffle=True, stratify=seg_labels)

        _, overall_counts = np.unique(seg_labels, return_counts=True)
        left_perc = overall_counts[0] / overall_counts.sum()
        if left_perc < 0.4 or left_perc > 0.6:
            print("Imbalanced dataset with split of: ", left_perc, 1 - left_perc)
        else:
            print("Classes balanced.")
        uniq, cnt = np.unique(train_label, return_counts=True)
        print("Class label splits in training set \n ", np.asarray((uniq, cnt)).T)
        uniq, cnt = np.unique(test_label, return_counts=True)
        print("Class label splits in test set\n ", np.asarray((uniq, cnt)).T)

        # Per-channel standardisation (all 64 channels) fitted on the training windows only.
        scaler = StandardScaler().fit(train_feature.reshape(-1, n_chan))
        train_norm = scaler.transform(train_feature.reshape(-1, n_chan))
        test_norm = scaler.transform(test_feature.reshape(-1, n_chan))
        print('After normalization, the shape of training feature:', train_norm.shape,
              '\nAfter normalization, the shape of test feature:', test_norm.shape)

        # Keep only the 21 topFR channel columns (indices in data-column order).
        top_fr_channels = [0, 3, 6, 8, 12, 15, 23, 24, 25, 28, 30, 37, 39, 42, 43, 45, 52, 55, 59, 60, 63]
        train_norm = train_norm[:, top_fr_channels]
        test_norm = test_norm[:, top_fr_channels]
        n_chan = len(top_fr_channels)

        train_norm = train_norm.reshape(-1, win_len, n_chan)
        test_norm = test_norm.reshape(-1, win_len, n_chan)
        print('After reshape, the shape of training feature:', train_norm.shape,
              '\nAfter reshape, the shape of test feature:', test_norm.shape)

        # (trial, time, channel) -> (trial, channel, time)
        train_ct = train_norm.transpose(0, 2, 1)
        test_ct = test_norm.transpose(0, 2, 1)
        print('After expand dims, the shape of training feature:', train_ct.shape,
              '\nAfter expand dims, the shape of test feature:', test_ct.shape)

        self.data = test_ct
        self.targets = test_label.flatten()

        print("data and target type:", type(self.data), type(self.targets))

    def __len__(self):
        """Number of test windows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(window(s), label(s))``; ``idx`` may be an int or a slice."""
        return self.data[idx], self.targets[idx]


In [None]:
def get_class_acc(confusion_matrix, class_id):
    """One-vs-rest sensitivity, specificity and accuracy for one class.

    Args:
        confusion_matrix: square array-like with true classes on the rows and
            predicted classes on the columns (``sklearn.metrics.confusion_matrix``
            convention).
        class_id: row/column index of the class of interest.

    Returns:
        Tuple ``(sensitivity, specificity, accuracy)``. Any ratio whose denominator
        is zero (e.g. the class never occurs) is reported as 0.0 instead of NaN.
    """
    cm = np.asarray(confusion_matrix, dtype=np.float64)
    TP = cm[class_id, class_id]
    FN = cm[class_id].sum() - TP     # true class_id, predicted as another class
    FP = cm[:, class_id].sum() - TP  # other classes predicted as class_id
    TN = cm.sum() - TP - FN - FP
    print("for class id: ", class_id)
    print(f"TP: {TP}, FN: {FN}, FP: {FP}, TN: {TN} ")

    def _ratio(num, den):
        # Guard every ratio the same way, instead of only sensitivity as before.
        return num / den if den != 0 else 0.

    sensitivity = _ratio(TP, TP + FN)
    specificity = _ratio(TN, TN + FP)
    accuracy = _ratio(TP + TN, TP + FP + FN + TN)
    return sensitivity, specificity, accuracy

In [None]:
def train_eegmmid(task_type, strat, sub_id, i="", out_dir="./results/riem_topfr"):
    """Within-subject MDM classification for one EEGMMID subject.

    Steps:

    1. Compute OAS covariance matrices of the training windows and score them
       with 10-fold cross-validation.
    2. Refit an MDM classifier (Riemannian mean and distance) on all training
       windows.
    3. Evaluate it on the held-out split.
    4. Pickle a summary to ``<out_dir>/eegmmid_ws_<strat>_<sub_id>_results<i>.pkl``.

    Args:
        task_type: tag stored in the results, e.g. "within_sub".
        strat: strategy tag, stored and used in the output file name.
        sub_id: subject id forwarded to EEGMMIDTrSet / EEGMMIDTsSet.
        i: iteration tag (string), stored and appended to the file name.
        out_dir: output directory; created if it does not exist.
    """
    start = time.time()
    train_ds = EEGMMIDTrSet(subject_id=sub_id)
    test_ds = EEGMMIDTsSet(subject_id=sub_id)
    train_data, labels = train_ds[:]
    test_data, test_labels = test_ds[:]

    # One OAS-shrunk spatial covariance matrix per window.
    cov_topfr_21 = pyriemann.estimation.Covariances('oas').transform(train_data)
    cov_topfr_21_test = pyriemann.estimation.Covariances('oas').transform(test_data)

    mdm = pyriemann.classification.MDM(metric=dict(mean='riemann', distance='riemann'))

    # 10-fold CV restricted to the training split; the test split stays untouched.
    cv = KFold(n_splits=10, shuffle=True, random_state=42)
    scores = cross_val_score(mdm, cov_topfr_21, labels, cv=cv, n_jobs=1)

    mdm = mdm.fit(cov_topfr_21, labels)
    y_pred = mdm.predict(cov_topfr_21_test)
    # Pin the label set to the training classes so rows/columns 0 and 1 always exist,
    # even when a class is absent from both the test labels and the predictions.
    cm = confusion_matrix(test_labels, y_pred, labels=mdm.classes_)

    acc_0 = get_class_acc(cm, 0)
    acc_1 = get_class_acc(cm, 1)
    acc = accuracy_score(test_labels, y_pred)

    # Chance level = share of the majority class in the (binary) training labels.
    class_balance = np.mean(labels == labels[0])
    class_balance = max(class_balance, 1. - class_balance)
    print("train acc: ", scores)
    print("test acc: ", acc, acc_0, acc_1)
    print("chance level acc: ", class_balance)
    print(i + ".")

    thisresults = [{"task_type": task_type,
                    "strategy": strat,
                    "sub_id": sub_id,
                    "iteration": i,
                    "chance level acc": class_balance,
                    "acc": acc,
                    "acc0": acc_0[0],   # sensitivity of class 0
                    "acc1": acc_1[0]}]  # sensitivity of class 1
    results = [{"task_type": task_type,
                "strategy": strat,
                "sub_id": sub_id,
                "iteration": i,
                "results": thisresults}]
    elapsed = time.time() - start

    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(
        out_dir, "eegmmid_ws_" + strat + "_" + str(sub_id) + "_results" + i + ".pkl")
    with open(out_path, "wb") as outfile:
        pickle.dump(results, outfile)
    print("\t" + str(elapsed) + " seconds")



In [None]:
# Within-subject MDM on the topFR channels for every subject id 1..109, one run each.
n_runs = 1
for s_id in range(1, 109 + 1):
    print("\n --------------------------------------------------- \n")
    print("Starting for subject id:", s_id)
    for itr in range(n_runs):
        train_eegmmid(sub_id=s_id, task_type="within_sub", strat="riem_topfr", i=str(itr))

In [None]:
import mne

# Indices (EEGMMID data-column order) of the 21 top feature-relevance channels.
top_fr_ind = [0, 3, 6, 8, 12, 15, 23, 24, 25, 28, 30, 37, 39, 42, 43, 45, 52, 55, 59, 60, 63]

# biosemi64 montage whose channel-name list is re-ordered to the data-column order
# (data column k <-> biosemi64 channel index[k]). Only ch_names is permuted.
biosemi_montage = mne.channels.make_standard_montage('biosemi64')
index = [8, 9, 10, 46, 45, 44, 43, 13, 12, 11, 47, 48, 49, 50, 16, 17, 18,
         31, 55, 54, 53, 0, 32, 33, 1, 2, 36, 35, 34, 6, 5, 4, 3, 37, 38,
         39, 40, 41, 7, 42, 14, 51, 23, 60, 15, 52, 22, 21, 20, 19, 30, 56,
         57, 58, 59, 24, 25, 29, 62, 61, 26, 28, 63, 27]
biosemi_montage.ch_names = [biosemi_montage.ch_names[k] for k in index]

In [None]:
# Look up the (re-ordered) montage name of every topFR channel index.
top_fr_ch_names = []
for ch_idx in top_fr_ind:
    top_fr_ch_names.append(biosemi_montage.ch_names[ch_idx])

In [None]:
# Names of the 21 topFR channels, in top_fr_ind order.
print(top_fr_ch_names)

In [None]:
import numpy as np

# NOTE(review): `biosemi_layout` is not created anywhere else in this notebook, so a fresh
# kernel raised NameError here. Fall back to MNE's built-in BioSemi layout, and confirm it
# is the layout the original analysis used.
if 'biosemi_layout' not in globals():
    biosemi_layout = mne.channels.read_layout('biosemi')

# Dimensions of the scalp grid.
rows, cols = 11, 11
location_map = np.zeros((rows, cols), dtype=float)

# Layout x/y positions (first two columns of .pos), scaled by 11 and truncated to cells.
sensor_positions = biosemi_layout.pos[:, :2] * 11

# Count how many sensors fall into each grid cell.
for position in sensor_positions:
    row, col = map(int, position)
    location_map[row, col] += 1.0

location_map

In [None]:
# Scaled (x, y) layout positions computed in the previous cell.
sensor_positions

In [None]:
# Integer grid cell (truncated x, truncated y) of every sensor position.
coord = [(int(x), int(y)) for x, y in sensor_positions]

In [None]:
# Layout names of the topFR channels. biosemi_layout is only re-ordered to the data-column
# order in a later cell, so map each data index through `index` (data column k <->
# biosemi64 entry index[k]). This gives the right names when the notebook runs top to bottom.
for ind in [0, 3, 6, 8, 12, 15, 23, 24, 25, 28, 30, 37, 39, 42, 43, 45, 52, 55, 59, 60, 63]:
    print(biosemi_layout.names[index[ind]])

In [None]:
# Re-order the layout to the EEGMMID data-column order (entry k <- biosemi entry index[k]).
# Not idempotent: running this cell twice applies the permutation twice.
biosemi_layout.pos = biosemi_layout.pos[index]
biosemi_layout.names = [biosemi_layout.names[k] for k in index]

In [None]:
# Mark the 11x11 grid cells that hold the 21 topFR sensors (the layout was re-ordered to
# data-column order by the previous cell).
rows, cols = 11, 11
location_map = np.zeros((rows, cols), dtype=float)

topfr_rows = [0, 3, 6, 8, 12, 15, 23, 24, 25, 28, 30, 37, 39, 42, 43, 45, 52, 55, 59, 60, 63]
sensor_positions = 11 * biosemi_layout.pos[topfr_rows, :2]

for x_pos, y_pos in sensor_positions:
    location_map[int(x_pos), int(y_pos)] = 1.0

# location_map is indexed [x, y]; transpose so that rows follow y when displayed.
location_map.T


In [None]:
# Same grid, marking the first 21 entries of the (re-ordered) layout.
rows, cols = 11, 11
location_map = np.zeros((rows, cols), dtype=float)

sensor_positions = 11 * biosemi_layout.pos[:21, :2]

for x_pos, y_pos in sensor_positions:
    location_map[int(x_pos), int(y_pos)] = 1.0

# location_map is indexed [x, y]; transpose so that rows follow y when displayed.
location_map.T

In [None]:
# Nearest-integer grid coordinates (for comparison with the truncation used above).
np.round(sensor_positions)