# AIM: Extract features on labeled data using the pretrained EEGNet
across subject pretext task

In [1]:
import numpy as np
import pandas as pd
import mne
import lightning.pytorch as pl
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchmetrics import F1Score, Accuracy
import random
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, StratifiedGroupKFold
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from IPython.display import display
%matplotlib inline


# prevent extensive logging
mne.set_log_level('WARNING')

## Loading epoch data & participant data of labeled sample
These dataframes have been filtered and stored in a previous project. See https://github.com/TSmolders/Internship_EEG for original code

In [2]:

df_participants = pd.read_pickle(r'D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TDBRAIN_participants_V2_data\df_participants.pkl')
sample_df = pd.read_pickle(r'D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TD-BRAIN_extracted_features\df_selected_stat_features.pkl')
sample_ids = sample_df['ID'].unique() # obtain unique IDs from subsampled dataframe containing epoched features
df_sample = df_participants[df_participants['participants_ID'].isin(sample_ids)] # filter participants dataframe to only include subsampled IDs
df_sample = df_sample[df_sample['sessID'] == 1] # filter first session
print(df_sample.shape)
print(df_sample['diagnosis'].value_counts())



(225, 12)
diagnosis
ADHD       45
HEALTHY    45
MDD        45
OCD        45
SMC        45
Name: count, dtype: int64


In [3]:
# functions for loading the epoched EEG data
def get_filepath(epoch_dir, participant_ids):
    """
    Function to get the filepath of the epoched EEG recording
    :param epoch_dir: directory containing the epoched EEG recordings
    :param participant_ids: list of participant IDs to include
    """
    filepaths = []
    found_ids = set()
    
    for subdir, dirs, files in os.walk(epoch_dir):
        for file in files:
            for participant_id in participant_ids:
                if participant_id in file:
                    filepaths.append(os.path.join(subdir, file))
                    found_ids.add(participant_id)
    
    # Print participant IDs if no files matching those IDs are found
    for participant_id in participant_ids:
        if participant_id not in found_ids:
            print(f"No files found for participant ID {participant_id}")

    return filepaths

class EpochDataset(torch.utils.data.Dataset):
    def __init__(self, participant_ids, epoch_dir):
        self.filepaths = get_filepath(epoch_dir, participant_ids)
        self.participant_ids = participant_ids
        self.epochs = []
        self.participant_ids = []
        self._load_data()
        print(f"Number of epochs: {self.epochs.shape[0]}")
        print(f"Number of participants: {len(self.participant_ids)}")

    def _load_data(self):
        all_epochs = []
        for filepath in self.filepaths:
            epochs = torch.load(filepath)
            # get participant ID from filepath to make sure the participant ID is correct
            participant_id = filepath.split("\\")[-1].split(".")[0]
            all_epochs.append(epochs)
            self.participant_ids.extend([participant_id]*epochs.shape[0])
        self.epochs = np.concatenate(all_epochs, axis=0)

    def __len__(self):
        return self.epochs.shape[0]
    
    def __getitem__(self, idx):
        epoch = self.epochs[idx]
        participant_id = self.participant_ids[idx]
        return torch.tensor(epoch, dtype=torch.float32), participant_id

In [4]:
# load the epochs into a dataset
participant_ids = df_sample['participants_ID'].tolist()
dataset = EpochDataset(participant_ids, r"D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TDBRAIN-dataset-derivatives\thesis_epoched_data\EC")
print(len(dataset))
print(dataset[0][0].shape)
print(dataset[0][1])
print(dataset[1][1])

No files found for participant ID sub-88073521
Number of epochs: 2688
Number of participants: 2688
2688
torch.Size([26, 1244])
sub-87964717
sub-87964717


participant sub-88073521 EC's recording is marked as BAD by the preprocessing pipeline of TDBRAIN

## Transfering pretrained weights & extracting features

### Functions:

In [5]:
def transfer_weights(pretrained_weights, pretext_model):
    """
    Function to transfer the pretrained weights to the pretext model
    param: pretrained_weights: the weights to transfer in a dictionary
    param: pretext_model: the model to transfer the weights to
    """
    pretrained_model = pretext_model
    modified_keys = {}
    for k, v in pretrained_weights.items():
        decomposed_key = k.split('.')
        if decomposed_key[0] == 'EEGNet' or decomposed_key[0] == 'ShallowNet': # remove the first part of each key to match the model's keys
            pretrained_key = '.'.join(decomposed_key[1:])
            modified_keys[pretrained_key] = v

            
    pretrained_model.load_state_dict(modified_keys)
        
    return pretrained_model

