In [1]:
import nibabel as nib
import pandas as pd
import numpy as np
import os
import scipy.io as sio
import torch
from torch.utils.data import TensorDataset, DataLoader
from huggingface_hub import hf_hub_download
from fMRIVAE_Model import BetaVAE

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def cifti2_filename(cifti1_path, output_dir):
    """
    Build the path of the "_cifti2" counterpart of a CIFTI file inside `output_dir`.

    Only the base name of `cifti1_path` is used. The tag "_cifti2" is inserted
    immediately before the compound CIFTI extension, e.g.
    "sub01.dtseries.nii" -> "sub01_cifti2.dtseries.nii". Nothing is read from
    or written to disk.

    Parameters:
        cifti1_path (str): Path of the source file.
        output_dir (str): Directory the returned path is placed in.

    Returns:
        str: `output_dir` joined with the renamed file name.

    Raises:
        ValueError: If the file name ends with none of the recognised suffixes.
    """
    file_name = os.path.basename(cifti1_path)

    # Compound extensions this helper recognises, checked in this order
    cifti_extensions = (".dtseries.nii", ".dconn.nii", ".dscalar.nii", ".dpconn.nii")
    extension = next((ext for ext in cifti_extensions if file_name.endswith(ext)), None)

    if extension is None:
        raise ValueError("Unknown CIFTI suffix in filename")

    # Everything in front of the matched extension is the stem
    stem = file_name[: len(file_name) - len(extension)]
    return os.path.join(output_dir, stem + "_cifti2" + extension)

In [3]:
def calc_correlations(input_file, tmask):
    """
    Correlate the cortical time series stored in a ".dtseries.nii" file.

    The columns that belong to the left and right cortex brain models are
    selected, the frames picked by `tmask` are kept, and np.corrcoef is run
    with one row per selected column.

    Parameters:
        input_file (str): Path to the file; must end with ".dtseries.nii".
        tmask: Index applied along the time axis to choose the frames that
            enter the correlation (the notebook passes a boolean array).

    Returns:
        np.ndarray: Output of np.corrcoef, one row/column per cortex column.

    Raises:
        ValueError: If `input_file` does not end with ".dtseries.nii".
    """
    # Guard clause: only the file name is checked, not the file content
    if not input_file.endswith(".dtseries.nii"):
        raise ValueError("Input file must be a CIFTI-2 time-series file. ")

    img = nib.load(input_file)
    series = img.get_fdata()
    # Index map 1 describes the second data axis, i.e. the brain models
    brain_model_map = img.header.get_index_map(1)

    cortex_structures = ('CIFTI_STRUCTURE_CORTEX_LEFT', 'CIFTI_STRUCTURE_CORTEX_RIGHT')
    # Each cortex brain model owns the contiguous columns [offset, offset + count)
    cortex_columns = np.array([
        column
        for model in brain_model_map.brain_models
        if model.brain_structure in cortex_structures
        for column in range(model.index_offset, model.index_offset + model.index_count)
    ])

    # Rows become cortex columns, then only the frames selected by tmask are kept
    kept_frames = np.transpose(series[:, cortex_columns])[:, tmask]
    return np.corrcoef(kept_frames)

