In [1]:
# Import packages and setup gpu configuration.
# This code block shouldn't need to be adjusted!
import os
import sys
import json
import yaml
import numpy as np
import copy
import math
import time
import random
from tqdm.auto import tqdm
import webdataset as wds
import matplotlib.pyplot as plt
import pandas as pd

import torch
import torch.nn as nn
from torchvision import transforms
import utils
from mae_utils.flat_models import *
import h5py
from typing import List, Dict, Any, Tuple
from sklearn.preprocessing import StandardScaler
import argparse

# tf32 data type is faster than standard float32 (this flag allows it for CUDA matmuls)
torch.backends.cuda.matmul.allow_tf32 = True
# following fixes a Conv3D CUDNN_NOT_SUPPORTED error
# (benchmark mode lets cuDNN pick the convolution algorithm at runtime)
torch.backends.cudnn.benchmark = True

# ## MODEL TO LOAD ##
# model_name = "HCPflat_large_gsrFalse_"
# parquet_folder = "epoch99"

# # outdir = os.path.abspath(f'checkpoints/{model_name}')
# outdir = os.path.abspath(f'checkpoints/{model_name}')

# print("outdir", outdir)
# # Load previous config.yaml if available
# if os.path.exists(f"{outdir}/config.yaml"):
#     config = yaml.load(open(f"{outdir}/config.yaml", 'r'), Loader=yaml.FullLoader)
#     print(f"Loaded config.yaml from ckpt folder {outdir}")
#     # create global variables from the config
#     print("\n__CONFIG__")
#     for attribute_name in config.keys():
#         print(f"{attribute_name} = {config[attribute_name]}")
#         globals()[attribute_name] = config[f'{attribute_name}']
#     print("\n")

# world_size = os.getenv('WORLD_SIZE')
# if world_size is None: 
#     world_size = 1
# else:
#     world_size = int(world_size)
# print(f"WORLD_SIZE={world_size}")

# if utils.is_interactive():
#     # Following allows you to change functions in models.py or utils.py and 
#     # have this notebook automatically update with your revisions
#     %load_ext autoreload
#     %autoreload 2

# batch_size = probe_batch_size
# num_epochs = probe_num_epochs

# data_type = torch.float32 # change depending on your mixed_precision
# global_batch_size = batch_size * world_size

# Pick the GPU when one is available and fall back to CPU otherwise, so the
# notebook can also be opened on a machine without CUDA. This mirrors the
# device selection done again right before the model is created.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Notebook-level defaults. Other settings (hcp_flat_path, seed, num_workers,
# target, ...) are supplied through the argparse cell instead of being set here.
num_frames = 16   # frames per sample; the precomputed HDF5 files use num_frames * 2
gsr = False
batch_size = 16   # also interpolated into jupyter_args when running interactively

print("PID of this process =",os.getpid())

PID of this process = 877125


In [2]:
# When running interactively there is no command line, so the CLI-style
# arguments are written here as a string; the argparse cell parses
# `jupyter_args` instead of sys.argv in that case.
if utils.is_interactive():
    model_name_suffix = "testing"
    print("model_name_suffix:", model_name_suffix)
    
    # batch_size must already be defined by an earlier cell (it is interpolated below).
    # The backslash continuations keep the line indentation inside the string;
    # that extra whitespace is harmless because the string is split() afterwards.
    jupyter_args = f"--hcp_flat_path=/weka/proj-medarc/shared/HCP-Flat \
                    --target=subject_id \
                    --model_suffix={model_name_suffix} \
                    --batch_size={batch_size} \
                    --max_lr=1e-5 --num_epochs=40 --no-save_ckpt --no-wandb_log --num_workers=10 \
                    --weight_decay=1e-5"
    # --multisubject_ckpt=../train_logs/multisubject_subj01_1024_24bs_nolow

    print(jupyter_args)
    # Turn the string into an argv-style list of tokens
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

model_name_suffix: testing
--hcp_flat_path=/weka/proj-medarc/shared/HCP-Flat                     --target=subject_id                     --model_suffix=testing                     --batch_size=16                     --max_lr=1e-5 --num_epochs=40 --no-save_ckpt --no-wandb_log --num_workers=10                     --weight_decay=1e-5


In [3]:
# Command-line / notebook configuration for the linear-probe run.
# Every parsed argument is later exported as a module-level global of the same name.
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--model_suffix", type=str, default="Testing_flat",
    help="name of model, used for ckpt saving and wandb logging (if enabled)",
)
parser.add_argument(
    "--hcp_flat_path", type=str, default=os.getcwd(),
    help="Path to the folder where the HCP-Flat data is stored",
)
parser.add_argument(
    "--batch_size", type=int, default=128,
    help="Batch size used by the train and test dataloaders",
)
parser.add_argument(
    "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
    help="whether to log to wandb",
)
parser.add_argument(
    "--num_epochs",type=int,default=150,
    help="number of epochs of training",
)
parser.add_argument(
    "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
    help="learning-rate schedule to use",
)
parser.add_argument(
    "--save_ckpt",action=argparse.BooleanOptionalAction,default=True,
    help="whether to save checkpoints",
)
parser.add_argument(
    "--seed",type=int,default=42,
    help="seed passed to utils.seed_everything",
)
parser.add_argument(
    "--max_lr",type=float,default=3e-4,
    help="learning rate given to the optimizer",
)
parser.add_argument(
    "--target",type=str,default='trial_type',  # "trial_type" or "subject_id" or see table on HCP_downstream.ipynb
    help="what to predict: 'trial_type', 'subject_id', or a subject-information column",
)
parser.add_argument(
    "--num_workers",type=int,default=10,
    help="worker budget for data loading",
)
parser.add_argument(
    "--weight_decay",type=float,default=1e-5,
    help="weight decay given to the optimizer",
)

# In a notebook there is no real argv, so parse the list prepared in `jupyter_args`
if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

print(f"------ ARGS ------- \n {args}")

# create global variables without the args prefix
# (note: this overwrites any same-named global defined by an earlier cell, e.g. batch_size)
for attribute_name in vars(args).keys():
    globals()[attribute_name] = getattr(args, attribute_name)

# seed all random functions
utils.seed_everything(seed)


------ ARGS ------- 
 Namespace(model_suffix='testing', hcp_flat_path='/weka/proj-medarc/shared/HCP-Flat', batch_size=16, wandb_log=False, num_epochs=40, lr_scheduler_type='cycle', save_ckpt=False, seed=42, max_lr=1e-05, target='subject_id', num_workers=10, weight_decay=1e-05)


