In [None]:
"""Example for plotting gradient data"""
import os.path as op
from glob import glob
import pickle

import matplotlib.image as mpimg
from matplotlib.gridspec import GridSpec
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
import ptitprince as pt

from utils import plot_gradient

In [None]:
# Project directories, each resolved to an absolute path from the current
# working directory.
data_dir, results_dir, figures_dir = map(
    op.abspath, ("../data", "../results", "../figures")
)

# Visualize Silhouette Scores

Mean silhouette scores were visualized to determine relative performance across percentile-based, 
k-means-based, and KDE-based segmentations, with the highest mean silhouette score representing 
the “best” relative performance.

In [None]:
# Load the score tables for the high-dimensional and the uni-dimensional
# analyses; both CSV files sit in the "segmentation" results folder.
scores_dir = op.join(results_dir, "segmentation")
hd_scores_df = pd.read_csv(op.join(scores_dir, "scores_high-dimensional.csv"))
ld_scores_df = pd.read_csv(op.join(scores_dir, "scores_uni-dimensional.csv"))

In [None]:
# Human-readable axis label for each score column name.
metric_dict = dict(
    silhouette="Mean Silhouette Coefficient",
    variance_ratio="Variance Ratio",
    cluster_separation="Cluster Separation",
)

In [None]:
sns.set_style("ticks", {"axes.grid": True})

metrics = ["silhouette", "variance_ratio", "cluster_separation"]
colors = ["#0A4D68", "#088395", "#05BFDB"]
hue_order = ["PCT", "KMeans", "KDE"]

# One three-row figure per score table; row i shows metrics[i].
fig_1, axes_1 = plt.subplots(3, 1)
fig_1.set_size_inches(6.5, 15)
fig_2, axes_2 = plt.subplots(3, 1)
fig_2.set_size_inches(6.5, 15)

# Everything that differs between the two figures; the drawing code below is
# shared.
figure_specs = (
    {
        "axes": axes_1,
        "scores": ld_scores_df,
        "group": "method",
        "extra_kws": {"hue_order": hue_order, "palette": colors},
        "legend_title": "Segmentation",
        "title": "One-Dimensional Segmentation",
    },
    {
        "axes": axes_2,
        "scores": hd_scores_df,
        "group": "component",
        "extra_kws": {"palette": "rocket_r"},
        "legend_title": "Components",
        "title": "High-Dimensional K-means Clustering",
    },
)

for met_i, metric in enumerate(metrics):
    is_top_row = met_i == 0
    is_bottom_row = met_i == 2

    for spec in figure_specs:
        ax = spec["axes"][met_i]
        scores = spec["scores"]

        # One line per group (method or component) across segment solutions.
        sns.lineplot(
            data=scores,
            x="segment",
            y=metric,
            hue=spec["group"],
            style=spec["group"],
            markers=True,
            dashes=False,
            ax=ax,
            **spec["extra_kws"],
        )

        if is_bottom_row:
            # Only the bottom panel carries the x-axis annotation and legend.
            ax.set_xlabel('Segment Solution', fontsize=18)
            ax.tick_params(axis='x', labelsize=16)
            legend = ax.legend(title=spec["legend_title"], fontsize=14)
            legend.get_title().set_fontsize('16')
        else:
            ax.set_xticklabels([])
            ax.set_xlabel('')
            ax.legend_.remove()

        if is_top_row:
            ax.set_title(spec["title"], fontsize=20)

        ax.set_ylabel(metric_dict[metric], fontsize=18)
        ax.tick_params(axis='y', labelsize=16)

        # Circle the extreme score over the whole table: the minimum for the
        # last metric, the maximum for the other two.
        column = scores[metric]
        if is_bottom_row:
            best_index, best_value = column.idxmin(), column.min()
        else:
            best_index, best_value = column.idxmax(), column.max()
        best_segment = scores.loc[best_index, 'segment']
        ax.scatter(best_segment, best_value, facecolors='none', edgecolors='red', s=150)

for figure, file_name in ((fig_1, "Fig-04.eps"), (fig_2, "Fig-10.eps")):
    figure.tight_layout()

fig_1.savefig(op.join(figures_dir, "Fig", "Fig-04.eps"))
fig_2.savefig(op.join(figures_dir, "Fig", "Fig-10.eps"))

# Visualize Silhouette Samples

