In [1]:
# Import packages and setup gpu configuration.
# This code block shouldnt need to be adjusted!
import os
import sys
import json
import yaml
import numpy as np
import copy
import math
import time
import random
from tqdm import tqdm
import webdataset as wds
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms
import utils
from mae_utils.flat_models import *
import pandas as pd
from typing import List, Dict, Any, Tuple
from sklearn.preprocessing import StandardScaler

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True
# following fixes a Conv3D CUDNN_NOT_SUPPORTED error
torch.backends.cudnn.benchmark = True

## MODEL TO LOAD ##
# YOU WILL NEED TO PRECOMPUTE AND SAVE THE FEATURES SEE: prep_HCP_downstream.ipynb
model_name = "HCPflat_large_gsrFalse_"
parquet_folder = "epoch99"

## DEFINE TARGET AND GLOBAL POOLING ##
global_pooling = True
target = "age" # "trial_type" or "sex" or "age"


# outdir = os.path.abspath(f'checkpoints/{model_name}')
outdir = os.path.abspath(f'checkpoints/{model_name}')

print("outdir", outdir)
# Load previous config.yaml if available
if os.path.exists(f"{outdir}/config.yaml"):
    config = yaml.load(open(f"{outdir}/config.yaml", 'r'), Loader=yaml.FullLoader)
    print(f"Loaded config.yaml from ckpt folder {outdir}")
    # create global variables from the config
    print("\n__CONFIG__")
    for attribute_name in config.keys():
        print(f"{attribute_name} = {config[attribute_name]}")
        globals()[attribute_name] = config[f'{attribute_name}']
    print("\n")

world_size = os.getenv('WORLD_SIZE')
if world_size is None: 
    world_size = 1
else:
    world_size = int(world_size)
print(f"WORLD_SIZE={world_size}")

if utils.is_interactive():
    # Following allows you to change functions in models.py or utils.py and 
    # have this notebook automatically update with your revisions
    %load_ext autoreload
    %autoreload 2

batch_size = probe_batch_size
num_epochs = probe_num_epochs

data_type = torch.float32 # change depending on your mixed_precision
global_batch_size = batch_size * world_size

device = torch.device('cuda')

print("PID of this process =",os.getpid())

utils.seed_everything(seed)

outdir /weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/HCPflat_large_gsrFalse_
Loaded config.yaml from ckpt folder /weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/HCPflat_large_gsrFalse_

__CONFIG__
base_lr = 0.001
batch_size = 32
ckpt_interval = 5
ckpt_saving = True
cls_embed = True
contrastive_loss_weight = 1.0
datasets_to_include = HCP
decoder_embed_dim = 512
grad_accumulation_steps = 1
grad_clip = 1.0
gsr = False
hcp_flat_path = /weka/proj-medarc/shared/HCP-Flat
mask_ratio = 0.75
model_name = HCPflat_large_gsrFalse_
no_qkv_bias = False
norm_pix_loss = False
nsd_flat_path = /weka/proj-medarc/shared/NSD-Flat
num_epochs = 100
num_frames = 16
num_samples_per_epoch = 200000
num_workers = 10
patch_size = 16
pct_masks_to_decode = 1
plotting = True
pred_t_dim = 8
print_interval = 20
probe_base_lr = 0.0003
probe_batch_size = 8
probe_num_epochs = 30
probe_num_samples_per_epoch = 100000
resume_from_ckpt = True
seed = 42
sep_pos_embed = True
t_patch_size = 2
test_nu

### Prepare hcp_flat dataset

In [2]:
from mae_utils.flat import load_hcp_flat_mask
from mae_utils.flat import create_hcp_flat
import mae_utils.visualize as vis

if utils.is_interactive(): # Use less samples per epoch for debugging
    probe_num_samples_per_epoch = 100000
    test_num_samples_per_epoch = 100000
    num_epochs = 10


# Load ckpt
if not os.path.exists(outdir) or not os.path.isdir(outdir):
    assert True, (f"\nCheckpoint folder {outdir} does not exist.\n")
