In [1]:
# Import packages and set up the GPU configuration.
# This cell shouldn't need to be adjusted other than model_name (when running interactively)!
import os
import sys
import json
import yaml
import numpy as np
import math
import time
import datetime
import random
from tqdm import tqdm
import webdataset as wds
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
from torchvision import transforms
import utils
from mae_utils import flat_models
from elbow.sinks import BufferedParquetWriter

## MODEL TO LOAD ##
# Interactive runs use a hard-coded checkpoint folder; scripted runs pass it as argv[1].
if utils.is_interactive():
    model_name = "HCPflat_large_gsrFalse_"
else:
    model_name = sys.argv[1]
outdir = os.path.abspath(f'checkpoints/{model_name}')
print("outdir", outdir)

# Load previously saved config.yaml made during main training script.
config_path = f"{outdir}/config.yaml"
if not os.path.exists(config_path):
    # Explicit error instead of `assert`, which is skipped under `python -O`.
    raise FileNotFoundError(f"config.yaml not found in checkpoint folder: {config_path}")
with open(config_path, 'r') as config_file:  # close the handle instead of leaking it
    config = yaml.load(config_file, Loader=yaml.FullLoader)
print(f"Loaded config.yaml from ckpt folder {outdir}")

# Create global variables from the config so later cells can use e.g. `seed`,
# `patch_size`, `hcp_flat_path` directly. NOTE: the config also contains
# `model_name`, so this overwrites the value chosen above.
print("\n__CONFIG__")
for attribute_name, attribute_value in config.items():
    print(f"{attribute_name} = {attribute_value}")
    globals()[attribute_name] = attribute_value
print("\n")

if utils.is_interactive():
    # Following allows you to change functions in other files and
    # have this notebook automatically update with your revisions.
    # run_line_magic is the plain-Python spelling of `%load_ext` / `%autoreload`.
    get_ipython().run_line_magic('load_ext', 'autoreload')
    get_ipython().run_line_magic('autoreload', '2')

device = torch.device('cuda')

print("PID of this process =", os.getpid())

# seed all random functions
utils.seed_everything(seed)

outdir /weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/HCPflat_large_gsrFalse_
Loaded config.yaml from ckpt folder /weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/HCPflat_large_gsrFalse_

__CONFIG__
base_lr = 0.001
batch_size = 32
ckpt_interval = 5
ckpt_saving = True
cls_embed = True
contrastive_loss_weight = 1.0
datasets_to_include = HCP
decoder_embed_dim = 512
grad_accumulation_steps = 1
grad_clip = 1.0
gsr = False
hcp_flat_path = /weka/proj-medarc/shared/HCP-Flat
mask_ratio = 0.75
model_name = HCPflat_large_gsrFalse_
no_qkv_bias = False
norm_pix_loss = False
nsd_flat_path = /weka/proj-medarc/shared/NSD-Flat
num_epochs = 100
num_frames = 16
num_samples_per_epoch = 200000
num_workers = 10
patch_size = 16
pct_masks_to_decode = 1
plotting = True
pred_t_dim = 8
print_interval = 20
probe_base_lr = 0.0003
probe_batch_size = 8
probe_num_epochs = 30
probe_num_samples_per_epoch = 100000
resume_from_ckpt = True
seed = 42
sep_pos_embed = True
t_patch_size = 2
test_nu

In [2]:
# Publish the HCP-Flat dataset root (taken from config.yaml) as an environment
# variable; presumably read by helpers in mae_utils — unverified.
os.environ.update({'HCP_FLAT_ROOT': hcp_flat_path})

In [3]:
# Feature pooling mode: pool unless the environment explicitly sets global_pool=False.
global_pool = os.getenv('global_pool') != "False"
print(f"global_pool = {global_pool}")

# gsr normally comes from config.yaml (loaded in the first cell); fall back to
# True only when the config did not define it.
try:
    gsr
except NameError:
    gsr = True
    print("set gsr to True")
