In [None]:
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import warnings

## Load Results

In [None]:
# Results directory of a single run. Replace `YYYYMMDD-HHMMSS` with the timestamp of the run.
PATH = "./results/YYYYMMDD-HHMMSS/"
# Later cells build paths by string concatenation (PATH + "..."), so PATH must end
# with a separator; os.path.join(PATH, "") appends one only if it is missing.
PATH = os.path.join(PATH, "")

In [None]:
# Collect the per-run CSVs of both experiment families. glob's order depends on the
# filesystem, so sort for a deterministic concatenation order.
results = sorted(glob.glob(PATH + "aggregating/csv/*.csv"))
results += sorted(glob.glob(PATH + "bootstrapping/csv/*.csv"))
if not results:
    # Fail here with a clear message instead of an opaque pd.concat error in the next cell.
    raise FileNotFoundError(
        f"No result CSVs found under {PATH!r}; set PATH to an existing run directory."
    )
len(results)

In [None]:
def _sample_ratio_sort_key(column):
    """Sort key for `DataFrame.sort_values`: order `sample_ratio` numerically.

    `sample_ratio` is kept as str (presumably so it plots as a categorical x-axis),
    which would otherwise sort lexicographically ("10" before "2"). If any value
    does not parse as a number, the column is returned unchanged, giving the plain
    string order. All other columns are returned as-is.
    """
    if column.name == "sample_ratio":
        numeric = pd.to_numeric(column, errors="coerce")
        if numeric.notna().all():
            return numeric
    return column


results = (
    pd.concat([pd.read_csv(result, dtype={"sample_ratio": str}) for result in results])
    .set_index("id")
    .sort_values(["bn", "sample_ratio"], key=_sample_ratio_sort_key)
)
results.head()

In [None]:
# Drop these classification metrics so they do not get a facet in the plots below.
# Reassign instead of dropping in place, and ignore already-missing columns, so the
# cell can be re-run without raising KeyError.
results = results.drop(
    columns=["sensitivity", "specificity", "accuracy", "balanced_accuracy"],
    errors="ignore",
)

In [None]:
# Per-network statistics, indexed by the "name" column.
stats = pd.read_csv("./stats.csv").set_index("name")
stats.head()

In [None]:
# Distinct `bn` values in sorted order; the plotting loop iterates in this order.
order = sorted(set(results["bn"]))
order

## Plot Results

In [None]:
# Machine epsilon of float64; keeps min-max normalisation denominators non-zero.
eps = np.finfo(np.float64).eps

In [None]:
# Seaborn "ticks" style with black axes edges and tick marks ("0" is matplotlib's
# grayscale-string notation for black). A preceding set_style("white") would be
# fully overwritten here, so it is not needed.
sns.set_style(
    "ticks",
    {
        "axes.edgecolor": "0",
        "xtick.color": "0",
        "ytick.color": "0",
    },
)
# Large fonts for the exported PDFs.
sns.set_context("paper", font_scale=2.30)
# Use scientific-notation tick labels only when the axis range falls outside 1e-5 .. 1e3.
mpl.rcParams["axes.formatter.limits"] = (-5, 3)

### Aggregation

In [None]:
# Map each `bn` value to its sub-frame of results.
groupby = {name: frame for name, frame in results.groupby("bn", sort=True)}

In [None]:
# Output directory for the per-network PDFs; no error if it already exists.
os.makedirs(f"{PATH}plots", exist_ok=True)

In [None]:
# Aggregated models overlaid as lines on the bootstrap box plots. colors, methods and
# linestyles are matched by position. Alternative configuration (TMA threshold sweep):
#   colors     = ["red", "red", "red", "red", "red", "green", "purple"]
#   methods    = ["tma_0.50", "tma_0.60", "tma_0.70", "tma_0.80", "tma_0.90", "sma", "uma"]
#   linestyles = ["-", (0, (1, 1)), (0, (1, 3)), (0, (1, 5)), (0, (1, 7)), "-", "-"]
colors = ["red", "blue", "orange", "green"]
methods = ["tma_0.50", "pma", "cma", "ima"]
linestyles = ["-", "-", "-", "-"]
# Identifier columns; every remaining column is a metric and gets its own facet.
id_vars = ["bn", "sample_ratio", "method"]


