Imports

In [1]:
import numpy as np
import os
from plotting import plot_sample_efficiency_curve, COLOR_MAPPING, LABEL_MAPPING, load_and_process_experiment_data, arange_frames

Load data

In [None]:
envs = [
    "mm_act_grid",
    "mm_grid",
    "mp_grid_on",
    "mp_grid_off",
    "ss"
]

run_ids = [
    "gru",
    "trxl",
    "gru_rec",
    "trxl_rec",
    "gru_25",
    "trxl_25",
    "gru_rec_25",
    "trxl_rec_25",
    "4stack",
    "16stack",
    "pos_enc",
    "ppo"
]

# Keep every `skip`-th checkpoint along the first axis (1 keeps all of them).
skip = 1


def postprocess_rewards(env, data, skip=1):
    """Remap SS rewards, average over episodes and subsample checkpoints.

    Args:
        env: Environment name. Only the exact name ``"ss"`` gets the reward
            remapping (a substring match would also hit unrelated names).
        data: Reward array from ``load_and_process_experiment_data``. Axis 2 is
            averaged out as the episodes axis; per the shape notes in this
            notebook it is presumably (checkpoints, runs, episodes) — TODO confirm.
        skip: Keep every ``skip``-th entry along axis 0.

    Returns:
        A new array with axis 2 averaged out and axis 0 subsampled by ``skip``.
        The input array is left unmodified.
    """
    rewards = np.array(data, copy=True)
    if env == "ss":
        # NOTE(review): folds raw SS reward values 0.25 -> 0.5 and 1.25 -> 1.0;
        # the reason is not visible in this notebook — confirm against the env.
        # Both masks are built before assigning so one remap can never feed the
        # other, and isclose tolerates float round-off in stored rewards.
        is_quarter = np.isclose(rewards, 0.25)
        is_one_and_quarter = np.isclose(rewards, 1.25)
        rewards[is_quarter] = 0.5
        rewards[is_one_and_quarter] = 1.0
    return rewards.mean(axis=2)[::skip]


# Load every env/run combination that has a results directory; combinations
# without one are skipped and counted so missing data is visible.
raw_data_dict = {}
n_missing = 0
for env in envs:
    raw_data_dict[env] = {}
    for run_id in run_ids:
        path = f"./results/{env}/{run_id}/"
        if not os.path.isdir(path):
            n_missing += 1
            continue
        rewards = load_and_process_experiment_data(path, "reward")
        raw_data_dict[env][run_id] = postprocess_rewards(env, rewards, skip)

print(f"Loaded {len(envs) * len(run_ids) - n_missing} runs; "
      f"skipped {n_missing} env/run combinations without a results directory.")


Process and aggregate data

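Each reward array returned by `load_and_process_experiment_data` for the current results has shape `(101, 5, 150)`:

- 101 checkpoints
- 5 runs
- 150 episodes

The loading cell above averages over the episode axis, so each entry of `raw_data_dict` has shape `(checkpoints, runs)`. The cell below then aggregates over runs (axis 1).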

In [3]:
# Aggregate data
mean_dict = {}
std_dict = {}
min_dict = {}
max_dict = {}
for env in raw_data_dict:
    mean_dict[env] = {}
    std_dict[env] = {}
    min_dict[env] = {}
    max_dict[env] = {}
    for run_id in raw_data_dict[env]:
        mean_dict[env][run_id] = raw_data_dict[env][run_id].mean(axis=1)
        std_dict[env][run_id] = raw_data_dict[env][run_id].std(axis=1)
        min_dict[env][run_id] = raw_data_dict[env][run_id].min(axis=1)
        max_dict[env][run_id] = raw_data_dict[env][run_id].max(axis=1) 

Plot mean and std across runs

mm_act_grid

In [None]:
# mm_act_grid: mean ± std of task progress across runs for each algorithm.
# The x-axis length comes from the gru curve; this assumes every plotted
# algorithm has the same number of checkpoints — TODO confirm.
plot_env = "mm_act_grid"
frames = arange_frames(mean_dict[plot_env]["gru"].shape[0])
algos = ["ppo", "pos_enc", "4stack", "16stack", "gru", "trxl"]
plot_sample_efficiency_curve(
    frames,
    mean_dict[plot_env],
    std_dict[plot_env],
    algorithms=algos,
    colors=COLOR_MAPPING,
    label_mapping=LABEL_MAPPING,
    figsize=(8, 7.5),
    xlabel="Steps (in millions)",
    ylabel="Task Progress",
    marker="",
    labelsize=32,
    ticklabelsize=32,
    spinewidth=2,
    out="fig10_mm_act_grid.pdf",
)

Plot mean and std across runs

mm_grid

In [None]:
# mm_grid: mean ± std of task progress across runs for each algorithm.
# The x-axis length comes from the gru curve; this assumes every plotted
# algorithm has the same number of checkpoints — TODO confirm.
plot_env = "mm_grid"
frames = arange_frames(mean_dict[plot_env]["gru"].shape[0])
algos = ["ppo", "pos_enc", "4stack", "16stack", "gru", "trxl"]
plot_sample_efficiency_curve(
    frames,
    mean_dict[plot_env],
    std_dict[plot_env],
    algorithms=algos,
    colors=COLOR_MAPPING,
    label_mapping=LABEL_MAPPING,
    figsize=(8, 7.5),
    xlabel="Steps (in millions)",
    ylabel="Task Progress",
    marker="",
    labelsize=32,
    ticklabelsize=32,
    spinewidth=2,
    out="fig10_mm_grid.pdf",
)

Plot mean and std across runs

mp_grid_on

In [None]:
# mp_grid_on: mean ± std of success rate across runs for each algorithm.
# The x-axis length comes from the gru curve; this assumes every plotted
# algorithm has the same number of checkpoints — TODO confirm.
# (The dead, commented-out `xticks` argument has been removed.)
frames = arange_frames(mean_dict["mp_grid_on"]["gru"].shape[0])
algos = ["ppo", "pos_enc", "4stack", "16stack", "gru", "trxl"]
plot_sample_efficiency_curve(frames,
                             mean_dict["mp_grid_on"],
                             std_dict["mp_grid_on"],
                             algorithms=algos,
                             colors=COLOR_MAPPING,
                             label_mapping=LABEL_MAPPING,
                             figsize=(8,7.5),
                             xlabel="Steps (in millions)",
                             ylabel="Success Rate",
                             marker="",
                             labelsize=32,
                             ticklabelsize=32,
                             spinewidth=2,
                             out="fig10_mp_grid_on.pdf")

Plot mean and std across runs

mp_grid_off

In [None]:
# mp_grid_off: mean ± std of success rate across runs for each algorithm.
# The x-axis length comes from the gru curve; this assumes every plotted
# algorithm has the same number of checkpoints — TODO confirm.
plot_env = "mp_grid_off"
frames = arange_frames(mean_dict[plot_env]["gru"].shape[0])
algos = ["ppo", "pos_enc", "4stack", "16stack", "gru", "trxl"]
plot_sample_efficiency_curve(
    frames,
    mean_dict[plot_env],
    std_dict[plot_env],
    algorithms=algos,
    colors=COLOR_MAPPING,
    label_mapping=LABEL_MAPPING,
    figsize=(8, 7.5),
    xlabel="Steps (in millions)",
    ylabel="Success Rate",
    marker="",
    labelsize=32,
    ticklabelsize=32,
    spinewidth=2,
    out="fig10_mp_grid_off.pdf",
)

Plot mean and std across runs

ss

In [None]:
# ss: mean ± std of success rate across runs. Unlike the other envs this
# compares the recurrent (_rec_25) and plain (_25) memory variants.
# The x-axis length comes from the gru_rec_25 curve; this assumes every
# plotted algorithm has the same number of checkpoints — TODO confirm.
plot_env = "ss"
frames = arange_frames(mean_dict[plot_env]["gru_rec_25"].shape[0])
algos = ["ppo", "pos_enc", "4stack", "16stack", "gru_rec_25", "trxl_rec_25", "gru_25", "trxl_25"]
plot_sample_efficiency_curve(
    frames,
    mean_dict[plot_env],
    std_dict[plot_env],
    algorithms=algos,
    colors=COLOR_MAPPING,
    label_mapping=LABEL_MAPPING,
    figsize=(8, 7.5),
    xlabel="Steps (in millions)",
    ylabel="Success Rate",
    marker="",
    labelsize=32,
    ticklabelsize=32,
    spinewidth=2,
    out="fig10_ss.pdf",
)