In [None]:
import nibabel as nib
import pandas as pd
import numpy as np
import os
import scipy.io as sio
import torch

In [2]:
def cifti2_filename(cifti1_path, output_dir):
    """Derive the ``_cifti2``-tagged output path for a CIFTI file.

    The basename of ``cifti1_path`` keeps its full CIFTI extension, with
    ``_cifti2`` inserted immediately before it, and is placed in ``output_dir``:
    ``/in/sub01.dtseries.nii`` -> ``<output_dir>/sub01_cifti2.dtseries.nii``.

    Raises:
        ValueError: if the basename ends with none of the recognised CIFTI
            extensions (.dtseries.nii, .dconn.nii, .dscalar.nii, .dpconn.nii).
    """
    name = os.path.basename(cifti1_path)
    recognised = (".dtseries.nii", ".dconn.nii", ".dscalar.nii", ".dpconn.nii")

    # First recognised extension the basename ends with (None if no match).
    extension = next((ext for ext in recognised if name.endswith(ext)), None)
    if extension is None:
        raise ValueError("Unknown CIFTI suffix in filename")

    stem = name[: -len(extension)]
    return os.path.join(output_dir, stem + "_cifti2" + extension)

In [3]:
def calc_correlations(input_file, tmask):
    """Vertex-by-vertex correlation matrix of cortical time series.

    Loads a ``.dtseries.nii`` file with nibabel and selects the columns that
    belong to the CIFTI_STRUCTURE_CORTEX_LEFT / _RIGHT brain models of index
    map 1, in the order those brain models appear. It transposes the result to
    (vertices, frames), keeps the frames selected by ``tmask``, and returns
    ``np.corrcoef`` over vertices.

    Args:
        input_file: path to the CIFTI time-series file.
        tmask: index/mask applied along the frame axis (the caller passes a
            boolean array; presumably one entry per frame — confirm).

    Returns:
        Square correlation matrix with one row/column per cortical vertex.

    Raises:
        ValueError: if ``input_file`` does not end with ``.dtseries.nii``.
    """
    if not input_file.endswith(".dtseries.nii"):
        raise ValueError("Input file must be a CIFTI-2 time-series file. ")

    image = nib.load(input_file)
    samples = image.get_fdata()
    cortical_structures = ('CIFTI_STRUCTURE_CORTEX_LEFT', 'CIFTI_STRUCTURE_CORTEX_RIGHT')

    # Each cortical brain model occupies a contiguous run of columns
    # [index_offset, index_offset + index_count).
    cortex_columns = np.array([
        column
        for model in image.header.get_index_map(1).brain_models
        if model.brain_structure in cortical_structures
        for column in range(model.index_offset, model.index_offset + model.index_count)
    ])

    vertex_by_frame = samples[:, cortex_columns].T
    return np.corrcoef(vertex_by_frame[:, tmask])

In [4]:
def format_data(corrs, transmat_path, img_size):
    """Project per-vertex data onto the left/right square hemisphere grids.

    Parameters
    ----------
    corrs : np.ndarray
        1-D or 2-D array whose first axis is cortical vertices, left hemisphere
        first and then right. This presumably matches the brain-model order
        used by ``calc_correlations`` — confirm for new data. Any remaining
        columns are carried through unchanged.
    transmat_path : str
        Directory containing ``Left_fMRI2Grid_{img_size}_by_{img_size}_NN.mat``
        and ``Right_fMRI2Grid_{img_size}_by_{img_size}_NN.mat``. Each file
        holds a ``grid_mapping`` matrix of shape
        (img_size * img_size, n_hemisphere_vertices).
    img_size : int
        Side length of the square output grid.

    Returns
    -------
    (left_surf_data, right_surf_data) : tuple of np.ndarray
        Each of shape (img_size, img_size, 1, n_columns), where n_columns is
        ``corrs.shape[1]`` (1 for 1-D input). The reshape uses Fortran order.

    Raises
    ------
    ValueError
        If the number of rows in ``corrs`` differs from the total number of
        vertices expected by the two mapping matrices.
    """
    left_mapping = sio.loadmat(
        os.path.join(transmat_path, f"Left_fMRI2Grid_{img_size}_by_{img_size}_NN.mat")
    )['grid_mapping']
    right_mapping = sio.loadmat(
        os.path.join(transmat_path, f"Right_fMRI2Grid_{img_size}_by_{img_size}_NN.mat")
    )['grid_mapping']

    # Take the hemisphere split from the mappings themselves: each mapping has
    # one column per vertex of its hemisphere. For the 192x192 masks this is
    # 29696 left + 29716 right; a hard-coded 29697 is off by one and breaks
    # the matrix product.
    n_left = left_mapping.shape[1]
    n_right = right_mapping.shape[1]
    if corrs.shape[0] != n_left + n_right:
        raise ValueError(
            f"Expected {n_left + n_right} vertex rows in corrs "
            f"({n_left} left + {n_right} right), got {corrs.shape[0]}."
        )
    left_data, right_data = corrs[:n_left], corrs[n_left:]

    grid_shape = (img_size, img_size, 1, -1)
    # np.asarray turns a possible np.matrix product into a plain ndarray so
    # that the 4-D reshape is allowed.
    left_surf_data = np.reshape(np.asarray(left_mapping @ left_data), grid_shape, order='F')
    right_surf_data = np.reshape(np.asarray(right_mapping @ right_data), grid_shape, order='F')

    return left_surf_data, right_surf_data

