In [1]:
# Import packages and setup gpu configuration.
# This code block shouldnt need to be adjusted!
import os
import sys
import json
import yaml
import numpy as np
import copy
import math
import time
import random
from tqdm.auto import tqdm
import webdataset as wds
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms
import utils
from mae_utils.flat_models import *
import h5py
from mae_utils import flat_models
import pandas as pd
from sklearn.preprocessing import StandardScaler
from typing import List, Dict, Any, Tuple
import argparse

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True
# following fixes a Conv3D CUDNN_NOT_SUPPORTED error
torch.backends.cudnn.benchmark = True




In [2]:
# if running this interactively, can specify jupyter_args here for argparser to use
if utils.is_interactive():
    model_name_suffix = "testing"
    print("model_name_suffix:", model_name_suffix)
    batch_size = 16
    # global_batch_size and batch_size should already be defined in the 2nd cell block
    jupyter_args = f"--found_model_name=NSDflat_large_gsrFalse_5sess_57734 --epoch_checkpoint epoch99.pth \
                    --hcp_flat_path=/weka/proj-medarc/shared/HCP-Flat \
                    --target=subject_id \
                    --model_suffix={model_name_suffix} \
                    --batch_size={batch_size} \
                    --max_lr=1e-5 --num_epochs=20 --no-save_ckpt --no-wandb_log --num_workers=10 \
                    --weight_decay=1e-5 \
                    --global_pool"
    # --multisubject_ckpt=../train_logs/multisubject_subj01_1024_24bs_nolow
    # suggested hyperparameters for trial_type: wd = 1e-5, max_lr = 3e-4
    print(jupyter_args)
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

model_name_suffix: testing
--found_model_name=NSDflat_large_gsrFalse_5sess_57734 --epoch_checkpoint epoch99.pth                     --hcp_flat_path=/weka/proj-medarc/shared/HCP-Flat                     --target=subject_id                     --model_suffix=testing                     --batch_size=16                     --max_lr=1e-5 --num_epochs=20 --no-save_ckpt --no-wandb_log --num_workers=10                     --weight_decay=1e-5                     --global_pool


In [3]:
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--found_model_name", type=str, default="Testing_flat",
    help="name of model, used for ckpt saving and wandb logging (if enabled)",
)
parser.add_argument(
    "--epoch_checkpoint", type=str, default="epoch99.pth",
    help="the epoch number of the found_model_name checkpoint",
)
parser.add_argument(
    "--model_suffix", type=str, default="Testing_flat",
    help="name of model, used for ckpt saving and wandb logging (if enabled)",
)
parser.add_argument(
    "--hcp_flat_path", type=str, default=os.getcwd(),
    help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
    "--nsd_flat_path", type=str, default='/weka/proj-medarc/shared/NSD-Flat',
    help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
    "--batch_size", type=int, default=128,
    help="Batch size can be increased by 10x if only training retreival submodule and not diffusion prior",
)
parser.add_argument(
    "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
    help="whether to log to wandb",
)
parser.add_argument(
    "--num_epochs",type=int,default=150,
    help="number of epochs of training",
)
parser.add_argument(
    "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
)
parser.add_argument(
    "--save_ckpt",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--seed",type=int,default=42,
)
parser.add_argument(
    "--max_lr",type=float,default=3e-4,
)
parser.add_argument(
    "--target",type=str,default='trial_type',  # "trial_type" or "subject_id" or see table on HCP_downstream.ipynb
)
parser.add_argument(
    "--num_workers",type=int,default=10,
)
parser.add_argument(
    "--weight_decay",type=float,default=1e-5,
)
parser.add_argument(
    "--global_pool",action=argparse.BooleanOptionalAction,default=False,
    help="not implemented yet",
)

if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

print(f"------ ARGS ------- \n {args}")

# create global variables without the args prefix
for attribute_name in vars(args).keys():
    globals()[attribute_name] = getattr(args, attribute_name)
    
# seed all random functions
utils.seed_everything(seed)


------ ARGS ------- 
 Namespace(found_model_name='NSDflat_large_gsrFalse_5sess_57734', epoch_checkpoint='epoch99.pth', model_suffix='testing', hcp_flat_path='/weka/proj-medarc/shared/HCP-Flat', nsd_flat_path='/weka/proj-medarc/shared/NSD-Flat', batch_size=16, wandb_log=False, num_epochs=20, lr_scheduler_type='cycle', save_ckpt=False, seed=42, max_lr=1e-05, target='subject_id', num_workers=10, weight_decay=1e-05, global_pool=True)


In [4]:
# ## MODEL TO LOAD ##
# if utils.is_interactive():
#     model_name = "HCPflat_large_gsrFalse_"
# else:
#     model_name = sys.argv[1]
    
# target = 'sex' # This can be 'trial_type' 'age' 'sex'


In [5]:

# outdir = os.path.abspath(f'checkpoints/{model_name}')
outdir = os.path.abspath(f'checkpoints/{found_model_name}')

print("outdir", outdir)
# Load previous config.yaml if available
if os.path.exists(f"{outdir}/config.yaml"):
    config = yaml.load(open(f"{outdir}/config.yaml", 'r'), Loader=yaml.FullLoader)
    print(f"Loaded config.yaml from ckpt folder {outdir}")
    # create global variables from the config
    print("\n__CONFIG__")
    for attribute_name in config.keys():
        print(f"{attribute_name} = {config[attribute_name]}")
        globals()[attribute_name] = config[f'{attribute_name}']
    print("\n")

world_size = os.getenv('WORLD_SIZE')
if world_size is None: 
    world_size = 1
else:
    world_size = int(world_size)
print(f"WORLD_SIZE={world_size}")

if utils.is_interactive():
    # Following allows you to change functions in models.py or utils.py and 
    # have this notebook automatically update with your revisions
    %load_ext autoreload
    %autoreload 2

# batch_size = probe_batch_size
# num_epochs = probe_num_epochs

data_type = torch.float32 # change depending on your mixed_precision
global_batch_size = batch_size * world_size

device = torch.device('cuda')

# hcp_flat_path = "/weka/proj-medarc/shared/HCP-Flat"
# seed = 42
# num_frames = 16
# gsr = False
# num_workers = 10
# batch_size = 128
# save_ckpt = True
# wandb_log = True
print("PID of this process =",os.getpid())
utils.seed_everything(seed)

