# AIM: Extract features on labeled data using the pretrained EEGNet
across subject pretext task

In [1]:
import numpy as np
import pandas as pd
import mne
import lightning.pytorch as pl
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchmetrics import F1Score, Accuracy
import random
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, StratifiedGroupKFold
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from IPython.display import display
%matplotlib inline


# prevent extensive logging
mne.set_log_level('WARNING')

## Loading epoch data & participant data of labeled sample
These dataframes have been filtered and stored in a previous project. See https://github.com/TSmolders/Internship_EEG for original code

In [2]:

df_participants = pd.read_pickle(r'D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TDBRAIN_participants_V2_data\df_participants.pkl')
sample_df = pd.read_pickle(r'D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TD-BRAIN_extracted_features\df_selected_stat_features.pkl')
sample_ids = sample_df['ID'].unique() # obtain unique IDs from subsampled dataframe containing epoched features
df_sample = df_participants[df_participants['participants_ID'].isin(sample_ids)] # filter participants dataframe to only include subsampled IDs
df_sample = df_sample[df_sample['sessID'] == 1] # filter first session
print(df_sample.shape)
print(df_sample['diagnosis'].value_counts())



(225, 12)
diagnosis
ADHD       45
HEALTHY    45
MDD        45
OCD        45
SMC        45
Name: count, dtype: int64


In [3]:
# functions for loading the epoched EEG data
def get_filepath(epoch_dir, participant_ids):
    """
    Function to get the filepath of the epoched EEG recording
    :param epoch_dir: directory containing the epoched EEG recordings
    :param ID: list of participant IDs to include
    """
    filepaths = []
    for subdir, dirs, files in os.walk(epoch_dir):
        for file in files:
            if any(participant_id in file for participant_id in participant_ids):
                filepaths.append(os.path.join(subdir, file))
    return filepaths

class EpochDataset(torch.utils.data.Dataset):
    def __init__(self, participant_ids, epoch_dir):
        self.filepaths = get_filepath(epoch_dir, participant_ids)
        self.participant_ids = participant_ids
        self.epochs = []
        self.participant_ids = []
        self._load_data()
        print(f"Number of epochs: {self.epochs.shape[0]}")
        print(f"Number of participants: {len(self.participant_ids)}")

    def _load_data(self):
        all_epochs = []
        for filepath in self.filepaths:
            epochs = torch.load(filepath)
            # get participant ID from filepath to make sure the participant ID is correct
            participant_id = filepath.split("\\")[-1].split(".")[0]
            all_epochs.append(epochs)
            self.participant_ids.extend([participant_id]*epochs.shape[0])
        self.epochs = np.concatenate(all_epochs, axis=0)

    def __len__(self):
        return self.epochs.shape[0]
    
    def __getitem__(self, idx):
        epoch = self.epochs[idx]
        participant_id = self.participant_ids[idx]
        return torch.tensor(epoch, dtype=torch.float32), participant_id

In [4]:
# load the epochs into a dataset
participant_ids = df_sample['participants_ID'].tolist()
dataset = EpochDataset(participant_ids, r"D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TDBRAIN-dataset-derivatives\thesis_epoched_data\EC")
print(len(dataset))
print(dataset[0][0].shape)
print(dataset[0][1])
print(dataset[1][1])

Number of epochs: 2688
Number of participants: 2688
2688
torch.Size([26, 1244])
sub-87964717
sub-87964717


## Transfering pretrained weights & extracting features

### Functions:

In [5]:
def transfer_weights(pretrained_weights, pretext_model):
    """
    Function to transfer the pretrained weights to the pretext model
    param: pretrained_weights: the weights to transfer in a dictionary
    param: pretext_model: the model to transfer the weights to
    """
    pretrained_model = pretext_model
    modified_keys = {}
    for k, v in pretrained_weights.items():
        decomposed_key = k.split('.')
        if decomposed_key[0] == 'EEGNet' or decomposed_key[0] == 'ShallowNet': # remove the first part of each key to match the model's keys
            pretrained_key = '.'.join(decomposed_key[1:])
            modified_keys[pretrained_key] = v

            
    pretrained_model.load_state_dict(modified_keys)
        
    return pretrained_model

