Imports

In [1]:
import numpy as np
import os
from plotting import plot_sample_efficiency_curve, COLOR_MAPPING, LABEL_MAPPING, load_and_process_experiment_data, arange_frames

Load data

In [None]:
# Determine tiles visited: for each index of axis 2 of the (checkpoints, 5, 50, 3)
# layout (presumably the 50 environment seeds), take the best reward trxl_rec reached
# on mp_off_dense and subtract 0.9. The loading cell below divides the dense-reward
# curves by this value to express them as task progress.
# NOTE(review): 0.9 is presumably the success bonus, leaving 0.1 reward per tile
# (hence the "* 10" in the printout) — confirm against the environment definition.
path = "./results/mp_off_dense/trxl_rec/"
rewards_raw = load_and_process_experiment_data(path, "reward")
# Take the checkpoint count from the data (as the loading loop does with
# success_data.shape[0]) instead of hard-coding 51, so longer runs still reshape.
original_shape = (rewards_raw.shape[0], 5, 50, 3)
rewards = rewards_raw.reshape(original_shape)
successes = load_and_process_experiment_data(path, "success").reshape(original_shape)
seeds = load_and_process_experiment_data(path, "seed").reshape(original_shape)
successes_agg = successes.max(axis=(0, 1, 3))
tiles_visited = rewards.max(axis=(0, 1, 3)) - 0.9
# The loading cell divides by tiles_visited, so every entry must be strictly positive.
assert np.all(tiles_visited > 0), f"non-positive tiles_visited entries: {tiles_visited}"
print(f"Tiles visited: {tiles_visited * 10}")

In [None]:
envs = [
    "mm10",
    "mp_off_dense",
    "ss"
]

run_ids = [
    "gru",
    "trxl",
    "gru_rec",
    "trxl_rec",
    "gru_25",
    "trxl_25",
    "gru_rec_25",
    "trxl_rec_25",
    "lstm",
    "gtrxl_b0",
    "gtrxl_b2",
    "gtrxl_b0_rec",
    "gtrxl_b2_rec"
]

skip = 1

# Loop over all possible paths and load data
raw_data_dict = {}
for env in envs:
    raw_data_dict[env] = {}
    for run_id in run_ids:
        path = f"./results/{env}/{run_id}/"
        if os.path.exists(path) and os.path.isdir(path):
            data = load_and_process_experiment_data(path, "reward")

            if "dense" in env:
                success_data = load_and_process_experiment_data(path, "success")
                original_shape = (success_data.shape[0], 5, 50, 3)
                target_shape = (success_data.shape[0], 5, 150)
                success_data = success_data.reshape(original_shape)
                success_data = (success_data * 0.9)
                reward_data = data.reshape(original_shape)
                reward_data = reward_data - success_data
                data = reward_data / tiles_visited[np.newaxis, np.newaxis, :, np.newaxis]
                data = data.reshape(target_shape)

            if "ss" in env:
                data[data == 0.25] = 0.5
                data[data == 1.25] = 1.0

            # Average over the episodes dimension
            data = data.mean(axis=2)

            multiplier = 1
            # if "mm" in env:
            #     multiplier = 10

            raw_data_dict[env][run_id] = data[::skip] * multiplier
        else:
            continue



Process and aggregate data

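Each loaded reward array has shape `(checkpoints, runs, episodes)`, for example `(101, 5, 150)`:

- 101 checkpoints
- 5 runs
- 150 episodes

The loading cell above averages over the episodes axis. The cell below aggregates the result across the runs axis.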

In [4]:
# Aggregate data: collapse axis 1 (the runs axis) of every loaded curve into
# per-checkpoint mean, standard deviation, minimum and maximum, keyed env -> run_id.
mean_dict, std_dict, min_dict, max_dict = {}, {}, {}, {}
for env, curves_by_run in raw_data_dict.items():
    mean_dict[env] = {rid: curves.mean(axis=1) for rid, curves in curves_by_run.items()}
    std_dict[env] = {rid: curves.std(axis=1) for rid, curves in curves_by_run.items()}
    min_dict[env] = {rid: curves.min(axis=1) for rid, curves in curves_by_run.items()}
    max_dict[env] = {rid: curves.max(axis=1) for rid, curves in curves_by_run.items()}

Plot mean and std across runs

mm10

In [None]:
# mm10: plot the mean and std across runs for every loaded run_id.
# The x-axis length comes from the gru curve.
mm10_mean = mean_dict["mm10"]
mm10_std = std_dict["mm10"]
frames = arange_frames(mm10_mean["gru"].shape[0])
plot_sample_efficiency_curve(
    frames,
    mm10_mean,
    mm10_std,
    colors=COLOR_MAPPING,
    label_mapping=LABEL_MAPPING,
    figsize=(12, 7.5),
    xlabel="Steps (in millions)",
    ylabel="Task Progress",
    marker="",
    labelsize=32,
    ticklabelsize=32,
    spinewidth=2,
    out="fig11_mm10.pdf",
)

Plot mean and std across runs

mp_off_dense

In [None]:
# mp_off_dense: plot the mean and std across runs for every loaded run_id.
# The x-axis length comes from the gru curve.
dense_mean = mean_dict["mp_off_dense"]
dense_std = std_dict["mp_off_dense"]
frames = arange_frames(dense_mean["gru"].shape[0])
plot_sample_efficiency_curve(
    frames,
    dense_mean,
    dense_std,
    colors=COLOR_MAPPING,
    label_mapping=LABEL_MAPPING,
    figsize=(12, 7.5),
    xlabel="Steps (in millions)",
    ylabel="Task Progress",
    marker="",
    labelsize=32,
    ticklabelsize=32,
    spinewidth=2,
    out="fig11_mp_off_dense.pdf",
)

Plot mean and std across runs

ss

In [None]:
# ss: plot only the run_ids listed in `algos`.
# The x-axis length comes from the first of them (gru_rec_25).
algos = ["gru_rec_25", "trxl_rec_25", "gru_25", "trxl_25", "lstm", "gtrxl_b0_rec", "gtrxl_b2_rec"]
ss_mean = mean_dict["ss"]
frames = arange_frames(ss_mean[algos[0]].shape[0])
plot_sample_efficiency_curve(
    frames,
    ss_mean,
    std_dict["ss"],
    algorithms=algos,
    colors=COLOR_MAPPING,
    label_mapping=LABEL_MAPPING,
    figsize=(12, 7.5),
    xlabel="Steps (in millions)",
    ylabel="Task Progress",
    marker="",
    labelsize=32,
    ticklabelsize=32,
    spinewidth=2,
    out="fig11_ss.pdf",
)

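Wall-clock time

Task progress on mp_off_dense, plotted against estimated training time in hours instead of environment steps.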

mp_off_dense

In [8]:
# Wall-time x-axes: for each algorithm, a list of elapsed hours at every checkpoint.
# hours(i) = rate * checkpoint_interval * i / 3600, where 3600 converts seconds to hours.
# NOTE(review): rates are presumably measured seconds per training iteration, and
# num_checkpoints must match the length of the mp_off_dense curves in mean_dict —
# confirm both against the training logs.
num_checkpoints = 51
checkpoint_interval = 500
walltime_rates = {
    "gru": 5.19,
    "lstm": 5.09,
    "trxl": 7.39,
    "gtrxl_b0": 6.68,
}
frames = {
    algo: [rate * checkpoint_interval * i / 3600 for i in range(num_checkpoints)]
    for algo, rate in walltime_rates.items()
}

In [None]:
# Plot the mp_off_dense curves against wall-clock hours.
# Only the algorithms that have a wall-time x-axis in the `frames` dict from the previous cell are drawn.
dense_mean = mean_dict["mp_off_dense"]
dense_std = std_dict["mp_off_dense"]
plot_sample_efficiency_curve(
    frames,
    dense_mean,
    dense_std,
    algorithms=frames.keys(),
    colors=COLOR_MAPPING,
    label_mapping=LABEL_MAPPING,
    figsize=(18, 7.5),
    xlabel="Hours",
    ylabel="Task Progress",
    marker="",
    labelsize=32,
    ticklabelsize=32,
    spinewidth=2,
    out="fig12_mp_off_dense.pdf",
)