outdir /weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/NSDflat_large_gsrFalse_5sess_57734
Loaded config.yaml from ckpt folder /weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/NSDflat_large_gsrFalse_5sess_57734

__CONFIG__
base_lr = 0.001
batch_size = 32
ckpt_interval = 25
ckpt_saving = True
cls_embed = False
cls_forward = False
contrastive_loss_weight = 1.0
datasets_to_include = NSD
decoder_cls_embed = False
decoder_embed_dim = 512
global_pool = False
grad_accumulation_steps = 1
grad_clip = 1.0
gsr = False
hcp_flat_path = /weka/proj-medarc/shared/HCP-Flat
mask_ratio = 0.75
model_name = NSDflat_large_gsrFalse_5sess
model_size = large
no_qkv_bias = False
norm_pix_loss = False
nsd_flat_path = /weka/proj-medarc/shared/NSD-Flat
num_epochs = 100
num_frames = 16
num_samples_per_epoch = 200000
num_sessions = 5
num_workers = 8
patch_size = 16
pct_masks_to_decode = 1
plotting = True
pred_t_dim = 8
print_interval = 20
probe_base_lr = 0.0003
probe_batch_size = 8
probe_nu

In [6]:
# if os.getenv('global_pool') == "False":
#     global_pool = False
# else:
#     global_pool = True
print(f"global_pool = {global_pool}")

try:
    gsr
except:
    gsr = True
    print("set gsr to True")
print(f"gsr = {gsr}")

for attribute_name in vars(args).keys():
    globals()[attribute_name] = getattr(args, attribute_name)

global_pool = False
gsr = False


In [7]:
# #### UNCOMMENT THIS TO SAVE THE HCP-FLAT IN HDF5 FORMAT


# from torch.utils.data import default_collate
# from mae_utils.flat import load_hcp_flat_mask
# from mae_utils.flat import create_hcp_flat
# from mae_utils.flat import batch_unmask
# import mae_utils.visualize as vis


# batch_size = 1
# print(f"changed batch_size to {batch_size}")
# load_file_frames = num_frames * 2
# print(f"Calculating with {load_file_frames} frames, doubling to approximate TR")

# ## Test ##
# datasets_to_include = "HCP"
# assert "HCP" in datasets_to_include
# test_dataset = create_hcp_flat(root=hcp_flat_path, 
#                 clip_mode="event", frames=load_file_frames, shuffle=False, gsr=gsr, sub_list = 'test')
# test_dl = wds.WebLoader(
#     test_dataset.batched(batch_size, partial=False, collation_fn=default_collate),
#     batch_size=None,
#     shuffle=False,
#     num_workers=num_workers,
#     pin_memory=True,
# )

# ## Train ##
# assert "HCP" in datasets_to_include
# train_dataset = create_hcp_flat(root=hcp_flat_path, 
#                 clip_mode="event", frames=load_file_frames, shuffle=False, gsr=gsr, sub_list = 'train')
# train_dl = wds.WebLoader(
#     train_dataset.batched(batch_size, partial=True, collation_fn=default_collate),
#     batch_size=None,
#     shuffle=False,
#     num_workers=num_workers,
#     pin_memory=True,
# )

# def flatten_meta(meta_dict):
#     """
#     Flatten the meta dictionary by:
#     - Replacing single-item lists with the item itself.
#     - Converting tensors to scalar numbers.
#     """
#     flattened = {}
#     for key, value in meta_dict.items():
#         if isinstance(value, list):
#             if len(value) == 1:
#                 flattened[key] = value[0]  # Replace list with its single item
#             else:
#                 flattened[key] = value  # Keep as is if multiple items
#         elif isinstance(value, torch.Tensor):
#             # Convert tensor to scalar
#             if value.numel() == 1:
#                 flattened[key] = value.item()
#             else:
#                 flattened[key] = value.tolist()  # Convert multi-element tensor to list
#         else:
#             flattened[key] = value  # Keep the value as is
#     return flattened


# import h5py
# meta_array = np.array([], dtype=object)
# # Open an HDF5 file in write mode
# with h5py.File(f'test_HCP_raw_flatmaps_{load_file_frames}f.hdf5', 'w') as h5f:
#     flatmaps_dset = None
    
#     total_samples = 0

#     for i, batch in tqdm(enumerate(test_dl), total = 12000):
#         images = batch[0]
#         meta = batch[1]
#         batch_size = images.shape[0]
#         meta_serializable = meta.copy()
        
        
#         # Step 2: Serialize the dictionary to a JSON string
#         meta_str = json.dumps(flatten_meta(meta_serializable), indent=4)
#         meta_array = np.append(meta_array, meta_str)
#         if flatmaps_dset is None:
#             # Initialize datasets with unlimited (None) maxshape along the first axis
#             flatmaps_shape = (0,) + images.shape[1:]
#             flatmaps_maxshape = (None,) + images.shape[1:]

#             flatmaps_dset = h5f.create_dataset(
#                 'flatmaps',
#                 shape=flatmaps_shape,
#                 maxshape=flatmaps_maxshape,
#                 dtype=np.float16,
#                 chunks=True  # Enable chunking for efficient resizing
#             )

#         # Resize datasets to accommodate new data
#         flatmaps_dset.resize(total_samples + batch_size, axis=0)

#         # Write data to the datasets
#         flatmaps_dset[total_samples:total_samples + batch_size] = images.numpy().astype(np.float16)

#         total_samples += batch_size
        
#     print(f"Processed {total_samples} samples")
# np.save(f'metadata_test_HCP_raw_flatmaps_{load_file_frames}f.npy', meta_array)


# import h5py
# meta_array = np.array([], dtype=object)
# # Open an HDF5 file in write mode
# with h5py.File(f'train_HCP_raw_flatmaps_{load_file_frames}f.hdf5', 'w') as h5f:
#     flatmaps_dset = None
    
#     total_samples = 0

#     for i, batch in tqdm(enumerate(train_dl), total = 120000):
#         images = batch[0]
#         meta = batch[1]
#         batch_size = images.shape[0]
#         meta_serializable = meta.copy()
        
        
#         # Step 2: Serialize the dictionary to a JSON string
#         meta_str = json.dumps(flatten_meta(meta_serializable), indent=4)
#         meta_array = np.append(meta_array, meta_str)
#         if flatmaps_dset is None:
#             # Initialize datasets with unlimited (None) maxshape along the first axis
#             flatmaps_shape = (0,) + images.shape[1:]
#             flatmaps_maxshape = (None,) + images.shape[1:]

