In [None]:
# Import packages and setup gpu configuration.
# This code block shouldn't need to be adjusted!
import os
import sys
import json
import yaml
import numpy as np
import copy
import math
import time
import datetime
import random
from tqdm import tqdm
import webdataset as wds
import matplotlib.pyplot as plt
import pandas as pd

import torch
import torch.nn as nn
from torchvision import transforms
import utils
from flat_models import *

from elbow.sinks import BufferedParquetWriter

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True
# following fixes a Conv3D CUDNN_NOT_SUPPORTED error
torch.backends.cudnn.benchmark = True

## MODEL TO LOAD ##
model_name = "nsdflat_large_gsr_"
outdir = os.path.abspath(f'checkpoints/{model_name}')
print("outdir", outdir)

# Load the config.yaml stored in the checkpoint folder, if available, and
# expose every top-level key as a notebook global. Later cells read names such
# as `seed`, `probe_batch_size` and `datasets_to_include` directly.
config_path = f"{outdir}/config.yaml"
if os.path.exists(config_path):
    # NOTE(review): FullLoader is kept to preserve behavior; consider
    # yaml.safe_load if the config only ever contains plain data.
    # The `with` block closes the file handle as soon as parsing is done.
    with open(config_path, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    print(f"Loaded config.yaml from ckpt folder {outdir}")
    # create global variables from the config
    print("\n__CONFIG__")
    for attribute_name, attribute_value in config.items():
        print(f"{attribute_name} = {attribute_value}")
        globals()[attribute_name] = attribute_value
    print("\n")
else:
    # Without the config the globals used further down are never created, so
    # say so here instead of leaving only a later NameError as the symptom.
    print(f"WARNING: no config.yaml found in {outdir}; config-derived variables are undefined")

### Multi-GPU config ###
# Each setting is read from an environment variable and falls back to its
# single-process default when the variable is not set.
device_count = torch.cuda.device_count()
print(f"Number of available CUDA devices: {device_count}")

local_rank = int(os.getenv('LOCAL_RANK', 0))
print(f"LOCAL RANK={local_rank}")

num_devices = int(os.getenv('NUM_GPUS', 1))
print(f"NUM GPUS={num_devices}")
distributed = num_devices > 1
if distributed:
    # NUM_GPUS has to agree with the number of devices torch can see
    assert device_count == num_devices

node = int(os.getenv('SLURM_NODEID', 0))
print(f"NODE={node}")

global_rank = int(os.getenv('RANK', 0))
print(f"GLOBAL RANK={global_rank}")

world_size = int(os.getenv('WORLD_SIZE', 1))
print(f"WORLD_SIZE={world_size}")

# Enable IPython's autoreload extension, but only when utils.is_interactive()
# reports an interactive session.
if utils.is_interactive():
    # Following allows you to change functions in imported project modules
    # (e.g. flat_models or utils) and have this notebook automatically update
    # with your revisions
    %load_ext autoreload
    %autoreload 2

# probe_batch_size / probe_num_epochs are expected to already exist as notebook
# globals (presumably created from config.yaml -- verify against the config).
batch_size, num_epochs = probe_batch_size, probe_num_epochs
global_batch_size = batch_size * world_size

data_type = torch.float32 # change depending on your mixed_precision
device = torch.device('cuda')

print("PID of this process =",os.getpid())
print(
    "device =", device,
    "distributed =", distributed,
    "num_devices =", num_devices,
    "local rank =", local_rank,
    "world size =", world_size,
    "data_type =", data_type,
)

# seed all random functions; the seed is offset by the global rank
utils.seed_everything(seed + global_rank)

# hcp_flat

In [None]:
from util.flat import load_hcp_flat_mask, load_nsd_flat_mask
from util.flat import create_hcp_flat, create_nsd_flat
from util.flat import batch_unmask
import util.visualize as vis

# Select the flat-map mask for the configured dataset. HCP takes precedence
# when both names appear in datasets_to_include.
if "HCP" in datasets_to_include:
    flat_mask = load_hcp_flat_mask()
elif "NSD" in datasets_to_include:
    flat_mask = load_nsd_flat_mask()
else:
    # Fail here with a clear message instead of a NameError on `flat_mask`
    # when the model is constructed.
    raise ValueError("No valid dataset specified in datasets_to_include")

# Hyper-parameters for the model constructor. Everything except decoder_depth
# comes from notebook globals (presumably the loaded config).
# NOTE(review): decoder_depth is hard-coded to 4 -- confirm it matches the
# checkpoint that is loaded below.
model_kwargs = dict(
    patch_size=patch_size,
    decoder_embed_dim=decoder_embed_dim,
    t_patch_size=t_patch_size,
    pred_t_dim=pred_t_dim,
    decoder_depth=4,
    cls_embed=cls_embed,
    norm_pix_loss=norm_pix_loss,
    no_qkv_bias=no_qkv_bias,
    sep_pos_embed=sep_pos_embed,
    trunc_init=trunc_init,
    pct_masks_to_decode=pct_masks_to_decode,
    img_mask=flat_mask,
)
model = mae_vit_large_fmri(**model_kwargs)

# Load checkpoint

In [None]:
checkpoint_files = [f for f in os.listdir(outdir) if f.endswith('.pth')]

# Keep only files named exactly "epoch<N>.pth" and remember (N, filename), so
# the file that gets loaded is one that really exists on disk. Names such as
# "12.pth" or "model_epoch3.pth" are skipped.
checkpoint_prefix, checkpoint_suffix = "epoch", ".pth"
epoch_candidates = []
for file in checkpoint_files:
    if not file.startswith(checkpoint_prefix):
        continue
    epoch_text = file[len(checkpoint_prefix):-len(checkpoint_suffix)]
    if epoch_text.isdigit():
        epoch_candidates.append((int(epoch_text), file))

if not epoch_candidates:
    raise FileNotFoundError(f"No epoch<N>.pth checkpoint found in {outdir}")
epoch_numbers = [epoch for epoch, _ in epoch_candidates]
latest_epoch, latest_checkpoint = max(epoch_candidates)

# # Or specify epoch number 
# latest_checkpoint = "epoch15.pth"

# Load the checkpoint
checkpoint_path = os.path.join(outdir, latest_checkpoint)

# map_location="cpu" makes loading independent of the device the checkpoint
# was saved from; the weights are moved to `device` right after.
state = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state["model_state_dict"], strict=True)
model.to(device)
# The notebook only runs inference, so put the model in evaluation mode.
model.eval()