In [None]:
def rand_sample(left_surf_data, right_surf_data, sample_ratio=1.0, rng=None):
    """Randomly subsample the same samples (last axis) from both hemispheres.

    Parameters
    ----------
    left_surf_data, right_surf_data : np.ndarray
        Arrays of identical shape whose last axis indexes samples, e.g. the
        (img_size, img_size, 1, N) output of ``format_data``.
    sample_ratio : float, default 1.0
        Fraction of samples to keep, in (0, 1]. The count is rounded to the
        nearest integer and is at least 1 whenever any samples exist.
    rng : np.random.Generator, optional
        Source of randomness, for reproducible draws. If None, the global
        ``np.random`` state is used (so ``np.random.seed`` still applies).

    Returns
    -------
    (sampled_left, sampled_right) : tuple of np.ndarray
        Both arrays indexed with the same randomly chosen, non-repeating
        sample indices along the last axis.

    Raises
    ------
    ValueError
        If ``sample_ratio`` is not in (0, 1].
    """
    assert left_surf_data.shape == right_surf_data.shape
    if sample_ratio <= 0 or sample_ratio > 1.0:
        raise ValueError("Please pick a sample ratio between 0 and 1. ")

    sample_size = left_surf_data.shape[-1]
    # np.random.choice requires an integer count; sample_size * sample_ratio is
    # a float. Rounding absorbs float error (100 * 0.29 == 28.999...), and the
    # clamp keeps tiny ratios from silently producing an empty selection.
    n_keep = min(sample_size, max(1, int(round(sample_size * sample_ratio))))

    chooser = np.random if rng is None else rng
    indices = chooser.choice(sample_size, n_keep, replace=False)

    # Use identical indices for both hemispheres so left/right samples stay paired.
    return left_surf_data[..., indices], right_surf_data[..., indices]

In [None]:
def load_model(zdim):
    """Placeholder: build/load the model used by ``model_inference``.

    ``model_inference`` expects the returned object to support
    ``.to(device)`` (returning the model), ``.eval()``,
    ``._encode(xL, xR)`` and ``._decode(z)``.

    Args:
        zdim: latent dimensionality of the model.

    Raises:
        NotImplementedError: always, until a real loader is written. Failing
            here is clearer than returning None and crashing later on ``.to``.
    """
    raise NotImplementedError("load_model is a placeholder; implement model construction/loading.")

def generate_loader(left_surf_data, right_surf_data, batch_size):
    """Placeholder: batch paired left/right hemisphere data for inference.

    ``model_inference`` iterates the returned object as ``(xL, xR)`` pairs and
    calls ``.to(device)`` on each element, so the elements must be tensors.

    Args:
        left_surf_data, right_surf_data: paired hemisphere inputs.
        batch_size: number of samples per batch.

    Raises:
        NotImplementedError: always, until a real loader is written. Failing
            here is clearer than returning None and failing on iteration.
    """
    raise NotImplementedError("generate_loader is a placeholder; implement batching of (xL, xR) pairs.")

