In [None]:
import os
import json
# data operation
import numpy as np
from scipy import stats
# plot
# import seaborn as sns
import matplotlib.pylab as plt
%config InlineBackend.figure_formats = ["svg"]
# custom functions
from utils import *

## Define paths and constants

In [None]:
# random seed
seed = 9873
# Apply the seed to NumPy's global RNG so that any later step drawing from it
# gives the same result after Restart Kernel -> Run All. Previously `seed` was
# declared but never used anywhere in the notebook.
np.random.seed(seed)
# permutation number
# permutation_n = 1000

# internal and external dataset path
internal_path = "./data/internal"
external_path = "./data/external"
# subjects to exclude
internal_exclude = []
external_exclude = []

# neurosynth network path (network name -> association-test z-map, FDR 0.01)
network2path = {
    "control": "./data/ROIs/cognitive control_association-test_z_FDR_0.01.nii",
    "moral": "./data/ROIs/moral_association-test_z_FDR_0.01.nii",
    "reward": "./data/ROIs/reward_association-test_z_FDR_0.01.nii",
    "self": "./data/ROIs/self referential_association-test_z_FDR_0.01.nii"
}
# JSON cache of the network -> ROI-index mapping
network2rois_path = "./data/generate/network2rois.json"

# dataset X, y (cached .npy arrays for the internal / external datasets)
in_data_path = "./data/generate/in_data.npy"
ex_data_path = "./data/generate/ex_data.npy"

# community names/affiliation
comm_names_path = "./data/ROIs/power264CommunityNamesAbb.txt"
comm_affi_path = "./data/ROIs/power264CommunityAffiliation.1D"

## Get input data

### Load pre-defined network based on 264ROIs

In [None]:
# Load the community names and the per-ROI community affiliation codes.
comm_names = np.loadtxt(comm_names_path, dtype=str)
comm_affi = np.loadtxt(comm_affi_path, dtype=int)
# comm_names[-1] = "UN"
# Map every community name to the indices of its ROIs: the k-th name (counting
# from 1) is matched against affiliation code k.
comm2rois = {
    name: np.argwhere(comm_affi == code).ravel().tolist()
    for code, name in enumerate(comm_names, start=1)
}

### Calculate overlap between neurosynth ROIs and 264ROIs

In [None]:
# Network -> ROI-index mapping: computed with extract_roi when no JSON cache
# exists at network2rois_path, otherwise read back from that cache.
if not os.path.exists(network2rois_path):
    network2rois = extract_roi(network2path, network2rois_path)
else:
    with open(network2rois_path, "r") as f:
        network2rois = json.load(f)
# Sorted union of the ROI indices that belong to at least one network.
indexs = list(network2rois.values())
select_rois = sorted(set(indexs[0]).union(*indexs[1:]))

### Extract functional connectivity matrix and behavioral data

In [None]:
# Internal dataset, (n_in_subjects, n_features + 1): load the saved array when
# it exists, otherwise build it with get_input.
if os.path.exists(in_data_path):
    in_data = np.load(in_data_path)
else:
    in_data = get_input(internal_path, internal_exclude, in_data_path, select_rois)

# External dataset, (n_ex_subjects, n_features + 1): same load-or-build logic.
if os.path.exists(ex_data_path):
    ex_data = np.load(ex_data_path)
else:
    ex_data = get_input(external_path, external_exclude, ex_data_path, select_rois)

### Prepare input data

In [None]:
# The last column is the prediction target; every column before it is a feature.
in_X = in_data[:, :-1]
in_y = in_data[:, -1]
ex_X = ex_data[:, :-1]
ex_y = ex_data[:, -1]

## Select features based on FC-behavior correlation

In [None]:
# Cross-validation setting passed to prediction() below.
# NOTE(review): presumably "loocv" selects leave-one-out and an integer selects
# k-fold — confirm against utils.prediction.
cv = "loocv"
# cv = 10

