In [None]:
import gc
import itertools
import math
import os
import os.path as op
from ast import literal_eval

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from utils import _get_twfrequencies

In [None]:
def plot_radar(corrs, features, model, fig=None, ax=None, out_fig=None, max_features=10):
    """Draw a polar bar ("radar") chart of the strongest feature correlations.

    Only the first ``max_features`` entries are drawn, so callers should pass
    ``corrs``/``features`` already sorted by decreasing correlation.

    Parameters
    ----------
    corrs : array-like of float
        1-D correlation values, one per feature. Must be non-empty.
    features : sequence of str
        Feature labels aligned with ``corrs``.
    model : str
        Decoding model name. For ``"lda"`` and ``"gclda"`` the first
        ``"_"``-separated token of each label is dropped and the remaining
        tokens are stacked on separate lines; otherwise spaces become newlines.
    fig, ax : matplotlib Figure and polar Axes, optional
        Drawing target. If ``ax`` is omitted, a new 9x9-inch polar figure is
        created (or a polar Axes is added to ``fig`` when only ``fig`` is
        given). If only ``ax`` is given, its parent figure is used.
    out_fig : str, optional
        If given, the figure is saved there (``bbox_inches="tight"``) and then
        closed.
    max_features : int, optional
        Maximum number of bars to draw (default 10).

    Raises
    ------
    ValueError
        If ``corrs`` is empty (there is nothing to scale the radial axis by).

    Notes
    -----
    Sets ``plt.rcParams["text.color"]`` globally as a side effect.
    """
    if len(corrs) == 0:
        raise ValueError("plot_radar: `corrs` is empty; nothing to plot.")

    fontsize = 36
    n_rows = min(len(corrs), max_features)

    corrs = np.asarray(corrs)[:n_rows]
    features = features[:n_rows]
    # Evenly spaced bar positions around the circle, starting at angle 0.
    angles = [n / float(n_rows) * 2 * np.pi for n in range(n_rows)]
    if model in ("lda", "gclda"):
        features = ["\n".join(feature.split("_")[1:]).replace(" ", "\n") for feature in features]
    else:
        features = [feature.replace(" ", "\n") for feature in features]

    # Radial limit: the largest correlation rounded up to the next 0.1.
    roundup_corr = math.ceil(corrs.max() * 10) / 10

    # Color scheme. cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap remains.
    plt.rcParams["text.color"] = "#1f1f1f"
    cmap = plt.get_cmap("YlOrRd")
    norm = plt.Normalize(vmin=corrs.min(), vmax=corrs.max())
    colors = cmap(norm(corrs))

    # Resolve the drawing target so that fig and ax are both always set.
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(figsize=(9, 9), subplot_kw={"projection": "polar"})
        else:
            ax = fig.add_subplot(projection="polar")
    elif fig is None:
        fig = ax.figure

    ax.set_theta_offset(0)
    ax.set_ylim(-0.1, roundup_corr)

    ax.bar(angles, corrs, color=colors, alpha=0.9, width=0.52, zorder=10)
    ax.vlines(angles, 0, roundup_corr, color="grey", ls=(0, (4, 4)), zorder=11)

    ax.set_xticks(angles)
    ax.set_xticklabels(features, size=fontsize, zorder=13)
    ax.xaxis.grid(False)

    # Radial gridlines every 0.1 from 0 up to roundup_corr. Build them from an
    # integer count: a float-step np.arange can gain/lose a tick to rounding.
    n_yticks = int(round(roundup_corr * 10)) + 1
    yticks = np.round(np.arange(n_yticks) / 10, 1)
    ax.set_yticks(yticks)
    ax.set_yticklabels([])  # labels are drawn manually below

    ax.spines["start"].set_color("none")
    ax.spines["polar"].set_color("none")

    # Push the (large) feature labels away from the bars.
    for xtick in ax.xaxis.get_major_ticks():
        xtick.set_pad(90)

    # Radial scale labels, drawn at angle pi/2 just below each gridline.
    sep = 0.06
    for ytick in yticks:
        ax.text(np.pi / 2, ytick - sep, f"{ytick}", ha="center", size=fontsize - 2, color="grey", zorder=12)

    if out_fig is not None:
        fig.savefig(out_fig, bbox_inches="tight")
        # Close this figure explicitly, not whatever pyplot considers "current".
        plt.close(fig)
        gc.collect()

In [None]:
from wordcloud import WordCloud

