In [None]:
"""Example for plotting gradient data"""
import os.path as op
from glob import glob
import itertools
import gc

import numpy as np
from neuromaps.datasets import fetch_fslr
from surfplot import Plot
from gradec.fetcher import _fetch_metamaps
from surfplot.utils import add_fslr_medial_wall, threshold
import matplotlib.pyplot as plt
from nilearn.plotting.cm import _cmap_d as nilearn_cmaps
import pandas as pd

In [None]:
# Base data directory, relative to this notebook. plot_surf_maps passes
# <data_dir>/neuromaps to fetch_fslr as the location for the fsLR surfaces.
data_dir = "../data"

In [None]:
def plot_surf_maps(lh_grad, rh_grad, threshold_, color_range, cmap, dpi, data_dir, out_filename):
    """Render a left-hemisphere map on the inflated fsLR-32k surface and save it.

    The map is drawn in a single lateral view on top of a sulcal-depth
    background layer (``binary_r`` colormap, no colorbars).

    Parameters
    ----------
    lh_grad : array-like
        Per-vertex values for the left hemisphere; this is what gets plotted.
    rh_grad : array-like
        Per-vertex values for the right hemisphere. Accepted for interface
        symmetry only: the current view shows the left hemisphere alone.
    threshold_ : float
        Passed to ``surfplot.utils.threshold`` before plotting.
    color_range : tuple of float
        ``(vmin, vmax)`` used for the data layer.
    cmap : str or matplotlib Colormap
        Colormap for the data layer.
    dpi : int
        Resolution of the saved image.
    data_dir : str
        Base directory; ``<data_dir>/neuromaps`` is handed to ``fetch_fslr``.
    out_filename : str
        Output path; the image format follows the file extension.
    """
    neuromaps_dir = op.join(data_dir, "neuromaps")

    surfaces = fetch_fslr(density="32k", data_dir=neuromaps_dir)
    # Only the left-hemisphere surface and sulcal map are needed for this view.
    lh, _ = surfaces["inflated"]
    sulc_lh, _ = surfaces["sulc"]

    lh_grad = threshold(lh_grad, threshold_)

    p = Plot(lh, views="lateral")
    p.add_layer({"left": sulc_lh}, cmap="binary_r", cbar=False)
    p.add_layer({"left": lh_grad}, cmap=cmap, cbar=False, color_range=color_range)
    fig = p.build()

    fig.savefig(out_filename, bbox_inches="tight", dpi=dpi, transparent=True)
    # Close exactly this figure. Do not follow with plt.clf(): with no current
    # figure it would create a fresh, never-closed one on every call.
    plt.close(fig)
    gc.collect()

In [None]:
methods = ["Percentile", "KMeans", "KDE"]
dset_names = ["neurosynth", "neuroquery"]
models = ["term", "lda", "gclda"]

full_vertices = 64984
hemi_vertices = full_vertices // 2

total_methods = 18
color_map = plt.get_cmap("tab20")
segmentations = [3, 17, 32]
seg_sols = [[1, 2, 3], [1, 9, 17], [1, 16, 32]]
data_df = pd.read_csv("../results/performance/performance.tsv", delimiter="\t")
map_out_path = "../figures/Fig/decoding"
# Set to True to additionally render the surface map selected for every
# method (slow: fetches the map collection on every iteration).
plot_maps = False


def _select_performance_row(df, method_name, segment_solution, segment):
    """Return the first row of ``df`` for one method / segment solution / segment.

    Raises
    ------
    ValueError
        If no row matches, naming the missing combination (a bare
        ``.values[0]`` on an empty selection would raise an opaque IndexError).
    """
    rows = df[
        (df["method"] == method_name)
        & (df["segment_solution"] == segment_solution)
        & (df["segment"] == segment)
    ]
    if rows.empty:
        raise ValueError(
            f"No performance entry for method={method_name!r}, "
            f"segment_solution={segment_solution}, segment={segment}"
        )
    return rows.iloc[0]


