In [None]:
import os
import json
# data operation
import numpy as np
from scipy import stats
# plot
# import seaborn as sns
import matplotlib.pylab as plt
%config InlineBackend.figure_formats = ["svg"]
# custom functions
from utils import *

## Define paths and constants

In [None]:
# random seed
seed = 9873
# Apply the seed to NumPy's global RNG so that stochastic steps drawing from it
# give the same result after "Restart Kernel -> Run All". Previously `seed` was
# defined but never handed to any random number generator.
# NOTE(review): this only takes effect if the utils helpers (e.g. permutation)
# draw from np.random's global state — confirm; if they build their own
# Generator, pass `seed` to them explicitly instead.
np.random.seed(seed)
# permutation number
# permutation_n = 1000

# internal and external dataset path
internal_path = "./data/internal"
external_path = "./data/external"
# subjects to exclude
internal_exclude = []
external_exclude = []

# neurosynth network path: network label -> association-test z-map (NIfTI)
network2path = {
    "control": "./data/ROIs/cognitive control_association-test_z_FDR_0.01.nii",
    "moral": "./data/ROIs/moral_association-test_z_FDR_0.01.nii",
    "reward": "./data/ROIs/reward_association-test_z_FDR_0.01.nii",
    "self": "./data/ROIs/self referential_association-test_z_FDR_0.01.nii"
}
# JSON cache of the network -> ROI-index mapping (read back later if present)
network2rois_path = "./data/generate/network2rois.json"

# dataset X, y: cached arrays for the internal / external datasets
in_data_path = "./data/generate/in_data.npy"
ex_data_path = "./data/generate/ex_data.npy"

# community names/affiliation
comm_names_path = "./data/ROIs/power264CommunityNamesAbb.txt"
comm_affi_path = "./data/ROIs/power264CommunityAffiliation.1D"

## Get input data

### Load pre-defined network based on 264ROIs

In [None]:
# Load the community labels (strings) and the per-ROI community id (integers).
comm_names = np.loadtxt(comm_names_path, dtype=str)
comm_affi = np.loadtxt(comm_affi_path, dtype=int)
# Map each community name to the positions in comm_affi whose id equals the
# name's 1-based position in comm_names.
comm2rois = {
    name: np.argwhere(comm_affi == label).ravel().tolist()
    for label, name in enumerate(comm_names, start=1)
}

### Calculate overlap between neurosynth ROIs and 264ROIs

In [None]:
# Get the indices (in the 264-ROI template) of every overlapping ROI.
# The network -> ROI mapping is read from the JSON cache when that file exists;
# otherwise extract_roi() is called with the cache path.
# NOTE(review): presumably extract_roi also writes the cache file — confirm.
if not os.path.exists(network2rois_path):
    network2rois = extract_roi(network2path, network2rois_path)
else:
    with open(network2rois_path, "r") as f:
        network2rois = json.load(f)
indexs = list(network2rois.values())
# Union of all per-network ROI lists, de-duplicated and in ascending order.
select_rois = sorted(set(indexs[0]).union(*indexs[1:]))

### Extract functional connectivity matrix and behavioral data

In [None]:
# Load each dataset from its .npy cache when present, otherwise build it with
# get_input().
# NOTE(review): an existing cache is reused even if select_rois has changed
# since it was written — delete the .npy files after changing the ROI set.

# internal dataset: (n_in_subjects, n_features + 1)
if os.path.exists(in_data_path):
    in_data = np.load(in_data_path)
else:
    in_data = get_input(internal_path, internal_exclude, in_data_path, select_rois)

# external dataset: (n_ex_subjects, n_features + 1)
if os.path.exists(ex_data_path):
    ex_data = np.load(ex_data_path)
else:
    ex_data = get_input(external_path, external_exclude, ex_data_path, select_rois)

### Prepare input data

In [None]:
# Separate features from the target: the last column is y, every column before
# it is a feature.
in_X = in_data[:, :-1]
in_y = in_data[:, -1]
ex_X = ex_data[:, :-1]
ex_y = ex_data[:, -1]

## CPM

### Internal validation

In [None]:
# Cross-validation scheme, passed as the `cv` argument to every prediction()
# call in this notebook. The commented-out alternative would pass the integer
# 10 instead; note the feature-selection cell is marked "only for LOOCV".
cv = "loocv"
# cv = 10

