In [2]:
# Import packages and setup gpu configuration.
# This code block shouldn't need to be adjusted!
import os
import sys
import json
import yaml
import numpy as np
import copy
import math
import time
import datetime
import random
from tqdm import tqdm
import webdataset as wds
import matplotlib.pyplot as plt
import pandas as pd

import torch
import torch.nn as nn
from torchvision import transforms
import utils
from flat_models import *

from elbow.sinks import BufferedParquetWriter

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True
# following fixes a Conv3D CUDNN_NOT_SUPPORTED error
torch.backends.cudnn.benchmark = True

## MODEL TO LOAD ##
model_name = "nsdflat_large"
outdir = os.path.abspath(f'checkpoints/{model_name}')
print("outdir", outdir)
# Load previous config.yaml if available
config_path = f"{outdir}/config.yaml"
if os.path.exists(config_path):
    # Use a context manager so the file handle is closed as soon as parsing
    # finishes (a bare open() inside yaml.load() is never explicitly closed).
    # NOTE(review): FullLoader accepts more than plain scalars/lists/dicts;
    # if the config only holds plain values, yaml.safe_load would be the
    # safer choice -- confirm before switching.
    with open(config_path, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    print(f"Loaded config.yaml from ckpt folder {outdir}")
    # create global variables from the config
    # (every top-level key becomes a notebook-level variable of the same name,
    # which may overwrite names set above, e.g. model_name)
    print("\n__CONFIG__")
    for attribute_name, attribute_value in config.items():
        print(f"{attribute_name} = {attribute_value}")
        globals()[attribute_name] = attribute_value
    print("\n")
else:
    # Later cells read config-derived globals; make the root cause visible
    # instead of failing further down with an unexplained NameError.
    print(f"WARNING: no config.yaml found in {outdir}; "
          "config-derived globals (e.g. probe_batch_size, seed) will be undefined")

### Multi-GPU config ###
def _env_int(var_name, fallback):
    """Return environment variable `var_name` parsed as int, or `fallback` when it is unset."""
    raw_value = os.getenv(var_name)
    return fallback if raw_value is None else int(raw_value)


device_count = torch.cuda.device_count()
print(f"Number of available CUDA devices: {device_count}")

# Rank of this process, read from LOCAL_RANK (0 when unset)
local_rank = _env_int('LOCAL_RANK', 0)
print(f"LOCAL RANK={local_rank}")

# Number of GPUs requested via NUM_GPUS (1 when unset)
num_devices = _env_int('NUM_GPUS', 1)
print(f"NUM GPUS={num_devices}")
distributed = num_devices > 1
if distributed:
    # The requested GPU count must match what CUDA reports
    assert device_count == num_devices

# Node index, read from SLURM_NODEID (0 when unset)
node = _env_int('SLURM_NODEID', 0)
print(f"NODE={node}")

# Global rank, read from RANK (0 when unset)
global_rank = _env_int('RANK', 0)
print(f"GLOBAL RANK={global_rank}")

# World size, read from WORLD_SIZE (1 when unset)
world_size = _env_int('WORLD_SIZE', 1)
print(f"WORLD_SIZE={world_size}")

if utils.is_interactive():
    # Following allows you to change functions in imported modules
    # (e.g. flat_models or utils) and have this notebook automatically
    # update with your revisions.
    # These are IPython magics, so they only parse when the cell is run by an
    # IPython kernel; the utils.is_interactive() guard is presumably what
    # detects that -- TODO confirm against utils.
    %load_ext autoreload
    %autoreload 2

# Use the probe-specific settings from the loaded config for this notebook
batch_size = probe_batch_size
num_epochs = probe_num_epochs
global_batch_size = batch_size * world_size

# change data_type depending on your mixed_precision
data_type = torch.float32
device = torch.device('cuda')

print("PID of this process =", os.getpid())
print(
    "device =", device,
    "distributed =", distributed,
    "num_devices =", num_devices,
    "local rank =", local_rank,
    "world size =", world_size,
    "data_type =", data_type,
)

# seed all random functions (offset by the global rank so each process gets its own seed)
utils.seed_everything(seed + global_rank)

outdir /weka/proj-fmri/paulscotti/fMRI-foundation-model/flat/checkpoints/nsdflat_large
Loaded config.yaml from ckpt folder /weka/proj-fmri/paulscotti/fMRI-foundation-model/flat/checkpoints/nsdflat_large

__CONFIG__
base_lr = 0.001
batch_size = 32
ckpt_interval = 5
ckpt_saving = True
cls_embed = True
decoder_embed_dim = 512
grad_accumulation_steps = 1
grad_clip = 1.0
hcp_flat_path = /weka/proj-medarc/shared/NSD-Flat
mask_ratio = 0.75
model_name = nsdflat_large
no_qkv_bias = False
norm_pix_loss = False
num_epochs = 100
num_frames = 16
num_samples_per_epoch = 200000
num_workers = 10
patch_size = 16
pct_masks_to_decode = 1
pred_t_dim = 8
print_interval = 20
probe_base_lr = 0.0003
probe_batch_size = 8
probe_num_epochs = 30
probe_num_samples_per_epoch = 100000
resume_from_ckpt = True
seed = 42
sep_pos_embed = True
t_patch_size = 2
test_num_samples_per_epoch = 50000
test_set = False
trunc_init = False
use_contrastive_loss = False
wandb_log = True


Number of available CUDA devices: 1
LOCAL RA

# hcp_flat

In [24]:
from util.hcp_flat import load_hcp_flat_mask, load_nsd_flat_mask
from util.hcp_flat import create_hcp_flat, create_nsd_flat
import util.visualize as vis

# Architecture hyperparameters: everything except decoder_depth comes from the
# config-derived notebook globals.
model_kwargs = dict(
    patch_size=patch_size,
    decoder_embed_dim=decoder_embed_dim,
    t_patch_size=t_patch_size,
    pred_t_dim=pred_t_dim,
    decoder_depth=4,
    cls_embed=cls_embed,
    norm_pix_loss=norm_pix_loss,
    no_qkv_bias=no_qkv_bias,
    sep_pos_embed=sep_pos_embed,
    trunc_init=trunc_init,
    pct_masks_to_decode=pct_masks_to_decode,
)

# Alternative constructor left by the author: mae_vit_small_fmri(**model_kwargs)
model = mae_vit_large_fmri(**model_kwargs)

img_size (144, 320) patch_size (16, 16) frames 16 t_patch_size 2
model initialized


# Load checkpoint

In [26]:
checkpoint_files = [f for f in os.listdir(outdir) if f.endswith('.pth')]

# Extract epoch numbers and find the largest.
# The epoch is the integer between the last "epoch" and the following "." in
# the file name; files where that text is not an integer are skipped.
# The real file name is remembered per epoch so the path that gets loaded is
# a file that exists (rebuilding "epoch{N}.pth" breaks for names such as
# "model_epoch99.pth" or "epoch099.pth").
checkpoint_by_epoch = {}
for file in checkpoint_files:
    try:
        epoch_number = int(file.split('epoch')[-1].split('.')[0])
    except ValueError:
        continue
    # keep the first file seen for a given epoch number
    checkpoint_by_epoch.setdefault(epoch_number, file)

if not checkpoint_by_epoch:
    raise FileNotFoundError(
        f"No '.pth' checkpoint with a parsable epoch number found in {outdir}"
    )

epoch_numbers = list(checkpoint_by_epoch)
latest_epoch = max(checkpoint_by_epoch)
latest_checkpoint = checkpoint_by_epoch[latest_epoch]

# # Or specify epoch number 
# latest_checkpoint = "epoch15.pth"

# Load the checkpoint
checkpoint_path = os.path.join(outdir, latest_checkpoint)

# map_location="cpu": deserialize onto the CPU regardless of which device the
# checkpoint was saved from; the model is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state["model_state_dict"], strict=True)
model.to(device)

print(f"\nLoaded checkpoint {latest_checkpoint} from {outdir}\n")


Loaded checkpoint epoch99.pth from /weka/proj-fmri/paulscotti/fMRI-foundation-model/flat/checkpoints/nsdflat_large



## Create dataset and data loaders

In [129]:
# Override the probe batch size for extraction (no gradients are kept, so a
# larger batch fits).
batch_size = 64 #12
print(f"changed batch_size to {batch_size}")

## Test ##
# The dataset flavour is inferred from the data path stored in the config.
if "HCP" in hcp_flat_path:
    dataset = "hcp"
    test_dataset = create_hcp_flat(root=hcp_flat_path, 
                    split="test", frames=num_frames, shuffle=False)
elif "NSD" in hcp_flat_path:
    dataset = "nsd"
    test_dataset = create_nsd_flat(root=hcp_flat_path, 
                    frames=num_frames, shuffle=False)
else:
    # Fail here with a clear message instead of a NameError on
    # `test_dataset` / `dataset` further down.
    raise ValueError(
        f"Cannot infer dataset type from hcp_flat_path={hcp_flat_path!r}; "
        "expected the path to contain 'HCP' or 'NSD'"
    )

# NOTE(review): partial=False presumably drops incomplete trailing batches, so
# a few samples at the end of the stream may not be extracted -- confirm that
# this is intended for feature extraction.
test_dl = wds.WebLoader(
    test_dataset.batched(batch_size, partial=False),
    batch_size=None,  # batching is already done in the dataset pipeline above
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
)

changed batch_size to 64


# Start extraction

In [130]:
# One full pass over the loader to count batches (`cnt`) and samples (`total`).
# `cnt` is later used as the progress-bar length during extraction, and the
# last `sample` batch is reused by the shape check in the next cell.
cnt, total = 0, 0
if dataset == "hcp":
    # hcp batches are (sample, meta) pairs
    for cnt, (sample, sample_meta) in enumerate(test_dl, start=1):
        total += len(sample)
elif dataset == "nsd":
    # nsd batches additionally carry events
    for cnt, (sample, sample_meta, sample_events) in enumerate(test_dl, start=1):
        total += len(sample)
print("cnt", cnt)
print("total", total)

cnt 2292
total 146688


In [136]:
# Sanity check: push a single sample (batch dimension kept via list indexing)
# through the patch embedding and show the resulting shape.
model.to(device)
with torch.no_grad():
    x = model.patch_embed(sample[[0]].to(device))
print(x.shape)

torch.Size([1, 8, 180, 1024])


In [137]:
@torch.no_grad()
def extract_features(dl, dataset="nsd", total=None):
    """Yield one record per sample with the model's features and its metadata.

    Each batch from `dl` is moved to the notebook-level `device`, passed
    through `model.forward_features`, and converted to a NumPy array on the
    CPU. The batch is then split into per-sample dicts.

    Args:
        dl: iterable of batches. For ``dataset="hcp"`` each batch is
            ``(samples, samples_meta)``; for ``dataset="nsd"`` it is
            ``(samples, samples_meta, samples_events)``. Each metadata entry
            must be a mapping, since it is merged into the record.
        dataset: ``"hcp"`` or ``"nsd"``; selects the batch layout.
        total: number of batches, used only for the progress bar. When None,
            falls back to the notebook-level ``cnt`` if it exists, otherwise
            the bar is shown without a total.

    Yields:
        dict with key ``"feature"`` (the sample's feature array), the keys of
        the sample's metadata mapping, and -- for ``"nsd"`` only -- ``"events"``.

    Raises:
        ValueError: if `dataset` is not ``"hcp"`` or ``"nsd"`` (raised when
            iteration starts, because this is a generator).

    Side effects:
        Puts the notebook-level `model` into eval mode.
    """
    if dataset not in ("hcp", "nsd"):
        # Previously an unknown dataset silently produced no records.
        raise ValueError(f"Unknown dataset {dataset!r}; expected 'hcp' or 'nsd'")

    if total is None:
        # Optional: the batch count computed by the counting cell, if it was run.
        total = globals().get("cnt")

    # Inference should not run with train-mode layers (dropout etc.) active.
    model.eval()

    has_events = dataset == "nsd"
    for batch in tqdm(dl, total=total):
        if has_events:
            samples, samples_meta, samples_events = batch
        else:
            samples, samples_meta = batch

        features = model.forward_features(samples.to(device)).cpu().numpy()

        if has_events:
            for feat, meta, events in zip(features, samples_meta, samples_events):
                yield {"feature": feat, **meta,
                       "events": events}
        else:
            for feat, meta in zip(features, samples_meta):
                yield {"feature": feat, **meta}

In [138]:
# Features are written to a folder named after the checkpoint file with its
# last 4 characters (the ".pth" extension) removed.
checkpoint_stem = latest_checkpoint[:-4]
outdir_parquet = f'{outdir}/{checkpoint_stem}'
print(outdir_parquet)

os.makedirs(outdir_parquet, exist_ok=True)

/weka/proj-fmri/paulscotti/fMRI-foundation-model/flat/checkpoints/nsdflat_large/epoch99


In [139]:
# Re-seed so the extraction pass is reproducible on its own.
utils.seed_everything(seed)

test_parquet_path = f"{outdir_parquet}/test.parquet"

print("Start extract")
print("output location:", test_parquet_path)
start_time = time.time()

# Stream records into the parquet file as they are produced instead of
# collecting all features in memory first.
with BufferedParquetWriter(test_parquet_path, blocking=True) as writer:
    # The loop variable is `row`, not `sample`: `sample` is the batch tensor
    # kept from the counting cell and read by the shape-check cell, so
    # overwriting it with a dict here would break re-running that cell.
    for row in extract_features(test_dl, dataset=dataset):
        writer.write(row)

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Extract time {}".format(total_time_str))
print(torch.cuda.memory_allocated())

Start extract
output location: /weka/proj-fmri/paulscotti/fMRI-foundation-model/flat/checkpoints/nsdflat_large/epoch99/test.parquet


100%|███████████████████████████████████████| 2292/2292 [25:57<00:00,  1.47it/s]

Extract time 0:25:57
3403167232





In [None]:
# Reminder of where the extracted test features were written
print(f"output location: {outdir_parquet}/test.parquet")

In [None]:
# utils.seed_everything(seed)

# print("Start extract")
# start_time = time.time()

# with BufferedParquetWriter(f"{outdir_parquet}/train.parquet", blocking=True) as writer:
#     for sample in extract_features(train_dl):
#         writer.write(sample)

# total_time = time.time() - start_time
# total_time_str = str(datetime.timedelta(seconds=int(total_time)))
# print("Extract time {}".format(total_time_str))
# print(torch.cuda.memory_allocated())

In [140]:
# Read the extracted features back from disk and display them as a check
test_parquet_file = f"{outdir_parquet}/test.parquet"
features = pd.read_parquet(test_parquet_file)
features

Unnamed: 0,feature,key,sub,ses,run,start,events
0,"[1.2957616, 0.8341554, 0.977095, 0.7594984, 0....",sub-01_ses-25_run-09,1,25,9,0,"[{'index': 12, 'nsd_id': 45455}, {'index': 16,..."
1,"[1.6828736, 0.76685095, 0.801588, 1.1775866, 0...",sub-01_ses-25_run-09,1,25,9,16,"[{'index': 12, 'nsd_id': 45455}, {'index': 16,..."
2,"[1.3467954, 0.62295175, 0.53726166, 0.94804835...",sub-01_ses-25_run-09,1,25,9,32,"[{'index': 12, 'nsd_id': 45455}, {'index': 16,..."
3,"[1.2597265, 0.56107175, 0.88799864, 1.0987167,...",sub-01_ses-25_run-09,1,25,9,48,"[{'index': 12, 'nsd_id': 45455}, {'index': 16,..."
4,"[0.9291683, 0.7384197, 0.43327004, 0.96608216,...",sub-01_ses-25_run-09,1,25,9,64,"[{'index': 12, 'nsd_id': 45455}, {'index': 16,..."
...,...,...,...,...,...,...,...
146683,"[0.39492762, 0.9567513, 1.9104024, 0.6862465, ...",sub-01_ses-09_run-07,1,9,7,15,"[{'index': 12, 'nsd_id': 46391}, {'index': 16,..."
146684,"[0.6176336, 0.97453976, 1.6229396, 0.7125873, ...",sub-01_ses-09_run-07,1,9,7,31,"[{'index': 12, 'nsd_id': 46391}, {'index': 16,..."
146685,"[0.36875397, 0.9014565, 2.026524, 0.7028155, 0...",sub-01_ses-09_run-07,1,9,7,47,"[{'index': 12, 'nsd_id': 46391}, {'index': 16,..."
146686,"[0.11672564, 0.7731965, 2.1594627, 0.8890982, ...",sub-01_ses-09_run-07,1,9,7,63,"[{'index': 12, 'nsd_id': 46391}, {'index': 16,..."