In [None]:
# Cross-validated prediction on the internal data without a preselected feature
# set (feature_index=None); the third return value is not needed at this stage.
y_actual, y_predict, _, feature_per, feature_freq = prediction(
    in_X, in_y, in_X, in_y, cv,
    feature_index=None,
    validation="feature",
)
# Spearman rank correlation between actual and predicted values
coef, pvalue = stats.spearmanr(y_actual, y_predict)
# print(f"internal CV result: coef is {coef}, p value is {pvalue:.5f}")

In [None]:
# select most frequent feature according to the feature numbers in every cv
# NOTE(review): feature_freq is assumed to hold, per feature, the number of CV
# folds in which that feature was selected — confirm against utils.prediction.
# With leave-one-out there is one fold per subject, but an integer cv (see the
# commented `cv = 10`) runs that many folds, so the count is normalised by the
# fold count instead of always by the number of subjects; otherwise a k-fold
# run could never exceed the 0.9 threshold.
n_folds = cv if isinstance(cv, int) else in_X.shape[0]
# feature_index = sorted(feature_freq.argsort()[::-1][:int(np.percentile(feature_per, 80))])
# keep features selected in more than 90% of the folds
feature_index = sorted(np.argwhere(feature_freq / n_folds > 0.9).reshape(-1, ).tolist())
# feature_index = sorted(np.array(np.where(feature_freq>0)).reshape(-1, ).tolist())
# feature_index = sorted(feature_freq.argsort()[::-1][:int(len(feature_freq)*0.1)])
len(feature_index)

## Permutation test

### Fix features on internal data

In [None]:
# Cross-validated prediction on the internal data, restricted to the selected
# features.
in_y_actual, in_y_predict, in_model_coefs, _, _ = prediction(
    in_X, in_y, in_X, in_y, cv, feature_index, "feature"
)
in_coef, in_pvalue = stats.spearmanr(in_y_actual, in_y_predict)
# permutation; the result is plotted against in_coef further down
in_coefs = permutation(
    in_X, in_y, in_X, in_y, in_coef, feature_index, "feature"
)

In [None]:
# NOTE(review): this repeats, with identical arguments, the prediction() call of
# the previous cell and overwrites the same variables; it looks redundant unless
# prediction() is non-deterministic — confirm before removing.
in_y_actual, in_y_predict, in_model_coefs, _, _ = prediction(in_X, in_y, in_X, in_y, cv, feature_index, "feature")

### Fix features on external data

In [None]:
# Cross-validated prediction on the external data, restricted to the features
# selected from the internal data.
ex_y_actual, ex_y_predict, ex_model_coefs, _, _ = prediction(
    ex_X, ex_y, ex_X, ex_y, cv, feature_index, "feature"
)
ex_coef, ex_pvalue = stats.spearmanr(ex_y_actual, ex_y_predict)
# permutation; the result is plotted against ex_coef further down
ex_coefs = permutation(
    ex_X, ex_y, ex_X, ex_y, ex_coef, feature_index, "feature"
)

## Functional connectivity plot based on the model

In [None]:
# fc matrix, 1 means item in selected features, 0 means out of selected features
# (recovery_fc is given the selected ROIs, the selected feature indices and the
# constant value 1 to place at those features)
fc_sum = recovery_fc(select_rois, feature_index, 1)

In [None]:
# Internal model: coefficients averaged over axis 0 of the collected
# in_model_coefs, then placed back into matrix form by recovery_fc.
in_fc_coef = recovery_fc(
    select_rois,
    feature_index,
    np.mean(np.array(in_model_coefs), axis=0),
)
# External model: same treatment for ex_model_coefs.
ex_fc_coef = recovery_fc(
    select_rois,
    feature_index,
    np.mean(np.array(ex_model_coefs), axis=0),
)

## Plot

### Plot internal/external prediction/permutation

#### Model based on internal data

