In [1]:
import h5py as h5
import numpy as np
import re
import torch

from algonauts.data import FMRI_Dataset
from algonauts.utils.utils import collect_predictions
from algonauts.utils.viz import plot_glass_brain, load_and_label_atlas, voxelwise_pearsonr


In [2]:
feature_paths = {
    "aud_last": "../data/features/Omni/Qwen2.5_3B/features_tr1.49_len8_before6/aud_last", #torch.Size([102, 1280])
    "aud_ln_post": "../data/features/Omni/Qwen2.5_3B/features_tr1.49_len8_before6/audio_ln_post", #torch.Size([102, 1280])
    "conv3d_features": "../data/features/Omni/Qwen2.5_3B/features_tr1.49_len8_before6/conv3d_features", #torch.Size([3536, 1280])
}

input_dims = {
    "aud_last": 1280 * 2,
}

modality_keys = list(input_dims.keys())

In [3]:
ds = FMRI_Dataset(
    "../data/raw/fmri",
    input_dims=input_dims,
    modalities=modality_keys,
    feature_paths=feature_paths,
    normalize_bold=False
)
filter_fn = lambda sample: sample["name"] in ["life", 'figures']
ds.filter_samples(filter_fn)
len(ds)

136

In [4]:
ds.samples[0]["fmri_file"]
ds.subject_name_id_dict.keys()

dict_keys(['sub-01', 'sub-02', 'sub-03', 'sub-05'])

In [5]:


sort = lambda lst: sorted(lst, key=lambda s: int(re.search(r"(figures|life)(\d+)", s).group(2)))

data_life1 = []
order_life1 = []
data_life2 = []
order_life2 = []
for subject in ds.subject_name_id_dict.keys():
    with h5.File(f"../data/raw/fmri/{subject}/func/{subject}_task-movie10_space-MNI152NLin2009cAsym_atlas-Schaefer18_parcel-1000Par7Net_bold.h5") as f:
        keys = list(f.keys())
        keys_life = [key for key in keys if "figures" in key or "life" in key]
        keys_life1 = [key for key in keys_life if "run-1" in key]
        keys_life2 = [key for key in keys_life if "run-2" in key]

        keys_life1 = sort(keys_life1)
        keys_life2 = sort(keys_life2)

        for key in keys_life1:
            order_life1.append(key)
            data_life1.append(f[key][:])
        for key in keys_life2:
            order_life2.append(key)
            data_life2.append(f[key][:])

data_life1 = np.concatenate(data_life1, axis=0)
data_life2 = np.concatenate(data_life2, axis=0)
print(data_life1.shape, data_life2.shape)
print(order_life1)
print(order_life2)

(27484, 1000) (27484, 1000)
['ses-006_task-life01_run-1', 'ses-007_task-figures01_run-1', 'ses-006_task-life02_run-1', 'ses-007_task-figures02_run-1', 'ses-007_task-figures03_run-1', 'ses-007_task-life03_run-1', 'ses-007_task-life04_run-1', 'ses-009_task-figures04_run-1', 'ses-007_task-life05_run-1', 'ses-008_task-figures05_run-1', 'ses-008_task-figures06_run-1', 'ses-008_task-figures07_run-1', 'ses-008_task-figures08_run-1', 'ses-008_task-figures09_run-1', 'ses-009_task-figures10_run-1', 'ses-009_task-figures11_run-1', 'ses-009_task-figures12_run-1', 'ses-001_task-life01_run-1', 'ses-002_task-figures01_run-1', 'ses-001_task-life02_run-1', 'ses-002_task-figures02_run-1', 'ses-001_task-life03_run-1', 'ses-002_task-figures03_run-1', 'ses-001_task-life04_run-1', 'ses-002_task-figures04_run-1', 'ses-002_task-figures05_run-1', 'ses-002_task-life05_run-1', 'ses-002_task-figures06_run-1', 'ses-003_task-figures07_run-1', 'ses-003_task-figures08_run-1', 'ses-003_task-figures09_run-1', 'ses-003_

In [6]:
masker = load_and_label_atlas(ds.samples[0]["subject_atlas"])
r_retest = voxelwise_pearsonr(data_life1, data_life2)

In [7]:
def collect_predictions_per_sample(loader, model, device):
    model.eval()
    fmri_true, fmri_pred, meta = [], [], []
    sid_map = {v: k for k, v in loader.dataset.subject_name_id_dict.items()}

    with torch.no_grad():
        for batch in loader:            # batch-size == 1
            sid_t  = batch["subject_ids"][0]
            run_id = batch["run_ids"][0]
            dataset_name = batch["dataset_names"][0]

            fmri  = batch["fmri"].to(device)[0]
            attn  = batch["attention_masks"].to(device)[0].bool()
            feats = {k: batch[k].to(device) for k in loader.dataset.modalities}

            pred = model(feats,
                         batch["subject_ids"],
                         batch["run_ids"],
                         batch["attention_masks"].to(device))[0]

            fmri_true.append(fmri[attn].cpu().numpy())
            fmri_pred.append(pred[attn].cpu().numpy())

            sid = sid_map[sid_t]
            atlas_path = loader.dataset.samples[0]["subject_atlas"].format(subject=sid)

            meta.append({"subject": sid,
                         "run": run_id,
                         "atlas_path": atlas_path,
                         "dataset_name": dataset_name})

    return fmri_true, fmri_pred, meta

In [50]:
import wandb
from algonauts.data.loader import get_train_val_loaders
from algonauts.models import load_model_from_ckpt

model, config = load_model_from_ckpt(
                model_ckpt_path="../runs/model1/checkpoints/0ixqflbd/final_model.pt",
                params_path="../runs/model1/checkpoints/0ixqflbd/config.yaml",
                device="cuda",
            )
print("Loaded model")
wandb.init(project="try", config=vars(config), )

model.to("cuda")
config.val_name = "life"
config.filter_name = []
config.batch_size = 1
train_loader, valid_loader = get_train_val_loaders(config)
print(len(train_loader), len(valid_loader))
print("Loaded data")
out_dir="many/Figures"

fmri_true, fmri_pred, meta = collect_predictions_per_sample(
    valid_loader, model, device="cuda"
)

r = voxelwise_pearsonr(fmri_true, fmri_pred)

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

In [None]:
plot_glass_brain(r_retest, "test_retest", masker, filename="test_retest")
plot_glass_brain(r, "model", masker, filename="model")