In [1]:
import nibabel as nib
import pandas as pd
import numpy as np
import os
import glob
import scipy.io as sio
from scipy.spatial.distance import cdist
from scipy.stats import pearsonr
import torch
from torch.utils.data import TensorDataset, DataLoader
from huggingface_hub import hf_hub_download
from fMRIVAE_Model import BetaVAE
import matplotlib.pyplot as plt
from tqdm import tqdm
# import subprocess
# from sklearn.metrics import silhouette_score, davies_bouldin_score, pairwise_distances, silhouette_samples
# from sklearn.svm import SVC
# from sklearn.model_selection import train_test_split
# from sklearn.preprocessing import StandardScaler
# from sklearn.model_selection import StratifiedKFold
# from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score
# import seaborn as sns


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_dtseries_and_tmask(subject_id, dtseries_folder, tmask_folder):
    """
    Loads the four resting-state dtseries runs and their tmask files for a
    subject and concatenates them along time in a fixed run order
    (REST1_LR, REST1_RL, REST2_LR, REST2_RL).

    Parameters
    ----------
    subject_id : str
        e.g., "996782"
    dtseries_folder : str
        Path to folder containing the
        "<subject_id>_rfMRI_<run>_surf_subcort_normalwall.dtseries.nii" files.
    tmask_folder : str
        Path to folder containing the "<subject_id>_rfMRI_<run>_NEW_TMASK.txt"
        files (one integer per line; non-zero becomes True). Blank lines are
        ignored.

    Returns
    -------
    concatenated_data : ndarray (n_vertices, total_timepoints)
        Concatenated vertex-level time series (no mask applied).
    concatenated_tmask : ndarray (total_timepoints,)
        Boolean array indicating which timepoints are valid.
    vertex_indices : ndarray
        Column indices of the left/right cortex brain models, taken from the
        header of the first run and reused for every run.

    Raises
    ------
    FileNotFoundError
        If a tmask file cannot be opened (a missing dtseries file is expected
        to surface the same way from nibabel -- not verified here).
    AssertionError
        If a run's tmask length differs from that run's number of timepoints.
    """
    run_order = ['REST1_LR', 'REST1_RL', 'REST2_LR', 'REST2_RL']
    cortex_structures = ('CIFTI_STRUCTURE_CORTEX_LEFT', 'CIFTI_STRUCTURE_CORTEX_RIGHT')
    all_data = []
    all_tmask = []
    vertex_indices = None

    for run in run_order:
        # Build filenames
        dtseries_path = os.path.join(dtseries_folder, f"{subject_id}_rfMRI_{run}_surf_subcort_normalwall.dtseries.nii")
        tmask_path = os.path.join(tmask_folder, f"{subject_id}_rfMRI_{run}_NEW_TMASK.txt")

        # Read the (small) tmask first so a missing mask fails before the large
        # dtseries is loaded. Blank/trailing lines are skipped instead of
        # crashing int('').
        with open(tmask_path, 'r') as f:
            flags = [int(token) for token in (line.strip() for line in f) if token]
        tmask = np.array(flags, dtype=bool)

        # Load dtseries; rows are indexed as timepoints, columns as grayordinates
        img = nib.load(dtseries_path)
        data = img.get_fdata()

        # The concatenated mask is later used to index the time axis, so each
        # run's mask must match that run's length.
        assert tmask.shape[0] == data.shape[0], \
            f"Mismatch in dtseries ({data.shape[0]}) and tmask ({tmask.shape[0]}) length for run {run}"

        if vertex_indices is None:
            # Collect the column range of each cortical brain model (first run only)
            brain_models = img.header.get_index_map(1).brain_models
            vertex_indices = np.array(
                [
                    column
                    for bm in brain_models
                    if bm.brain_structure in cortex_structures
                    for column in range(bm.index_offset, bm.index_offset + bm.index_count)
                ],
                dtype=np.intp,  # stays a valid index array even when empty
            )

        all_data.append(data[:, vertex_indices].T)  # (n_vertices, time)
        all_tmask.append(tmask)

    concatenated_data = np.concatenate(all_data, axis=1)
    concatenated_tmask = np.concatenate(all_tmask)

    return concatenated_data, concatenated_tmask, vertex_indices