In [None]:
# Internal model: predicted vs. actual values, saved as SVG, followed by the
# Spearman coefficient and p value computed earlier.
plot_corr(
    in_y_actual,
    in_y_predict,
    "./plot/internal_corr.svg",
    "orange",
)
print(f"internal CV result: coef is {in_coef}, p value is {in_pvalue:.5f}")

In [None]:
# Internal model: observed coefficient shown against the permutation() results.
plot_permutation(
    in_coef,
    in_coefs,
    "./plot/internal_permutation.svg",
    "orange",
)

#### Model based on external data

In [None]:
# External model: predicted vs. actual values, saved as SVG, followed by the
# Spearman coefficient and p value computed earlier.
plot_corr(
    ex_y_actual,
    ex_y_predict,
    "./plot/external_corr.svg",
    "#F17D80",
)
print(f"external CV result: coef is {ex_coef}, p value is {ex_pvalue:.5f}")

In [None]:
# External model: observed coefficient shown against the permutation() results.
plot_permutation(
    ex_coef,
    ex_coefs,
    "./plot/external_permutation.svg",
    "#F17D80",
)

### Plot summarized edges

In [None]:
# plot 14 communities according to 264ROI
# NOTE(review): plot_conn is defined in a later cell of this notebook (under
# "Internal model coefficients"). On a fresh kernel this call only works if
# `from utils import *` also provides a compatible plot_conn — confirm, or run
# the defining cell first.
plot_conn(comm2rois, fc_sum, save_path="./plot/comm_fc_sum.svg", fontsize=16, colormap=plt.cm.Greens, annot=True)

In [None]:
# plot 4 networks according to 264ROI
# NOTE(review): this call runs before the cell that defines plot_conn in this
# notebook; it relies on plot_conn already being in scope — confirm.
plot_conn(network2rois, fc_sum, save_path="./plot/network_fc_sum.svg", fontsize=20, colormap=plt.cm.Greens, annot=True)

### Plot model coefficients

#### Internal model coefficients

In [None]:
def plot_conn(name2index, full_matrix, save_path=None, fontsize=12, filter=None, colormap=plt.cm.Blues, annot=False):
    """Plot a lower-triangular heatmap of matrix values summed between ROI groups.

    Parameters
    ----------
    name2index : dict
        Maps a group name (e.g. a community or network) to the list of ROI
        indices in that group. The indices address both axes of ``full_matrix``.
    full_matrix : numpy.ndarray
        2-D matrix indexed by ROI on both axes (edge indicators or model
        coefficients). It is not modified.
    save_path : str, optional
        When given, the figure is saved there with ``bbox_inches="tight"``.
    fontsize : int
        Font size of the axis/tick labels and of the cell annotations.
    filter : {None, "abs", "pos", "neg"}
        ``"abs"`` sums absolute values, ``"pos"`` zeroes negative entries,
        ``"neg"`` zeroes positive entries. Any other value leaves the matrix
        as it is.
    colormap : matplotlib colormap
        Colormap handed to ``seaborn.heatmap``.
    annot : bool
        Whether to write the (integer-rounded, ``fmt=".0f"``) value in each cell.

    Returns
    -------
    None
        The figure is drawn on a new 10x8 matplotlib figure; the sum of the
        filtered matrix is printed.
    """
    # Imported locally: the notebook-level `import seaborn as sns` is commented
    # out, so `sns` is not guaranteed to be in scope on a fresh kernel.
    import seaborn as sns

    # Keep only groups that own at least one non-zero entry in their rows.
    # Testing for any non-zero entry (instead of a non-zero signed sum) keeps a
    # group whose positive and negative coefficients happen to cancel. The
    # unfiltered matrix is used on purpose so that the "abs", "pos" and "neg"
    # plots of one matrix all share the same set of groups.
    groups = {
        name: index
        for name, index in name2index.items()
        if np.any(full_matrix[index] != 0)
    }

    # Apply the requested value filter on a fresh array (input left untouched).
    if filter == "abs":
        values = np.abs(full_matrix)
    elif filter == "pos":
        values = np.where(full_matrix < 0, 0, full_matrix)
    elif filter == "neg":
        values = np.where(full_matrix > 0, 0, full_matrix)
    else:
        values = full_matrix
    print(f"Coefficients sum: {np.sum(values)}")

    # Sum the filtered values over every (group, group) block.
    labels = list(groups)
    fc = np.zeros((len(labels), len(labels)))
    for i, row_name in enumerate(labels):
        for j, col_name in enumerate(labels):
            fc[i, j] = np.sum(values[np.ix_(groups[row_name], groups[col_name])])

    # Hide the cells strictly above the diagonal; only the lower triangle is drawn.
    mask = np.zeros_like(fc, dtype=bool)
    mask[np.triu_indices_from(mask, k=1)] = True

    plt.figure(figsize=(10, 8))
    with plt.style.context({"axes.labelsize": fontsize, "xtick.labelsize": fontsize, "ytick.labelsize": fontsize}):
        # Tick labels are passed as plain lists rather than dict views.
        g = sns.heatmap(
            fc,
            mask=mask,
            annot=annot,
            annot_kws=dict(size=fontsize),
            fmt=".0f",
            cmap=colormap,
            linewidths=0.5,
            square=True,
            xticklabels=labels,
            yticklabels=labels,
        )
        g.set_xticklabels(g.get_xticklabels(), rotation=45, horizontalalignment="right")
        g.set_yticklabels(g.get_yticklabels(), rotation=0, horizontalalignment="right")
    if save_path:
        plt.savefig(save_path, bbox_inches="tight")