def plot_cloud(features_list, frequencies, corrs, model, fig=None, ax=None, out_fig=None):
    """Draw a word cloud whose word sizes are weighted by decoding correlation.

    Parameters
    ----------
    features_list : sequence
        For ``"lda"``/``"gclda"``: one sequence of words per topic. For other
        models: one word (term) per entry.
    frequencies : sequence or None
        Topic models only: per-topic word weights aligned with
        ``features_list``; each word is weighted by ``freq * corr``. Ignored
        for other models.
    corrs : sequence of float
        One correlation per entry of ``features_list``.
    model : str
        Decoding model name (``"lda"``, ``"gclda"`` or a term-based model).
    fig, ax : matplotlib Figure and Axes, optional
        Drawing target. If ``ax`` is omitted, a new 9x5-inch figure is created
        (or an Axes is added to ``fig`` when only ``fig`` is given). If only
        ``ax`` is given, its parent figure is used.
    out_fig : str, optional
        If given, the figure is saved there at 100 dpi. A figure created by
        this function is then closed; a caller-supplied figure is left open.

    Raises
    ------
    ValueError
        If there are no words to draw.

    Notes
    -----
    A word that occurs more than once keeps its first weight, so pass entries
    sorted by decreasing correlation to keep the strongest one.
    """
    frequencies_dict = {}
    if model in ("lda", "gclda"):
        for words, word_freqs, corr in zip(features_list, frequencies, corrs):
            for word, freq in zip(words, word_freqs):
                frequencies_dict.setdefault(word, freq * corr)
    else:
        for word, corr in zip(features_list, corrs):
            frequencies_dict.setdefault(word, corr)

    # Validate before allocating a figure so an empty input does not leak one.
    if not frequencies_dict:
        raise ValueError("plot_cloud: no words to plot (empty `features_list`).")

    dpi = 100
    w = 9
    h = 5
    # Only figures created here are closed after saving.
    owns_fig = fig is None and ax is None
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(figsize=(w, h))
        else:
            ax = fig.add_subplot()
    elif fig is None:
        fig = ax.figure

    wc = WordCloud(
        width=w * dpi,
        height=h * dpi,
        background_color="white",
        random_state=0,
        colormap="YlOrRd",
    )
    wc.generate_from_frequencies(frequencies=frequencies_dict)
    ax.imshow(wc)
    # Hide ticks and frame but keep the Axes on, so a y-label set by the
    # caller (e.g. a grid row label) stays visible.
    ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_ticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    if out_fig is not None:
        fig.savefig(out_fig, bbox_inches="tight", dpi=dpi)
        if owns_fig:
            # Leaving it open leaks one figure per call and leaves it as
            # pyplot's "current" figure, which confuses later plt.* calls.
            plt.close(fig)
        gc.collect()

In [None]:
# Load decoding correlations and FDR p-values for every (model, dataset,
# method) combination and every segmentation solution as long-format tables.
# data_lst[combo_i][seg_sol - min_seg_sol] is a DataFrame indexed by feature
# with columns method, seg_sol, seg_id, corr, pval, classification (plus
# frequencies for topic models). combo_i follows the
# itertools.product(models, dset_names, methods) order; the plotting cell
# relies on that ordering.
methods = ["Percentile", "KMeans", "KDE"]
dset_names = ["neurosynth", "neuroquery"]
models = ["term", "lda", "gclda"]

classifications_dir = op.join("../data/classification")
models_dir = op.join("../data/models")

n_segmentations = 30  # number of segmentation solutions on disk
min_seg_sol = 3  # smallest solution: corrs_03.csv ... corrs_32.csv
data_lst = []
for model, dset_name, method in itertools.product(models, dset_names, methods):
    corr_dir = op.join("../results/decoding", f"{dset_name}_{model}_corr_{method}")
    is_topic_model = model in ["lda", "gclda"]

    # Data for the word clouds (topic models only).
    # NOTE(review): the meaning of the literal 3 is defined by
    # utils._get_twfrequencies — confirm there.
    frequencies = (
        _get_twfrequencies(dset_name, model, 3, models_dir)
        if is_topic_model
        else None
    )

    # The classification table depends only on (model, dset_name): read it
    # once per combination instead of once per segmentation solution.
    class_df = pd.read_csv(
        op.join(classifications_dir, f"{model}_{dset_name}_classification.csv"),
        index_col="FEATURE",
    )
    # First label per feature (equivalent to `.loc[[feature]].values[0]`).
    first_labels = class_df.loc[~class_df.index.duplicated(), "Classification"]

    combo_data_lst = []
    for seg_sol in range(min_seg_sol, min_seg_sol + n_segmentations):
        corr_file = op.join(corr_dir, f"corrs_{seg_sol:02d}.csv")
        pval_file = op.join(corr_dir, f"pvals-FDR_{seg_sol:02d}.csv")
        corr_df = pd.read_csv(corr_file, index_col="feature")
        pval_df = pd.read_csv(pval_file, index_col="feature")
        features = corr_df.index.to_list()
        n_segments = corr_df.shape[1]

        if model == "term":
            # Terms missing from the classification table are "Non-Specific".
            classification = [first_labels.get(feature, "Non-Specific") for feature in features]
        else:
            # NOTE(review): assumes the classification table lists topics in
            # the same order as the rows of corr_df — confirm.
            classification = class_df["Classification"].to_list()

        # Wide (feature x segment) -> long (one row per feature/segment).
        # melt stacks column by column, so per-feature lists are tiled
        # n_segments times below.
        seg_df = corr_df.melt(ignore_index=False).rename(columns={"variable": "seg_id", "value": "corr"})
        seg_df["seg_id"] = seg_df["seg_id"].astype(int) + 1  # 1-based segment ids
        # Put the p-values in corr_df's row/column order before melting so the
        # two long tables line up row for row (KeyError if labels mismatch),
        # instead of relying on alignment of a duplicate-labelled index.
        seg_df["pval"] = (
            pval_df.loc[corr_df.index, corr_df.columns]
            .melt(ignore_index=False)["value"]
            .to_numpy()
        )

        seg_df.insert(0, "seg_sol", seg_sol)
        seg_df.insert(0, "method", f"{model}_{dset_name}_{method}")
        seg_df["classification"] = classification * n_segments

        if is_topic_model:
            # NOTE(review): assumes `frequencies` is a Python list (tiled by
            # list repetition); a NumPy array would be multiplied instead.
            seg_df["frequencies"] = frequencies * n_segments

        combo_data_lst.append(seg_df)
    data_lst.append(combo_data_lst)

