# AIM: Extract features on labeled data using the pretrained EEGNet
for demeaned epochs

In [1]:
import numpy as np
import pandas as pd
import mne
import lightning.pytorch as pl
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchmetrics import F1Score, Accuracy
import random
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from IPython.display import display
%matplotlib inline


# prevent extensive logging
mne.set_log_level('WARNING')

## Loading epoch data & participant data of labeled sample
These dataframes have been filtered and stored in a previous project. See https://github.com/TSmolders/Internship_EEG for original code

In [2]:

df_participants = pd.read_pickle(r'D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TDBRAIN_participants_V2_data\df_participants.pkl')
sample_df = pd.read_pickle(r'D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TD-BRAIN_extracted_features\df_selected_stat_features.pkl')
sample_ids = sample_df['ID'].unique() # obtain unique IDs from subsampled dataframe containing epoched features
df_sample = df_participants[df_participants['participants_ID'].isin(sample_ids)] # filter participants dataframe to only include subsampled IDs
df_sample = df_sample[df_sample['sessID'] == 1] # filter first session
print(df_sample.shape)
print(df_sample['diagnosis'].value_counts())



(225, 12)
diagnosis
ADHD       45
HEALTHY    45
MDD        45
OCD        45
SMC        45
Name: count, dtype: int64


In [3]:
# functions for loading the epoched EEG data
def get_filepath(epoch_dir, participant_ids):
    """
    Function to get the filepath of the epoched EEG recording
    :param epoch_dir: directory containing the epoched EEG recordings
    :param ID: list of participant IDs to include
    """
    filepaths = []
    for subdir, dirs, files in os.walk(epoch_dir):
        for file in files:
            if any(participant_id in file for participant_id in participant_ids):
                filepaths.append(os.path.join(subdir, file))
    return filepaths

class EpochDataset(torch.utils.data.Dataset):
    def __init__(self, participant_ids, epoch_dir):
        self.filepaths = get_filepath(epoch_dir, participant_ids)
        self.participant_ids = participant_ids
        self.epochs = []
        self.participant_ids = []
        self._load_data()
        print(f"Number of epochs: {self.epochs.shape[0]}")
        print(f"Number of participants: {len(self.participant_ids)}")

    def _load_data(self):
        all_epochs = []
        for filepath in self.filepaths:
            epochs = torch.load(filepath)
            # get participant ID from filepath to make sure the participant ID is correct
            participant_id = filepath.split("\\")[-1].split(".")[0]
            all_epochs.append(epochs)
            self.participant_ids.extend([participant_id]*epochs.shape[0])
        self.epochs = np.concatenate(all_epochs, axis=0)

    def __len__(self):
        return self.epochs.shape[0]
    
    def __getitem__(self, idx):
        epoch = self.epochs[idx]
        participant_id = self.participant_ids[idx]
        return torch.tensor(epoch, dtype=torch.float32), participant_id

In [4]:
# load the epochs into a dataset
participant_ids = df_sample['participants_ID'].tolist()
dataset = EpochDataset(participant_ids, r"D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TDBRAIN-dataset-derivatives\thesis_epoched_data\EC")
print(len(dataset))
print(dataset[0][0].shape)
print(dataset[0][1])
print(dataset[1][1])

Number of epochs: 2688
Number of participants: 2688
2688
torch.Size([26, 1244])
sub-87964717
sub-87964717


## EEGNet architecture

In [5]:
# create Conv2d with max norm constraint
class Conv2dWithConstraint(nn.Conv2d):
    def __init__(self, *args, max_norm: int = 1, **kwargs):
        self.max_norm = max_norm
        super(Conv2dWithConstraint, self).__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.weight.data = torch.renorm(self.weight.data, p=2, dim=0, maxnorm=self.max_norm)
        return super(Conv2dWithConstraint, self).forward(x)
    
