In [4]:
import os
import warnings
import subprocess
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, zscore, mode
from scipy.spatial import cKDTree

from precision_mapping import mapping
from feature_extraction import utils


warnings.filterwarnings('ignore')
os.chdir('/host/corin/tank/jonah/pmat/pmfe/')

func = '/host/corin/tank/jonah/pmat/results/HCD0001305/surface_dtseries_smoothed.L.func.gii'
surf = '/host/cicero/local_raid/data/HCPD/HCPD_2.0_Release/Structural_Preprocessed_Recommended/fmriresults01/HCD0001305_V1_MR/MNINonLinear/fsaverage_LR32k/HCD0001305_V1_MR.L.midthickness_MSMAll.32k_fs_LR.surf.gii'
output = '/host/corin/tank/jonah/pmat/pm_test'
tmp = f'{output}/tmp'
networks = '/host/corin/tank/jonah/pmat/pm_test/networks.L.label.gii'
hemi = 'L'



network_indices, network_labels = utils.get_template_info()


In [None]:


def get_clusters(networks, surf, hemi):
    '''Return a .'''

    network_indices, network_labels = utils.get_template_info()

    # Create list of boolean arrays for each network.
    networks_bool = [nib.load(networks).darrays[0].data == network_idx for network_idx in network_indices]

    # Separate networks into separate darrays.
    network_gii = mapping.create_func_gii(networks_bool, hemi=hemi, map_names=network_labels)
    nib.save(network_gii, f'networks_separated.{hemi}.func.gii')

    # Identify clusters for each network.
    subprocess.run([
        f'wb_command', '-metric-find-clusters',
        f'{surf}',                                # <surface> - the surface to compute on
        f'networks_separated.{hemi}.func.gii',    # <metric-in> - the input metric
        '0',                                      # <value-threshold> - threshold for data values
        '0',                                      # <minimum-area> - threshold for cluster area, in mm^2
        f'{tmp}/clusters.{hemi}.func.gii'         # <metric-out> - output - the output metric
    ])

    # Create single darray for each cluster.
    cluster_darrays = nib.load(f'{tmp}/clusters.{hemi}.func.gii').darrays

    clusters_bool = []
    for darray in cluster_darrays:
        cluster_indices = set(darray.data)
        cluster_indices.remove(0)
        clusters_bool.append([darray.data == idx for idx in cluster_indices])

    cluster_bool = np.vstack(clusters_bool)
    map_names=[f'{idx}' for idx in range(cluster_bool.shape[1])]

    # Write clusters to temporaray directory.
    clusters_gii = mapping.create_func_gii(cluster_bool, hemi, map_names=map_names)
    nib.save(clusters_gii, f'{tmp}/clusters.{hemi}.func.gii')

    # Find borders of clusters.
    subprocess.run([
        f'wb_command', '-metric-rois-to-border',
        f'{surf}',                                # <surface> - the surface to use for neighbor information
        f'{tmp}/clusters.{hemi}.func.gii',        # <metric> - the input metric containing ROIs
        f'network_clusters',                      # <class-name> - the name to use for the class of the output borders
        f'{tmp}/borders.{hemi}.border',           # <border-out> - output - the output border file
    ])

    # Find border vertices.
    subprocess.run([
        f'wb_command', '-border-to-vertices',
        f'{surf}',                               # <surface> - the surface to compute on
        f'{tmp}/borders.{hemi}.border',          # <border-file> - the border file
        f'{tmp}/borders.{hemi}.func.gii'         # <metric-out> - output - the output metric file
    ])


def get_cluster_sharpness(network_data, cluster_data, border_data, time_series, network_labels, coords, distance=5):

    # Get indices of vertices on the border.
    border_vertices = np.argwhere(border_data).flatten()

    # Get network-index of current cluster.
    cluster_network, _ = mode(network_data[np.bool(cluster_data)])

    # Get cluster time-series (mean BOLD signal across vertices within the cluster).
    cluster_xs = time_series[:,np.bool(cluster_data)].mean(axis=1)

    # Get list of nearby vertices (+/- distance) for each border vertex.
    all_nearby_vertices = [np.array(tree.query_ball_point(coords[border_vertex], r=distance)) for border_vertex in border_vertices]
    relevant_vertices = np.unique(np.hstack(all_nearby_vertices))

    # Correlate the time-series of each vertex on the surface with the cluster time-series.
    r_vals = np.zeros(network_data.shape[0])
    r_vals[relevant_vertices] = np.array([np.corrcoef(cluster_xs, vertex_xs)[0,1] for vertex_xs in time_series.T[relevant_vertices]])

    # Get network-wise border-sharpness.
    cluster_sharpness = {label:[] for label in network_labels}
    for vertex_idx, _ in enumerate(border_vertices):

        # Find all vertices within specified distance.
        nearby_vertices = all_nearby_vertices[vertex_idx]

        # Get network-indices of each vertex.
        vertex_networks = network_data[nearby_vertices]

        # Get r-vals inside cluster.
        inside_corrs = r_vals[nearby_vertices][vertex_networks == cluster_network]

        for net_idx, net_label in zip(network_indices, network_labels):
            outside_corrs = r_vals[nearby_vertices][vertex_networks == net_idx]
            if len(outside_corrs) < 5: continue
            cluster_sharpness[net_label].append(utils.get_cohens_d(inside_corrs, outside_corrs))

    cluster_sharpness_mean = {label: np.nanmean(cluster_sharpness[label]) for label in network_labels}

    return cluster_sharpness_mean, cluster_network



