# AIM: Extract features on labeled data using the pretrained EEGNet
for demeaned epochs

In [1]:
import numpy as np
import pandas as pd
import mne
import lightning.pytorch as pl
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchmetrics import F1Score, Accuracy
import random
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from IPython.display import display
%matplotlib inline


# prevent extensive logging
mne.set_log_level('WARNING')

## Loading epoch data & participant data of labeled sample
These dataframes have been filtered and stored in a previous project. See https://github.com/TSmolders/Internship_EEG for original code

In [2]:

df_participants = pd.read_pickle(r'D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TDBRAIN_participants_V2_data\df_participants.pkl')
sample_df = pd.read_pickle(r'D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TD-BRAIN_extracted_features\df_selected_stat_features.pkl')
sample_ids = sample_df['ID'].unique() # obtain unique IDs from subsampled dataframe containing epoched features
df_sample = df_participants[df_participants['participants_ID'].isin(sample_ids)] # filter participants dataframe to only include subsampled IDs
df_sample = df_sample[df_sample['sessID'] == 1] # filter first session
print(df_sample.shape)
print(df_sample['diagnosis'].value_counts())



(225, 12)
diagnosis
ADHD       45
HEALTHY    45
MDD        45
OCD        45
SMC        45
Name: count, dtype: int64


In [3]:
# functions for loading the epoched EEG data
def get_filepath(epoch_dir, participant_ids):
    """
    Function to get the filepath of the epoched EEG recording
    :param epoch_dir: directory containing the epoched EEG recordings
    :param ID: list of participant IDs to include
    """
    filepaths = []
    for subdir, dirs, files in os.walk(epoch_dir):
        for file in files:
            if any(participant_id in file for participant_id in participant_ids):
                filepaths.append(os.path.join(subdir, file))
    return filepaths

class EpochDataset(torch.utils.data.Dataset):
    def __init__(self, participant_ids, epoch_dir):
        self.filepaths = get_filepath(epoch_dir, participant_ids)
        self.participant_ids = participant_ids
        self.epochs = []
        self.participant_ids = []
        self._load_data()
        print(f"Number of epochs: {self.epochs.shape[0]}")
        print(f"Number of participants: {len(self.participant_ids)}")

    def _load_data(self):
        all_epochs = []
        for filepath in self.filepaths:
            epochs = torch.load(filepath)
            # get participant ID from filepath to make sure the participant ID is correct
            participant_id = filepath.split("\\")[-1].split(".")[0]
            all_epochs.append(epochs)
            self.participant_ids.extend([participant_id]*epochs.shape[0])
        self.epochs = np.concatenate(all_epochs, axis=0)

    def __len__(self):
        return self.epochs.shape[0]
    
    def __getitem__(self, idx):
        epoch = self.epochs[idx]
        participant_id = self.participant_ids[idx]
        return torch.tensor(epoch, dtype=torch.float32), participant_id

In [4]:
# load the epochs into a dataset
participant_ids = df_sample['participants_ID'].tolist()
dataset = EpochDataset(participant_ids, r"D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TDBRAIN-dataset-derivatives\thesis_epoched_data\EC")
print(len(dataset))
print(dataset[0][0].shape)
print(dataset[0][1])
print(dataset[1][1])

Number of epochs: 2688
Number of participants: 2688
2688
torch.Size([26, 1244])
sub-87964717
sub-87964717


## EEGNet architecture

In [5]:
# create Conv2d with max norm constraint
class Conv2dWithConstraint(nn.Conv2d):
    def __init__(self, *args, max_norm: int = 1, **kwargs):
        self.max_norm = max_norm
        super(Conv2dWithConstraint, self).__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.weight.data = torch.renorm(self.weight.data, p=2, dim=0, maxnorm=self.max_norm)
        return super(Conv2dWithConstraint, self).forward(x)
    
