# AIM: Extract features on labeled data using the pretrained EEGNet

In [1]:
import numpy as np
import pandas as pd
import mne
import lightning.pytorch as pl
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchmetrics import F1Score, Accuracy
import random
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, f1_score
%matplotlib inline


# prevent extensive logging
mne.set_log_level('WARNING')

## Loading epoch data & participant data of labeled sample
These dataframes have been filtered and stored in a previous project. See https://github.com/TSmolders/Internship_EEG for original code

In [2]:

df_participants = pd.read_pickle(r'D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TDBRAIN_participants_V2_data\df_participants.pkl')
sample_df = pd.read_pickle(r'D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TD-BRAIN_extracted_features\df_selected_stat_features.pkl')
sample_ids = sample_df['ID'].unique() # obtain unique IDs from subsampled dataframe containing epoched features
df_sample = df_participants[df_participants['participants_ID'].isin(sample_ids)] # filter participants dataframe to only include subsampled IDs
df_sample = df_sample[df_sample['sessID'] == 1] # filter first session
print(df_sample.shape)
print(df_sample['diagnosis'].value_counts())



(225, 12)
diagnosis
ADHD       45
HEALTHY    45
MDD        45
OCD        45
SMC        45
Name: count, dtype: int64


In [3]:
# functions for loading the epoched EEG data
def get_filepath(epoch_dir, participant_ids):
    """
    Function to get the filepath of the epoched EEG recording
    :param epoch_dir: directory containing the epoched EEG recordings
    :param ID: list of participant IDs to include
    """
    filepaths = []
    for subdir, dirs, files in os.walk(epoch_dir):
        for file in files:
            if any(participant_id in file for participant_id in participant_ids):
                filepaths.append(os.path.join(subdir, file))
    return filepaths

class EpochDataset(torch.utils.data.Dataset):
    def __init__(self, participant_ids, epoch_dir):
        self.filepaths = get_filepath(epoch_dir, participant_ids)
        self.participant_ids = participant_ids
        self.epochs = []
        self.participant_ids = []
        self._load_data()
        print(f"Number of epochs: {self.epochs.shape[0]}")
        print(f"Number of participants: {len(self.participant_ids)}")

    def _load_data(self):
        all_epochs = []
        for filepath in self.filepaths:
            epochs = torch.load(filepath)
            # get participant ID from filepath to make sure the participant ID is correct
            participant_id = filepath.split("\\")[-1].split(".")[0]
            all_epochs.append(epochs)
            self.participant_ids.extend([participant_id]*epochs.shape[0])
        self.epochs = np.concatenate(all_epochs, axis=0)

    def __len__(self):
        return self.epochs.shape[0]
    
    def __getitem__(self, idx):
        epoch = self.epochs[idx]
        participant_id = self.participant_ids[idx]
        return torch.tensor(epoch, dtype=torch.float32), participant_id

In [4]:
# load the epochs into a dataset
participant_ids = df_sample['participants_ID'].tolist()
dataset = EpochDataset(participant_ids, r"D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TDBRAIN-dataset-derivatives\thesis_epoched_data\EC")
print(len(dataset))
print(dataset[0][0].shape)
print(dataset[0][1])
print(dataset[1][1])

Number of epochs: 2688
Number of participants: 2688
2688
torch.Size([26, 1244])
sub-87964717
sub-87964717


## EEGNet architecture

In [5]:
# create Conv2d with max norm constraint
class Conv2dWithConstraint(nn.Conv2d):
    def __init__(self, *args, max_norm: int = 1, **kwargs):
        self.max_norm = max_norm
        super(Conv2dWithConstraint, self).__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.weight.data = torch.renorm(self.weight.data, p=2, dim=0, maxnorm=self.max_norm)
        return super(Conv2dWithConstraint, self).forward(x)
    
