In [1]:
# Import packages and setup gpu configuration.
# This code block shouldnt need to be adjusted!
import os
import sys
import json
import yaml
import numpy as np
import copy
import math
import time
import random
from tqdm.auto import tqdm
import webdataset as wds
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms
import utils
from mae_utils.flat_models import *
import h5py
from mae_utils import flat_models

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True
# following fixes a Conv3D CUDNN_NOT_SUPPORTED error
torch.backends.cudnn.benchmark = True

# ## MODEL TO LOAD ##
if utils.is_interactive():
    model_name = "HCPflat_large_gsrFalse_"
else:
    model_name = sys.argv[1]
    

# outdir = os.path.abspath(f'checkpoints/{model_name}')
outdir = os.path.abspath(f'checkpoints/{model_name}')

print("outdir", outdir)
# Load previous config.yaml if available
if os.path.exists(f"{outdir}/config.yaml"):
    config = yaml.load(open(f"{outdir}/config.yaml", 'r'), Loader=yaml.FullLoader)
    print(f"Loaded config.yaml from ckpt folder {outdir}")
    # create global variables from the config
    print("\n__CONFIG__")
    for attribute_name in config.keys():
        print(f"{attribute_name} = {config[attribute_name]}")
        globals()[attribute_name] = config[f'{attribute_name}']
    print("\n")

world_size = os.getenv('WORLD_SIZE')
if world_size is None: 
    world_size = 1
else:
    world_size = int(world_size)
print(f"WORLD_SIZE={world_size}")

if utils.is_interactive():
    # Following allows you to change functions in models.py or utils.py and 
    # have this notebook automatically update with your revisions
    %load_ext autoreload
    %autoreload 2

batch_size = probe_batch_size
num_epochs = probe_num_epochs

data_type = torch.float32 # change depending on your mixed_precision
global_batch_size = batch_size * world_size

device = torch.device('cuda')

hcp_flat_path = "/weka/proj-medarc/shared/HCP-Flat"
# seed = 42
# num_frames = 16
# gsr = False
# num_workers = 10
# batch_size = 128
save_ckpt = True
wandb_log = True
print("PID of this process =",os.getpid())
utils.seed_everything(seed)

outdir /weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/HCPflat_large_gsrFalse_
Loaded config.yaml from ckpt folder /weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/HCPflat_large_gsrFalse_

__CONFIG__
base_lr = 0.001
batch_size = 32
ckpt_interval = 5
ckpt_saving = True
cls_embed = True
contrastive_loss_weight = 1.0
datasets_to_include = HCP
decoder_embed_dim = 512
grad_accumulation_steps = 1
grad_clip = 1.0
gsr = False
hcp_flat_path = /weka/proj-medarc/shared/HCP-Flat
mask_ratio = 0.75
model_name = HCPflat_large_gsrFalse_
no_qkv_bias = False
norm_pix_loss = False
nsd_flat_path = /weka/proj-medarc/shared/NSD-Flat
num_epochs = 100
num_frames = 16
num_samples_per_epoch = 200000
num_workers = 10
patch_size = 16
pct_masks_to_decode = 1
plotting = True
pred_t_dim = 8
print_interval = 20
probe_base_lr = 0.0003
probe_batch_size = 8
probe_num_epochs = 30
probe_num_samples_per_epoch = 100000
resume_from_ckpt = True
seed = 42
sep_pos_embed = True
t_patch_size = 2
test_nu

In [2]:
if os.getenv('global_pool') == "False":
    global_pool = False
else:
    global_pool = True
print(f"global_pool = {global_pool}")

try:
    gsr
except:
    gsr = True
    print("set gsr to True")
print(f"gsr = {gsr}")

global_pool = True
gsr = False


In [3]:
#### UNCOMMENT THIS TO SAVE THE HCP-FLAT IN HDF5 FORMAT


# from torch.utils.data import default_collate
# from mae_utils.flat import load_hcp_flat_mask
# from mae_utils.flat import create_hcp_flat
# from mae_utils.flat import batch_unmask
# import mae_utils.visualize as vis