In [None]:
def model_inference(zdim, left_surf_data, right_surf_data, mode, batch_size, device):
    """Run the model's encoder, and optionally its decoder, over paired hemisphere data.

    Parameters
    ----------
    zdim : int
        Latent dimensionality. The first ``zdim`` columns of each encoder
        output are used as the latent mean ``mu`` fed to the decoder.
    left_surf_data, right_surf_data :
        Left/right inputs, passed unchanged to ``generate_loader``.
    mode : {"encode", "both"}
        "encode" returns only encoder outputs; "both" also decodes ``mu``.
    batch_size : int
        Batch size forwarded to ``generate_loader``.
    device : str or torch.device
        Device the model and each batch are moved to.

    Returns
    -------
    np.ndarray or tuple of np.ndarray
        mode="encode": encoder outputs concatenated along axis 0.
        mode="both": (encoder outputs, left reconstructions, right reconstructions),
        each concatenated along axis 0.

    Raises
    ------
    ValueError
        If ``mode`` is not "encode" or "both" (checked before any work is
        done), or if the loader yields no batches.
    """
    if mode not in ("encode", "both"):
        raise ValueError("Invalid mode. Choose 'encode' or 'both'.")
    decode = mode == "both"

    model = load_model(zdim).to(device)
    model.eval()
    inference_loader = generate_loader(left_surf_data, right_surf_data, batch_size)

    all_z_distributions = []
    all_xL_recon = []
    all_xR_recon = []

    # no_grad avoids building an autograd graph during inference. It is also
    # required for correctness: Tensor.numpy() raises on tensors that require grad.
    with torch.no_grad():
        for xL, xR in inference_loader:
            z_distribution = model._encode(xL.to(device), xR.to(device))
            all_z_distributions.append(z_distribution.cpu().numpy())

            if decode:
                # Decode from the mean only. The slice already lives on
                # `device`, so it is not re-wrapped with torch.tensor, which
                # would copy and emit a UserWarning.
                mu = z_distribution[:, :zdim]
                xL_recon, xR_recon = model._decode(mu)
                all_xL_recon.append(xL_recon.cpu().numpy())
                all_xR_recon.append(xR_recon.cpu().numpy())

    if not all_z_distributions:
        raise ValueError("Inference loader yielded no batches.")

    all_z_distributions = np.concatenate(all_z_distributions, axis=0)
    if not decode:
        return all_z_distributions
    return (
        all_z_distributions,
        np.concatenate(all_xL_recon, axis=0),
        np.concatenate(all_xR_recon, axis=0),
    )

In [None]:
def visualize_representations(zs, labels):
    """Placeholder: visualize latent representations coloured by label.

    Args:
        zs: latent representations (presumably the encoder output of
            ``model_inference`` — confirm).
        labels: per-sample labels used to group/colour the plot.

    Raises:
        NotImplementedError: always, until implemented. A silent no-op would
            let a notebook run appear to have produced a figure.
    """
    raise NotImplementedError("visualize_representations is a placeholder; implement plotting.")

In [18]:
example_cifti_cohort_filepath = "./data/cohort_files/cohortfiles_washu120.txt"
example_tmask_cohort_filepath = "./data/tmask_files/tmasklist_washu120.txt"

# Whitespace-delimited, headerless cohort lists. sep=r"\s+" replaces
# delim_whitespace=True, which is deprecated since pandas 2.2.
cifti_cohort_df = pd.read_csv(example_cifti_cohort_filepath, sep=r"\s+", header=None)
tmask_cohort_df = pd.read_csv(example_tmask_cohort_filepath, sep=r"\s+", header=None)

example_idx = 111
# Column 0 is presumably the subject ID and column 1 the file path; the
# assert below checks that both lists refer to the same subject.
subj, cifti1_path = cifti_cohort_df.iloc[example_idx, :2].tolist()
tmask_subj, tmask_path = tmask_cohort_df.iloc[example_idx, :2].tolist()
assert subj == tmask_subj, (
    f"Cohort/tmask row mismatch at index {example_idx}: {subj!r} != {tmask_subj!r}"
)
# Cast the 0/1 temporal mask to bool so it selects frames rather than being
# used as integer frame indices.
tmask = np.loadtxt(tmask_path, dtype=int).astype(bool)

cifti2_path = cifti2_filename(cifti1_path, "./data/washu120/")

# Reuse the helper rather than duplicating its cortical-vertex selection and
# correlation inline. Its large intermediates are also freed once it returns.
corr_matrix = calc_correlations(cifti2_path, tmask)

In [19]:
# Grid-mapping matrices for each hemisphere (presumably vertex -> 192x192
# grid projections, as used by format_data).
transmat_path = "./mask"

left_mask_struct = sio.loadmat(
    os.path.join(transmat_path, "Left_fMRI2Grid_192_by_192_NN.mat")
)
left_mask = left_mask_struct['grid_mapping']

right_mask_struct = sio.loadmat(
    os.path.join(transmat_path, "Right_fMRI2Grid_192_by_192_NN.mat")
)
right_mask = right_mask_struct['grid_mapping']

# Shapes are (grid cells, hemisphere vertices). The column counts give the
# left/right vertex split of the cortical data.
print(left_mask.shape)
print(right_mask.shape)

(36864, 29696)
(36864, 29716)
