In [1]:
# Import packages and setup gpu configuration.
# This code block shouldn't need to be adjusted!
import os
import sys
import json
import yaml
import numpy as np
import copy
import math
import time
import random
from tqdm import tqdm
import webdataset as wds
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms
import utils
from mae_utils.flat_models import *
import pandas as pd
from typing import List, Dict, Any, Tuple
from sklearn.preprocessing import StandardScaler

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True
# following fixes a Conv3D CUDNN_NOT_SUPPORTED error
torch.backends.cudnn.benchmark = True

## MODEL TO LOAD ##
# YOU WILL NEED TO PRECOMPUTE AND SAVE THE FEATURES SEE: prep_HCP_downstream.ipynb
model_name = "NSDflat_large_gsrFalse_5sess_57734"
parquet_folder = "epoch99"

## DEFINE TARGET AND GLOBAL POOLING ##
global_pooling = True
target = "subject_id" # "trial_type" or "subject_id" or see table below


# outdir = os.path.abspath(f'checkpoints/{model_name}')
outdir = os.path.abspath(f'checkpoints/{model_name}')

print("outdir", outdir)
# Load previous config.yaml if available
if os.path.exists(f"{outdir}/config.yaml"):
    config = yaml.load(open(f"{outdir}/config.yaml", 'r'), Loader=yaml.FullLoader)
    print(f"Loaded config.yaml from ckpt folder {outdir}")
    # create global variables from the config
    print("\n__CONFIG__")
    for attribute_name in config.keys():
        print(f"{attribute_name} = {config[attribute_name]}")
        globals()[attribute_name] = config[f'{attribute_name}']
    print("\n")

world_size = os.getenv('WORLD_SIZE')
if world_size is None: 
    world_size = 1
else:
    world_size = int(world_size)
print(f"WORLD_SIZE={world_size}")

if utils.is_interactive():
    # Following allows you to change functions in models.py or utils.py and 
    # have this notebook automatically update with your revisions
    %load_ext autoreload
    %autoreload 2

batch_size = probe_batch_size
num_epochs = probe_num_epochs

data_type = torch.float32 # change depending on your mixed_precision
global_batch_size = batch_size * world_size

device = torch.device('cuda')

print("PID of this process =",os.getpid())

utils.seed_everything(seed)

outdir /weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/NSDflat_large_gsrFalse_40sess_9908
Loaded config.yaml from ckpt folder /weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/NSDflat_large_gsrFalse_40sess_9908

__CONFIG__
base_lr = 0.001
batch_size = 32
ckpt_interval = 25
ckpt_saving = True
cls_embed = False
cls_forward = False
contrastive_loss_weight = 1.0
datasets_to_include = NSD
decoder_cls_embed = False
decoder_embed_dim = 512
global_pool = False
grad_accumulation_steps = 1
grad_clip = 1.0
gsr = False
hcp_flat_path = /weka/proj-medarc/shared/HCP-Flat
mask_ratio = 0.75
model_name = NSDflat_large_gsrFalse_40sess
model_size = large
no_qkv_bias = False
norm_pix_loss = False
nsd_flat_path = /weka/proj-medarc/shared/NSD-Flat
num_epochs = 100
num_frames = 16
num_samples_per_epoch = 200000
num_sessions = 40
num_workers = 8
patch_size = 16
pct_masks_to_decode = 1
plotting = True
pred_t_dim = 8
print_interval = 20
probe_base_lr = 0.0003
probe_batch_size = 8
probe_

# Columns to decode from subjects

## 1. Basic Demographics
| Variable  | Short Description | Type of Value |
|-----------|-------------------|---------------|
| Age_in_Yrs       | Age of the subject (in years) | Numerical |
| Gender       | Biological sex of the subject (e.g., Male/Female/Other) | Categorical |
| Race      | Self-reported racial category (e.g., White, Black, Asian, etc.). | Categorical |
| Ethnicity | Self-reported ethnicity (e.g., Hispanic/Latino, Non-Hispanic, etc.). | Categorical |