In [3]:
def load_masked_ptseries(subject_id, ptseries_folder, parcellation, tmask):
    """
    Loads the ptseries file for a subject and applies reordering and tmask.
    Assumes ptseries is already concatenated for REST1 and REST2.

    Parameters
    ----------
    subject_id : str
        Used to build the file name "<subject_id>.Rest12.ptseries.nii".
    ptseries_folder : str
        Folder containing the ptseries file.
    parcellation : dict
        Mapping (or loadmat struct) whose ``["order"][0, 0]`` entry holds the
        1-based parcel order to apply to the ptseries columns.
    tmask : np.ndarray
        Per-timepoint keep mask; coerced to bool, so 0/1 integers are accepted.

    Returns
    -------
    ptseries_masked : np.ndarray (n_parcels, time)
        Reordered parcel time series restricted to the kept timepoints.

    Raises
    ------
    AssertionError
        If the number of ptseries timepoints differs from the tmask length.
    """
    pt_file = os.path.join(ptseries_folder, f"{subject_id}.Rest12.ptseries.nii")
    pt_data = nib.load(pt_file).get_fdata()

    # Force boolean masking: an integer 0/1 array would otherwise be interpreted
    # as row indices and silently return rows 0 and 1 repeated.
    keep = np.asarray(tmask, dtype=bool)

    assert pt_data.shape[0] == keep.shape[0], \
        f"Mismatch in ptseries ({pt_data.shape[0]}) and tmask ({keep.shape[0]}) length"

    # Convert the 1-based order to 0-based indices. Cast to an integer type
    # first, since values loaded from .mat files may be stored as floats, which
    # NumPy rejects as indices.
    parcel_order = np.asarray(parcellation["order"][0, 0]).flatten().astype(np.intp) - 1
    pt_ordered = pt_data[:, parcel_order]
    return pt_ordered[keep].T  # shape: (n_parcels, time)


In [4]:
def compute_parcel_vertex_correlation(ptseries_masked, cortex_masked):
    """
    Computes Pearson correlation between parcel and vertex time series.

    Only the parcel-by-vertex block is computed. Stacking both inputs and
    calling np.corrcoef would also build the vertex-by-vertex block, which for
    tens of thousands of vertices costs tens of gigabytes and is then thrown
    away.

    Parameters
    ----------
    ptseries_masked : ndarray (n_parcels, n_timepoints)
    cortex_masked : ndarray (n_vertices, n_timepoints)

    Returns
    -------
    corr_matrix : ndarray (n_parcels, n_vertices)
        Pearson r for every (parcel, vertex) pair, clipped to [-1, 1].
        Rows/columns belonging to a constant time series are NaN (0/0), as
        with np.corrcoef.

    Raises
    ------
    ValueError
        If the two inputs do not have the same number of timepoints.
    """
    parcels = np.atleast_2d(np.asarray(ptseries_masked, dtype=np.float64))
    vertices = np.atleast_2d(np.asarray(cortex_masked, dtype=np.float64))
    if parcels.shape[1] != vertices.shape[1]:
        raise ValueError(
            f"Time dimension mismatch: parcels have {parcels.shape[1]} timepoints, "
            f"vertices have {vertices.shape[1]}"
        )

    # Remove each series' mean; r is then the cosine between centered rows.
    parcels_centered = parcels - parcels.mean(axis=1, keepdims=True)
    vertices_centered = vertices - vertices.mean(axis=1, keepdims=True)

    # Row-wise L2 norms without materializing a squared copy of the inputs.
    parcel_norms = np.sqrt(np.einsum('ij,ij->i', parcels_centered, parcels_centered))
    vertex_norms = np.sqrt(np.einsum('ij,ij->i', vertices_centered, vertices_centered))

    # (n_parcels, T) @ (T, n_vertices): the only block that is actually needed.
    corr_matrix = parcels_centered @ vertices_centered.T
    corr_matrix /= np.outer(parcel_norms, vertex_norms)

    # Guard against |r| creeping past 1 through floating-point rounding.
    np.clip(corr_matrix, -1.0, 1.0, out=corr_matrix)
    return corr_matrix


