In [74]:
from pathlib import Path
import pandas as pd
from functools import partial

def get_result_str(x, metric: str, decimals: int = 4):
    r"""Summarize one metric column of a group as a LaTeX ``mean \pm std`` cell.

    Parameters
    ----------
    x : pandas.DataFrame
        Rows of a single group (e.g. one ``groupby`` group); must contain a
        column named ``metric``.
    metric : str
        Name of the column to summarize.
    decimals : int, default 4
        Digits after the decimal point used for both the mean and the std.

    Returns
    -------
    pandas.Series
        One-element Series indexed by ``metric`` whose value is a math-mode
        string such as ``$0.4764 \pm 0.3680$``. The std is pandas' default
        sample standard deviation (ddof=1), so a single-row group renders
        the std as ``nan``.
    """
    column = x[metric]
    # Nested format field lets the caller choose the precision.
    cell = rf"${column.mean():.{decimals}f} \pm {column.std():.{decimals}f}$"
    return pd.Series({metric: cell})


model = "ilm"  # ["ilm", "igpr"]

agents = ["sac", "dqn", "ppo"]
metrics = ["mean_squared_error"]  # , "mean_absolute_error"]

# Build one per-phase summary DataFrame per agent whose CSV exists, then join
# them side by side into a single LaTeX table.
table = []
# Agents that actually contributed a DataFrame. The column labels are built
# from this list, so they stay aligned with the data when a CSV is missing.
found_agents = []
for agent in agents:
    fn = Path(f"../figures/{agent}/{model}_maps/iqm_fit_cv.csv")
    if not fn.is_file():
        print("Did not find", fn)
        continue

    df = pd.read_csv(fn)

    result = []
    for metric in metrics:
        # Select only the metric column before applying, so the grouping
        # column is not handed to get_result_str. pandas deprecates apply()
        # operating on grouping columns.
        ret = df.groupby(by="phase_index")[[metric]].apply(
            partial(get_result_str, metric=metric)
        )
        result.append(ret)

    result = pd.concat(result, axis=1)
    table.append(result)
    found_agents.append(agent)

if not table:
    raise FileNotFoundError(
        f"No iqm_fit_cv.csv found for any of {agents} (model={model!r})"
    )

# Rows are aligned on phase_index. Phases missing for an agent show up as NaN.
table = pd.concat(table, axis=1)
table.columns = pd.MultiIndex.from_tuples(
    [
        (a.upper(), m.replace("_", " ").capitalize())
        for a in found_agents
        for m in metrics
    ]
)
table.index.name = "Phase"
print(table.to_latex(escape=False))

\begin{tabular}{llll}
\toprule
{} &                  SAC &                  DQN &                  PPO \\
{} &   Mean squared error &   Mean squared error &   Mean squared error \\
Phase &                      &                      &                      \\
\midrule
1     &  $0.4764 \pm 0.3680$ &  $0.0044 \pm 0.0008$ &  $0.0028 \pm 0.0006$ \\
2     &  $2.3295 \pm 0.6123$ &  $0.0466 \pm 0.0187$ &  $0.0076 \pm 0.0034$ \\
3     &  $2.0241 \pm 0.4517$ &  $0.0587 \pm 0.0281$ &  $0.0211 \pm 0.0059$ \\
4     &  $2.1405 \pm 0.3728$ &                  NaN &                  NaN \\
\bottomrule
\end{tabular}



  # NOTE(review): this appears to re-print the same LaTeX table that the cell
  # above already printed via `print(table.to_latex(escape=False))`. The
  # two-space leading indent looks like a notebook-export artifact; as a
  # top-level statement in a plain .py file it would raise IndentationError.
  print(table.to_latex(escape=False))