##### abs

In [None]:
# Internal model, absolute coefficients, grouped by the 14 communities.
plot_conn(
    comm2rois,
    in_fc_coef,
    save_path="./plot/in_comm_abs_fc_coeff.svg",
    fontsize=16,
    filter="abs",
)

In [None]:
# Internal model, absolute coefficients, grouped by the 4 networks.
plot_conn(
    network2rois,
    in_fc_coef,
    save_path="./plot/in_network_abs_fc_coeff.svg",
    fontsize=20,
    filter="abs",
)

##### pos

In [None]:
# Internal model, positive coefficients only, grouped by the 14 communities.
plot_conn(
    comm2rois,
    in_fc_coef,
    save_path="./plot/in_comm_pos_fc_coeff.svg",
    fontsize=16,
    filter="pos",
)

In [None]:
# Internal model, positive coefficients only, grouped by the 4 networks.
plot_conn(
    network2rois,
    in_fc_coef,
    save_path="./plot/in_network_pos_fc_coeff.svg",
    fontsize=20,
    filter="pos",
)

##### neg

In [None]:
# Internal model, negative coefficients only, grouped by the 14 communities;
# the reversed Blues colormap is passed for this plot.
plot_conn(
    comm2rois,
    in_fc_coef,
    save_path="./plot/in_comm_neg_fc_coeff.svg",
    fontsize=16,
    filter="neg",
    colormap=plt.cm.Blues_r,
)

In [None]:
# Internal model, negative coefficients only, grouped by the 4 networks;
# the reversed Blues colormap is passed for this plot.
plot_conn(
    network2rois,
    in_fc_coef,
    save_path="./plot/in_network_neg_fc_coeff.svg",
    fontsize=20,
    filter="neg",
    colormap=plt.cm.Blues_r,
)

#### External model coefficients

##### abs

In [None]:
# External model, absolute coefficients, grouped by the 14 communities.
plot_conn(
    comm2rois,
    ex_fc_coef,
    save_path="./plot/ex_comm_abs_fc_coeff.svg",
    fontsize=16,
    filter="abs",
)

In [None]:
# External model, absolute coefficients, grouped by the 4 networks.
plot_conn(
    network2rois,
    ex_fc_coef,
    save_path="./plot/ex_network_abs_fc_coeff.svg",
    fontsize=20,
    filter="abs",
)