In [4]:
def forward_reformatting(corrs, transmat_path, img_size):
    """
    Project cortical maps onto one 2-D grid per hemisphere.

    The 'grid_mapping' matrices are read from
    "Left_fMRI2Grid_192_by_192_NN.mat" and "Right_fMRI2Grid_192_by_192_NN.mat"
    inside `transmat_path`. The rows of `corrs` are split into a left block
    (as many rows as the left mapping has columns) followed by a right block,
    each block is multiplied by its mapping, and the result is laid out as
    images in column-major (order='F') fashion.

    Parameters:
        corrs (np.ndarray): Array of shape (n_left + n_right, n_maps), left
            hemisphere rows first. A 1-D array is treated as a single map.
        transmat_path (str): Directory containing the two .mat files.
        img_size (int): Side length of the square grid; img_size**2 must equal
            the number of rows of each mapping (the file names are fixed to
            the 192x192 mappings).

    Returns:
        tuple: (left_surf_data, right_surf_data), each of shape
            (n_maps, 1, img_size, img_size).

    Raises:
        ValueError: If the number of rows of `corrs` does not match the total
            number of columns of the two mappings.
    """
    left_mask = sio.loadmat(os.path.join(transmat_path, "Left_fMRI2Grid_192_by_192_NN.mat"))
    right_mask = sio.loadmat(os.path.join(transmat_path, "Right_fMRI2Grid_192_by_192_NN.mat"))
    left_transmat = left_mask['grid_mapping']
    right_transmat = right_mask['grid_mapping']

    # A single map passed as a vector becomes one column
    if corrs.ndim == 1:
        corrs = corrs[:, np.newaxis]

    # Take the hemisphere split from the mappings themselves rather than
    # hard-coding the left-hemisphere size
    n_left = left_transmat.shape[1]
    n_right = right_transmat.shape[1]
    if corrs.shape[0] != n_left + n_right:
        raise ValueError(
            f"corrs has {corrs.shape[0]} rows, but the grid mappings expect "
            f"{n_left} (left) + {n_right} (right) = {n_left + n_right}"
        )
    left_data, right_data = corrs[:n_left, :], corrs[n_left:, :]

    def _to_grid(transmat, hemi_data):
        # (grid cells, n_maps) -> (height, width, 1, n_maps), column-major fill
        grid = np.reshape(transmat @ hemi_data, (img_size, img_size, 1, -1), order='F')
        # -> (n_maps, 1, height, width)
        return np.transpose(grid, axes=(3, 2, 0, 1))

    left_surf_data = _to_grid(left_transmat, left_data)
    right_surf_data = _to_grid(right_transmat, right_data)

    return left_surf_data, right_surf_data

In [5]:
def backword_reformatting(left_surf_recon, right_surf_recon, transmat_path):
    """
    Map per-hemisphere grid images back to stacked cortical maps.

    The 'inverse_transformation' matrices are read from
    "Left_fMRI2Grid_192_by_192_NN.mat" and "Right_fMRI2Grid_192_by_192_NN.mat"
    inside `transmat_path`. Each image is flattened, multiplied by the matrix
    of its hemisphere, and the left rows are stacked on top of the right rows.
    Entries that are exactly 0 in the stacked result are overwritten with 1.

    Parameters:
        left_surf_recon (np.ndarray): Left grids; the first axis indexes maps.
        right_surf_recon (np.ndarray): Right grids, same shape as the left.
        transmat_path (str): Directory containing the two .mat files.

    Returns:
        np.ndarray: Array of shape (left rows + right rows, n_maps).

    NOTE(review): the images are flattened in NumPy's default C order here,
    whereas forward_reformatting fills its grids with order='F'. Confirm that
    'inverse_transformation' expects this ordering.
    """
    assert left_surf_recon.shape == right_surf_recon.shape
    n_maps = left_surf_recon.shape[0]

    hemisphere_inputs = (
        ("Left_fMRI2Grid_192_by_192_NN.mat", left_surf_recon),
        ("Right_fMRI2Grid_192_by_192_NN.mat", right_surf_recon),
    )

    hemisphere_rows = []
    for mat_name, surf_recon in hemisphere_inputs:
        grid_to_cortex = sio.loadmat(os.path.join(transmat_path, mat_name))['inverse_transformation']
        # One column per map, one row per grid cell
        flat_grids = surf_recon.reshape(n_maps, -1).T
        hemisphere_rows.append(grid_to_cortex @ flat_grids)

    dtseries_recon = np.vstack(hemisphere_rows)
    # Exact zeros are replaced by 1
    dtseries_recon[dtseries_recon == 0] = 1
    return dtseries_recon