In [4]:
# #### UNCOMMENT THIS TO SAVE THE HCP-FLAT IN HDF5 FORMAT


# from torch.utils.data import default_collate
# from mae_utils.flat import load_hcp_flat_mask
# from mae_utils.flat import create_hcp_flat
# from mae_utils.flat import batch_unmask
# import mae_utils.visualize as vis


# batch_size = 1
# print(f"changed batch_size to {batch_size}")
# load_file_frames = num_frames * 2
# print(f"Calculating with {load_file_frames} frames, doubling to approximate TR")

# ## Test ##
# datasets_to_include = "HCP"
# assert "HCP" in datasets_to_include
# test_dataset = create_hcp_flat(root=hcp_flat_path, 
#                 clip_mode="event", frames=load_file_frames, shuffle=False, gsr=gsr, sub_list = 'test')
# test_dl = wds.WebLoader(
#     test_dataset.batched(batch_size, partial=False, collation_fn=default_collate),
#     batch_size=None,
#     shuffle=False,
#     num_workers=num_workers,
#     pin_memory=True,
# )

# ## Train ##
# assert "HCP" in datasets_to_include
# train_dataset = create_hcp_flat(root=hcp_flat_path, 
#                 clip_mode="event", frames=load_file_frames, shuffle=False, gsr=gsr, sub_list = 'train')
# train_dl = wds.WebLoader(
#     train_dataset.batched(batch_size, partial=True, collation_fn=default_collate),
#     batch_size=None,
#     shuffle=False,
#     num_workers=num_workers,
#     pin_memory=True,
# )

# def flatten_meta(meta_dict):
#     """
#     Flatten the meta dictionary by:
#     - Replacing single-item lists with the item itself.
#     - Converting tensors to scalar numbers.
#     """
#     flattened = {}
#     for key, value in meta_dict.items():
#         if isinstance(value, list):
#             if len(value) == 1:
#                 flattened[key] = value[0]  # Replace list with its single item
#             else:
#                 flattened[key] = value  # Keep as is if multiple items
#         elif isinstance(value, torch.Tensor):
#             # Convert tensor to scalar
#             if value.numel() == 1:
#                 flattened[key] = value.item()
#             else:
#                 flattened[key] = value.tolist()  # Convert multi-element tensor to list
#         else:
#             flattened[key] = value  # Keep the value as is
#     return flattened


# import h5py
# meta_array = np.array([], dtype=object)
# # Open an HDF5 file in write mode
# with h5py.File(f'test_HCP_raw_flatmaps_{load_file_frames}f.hdf5', 'w') as h5f:
#     flatmaps_dset = None
    
#     total_samples = 0

#     for i, batch in tqdm(enumerate(test_dl), total = 12000):
#         images = batch[0]
#         meta = batch[1]
#         batch_size = images.shape[0]
#         meta_serializable = meta.copy()
        
        
#         # Step 2: Serialize the dictionary to a JSON string
#         meta_str = json.dumps(flatten_meta(meta_serializable), indent=4)
#         meta_array = np.append(meta_array, meta_str)
#         if flatmaps_dset is None:
#             # Initialize datasets with unlimited (None) maxshape along the first axis
#             flatmaps_shape = (0,) + images.shape[1:]
#             flatmaps_maxshape = (None,) + images.shape[1:]

#             flatmaps_dset = h5f.create_dataset(
#                 'flatmaps',
#                 shape=flatmaps_shape,
#                 maxshape=flatmaps_maxshape,
#                 dtype=np.float16,
#                 chunks=True  # Enable chunking for efficient resizing
#             )

#         # Resize datasets to accommodate new data
#         flatmaps_dset.resize(total_samples + batch_size, axis=0)

#         # Write data to the datasets
#         flatmaps_dset[total_samples:total_samples + batch_size] = images.numpy().astype(np.float16)

#         total_samples += batch_size
        
#     print(f"Processed {total_samples} samples")
# np.save(f'metadata_test_HCP_raw_flatmaps_{load_file_frames}f.npy', meta_array)


# import h5py
# meta_array = np.array([], dtype=object)
# # Open an HDF5 file in write mode
# with h5py.File(f'train_HCP_raw_flatmaps_{load_file_frames}f.hdf5', 'w') as h5f:
#     flatmaps_dset = None
    
#     total_samples = 0

#     for i, batch in tqdm(enumerate(train_dl), total = 120000):
#         images = batch[0]
#         meta = batch[1]
#         batch_size = images.shape[0]
#         meta_serializable = meta.copy()
        
        
#         # Step 2: Serialize the dictionary to a JSON string
#         meta_str = json.dumps(flatten_meta(meta_serializable), indent=4)
#         meta_array = np.append(meta_array, meta_str)
#         if flatmaps_dset is None:
#             # Initialize datasets with unlimited (None) maxshape along the first axis
#             flatmaps_shape = (0,) + images.shape[1:]
#             flatmaps_maxshape = (None,) + images.shape[1:]

#             flatmaps_dset = h5f.create_dataset(
#                 'flatmaps',
#                 shape=flatmaps_shape,
#                 maxshape=flatmaps_maxshape,
#                 dtype=np.float16,
#                 chunks=True  # Enable chunking for efficient resizing
#             )

#         # Resize datasets to accommodate new data
#         flatmaps_dset.resize(total_samples + batch_size, axis=0)

#         # Write data to the datasets
#         flatmaps_dset[total_samples:total_samples + batch_size] = images.numpy().astype(np.float16)

#         total_samples += batch_size
        
#     print(f"Processed {total_samples} samples")
# np.save(f'metadata_train_HCP_raw_flatmaps_{load_file_frames}f.npy', meta_array)

### Data

In [5]:
# Folder that holds the precomputed raw-flatmap HDF5 files and their metadata arrays
hdf5_base_path_raw_file = '.'  # use this /weka/proj-fmri/ckadirt/fMRI-foundation-model/src if you don't want to store them again
# The precomputed files are named after twice the number of model frames
load_file_frames = num_frames * 2