In [None]:
# Collect, per segmentation method, the silhouette-sample array and the
# "labels" entry of the pickled results dictionary. Both lists follow the
# order PCT, KMeans, KDE.
# NOTE: pickle.load can execute arbitrary code; only read trusted result files.
segmentation_dir = op.join(results_dir, "segmentation")
samples_arrays, labels_lst = [], []
for method in ("PCT", "KMeans", "KDE"):
    samples_arrays.append(np.load(op.join(segmentation_dir, f"{method}_silhouette.npy")))

    with open(op.join(segmentation_dir, f"{method}_results.pkl"), "rb") as results_dict_file:
        labels_lst.append(pickle.load(results_dict_file)["labels"])

In [None]:
# Create a long-format pandas dataframe (violin_df) for the violin plots.
# Each row holds one silhouette value together with the segment solution
# (stored as a string) and the segmentation method it belongs to.
min_n_segments = 2
n_segments = 31
segment_sizes = np.arange(min_n_segments, n_segments + min_n_segments)

# Must follow the order in which samples_arrays was filled.
method_names = ["PCT", "KMeans", "KDE"]
if len(samples_arrays) != len(method_names):
    raise ValueError(
        f"Expected {len(method_names)} silhouette arrays, got {len(samples_arrays)}"
    )

segment_sizes_lst = []
sample_scores_lst = []
segmentation_names_lst = []
for method_name, samples_arr in zip(method_names, samples_arrays):
    n_samples = samples_arr.shape[1]
    for segm_i in range(n_segments):
        segment_sizes_lst.append([str(segment_sizes[segm_i])] * n_samples)
        sample_scores_lst.append(samples_arr[segm_i, :])

    # Label the rows of this method right here, using its own sample count,
    # so the labels stay aligned even if the arrays differ in width.
    n_sample_scores = n_segments * n_samples
    segmentation_names_lst.extend([method_name] * n_sample_scores)

# Build the frame in one step instead of growing an empty DataFrame.
violin_df = pd.DataFrame(
    {
        "segment_sizes": np.hstack(segment_sizes_lst),
        "segmentation": segmentation_names_lst,
        "sample_scores": np.hstack(sample_scores_lst),
    }
)

In [None]:
# One figure per segment solution: a raincloud-style distribution of the
# silhouette values (left panel) next to one filled silhouette profile per
# segmentation method (three right panels). Each figure is saved as a PNG and
# the matching LaTeX \includegraphics line is printed.
sns.set_style("ticks")
colors = ["#0A4D68", "#088395", "#05BFDB"]
hue_order = ["PCT", "KMeans", "KDE"]

yticks = np.array([-0.6, -0.4, -0.2, 0, 0.2, 0.4, 0.6, 0.8])

fontsize = 12
color = plt.colormaps["viridis"]
ort = "v"
dy = "sample_scores"
dx = "segmentation"

methods = ["PCT", "KMeans", "KDE"]

# Read the boundaries of every method once, instead of re-opening the pickle
# for every segment solution.
# NOTE: pickle.load can execute arbitrary code; only read trusted result files.
boundaries_by_method = []
for method in methods:
    with open(op.join(results_dir, "segmentation", f"{method}_results.pkl"), "rb") as results_file:
        results_dict = pickle.load(results_file)
    boundaries_by_method.append(results_dict["boundaries"])