else:
    checkpoint_files = [f for f in os.listdir(outdir) if f.endswith('.pth')]

    # Find the latest ckpt to load
    epoch_numbers = []
    for file in checkpoint_files:
        try:
            epoch_number = int(file.split('epoch')[-1].split('.')[0])
            epoch_numbers.append(epoch_number)
        except ValueError:
            continue
    latest_epoch = max(epoch_numbers)
    checkpoint_name = f"epoch{latest_epoch}.pth"
    
    ### Or provide the specific checkpoint you want to load
    # checkpoint_name = "epoch10.pth" #"epoch15.pth"

    # Load the checkpoint
#     checkpoint_path = os.path.join(outdir, checkpoint_name)
#     state = torch.load(checkpoint_path)

# model = mae_vit_large_fmri(
#     patch_size=16,
#     decoder_embed_dim=decoder_embed_dim,
#     t_patch_size=t_patch_size,
#     pred_t_dim=pred_t_dim,
#     decoder_depth=4,
#     cls_embed=cls_embed,
#     norm_pix_loss=norm_pix_loss,
#     no_qkv_bias=no_qkv_bias,
#     sep_pos_embed=sep_pos_embed,
#     trunc_init=trunc_init,
#     img_mask=state["model_state_dict"]['img_mask']
# )

# model.load_state_dict(state["model_state_dict"], strict=True) #model_state_dict
# print(f"\nLoaded checkpoint {checkpoint_name} from {outdir}\n")

# model.eval()
# model.requires_grad_(False)
# model.to(device)

### Prepare subjects information

In [3]:
# open the file containing subject information
if target == "age" or target == "sex":
    subject_information_HCP_path = os.path.join(hcp_flat_path, "subjects_data_restricted.csv")
    try:
        subject_information_HCP = pd.read_csv(subject_information_HCP_path)
    except:
        try:
            subject_information_HCP = pd.read_csv('./unrestricted_clane9_4_23_2024_13_28_14.csv')   
        except:
            assert False, "Subject information file not found"

    ###### This is for unrestricted
    # age_related_columns = [
    #     'Age', 'PicSeq_AgeAdj', 'CardSort_AgeAdj', 'Flanker_AgeAdj',
    #     'ReadEng_AgeAdj', 'PicVocab_AgeAdj', 'ProcSpeed_AgeAdj',
    #     'CogFluidComp_AgeAdj', 'CogEarlyComp_AgeAdj', 'CogTotalComp_AgeAdj',
    #     'CogCrystalComp_AgeAdj', 'Endurance_AgeAdj', 'Dexterity_AgeAdj',
    #     'Strength_AgeAdj', 'Odor_AgeAdj', 'Taste_AgeAdj'
    # ]
    
    # sex_related_columns = [
    #     'Gender'
    # ]

    ###### This is for restricted
    gender_related_columns = [
        'Gender'
    ]

    age_related_columns = [
        'Age_in_Yrs',
        'Menstrual_AgeBegan',
        'Menstrual_AgeIrreg',
        'Menstrual_AgeStop',
        'SSAGA_Alc_Age_1st_Use',
        'SSAGA_TB_Age_1st_Cig',
        'SSAGA_Mj_Age_1st_Use',
        'Endurance_AgeAdj',
        'Dexterity_AgeAdj',
        'Strength_AgeAdj',
        'PicSeq_AgeAdj',
        'CardSort_AgeAdj',
        'Flanker_AgeAdj',
        'ReadEng_AgeAdj',
        'PicVocab_AgeAdj',
        'ProcSpeed_AgeAdj',
        'Odor_AgeAdj',
        'Taste_AgeAdj'
    ]

    # # show the first few rows of the subject information
    # subject_information_HCP[age_related_columns + sex_related_columns].head()

    # Handle missing values (e.g., impute with mean)
    mean_age = subject_information_HCP['Age_in_Yrs'].mean()
    
    # Initialize the scaler
    scaler = StandardScaler()
    
    # Perform z-score normalization
    subject_information_HCP['Age_in_Yrs_z'] = scaler.fit_transform(subject_information_HCP[['Age_in_Yrs']])


    
