In [None]:
import sys
sys.path.append('/causal-discovery')

from cdrl.agent.mcts.mcts_agent import *
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import seaborn as sns

from cdrl.io.storage import EvaluationStorage
from cdrl.io.file_paths import FilePaths

# Budgets swept in this analysis. They are kept as strings because they are embedded
# verbatim in the experiment ids ("sachs_varbudgetb<budget>").
# NB: every entry needs its own comma -- adjacent string literals are silently
# concatenated by Python (e.g. "0.25" "0.5" -> "0.250.5").
budgets = ["0.1", "0.25", "0.5", "1", "2.5", "5", "10", "25", "50", "100", "250", "500", "1178"]

# One varied-budget experiment per budget.
exp_ids_main = [f"sachs_varbudgetb{budget}" for budget in budgets]
# The baselines all come from the single "sachs_primary" run. It is listed once per
# budget so its results get paired with every budget when the metrics are aggregated.
exp_ids_baselines = ["sachs_primary"] * len(budgets)


In [None]:
# Flatten the "eval" metrics of every experiment into one long-form frame for seaborn
# (one row per evaluation entry). Position k of exp_ids_main / exp_ids_baselines belongs
# to budgets[k]. Pair them explicitly and fail loudly if the lists ever drift apart.
assert len(exp_ids_main) == len(budgets) == len(exp_ids_baselines), \
    "each experiment id needs exactly one matching budget"

# Read each distinct experiment only once. The baseline list repeats "sachs_primary"
# for every budget, which would otherwise re-read the same storage len(budgets) times.
# list() materialises the entries so the cached result can be iterated repeatedly.
metrics_by_exp = {
    exp_id: list(EvaluationStorage(FilePaths('/experiment_data', exp_id)).get_metrics_data("eval"))
    for exp_id in dict.fromkeys(exp_ids_main + exp_ids_baselines)
}

all_data = []
for exp_ids in [exp_ids_main, exp_ids_baselines]:
    for budget, exp_id in zip(budgets, exp_ids):
        for entry in metrics_by_exp[exp_id]:
            if "primary" in exp_id and entry["agent"] in ["uctfull", "randomshooting"]:
                # overlap between "primary" and 1178 budget experiments, skip so we do not include this data twice.
                continue

            results = entry["results"]
            all_data.append({
                "agent": entry["agent"],
                "construct_reward": results["construct"]["reward"],
                "construct_shd": results["construct"]["shd"],
                "prune_shd": results["prune_cam"]["shd"],
                "budget": float(budget),
                "total_seconds": entry["duration_construct_s"],
            })

budget_df = pd.DataFrame(all_data)
# The fixed, ordered agent categories set the hue / legend order of the plots below.
# NOTE(review): agent names outside this list become NaN and silently drop out of the plots.
budget_df['agent'] = pd.Categorical(budget_df['agent'], categories=["uctfull", "rlbic", "greedy", "randomshooting", "random"], ordered=True)
# A stable sort keeps the original row order within each agent.
budget_df = budget_df.sort_values(by=["agent"], kind="stable")


In [None]:
from matplotlib.legend_handler import HandlerLine2D


def legend_handle_update(handle, orig):
    """Copy the style of plotted line `orig` onto legend `handle`, then thicken it.

    Used as the `update_func` of `HandlerLine2D`, so the figure-level legend shows
    thick line samples in each agent's colour.
    """
    handle.update_from(orig)
    handle.set_linewidth(8)

# Global styling: very large fonts and thick lines for a wide 4-panel figure.
sns.set(font_scale=3.5)
plt.rc('font', family='serif')
mpl.rcParams['text.usetex'] = False
mpl.rcParams["lines.linewidth"] = 8
mpl.rcParams["lines.markersize"] = 72

dims = (5 * 8.26, 1.2 * 8.26)
plot_ys = ["construct_reward", "construct_shd", "prune_shd", "total_seconds"]
plot_ys_display = ["Construction reward", r"Construction SHD", r"Pruning SHD", r"Total seconds"]
fig, axes = plt.subplots(1, len(plot_ys), figsize=dims, squeeze=False, sharey=False, sharex=False)

for col, (plot_y, y_label) in enumerate(zip(plot_ys, plot_ys_display)):
    ax = axes[0][col]

    sns.lineplot(data=budget_df, x="budget", y=plot_y, ax=ax, hue="agent")
    # Keep the hue legend entries, then drop the per-axis legend; a single shared
    # figure legend is drawn after the loop.
    handles, labels = ax.get_legend_handles_labels()
    axis_legend = ax.get_legend()
    if axis_legend is not None:
        axis_legend.remove()

    ax.set_xscale("log")
    ax.get_xaxis().get_major_formatter().labelOnlyBase = False
    ax.set_xticks([10.0 ** exponent for exponent in range(-1, 4)])
    if plot_y == "total_seconds":
        ax.set_yscale("log")
        ax.set_yticks([10.0 ** exponent for exponent in range(0, 4)])

    ax.tick_params(axis='both', which='major', labelsize=42)
    ax.tick_params(axis='both', which='minor', labelsize=24)

    ax.set_ylabel(y_label)
    ax.set_xlabel("Budget")

# Lay out only now that ticks and axis labels exist, so tight_layout can account for them.
fig.tight_layout()

agent_display_names = {"uctfull": "CD-UCT",
                       "rlbic": "RL-BIC",
                       "greedy": "Greedy Search",
                       "random": "Uniform Sampling",
                       "randomshooting": "Random Search",
                       }

# Older seaborn versions prepend an entry for the hue title ("agent"); newer ones do not.
# Keep exactly the entries that name a known agent, instead of blindly slicing off the
# first entry (which would drop CD-UCT on newer seaborn).
legend_pairs = [(handle, label) for handle, label in zip(handles, labels) if label in agent_display_names]
legend_handles = [handle for handle, _ in legend_pairs]
display_labels = [agent_display_names[label] for _, label in legend_pairs]
fig.legend(legend_handles, display_labels, loc='upper center', borderaxespad=-0.2, fontsize="medium",
           ncol=max(len(display_labels), 1),
           handler_map={plt.Line2D: HandlerLine2D(update_func=legend_handle_update)})

fp_out = FilePaths('/experiment_data', 'aggregate_cdrl')
fig.savefig(fp_out.figures_dir / "budget_analysis_final.pdf", bbox_inches="tight")