In [None]:
"""Example for plotting gradient data"""
import os.path as op
from glob import glob
import pickle

import matplotlib.image as mpimg
from matplotlib.gridspec import GridSpec
import matplotlib.pyplot as plt

from utils import plot_gradient

In [None]:
# Root folder for rendered figures, resolved once to an absolute path.
figures_dir = op.abspath("../figures")
# Data folder, kept as a relative path; it is handed to plot_gradient as its
# first argument in the segmentation cells.
data_dir = "../data"

# Visualize segmentation

## Percentile

In [None]:
# Folder holding the percentile segmentation maps, and the folder passed to
# plot_gradient as `out_dir`.
percent_grad_seg_path = "../results/segmentation/pct"
percent_grad_out_path = "../figures/Fig/segmentation/pct"

# Collect the `-L` (lh) and `-R` (rh) files separately, each in sorted order.
percent_grad_seg_lh_fnames = glob(
    op.join(percent_grad_seg_path, "*Percentile*_desc-Bin*-L_feature.func.gii")
)
percent_grad_seg_lh_fnames.sort()
percent_grad_seg_rh_fnames = glob(
    op.join(percent_grad_seg_path, "*Percentile*_desc-Bin*-R_feature.func.gii")
)
percent_grad_seg_rh_fnames.sort()

# Pair the two sorted lists position by position. Note that zip() yields a
# one-shot iterator, so it can only be consumed once.
percent_grad_seg_fnames = zip(percent_grad_seg_lh_fnames, percent_grad_seg_rh_fnames)

plot_gradient(
    data_dir,
    percent_grad_seg_fnames,
    title=False,
    out_dir=percent_grad_out_path,
)

## KMeans

In [None]:
# Folder holding the KMeans segmentation maps, and the folder passed to
# plot_gradient as `out_dir`.
kmeans_grad_seg_path = "../results/segmentation/kmeans"
kmeans_grad_out_path = "../figures/Fig/segmentation/kmeans"

# Collect the `-L` (lh) and `-R` (rh) files separately, each in sorted order.
kmeans_grad_seg_lh_fnames = glob(
    op.join(kmeans_grad_seg_path, "*KMeans*_desc-Bin*-L_feature.func.gii")
)
kmeans_grad_seg_lh_fnames.sort()
kmeans_grad_seg_rh_fnames = glob(
    op.join(kmeans_grad_seg_path, "*KMeans*_desc-Bin*-R_feature.func.gii")
)
kmeans_grad_seg_rh_fnames.sort()

# Pair the two sorted lists position by position. Note that zip() yields a
# one-shot iterator, so it can only be consumed once.
kmeans_grad_seg_fnames = zip(kmeans_grad_seg_lh_fnames, kmeans_grad_seg_rh_fnames)

plot_gradient(
    data_dir,
    kmeans_grad_seg_fnames,
    title=False,
    out_dir=kmeans_grad_out_path,
)

## KDE

In [None]:
# Folder holding the KDE segmentation maps, and the folder passed to
# plot_gradient as `out_dir`.
kde_grad_seg_path = "../results/segmentation/kde"
kde_grad_out_path = "../figures/Fig/segmentation/kde"

# Collect the `-L` (lh) and `-R` (rh) files separately, each in sorted order.
# The rh list must glob the `-R` suffix: globbing `-L` again would pair every
# left-hemisphere file with itself and the right-hemisphere maps would never
# reach plot_gradient.
kde_grad_seg_lh_fnames = sorted(
    glob(op.join(kde_grad_seg_path, "*KDE*_desc-Bin*-L_feature.func.gii"))
)
kde_grad_seg_rh_fnames = sorted(
    glob(op.join(kde_grad_seg_path, "*KDE*_desc-Bin*-R_feature.func.gii"))
)

# Pair the two sorted lists position by position. Note that zip() yields a
# one-shot iterator, so it can only be consumed once.
kde_grad_seg_fnames = zip(kde_grad_seg_lh_fnames, kde_grad_seg_rh_fnames)

plot_gradient(data_dir, kde_grad_seg_fnames, title=False, out_dir=kde_grad_out_path)

In [None]:
# Page layout: every saved figure is a grid with one row per segment
# (n_rows per page) and one column per segmentation method (n_cols).
n_cols = 3
n_rows = 5
w = 7.5
h = 9.5

img_lbs = ["PCT", "KMeans", "KDE"]
seg_fig_dir = op.join(figures_dir, "Fig", "segmentation")


def _save_page(page_fig, page_label):
    """Finalize one page: tighten its layout, save it as EPS, and close it.

    Parameters
    ----------
    page_fig : matplotlib.figure.Figure
        The page to write out.
    page_label : str
        Prefix of the output file name (``<page_label>_gradient-maps.eps``).
    """
    page_fname = f"{page_label}_gradient-maps.eps"
    page_fig.tight_layout(pad=0.1, w_pad=0.6, h_pad=0.1)
    page_fig.savefig(op.join(seg_fig_dir, page_fname), bbox_inches="tight", dpi=500)
    # Print the matching LaTeX include line. The backslash is escaped because
    # "\i" is not a valid escape sequence in a regular (non-raw) string.
    print(f"\\includegraphics[scale=1]{{{page_fname}}}\n")
    # Close exactly this figure. Calling plt.clf() after a close would
    # implicitly create a fresh, empty current figure.
    plt.close(page_fig)