class EEGNet(nn.Module):
    """
    Code taken and adjusted from pytorch implementation of EEGNet
    url: https://github.com/torcheeg/torcheeg/blob/v1.1.0/torcheeg/models/cnn/eegnet.py#L5
    """
    def __init__(self,
                 chunk_size: int = 1244, # number of data points in each EEG chunk
                 num_electrodes: int = 26, # number of EEG electrodes
                 F1: int = 8, # number of filters in first convolutional layer
                 F2: int = 16, # number of filters in second convolutional layer
                 D: int = 2, # depth multiplier
                 num_extracted_features: int = 100, # number of features to extract
                 kernel_1: int = 64, # the filter size of block 1 (half of sfreq (125 Hz))
                 kernel_2: int = 16, # the filter size of block 2 (one eight of sfreq (500 Hz))
                 dropout: float = 0.25): # dropout rate
        super(EEGNet, self).__init__()
        self.F1 = F1
        self.F2 = F2
        self.D = D
        self.chunk_size = chunk_size
        self.num_extracted_features = num_extracted_features
        self.num_electrodes = num_electrodes
        self.kernel_1 = kernel_1
        self.kernel_2 = kernel_2
        self.dropout = dropout

        self.block1 = nn.Sequential(
            nn.Conv2d(1, self.F1, (1, self.kernel_1), stride=1, padding=(0, self.kernel_1 // 2), bias=False),
            nn.BatchNorm2d(self.F1, momentum=0.01, affine=True, eps=1e-3),
            Conv2dWithConstraint(self.F1,
                                 self.F1 * self.D, (self.num_electrodes, 1),
                                 max_norm=1,
                                 stride=1,
                                 padding=(0, 0),
                                 groups=self.F1,
                                 bias=False), nn.BatchNorm2d(self.F1 * self.D, momentum=0.01, affine=True, eps=1e-3),
            nn.ELU(), nn.AvgPool2d((1, 4), stride=4), nn.Dropout(p=dropout))

        self.block2 = nn.Sequential(
            nn.Conv2d(self.F1 * self.D,
                      self.F1 * self.D, (1, self.kernel_2),
                      stride=1,
                      padding=(0, self.kernel_2 // 2),
                      bias=False,
                      groups=self.F1 * self.D),
            nn.Conv2d(self.F1 * self.D, self.F2, 1, padding=(0, 0), groups=1, bias=False, stride=1),
            nn.BatchNorm2d(self.F2, momentum=0.01, affine=True, eps=1e-3), nn.ELU(), nn.AvgPool2d((1, 8), stride=8),
            nn.Dropout(p=dropout))

        self.lin = nn.Linear(self.feature_dim(), num_extracted_features, bias=False)


    def feature_dim(self):
        # function to calculate the number of features after the convolutional blocks
        with torch.no_grad():
            mock_eeg = torch.zeros(1, 1, self.num_electrodes, self.chunk_size)

            mock_eeg = self.block1(mock_eeg)
            mock_eeg = self.block2(mock_eeg)

        return self.F2 * mock_eeg.shape[3]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block1(x)
        x = self.block2(x)
        x = x.flatten(start_dim=1)
        x = self.lin(x)
        return x

## Transfering pretrained weights & extracting features

### Functions:

In [31]:
def transfer_weights(pretrained_weights, pretext_model=EEGNet()):
    """
    Function to transfer the pretrained weights to the pretext model
    param: pretrained_weights: the weights to transfer in a dictionary
    param: pretext_model: the model to transfer the weights to
    """
    pretrained_model = pretext_model
    modified_keys = {}
    for k, v in pretrained_weights.items():
        decomposed_key = k.split('.')
        if decomposed_key[0] == 'EEGNet': # remove the first part of each key to match the model's keys
            pretrained_key = '.'.join(decomposed_key[1:])
            modified_keys[pretrained_key] = v

            
    pretrained_model.load_state_dict(modified_keys)
        
    return pretrained_model

def extract_features(pretrained_model, data, pretext_task, df_sample, to_disk=False):
    """
    Function to extract features from the pretrained model
    param: pretrained_model: the model to extract features from
    param: data: the dataset containing the epochs to extract features from
    param: pretext_task: a string indicating the specific pretext task to save the features as
    param: df_sample: the dataframe containing the sampled participant IDs and their corresponding diagnosis
    param: to_disk: boolean to save the features to disk
    """
    dataloader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)
    pretrained_model.eval()
    features_list = []
    participant_ids = []
    with torch.no_grad():  # Disable gradient calculation
        for batch in dataloader:
            epoch, participant_id = batch  # Remove the batch dimension
            epoch = epoch.unsqueeze(0)  # Add dimension
            # print(epoch.shape)
            features = pretrained_model(epoch)  # Extract features
            features = features.squeeze(0)
            features = features.numpy()
            features_list.append(features)
            participant_ids.append(participant_id[0])

    
    features_df = pd.DataFrame(features_list) # store as dataframe
    features_df['ID'] = participant_ids # add participant IDs to the dataframe
    # map the diagnosis values from df_sample to the dataframe based on participant IDs
    features_df['diagnosis'] = features_df['ID'].map(df_sample.set_index('participants_ID')['diagnosis'])
    
    print(f'{features_df.shape = }')
    display(features_df.head(3))


    if to_disk:
        features_df.to_pickle(f'D:/Documents/Master_Data_Science/Thesis/thesis_code/DataScience_Thesis/data/SSL_features/df_{pretext_task}_features.pkl')
    
    return features_df

def evaluate_features(features_df):
    """
    Function to quickly evaluate the extracted features. Doesn't stratify/group data splitting! 
    """
    X = features_df.drop(['ID', 'diagnosis'], axis=1)
    y = features_df['diagnosis']

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

    # quick SVM model
    clf = SVC()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print('quick SVM model')
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print(f1_score(y_test, y_pred, average='macro'))
    print()

    # quick random forest model
    clf = RandomForestClassifier()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print('quick random forest model')
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print(f1_score(y_test, y_pred, average='macro'))

    return

def get_ssl_features(
        pretext_task,
        data,
        df_sample,
        pretext_model=EEGNet(),
        eval=True,
        to_disk=True,
        weights_dir=r'D:\Documents\Master_Data_Science\Thesis\thesis_code\DataScience_Thesis\data\pretext_model_weights'
                     ):
    """
    Obtains SSL features from the weights trained by the pretext model.
    param: pretext_task: a string indicating the specific pretext task to load the weights from and save the features as
    param: data: the dataset containing the epochs to extract features from
    param: df_sample: the dataframe containing the sampled participant IDs and their corresponding diagnosis
    param: pretext_model: the pretext model to transfer the weights to
    param: eval: boolean to evaluate the features
    param: to_disk: boolean to save the features to disk
    """
    pretrained_weights = torch.load(f'{weights_dir}\{pretext_task}_weights.pt')
    pretrained_model = transfer_weights(pretrained_weights, pretext_model)
    features_df = extract_features(pretrained_model, data, pretext_task, df_sample=df_sample, to_disk=to_disk)
    if eval:
        evaluate_features(features_df)
    
    return

### randomly initialized model

In [7]:
features_df = extract_features(EEGNet(), dataset, df_sample=df_sample, pretext_task='random', to_disk=False)
evaluate_features(features_df)

features_df.shape = (2688, 102)
features_df.head(3) =           0         1         2         3         4         5         6  \
0 -0.004536  0.320280  0.405812  0.039802 -0.308949 -0.257266  0.252949   
1 -0.168716  0.159646  0.369122  0.241846 -0.294710 -0.187903  0.017133   
2 -0.021149  0.256763  0.265521  0.361246 -0.253374 -0.083487 -0.051087   

          7         8         9  ...        92        93        94        95  \
0  0.060935 -0.188668 -0.008633  ...  0.060558 -0.385546 -0.374811  0.350624   
1 -0.083962  0.277266 -0.150524  ... -0.394782 -0.591801 -0.386088  0.295578   
2  0.141549 -0.148216 -0.296685  ... -0.344894 -0.313566 -0.437223  0.387534   

         96        97        98        99            ID  diagnosis  
0 -0.133003 -0.174662 -0.396049  0.490404  sub-87964717        SMC  
1 -0.153409 -0.321951  0.115503  0.097863  sub-87964717        SMC  
2 -0.250683 -0.171399 -0.170894  0.026129  sub-87964717        SMC  

[3 rows x 102 columns]
quick SVM model
[[44 15 

### pretext model with default parameters (0.25 dropout, 0 weight decay, binary cross entropy loss)

In [32]:
# best model checkpoint
get_ssl_features('default_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)
features_df.head(3) =           0         1         2         3         4         5         6  \
0 -0.511930  0.052950  0.629106 -0.078158 -1.339599 -0.417798  0.066230   
1  1.353647  0.468311 -1.459589  0.212438 -0.850818 -0.448878 -0.012591   
2 -0.763003 -0.355235  0.214412  0.381083 -0.696883  0.254853  0.900915   

          7         8         9  ...        92        93        94        95  \
0  0.000036 -0.960075 -0.984144  ...  2.038200  0.822667 -0.136034 -0.178997   
1  1.331975 -0.114144 -1.265044  ...  0.591929 -0.040963  1.020034  1.659824   
2  0.199232 -0.951991 -0.998366  ...  1.196933  0.748006  1.391716 -0.761784   

         96        97        98        99            ID  diagnosis  
0 -0.077823  0.701746  0.307834 -1.124583  sub-87964717        SMC  
1 -0.552552  1.591685 -0.362145 -1.257051  sub-87964717        SMC  
2  0.502140 -0.699653  1.631659 -0.162776  sub-87964717        SMC  

[3 rows x 102 columns]
quick SVM model
[[52  6 

In [33]:
# overtrained model
get_ssl_features('overtrained_default_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)
features_df.head(3) =           0         1         2         3         4         5         6  \
0 -1.499251 -1.839241  1.299461 -0.906688 -1.627193 -2.607241  1.848223   
1  2.603693  2.615295 -2.526045  3.560906 -2.999318 -0.375544  1.581350   
2 -1.489633  1.421686  0.774805  1.866495 -0.733223  4.067636  1.103560   

          7         8         9  ...        92        93        94        95  \
0 -0.372570 -1.733146 -2.572570  ...  3.662326  2.495516  1.045900 -1.393403   
1  2.049844  0.889608 -0.389436  ... -1.864482  0.224161  2.841857  2.038397   
2 -1.521009 -1.515620 -0.420507  ...  1.860082 -0.845679  3.349092  0.084994   

         96        97        98        99            ID  diagnosis  
0  0.830747  0.784356  1.230647 -2.264533  sub-87964717        SMC  
1  0.840320  1.497790 -4.382267 -1.196007  sub-87964717        SMC  
2  1.980972  0.776525  6.022857  0.976138  sub-87964717        SMC  

[3 rows x 102 columns]
quick SVM model
[[56  9 

### pretext model with (0.5 dropout, 0.01 weight decay, binary cross entropy loss) trained for 300 epochs

In [14]:
# best model checkpoint
get_ssl_features('0.5dropout_0.01wd_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)
features_df.head(3) =               0             1             2             3             4  \
0  9.528830e-43  4.428103e-43  6.207752e-43  2.410233e-43  2.012265e-42   
1 -4.932571e-43  1.050974e-43 -1.897358e-42  1.844109e-42  6.003163e-42   
2  1.324227e-42 -2.490107e-42  2.368194e-42 -7.791219e-43 -2.260294e-42   

              5             6             7             8             9  ...  \
0  1.296201e-42  1.454548e-42  3.161329e-42 -1.370470e-42  3.222986e-43  ...   
1  1.233143e-43  2.586797e-42  2.846037e-42 -3.972681e-42 -2.134178e-42  ...   
2 -3.461207e-43 -6.011570e-43  2.100546e-42  3.150119e-42 -1.401298e-44  ...   

             92            93            94            95            96  \
0 -3.138909e-43  4.091792e-43  3.983892e-42  9.335736e-08 -6.053609e-43   
1  1.258366e-42 -2.045896e-43  4.240329e-42  8.344613e-07  1.458752e-42   
2 -4.764415e-44 -1.636717e-42 -2.754953e-42 -3.120408e-07 -2.310741e-42   

             97        

In [15]:
# overtrained model
get_ssl_features('overtrained_0.5dropout_0.01wd_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)
features_df.head(3) =               0             1             2             3             4  \
0  1.063586e-42  7.006492e-45  2.928714e-43  9.528830e-44  1.490982e-42   
1 -1.177091e-43 -1.541428e-43 -9.486791e-43  1.388687e-42  4.074976e-42   
2  1.053776e-42 -1.974430e-42  1.226136e-42 -2.410233e-43 -1.167282e-42   

              5             6             7             8             9  ...  \
0  1.210722e-42  9.290609e-43  2.158000e-42 -1.094414e-42  4.287973e-43  ...   
1  5.857428e-43  2.300932e-42  1.716591e-42 -2.421444e-42 -1.027152e-42  ...   
2  1.008935e-43 -2.424246e-43  1.430726e-42  1.987041e-42  2.087935e-43  ...   

             92            93            94        95            96        97  \
0  2.522337e-44  6.473999e-43  2.511127e-42  0.033212 -1.245754e-42  0.000101   
1  1.223334e-42  2.396220e-43  2.729729e-42  0.066676  9.570869e-43  0.000163   
2  5.885454e-44 -7.272739e-43 -1.880543e-42 -0.078971 -1.842707e-42 -0.000132   


### pretext model with tpos=2 (0.5 dropout, 0.01 weight decay, binary cross entropy loss, 1e-5 lr) trained for 300 epochs

In [16]:
# best model checkpoint
get_ssl_features('tpos2_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)
features_df.head(3) =           0         1         2         3             4         5         6  \
0  0.061802  0.005085  0.007856 -0.035400  2.385010e-42 -0.008866 -0.056749   
1 -0.029408  0.061680  0.071100 -0.022571  3.335090e-42 -0.076831  0.077834   
2 -0.025841 -0.010355 -0.011240 -0.053490  1.715189e-42  0.035290  0.000481   

          7         8         9  ...        92        93        94        95  \
0 -0.065079 -0.065252 -0.002033  ... -0.014515 -0.007414  0.136899 -0.032047   
1 -0.027313 -0.020615 -0.008531  ... -0.026250 -0.011071 -0.027499 -0.127595   
2 -0.017168 -0.007870  0.079678  ... -0.044766  0.064771  0.049487  0.048421   

         96        97        98        99            ID  diagnosis  
0 -0.040107  0.011386  0.046973 -0.040559  sub-87964717        SMC  
1 -0.109569  0.031522  0.031266 -0.068513  sub-87964717        SMC  
2  0.054905  0.005165  0.030797 -0.034964  sub-87964717        SMC  

[3 rows x 102 columns]
quick SV

In [17]:
# overtrained model
get_ssl_features('overtrained_tpos2_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)
features_df.head(3) =           0         1         2         3             4         5         6  \
0  0.061802  0.005085  0.007856 -0.035400  2.385010e-42 -0.008866 -0.056749   
1 -0.029408  0.061680  0.071100 -0.022571  3.335090e-42 -0.076831  0.077834   
2 -0.025841 -0.010355 -0.011240 -0.053490  1.715189e-42  0.035290  0.000481   

          7         8         9  ...        92        93        94        95  \
0 -0.065079 -0.065252 -0.002033  ... -0.014515 -0.007414  0.136899 -0.032047   
1 -0.027313 -0.020615 -0.008531  ... -0.026250 -0.011071 -0.027499 -0.127595   
2 -0.017168 -0.007870  0.079678  ... -0.044766  0.064771  0.049487  0.048421   

         96        97        98        99            ID  diagnosis  
0 -0.040107  0.011386  0.046973 -0.040559  sub-87964717        SMC  
1 -0.109569  0.031522  0.031266 -0.068513  sub-87964717        SMC  
2  0.054905  0.005165  0.030797 -0.034964  sub-87964717        SMC  

[3 rows x 102 columns]
quick SV

### pretext model with soft margin loss (0.5 dropout, 0.01 weight decay, 1e-5 lr) trained for 300 epochs

In [18]:
# best model checkpoint
get_ssl_features('soft_margin_loss_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)
features_df.head(3) =           0         1         2         3         4         5         6  \
0  0.219832  0.098533  0.005199  0.195303  0.086944 -0.068312  0.143486   
1 -0.120667 -0.014623 -0.183000 -0.099864 -0.015822  0.276396 -0.171983   
2 -0.036115 -0.154194  0.118493 -0.061462 -0.167263  0.196827  0.051444   

          7         8             9  ...        92        93        94  \
0  0.324379  0.001090 -6.033991e-42  ... -0.038351 -0.131607 -0.238432   
1  0.062228  0.104401 -7.340001e-42  ...  0.121118 -0.130495  0.147437   
2 -0.082809 -0.054026  2.910497e-42  ...  0.148639  0.086051 -0.018675   

         95        96        97        98        99            ID  diagnosis  
0 -0.041366  0.061754  0.003293  0.288630  0.028199  sub-87964717        SMC  
1 -0.115310 -0.065496  0.120059  0.031802 -0.103211  sub-87964717        SMC  
2  0.124328 -0.143368 -0.065945 -0.056548 -0.054269  sub-87964717        SMC  

[3 rows x 102 columns]
quick SV

In [19]:
# overtrained model
get_ssl_features('overtrained_soft_margin_loss_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)
features_df.head(3) =           0         1         2         3         4         5         6  \
0  0.226481  0.088861  0.005907  0.204030  0.081828 -0.069658  0.137385   
1 -0.109406 -0.035708 -0.181463 -0.088779 -0.029132  0.255942 -0.170817   
2 -0.038523 -0.145451  0.118563 -0.055741 -0.160860  0.196051  0.050994   

          7         8             9  ...        92        93        94  \
0  0.318937 -0.004623 -5.871441e-42  ... -0.037064 -0.121207 -0.240783   
1  0.082091  0.100829 -7.055538e-42  ...  0.117451 -0.140545  0.121451   
2 -0.101076 -0.050460  2.813807e-42  ...  0.154400  0.077547 -0.002461   

         95        96        97        98        99            ID  diagnosis  
0 -0.037577  0.074509 -0.005701  0.278743  0.055143  sub-87964717        SMC  
1 -0.118065 -0.071716  0.122560  0.047558 -0.077472  sub-87964717        SMC  
2  0.111397 -0.135598 -0.061939 -0.067626 -0.054492  sub-87964717        SMC  

[3 rows x 102 columns]
quick SV