In [1]:
# Import packages and setup gpu configuration.
# This code block shouldnt need to be adjusted other than the model_name (if interactive)!
import os
import sys
import json
import yaml
import numpy as np
import math
import time
import datetime
import random
from tqdm import tqdm
import webdataset as wds
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
from torchvision import transforms
import utils
from mae_utils import flat_models

## MODEL TO LOAD ##
if utils.is_interactive():
    model_name = "NSDflat_large_gsrFalse_"
else:
    model_name = sys.argv[1]
outdir = os.path.abspath(f'checkpoints/{model_name}')
print("outdir", outdir)

# Load previously saved config.yaml made during main training script
assert os.path.exists(f"{outdir}/config.yaml")
config = yaml.load(open(f"{outdir}/config.yaml", 'r'), Loader=yaml.FullLoader)
print(f"Loaded config.yaml from ckpt folder {outdir}")
# create global variables from the config
print("\n__CONFIG__")
for attribute_name in config.keys():
    print(f"{attribute_name} = {config[attribute_name]}")
    globals()[attribute_name] = config[f'{attribute_name}']
print("\n")

if utils.is_interactive():
    # Following allows you to change functions in other files and 
    # have this notebook automatically update with your revisions
    %load_ext autoreload
    %autoreload 2

device = torch.device('cuda')

print("PID of this process =",os.getpid())

# seed all random functions
utils.seed_everything(seed)

outdir /weka/proj-fmri/paulscotti/fMRI-foundation-model/src/checkpoints/NSDflat_large_gsrFalse_
Loaded config.yaml from ckpt folder /weka/proj-fmri/paulscotti/fMRI-foundation-model/src/checkpoints/NSDflat_large_gsrFalse_

__CONFIG__
base_lr = 0.001
batch_size = 32
ckpt_interval = 5
ckpt_saving = True
cls_embed = True
contrastive_loss_weight = 1.0
datasets_to_include = NSD
decoder_embed_dim = 512
grad_accumulation_steps = 1
grad_clip = 1.0
gsr = False
hcp_flat_path = /weka/proj-medarc/shared/HCP-Flat
mask_ratio = 0.75
model_name = NSDflat_large_gsrFalse_
no_qkv_bias = False
norm_pix_loss = False
nsd_flat_path = /weka/proj-medarc/shared/NSD-Flat
num_epochs = 100
num_frames = 16
num_samples_per_epoch = 200000
num_workers = 10
patch_size = 16
pct_masks_to_decode = 1
plotting = True
pred_t_dim = 8
print_interval = 20
probe_base_lr = 0.0003
probe_batch_size = 8
probe_num_epochs = 30
probe_num_samples_per_epoch = 100000
resume_from_ckpt = True
seed = 42
sep_pos_embed = True
t_patch_size = 2
t

# nsd_flat

## Set nsd_flat downstream config

In [2]:
if os.getenv('global_pool') == "False":
    global_pool = False
else:
    global_pool = True
print(f"global_pool = {global_pool}")

if os.getenv('use_visual_roi') == "False":
    use_visual_roi = False
else:
    use_visual_roi = True
print(f"use_visual_roi = {use_visual_roi}")

try:
    gsr
except:
    gsr = True
    print("set gsr to True")
print(f"gsr = {gsr}")

global_pool = True
use_visual_roi = True
gsr = False


## Load model

In [3]:
from mae_utils.flat import load_nsd_flat_mask, load_nsd_flat_mask_visual
from mae_utils.flat import create_nsd_flat
from mae_utils.flat import batch_unmask
import mae_utils.visualize as vis

flat_mask = load_nsd_flat_mask()
flat_mask_visual = load_nsd_flat_mask_visual()

model = flat_models.mae_vit_large_fmri(
    patch_size=patch_size,
    decoder_embed_dim=decoder_embed_dim,
    t_patch_size=t_patch_size,
    pred_t_dim=pred_t_dim,
    decoder_depth=4,
    cls_embed=cls_embed,
    norm_pix_loss=norm_pix_loss,
    no_qkv_bias=no_qkv_bias,
    sep_pos_embed=sep_pos_embed,
    trunc_init=trunc_init,
    pct_masks_to_decode=pct_masks_to_decode,
    img_mask=flat_mask,
)

img_size (144, 320) patch_size (16, 16) frames 16 t_patch_size 2
model initialized


## Load checkpoint

In [4]:
checkpoint_files = [f for f in os.listdir(outdir) if f.endswith('.pth')]

if utils.is_interactive():
    latest_checkpoint = "epoch99.pth"
else:
    latest_checkpoint = sys.argv[2] 
print(f"latest_checkpoint: {latest_checkpoint}")