class EEGNet(nn.Module):
    """
    Code taken and adjusted from pytorch implementation of EEGNet
    url: https://github.com/torcheeg/torcheeg/blob/v1.1.0/torcheeg/models/cnn/eegnet.py#L5
    """
    def __init__(self,
                 chunk_size: int = 1244, # number of data points in each EEG chunk
                 num_electrodes: int = 26, # number of EEG electrodes
                 F1: int = 8, # number of filters in first convolutional layer
                 F2: int = 16, # number of filters in second convolutional layer
                 D: int = 2, # depth multiplier
                 num_extracted_features: int = 100, # number of features to extract
                 kernel_1: int = 64, # the filter size of block 1 (half of sfreq (125 Hz))
                 kernel_2: int = 16, # the filter size of block 2 (one eight of sfreq (500 Hz))
                 dropout: float = 0.25): # dropout rate
        super(EEGNet, self).__init__()
        self.F1 = F1
        self.F2 = F2
        self.D = D
        self.chunk_size = chunk_size
        self.num_extracted_features = num_extracted_features
        self.num_electrodes = num_electrodes
        self.kernel_1 = kernel_1
        self.kernel_2 = kernel_2
        self.dropout = dropout

        self.block1 = nn.Sequential(
            nn.Conv2d(1, self.F1, (1, self.kernel_1), stride=1, padding=(0, self.kernel_1 // 2), bias=False),
            nn.BatchNorm2d(self.F1, momentum=0.01, affine=True, eps=1e-3),
            Conv2dWithConstraint(self.F1,
                                 self.F1 * self.D, (self.num_electrodes, 1),
                                 max_norm=1,
                                 stride=1,
                                 padding=(0, 0),
                                 groups=self.F1,
                                 bias=False), nn.BatchNorm2d(self.F1 * self.D, momentum=0.01, affine=True, eps=1e-3),
            nn.ELU(), nn.AvgPool2d((1, 4), stride=4), nn.Dropout(p=dropout))

        self.block2 = nn.Sequential(
            nn.Conv2d(self.F1 * self.D,
                      self.F1 * self.D, (1, self.kernel_2),
                      stride=1,
                      padding=(0, self.kernel_2 // 2),
                      bias=False,
                      groups=self.F1 * self.D),
            nn.Conv2d(self.F1 * self.D, self.F2, 1, padding=(0, 0), groups=1, bias=False, stride=1),
            nn.BatchNorm2d(self.F2, momentum=0.01, affine=True, eps=1e-3), nn.ELU(), nn.AvgPool2d((1, 8), stride=8),
            nn.Dropout(p=dropout))

        self.lin = nn.Linear(self.feature_dim(), num_extracted_features, bias=False)


    def feature_dim(self):
        # function to calculate the number of features after the convolutional blocks
        with torch.no_grad():
            mock_eeg = torch.zeros(1, 1, self.num_electrodes, self.chunk_size)

            mock_eeg = self.block1(mock_eeg)
            mock_eeg = self.block2(mock_eeg)

        return self.F2 * mock_eeg.shape[3]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block1(x)
        x = self.block2(x)
        x = x.flatten(start_dim=1)
        x = self.lin(x)
        return x

## Transfering pretrained weights & extracting features

### Functions:

In [26]:
def transfer_weights(pretrained_weights, pretext_model=EEGNet()):
    """
    Function to transfer the pretrained weights to the pretext model
    param: pretrained_weights: the weights to transfer in a dictionary
    param: pretext_model: the model to transfer the weights to
    """
    pretrained_model = pretext_model
    modified_keys = {}
    for k, v in pretrained_weights.items():
        decomposed_key = k.split('.')
        if decomposed_key[0] == 'EEGNet': # remove the first part of each key to match the model's keys
            pretrained_key = '.'.join(decomposed_key[1:])
            modified_keys[pretrained_key] = v

            
    pretrained_model.load_state_dict(modified_keys)

    # assert that the weights are transferred correctly
    assert torch.equal(pretrained_model.block1[0].weight, pretrained_weights['EEGNet.block1.0.weight'])
    
    return pretrained_model

def extract_features(pretrained_model, data, pretext_task, df_sample, to_disk=False):
    """
    Function to extract features from the pretrained model
    param: pretrained_model: the model to extract features from
    param: data: the dataset containing the epochs to extract features from
    param: pretext_task: a string indicating the specific pretext task to save the features as
    param: df_sample: the dataframe containing the sampled participant IDs and their corresponding diagnosis
    param: to_disk: boolean to save the features to disk
    """
    dataloader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)
    pretrained_model.eval()
    features_list = []
    participant_ids = []
    with torch.no_grad():  # Disable gradient calculation
        for batch in dataloader:
            epoch, participant_id = batch  # Remove the batch dimension
            epoch = epoch.unsqueeze(0)  # Add dimension
            # print(epoch.shape)
            features = pretrained_model(epoch)  # Extract features
            features = features.squeeze(0)
            features = features.numpy()
            features_list.append(features)
            participant_ids.append(participant_id[0])

    
    features_df = pd.DataFrame(features_list) # store as dataframe
    features_df['ID'] = participant_ids # add participant IDs to the dataframe
    # map the diagnosis values from df_sample to the dataframe based on participant IDs
    features_df['diagnosis'] = features_df['ID'].map(df_sample.set_index('participants_ID')['diagnosis'])
    
    print(f'{features_df.shape = }')
    print(f'{features_df.head(3) = }')

    if to_disk:
        features_df.to_pickle(f'D:/Documents/Master_Data_Science/Thesis/thesis_code/DataScience_Thesis/data/SSL_features/df_{pretext_task}_features.pkl')
    
    return features_df

def evaluate_features(features_df):
    """
    Function to quickly evaluate the extracted features. Doesn't stratify/group data splitting! 
    """
    X = features_df.drop(['ID', 'diagnosis'], axis=1)
    y = features_df['diagnosis']

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

    # quick SVM model
    clf = SVC()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print('quick SVM model')
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print(f1_score(y_test, y_pred, average='macro'))
    print()

    # quick random forest model
    clf = RandomForestClassifier()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print('quick random forest model')
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print(f1_score(y_test, y_pred, average='macro'))

    return

def get_ssl_features(
        pretext_task,
        data,
        df_sample,
        pretext_model=EEGNet(),
        eval=True,
        to_disk=True,
        weights_dir=r'D:\Documents\Master_Data_Science\Thesis\thesis_code\DataScience_Thesis\data\pretext_model_weights'
                     ):
    """
    Obtains SSL features from the weights trained by the pretext model.
    param: pretext_task: a string indicating the specific pretext task to load the weights from and save the features as
    param: data: the dataset containing the epochs to extract features from
    param: df_sample: the dataframe containing the sampled participant IDs and their corresponding diagnosis
    param: pretext_model: the pretext model to transfer the weights to
    param: eval: boolean to evaluate the features
    param: to_disk: boolean to save the features to disk
    """
    pretrained_weights = torch.load(f'{weights_dir}\{pretext_task}_weights.pt')
    pretrained_model = transfer_weights(pretrained_weights, pretext_model)
    features_df = extract_features(pretrained_model, data, pretext_task, df_sample=df_sample, to_disk=to_disk)
    if eval:
        evaluate_features(features_df)
    
    return

### randomly initialized model

In [28]:
features_df = extract_features(EEGNet(), dataset, df_sample=df_sample, pretext_task='random', to_disk=False)
evaluate_features(features_df)

features_df.shape = (2688, 102)
features_df.head(3) =           0         1         2         3         4         5         6  \
0 -0.402753 -0.018044  0.215058 -0.077406 -0.205178  0.171975 -0.167967   
1 -0.229062  0.036969  0.568885 -0.129756 -0.021982  0.176611 -0.151093   
2 -0.266612 -0.267169  0.270881  0.135174 -0.010860  0.258197 -0.105652   

          7         8         9  ...        92        93        94        95  \
0 -0.120625  0.146552 -0.416287  ...  0.170753  0.015566 -0.161969  0.324737   
1  0.002237 -0.212080 -0.567819  ...  0.202544  0.197015 -0.137225  0.456216   
2  0.010338 -0.002674 -0.073011  ...  0.279790  0.164753 -0.070322 -0.039753   

        96        97        98        99            ID  diagnosis  
0  0.22090 -0.109117  0.144687 -0.142254  sub-87964717        SMC  
1  0.20367 -0.121167 -0.081884  0.050865  sub-87964717        SMC  
2  0.13927 -0.221297  0.002260 -0.182580  sub-87964717        SMC  

[3 rows x 102 columns]
quick SVM model
[[47 11  9 2

### pretext model with default parameters (0.25 dropout, 0 weight decay, binary cross entropy loss)

In [18]:
# best model checkpoint
get_ssl_features('standard_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)
features_df.head(3) =           0         1         2         3         4         5         6  \
0 -1.893295 -2.149688 -0.291691  0.780488 -1.615547  1.665748  0.461150   
1 -2.980256 -0.838663  0.650124 -0.476067 -1.397269 -0.044820  1.434734   
2  0.775971  0.236548  1.128384  0.707310  0.376685 -1.288042 -0.227033   

          7         8         9  ...        92        93        94        95  \
0 -0.837854  0.562078  1.077490  ... -0.341244  0.210595  1.314447  2.179167   
1 -0.210545  1.121377 -1.097450  ... -2.101032 -1.259048  1.215847 -1.073825   
2  0.380525 -0.230556  1.754263  ...  0.305352  0.125076  1.234537 -1.021486   

         96        97        98        99            ID  diagnosis  
0 -2.016901  0.171110 -0.861355 -0.053142  sub-87964717        SMC  
1 -0.081516  0.933414 -1.175488  1.451200  sub-87964717        SMC  
2 -0.300199 -0.030796  1.932677  1.266169  sub-87964717        SMC  

[3 rows x 102 columns]
quick SVM model
[[56  8 

In [19]:
# overtrained model
get_ssl_features('overtrained_standard_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)
features_df.head(3) =           0         1         2         3         4         5         6  \
0 -7.071573 -5.241930 -5.575650 -0.064017 -5.057350  2.548192 -1.623831   
1 -3.345401 -1.141400  1.936891  0.652492 -5.435171 -0.186471  2.487033   
2 -1.183420 -0.480357  3.380833  1.488391  2.152638 -4.007096  1.255239   

          7         8         9  ...        92        93        94        95  \
0  1.098861 -0.824495 -3.680593  ... -5.903049 -3.027779  2.351406  3.764353   
1 -1.902161  2.025343 -1.199311  ... -4.263688 -1.011250  7.327817 -4.517756   
2  1.056725 -2.077426  4.147937  ... -1.247893  4.364254  0.495155  2.827570   

         96        97        98        99            ID  diagnosis  
0 -4.891006  0.856988  0.988189 -0.587294  sub-87964717        SMC  
1  2.113866  2.032950 -0.433161  4.070028  sub-87964717        SMC  
2 -0.980230  0.179054  2.691066  3.877378  sub-87964717        SMC  

[3 rows x 102 columns]
quick SVM model
[[57  8 

### pretext model with (0.5 dropout, 0.01 weight decay, binary cross entropy loss) trained for 300 epochs

In [20]:
# best model checkpoint
get_ssl_features('pretext_0.5dropout_0.01wd_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)
features_df.head(3) =           0             1             2             3             4  \
0 -0.027774  2.082330e-42  1.740413e-42  3.405155e-43 -1.926785e-42   
1  0.027152  1.566652e-42  1.395693e-42 -1.680157e-42 -6.558077e-43   
2  0.167967  1.168683e-42  1.795063e-42  9.528830e-44  6.011570e-43   

              5             6             7             8             9  ...  \
0 -2.895083e-42 -1.757228e-42 -2.490107e-42 -2.310741e-42  7.707142e-44  ...   
1 -4.049753e-43 -4.694350e-43 -9.696985e-43  1.443337e-42  2.564376e-43  ...   
2 -7.160635e-43  1.290596e-42 -8.379765e-43 -1.631111e-42  1.936594e-42  ...   

             92        93            94            95            96        97  \
0 -1.894556e-42 -0.130888  7.959375e-43  1.589072e-42 -1.877740e-43  0.009687   
1 -9.935206e-43  0.480544 -1.681558e-44  1.492383e-42 -5.871441e-43 -0.266607   
2  3.068844e-43 -0.167957  4.049753e-43 -8.071479e-43  1.233143e-42  0.005749   

             98

In [21]:
# overtrained model
get_ssl_features('overtrained_pretext_0.5dropout_0.01wd_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)
features_df.head(3) =               0             1             2             3             4  \
0  2.306537e-42  1.893154e-42  1.460153e-42  2.101948e-43 -1.455949e-42   
1  4.960597e-42  1.160275e-42  7.496947e-43 -1.364865e-42 -4.694350e-43   
2 -2.186026e-43  1.090210e-42  1.122440e-42  1.079000e-43  2.522337e-43   

              5             6             7             8             9  ...  \
0 -1.771241e-42 -1.150466e-42 -1.677354e-42 -1.862326e-42 -1.093013e-43  ...   
1 -3.124896e-43 -2.928714e-43 -4.273960e-43  1.193906e-42  2.522337e-43  ...   
2 -3.615350e-43  8.646012e-43 -4.722376e-43 -1.221932e-42  1.279385e-42  ...   

             92        93            94            95            96  \
0 -1.554040e-42  0.059893  5.254869e-43  7.314778e-43  1.219130e-43   
1 -7.931349e-43  0.519542  4.484155e-44  8.267661e-43 -1.008935e-42   
2 -1.149065e-43 -0.100525  2.143987e-43 -8.870219e-43  6.221765e-43   

             97            98          

### pretext model with tpos=2 (0.5 dropout, 0.01 weight decay, binary cross entropy loss, 1e-5 lr) trained for 300 epochs

In [22]:
# best model checkpoint
get_ssl_features('pretext_tpos2_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)
features_df.head(3) =           0         1         2         3         4         5         6  \
0  0.044288  0.077095  0.041208  0.069365  0.005585  0.086525 -0.039408   
1  0.038830 -0.073380 -0.002620  0.041997  0.026605  0.224603  0.007724   
2  0.007056 -0.006016  0.013439 -0.068288  0.018402  0.035598  0.056252   

          7             8         9  ...        92        93        94  \
0 -0.099715 -2.858649e-43  0.031967  ...  0.011712 -0.130427  0.111504   
1 -0.003489 -8.421804e-43  0.136855  ... -0.012818 -0.036208  0.050513   
2 -0.010388  3.744269e-42 -0.003925  ... -0.053303  0.017775  0.037817   

         95        96        97        98        99            ID  diagnosis  
0  0.042759  0.034478  0.010688 -0.070165  0.023789  sub-87964717        SMC  
1  0.145209  0.123046 -0.009172  0.140408 -0.051043  sub-87964717        SMC  
2  0.030690  0.103314  0.061628 -0.052314 -0.033971  sub-87964717        SMC  

[3 rows x 102 columns]
quick SV

In [23]:
# overtrained model
get_ssl_features('overtrained_pretext_tpos2_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)
features_df.head(3) =           0         1         2         3         4         5         6  \
0  0.046656  0.084502  0.040037  0.069342  0.004450  0.089280 -0.038479   
1  0.044932 -0.074159 -0.003874  0.043264  0.026199  0.228104  0.010726   
2  0.008904 -0.003818  0.019000 -0.069606  0.023631  0.033476  0.056637   

          7             8         9  ...        92        93        94  \
0 -0.099987 -2.073922e-43  0.033823  ...  0.009429 -0.126348  0.113319   
1 -0.005795 -8.716076e-43  0.138301  ... -0.017095 -0.035397  0.051016   
2 -0.007491  3.905419e-42 -0.002699  ... -0.059963  0.021154  0.036736   

         95        96        97        98        99            ID  diagnosis  
0  0.044803  0.046632  0.007344 -0.075668  0.017424  sub-87964717        SMC  
1  0.144371  0.127535 -0.006805  0.139712 -0.054080  sub-87964717        SMC  
2  0.038853  0.110679  0.063725 -0.055768 -0.033633  sub-87964717        SMC  

[3 rows x 102 columns]
quick SV

### pretext model with soft margin loss (0.5 dropout, 0.01 weight decay, 1e-5 lr) trained for 300 epochs

In [24]:
# best model checkpoint
get_ssl_features('pretext_soft_margin_loss_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)
features_df.head(3) =           0         1         2         3         4         5         6  \
0 -0.232964  0.055680  0.150263  0.086725 -0.071166 -0.041175 -0.086144   
1 -0.001101  0.135006 -0.079469  0.213518  0.022893 -0.341625 -0.085815   
2  0.086274 -0.131936 -0.181186 -0.051774  0.172969 -0.060890  0.193879   

          7         8         9  ...        92        93        94        95  \
0  0.057028 -0.049891 -0.041486  ... -0.023152 -0.026366  0.016942 -0.235924   
1 -0.101287  0.041595 -0.148309  ... -0.175478  0.248597 -0.118389 -0.129514   
2 -0.129158  0.087837 -0.108324  ...  0.121672  0.093388 -0.112233  0.133843   

         96        97        98        99            ID  diagnosis  
0 -0.006213 -0.035274 -0.011762  0.162296  sub-87964717        SMC  
1  0.289626 -0.236137 -0.200228  0.025403  sub-87964717        SMC  
2 -0.015941  0.035154  0.137667 -0.192194  sub-87964717        SMC  

[3 rows x 102 columns]
quick SVM model
[[46 11 

In [25]:
# overtrained model
get_ssl_features('overtrained_pretext_soft_margin_loss_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)
features_df.head(3) =           0         1         2         3         4         5         6  \
0 -0.226694  0.028478  0.151133  0.126223 -0.046721 -0.056416 -0.037307   
1 -0.016170  0.127135 -0.084554  0.239610  0.023522 -0.379305 -0.063542   
2  0.097132 -0.129453 -0.149846 -0.015699  0.162194 -0.032629  0.234710   

          7         8         9  ...        92        93        94        95  \
0  0.051133 -0.033549 -0.039247  ... -0.017119 -0.038116  0.014537 -0.229170   
1 -0.048511  0.026792 -0.103350  ... -0.175404  0.206910 -0.124889 -0.155380   
2 -0.124972  0.091944 -0.110038  ...  0.096640  0.126587 -0.122890  0.141575   

         96        97        98        99            ID  diagnosis  
0  0.026684 -0.021335 -0.008719  0.132885  sub-87964717        SMC  
1  0.289495 -0.240159 -0.205877  0.003637  sub-87964717        SMC  
2 -0.012638  0.010316  0.109036 -0.197871  sub-87964717        SMC  

[3 rows x 102 columns]
quick SVM model
[[43 14 