def extract_features(pretrained_model, data, pretext_task, df_sample, to_disk=False):
    """
    Function to extract features from the pretrained model
    param: pretrained_model: the model to extract features from
    param: data: the dataset containing the epochs to extract features from
    param: pretext_task: a string indicating the specific pretext task to save the features as
    param: df_sample: the dataframe containing the sampled participant IDs and their corresponding diagnosis
    param: to_disk: boolean to save the features to disk
    """
    dataloader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)
    pretrained_model.eval()
    features_list = []
    participant_ids = []
    with torch.no_grad():  # Disable gradient calculation
        for batch in dataloader:
            epoch, participant_id = batch  # Remove the batch dimension
            epoch = epoch.unsqueeze(0)  # Add dimension
            # print(epoch.shape)
            features = pretrained_model(epoch)  # Extract features
            features = features.squeeze(0)
            features = features.numpy()
            features_list.append(features)
            participant_ids.append(participant_id[0])

    
    features_df = pd.DataFrame(features_list) # store as dataframe
    features_df['ID'] = participant_ids # add participant IDs to the dataframe
    # map the diagnosis values from df_sample to the dataframe based on participant IDs
    features_df['diagnosis'] = features_df['ID'].map(df_sample.set_index('participants_ID')['diagnosis'])
    
    print(f'{features_df.shape = }')
    display(features_df.head(3))


    if to_disk:
        features_df.to_pickle(f'D:/Documents/Master_Data_Science/Thesis/thesis_code/DataScience_Thesis/data/SSL_features/df_{pretext_task}_features.pkl')
    
    return features_df

def evaluate_features(features_df):
    """
    Function to quickly evaluate the extracted features. Doesn't stratify/group data splitting! 
    """
    groups = features_df['ID']
    X = features_df.drop(['ID', 'diagnosis'], axis=1)
    y = features_df['diagnosis']

    sgkf = StratifiedGroupKFold(n_splits=5, shuffle=False)

    # Get the train and test indices
    for train_index, test_index in sgkf.split(X, y, groups):
        X_train, X_test = X.iloc[train_index], X.iloc[test_index]
        y_train, y_test = y.iloc[train_index], y.iloc[test_index]
        break  # Only use the first split

    # quick SVM model
    clf = SVC()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print('quick SVM model')
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print(f1_score(y_test, y_pred, average='macro'))
    print()

    # quick random forest model
    clf = RandomForestClassifier()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print('quick random forest model')
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print(f1_score(y_test, y_pred, average='macro'))

    return

