In [None]:
import sys
sys.path.append('/causal-discovery')

from cdrl.agent.mcts.mcts_agent import *
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import seaborn as sns

from cdrl.io.storage import EvaluationStorage
from cdrl.io.file_paths import FilePaths

# Suffix appended to every instance name to form its experiment id; it is also
# reused further down when naming the output figure.
base_suffix = "timings"
# Instance names "synth<N>lr" for N = 10, 15, ..., 50.
instances = [f"synth{n_nodes}lr" for n_nodes in range(10, 51, 5)]

# One experiment id per instance, e.g. "synth10lr_timings".
exp_ids = ["_".join((instance, base_suffix)) for instance in instances]
# Output location for the aggregated artefacts produced by this notebook.
fp_out = FilePaths('/experiment_data', 'aggregate_cdrl')


In [None]:
# Collect one row per evaluation entry across all timing experiments.
all_timings_data = []

for exp_id in exp_ids:
    fp_in = FilePaths('/experiment_data', exp_id)
    storage = EvaluationStorage(fp_in)
    emd = storage.get_metrics_data("eval")

    # N is the run of digits embedded in the instance name ("synth10lr" -> 10).
    # Extracting the digits (rather than slicing fixed positions) keeps this
    # correct for one- or three-digit sizes as well.
    instance_name = exp_id.split("_")[0]
    n_nodes = int("".join(ch for ch in instance_name if ch.isdigit()))

    for e in emd:
        duration_s = e['duration_construct_s']
        # NOTE(review): the factor 2 presumably means two actions per edge — confirm.
        n_actions = len(e['results']['construct']['edges']) * 2
        all_timings_data.append({
            'total_duration_s': duration_s,
            # A run that produced no edges has no per-action time; record NaN
            # instead of aborting the whole cell with ZeroDivisionError.
            'per_action_s': duration_s / n_actions if n_actions else float("nan"),
            'N': n_nodes,
            'agent': e['agent'],
        })

timings_df = pd.DataFrame(all_timings_data)
timings_df

In [None]:
# Plot the per-action construction time against N, one line per agent, and
# save the figure as a PDF in the aggregate figures directory.
measurements = ["per_action_s"]

sns.set(font_scale=3.5)
plt.rc('font', family='serif')
# mpl.rcParams['text.usetex'] = True
mpl.rcParams["lines.linewidth"] = 8
mpl.rcParams["lines.markersize"] = 72


palette = {"uctfull": "C0", "uctfullnaive": "C5"}
# Not read anywhere in this cell; retained in case other cells rely on it.
legend_i = 0

agent_display_names = {"uctfull": "CD-UCT",
                       "uctfullnaive": "UCT (Naive)"}

dims = (8.26 * len(measurements) * 1.5, 8.26 * 1.15)

fig, axes = plt.subplots(1, len(measurements), figsize=dims, squeeze=False, sharey=False, sharex=False)

# Legend entries of the last subplot drawn; reused for the figure-level legend.
handles, labels = [], []
for i, measurement in enumerate(measurements):
    ax = axes[0][i]

    sns.lineplot(data=timings_df, x="N", y=measurement, ax=ax, hue="agent", palette=palette)
    ax.get_yaxis().get_major_formatter().labelOnlyBase = False
    handles, labels = ax.get_legend_handles_labels()
    # The per-axes legend is replaced by a single figure-level legend below.
    if ax.get_legend() is not None:
        ax.get_legend().remove()
    ax.set_ylabel("Seconds per action")

# Keep only entries that correspond to known agents. Some seaborn versions put
# the hue variable name in the legend as a leading placeholder entry and others
# do not; filtering by label handles both, whereas slicing off the first entry
# would drop a real agent in the latter case.
legend_entries = [(handle, agent_display_names[label])
                  for handle, label in zip(handles, labels)
                  if label in agent_display_names]
legend_handles = [handle for handle, _ in legend_entries]
display_labels = [name for _, name in legend_entries]
fig.legend(legend_handles, display_labels, loc='upper left', borderaxespad=3.5, fontsize="medium")

# fig.suptitle(f"", y=0.92, fontsize=64)
plt.savefig(fp_out.figures_dir / f"{base_suffix}_timings.pdf", bbox_inches="tight")

In [None]:
# Per-action time aggregated per (N, agent); pivot_table's default aggregation
# (the mean) is used since no aggfunc is given.
per_action_by_agent = timings_df.pivot_table(values="per_action_s", index="N", columns=["agent"])
# Ratio of the "uctfullnaive" per-action time to the "uctfull" one, for each N.
tp = per_action_by_agent.assign(
    speedup=per_action_by_agent["uctfullnaive"] / per_action_by_agent["uctfull"]
)
tp

In [None]:
# Instances whose "<instance>_primary" experiments feed the timings table.
# instances = ["sachs"] + [f"syntren{d}" for d in range(1, 11)]
instances = ["sachs", "syntren1"]

fp_out = FilePaths('/experiment_data', 'aggregate_cdrl')

experiment_ids = [f"{inst}_primary" for inst in instances]

