In [None]:
import sys
sys.path.append('/causal-discovery-dv')

from cdrl.agent.mcts.mcts_agent import *
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import seaborn as sns

from cdrl.io.storage import EvaluationStorage
from cdrl.io.file_paths import FilePaths

# Names of the discrete-variable benchmark instances. Used below to build the
# experiment ids and as the row order of the "discrete" LaTeX table.
discrete_instances = ["asia", "child", "insurance"]
# Path handle for aggregated outputs; the LaTeX tables are written under
# fp_out.figures_dir.
fp_out = FilePaths('/experiment_data', 'aggregate_cdrl')

In [None]:
# Metric names read from each evaluation entry, keyed by the results type
# that is passed to get_eval_df as `which_results`.
metrics_to_display = dict(
    construct=["shd", "fdr", "tpr", "reward"],
)

def get_eval_df(experiment_ids, which_results="construct", collapse_syntren=True, skip_nonmdp=True):
    """Collect evaluation metrics of several experiments into one long-format frame.

    Args:
        experiment_ids: iterable of experiment id strings; the text before the
            first underscore of each id is used as the instance name.
        which_results: key into ``metrics_to_display`` and into each entry's
            ``"results"`` mapping.
        collapse_syntren: if true, every instance name starting with
            ``"syntren"`` is reported as plain ``"syntren"``.
        skip_nonmdp: if true (and ``which_results == "construct"``), entries of
            the agents ``cam``, ``lingam`` and ``notears`` are left out.

    Returns:
        A DataFrame with one row per (entry, metric) and the columns
        ``metric``, ``value``, ``agent`` and ``instance``. Agents whose name
        starts with ``"uct"`` are all labelled ``"uct"``.
    """
    excluded_agents = ["cam", "lingam", "notears"]
    rows = []

    for exp_id in experiment_ids:
        storage = EvaluationStorage(FilePaths('/experiment_data', exp_id))
        eval_entries = storage.get_metrics_data("eval")
        metric_names = metrics_to_display[which_results]

        for entry in eval_entries:
            is_excluded = (which_results == "construct"
                           and skip_nonmdp
                           and entry["agent"] in excluded_agents)
            if is_excluded:
                continue

            for metric_name in metric_names:
                value = entry["results"][which_results][metric_name]

                # All UCT variants are pooled under a single agent label.
                agent_label = "uct" if entry["agent"].startswith("uct") else entry["agent"]

                instance_name = exp_id.split("_")[0]
                if collapse_syntren and instance_name.startswith("syntren"):
                    instance_name = "syntren"

                rows.append({
                    "metric": metric_name,
                    "value": value,
                    "agent": agent_label,
                    "instance": instance_name,
                })

    return pd.DataFrame(rows)

In [None]:
import scipy as sp

def compute_ci(data, confidence=0.95):
    """Return the half-width of a Student-t confidence interval for the mean.

    Args:
        data: sequence of numeric samples.
        confidence: two-sided confidence level of the interval (default 0.95).

    Returns:
        ``sem(data) * t.ppf((1 + confidence) / 2, n - 1)`` for ``n >= 2``
        samples, ``0.`` for a single sample, and ``nan`` when ``data`` is empty.
    """
    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee that `scipy.stats` is loaded on every SciPy version.
    from scipy import stats

    n = len(data)
    if n == 0:
        # No samples: the interval is undefined. Returning NaN directly gives
        # the same result as the arithmetic below without numerical warnings.
        return float("nan")
    if n == 1:
        return 0.

    samples = np.asarray(data)
    standard_error = stats.sem(samples)
    return standard_error * stats.t.ppf((1 + confidence) / 2., n - 1)

