Imports

In [2]:
import numpy as np
import os
from plotting import plot_sample_efficiency_curve, COLOR_MAPPING, LABEL_MAPPING, load_and_process_experiment_data, arange_frames

Load data

In [None]:
# Environments to load (sub-directories of ./results/)
envs = [
    "emm",
]

# Run IDs to load for every environment (sub-directories of ./results/<env>/)
run_ids = [
    "gru_384",
    "trxl",
    "trxl_rec",
    "trxl_lr",
    "trxl_qpos",
    "trxl_gt",
    "trxl_gt_qpos_lr",
    "trxl_lr_learned",
    "trxl_lr_relative"
]

# Keep every `skip`-th entry along the first axis of the loaded data; 1 keeps all
skip = 1

# Per-environment factor applied to the episode-averaged reward.
# Environments without an entry are left unscaled (factor 1).
# NOTE(review): presumably rescales the reward to the unit shown on the
# y-axis of the plots — confirm against the environment definitions.
REWARD_MULTIPLIERS = {
    "emm": 10,
    "emp": 1,
    "ess": 4,
}

# Loop over all possible paths and load data
raw_data_dict = {}
for env in envs:
    raw_data_dict[env] = {}
    # The factor only depends on the environment, so look it up once per env
    multiplier = REWARD_MULTIPLIERS.get(env, 1)
    for run_id in run_ids:
        path = f"./results/{env}/{run_id}/"
        # isdir() is already False for non-existent paths, so a separate
        # exists() check is not needed
        if not os.path.isdir(path):
            # Report the gap here: later cells index runs by name and would
            # otherwise fail with a bare KeyError far away from the cause
            print(f"Skipping {env}/{run_id}: no results directory at {path}")
            continue

        data = load_and_process_experiment_data(path, "reward")
        # Average over the episodes dimension
        data = data.mean(axis=2)

        raw_data_dict[env][run_id] = data[::skip] * multiplier


Process and aggregate data

Shape of the loaded data for each run ID, before averaging over episodes: (101, 5, 150)

- 101 Checkpoints
- 5 Runs
- 150 Episodes

In [9]:
# Aggregate data: reduce axis 1 of every loaded array to its mean, std, min and max
mean_dict, std_dict, min_dict, max_dict = {}, {}, {}, {}
for env, runs in raw_data_dict.items():
    mean_dict[env] = {run_id: values.mean(axis=1) for run_id, values in runs.items()}
    std_dict[env] = {run_id: values.std(axis=1) for run_id, values in runs.items()}
    min_dict[env] = {run_id: values.min(axis=1) for run_id, values in runs.items()}
    max_dict[env] = {run_id: values.max(axis=1) for run_id, values in runs.items()}

# Setup frames: the x-axis length is taken from the "gru_384" run of "emm"
num_checkpoints = mean_dict["emm"]["gru_384"].shape[0]
frames = arange_frames(num_checkpoints, skip)

Plot mean and std across runs

EMM

In [None]:
# Run IDs drawn in this figure (keys of mean_dict["emm"] / std_dict["emm"])
algos = [
    "gru_384",
    "trxl",
    "trxl_rec",
    "trxl_lr",
    "trxl_qpos",
    "trxl_gt",
    "trxl_gt_qpos_lr",
]

# Styling and axis options, kept apart from the data arguments
plot_options = dict(
    algorithms=algos,
    colors=COLOR_MAPPING,
    label_mapping=LABEL_MAPPING,
    figsize=(16.5, 4.5),
    xticks=list(range(0, 820, 100)),  # ticks at 0, 100, ..., 800
    xlabel="Steps (in millions)",
    ylabel="Commands Executed",
    marker="",
    out="fig8.pdf",  # presumably the file the figure is written to — confirm in plotting.py
)

plot_sample_efficiency_curve(frames, mean_dict["emm"], std_dict["emm"], **plot_options)

In [None]:
# Run IDs drawn in this figure (keys of mean_dict["emm"] / std_dict["emm"])
algos = [
    "trxl",
    "trxl_lr",
    "trxl_lr_learned",
    "trxl_lr_relative",
]

# Styling and axis options, kept apart from the data arguments
plot_options = dict(
    algorithms=algos,
    colors=COLOR_MAPPING,
    label_mapping=LABEL_MAPPING,
    figsize=(16.5, 4.5),
    xticks=list(range(0, 820, 100)),  # ticks at 0, 100, ..., 800
    xlabel="Steps (in millions)",
    ylabel="Commands Executed",
    marker="",
    out="fig9.pdf",  # presumably the file the figure is written to — confirm in plotting.py
)

plot_sample_efficiency_curve(frames, mean_dict["emm"], std_dict["emm"], **plot_options)

In [None]:
# Run IDs drawn in this figure (keys of raw_data_dict["emm"])
algos = [
    "trxl",
    "trxl_lr",
    "trxl_lr_learned",
    "trxl_lr_relative",
]

# Styling and axis options, kept apart from the data arguments
plot_options = dict(
    algorithms=algos,
    colors=COLOR_MAPPING,
    label_mapping=LABEL_MAPPING,
    figsize=(16.5, 4.5),
    xticks=list(range(0, 820, 100)),  # ticks at 0, 100, ..., 800
    xlabel="Steps (in millions)",
    ylabel="Commands Executed",
    marker="",
)

# NOTE(review): unlike the other plot cells, this one passes the un-aggregated
# raw_data_dict and no std argument — presumably to draw the individual runs;
# confirm that plot_sample_efficiency_curve supports this call form.
plot_sample_efficiency_curve(frames, raw_data_dict["emm"], **plot_options)