#             flatmaps_dset = h5f.create_dataset(
#                 'flatmaps',
#                 shape=flatmaps_shape,
#                 maxshape=flatmaps_maxshape,
#                 dtype=np.float16,
#                 chunks=True  # Enable chunking for efficient resizing
#             )

#         # Resize datasets to accommodate new data
#         flatmaps_dset.resize(total_samples + batch_size, axis=0)

#         # Write data to the datasets
#         flatmaps_dset[total_samples:total_samples + batch_size] = images.numpy().astype(np.float16)

#         total_samples += batch_size
        
#     print(f"Processed {total_samples} samples")
# np.save(f'metadata_train_HCP_raw_flatmaps_{load_file_frames}f.npy', meta_array)

### Preparing data

In [8]:
hdf5_base_path_raw_file = '.'  # use this /weka/proj-fmri/ckadirt/fMRI-foundation-model/src if you don't want to store them again
load_file_frames = num_frames * 2

try:
    f_train = h5py.File(f'{hdf5_base_path_raw_file}/train_HCP_raw_flatmaps_{load_file_frames}f.hdf5', 'r')
    flatmaps_train = f_train['flatmaps']
    
    f_test = h5py.File(f'{hdf5_base_path_raw_file}/test_HCP_raw_flatmaps_{load_file_frames}f.hdf5', 'r')
    flatmaps_test = f_test['flatmaps']
    
    metadata_train = np.load(f'{hdf5_base_path_raw_file}/metadata_train_HCP_raw_flatmaps_{load_file_frames}f.npy', allow_pickle=True)
    metadata_test = np.load(f'{hdf5_base_path_raw_file}/metadata_test_HCP_raw_flatmaps_{load_file_frames}f.npy', allow_pickle=True)
    print("Loaded flatmaps")
except:
    print(f"Make sure you have the raw flatmaps precomputed for this num frames: {load_file_frames}. You can do it uncommenting the cell above")

Loaded flatmaps


In [9]:
from collections import defaultdict
# creating a new test train data split for target subject_id
split_type = 'random' 
if target == 'subject_id':
    if split_type == 'uniform':
        print("Target is subject_id, so creating a new split on datasets")
        # 1. Combine old metadata
        combined_metadata = []
        # combined_metadata will hold tuples of (json_string, source, index_in_that_source)
    
        # Append all train metadata
        for i, m_str in enumerate(metadata_train):
            combined_metadata.append((m_str, 'train', i))
        
        train_len = len(metadata_train)
        
        # Append all test metadata
        for i, m_str in enumerate(metadata_test):
            combined_metadata.append((m_str, 'test', i))
        
        # Now we have a big list containing all samples from both old train & old test
        print(f"Total combined samples: {len(combined_metadata)}")
    
        subject_to_indices = defaultdict(list)
        
        for global_idx, (m_str, source, idx_in_source) in enumerate(combined_metadata):
            m_dict = json.loads(m_str)
            subj = m_dict["sub"]      # e.g. "285446"
            subject_to_indices[subj].append(global_idx)
    
        train_ratio = 0.9
        new_train_indices = []
        new_test_indices = []
        
        for subj, global_idxs in subject_to_indices.items():
            # Shuffle the subject’s indexes in-place so we can do a random split
            random.shuffle(global_idxs)
            cutoff = max(1, int(len(global_idxs) * train_ratio))        
            
            subj_train = global_idxs[:cutoff]
            subj_test  = global_idxs[cutoff:]
            
            new_train_indices.extend(subj_train)
            new_test_indices.extend(subj_test)

    if split_type == 'random':
        combined_metadata = []
        # Label each sample by whether it came from old train/test
        for i, m_str in enumerate(metadata_train):
            combined_metadata.append((m_str, 'train', i)) 
        for i, m_str in enumerate(metadata_test):
            combined_metadata.append((m_str, 'test', i)) 
    
        print(f"Total combined samples: {len(combined_metadata)}")
        p_train = 0.9  # Probability a given sample goes to train
    
        new_train_indices = []
        new_test_indices = []
    
        # Randomly assign each sample to train or test
        for global_idx, (m_str, source, idx_in_source) in enumerate(combined_metadata):
            if random.random() < p_train:
                new_train_indices.append(global_idx)
            else:
                new_test_indices.append(global_idx)

        # 1) Build a set of subjects already in train
        subject_in_train = set()
        for g_idx in new_train_indices:
            meta_str, source, idx_in_source = combined_metadata[g_idx]
            subj = json.loads(meta_str)["sub"]
            subject_in_train.add(subj)
        
        # 2) Group test indices by subject
        subject_to_test_indices = defaultdict(list)
        for g_idx in new_test_indices:
            meta_str, source, idx_in_source = combined_metadata[g_idx]
            subj = json.loads(meta_str)["sub"]
            subject_to_test_indices[subj].append(g_idx)
        
        # 3) For any subject not in train, move exactly 1 test sample to train
        for subj, test_g_idxs in subject_to_test_indices.items():
            alone_samples = 0
            if subj not in subject_in_train:
                # Move one random test sample for this subject into train
                chosen_idx = random.choice(test_g_idxs)
                new_test_indices.remove(chosen_idx)
                new_train_indices.append(chosen_idx)
                subject_in_train.add(subj)
                alone_samples = alone_samples + 1
        print(f"Number of subjects on train with just 1 sample: {alone_samples}")

Target is subject_id, so creating a new split on datasets
Total combined samples: 3300


In [10]:
from torch.utils.data import Dataset, DataLoader

class HCPSplitDataset(Dataset):
    def __init__(self, combined_metadata, indices, flatmaps_train, flatmaps_test):
        """
        Args:
            combined_metadata: list of (json_string, source, idx_in_source)
            indices: the list of global indices that define this split
            flatmaps_train: h5py dataset for train
            flatmaps_test:  h5py dataset for test
        """
        self.combined_metadata = combined_metadata
        self.indices = indices
        self.flatmaps_train = flatmaps_train
        self.flatmaps_test = flatmaps_test

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        # i here is the i-th sample in our new train/test list
        global_idx = self.indices[i]
        m_str, source, idx_in_source = self.combined_metadata[global_idx]

        # Parse JSON
        m_dict = json.loads(m_str)

        # Retrieve the actual flatmaps from the correct HDF5
        if source == 'train':
            # old train set
            x = self.flatmaps_train[idx_in_source]
        else:
            # old test set
            x = self.flatmaps_test[idx_in_source]

        return x, m_dict