print(f"\nLoaded checkpoint {latest_checkpoint} from {outdir}\n")

## Create dataset and data loaders

In [None]:
# Feature extraction runs with one sample per batch.
batch_size = 1
print(f"changed batch_size to {batch_size}")

## Test ##
# HCP takes precedence when both dataset names are listed.
if "HCP" in datasets_to_include:
    test_dataset = create_hcp_flat(root=hcp_flat_path,
                    split="test", frames=num_frames, shuffle=False)
elif "NSD" in datasets_to_include:
    # NOTE(review): no split is passed for NSD and the subject is fixed to
    # "sub-01" -- confirm this is the intended evaluation set.
    test_dataset = create_nsd_flat(root=nsd_flat_path,
                    frames=num_frames, shuffle=False,
                    sub="sub-01")
else:
    # Fail here with a clear message instead of a NameError on `test_dataset`.
    raise ValueError("No valid dataset specified in datasets_to_include")

# partial=False drops an incomplete final batch; batch_size=None because the
# dataset is already batched above.
test_dl = wds.WebLoader(
    test_dataset.batched(batch_size, partial=False),
    batch_size=None,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
)

# Start extraction

In [None]:
# One full pass over the loader: count batches (`cnt`, used later as the tqdm
# total in extract_features) and samples (`total`). For NSD, also record the
# session / run / start of every sample so coverage can be inspected below.
sessions = []
runs = []
starts = []
cnt = 0
total = 0
if "HCP" in datasets_to_include:
    for sample, sample_meta in test_dl:
        cnt += 1
        total += len(sample)
elif "NSD" in datasets_to_include:
    for sample, sample_meta, sample_events, sample_means, sample_sds in test_dl:
        # extend() keeps the lists flat, so the arrays built below are 1-D
        # even if batches do not all have the same length.
        sessions.extend(s['ses'] for s in sample_meta)
        runs.extend(s['run'] for s in sample_meta)
        starts.extend(s['start'] for s in sample_meta)
        cnt += 1
        total += len(sample)
