In [None]:
import sys
sys.path.append('/causal-discovery')

import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

from cdrl.io.storage import EvaluationStorage
from cdrl.io.file_paths import FilePaths

# Graph family tag embedded in every experiment ID below.
# NOTE(review): "er" presumably stands for Erdős–Rényi random graphs — confirm.
gtype = "er"

# Sweep values. `gsizes` ends up as the "e" column (plotted as "Number of edges")
# and `datasizes` as the "m" column (plotted as "Number of datapoints").
gsizes = [15, 20, 25, 30, 35, 40, 45]
datasizes = [10, 25, 50, 75, 100, 175, 250, 375, 500, 750, 1000, 1750, 2500, 3750, 5000]

# One experiment ID per sweep value. The trailing "_e<k>" / "_n<k>" token is
# parsed back into an integer by get_data_for_exp_ids, so keep that format.
exp_ids_vardensity = [f"synth10gpr_vardensity_{gtype}_e{gsize}" for gsize in gsizes]
exp_ids_vardata = [f"synth10gpr_vardata_{gtype}_n{datasize}" for datasize in datasizes]

# Output location for aggregate artefacts; the final figure is written under
# fp_out.figures_dir.
fp_out = FilePaths('/experiment_data', 'aggregate_cdrl')


In [None]:
def get_data_for_exp_ids(exp_ids):
    """Collect per-run evaluation metrics for several experiments into one DataFrame.

    For each experiment ID the "eval" metrics are read through
    ``EvaluationStorage``. Every returned record contributes one row holding
    the ``tpr``, ``fdr`` and ``shd`` values found under
    ``record['results']['construct']`` plus the record's ``agent``.

    The swept quantity is recovered from the last ``_``-separated token of the
    experiment ID by dropping its first character and parsing the rest as an
    integer (e.g. ``..._e15`` -> 15). It is stored in column ``e`` when the ID
    contains ``"density"`` and in column ``m`` when it contains ``"data"``.

    Parameters
    ----------
    exp_ids : iterable of str
        Experiment identifiers located under ``/experiment_data``.

    Returns
    -------
    pandas.DataFrame
        One row per metrics record, sorted by ``agent``. ``agent`` is an
        ordered categorical with a fixed category order; agent names outside
        that list become NaN. If no records are found at all, an empty frame
        with the columns ``tpr``, ``fdr``, ``shd`` and ``agent`` is returned
        instead of failing on the missing ``agent`` column.

    Raises
    ------
    ValueError
        If an ID contains ``"density"`` or ``"data"`` but its last token
        (minus the first character) is not an integer.
    """
    metrics = ['tpr', 'fdr', 'shd']
    agent_order = ["uctfull", "rlbic", "greedy", "randomshooting", "random"]

    all_data = []

    for exp_id in exp_ids:
        fp_in = FilePaths('/experiment_data', exp_id)
        storage = EvaluationStorage(fp_in)
        emd = storage.get_metrics_data("eval")

        # The sweep value depends only on the experiment ID, so parse it once
        # per experiment rather than once per record.
        sweep_fields = {}
        is_density_sweep = "density" in exp_id
        is_data_sweep = "data" in exp_id
        if is_density_sweep or is_data_sweep:
            sweep_value = int(exp_id.split("_")[-1][1:])
            if is_density_sweep:
                sweep_fields['e'] = sweep_value
            if is_data_sweep:
                sweep_fields['m'] = sweep_value

        for e in emd:
            entry = dict(sweep_fields)
            construct_results = e['results']['construct']
            for metric in metrics:
                entry[metric] = construct_results[metric]

            entry['agent'] = e['agent']
            all_data.append(entry)

    if all_data:
        df = pd.DataFrame(all_data)
    else:
        # pd.DataFrame([]) has no columns at all, which would make the
        # categorical conversion below raise KeyError('agent').
        df = pd.DataFrame(columns=metrics + ['agent'])

    df['agent'] = pd.Categorical(df['agent'], categories=agent_order, ordered=True)
    df = df.sort_values(by=["agent"])
    return df