for segm_i in range(n_segments):
    segment_size = segm_i + min_n_segments

    # Rows of violin_df for this segment solution; shared by the three
    # distribution plots below.
    segment_df = violin_df[violin_df.segment_sizes == str(segment_size)]

    fig, axes = plt.subplots(1, 4)
    fig.set_size_inches(7.5, 3)

    for method_i, method in enumerate(methods):
        bound_arr = boundaries_by_method[method_i][segm_i]

        imb_axis = axes[method_i + 1]

        norm = plt.Normalize(0, segment_size-1)
        x_min, x_max = np.round(bound_arr[0], 2), np.round(bound_arr[-1], 2)
        x_med = np.round((x_min + x_max) / 2, 2)
        for cluster_i in range(segment_size):
            boun_i, boun_j = (bound_arr[cluster_i], bound_arr[cluster_i + 1])
            # Aggregate the silhouette scores for samples belonging to
            # cluster i, and sort them. Boolean-mask indexing returns a copy,
            # so the in-place sort does not touch samples_arrays.
            ith_cluster_silhouette_values = samples_arrays[method_i][segm_i, labels_lst[method_i][segm_i] == cluster_i]

            ith_cluster_silhouette_values.sort()

            size_cluster_i = ith_cluster_silhouette_values.shape[0]

            # Spread the sorted values evenly between the two boundaries of
            # the cluster, so the width of each band is the boundary interval.
            imb_axis.fill_between(
                np.linspace(boun_i, boun_j, size_cluster_i),
                0,
                ith_cluster_silhouette_values,
                facecolor=color(norm(cluster_i)),
                edgecolor=color(norm(cluster_i)),
                alpha=1,
            )

        imb_axis.set_yticks(yticks)
        imb_axis.axes.yaxis.set_ticklabels([])
        imb_axis.set_xticks([x_min, x_med, x_max])
        imb_axis.set_xticklabels([x_min, x_med, x_max], fontsize=fontsize-2)

        if method_i == 1:
            imb_axis.set_xlabel("Cluster Imbalance", fontsize=fontsize)
        imb_axis.set_title(method)
        # Horizontal line at the mean silhouette score of THIS method for this
        # segment solution. The row has to be selected by method as well as by
        # segment; otherwise every panel shows the first method's score.
        y_line = ld_scores_df[
            (ld_scores_df.segment == segment_size) & (ld_scores_df.method == method)
        ]["silhouette"].values[0]
        imb_axis.axhline(y=y_line, color="black", linestyle="--")
        imb_axis.grid(axis='y', which='major', color='gray', alpha=0.5)

    pt.half_violinplot(
        x=dx,
        y=dy,
        data=segment_df,
        palette=colors,
        bw=0.05,
        cut=0.0,
        scale="area",
        width=0.8,
        dodge=False,
        inner=None,
        orient=ort,
        ax=axes[0],
    )
    sns.stripplot(
        x=dx,
        y=dy,
        data=segment_df,
        hue_order=hue_order,
        palette=colors,
        edgecolor="white",
        dodge=False,
        size=1,
        jitter=1,
        zorder=0,
        orient=ort,
        ax=axes[0],
    )
    box_axe = sns.boxplot(
        x=dx,
        y=dy,
        data=segment_df,
        palette=colors,
        width=0.2,
        zorder=10,
        dodge=True,
        showcaps=True,
        showfliers=False,
        boxprops={"zorder": 9, "alpha": 0.8},
        whiskerprops={"color": "black", "zorder": 10},
        capprops={"color": "black", "zorder": 10},
        medianprops={"color": "black", "zorder": 10},
        saturation=1,
        orient=ort,
        ax=axes[0],
    )
    plt.setp(box_axe.collections + box_axe.artists, alpha=0.8)

    axes[0].set_title("Silhouette Distribution")
    axes[0].set_ylabel("Silhouette Coefficient", fontsize=fontsize)
    axes[0].set_xlabel("")
    axes[0].set_yticks(yticks)
    axes[0].set_yticklabels(yticks, fontsize=fontsize-2)
    axes[0].grid(axis='y', which='major', color='gray', alpha=0.5)

    fig.suptitle(f'Segment Solution: {segment_size:02d}', fontsize=fontsize)
    fig.tight_layout()
    fig.savefig(op.join(figures_dir, "Fig", "silhouette", f"{segment_size:02d}_silhouette_samples.png"), bbox_inches="tight", dpi=500)
    # The backslash is escaped explicitly: "\i" is an invalid escape sequence.
    print(f"\\includegraphics[scale=1]{{{segment_size:02d}_silhouette_samples.png}}\n")
    # Close exactly this figure. A plt.clf() after plt.close() would create a
    # fresh, empty figure and leave it open.
    plt.close(fig)

## Combine Plot: Silhouette Coefficient Distribution + Visualize Cluster Balance

This plot is useful for determining cluster imbalance. We display the silhouette 
coefficient for each sample per cluster to visualize which clusters are dense and which are not.

In [None]:
# Combined 3 x 4 figure: one row per selected segment solution, with the
# silhouette distribution in the first column and one filled silhouette
# profile per segmentation method in the remaining three columns.
sns.set_style("ticks")

colors = ["#0A4D68", "#088395", "#05BFDB"]
hue_order = ["PCT", "KMeans", "KDE"]
yticks = np.array([-0.6, -0.4, -0.2, 0, 0.2, 0.4, 0.6, 0.8])