## 2. Cognitive / “IQ-like” Measures
| Variable          | Short Description | Type of Value |
|--------------------|-------------------|---------------|
| PMAT24_A_CR       | Fluid intelligence / abstract reasoning. Assesses pattern recognition, logical reasoning, and problem-solving using 24 matrix-style puzzles. | Numerical (count score) |
| CardSort_Unadj    | Executive function (cognitive flexibility). Assesses ability to switch attention between rules or dimensions (raw/unadjusted score). | Numerical |
| CardSort_AgeAdj   | Same test as CardSort_Unadj but age-adjusted score. | Numerical (standardized) |
| ListSort_Unadj    | Working memory. Involves remembering/sorting lists (e.g., animals, foods) in size order (raw/unadjusted score). | Numerical |
| ListSort_AgeAdj   | Same test as ListSort_Unadj but age-adjusted score. | Numerical (standardized) |
| PicSeq_Unadj      | Episodic memory. Assesses ability to recall sequences of illustrated objects/activities (raw/unadjusted score). | Numerical |
| PicSeq_AgeAdj     | Same test as PicSeq_Unadj but age-adjusted score. | Numerical (standardized) |

## 3. Personality Traits (Big Five)
| Variable          | Short Description | Type of Value |
|--------------------|-------------------|---------------|
| NEOFAC_A          | Agreeableness (factor score from NEO inventory) | Numerical |
| NEOFAC_O          | Openness to Experience (factor score) | Numerical |
| NEOFAC_C          | Conscientiousness (factor score) | Numerical |
| NEOFAC_N          | Neuroticism (factor score) | Numerical |
| NEOFAC_E          | Extraversion (factor score) | Numerical |
| NEORAW_01 to 60   | Raw item-level responses (60 items) from NEO-based inventory, each reflecting specific behavior/preference statements. | Numerical (item-level) |

## 4. Psychopathology / Mental Health
| Variable          | Short Description | Type of Value |
|--------------------|-------------------|---------------|
| ASR_Anxd_Raw      | Adult Self-Report anxiety problems (raw score). | Numerical (sum) |
| ASR_Attn_Raw      | Adult Self-Report attention problems (raw score). | Numerical (sum) |
| ASR_Aggr_Raw      | Adult Self-Report aggressive behavior (raw score). | Numerical (sum) |
| DSM_Depr_Raw      | DSM-oriented depression subscale (raw score). | Numerical (sum) |
| DSM_Anxi_Raw      | DSM-oriented anxiety subscale (raw score). | Numerical (sum) |
| SSAGA_PanicDisorder | Semi-Structured Assessment for the Genetics of Alcoholism: Panic disorder diagnosis (yes/no). | Categorical (binary) |
| SSAGA_Depressive_Ep | SSAGA depression episode diagnosis (yes/no). | Categorical (binary) |
| SSAGA_Depressive_Sx | SSAGA depression symptom count or severity. | Numerical (count/scale) |

## 5. Substance Use Phenotypes
| Variable          | Short Description | Type of Value |
|--------------------|-------------------|---------------|
| SSAGA_Alc_12_Frq  | Frequency of alcohol use in last 12 months (from SSAGA). | Numerical (count/frequency) |
| SSAGA_Alc_12_Max_Drinks | Max number of drinks in one sitting, last 12 months. | Numerical (count) |
| SSAGA_Times_Used_Illicits | Times used illicit substances (lifetime or specific period). | Numerical (count) |
| SSAGA_Times_Used_Cocaine | Times used cocaine. | Numerical (count) |
| Total_Drinks_7days | Total number of drinks in the past 7 days (short-term measure). | Numerical (count) |
| Total_Any_Tobacco_7days | Total tobacco use in the past 7 days (cigarettes, e-cig, etc.). | Numerical (count) |

