In [None]:
import itertools
import os
import os.path as op

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
# Absolute paths of the results, data and figure folders, resolved against the
# working directory the notebook is run from.
result_dir, data_dir, figure_dir = map(op.abspath, ("../results", "../data", "./Fig"))

In [None]:
sns.set(style="white")

# Term categories, in the order they are counted, drawn in the pies and later
# used as table columns. (The misspelled name is kept on purpose: other cells
# reference `cotegories`.)
cotegories = np.array(["Functional", "Clinical", "Anatomical", "Non-Specific"])

dset_names = ["neurosynth", "neuroquery"]
models = ["term", "lda", "gclda"]

# Pie styling is the same for every chart, so build it once outside the loop:
# the first wedge is pulled out of the pie and seaborn's muted palette is used.
explode = [0.1] + [0] * (len(cotegories) - 1)
palette_color = sns.color_palette("muted")

method_lst = []
class_lst = []
for model, dset_name in itertools.product(models, dset_names):
    method_name = f"{model}_{dset_name}"
    print(model, dset_name)
    data_df = pd.read_csv(
        op.join(data_dir, "classification", f"{method_name}_classification.csv")
    )

    # Count every category in a single pass. Categories absent from the file
    # count as 0; labels that are not in `cotegories` are ignored.
    counts = (
        data_df["Classification"]
        .value_counts()
        .reindex(cotegories, fill_value=0)
        .to_numpy()
    )

    # Long-format labels: one (method, classification) entry per counted row.
    class_lst.append(np.repeat(cotegories, counts))
    method_lst.append(np.full(counts.sum(), method_name))

    # One pie per model/dataset pair, titled so the charts can be told apart.
    fig, ax = plt.subplots()
    ax.pie(
        counts,
        labels=cotegories,
        colors=palette_color,
        explode=explode,
        autopct="%.0f%%",
    )
    ax.set_title(method_name)
    plt.show()
    plt.close(fig)  # release the figure; several are created in this loop

new_data_df = pd.DataFrame(
    {
        "method": np.concatenate(method_lst),
        "classification": np.concatenate(class_lst),
    }
)
new_data_df


In [None]:
# Share of each classification within every method (each row sums to 1), with
# columns in `cotegories` order and methods sorted in descending name order.
cross_data_prop_df = pd.crosstab(
    new_data_df["method"],
    new_data_df["classification"],
    normalize="index",
)[cotegories].sort_index(ascending=False)
cross_data_prop_df

In [None]:
# Raw number of entries per method and classification, with columns in
# `cotegories` order and methods sorted in descending name order.
cross_data_df = pd.crosstab(
    new_data_df["method"],
    new_data_df["classification"],
)[cotegories].sort_index(ascending=False)
cross_data_df

In [None]:
fontsize = 11

# One fill colour per classification column of the table, in column order.
colors = ["#393E46", '#6D9886', '#F2E7D5', '#F7F7F7']

# Short display names for the two halves of a "<model>_<dataset>" index entry.
model_labels = {"term": "Term", "lda": "LDA", "gclda": "GCLDA"}
dset_labels = {"neurosynth": "NS", "neuroquery": "NQ"}

fig, ax = plt.subplots(1, 1)
fig.set_size_inches(3.5, 3.5)

# Stacked bars: one bar per method, one segment per classification.
cross_data_prop_df.plot(
    kind='bar',
    stacked=True,
    color=colors,
    edgecolor='white',
    linewidth=2,
    width=0.9,
    ax=ax,
)

# Derive the tick labels from the table index so that each label always
# matches the bar it sits under, whatever order the rows are in. Unknown
# model/dataset names fall back to their raw spelling.
tick_labels = []
for method in cross_data_prop_df.index:
    model, _, dset_name = method.partition("_")
    tick_labels.append(
        f"{model_labels.get(model, model)}-{dset_labels.get(dset_name, dset_name)}"
    )
ax.set_xticklabels(tick_labels, fontsize=fontsize)

ax.legend(
    loc="upper center",
    bbox_to_anchor=(0.5, 1.18),
    ncol=4,
    fontsize=fontsize,
)
ax.tick_params(axis="both", labelsize=fontsize)
ax.set_xlabel("Decoding Strategy", fontsize=fontsize+2)
ax.set_ylabel("Proportion", fontsize=fontsize+2)

# Write both copies under the configured figure folder, creating it (and the
# "classification" sub-folder) first so a fresh checkout does not fail here.
class_figure_dir = op.join(figure_dir, "classification")
os.makedirs(class_figure_dir, exist_ok=True)
fig.savefig(op.join(class_figure_dir, "class_prop_barh.eps"), bbox_inches="tight")
fig.savefig(op.join(figure_dir, "Fig-09.eps"), bbox_inches="tight")
plt.show()