In [None]:
import os.path as op
import itertools
import gc

import matplotlib.pyplot as plt
import pandas as pd
from gradec.fetcher import _fetch_features, _fetch_frequencies, _fetch_classification
from gradec.plot import plot_radar, plot_cloud
from gradec.utils import _decoding_filter

In [None]:
# Data folder, located one level above the working directory; it is handed to
# the gradec fetchers through their `data_dir` argument.
data_dir = op.join(op.pardir, "data")

In [None]:
# Axes of the survey: every (model, dataset, method) combination is one row
# of the composite figure.
methods = ["PCT", "KMeans", "KDE"]
dset_names = ["neurosynth", "neuroquery"]
models = ["term", "lda", "gclda"]

# Short display labels used for the row annotations of the figure.
label_dict = {
    # Segmentation methods are shown under their own name.
    **{method_name: method_name for method_name in methods},
    # Datasets are abbreviated.
    "neurosynth": "NS",
    "neuroquery": "NQ",
    # Decoding models.
    "term": "Term",
    "lda": "LDA",
    "gclda": "GCLDA",
}

# Render one composite word-cloud figure per segmentation solution. Each figure
# has one row per (model, dataset, method) combination and one column per
# segment of the solution.
n_rows = len(models) * len(dset_names) * len(methods)

for seg_sol in range(2, 33):
    big_cloud_fn = op.join("./Fig", "survey", f"cloud_{seg_sol:02d}.png")
    # Emit the LaTeX include line for this figure. The backslash is escaped
    # explicitly because "\i" is not a valid Python escape sequence.
    print(f"\\includegraphics[scale=1]{{cloud_{seg_sol:02d}.png}}")

    # The composite figure acts as a cache marker: skip solutions already rendered.
    if op.exists(big_cloud_fn):
        continue

    cloud_fig, cloud_axes_tpl = plt.subplots(n_rows, seg_sol)
    cloud_fig.set_size_inches(1.6 * seg_sol, 15)

    combinations = itertools.product(models, dset_names, methods)
    for row_i, (model, dset_name, method) in enumerate(combinations):
        corr_dir = op.join("../results/decoding", f"{dset_name}_{model}_corr_{method}")
        corr_file = op.join(corr_dir, f"corrs_{seg_sol:02d}.csv")
        corr_df = pd.read_csv(corr_file, index_col="feature")

        # Load features for visualization
        features = _fetch_features(dset_name, model, data_dir=data_dir)
        classification, class_lst = _fetch_classification(dset_name, model, data_dir=data_dir)

        # Topic models additionally carry frequencies; they do not depend on
        # the segment, so fetch them once per row instead of once per segment.
        is_topic_model = model in ["lda", "gclda"]
        if is_topic_model:
            frequencies = _fetch_frequencies(dset_name, model, data_dir=data_dir)

        for seg in range(seg_sol):
            seg_id = seg + 1  # one-based segment number used in titles and file names
            cloud_ax = cloud_axes_tpl[row_i, seg]

            # Correlation columns are addressed by the zero-based segment index as a string.
            data_df = corr_df[[f"{seg}"]]

            # `_decoding_filter` returns an extra element (the filtered
            # frequencies) when frequencies are supplied.
            cloud_kwargs = {}
            if is_topic_model:
                filtered_df, filtered_features, filtered_frequencies = _decoding_filter(
                    data_df,
                    features,
                    classification,
                    freq_by_topic=frequencies,
                    class_by_topic=class_lst,
                )
                cloud_kwargs["frequencies"] = filtered_frequencies
            else:
                filtered_df, filtered_features = _decoding_filter(
                    data_df,
                    features,
                    classification,
                )
            filtered_df.columns = ["r"]

            # Visualize results
            corrs = filtered_df["r"].to_numpy()

            # Word cloud plot: drawn on the shared grid and also handed an
            # individual output path per (combination, segment).
            cloud_fn = op.join(
                "./Fig",
                "survey",
                f"cloud_{model}-{dset_name}-{method}_{seg_sol:02d}-{seg_id:02d}.png",
            )
            plot_cloud(
                corrs,
                filtered_features,
                model,
                cmap="YlOrRd",
                ax=cloud_ax,
                out_fig=cloud_fn,
                **cloud_kwargs,
            )

            # Column titles only on the first row, row labels only on the first column.
            if row_i == 0:
                cloud_ax.set_title(f"Segment {seg_id:02d}", fontsize=8)
            if seg_id == 1:
                cloud_ax.set_ylabel(
                    f"{label_dict[model]}-{label_dict[dset_name]}\n{label_dict[method]}",
                    fontsize=6,
                )

    # Operate on the figure object explicitly rather than on pyplot's implicit
    # "current figure", so the right figure is laid out, saved and released.
    cloud_fig.tight_layout(w_pad=0.8, h_pad=0.8)
    cloud_fig.savefig(big_cloud_fn, bbox_inches="tight", dpi=300)
    plt.close(cloud_fig)
    gc.collect()