try:
    # The HDF5 files are intentionally left open: `flatmaps_train` / `flatmaps_test`
    # are h5py datasets that later cells index into, which needs the file handles alive.
    f_train = h5py.File(f'{hdf5_base_path_raw_file}/train_HCP_raw_flatmaps_{load_file_frames}f.hdf5', 'r')
    flatmaps_train = f_train['flatmaps']

    f_test = h5py.File(f'{hdf5_base_path_raw_file}/test_HCP_raw_flatmaps_{load_file_frames}f.hdf5', 'r')
    flatmaps_test = f_test['flatmaps']

    # NOTE: allow_pickle=True unpickles the file contents, which can execute arbitrary
    # code. Only load metadata files that you generated yourself.
    metadata_train = np.load(f'{hdf5_base_path_raw_file}/metadata_train_HCP_raw_flatmaps_{load_file_frames}f.npy', allow_pickle=True)
    metadata_test = np.load(f'{hdf5_base_path_raw_file}/metadata_test_HCP_raw_flatmaps_{load_file_frames}f.npy', allow_pickle=True)
    print("Loaded flatmaps")
except (OSError, KeyError):
    # Missing/unreadable file (OSError) or an HDF5 file without a 'flatmaps' dataset
    # (KeyError). Print the hint and re-raise: continuing would only postpone the
    # failure to a confusing NameError in a later cell.
    print(f"Make sure you have the raw flatmaps precomputed for this num frames: {load_file_frames}. You can do it uncommenting the cell above")
    raise

Loaded flatmaps


In [6]:
from collections import defaultdict

# Creating a new train/test split for target subject_id.
# The original split keeps subjects disjoint between train and test, which makes
# subject identification impossible; here samples of the old train and test sets
# are pooled and re-split so that test subjects also occur in train.
#   'uniform': per subject, ~90% of its samples go to train (at least one)
#   'random' : every sample goes to train with probability 0.9, then any subject
#              that ended up with no train sample gets one moved over from test
split_type = 'random'
if target == 'subject_id':
    if split_type not in ('uniform', 'random'):
        # Fail here rather than with a NameError on new_train_indices later on
        raise ValueError(f"Unknown split_type '{split_type}', expected 'uniform' or 'random'")

    if split_type == 'uniform':
        print("Target is subject_id, so creating a new split on datasets")

    # combined_metadata holds tuples of (json_string, source, index_in_that_source);
    # a sample's position in this list is its "global index".
    combined_metadata = [(m_str, 'train', i) for i, m_str in enumerate(metadata_train)]
    combined_metadata.extend((m_str, 'test', i) for i, m_str in enumerate(metadata_test))
    print(f"Total combined samples: {len(combined_metadata)}")

    new_train_indices = []
    new_test_indices = []

    if split_type == 'uniform':
        # Group global indices by subject, e.g. "285446"
        subject_to_indices = defaultdict(list)
        for global_idx, (m_str, source, idx_in_source) in enumerate(combined_metadata):
            subject_to_indices[json.loads(m_str)["sub"]].append(global_idx)

        train_ratio = 0.9
        for subj, global_idxs in subject_to_indices.items():
            # Shuffle the subject's indices in place so the cut below is a random split
            random.shuffle(global_idxs)
            # max(1, ...) guarantees every subject has at least one train sample
            cutoff = max(1, int(len(global_idxs) * train_ratio))
            new_train_indices.extend(global_idxs[:cutoff])
            new_test_indices.extend(global_idxs[cutoff:])

    else:  # split_type == 'random'
        p_train = 0.9  # Probability a given sample goes to train

        # Randomly assign each sample to train or test
        for global_idx in range(len(combined_metadata)):
            if random.random() < p_train:
                new_train_indices.append(global_idx)
            else:
                new_test_indices.append(global_idx)

        # 1) Build a set of subjects already in train
        subject_in_train = {
            json.loads(combined_metadata[g_idx][0])["sub"] for g_idx in new_train_indices
        }

        # 2) Group test indices by subject
        subject_to_test_indices = defaultdict(list)
        for g_idx in new_test_indices:
            subj = json.loads(combined_metadata[g_idx][0])["sub"]
            subject_to_test_indices[subj].append(g_idx)

        # 3) For any subject not in train, move exactly 1 test sample to train.
        # The counter is initialised once, before the loop: resetting it per
        # iteration made the reported number reflect only the last subject.
        alone_samples = 0
        for subj, test_g_idxs in subject_to_test_indices.items():
            if subj not in subject_in_train:
                # Move one random test sample for this subject into train
                chosen_idx = random.choice(test_g_idxs)
                new_test_indices.remove(chosen_idx)
                new_train_indices.append(chosen_idx)
                subject_in_train.add(subj)
                alone_samples += 1
        print(f"Number of subjects on train with just 1 sample: {alone_samples}")

Total combined samples: 97246
Number of subjects on train with just 1 sample: 0


### Create the dataloader

In [7]:
from torch.utils.data import Dataset, DataLoader

class HCPSplitDataset(Dataset):
    """Dataset view over a subset of the pooled (old train + old test) samples.

    Each item is looked up through ``indices`` -> ``combined_metadata`` and the
    flatmap is then read from whichever store the sample originally came from.
    """

    def __init__(self, combined_metadata, indices, flatmaps_train, flatmaps_test):
        """
        Args:
            combined_metadata: list of (json_string, source, idx_in_source)
            indices: global indices (positions in combined_metadata) forming this split
            flatmaps_train: indexable store for samples whose source is 'train'
            flatmaps_test:  indexable store for samples with any other source
        """
        self.combined_metadata = combined_metadata
        self.indices = indices
        self.flatmaps_train = flatmaps_train
        self.flatmaps_test = flatmaps_test

    def __len__(self):
        # The split size is given by the index list, not by the pooled metadata
        return len(self.indices)

    def __getitem__(self, i):
        # Map the split-local position to the pooled entry
        meta_json, origin, row = self.combined_metadata[self.indices[i]]
        meta = json.loads(meta_json)

        # Read from the store the sample was originally saved in
        store = self.flatmaps_train if origin == 'train' else self.flatmaps_test
        return store[row], meta

class HCPFlatDataset(Dataset):
    """Pairs an indexable flatmap store with a parallel sequence of JSON metadata strings."""

    def __init__(self, flatmaps, metadata):
        self.flatmaps = flatmaps
        self.metadata = metadata

    def __len__(self):
        # Length follows the metadata sequence
        return len(self.metadata)

    def __getitem__(self, idx):
        flatmap = self.flatmaps[idx]
        # Metadata is stored as a JSON string and decoded on access
        meta = json.loads(self.metadata[idx])
        return flatmap, meta

