# AIM: Extract features from the TDBRAIN replication set

For now only extracted for the best models based on downstream performance

In [1]:
import numpy as np
import pandas as pd
import mne
import lightning.pytorch as pl
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchmetrics import F1Score, Accuracy
import random
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, StratifiedGroupKFold
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from IPython.display import display
%matplotlib inline


# prevent extensive logging
mne.set_log_level('WARNING')

## Loading epoch data & participant data of replication sample

In [2]:

df_replication = pd.read_csv(r'D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TDBRAIN_participants_V2_data\TDBRAIN_replication_template_V2.csv')
# set participants_ID column as string
df_replication['participants_ID'] = df_replication['participants_ID'].astype(str)
df_replication

Unnamed: 0,participants_ID,DISC/REP,indication,formal_status,Dataset,Consent,sessSeason,sessTime,Responder,Remitter,...,BDI_pre,BDI_post,rTMS PROTOCOL,ADHD_pre_Hyp_leading,ADHD_pre_Att_leading,ADHD_post_Att_leading,ADHD_post_Hyp_leading,NF Protocol,YBOCS_pre,YBOCS_post
0,19681349,REPLICATION,REPLICATION,REPLICATION,REPLICATION,YES,REPLICATION,REPLICATION,REPLICATION,REPLICATION,...,REPLICATION,REPLICATION,2.0,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION
1,19681385,REPLICATION,REPLICATION,REPLICATION,REPLICATION,YES,REPLICATION,REPLICATION,REPLICATION,REPLICATION,...,REPLICATION,REPLICATION,1.0,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION
2,19684666,REPLICATION,REPLICATION,REPLICATION,REPLICATION,YES,REPLICATION,REPLICATION,REPLICATION,REPLICATION,...,REPLICATION,REPLICATION,1.0,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION
3,19686324,REPLICATION,REPLICATION,REPLICATION,REPLICATION,YES,REPLICATION,REPLICATION,REPLICATION,REPLICATION,...,REPLICATION,REPLICATION,2.0,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION
4,19687321,REPLICATION,REPLICATION,REPLICATION,REPLICATION,YES,REPLICATION,REPLICATION,REPLICATION,REPLICATION,...,REPLICATION,REPLICATION,3.0,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
115,19739349,REPLICATION,REPLICATION,REPLICATION,REPLICATION,YES,REPLICATION,REPLICATION,REPLICATION,REPLICATION,...,REPLICATION,REPLICATION,,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION
116,19739424,REPLICATION,REPLICATION,REPLICATION,REPLICATION,YES,REPLICATION,REPLICATION,REPLICATION,REPLICATION,...,REPLICATION,REPLICATION,,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION
117,19740051,REPLICATION,REPLICATION,REPLICATION,REPLICATION,YES,REPLICATION,REPLICATION,REPLICATION,REPLICATION,...,REPLICATION,REPLICATION,8.0,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION
118,19740274,REPLICATION,REPLICATION,REPLICATION,REPLICATION,YES,REPLICATION,REPLICATION,REPLICATION,REPLICATION,...,REPLICATION,REPLICATION,,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION,REPLICATION


In [3]:
# functions for loading the epoched EEG data
def get_filepath(epoch_dir, participant_ids):
    """
    Function to get the filepath of the epoched EEG recording
    :param epoch_dir: directory containing the epoched EEG recordings
    :param participant_ids: list of participant IDs to include
    """
    filepaths = []
    found_ids = set()
    
    for subdir, dirs, files in os.walk(epoch_dir):
        for file in files:
            for participant_id in participant_ids:
                if participant_id in file:
                    filepaths.append(os.path.join(subdir, file))
                    found_ids.add(participant_id)
    
    # Print participant IDs if no files matching those IDs are found
    for participant_id in participant_ids:
        if participant_id not in found_ids:
            print(f"No files found for participant ID {participant_id}")

    return filepaths

class EpochDataset(torch.utils.data.Dataset):
    def __init__(self, participant_ids, epoch_dir):
        self.filepaths = get_filepath(epoch_dir, participant_ids)
        self.participant_ids = participant_ids
        self.epochs = []
        self.participant_ids = []
        self._load_data()
        print(f"Number of epochs: {self.epochs.shape[0]}")
        print(f"Number of participants: {len(self.participant_ids)}")

    def _load_data(self):
        all_epochs = []
        for filepath in self.filepaths:
            epochs = torch.load(filepath)
            # get participant ID from filepath to make sure the participant ID is correct
            participant_id = filepath.split("\\")[-1].split(".")[0]
            all_epochs.append(epochs)
            self.participant_ids.extend([participant_id]*epochs.shape[0])
        self.epochs = np.concatenate(all_epochs, axis=0)

    def __len__(self):
        return self.epochs.shape[0]
    
    def __getitem__(self, idx):
        epoch = self.epochs[idx]
        participant_id = self.participant_ids[idx]
        return torch.tensor(epoch, dtype=torch.float32), participant_id