def _plot_decoding_map(dset_name, model, corr_idx, out_filename):
    """Plot row ``corr_idx`` of ``_fetch_metamaps(dset_name, model)`` on the surface.

    Relies on ``_fetch_metamaps``, ``add_fslr_medial_wall``, ``nilearn_cmaps``
    and ``plot_surf_maps`` plus the globals ``data_dir``, ``hemi_vertices`` and
    ``full_vertices`` defined in earlier cells / above.
    """
    maps_fslr = _fetch_metamaps(dset_name, model, data_dir=data_dir)
    data = maps_fslr[corr_idx, :]
    # GCLDA maps are thresholded at their 80th percentile and shown one-sided;
    # other models use a fixed threshold of 2 and a symmetric range.
    threshold_ = np.percentile(data, 80) if model == "gclda" else 2
    cmap = "afmhot" if model == "gclda" else nilearn_cmaps["cold_hot"]
    max_val = round(np.max(np.abs(data)), 2)
    range_ = (0, max_val) if model == "gclda" else (-max_val, max_val)
    data = add_fslr_medial_wall(data)
    data_lh, data_rh = data[:hemi_vertices], data[hemi_vertices:full_vertices]
    plot_surf_maps(data_lh, data_rh, threshold_, range_, cmap, 100, data_dir, out_filename)


def _plot_text_column(texts, colors, fig_size, out_filename):
    """Draw ``texts`` as an evenly spaced, top-to-bottom column of bold labels and save.

    Parameters
    ----------
    texts : list of str
        Labels, first one at the top of the axes.
    colors : sequence of colours
        One colour per label; ``zip`` stops at the shorter of the two.
    fig_size : tuple of float
        Figure size in inches, ``(width, height)``.
    out_filename : str
        Destination path; the format follows the file extension.
    """
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(*fig_size)

    y_coords = np.linspace(1, 0, len(texts))
    for y_coord, color, label in zip(y_coords, colors, texts):
        ax.text(0.5, y_coord, label, ha="center", va="center", weight="bold", fontsize=22, color=color)

    ax.axis("off")
    fig.tight_layout()
    fig.savefig(out_filename, bbox_inches="tight", transparent=True)
    # Close this figure only; plt.clf() afterwards would create a new leaked figure.
    plt.close(fig)
    gc.collect()


n_dsets = len(dset_names)
for seg_i, segmentation in enumerate(segmentations):
    print(f"Segmentation: {segmentation}")
    for seg_sol in seg_sols[seg_i]:
        print(f"\tSegmentation solution: {seg_sol}")
        sub_features_lst = []
        for dset_i, dset_name in enumerate(dset_names):
            sub_class_lst = []
            sub_features_dset_lst = []
            for model_i, model in enumerate(models):
                for iter_i, method in enumerate(methods):
                    method_name = f"{model}_{dset_name}_{method}"
                    row = _select_performance_row(data_df, method_name, segmentation, seg_sol)

                    if plot_maps:
                        prefix = f"{segmentation:02d}-{seg_sol:02d}-{model_i}-{dset_i}-{iter_i}"
                        out_filename = op.join(
                            map_out_path, f"maps_{prefix}_{dset_name}-{model}-{method}.tiff"
                        )
                        _plot_decoding_map(dset_name, model, row["corr_idx"], out_filename)

                    feature = row["features"]
                    if model != "term":
                        # Drop the leading token and join the remaining terms.
                        feature = "; ".join(feature.split("_")[1:])

                    sub_features_dset_lst.append(feature)
                    sub_class_lst.append(row["classification"])

            sub_features_lst.append(sub_features_dset_lst)
            print(sub_class_lst)

            # Dataset k takes every n_dsets-th colour starting at k, matching the
            # interleaved colour order of the combined features figure below.
            dset_colors = color_map.colors[:total_methods][dset_i::n_dsets]
            _plot_text_column(
                sub_class_lst,
                dset_colors,
                (4, 4),
                op.join("./Fig", "performance", f"classification-{dset_name}-{segmentation:02d}-{seg_sol:02d}.eps"),
            )

        # Interleave the per-dataset lists: d0[0], d1[0], d0[1], d1[1], ...
        sub_features_lst = [item for group in zip(*sub_features_lst) for item in group]
        print(sub_features_lst)

        _plot_text_column(
            sub_features_lst,
            color_map.colors[:total_methods],
            (7, 8),
            op.join("./Fig", "performance", f"features-{segmentation:02d}-{seg_sol:02d}.eps"),
        )