In [5]:
def forward_reformatting(corrs, transmat_path, img_size):
    """
    Projects vertex-level data onto 2D cortical grids for left and right hemispheres.

    This function:
    - Loads the hemisphere-specific 'grid_mapping' matrices for the requested grid size
    - Splits the input columns into left and right hemispheres (left first)
    - Applies each mapping and reshapes the result into image format

    Parameters
    ----------
    corrs : ndarray of shape (num_samples, num_features)
        Vertex-level input data (e.g., correlation values or BOLD signals).
        num_features must equal the number of columns of the left mapping plus
        the number of columns of the right mapping (59412 = 29696 left +
        29716 right for the standard 192x192 files, per the original code).

    transmat_path : str
        Directory containing 'Left_fMRI2Grid_<img_size>_by_<img_size>_NN.mat' and
        'Right_fMRI2Grid_<img_size>_by_<img_size>_NN.mat', each with a
        'grid_mapping' matrix.

    img_size : int
        The side length of the 2D grid (e.g., 192 for a 192x192 projection).
        It selects the mapping files and the output image shape.

    Returns
    -------
    left_surf_data : ndarray of shape (num_samples, 1, img_size, img_size)
        2D grid representation of left hemisphere data.

    right_surf_data : ndarray of shape (num_samples, 1, img_size, img_size)
        2D grid representation of right hemisphere data.

    Raises
    ------
    AssertionError
        If num_features does not match the vertex counts of the mapping files.

    Notes
    -----
    - The transformation matrices are assumed to be of shape
      (img_size*img_size, num_vertices).
    - The flat projection is reshaped with NumPy's default row-major (C) order.
      A column-major (order='F') variant existed in earlier revisions but is
      not what this function does.
    - Output is in NCHW layout with a single channel.
    """
    num_samples, num_vertices = corrs.shape

    # Load transformation matrices for the requested grid size
    grid_tag = f"{img_size}_by_{img_size}"
    left_transmat = sio.loadmat(os.path.join(transmat_path, f"Left_fMRI2Grid_{grid_tag}_NN.mat"))['grid_mapping']
    right_transmat = sio.loadmat(os.path.join(transmat_path, f"Right_fMRI2Grid_{grid_tag}_NN.mat"))['grid_mapping']

    # The hemisphere split is dictated by the mappings themselves: each
    # projection consumes exactly as many columns as its matrix has.
    n_left = left_transmat.shape[1]
    n_right = right_transmat.shape[1]
    assert num_vertices == n_left + n_right, \
        f"Expected {n_left + n_right} cortical vertices ({n_left} left + {n_right} right)."

    # Split input features into left and right hemispheres
    left_data = corrs[:, :n_left]   # shape: (num_samples, n_left)
    right_data = corrs[:, n_left:]  # shape: (num_samples, n_right)

    # Project data onto 2D grid
    left_proj = left_data @ left_transmat.T    # shape: (num_samples, img_size * img_size)
    right_proj = right_data @ right_transmat.T # shape: (num_samples, img_size * img_size)

    # Reshape to (num_samples, 1, img_size, img_size) in row-major order
    image_shape = (num_samples, 1, img_size, img_size)
    left_surf_data = np.reshape(left_proj, image_shape)
    right_surf_data = np.reshape(right_proj, image_shape)

    return left_surf_data, right_surf_data

def backword_reformatting(left_surf_recon, right_surf_recon, transmat_path):
    """
    Maps reconstructed left/right grid images back to vertex space.

    Each hemisphere's images are flattened to (batch, -1), transposed, and
    left-multiplied by that hemisphere's 'inverse_transformation' matrix read
    from the fMRI2Grid .mat file. The two results are stacked with the left
    hemisphere rows first.

    Parameters
    ----------
    left_surf_recon, right_surf_recon : ndarray
        Reconstructed grids with identical shapes; axis 0 is the batch axis.
    transmat_path : str
        Directory containing 'Left_fMRI2Grid_192_by_192_NN.mat' and
        'Right_fMRI2Grid_192_by_192_NN.mat'.

    Returns
    -------
    dtseries_recon : ndarray of shape (n_left_rows + n_right_rows, batch)
        Stacked back-projection. Entries that are exactly 0 are overwritten
        with 1.
        NOTE(review): the reason for the 0 -> 1 replacement is not visible in
        this file -- confirm it is intended before relying on it.
    """
    assert left_surf_recon.shape == right_surf_recon.shape
    n_samples = left_surf_recon.shape[0]

    # Read both mapping files up front (left, then right)
    hemi_files = ("Left_fMRI2Grid_192_by_192_NN.mat", "Right_fMRI2Grid_192_by_192_NN.mat")
    hemi_mats = [sio.loadmat(os.path.join(transmat_path, fname)) for fname in hemi_files]
    inverse_maps = [mat['inverse_transformation'] for mat in hemi_mats]

    # Back-project each hemisphere: inverse @ (flattened images, one column per sample)
    hemi_recons = (left_surf_recon, right_surf_recon)
    per_hemi = [
        inverse @ recon.reshape(n_samples, -1).T
        for inverse, recon in zip(inverse_maps, hemi_recons)
    ]

    dtseries_recon = np.vstack(per_hemi)
    exact_zero = dtseries_recon == 0
    dtseries_recon[exact_zero] = 1
    return dtseries_recon