def extract_features(pretrained_model, data, pretext_task, df_sample, to_disk=False, num_extracted_features=100, add_missing_ids=True):
    """
    Function to extract features from the pretrained model
    param: pretrained_model: the model to extract features from
    param: data: the dataset containing the epochs to extract features from
    param: pretext_task: a string indicating the specific pretext task to save the features as
    param: df_sample: the dataframe containing the sampled participant IDs and their corresponding diagnosis
    param: to_disk: boolean to save the features to disk
    param: num_extracted_features: the number of features to extract
    param: add_missing_ids: boolean to add missing participant IDs to the dataframe
    """
    dataloader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)
    pretrained_model.eval()
    features_list = []
    participant_ids = []
    with torch.no_grad():  # Disable gradient calculation
        for batch in dataloader:
            epoch, participant_id = batch  # Remove the batch dimension
            epoch = epoch.unsqueeze(0)  # Add dimension
            # print(epoch.shape)
            features = pretrained_model(epoch)  # Extract features
            features = features.squeeze(0)
            features = features.numpy()
            features_list.append(features)
            participant_ids.append(participant_id[0])

    if add_missing_ids:
        # Probably not the best idea to add missing IDs with NaN values (in this
        # case only one ID), but the baselines from the previous internetship were 
        # trained with the same approach (using mean impute for the 12 missing epochs of one participant). 
        # Considering only 1 participant is missing (due to 1 BAD preprocessing file), 
        # it shouldn't have a large impact on the results. But it's something to keep in mind for future work.

        # check if any IDs are missing compared to the sampled IDs
        missing_ids = set(df_sample['participants_ID'].tolist()) - set(participant_ids)
        # if they are missing, add them with NaN values
        if missing_ids:
            for missing_id in missing_ids:
                num_missing_epochs = 12  # Assuming 12 missing epochs for each missing participant
                for _ in range(num_missing_epochs):
                    features_list.append([np.nan] * num_extracted_features)
                    participant_ids.append(missing_id)

    features_df = pd.DataFrame(features_list) # store as dataframe
    features_df['ID'] = participant_ids # add participant IDs to the dataframe

    # map the diagnosis values from df_sample to the dataframe based on participant IDs
    features_df['diagnosis'] = features_df['ID'].map(df_sample.set_index('participants_ID')['diagnosis'])

    
    print(f'{features_df.shape = }')
    # print(f'{features_df["diagnosis"].value_counts()}')
    # print(f'{features_df.isnull().sum()}')
    display(features_df.head(3))


    if to_disk:
        features_df.to_pickle(f'D:/Documents/Master_Data_Science/Thesis/thesis_code/DataScience_Thesis/data/SSL_features/df_{pretext_task}_features.pkl')
    
    return features_df

def evaluate_features(features_df):
    """
    Function to quickly evaluate the extracted features. Doesn't stratify/group data splitting! 
    """
    groups = features_df['ID']
    X = features_df.drop(['ID', 'diagnosis'], axis=1)
    y = features_df['diagnosis']

    # if X contains NaN values, impute them with the feature-wise mean
    if X.isnull().values.any():
        X = X.fillna(X.mean())

    sgkf = StratifiedGroupKFold(n_splits=5, shuffle=False)

    # Get the train and test indices
    for train_index, test_index in sgkf.split(X, y, groups):
        X_train, X_test = X.iloc[train_index], X.iloc[test_index]
        y_train, y_test = y.iloc[train_index], y.iloc[test_index]
        break  # Only use the first split
    
    # quick SVM model
    clf = SVC()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print('quick SVM model')
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print(f1_score(y_test, y_pred, average='macro'))
    print()

    # quick random forest model
    clf = RandomForestClassifier()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print('quick random forest model')
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print(f1_score(y_test, y_pred, average='macro'))

    return