## 6. Anthropometric / Basic Health
| Variable          | Short Description | Type of Value |
|--------------------|-------------------|---------------|
| BMI               | Body Mass Index (weight/height²). | Numerical |
| Height            | Height of the subject (e.g., in cm). | Numerical |
| Weight            | Weight of the subject (e.g., in kg). | Numerical |
| BPSystolic        | Systolic blood pressure (mmHg). | Numerical |
| BPDiastolic       | Diastolic blood pressure (mmHg). | Numerical |
| HbA1C             | Glycated hemoglobin level (indicates average blood glucose over ~3 months). | Numerical (continuous) |
| ThyroidHormone    | Thyroid hormone levels (e.g., TSH, T4). | Numerical (continuous) |

## 7. Sleep / Quality of Life
| Variable          | Short Description | Type of Value |
|--------------------|-------------------|---------------|
| PSQI_Score        | Pittsburgh Sleep Quality Index global score (higher = poorer sleep quality). | Numerical (sum/index) |
| PSQI Components   | Sub-scores for sleep latency, duration, disturbances, etc. | Numerical |
| PainInterf_Tscore | PROMIS measure of pain interference with daily activities (T-score standardized). | Numerical (T-score) |
| LifeSatisf_Unadj  | Self-reported life satisfaction (raw score). | Numerical (score) |
| MeanPurp_Unadj    | Self-reported sense of meaning/purpose in life (raw score). | Numerical (score) |


### Prepare hcp_flat dataset

In [2]:
from mae_utils.flat import load_hcp_flat_mask
from mae_utils.flat import create_hcp_flat
import mae_utils.visualize as vis

# Interactive sessions override the per-epoch sample counts and the number of
# epochs with smaller values for debugging.
if utils.is_interactive():
    probe_num_samples_per_epoch = 100000
    test_num_samples_per_epoch = 100000
    num_epochs = 10


# Find the latest checkpoint in outdir.
# checkpoint_name stays None when nothing usable is found. A message is printed
# instead of raising because the model-loading code that would consume the
# checkpoint is commented out below.
checkpoint_name = None
if not os.path.isdir(outdir):
    print(f"\nCheckpoint folder {outdir} does not exist.\n")
else:
    checkpoint_files = [f for f in os.listdir(outdir) if f.endswith('.pth')]

    # Collect the epoch number of every file whose name ends in "epoch<N>.pth";
    # files without a parsable epoch number are skipped.
    epoch_numbers = []
    for file in checkpoint_files:
        try:
            epoch_numbers.append(int(file.split('epoch')[-1].split('.')[0]))
        except ValueError:
            continue

    if epoch_numbers:
        latest_epoch = max(epoch_numbers)
        checkpoint_name = f"epoch{latest_epoch}.pth"
    else:
        # max() on an empty list would raise an uninformative ValueError.
        print(f"\nNo epoch checkpoint (.pth) found in {outdir}.\n")

    ### Or provide the specific checkpoint you want to load
    # checkpoint_name = "epoch10.pth" #"epoch15.pth"

    # Load the checkpoint
#     checkpoint_path = os.path.join(outdir, checkpoint_name)
#     state = torch.load(checkpoint_path)

# model = mae_vit_large_fmri(
#     patch_size=16,
#     decoder_embed_dim=decoder_embed_dim,
#     t_patch_size=t_patch_size,
#     pred_t_dim=pred_t_dim,
#     decoder_depth=4,
#     cls_embed=cls_embed,
#     norm_pix_loss=norm_pix_loss,
#     no_qkv_bias=no_qkv_bias,
#     sep_pos_embed=sep_pos_embed,
#     trunc_init=trunc_init,
#     img_mask=state["model_state_dict"]['img_mask']
# )

# model.load_state_dict(state["model_state_dict"], strict=True) #model_state_dict
# print(f"\nLoaded checkpoint {checkpoint_name} from {outdir}\n")

# model.eval()
# model.requires_grad_(False)
# model.to(device)

### Prepare subjects information

In [3]:
from sklearn.preprocessing import LabelEncoder

