In [1]:
# Import packages and set up the GPU configuration.
# This code block shouldn't need to be adjusted!
import os
import sys
import json
import yaml
import numpy as np
import copy
import math
import time
import datetime
import random
from tqdm import tqdm
import webdataset as wds
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms
import utils
from flat_models import *

from elbow.sinks import BufferedParquetWriter

# Let cuDNN benchmark candidate convolution algorithms and keep the fastest one;
# per the original author this also avoids a Conv3D CUDNN_NOT_SUPPORTED error.
torch.backends.cudnn.benchmark = True
# Permit TensorFloat-32 in CUDA matmuls: faster than full float32 on GPUs that support it,
# at the cost of a shorter mantissa.
torch.backends.cuda.matmul.allow_tf32 = True

## MODEL TO LOAD ##
model_name = "ps16_mask9_3losses_bs32"
outdir = os.path.abspath(f'checkpoints/{model_name}')
print("outdir", outdir)
# Load previous config.yaml if available; every key becomes a global variable of this notebook.
config_path = f"{outdir}/config.yaml"
if os.path.exists(config_path):
    # `with` closes the file deterministically instead of leaking the handle.
    with open(config_path, 'r') as config_file:
        # An empty YAML file parses to None; treat it as an empty config.
        # NOTE(review): FullLoader can build some Python objects; yaml.safe_load would be
        # stricter if the config only ever holds plain scalars/lists/dicts.
        config = yaml.load(config_file, Loader=yaml.FullLoader) or {}
    print(f"Loaded config.yaml from ckpt folder {outdir}")
    # create global variables from the config
    print("\n__CONFIG__")
    for attribute_name, attribute_value in config.items():
        print(f"{attribute_name} = {attribute_value}")
        globals()[attribute_name] = attribute_value
    print("\n")
else:
    # Later cells read config values (e.g. probe_batch_size, seed) as globals, so say so early
    # rather than failing later with an unexplained NameError.
    print(f"WARNING: no config.yaml found in {outdir}; config-derived globals are undefined.")

### Multi-GPU config ###
def _env_int(var_name, default):
    """Return environment variable `var_name` parsed as an int, or `default` when it is unset."""
    raw_value = os.getenv(var_name)
    return default if raw_value is None else int(raw_value)


device_count = torch.cuda.device_count()
print(f"Number of available CUDA devices: {device_count}")

local_rank = _env_int('LOCAL_RANK', 0)
print(f"LOCAL RANK={local_rank}")

num_devices = _env_int('NUM_GPUS', 1)
print(f"NUM GPUS={num_devices}")
# More than one GPU requested means a distributed run; all requested GPUs must be visible.
distributed = num_devices > 1
if distributed:
    assert device_count == num_devices

node = _env_int('SLURM_NODEID', 0)
print(f"NODE={node}")

global_rank = _env_int('RANK', 0)
print(f"GLOBAL RANK={global_rank}")

world_size = _env_int('WORLD_SIZE', 1)
print(f"WORLD_SIZE={world_size}")

if utils.is_interactive():
    # Reload edited modules (e.g. models.py, utils.py) automatically before each cell runs,
    # so changes are picked up without restarting the kernel. These calls are what IPython
    # expands `%load_ext autoreload` / `%autoreload 2` into.
    get_ipython().run_line_magic('load_ext', 'autoreload')
    get_ipython().run_line_magic('autoreload', '2')

# Extraction reuses the linear-probe batch size / epoch count stored in config.yaml.
batch_size = probe_batch_size
num_epochs = probe_num_epochs
# NOTE(review): this hard-coded path overrides the hcp_flat_path value loaded from config.yaml.
hcp_flat_path = "/weka/proj-medarc/shared/hcp_flat"

data_type = torch.float32 # change depending on your mixed_precision
global_batch_size = batch_size * world_size

# In a multi-GPU run each process must use its own GPU: a bare 'cuda' device resolves to the
# current CUDA device, which would be the same GPU on every rank.
device = torch.device(f'cuda:{local_rank}') if distributed else torch.device('cuda')