# batch_size = 26
# print(f"changed batch_size to {batch_size}")

# ## Test ##
# datasets_to_include = "HCP"
# assert "HCP" in datasets_to_include
# test_dataset = create_hcp_flat(root=hcp_flat_path, 
#                 clip_mode="event", frames=num_frames, shuffle=False, gsr=gsr, sub_list = 'test')
# test_dl = wds.WebLoader(
#     test_dataset.batched(batch_size, partial=False, collation_fn=default_collate),
#     batch_size=None,
#     shuffle=False,
#     num_workers=num_workers,
#     pin_memory=True,
# )

# ## Train ##
# assert "HCP" in datasets_to_include
# train_dataset = create_hcp_flat(root=hcp_flat_path, 
#                 clip_mode="event", frames=num_frames, shuffle=False, gsr=gsr, sub_list = 'train')
# train_dl = wds.WebLoader(
#     train_dataset.batched(batch_size, partial=False, collation_fn=default_collate),
#     batch_size=None,
#     shuffle=False,
#     num_workers=num_workers,
#     pin_memory=True,
# )

# def flatten_meta(meta_dict):
#     """
#     Flatten the meta dictionary by:
#     - Replacing single-item lists with the item itself.
#     - Converting tensors to scalar numbers.
#     """
#     flattened = {}
#     for key, value in meta_dict.items():
#         if isinstance(value, list):
#             if len(value) == 1:
#                 flattened[key] = value[0]  # Replace list with its single item
#             else:
#                 flattened[key] = value  # Keep as is if multiple items
#         elif isinstance(value, torch.Tensor):
#             # Convert tensor to scalar
#             if value.numel() == 1:
#                 flattened[key] = value.item()
#             else:
#                 flattened[key] = value.tolist()  # Convert multi-element tensor to list
#         else:
#             flattened[key] = value  # Keep the value as is
#     return flattened

# import h5py
# meta_array = np.array([], dtype=object)
# # Open an HDF5 file in write mode
# with h5py.File('train_hcp.hdf5', 'w') as h5f:
#     flatmaps_dset = None
    
#     total_samples = 0

#     for i, batch in tqdm(enumerate(train_dl), total = 120000):
#         images = batch['image'][0]
#         meta = batch['meta']
#         batch_size = images.shape[0]
#         meta_serializable = meta.copy()
        
        
#         # Step 2: Serialize the dictionary to a JSON string
#         meta_str = json.dumps(flatten_meta(meta_serializable), indent=4)
#         meta_array = np.append(meta_array, meta_str)
#         if flatmaps_dset is None:
#             # Initialize datasets with unlimited (None) maxshape along the first axis
#             flatmaps_shape = (0,) + images.shape[1:]
#             flatmaps_maxshape = (None,) + images.shape[1:]

#             flatmaps_dset = h5f.create_dataset(
#                 'flatmaps',
#                 shape=flatmaps_shape,
#                 maxshape=flatmaps_maxshape,
#                 dtype=np.float16,
#                 chunks=True  # Enable chunking for efficient resizing
#             )

#         # Resize datasets to accommodate new data
#         flatmaps_dset.resize(total_samples + batch_size, axis=0)

#         # Write data to the datasets
#         flatmaps_dset[total_samples:total_samples + batch_size] = images.numpy().astype(np.float16)

#         total_samples += batch_size
        
#     print(f"Processed {total_samples} samples")
# np.save('metadata_test_HCP.npy', meta_array)


# import h5py
# meta_array = np.array([], dtype=object)
# # Open an HDF5 file in write mode
# with h5py.File('test_hcp.hdf5', 'w') as h5f:
#     flatmaps_dset = None
    
#     total_samples = 0