ort = "v"
dy = "sample_scores"
dx = "segmentation"

methods = ["PCT", "KMeans", "KDE"]

# The colormap does not change inside the loops, so it is looked up once.
# plt.colormaps is used because cm.get_cmap was removed in Matplotlib 3.9.
color = plt.colormaps["viridis"]

# Read the boundaries of every method once, instead of once per figure row.
# NOTE: pickle.load can execute arbitrary code; only read trusted result files.
# NOTE(review): boundaries are read from "new_{method}_results.pkl", whereas
# labels_lst was filled from "{method}_results.pkl" -- confirm that both files
# describe the same segmentation.
boundaries_by_method = []
for method in methods:
    with open(op.join(results_dir, "segmentation", f"new_{method}_results.pkl"), "rb") as results_file:
        results_dict = pickle.load(results_file)
    boundaries_by_method.append(results_dict["boundaries"])

fig, axes_tpl = plt.subplots(3, 4)
fig.set_size_inches(15, 11)
for segm_i, segment_size in enumerate([3, 17, 32]):
    vio_axis = axes_tpl[segm_i, 0]

    # Entry 0 of the per-solution arrays is the 2-segment solution.
    solution_i = segment_size - 2

    # Rows of violin_df for this segment solution; shared by the three
    # distribution plots below.
    segment_df = violin_df[violin_df.segment_sizes == str(segment_size)]

    for method_i, method in enumerate(methods):
        bound_arr = boundaries_by_method[method_i][solution_i]

        imb_axis = axes_tpl[segm_i, method_i + 1]

        norm = plt.Normalize(0, segment_size-1)
        x_min, x_max = np.round(bound_arr[0], 2), np.round(bound_arr[-1], 2)
        x_med = np.round((x_min + x_max) / 2, 2)
        for cluster_i in range(segment_size):
            boun_i, boun_j = (bound_arr[cluster_i], bound_arr[cluster_i + 1])
            # Aggregate the silhouette scores for samples belonging to
            # cluster i, and sort them. Boolean-mask indexing returns a copy,
            # so the in-place sort does not touch samples_arrays.
            ith_cluster_silhouette_values = samples_arrays[method_i][solution_i, labels_lst[method_i][solution_i] == cluster_i]

            ith_cluster_silhouette_values.sort()

            size_cluster_i = ith_cluster_silhouette_values.shape[0]

            imb_axis.fill_between(
                np.linspace(boun_i, boun_j, size_cluster_i),
                0,
                ith_cluster_silhouette_values,
                facecolor=color(norm(cluster_i)),
                edgecolor=color(norm(cluster_i)),
                alpha=1,
            )

        imb_axis.set_yticks(yticks)
        imb_axis.set_xticks([])  # Clear the x-axis ticks
        imb_axis.axes.yaxis.set_ticklabels([])  # Keep y ticks, hide their labels

        # Only the 32-segment row (the last one) gets x tick labels.
        if segment_size == 32:
            imb_axis.set_xticks([x_min, x_med, x_max])
            imb_axis.set_xticklabels([x_min, x_med, x_max], fontsize=18)

        # Horizontal line at the mean silhouette score of this method for this
        # segment solution.
        y_line = ld_scores_df[(ld_scores_df.segment == segment_size) & (ld_scores_df.method == method)]["silhouette"].values[0]
        imb_axis.axhline(y=y_line, color="black", linestyle="--")
        imb_axis.grid(axis='y', which='major', color='gray', alpha=0.5)

    pt.half_violinplot(
        x=dx,
        y=dy,
        data=segment_df,
        palette=colors,
        bw=0.05,
        cut=0.0,
        scale="area",
        width=0.8,
        dodge=False,
        inner=None,
        orient=ort,
        ax=vio_axis,
    )
    sns.stripplot(
        x=dx,
        y=dy,
        data=segment_df,
        hue_order=hue_order,
        palette=colors,
        edgecolor="white",
        dodge=False,
        size=1,
        jitter=1,
        zorder=0,
        orient=ort,
        ax=vio_axis,
    )
    box_axe = sns.boxplot(
        x=dx,
        y=dy,
        data=segment_df,
        palette=colors,
        width=0.2,
        zorder=10,
        dodge=True,
        showcaps=True,
        showfliers=False,
        boxprops={"zorder": 9, "alpha": 0.8},
        whiskerprops={"color": "black", "zorder": 10},
        capprops={"color": "black", "zorder": 10},
        medianprops={"color": "black", "zorder": 10},
        saturation=1,
        orient=ort,
        ax=vio_axis,
    )
    plt.setp(box_axe.collections + box_axe.artists, alpha=0.8)

    vio_axis.set_ylabel("")
    vio_axis.set_xlabel("")
    vio_axis.set_yticks(yticks)
    vio_axis.set_yticklabels(yticks, fontsize=18)
    vio_axis.set_xticklabels([])
    vio_axis.grid(axis='y', which='major', color='gray', alpha=0.5)