class EEGNet(nn.Module):
    """
    Code taken and adjusted from pytorch implementation of EEGNet
    url: https://github.com/torcheeg/torcheeg/blob/v1.1.0/torcheeg/models/cnn/eegnet.py#L5
    """
    def __init__(self,
                 chunk_size: int = 1244, # number of data points in each EEG chunk
                 num_electrodes: int = 26, # number of EEG electrodes
                 F1: int = 8, # number of filters in first convolutional layer
                 F2: int = 16, # number of filters in second convolutional layer
                 D: int = 2, # depth multiplier
                 num_extracted_features: int = 100, # number of features to extract
                 kernel_1: int = 64, # the filter size of block 1 (half of sfreq (125 Hz))
                 kernel_2: int = 16, # the filter size of block 2 (one eight of sfreq (500 Hz))
                 dropout: float = 0.25): # dropout rate
        super(EEGNet, self).__init__()
        self.F1 = F1
        self.F2 = F2
        self.D = D
        self.chunk_size = chunk_size
        self.num_extracted_features = num_extracted_features
        self.num_electrodes = num_electrodes
        self.kernel_1 = kernel_1
        self.kernel_2 = kernel_2
        self.dropout = dropout

        self.block1 = nn.Sequential(
            nn.Conv2d(1, self.F1, (1, self.kernel_1), stride=1, padding=(0, self.kernel_1 // 2), bias=False),
            nn.BatchNorm2d(self.F1, momentum=0.01, affine=True, eps=1e-3),
            Conv2dWithConstraint(self.F1,
                                 self.F1 * self.D, (self.num_electrodes, 1),
                                 max_norm=1,
                                 stride=1,
                                 padding=(0, 0),
                                 groups=self.F1,
                                 bias=False), nn.BatchNorm2d(self.F1 * self.D, momentum=0.01, affine=True, eps=1e-3),
            nn.ELU(), nn.AvgPool2d((1, 4), stride=4), nn.Dropout(p=dropout))

        self.block2 = nn.Sequential(
            nn.Conv2d(self.F1 * self.D,
                      self.F1 * self.D, (1, self.kernel_2),
                      stride=1,
                      padding=(0, self.kernel_2 // 2),
                      bias=False,
                      groups=self.F1 * self.D),
            nn.Conv2d(self.F1 * self.D, self.F2, 1, padding=(0, 0), groups=1, bias=False, stride=1),
            nn.BatchNorm2d(self.F2, momentum=0.01, affine=True, eps=1e-3), nn.ELU(), nn.AvgPool2d((1, 8), stride=8),
            nn.Dropout(p=dropout))

        self.lin = nn.Linear(self.feature_dim(), num_extracted_features, bias=False)


    def feature_dim(self):
        # function to calculate the number of features after the convolutional blocks
        with torch.no_grad():
            mock_eeg = torch.zeros(1, 1, self.num_electrodes, self.chunk_size)

            mock_eeg = self.block1(mock_eeg)
            mock_eeg = self.block2(mock_eeg)

        return self.F2 * mock_eeg.shape[3]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block1(x)
        x = self.block2(x)
        x = x.flatten(start_dim=1)
        x = self.lin(x)
        return x

## Transfering pretrained weights & extracting features

### Functions:

In [6]:
def transfer_weights(pretrained_weights, pretext_model=EEGNet()):
    """
    Function to transfer the pretrained weights to the pretext model
    param: pretrained_weights: the weights to transfer in a dictionary
    param: pretext_model: the model to transfer the weights to
    """
    pretrained_model = pretext_model
    modified_keys = {}
    for k, v in pretrained_weights.items():
        decomposed_key = k.split('.')
        if decomposed_key[0] == 'EEGNet': # remove the first part of each key to match the model's keys
            pretrained_key = '.'.join(decomposed_key[1:])
            modified_keys[pretrained_key] = v

            
    pretrained_model.load_state_dict(modified_keys)
        
    return pretrained_model

def extract_features(pretrained_model, data, pretext_task, df_sample, to_disk=False):
    """
    Function to extract features from the pretrained model
    param: pretrained_model: the model to extract features from
    param: data: the dataset containing the epochs to extract features from
    param: pretext_task: a string indicating the specific pretext task to save the features as
    param: df_sample: the dataframe containing the sampled participant IDs and their corresponding diagnosis
    param: to_disk: boolean to save the features to disk
    """
    dataloader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)
    pretrained_model.eval()
    features_list = []
    participant_ids = []
    with torch.no_grad():  # Disable gradient calculation
        for batch in dataloader:
            epoch, participant_id = batch  # Remove the batch dimension
            epoch = epoch.unsqueeze(0)  # Add dimension
            # print(epoch.shape)
            features = pretrained_model(epoch)  # Extract features
            features = features.squeeze(0)
            features = features.numpy()
            features_list.append(features)
            participant_ids.append(participant_id[0])

    
    features_df = pd.DataFrame(features_list) # store as dataframe
    features_df['ID'] = participant_ids # add participant IDs to the dataframe
    # map the diagnosis values from df_sample to the dataframe based on participant IDs
    features_df['diagnosis'] = features_df['ID'].map(df_sample.set_index('participants_ID')['diagnosis'])
    
    print(f'{features_df.shape = }')
    display(features_df.head(3))


    if to_disk:
        features_df.to_pickle(f'D:/Documents/Master_Data_Science/Thesis/thesis_code/DataScience_Thesis/data/SSL_features/df_{pretext_task}_features.pkl')
    
    return features_df

def evaluate_features(features_df):
    """
    Function to quickly evaluate the extracted features. Doesn't stratify/group data splitting! 
    """
    X = features_df.drop(['ID', 'diagnosis'], axis=1)
    y = features_df['diagnosis']

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

    # quick SVM model
    clf = SVC()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print('quick SVM model')
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print(f1_score(y_test, y_pred, average='macro'))
    print()

    # quick random forest model
    clf = RandomForestClassifier()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print('quick random forest model')
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print(f1_score(y_test, y_pred, average='macro'))

    return

def get_ssl_features(
        pretext_task,
        data,
        df_sample,
        pretext_model=EEGNet(),
        eval=True,
        to_disk=True,
        weights_dir=r'D:\Documents\Master_Data_Science\Thesis\thesis_code\DataScience_Thesis\data\pretext_model_weights'
                     ):
    """
    Obtains SSL features from the weights trained by the pretext model.
    param: pretext_task: a string indicating the specific pretext task to load the weights from and save the features as
    param: data: the dataset containing the epochs to extract features from
    param: df_sample: the dataframe containing the sampled participant IDs and their corresponding diagnosis
    param: pretext_model: the pretext model to transfer the weights to
    param: eval: boolean to evaluate the features
    param: to_disk: boolean to save the features to disk
    """
    pretrained_weights = torch.load(f'{weights_dir}\{pretext_task}_weights.pt')
    pretrained_model = transfer_weights(pretrained_weights, pretext_model)
    features_df = extract_features(pretrained_model, data, pretext_task, df_sample=df_sample, to_disk=to_disk)
    if eval:
        evaluate_features(features_df)
    
    return

### randomly initialized model

In [7]:
features_df = extract_features(EEGNet(), dataset, df_sample=df_sample, pretext_task='random', to_disk=False)
evaluate_features(features_df)

features_df.shape = (2688, 102)
features_df.head(3) =           0         1         2         3         4         5         6  \
0 -0.004536  0.320280  0.405812  0.039802 -0.308949 -0.257266  0.252949   
1 -0.168716  0.159646  0.369122  0.241846 -0.294710 -0.187903  0.017133   
2 -0.021149  0.256763  0.265521  0.361246 -0.253374 -0.083487 -0.051087   

          7         8         9  ...        92        93        94        95  \
0  0.060935 -0.188668 -0.008633  ...  0.060558 -0.385546 -0.374811  0.350624   
1 -0.083962  0.277266 -0.150524  ... -0.394782 -0.591801 -0.386088  0.295578   
2  0.141549 -0.148216 -0.296685  ... -0.344894 -0.313566 -0.437223  0.387534   

         96        97        98        99            ID  diagnosis  
0 -0.133003 -0.174662 -0.396049  0.490404  sub-87964717        SMC  
1 -0.153409 -0.321951  0.115503  0.097863  sub-87964717        SMC  
2 -0.250683 -0.171399 -0.170894  0.026129  sub-87964717        SMC  

[3 rows x 102 columns]
quick SVM model
[[44 15 

### pretext model with default parameters (0.25 dropout, 0 weight decay, binary cross entropy loss)

In [11]:
# best model checkpoint
get_ssl_features('acrossRP_default_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,0.121063,-2.412203,-0.631755,0.131576,0.145455,-1.03796,0.999677,-0.510573,0.570545,-0.796339,...,-0.204599,-0.080658,-0.796637,1.746416,0.229279,1.084207,0.164622,0.126958,sub-87964717,SMC
1,-0.689976,1.406009,-0.107374,-0.123132,-1.165247,0.258329,0.123246,0.44922,-0.620026,0.380591,...,0.010307,-0.489512,-0.351726,1.238766,-0.106024,-0.274885,-0.23567,0.027215,sub-87964717,SMC
2,-0.524933,-0.436216,0.495016,0.564053,1.216422,-0.075507,-0.430498,-0.623051,-0.569874,-1.174702,...,0.231564,-1.004634,-0.554257,-0.129656,-0.976007,0.629619,-1.015885,1.263301,sub-87964717,SMC


quick SVM model
[[54 14 10  8 20]
 [ 9 29 23 20 27]
 [13 21 18 24 32]
 [ 8 24 14 26 36]
 [ 6 15 14 16 57]]
              precision    recall  f1-score   support

        ADHD       0.60      0.51      0.55       106
     HEALTHY       0.28      0.27      0.27       108
         MDD       0.23      0.17      0.19       108
         OCD       0.28      0.24      0.26       108
         SMC       0.33      0.53      0.41       108

    accuracy                           0.34       538
   macro avg       0.34      0.34      0.34       538
weighted avg       0.34      0.34      0.34       538

0.33659677869040294

quick random forest model
[[51 20 11 13 11]
 [20 32 20 15 21]
 [16 24 19 22 27]
 [16 26 15 23 28]
 [ 5 16 30 17 40]]
              precision    recall  f1-score   support

        ADHD       0.47      0.48      0.48       106
     HEALTHY       0.27      0.30      0.28       108
         MDD       0.20      0.18      0.19       108
         OCD       0.26      0.21      0.23      

In [12]:
# overtrained model
get_ssl_features('overtrained_acrossRP_default_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,-0.200007,-2.224988,-0.764491,0.012763,-0.061264,-0.762751,0.821297,-0.421128,0.748596,-0.467404,...,0.036675,-0.098485,-0.552036,1.752755,0.252039,0.89179,-0.069077,0.01054,sub-87964717,SMC
1,-0.507994,1.191808,-0.125921,-0.18772,-1.447353,0.048781,0.017578,0.436701,-0.42518,0.148685,...,-0.054039,-0.702327,0.19098,1.35384,-0.531177,-0.043534,-0.046193,-0.341877,sub-87964717,SMC
2,-0.675697,-0.502504,0.239002,0.477379,1.267334,0.210188,-0.723624,-0.302985,-0.696387,-0.656995,...,0.20822,-0.60582,-0.668775,0.11345,-0.751433,0.713029,-1.132169,1.226705,sub-87964717,SMC


quick SVM model
[[52 16  9 12 17]
 [11 34 19 17 27]
 [14 21 17 25 31]
 [ 9 23 15 26 35]
 [ 8 16 11 14 59]]
              precision    recall  f1-score   support

        ADHD       0.55      0.49      0.52       106
     HEALTHY       0.31      0.31      0.31       108
         MDD       0.24      0.16      0.19       108
         OCD       0.28      0.24      0.26       108
         SMC       0.35      0.55      0.43       108

    accuracy                           0.35       538
   macro avg       0.35      0.35      0.34       538
weighted avg       0.34      0.35      0.34       538

0.3410578523880901

quick random forest model
[[47 23 16 10 10]
 [20 25 26 20 17]
 [17 25 17 21 28]
 [13 21 23 25 26]
 [ 9 16 23 21 39]]
              precision    recall  f1-score   support

        ADHD       0.44      0.44      0.44       106
     HEALTHY       0.23      0.23      0.23       108
         MDD       0.16      0.16      0.16       108
         OCD       0.26      0.23      0.24       

### pretext model with (0.5 dropout, 0.01 weight decay, binary cross entropy loss) trained for 300 epochs

In [7]:
# best model checkpoint
get_ssl_features('acrossRP_0.5dropout_0.01wd_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,-7.461343999999999e-20,-1.116169e-24,-4.65838e-23,-8.359894e-25,-2.340574e-24,2.629246e-21,1.418604e-21,2.707291e-23,1.916131e-24,-1.039791e-24,...,-5.486037e-20,9.684058e-24,-6.967804e-25,-1.199987e-21,-1.35138e-24,1.955244e-24,1.6401140000000002e-25,-5.0752850000000007e-23,sub-87964717,SMC
1,-7.461353e-20,-1.116169e-24,-4.658385e-23,-8.359892e-25,-2.340575e-24,2.6292489999999998e-21,1.418605e-21,2.7072950000000003e-23,1.916129e-24,-1.0397979999999999e-24,...,-5.486043e-20,9.684098999999999e-24,-6.967818e-25,-1.1999879999999999e-21,-1.35138e-24,1.955244e-24,1.64011e-25,-5.0752860000000003e-23,sub-87964717,SMC
2,-7.461351e-20,-1.116174e-24,-4.6583880000000003e-23,-8.359868e-25,-2.340578e-24,2.629248e-21,1.418605e-21,2.7072780000000005e-23,1.916124e-24,-1.039802e-24,...,-5.486041999999999e-20,9.684058e-24,-6.967835000000001e-25,-1.1999879999999999e-21,-1.351378e-24,1.955241e-24,1.6401110000000001e-25,-5.075278e-23,sub-87964717,SMC


quick SVM model
[[ 0  1  0 32 73]
 [ 0  0  0 28 80]
 [ 0  0  0 28 80]
 [ 0  0  0 28 80]
 [ 0  0  0 25 83]]
              precision    recall  f1-score   support

        ADHD       0.00      0.00      0.00       106
     HEALTHY       0.00      0.00      0.00       108
         MDD       0.00      0.00      0.00       108
         OCD       0.20      0.26      0.22       108
         SMC       0.21      0.77      0.33       108

    accuracy                           0.21       538
   macro avg       0.08      0.21      0.11       538
weighted avg       0.08      0.21      0.11       538

0.11085293555173073



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


quick random forest model
[[  0   0 106   0   0]
 [  0   0 108   0   0]
 [  0   0 108   0   0]
 [  0   0 108   0   0]
 [  0   0 108   0   0]]
              precision    recall  f1-score   support

        ADHD       0.00      0.00      0.00       106
     HEALTHY       0.00      0.00      0.00       108
         MDD       0.20      1.00      0.33       108
         OCD       0.00      0.00      0.00       108
         SMC       0.00      0.00      0.00       108

    accuracy                           0.20       538
   macro avg       0.04      0.20      0.07       538
weighted avg       0.04      0.20      0.07       538

0.06687306501547988


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [8]:
# overtrained model
get_ssl_features('overtrained_acrossRP_0.5dropout_0.01wd_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,1.760137e-29,-7.355686e-35,-3.2918880000000004e-33,-3.866369e-35,1.79805e-34,2.489792e-31,2.230983e-31,-4.361544e-33,-8.405089999999999e-34,1.173681e-33,...,1.4946410000000002e-29,1.441292e-33,2.56255e-35,-1.265104e-31,6.304309e-35,-4.616614e-35,1.028053e-35,5.29803e-33,sub-87964717,SMC
1,1.760137e-29,-7.355686e-35,-3.2918880000000004e-33,-3.866369e-35,1.79805e-34,2.489792e-31,2.230983e-31,-4.361544e-33,-8.405089999999999e-34,1.173681e-33,...,1.4946410000000002e-29,1.441292e-33,2.56255e-35,-1.265104e-31,6.304309e-35,-4.616614e-35,1.028053e-35,5.29803e-33,sub-87964717,SMC
2,1.760137e-29,-7.355686e-35,-3.2918880000000004e-33,-3.866369e-35,1.79805e-34,2.489792e-31,2.230983e-31,-4.361544e-33,-8.405089999999999e-34,1.173681e-33,...,1.4946410000000002e-29,1.441292e-33,2.56255e-35,-1.265104e-31,6.304309e-35,-4.616614e-35,1.028053e-35,5.29803e-33,sub-87964717,SMC


quick SVM model
[[  0   0   0   0 106]
 [  0   0   0   0 108]
 [  0   0   0   0 108]
 [  0   0   0   0 108]
 [  0   0   0   0 108]]
              precision    recall  f1-score   support

        ADHD       0.00      0.00      0.00       106
     HEALTHY       0.00      0.00      0.00       108
         MDD       0.00      0.00      0.00       108
         OCD       0.00      0.00      0.00       108
         SMC       0.20      1.00      0.33       108

    accuracy                           0.20       538
   macro avg       0.04      0.20      0.07       538
weighted avg       0.04      0.20      0.07       538

0.06687306501547988



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


quick random forest model
[[  0   0   0   0 106]
 [  0   0   0   0 108]
 [  0   0   0   0 108]
 [  0   0   0   0 108]
 [  0   0   0   0 108]]
              precision    recall  f1-score   support

        ADHD       0.00      0.00      0.00       106
     HEALTHY       0.00      0.00      0.00       108
         MDD       0.00      0.00      0.00       108
         OCD       0.00      0.00      0.00       108
         SMC       0.20      1.00      0.33       108

    accuracy                           0.20       538
   macro avg       0.04      0.20      0.07       538
weighted avg       0.04      0.20      0.07       538

0.06687306501547988


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


### pretext model with tpos=2 (0.25)

In [9]:
# best model checkpoint
get_ssl_features('acrossRP_tpos2_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,0.244671,0.785922,0.24382,0.071878,0.22932,-1.181531,0.79153,0.621222,1.883599,0.07241,...,0.30511,-0.852624,2.626107,-1.362296,1.051692,0.067965,0.142839,0.593177,sub-87964717,SMC
1,-0.866886,-0.009283,0.107281,-0.960699,-0.513096,-0.627648,-1.952809,0.5555,0.240736,0.506544,...,0.089955,-1.535046,-0.032849,-0.690623,1.025768,-0.561509,0.964592,1.262896,sub-87964717,SMC
2,0.085198,0.510124,-0.556894,0.513048,0.194047,0.597693,0.020339,1.538817,-0.674133,0.583213,...,1.215022,0.519565,1.288856,0.170691,-0.963983,0.235204,-0.356454,-0.388962,sub-87964717,SMC


quick SVM model
[[52 13 13 15 13]
 [13 32 19 20 24]
 [19 10 24 21 34]
 [14 21 19 20 34]
 [ 3 17 20 25 43]]
              precision    recall  f1-score   support

        ADHD       0.51      0.49      0.50       106
     HEALTHY       0.34      0.30      0.32       108
         MDD       0.25      0.22      0.24       108
         OCD       0.20      0.19      0.19       108
         SMC       0.29      0.40      0.34       108

    accuracy                           0.32       538
   macro avg       0.32      0.32      0.32       538
weighted avg       0.32      0.32      0.32       538

0.3169203361830518

quick random forest model
[[52 22 14  9  9]
 [15 26 20 23 24]
 [19 20 22 23 24]
 [11 23 20 24 30]
 [ 9 18 24 24 33]]
              precision    recall  f1-score   support

        ADHD       0.49      0.49      0.49       106
     HEALTHY       0.24      0.24      0.24       108
         MDD       0.22      0.20      0.21       108
         OCD       0.23      0.22      0.23       

In [11]:
# overtrained model
get_ssl_features('overtrained_acrossRP_tpos2_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,0.258715,0.919693,0.405137,-0.342024,0.540094,-0.559412,0.457357,0.733314,1.610013,-0.037336,...,0.207041,-0.422054,2.468134,-1.928622,1.343169,0.196285,0.277656,0.392743,sub-87964717,SMC
1,-1.356684,-0.19582,0.527072,-0.849958,-0.138088,-1.295241,-2.01638,1.184622,-0.049766,0.466154,...,0.217456,-1.426862,0.155748,-0.632025,0.832715,-0.514698,0.894622,1.221896,sub-87964717,SMC
2,-0.180131,0.633555,-0.703599,0.379585,0.542985,0.528939,-0.33577,1.789057,-0.356298,0.45699,...,1.203158,0.45434,1.396475,0.147737,-0.983699,0.806666,-0.501196,-0.548008,sub-87964717,SMC


quick SVM model
[[50 14 12 16 14]
 [13 31 18 23 23]
 [14 14 25 27 28]
 [12 19 19 22 36]
 [ 4 18 18 21 47]]
              precision    recall  f1-score   support

        ADHD       0.54      0.47      0.50       106
     HEALTHY       0.32      0.29      0.30       108
         MDD       0.27      0.23      0.25       108
         OCD       0.20      0.20      0.20       108
         SMC       0.32      0.44      0.37       108

    accuracy                           0.33       538
   macro avg       0.33      0.33      0.33       538
weighted avg       0.33      0.33      0.32       538

0.32527732168000933

quick random forest model
[[51 20  9 15 11]
 [20 28 26 16 18]
 [16 17 27 28 20]
 [14 22 17 26 29]
 [ 8 20 21 25 34]]
              precision    recall  f1-score   support

        ADHD       0.47      0.48      0.47       106
     HEALTHY       0.26      0.26      0.26       108
         MDD       0.27      0.25      0.26       108
         OCD       0.24      0.24      0.24      

### pretext model with soft margin loss (0.25 dropout)

In [7]:
# best model checkpoint
get_ssl_features('acrossRP_soft_margin_loss_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,0.571377,1.348506,1.306309,-0.039131,-0.679116,0.430812,0.206737,-0.845433,0.675838,0.154717,...,-0.068997,0.613388,1.290965,-0.000601,0.974056,-1.492894,-0.053755,-0.535912,sub-87964717,SMC
1,-0.132364,0.804168,1.05046,-0.905266,0.033737,0.219633,0.700095,-0.56442,-0.493905,0.182286,...,1.280335,1.079072,-0.893294,0.264488,-0.088887,-0.256551,-0.268115,1.45727,sub-87964717,SMC
2,0.25646,-0.461006,0.935647,0.111506,0.422974,-0.991261,0.233257,-0.667897,0.110081,-0.05823,...,-0.722836,-0.620188,0.466873,-1.125999,-0.319528,-0.976971,0.373419,-0.176336,sub-87964717,SMC


quick SVM model
[[52 16 10  6 22]
 [21 26 15 16 30]
 [14 10 20 24 40]
 [11 18 20 21 38]
 [ 4 13 20 20 51]]
              precision    recall  f1-score   support

        ADHD       0.51      0.49      0.50       106
     HEALTHY       0.31      0.24      0.27       108
         MDD       0.24      0.19      0.21       108
         OCD       0.24      0.19      0.22       108
         SMC       0.28      0.47      0.35       108

    accuracy                           0.32       538
   macro avg       0.32      0.32      0.31       538
weighted avg       0.32      0.32      0.31       538

0.309566197353218

quick random forest model
[[54 10 15 12 15]
 [17 33 16 24 18]
 [16 18 28 26 20]
 [17 26 22 21 22]
 [ 9 18 22 24 35]]
              precision    recall  f1-score   support

        ADHD       0.48      0.51      0.49       106
     HEALTHY       0.31      0.31      0.31       108
         MDD       0.27      0.26      0.27       108
         OCD       0.20      0.19      0.20       1

In [9]:
# overtrained model
get_ssl_features('overtrained_acrossRP_soft_margin_loss_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,0.571377,1.348506,1.306309,-0.039131,-0.679116,0.430812,0.206737,-0.845433,0.675838,0.154717,...,-0.068997,0.613388,1.290965,-0.000601,0.974056,-1.492894,-0.053755,-0.535912,sub-87964717,SMC
1,-0.132364,0.804168,1.05046,-0.905266,0.033737,0.219633,0.700095,-0.56442,-0.493905,0.182286,...,1.280335,1.079072,-0.893294,0.264488,-0.088887,-0.256551,-0.268115,1.45727,sub-87964717,SMC
2,0.25646,-0.461006,0.935647,0.111506,0.422974,-0.991261,0.233257,-0.667897,0.110081,-0.05823,...,-0.722836,-0.620188,0.466873,-1.125999,-0.319528,-0.976971,0.373419,-0.176336,sub-87964717,SMC


quick SVM model
[[52 16 10  6 22]
 [21 26 15 16 30]
 [14 10 20 24 40]
 [11 18 20 21 38]
 [ 4 13 20 20 51]]
              precision    recall  f1-score   support

        ADHD       0.51      0.49      0.50       106
     HEALTHY       0.31      0.24      0.27       108
         MDD       0.24      0.19      0.21       108
         OCD       0.24      0.19      0.22       108
         SMC       0.28      0.47      0.35       108

    accuracy                           0.32       538
   macro avg       0.32      0.32      0.31       538
weighted avg       0.32      0.32      0.31       538

0.309566197353218

quick random forest model
[[52 11  9 18 16]
 [18 25 24 18 23]
 [12 21 20 27 28]
 [13 19 26 15 35]
 [ 5 21 22 22 38]]
              precision    recall  f1-score   support

        ADHD       0.52      0.49      0.50       106
     HEALTHY       0.26      0.23      0.24       108
         MDD       0.20      0.19      0.19       108
         OCD       0.15      0.14      0.14       1