In [None]:
# Internal validation: prediction() receives the internal data as both of its
# (X, y) pairs, with the chosen cv scheme and validation="kfold".
y_actual, y_predict, _, feature_per, feature_freq = prediction(
    in_X,
    in_y,
    in_X,
    in_y,
    cv,
    feature_index=None,
    validation="kfold",
)
# Score the run by the Spearman rank correlation of actual vs. predicted values.
coef, pvalue = stats.spearmanr(y_actual, y_predict)
print(f"internal validation: coef is {coef}, p value is {pvalue:.5f}")

### External validation

In [None]:
# Select features based on fc-behavior correlation.
# A feature is kept only when feature_freq equals the number of subjects
# (in_X.shape[0]). That denominator is the fold count only under LOOCV, so
# refuse to run under any other scheme rather than silently mis-selecting.
# NOTE(review): presumably feature_freq[j] counts the folds in which feature j
# was selected — confirm in utils.prediction.
if cv != "loocv":
    raise ValueError(
        f"feature selection threshold is only valid for LOOCV, got cv={cv!r}"
    )
# flatnonzero already returns ascending indices, so no extra sort is needed.
feature_index = np.flatnonzero(feature_freq / in_X.shape[0] >= 1.0).tolist()
len(feature_index)

In [None]:
# External validation: prediction() gets the internal data as its first (X, y)
# pair and the external data as its second, restricted to feature_index and
# run in "single" validation mode.
inex_y_actual, inex_y_predict, inex_model_coefs, _, _ = prediction(
    in_X, in_y, ex_X, ex_y, cv, feature_index, "single"
)
inex_coef, inex_pvalue = stats.spearmanr(inex_y_actual, inex_y_predict)
print(f"external validation: coef is {inex_coef}, p value is {inex_pvalue:.5f}")
# permutation
# NOTE(review): presumably returns the coefficients of the permuted runs, to be
# compared against inex_coef — confirm in utils.permutation.
inex_coefs = permutation(
    in_X, in_y, ex_X, ex_y, inex_coef, feature_index, "single"
)

### Rank features

In [None]:
# Leave-one-feature-out ranking: drop each selected feature in turn, repeat the
# external validation with the remaining features, and record the Spearman
# coefficient and p value of that reduced model.
rank_coefs = []
rank_pvalues = []
for i in range(len(feature_index)):
    # feature_index without its i-th entry
    feature_filter = np.delete(feature_index, i)
    # using selected features to train on internal data and test on external data
    tmp_inex_y_actual, tmp_inex_y_predict, tmp_inex_model_coefs, _, _ = prediction(in_X, in_y, ex_X, ex_y, cv, feature_filter, "single")
    tmp_inex_coef, tmp_inex_pvalue = stats.spearmanr(tmp_inex_y_actual, tmp_inex_y_predict)
    rank_coefs.append(tmp_inex_coef)
    rank_pvalues.append(tmp_inex_pvalue)
# compute feature importance: drop in coefficient when the feature is removed
rank_features = inex_coef - np.array(rank_coefs)
# standardize rank (z-score). When every feature yields the same drop the
# standard deviation is 0; dividing would turn all importances into NaN, so in
# that case keep the centred values (all zeros) instead.
rank_std = rank_features.std()
rank_features = rank_features - rank_features.mean()
if rank_std > 0:
    rank_features = rank_features / rank_std

## Plot

### FC matrix recovery

In [None]:
# FC counting matrix, 1 means connection between features, 0 means no connection
# (the constant 1 is passed as the value for every selected feature).
fc_sum = recovery_fc(select_rois, feature_index, 1)
# FC importance matrix: same call, but with the standardized leave-one-out
# importance scores (rank_features) as the per-feature values.
# NOTE(review): assumes rank_features is ordered like feature_index — it is
# built by iterating over feature_index, but confirm recovery_fc expects that.
rank_fc = recovery_fc(select_rois, feature_index, rank_features)

### Plot results of internal validation

In [None]:
# prediction vs. target for the internal validation run
# NOTE(review): the third argument looks like the output file and the fourth
# like the plot colour — confirm against utils.plot_corr.
plot_corr(y_actual, y_predict, "./plot/internal_corr.svg", "orange")
# Repeat the summary statistics next to the figure.
print(f"internal validation: coef is {coef}, p value is {pvalue:.5f}")

### Plot results of external validation

In [None]:
# prediction vs. target for the external validation run
# NOTE(review): the third argument looks like the output file and the fourth
# like the plot colour — confirm against utils.plot_corr.
plot_corr(inex_y_actual, inex_y_predict, "./plot/external_corr.svg", "#F17D80")
# Repeat the summary statistics next to the figure.
print(f"external validation: coef is {inex_coef}, p value is {inex_pvalue:.5f}")