fig.supylabel("Silhouette Coefficient", fontsize=20)
fig.tight_layout()
fig.savefig(op.join(figures_dir, "Fig", "silhouette", "silhouette_samples.png"), bbox_inches="tight", dpi=1000)
plt.show()

## Visualize Confidence Maps

Confidence (silhouette) values for each vertex with respect to its assigned network are
visualized. Regions close to the boundaries between networks were less confident of their 
assignment. The best value is 1 and the worst value is -1. Values near 0 indicate overlapping 
clusters.


### Percentile

In [None]:
# Silhouette-sample maps of the percentile (PCT) segmentation: gather the
# left- and right-hemisphere files, sort each list by name, and pair them by
# position before handing them to plot_gradient.
percentile_samples_seg_path = op.join(results_dir, "segmentation", "PCT_confidence-maps")

lh_pattern = "*desc-SilhouetteSamples_space-fsLR_den-32k_hemi-L_feature.func.gii"
rh_pattern = "*desc-SilhouetteSamples_space-fsLR_den-32k_hemi-R_feature.func.gii"
percentile_samples_seg_lh_fnames = sorted(glob(op.join(percentile_samples_seg_path, lh_pattern)))
percentile_samples_seg_rh_fnames = sorted(glob(op.join(percentile_samples_seg_path, rh_pattern)))

# zip() yields a one-shot iterator: it is empty once it has been consumed.
percentile_samples_seg_fnames = zip(
    percentile_samples_seg_lh_fnames,
    percentile_samples_seg_rh_fnames,
)

plot_gradient(
    data_dir,
    percentile_samples_seg_fnames,
    cmap="afmhot",
    color_range=(-0.5,.5),
    out_dir=op.join(figures_dir, "Fig", "silhouette"),
)

### KMeans

In [None]:
# Silhouette-sample maps of the KMeans segmentation: gather the left- and
# right-hemisphere files, sort each list by name, and pair them by position
# before handing them to plot_gradient.
kmeans_samples_seg_path = op.join(results_dir, "segmentation", "KMeans_confidence-maps")

lh_pattern = "*desc-SilhouetteSamples_space-fsLR_den-32k_hemi-L_feature.func.gii"
rh_pattern = "*desc-SilhouetteSamples_space-fsLR_den-32k_hemi-R_feature.func.gii"
kmeans_samples_seg_lh_fnames = sorted(glob(op.join(kmeans_samples_seg_path, lh_pattern)))
kmeans_samples_seg_rh_fnames = sorted(glob(op.join(kmeans_samples_seg_path, rh_pattern)))

# zip() yields a one-shot iterator: it is empty once it has been consumed.
kmeans_samples_seg_fnames = zip(
    kmeans_samples_seg_lh_fnames,
    kmeans_samples_seg_rh_fnames,
)

plot_gradient(
    data_dir,
    kmeans_samples_seg_fnames,
    cmap="afmhot",
    color_range=(-0.5,.5),
    out_dir=op.join(figures_dir, "Fig", "silhouette"),
)

### KDE

In [None]:
# Silhouette-sample maps of the KDE segmentation: gather the left- and
# right-hemisphere files, sort each list by name, and pair them by position
# before handing them to plot_gradient.
#
# data_dir is deliberately NOT reassigned here. The absolute path set in the
# configuration cell is reused, as in the Percentile and KMeans cells, so this
# cell no longer replaces the global with a working-directory-relative path.
kde_samples_seg_path = op.join(results_dir, "segmentation", "KDE_confidence-maps")
kde_samples_seg_lh_fnames = sorted(glob(op.join(kde_samples_seg_path, "*desc-SilhouetteSamples_space-fsLR_den-32k_hemi-L_feature.func.gii")))
kde_samples_seg_rh_fnames = sorted(glob(op.join(kde_samples_seg_path, "*desc-SilhouetteSamples_space-fsLR_den-32k_hemi-R_feature.func.gii")))
# zip() yields a one-shot iterator: it is empty once it has been consumed.
kde_samples_seg_fnames = zip(kde_samples_seg_lh_fnames, kde_samples_seg_rh_fnames)