In [4]:
# load the epochs into a dataset
participant_ids = df_replication['participants_ID'].tolist()
dataset = EpochDataset(participant_ids, r"D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TDBRAIN-dataset-derivatives\thesis_epoched_data\EC")
print(len(dataset))
print(dataset[0][0].shape)
print(dataset[0][1])
print(dataset[1][1])

Number of epochs: 1440
Number of participants: 1440
1440
torch.Size([26, 1244])
sub-19681349
sub-19681349


## Transfering pretrained weights & extracting features

### Functions:

These functions have been adjusted specifically for the replication set

In [5]:
def transfer_weights(pretrained_weights, pretext_model):
    """
    Function to transfer the pretrained weights to the pretext model
    param: pretrained_weights: the weights to transfer in a dictionary
    param: pretext_model: the model to transfer the weights to
    """
    pretrained_model = pretext_model
    modified_keys = {}
    for k, v in pretrained_weights.items():
        decomposed_key = k.split('.')
        if decomposed_key[0] == 'EEGNet' or decomposed_key[0] == 'ShallowNet': # remove the first part of each key to match the model's keys
            pretrained_key = '.'.join(decomposed_key[1:])
            modified_keys[pretrained_key] = v

            
    pretrained_model.load_state_dict(modified_keys)
        
    return pretrained_model

def extract_features(pretrained_model, data, pretext_task, df_sample, to_disk=False, num_extracted_features=100, add_missing_ids=False):
    """
    Function to extract features from the pretrained model
    param: pretrained_model: the model to extract features from
    param: data: the dataset containing the epochs to extract features from
    param: pretext_task: a string indicating the specific pretext task to save the features as
    param: df_sample: the dataframe containing the sampled participant IDs and their corresponding diagnosis
    param: to_disk: boolean to save the features to disk
    param: num_extracted_features: the number of features to extract
    param: add_missing_ids: boolean to add missing participant IDs to the dataframe
    """
    dataloader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)
    pretrained_model.eval()
    features_list = []
    participant_ids = []
    with torch.no_grad():  # Disable gradient calculation
        for batch in dataloader:
            epoch, participant_id = batch  # Remove the batch dimension
            epoch = epoch.unsqueeze(0)  # Add dimension
            # print(epoch.shape)
            features = pretrained_model(epoch)  # Extract features
            features = features.squeeze(0)
            features = features.numpy()
            features_list.append(features)
            participant_ids.append(participant_id[0])

    if add_missing_ids:
        # Probably not the best idea to add missing IDs with NaN values (in this
        # case only one ID), but the baselines from the previous internetship were 
        # trained with the same approach (using mean impute for the 12 missing epochs of one participant). 
        # Considering only 1 participant is missing (due to 1 BAD preprocessing file), 
        # it shouldn't have a large impact on the results. But it's something to keep in mind for future work.

        # check if any IDs are missing compared to the sampled IDs
        missing_ids = set(df_sample['participants_ID'].tolist()) - set(participant_ids)
        # if they are missing, add them with NaN values
        if missing_ids:
            for missing_id in missing_ids:
                num_missing_epochs = 12  # Assuming 12 missing epochs for each missing participant
                for _ in range(num_missing_epochs):
                    features_list.append([np.nan] * num_extracted_features)
                    participant_ids.append(missing_id)

    features_df = pd.DataFrame(features_list) # store as dataframe
    features_df['ID'] = participant_ids # add participant IDs to the dataframe

    
    print(f'{features_df.shape = }')
    # print(f'{features_df["diagnosis"].value_counts()}')
    # print(f'{features_df.isnull().sum()}')
    display(features_df.head(3))


    if to_disk:
        features_df.to_pickle(f'D:/Documents/Master_Data_Science/Thesis/thesis_code/DataScience_Thesis/data/SSL_features/df_replication_set_{pretext_task}_features.pkl')
    
    return features_df

