In [1]:
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm import tqdm
import gc

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from accelerate import Accelerator

from statistics import *

from torchmetrics.regression import PearsonCorrCoef

os.chdir("/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/src")

# SDXL unCLIP requires code from https://github.com/Stability-AI/generative-models/tree/main
sys.path.append('generative_models/')
import sgm
from models import Clipper
from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder # bigG embedder

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

# custom functions #
from utils import *

In [20]:
import numpy

def collection_brain_region_masks(brain_region_masks, file):
    
    
    for subject in file.keys():
        for region in file[subject].keys():
            brain_region_masks[subject][region] = file[subject][region][:]
        
    file.close()
    
    return brain_region_masks


file_kastner = h5py.File('/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset/kastner_rois.hdf5', "r")
file_mtl = h5py.File('/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset/mtl_rois.hdf5', "r")
file_thalamus = h5py.File('/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset/thalamus_rois.hdf5', "r")
file_streams = h5py.File('/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset/streams_rois.hdf5', "r")

brain_region_masks = {}
for subject in range(1,9):
    brain_region_masks[f'subj{subject:02d}'] = {}
    
streams_brain_region_masks = {}
for subject in range(1,9):
    streams_brain_region_masks[f'subj{subject:02d}'] = {}

brain_region_masks = collection_brain_region_masks(brain_region_masks, file_kastner)
brain_region_masks = collection_brain_region_masks(brain_region_masks, file_mtl)
brain_region_masks = collection_brain_region_masks(brain_region_masks, file_thalamus)
streams_brain_region_masks = collection_brain_region_masks(streams_brain_region_masks, file_streams)

for subject in brain_region_masks.keys():

    all_rois_besides_nsd_general = numpy.logical_or(brain_region_masks[subject]['35'], brain_region_masks[subject]['36'])
    
    for roi, mask in brain_region_masks[subject].items():
        if roi != 'nsd_general':
            all_rois_besides_nsd_general = numpy.logical_or(all_rois_besides_nsd_general, mask)
    
    both_nsd_general_and_other_mask = numpy.logical_and(all_rois_besides_nsd_general, brain_region_masks[subject]['nsd_general'])
    non_nsd_general = numpy.logical_xor(brain_region_masks[subject]['nsd_general'], both_nsd_general_and_other_mask)
    brain_region_masks[subject]['unlabled_nsd_general'] = non_nsd_general

for subject in streams_brain_region_masks.keys():
    
    all_rois_besides_nsd_general = numpy.logical_or(brain_region_masks[subject]['35'], brain_region_masks[subject]['36'])
    
    for roi, mask in brain_region_masks[subject].items():
        if roi != 'nsd_general':
            all_rois_besides_nsd_general = numpy.logical_or(all_rois_besides_nsd_general, mask)
            
    all_rois_streams = numpy.logical_or(streams_brain_region_masks[subject]['early'], streams_brain_region_masks[subject]['midventral'])
    
    for roi, mask in streams_brain_region_masks[subject].items():
        if roi == 'lateral' or roi == 'parietal' or roi == 'ventral':
            both_all_rois_and_stream_roi = numpy.logical_and(all_rois_besides_nsd_general, streams_brain_region_masks[subject][roi])
            non_all_rois_and_stream_roi = numpy.logical_xor(streams_brain_region_masks[subject][roi], both_all_rois_and_stream_roi)
            brain_region_masks[subject][roi] = non_all_rois_and_stream_roi
    
    
        

with h5py.File("/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset/roi_collection.hdf5", 'w') as hdf:
    for subject in range(1, 9):
        subject_str = f"subj{subject:02d}"
        mask_dictionary = brain_region_masks[subject_str]

        # Create a group for each subject
        subject_group = hdf.create_group(subject_str)

        # Iterate through the mask dictionary and create datasets for each mask
        for region, mask in mask_dictionary.items():
            if region != 'nsd_general':
                subject_group.create_dataset(region, data=mask)

In [21]:
import numpy 

brain_region_masks = {}
with h5py.File("/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset/roi_collection.hdf5", "r") as file:
    # Iterate over each subject
    for subject in file.keys():
        subject_group = file[subject]
        subject_masks = {}
        # Load the masks data for each subject
        for region in subject_group.keys():
            subject_masks[region] = subject_group[region][:]
        brain_region_masks[subject] = subject_masks
        
subject_masks = brain_region_masks[f"subj01"]