def get_label_unrestricted(subject_id: List[str], target: str, method_for_age: str = 'mean') -> List:
    """
    Get the label for the given subject id and target.

    For sex 0 is F and 1 is M
    """

    # convert to list of ints
    subject_id = [int(x) for x in subject_id]

    if target == "age":
        age_array = []
        for subject in subject_id:
            c_age = subject_information_HCP[subject_information_HCP['Subject'] == subject]['Age'].values
            # if the subject is not in the subject information file trigger an error
            if len(c_age) == 0:
                assert False, f"Subject {subject} not found in subject information file"
            if len(c_age) > 1:
                print(f"Warning: Multiple entries for subject {subject}")

            c_age = c_age[0].split('-')
            if len(c_age) < 2:
                c_age = c_age[0].split('+')
                age_array.append(int(c_age[0]))
            else:
                if method_for_age == 'mean':
                    age_array.append(np.mean([int(x) for x in c_age]))
                elif method_for_age == 'min':
                    age_array.append(np.min([int(x) for x in c_age]))
                elif method_for_age == 'max':
                    age_array.append(np.max([int(x) for x in c_age]))
                else:
                    assert False, f"Method {method_for_age} not recognized"

        return np.array(age_array)  
    
    elif target == 'sex':
        sex_array = []
        for subject in subject_id:
            c_sex = subject_information_HCP[subject_information_HCP['Subject'] == subject]['Gender'].values
            # if the subject is not in the subject information file trigger an error
            if len(c_sex) == 0:
                assert False, f"Subject {subject} not found in subject information file"
            if len(c_sex) > 1:
                print(f"Warning: Multiple entries for subject {subject}")
            sex_array.append(int(c_sex[0] == 'M'))
        return sex_array

def get_label_restricted(subject_id: List[str], target: str, normalized: bool = True) -> List:
    """
    Get the label for the given subject id and target.

    For sex 0 is F and 1 is M
    """

    # convert to list of ints
    subject_id = [int(x) for x in subject_id]

    if target == "age":
        age_array = []
        for subject in subject_id:
            c_age = subject_information_HCP[subject_information_HCP['Subject'] == subject]['Age_in_Yrs' if not normalized else 'Age_in_Yrs_z'].values
            # if the subject is not in the subject information file trigger an error
            if len(c_age) == 0:
                assert False, f"Subject {subject} not found in subject information file"
            if len(c_age) > 1:
                print(f"Warning: Multiple entries for subject {subject}")

            age_array.append(np.int8(c_age[0]))

        return np.array(age_array)  
    
    elif target == 'sex':
        sex_array = []
        for subject in subject_id:
            c_sex = subject_information_HCP[subject_information_HCP['Subject'] == subject]['Gender'].values
            # if the subject is not in the subject information file trigger an error
            if len(c_sex) == 0:
                assert False, f"Subject {subject} not found in subject information file"
            if len(c_sex) > 1:
                print(f"Warning: Multiple entries for subject {subject}")
            sex_array.append(int(c_sex[0] == 'M'))
        return sex_array

### Fit sklearn model

In [4]:
import argparse
import json
import os
import pickle
from pathlib import Path

import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegressionCV, Ridge
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder


print(f"Target: {target}")

train_features = pd.read_parquet(f"{outdir}_gp{str(global_pooling)}/{parquet_folder}/HCP/train.parquet")
test_features = pd.read_parquet(f"{outdir}_gp{str(global_pooling)}/{parquet_folder}/HCP/test.parquet")

print(f"train: {train_features.shape}, test: {test_features.shape}")
print(f"test: {test_features.shape}")

X_train = np.stack(train_features["feature"])
X_test = np.stack(test_features["feature"])
print(f"X_train: {X_train.shape}, X_test: {X_test.shape}")
print(f"X_test: {X_test.shape}")


if target == "task":
    labels_train = train_features["task"].str.rstrip("1234").values
    labels_test = test_features["task"].str.rstrip("1234").values
elif target == "trial_type":
    labels_train = train_features["trial_type"].values
    labels_test = test_features["trial_type"].values
elif target == "sex":
    labels_train = [str(sample[0]) for sample in train_features['sub'].values.tolist()]
    labels_test = [str(sample[0]) for sample in test_features['sub'].values.tolist()]
    labels_train = get_label_restricted(labels_train, target="sex")
    labels_test = get_label_restricted(labels_test, target="sex")
elif target == "age":
    labels_train = [str(sample[0]) for sample in train_features['sub'].values.tolist()]
    labels_test = [str(sample[0]) for sample in test_features['sub'].values.tolist()]
    labels_train = get_label_restricted(labels_train, target="age")
    labels_test = get_label_restricted(labels_test, target="age")

