In [None]:
%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from tqdm.notebook import tqdm
from joblib import Parallel, delayed

# Make the directory five levels above the current working directory importable, so that the
# `seads` imports that follow resolve. The membership check keeps this cell idempotent:
# re-running it (common with %autoreload workflows) no longer appends duplicate sys.path entries.
parent_path = str(Path("../../../../../").resolve())
print(parent_path)
if parent_path not in sys.path:
    sys.path.append(parent_path)

from seads import EXPERIMENT_DIR
from seads.utils.evaluation import plot_interpolated_performance
from seads.jobs.evaluation.load_evaluation_rollouts import load_evaluation_rollouts

import matplotlib as mpl
mpl.pyplot.rcdefaults()

from matplotlib import rc
import matplotlib.pyplot as plt

%matplotlib inline

FIG_SCALE = 3
TITLE_FONTSIZE = 18
rc('font',**{'family':'sans-serif','sans-serif':['Computer Modern Sans Serif']})
rc('text', usetex=True)
mpl.rcParams.update({'font.size': 18})

In [None]:
# Evaluation environments as "<game>_<manipulator>" identifiers: grouped by manipulator
# (cursor, reacher, jaco), LightsOut before TileSwap within each group. The figure below
# draws one panel per entry, in this order.
ENVS = [
    f"{game}_{manipulator}"
    for manipulator in ("cursor", "reacher", "jaco")
    for game in ("lightsout", "tileswap")
]

def load_seads_data(envs=None, n_seeds=10, n_rollouts_per_depth=20, keep_episodes=False):
    """Load the evaluation rollouts of the default-configuration SEADS runs.

    Every environment ``env`` maps to the run stem ``f"corl22_{env}_default"``; the list of
    stems is handed to ``load_evaluation_rollouts`` together with the remaining arguments.
    Called without arguments this behaves exactly like the original hard-coded version.

    Args:
        envs: Iterable of environment identifiers. ``None`` (default) means ``ENVS``,
            looked up at call time.
        n_seeds: Number of seeds per run stem, forwarded to ``load_evaluation_rollouts``.
        n_rollouts_per_depth: Forwarded to ``load_evaluation_rollouts``.
        keep_episodes: Forwarded to ``load_evaluation_rollouts`` as a keyword argument.

    Returns:
        The value returned by ``load_evaluation_rollouts``; this notebook unpacks it as
        ``(df_agg, df_all)``.
    """
    if envs is None:
        envs = ENVS
    run_stem_list = [f"corl22_{env}_default" for env in envs]
    return load_evaluation_rollouts(
        run_stem_list, n_seeds, n_rollouts_per_depth, keep_episodes=keep_episodes
    )

In [None]:
seads_df_agg, seads_df_all = load_seads_data()

# Split the non-aggregated frame by its `replanning` flag. Rows whose flag is neither
# True nor False end up in neither subset.
seads_df_all_replanning = seads_df_all.loc[seads_df_all["replanning"].eq(True)]
seads_df_all_noreplanning = seads_df_all.loc[seads_df_all["replanning"].eq(False)]

In [None]:
# Count aggregated rows per (run, seed, replanning flag, checkpoint) group and display every
# group whose count is not 100. An empty result means all groups are complete.
sanity_df = (
    seads_df_agg
    .groupby(["run_stem", "run_seed", "replanning", "ckpt_step"])
    .size()
    .reset_index()
)
sanity_df.loc[sanity_df[0].ne(100)]

In [None]:
# Display up to 500 DataFrame rows before pandas truncates the output.
pd.options.display.max_rows = 500

## Success rate evaluation

In [None]:
SEEDS_TO_RETAIN = list(range(1, 11))

# Passed to plot_interpolated_performance for every curve: the value column to plot and
# whether to show it in percent.
plot_kwargs = {
    "value_column_name": "mean_val_success_rate",
    "plot_percent": True,
}

# Panel-title prefix for each game.
GAME_DISPLAY_NAMES = {"lightsout": "LightsOut", "tileswap": "TileSwap"}

# Per manipulator: (max_env_steps passed to the plotting helper, x tick positions, x tick labels).
# All tick labels are plain text so that, with text.usetex enabled, every "0" tick is typeset
# the same way in all panels (math-mode "$0$" would render in a different font).
X_AXIS_BY_MANIP = {
    "cursor": (int(5e5), [0, 5e5], ["0", "0.5M"]),
    "reacher": (int(1e7), [0, 0.5e7, 1e7], ["0", "5M", "10M"]),
    "jaco": (int(1e7), [0, 0.5e7, 1e7], ["0", "5M", "10M"]),
}

# Curves drawn on every panel, in drawing order: (data frame, legend label).
CURVES = [
    (seads_df_all_replanning, "SEADS (replan.)"),
    (seads_df_all_noreplanning, "SEADS (no replan.)"),
]

figsize = (1 * FIG_SCALE * 5.4, 0.6 * FIG_SCALE)
fig, ax_arr = plt.subplots(nrows=1, ncols=len(ENVS), figsize=figsize, sharex=False, sharey=True, squeeze=False)

for ax in ax_arr.flatten():
    ax.set_ylim(-5, 105)

# Handed to the plotting helper for the first panel only (presumably it appends the curve
# handles/labels there — confirm) and then used to build the single shared legend.
labels = []
artists = []

for idx, env in enumerate(ENVS):
    game, manip = env.split("_")
    env_name = GAME_DISPLAY_NAMES[game] + manip.capitalize()
    if manip not in X_AXIS_BY_MANIP:
        raise ValueError(f"Unknown manipulator {manip!r} in environment {env!r}")
    max_env_steps, xticks, xticklabels = X_AXIS_BY_MANIP[manip]

    # Only the first panel collects legend entries, so each label appears once in the legend.
    legend_lists = [artists, labels] if idx == 0 else None

    ax = ax_arr[0, idx]
    ax.set_title(env_name, fontsize=TITLE_FONTSIZE)
    for curve_df, curve_label in CURVES:
        plot_interpolated_performance(
            ax, curve_df, f"corl22_{env}_default",
            label=curve_label, legend_lists=legend_lists,
            max_env_steps=max_env_steps, seeds_to_retain=SEEDS_TO_RETAIN, **plot_kwargs)
    ax.set_xlabel("Env. steps")
    # The y axis is shared, so only the leftmost panel gets a label.
    ax.set_ylabel(r"Success rate (\%)" if idx == 0 else "")
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels)

# NOTE(review): hspace controls spacing between subplot rows; with a single row it has no visible effect.
plt.subplots_adjust(hspace=0.5)

# Shared legend, anchored (in the first panel's axes coordinates) below the middle of the row.
ax_arr[0, 0].legend(artists, labels, fontsize=TITLE_FONTSIZE, ncol=4, loc="lower center", bbox_to_anchor=(3, -1))

# Create the output directory on demand so the cell also works on a fresh checkout.
output_path = Path("generated") / "successrate_seads_default.png"
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, bbox_inches="tight")