print(f"gsr = {gsr}")
# The flags are not force-overwritten after printing, so the values shown
# above are exactly the ones used downstream.


# Build the HCP-Flat MAE model

In [4]:
from mae_utils.flat import load_hcp_flat_mask
from mae_utils.flat import create_hcp_flat
from mae_utils.flat import batch_unmask
import mae_utils.visualize as vis

# Mask loaded from the HCP-Flat root (presumably the valid-pixel mask of the
# flattened surface — unverified); it is handed to the model as `img_mask`.
flat_mask = load_hcp_flat_mask(hcp_flat_path)

# Architecture hyperparameters come from the config globals, except
# decoder_depth, which is hard-coded to 4 here.
mae_kwargs = dict(
    patch_size=patch_size,
    decoder_embed_dim=decoder_embed_dim,
    t_patch_size=t_patch_size,
    pred_t_dim=pred_t_dim,
    decoder_depth=4,
    cls_embed=cls_embed,
    norm_pix_loss=norm_pix_loss,
    no_qkv_bias=no_qkv_bias,
    sep_pos_embed=sep_pos_embed,
    trunc_init=trunc_init,
    pct_masks_to_decode=pct_masks_to_decode,
    img_mask=flat_mask,
)
model = flat_models.mae_vit_large_fmri(**mae_kwargs)

img_size (144, 320) patch_size (16, 16) frames 16 t_patch_size 2
model initialized


# Load checkpoint

In [5]:
# Checkpoints available in the run folder (used for a helpful error message below).
checkpoint_files = sorted(f for f in os.listdir(outdir) if f.endswith('.pth'))

if utils.is_interactive():
    latest_checkpoint = "epoch99.pth"
else:
    latest_checkpoint = sys.argv[2]
print(f"latest_checkpoint: {latest_checkpoint}")

# Load the checkpoint
checkpoint_path = os.path.join(outdir, latest_checkpoint)
if not os.path.isfile(checkpoint_path):
    raise FileNotFoundError(
        f"{checkpoint_path} does not exist; .pth files in {outdir}: {checkpoint_files}"
    )