#     for i, batch in tqdm(enumerate(test_dl), total = 12000):
#         images = batch['image'][0]
#         meta = batch['meta']
#         batch_size = images.shape[0]
#         meta_serializable = meta.copy()
        
        
#         # Step 2: Serialize the dictionary to a JSON string
#         meta_str = json.dumps(flatten_meta(meta_serializable), indent=4)
#         meta_array = np.append(meta_array, meta_str)
#         if flatmaps_dset is None:
#             # Initialize datasets with unlimited (None) maxshape along the first axis
#             flatmaps_shape = (0,) + images.shape[1:]
#             flatmaps_maxshape = (None,) + images.shape[1:]

#             flatmaps_dset = h5f.create_dataset(
#                 'flatmaps',
#                 shape=flatmaps_shape,
#                 maxshape=flatmaps_maxshape,
#                 dtype=np.float16,
#                 chunks=True  # Enable chunking for efficient resizing
#             )

#         # Resize datasets to accommodate new data
#         flatmaps_dset.resize(total_samples + batch_size, axis=0)

#         # Write data to the datasets
#         flatmaps_dset[total_samples:total_samples + batch_size] = images.numpy().astype(np.float16)

#         total_samples += batch_size
        
#     print(f"Processed {total_samples} samples")
# np.save('metadata_train_HCP.npy', meta_array)

### Preparing data

In [4]:
from sklearn.preprocessing import LabelEncoder

INCLUDE_CONDS = {
    "fear",
    "neut",
    "math",
    "story",
    "lf",
    "lh",
    "rf",
    "rh",
    "t",
    "match",
    "relation",
    "mental",
    "rnd",
    "0bk_body",
    "2bk_body",
    "0bk_faces",
    "2bk_faces",
    "0bk_places",
    "2bk_places",
    "0bk_tools",
    "2bk_tools",
}

# test_data = []

# # Iterate over the DataLoader with a progress bar
# for sample in tqdm(train_dl, desc="Processing samples"):
#     x = sample['image']
#     y = sample['meta']['trial_type']
#     key = sample['meta']['key']
#     print(x.shape, y, key)
#     break
# Initialize the label encoder
label_encoder = LabelEncoder()
label_encoder.fit(sorted(INCLUDE_CONDS))  # Ensure consistent ordering

num_classes = len(label_encoder.classes_)
print(f"Number of classes: {num_classes}")

Number of classes: 21


In [5]:
f_train = h5py.File('/weka/proj-fmri/ckadirt/fMRI-foundation-model/src/train_hcp.hdf5', 'r')
flatmaps_train = f_train['flatmaps']

f_test = h5py.File('/weka/proj-fmri/ckadirt/fMRI-foundation-model/src/test_hcp.hdf5', 'r')
flatmaps_test = f_test['flatmaps']

metadata_train = np.load('/weka/proj-fmri/ckadirt/fMRI-foundation-model/src/metadata_train_HCP.npy', allow_pickle=True)
metadata_test = np.load('/weka/proj-fmri/ckadirt/fMRI-foundation-model/src/metadata_test_HCP.npy', allow_pickle=True)

In [6]:
from torch.utils.data import Dataset, DataLoader

class HCPFlatDataset(Dataset):
    def __init__(self, flatmaps, metadata):
        self.flatmaps = flatmaps
        self.metadata = metadata

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        return self.flatmaps[idx], json.loads(self.metadata[idx])
print("Moving datasets to ram")
# Loading to cpu for faster training, this can take several minutes.  Remove this [:] if you want to move one at the time.
train_dataset = HCPFlatDataset(flatmaps_train, metadata_train)
train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

test_dataset = HCPFlatDataset(flatmaps_test, metadata_test)
test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
print("Datasets ready")

Moving datasets to ram
Datasets ready


### Creating and loading Model

In [7]:
from mae_utils.flat import load_hcp_flat_mask
from mae_utils.flat import create_hcp_flat
from mae_utils.flat import batch_unmask
import mae_utils.visualize as vis

flat_mask = load_hcp_flat_mask(hcp_flat_path)