###### Column lists for the restricted subject-information table
# Each dict maps a descriptive group name to the columns in that group; the
# flat lists used by the rest of the notebook are built from them in order.

# Categorical columns (e.g., demographic categories, binary diagnoses)
_CATEGORICAL_COLUMN_GROUPS = {
    "demographics": ["Gender", "Race", "Ethnicity"],
    # Panic disorder diagnosis (yes/no), depressive episode diagnosis (yes/no)
    "diagnoses": ["SSAGA_PanicDisorder", "SSAGA_Depressive_Ep"],
}

# Numerical columns (continuous, counts, raw scores, standardized scores, etc.)
_NUMERICAL_COLUMN_GROUPS = {
    "demographics": ["Age_in_Yrs"],
    # Cognitive / "IQ-like" measures
    "cognitive": [
        "PMAT24_A_CR",
        "CardSort_Unadj", "CardSort_AgeAdj",
        "ListSort_Unadj", "ListSort_AgeAdj",
        "PicSeq_Unadj", "PicSeq_AgeAdj",
    ],
    # Personality traits (Big Five).
    # If you have all 60 NEO item-level responses, add them here:
    # "NEORAW_01", "NEORAW_02", ..., "NEORAW_60"
    "personality": ["NEOFAC_A", "NEOFAC_O", "NEOFAC_C", "NEOFAC_N", "NEOFAC_E"],
    # Psychopathology / mental health (the last entry is a symptom count or severity)
    "mental_health": [
        "ASR_Anxd_Raw", "ASR_Attn_Raw", "ASR_Aggr_Raw",
        "DSM_Depr_Raw", "DSM_Anxi_Raw",
        "SSAGA_Depressive_Sx",
    ],
    # Substance use phenotypes
    "substance_use": [
        "SSAGA_Alc_12_Frq", "SSAGA_Alc_12_Max_Drinks",
        "SSAGA_Times_Used_Illicits", "SSAGA_Times_Used_Cocaine",
        "Total_Drinks_7days", "Total_Any_Tobacco_7days",
    ],
    # Anthropometric / basic health
    "health": [
        "BMI", "Height", "Weight",
        "BPSystolic", "BPDiastolic",
        "HbA1C", "ThyroidHormone",
    ],
    # Sleep / quality of life.
    # If you have separate PSQI component scores, list them here too, e.g.:
    # "PSQI_Component1", "PSQI_Component2", ...
    "sleep_quality_of_life": [
        "PSQI_Score",
        "PainInterf_Tscore", "LifeSatisf_Unadj", "MeanPurp_Unadj",
    ],
}

categorical_columns = [
    column for group in _CATEGORICAL_COLUMN_GROUPS.values() for column in group
]
numerical_columns = [
    column for group in _NUMERICAL_COLUMN_GROUPS.values() for column in group
]


# Classify the requested target. 'special' targets are read directly from the
# feature tables; the other two kinds need the subject-information table.
if target in ['subject_id', 'trial_type']:
    target_type = 'special'
elif target in categorical_columns:
    target_type = 'categorical'
elif target in numerical_columns:
    target_type = 'numerical'
else:
    # Fail here with a clear message instead of a NameError on target_type later.
    raise ValueError(
        f"Unknown target {target!r}: expected 'subject_id', 'trial_type', "
        "or a name from categorical_columns / numerical_columns"
    )