# Pick the datasets for this run, then build both loaders the same way.
if target == 'subject_id':
    # Re-split samples: both datasets read from the pooled metadata and may pull
    # flatmaps from either of the two HDF5 stores.
    train_dataset_new = HCPSplitDataset(
        combined_metadata, new_train_indices, flatmaps_train, flatmaps_test
    )
    test_dataset_new = HCPSplitDataset(
        combined_metadata, new_test_indices, flatmaps_train, flatmaps_test
    )
    _train_source, _test_source = train_dataset_new, test_dataset_new
else:
    # original approach: keep the precomputed train/test files as they are
    train_dataset = HCPFlatDataset(flatmaps_train, metadata_train)
    test_dataset = HCPFlatDataset(flatmaps_test, metadata_test)
    _train_source, _test_source = train_dataset, test_dataset

# Each loader gets half of the worker budget; only the training loader shuffles.
_loader_workers = num_workers // 2
train_dl = DataLoader(_train_source, batch_size=batch_size, shuffle=True, num_workers=_loader_workers)
test_dl = DataLoader(_test_source, batch_size=batch_size, shuffle=False, num_workers=_loader_workers)



In [8]:
# Pull a single batch from the training loader and stop; `batch` is kept as a
# global and reused by later cells (shape check, input-dimension computation).
for batch in train_dl:
    break

In [9]:
# Shape of the first element (the flatmaps) of the batch fetched above
batch[0].shape

torch.Size([16, 1, 32, 144, 320])

### Load subject information

In [10]:
from sklearn.preprocessing import LabelEncoder

###### This is for restricted
# (presumably the column names of the restricted subject-information CSV — verify against the file)
# A target listed in `categorical_columns` is label-encoded and treated as classification;
# a target listed in `numerical_columns` is z-scored and treated as regression.
# Categorical columns (e.g., demographic categories, binary diagnoses)
categorical_columns = [
    "Gender",
    "Race",
    "Ethnicity",
    "SSAGA_PanicDisorder",  # Panic disorder diagnosis (yes/no)
    "SSAGA_Depressive_Ep"   # Depressive episode diagnosis (yes/no)
]

# Numerical columns (continuous, counts, raw scores, standardized scores, etc.)
numerical_columns = [
    # Basic demographics
    "Age_in_Yrs",
    
    # Cognitive / "IQ-like" Measures
    "PMAT24_A_CR",
    "CardSort_Unadj",
    "CardSort_AgeAdj",
    "ListSort_Unadj",
    "ListSort_AgeAdj",
    "PicSeq_Unadj",
    "PicSeq_AgeAdj",
    
    # Personality Traits (Big Five)
    "NEOFAC_A",
    "NEOFAC_O",
    "NEOFAC_C",
    "NEOFAC_N",
    "NEOFAC_E",
    # If you have all 60 NEO item-level responses:
    # "NEORAW_01", "NEORAW_02", ..., "NEORAW_60",
    
    # Psychopathology / Mental Health
    "ASR_Anxd_Raw",
    "ASR_Attn_Raw",
    "ASR_Aggr_Raw",
    "DSM_Depr_Raw",
    "DSM_Anxi_Raw",
    "SSAGA_Depressive_Sx",  # Symptom count or severity
    
    # Substance Use Phenotypes
    "SSAGA_Alc_12_Frq",
    "SSAGA_Alc_12_Max_Drinks",
    "SSAGA_Times_Used_Illicits",
    "SSAGA_Times_Used_Cocaine",
    "Total_Drinks_7days",
    "Total_Any_Tobacco_7days",
    
    # Anthropometric / Basic Health
    "BMI",
    "Height",
    "Weight",
    "BPSystolic",
    "BPDiastolic",
    "HbA1C",
    "ThyroidHormone",
    
    # Sleep / Quality of Life
    "PSQI_Score",
    # If you have separate PSQI component scores, list them here too, e.g.:
    # "PSQI_Component1", "PSQI_Component2", ...
    "PainInterf_Tscore",
    "LifeSatisf_Unadj",
    "MeanPurp_Unadj"
]


# Classify the requested target:
#   'special'     -> labels come from the sample metadata (subject_id / trial_type)
#   'categorical' -> label-encoded subject-information column
#   'numerical'   -> z-scored subject-information column
if target in ['subject_id', 'trial_type']:
    target_type = 'special'
elif target in categorical_columns:
    target_type = 'categorical'
elif target in numerical_columns:
    target_type = 'numerical'
else:
    # Without this branch target_type would stay undefined (or keep a stale value
    # from a previous run of the cell) and the failure would surface much later.
    raise ValueError(
        f"Unknown target '{target}': expected 'subject_id', 'trial_type', "
        "or a name listed in categorical_columns / numerical_columns"
    )


# open the file containing subject information
# (only needed when the label comes from a subject-information column)
if not (target in ['subject_id', 'trial_type']):
    subject_information_HCP_path = os.path.join(hcp_flat_path, "subjects_data_restricted.csv")

    # Try the CSV inside hcp_flat_path first, then the local fallback file.
    # Only file-access errors trigger the fallback; anything else (e.g. a malformed
    # CSV) is raised as-is instead of being hidden behind a bare except.
    subject_information_HCP = None
    for _csv_path in (subject_information_HCP_path, './unrestricted_clane9_4_23_2024_13_28_14.csv'):
        try:
            subject_information_HCP = pd.read_csv(_csv_path)
            break
        except OSError:
            continue
    if subject_information_HCP is None:
        raise FileNotFoundError("Subject information file not found")

    if target in numerical_columns:
        # Count NaNs or missing values
        n_missing = subject_information_HCP[target].isnull().sum()
        print(f"Number of missing values in {target}: {n_missing}. Replacing with mean.")
        mean_ = subject_information_HCP[target].mean()
        # Replace missing values or NaNs with the mean.
        # Assign the result back: `df[col].fillna(..., inplace=True)` acts on a
        # temporary Series and does not update the frame under pandas Copy-on-Write.
        subject_information_HCP[target] = subject_information_HCP[target].fillna(mean_)
        # Perform z-score normalization; the result is stored in '<target>_z'
        scaler = StandardScaler()
        subject_information_HCP[f'{target}_z'] = scaler.fit_transform(subject_information_HCP[[target]])

    if target in categorical_columns:
        # Perform label encoding; the integer codes are stored in '<target>_encoded'
        label_enc = LabelEncoder()
        subject_information_HCP[f'{target}_encoded'] = label_enc.fit_transform(subject_information_HCP[target])