##### pos

In [None]:
# External model, positive coefficients only, grouped by the 14 communities.
plot_conn(
    comm2rois,
    ex_fc_coef,
    save_path="./plot/ex_comm_pos_fc_coeff.svg",
    fontsize=16,
    filter="pos",
)

In [None]:
# External model, positive coefficients only, grouped by the 4 networks.
plot_conn(
    network2rois,
    ex_fc_coef,
    save_path="./plot/ex_network_pos_fc_coeff.svg",
    fontsize=20,
    filter="pos",
)

##### neg

In [None]:
# External model, negative coefficients only, grouped by the 14 communities;
# the reversed Blues colormap is passed for this plot.
plot_conn(
    comm2rois,
    ex_fc_coef,
    save_path="./plot/ex_comm_neg_fc_coeff.svg",
    fontsize=16,
    filter="neg",
    colormap=plt.cm.Blues_r,
)

In [None]:
# External model, negative coefficients only, grouped by the 4 networks;
# the reversed Blues colormap is passed for this plot.
plot_conn(
    network2rois,
    ex_fc_coef,
    save_path="./plot/ex_network_neg_fc_coeff.svg",
    fontsize=20,
    filter="neg",
    colormap=plt.cm.Blues_r,
)

In [None]:
# Display the reversed Blues colormap that the "neg" plots above pass to
# plot_conn; this cell has no other effect.
plt.cm.Blues_r

### Plot markers

In [None]:
from nilearn import plotting, datasets
# The original cell used `input_data` without importing it (NameError on a fresh
# kernel). Recent nilearn releases expose the maskers under `nilearn.maskers`;
# older ones only have `nilearn.input_data`, so try the new location first.
try:
    from nilearn import maskers as input_data
except ImportError:
    from nilearn import input_data

# Power 2011 ROI coordinates, stacked into one (x, y, z) row per ROI.
power = datasets.fetch_coords_power_2011()
coords = np.vstack((power.rois["x"], power.rois["y"], power.rois["z"])).T
# Spheres masker seeded at every ROI coordinate with radius=5.
spheres_masker = input_data.NiftiSpheresMasker(seeds=coords, radius=5)

In [None]:
# Hex colour assigned to the nodes of each network.
color_mapping = {
    "control": "#129490",
    "moral": "#70B77E",
    "reward": "#E0A890",
    "self": "#CE1483"
}
# Coordinates of the selected ROIs, in select_rois order.
sub_coord = np.array([coords[roi] for roi in select_rois])
# Reverse lookup ROI -> network; an ROI listed under several networks keeps the
# network that is iterated last.
roi2network = {
    roi: name
    for name, members in network2rois.items()
    for roi in members
}
# One colour per selected ROI, following its network.
node_color = [color_mapping[roi2network[roi]] for roi in select_rois]

In [None]:
# Interactive marker view of the selected ROIs, coloured by network and opened
# in the browser.
marker_view = plotting.view_markers(
    sub_coord,
    marker_color=node_color,
    marker_size=8,
)
marker_view.open_in_browser()

### Plot internal/external model coefficients of features (FC)

In [None]:
# Interactive connectome of the internal model coefficients among the selected
# ROIs, with edge_threshold="99%", opened in the browser.
in_view = plotting.view_connectome(
    in_fc_coef[np.ix_(select_rois, select_rois)],
    coords[select_rois],
    node_color=node_color,
    node_size=8,
    edge_threshold="99%",
    colorbar_fontsize=20,
)
in_view.open_in_browser()

In [None]:
# Interactive connectome of the external model coefficients among the selected
# ROIs, with edge_threshold="99%", opened in the browser.
ex_view = plotting.view_connectome(
    ex_fc_coef[np.ix_(select_rois, select_rois)],
    coords[select_rois],
    node_color=node_color,
    node_size=8,
    edge_threshold="99%",
    colorbar_fontsize=20,
)
ex_view.open_in_browser()