def augment_with_cis(results_pivot, orig_df):
    """Return a copy of ``results_pivot`` whose agent cells read ``"mean±ci"``.

    Args:
        results_pivot: frame with an ``instance`` and a ``metric`` column in the
            first two positions, followed by one column of mean values per agent.
        orig_df: long-format frame with the columns ``metric``, ``instance``,
            ``agent`` and ``value`` holding the individual measurements.

    Returns:
        A new DataFrame. For every agent column except ``greedy``, ``cam``,
        ``notears``, ``lingam`` and ``gobnilp`` the value is replaced by the
        string ``"<mean>±<ci>"`` (three decimals each), where ``ci`` is
        ``compute_ci`` of the matching ``orig_df`` values. ``results_pivot``
        itself is not modified.
    """
    # DataFrame.copy() is a deep copy by default and, unlike copy.deepcopy,
    # does not depend on a name that is only available through a star import.
    pivot_cp = results_pivot.copy()

    # Columns 0 and 1 are the row keys; the listed agents keep their plain
    # mean values (presumably because they yield a single value per instance
    # - TODO confirm).
    all_algos = results_pivot.columns.tolist()[2:]
    all_algos = [a for a in all_algos if a not in ["greedy", "cam", "notears", "lingam", "gobnilp"]]

    for algo in all_algos:
        algo_cis = []

        for row in results_pivot.itertuples():
            metric = getattr(row, 'metric')
            instance = getattr(row, "instance")

            relevant_entries = orig_df[(orig_df["metric"] == metric) &
                                       (orig_df['instance'] == instance) &
                                       (orig_df['agent'] == algo)]
            algo_cis.append(compute_ci(relevant_entries["value"].tolist()))

        # Format straight from the means and the CIs computed above. This
        # avoids temporary "<algo>_ci" columns and the row-wise agg over the
        # whole frame.
        pivot_cp[algo] = [
            f"{mean:.3f}±{ci:.3f}"
            for mean, ci in zip(results_pivot[algo].tolist(), algo_cis)
        ]

    return pivot_cp



In [None]:
import re