In [6]:
def rand_sample(left_surf_data, right_surf_data, sample_ratio=1.0, rng=None):
    """
    Draw the same random subset of samples from the left and right arrays.

    Samples are chosen along the first axis without replacement, and the same
    indices are applied to both arrays so that left/right pairs stay aligned.

    Parameters:
        left_surf_data: Array whose first axis indexes samples.
        right_surf_data: Array with the same shape as `left_surf_data`.
        sample_ratio (float): Fraction of samples to keep, in (0, 1]. The
            number kept is int(n_samples * sample_ratio).
        rng (None, int or np.random.Generator): Source of randomness. With the
            default None the draw uses the global np.random state (so
            np.random.seed still controls it); anything else is passed to
            np.random.default_rng to get a reproducible, self-contained draw.

    Returns:
        tuple: (sampled_left_surf_data, sampled_right_surf_data).

    Raises:
        ValueError: If `sample_ratio` is not in (0, 1].
    """
    assert left_surf_data.shape == right_surf_data.shape
    if not 0 < sample_ratio <= 1.0:
        raise ValueError("Please pick a sample ratio between 0 and 1. ")

    sample_size = left_surf_data.shape[0]
    n_keep = int(sample_size * sample_ratio)

    if rng is None:
        # Legacy behaviour: global NumPy random state
        indices = np.random.choice(sample_size, n_keep, replace=False)
    else:
        indices = np.random.default_rng(rng).choice(sample_size, n_keep, replace=False)

    # Indexing only the first axis works for any number of trailing dimensions
    sampled_left_surf_data = left_surf_data[indices]
    sampled_right_surf_data = right_surf_data[indices]
    return sampled_left_surf_data, sampled_right_surf_data

In [7]:
def load_model(zdim, nc, device):
    """
    Download a pretrained BetaVAE checkpoint and return the model in eval mode.

    The checkpoint for the requested latent dimension is fetched from the
    Hugging Face repository "cindyhfls/fcMRI-VAE", loaded onto the CPU, copied
    into a freshly built BetaVAE, and the model is then moved to `device`.

    Parameters:
        zdim (int): Latent dimension; checkpoints exist for 2, 3 and 4.
        nc (int): Passed to BetaVAE as `nc`.
        device: Device the returned model is moved to.

    Returns:
        BetaVAE: The model with loaded weights, in eval mode, on `device`.

    Raises:
        ValueError: If `zdim` is not 2, 3 or 4.
    """
    # Checkpoint file inside the repository for each supported latent dimension
    checkpoint_files = (
        (2, "Checkpoint/checkpoint49_2024-03-28_Zdim_2_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar"),
        (3, "Checkpoint/checkpoint49_2024-11-28_Zdim_3_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar"),
        (4, "Checkpoint/checkpoint49_2024-06-21_Zdim_4_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar"),
    )
    filename = next((name for dim, name in checkpoint_files if zdim == dim), None)
    if filename is None:
        raise ValueError("Invalid latent dimension. Please choose among 2, 3 and 4. ")

    weights_path = hf_hub_download(repo_id="cindyhfls/fcMRI-VAE", filename=filename)

    # NOTE(review): torch.load unpickles a file downloaded from the network;
    # only use it with a repository you trust.
    saved = torch.load(weights_path, map_location="cpu")

    vae = BetaVAE(z_dim=zdim, nc=nc)
    vae.load_state_dict(saved['state_dict'])
    vae.eval()
    return vae.to(device)