In [6]:
def load_model(zdim, nc, device):
    """
    Downloads the pretrained checkpoint for the requested latent dimension
    from the Hugging Face Hub and returns a BetaVAE ready for inference.

    Parameters
    ----------
    zdim : int
        Latent dimensionality; must be 2, 3 or 4 (one checkpoint exists per value).
    nc : int
        Channel count forwarded to ``BetaVAE(nc=...)``.
    device : str or torch.device
        Device the returned model is moved to.

    Returns
    -------
    BetaVAE in eval mode, with the checkpoint's 'state_dict' loaded, on `device`.

    Raises
    ------
    ValueError
        If `zdim` is not one of the supported values.
    """
    repo_id = "cindyhfls/fcMRI-VAE"
    # (latent dimension, checkpoint file inside the repo)
    checkpoints = (
        (2, "Checkpoint/checkpoint49_2024-03-28_Zdim_2_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar"),
        (3, "Checkpoint/checkpoint49_2024-11-28_Zdim_3_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar"),
        (4, "Checkpoint/checkpoint49_2024-06-21_Zdim_4_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar"),
    )

    filename = None
    for supported_zdim, checkpoint_name in checkpoints:
        if zdim == supported_zdim:
            filename = checkpoint_name
            break
    if filename is None:
        raise ValueError("Invalid latent dimension. Please choose among 2, 3 and 4. ")

    local_file = hf_hub_download(repo_id=repo_id, filename=filename)
    # NOTE(review): torch.load unpickles the downloaded file; only use with a
    # trusted repository (consider weights_only=True if the checkpoint allows it).
    # Weights are first materialized on CPU, then the model is moved to `device`.
    saved = torch.load(local_file, map_location="cpu")

    vae = BetaVAE(z_dim=zdim, nc=nc)
    vae.load_state_dict(saved['state_dict'])
    vae.eval()
    return vae.to(device)

In [7]:
def model_inference(left_surf_data, right_surf_data, zdim, nc, mode, batch_size, device, model=None):
    """
    Run inference on left and right surface data using a VAE model.

    Parameters:
        left_surf_data (np.ndarray or torch.Tensor): Input of shape (N, C, H, W).
            NumPy input is converted to float32 tensors; tensors are used as given.
        right_surf_data (np.ndarray or torch.Tensor): Same shape as left_surf_data
        zdim (int): Dimensionality of latent space; the first `zdim` columns of the
            encoder output are treated as the latent means.
        nc (int): Number of input channels (only used when the model is loaded here)
        mode (str): "encode" for latent output, "both" for latent and reconstruction
        batch_size (int): Inference batch size
        device (torch.device): Target device for model and data
        model (optional): An already-loaded model exposing `_encode(xL, xR)` and
            `_decode(z)`. When omitted, `load_model(zdim, nc, device)` is called,
            which fetches and deserializes the checkpoint on every call; pass a
            model to avoid that cost when calling this function in a loop.

    Returns:
        If mode == "encode":
            z_distributions: np.ndarray of shape (N, 2*zdim)
        If mode == "both":
            Tuple of:
                z_distributions: np.ndarray of shape (N, 2*zdim)
                xL_recon: np.ndarray of shape (N, C, H, W)
                xR_recon: np.ndarray of shape (N, C, H, W)
        (Shapes are those documented for the BetaVAE used here; they are
        whatever `_encode` / `_decode` return.)

    Raises:
        ValueError: If `mode` is neither "encode" nor "both". This is checked
            before any model loading or inference work is done.
    """
    # Fail fast: an invalid mode used to be reported only after the full pass.
    if mode not in ("encode", "both"):
        raise ValueError("Invalid mode. Choose 'encode' or 'both'.")
    want_recon = mode == "both"

    def as_float_tensor(data):
        # NumPy arrays become float32 tensors; anything else passes through untouched.
        if isinstance(data, np.ndarray):
            return torch.tensor(data, dtype=torch.float32)
        return data

    if model is None:
        model = load_model(zdim, nc, device=device)
    else:
        model = model.to(device)
    model.eval()

    # shuffle=False keeps output rows aligned with input rows
    dataset = TensorDataset(as_float_tensor(left_surf_data), as_float_tensor(right_surf_data))
    inference_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    all_z_distributions = []
    all_xL_recon = []
    all_xR_recon = []

    with torch.no_grad():
        for xL, xR in inference_loader:
            xL = xL.to(device)
            xR = xR.to(device)
            z_distribution = model._encode(xL, xR)
            all_z_distributions.append(z_distribution.cpu().numpy())

            if want_recon:
                # Decode from the latent means. clone() keeps the decoder from
                # aliasing the encoder output that was just exported above.
                mu = z_distribution[:, :zdim].clone()
                xL_recon, xR_recon = model._decode(mu)
                all_xL_recon.append(xL_recon.cpu().numpy())
                all_xR_recon.append(xR_recon.cpu().numpy())

    all_z_distributions = np.concatenate(all_z_distributions, axis=0)
    if not want_recon:
        return all_z_distributions

    all_xL_recon = np.concatenate(all_xL_recon, axis=0)
    all_xR_recon = np.concatenate(all_xR_recon, axis=0)
    return all_z_distributions, all_xL_recon, all_xR_recon