class HCPFlatDataset(Dataset):
    def __init__(self, flatmaps, metadata):
        self.flatmaps = flatmaps
        self.metadata = metadata

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        return self.flatmaps[idx], json.loads(self.metadata[idx])

if target == 'subject_id':
    train_dataset_new = HCPSplitDataset(
        combined_metadata,
        new_train_indices,
        flatmaps_train,
        flatmaps_test
    )

    test_dataset_new = HCPSplitDataset(
        combined_metadata,
        new_test_indices,
        flatmaps_train,
        flatmaps_test
    )

    # And then the dataloaders
    train_dl = DataLoader(train_dataset_new, batch_size=batch_size, shuffle=True, num_workers=num_workers//2)
    test_dl = DataLoader(test_dataset_new, batch_size=batch_size, shuffle=False, num_workers=num_workers//2)

else:
    # original approach
    train_dataset = HCPFlatDataset(flatmaps_train, metadata_train)
    train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers//2)

    test_dataset = HCPFlatDataset(flatmaps_test, metadata_test)
    test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers//2)

In [11]:
for batch in train_dl:
    break

In [12]:
batch[0].shape

torch.Size([16, 1, 32, 144, 320])

In [13]:
from sklearn.preprocessing import LabelEncoder

###### This is for restricted
# Categorical columns (e.g., demographic categories, binary diagnoses)
categorical_columns = [
    "Gender",
    "Race",
    "Ethnicity",
    "SSAGA_PanicDisorder",  # Panic disorder diagnosis (yes/no)
    "SSAGA_Depressive_Ep"   # Depressive episode diagnosis (yes/no)
]

# Numerical columns (continuous, counts, raw scores, standardized scores, etc.)
numerical_columns = [
    # Basic demographics
    "Age_in_Yrs",
    
    # Cognitive / "IQ-like" Measures
    "PMAT24_A_CR",
    "CardSort_Unadj",
    "CardSort_AgeAdj",
    "ListSort_Unadj",
    "ListSort_AgeAdj",
    "PicSeq_Unadj",
    "PicSeq_AgeAdj",
    
    # Personality Traits (Big Five)
    "NEOFAC_A",
    "NEOFAC_O",
    "NEOFAC_C",
    "NEOFAC_N",
    "NEOFAC_E",
    # If you have all 60 NEO item-level responses:
    # "NEORAW_01", "NEORAW_02", ..., "NEORAW_60",
    
    # Psychopathology / Mental Health
    "ASR_Anxd_Raw",
    "ASR_Attn_Raw",
    "ASR_Aggr_Raw",
    "DSM_Depr_Raw",
    "DSM_Anxi_Raw",
    "SSAGA_Depressive_Sx",  # Symptom count or severity
    
    # Substance Use Phenotypes
    "SSAGA_Alc_12_Frq",
    "SSAGA_Alc_12_Max_Drinks",
    "SSAGA_Times_Used_Illicits",
    "SSAGA_Times_Used_Cocaine",
    "Total_Drinks_7days",
    "Total_Any_Tobacco_7days",
    
    # Anthropometric / Basic Health
    "BMI",
    "Height",
    "Weight",
    "BPSystolic",
    "BPDiastolic",
    "HbA1C",
    "ThyroidHormone",
    
    # Sleep / Quality of Life
    "PSQI_Score",
    # If you have separate PSQI component scores, list them here too, e.g.:
    # "PSQI_Component1", "PSQI_Component2", ...
    "PainInterf_Tscore",
    "LifeSatisf_Unadj",
    "MeanPurp_Unadj"
]


if target in ['subject_id', 'trial_type']:
    target_type = 'special'
elif target in categorical_columns:
    target_type = 'categorical'
elif target in numerical_columns:
    target_type = 'numerical'


# open the file containing subject information
if not (target in ['subject_id', 'trial_type']):
    subject_information_HCP_path = os.path.join(hcp_flat_path, "subjects_data_restricted.csv")
    try:
        subject_information_HCP = pd.read_csv(subject_information_HCP_path)
    except:
        try:
            subject_information_HCP = pd.read_csv('./unrestricted_clane9_4_23_2024_13_28_14.csv')   
        except:
            assert False, "Subject information file not found"

    # # show the first few rows of the subject information
    # subject_information_HCP[age_related_columns + sex_related_columns].head()

    # Handle missing values (e.g., impute with mean)
    

    if target in numerical_columns:
        # Count NaNs or missing values
        n_missing = subject_information_HCP[target].isnull().sum()
        print(f"Number of missing values in {target}: {n_missing}. Replacing with mean.")
        mean_ = subject_information_HCP[target].mean()
        # Replace missing values or NaNs with the mean
        subject_information_HCP[target].fillna(mean_, inplace=True)
        # Initialize the scaler
        scaler = StandardScaler()    
        # Perform z-score normalization
        subject_information_HCP[f'{target}_z'] = scaler.fit_transform(subject_information_HCP[[target]])

    if target in categorical_columns:
        # Perform label encoding
        label_enc = LabelEncoder()
        subject_information_HCP[f'{target}_encoded'] = label_enc.fit_transform(subject_information_HCP[target])

def train_test_split_by_subject(df, test_ratio=0.1, random_state=42):
    """
    Split a dataframe into train and test so that
    every subject in test also appears in train at least once.

    Parameters
    ----------
    df : pd.DataFrame
        Your dataset, containing at least the columns:
        ['sub', ...]
    test_ratio : float
        Percentage of each subject's rows to allocate to test.
    random_state : int
        Random seed for reproducibility.

    Returns
    -------
    train_df : pd.DataFrame
    test_df : pd.DataFrame
    """
    np.random.seed(random_state)
    
    train_dfs = []
    test_dfs = []
    
    # Group by subject
    for subject, df_sub in df.groupby('sub'):
        n = len(df_sub)
        
        # If the subject only has 1 row, put it all in train
        if n == 1:
            train_dfs.append(df_sub)
        else:
            # Decide how many rows go to test
            n_test = int(round(test_ratio * n))
            # Ensure at least 1 row ends up in train
            # (i.e. if rounding leads to n_test == n, reduce n_test by 1)
            if n_test >= n:
                n_test = n - 1
            
            # Randomly sample n_test rows for test
            test_rows = df_sub.sample(n_test, random_state=random_state)
            # The remaining go to train
            train_rows = df_sub.drop(test_rows.index)
            
            test_dfs.append(test_rows)
            train_dfs.append(train_rows)
    
    # Combine all splits
    train_df = pd.concat(train_dfs).sample(frac=1, random_state=random_state).reset_index(drop=True)
    test_df = pd.concat(test_dfs).sample(frac=1, random_state=random_state).reset_index(drop=True)
    
    return train_df, test_df

    
def get_label_restricted(subject_id: List[str], target: str, normalized: bool = True) -> List:
    """
    Get the label for the given subject id and target
    """
    subject_id = [int(x) for x in subject_id]

    if target in numerical_columns:
        target_array = []
        for subject in subject_id:
            c_target = subject_information_HCP[subject_information_HCP['Subject'] == subject][f'{target}' if not normalized else f'{target}_z'].values
            # if the subject is not in the subject information file trigger an error
            if len(c_target) == 0:
                assert False, f"Subject {subject} not found in subject information file"
            if len(c_target) > 1:
                print(f"Warning: Multiple entries for subject {subject}")

            target_array.append(np.float32(c_target[0]))

        return np.array(target_array)
    
    elif target in categorical_columns:
        target_array = []
        for subject in subject_id:
            c_target = subject_information_HCP[subject_information_HCP['Subject'] == subject][f'{target}_encoded'].values
            # if the subject is not in the subject information file trigger an error
            if len(c_target) == 0:
                assert False, f"Subject {subject} not found in subject information file"
            if len(c_target) > 1:
                print(f"Warning: Multiple entries for subject {subject}")

            target_array.append(np.int8(c_target[0]))

        return np.array(target_array)

In [14]:
from sklearn.preprocessing import LabelEncoder

if target == "trial_type":
    INCLUDE_CONDS = {
        "fear",
        "neut",
        "math",
        "story",
        "lf",
        "lh",
        "rf",
        "rh",
        "t",
        "match",
        "relation",
        "mental",
        "rnd",
        "0bk_body",
        "2bk_body",
        "0bk_faces",
        "2bk_faces",
        "0bk_places",
        "2bk_places",
        "0bk_tools",
        "2bk_tools",
    }
    # Initialize the label encoder
    label_enc = LabelEncoder()
    label_enc.fit(sorted(INCLUDE_CONDS))  # Ensure consistent ordering
elif target == 'subject_id':
    all_subs = [json.loads(i_data[0])['sub'] for i_data in combined_metadata]
    label_enc = LabelEncoder()
    label_enc.fit(all_subs)
    
if target_type in ['categorical', 'special']:
    num_classes = len(label_enc.classes_)
    print(f"Number of classes: {num_classes}")

Number of classes: 376


### Creating and loading Model

In [15]:
from mae_utils.flat import load_hcp_flat_mask, load_nsd_flat_mask
from mae_utils.flat import create_hcp_flat
from mae_utils.flat import batch_unmask
import mae_utils.visualize as vis


if "HCP" in datasets_to_include:
    flat_mask = load_hcp_flat_mask()
    nsd_mask = None
    hcp_mask = None
elif "NSD" in datasets_to_include:
    flat_mask = load_nsd_flat_mask()
    nsd_mask = None
    hcp_mask = None
elif "BOTH" in datasets_to_include:
    flat_mask = None
    nsd_mask = load_nsd_flat_mask()
    hcp_mask = load_hcp_flat_mask()

assert model_size in {"huge", "large", "small"}, "undefined model_size"

if model_size=="huge":
    mae_model = flat_models.mae_vit_huge_fmri(
        patch_size=patch_size,
        decoder_embed_dim=decoder_embed_dim,
        t_patch_size=t_patch_size,
        pred_t_dim=pred_t_dim,
        decoder_depth=4,
        cls_embed=cls_embed,
        norm_pix_loss=norm_pix_loss,
        no_qkv_bias=no_qkv_bias,
        sep_pos_embed=sep_pos_embed,
        trunc_init=trunc_init,
        pct_masks_to_decode=pct_masks_to_decode,
        img_mask=flat_mask,
        nsd_mask=nsd_mask,
        hcp_mask=hcp_mask,
        use_source_embeds=use_source_embeds,
        use_decoder_contrastive_loss=use_decoder_contrastive_loss,
        source_embed_train_mode=source_embed_train_mode,
        source_embed_mode=source_embed_mode,
        use_contrastive_loss=use_contrastive_loss
    )
elif model_size=="large":
    mae_model = flat_models.mae_vit_large_fmri(
        patch_size=patch_size,
        decoder_embed_dim=decoder_embed_dim,
        t_patch_size=t_patch_size,
        pred_t_dim=pred_t_dim,
        decoder_depth=4,
        cls_embed=cls_embed,
        norm_pix_loss=norm_pix_loss,
        no_qkv_bias=no_qkv_bias,
        sep_pos_embed=sep_pos_embed,
        trunc_init=trunc_init,
        pct_masks_to_decode=pct_masks_to_decode,
        img_mask=flat_mask,
        nsd_mask=nsd_mask,
        hcp_mask=hcp_mask,
        use_source_embeds=use_source_embeds,
        use_decoder_contrastive_loss=use_decoder_contrastive_loss,
        source_embed_train_mode=source_embed_train_mode,
        source_embed_mode=source_embed_mode,
        use_contrastive_loss=use_contrastive_loss
    )
elif model_size=="small":
    mae_model = flat_models.mae_vit_small_fmri(
        patch_size=patch_size,
        decoder_embed_dim=decoder_embed_dim,
        t_patch_size=t_patch_size,
        pred_t_dim=pred_t_dim,
        decoder_depth=4,
        cls_embed=cls_embed,
        norm_pix_loss=norm_pix_loss,
        no_qkv_bias=no_qkv_bias,
        sep_pos_embed=sep_pos_embed,
        trunc_init=trunc_init,
        pct_masks_to_decode=pct_masks_to_decode,
        img_mask=flat_mask,
        nsd_mask=nsd_mask,
        hcp_mask=hcp_mask,
        use_source_embeds=use_source_embeds,
        use_decoder_contrastive_loss=use_decoder_contrastive_loss,
        source_embed_train_mode=source_embed_train_mode,
        source_embed_mode=source_embed_mode,
        use_contrastive_loss=use_contrastive_loss
    )

img_size (144, 320) patch_size (16, 16) frames 16 t_patch_size 2
model initialized


In [16]:
checkpoint_files = [f for f in os.listdir(outdir) if f.endswith('.pth')]

latest_checkpoint = epoch_checkpoint
    
print(f"latest_checkpoint: {latest_checkpoint}")

# Load the checkpoint
checkpoint_path = os.path.join(outdir, latest_checkpoint)

state = torch.load(checkpoint_path)
mae_model.load_state_dict(state["model_state_dict"], strict=False)
mae_model.to(device)

print(f"\nLoaded checkpoint {latest_checkpoint} from {outdir}\n")

latest_checkpoint: epoch99.pth


  state = torch.load(checkpoint_path)



Loaded checkpoint epoch99.pth from /weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/NSDflat_large_gsrFalse_5sess_57734



In [17]:
class LinearClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(LinearClassifier, self).__init__()
        self.linear = nn.Linear(input_dim, num_classes)
        
    def forward(self, x):
        # Flatten the input except for the batch dimension
        x = x.view(x.size(0), -1)
        out = self.linear(x)
        return out  # Raw logits

# Determine the input dimension from a single sample
# Assuming images are of shape [1, 16, 144, 320]
with torch.no_grad():
    images = batch[0]
    images_shape = images.shape
    images_reshaped = images.view(len(images), 2, images_shape[2]//2, images_shape[3], images_shape[4])
    images = images_reshaped.mean(dim=1).unsqueeze(1).to(torch.float)
    print(images.shape)
    input_dim = np.prod(mae_model(images.to(device),forward_features=True, global_pool=global_pool, cls_forward=cls_forward).shape[1:])
    print(f"Input dimension: {input_dim}")


torch.Size([16, 1, 16, 144, 320])
Input dimension: 1024


In [18]:
# mae_model.n_mask_patches, cls_forward, global_pool

In [19]:
class FullModel(nn.Module):
    def __init__(self, lc_model, mae_model):
        super(FullModel, self).__init__()
        self.lc_model = lc_model
        self.mae_model = mae_model
        
        
    def forward(self, x, gsr):
        x = self.mae_model(x, forward_features=True, global_pool=global_pool, cls_forward=cls_forward)
        x = self.lc_model(x)
        return x


In [20]:
# Initialize the model

if (target in ["trial_type", "subject_id"]) or (target in categorical_columns):
    lc_model = LinearClassifier(input_dim=input_dim, num_classes=num_classes)
    criterion = nn.CrossEntropyLoss()

elif target in numerical_columns:
    lc_model = LinearClassifier(input_dim=input_dim, num_classes=1)
    criterion = nn.MSELoss()

# elif target == "sex":
#     lc_model = LinearClassifier(input_dim=input_dim, num_classes=1)
#     criterion = nn.BCEWithLogitsLoss()


# lc_model = LinearClassifier(input_dim=input_dim, num_classes=num_classes)

model = FullModel(lc_model, mae_model)

# Move the model to the GPU
model.to(device)

# Define optimizer with L2 regularization (weight_decay)
# learning_rate = 1e-4
# weight_decay = 1e-5  # Adjust based on your needs

optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)

num_iterations_per_epoch = math.ceil(flatmaps_train.shape[0]/batch_size)

if lr_scheduler_type == 'linear':
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        total_iters=int(np.floor(num_epochs*num_iterations_per_epoch)),
        last_epoch=-1
    )
elif lr_scheduler_type == 'cycle':
    total_steps=int(np.floor(num_epochs*num_iterations_per_epoch))
    print("total_steps", total_steps)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, 
        max_lr=max_lr,
        total_steps=total_steps,
        final_div_factor=1000,
        last_epoch=-1, pct_start=2/num_epochs
    )



