In [None]:
import json
import os

import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from neuromaps.datasets import fetch_atlas

from heteromodes.plotting import plot_brain, plot_heatmap
from heteromodes.utils import load_hmap, unmask

# Apply seaborn's "white" style once, up front, so it is in effect for every figure drawn below.
sns.set_theme(style="white")

In [None]:
# Resolve the project root. Every path in this notebook is built from PROJ_DIR, but no
# cell assigns it, so a fresh kernel (Restart & Run All) would stop with a NameError.
# A PROJ_DIR that is already defined is left untouched; otherwise fall back to the
# environment and fail with an actionable message if it is missing there too.
if "PROJ_DIR" not in globals():
    PROJ_DIR = os.environ.get("PROJ_DIR")
    if not PROJ_DIR:
        raise RuntimeError(
            "PROJ_DIR is not defined. Assign PROJ_DIR in the notebook or export the "
            "PROJ_DIR environment variable before running this cell."
        )

# Read which result set to visualise and how its heterogeneity maps are labelled.
with open(f"{PROJ_DIR}/scripts/model_rest/results_config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

# Example of the expected values:
#   id = 39
#   hmap_labels = ["None", "myelinmap", "thickness"]
#   hmap_labels_plotting = ["None", "T1w/T2w", "Cortical thickness"]
# NOTE: `id` shadows the Python builtin, but the cells below read it under this name.
id = config["id"]
hmap_labels = config["hmap_labels"]                    # used to build result file names
hmap_labels_plotting = config["hmap_labels_plotting"]  # display names used on the figures

print(f"Loading id:{id} results")

In [None]:
# Collect, for every heterogeneity map, the test-set scores and the best-fit model outputs
# stored in its cross-validated results file. Each list gets one entry per map, in the
# order of `hmap_labels`.
edge_fc_xval, node_fc_xval, phase_xval, fc_matrices_xval, phase_maps_xval = [], [], [], [], []
alpha_best_xval = {}
for hmap_label in hmap_labels:
    file = f"{PROJ_DIR}/results/model_rest/group/id-{id}/{hmap_label}_results_crossval-True.h5"

    # Nothing is written in this cell, so open read-only: 'r+' needs write permission
    # on the results file and leaves it open for modification for no benefit.
    with h5py.File(file, 'r') as f:
        # `metrics` of the last file read stays in scope and drives the plotting cells below.
        metrics = f.attrs['metrics']

        edge_fc_xval.append(f['edge_fc_test'][:].flatten())
        node_fc_xval.append(f['node_fc_test'][:].flatten())
        # Phase scores are compared by magnitude, so the sign is dropped.
        phase_xval.append(np.abs(f['phase_test'][:].flatten()))

        # Average the best parameter combinations over the first axis; the first
        # parameter of a combination is the one reported as alpha.
        best_comb_xval = np.mean(f['best_combs'][:], axis=0)
        alpha_best_xval[hmap_label] = best_comb_xval[0]

        # Averaged over the last axis of 'best_fcs' (presumably the cross-validation
        # splits — TODO confirm against the script that writes these files).
        fc_matrices_xval.append(np.mean(f['best_fcs'][:], axis=2))
        if "phase" in metrics:
            # Same idea for the phase maps, averaged over axis 1.
            phase_maps_xval.append(np.mean(f['best_phase_maps'][:], axis=1))

print(f"alpha_best: {alpha_best_xval}")
print(np.shape(fc_matrices_xval), np.shape(phase_maps_xval))

## Evaluation metrics for each model

In [None]:
# Set plotting defaults (fs_ax / fs_title are also read by later plotting cells)
fs_ax = 15
fs_title = 20
plt.rcParams['xtick.major.size'] = 5
plt.rcParams['xtick.major.width'] = 1.5
plt.rcParams['xtick.bottom'] = True
plt.rcParams['ytick.left'] = True

# Ring-shaped marker for the strip plots: an outer circle joined to a reversed inner
# circle scaled to 70 % of its size.
pnts = np.linspace(0, np.pi * 2, 24)
circ = np.c_[np.sin(pnts) / 2, -np.cos(pnts) / 2]
vert = np.r_[circ, circ[::-1] * .7]
open_circle = mpl.path.Path(vert)


def _plot_fit_distribution(ax, data, title, ylabel, ylim):
    """Draw one strip + violin panel with one category per heterogeneity map.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    data : scores passed straight to seaborn (one group per heterogeneity map).
    title : panel title.
    ylabel : y-axis label.
    ylim : (lower, upper) limits of the y-axis.

    Reads `open_circle`, `hmap_labels_plotting`, `fs_ax` and `fs_title` from the notebook
    namespace.
    """
    sns.stripplot(data=data, ax=ax, marker=open_circle, size=4, alpha=0.5, zorder=1, linewidth=1)
    sns.violinplot(data=data, ax=ax, density_norm="count", fill=False, linewidth=3, inner="box",
                   inner_kws={"box_width": 5, "whis_width": 1, "color": "black"})
    ax.set_xticks(ticks=range(len(hmap_labels_plotting)), labels=hmap_labels_plotting,
                  ha='right', fontsize=fs_ax)
    ax.tick_params(axis='x', labelrotation=45)
    ax.tick_params(axis='y', labelsize=fs_ax)
    ax.set_title(title, fontsize=fs_title)
    ax.set_xlabel("Heterogeneity map", fontsize=fs_ax)
    ax.set_ylabel(ylabel, fontsize=fs_ax)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_ylim(*ylim)


# One panel per metric plus one for the combined score. squeeze=False + ravel() always
# yields a flat array of Axes; wrapping the result in a list when there is a single
# metric would instead nest the two-element Axes array and break axs[i].
fig, axs = plt.subplots(1, len(metrics)+1, figsize=(len(metrics)*9, 5), squeeze=False)
axs = axs.ravel()

# Individual metrics, drawn left to right in a fixed order; `i` tracks the next free panel.
i = 0
metric_panels = (
    ("edge_fc", edge_fc_xval, "Edge-level FC fit"),
    ("node_fc", node_fc_xval, "Node-level FC fit"),
    ("phase", phase_xval, "DPI fit"),
)
for metric_key, scores, panel_title in metric_panels:
    if metric_key in metrics:
        _plot_fit_distribution(axs[i], scores, panel_title, "Pearson's r", (0, 1))
        i += 1

# Vertical rule in the combined panel, half a step to the right of the last category
# (intended to separate the last metric from the combined metric)
axs[i].axvline(x=len(hmap_labels_plotting)-0.5, color='black', linewidth=1)

# Combined metric: element-wise sum of the three score arrays. Transposed so that each
# heterogeneity map becomes one column, i.e. one category on the x-axis.
combined = np.array(edge_fc_xval) + np.array(node_fc_xval) + np.array(phase_xval)
_plot_fit_distribution(axs[i], combined.T, "Overall fit", "Combined metric", (0, 2.25))

plt.show()

## Plot model fit scores

In [None]:
# Best full-data (non-cross-validated) fit of every heterogeneity map.
# Each score list keeps exactly one entry per map, in the order of `hmap_labels`.
edge_fc_fit, node_fc_fit, phase_fit, fc_matrix_fit, phase_map_fit = [], [], [], [], []
n_loaded = 0
for hmap_label in hmap_labels:
    file = f"{PROJ_DIR}/results/model_rest/group/id-{id}/{hmap_label}_results_crossval-False.h5"
    try:
        # Read-only: nothing is written to the results file here.
        with h5py.File(file, 'r') as f:
            metrics = f.attrs['metrics']
            edge_corr = f['edge_fc_corr'][:]
            node_corr = f['node_fc_corr'][:]
            phase_corr = f['phase_corr'][:]
    except (OSError, KeyError) as err:
        # Missing/unreadable file or missing dataset: report it and keep a NaN placeholder
        # so the remaining points still line up with their x tick labels.
        print(f"Could not load {file}: {err!r}")
        edge_fc_fit.append(np.nan)
        node_fc_fit.append(np.nan)
        phase_fit.append(np.nan)
        continue

    # The best parameter combination is the one that maximises the summed score.
    best_ind = np.argmax(edge_corr + node_corr + phase_corr)
    edge_fc_fit.append(edge_corr[best_ind])
    node_fc_fit.append(node_corr[best_ind])
    phase_fit.append(phase_corr[best_ind])
    n_loaded += 1

if n_loaded > 0:
    # Plot evaluation metrics: one panel per metric, always as a flat array of Axes
    fig, axs = plt.subplots(1, len(metrics), figsize=(len(metrics)*9, 5), squeeze=False)
    axs = axs.ravel()

    fit_panels = (
        ("edge_fc", edge_fc_fit, "Edge-level FC fit"),
        ("node_fc", node_fc_fit, "Node-level FC fit"),
        ("phase", phase_fit, "Phase fit"),
    )
    i = 0
    for metric_key, fit_values, panel_title in fit_panels:
        if metric_key not in metrics:
            continue
        ax = axs[i]
        ax.plot(fit_values, '.')
        ax.set_xticks(ticks=range(len(hmap_labels_plotting)), labels=hmap_labels_plotting, ha='right', fontsize=15)
        ax.tick_params(axis='x', labelrotation=45)
        ax.tick_params(axis='y', labelsize=fs_ax)
        ax.set_title(panel_title, fontsize=fs_title)
        ax.set_xlabel("Heterogeneity map", fontsize=fs_ax)
        ax.set_ylabel("Pearson's r", fontsize=fs_ax)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_ylim(0, 1)
        i += 1

    plt.show()

## Alpha landscape

In [None]:
# Load homogeneous results: mean test-set score of the "None" (homogeneous) model, used
# as the dashed reference lines in the alpha landscape plots.
file = f"{PROJ_DIR}/results/model_rest/group/id-{id}/None_results_crossval-True.h5"
with h5py.File(file, 'r') as f:
    edge_fc_hom = np.mean(f['edge_fc_test'][:])
    node_fc_hom = np.mean(f['node_fc_test'][:])
    # Phase scores are taken by magnitude everywhere else in this notebook (test scores
    # and landscape), so the baseline must use the absolute value too to be comparable.
    phase_hom = np.mean(np.abs(f['phase_test'][:]))

combined_hom = edge_fc_hom + node_fc_hom + phase_hom

In [None]:
# Training-set score landscapes of the heterogeneous models. The first label is skipped
# (the homogeneous baseline is loaded separately); each list gets one entry per remaining map.
edge_fc_land = []
node_fc_land = []
phase_land = []
combined_land = []
best_alpha = []
alpha_vals = []
for hmap_label in hmap_labels[1:]:
    file = f"{PROJ_DIR}/results/model_rest/group/id-{id}/{hmap_label}_results_crossval-True.h5"

    with h5py.File(file, 'r') as f:
        edge_train = f['edge_fc_train'][:]
        node_train = f['node_fc_train'][:]
        phase_train = f['phase_train'][:]
        combined_train = f['combined_train'][:]
        alpha_column = f['combs'][:, 0]    # first parameter of every tested combination
        best_combs = f['best_combs'][:]

    # Scores are averaged over the first axis; phase is averaged by magnitude.
    edge_fc_land.append(np.mean(edge_train, axis=0))
    node_fc_land.append(np.mean(node_train, axis=0))
    phase_land.append(np.mean(np.abs(phase_train), axis=0))
    combined_land.append(np.mean(combined_train, axis=0))

    alpha_vals.append(alpha_column)
    # Mean best combination; its first entry is the alpha marked in the plots.
    best_alpha.append(np.mean(best_combs, axis=0)[0])

In [None]:
# One panel per heterogeneous map (every label except the first, homogeneous one).
n_het = len(hmap_labels[1:])
if len(hmap_labels) > 2:
    # Two rows, with as many columns as needed to fit all maps
    fig, axs = plt.subplots(2, int(np.ceil(n_het/2)), figsize=(n_het*5, 15))
    axs = axs.flatten()
else:
    fig, axs = plt.subplots(1, 1, figsize=(9, 6))
    axs = [axs]

# (metric key, homogeneous baseline, per-map landscapes, legend name, colour)
landscape_curves = (
    ("edge_fc", edge_fc_hom, edge_fc_land, "Edge FC", "tab:blue"),
    ("node_fc", node_fc_hom, node_fc_land, "Node FC", "tab:orange"),
    ("phase", phase_hom, phase_land, "Phase", "tab:green"),
)

for i, hmap_label in enumerate(hmap_labels[1:]):
    ax = axs[i]

    # Dashed line: homogeneous baseline. Solid line: heterogeneous score across alpha.
    for metric_key, hom_value, land, name, colour in landscape_curves:
        if metric_key in metrics:
            ax.axhline(hom_value, label=f"{name} (Hom.)", linestyle="--", color=colour)
            ax.plot(alpha_vals[i], np.array(land[i]), label=f"{name} (Het.)", color=colour)

    # Plot combined (only meaningful when more than one metric was fitted)
    if len(metrics) > 1:
        ax.axhline(combined_hom, label="Combined (Hom.)", linestyle="--", color="tab:red")
        ax.plot(alpha_vals[i], np.array(combined_land[i]), label="Combined (Het.)", color="tab:red")

    # Plot best alpha
    ax.axvline(best_alpha[i], linestyle="-.", color="black", label="Best alpha")

    # Set labels
    ax.set_xticks(alpha_vals[i])
    ax.set_xticklabels([f"{x:.1f}" for x in alpha_vals[i]], fontsize=9)
    ax.set_title(hmap_labels[i+1], fontsize=20)
    ax.set_xlabel("Alpha")
    ax.set_ylabel("Pearson's r")
    ax.set_ylim(-1, 3)
    ax.legend()

# Hide only the axes that received no data. Testing "odd number of plots" instead would
# also blank the single populated axis when there is exactly one heterogeneous map.
for unused_ax in axs[n_het:]:
    unused_ax.set_axis_off()

fig.suptitle(f"Alpha landscapes for id: {id}", fontsize=30)
plt.show()

## Labels

In [None]:
# Display name of each map followed by its mean best alpha from cross-validation.
hmap_labels_plotting_alpha = [
    f"{label}\n($\\alpha$ = {alpha_best_xval[hmap_label]:.2f})"
    for label, hmap_label in zip(hmap_labels_plotting, hmap_labels)
]
n_labels = len(hmap_labels_plotting_alpha)

# squeeze=False keeps `axs` an array even for a single label, where plt.subplots would
# otherwise return a bare Axes that has no .flatten().
fig, axs = plt.subplots(1, len(hmap_labels), figsize=(5*len(hmap_labels), 2), squeeze=False)
axs = axs.flatten()

# Requesting n_labels colours makes the current palette cycle instead of raising an
# IndexError once there are more labels than palette entries; the leading colours are
# the same ones seaborn assigns to the categories in the plots above.
box_colours = sns.color_palette(n_colors=n_labels)
for ax, label, box_colour in zip(axs, hmap_labels_plotting_alpha, box_colours):
    ax.text(0.5, 0.5, label,
            fontsize=25,
            bbox=dict(facecolor='white', edgecolor=box_colour, boxstyle='round,pad=0.3', linewidth=6),
            ha='center', va='center')
    # Remove axes so that only the boxed label remains
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(False)

plt.show()

## FC matrices

In [None]:
# One heatmap per heterogeneity map, showing its averaged best-fit FC matrix.
# squeeze=False + ravel() keeps `axs` indexable even when there is a single map
# (plt.subplots would otherwise return a bare Axes and axs[i] would fail).
fig, axs = plt.subplots(1, len(hmap_labels), figsize=(5*len(hmap_labels), 5), squeeze=False)
axs = axs.ravel()
for i in range(len(hmap_labels)):
    # Diverging colormap centred on zero
    plot_heatmap(fc_matrices_xval[i], ax=axs[i], cmap="seismic", cbar=True, center=0)

plt.show()

## Phase maps

In [None]:
# Surface used to render the phase maps: fsLR atlas at 4k density.
den = "4k"
fslr = fetch_atlas("fslr", den)
# First entry of the inflated surfaces (presumably the left hemisphere — TODO confirm).
surf = fslr["inflated"][0]

# Mask stored with the homogeneous results, needed to project masked data back onto the
# surface. Read inside a context manager so the HDF5 file handle is closed again instead
# of staying open for the rest of the kernel session.
with h5py.File(f"{PROJ_DIR}/results/model_rest/group/id-{id}/None_results_crossval-True.h5", 'r') as f:
    medmask = f['medmask'][:]

In [None]:
# Phase maps exist only when the phase metric was part of the fit.
if "phase" in metrics:
    # Stack the per-map phase maps, transpose so that each map is one column, then
    # expand them with the mask loaded above before plotting.
    datap = unmask(np.transpose(np.array(phase_maps_xval)), medmask)
    fig = plot_brain(
        surf,
        datap,
        labels=hmap_labels_plotting,
        cbar=False,
        cbar_kws={"fontsize": 25},
        cmap="turbo",
    )
    plt.show()