def evaluate_features(features_df):
    """
    Function to quickly evaluate the extracted features. Doesn't stratify/group data splitting! 
    """
    groups = features_df['ID']
    X = features_df.drop(['ID', 'diagnosis'], axis=1)
    y = features_df['diagnosis']

    # if X contains NaN values, impute them with the feature-wise mean
    if X.isnull().values.any():
        X = X.fillna(X.mean())

    sgkf = StratifiedGroupKFold(n_splits=5, shuffle=False)

    # Get the train and test indices
    for train_index, test_index in sgkf.split(X, y, groups):
        X_train, X_test = X.iloc[train_index], X.iloc[test_index]
        y_train, y_test = y.iloc[train_index], y.iloc[test_index]
        break  # Only use the first split
    
    # quick SVM model
    clf = SVC()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print('quick SVM model')
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print(f1_score(y_test, y_pred, average='macro'))
    print()

    # quick random forest model
    clf = RandomForestClassifier()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print('quick random forest model')
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print(f1_score(y_test, y_pred, average='macro'))

    return

def get_ssl_features(
        pretext_task,
        data,
        df_sample,
        num_extracted_features=100,
        eval=True,
        to_disk=True,
        pretext_model='EEGNet',
        weights_dir=r'D:\Documents\Master_Data_Science\Thesis\thesis_code\DataScience_Thesis\data\pretext_model_weights'
                     ):
    """
    Obtains SSL features from the weights trained by the pretext model.
    param: pretext_task: a string indicating the specific pretext task to load the weights from and save the features as
    param: data: the dataset containing the epochs to extract features from
    param: df_sample: the dataframe containing the sampled participant IDs and their corresponding diagnosis
    param: num_extracted_features: the number of features to extract
    param: eval: boolean to evaluate the features
    param: to_disk: boolean to save the features to disk
    param: pretext_model: the model to extract features from
    param: weights_dir: the directory containing the weights of the pretext model
    """
    # Need to define model class here to avoid issues with different number of extracted features
    # create Conv2d with max norm constraint
    class Conv2dWithConstraint(nn.Conv2d):
        def __init__(self, *args, max_norm: int = 1, **kwargs):
            self.max_norm = max_norm
            super(Conv2dWithConstraint, self).__init__(*args, **kwargs)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            self.weight.data = torch.renorm(self.weight.data, p=2, dim=0, maxnorm=self.max_norm)
            return super(Conv2dWithConstraint, self).forward(x)
        
    class EEGNet(nn.Module):
        """
        Code taken and adjusted from pytorch implementation of EEGNet
        url: https://github.com/torcheeg/torcheeg/blob/v1.1.0/torcheeg/models/cnn/eegnet.py#L5
        """
        def __init__(self,
                    chunk_size: int = 1244, # number of data points in each EEG chunk
                    num_electrodes: int = 26, # number of EEG electrodes
                    F1: int = 8, # number of filters in first convolutional layer
                    F2: int = 16, # number of filters in second convolutional layer
                    D: int = 2, # depth multiplier
                    num_extracted_features: int = num_extracted_features, # number of features to extract
                    kernel_1: int = 64, # the filter size of block 1 (half of sfreq (125 Hz))
                    kernel_2: int = 16, # the filter size of block 2 (one eight of sfreq (500 Hz))
                    dropout: float = 0.25): # dropout rate
            super(EEGNet, self).__init__()
            self.F1 = F1
            self.F2 = F2
            self.D = D
            self.chunk_size = chunk_size
            self.num_extracted_features = num_extracted_features
            self.num_electrodes = num_electrodes
            self.kernel_1 = kernel_1
            self.kernel_2 = kernel_2
            self.dropout = dropout

            self.block1 = nn.Sequential(
                nn.Conv2d(1, self.F1, (1, self.kernel_1), stride=1, padding=(0, self.kernel_1 // 2), bias=False),
                nn.BatchNorm2d(self.F1, momentum=0.01, affine=True, eps=1e-3),
                Conv2dWithConstraint(self.F1,
                                    self.F1 * self.D, (self.num_electrodes, 1),
                                    max_norm=1,
                                    stride=1,
                                    padding=(0, 0),
                                    groups=self.F1,
                                    bias=False), nn.BatchNorm2d(self.F1 * self.D, momentum=0.01, affine=True, eps=1e-3),
                nn.ELU(), nn.AvgPool2d((1, 4), stride=4), nn.Dropout(p=dropout))

            self.block2 = nn.Sequential(
                nn.Conv2d(self.F1 * self.D,
                        self.F1 * self.D, (1, self.kernel_2),
                        stride=1,
                        padding=(0, self.kernel_2 // 2),
                        bias=False,
                        groups=self.F1 * self.D),
                nn.Conv2d(self.F1 * self.D, self.F2, 1, padding=(0, 0), groups=1, bias=False, stride=1),
                nn.BatchNorm2d(self.F2, momentum=0.01, affine=True, eps=1e-3), nn.ELU(), nn.AvgPool2d((1, 8), stride=8),
                nn.Dropout(p=dropout))

            self.lin = nn.Linear(self.feature_dim(), num_extracted_features, bias=False)


        def feature_dim(self):
            # function to calculate the number of features after the convolutional blocks
            with torch.no_grad():
                mock_eeg = torch.zeros(1, 1, self.num_electrodes, self.chunk_size)

                mock_eeg = self.block1(mock_eeg)
                mock_eeg = self.block2(mock_eeg)

            return self.F2 * mock_eeg.shape[3]

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.block1(x)
            x = self.block2(x)
            x = x.flatten(start_dim=1)
            x = self.lin(x)
            return x
    
    class ShallowNet(nn.Module):
        """
        Pytorch implementation of the ShallowNet Encoder.
        Code taken and adjusted from:
        https://github.com/MedMaxLab/selfEEG/blob/024402ba4bde95051d86ab2524cc71105bfd5c25/selfeeg/models/zoo.py#L693
        """

        def __init__(self,
                    samples=1244,
                    chans=26, # number of EEG channels
                    F=40, # number of output filters in the temporal convolution layer
                    K1=25, # length of the temporal convolutional layer
                    pool=75, # temporal pooling kernel size
                    dropout=0.2, # dropout probability
                    num_extracted_features=num_extracted_features # number of features to extract
                    ):

            super(ShallowNet, self).__init__()
            self.conv1 = nn.Conv2d(1, F, (1, K1), stride=(1, 1))
            self.conv2 = nn.Conv2d(F, F, (chans, 1), stride=(1, 1))
            self.batch1 = nn.BatchNorm2d(F)
            self.pool2 = nn.AvgPool2d((1, pool), stride=(1, 15))
            self.flatten2 = nn.Flatten()
            self.drop1 = nn.Dropout(dropout)
            self.lin = nn.Linear(
                F * ((samples - K1 + 1 - pool) // 15 + 1), num_extracted_features
            )

        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.batch1(x)
            x = torch.square(x)
            x = self.pool2(x)
            x = torch.log(torch.clamp(x, 1e-7, 10000))
            x = self.flatten2(x)
            x = self.drop1(x)
            x = self.lin(x)

            return x
    
    if pretext_model == 'EEGNet':
        pretext_model = EEGNet()
    if pretext_model == 'ShallowNet':
        pretext_model = ShallowNet()
        
    pretrained_weights = torch.load(f'{weights_dir}\{pretext_task}_weights.pt')
    pretrained_model = transfer_weights(pretrained_weights, pretext_model)
    features_df = extract_features(pretrained_model, data, pretext_task, df_sample=df_sample, to_disk=to_disk, num_extracted_features=num_extracted_features)
    if eval:
        evaluate_features(features_df)
    
    return

### 3. Across Subject pretext task

In [6]:
# overtrained model
get_ssl_features('fullytrained_acrossSub_ShallowNet_pretext_model', dataset, df_replication, eval=False, to_disk=True, pretext_model='ShallowNet')

features_df.shape = (1440, 101)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,91,92,93,94,95,96,97,98,99,ID
0,-7.190171,-19.356483,7.035101,-5.932686,1.865819,0.201499,22.717876,-1.551631,6.496997,-29.09519,...,5.371091,6.682823,14.467761,-4.801239,10.828082,4.138763,-21.877794,-3.519001,-18.737434,sub-19681349
1,-13.781171,-5.390237,-1.49506,-17.824326,-2.177072,4.026847,23.834311,-5.461724,-3.255945,-23.361437,...,10.863097,-2.749647,7.217631,-6.415405,5.745875,0.000694,-17.799702,-6.54289,-14.986298,sub-19681349
2,-18.367344,0.138474,-2.984107,-7.175438,-0.735494,3.747396,13.888578,-4.137311,-4.859614,-28.061474,...,12.408052,-5.423985,11.853927,2.480663,7.429001,-6.48644,-18.050285,-3.872681,-8.611533,sub-19681349