# num_epochs = 20  # Adjust as needed


total_steps 3760


### Data

In [21]:
import uuid

myuuid = uuid.uuid4()
str(myuuid)

'b544ad7c-f39d-4c1c-a38a-2d0f9dc863ac'

In [22]:
import wandb

if utils.is_interactive():
    print("Running in interactive notebook. Disabling W&B and ckpt saving.")
    wandb_log = False
    save_ckpt = False

if wandb_log:
    wandb_project = 'fMRI-foundation-model'
    wandb_config = {
        "model_name": f'{found_model_name}_HCP_FT_{target}',
        "batch_size": batch_size,
        "weight_decay": weight_decay,
        "num_epochs": num_epochs,
        "seed": seed,
        "lr_scheduler_type": lr_scheduler_type,
        "save_ckpt": save_ckpt,
        "seed": seed,
        "max_lr": max_lr,
        "target": target,
        "num_workers": num_workers,
        "weight_decay": weight_decay
    }
    print("wandb_config:\n", wandb_config)
    random_id = random.randint(0, 100000)
    wandb_id = f"{found_model_name}_{model_suffix}_{target}_HCPFT_{myuuid}"
    print("wandb_id:", wandb_id)
    wandb.init(
        id=wandb_id,
        project=wandb_project,
        name=f"{found_model_name}_{model_suffix}_{target}_HCPFT",
        config=wandb_config,
        resume="allow",
    )