In [8]:
# Input locations: per-run dtseries, concatenated REST1+REST2 ptseries, and per-run tmasks
dtseries_folder = "/data/laumann/data1/tunde/bold_data_final/"
ptseries_folder = "/data/wheelock/data1/datasets/HCP/HCP_965_all_Gordon333_20221226/parcel_matrices/"
tmask_folder = "/data/wheelock/data1/datasets/HCP/HCP_all_masks/"

# Whitespace-delimited table; its 'subject' column lists the subjects to process.
# sep=r"\s+" is the documented replacement for the deprecated delim_whitespace=True.
df = pd.read_csv("/data/wheelock/data1/datasets/HCP/HCP_965_10min_Gordon333_20221123/retained_FD_Rest1.txt", sep=r"\s+")
subj_hcpAll = df['subject'].to_list()

# Folder holding the fMRI2Grid mapping files and the parcellation definition
transmat_path = "./mask"
parcellation_filename = "IM_Gordon_13nets_333Parcels.mat"
# "IM" struct from the parcellation file; its "order" field is used to reorder parcels
parcellation = sio.loadmat(os.path.join(transmat_path, parcellation_filename))["IM"]


In [9]:
# Latent dimensionality of the VAE checkpoint to use
zdim = 2

# Output folders: parcel-to-vertex correlation matrices and VAE latent means
latents_savepath = "./data/hcp/vae_latents/"
parcel_to_vertex_savepath = "./data/hcp/parcel_to_vertex_corrs/"

In [10]:
# Subjects skipped because an input file was missing/unreadable or lengths disagreed
missing_subjects = []

# The output folder must exist before writing: a missing folder makes the save
# raise FileNotFoundError, which the handler below would report as a "skipped"
# subject after all the per-subject computation has already been done.
os.makedirs(latents_savepath, exist_ok=True)

for subj in tqdm(subj_hcpAll, desc="Processing subjects"):
    try:
        # Load dtseries and tmask
        dtseries_data, full_tmask, vertex_indices = load_dtseries_and_tmask(subj, dtseries_folder, tmask_folder)

        # Load ptseries and apply tmask
        ptseries_masked = load_masked_ptseries(subj, ptseries_folder, parcellation, full_tmask)

        # Compute correlation on the retained timepoints only
        corr_matrix = compute_parcel_vertex_correlation(ptseries_masked, dtseries_data[:, full_tmask])

        # (Optional) corr_matrix can be saved under parcel_to_vertex_savepath here.

        # Project onto the 2D grids and encode; keep only the latent means
        l_surf, r_surf = forward_reformatting(corr_matrix, transmat_path=transmat_path, img_size=192)
        zs = model_inference(l_surf, r_surf, zdim=zdim, nc=1, mode="encode", batch_size=16, device="cpu")
        mus = zs[:, :zdim]

        sio.savemat(os.path.join(latents_savepath, f"{subj}_latents.mat"), {'mu_distributions': mus})

    except (FileNotFoundError, AssertionError, nib.filebasedimages.ImageFileError) as e:
        print(f"Skipping subject {subj} due to error: {e}")
        missing_subjects.append(subj)

print(f"\nSkipped {len(missing_subjects)} subjects: {missing_subjects}")



Processing subjects: 100%|██████████| 965/965 [15:04:01<00:00, 56.21s/it]  


Skipped 0 subjects: []