# open the file containing subject information
if target_type != 'special':
    subject_information_HCP_path = os.path.join(hcp_flat_path, "subjects_data_restricted.csv")
    try:
        subject_information_HCP = pd.read_csv(subject_information_HCP_path)
    except OSError:
        # The restricted table could not be opened; fall back to the
        # unrestricted CSV in the working directory.
        try:
            subject_information_HCP = pd.read_csv('./unrestricted_clane9_4_23_2024_13_28_14.csv')
        except OSError as err:
            raise FileNotFoundError("Subject information file not found") from err

    # # show the first few rows of the subject information
    # subject_information_HCP[age_related_columns + sex_related_columns].head()

    # Handle missing values (e.g., impute with mean)
    # NOTE(review): mean_age is not used anywhere in this section.
    mean_age = subject_information_HCP['Age_in_Yrs'].mean()

    if target in numerical_columns:
        # Count NaNs or missing values
        n_missing = subject_information_HCP[target].isnull().sum()
        print(f"Number of missing values in {target}: {n_missing}. Replacing with mean.")
        mean_ = subject_information_HCP[target].mean()
        # Replace missing values or NaNs with the mean.
        # The result is assigned back to the column: calling
        # fillna(inplace=True) on a column selection is chained assignment,
        # which is not guaranteed to update the DataFrame itself.
        subject_information_HCP[target] = subject_information_HCP[target].fillna(mean_)
        # Initialize the scaler
        scaler = StandardScaler()
        # Perform z-score normalization; the result is stored in '<target>_z'
        subject_information_HCP[f'{target}_z'] = scaler.fit_transform(subject_information_HCP[[target]])

    if target in categorical_columns:
        # Perform label encoding; the integer codes are stored in '<target>_encoded'
        label_enc = LabelEncoder()
        subject_information_HCP[f'{target}_encoded'] = label_enc.fit_transform(subject_information_HCP[target])

def train_test_split_by_subject(df, test_ratio=0.1, random_state=42):
    """
    Split a dataframe into train and test so that
    every subject in test also appears in train at least once.

    Each subject's rows are split independently: roughly ``test_ratio`` of
    them go to test and the rest to train. Subjects with a single row go
    entirely to train. Both outputs are shuffled and re-indexed from 0.

    Parameters
    ----------
    df : pd.DataFrame
        Your dataset, containing at least the columns:
        ['sub', ...]
        The input is not modified. Its row labels are ignored, so repeated
        labels (e.g. from ``pd.concat`` without ``ignore_index``) are fine.
    test_ratio : float
        Fraction of each subject's rows to allocate to test.
    random_state : int
        Random seed for reproducibility.

    Returns
    -------
    train_df : pd.DataFrame
    test_df : pd.DataFrame
        Either may be empty (with the columns of ``df``) when no rows fall
        into it, e.g. ``test_df`` when every subject has a single row.

    Notes
    -----
    As a side effect this seeds NumPy's global random generator with
    ``random_state``.
    """
    # Kept for backward compatibility: code running after this call may rely
    # on the global NumPy seed having been set here.
    np.random.seed(random_state)

    # Work on a copy with unique positional row labels. The split below removes
    # the test rows with drop(labels); with repeated labels that would drop
    # more rows than were sampled and silently lose data.
    df = df.reset_index(drop=True)

    train_parts = []
    test_parts = []

    # Group by subject
    for _, df_sub in df.groupby('sub'):
        n = len(df_sub)

        # If the subject only has 1 row, put it all in train
        if n == 1:
            train_parts.append(df_sub)
            continue

        # Decide how many rows go to test, always leaving at least 1 row in train
        n_test = min(int(round(test_ratio * n)), n - 1)

        # Randomly sample n_test rows for test; the remaining go to train
        test_rows = df_sub.sample(n_test, random_state=random_state)
        test_parts.append(test_rows)
        train_parts.append(df_sub.drop(test_rows.index))

    def _combine(parts):
        # Concatenate, shuffle and re-index the per-subject pieces.
        # pd.concat raises on an empty list, so return an empty frame instead.
        if not parts:
            return df.iloc[0:0].copy()
        combined = pd.concat(parts)
        return combined.sample(frac=1, random_state=random_state).reset_index(drop=True)

    return _combine(train_parts), _combine(test_parts)

    
