In [None]:
import sys
from pathlib import Path
import pickle
import pandas as pd
import numpy as np
from itertools import product
from tqdm.notebook import tqdm
from joblib import Parallel, delayed

# Resolve the directory five levels above the current working directory and put
# it on the import path (presumably the repository root that contains the
# `seads` package imported below — verify if the notebook is moved).
parent_path = str(Path("../../../../../").resolve())
print(parent_path)
# Guard against piling up duplicate entries when this cell is re-run in the
# same kernel.
if parent_path not in sys.path:
    sys.path.append(parent_path)

from seads import EXPERIMENT_DIR
from seads.jobs.evaluation.load_skill_coverage_data import load_skill_coverage
from seads.utils.evaluation import plot_interpolated_performance

import matplotlib as mpl
from matplotlib import rc
# Use a Computer Modern sans-serif face and let LaTeX render all figure text.
rc('font', family='sans-serif', **{'sans-serif': ['Computer Modern Sans Serif']})
rc('text', usetex=True)
# Initial base font size; it is overridden to 18 further down in this cell.
mpl.rcParams['font.size'] = 12
import matplotlib.pyplot as plt

%matplotlib inline

# Figure sizing / typography constants used by the plotting cells below.
FIG_SCALE = 3
TITLE_FONTSIZE = 18

# Apply the text settings again now that `%matplotlib inline` has run, this
# time with a base font size of 18.
# NOTE(review): presumably repeated because activating the inline backend can
# reset some rcParams (e.g. font.size) — confirm before removing.
mpl.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Computer Modern Sans Serif'],
    'text.usetex': True,
    'font.size': 18,
})

# Sub-directory "train_lads_corl22" of the project's experiment directory.
base_dir = EXPERIMENT_DIR.joinpath("train_lads_corl22")

In [None]:
# Environment names follow the "<game>_<manipulator>" pattern; the order
# (manipulator-major, game-minor) determines the left-to-right panel order
# in the figure.
ENVS = [
    f"{game}_{manip}"
    for manip in ("cursor", "reacher", "jaco")
    for game in ("lightsout", "tileswap")
]

# Run variants to load; only the default configuration is used here.
VARIANTS = ["default"]

# Seeds 1 through 10 (inclusive).
RUN_SEEDS = list(range(1, 11))

In [None]:
# Load the skill-coverage data for the "corl22" runs across all configured
# environments, variants and seeds.
# NOTE(review): presumably returns a pandas DataFrame (it is indexed by column
# name in the next cell) — confirm against load_skill_coverage.
run_prefix = "corl22"
df = load_skill_coverage(run_prefix, ENVS, VARIANTS, RUN_SEEDS)

In [None]:
# Total environment steps = sum of the two interaction counters.
# The column is added to `df` in place so later cells can keep using it.
df["env_steps"] = df["env_interactions_lpfm"].add(df["env_interactions_sac"])

# Keep only the columns needed for the coverage plots.
evaluations = df[
    [
        "covered_mean",
        "run_stem",
        "run_seed",
        "env_steps",
        "checkpoint",
    ]
]

In [None]:
# Maps an ablation abbreviation (as used in run names) to its legend label.
ABLATION_NAMES = dict(default="SEADS")

def plot_ablations(ablations):
    """Plot skill-coverage curves, one panel per environment in ``ENVS``.

    For every environment a subplot is created, and for every requested
    ablation ``plot_interpolated_performance`` is called on the module-level
    ``evaluations`` data for the ``"covered_mean"`` column. A shared legend is
    placed below the panels. The figure is left as the current pyplot figure
    so the caller can save it.

    Parameters
    ----------
    ablations : sequence of str
        Ablation abbreviations; each must be a key of ``ABLATION_NAMES``.

    Returns
    -------
    pandas.DataFrame
        Built from the objects returned by ``plot_interpolated_performance``
        (one per environment/ablation pair), each annotated with
        ``"run_stem"``, ``"env"`` and ``"ablation"`` entries.

    Raises
    ------
    ValueError
        If an environment name does not split into exactly ``<game>_<manip>``
        or its game is not one of the known games.
    KeyError
        If an ablation abbreviation is missing from ``ABLATION_NAMES``.
    """
    # Per-game settings: (display name, number of unique skills, y-axis upper
    # limit). The skill count is drawn as a dashed reference line and used as
    # the top y-tick. LightsOut has 25 unique skills, TileSwap has 12.
    game_specs = {
        "lightsout": ("LightsOut", 25, 27),
        "tileswap": ("TileSwap", 12, 13),
    }
    # x-axis settings: (max env steps plotted, tick positions, tick labels).
    # Cursor environments are shown up to 0.5M steps, all others up to 10M.
    cursor_axis = (int(5e5), [0, int(5e5)], ["0", "0.5M"])
    other_axis = (int(1e7), [0, int(5e6), int(1e7)], ["0", "5M", "10M"])

    run_prefix = "corl22"

    figsize = (1 * FIG_SCALE * 5.4, 0.6 * FIG_SCALE)
    fig, ax_arr = plt.subplots(
        nrows=1,
        ncols=len(ENVS),
        figsize=figsize,
        sharex=False,
        sharey=False,
        squeeze=False,
    )

    performance_data = []

    for idx, env in enumerate(ENVS):
        game, manip = env.split("_")
        # Fail early with a descriptive error instead of a bare KeyError from
        # the lookup below.
        if game not in game_specs:
            raise ValueError(
                f"Unknown game {game!r} in environment {env!r}; "
                f"expected one of {sorted(game_specs)}"
            )
        display_name, n_skills, y_max = game_specs[game]
        env_name = display_name + manip.capitalize()
        max_s, xticks, xticklabels = cursor_axis if manip == "cursor" else other_axis

        ax = ax_arr[0, idx]
        # Legend handles/labels for this panel; passed to the plotting helper
        # (presumably filled in by it — confirm in plot_interpolated_performance).
        artists, labels = [], []
        legend_lists = [artists, labels]
        for ablation_abbr in ablations:
            label = ABLATION_NAMES[ablation_abbr]
            run_stem = f"{run_prefix}_{game}_{manip}_{ablation_abbr}"
            perf_data = plot_interpolated_performance(
                ax,
                evaluations,
                run_stem,
                "covered_mean",
                max_env_steps=max_s,
                label=label,
                legend_lists=legend_lists,
            )
            perf_data["run_stem"] = run_stem
            perf_data["env"] = env_name
            perf_data["ablation"] = f"{label}"

            performance_data.append(perf_data)

        ax.set_ylim(0, y_max)
        ax.axhline(y=n_skills, color="gray", linestyle="--")
        ax.set_yticks([0, n_skills])

        ax.set_xlabel("Env. steps")
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels)

        ax.set_title(env_name, fontsize=TITLE_FONTSIZE)

    ax_arr[0, 0].set_ylabel("Average unique\n game moves")
    # The shared legend uses the lists from the last panel; every panel plots
    # the same ablations, so one set of entries is enough.
    fig.legend(
        artists,
        labels,
        fontsize=TITLE_FONTSIZE,
        ncol=len(ablations),
        loc="lower center",
        bbox_to_anchor=(0.5, -0.55),
    )

    performance_data = pd.DataFrame(performance_data)
    return performance_data
    
# Draw the default (SEADS-only) figure; it stays the current pyplot figure.
plot_ablations(["default"])
# savefig does not create missing directories, so make sure the output
# directory exists before writing (e.g. on a fresh checkout).
Path("generated").mkdir(parents=True, exist_ok=True)
plt.savefig("generated/skillcoverage_default.png", bbox_inches='tight')