mae_model = flat_models.mae_vit_large_fmri(
    patch_size=patch_size,
    decoder_embed_dim=decoder_embed_dim,
    t_patch_size=t_patch_size,
    pred_t_dim=pred_t_dim,
    decoder_depth=4,
    cls_embed=cls_embed,
    norm_pix_loss=norm_pix_loss,
    no_qkv_bias=no_qkv_bias,
    sep_pos_embed=sep_pos_embed,
    trunc_init=trunc_init,
    pct_masks_to_decode=pct_masks_to_decode,
    img_mask=flat_mask,
)

img_size (144, 320) patch_size (16, 16) frames 16 t_patch_size 2
model initialized


In [8]:
checkpoint_files = [f for f in os.listdir(outdir) if f.endswith('.pth')]

if utils.is_interactive():
    latest_checkpoint = "epoch99.pth"
else:
    latest_checkpoint = sys.argv[2] 
print(f"latest_checkpoint: {latest_checkpoint}")

# Load the checkpoint
checkpoint_path = os.path.join(outdir, latest_checkpoint)

state = torch.load(checkpoint_path)
mae_model.load_state_dict(state["model_state_dict"], strict=False)
mae_model.to(device)

print(f"\nLoaded checkpoint {latest_checkpoint} from {outdir}\n")

latest_checkpoint: epoch99.pth


  state = torch.load(checkpoint_path)



Loaded checkpoint epoch99.pth from /weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/HCPflat_large_gsrFalse_



In [9]:
class LinearClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(LinearClassifier, self).__init__()
        self.linear = nn.Linear(input_dim, num_classes)
        
    def forward(self, x):
        # Flatten the input except for the batch dimension
        x = x.view(x.size(0), -1)
        out = self.linear(x)
        return out  # Raw logits

# Determine the input dimension from a single sample
# Assuming images are of shape [1, 16, 144, 320]
input_dim = np.prod(mae_model(torch.randn(1,1,16,144,320).to(device),global_pool=global_pool, forward_features = True).shape[1:])
print(f"Input dimension: {input_dim}")


Input dimension: 1024


In [10]:
class FullModel(nn.Module):
    def __init__(self, lc_model, mae_model):
        super(FullModel, self).__init__()
        self.lc_model = lc_model
        self.mae_model = mae_model
        
        
    def forward(self, x, gsr):
        x = self.mae_model(x, global_pool=global_pool, forward_features = True)
        x = self.lc_model(x)
        return x


In [11]:
# Initialize the model
lc_model = LinearClassifier(input_dim=input_dim, num_classes=num_classes)

model = FullModel(lc_model, mae_model)

# Move the model to the GPU
model.to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define optimizer with L2 regularization (weight_decay)
learning_rate = 1e-4
weight_decay = 1e-5  # Adjust based on your needs
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
num_epochs = 20  # Adjust as needed


### Data

In [16]:
import uuid

myuuid = uuid.uuid4()
str(myuuid)

'9f9d4444-58f9-4378-9a9a-ea56bf707eb6'

In [17]:
import wandb

if utils.is_interactive():
    print("Running in interactive notebook. Disabling W&B and ckpt saving.")
    wandb_log = True
    save_ckpt = True

if wandb_log:
    wandb_project = 'fMRI-foundation-model'
    wandb_config = {
        "model_name": model_name+'_HCP_FT',
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "num_epochs": num_epochs,
        "seed": seed,
    }
    print("wandb_config:\n", wandb_config)
    random_id = str(uuid.uuid4())
    print("wandb_id:", "HCPflat_raw" + f"_{random_id}")
    wandb.init(
        id=model_name+'_HCP_FT' + f"_{random_id}",
        project=wandb_project,
        name=model_name+'_HCP_FT',
        config=wandb_config,
        resume="allow",
    )

Running in interactive notebook. Disabling W&B and ckpt saving.
wandb_config:
 {'model_name': 'HCPflat_large_gsrFalse__HCP_FT', 'batch_size': 8, 'learning_rate': 0.0001, 'weight_decay': 1e-05, 'num_epochs': 20, 'seed': 42}