logic_or_mask = numpy.logical_or(subject_masks['35'], subject_masks['36'])
for roi, mask in subject_masks.items():
    print(f"ROI: {roi}, mask shape: {mask.shape}, True count: {np.sum(mask)}")
    if roi != 'nsd_general':
        logic_or_mask = numpy.logical_or(logic_or_mask, mask)
    
# logic_or_mask = numpy.logical_not(logic_or_mask, subject_masks['nsd_general'])

ROI: 35, mask shape: (238508,), True count: 241
ROI: 36, mask shape: (238508,), True count: 652
ROI: CA1, mask shape: (238508,), True count: 242
ROI: CA2, mask shape: (238508,), True count: 17
ROI: CA3, mask shape: (238508,), True count: 32
ROI: DG, mask shape: (238508,), True count: 254
ROI: ERC, mask shape: (238508,), True count: 343
ROI: FEF, mask shape: (238508,), True count: 72
ROI: HT, mask shape: (238508,), True count: 249
ROI: IPS0, mask shape: (238508,), True count: 606
ROI: IPS1, mask shape: (238508,), True count: 462
ROI: IPS2, mask shape: (238508,), True count: 500
ROI: IPS3, mask shape: (238508,), True count: 508
ROI: IPS4, mask shape: (238508,), True count: 65
ROI: IPS5, mask shape: (238508,), True count: 14
ROI: LGN, mask shape: (238508,), True count: 55
ROI: LO1, mask shape: (238508,), True count: 343
ROI: LO2, mask shape: (238508,), True count: 193
ROI: PHC, mask shape: (238508,), True count: 221
ROI: PHC1, mask shape: (238508,), True count: 199
ROI: PHC2, mask shape: 

In [15]:
print(np.sum(numpy.logical_and(logic_or_mask, subject_masks['nsd_general'])))

9136


In [None]:
def test(holdout_subject=1, top_n_rois=-1, data_path="../dataset/"):
    
    # Initialize an empty list to store the dataset names
    dataset_names = []

    with h5py.File(f'{data_path}/kastner_rois.hdf5', 'r') as file:
        # Function to recursively collect dataset names
        def collect_names(name, obj):
            if isinstance(obj, h5py.Dataset):
                dataset_names.append(name)

        # Iterate through the file structure and collect dataset names
        file.visititems(collect_names)
        
    
    
    
    
    with h5py.File(f'{data_path}/kastner_rois.hdf5', 'r') as file:
        roi = f['betas'][:]
        betas = torch.from_numpy(betas).to("cpu")
        
        beta_file = f"{data_path}/preprocessed_data/subject{subject}/whole_brain_include_heldout.pt"
        x = torch.load(beta_file).requires_grad_(False).to("cpu")
        
    # Function to recursively print the structure of the file
    def print_structure(name, obj):
        if isinstance(obj, h5py.Group):
            print(f"Group: {name}")
        elif isinstance(obj, h5py.Dataset):
            print(f"Dataset: {name}, shape: {obj.shape}, dtype: {obj.dtype}")

    # Iterate through the file structure
    file.visititems(print_structure)
    
    # betas = file['betas'][:]
    # betas = torch.from_numpy(betas).to("cpu")

In [None]:
test(holdout_subject=1, top_n_rois=-1, data_path="/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset")

In [3]:
with h5py.File('/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset/kastner_rois.hdf5', 'r') as file:
    # Function to recursively print the structure of the file
    def print_structure(name, obj):
        if isinstance(obj, h5py.Group):
            print(f"Group: {name}")
        elif isinstance(obj, h5py.Dataset):
            print(f"Dataset: {name}, shape: {obj.shape}, dtype: {obj.dtype}")

    # Iterate through the file structure
    file.visititems(print_structure)
    
    # betas = file['betas'][:]
    # betas = torch.from_numpy(betas).to("cpu")

Group: subj01
Dataset: subj01/FEF, shape: (238508,), dtype: bool
Dataset: subj01/IPS0, shape: (238508,), dtype: bool
Dataset: subj01/IPS1, shape: (238508,), dtype: bool
Dataset: subj01/IPS2, shape: (238508,), dtype: bool
Dataset: subj01/IPS3, shape: (238508,), dtype: bool
Dataset: subj01/IPS4, shape: (238508,), dtype: bool
Dataset: subj01/IPS5, shape: (238508,), dtype: bool
Dataset: subj01/LO1, shape: (238508,), dtype: bool
Dataset: subj01/LO2, shape: (238508,), dtype: bool
Dataset: subj01/PHC1, shape: (238508,), dtype: bool
Dataset: subj01/PHC2, shape: (238508,), dtype: bool
Dataset: subj01/SPL1, shape: (238508,), dtype: bool
Dataset: subj01/TO1, shape: (238508,), dtype: bool
Dataset: subj01/TO2, shape: (238508,), dtype: bool
Dataset: subj01/V1d, shape: (238508,), dtype: bool
Dataset: subj01/V1v, shape: (238508,), dtype: bool
Dataset: subj01/V2d, shape: (238508,), dtype: bool
Dataset: subj01/V2v, shape: (238508,), dtype: bool
Dataset: subj01/V3A, shape: (238508,), dtype: bool
Dataset:

In [11]:

with h5py.File('/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset/kastner_rois.hdf5', 'r') as file:
    for roi in file['subj01'].keys():
        print(roi)
        for subject in file.keys():
            print(subject)
            mask = file[subject][roi]

FEF
subj01
subj02
subj03
subj04
subj05
subj06
subj07
subj08
IPS0
subj01
subj02
subj03
subj04
subj05
subj06
subj07
subj08
IPS1
subj01
subj02
subj03
subj04
subj05
subj06
subj07
subj08
IPS2
subj01
subj02
subj03
subj04
subj05
subj06
subj07
subj08
IPS3
subj01
subj02
subj03
subj04
subj05
subj06
subj07
subj08
IPS4
subj01
subj02
subj03
subj04
subj05
subj06
subj07
subj08
IPS5
subj01
subj02
subj03
subj04
subj05
subj06
subj07
subj08
LO1
subj01
subj02
subj03
subj04
subj05
subj06
subj07
subj08
LO2
subj01
subj02
subj03
subj04
subj05
subj06
subj07
subj08
PHC1
subj01
subj02
subj03
subj04
subj05
subj06
subj07
subj08
PHC2
subj01
subj02
subj03
subj04
subj05
subj06
subj07
subj08
SPL1
subj01
subj02
subj03
subj04
subj05
subj06
subj07
subj08
TO1
subj01
subj02
subj03
subj04
subj05
subj06
subj07
subj08
TO2
subj01
subj02
subj03
subj04
subj05
subj06
subj07
subj08
V1d
subj01
subj02
subj03
subj04
subj05
subj06
subj07
subj08
V1v
subj01
subj02
subj03
subj04
subj05
subj06
subj07
subj08
V2d
subj01
subj02
subj03
subj04

In [12]:
brain_region_masks = {}
with h5py.File('/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset/kastner_rois.hdf5', "r") as file:
    # Iterate over each subject
    for subject in file.keys():
        subject_group = file[subject]
        subject_masks = {}
        # Load the masks data for each subject
        for region in subject_group.keys():
            subject_masks[region] = subject_group[region][:]
        brain_region_masks[subject] = subject_masks
        
subject_masks = brain_region_masks[f"subj01"]

In [2]:
# Compute Pearson correlation along the 18 trials (axis 0) for each of the 72 elements
def pearson_corr(x, y):
    # Mean of each row (across trials)
    mean_x = torch.mean(x, dim=0)
    mean_y = torch.mean(y, dim=0)
    
    # Covariance numerator
    cov = torch.sum((x - mean_x) * (y - mean_y), dim=0)
    
    # Standard deviations
    std_x = torch.sqrt(torch.sum((x - mean_x) ** 2, dim=0))
    std_y = torch.sqrt(torch.sum((y - mean_y) ** 2, dim=0))
    
    # Pearson correlation
    corr = cov / (std_x * std_y)
    
    return corr