def train_test_split_by_subject(df, test_ratio=0.1, random_state=42):
    """
    Split a dataframe into train and test so that
    every subject in test also appears in train at least once.

    Rows are split per subject: roughly ``test_ratio`` of a subject's rows go to
    test, but never all of them. Subjects with a single row go entirely to train.
    Both outputs are shuffled and re-indexed from 0.

    Parameters
    ----------
    df : pd.DataFrame
        Your dataset, containing at least the column 'sub'. Rows are removed from
        a subject's group by index label, so the index is expected to be unique.
    test_ratio : float
        Fraction of each subject's rows to allocate to test.
    random_state : int
        Random seed for reproducibility.

    Returns
    -------
    train_df : pd.DataFrame
    test_df : pd.DataFrame
        Either frame is empty (with the columns of ``df``) when no rows were
        assigned to it, e.g. when ``df`` is empty or every subject has one row.
    """
    # Kept from the original implementation: the global NumPy RNG is re-seeded as a
    # side effect. The sampling below does not depend on it (it uses random_state).
    np.random.seed(random_state)

    train_parts = []
    test_parts = []

    for _, df_sub in df.groupby('sub'):
        n = len(df_sub)

        # A single-row subject cannot be in both splits: keep it in train
        if n == 1:
            train_parts.append(df_sub)
            continue

        # Rows for test, capped at n - 1 so at least one row stays in train
        n_test = min(int(round(test_ratio * n)), n - 1)

        test_rows = df_sub.sample(n_test, random_state=random_state)
        test_parts.append(test_rows)
        train_parts.append(df_sub.drop(test_rows.index))

    def _combine(parts):
        # pd.concat raises on an empty list, so return an empty frame with the
        # same columns in that case instead of crashing.
        if not parts:
            return df.iloc[0:0].reset_index(drop=True)
        return pd.concat(parts).sample(frac=1, random_state=random_state).reset_index(drop=True)

    return _combine(train_parts), _combine(test_parts)

    
def get_label_restricted(subject_id: List[str], target: str, normalized: bool = True) -> List:
    """
    Get the label for the given subject ids and target.

    Labels are read from the module-level ``subject_information_HCP`` frame by
    matching its 'Subject' column against each id.

    Parameters
    ----------
    subject_id : list of str
        Subject ids; each one is converted with ``int()`` before the lookup.
    target : str
        A name from ``numerical_columns`` or ``categorical_columns``.
    normalized : bool
        Numerical targets only: read the z-scored '<target>_z' column instead of
        the raw '<target>' column.

    Returns
    -------
    np.ndarray
        One label per subject, in input order: float32 values for numerical
        targets, int8 codes (from '<target>_encoded') for categorical targets.

    Raises
    ------
    ValueError
        If ``target`` is in neither column list (this used to return None silently).
    AssertionError
        If a subject has no row in ``subject_information_HCP``.
    """
    subject_id = [int(x) for x in subject_id]

    # Select the column to read and the scalar type used for each label
    if target in numerical_columns:
        column = f'{target}_z' if normalized else f'{target}'
        cast = np.float32
    elif target in categorical_columns:
        column = f'{target}_encoded'
        cast = np.int8  # NOTE: int8 only holds codes up to 127
    else:
        raise ValueError(f"Target '{target}' is not in numerical_columns or categorical_columns")

    target_array = []
    for subject in subject_id:
        matches_subject = subject_information_HCP['Subject'] == subject
        c_target = subject_information_HCP.loc[matches_subject, column].values
        # if the subject is not in the subject information file trigger an error
        # (raised explicitly so the check is not stripped when running with -O)
        if len(c_target) == 0:
            raise AssertionError(f"Subject {subject} not found in subject information file")
        if len(c_target) > 1:
            # Duplicate rows: warn and use the first one
            print(f"Warning: Multiple entries for subject {subject}")

        target_array.append(cast(c_target[0]))

    return np.array(target_array)
    

In [11]:
# # open the file containing subject information
# if target == "age" or target == "sex":
#     subject_information_HCP_path = os.path.join(hcp_flat_path, "subjects_data_restricted.csv")
#     try:
#         subject_information_HCP = pd.read_csv(subject_information_HCP_path)
#     except:
#         try:
#             subject_information_HCP = pd.read_csv('./unrestricted_clane9_4_23_2024_13_28_14.csv')   
#         except:
#             assert False, "Subject information file not found"

#     ###### This is for unrestricted
#     # age_related_columns = [
#     #     'Age', 'PicSeq_AgeAdj', 'CardSort_AgeAdj', 'Flanker_AgeAdj',
#     #     'ReadEng_AgeAdj', 'PicVocab_AgeAdj', 'ProcSpeed_AgeAdj',
#     #     'CogFluidComp_AgeAdj', 'CogEarlyComp_AgeAdj', 'CogTotalComp_AgeAdj',
#     #     'CogCrystalComp_AgeAdj', 'Endurance_AgeAdj', 'Dexterity_AgeAdj',
#     #     'Strength_AgeAdj', 'Odor_AgeAdj', 'Taste_AgeAdj'
#     # ]
    
#     # sex_related_columns = [
#     #     'Gender'
#     # ]

#     ###### This is for restricted
#     gender_related_columns = [
#         'Gender'
#     ]

#     age_related_columns = [
#         'Age_in_Yrs',
#         'Menstrual_AgeBegan',
#         'Menstrual_AgeIrreg',
#         'Menstrual_AgeStop',
#         'SSAGA_Alc_Age_1st_Use',
#         'SSAGA_TB_Age_1st_Cig',
#         'SSAGA_Mj_Age_1st_Use',
#         'Endurance_AgeAdj',
#         'Dexterity_AgeAdj',
#         'Strength_AgeAdj',
#         'PicSeq_AgeAdj',
#         'CardSort_AgeAdj',
#         'Flanker_AgeAdj',
#         'ReadEng_AgeAdj',
#         'PicVocab_AgeAdj',
#         'ProcSpeed_AgeAdj',
#         'Odor_AgeAdj',
#         'Taste_AgeAdj'
#     ]

#     # # show the first few rows of the subject information
#     # subject_information_HCP[age_related_columns + sex_related_columns].head()

#     # Handle missing values (e.g., impute with mean)
#     mean_age = subject_information_HCP['Age_in_Yrs'].mean()
    
#     # Initialize the scaler
#     scaler = StandardScaler()
    