In [None]:
# Load both sweeps in a fixed order: edge-density sweep first, dataset-size sweep second.
density_df, data_df = map(get_data_for_exp_ids, (exp_ids_vardensity, exp_ids_vardata))

In [None]:
from matplotlib.legend_handler import HandlerLine2D
def legend_handle_update(handle, orig):
    """Style a legend handle after its plotted line, then set its line width to 8.

    Passed as ``update_func`` to ``HandlerLine2D`` when the figure legend is built.

    Parameters
    ----------
    handle : the legend artist being created; modified in place.
    orig : the plotted artist whose properties are copied onto ``handle``.
    """
    handle.update_from(orig)
    handle.set_linewidth(8)

# Global styling for a large, print-oriented figure.
sns.set(font_scale=3.5)
plt.rc('font', family='serif')
# mpl.rcParams['text.usetex'] = True
mpl.rcParams["lines.linewidth"] = 8
mpl.rcParams["lines.markersize"] = 72

# NOTE(review): not read anywhere in this cell; kept in case another cell uses it.
legend_i = 1

dims = (2.5 * 8.26, 1.2 * 8.26)

# One row, two panels: left = edge-density sweep, right = dataset-size sweep.
# squeeze=False keeps `axes` two-dimensional so it is indexed as axes[0][i].
fig, axes = plt.subplots(1, 2, figsize=dims, squeeze=False, sharey=False, sharex=False)

xs = ["e", "m"]
display_metric = "shd"

agent_display_names = {"uctfull": "CD-UCT",
                       "rlbic": "RL-BIC",
                       "greedy": "Greedy Search",
                       "random": "Uniform Sampling",
                       "randomshooting": "Random Search"}

for i, df in enumerate([density_df, data_df]):
    ax = axes[0][i]

    sns.lineplot(data=df, x=xs[i], y=display_metric, ax=ax, hue="agent")
    if i == 1:
        # Dataset sizes span several orders of magnitude.
        ax.set_xscale("log")
        ax.get_xaxis().get_major_formatter().labelOnlyBase = False

    # Grab the legend entries before dropping the per-panel legend; a single
    # shared legend is attached to the figure below. After the loop these hold
    # the entries of the last panel.
    handles, labels = ax.get_legend_handles_labels()
    panel_legend = ax.get_legend()
    if panel_legend is not None:
        panel_legend.remove()

    ax.set_ylabel("SHD")
    if xs[i] == "e":
        ax.set_xlabel("Number of edges")
    else:
        ax.set_xlabel("Number of datapoints")

# Pick legend entries by agent name rather than by list position. Positional
# indexing depends on the seaborn version: older releases prepend the hue
# variable name ("agent") as a pseudo-entry, newer ones use a real legend
# title, which shifts every index by one. Entries are kept only for known
# agents that actually have rows in the plotted data, so unused categories
# (and the "agent" pseudo-entry, if present) are dropped.
plotted_agents = {agent
                  for plotted_df in (density_df, data_df)
                  for agent in plotted_df["agent"].dropna().unique()}
legend_entries = [(entry_handle, entry_label)
                  for entry_handle, entry_label in zip(handles, labels)
                  if entry_label in agent_display_names and entry_label in plotted_agents]

relevant_handles = [entry_handle for entry_handle, _ in legend_entries]
relevant_labels = [entry_label for _, entry_label in legend_entries]

display_labels = [agent_display_names[label] for label in relevant_labels]
fig.legend(relevant_handles, display_labels, loc='upper center', borderaxespad=-0.25, fontsize="medium", ncol=2,
           handler_map={plt.Line2D: HandlerLine2D(update_func=legend_handle_update)})
# Leave the top 11% of the canvas free for the figure-level legend.
fig.tight_layout(rect=[0,0,1,0.89])
fig.savefig(fp_out.figures_dir / f"finalvartopology_{display_metric}.pdf", bbox_inches="tight")