# AIM: Extract features on labeled data using the pretrained EEGNet
across subject pretext task

In [1]:
import numpy as np
import pandas as pd
import mne
import lightning.pytorch as pl
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchmetrics import F1Score, Accuracy
import random
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, StratifiedGroupKFold
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from IPython.display import display
%matplotlib inline


# prevent extensive logging
mne.set_log_level('WARNING')

## Loading epoch data & participant data of labeled sample
These dataframes have been filtered and stored in a previous project. See https://github.com/TSmolders/Internship_EEG for original code

In [2]:

df_participants = pd.read_pickle(r'D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TDBRAIN_participants_V2_data\df_participants.pkl')
sample_df = pd.read_pickle(r'D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TD-BRAIN_extracted_features\df_selected_stat_features.pkl')
sample_ids = sample_df['ID'].unique() # obtain unique IDs from subsampled dataframe containing epoched features
df_sample = df_participants[df_participants['participants_ID'].isin(sample_ids)] # filter participants dataframe to only include subsampled IDs
df_sample = df_sample[df_sample['sessID'] == 1] # filter first session
print(df_sample.shape)
print(df_sample['diagnosis'].value_counts())



(225, 12)
diagnosis
ADHD       45
HEALTHY    45
MDD        45
OCD        45
SMC        45
Name: count, dtype: int64


In [3]:
# functions for loading the epoched EEG data
def get_filepath(epoch_dir, participant_ids):
    """
    Function to get the filepath of the epoched EEG recording
    :param epoch_dir: directory containing the epoched EEG recordings
    :param ID: list of participant IDs to include
    """
    filepaths = []
    for subdir, dirs, files in os.walk(epoch_dir):
        for file in files:
            if any(participant_id in file for participant_id in participant_ids):
                filepaths.append(os.path.join(subdir, file))
    return filepaths

class EpochDataset(torch.utils.data.Dataset):
    def __init__(self, participant_ids, epoch_dir):
        self.filepaths = get_filepath(epoch_dir, participant_ids)
        self.participant_ids = participant_ids
        self.epochs = []
        self.participant_ids = []
        self._load_data()
        print(f"Number of epochs: {self.epochs.shape[0]}")
        print(f"Number of participants: {len(self.participant_ids)}")

    def _load_data(self):
        all_epochs = []
        for filepath in self.filepaths:
            epochs = torch.load(filepath)
            # get participant ID from filepath to make sure the participant ID is correct
            participant_id = filepath.split("\\")[-1].split(".")[0]
            all_epochs.append(epochs)
            self.participant_ids.extend([participant_id]*epochs.shape[0])
        self.epochs = np.concatenate(all_epochs, axis=0)

    def __len__(self):
        return self.epochs.shape[0]
    
    def __getitem__(self, idx):
        epoch = self.epochs[idx]
        participant_id = self.participant_ids[idx]
        return torch.tensor(epoch, dtype=torch.float32), participant_id

In [4]:
# load the epochs into a dataset
participant_ids = df_sample['participants_ID'].tolist()
dataset = EpochDataset(participant_ids, r"D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TDBRAIN-dataset-derivatives\thesis_epoched_data\EC")
print(len(dataset))
print(dataset[0][0].shape)
print(dataset[0][1])
print(dataset[1][1])

Number of epochs: 2688
Number of participants: 2688
2688
torch.Size([26, 1244])
sub-87964717
sub-87964717


## Transfering pretrained weights & extracting features

### Functions:

In [5]:
def transfer_weights(pretrained_weights, pretext_model):
    """
    Function to transfer the pretrained weights to the pretext model
    param: pretrained_weights: the weights to transfer in a dictionary
    param: pretext_model: the model to transfer the weights to
    """
    pretrained_model = pretext_model
    modified_keys = {}
    for k, v in pretrained_weights.items():
        decomposed_key = k.split('.')
        if decomposed_key[0] == 'EEGNet': # remove the first part of each key to match the model's keys
            pretrained_key = '.'.join(decomposed_key[1:])
            modified_keys[pretrained_key] = v

            
    pretrained_model.load_state_dict(modified_keys)
        
    return pretrained_model

def extract_features(pretrained_model, data, pretext_task, df_sample, to_disk=False):
    """
    Function to extract features from the pretrained model
    param: pretrained_model: the model to extract features from
    param: data: the dataset containing the epochs to extract features from
    param: pretext_task: a string indicating the specific pretext task to save the features as
    param: df_sample: the dataframe containing the sampled participant IDs and their corresponding diagnosis
    param: to_disk: boolean to save the features to disk
    """
    dataloader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)
    pretrained_model.eval()
    features_list = []
    participant_ids = []
    with torch.no_grad():  # Disable gradient calculation
        for batch in dataloader:
            epoch, participant_id = batch  # Remove the batch dimension
            epoch = epoch.unsqueeze(0)  # Add dimension
            # print(epoch.shape)
            features = pretrained_model(epoch)  # Extract features
            features = features.squeeze(0)
            features = features.numpy()
            features_list.append(features)
            participant_ids.append(participant_id[0])

    
    features_df = pd.DataFrame(features_list) # store as dataframe
    features_df['ID'] = participant_ids # add participant IDs to the dataframe
    # map the diagnosis values from df_sample to the dataframe based on participant IDs
    features_df['diagnosis'] = features_df['ID'].map(df_sample.set_index('participants_ID')['diagnosis'])
    
    print(f'{features_df.shape = }')
    display(features_df.head(3))


    if to_disk:
        features_df.to_pickle(f'D:/Documents/Master_Data_Science/Thesis/thesis_code/DataScience_Thesis/data/SSL_features/df_{pretext_task}_features.pkl')
    
    return features_df

def evaluate_features(features_df):
    """
    Function to quickly evaluate the extracted features. Doesn't stratify/group data splitting! 
    """
    groups = features_df['ID']
    X = features_df.drop(['ID', 'diagnosis'], axis=1)
    y = features_df['diagnosis']

    sgkf = StratifiedGroupKFold(n_splits=5, shuffle=False)

    # Get the train and test indices
    for train_index, test_index in sgkf.split(X, y, groups):
        X_train, X_test = X.iloc[train_index], X.iloc[test_index]
        y_train, y_test = y.iloc[train_index], y.iloc[test_index]
        break  # Only use the first split

    # quick SVM model
    clf = SVC()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print('quick SVM model')
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print(f1_score(y_test, y_pred, average='macro'))
    print()

    # quick random forest model
    clf = RandomForestClassifier()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print('quick random forest model')
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print(f1_score(y_test, y_pred, average='macro'))

    return