In [None]:
# For each segmentation solution, draw one word-cloud panel per segment for
# every (model, dataset, method) combination. Each panel (and its radar plot)
# is also saved as an individual file, and the whole grid is saved as one PNG.
# Panels show positively correlated "Functional" features with pval < 0.05.
methods = ["Percentile", "KMeans", "KDE"]
dset_names = ["neurosynth", "neuroquery"]
models = ["term", "lda", "gclda"]

label_dict = {
    "Percentile": "PCT",
    "KMeans": "KMeans",
    "KDE": "KDE",
    "neurosynth": "NS",
    "neuroquery": "NQ",
    "term": "Term",
    "lda": "LDA",
    "gclda": "GCLDA",
}

fig_dir = op.join("./Fig", "survey")
os.makedirs(fig_dir, exist_ok=True)

# Grid rows follow the same itertools.product order used to build data_lst.
combos = list(itertools.product(models, dset_names, methods))

for seg_sol in range(5, 6):
    big_cloud_fn = op.join(fig_dir, f"filtered-cloud_{seg_sol:02d}.png")
    if op.exists(big_cloud_fn):
        continue  # grid already rendered

    # squeeze=False keeps a 2-D axes array even for a single row or column.
    cloud_fig, cloud_axes_tpl = plt.subplots(len(combos), seg_sol, squeeze=False)
    cloud_fig.set_size_inches(1.6 * seg_sol, 15)

    for row_i, (model, dset_name, method) in enumerate(combos):
        # data_lst's inner lists start at segmentation solution 3.
        data_df = data_lst[row_i][seg_sol - 3]
        is_topic_model = model in ["lda", "gclda"]

        for seg_id in range(1, seg_sol + 1):
            cloud_ax = cloud_axes_tpl[row_i, seg_id - 1]

            # NOTE: the FDR condition (pval < 0.05) is meant only for the final
            # figures; drop it to reproduce the unfiltered version.
            filtered_df = data_df.query(
                f'seg_sol == {seg_sol} & seg_id == {seg_id} & corr > 0 & classification == "Functional" & pval < 0.05'
            ).sort_values(by=["corr"], ascending=False)

            if filtered_df.empty:
                # Nothing survived the filters and the plotting helpers cannot
                # draw empty input: leave a blank, frameless panel.
                cloud_ax.set_xticks([])
                cloud_ax.set_yticks([])
                for spine in cloud_ax.spines.values():
                    spine.set_visible(False)
            else:
                corrs = filtered_df["corr"].values
                features = filtered_df.index.values
                frequencies = filtered_df["frequencies"].values if is_topic_model else None
                # Topic labels are split into their words (first token dropped).
                features_split = [feature.split("_")[1:] for feature in features] if is_topic_model else features

                radar_fn = op.join(fig_dir, f"filtered-radar_{model}-{dset_name}-{method}_{seg_sol:02d}-{seg_id:02d}.eps")
                cloud_fn = op.join(fig_dir, f"filtered-cloud_{model}-{dset_name}-{method}_{seg_sol:02d}-{seg_id:02d}.png")

                # Standalone per-segment files.
                plot_radar(corrs, features, model, out_fig=radar_fn)
                plot_cloud(features_split, frequencies, corrs, model, out_fig=cloud_fn)
                # Panel in the grid figure.
                plot_cloud(features_split, frequencies, corrs, model, cloud_fig, cloud_ax)

            if row_i == 0:
                cloud_ax.set_title(f"Segment {seg_id:02d}", fontsize=8)
            if seg_id == 1:
                print(label_dict[model], label_dict[dset_name], label_dict[method])
                cloud_ax.set_ylabel(
                    f"{label_dict[model]}-{label_dict[dset_name]}\n{label_dict[method]}",
                    fontsize=6,
                )

    # Target the grid explicitly: the helpers above create their own figures,
    # so pyplot's "current figure" is not guaranteed to be cloud_fig here.
    cloud_fig.tight_layout(w_pad=0.8, h_pad=0.8)
    cloud_fig.savefig(big_cloud_fn, bbox_inches="tight", dpi=1000)
    # Close the grid and any per-panel figures the helpers left open.
    plt.close("all")
    gc.collect()