In [8]:
def model_inference(left_surf_data, right_surf_data, zdim, nc, mode, batch_size, device):
    """
    Run inference on left and right surface data using a VAE model.

    The pretrained model for `zdim` is obtained through `load_model`, the data
    are fed through it in order (no shuffling), and the per-batch results are
    concatenated along the first axis.

    Parameters:
        left_surf_data (np.ndarray or torch.Tensor): Input tensor of shape (batch, C, H, W)
        right_surf_data (np.ndarray or torch.Tensor): Same shape as left_surf_data
        zdim (int): Dimensionality of latent space
        nc (int): Number of input channels
        mode (str): "encode" for latent output, "both" for latent and reconstruction
        batch_size (int): Inference batch size
        device (torch.device): Target device for model and data

    Returns:
        If mode == "encode":
            z_distributions: np.ndarray of shape (N, 2*zdim)
        If mode == "both":
            Tuple of:
                z_distributions: np.ndarray of shape (N, 2*zdim)
                xL_recon: np.ndarray of shape (N, C, H, W)
                xR_recon: np.ndarray of shape (N, C, H, W)

    Raises:
        ValueError: If `mode` is neither "encode" nor "both". The check runs
            before the model is loaded or any data is processed.
    """
    # Fail fast: a bad mode must not cost a checkpoint download and a full pass
    if mode not in ("encode", "both"):
        raise ValueError("Invalid mode. Choose 'encode' or 'both'.")

    def generate_loader(left_surf_data, right_surf_data, batch_size):
        # NumPy inputs are converted to float32 tensors; tensors pass through
        if isinstance(left_surf_data, np.ndarray):
            left_surf_data = torch.tensor(left_surf_data, dtype=torch.float32)
        if isinstance(right_surf_data, np.ndarray):
            right_surf_data = torch.tensor(right_surf_data, dtype=torch.float32)

        # shuffle=False keeps the outputs aligned with the input order
        dataset = TensorDataset(left_surf_data, right_surf_data)
        return DataLoader(dataset, batch_size=batch_size, shuffle=False)

    model = load_model(zdim, nc, device=device)
    model.eval()
    inference_loader = generate_loader(left_surf_data, right_surf_data, batch_size)

    want_recon = mode == "both"
    all_z_distributions = []
    all_xL_recon = []
    all_xR_recon = []

    with torch.no_grad():
        for xL, xR in inference_loader:
            xL = xL.to(device)
            xR = xR.to(device)
            z_distribution = model._encode(xL, xR)
            all_z_distributions.append(z_distribution.cpu().numpy())

            if want_recon:
                # Decode from the first zdim columns of the encoder output;
                # clone so the decoder never aliases the stored encoder output
                z = z_distribution[:, :zdim].clone().detach().to(device)
                xL_recon, xR_recon = model._decode(z)
                all_xL_recon.append(xL_recon.cpu().numpy())
                all_xR_recon.append(xR_recon.cpu().numpy())

    all_z_distributions = np.concatenate(all_z_distributions, axis=0)

    if not want_recon:
        return all_z_distributions

    all_xL_recon = np.concatenate(all_xL_recon, axis=0)
    all_xR_recon = np.concatenate(all_xR_recon, axis=0)
    return all_z_distributions, all_xL_recon, all_xR_recon

In [9]:
def calc_etasquared(a, b):
    """
    Eta squared between every column of `a` and every column of `b`
    (Cohen 2008 Neuroimage).

    For a pair of columns, with m their element-wise mean and Mbar the mean of
    m, the value is 1 - SSwithin / SStotal, where SSwithin sums the squared
    deviations of both columns from m and SStotal sums their squared
    deviations from Mbar. NaN terms are ignored in the mean and in the sums.

    Parameters:
    a : np.ndarray
        First input array; a 1-D array is treated as one column.
    b : np.ndarray
        Second input array with the same number of rows as `a`.

    Returns:
    etasquared : np.ndarray
        Array of shape (columns of b, columns of a); entry [j, i] compares
        column i of `a` with column j of `b`.
    """
    # Promote vectors to single-column matrices
    a = a[:, np.newaxis] if a.ndim == 1 else a
    b = b[:, np.newaxis] if b.ndim == 1 else b

    assert a.shape[0] == b.shape[0], 'input size mismatch'

    n_cols_a, n_cols_b = a.shape[1], b.shape[1]
    etasquared = np.full((n_cols_b, n_cols_a), np.nan)

    # Visit every (column of a, column of b) pair
    for col_a, col_b in np.ndindex(n_cols_a, n_cols_b):
        first = a[:, col_a]
        second = b[:, col_b]

        pair_mean = (first + second) / 2
        grand_mean = np.nanmean(pair_mean)

        ss_within = np.nansum((first - pair_mean) ** 2 + (second - pair_mean) ** 2)
        ss_total = np.nansum((first - grand_mean) ** 2 + (second - grand_mean) ** 2)
        etasquared[col_b, col_a] = 1 - ss_within / ss_total

    return etasquared

In [10]:
def visualize_representations(zs, labels):
    """
    Placeholder for a visualisation of latent representations.

    Not implemented yet: both arguments are ignored and None is returned.

    Parameters:
        zs: Unused.
        labels: Unused.
    """
    return None

In [11]:
example_cifti_cohort_filepath = "./data/cohort_files/cohortfiles_washu120.txt"
example_tmask_cohort_filepath = "./data/tmask_files/tmasklist_washu120.txt"