In [None]:
categories = np.array(["Functional", "Clinical", "Anatomical", "Non-Specific"])
cotegories = categories  # backward-compatible alias for the previous (misspelled) name

methods = ["Percentile", "KMeans", "KDE"]
dset_names = ["neurosynth", "neuroquery"]
models = ["term", "lda", "gclda"]

# Plot styling shared by every figure in the loop.
fontsize = 11
colors = ["#393E46", '#6D9886', '#F2E7D5', '#F7F7F7']

data_df = pd.read_csv("../results/performance/performance.tsv", delimiter="\t")
for seg_sol in range(3, 33):
    # One record per (segment, dataset, model, method) classification.
    records = []
    for seg_id in range(1, seg_sol + 1):
        for dset_name, model, method in itertools.product(dset_names, models, methods):
            method_name = f"{model}_{dset_name}_{method}"
            temp_df = data_df[
                (data_df["method"] == method_name)
                & (data_df["segment_solution"] == seg_sol)
                & (data_df["segment"] == seg_id)
            ]
            if temp_df.empty:
                raise ValueError(
                    f"No performance entry for method={method_name!r}, "
                    f"segment_solution={seg_sol}, segment={seg_id}"
                )
            records.append(
                {"segment": seg_id, "classification": temp_df["classification"].values[0]}
            )

    data2plot_df = pd.DataFrame(records)

    # Per-segment proportion of each category; categories absent from this
    # solution get a 0 column so every figure has the same stack order.
    cross_data_prop_df = (
        pd.crosstab(
            index=data2plot_df["segment"],
            columns=data2plot_df["classification"],
            normalize="index",
        )
        .reindex(columns=categories, fill_value=0)
        .sort_index(ascending=True)
    )

    fig, ax = plt.subplots(1, 1)
    # Widen the figure with the number of segments so the bars stay legible.
    fig.set_size_inches(9 + seg_sol * 0.2, 4)

    cross_data_prop_df.plot(
        kind='bar',
        stacked=True,
        color=colors,
        edgecolor='white',
        linewidth=2,
        width=0.9,
        ax=ax,
    )

    ax.tick_params(axis='x', rotation=0)
    ax.tick_params(axis='both', labelsize=fontsize)
    ax.legend(
        loc="upper left",
        bbox_to_anchor=(1, 1),
        ncol=1,
        fontsize=fontsize,
    )
    ax.set_xlabel("Segment ID", fontsize=fontsize + 2)
    ax.set_ylabel("Proportion", fontsize=fontsize + 2)
    ax.set_title(f"Segment Solution: {seg_sol:02d}", fontsize=fontsize + 2)

    fig.savefig(op.join("./Fig", "classification", f"class_segsol-{seg_sol}.eps"), bbox_inches="tight")
    plt.close(fig)
    gc.collect()
    

In [None]:
# One LaTeX \includegraphics line per class_segsol-<N>.eps figure (N = 3..32),
# matching the segment solutions plotted in the previous cell.
include_lines = [
    f"\\includegraphics[scale=0.47]{{class_segsol-{seg_sol}.eps}}" for seg_sol in range(3, 33)
]
for line in include_lines:
    # The explicit "\n" plus print's own newline leaves a blank line between entries.
    print(f"{line}\n")