# Load the checkpoint
checkpoint_path = os.path.join(outdir, latest_checkpoint)

state = torch.load(checkpoint_path)
model.load_state_dict(state["model_state_dict"], strict=False)
model.to(device)

print(f"\nLoaded checkpoint {latest_checkpoint} from {outdir}\n")

latest_checkpoint: epoch99.pth


  state = torch.load(checkpoint_path)



Loaded checkpoint epoch99.pth from /weka/proj-fmri/paulscotti/fMRI-foundation-model/src/checkpoints/NSDflat_large_gsrFalse_



In [5]:
if use_visual_roi:
    print("change mask to visual cortex only")
    model.initialize_mask(flat_mask_visual)

change mask to visual cortex only


## Create dataset and data loaders

In [6]:
batch_size = 1
print(f"changed batch_size to {batch_size}")

## Test ##
datasets_to_include = "NSD"
assert "NSD" in datasets_to_include
test_dataset = create_nsd_flat(root=nsd_flat_path, 
                frames=num_frames, shuffle=False, 
                gsr=gsr, sub="sub-01", run="task-only")
test_dl = wds.WebLoader(
    test_dataset.batched(batch_size, partial=False),
    batch_size=None,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
)

changed batch_size to 1


# Start extraction

In [7]:
cnt = 141930 # 2213 if bs=64, 5909 if bs=24, 141930 if bs=1

In [8]:
# random_subselection = np.random.choice(np.arange(294912), 10240)
@torch.no_grad()
def extract_features(dl, global_pool=True):
    for samples, samples_meta, samples_events, sample_means, sample_sds in tqdm(dl,total=cnt):     
        #### sanity testing by using just the raw flat maps ####
        # x = samples.to(device)
        # x = model.patch_embed(x)
        # N, T, L, C = x.shape
        # x = x[:, :, model.patch_mask_indices]
        # x = x.flatten(1)
        # features = x[:,random_subselection]
        ########
        
        features = model.forward_features(samples.to(device),global_pool=global_pool)
        features = features.flatten(1)
        features = features.cpu().numpy()

        for feat, meta, events in zip(features, samples_meta, samples_events):
            if meta['start']==0:
                meta["events"] = events
            else:
                meta["events"] = None
            yield feat, meta

In [9]:
out_folder = f'{outdir}_gp{global_pool}_visual{use_visual_roi}/{latest_checkpoint[:-4]}'
print(out_folder)
os.makedirs(out_folder,exist_ok=True)

/weka/proj-fmri/paulscotti/fMRI-foundation-model/src/checkpoints/NSDflat_large_gsrFalse__gpFalse_visualTrue/epoch99


In [None]:
import h5py
utils.seed_everything(seed)
output_path = f"{out_folder}/test.h5"
print("Output location:", output_path)
start_time = time.time()
chunk_size = 10000