def get_ssl_features(
        pretext_task,
        data,
        df_sample,
        num_extracted_features=100,
        conv_length=25,
        eval=True,
        to_disk=True,
        pretext_model='EEGNet',
        weights_dir=r'D:\Documents\Master_Data_Science\Thesis\thesis_code\DataScience_Thesis\data\pretext_model_weights'
                     ):
    """
    Obtains SSL features from the weights trained by the pretext model.
    param: pretext_task: a string indicating the specific pretext task to load the weights from and save the features as
    param: data: the dataset containing the epochs to extract features from
    param: df_sample: the dataframe containing the sampled participant IDs and their corresponding diagnosis
    param: num_extracted_features: the number of features to extract
    param: conv_length: the length of the temporal convolutional filter of ShallowNet
    param: eval: boolean to evaluate the features
    param: to_disk: boolean to save the features to disk
    param: pretext_model: the model to extract features from
    param: weights_dir: the directory containing the weights of the pretext model
    """
    # Need to define model class here to avoid issues with different number of extracted features
    # create Conv2d with max norm constraint
    class Conv2dWithConstraint(nn.Conv2d):
        def __init__(self, *args, max_norm: int = 1, **kwargs):
            self.max_norm = max_norm
            super(Conv2dWithConstraint, self).__init__(*args, **kwargs)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            self.weight.data = torch.renorm(self.weight.data, p=2, dim=0, maxnorm=self.max_norm)
            return super(Conv2dWithConstraint, self).forward(x)
        
    class EEGNet(nn.Module):
        """
        Code taken and adjusted from pytorch implementation of EEGNet
        url: https://github.com/torcheeg/torcheeg/blob/v1.1.0/torcheeg/models/cnn/eegnet.py#L5
        """
        def __init__(self,
                    chunk_size: int = 1244, # number of data points in each EEG chunk
                    num_electrodes: int = 26, # number of EEG electrodes
                    F1: int = 8, # number of filters in first convolutional layer
                    F2: int = 16, # number of filters in second convolutional layer
                    D: int = 2, # depth multiplier
                    num_extracted_features: int = num_extracted_features, # number of features to extract
                    kernel_1: int = 64, # the filter size of block 1 (half of sfreq (125 Hz))
                    kernel_2: int = 16, # the filter size of block 2 (one eight of sfreq (500 Hz))
                    dropout: float = 0.25): # dropout rate
            super(EEGNet, self).__init__()
            self.F1 = F1
            self.F2 = F2
            self.D = D
            self.chunk_size = chunk_size
            self.num_extracted_features = num_extracted_features
            self.num_electrodes = num_electrodes
            self.kernel_1 = kernel_1
            self.kernel_2 = kernel_2
            self.dropout = dropout

            self.block1 = nn.Sequential(
                nn.Conv2d(1, self.F1, (1, self.kernel_1), stride=1, padding=(0, self.kernel_1 // 2), bias=False),
                nn.BatchNorm2d(self.F1, momentum=0.01, affine=True, eps=1e-3),
                Conv2dWithConstraint(self.F1,
                                    self.F1 * self.D, (self.num_electrodes, 1),
                                    max_norm=1,
                                    stride=1,
                                    padding=(0, 0),
                                    groups=self.F1,
                                    bias=False), nn.BatchNorm2d(self.F1 * self.D, momentum=0.01, affine=True, eps=1e-3),
                nn.ELU(), nn.AvgPool2d((1, 4), stride=4), nn.Dropout(p=dropout))

            self.block2 = nn.Sequential(
                nn.Conv2d(self.F1 * self.D,
                        self.F1 * self.D, (1, self.kernel_2),
                        stride=1,
                        padding=(0, self.kernel_2 // 2),
                        bias=False,
                        groups=self.F1 * self.D),
                nn.Conv2d(self.F1 * self.D, self.F2, 1, padding=(0, 0), groups=1, bias=False, stride=1),
                nn.BatchNorm2d(self.F2, momentum=0.01, affine=True, eps=1e-3), nn.ELU(), nn.AvgPool2d((1, 8), stride=8),
                nn.Dropout(p=dropout))

            self.lin = nn.Linear(self.feature_dim(), num_extracted_features, bias=False)


        def feature_dim(self):
            # function to calculate the number of features after the convolutional blocks
            with torch.no_grad():
                mock_eeg = torch.zeros(1, 1, self.num_electrodes, self.chunk_size)

                mock_eeg = self.block1(mock_eeg)
                mock_eeg = self.block2(mock_eeg)

            return self.F2 * mock_eeg.shape[3]

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.block1(x)
            x = self.block2(x)
            x = x.flatten(start_dim=1)
            x = self.lin(x)
            return x
    
    class ShallowNet(nn.Module):
        """
        Pytorch implementation of the ShallowNet Encoder.
        Code taken and adjusted from:
        https://github.com/MedMaxLab/selfEEG/blob/024402ba4bde95051d86ab2524cc71105bfd5c25/selfeeg/models/zoo.py#L693
        """

        def __init__(self,
                    samples=1244,
                    chans=26, # number of EEG channels
                    F=40, # number of output filters in the temporal convolution layer
                    K1=conv_length, # length of the temporal convolutional layer
                    pool=75, # temporal pooling kernel size
                    dropout=0.2, # dropout probability
                    num_extracted_features=num_extracted_features # number of features to extract
                    ):

            super(ShallowNet, self).__init__()
            self.conv1 = nn.Conv2d(1, F, (1, K1), stride=(1, 1))
            self.conv2 = nn.Conv2d(F, F, (chans, 1), stride=(1, 1))
            self.batch1 = nn.BatchNorm2d(F)
            self.pool2 = nn.AvgPool2d((1, pool), stride=(1, 15))
            self.flatten2 = nn.Flatten()
            self.drop1 = nn.Dropout(dropout)
            self.lin = nn.Linear(
                F * ((samples - K1 + 1 - pool) // 15 + 1), num_extracted_features
            )

        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.batch1(x)
            x = torch.square(x)
            x = self.pool2(x)
            x = torch.log(torch.clamp(x, 1e-7, 10000))
            x = self.flatten2(x)
            x = self.drop1(x)
            x = self.lin(x)

            return x
    
    if pretext_model == 'EEGNet':
        pretext_model = EEGNet()
    if pretext_model == 'ShallowNet':
        pretext_model = ShallowNet()
        
    pretrained_weights = torch.load(f'{weights_dir}\{pretext_task}_weights.pt')
    pretrained_model = transfer_weights(pretrained_weights, pretext_model)
    features_df = extract_features(pretrained_model, data, pretext_task, df_sample=df_sample, to_disk=to_disk, num_extracted_features=num_extracted_features)
    if eval:
        evaluate_features(features_df)
    
    return

### 0. randomly initialized model

In [7]:
features_df = extract_features(EEGNet(), dataset, df_sample=df_sample, pretext_task='random', to_disk=False)
evaluate_features(features_df)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,-0.025231,0.558265,0.38643,-0.427448,-0.195549,-0.385354,0.035971,0.385304,0.106288,0.573986,...,0.261263,0.668248,-0.518208,0.242272,0.186938,0.214983,0.118263,-0.321973,sub-87964717,SMC
1,-0.111984,0.977839,0.131731,-0.051678,-0.290045,-0.434035,0.237411,-0.101681,-0.275212,0.762277,...,0.520117,0.624576,-0.485054,0.251544,-0.03708,-0.125105,-0.325259,-0.672581,sub-87964717,SMC
2,-0.033515,0.690026,-0.037421,-0.427968,-0.044509,-0.273484,0.223849,0.226518,0.001657,0.772421,...,0.178688,0.269243,-0.326347,-0.047144,-0.056313,0.331418,0.099657,-0.459115,sub-87964717,SMC


quick SVM model
[[58 17  5 16 12]
 [15 36  9 32 16]
 [13 29 24 21 21]
 [ 0 27 21 33 27]
 [ 3  4 15 26 60]]
              precision    recall  f1-score   support

        ADHD       0.65      0.54      0.59       108
     HEALTHY       0.32      0.33      0.33       108
         MDD       0.32      0.22      0.26       108
         OCD       0.26      0.31      0.28       108
         SMC       0.44      0.56      0.49       108

    accuracy                           0.39       540
   macro avg       0.40      0.39      0.39       540
weighted avg       0.40      0.39      0.39       540

0.3899649803774411

quick random forest model
[[62 14 14 10  8]
 [18 34 20 22 14]
 [22 28 26 15 17]
 [10 24 28 18 28]
 [ 3 23 15 19 48]]
              precision    recall  f1-score   support

        ADHD       0.54      0.57      0.56       108
     HEALTHY       0.28      0.31      0.29       108
         MDD       0.25      0.24      0.25       108
         OCD       0.21      0.17      0.19       

### 4. Contrastive Loss pretext task

#### 4a. Default (less noise and smaller notch filter range)

In [25]:
# best model checkpoint
get_ssl_features('contrastive_loss_ShallowNet_pretext_model', dataset, df_sample, eval=True, to_disk=True, pretext_model='ShallowNet')

features_df.shape = (2700, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,-2.369301,-4.337305,-4.308994,3.001537,-1.646338,-6.464166,3.485768,6.120035,-5.26361,-3.524369,...,4.930452,-5.001675,1.037297,-5.185491,-5.979504,-0.461733,-1.838399,-6.754786,sub-87964717,SMC
1,-2.853487,-4.522071,-3.694927,3.755439,-1.284667,-6.863893,3.434332,6.634506,-5.45096,-3.620916,...,4.37609,-4.527611,0.599956,-4.963607,-5.508793,-0.068413,-1.895499,-7.050792,sub-87964717,SMC
2,-2.440639,-4.519271,-3.691686,2.828543,-1.111394,-6.356917,3.354033,6.216813,-4.895897,-3.326767,...,4.624696,-4.263685,1.499501,-4.416896,-5.274657,0.247817,-2.149446,-6.030265,sub-87964717,SMC


quick SVM model
[[60  1 15 20 12]
 [ 0  1 25 49 33]
 [ 0  0  0 60 48]
 [ 2  1 23 35 47]
 [ 9  3 31 35 30]]
              precision    recall  f1-score   support

        ADHD       0.85      0.56      0.67       108
     HEALTHY       0.17      0.01      0.02       108
         MDD       0.00      0.00      0.00       108
         OCD       0.18      0.32      0.23       108
         SMC       0.18      0.28      0.22       108

    accuracy                           0.23       540
   macro avg       0.27      0.23      0.23       540
weighted avg       0.27      0.23      0.23       540

0.2263550577094188

quick random forest model
[[68  8 11 15  6]
 [ 8 32 23 22 23]
 [12 16 13 29 38]
 [12 12 27 19 38]
 [12 23 22 23 28]]
              precision    recall  f1-score   support

        ADHD       0.61      0.63      0.62       108
     HEALTHY       0.35      0.30      0.32       108
         MDD       0.14      0.12      0.13       108
         OCD       0.18      0.18      0.18       

In [26]:
# overtrained model
get_ssl_features('fullytrained_contrastive_loss_ShallowNet_pretext_model', dataset, df_sample, eval=True, to_disk=True, pretext_model='ShallowNet')

features_df.shape = (2700, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,10.451659,30.639395,11.807755,6.348997,17.861507,-31.225868,7.20275,8.405975,7.505048,18.203638,...,41.697327,-0.088799,13.625303,35.362148,20.136534,13.248628,27.597157,24.472734,sub-87964717,SMC
1,4.270686,31.222342,11.271786,6.745347,21.709585,-31.926849,8.319916,9.044333,7.242358,16.904369,...,40.668343,4.712557,12.581395,36.615234,24.236464,14.825472,25.072674,22.385237,sub-87964717,SMC
2,8.565495,23.934246,13.511322,5.521217,27.356424,-27.029285,8.649724,7.774141,6.803407,13.373815,...,43.222195,7.897653,17.840275,29.360674,16.3503,21.295179,20.767286,21.513336,sub-87964717,SMC


quick SVM model
[[52  3 24 19 10]
 [ 0  2 38 52 16]
 [ 0  2  5 63 38]
 [ 6  1 21 39 41]
 [15  4 28 37 24]]
              precision    recall  f1-score   support

        ADHD       0.71      0.48      0.57       108
     HEALTHY       0.17      0.02      0.03       108
         MDD       0.04      0.05      0.04       108
         OCD       0.19      0.36      0.25       108
         SMC       0.19      0.22      0.20       108

    accuracy                           0.23       540
   macro avg       0.26      0.23      0.22       540
weighted avg       0.26      0.23      0.22       540

0.22007529805457024

quick random forest model
[[65 13 21  6  3]
 [ 2 32 20 27 27]
 [ 8 11 17 40 32]
 [ 8 21 27 20 32]
 [20 24 25 11 28]]
              precision    recall  f1-score   support

        ADHD       0.63      0.60      0.62       108
     HEALTHY       0.32      0.30      0.31       108
         MDD       0.15      0.16      0.16       108
         OCD       0.19      0.19      0.19      

#### 4b. more noise

In [6]:
# best model checkpoint
get_ssl_features('morenoise_contrastive_loss_ShallowNet_pretext_model', dataset, df_sample, eval=True, to_disk=True, pretext_model='ShallowNet')

features_df.shape = (2700, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,-4.375022,5.794556,6.069884,6.09371,1.83473,5.716855,-3.119078,0.474394,6.809621,-1.210216,...,-4.96541,2.371307,-5.030252,5.662517,7.00875,5.72688,-6.449692,4.66626,sub-87964717,SMC
1,-4.861523,5.858618,6.520693,6.273579,1.634162,6.12232,-2.914032,0.609031,7.766216,-1.138097,...,-5.335489,2.482917,-4.927655,6.497468,7.524563,6.575546,-6.70115,4.664319,sub-87964717,SMC
2,-3.831183,5.160967,5.403428,5.270938,1.59719,5.078155,-3.127128,0.630827,6.514303,-1.870754,...,-4.722949,2.200771,-4.730864,5.650596,6.58805,5.984232,-5.903697,3.915682,sub-87964717,SMC


quick SVM model
[[60  2 15 20 11]
 [ 1  0 23 53 31]
 [ 0  0  0 60 48]
 [ 2  2 23 35 46]
 [ 5  0 34 44 25]]
              precision    recall  f1-score   support

        ADHD       0.88      0.56      0.68       108
     HEALTHY       0.00      0.00      0.00       108
         MDD       0.00      0.00      0.00       108
         OCD       0.17      0.32      0.22       108
         SMC       0.16      0.23      0.19       108

    accuracy                           0.22       540
   macro avg       0.24      0.22      0.22       540
weighted avg       0.24      0.22      0.22       540

0.21728835755322745

quick random forest model
[[71 10  9 12  6]
 [ 8 25 26 28 21]
 [12 15  9 32 40]
 [15 19 27 21 26]
 [10 23 29 19 27]]
              precision    recall  f1-score   support

        ADHD       0.61      0.66      0.63       108
     HEALTHY       0.27      0.23      0.25       108
         MDD       0.09      0.08      0.09       108
         OCD       0.19      0.19      0.19      

In [7]:
# overtrained model
get_ssl_features('fullytrained_morenoise_contrastive_loss_ShallowNet_pretext_model', dataset, df_sample, eval=True, to_disk=True, pretext_model='ShallowNet')

features_df.shape = (2700, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,5.062675,14.4969,12.388856,14.589525,13.754212,16.743914,1.877846,12.042021,17.341433,7.048843,...,12.060273,5.8686,12.307624,13.730348,5.139685,6.761933,-20.693783,17.270153,sub-87964717,SMC
1,4.442379,14.378494,11.455668,14.687763,19.609699,17.661081,2.484261,10.325764,18.975815,1.307699,...,15.322884,3.847873,17.482456,13.7056,6.976749,8.85433,-23.702114,16.688944,sub-87964717,SMC
2,-0.776249,13.710887,12.173506,13.937624,20.609173,15.687565,2.53058,6.007638,18.193617,0.455962,...,10.397358,2.586116,7.428412,11.327138,6.06275,9.511157,-12.400146,19.127943,sub-87964717,SMC


quick SVM model
[[60  2 15 22  9]
 [ 6  1 24 55 22]
 [ 0  1  7 63 37]
 [13  0 21 32 42]
 [17  5 20 35 31]]
              precision    recall  f1-score   support

        ADHD       0.62      0.56      0.59       108
     HEALTHY       0.11      0.01      0.02       108
         MDD       0.08      0.06      0.07       108
         OCD       0.15      0.30      0.20       108
         SMC       0.22      0.29      0.25       108

    accuracy                           0.24       540
   macro avg       0.24      0.24      0.23       540
weighted avg       0.24      0.24      0.23       540

0.2258589540233764

quick random forest model
[[58 13 25  6  6]
 [ 6 28 21 33 20]
 [ 9 24  8 37 30]
 [10 25 20 26 27]
 [17 24 19 18 30]]
              precision    recall  f1-score   support

        ADHD       0.58      0.54      0.56       108
     HEALTHY       0.25      0.26      0.25       108
         MDD       0.09      0.07      0.08       108
         OCD       0.22      0.24      0.23       

#### 4c. larger filter range

In [8]:
# best model checkpoint
get_ssl_features('largerfilter_contrastive_loss_ShallowNet_pretext_model', dataset, df_sample, eval=True, to_disk=True, pretext_model='ShallowNet')

features_df.shape = (2700, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,-0.025737,-6.11589,5.020125,2.999495,-3.26637,-4.776773,-4.618756,-4.514643,4.348269,0.317911,...,-6.907289,4.401042,5.669729,-3.566941,-4.846494,-6.191189,-1.29423,-1.792427,sub-87964717,SMC
1,-0.97631,-6.085506,5.247127,2.422552,-3.270205,-5.534034,-3.866799,-4.48331,4.437867,0.64971,...,-6.701769,3.485051,5.769794,-4.290616,-4.510294,-6.559725,-1.004886,-1.825006,sub-87964717,SMC
2,-0.312646,-6.262755,4.751021,2.101804,-3.160367,-4.992777,-3.877764,-4.183433,3.646285,1.012359,...,-6.092437,4.04623,4.78868,-3.850194,-4.16473,-5.837598,-1.537514,-1.275081,sub-87964717,SMC


quick SVM model
[[61  0 13 22 12]
 [ 0  0 19 57 32]
 [ 0  0  0 60 48]
 [ 2  0 22 37 47]
 [10  1 27 41 29]]
              precision    recall  f1-score   support

        ADHD       0.84      0.56      0.67       108
     HEALTHY       0.00      0.00      0.00       108
         MDD       0.00      0.00      0.00       108
         OCD       0.17      0.34      0.23       108
         SMC       0.17      0.27      0.21       108

    accuracy                           0.24       540
   macro avg       0.24      0.24      0.22       540
weighted avg       0.24      0.24      0.22       540

0.22237407687996208

quick random forest model
[[72 11  5 16  4]
 [ 9 30 20 23 26]
 [ 9 17 10 31 41]
 [16 23 19 16 34]
 [ 9 24 21 26 28]]
              precision    recall  f1-score   support

        ADHD       0.63      0.67      0.65       108
     HEALTHY       0.29      0.28      0.28       108
         MDD       0.13      0.09      0.11       108
         OCD       0.14      0.15      0.15      

In [9]:
# overtrained model
get_ssl_features('fullytrained_largerfilter_contrastive_loss_ShallowNet_pretext_model', dataset, df_sample, eval=True, to_disk=True, pretext_model='ShallowNet')

features_df.shape = (2700, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,71.065979,36.948509,66.55809,56.225983,39.355999,-14.149976,43.426926,14.194438,87.023689,75.097778,...,40.441429,82.254112,66.507523,50.361832,-34.479275,26.571655,86.194702,108.976639,sub-87964717,SMC
1,68.458626,35.800407,70.292191,71.781693,34.119583,-13.852758,52.386978,10.035754,84.584549,80.632263,...,47.788017,85.976685,69.383492,48.467571,-45.635483,33.320778,97.97406,111.45816,sub-87964717,SMC
2,62.577991,37.145359,67.81678,54.796619,41.906292,-12.918624,47.667534,13.643794,80.488182,85.534607,...,37.664093,82.052422,65.906166,46.924179,-30.951738,26.604286,89.317467,109.873535,sub-87964717,SMC


quick SVM model
[[60  1 14 18 15]
 [ 8  9 25 15 51]
 [ 0  7  9 29 63]
 [ 4 11 36 12 45]
 [ 9 16 32 22 29]]
              precision    recall  f1-score   support

        ADHD       0.74      0.56      0.63       108
     HEALTHY       0.20      0.08      0.12       108
         MDD       0.08      0.08      0.08       108
         OCD       0.12      0.11      0.12       108
         SMC       0.14      0.27      0.19       108

    accuracy                           0.22       540
   macro avg       0.26      0.22      0.23       540
weighted avg       0.26      0.22      0.23       540

0.22756821321635218

quick random forest model
[[61  9  8 20 10]
 [16 22 19 15 36]
 [ 2 22 11 30 43]
 [ 5 19 31 24 29]
 [20 17 22 29 20]]
              precision    recall  f1-score   support

        ADHD       0.59      0.56      0.58       108
     HEALTHY       0.25      0.20      0.22       108
         MDD       0.12      0.10      0.11       108
         OCD       0.20      0.22      0.21      

#### 4d. subsequent augmentations

In [6]:
# best model checkpoint
get_ssl_features('subseqaug_contrastive_loss_ShallowNet_pretext_model', dataset, df_sample, eval=True, to_disk=True, pretext_model='ShallowNet')

features_df.shape = (2700, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,-0.96368,-3.591621,3.490758,-3.251682,-1.148758,3.85553,-2.337285,2.084611,2.285586,0.963054,...,-0.818257,3.968619,5.330093,-3.217261,-3.148831,-2.451833,-2.596372,-2.33103,sub-87964717,SMC
1,-1.363783,-4.350845,3.072384,-3.601652,-1.908192,4.745199,-2.351085,3.175907,3.03411,0.599782,...,-1.719849,4.780525,5.73867,-3.02659,-4.182706,-2.612415,-2.90841,-2.30594,sub-87964717,SMC
2,-0.607921,-3.72336,2.286056,-2.737145,-1.18909,3.720479,-1.992744,1.723267,3.045962,0.827706,...,-1.267317,3.959916,4.22477,-3.070462,-3.573229,-2.659045,-2.417223,-2.448407,sub-87964717,SMC


quick SVM model
[[61  4 12 20 11]
 [ 1 21 23 33 30]
 [ 0  8  7 46 47]
 [ 2  7 29 26 44]
 [ 8  7 34 27 32]]
              precision    recall  f1-score   support

        ADHD       0.85      0.56      0.68       108
     HEALTHY       0.45      0.19      0.27       108
         MDD       0.07      0.06      0.07       108
         OCD       0.17      0.24      0.20       108
         SMC       0.20      0.30      0.24       108

    accuracy                           0.27       540
   macro avg       0.35      0.27      0.29       540
weighted avg       0.35      0.27      0.29       540

0.2899534673781674

quick random forest model
[[72 11  8 10  7]
 [ 8 42 18 24 16]
 [ 9 16 13 34 36]
 [15 18 24 23 28]
 [11 22 27 19 29]]
              precision    recall  f1-score   support

        ADHD       0.63      0.67      0.65       108
     HEALTHY       0.39      0.39      0.39       108
         MDD       0.14      0.12      0.13       108
         OCD       0.21      0.21      0.21       

In [7]:
# overtrained model
get_ssl_features('fullytrained_subseqaug_contrastive_loss_ShallowNet_pretext_model', dataset, df_sample, eval=True, to_disk=True, pretext_model='ShallowNet')

features_df.shape = (2700, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,-2.365527,2.541262,1.100096,0.263822,9.412162,5.971048,-2.964276,4.720715,-0.906365,9.890745,...,-10.604981,-4.588036,-3.630062,-2.656839,8.993075,-1.147043,1.348928,-2.472632,sub-87964717,SMC
1,12.57119,3.077733,0.109194,-3.75491,-6.433097,3.554824,-4.701464,0.027851,1.998845,10.217842,...,6.874039,7.297047,-2.101861,-5.019956,11.835664,6.138152,-4.321651,-7.959854,sub-87964717,SMC
2,5.848882,-6.175635,-0.739543,-6.396237,6.124206,2.883111,-5.656449,3.49595,-4.675406,11.54311,...,-8.946095,-0.620383,-0.355117,-5.221846,8.734726,1.875726,-4.062213,-1.752447,sub-87964717,SMC


quick SVM model
[[44 24 10 19 11]
 [14 46 16 17 15]
 [ 2 33 24 22 27]
 [ 9 37 16 26 20]
 [ 9 14 29 15 41]]
              precision    recall  f1-score   support

        ADHD       0.56      0.41      0.47       108
     HEALTHY       0.30      0.43      0.35       108
         MDD       0.25      0.22      0.24       108
         OCD       0.26      0.24      0.25       108
         SMC       0.36      0.38      0.37       108

    accuracy                           0.34       540
   macro avg       0.35      0.34      0.34       540
weighted avg       0.35      0.34      0.34       540

0.3362587237092486

quick random forest model
[[46 19 10 19 14]
 [20 37 17 14 20]
 [ 8 32 23 20 25]
 [20 34 15 19 20]
 [14 15 19 24 36]]
              precision    recall  f1-score   support

        ADHD       0.43      0.43      0.43       108
     HEALTHY       0.27      0.34      0.30       108
         MDD       0.27      0.21      0.24       108
         OCD       0.20      0.18      0.19       

#### 4e. subsequent augmentations trained for 3000 epochs and 64 temp filter length

In [6]:
# best model checkpoint
get_ssl_features('subsaug64f_contrastive_loss_ShallowNet_pretext_model', dataset, df_sample, eval=True, to_disk=True, pretext_model='ShallowNet', conv_length=64)

features_df.shape = (2700, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,4.417768,-1.559999,4.175131,2.281379,5.381208,4.258815,-1.948829,-3.461449,-0.987703,3.185023,...,-4.469476,-1.576374,3.763077,2.700604,2.574667,3.813061,4.969516,-4.791446,sub-87964717,SMC
1,4.50046,-1.291694,4.185689,2.55144,6.034654,5.136297,-2.689325,-3.706076,-1.666248,4.16777,...,-5.280787,-1.949486,3.782542,4.056895,2.920238,4.448786,4.992928,-5.752219,sub-87964717,SMC
2,3.453121,-0.827561,3.077811,2.095023,4.728253,3.954333,-1.842008,-2.558147,-1.265961,3.330319,...,-4.563795,-1.284568,3.653576,3.010451,2.049146,3.317192,4.804771,-4.828982,sub-87964717,SMC


quick SVM model
[[60 17  2 19 10]
 [ 1 27  9 47 24]
 [ 0  9  0 58 41]
 [ 1 10 22 32 43]
 [ 7 29 11 40 21]]
              precision    recall  f1-score   support

        ADHD       0.87      0.56      0.68       108
     HEALTHY       0.29      0.25      0.27       108
         MDD       0.00      0.00      0.00       108
         OCD       0.16      0.30      0.21       108
         SMC       0.15      0.19      0.17       108

    accuracy                           0.26       540
   macro avg       0.30      0.26      0.27       540
weighted avg       0.30      0.26      0.27       540

0.2657065806628697

quick random forest model
[[63 10 19  8  8]
 [ 9 34 19 23 23]
 [15 16 11 29 37]
 [ 7 18 29 25 29]
 [17 24 22 18 27]]
              precision    recall  f1-score   support

        ADHD       0.57      0.58      0.58       108
     HEALTHY       0.33      0.31      0.32       108
         MDD       0.11      0.10      0.11       108
         OCD       0.24      0.23      0.24       

In [7]:
# overtrained model
get_ssl_features('fullytrained_subsaug64f_contrastive_loss_ShallowNet_pretext_model', dataset, df_sample, eval=True, to_disk=True, pretext_model='ShallowNet', conv_length=64)

features_df.shape = (2700, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,-2.702296,4.044283,-0.034612,2.119333,3.693329,5.193596,-5.600686,4.961525,1.553331,-0.481846,...,1.713349,0.660513,3.290288,0.753465,-1.346806,-0.859452,-4.114105,6.026663,sub-87964717,SMC
1,0.858901,-1.671891,1.751413,5.232744,3.083068,0.842232,1.67039,8.728765,7.049444,5.304541,...,-4.857956,2.994891,-1.128406,-2.183594,-6.029429,-4.887656,2.997276,4.020938,sub-87964717,SMC
2,-4.860819,4.865071,6.519952,3.495035,2.258393,-0.242716,1.899868,4.326125,0.873515,-2.444288,...,-2.769558,-5.850525,5.333314,8.29144,-5.350672,-1.408727,-0.852987,-1.8446,sub-87964717,SMC


quick SVM model
[[46 27 16 19  0]
 [15 54 14 21  4]
 [ 2 41 10 18 37]
 [15 36 15 27 15]
 [ 4 11 21 17 55]]
              precision    recall  f1-score   support

        ADHD       0.56      0.43      0.48       108
     HEALTHY       0.32      0.50      0.39       108
         MDD       0.13      0.09      0.11       108
         OCD       0.26      0.25      0.26       108
         SMC       0.50      0.51      0.50       108

    accuracy                           0.36       540
   macro avg       0.35      0.36      0.35       540
weighted avg       0.35      0.36      0.35       540

0.34844476748125863

quick random forest model
[[48 27 14 17  2]
 [20 46 10 24  8]
 [ 3 42 12 15 36]
 [ 6 43 19 18 22]
 [ 2 13 18 11 64]]
              precision    recall  f1-score   support

        ADHD       0.61      0.44      0.51       108
     HEALTHY       0.27      0.43      0.33       108
         MDD       0.16      0.11      0.13       108
         OCD       0.21      0.17      0.19      