print("cnt", cnt)
print("total", total)
sessions = np.array(sessions).flatten()
runs = np.array(runs).flatten()
starts = np.array(starts).flatten()

In [None]:
# Runs recorded for one session. `sess` is assigned here so that this cell
# also works on a fresh kernel (Restart & Run All).
sess = 5
np.unique(runs[sessions == sess])

In [None]:
# Inspect the recorded start values of a single session/run: how many there
# are, how many are distinct, and the largest one.
sess = 5
run = 3
run_starts = np.sort(starts[(sessions == sess) & (runs == run)])
print(len(run_starts))
print(len(np.unique(run_starts)))
print(run_starts[-1])

In [None]:
@torch.no_grad()
def extract_features(dl, dataset="nsd"):
    """Yield one record per sample holding its model features and metadata.

    Uses the notebook globals `model` (for `forward_features`), `device`
    (where each batch is moved to) and `cnt` (progress-bar total).

    Args:
        dl: Iterable of batches. For ``"hcp"`` every batch must unpack as
            ``(samples, samples_meta)``; for ``"nsd"`` as ``(samples,
            samples_meta, samples_events, sample_means, sample_sds)``.
        dataset: Either ``"hcp"`` or ``"nsd"``.

    Yields:
        dict: ``{"feature": <numpy array>, **meta}``; NSD records additionally
        carry an ``"events"`` entry.

    Raises:
        ValueError: If `dataset` is not supported. Because this is a
            generator, the error surfaces on the first iteration.
    """
    if dataset not in ("hcp", "nsd"):
        raise ValueError(f"Unknown dataset {dataset!r}; expected 'hcp' or 'nsd'")

    for batch in tqdm(dl, total=cnt):
        if dataset == "hcp":
            # NOTE(review): this path used to abort on a stray bare `err`
            # statement; validate the HCP output before relying on it.
            samples, samples_meta = batch
            samples_events = None
        else:
            # sample_means / sample_sds are delivered by the loader but not
            # used: inputs go to the model as-is (no normalization or
            # unmasking is applied in this function).
            samples, samples_meta, samples_events, _sample_means, _sample_sds = batch

        features = model.forward_features(samples.to(device))
        features = features.cpu().numpy()

        if dataset == "hcp":
            for feat, meta in zip(features, samples_meta):
                yield {"feature": feat, **meta}
        else:
            for feat, meta, events in zip(features, samples_meta, samples_events):
                yield {"feature": feat, **meta,
                       "events": events}

In [None]:
# Features go into a folder named after the checkpoint file without its
# extension. splitext does not depend on the extension being 4 characters.
checkpoint_stem = os.path.splitext(latest_checkpoint)[0]
outdir_parquet = f'{outdir}/{checkpoint_stem}'
print(outdir_parquet)

os.makedirs(outdir_parquet,exist_ok=True)

In [None]:
utils.seed_everything(seed)

# Resolve the dataset before the output file is opened, so an invalid
# configuration does not leave a parquet file behind. HCP is checked first to
# match the precedence used when the mask and the data loader were built.
if "HCP" in datasets_to_include:
    extract_dataset = "hcp"
elif "NSD" in datasets_to_include:
    extract_dataset = "nsd"
else:
    raise ValueError("No valid dataset specified in datasets_to_include")

print("Start extract")
print("output location:", f"{outdir_parquet}/test.parquet")
start_time = time.time()

# Stream every record produced by extract_features into test.parquet.
with BufferedParquetWriter(f"{outdir_parquet}/test.parquet", blocking=True) as writer:
    for sample in extract_features(test_dl, dataset=extract_dataset):
        writer.write(sample)

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Extract time {}".format(total_time_str))
print(torch.cuda.memory_allocated())

In [None]:
# Show where the extracted features were written.
print("output location:", "{}/test.parquet".format(outdir_parquet))

In [None]:
# Read the written parquet file back and display it as the cell output.
features = pd.read_parquet(path=f"{outdir_parquet}/test.parquet")
features

In [None]:
# Count the extracted rows for subject 1 in one session/run.
sess = 9  # alternative value kept from the original cell: 5
run = 7   # alternative value kept from the original cell: 3
selected_rows = (features['sub'] == 1) & (features['ses'] == sess) & (features['run'] == run)
selected_starts = features[selected_rows]['start'].values
len(np.sort(selected_starts))