def get_label_restricted(subject_id: List[str], target: str, normalized: bool = True) -> np.ndarray:
    """
    Get the label for the given subject id and target

    Labels are read from the module-level ``subject_information_HCP`` table,
    matching each id against its 'Subject' column.

    Parameters
    ----------
    subject_id : List[str]
        Subject ids, one per sample; each is converted with ``int``.
    target : str
        A name from ``numerical_columns`` or ``categorical_columns``.
    normalized : bool
        Numerical targets only: read the '<target>_z' column when True,
        the raw '<target>' column when False.

    Returns
    -------
    np.ndarray
        One label per entry of ``subject_id``, in the same order. Values are
        cast to float32 for numerical targets and to int8 for categorical
        targets.

    Raises
    ------
    ValueError
        If ``target`` is in neither column list.
    AssertionError
        If a subject has no row in the table.
    """
    if target in numerical_columns:
        column = f'{target}_z' if normalized else f'{target}'
        cast = np.float32
    elif target in categorical_columns:
        column = f'{target}_encoded'
        # NOTE(review): int8 only holds codes up to 127; confirm no
        # categorical target has more classes than that.
        cast = np.int8
    else:
        # Fail loudly instead of implicitly returning None.
        raise ValueError(f"Unsupported target {target!r} for get_label_restricted")

    # Many samples share a subject, so each subject is looked up only once.
    label_by_subject = {}
    target_array = []
    for subject in (int(x) for x in subject_id):
        if subject not in label_by_subject:
            c_target = subject_information_HCP.loc[
                subject_information_HCP['Subject'] == subject, column
            ].values
            # if the subject is not in the subject information file trigger an error
            # (raised explicitly so the check also runs under `python -O`;
            # AssertionError keeps the type the previous `assert` produced)
            if len(c_target) == 0:
                raise AssertionError(f"Subject {subject} not found in subject information file")
            if len(c_target) > 1:
                print(f"Warning: Multiple entries for subject {subject}")
            # When several rows match, the first one is used.
            label_by_subject[subject] = cast(c_target[0])
        target_array.append(label_by_subject[subject])

    return np.array(target_array)
    

### Fit sklearn model

In [7]:
import argparse
import json
import os
import pickle
from pathlib import Path

import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegressionCV, Ridge
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder


print(f"Target: {target}")

# Folder holding the precomputed features; the results JSON is written here too.
features_dir = f"{outdir}_gp{str(global_pooling)}/{parquet_folder}/HCP"
train_features = pd.read_parquet(f"{features_dir}/train.parquet")
test_features = pd.read_parquet(f"{features_dir}/test.parquet")

print(f"train: {train_features.shape}, test: {test_features.shape}")

X_train = np.stack(train_features["feature"])
X_test = np.stack(test_features["feature"])
print(f"X_train: {X_train.shape}, X_test: {X_test.shape}")

if target == "subject_id":
    # We have to redo the train test split including every possible subject from test on train.
    # ignore_index=True gives the combined frame unique row labels; the two
    # source frames can carry the same labels, which label-based row removal
    # during the split would otherwise mishandle.
    all_features = pd.concat([train_features, test_features], ignore_index=True)
    # Each 'sub' entry is a sequence; its first element is used as the subject id.
    all_features["sub"] = all_features["sub"].apply(lambda x: x[0])
    train_features, test_features = train_test_split_by_subject(all_features)
    labels_train = train_features['sub'].values
    labels_test = test_features['sub'].values
    X_train = np.stack(train_features["feature"])
    X_test = np.stack(test_features["feature"])
    print(f"New shapes bcz of Subjec id prediction.- X_train: {X_train.shape}, X_test: {X_test.shape}")
elif target == "trial_type":
    labels_train = train_features["trial_type"].values
    labels_test = test_features["trial_type"].values
elif target_type in ("numerical", "categorical"):
    # Look up each sample's label from the subject-information table.
    labels_train = [str(sample[0]) for sample in train_features['sub'].values.tolist()]
    labels_test = [str(sample[0]) for sample in test_features['sub'].values.tolist()]
    y_train = get_label_restricted(labels_train, target=target)
    y_test = get_label_restricted(labels_test, target=target)