labels_train = np.array([label[0] if isinstance(label, np.ndarray) else label for label in labels_train])
labels_test = np.array([label[0] if isinstance(label, np.ndarray) else label for label in labels_test])

if target == "trial_type" or target == "sex":
    label_enc = LabelEncoder()
    y_train = label_enc.fit_transform(labels_train)
    y_test = label_enc.transform(labels_test)
    print(f"classes ({len(label_enc.classes_)}): {label_enc.classes_}")
elif target == "age":
    y_train = np.array(labels_train)
    y_test = np.array(labels_test)


print(
    f"\ny_train: {y_train.shape} {y_train[:20]}\n"
    f"y_test: {y_test.shape} {y_test[:20]}"
)
del train_features, test_features

train_ind, val_ind = train_test_split(
    np.arange(len(X_train)), train_size=0.9, random_state=42
)
print(
    f"\ntrain_ind: {len(train_ind)} {train_ind[:10]}\n"
    f"val_ind: {len(val_ind)} {val_ind[:10]}"
)
X_train, X_val = X_train[train_ind], X_train[val_ind]
y_train, y_val = y_train[train_ind], y_train[val_ind]

print("Fitting PCA projection")
pca = PCA(n_components=384, whiten=True, svd_solver="randomized")
pca.fit(X_train)

X_train = pca.transform(X_train)
X_val = pca.transform(X_val)
X_test = pca.transform(X_test)

if target == "trial_type" or target == "sex":
    print("Fitting logistic regression")
    clf = LogisticRegressionCV()
    clf.fit(X_train, y_train)
    
    train_acc = clf.score(X_train, y_train)
    val_acc = clf.score(X_val, y_val)
    test_acc = clf.score(X_test, y_test)
    
    result = {
        "target": target,
        "train_acc": train_acc,
        "val_acc": val_acc,
        "test_acc": test_acc,
    }

elif target == "age":
    alpha = 10000
    print("Fitting ridge regression")
    clf = Ridge()
    clf.fit(X_train, y_train)
    
    # Calculate RÂ² scores
    train_r2 = clf.score(X_train, y_train)
    val_r2 = clf.score(X_val, y_val)
    test_r2 = clf.score(X_test, y_test)
    
    # Make predictions
    train_pred = clf.predict(X_train)
    val_pred = clf.predict(X_val)
    test_pred = clf.predict(X_test)
    
    # Calculate MSE
    train_mse = mean_squared_error(y_train, train_pred)
    val_mse = mean_squared_error(y_val, val_pred)
    test_mse = mean_squared_error(y_test, test_pred)
    
    # Compile results
    result = {
        "target": target,
        "train_r2": train_r2,
        "val_r2": val_r2,
        "test_r2": test_r2,
        "train_mse": float(train_mse),
        "val_mse": float(val_mse),
        "test_mse": float(test_mse),
    }

with open(f"{outdir}_gp{str(global_pooling)}/{parquet_folder}/HCP/downstream_{target}.json", 'w') as out_json:
    json.dump(result, out_json)

print(f"Done:\n{json.dumps(result)}")

Target: age
train: (111302, 9), test: (12082, 9)
test: (12082, 9)
X_train: (111302, 1024), X_test: (12082, 1024)
X_test: (12082, 1024)

y_train: (111302,) [ 0 -1 -1 -1  0  1  0  0 -1  0  0 -1 -1 -1  0  1  0  0 -1  0]
y_test: (12082,) [ 0  0 -1  0  0 -1 -1  0  0  0  0  0 -1  0  0 -1 -1  0  0  0]

train_ind: 100171 [73280 30515 54498  3785 49428 41996 11968 37870 49360 80177]
val_ind: 11131 [ 52061  23724 105636  84928  69162  94892   2339  55435 106582  63659]
Fitting PCA projection
Fitting ridge regression
Done:
{"target": "age", "train_r2": 0.15634626150131226, "val_r2": 0.15349626541137695, "test_r2": 0.12421047687530518, "train_mse": 0.33954328298568726, "val_mse": 0.3488796055316925, "test_mse": 0.3238401412963867}