Running in interactive notebook. Disabling W&B and ckpt saving.


In [23]:
for epoch in range(num_epochs):
    running_train_loss = 0.0
    correct_train = 0
    mse_age_train = 0.0
    total_train = 0
    step = 0

    # with torch.amp.autocast(device_type='cuda'):
    # Training Phase
    model.train()
    for batch in tqdm(train_dl, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        optimizer.zero_grad()
        images = batch[0].to(device).float()  # Shape: [batch_size, 1, 16, 144, 320]

        images_shape = images.shape
        images_reshaped = images.view(len(images), 2, images_shape[2]//2, images_shape[3], images_shape[4])
        images = images_reshaped.mean(dim=1).unsqueeze(1) 
        
        # Prepare labels based on target type
        if target == "trial_type":
            labels = batch[1]['trial_type']  # List of labels
            labels = label_enc.transform(labels)
            labels = torch.tensor(labels, dtype=torch.long).to(device)  # Shape: [batch_size]
        if target == 'subject_id':
            labels = label_enc.transform(batch[1]['sub'])
            labels = torch.tensor(labels, dtype=torch.long).to(device)  # Shape: [batch_size]
        elif target in numerical_columns:
            labels = get_label_restricted(batch[1]['sub'], target)
            labels = torch.tensor(labels, dtype=torch.float).to(device)  # Shape: [batch_size]
        elif target in categorical_columns:
            labels = get_label_restricted(batch[1]['sub'], target)
            labels = torch.tensor(labels, dtype=torch.long).to(device)  # Shape: [batch_size]
        
        
        # Forward pass
        outputs = model(images, gsr=gsr)  # Shape: [num_train_samples, num_classes]
        
        # Compute loss
        loss = criterion(outputs.squeeze(), labels.squeeze())
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        # Accumulate loss
        running_train_loss += loss.item()

        
        # Calculate and accumulate metrics
        if target_type in ["categorical", "special"]:
            _, predicted = torch.max(outputs, 1)
            correct_train += (predicted == labels).sum().item()
        elif target_type == "numerical":
            mse_age_train += (torch.sum((outputs.squeeze() - labels) ** 2).item())

        total_train += labels.size(0)
        step += 1

        # Print intermediate metrics every 100 steps
        if step % 100 == 0:
            if target_type in ['categorical', 'special']:
                current_accuracy = 100 * correct_train / total_train if total_train > 0 else 0.0
                print(f"Step [{step}/{len(train_dl)}] - Training Loss: {loss.item():.4f} - Training Accuracy: {current_accuracy:.2f}%")
            elif target_type == 'numerical':
                current_mse = mse_age_train / total_train if total_train > 0 else 0.0
                print(f"Step [{step}/{len(train_dl)}] - Training Loss: {loss.item():.4f} - Training MSE: {current_mse:.4f}")

        if lr_scheduler_type is not None:
                lr_scheduler.step()


    # Calculate epoch-level metrics
    epoch_train_loss = running_train_loss / step if step > 0 else 0.0

    if target_type in ['categorical', 'special']:
        train_accuracy = 100 * correct_train / total_train if total_train > 0 else 0.0
    elif target_type == 'numerical':
        train_mse = mse_age_train / total_train if total_train > 0 else 0.0
    
    # Validation Phase
    model.eval()
    running_val_loss = 0.0
    correct_val = 0
    mse_age_val = 0.0
    total_val = 0
    step_val = 0
    with torch.no_grad():
        for batch in tqdm(test_dl, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
            images = batch[0].to(device).float()

            images_shape = images.shape
            images_reshaped = images.view(len(images), 2, images_shape[2]//2, images_shape[3], images_shape[4])
            images = images_reshaped.mean(dim=1).unsqueeze(1) 
            
            # Prepare labels based on target type
            if target == "trial_type":
                labels = batch[1]['trial_type']  # List of labels
                labels = label_enc.transform(labels)
                labels = torch.tensor(labels, dtype=torch.long).to(device)  # Shape: [batch_size]
            if target == 'subject_id':
                labels = label_enc.transform(batch[1]['sub'])
                labels = torch.tensor(labels, dtype=torch.long).to(device)  # Shape: [batch_size]
            elif target in numerical_columns:
                labels = get_label_restricted(batch[1]['sub'], target)
                labels = torch.tensor(labels, dtype=torch.float).to(device)  # Shape: [batch_size]
            elif target in categorical_columns:
                labels = get_label_restricted(batch[1]['sub'], target)
                labels = torch.tensor(labels, dtype=torch.long).to(device)  # Shape: [batch_size]
                
            # Forward pass
            outputs = model(images, gsr=gsr)
            
            # Compute loss
            loss = criterion(outputs.squeeze(), labels.squeeze())

            # Accumulate loss
            running_val_loss += loss.item()

            # Calculate and accumulate metrics
            if target_type in ["categorical", "special"]:
                _, predicted = torch.max(outputs, 1)
                correct_val += (predicted == labels).sum().item()
            elif target_type == "numerical":
                mse_age_val += (torch.sum((outputs.squeeze() - labels) ** 2).item())

            total_val += labels.size(0)
            step_val += 1

    # Calculate epoch-level validation metrics
    epoch_val_loss = running_val_loss / step_val if step_val > 0 else 0.0

    if target_type in ['categorical', 'special']:
        val_accuracy = 100 * correct_val / total_val if total_val > 0 else 0.0
    elif target_type == 'numerical':
        val_mse = mse_age_val / total_val if total_val > 0 else 0.0

    # Print epoch-level metrics
    if target_type in ['categorical', 'special']:
        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"- Training Loss: {epoch_train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}% "
              f"- Validation Loss: {epoch_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
    elif target_type == 'numerical':
        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"- Training Loss: {epoch_train_loss:.4f}, Training MSE: {train_mse:.4f} "
              f"- Validation Loss: {epoch_val_loss:.4f}, Validation MSE: {val_mse:.4f}")

    # Log metrics with wandb
    if wandb_log:
        log_dict = {
            "epoch_train_loss": epoch_train_loss,
            "epoch_val_loss": epoch_val_loss,
        }
        if target in ["trial_type", "sex"]:
            log_dict.update({
                f"train_accuracy_{target}": train_accuracy,
                f"val_accuracy_{target}": val_accuracy,
            })
        elif target == "age":
            log_dict.update({
                f"train_mse_{target}": train_mse,
                f"val_mse_{target}": val_mse,
            })
        wandb.log(log_dict)
        
    if save_ckpt:
        outdir = os.path.abspath(f'checkpoints/{f"{found_model_name}_{model_suffix}_{target}_HCPFT"}')
        os.makedirs(outdir, exist_ok=True)
        print("outdir", outdir)
        # Save model and config
        torch.save(model.state_dict(), f"{outdir}/model.pth")
        with open(f"{outdir}/config.yaml", 'w') as f:
            yaml.dump(wandb_config, f)
        print(f"Saved model and config to {outdir}")
    

Epoch 1/20 - Training:   0%|          | 0/177 [00:01<?, ?it/s]

Step [100/177] - Training Loss: 6.4023 - Training Accuracy: 0.25%


Epoch 1/20 - Validation:   0%|          | 0/31 [00:01<?, ?it/s]

Epoch [1/20] - Training Loss: 6.2037, Training Accuracy: 0.39% - Validation Loss: 6.0060, Validation Accuracy: 0.62%


Epoch 2/20 - Training:   0%|          | 0/177 [00:01<?, ?it/s]

Step [100/177] - Training Loss: 5.9440 - Training Accuracy: 0.88%


Epoch 2/20 - Validation:   0%|          | 0/31 [00:01<?, ?it/s]

Epoch [2/20] - Training Loss: 5.8740, Training Accuracy: 0.96% - Validation Loss: 5.7398, Validation Accuracy: 1.04%


Epoch 3/20 - Training:   0%|          | 0/177 [00:01<?, ?it/s]

Step [100/177] - Training Loss: 5.7783 - Training Accuracy: 1.75%


Epoch 3/20 - Validation:   0%|          | 0/31 [00:01<?, ?it/s]

Epoch [3/20] - Training Loss: 5.4559, Training Accuracy: 2.52% - Validation Loss: 5.2250, Validation Accuracy: 4.36%


Epoch 4/20 - Training:   0%|          | 0/177 [00:01<?, ?it/s]

Step [100/177] - Training Loss: 4.5978 - Training Accuracy: 7.19%


Epoch 4/20 - Validation:   0%|          | 0/31 [00:01<?, ?it/s]

Epoch [4/20] - Training Loss: 4.5628, Training Accuracy: 8.16% - Validation Loss: 4.3656, Validation Accuracy: 11.00%


Epoch 5/20 - Training:   0%|          | 0/177 [00:01<?, ?it/s]

Step [100/177] - Training Loss: 3.9176 - Training Accuracy: 19.50%


Epoch 5/20 - Validation:   0%|          | 0/31 [00:01<?, ?it/s]

Epoch [5/20] - Training Loss: 3.3801, Training Accuracy: 21.15% - Validation Loss: 3.5673, Validation Accuracy: 19.92%


Epoch 6/20 - Training:   0%|          | 0/177 [00:01<?, ?it/s]

Step [100/177] - Training Loss: 2.4073 - Training Accuracy: 39.88%


Epoch 6/20 - Validation:   0%|          | 0/31 [00:01<?, ?it/s]

Epoch [6/20] - Training Loss: 2.2736, Training Accuracy: 42.48% - Validation Loss: 2.6854, Validation Accuracy: 34.44%


Epoch 7/20 - Training:   0%|          | 0/177 [00:01<?, ?it/s]

Step [100/177] - Training Loss: 0.7608 - Training Accuracy: 62.94%


Epoch 7/20 - Validation:   0%|          | 0/31 [00:01<?, ?it/s]

Epoch [7/20] - Training Loss: 1.3764, Training Accuracy: 64.34% - Validation Loss: 2.4143, Validation Accuracy: 39.83%


Epoch 8/20 - Training:   0%|          | 0/177 [00:01<?, ?it/s]

Step [100/177] - Training Loss: 0.7133 - Training Accuracy: 81.31%


Epoch 8/20 - Validation:   0%|          | 0/31 [00:01<?, ?it/s]

Epoch [8/20] - Training Loss: 0.7527, Training Accuracy: 82.33% - Validation Loss: 2.1406, Validation Accuracy: 46.27%


Epoch 9/20 - Training:   0%|          | 0/177 [00:01<?, ?it/s]

Step [100/177] - Training Loss: 0.3515 - Training Accuracy: 92.94%


Epoch 9/20 - Validation:   0%|          | 0/31 [00:01<?, ?it/s]

Epoch [9/20] - Training Loss: 0.3773, Training Accuracy: 92.51% - Validation Loss: 1.9166, Validation Accuracy: 52.28%


Epoch 10/20 - Training:   0%|          | 0/177 [00:01<?, ?it/s]

Exception ignored in: Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ff62c7c4e00><function _MultiProcessingDataLoaderIter.__del__ at 0x7ff62c7c4e00><function _MultiProcessingDataLoaderIter.__del__ at 0x7ff62c7c4e00><function _MultiProcessingDataLoaderIter.__del__ at 0x7ff62c7c4e00>



Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/admin/home-ckadirt/foundation_env/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
Traceback (most recent call last):
  File "/admin/home-ckadirt/foundation_env/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
  File "/admin/home-ckadirt/foundation_env/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
  File "/admin/home-ckadirt/foundation_env/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1477, in 

Step [100/177] - Training Loss: 0.1097 - Training Accuracy: 96.69%


Epoch 10/20 - Validation:   0%|          | 0/31 [00:01<?, ?it/s]

Epoch [10/20] - Training Loss: 0.2007, Training Accuracy: 96.70% - Validation Loss: 1.8844, Validation Accuracy: 54.77%


Epoch 11/20 - Training:   0%|          | 0/177 [00:01<?, ?it/s]

Step [100/177] - Training Loss: 0.0242 - Training Accuracy: 99.50%


Epoch 11/20 - Validation:   0%|          | 0/31 [00:01<?, ?it/s]

Epoch [11/20] - Training Loss: 0.0830, Training Accuracy: 99.50% - Validation Loss: 1.7742, Validation Accuracy: 57.68%


Epoch 12/20 - Training:   0%|          | 0/177 [00:01<?, ?it/s]

Step [100/177] - Training Loss: 0.0443 - Training Accuracy: 99.94%


Epoch 12/20 - Validation:   0%|          | 0/31 [00:01<?, ?it/s]

Epoch [12/20] - Training Loss: 0.0375, Training Accuracy: 99.96% - Validation Loss: 1.7499, Validation Accuracy: 57.68%


Epoch 13/20 - Training:   0%|          | 0/177 [00:01<?, ?it/s]

Step [100/177] - Training Loss: 0.0225 - Training Accuracy: 100.00%


Epoch 13/20 - Validation:   0%|          | 0/31 [00:01<?, ?it/s]

Epoch [13/20] - Training Loss: 0.0242, Training Accuracy: 100.00% - Validation Loss: 1.7274, Validation Accuracy: 57.68%


Epoch 14/20 - Training:   0%|          | 0/177 [00:01<?, ?it/s]

Step [100/177] - Training Loss: 0.0216 - Training Accuracy: 100.00%


Epoch 14/20 - Validation:   0%|          | 0/31 [00:01<?, ?it/s]

Epoch [14/20] - Training Loss: 0.0194, Training Accuracy: 100.00% - Validation Loss: 1.7368, Validation Accuracy: 58.51%


Epoch 15/20 - Training:   0%|          | 0/177 [00:01<?, ?it/s]

Step [100/177] - Training Loss: 0.0183 - Training Accuracy: 100.00%


Epoch 15/20 - Validation:   0%|          | 0/31 [00:01<?, ?it/s]

Epoch [15/20] - Training Loss: 0.0167, Training Accuracy: 100.00% - Validation Loss: 1.7349, Validation Accuracy: 58.09%


Epoch 16/20 - Training:   0%|          | 0/177 [00:01<?, ?it/s]

KeyboardInterrupt: 