wandb_id: HCPflat_raw_7ee35929-85c0-47da-91ab-dac2776c444f


VBox(children=(Label(value='0.022 MB of 0.022 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch_train_loss,█▄▁
epoch_val_loss,▂▁█
train_accuracy,▁▅█
val_accuracy,█▁▁

0,1
epoch_train_loss,1.35031
epoch_val_loss,3.83135
train_accuracy,59.09091
val_accuracy,25.0


In [None]:
for epoch in range(num_epochs):
    running_train_loss = 0.0
    correct_train = 0
    total_train = 0
    step = 0

    # with torch.amp.autocast(device_type='cuda'):
    # Training Phase
    model.train()
    for  batch in tqdm(train_dl, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        optimizer.zero_grad()
        images = batch[0].to(device).float().unsqueeze(1) #fix this  # Shape: [batch_size, 1, 16, 144, 320]
        labels = batch[1]['trial_type']  # List of labels
        
        encoded_labels = label_encoder.transform(labels)
        encoded_labels = torch.tensor(encoded_labels, dtype=torch.long).to(device)  # Shape: [batch_size]
        
        # Forward pass
        outputs = model(images, gsr=gsr)  # Shape: [num_train_samples, num_classes]
        
        # Compute loss
        loss = criterion(outputs, encoded_labels)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        # Accumulate loss
        running_train_loss += loss.item() * images.size(0)

        
        # Calculate accuracy
        _, predicted = torch.max(outputs, 1)
        
        correct_train += (predicted == encoded_labels).sum().item()
        total_train += encoded_labels.size(0)
        
        step = step + 1
        if step % 100 == 0:
            print(f"Step [{step}/{len(train_dl)}] - Training Loss: {loss.item():.4f} - Training Accuracy: {100 * correct_train / total_train:.2f}%")
        # thth

    epoch_train_loss = running_train_loss / total_train if total_train > 0 else 0.0
    train_accuracy = 100 * correct_train / total_train if total_train > 0 else 0.0
    
    # Validation Phase
    model.eval()
    running_val_loss = 0.0
    correct_val = 0
    total_val = 0
    
    with torch.no_grad():
        for batch in tqdm(test_dl, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
            
            images = batch[0].to(device).float().unsqueeze(1) #fix this
            labels = batch[1]['trial_type']
            
            # Encode labels to integer indices
            encoded_labels = label_encoder.transform(labels)
            encoded_labels = torch.tensor(encoded_labels, dtype=torch.long).to(device)
            
        
            # Forward pass
            outputs = model(images, gsr=gsr)
            
            # Compute loss
            loss = criterion(outputs, encoded_labels)
            
            # Accumulate loss
            running_val_loss += loss.item() * images.size(0)
            
            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            correct_val += (predicted == encoded_labels).sum().item()
            total_val += encoded_labels.size(0)

            
    
    epoch_val_loss = running_val_loss / total_val if total_val > 0 else 0.0
    val_accuracy = 100 * correct_val / total_val if total_val > 0 else 0.0
    
    print(f"Epoch [{epoch+1}/{num_epochs}] "
        f"- Training Loss: {epoch_train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}% "
        f"- Validation Loss: {epoch_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
    
    if wandb_log:
        wandb.log({
            "epoch_train_loss": epoch_train_loss,
            "epoch_val_loss": epoch_val_loss,
            "train_accuracy": train_accuracy,
            "val_accuracy": val_accuracy,
        })
    if save_ckpt:
        outdir = os.path.abspath(f'checkpoints/{model_name+"HCP_FT"}')
        os.makedirs(outdir, exist_ok=True)
        print("outdir", outdir)
        # Save model and config
        torch.save(model.state_dict(), f"{outdir}/model.pth")
        with open(f"{outdir}/config.yaml", 'w') as f:
            yaml.dump(wandb_config, f)
        print(f"Saved model and config to {outdir}")
    