def prepare_and_write_latex(df, which_results="construct", file_suffix=""):
    """Format a results table and write it as a LaTeX tabular.

    The table is restricted to a fixed column order, sorted, and its agent
    columns are renamed to display names. It is then rendered with
    ``DataFrame.to_latex``, post-processed with a series of regex
    substitutions (headers, metric and instance names, NaN markers, sentinel
    values, ``±`` confidence intervals), and written to
    ``fp_out.figures_dir / "<which_results>_final[_<file_suffix>].tex"``.

    Args:
        df: frame containing at least the columns ``instance``, ``metric``,
            ``uct``, ``greedy``, ``randomshooting``, ``random`` and ``gobnilp``.
        which_results: results type; used in the output file name and to select
            the sort keys and the number of leading centred columns.
        file_suffix: optional file-name suffix. ``"discrete"`` additionally
            orders rows by ``discrete_instances`` and renders ``nan±nan`` as
            ``?`` instead of an infinity sign.

    Returns:
        The sorted frame with renamed columns, i.e. the table as it was before
        LaTeX rendering. The caller's ``df`` is not modified.
    """
    colorder = ["instance", "metric", "uct", "greedy", "randomshooting", "random", "gobnilp"]

    print(colorder)
    agent_display_names = {"uct": "CD-UCT",
                           "rlbic": "RL-BIC",
                           "greedy": "Greedy Search",
                           "random": "Random Sampling",
                           "randomshooting": "Random Search",
                           "cam": "CAM",
                           "lingam": "LiNGAM",
                           "notears": "NOTEARS",
                           "gobnilp": "GOBNILP"}

    # Explicit copy so that the column assignments below act on our own frame
    # (no SettingWithCopyWarning, caller's frame untouched).
    df = df[colorder].copy()
    df['metric'] = pd.Categorical(df['metric'], categories=['reward', 'tpr', 'fdr', 'shd'], ordered=True)

    if file_suffix == "discrete":
        df["instance"] = pd.Categorical(df['instance'], categories=discrete_instances, ordered=True)

    # NOTE(review): "phase" (joint) and "notears" (prune_cam) are not part of
    # colorder - confirm whether these two branches are still in use.
    if which_results == "joint":
        sort_keys = ["phase", "instance", "metric"]
    elif file_suffix in ("discrete", ""):
        sort_keys = ["instance", "metric"]
    else:
        sort_keys = ["metric"]
    df = df.sort_values(by=sort_keys)

    if which_results == "prune_cam":
        df.loc[df["instance"] == "syntren", ["notears"]] = -100
    df = df.rename(columns=agent_display_names)

    texfile = str(fp_out.figures_dir / f"{which_results}_final{'_' + file_suffix if file_suffix != '' else ''}.tex")

    n_startcols = 3 if which_results == "joint" else (2 if file_suffix == "" else 1)
    colformat = f"{'c' * n_startcols}|" + ("r" * (len(colorder) - n_startcols))

    # Render to a string first: the file is then written exactly once, inside
    # a context manager, instead of write -> re-read -> rewrite.
    raw_content = df.to_latex(buf=None, float_format="{:0.3f}".format, index=False, column_format=colformat)

    # Substitutions are applied in insertion order, so "nan±nan" must come
    # before the generic "nan" pattern and before the "±<number>" pattern.
    replace_dict = {
        r"nan±nan": r"?" if file_suffix == "discrete" else r"$\\infty$",

        r"instance": r"",
        r"agg": r"",
        r"metric": r"",
        r"phase": r"Phase",
        r"construct": r"\\textbf{Construct}",
        r"prune": r"\\textbf{Prune}",

        r"reward": r"Reward $\\uparrow$",
        r"tpr": r"TPR $\\uparrow$",
        r"fdr": r"FDR $\\downarrow$",
        r"shd": r"SHD $\\downarrow$",
        # Sentinel values; the dot is escaped so it only matches a literal ".".
        r"-100\.000": r"$\\times$",
        r"-999\.000": r"---",

        r"sachs": r"\\textit{Sachs}",
        r"syntren": r"\\textit{SynTReN}",

        r"asia": r"\\textit{Asia}",
        r"child": r"\\textit{Child}",
        r"insurance": r"\\textit{Insurance}",

        r"mrr": r"MRR",

        r"NaN": r"$\\infty$",
        r"nan": r"$\\infty$",

        r"±(\d+\.\d+)": r"\\tiny{$\\pm\g<1>$}",
        r"±---": r"\\tiny{$\\pm0.000$}",
    }

    processed_content = raw_content
    for orig, targ in replace_dict.items():
        processed_content = re.sub(orig, targ, processed_content, flags=re.M)

    # The content may still contain "±", so fix the encoding instead of
    # relying on the platform default.
    with open(texfile, 'w', encoding="utf-8") as g:
        g.write(processed_content)

    return df

In [None]:
# Build and export the results table for the discrete-variable benchmarks.
experiment_ids = [f"{instance_name}_discretevars" for instance_name in discrete_instances]

discrete_df = get_eval_df(experiment_ids, which_results="construct")
# pivot_table aggregates with the mean by default: one row per
# (instance, metric), one column per agent.
discrete_pivot = discrete_df.pivot_table(columns=["agent"], index=["instance", "metric"])

# Work on a copy so discrete_pivot stays as computed. DataFrame.copy() replaces
# copy.deepcopy, which was only reachable through a star import.
dvfp = discrete_pivot.copy()
# Drop the outer column level so the columns are the agent names, then turn the
# (instance, metric) index into ordinary columns.
dvfp.columns = dvfp.columns.droplevel(0)
dvfp = pd.DataFrame(dvfp.to_records())

dvfp_final = augment_with_cis(dvfp, discrete_df)
# Last expression of the cell: the formatted table is displayed as the output.
prepare_and_write_latex(dvfp_final, which_results="construct", file_suffix="discrete")