In [22]:
brain_region_masks = {}
with h5py.File('/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset/kastner_rois.hdf5', "r") as file:
    
    # Iterate over each subject
    for subject in file.keys():
        subject_group = file[subject]
        subject_masks = {}
        # Load the masks data for each subject
        for region in subject_group.keys():
            subject_masks[region] = subject_group[region][:]
        brain_region_masks[subject] = subject_masks
        
    subjects = [f'subj0{i}' for i in range(2, 9)]

    # Initialize an empty dictionary to store ROI correlations
    roi_correlations = {}

    # Loop through each ROI for correlation analysis, with tqdm progress bar
    for roi in tqdm(file['subj01'].keys(), desc="Processing ROIs"):
        
        roi_data = []  # To store ROI data across all subjects
        print(f"\nProcessing ROI: {roi}")
            
        # Load the ROI data for each subject
        for subject in subjects:
            print(f"  Loading subject: {subject}")
                
            # Load the beta file for the subject
            beta_vision, _ = load_nsd_mental_imagery(subject=subject[-1:], mode='vision', stimtype="all", average=False, nest=False, whole_brain=True)
            beta_imagery, _ = load_nsd_mental_imagery(subject=subject[-1:], mode='imagery', stimtype="all", average=False, nest=False, whole_brain=True)
            
            print(f"    Shape of beta_vision: {beta_vision.shape}")
            print(f"    Shape of beta_imagery: {beta_imagery.shape}")
            
            # Load the boolean mask for the current ROI
            roi_mask = file[subject][roi]
            # nsd_general = np.array(file[subject]['nsd_general'])
            # print(type(nsd_general))
            # print(nsd_general)
            # print(f"  Loaded ROI mask for {subject}{roi}, mask shape: {roi_mask.shape}, True count: {np.sum(roi_mask)}")
            # print(f"  Loaded ROI mask for {subject}{roi}, mask shape: {nsd_general.shape}, True count: {np.sum(nsd_general)}")
            # roi_mask = roi_mask[nsd_general]
            print(f"  Loaded ROI mask for {subject}{roi}, mask shape: {roi_mask.shape}, True count: {np.sum(roi_mask)}")
            
            # Apply the ROI mask to the subject data (along the second dimension)
            masked_data_vision = beta_vision[..., roi_mask]
            masked_data_imagery = beta_imagery[..., roi_mask]
            print(f"    Shape of masked data vision: {masked_data_vision.shape}")
            print(f"    Shape of masked data imagery: {masked_data_imagery.shape}")
            
            # Remove the singleton dimension (1) to reduce the shape.
            masked_data_vision = masked_data_vision.squeeze(1)
            masked_data_imagery = masked_data_imagery.squeeze(1)
            
            # Compute the Pearson correlation for the two tensors
            #correlation = pearson_corr(masked_data_vision, masked_data_imagery)

            # Append the single pearson correlation value to 
            # the subject correlation tensor to the list
            pearson = PearsonCorrCoef(num_outputs=masked_data_vision.shape[-1])
            subject_specific_correlation_value = torch.mean(pearson(masked_data_vision, masked_data_imagery))
            print(f"    Subject specific correlation value: {subject_specific_correlation_value}")
            
            roi_data.append(subject_specific_correlation_value)
            
        # Stack the tensors along a new dimension and calculate the mean
        mean_roi_pearson_correlation_across_subjects = torch.mean(torch.stack(roi_data), dim=0)
        print(f"    ROI: {roi} Mean Pearson Correlation Across Subjects: {mean_roi_pearson_correlation_across_subjects}")
        
        # Store the mean correlation for the ROI
        # .item() converts the torch object to a float value.
        roi_correlations[roi] = mean_roi_pearson_correlation_across_subjects.item()
                
    # Sort ROIs by mean Pearson correlation
    sorted_rois = sorted(roi_correlations.items(), key=lambda x: x[1], reverse=True)

    # Print the ROIs ranked by correlation
    print("\nROIs ranked by Pearson correlation across subjects:")
    for roi, corr in sorted_rois:
        print(f"ROI: {roi}, Mean Pearson Correlation: {corr:.4f}")

Processing ROIs:   0%|          | 0/26 [00:00<?, ?it/s]


Processing ROI: FEF
  Loading subject: subj02
torch.Size([18, 1, 242606]) torch.Size([18, 3, 425, 425])
torch.Size([18, 1, 242606]) torch.Size([18, 3, 425, 425])
    Shape of beta_vision: torch.Size([18, 1, 242606])
    Shape of beta_imagery: torch.Size([18, 1, 242606])
  Loaded ROI mask for subj02FEF, mask shape: (242606,), True count: 97


  masked_data_vision = beta_vision[..., roi_mask]
  masked_data_imagery = beta_imagery[..., roi_mask]


    Shape of masked data vision: torch.Size([18, 1, 97])
    Shape of masked data imagery: torch.Size([18, 1, 97])
    Subject specific correlation value: -0.04663512110710144
  Loading subject: subj03
torch.Size([18, 1, 246730]) torch.Size([18, 3, 425, 425])
torch.Size([18, 1, 246730]) torch.Size([18, 3, 425, 425])
    Shape of beta_vision: torch.Size([18, 1, 246730])
    Shape of beta_imagery: torch.Size([18, 1, 246730])
  Loaded ROI mask for subj03FEF, mask shape: (246730,), True count: 74
    Shape of masked data vision: torch.Size([18, 1, 74])
    Shape of masked data imagery: torch.Size([18, 1, 74])
    Subject specific correlation value: 0.09989674389362335
  Loading subject: subj04