# Open the HDF5 file in write mode
with h5py.File(output_path, "w") as h5file:
    # Initialize variables to determine feature shape and dtype
    first_sample = next(extract_features(test_dl, global_pool=global_pool), None)
    if first_sample is None:
        raise ValueError("No samples found in the dataset.")
    
    first_feature, first_metadata = first_sample
    feature_shape = first_feature.shape  # e.g., (feature_length,)
    feature_dtype = first_feature.dtype  # e.g., float32
    
    # Create the extendable features dataset
    features_ds = h5file.create_dataset(
        "features",
        shape=(0,) + feature_shape,
        maxshape=(None,) + feature_shape,
        dtype=feature_dtype,
        compression="gzip",
        chunks=True
    )
    
    # Create separate datasets for each metadata field
    string_dt = h5py.string_dtype(encoding='utf-8')
    key_ds = h5file.create_dataset(
        "key",
        shape=(0,),
        maxshape=(None,),
        dtype=string_dt,
        compression="gzip",
        chunks=True
    )
    sub_ds = h5file.create_dataset(
        "sub",
        shape=(0,),
        maxshape=(None,),
        dtype='i8',
        compression="gzip",
        chunks=True
    )
    ses_ds = h5file.create_dataset(
        "ses",
        shape=(0,),
        maxshape=(None,),
        dtype='i8',
        compression="gzip",
        chunks=True
    )
    run_ds = h5file.create_dataset(
        "run",
        shape=(0,),
        maxshape=(None,),
        dtype='i8',
        compression="gzip",
        chunks=True
    )
    start_ds = h5file.create_dataset(
        "start",
        shape=(0,),
        maxshape=(None,),
        dtype='i8',
        compression="gzip",
        chunks=True
    )
    events_ds = h5file.create_dataset(
        "events",
        shape=(0,),
        maxshape=(None,),
        dtype=string_dt,
        compression="gzip",
        chunks=True
    )
    
    # Initialize buffers with the first sample
    buffer_features = [first_feature]
    buffer_keys = [first_metadata.get('key', '')]
    buffer_subs = [first_metadata.get('sub', 0)]
    buffer_sess = [first_metadata.get('ses', 0)]
    buffer_runs = [first_metadata.get('run', 0)]
    buffer_starts = [first_metadata.get('start', 0)]
    buffer_events = [json.dumps(first_metadata.get('events', []))]
    
    samples_processed = 1
    
    # Iterate over the remaining samples
    for feature, metadata in extract_features(test_dl, global_pool=global_pool):
        buffer_features.append(feature)
        buffer_keys.append(metadata.get('key', ''))
        buffer_subs.append(metadata.get('sub', 0))
        buffer_sess.append(metadata.get('ses', 0))
        buffer_runs.append(metadata.get('run', 0))
        buffer_starts.append(metadata.get('start', 0))
        buffer_events.append(json.dumps(metadata.get('events', [])))
        
        samples_processed += 1
        
        # When buffer is full, write to HDF5
        if samples_processed % chunk_size == 0:
            # Append features
            features_array = np.stack(buffer_features)  # Shape: (chunk_size, ...)
            current_size = features_ds.shape[0]
            new_size = current_size + features_array.shape[0]
            features_ds.resize((new_size,) + feature_shape)
            features_ds[current_size:new_size, ...] = features_array
            buffer_features = []
            
            # Append 'key'
            keys_encoded = np.array(buffer_keys, dtype=string_dt)
            key_ds.resize((new_size,))
            key_ds[current_size:new_size] = keys_encoded
            buffer_keys = []
            
            # Append 'sub'
            subs_array = np.array(buffer_subs, dtype='i8')
            sub_ds.resize((new_size,))
            sub_ds[current_size:new_size] = subs_array
            buffer_subs = []
            
            # Append 'ses'
            sess_array = np.array(buffer_sess, dtype='i8')
            ses_ds.resize((new_size,))
            ses_ds[current_size:new_size] = sess_array
            buffer_sess = []
            
            # Append 'run'
            runs_array = np.array(buffer_runs, dtype='i8')
            run_ds.resize((new_size,))
            run_ds[current_size:new_size] = runs_array
            buffer_runs = []
            
            # Append 'start'
            starts_array = np.array(buffer_starts, dtype='i8')
            start_ds.resize((new_size,))
            start_ds[current_size:new_size] = starts_array
            buffer_starts = []
            
            # Append 'events'
            events_encoded = np.array(buffer_events, dtype=string_dt)
            events_ds.resize((new_size,))
            events_ds[current_size:new_size] = events_encoded
            buffer_events = []
    
    # After loop, write any remaining data in buffers
    if buffer_features:
        features_array = np.stack(buffer_features)
        current_size = features_ds.shape[0]
        new_size = current_size + features_array.shape[0]
        features_ds.resize((new_size,) + feature_shape)
        features_ds[current_size:new_size, ...] = features_array
        
        # Append 'key'
        keys_encoded = np.array(buffer_keys, dtype=string_dt)
        key_ds.resize((new_size,))
        key_ds[current_size:new_size] = keys_encoded
        
        # Append 'sub'
        subs_array = np.array(buffer_subs, dtype='i8')
        sub_ds.resize((new_size,))
        sub_ds[current_size:new_size] = subs_array
        
        # Append 'ses'
        sess_array = np.array(buffer_sess, dtype='i8')
        ses_ds.resize((new_size,))
        ses_ds[current_size:new_size] = sess_array
        
        # Append 'run'
        runs_array = np.array(buffer_runs, dtype='i8')
        run_ds.resize((new_size,))
        run_ds[current_size:new_size] = runs_array
        
        # Append 'start'
        starts_array = np.array(buffer_starts, dtype='i8')
        start_ds.resize((new_size,))
        start_ds[current_size:new_size] = starts_array
        
        # Append 'events'
        events_encoded = np.array(buffer_events, dtype=string_dt)
        events_ds.resize((new_size,))
        events_ds[current_size:new_size] = events_encoded

# Calculate total time
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Extraction time:", total_time_str)
print(f"Processed {samples_processed} samples in total.")
h5file.close()

Output location: /weka/proj-fmri/paulscotti/fMRI-foundation-model/src/checkpoints/NSDflat_large_gsrFalse__gpFalse_visualTrue/epoch99/test.h5


  0%|                                                   | 0/141930 [00:03<?, ?it/s]
 28%|██████████▉                            | 39997/141930 [30:36<28:44, 59.10it/s]

In [None]:
f = h5py.File(f'{output_path}', 'r')
print(f.keys())