print("PID of this process =",os.getpid())
print("device =", device, "distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size, "data_type =", data_type)

# seed all random functions, offset by the global rank so different processes get different seeds
utils.seed_everything(seed + global_rank)

outdir /weka/proj-fmri/paulscotti/fMRI-foundation-model/flat/checkpoints/ps16_mask9_3losses_bs32
Loaded config.yaml from ckpt folder /weka/proj-fmri/paulscotti/fMRI-foundation-model/flat/checkpoints/ps16_mask9_3losses_bs32

__CONFIG__
base_lr = 0.001
batch_size = 32
ckpt_interval = 5
ckpt_saving = True
cls_embed = True
decoder_embed_dim = 512
grad_clip = 1.0
hcp_flat_path = /weka/proj-medarc/shared/hcp_flat
mask_ratio = 0.9
model_name = ps16_mask9_3losses_bs32
no_qkv_bias = False
norm_pix_loss = False
num_epochs = 100
num_frames = 16
num_samples_per_epoch = 200000
num_workers = 10
patch_size = 16
pred_t_dim = 8
print_interval = 20
probe_base_lr = 0.0003
probe_batch_size = 8
probe_num_epochs = 30
probe_num_samples_per_epoch = 100000
resume_from_ckpt = True
seed = 42
sep_pos_embed = True
t_patch_size = 2
test_num_samples_per_epoch = 50000
trunc_init = False
use_contrastive_loss = True
wandb_log = True


Number of available CUDA devices: 1
LOCAL RANK=0
NUM GPUS=1
NODE=0
GLOBAL RANK=0
WORL

# hcp_flat: build the frozen MAE encoder

In [2]:
from util.hcp_flat import load_hcp_flat_mask
from util.hcp_flat import create_hcp_flat
import util.visualize as vis

# Build the MAE ViT-small from the architecture hyperparameters in config.yaml.
# decoder_depth is not stored in config.yaml and is fixed to 4 here; it has to match the
# trained checkpoint because the weights are later loaded with strict=True.
model = mae_vit_small_fmri(
    patch_size=patch_size,
    t_patch_size=t_patch_size,
    pred_t_dim=pred_t_dim,
    decoder_embed_dim=decoder_embed_dim,
    decoder_depth=4,
    cls_embed=cls_embed,
    sep_pos_embed=sep_pos_embed,
    no_qkv_bias=no_qkv_bias,
    norm_pix_loss=norm_pix_loss,
    trunc_init=trunc_init,
)

# Inference only: eval mode, frozen parameters, moved to the target device.
# Each of these nn.Module methods returns the module itself, so `model` stays the same object.
model = model.eval().requires_grad_(False).to(device)
model

img_size (144, 320) patch_size (16, 16) frames 16 t_patch_size 2
model initialized


MaskedAutoencoderViT(
  (patch_embed): PatchEmbed(
    (proj): Conv3d(1, 384, kernel_size=(2, 16, 16), stride=(2, 16, 16))
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (q): Linear(in_features=384, out_features=384, bias=True)
        (k): Linear(in_features=384, out_features=384, bias=True)
        (v): Linear(in_features=384, out_features=384, bias=True)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop2): Dropout(p=0.0, inpl

## Create dataset and data loaders

In [3]:
# Override the probe batch size from config.yaml for feature extraction.
batch_size = 64 #12
print(f"changed batch_size to {batch_size}")


def _make_hcp_flat_loader(split):
    """Build the (dataset, loader) pair for one HCP-flat split, in file order (no shuffling).

    Batching happens in the dataset pipeline, hence `batch_size=None` on the WebLoader.
    `partial=True` keeps the final short batch so every clip gets a feature; with
    `partial=False` the leftover clips (up to batch_size - 1, presumably per loader worker)
    would be silently skipped during extraction.
    """
    dataset = create_hcp_flat(root=hcp_flat_path,
                              split=split, frames=num_frames,
                              clip_mode="event", shuffle=False)
    loader = wds.WebLoader(
        dataset.batched(batch_size, partial=True),
        batch_size=None,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    return dataset, loader


## Train ##
train_dataset, train_dl = _make_hcp_flat_loader("train")

## Test ##
test_dataset, test_dl = _make_hcp_flat_loader("test")

changed batch_size to 64


# Load checkpoint

In [4]:
# Load the newest "epoch<N>.pth" checkpoint from `outdir` into `model`.
if not os.path.isdir(outdir):
    # Abort the notebook with a real exception (previously an undefined name was used).
    raise FileNotFoundError(f"\nCheckpoint folder {outdir} does not exist.\n")

checkpoint_files = [f for f in os.listdir(outdir) if f.endswith('.pth')]

# Extract epoch numbers and find the largest
epoch_numbers = []
for file in checkpoint_files:
    try:
        epoch_number = int(file.split('epoch')[-1].split('.')[0])
        epoch_numbers.append(epoch_number)
    except ValueError:
        continue  # not an epoch<N>.pth file
if not epoch_numbers:
    raise FileNotFoundError(f"No epoch<N>.pth checkpoint found in {outdir}")
latest_epoch = max(epoch_numbers)
latest_checkpoint = f"epoch{latest_epoch}.pth"

# # Or specify epoch number
# latest_checkpoint = "epoch15.pth"

# Load the checkpoint
checkpoint_path = os.path.join(outdir, latest_checkpoint)

# map_location="cpu" avoids materialising every tensor in the checkpoint on the GPU it was saved
# from; load_state_dict then copies the weights into the model's parameters on `device`.
state = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state["model_state_dict"], strict=True)

print(f"\nLoaded checkpoint {latest_checkpoint} from {outdir}\n")


Loaded checkpoint epoch99.pth from /weka/proj-fmri/paulscotti/fMRI-foundation-model/flat/checkpoints/ps16_mask9_3losses_bs32



# Start extraction

In [5]:
# Features for this checkpoint are written to a folder named after it, e.g. <outdir>/epoch99.
checkpoint_stem = os.path.splitext(latest_checkpoint)[0]
outdir_parquet = f'{outdir}/{checkpoint_stem}'
print(outdir_parquet)

os.makedirs(outdir_parquet, exist_ok=True)

/weka/proj-fmri/paulscotti/fMRI-foundation-model/flat/checkpoints/ps16_mask9_3losses_bs32/epoch99


In [6]:
# Data-loading benchmark (no model): iterating train_dl alone with batch_size=12 took
# 49284 iterations (batches) in 14:03, about 58 it/s.
# for i, (samples, samples_meta) in enumerate(tqdm(train_dl)):
#     samples = samples.to(device, non_blocking=True)

In [7]:
@torch.no_grad()
def extract_features(dl, total=None):
    """Yield one record per sample: its encoder features plus its metadata.

    Runs the global `model.forward_features` on the global `device` over every batch of `dl`
    without autograd, then splits each batch back into per-sample dicts.

    Args:
        dl: iterable of ``(samples, samples_meta)`` batches. `samples` is a tensor (presumably
            batch-first); `samples_meta` holds one metadata mapping per sample.
        total: optional number of batches, used only by the tqdm progress bar. When None,
            tqdm uses len(dl) if the loader defines it and otherwise just counts batches.

    Yields:
        dict with key "feature" (that sample's features as a NumPy array) plus the keys of the
        sample's metadata mapping.
    """
    # `total` used to be hard-coded to 49284 (the train-batch count at batch_size=12), which
    # was wrong for test_dl and for any other batch size.
    for samples, samples_meta in tqdm(dl, total=total):
        samples = samples.to(device, non_blocking=True)

        # Copy the whole batch to host memory once, then split it per sample.
        features = model.forward_features(samples).cpu().numpy()

        for feat, meta in zip(features, samples_meta):
            yield {"feature": feat, **meta}

In [8]:
# Extract features for the test split and stream them into <outdir_parquet>/test.parquet.
utils.seed_everything(seed)

print("Start extract")
start_time = time.time()

test_parquet_path = f"{outdir_parquet}/test.parquet"
with BufferedParquetWriter(test_parquet_path, blocking=True) as writer:
    for record in extract_features(test_dl):
        writer.write(record)

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"Extract time {total_time_str}")
# Bytes of GPU memory still occupied by tensors after extraction.
print(torch.cuda.memory_allocated())

Start extract


  0%|▏                                    | 197/49284 [00:29<2:03:32,  6.62it/s]


Extract time 0:00:29
389945344


In [9]:
# Extract features for the train split and stream them into <outdir_parquet>/train.parquet.
utils.seed_everything(seed)

print("Start extract")
start_time = time.time()

train_parquet_path = f"{outdir_parquet}/train.parquet"
with BufferedParquetWriter(train_parquet_path, blocking=True) as writer:
    for record in extract_features(train_dl):
        writer.write(record)

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"Extract time {total_time_str}")
# Bytes of GPU memory still occupied by tensors after extraction.
print(torch.cuda.memory_allocated())

Start extract


  4%|█▎                                  | 1849/49284 [03:59<1:42:32,  7.71it/s]


Extract time 0:03:59
389945344