#     # Perform z-score normalization
#     subject_information_HCP['Age_in_Yrs_z'] = scaler.fit_transform(subject_information_HCP[['Age_in_Yrs']])


    
# def get_label_unrestricted(subject_id: List[str], target: str, method_for_age: str = 'mean') -> List:
#     """
#     Get the label for the given subject id and target.

#     For sex 0 is F and 1 is M
#     """

#     # convert to list of ints
#     subject_id = [int(x) for x in subject_id]

#     if target == "age":
#         age_array = []
#         for subject in subject_id:
#             c_age = subject_information_HCP[subject_information_HCP['Subject'] == subject]['Age'].values
#             # if the subject is not in the subject information file trigger an error
#             if len(c_age) == 0:
#                 assert False, f"Subject {subject} not found in subject information file"
#             if len(c_age) > 1:
#                 print(f"Warning: Multiple entries for subject {subject}")

#             c_age = c_age[0].split('-')
#             if len(c_age) < 2:
#                 c_age = c_age[0].split('+')
#                 age_array.append(int(c_age[0]))
#             else:
#                 if method_for_age == 'mean':
#                     age_array.append(np.mean([int(x) for x in c_age]))
#                 elif method_for_age == 'min':
#                     age_array.append(np.min([int(x) for x in c_age]))
#                 elif method_for_age == 'max':
#                     age_array.append(np.max([int(x) for x in c_age]))
#                 else:
#                     assert False, f"Method {method_for_age} not recognized"

#         return np.array(age_array)  
    
#     elif target == 'sex':
#         sex_array = []
#         for subject in subject_id:
#             c_sex = subject_information_HCP[subject_information_HCP['Subject'] == subject]['Gender'].values
#             # if the subject is not in the subject information file trigger an error
#             if len(c_sex) == 0:
#                 assert False, f"Subject {subject} not found in subject information file"
#             if len(c_sex) > 1:
#                 print(f"Warning: Multiple entries for subject {subject}")
#             sex_array.append(int(c_sex[0] == 'M'))
#         return sex_array

# def get_label_restricted(subject_id: List[str], target: str, normalized: bool = True) -> List:
#     """
#     Get the label for the given subject id and target.

#     For sex 0 is F and 1 is M
#     """

#     # convert to list of ints
#     subject_id = [int(x) for x in subject_id]

#     if target == "age":
#         age_array = []
#         for subject in subject_id:
#             c_age = subject_information_HCP[subject_information_HCP['Subject'] == subject]['Age_in_Yrs' if not normalized else 'Age_in_Yrs_z'].values
#             # if the subject is not in the subject information file trigger an error
#             if len(c_age) == 0:
#                 assert False, f"Subject {subject} not found in subject information file"
#             if len(c_age) > 1:
#                 print(f"Warning: Multiple entries for subject {subject}")

#             age_array.append(np.int8(c_age[0]))

#         return np.array(age_array)  
    
#     elif target == 'sex':
#         sex_array = []
#         for subject in subject_id:
#             c_sex = subject_information_HCP[subject_information_HCP['Subject'] == subject]['Gender'].values
#             # if the subject is not in the subject information file trigger an error
#             if len(c_sex) == 0:
#                 assert False, f"Subject {subject} not found in subject information file"
#             if len(c_sex) > 1:
#                 print(f"Warning: Multiple entries for subject {subject}")
#             sex_array.append(int(c_sex[0] == 'M'))
#         return sex_array

In [12]:
from sklearn.preprocessing import LabelEncoder

# Build the label encoder for the metadata-derived targets. For categorical
# subject-information targets `label_enc` was already fitted when the CSV was loaded.
if target == "trial_type":
    # Trial conditions used as classes
    INCLUDE_CONDS = {
        "fear", "neut",
        "math", "story",
        "lf", "lh", "rf", "rh", "t",
        "match", "relation",
        "mental", "rnd",
        "0bk_body", "2bk_body",
        "0bk_faces", "2bk_faces",
        "0bk_places", "2bk_places",
        "0bk_tools", "2bk_tools",
    }
    # Sorted before fitting to ensure consistent ordering
    label_enc = LabelEncoder().fit(sorted(INCLUDE_CONDS))
elif target == 'subject_id':
    # One class per subject found in the pooled (train + test) metadata
    all_subs = [json.loads(meta_str)['sub'] for meta_str, *_ in combined_metadata]
    label_enc = LabelEncoder().fit(all_subs)

# Regression targets have no classes; everything else reports the class count
if target_type in ['categorical', 'special']:
    num_classes = len(label_enc.classes_)
    print(f"Number of classes: {num_classes}")

Number of classes: 1093


In [13]:
# for sample in tqdm(train_dl):
#     x = sample[0]
#     subject_id = sample[1]['sub']

#     # benchmark time
#     start = time.time()
#     y = get_label(subject_id, 'age')
#     end = time.time()
#     print(f"Time taken: {end - start}")
#     print(x.shape, y, subject_id, torch.Tensor(y).shape)
#     break

### Create pytorch model