if target in ("trial_type", "subject_id"):
    if target == "trial_type":
        # Unwrap labels that are stored as arrays.
        labels_train = np.array([label[0] if isinstance(label, np.ndarray) else label for label in labels_train])
        labels_test = np.array([label[0] if isinstance(label, np.ndarray) else label for label in labels_test])
    # Encode the labels as integer class ids, fitted on the training labels only.
    label_enc = LabelEncoder()
    y_train = label_enc.fit_transform(labels_train)
    y_test = label_enc.transform(labels_test)
    print(f"classes ({len(label_enc.classes_)}): {label_enc.classes_}")


print(
    f"\ny_train: {y_train.shape} {y_train[:20]}\n"
    f"y_test: {y_test.shape} {y_test[:20]}"
)
del train_features, test_features

# Hold out 10% of the training samples as a validation set.
train_ind, val_ind = train_test_split(
    np.arange(len(X_train)), train_size=0.9, random_state=42
)
print(
    f"\ntrain_ind: {len(train_ind)} {train_ind[:10]}\n"
    f"val_ind: {len(val_ind)} {val_ind[:10]}"
)
X_train, X_val = X_train[train_ind], X_train[val_ind]
y_train, y_val = y_train[train_ind], y_train[val_ind]

# The projection is fitted on the training portion only and then applied to all splits.
print("Fitting PCA projection")
pca = PCA(n_components=384, whiten=True, svd_solver="randomized")
pca.fit(X_train)

X_train = pca.transform(X_train)
X_val = pca.transform(X_val)
X_test = pca.transform(X_test)

if target_type == "special" or target_type == "categorical":
    print("Fitting logistic regression")
    clf = LogisticRegressionCV()
    clf.fit(X_train, y_train)

    result = {
        "target": target,
        "train_acc": clf.score(X_train, y_train),
        "val_acc": clf.score(X_val, y_val),
        "test_acc": clf.score(X_test, y_test),
    }

elif target_type == "numerical":
    alpha = 10000
    print("Fitting ridge regression")
    # alpha was previously assigned but never passed, so the model silently
    # used Ridge's default regularization strength.
    clf = Ridge(alpha=alpha)
    clf.fit(X_train, y_train)

    # Calculate R² scores
    train_r2 = clf.score(X_train, y_train)
    val_r2 = clf.score(X_val, y_val)
    test_r2 = clf.score(X_test, y_test)

    # Calculate MSE from the predictions on each split
    train_mse = mean_squared_error(y_train, clf.predict(X_train))
    val_mse = mean_squared_error(y_val, clf.predict(X_val))
    test_mse = mean_squared_error(y_test, clf.predict(X_test))

    # Compile results
    result = {
        "target": target,
        "train_r2": train_r2,
        "val_r2": val_r2,
        "test_r2": test_r2,
        "train_mse": float(train_mse),
        "val_mse": float(val_mse),
        "test_mse": float(test_mse),
    }

with open(f"{features_dir}/downstream_{target}.json", 'w') as out_json:
    json.dump(result, out_json)

print(f"Done:\n{json.dumps(result)}")

Target: subject_id
train: (87748, 9), test: (9498, 9)
test: (9498, 9)
X_train: (87748, 1024), X_test: (9498, 1024)
New shapes bcz of Subjec id prediction.- X_train: (87662, 1024), X_test: (9584, 1024)
classes (1093): ['100206' '100307' '100408' ... '994273' '995174' '996782']

y_train: (87662,) [1061  370  234  885  817  741  398  107 1057  332   80 1056  185  777
  414  657  489  273 1081 1017]
y_test: (9584,) [421 803 279 528 678 879 562 338 871 436 221 289 266 307  84 490 178 493
 442 701]

train_ind: 78895 [70938 24482 28961 85974  2298 42843 60708 83536 24913 60912]
val_ind: 8767 [85066 29468 18593 71058 65582 59468 31687  8525 84802 26565]
Fitting PCA projection
Fitting logistic regression
Done:
{"target": "subject_id", "train_acc": 0.5002344888776221, "val_acc": 0.14086916847268166, "test_acc": 0.12635642737896494}