def get_ssl_features(
        pretext_task,
        data,
        df_sample,
        num_extracted_features=100,
        eval=True,
        to_disk=True,
        weights_dir=r'D:\Documents\Master_Data_Science\Thesis\thesis_code\DataScience_Thesis\data\pretext_model_weights'
                     ):
    """
    Obtains SSL features from the weights trained by the pretext model.
    param: pretext_task: a string indicating the specific pretext task to load the weights from and save the features as
    param: data: the dataset containing the epochs to extract features from
    param: df_sample: the dataframe containing the sampled participant IDs and their corresponding diagnosis
    param: num_extracted_features: the number of features to extract
    param: eval: boolean to evaluate the features
    param: to_disk: boolean to save the features to disk
    """
    # Need to define model class here to avoid issues with different number of extracted features
    # create Conv2d with max norm constraint
    class Conv2dWithConstraint(nn.Conv2d):
        def __init__(self, *args, max_norm: int = 1, **kwargs):
            self.max_norm = max_norm
            super(Conv2dWithConstraint, self).__init__(*args, **kwargs)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            self.weight.data = torch.renorm(self.weight.data, p=2, dim=0, maxnorm=self.max_norm)
            return super(Conv2dWithConstraint, self).forward(x)
        
    class EEGNet(nn.Module):
        """
        Code taken and adjusted from pytorch implementation of EEGNet
        url: https://github.com/torcheeg/torcheeg/blob/v1.1.0/torcheeg/models/cnn/eegnet.py#L5
        """
        def __init__(self,
                    chunk_size: int = 1244, # number of data points in each EEG chunk
                    num_electrodes: int = 26, # number of EEG electrodes
                    F1: int = 8, # number of filters in first convolutional layer
                    F2: int = 16, # number of filters in second convolutional layer
                    D: int = 2, # depth multiplier
                    num_extracted_features: int = num_extracted_features, # number of features to extract
                    kernel_1: int = 64, # the filter size of block 1 (half of sfreq (125 Hz))
                    kernel_2: int = 16, # the filter size of block 2 (one eight of sfreq (500 Hz))
                    dropout: float = 0.25): # dropout rate
            super(EEGNet, self).__init__()
            self.F1 = F1
            self.F2 = F2
            self.D = D
            self.chunk_size = chunk_size
            self.num_extracted_features = num_extracted_features
            self.num_electrodes = num_electrodes
            self.kernel_1 = kernel_1
            self.kernel_2 = kernel_2
            self.dropout = dropout

            self.block1 = nn.Sequential(
                nn.Conv2d(1, self.F1, (1, self.kernel_1), stride=1, padding=(0, self.kernel_1 // 2), bias=False),
                nn.BatchNorm2d(self.F1, momentum=0.01, affine=True, eps=1e-3),
                Conv2dWithConstraint(self.F1,
                                    self.F1 * self.D, (self.num_electrodes, 1),
                                    max_norm=1,
                                    stride=1,
                                    padding=(0, 0),
                                    groups=self.F1,
                                    bias=False), nn.BatchNorm2d(self.F1 * self.D, momentum=0.01, affine=True, eps=1e-3),
                nn.ELU(), nn.AvgPool2d((1, 4), stride=4), nn.Dropout(p=dropout))

            self.block2 = nn.Sequential(
                nn.Conv2d(self.F1 * self.D,
                        self.F1 * self.D, (1, self.kernel_2),
                        stride=1,
                        padding=(0, self.kernel_2 // 2),
                        bias=False,
                        groups=self.F1 * self.D),
                nn.Conv2d(self.F1 * self.D, self.F2, 1, padding=(0, 0), groups=1, bias=False, stride=1),
                nn.BatchNorm2d(self.F2, momentum=0.01, affine=True, eps=1e-3), nn.ELU(), nn.AvgPool2d((1, 8), stride=8),
                nn.Dropout(p=dropout))

            self.lin = nn.Linear(self.feature_dim(), num_extracted_features, bias=False)


        def feature_dim(self):
            # function to calculate the number of features after the convolutional blocks
            with torch.no_grad():
                mock_eeg = torch.zeros(1, 1, self.num_electrodes, self.chunk_size)

                mock_eeg = self.block1(mock_eeg)
                mock_eeg = self.block2(mock_eeg)

            return self.F2 * mock_eeg.shape[3]

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.block1(x)
            x = self.block2(x)
            x = x.flatten(start_dim=1)
            x = self.lin(x)
            return x
        
    pretext_model = EEGNet()
    pretrained_weights = torch.load(f'{weights_dir}\{pretext_task}_weights.pt')
    pretrained_model = transfer_weights(pretrained_weights, pretext_model)
    features_df = extract_features(pretrained_model, data, pretext_task, df_sample=df_sample, to_disk=to_disk)
    if eval:
        evaluate_features(features_df)
    
    return

### randomly initialized model

In [7]:
features_df = extract_features(EEGNet(), dataset, df_sample=df_sample, pretext_task='random', to_disk=False)
evaluate_features(features_df)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,-0.025231,0.558265,0.38643,-0.427448,-0.195549,-0.385354,0.035971,0.385304,0.106288,0.573986,...,0.261263,0.668248,-0.518208,0.242272,0.186938,0.214983,0.118263,-0.321973,sub-87964717,SMC
1,-0.111984,0.977839,0.131731,-0.051678,-0.290045,-0.434035,0.237411,-0.101681,-0.275212,0.762277,...,0.520117,0.624576,-0.485054,0.251544,-0.03708,-0.125105,-0.325259,-0.672581,sub-87964717,SMC
2,-0.033515,0.690026,-0.037421,-0.427968,-0.044509,-0.273484,0.223849,0.226518,0.001657,0.772421,...,0.178688,0.269243,-0.326347,-0.047144,-0.056313,0.331418,0.099657,-0.459115,sub-87964717,SMC


quick SVM model
[[58 17  5 16 12]
 [15 36  9 32 16]
 [13 29 24 21 21]
 [ 0 27 21 33 27]
 [ 3  4 15 26 60]]
              precision    recall  f1-score   support

        ADHD       0.65      0.54      0.59       108
     HEALTHY       0.32      0.33      0.33       108
         MDD       0.32      0.22      0.26       108
         OCD       0.26      0.31      0.28       108
         SMC       0.44      0.56      0.49       108

    accuracy                           0.39       540
   macro avg       0.40      0.39      0.39       540
weighted avg       0.40      0.39      0.39       540

0.3899649803774411

quick random forest model
[[62 14 14 10  8]
 [18 34 20 22 14]
 [22 28 26 15 17]
 [10 24 28 18 28]
 [ 3 23 15 19 48]]
              precision    recall  f1-score   support

        ADHD       0.54      0.57      0.56       108
     HEALTHY       0.28      0.31      0.29       108
         MDD       0.25      0.24      0.25       108
         OCD       0.21      0.17      0.19       

### pretext model with default parameters (0.25 dropout, 0 weight decay, binary cross entropy loss)

In [7]:
# best model checkpoint
get_ssl_features('acrosssub_default_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,3.623887,-1.228523,-1.482568,1.651297,-17.790987,5.814167,-1.637995,7.890889,0.264064,-1.745725,...,1.176101,19.422047,1.221881,9.082914,10.107931,1.013603,-0.151076,-0.010935,sub-87964717,SMC
1,4.06391,-0.249095,-1.333639,0.85253,-16.195347,5.602299,-0.927352,6.921437,-0.603323,-3.448514,...,0.694274,17.92716,2.676285,9.526098,10.338356,1.623796,0.002702,0.939451,sub-87964717,SMC
2,3.290289,0.072605,-0.576546,1.302655,-16.728363,5.784608,-1.060832,6.392365,1.129892,-3.624918,...,1.223866,19.308262,-0.15916,8.258363,11.810682,3.635775,0.800181,1.143464,sub-87964717,SMC


quick SVM model
[[44  9 11 30 14]
 [12 40 11 33 12]
 [16  6 22 24 40]
 [14 33 14 27 20]
 [18  4 31  3 52]]
              precision    recall  f1-score   support

        ADHD       0.42      0.41      0.42       108
     HEALTHY       0.43      0.37      0.40       108
         MDD       0.25      0.20      0.22       108
         OCD       0.23      0.25      0.24       108
         SMC       0.38      0.48      0.42       108

    accuracy                           0.34       540
   macro avg       0.34      0.34      0.34       540
weighted avg       0.34      0.34      0.34       540

0.34024176421440494

quick random forest model
[[63  0 23 12 10]
 [ 2 52 15 29 10]
 [10 12 24 35 27]
 [ 6 30 20 23 29]
 [15  7 16  7 63]]
              precision    recall  f1-score   support

        ADHD       0.66      0.58      0.62       108
     HEALTHY       0.51      0.48      0.50       108
         MDD       0.24      0.22      0.23       108
         OCD       0.22      0.21      0.21      

In [8]:
# overtrained model
get_ssl_features('overtrained_acrosssub_default_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,4.888031,-1.091719,-1.84299,1.659128,-18.793324,7.022872,-1.732366,9.151362,0.017521,-0.97539,...,0.738329,20.309301,1.319816,10.909989,9.26828,-0.507095,-0.809679,0.530718,sub-87964717,SMC
1,5.39201,-0.476683,-1.306184,-0.017972,-17.125385,7.14606,-0.269107,7.905194,-0.789516,-3.939141,...,-0.030765,18.537172,3.2459,11.278497,10.519344,0.110141,-0.592927,1.838493,sub-87964717,SMC
2,3.965634,0.235488,-0.46326,0.626689,-17.413103,6.74487,-1.186025,7.509469,0.89055,-3.323798,...,0.611838,19.963472,0.414842,10.219461,11.495079,3.069146,-0.335685,2.482077,sub-87964717,SMC


quick SVM model
[[44  8  6 34 16]
 [19 36  9 28 16]
 [15  5 22 21 45]
 [ 6 28 17 37 20]
 [16  1 27  9 55]]
              precision    recall  f1-score   support

        ADHD       0.44      0.41      0.42       108
     HEALTHY       0.46      0.33      0.39       108
         MDD       0.27      0.20      0.23       108
         OCD       0.29      0.34      0.31       108
         SMC       0.36      0.51      0.42       108

    accuracy                           0.36       540
   macro avg       0.36      0.36      0.36       540
weighted avg       0.36      0.36      0.36       540

0.3556582280142917

quick random forest model
[[63  0 16 17 12]
 [ 3 61 17 18  9]
 [ 9  2 31 30 36]
 [ 3 25 18 29 33]
 [18  4 18  5 63]]
              precision    recall  f1-score   support

        ADHD       0.66      0.58      0.62       108
     HEALTHY       0.66      0.56      0.61       108
         MDD       0.31      0.29      0.30       108
         OCD       0.29      0.27      0.28       

### pretext model with (0.5 dropout, 0.01 weight decay, binary cross entropy loss) trained for 50 epochs

In [8]:
# best model checkpoint
get_ssl_features('acrosssub_0.5dropout_0.01wd_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,0.110355,0.811385,0.972408,1.268003,0.849162,-1.444711,-0.000281,-0.144839,-0.278361,0.77008,...,1.697409,-0.454907,0.853184,1.153516,-0.163238,1.132012,-1.542204,-0.077861,sub-87964717,SMC
1,-0.032289,0.799907,0.899556,1.262478,0.876555,-1.422669,-0.066222,0.037951,-0.280847,0.757096,...,1.69331,-0.431586,0.814568,1.041446,0.013759,1.129061,-1.500744,0.015565,sub-87964717,SMC
2,0.321387,0.773045,0.829811,1.196555,0.767482,-1.333442,-0.035082,-0.10344,-0.247684,0.693844,...,1.487311,-0.39089,0.752712,0.992581,-0.097055,0.994175,-1.416777,0.161192,sub-87964717,SMC


quick SVM model
[[55 12  0 31 10]
 [11 16  1 42 38]
 [16 21  0 28 43]
 [ 0 34  3 32 39]
 [ 0 12  2 49 45]]
              precision    recall  f1-score   support

        ADHD       0.67      0.51      0.58       108
     HEALTHY       0.17      0.15      0.16       108
         MDD       0.00      0.00      0.00       108
         OCD       0.18      0.30      0.22       108
         SMC       0.26      0.42      0.32       108

    accuracy                           0.27       540
   macro avg       0.25      0.27      0.26       540
weighted avg       0.25      0.27      0.26       540

0.25505873859743794

quick random forest model
[[54 17  8 24  5]
 [15 35 17 21 20]
 [24 26 25 20 13]
 [ 2 37 26 23 20]
 [ 3 20 20 47 18]]
              precision    recall  f1-score   support

        ADHD       0.55      0.50      0.52       108
     HEALTHY       0.26      0.32      0.29       108
         MDD       0.26      0.23      0.25       108
         OCD       0.17      0.21      0.19      

In [10]:
# overtrained model
get_ssl_features('overtrained_acrosssub_0.5dropout_0.01wd_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,0.110355,0.811385,0.972408,1.268003,0.849162,-1.444711,-0.000281,-0.144839,-0.278361,0.77008,...,1.697409,-0.454907,0.853184,1.153516,-0.163238,1.132012,-1.542204,-0.077861,sub-87964717,SMC
1,-0.032289,0.799907,0.899556,1.262478,0.876555,-1.422669,-0.066222,0.037951,-0.280847,0.757096,...,1.69331,-0.431586,0.814568,1.041446,0.013759,1.129061,-1.500744,0.015565,sub-87964717,SMC
2,0.321387,0.773045,0.829811,1.196555,0.767482,-1.333442,-0.035082,-0.10344,-0.247684,0.693844,...,1.487311,-0.39089,0.752712,0.992581,-0.097055,0.994175,-1.416777,0.161192,sub-87964717,SMC


quick SVM model
[[55 12  0 31 10]
 [11 16  1 42 38]
 [16 21  0 28 43]
 [ 0 34  3 32 39]
 [ 0 12  2 49 45]]
              precision    recall  f1-score   support

        ADHD       0.67      0.51      0.58       108
     HEALTHY       0.17      0.15      0.16       108
         MDD       0.00      0.00      0.00       108
         OCD       0.18      0.30      0.22       108
         SMC       0.26      0.42      0.32       108

    accuracy                           0.27       540
   macro avg       0.25      0.27      0.26       540
weighted avg       0.25      0.27      0.26       540

0.25505873859743794

quick random forest model
[[57 15  7 25  4]
 [17 32 16 21 22]
 [25 20 25 17 21]
 [ 3 39 26 24 16]
 [ 6 23 15 41 23]]
              precision    recall  f1-score   support

        ADHD       0.53      0.53      0.53       108
     HEALTHY       0.25      0.30      0.27       108
         MDD       0.28      0.23      0.25       108
         OCD       0.19      0.22      0.20      

### pretext model with trained for 10 epochs

In [6]:
# best model checkpoint
get_ssl_features('acrosssub_10epochs_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,0.48144,0.633112,-1.726813,9.642133,0.69213,-10.080816,1.418434,1.354628,-3.802871,4.347657,...,2.109843,2.293432,4.100114,-5.626955,0.154212,1.305998,-1.614531,-7.114364,sub-87964717,SMC
1,-0.288733,0.207911,0.257113,6.740704,0.79273,-8.717131,0.794609,-1.65943,-2.662497,0.801926,...,-0.072879,2.290982,3.123032,-3.737323,0.073285,1.190753,-0.780681,-5.290376,sub-87964717,SMC
2,0.185403,0.507175,-0.284592,7.127849,0.205872,-10.026775,0.720594,-0.380045,-3.082051,1.919059,...,1.387761,2.440483,1.969689,-4.929652,0.935145,1.276471,-1.320437,-5.335428,sub-87964717,SMC


quick SVM model
[[56  3 29 12  8]
 [12 30 26 23 17]
 [15 21 27 16 29]
 [13 13 37 31 14]
 [15 14 13 17 49]]
              precision    recall  f1-score   support

        ADHD       0.50      0.52      0.51       108
     HEALTHY       0.37      0.28      0.32       108
         MDD       0.20      0.25      0.23       108
         OCD       0.31      0.29      0.30       108
         SMC       0.42      0.45      0.44       108

    accuracy                           0.36       540
   macro avg       0.36      0.36      0.36       540
weighted avg       0.36      0.36      0.36       540

0.35778966126851774

quick random forest model
[[64  1 24 10  9]
 [13 50  8 26 11]
 [20 20 26 22 20]
 [10 26 26 34 12]
 [16 19 19 18 36]]
              precision    recall  f1-score   support

        ADHD       0.52      0.59      0.55       108
     HEALTHY       0.43      0.46      0.45       108
         MDD       0.25      0.24      0.25       108
         OCD       0.31      0.31      0.31      

In [7]:
# overtrained model
get_ssl_features('overtrained_acrosssub_10epochs_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,0.193253,1.199918,0.319841,10.120821,0.304049,-8.346296,1.340301,0.298271,-3.376927,3.718822,...,1.891683,3.434391,2.526464,-3.987168,-0.804768,1.362534,-1.847329,-6.440199,sub-87964717,SMC
1,-0.553735,1.069052,2.289602,8.52388,0.626479,-8.482388,0.357972,-2.737444,-2.118034,0.706097,...,-1.176191,2.789662,2.244703,-3.112911,-1.079146,1.799031,-1.622424,-4.895182,sub-87964717,SMC
2,-0.051257,1.426289,1.955187,8.644217,0.076998,-9.104626,0.291343,-0.411044,-1.611388,1.95562,...,0.741185,2.922426,1.33208,-3.222737,-0.557499,1.172376,-1.824687,-4.568328,sub-87964717,SMC


quick SVM model
[[57  2 29 14  6]
 [13 33 24 25 13]
 [21 17 19 22 29]
 [ 7 16 38 33 14]
 [12 11 16 22 47]]
              precision    recall  f1-score   support

        ADHD       0.52      0.53      0.52       108
     HEALTHY       0.42      0.31      0.35       108
         MDD       0.15      0.18      0.16       108
         OCD       0.28      0.31      0.29       108
         SMC       0.43      0.44      0.43       108

    accuracy                           0.35       540
   macro avg       0.36      0.35      0.35       540
weighted avg       0.36      0.35      0.35       540

0.3532185398650852

quick random forest model
[[62  2 29 11  4]
 [12 45 11 29 11]
 [16 20 29 18 25]
 [14 28 20 32 14]
 [14 20 22 12 40]]
              precision    recall  f1-score   support

        ADHD       0.53      0.57      0.55       108
     HEALTHY       0.39      0.42      0.40       108
         MDD       0.26      0.27      0.26       108
         OCD       0.31      0.30      0.30       

### pretext model 150 features

In [8]:
# best model checkpoint
get_ssl_features('acrosssub_150feat_pretext_model', dataset, df_sample, eval=True, to_disk=True)

FileNotFoundError: [Errno 2] No such file or directory: 'D:\\Documents\\Master_Data_Science\\Thesis\\thesis_code\\DataScience_Thesis\\data\\pretext_model_weights\\acrosssub_150feat_pretext_model_weights.pt'

In [10]:
# overtrained model
get_ssl_features('overtrained_acrosssub_150feat_pretext_model', dataset, df_sample, eval=True, to_disk=True, num_extracted_features=150)

features_df.shape = (2688, 152)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,142,143,144,145,146,147,148,149,ID,diagnosis
0,-2.559961,-3.054284,-2.644992,-6.277209,-0.028564,-0.906803,-21.77343,-0.134492,-1.485628,0.396045,...,1.410881,-1.88887,1.88595,-3.703908,-12.834751,2.617877,13.235068,-1.618843,sub-87964717,SMC
1,-0.068584,-0.779696,-0.959284,-9.155769,0.970993,-0.524335,-19.471977,2.17656,-1.005807,1.004661,...,4.83721,-2.676227,0.850829,-7.057852,-10.173094,1.023762,10.949565,-0.758843,sub-87964717,SMC
2,-0.333431,-1.506322,-3.900646,-7.29097,-0.30578,-2.440059,-18.541176,0.154866,-0.744295,1.786914,...,5.353173,-4.43046,1.096191,-5.224116,-8.479725,2.299004,10.07663,-1.119349,sub-87964717,SMC


quick SVM model
[[53  0 13 31 11]
 [ 7 59 19 15  8]
 [26  9 38 19 16]
 [ 9 35 38 21  5]
 [17 34 13 18 26]]
              precision    recall  f1-score   support

        ADHD       0.47      0.49      0.48       108
     HEALTHY       0.43      0.55      0.48       108
         MDD       0.31      0.35      0.33       108
         OCD       0.20      0.19      0.20       108
         SMC       0.39      0.24      0.30       108

    accuracy                           0.36       540
   macro avg       0.36      0.36      0.36       540
weighted avg       0.36      0.36      0.36       540

0.35845846927937236

quick random forest model
[[59  0 19 24  6]
 [ 7 64  3 29  5]
 [14 16 38 18 22]
 [12 33 33 28  2]
 [17 30  9 18 34]]
              precision    recall  f1-score   support

        ADHD       0.54      0.55      0.54       108
     HEALTHY       0.45      0.59      0.51       108
         MDD       0.37      0.35      0.36       108
         OCD       0.24      0.26      0.25      

### pretext model 50 features

In [16]:
# best model checkpoint
get_ssl_features('acrosssub_50feat_pretext_model', dataset, df_sample, eval=True, to_disk=True, num_extracted_features=50)

features_df.shape = (2688, 52)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,42,43,44,45,46,47,48,49,ID,diagnosis
0,0.532304,3.829793,-0.430381,4.252481,-1.384971,-2.284594,-19.166306,-4.244578,1.721899,-2.508856,...,-1.258278,-1.058604,0.927077,0.233525,0.605244,-0.611597,-1.793311,-0.969478,sub-87964717,SMC
1,1.282263,2.542199,2.460679,4.44435,-0.421117,0.519686,-20.942211,-4.378036,1.536193,-2.673515,...,-3.298463,-1.547187,0.078922,0.734727,-0.862515,-0.289685,-2.435214,-1.241847,sub-87964717,SMC
2,1.07103,0.693853,1.960208,5.386222,0.254885,-2.035523,-23.922855,-3.678423,2.244879,-3.970068,...,-0.40765,-0.171258,0.938032,3.459286,3.073155,-0.338672,-1.132302,0.352085,sub-87964717,SMC


quick SVM model
[[57  0 13 23 15]
 [25 41  3 24 15]
 [22 35 19 24  8]
 [ 7 23 28 42  8]
 [12 28  3 16 49]]
              precision    recall  f1-score   support

        ADHD       0.46      0.53      0.49       108
     HEALTHY       0.32      0.38      0.35       108
         MDD       0.29      0.18      0.22       108
         OCD       0.33      0.39      0.35       108
         SMC       0.52      0.45      0.48       108

    accuracy                           0.39       540
   macro avg       0.38      0.39      0.38       540
weighted avg       0.38      0.39      0.38       540

0.3796044937506903

quick random forest model
[[60  3 10 19 16]
 [ 4 55  0 32 17]
 [20 30 21 18 19]
 [ 9 20 25 32 22]
 [13 24  6 23 42]]
              precision    recall  f1-score   support

        ADHD       0.57      0.56      0.56       108
     HEALTHY       0.42      0.51      0.46       108
         MDD       0.34      0.19      0.25       108
         OCD       0.26      0.30      0.28       

In [15]:
# overtrained model
get_ssl_features('overtrained_acrosssub_50feat_pretext_model', dataset, df_sample, eval=True, to_disk=True, num_extracted_features=50)

features_df.shape = (2688, 52)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,42,43,44,45,46,47,48,49,ID,diagnosis
0,0.985169,3.891889,-1.663651,3.693137,-1.054052,-2.293855,-16.974684,-4.104893,1.525412,-2.65457,...,-0.131403,-0.581822,-0.064482,-0.245294,0.507439,-0.253126,-0.792511,-1.782067,sub-87964717,SMC
1,1.417029,4.115095,1.167675,4.628066,-0.84252,0.439773,-19.010715,-4.591773,2.259824,-2.43431,...,-2.745334,-2.295626,0.041776,0.951865,-1.172435,1.165308,-1.550761,-1.830126,sub-87964717,SMC
2,1.30008,2.458149,-0.347941,4.680026,1.513699,-2.501022,-22.114927,-3.981549,2.064567,-3.587168,...,-0.174688,-0.942412,0.703044,2.498853,2.830074,0.088156,0.008734,-0.577674,sub-87964717,SMC


quick SVM model
[[57  0  9 27 15]
 [24 40  5 23 16]
 [20 38 22 18 10]
 [ 7 14 27 43 17]
 [14 27  2 21 44]]
              precision    recall  f1-score   support

        ADHD       0.47      0.53      0.50       108
     HEALTHY       0.34      0.37      0.35       108
         MDD       0.34      0.20      0.25       108
         OCD       0.33      0.40      0.36       108
         SMC       0.43      0.41      0.42       108

    accuracy                           0.38       540
   macro avg       0.38      0.38      0.38       540
weighted avg       0.38      0.38      0.38       540

0.37595825877971795

quick random forest model
[[60  3  8 21 16]
 [ 3 53  2 30 20]
 [19 24 33 17 15]
 [ 8 18 22 37 23]
 [18 31  4 12 43]]
              precision    recall  f1-score   support

        ADHD       0.56      0.56      0.56       108
     HEALTHY       0.41      0.49      0.45       108
         MDD       0.48      0.31      0.37       108
         OCD       0.32      0.34      0.33      

### pretext model with soft margin loss (0.25 dropout)

In [6]:
# best model checkpoint
get_ssl_features('acrosssub_soft_margin_loss_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,-7.529795,-4.775673,12.061558,10.33105,0.569838,2.952614,0.439917,0.16706,5.21309,-1.864353,...,3.128101,8.402712,-2.372483,1.281823,-8.790417,-0.179843,2.697247,-4.472449,sub-87964717,SMC
1,-6.932564,-3.487274,9.676962,9.289244,0.574574,2.927734,-0.353255,-0.186739,5.203309,-2.24973,...,2.614557,5.515316,-3.588432,2.151644,-6.597898,-0.385852,2.548435,-3.753773,sub-87964717,SMC
2,-7.055388,-3.230823,9.919414,11.759319,-0.670908,3.54558,1.525463,-0.568282,5.000744,-2.417435,...,2.675336,6.487484,-3.388251,1.319001,-6.86921,0.183566,2.301235,-1.50848,sub-87964717,SMC


quick SVM model
[[59  5 13 26  5]
 [21 56 10  9 12]
 [26 16 37 14 15]
 [ 1 49 30 24  4]
 [ 8  6 63 17 14]]
              precision    recall  f1-score   support

        ADHD       0.51      0.55      0.53       108
     HEALTHY       0.42      0.52      0.47       108
         MDD       0.24      0.34      0.28       108
         OCD       0.27      0.22      0.24       108
         SMC       0.28      0.13      0.18       108

    accuracy                           0.35       540
   macro avg       0.35      0.35      0.34       540
weighted avg       0.35      0.35      0.34       540

0.3397957970483333

quick random forest model
[[62  4  2 30 10]
 [19 50 10 17 12]
 [19 14 30 19 26]
 [ 0 53 26 26  3]
 [12  4 52 14 26]]
              precision    recall  f1-score   support

        ADHD       0.55      0.57      0.56       108
     HEALTHY       0.40      0.46      0.43       108
         MDD       0.25      0.28      0.26       108
         OCD       0.25      0.24      0.24       

In [7]:
# overtrained model
get_ssl_features('overtrained_acrosssub_soft_margin_loss_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,-5.855083,-7.727406,12.22132,14.037726,0.102029,2.820677,-0.149468,0.302036,3.921639,-1.482393,...,3.979167,10.372499,-2.736665,2.408197,-8.681751,0.079202,2.404514,-5.526348,sub-87964717,SMC
1,-5.438762,-5.820913,9.628739,11.997428,-0.018043,1.948442,-0.738979,-0.414632,4.591846,-1.989791,...,3.682869,6.663427,-3.57217,2.588218,-5.981269,0.278089,2.353071,-3.891693,sub-87964717,SMC
2,-5.352138,-5.79145,9.795975,15.059916,-1.232315,2.726134,0.533534,-0.766387,4.567248,-1.457844,...,4.47485,8.624177,-3.688293,1.963357,-6.804841,1.191428,1.76773,-2.902045,sub-87964717,SMC


quick SVM model
[[59  7  9 28  5]
 [23 47 12 15 11]
 [25 17 41 12 13]
 [ 1 50 29 24  4]
 [ 8  5 60 18 17]]
              precision    recall  f1-score   support

        ADHD       0.51      0.55      0.53       108
     HEALTHY       0.37      0.44      0.40       108
         MDD       0.27      0.38      0.32       108
         OCD       0.25      0.22      0.23       108
         SMC       0.34      0.16      0.22       108

    accuracy                           0.35       540
   macro avg       0.35      0.35      0.34       540
weighted avg       0.35      0.35      0.34       540

0.33888672949571375

quick random forest model
[[60  3  8 26 11]
 [21 38 14 23 12]
 [19 17 38 15 19]
 [ 2 52 31 22  1]
 [11  0 55 20 22]]
              precision    recall  f1-score   support

        ADHD       0.53      0.56      0.54       108
     HEALTHY       0.35      0.35      0.35       108
         MDD       0.26      0.35      0.30       108
         OCD       0.21      0.20      0.21      

### pretext model with soft margin loss (10 epochs)

In [8]:
# best model checkpoint
get_ssl_features('acrosssub_10ep_soft_margin_loss_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,-1.185726,-1.812652,5.068885,0.24698,-2.601562,1.16609,-2.460453,1.01484,4.117381,0.738197,...,-2.89908,-7.798763,-1.278432,-1.515464,-3.018791,15.429169,3.20452,0.910752,sub-87964717,SMC
1,0.033688,-0.387711,4.267061,1.124661,-1.836862,0.044589,-0.043063,0.80314,2.214555,1.553338,...,-3.567886,-7.010677,-2.68365,-0.171279,-2.612827,17.178402,2.143798,1.743032,sub-87964717,SMC
2,-0.483011,0.152234,4.092942,0.670455,-1.64723,0.816747,0.704018,0.984571,3.125106,0.942811,...,-3.160301,-7.528048,-2.740133,0.843616,-1.651058,16.694237,1.527334,2.125,sub-87964717,SMC


quick SVM model
[[59 18  7 13 11]
 [ 1 57  6 27 17]
 [ 9 18 23 34 24]
 [ 0 51 30 15 12]
 [11  5  5  3 84]]
              precision    recall  f1-score   support

        ADHD       0.74      0.55      0.63       108
     HEALTHY       0.38      0.53      0.44       108
         MDD       0.32      0.21      0.26       108
         OCD       0.16      0.14      0.15       108
         SMC       0.57      0.78      0.66       108

    accuracy                           0.44       540
   macro avg       0.43      0.44      0.43       540
weighted avg       0.43      0.44      0.43       540

0.4268945162457028

quick random forest model
[[62  6 18 14  8]
 [ 2 45  8 38 15]
 [ 7 18 31 32 20]
 [ 0 36 38 22 12]
 [ 5  6 43 12 42]]
              precision    recall  f1-score   support

        ADHD       0.82      0.57      0.67       108
     HEALTHY       0.41      0.42      0.41       108
         MDD       0.22      0.29      0.25       108
         OCD       0.19      0.20      0.19       

In [9]:
# overtrained model
get_ssl_features('overtrained_acrosssub_10ep_soft_margin_loss_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,-1.330847,-3.088368,6.051607,-0.310206,-2.076479,0.169632,-2.489274,1.178272,4.181692,0.570766,...,-2.514752,-8.293884,-0.184375,-1.302711,-2.548679,15.433952,2.838049,0.641423,sub-87964717,SMC
1,0.127345,-1.546696,4.803869,0.438345,-1.315916,-1.010393,0.150936,1.616889,2.439004,1.251327,...,-3.269745,-7.268375,-1.956525,-0.088381,-2.118301,17.572662,1.327049,1.884496,sub-87964717,SMC
2,-0.189558,-1.203649,4.795016,-0.336626,-0.929377,-0.361057,1.00055,1.881567,3.439887,0.08926,...,-2.845836,-7.916263,-1.761527,1.068087,-1.135421,17.265411,0.589541,1.947669,sub-87964717,SMC


quick SVM model
[[60 18  7 12 11]
 [ 1 59  2 29 17]
 [14 15 18 33 28]
 [ 1 58 24 11 14]
 [13  8  3  3 81]]
              precision    recall  f1-score   support

        ADHD       0.67      0.56      0.61       108
     HEALTHY       0.37      0.55      0.44       108
         MDD       0.33      0.17      0.22       108
         OCD       0.12      0.10      0.11       108
         SMC       0.54      0.75      0.63       108

    accuracy                           0.42       540
   macro avg       0.41      0.42      0.40       540
weighted avg       0.41      0.42      0.40       540

0.4025391648115971

quick random forest model
[[61 11 13 14  9]
 [ 1 41  6 42 18]
 [ 7 17 32 32 20]
 [ 2 42 32 19 13]
 [ 7  7 33 15 46]]
              precision    recall  f1-score   support

        ADHD       0.78      0.56      0.66       108
     HEALTHY       0.35      0.38      0.36       108
         MDD       0.28      0.30      0.29       108
         OCD       0.16      0.18      0.17       

### pretext model with soft margin loss (3 epochs)

In [10]:
# best model checkpoint
get_ssl_features('acrosssub_3ep_soft_margin_loss_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,22.240875,-2.774086,10.166421,-2.734844,-20.334148,4.683058,-1.84208,-0.716484,3.789842,-0.530341,...,-2.135378,1.71317,2.878568,-0.334627,-1.462004,-1.275851,1.567474,-0.333517,sub-87964717,SMC
1,19.033426,-2.217509,10.272071,-3.890897,-16.428486,5.538733,-0.886464,-1.512452,3.135611,-1.985455,...,-1.226656,1.212124,3.571429,-0.577601,-1.085243,0.270354,0.655387,-1.460791,sub-87964717,SMC
2,20.169655,-3.53853,8.025307,-2.360705,-18.272587,4.484413,-0.634366,-0.157521,3.838518,-1.168988,...,-1.228914,1.31487,2.982244,-0.001873,-1.395818,-1.726364,0.948865,-0.374595,sub-87964717,SMC


quick SVM model
[[53 11 30  4 10]
 [ 3 33  5 44 23]
 [ 9 11 28 30 30]
 [ 0 18 27 18 45]
 [12 22  3 15 56]]
              precision    recall  f1-score   support

        ADHD       0.69      0.49      0.57       108
     HEALTHY       0.35      0.31      0.33       108
         MDD       0.30      0.26      0.28       108
         OCD       0.16      0.17      0.16       108
         SMC       0.34      0.52      0.41       108

    accuracy                           0.35       540
   macro avg       0.37      0.35      0.35       540
weighted avg       0.37      0.35      0.35       540

0.3505702716765301

quick random forest model
[[59 14 26  5  4]
 [ 5 26  4 52 21]
 [ 9 23 24 34 18]
 [ 6 16 26 24 36]
 [ 8  9  8 46 37]]
              precision    recall  f1-score   support

        ADHD       0.68      0.55      0.61       108
     HEALTHY       0.30      0.24      0.27       108
         MDD       0.27      0.22      0.24       108
         OCD       0.15      0.22      0.18       

In [11]:
# overtrained model
get_ssl_features('overtrained_acrosssub_3ep_soft_margin_loss_pretext_model', dataset, df_sample, eval=True, to_disk=True)

features_df.shape = (2688, 102)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,92,93,94,95,96,97,98,99,ID,diagnosis
0,22.240875,-2.774086,10.166421,-2.734844,-20.334148,4.683058,-1.84208,-0.716484,3.789842,-0.530341,...,-2.135378,1.71317,2.878568,-0.334627,-1.462004,-1.275851,1.567474,-0.333517,sub-87964717,SMC
1,19.033426,-2.217509,10.272071,-3.890897,-16.428486,5.538733,-0.886464,-1.512452,3.135611,-1.985455,...,-1.226656,1.212124,3.571429,-0.577601,-1.085243,0.270354,0.655387,-1.460791,sub-87964717,SMC
2,20.169655,-3.53853,8.025307,-2.360705,-18.272587,4.484413,-0.634366,-0.157521,3.838518,-1.168988,...,-1.228914,1.31487,2.982244,-0.001873,-1.395818,-1.726364,0.948865,-0.374595,sub-87964717,SMC


quick SVM model
[[53 11 30  4 10]
 [ 3 33  5 44 23]
 [ 9 11 28 30 30]
 [ 0 18 27 18 45]
 [12 22  3 15 56]]
              precision    recall  f1-score   support

        ADHD       0.69      0.49      0.57       108
     HEALTHY       0.35      0.31      0.33       108
         MDD       0.30      0.26      0.28       108
         OCD       0.16      0.17      0.16       108
         SMC       0.34      0.52      0.41       108

    accuracy                           0.35       540
   macro avg       0.37      0.35      0.35       540
weighted avg       0.37      0.35      0.35       540

0.3505702716765301

quick random forest model
[[60 14 28  4  2]
 [ 7 21 10 48 22]
 [ 9 22 24 28 25]
 [ 2 11 28 26 41]
 [ 8  7  8 46 39]]
              precision    recall  f1-score   support

        ADHD       0.70      0.56      0.62       108
     HEALTHY       0.28      0.19      0.23       108
         MDD       0.24      0.22      0.23       108
         OCD       0.17      0.24      0.20       

### pretext model with soft margin loss (150 features)

In [12]:
# best model checkpoint
get_ssl_features('acrosssub_150feat_soft_margin_loss_pretext_model', dataset, df_sample, eval=True, to_disk=True, num_extracted_features=150)

features_df.shape = (2688, 152)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,142,143,144,145,146,147,148,149,ID,diagnosis
0,-11.013163,-2.091952,-0.754555,5.430296,1.437447,-3.404385,0.15488,1.820009,1.522059,1.170023,...,-2.105203,-0.543057,-2.778214,0.24518,-2.521306,-0.064656,1.595937,2.878217,sub-87964717,SMC
1,-9.135051,-3.131505,-0.448994,3.926103,2.140814,-1.957871,-1.274686,1.618623,1.282301,1.610223,...,-1.305158,-0.875738,-2.207716,-0.567321,-2.166052,-0.568555,0.867209,1.666896,sub-87964717,SMC
2,-6.380183,-2.706604,0.422875,1.756941,1.687053,-3.623837,-1.334724,3.064804,0.841953,1.148472,...,-1.216941,-0.543375,-3.546489,-3.584837,-4.308726,-0.400823,1.41538,0.387051,sub-87964717,SMC


quick SVM model
[[62  6 10 18 12]
 [21 52  3 12 20]
 [14 27 17 14 36]
 [ 4 47 29 17 11]
 [ 4 30 52  6 16]]
              precision    recall  f1-score   support

        ADHD       0.59      0.57      0.58       108
     HEALTHY       0.32      0.48      0.39       108
         MDD       0.15      0.16      0.16       108
         OCD       0.25      0.16      0.19       108
         SMC       0.17      0.15      0.16       108

    accuracy                           0.30       540
   macro avg       0.30      0.30      0.29       540
weighted avg       0.30      0.30      0.29       540

0.29490342668337044

quick random forest model
[[69 13  7 10  9]
 [16 66  4  5 17]
 [18 18 22 18 32]
 [ 8 38 23 14 25]
 [ 3 30 55  6 14]]
              precision    recall  f1-score   support

        ADHD       0.61      0.64      0.62       108
     HEALTHY       0.40      0.61      0.48       108
         MDD       0.20      0.20      0.20       108
         OCD       0.26      0.13      0.17      

In [13]:
# overtrained model
get_ssl_features('overtrained_acrosssub_150feat_soft_margin_loss_pretext_model', dataset, df_sample, eval=True, to_disk=True, num_extracted_features=150)

features_df.shape = (2688, 152)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,142,143,144,145,146,147,148,149,ID,diagnosis
0,-10.468291,-1.732355,0.341479,4.108348,3.223325,-3.020484,-0.500779,2.004912,1.712551,1.110271,...,-2.859026,-0.844783,-2.657928,0.793354,-3.167115,0.016587,1.367866,3.30115,sub-87964717,SMC
1,-9.754599,-1.935057,0.716301,2.785853,4.041155,-2.372832,-1.186751,2.527348,0.768679,1.590828,...,-1.867083,-1.739313,-1.871639,-0.142649,-3.156472,-0.317873,1.252581,2.470677,sub-87964717,SMC
2,-6.853885,-2.359985,1.545487,0.807781,2.882714,-2.942299,-0.875146,3.333212,1.048014,0.91397,...,-2.409781,-1.628225,-2.899913,-2.839725,-4.747772,-0.190376,1.226963,0.614625,sub-87964717,SMC


quick SVM model
[[58 16  5 21  8]
 [18 54  4 14 18]
 [20 22 20 12 34]
 [ 4 49 27 16 12]
 [ 7 30 52  6 13]]
              precision    recall  f1-score   support

        ADHD       0.54      0.54      0.54       108
     HEALTHY       0.32      0.50      0.39       108
         MDD       0.19      0.19      0.19       108
         OCD       0.23      0.15      0.18       108
         SMC       0.15      0.12      0.13       108

    accuracy                           0.30       540
   macro avg       0.29      0.30      0.29       540
weighted avg       0.29      0.30      0.29       540

0.28546456589167535

quick random forest model
[[61 16  8 14  9]
 [19 57 13  7 12]
 [19 14 32 18 25]
 [ 8 33 28 25 14]
 [ 4 25 62  5 12]]
              precision    recall  f1-score   support

        ADHD       0.55      0.56      0.56       108
     HEALTHY       0.39      0.53      0.45       108
         MDD       0.22      0.30      0.25       108
         OCD       0.36      0.23      0.28      