In [14]:
class LinearClassifier(nn.Module):
    """Linear probe: a single fully connected layer over the flattened input.

    Args:
        input_dim: Number of features per sample once all non-batch
            dimensions are flattened.
        num_classes: Number of output units. The notebook passes the class
            count for classification and 1 for regression.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        """Flatten everything but the batch dimension and apply the linear layer.

        Returns the raw layer outputs (logits) of shape [batch, num_classes].
        """
        # reshape() instead of view(): view() raises on non-contiguous inputs
        # (e.g. transposed/permuted tensors), while reshape() copies only when
        # it has to and is otherwise identical.
        x = x.reshape(x.size(0), -1)
        return self.linear(x)  # Raw logits

# Determine the classifier's input dimension from a single sample.
# `batch` must already exist from an earlier cell; batch[0] is the image
# tensor, so batch[0][0] is one sample without the batch dimension.
# NOTE(review): the sample shape is not fixed here -- check the printed value
# below against the expected frames x height x width.
sample_batch = batch
sample_image = sample_batch[0][0]
# numel() counts elements directly and, unlike view(-1), does not require the
# tensor to be contiguous.
input_dim = sample_image.numel()
print(f"Input dimension: {input_dim}")


Input dimension: 1474560


In [15]:
# Initialize the model, loss, optimizer and learning-rate scheduler.

if (target in ["trial_type", "subject_id"]) or (target in categorical_columns):
    # Classification: one logit per class, trained with cross-entropy.
    model = LinearClassifier(input_dim=input_dim, num_classes=num_classes)
    criterion = nn.CrossEntropyLoss()

elif target in numerical_columns:
    # Regression: a single output unit, trained with mean squared error.
    model = LinearClassifier(input_dim=input_dim, num_classes=1)
    criterion = nn.MSELoss()

# elif target == "sex":
#     model = LinearClassifier(input_dim=input_dim, num_classes=1)
#     criterion = nn.BCEWithLogitsLoss()

else:
    # Fail here rather than with a NameError on `model` further down.
    raise ValueError(
        f"Unsupported target {target!r}: expected 'trial_type', 'subject_id', "
        "or a name listed in categorical_columns / numerical_columns"
    )

# Move the model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# import schedulefree
# optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)

# The scheduler is stepped once per batch in the training loop, so its horizon
# is expressed in optimizer steps, not epochs.
num_iterations_per_epoch = math.ceil(flatmaps_train.shape[0]/batch_size)
total_steps = int(num_epochs * num_iterations_per_epoch)

if lr_scheduler_type is None:
    # No scheduling: the training loop skips scheduler steps in this case.
    lr_scheduler = None
elif lr_scheduler_type == 'linear':
    # NOTE(review): with LinearLR's default factors this ramps the LR up to
    # max_lr over the run rather than decaying it -- confirm this is intended.
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        total_iters=total_steps,
        last_epoch=-1
    )
elif lr_scheduler_type == 'cycle':
    print("total_steps", total_steps)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_lr,
        total_steps=total_steps,
        final_div_factor=1000,
        last_epoch=-1,
        # Warm up for the first two epochs; OneCycleLR rejects pct_start > 1,
        # so clamp for runs shorter than two epochs.
        pct_start=min(1.0, 2 / num_epochs)
    )
else:
    # Fail here rather than with a NameError on `lr_scheduler` mid-training.
    raise ValueError(
        f"Unsupported lr_scheduler_type {lr_scheduler_type!r}: "
        "expected 'linear', 'cycle' or None"
    )


total_steps 219400


### Wandb logging

In [16]:
import wandb
import uuid

# Unique suffix so every launch gets its own W&B run id.
myuuid = uuid.uuid4()

if utils.is_interactive():
    print("Running in interactive notebook. Disabling W&B and ckpt saving.")
    wandb_log = False
    save_ckpt = False

# The run configuration and random id are needed both for W&B logging and for
# checkpoint saving (checkpoint directory name and config.yaml), so build them
# whenever either feature is enabled -- not only when wandb_log is set.
if wandb_log or save_ckpt:
    wandb_config = {
        "model_name": f"HCPflat_raw_{target}",
        "batch_size": batch_size,
        "weight_decay": weight_decay,
        "num_epochs": num_epochs,
        "seed": seed,
        "lr_scheduler_type": lr_scheduler_type,
        "save_ckpt": save_ckpt,
        "max_lr": max_lr,
        "target": target,
        "num_workers": num_workers,
    }
    random_id = random.randint(0, 100000)

if wandb_log:
    wandb_project = 'fMRI-foundation-model'
    print("wandb_config:\n", wandb_config)
    wandb_id = "HCPflat_raw" + f"_{model_suffix}_{target}_{myuuid}"
    print("wandb_id:", wandb_id)
    wandb.init(
        id=wandb_id,
        project=wandb_project,
        name="HCPflat_raw"+ f"_{model_suffix}_{target}",
        config=wandb_config,
        resume="allow",
    )

Running in interactive notebook. Disabling W&B and ckpt saving.


### Training loop

In [None]:
def _labels_for_batch(batch_meta):
    """Build the label tensor for one batch, on `device`, for the current `target`.

    `batch_meta` is the metadata element of a batch (batch[1]); it is indexed
    by 'trial_type' or 'sub' depending on the target. Encoded targets
    ('trial_type', 'subject_id') go through `label_enc`; column targets are
    looked up with `get_label_restricted`. Exactly one branch applies, so a
    target handled by `label_enc` is never overwritten by a column lookup.

    Raises:
        ValueError: if `target` matches none of the supported kinds.
    """
    if target == "trial_type":
        values = label_enc.transform(batch_meta['trial_type'])
        dtype = torch.long
    elif target == 'subject_id':
        values = label_enc.transform(batch_meta['sub'])
        dtype = torch.long
    elif target in numerical_columns:
        values = get_label_restricted(batch_meta['sub'], target)
        dtype = torch.float
    elif target in categorical_columns:
        values = get_label_restricted(batch_meta['sub'], target)
        dtype = torch.long
    else:
        raise ValueError(f"Unsupported target {target!r}: cannot build labels")
    return torch.tensor(values, dtype=dtype).to(device)  # expected shape: [batch_size]


for epoch in range(num_epochs):
    running_train_loss = 0.0
    correct_train = 0
    mse_age_train = 0.0  # summed squared error (any numerical target)
    total_train = 0
    step = 0

    # Training Phase
    model.train()

    for batch in tqdm(train_dl, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        optimizer.zero_grad()
        # presumably [batch_size, num_frames, 144, 320] -- TODO confirm
        images = batch[0].to(device).float()
        labels = _labels_for_batch(batch[1])

        # Forward pass
        outputs = model(images)  # [batch_size, num_classes] (num_classes == 1 for regression)

        # squeeze() drops the trailing singleton dimension of regression
        # outputs so they line up with the 1-D labels.
        loss = criterion(outputs.squeeze(), labels.squeeze())

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Accumulate loss
        running_train_loss += loss.item()

        # Calculate and accumulate metrics
        if target_type in ["categorical", "special"]:
            _, predicted = torch.max(outputs, 1)
            correct_train += (predicted == labels).sum().item()
        elif target_type == "numerical":
            mse_age_train += (torch.sum((outputs.squeeze() - labels) ** 2).item())

        total_train += labels.size(0)
        step += 1

        # Print intermediate (running) metrics every 100 steps
        if step % 100 == 0:
            if target_type in ['categorical', 'special']:
                current_accuracy = 100 * correct_train / total_train if total_train > 0 else 0.0
                print(f"Step [{step}/{len(train_dl)}] - Training Loss: {loss.item():.4f} - Training Accuracy: {current_accuracy:.2f}%")
            elif target_type == 'numerical':
                current_mse = mse_age_train / total_train if total_train > 0 else 0.0
                print(f"Step [{step}/{len(train_dl)}] - Training Loss: {loss.item():.4f} - Training MSE: {current_mse:.4f}")

        # The scheduler advances once per optimizer step.
        if lr_scheduler_type is not None:
            lr_scheduler.step()

    # Calculate epoch-level metrics
    epoch_train_loss = running_train_loss / step if step > 0 else 0.0

    if target_type in ['categorical', 'special']:
        train_accuracy = 100 * correct_train / total_train if total_train > 0 else 0.0
    elif target_type == 'numerical':
        train_mse = mse_age_train / total_train if total_train > 0 else 0.0

    # Validation Phase (metrics are computed on `test_dl`)
    model.eval()
    running_val_loss = 0.0
    correct_val = 0
    mse_age_val = 0.0
    total_val = 0
    step_val = 0

    with torch.no_grad():
        for batch in tqdm(test_dl, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
            images = batch[0].to(device).float()
            labels = _labels_for_batch(batch[1])

            # Forward pass
            outputs = model(images)

            loss = criterion(outputs.squeeze(), labels.squeeze())

            # Accumulate loss
            running_val_loss += loss.item()

            # Calculate and accumulate metrics
            if target_type in ["categorical", "special"]:
                _, predicted = torch.max(outputs, 1)
                correct_val += (predicted == labels).sum().item()
            elif target_type == "numerical":
                mse_age_val += (torch.sum((outputs.squeeze() - labels) ** 2).item())

            total_val += labels.size(0)
            step_val += 1

    # Calculate epoch-level validation metrics
    epoch_val_loss = running_val_loss / step_val if step_val > 0 else 0.0

    if target_type in ['categorical', 'special']:
        val_accuracy = 100 * correct_val / total_val if total_val > 0 else 0.0
    elif target_type == 'numerical':
        val_mse = mse_age_val / total_val if total_val > 0 else 0.0

    # Print epoch-level metrics
    if target_type in ['categorical', 'special']:
        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"- Training Loss: {epoch_train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}% "
              f"- Validation Loss: {epoch_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
    elif target_type == 'numerical':
        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"- Training Loss: {epoch_train_loss:.4f}, Training MSE: {train_mse:.4f} "
              f"- Validation Loss: {epoch_val_loss:.4f}, Validation MSE: {val_mse:.4f}")

    # Log metrics with wandb
    if wandb_log:
        log_dict = {
            "epoch_train_loss": epoch_train_loss,
            "epoch_val_loss": epoch_val_loss,
        }
        if target_type in ['categorical', 'special']:
            log_dict.update({
                f"train_accuracy_{target}": train_accuracy,
                f"val_accuracy_{target}": val_accuracy,
            })
        elif target_type == 'numerical':
            log_dict.update({
                f"train_mse_{target}": train_mse,
                f"val_mse_{target}": val_mse,
            })
        wandb.log(log_dict)

# After training: optionally persist the model weights and the run config.
# `random_id` and `wandb_config` come from the W&B logging cell above.
if save_ckpt:
    outdir = os.path.abspath(f'checkpoints/HCPflat_raw_{model_suffix}_{target}_{random_id}')
    os.makedirs(outdir, exist_ok=True)
    print("Saving checkpoint to:", outdir)

    # Weights only (state_dict), not the full pickled module.
    torch.save(model.state_dict(), os.path.join(outdir, "model.pth"))

    # Store the run configuration next to the weights.
    with open(os.path.join(outdir, "config.yaml"), 'w') as config_file:
        yaml.dump(wandb_config, config_file)

    print(f"Model and config saved to {outdir}")


Epoch 1/40 - Training:   0%|          | 0/5465 [00:00<?, ?it/s]

Step [100/5465] - Training Loss: 7.2323 - Training Accuracy: 0.06%
Step [200/5465] - Training Loss: 7.0261 - Training Accuracy: 0.03%
Step [300/5465] - Training Loss: 7.1709 - Training Accuracy: 0.06%
Step [400/5465] - Training Loss: 7.0584 - Training Accuracy: 0.08%
Step [500/5465] - Training Loss: 7.2548 - Training Accuracy: 0.10%
Step [600/5465] - Training Loss: 7.2418 - Training Accuracy: 0.17%
Step [700/5465] - Training Loss: 7.3006 - Training Accuracy: 0.17%
Step [800/5465] - Training Loss: 7.0092 - Training Accuracy: 0.17%
Step [900/5465] - Training Loss: 7.0600 - Training Accuracy: 0.22%
Step [1000/5465] - Training Loss: 7.0577 - Training Accuracy: 0.26%
Step [1100/5465] - Training Loss: 6.8615 - Training Accuracy: 0.29%
Step [1200/5465] - Training Loss: 7.2380 - Training Accuracy: 0.32%
Step [1300/5465] - Training Loss: 7.2289 - Training Accuracy: 0.39%
Step [1400/5465] - Training Loss: 7.0936 - Training Accuracy: 0.47%
Step [1500/5465] - Training Loss: 7.4408 - Training Accur

Epoch 1/40 - Validation:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch [1/40] - Training Loss: 7.1431, Training Accuracy: 5.68% - Validation Loss: 7.3734, Validation Accuracy: 15.28%


Epoch 2/40 - Training:   0%|          | 0/5465 [00:00<?, ?it/s]

Step [100/5465] - Training Loss: 1.0922 - Training Accuracy: 81.06%
Step [200/5465] - Training Loss: 1.4752 - Training Accuracy: 81.16%
Step [300/5465] - Training Loss: 3.1480 - Training Accuracy: 81.25%
Step [400/5465] - Training Loss: 1.9623 - Training Accuracy: 81.25%
Step [500/5465] - Training Loss: 1.3727 - Training Accuracy: 81.44%
Step [600/5465] - Training Loss: 1.3185 - Training Accuracy: 81.59%
Step [700/5465] - Training Loss: 1.5319 - Training Accuracy: 81.60%
Step [800/5465] - Training Loss: 1.0762 - Training Accuracy: 81.27%
Step [900/5465] - Training Loss: 1.9257 - Training Accuracy: 81.10%
Step [1000/5465] - Training Loss: 1.9981 - Training Accuracy: 80.86%
Step [1100/5465] - Training Loss: 2.4925 - Training Accuracy: 80.57%
Step [1200/5465] - Training Loss: 1.8952 - Training Accuracy: 80.53%
Step [1300/5465] - Training Loss: 1.7435 - Training Accuracy: 80.40%
Step [1400/5465] - Training Loss: 3.3180 - Training Accuracy: 80.15%
Step [1500/5465] - Training Loss: 3.0942 - 