def _scale_bic(frame):
    """Return a copy of `frame` whose BIC columns are min-max scaled within each sample_ratio.

    `scaled_in_bic` / `scaled_out_bic` are inserted at column positions 5 and 6,
    because facet order follows column order. The raw `in_bic` / `out_bic` columns
    are then dropped. Values are written through `.loc` on an explicit copy:
    chained assignment (`frame[c].loc[mask] = ...`) does not write back under
    pandas Copy-on-Write.
    """
    scaled = frame.copy()
    if "scaled_in_bic" not in scaled.columns:
        scaled.insert(5, "scaled_in_bic", scaled["in_bic"].astype(float))
        scaled.insert(6, "scaled_out_bic", scaled["out_bic"].astype(float))
        for c in ["scaled_in_bic", "scaled_out_bic"]:
            for s in scaled["sample_ratio"].unique():
                mask = scaled["sample_ratio"] == s
                k = scaled.loc[mask, c]
                # eps keeps the denominator non-zero when every value in the group is equal.
                scaled.loc[mask, c] = ((k - k.min()) / (k.max() - k.min() + eps)).to_numpy()
    if "in_bic" in scaled.columns:
        scaled = scaled.drop(columns=["in_bic", "out_bic"])
    return scaled


for bn in order:
    # Keep the bootstrap rows ("none") plus the selected aggregated models.
    r = groupby[bn]
    r = _scale_bic(r[r["method"].isin(["none"] + methods)])
    by_method = {method: frame for method, frame in r.groupby("method", sort=True)}
    # Metric columns, in the same order pd.melt uses when building the facets.
    columns = [c for c in by_method["none"].columns if c not in id_vars]

    # Box plot of the bootstrap runs, one facet per metric.
    g = sns.FacetGrid(
        pd.melt(by_method["none"], id_vars=id_vars, var_name="metric"),
        col="metric",
        height=5.5,
        aspect=.75,
        sharex=False,
        sharey=False,
        margin_titles=True,
    )
    with warnings.catch_warnings():
        # Silence seaborn warnings raised while mapping boxplot onto the grid.
        warnings.simplefilter("ignore")
        g.map(sns.boxplot, "sample_ratio", "value", color="white")

    # Per-facet decoration, drawn once per axis rather than once per method.
    for i, column in enumerate(columns):
        ax = g.facet_axis(0, i)
        # Fix the y-range before the lines are drawn (set_ylim turns y-autoscaling off).
        if column in ["scaled_in_bic", "scaled_out_bic", "f1"]:
            ax.set(ylim=(-0.05, 1.05))
        if column in ["shd"]:
            ax.set(ylim=(-0.05, None))
        # Dashed separator between the "Low" and "High" sample-size annotations.
        ax.axvline(3.00, 0.05, 0.95, color="gray", linestyle="--")
        if i == 0:
            # Hand-placed annotations (data coordinates), first facet only.
            ax.text(1.95, 0.15, "Low", color="gray", fontsize=12, rotation=90)
            ax.text(2.40, 0.05, "sample size", color="gray", fontsize=12, rotation=90)
            ax.text(3.30, 0.15, "High", color="gray", fontsize=12, rotation=90)
            ax.text(3.75, 0.05, "sample size", color="gray", fontsize=12, rotation=90)

    # One line per aggregated model on every metric facet.
    for color, method, linestyle in zip(colors, methods, linestyles):
        for i, column in enumerate(columns):
            sns.lineplot(
                data=by_method[method],
                x="sample_ratio",
                y=column,
                color=color,
                linewidth=3,
                linestyle=linestyle,
                marker="o",
                markersize=9,
                ax=g.facet_axis(0, i),
            )

    # lineplot overwrites the y-label with the metric name: label only the first
    # facet, using the part of `bn` after its first "-" (the whole name if none).
    parts = bn.split("-")
    row_label = (parts[1] if len(parts) > 1 else bn).upper()
    for i in range(len(columns)):
        ax = g.facet_axis(0, i)
        if i == 0:
            ax.set_ylabel(row_label, labelpad=15)
        else:
            ax.set_ylabel(None)

    # Capture tick labels before despine(trim=True), then re-apply them rotated.
    xticks = [a.get_xticklabels() for a in g.axes.flat]
    sns.despine(offset=2.5, trim=True)
    for ax, labels in zip(g.axes.flat, xticks):
        ax.set_xlabel(ax.get_xlabel(), labelpad=12)
        ax.set_xticklabels(labels, rotation=90)

    # Legend derived from colors/methods so it cannot drift from the plotted lines
    # ("tma_0.50" -> "TMA", "pma" -> "PMA", ...).
    handles = [mpatches.Patch(color="gray", label="Bootstrap")] + [
        mpatches.Patch(color=color, label=method.split("_")[0].upper())
        for color, method in zip(colors, methods)
    ]
    # Hand-tuned anchor near the bottom of the figure.
    plt.figlegend(handles=handles, bbox_to_anchor=(0.775, 0.01), ncol=len(handles))
    plt.savefig(PATH + "plots/" + bn + ".pdf", bbox_inches="tight")
    plt.show()
    plt.close()