def get_ssl_features(
        pretext_task,
        data,
        df_sample,
        num_extracted_features=100,
        eval=True,
        to_disk=True,
        pretext_model='EEGNet',
        weights_dir=r'D:\Documents\Master_Data_Science\Thesis\thesis_code\DataScience_Thesis\data\pretext_model_weights'
                     ):
    """
    Obtains SSL features from the weights trained by the pretext model.
    param: pretext_task: a string indicating the specific pretext task to load the weights from and save the features as
    param: data: the dataset containing the epochs to extract features from
    param: df_sample: the dataframe containing the sampled participant IDs and their corresponding diagnosis
    param: num_extracted_features: the number of features to extract
    param: eval: boolean to evaluate the features
    param: to_disk: boolean to save the features to disk
    param: pretext_model: the model to extract features from
    param: weights_dir: the directory containing the weights of the pretext model
    """
    # Need to define model class here to avoid issues with different number of extracted features
    # create Conv2d with max norm constraint
    class Conv2dWithConstraint(nn.Conv2d):
        def __init__(self, *args, max_norm: int = 1, **kwargs):
            self.max_norm = max_norm
            super(Conv2dWithConstraint, self).__init__(*args, **kwargs)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            self.weight.data = torch.renorm(self.weight.data, p=2, dim=0, maxnorm=self.max_norm)
            return super(Conv2dWithConstraint, self).forward(x)
        
    class EEGNet(nn.Module):
        """
        Code taken and adjusted from pytorch implementation of EEGNet
        url: https://github.com/torcheeg/torcheeg/blob/v1.1.0/torcheeg/models/cnn/eegnet.py#L5
        """
        def __init__(self,
                    chunk_size: int = 1244, # number of data points in each EEG chunk
                    num_electrodes: int = 26, # number of EEG electrodes
                    F1: int = 8, # number of filters in first convolutional layer
                    F2: int = 16, # number of filters in second convolutional layer
                    D: int = 2, # depth multiplier
                    num_extracted_features: int = num_extracted_features, # number of features to extract
                    kernel_1: int = 64, # the filter size of block 1 (half of sfreq (125 Hz))
                    kernel_2: int = 16, # the filter size of block 2 (one eight of sfreq (500 Hz))
                    dropout: float = 0.25): # dropout rate
            super(EEGNet, self).__init__()
            self.F1 = F1
            self.F2 = F2
            self.D = D
            self.chunk_size = chunk_size
            self.num_extracted_features = num_extracted_features
            self.num_electrodes = num_electrodes
            self.kernel_1 = kernel_1
            self.kernel_2 = kernel_2
            self.dropout = dropout

            self.block1 = nn.Sequential(
                nn.Conv2d(1, self.F1, (1, self.kernel_1), stride=1, padding=(0, self.kernel_1 // 2), bias=False),
                nn.BatchNorm2d(self.F1, momentum=0.01, affine=True, eps=1e-3),
                Conv2dWithConstraint(self.F1,
                                    self.F1 * self.D, (self.num_electrodes, 1),
                                    max_norm=1,
                                    stride=1,
                                    padding=(0, 0),
                                    groups=self.F1,
                                    bias=False), nn.BatchNorm2d(self.F1 * self.D, momentum=0.01, affine=True, eps=1e-3),
                nn.ELU(), nn.AvgPool2d((1, 4), stride=4), nn.Dropout(p=dropout))

            self.block2 = nn.Sequential(
                nn.Conv2d(self.F1 * self.D,
                        self.F1 * self.D, (1, self.kernel_2),
                        stride=1,
                        padding=(0, self.kernel_2 // 2),
                        bias=False,
                        groups=self.F1 * self.D),
                nn.Conv2d(self.F1 * self.D, self.F2, 1, padding=(0, 0), groups=1, bias=False, stride=1),
                nn.BatchNorm2d(self.F2, momentum=0.01, affine=True, eps=1e-3), nn.ELU(), nn.AvgPool2d((1, 8), stride=8),
                nn.Dropout(p=dropout))

            self.lin = nn.Linear(self.feature_dim(), num_extracted_features, bias=False)


        def feature_dim(self):
            # function to calculate the number of features after the convolutional blocks
            with torch.no_grad():
                mock_eeg = torch.zeros(1, 1, self.num_electrodes, self.chunk_size)

                mock_eeg = self.block1(mock_eeg)
                mock_eeg = self.block2(mock_eeg)

            return self.F2 * mock_eeg.shape[3]

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.block1(x)
            x = self.block2(x)
            x = x.flatten(start_dim=1)
            x = self.lin(x)
            return x
    
    class ShallowNet(nn.Module):
        """
        Pytorch implementation of the ShallowNet Encoder.
        Code taken and adjusted from:
        https://github.com/MedMaxLab/selfEEG/blob/024402ba4bde95051d86ab2524cc71105bfd5c25/selfeeg/models/zoo.py#L693
        """

        def __init__(self,
                    samples=1244,
                    chans=26, # number of EEG channels
                    F=40, # number of output filters in the temporal convolution layer
                    K1=25, # length of the temporal convolutional layer
                    pool=75, # temporal pooling kernel size
                    dropout=0.2, # dropout probability
                    num_extracted_features=num_extracted_features # number of features to extract
                    ):

            super(ShallowNet, self).__init__()
            self.conv1 = nn.Conv2d(1, F, (1, K1), stride=(1, 1))
            self.conv2 = nn.Conv2d(F, F, (chans, 1), stride=(1, 1))
            self.batch1 = nn.BatchNorm2d(F)
            self.pool2 = nn.AvgPool2d((1, pool), stride=(1, 15))
            self.flatten2 = nn.Flatten()
            self.drop1 = nn.Dropout(dropout)
            self.lin = nn.Linear(
                F * ((samples - K1 + 1 - pool) // 15 + 1), num_extracted_features
            )

        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.batch1(x)
            x = torch.square(x)
            x = self.pool2(x)
            x = torch.log(torch.clamp(x, 1e-7, 10000))
            x = self.flatten2(x)
            x = self.drop1(x)
            x = self.lin(x)

            return x
    
    if pretext_model == 'EEGNet':
        pretext_model = EEGNet()
    if pretext_model == 'ShallowNet':
        pretext_model = ShallowNet()
        
    pretrained_weights = torch.load(f'{weights_dir}\{pretext_task}_weights.pt')
    pretrained_model = transfer_weights(pretrained_weights, pretext_model)
    features_df = extract_features(pretrained_model, data, pretext_task, df_sample=df_sample, to_disk=to_disk)
    if eval:
        evaluate_features(features_df)
    
    return

### randomly initialized model

In [7]:
features_df = extract_features(EEGNet(), dataset, df_sample=df_sample, pretext_task='random', to_disk=False)
evaluate_features(features_df)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,-0.025231,0.558265,0.38643,-0.427448,-0.195549,-0.385354,0.035971,0.385304,0.106288,0.573986,...,0.261263,0.668248,-0.518208,0.242272,0.186938,0.214983,0.118263,-0.321973,sub-87964717,SMC
1,-0.111984,0.977839,0.131731,-0.051678,-0.290045,-0.434035,0.237411,-0.101681,-0.275212,0.762277,...,0.520117,0.624576,-0.485054,0.251544,-0.03708,-0.125105,-0.325259,-0.672581,sub-87964717,SMC
2,-0.033515,0.690026,-0.037421,-0.427968,-0.044509,-0.273484,0.223849,0.226518,0.001657,0.772421,...,0.178688,0.269243,-0.326347,-0.047144,-0.056313,0.331418,0.099657,-0.459115,sub-87964717,SMC


quick SVM model
[[58 17  5 16 12]
 [15 36  9 32 16]
 [13 29 24 21 21]
 [ 0 27 21 33 27]
 [ 3  4 15 26 60]]
              precision    recall  f1-score   support

        ADHD       0.65      0.54      0.59       108
     HEALTHY       0.32      0.33      0.33       108
         MDD       0.32      0.22      0.26       108
         OCD       0.26      0.31      0.28       108
         SMC       0.44      0.56      0.49       108

    accuracy                           0.39       540
   macro avg       0.40      0.39      0.39       540
weighted avg       0.40      0.39      0.39       540

0.3899649803774411

quick random forest model
[[62 14 14 10  8]
 [18 34 20 22 14]
 [22 28 26 15 17]
 [10 24 28 18 28]
 [ 3 23 15 19 48]]
              precision    recall  f1-score   support

        ADHD       0.54      0.57      0.56       108
     HEALTHY       0.28      0.31      0.29       108
         MDD       0.25      0.24      0.25       108
         OCD       0.21      0.17      0.19       

### pretext model with default parameters (0.25 dropout, 0 weight decay, binary cross entropy loss)

In [9]:
# best model checkpoint
get_ssl_features('contrastive_loss_default_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,-0.42721,-0.32027,-0.282973,-0.141028,0.262942,0.065654,0.073298,0.137498,0.29613,-0.075196,...,-0.029153,0.157873,0.154806,0.03309,0.054093,-0.021883,-0.077166,0.171117,sub-87964717,SMC
1,-0.317875,-0.391166,-0.346829,-0.101455,0.218276,-0.002775,0.184389,0.275933,0.219873,-0.199641,...,0.079862,-0.126056,0.120228,0.241789,0.017197,-0.116866,0.137638,0.37838,sub-87964717,SMC
2,-0.189565,-0.228894,-0.134339,-0.100182,0.126142,0.122312,0.15434,0.137634,0.436666,0.000482,...,0.053833,0.14299,0.144225,-0.100617,0.098143,-0.059571,-0.064255,0.40905,sub-87964717,SMC


quick SVM model
[[56  9 14 14 15]
 [ 4 31 17 36 20]
 [16 12 24 38 18]
 [ 5 23 21 35 24]
 [ 0  5 29 20 54]]
              precision    recall  f1-score   support

        ADHD       0.69      0.52      0.59       108
     HEALTHY       0.39      0.29      0.33       108
         MDD       0.23      0.22      0.23       108
         OCD       0.24      0.32      0.28       108
         SMC       0.41      0.50      0.45       108

    accuracy                           0.37       540
   macro avg       0.39      0.37      0.38       540
weighted avg       0.39      0.37      0.38       540

0.3756998493301762

quick random forest model
[[63  9 13 13 10]
 [18 31 21 22 16]
 [28 25 14 23 18]
 [11 28 20 23 26]
 [ 4 21 23 18 42]]
              precision    recall  f1-score   support

        ADHD       0.51      0.58      0.54       108
     HEALTHY       0.27      0.29      0.28       108
         MDD       0.15      0.13      0.14       108
         OCD       0.23      0.21      0.22       

In [8]:
# overtrained model
get_ssl_features('overtrained_contrastive_loss_default_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,0.128607,0.138197,-2.235987,-1.701625,3.545807,-1.874852,2.633586,-1.813814,-0.959881,-2.024573,...,0.010959,2.923007,3.074398,0.924193,-0.835261,-0.142691,-1.170481,-1.667304,sub-87964717,SMC
1,-1.118486,2.918775,-3.907192,-0.223409,4.939961,-0.68072,2.935375,-1.185846,1.424726,-0.685542,...,0.466219,-0.417178,-0.076883,1.810046,-1.779487,-0.976218,-1.046167,2.468539,sub-87964717,SMC
2,0.380765,-0.459981,-0.608639,-0.373267,2.232232,1.592137,0.924958,-3.015853,0.509243,-1.72856,...,0.33991,3.535902,2.091701,0.75523,-1.79828,-1.933364,-0.464456,-0.927624,sub-87964717,SMC


quick SVM model
[[58 13  3  5 29]
 [ 5 30 26 11 36]
 [ 6 24 14  7 57]
 [ 2 22 14 15 55]
 [ 1 14  7  8 78]]
              precision    recall  f1-score   support

        ADHD       0.81      0.54      0.64       108
     HEALTHY       0.29      0.28      0.28       108
         MDD       0.22      0.13      0.16       108
         OCD       0.33      0.14      0.19       108
         SMC       0.31      0.72      0.43       108

    accuracy                           0.36       540
   macro avg       0.39      0.36      0.34       540
weighted avg       0.39      0.36      0.34       540

0.34323051852264397

quick random forest model
[[59 15 13 11 10]
 [11 36 24 17 20]
 [10 30 23 17 28]
 [ 9 22 23 23 31]
 [ 4 20 21 24 39]]
              precision    recall  f1-score   support

        ADHD       0.63      0.55      0.59       108
     HEALTHY       0.29      0.33      0.31       108
         MDD       0.22      0.21      0.22       108
         OCD       0.25      0.21      0.23      