get_clusters(networks, surf, hemi, network_indices)


# Inputs: 
clusters = f'{tmp}/clusters.{hemi}.func.gii'
borders = f'{tmp}/borders.{hemi}.func.gii'

network_data = nib.load(networks).darrays[0].data
cluster_darrays = nib.load(clusters).darrays
border_darrays = nib.load(clusters).darrays
n_clusters = len(cluster_darrays)

# Get time-series.
time_series = utils.get_time_series(func)

# Load surface coordinates and define KDTree.
surf_gii = nib.load(surf)
coords = surf_gii.darrays[0].data
tree = cKDTree(coords)


# Get boundary sharpness of each cluster.
cluster_sharpness = {idx:[] for idx in network_indices}
for cluster_idx in range(n_clusters):

    # Get vertex-wise network labels.
    cluster_data = cluster_darrays[cluster_idx].data
    border_data = border_darrays[cluster_idx].data

    sharpness, net_idx = get_cluster_sharpness(network_data, cluster_data, border_data, time_series, network_labels, coords)
    cluster_sharpness[net_idx].append(sharpness)


# Combine into single dataframe.
all_df = []
for idx in network_indices:
    df = pd.DataFrame(cluster_sharpness[idx])
    df.insert(0,'network_idx',idx)
    all_df.append(df)

df = pd.concat(all_df)
df.reset_index(drop=True)
idx_to_label = {idx: label for idx, label in zip(network_indices, network_labels)}
df.insert(0, 'network_label', df['network_idx'].map(idx_to_label))

df.groupby('network_label').mean()


Unnamed: 0_level_0,network_idx,visual_a,visual_b,somatomotor_a,somatomotor_b,dorsal_attention_a,dorsal_attention_b,ventral_attention,salience,limbic_a,limbic_b,control_c,control_a,control_b,temporal_parietal,default_c,default_a,default_b
network_label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
control_a,12.0,3.206007,5.871359,9.332896,,4.219301,2.138841,3.738891,2.786528,2.931059,3.291236,2.295918,0.0,3.142575,3.937993,4.22865,3.335985,3.908567
control_b,13.0,7.216858,1.84627,6.023683,,9.049582,5.575248,6.958099,2.85933,3.515981,3.052014,1.832026,2.788712,0.0,3.927023,2.384702,1.713324,2.645802
control_c,11.0,9.864905,2.841365,8.674176,2.276344,4.287628,6.596193,3.925744,2.42487,4.608,,0.0,2.599923,2.879389,9.967596,3.374908,2.249681,4.78258
default_a,16.0,5.299863,2.405091,3.045145,3.511434,4.082169,7.24406,2.89745,3.705665,2.995526,2.77472,2.298594,4.189055,2.42535,4.067667,2.901653,0.0,2.613999
default_b,17.0,8.061296,,1.250059,6.016866,4.471443,5.071686,3.745767,2.868925,1.562343,3.500287,4.794703,6.115773,2.575064,2.995087,3.757147,2.444354,0.0
default_c,15.0,2.25027,2.135876,,9.608708,2.702073,3.163061,,,2.569471,5.907587,4.95063,,2.922754,3.402599,0.0,1.344298,
dorsal_attention_a,5.0,2.057348,2.026761,4.252942,,0.0,2.939969,3.884559,3.383444,3.149416,1.243513,2.704535,2.66843,6.885227,5.721536,2.105514,4.843378,4.339534
dorsal_attention_b,6.0,6.736912,2.231082,2.517406,5.798584,3.054835,0.0,2.067205,2.29537,3.354049,3.592916,2.021371,2.039454,4.862714,3.504517,3.635397,4.418421,15.249109
limbic_a,9.0,1.715022,2.311432,,3.899602,1.151254,,2.10313,3.934714,0.0,1.807778,3.485945,2.379518,2.310513,1.321678,2.92573,2.034708,1.462341
limbic_b,10.0,,5.110385,3.148295,4.566976,5.17124,3.053352,5.538825,2.879245,2.943986,0.0,,3.079896,2.698298,2.237132,3.062526,3.094514,3.61938