plot_gradient(data_dir, kde_samples_seg_fnames, cmap="afmhot", color_range=(-0.5,.5), out_dir=op.join(figures_dir, "Fig", "silhouette"))

In [None]:
# Assemble the per-method confidence-map images into montage pages: one row
# per segment solution (PCT | KMeans | KDE | colorbar), n_rows rows per page.
# Each full page, and the final (possibly partial) page, is saved as an EPS
# file and the matching LaTeX \includegraphics line is printed.
full_vertices = 64984
hemi_vertices = full_vertices // 2
space = 4  # grid columns spanned by each brain image
n_cols = 3*space+1  # three images plus one column for the colorbar
n_rows = 5
w = 7.5
h = 9.5

img_cbar = op.join(figures_dir, "Fig", "silhouette", "silhouette_samples_cbar.png")

img_lbs = ["PCT", "KMeans", "KDE"]
pct_files = sorted(glob(op.join(figures_dir, "Fig", "silhouette", "PCT*-SilhouetteSamples.tiff")))
kms_files = sorted(glob(op.join(figures_dir, "Fig", "silhouette", "KMeans*-SilhouetteSamples.tiff")))
kde_files = sorted(glob(op.join(figures_dir, "Fig", "silhouette", "KDE*-SilhouetteSamples.tiff")))

# Materialise the triples so the last segment solution is known up front;
# solutions are numbered from 2. The final page is then flushed however many
# files were found, instead of only when the last solution is exactly 32.
file_triples = list(zip(pct_files, kms_files, kde_files))
last_segment_size = len(file_triples) + 1

# The colorbar image is identical on every row, so it is read only once (and
# not at all when there is nothing to plot).
cbar_img = mpimg.imread(img_cbar) if file_triples else None

row = 0
for segment_size, (pct_file, kms_file, kde_file) in enumerate(file_triples, start=2):
    # Column titles and the "Segment Solution" header go on the first row of
    # each page only.
    add_title = row == 0

    if row == 0:
        fig = plt.figure(figsize=(w, h))
        fig.subplots_adjust(
            left=None, bottom=None, right=None, top=None, wspace=0.9, hspace=0
        )
        gs = GridSpec(n_rows, n_cols, figure=fig)

    col = 0
    for img_i, img_file in enumerate([pct_file, kms_file, kde_file]):
        img = mpimg.imread(img_file)
        ax = fig.add_subplot(gs[row, col : col + space], aspect="equal")
        ax.imshow(img)

        if img_i == 0:
            # The first image keeps its axes (without ticks) so that the
            # y-label can carry the segment-solution number.
            ax.set_xticks([])
            ax.set_yticks([])
            if add_title:
                ax.set_ylabel(f"Segment\nSolution\n\n\n\n{segment_size:02d}", rotation=0, labelpad=35, fontsize=12)
            else:
                ax.set_ylabel(f"\n{segment_size:02d}", rotation=0, labelpad=35, fontsize=12)
            plt.setp(ax.spines.values(), color=None)
        else:
            ax.set_axis_off()

        col += space

        if add_title:
            ax.set_title(img_lbs[img_i], fontsize=14)

    ax = fig.add_subplot(gs[row, n_cols-1], aspect="equal")
    ax.imshow(cbar_img)
    ax.set_axis_off()

    if row == n_rows - 1 or segment_size == last_segment_size:
        row = 0
        fig.tight_layout(pad=0.1, w_pad=0.1)
        fig.savefig(op.join(figures_dir, "Fig", "silhouette", f"{segment_size:02d}_confidence-maps.eps"), bbox_inches="tight", dpi=500)
        # The backslash is escaped explicitly: "\i" is an invalid escape sequence.
        print(f"\\includegraphics[scale=1]{{{segment_size:02d}_confidence-maps.eps}}\n")
        # Close exactly this figure. A plt.clf() after plt.close() would
        # create a fresh, empty figure and leave it open.
        plt.close(fig)
    else:
        row += 1