In [None]:
# permutation: plot the observed external coefficient (inex_coef) together with
# the values returned by permutation() (inex_coefs).
# NOTE(review): presumably drawn as a null distribution with the observed value
# marked — confirm against utils.plot_permutation.
plot_permutation(inex_coef, inex_coefs, "./plot/external_permutation.svg", "#F17D80")

### Plot summed number of features

In [None]:
# Selected-feature counts, grouped by the 14 communities.
plot_conn(
    comm2rois,
    fc_sum,
    save_path="./plot/comm_fc_sum.svg",
    fontsize=16,
    colormap=plt.cm.Greens,
    annot=True,
)

In [None]:
# Selected-feature counts, grouped by the 4 networks.
plot_conn(
    network2rois,
    fc_sum,
    save_path="./plot/network_fc_sum.svg",
    fontsize=20,
    colormap=plt.cm.Greens,
    annot=True,
)

### Plot feature importance

In [None]:
# Feature importance, grouped by the 14 communities; the file name records how
# many features were selected.
plot_conn(
    comm2rois,
    rank_fc,
    save_path=f'./plot/comm_rank_model_fn-{len(feature_index)}.svg',
    fontsize=16,
)

In [None]:
# Feature importance, grouped by the 4 networks; the file name records how
# many features were selected.
plot_conn(
    network2rois,
    rank_fc,
    save_path=f'./plot/network_rank_model_fn-{len(feature_index)}.svg',
    fontsize=20,
)

### Plot markers and selected features

In [None]:
# Selected FC recovery: Pearson r between each selected feature column of the
# internal data and the internal target, in feature_index order.
feature_beh_corr = np.array(
    [stats.pearsonr(in_X[:, col], in_y)[0] for col in feature_index]
)
# Pass those correlations to recovery_fc as the per-feature values.
ex_fc_coef = recovery_fc(select_rois, feature_index, feature_beh_corr)

In [None]:
from nilearn import plotting, datasets
# The original cell referenced `input_data.NiftiSpheresMasker`, but `input_data`
# is never imported in this notebook, so it raised NameError unless the
# star-import from utils happened to export it. Import the class explicitly:
# recent nilearn releases expose it in nilearn.maskers, older ones in
# nilearn.input_data.
try:
    from nilearn.maskers import NiftiSpheresMasker
except ImportError:
    from nilearn.input_data import NiftiSpheresMasker

# Power (2011) atlas ROI coordinates, stacked into one (x, y, z) row per ROI.
power = datasets.fetch_coords_power_2011()
coords = np.vstack((power.rois["x"], power.rois["y"], power.rois["z"])).T
# Spherical masker with a radius of 5 around each coordinate.
# NOTE(review): spheres_masker is not used by any later cell in this notebook.
spheres_masker = NiftiSpheresMasker(seeds=coords, radius=5)

In [None]:
# Display colour for each of the four networks.
color_mapping = {
    "control": "#129490",
    "moral": "#70B77E",
    "reward": "#E0A890",
    "self": "#CE1483"
}
# Coordinates of the selected ROIs, in select_rois order.
sub_coord = np.array([coords[roi] for roi in select_rois])
# Reverse lookup ROI index -> network label. An ROI listed under several
# networks keeps the label of the last network iterated.
roi2network = {
    roi: label
    for label, members in network2rois.items()
    for roi in members
}
# One colour per selected ROI, aligned with sub_coord.
node_color = [color_mapping[roi2network[roi]] for roi in select_rois]

In [None]:
# Interactive view of the selected ROIs as markers, coloured by network,
# opened in the system web browser.
marker_view = plotting.view_markers(sub_coord, marker_color=node_color, marker_size=8)
marker_view.open_in_browser()

In [None]:
# Interactive connectome restricted to the selected ROIs: the adjacency matrix
# is the select_rois x select_rois sub-block of ex_fc_coef, nodes are coloured
# by network, and edge_threshold is set to "99%".
ex_view = plotting.view_connectome(
    ex_fc_coef[np.ix_(select_rois, select_rois)],
    coords[select_rois],
    node_color=node_color,
    node_size=8,
    edge_threshold="99%",
    colorbar_fontsize=20,
    symmetric_cmap=False,
    edge_cmap="Reds",
)
ex_view.open_in_browser()