row = 0
fig = gs = page_label = None
for segment_size in range(3, 33):
    pct_files = sorted(glob(op.join(seg_fig_dir, "pct", f"*{segment_size:02d}-*.tiff")))
    kms_files = sorted(glob(op.join(seg_fig_dir, "kmeans", f"*{segment_size:02d}-*.tiff")))
    kde_files = sorted(glob(op.join(seg_fig_dir, "kde", f"*{segment_size:02d}-*.tiff")))

    # zip() stops at the shortest list, so a segment is only drawn when all
    # three methods have an image for it.
    for segment_id, row_files in enumerate(zip(pct_files, kms_files, kde_files), start=1):
        # Label of the segment drawn in this row; the last row's label also
        # names the page file when the page is saved.
        page_label = f"{segment_size:02d}-{segment_id:02d}"

        # The first row of every page starts a new figure and carries the
        # column titles plus the "Segment Solution" header of the row labels.
        add_title = row == 0
        if add_title:
            fig = plt.figure(figsize=(w, h))
            fig.subplots_adjust(wspace=0.9)
            gs = GridSpec(n_rows, n_cols, figure=fig)

        for img_i, img_file in enumerate(row_files):
            ax = fig.add_subplot(gs[row, img_i], aspect="equal")
            ax.imshow(mpimg.imread(img_file))

            if img_i == 0:
                # The first column keeps its axes (without ticks) so that the
                # y-label can serve as the row label.
                ax.set_xticks([])
                ax.set_yticks([])
                header = "Segment\nSolution\n\n\n" if add_title else ""
                ax.set_ylabel(
                    f"{header}\n{page_label}", rotation=0, labelpad=35, fontsize=12
                )
                # NOTE(review): presumably meant to hide the frame; color=None
                # may only reset the spines to their default color - confirm.
                plt.setp(ax.spines.values(), color=None)
            else:
                ax.set_axis_off()

            if add_title:
                ax.set_title(img_lbs[img_i], fontsize=14)

        if row == n_rows - 1:
            # The page is full: write it out and start over at the top row.
            _save_page(fig, page_label)
            row = 0
        else:
            row += 1

# Flush a trailing, partially filled page so that its rows are not lost when
# the number of segments is not a multiple of n_rows.
if row > 0:
    _save_page(fig, page_label)

In [None]:
def _read_results(results_dir, fname):
    """Unpickle and return the object stored in ``fname`` inside ``results_dir``."""
    # pickle.load can execute arbitrary code; only use it on trusted files.
    with open(op.join(results_dir, fname), "rb") as results_file:
        return pickle.load(results_file)


# One results object per segmentation method, read from that method's results folder.
pct_dict = _read_results(percent_grad_seg_path, "pct_results.pkl")
kmeans_dict = _read_results(kmeans_grad_seg_path, "kmeans_results.pkl")
kde_dict = _read_results(kde_grad_seg_path, "kde_results.pkl")

In [None]:
# NOTE(review): numpy is not used in this cell and pyplot is already imported
# in the first cell; both imports are kept in case later cells rely on them.
import numpy as np
import matplotlib.pyplot as plt

n_segments = 30
min_n_segments = 3
colors = ["#05BFDB", "#088395", "#0A4D68"]

# NOTE(review): never filled in this cell; kept in case it is used elsewhere.
boundaries = []

# One figure per segment solution (min_n_segments .. min_n_segments + n_segments - 1).
# seg_i indexes the stored results; n_segment is the number of segments in that solution.
for seg_i, n_segment in enumerate(range(min_n_segments, n_segments + min_n_segments)):
    fig, ax = plt.subplots(figsize=(10, 2))

    # Each method is drawn on its own horizontal level: y = 1, 2, 3.
    method_results = [kde_dict, kmeans_dict, pct_dict]
    yticks = [0.5]
    for dict_i, results_dict in enumerate(method_results):
        bound_arr = results_dict["boundaries"][seg_i]
        peaks_arr = results_dict["peaks"][seg_i]
        y_level = dict_i + 1

        # One marker per segment at its peak. The horizontal error bar reaches
        # from the peak to the segment's two boundaries (assumes
        # bound_arr[i] <= peaks_arr[i] <= bound_arr[i + 1] - TODO confirm).
        x = [peaks_arr[i] for i in range(n_segment)]
        y = [y_level] * n_segment
        x_err = [
            [abs(peaks_arr[i] - bound_arr[i]) for i in range(n_segment)],
            [abs(peaks_arr[i] - bound_arr[i + 1]) for i in range(n_segment)],
        ]

        (_, caps, _) = ax.errorbar(
            x,
            y,
            xerr=x_err,
            fmt="o",
            capsize=15,
            elinewidth=3,
            ecolor=colors[dict_i],
            markerfacecolor="r",
            markeredgecolor="r",
        )
        for cap in caps:
            cap.set_markeredgewidth(3)

        yticks.append(y_level)
    yticks.append(len(method_results) + 0.5)

    # The ticks are set although the axis is switched off right after;
    # presumably this widens the y-limits to include the 0.5 padding above and
    # below the levels - confirm.
    ax.set_yticks(yticks)
    ax.axis("off")
    fig.tight_layout()
    fig.savefig(
        op.join("../figures/Fig/segmentation", f"bound_{n_segment:02d}.eps"),
        bbox_inches="tight",
    )

    plt.show()
    # Release the figure explicitly so that the figures created by this loop
    # do not pile up when the backend does not close them on show().
    plt.close(fig)