# Whitespace-delimited cohort lists without a header row. `sep=r"\s+"` is the
# supported spelling of `delim_whitespace=True`, which pandas deprecated in 2.2.
cifti_cohort_df = pd.read_csv(example_cifti_cohort_filepath, sep=r"\s+", header=None)
tmask_cohort_df = pd.read_csv(example_tmask_cohort_filepath, sep=r"\s+", header=None)

# Row of the cohort files used as the worked example
example_idx = 111
subj, cifti1_path = cifti_cohort_df.iloc[example_idx, :2].tolist()
tmask_subj, tmask_path = tmask_cohort_df.iloc[example_idx, :].tolist()
# Both lists must describe the same subject on the same row
assert subj == tmask_subj
# The temporal mask is stored as integers and used as a boolean frame selector
tmask = np.loadtxt(tmask_path, dtype=int).astype(bool)

cifti2_path = cifti2_filename(cifti1_path, "./data/washu120/")

# The steps below mirror calc_correlations, but keep the intermediate objects
# (image, header, index map, time series) available as notebook variables.
cifti_img = nib.load(cifti2_path)
cifti_data = cifti_img.get_fdata()
cifti_header = cifti_img.header
bm_index_map = cifti_img.header.get_index_map(1)

# Collect the data columns that belong to the left and right cortex
vertex_indices = []
for bm in bm_index_map.brain_models:
    structure = bm.brain_structure
    if structure in ['CIFTI_STRUCTURE_CORTEX_LEFT', 'CIFTI_STRUCTURE_CORTEX_RIGHT']:
        # Each brain model owns the contiguous columns [offset, offset + count)
        offset = bm.index_offset
        count = bm.index_count
        vertex_indices.extend(range(offset, offset + count))

vertex_indices = np.array(vertex_indices)
# One row per cortex column, one column per time frame
cortex_ts = np.transpose(cifti_data[:, vertex_indices])
# Keep only the frames selected by the temporal mask
masked_cortex_ts = cortex_ts[:, tmask]
corr_matrix = np.corrcoef(masked_cortex_ts)

In [12]:
# Directory that holds the per-hemisphere transformation .mat files
transmat_path = "./mask"

# Read both hemisphere structs first, then pull out their forward grid mappings
left_mask_struct = sio.loadmat(os.path.join(transmat_path, "Left_fMRI2Grid_192_by_192_NN.mat"))
right_mask_struct = sio.loadmat(os.path.join(transmat_path, "Right_fMRI2Grid_192_by_192_NN.mat"))
left_mask, right_mask = left_mask_struct['grid_mapping'], right_mask_struct['grid_mapping']

# Shape of each mapping, left hemisphere first
for grid_mapping in (left_mask, right_mask):
    print(grid_mapping.shape)
# mask = sio.loadmat(os.path.join(transmat_path, "MSE_Mask.mat"))

(36864, 29696)
(36864, 29716)


In [13]:
# Smoke test: project the full correlation matrix onto the 192x192 grids
left_surf_data, right_surf_data = forward_reformatting(corr_matrix, transmat_path, img_size=192)
# Shape of each hemisphere's grid stack, left first
for hemi_grids in (left_surf_data, right_surf_data):
    print(hemi_grids.shape)


(59412, 1, 192, 192)
(59412, 1, 192, 192)


In [14]:
# NOTE(review): model_inference loads its own copy of the checkpoint, so this
# handle is not used by the call below; it is kept for any later cell that needs it.
model = load_model(zdim=2, nc=1, device="cpu")

# rand_sample draws from NumPy's global random state; fixing the seed makes the
# 10% subsample (and everything computed from it) identical on every run.
np.random.seed(0)
# NOTE: this overwrites the full-size grids from the previous cell, so re-running
# only this cell subsamples the already reduced data again.
left_surf_data, right_surf_data = rand_sample(left_surf_data, right_surf_data, sample_ratio=0.1)
print(left_surf_data.shape)
zs, left_surf_recon, right_surf_recon = model_inference(left_surf_data, right_surf_data, zdim=2, nc=1, mode="both", batch_size=16, device="cpu")

print(zs.shape)
print(left_surf_recon.shape)
print(right_surf_recon.shape)

(5941, 1, 192, 192)
(5941, 4)
(5941, 1, 192, 192)
(5941, 1, 192, 192)