# map_location='cpu' keeps the loaded checkpoint (which may also hold optimizer
# state — unverified) off the GPU; the weights are copied into the model, which is
# then moved with model.to(device). weights_only=False preserves the previous
# full-unpickle behaviour and silences torch's FutureWarning — this runs pickle,
# so only load trusted checkpoints.
state = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
load_result = model.load_state_dict(state["model_state_dict"], strict=False)
# strict=False tolerates key mismatches; report them instead of hiding them.
if load_result.missing_keys:
    print(f"WARNING: keys missing from checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"WARNING: unexpected keys in checkpoint: {load_result.unexpected_keys}")
model.to(device)

print(f"\nLoaded checkpoint {latest_checkpoint} from {outdir}\n")

latest_checkpoint: epoch99.pth


  state = torch.load(checkpoint_path)



Loaded checkpoint epoch99.pth from /weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/HCPflat_large_gsrFalse_



## Create dataset and data loaders

In [6]:
from torch.utils.data import default_collate
batch_size = 1
print(f"changed batch_size to {batch_size}")

# This notebook only builds HCP-Flat loaders; fail fast if the training config
# (datasets_to_include) asked for something else.
assert "HCP" in datasets_to_include


def _make_hcp_flat_loader(sub_list):
    """Build an unshuffled HCP-Flat dataset and WebLoader for one subject split.

    Args:
        sub_list: subject split name passed to create_hcp_flat ('test' or 'train').

    Returns:
        (dataset, loader). Batches of `batch_size` are collated with
        default_collate; an incomplete final batch is dropped (partial=False).
    """
    dataset = create_hcp_flat(root=hcp_flat_path, clip_mode="event", frames=num_frames,
                              shuffle=False, gsr=gsr, sub_list=sub_list)
    loader = wds.WebLoader(
        dataset.batched(batch_size, partial=False, collation_fn=default_collate),
        batch_size=None,  # batching already happens in .batched() above
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    return dataset, loader


## Test ##
test_dataset, test_dl = _make_hcp_flat_loader('test')

## Train ##
train_dataset, train_dl = _make_hcp_flat_loader('train')

changed batch_size to 1


# Start extraction

In [7]:
# Total number of batches shown by the tqdm progress bar. A single constant cannot
# match both splits (test and train have very different sizes), so leave it as
# None: tqdm then shows a plain running count instead of a wrong total.
cnt = None

In [8]:
@torch.no_grad()
def extract_features(dl, global_pool=True):
    """Run the global `model` encoder over `dl` and yield one record per sample.

    Args:
        dl: iterable of batches; each batch is a dict with an 'image' tensor
            (batch dimension first) and a 'meta' dict whose values are collated
            along the batch axis (assumed — TODO confirm against create_hcp_flat).
        global_pool: forwarded to the model's forward pass.

    Yields:
        dict with key "feature" (one sample's flattened feature vector as a numpy
        array) plus every meta key. Array/list/tuple meta values are reduced to that
        sample's length-1 slice, which matches the record layout previously produced
        with batch_size=1. Other meta values are copied through unchanged.
    """
    for samples in tqdm(dl, total=cnt):
        features = model(samples['image'].to(device), global_pool=global_pool, forward_features=True)
        features = features.flatten(1).cpu().numpy()

        meta_dict = {}
        for key, value in samples['meta'].items():
            if isinstance(value, torch.Tensor):
                value = value.cpu().numpy()
            meta_dict[key] = value

        # Index by sample position. Zipping `features` with the meta dict would
        # iterate the dict's *keys*, emitting min(batch, n_keys) rows and attaching
        # the whole batch's metadata to each row.
        for i, feat in enumerate(features):
            yield {"feature": feat, **{key: _slice_sample(value, i) for key, value in meta_dict.items()}}


def _slice_sample(value, i):
    """Return sample i of a batched meta value as a length-1 slice; pass other values through."""
    if isinstance(value, (np.ndarray, list, tuple)):
        return value[i:i + 1]
    return value

In [9]:
# Output folder for the extracted features, tagged with the pooling mode and the
# checkpoint name without its file extension (e.g. "epoch99.pth" -> "epoch99").
checkpoint_stem = os.path.splitext(latest_checkpoint)[0]
out_folder = f'{outdir}_gp{global_pool}/{checkpoint_stem}/HCP'
print(out_folder)
os.makedirs(out_folder, exist_ok=True)

/weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/HCPflat_large_gsrFalse__gpTrue/epoch99/HCP


In [10]:

# Ensure the output Parquet directory exists (tagged with pooling mode and
# checkpoint name without its file extension).
checkpoint_stem = os.path.splitext(latest_checkpoint)[0]
outdir_parquet = os.path.join(f'{outdir}_gp{global_pool}', checkpoint_stem, 'HCP')
os.makedirs(outdir_parquet, exist_ok=True)

utils.seed_everything(seed)

print("Start extract")
start_time = time.time()

# One Parquet file per split; test first, then train.
for split, loader in (("test", test_dl), ("train", train_dl)):
    with BufferedParquetWriter(f"{outdir_parquet}/{split}.parquet", blocking=True) as writer:
        # Pass global_pool explicitly so the features match the `_gp{global_pool}`
        # folder name; extract_features otherwise defaults to global_pool=True.
        for sample in extract_features(loader, global_pool=global_pool):
            writer.write(sample)

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Extract time {}".format(total_time_str))
print(torch.cuda.memory_allocated())


/weka/proj-fmri/ckadirt/fMRI-foundation-model/src/checkpoints/HCPflat_large_gsrFalse__gpTrue/epoch99
Start extract


12082it [12:19, 16.35it/s]                                                                                                                                                                             
111302it [1:25:46, 21.63it/s]                                                                                                                                                                          

Extract time 1:38:05
3206953984