In [None]:
def get_timings_df(experiment_ids, collapse_syntren=True):
    """Build a DataFrame of construction durations, one row per evaluation entry.

    For every experiment id the "eval" metrics are read from
    ``/experiment_data/<exp_id>`` and each entry contributes a row with:

    * ``total_seconds`` - the entry's ``duration_construct_s`` value,
    * ``agent`` - the entry's agent name, with every name starting with
      ``"uct"`` mapped to plain ``"uct"``,
    * ``instance`` - the part of the experiment id before the first ``"_"``.

    Args:
        experiment_ids: iterable of experiment id strings.
        collapse_syntren: when true, any instance whose name starts with
            ``"syntren"`` is reported as just ``"syntren"``.

    Returns:
        A pandas DataFrame with the columns described above.
    """
    rows = []

    for exp_id in experiment_ids:
        metrics = EvaluationStorage(FilePaths('/experiment_data', exp_id)).get_metrics_data("eval")

        # The instance label depends only on the experiment id, so resolve it once.
        instance_name = exp_id.split("_")[0]
        if collapse_syntren and instance_name.startswith("syntren"):
            instance_name = "syntren"

        for entry in metrics:
            duration = entry["duration_construct_s"]
            agent_name = entry["agent"]
            rows.append({
                "total_seconds": duration,
                "agent": "uct" if agent_name.startswith("uct") else agent_name,
                "instance": instance_name,
            })

    return pd.DataFrame(rows)

In [None]:
# Raw per-entry timings for the selected experiments; instances starting with
# "syntren" are merged into a single "syntren" label.
tdf = get_timings_df(experiment_ids, collapse_syntren=True)
tdf

In [None]:
# Aggregate (pivot_table default: mean) with one row per instance and one
# column per agent.
mean_by_instance = tdf.pivot_table(index=["instance"], columns=["agent"])
# Drop the outermost column level so that the columns are just the agent names.
mean_by_instance.columns = mean_by_instance.columns.droplevel(0)
# Round-trip through records so that "instance" becomes an ordinary column.
timings_pivot = pd.DataFrame(mean_by_instance.to_records())
timings_pivot

In [None]:
import re

def format_timing(val_s):
    """Render a duration in seconds as a display string for the timings table.

    Args:
        val_s: a duration in seconds (any real number, NaN allowed), or a
            string.

    Returns:
        * the input unchanged when it is a string (including ``str``
          subclasses such as ``numpy.str_``), so label cells pass through;
        * ``"---"`` when the value is NaN;
        * ``"<00:01"`` when the duration truncates to zero whole seconds;
        * otherwise ``"H:MM:SS"``, with the seconds truncated towards zero.
    """
    # isinstance (rather than an exact type comparison) also accepts str
    # subclasses, which would otherwise be multiplied as if they were numbers.
    if isinstance(val_s, str):
        return val_s

    if np.isnan(float(val_s)):
        return "---"

    val_seconds = int(val_s)
    if val_seconds == 0:
        return "<00:01"

    hours, remainder = divmod(val_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return '{:01}:{:02}:{:02}'.format(hours, minutes, seconds)

# Turn every cell into its display string (string cells pass through unchanged).
timings_pivot = timings_pivot.applymap(format_timing)
# Column order of the exported table; "instance" is the leading label column.
colorder = ["instance", "uct", "rlbic", "greedy", "randomshooting", "random", "cam", "lingam", "notears", "pc", "ges"]


agent_display_names = {"uct": "CD-UCT",
                       "rlbic": "RL-BIC",
                       "greedy": "Greedy Search",
                       "random": "Uniform Sampling",
                       "randomshooting": "Random Search",
                       "cam": "CAM",
                       "lingam": "LiNGAM",
                       "notears": "NOTEARS",
                       "ges": "GES",
                       "pc": "PC"}

# Select/reorder the columns and apply the display names in one non-mutating step.
timings_pivot = timings_pivot[colorder].rename(columns=agent_display_names)

texfile = str(fp_out.figures_dir / "timings_final.tex")
n_startcols = 1
# Centered label column(s), a vertical rule, then one right-aligned column per agent.
colformat = "c" * n_startcols + "|" + "r" * (len(colorder) - n_startcols)
# The context manager closes the file even if to_latex raises.
with open(texfile, 'w') as fh:
    timings_pivot.to_latex(buf=fh, index=False, column_format=colformat)

# Regex pattern -> replacement template, applied in insertion order to the
# generated .tex file.
#
# The values are re.sub replacement templates, not plain strings: a backslash
# followed by an ASCII letter is parsed as an escape, and unknown escapes such
# as "\u", "\d" or "\i" raise re.error on Python >= 3.7. Every literal LaTeX
# backslash is therefore written as "\\" inside the raw string.
replace_dict = {
    r"instance" : r"",
    r"agg" : r"",
    r"metric" : r"",
    r"reward": r"Reward $\\uparrow$",
    r"tpr": r"TPR $\\uparrow$",
    r"fdr": r"FDR $\\downarrow$",
    r"shd": r"SHD $\\downarrow$",
    r"-100.000": r"$\\times$",

    r"sachs": r"\\textit{Sachs}",
    r"syntren": r"\\textit{SynTReN}",
    r"mrr": r"MRR",

    r"nan±nan": r"$\\infty$",
    r"NaN": r"$\\infty$",
    r"nan": r"$\\infty$",

    # \g<1> re-inserts the captured number after the \pm sign.
    r"±(\d+\.\d+)": r"\\tiny{$\\pm\g<1>$}",
    r"±---": r"\\tiny{$\\pm0.000$}"
}

with open(texfile, 'r') as f:
    raw_content = f.read()

processed_content = raw_content
for orig, targ in replace_dict.items():
    processed_content = re.sub(orig, targ, processed_content, flags = re.M)

with open(texfile, 'w') as g:
    g.write(processed_content)