torch.Size([18, 1, 229642]) torch.Size([18, 3, 425, 425])
torch.Size([18, 1, 229642]) torch.Size([18, 3, 425, 425])
    Shape of beta_vision: torch.Size([18, 1, 229642])
    Shape of beta_imagery: torch.Size([18, 1, 229642])
  Loaded ROI mask for subj04FEF, mask shape: (229642,), True count: 107


In [13]:
print(subject_masks)

{'FEF': array([False, False, False, ..., False, False, False]), 'IPS0': array([False, False, False, ..., False, False, False]), 'IPS1': array([False, False, False, ..., False, False, False]), 'IPS2': array([False, False, False, ..., False, False, False]), 'IPS3': array([False, False, False, ..., False, False, False]), 'IPS4': array([False, False, False, ..., False, False, False]), 'IPS5': array([False, False, False, ..., False, False, False]), 'LO1': array([False, False, False, ..., False, False, False]), 'LO2': array([False, False, False, ..., False, False, False]), 'PHC1': array([False, False, False, ..., False, False, False]), 'PHC2': array([False, False, False, ..., False, False, False]), 'SPL1': array([False, False, False, ..., False, False, False]), 'TO1': array([False, False, False, ..., False, False, False]), 'TO2': array([False, False, False, ..., False, False, False]), 'V1d': array([False, False, False, ..., False, False, False]), 'V1v': array([False, False, False, ..., False

In [2]:
# Initialize an empty list to store the dataset names
dataset_names = []

with h5py.File('/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset/kastner_rois.hdf5', 'r') as file:
    # Function to recursively collect dataset names
    def collect_names(name, obj):
        if isinstance(obj, h5py.Dataset):
            dataset_names.append(name)

    # Iterate through the file structure and collect dataset names
    file.visititems(collect_names)

# Print the list of dataset names
print(dataset_names)

['subj01/FEF', 'subj01/IPS0', 'subj01/IPS1', 'subj01/IPS2', 'subj01/IPS3', 'subj01/IPS4', 'subj01/IPS5', 'subj01/LO1', 'subj01/LO2', 'subj01/PHC1', 'subj01/PHC2', 'subj01/SPL1', 'subj01/TO1', 'subj01/TO2', 'subj01/V1d', 'subj01/V1v', 'subj01/V2d', 'subj01/V2v', 'subj01/V3A', 'subj01/V3B', 'subj01/V3d', 'subj01/V3v', 'subj01/VO1', 'subj01/VO2', 'subj01/hV4', 'subj01/nsd_general', 'subj02/FEF', 'subj02/IPS0', 'subj02/IPS1', 'subj02/IPS2', 'subj02/IPS3', 'subj02/IPS4', 'subj02/IPS5', 'subj02/LO1', 'subj02/LO2', 'subj02/PHC1', 'subj02/PHC2', 'subj02/SPL1', 'subj02/TO1', 'subj02/TO2', 'subj02/V1d', 'subj02/V1v', 'subj02/V2d', 'subj02/V2v', 'subj02/V3A', 'subj02/V3B', 'subj02/V3d', 'subj02/V3v', 'subj02/VO1', 'subj02/VO2', 'subj02/hV4', 'subj02/nsd_general', 'subj03/FEF', 'subj03/IPS0', 'subj03/IPS1', 'subj03/IPS2', 'subj03/IPS3', 'subj03/IPS4', 'subj03/IPS5', 'subj03/LO1', 'subj03/LO2', 'subj03/PHC1', 'subj03/PHC2', 'subj03/SPL1', 'subj03/TO1', 'subj03/TO2', 'subj03/V1d', 'subj03/V1v', 'sub

In [9]:
import h5py
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm  # Import tqdm for progress bar

# Function to calculate Pearson correlation between subjects' ROIs
def pearson_correlation_across_subjects(roi_data):
    n_subjects = len(roi_data)
    
    # Collect the flattened ROI data for each subject
    roi_data_flattened = [data.flatten() for data in roi_data]
    
    # Calculate pairwise Pearson correlation between subjects
    corr_matrix = np.corrcoef(roi_data_flattened)
    
    # Return only the upper triangular part of the matrix (since it's symmetric)
    return corr_matrix

# Path to data and HDF5 file
hdf5_file = '/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset/kastner_rois.hdf5'
data_path = '/home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset/preprocessed_data'

# Load the HDF5 file and process the data
with h5py.File(hdf5_file, 'r') as file:
    subjects = [f'subj0{i}' for i in range(2, 9)]  # Exclude subject 01, load subj02 to subj08
    rois = [f'/FEF', '/IPS0', '/IPS1', '/IPS2', '/IPS3', '/IPS4', '/IPS5', 
            '/LO1', '/LO2', '/PHC1', '/PHC2', '/SPL1', '/TO1', '/TO2', 
            '/V1d', '/V1v', '/V2d', '/V2v', '/V3A', '/V3B', '/V3d', '/V3v', 
            '/VO1', '/VO2', '/hV4', '/nsd_general']

    # Initialize an empty dictionary to store ROI correlations
    roi_correlations = {}

    # Loop through each ROI for correlation analysis, with tqdm progress bar
    for roi in tqdm(rois, desc="Processing ROIs"):
        roi_data = []  # To store ROI data across all subjects
        print(f"\nProcessing ROI: {roi}")
        
        # Load the ROI data for each subject
        for subject in subjects:
            print(f"  Loading subject: {subject}")
            
            # Load the beta file for the subject
            beta_file = f"{data_path}/subject{subject[-1:]}/whole_brain_include_heldout.pt"
            print(f"  Loading beta file: {beta_file}")
            x = torch.load(beta_file).requires_grad_(False).to("cpu")
            print(f"    Shape of x: {x.shape}")
            
            # Load the boolean mask for the current ROI
            roi_mask = file[f'{subject}{roi}'][:]
            print(f"  Loaded ROI mask for {subject}{roi}, mask shape: {roi_mask.shape}, True count: {np.sum(roi_mask)}")
            
            # Apply the ROI mask to the subject data (along the second dimension)
            masked_data = x[:, roi_mask]
            print(f"    Shape of masked data: {masked_data.shape}")
            
            # Append masked data for this subject
            roi_data.append(masked_data.numpy())  # Convert tensor to numpy
        
        print(f"  Collected masked data for {len(roi_data)} subjects")
        
        # Calculate the Pearson correlation across subjects for this ROI
        print(f"  Calculating Pearson correlation for ROI: {roi}")
        corr_matrix = pearson_correlation_across_subjects(roi_data)
        
        # Calculate the mean correlation for this ROI (for ranking)
        mean_corr = np.mean(corr_matrix[np.triu_indices_from(corr_matrix, k=1)])
        print(f"  Mean Pearson correlation for {roi}: {mean_corr:.4f}")
        
        # Store the mean correlation for the ROI
        roi_correlations[roi] = mean_corr

    # Sort ROIs by mean Pearson correlation
    sorted_rois = sorted(roi_correlations.items(), key=lambda x: x[1], reverse=True)

    # Print the ROIs ranked by correlation
    print("\nROIs ranked by Pearson correlation across subjects:")
    for roi, corr in sorted_rois:
        print(f"ROI: {roi}, Mean Pearson Correlation: {corr:.4f}")




Processing ROIs:   0%|          | 0/26 [00:00<?, ?it/s]


Processing ROI: /FEF
  Loading subject: subj02
  Loading beta file: /home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset/preprocessed_data/subject2/whole_brain_include_heldout.pt
    Shape of x: torch.Size([30000, 242606])
  Loaded ROI mask for subj02/FEF, mask shape: (242606,), True count: 97
    Shape of masked data: torch.Size([30000, 97])
  Loading subject: subj03
  Loading beta file: /home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset/preprocessed_data/subject3/whole_brain_include_heldout.pt
    Shape of x: torch.Size([24000, 246730])
  Loaded ROI mask for subj03/FEF, mask shape: (246730,), True count: 74
    Shape of masked data: torch.Size([24000, 74])
  Loading subject: subj04
  Loading beta file: /home/naxos2-raid25/ojeda040/local/ojeda040/MindEye_Imagery/dataset/preprocessed_data/subject4/whole_brain_include_heldout.pt
    Shape of x: torch.Size([22500, 229642])
  Loaded ROI mask for subj04/FEF, mask shape: (229642,), True count: 107
    Shape 

Processing ROIs:   0%|          | 0/26 [02:33<?, ?it/s]

    Shape of x: torch.Size([22500, 234961])
  Loaded ROI mask for subj08/FEF, mask shape: (234961,), True count: 93
    Shape of masked data: torch.Size([22500, 93])
  Collected masked data for 7 subjects
  Calculating Pearson correlation for ROI: /FEF





ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (7,) + inhomogeneous part.