In [21]:
# Gym environment id (as stored in a run's config) -> task name used to match
# entries of the official score file.
gym_id_to_task = {
    "ALE/Pong-v5": "atari_pong",
    "ALE/MsPacman-v5": "atari_ms_pacman",
}

# W&B run ids whose curves are fetched and plotted.
RUN_IDS = [
    "ez5lxzjq",
    "t57gvwnw",
]

# Number of episodes averaged together when smoothing our episode returns.
ROLLING_WINDOW = 10

In [22]:
import wandb
from tqdm import tqdm

# Client for the Weights & Biases public API; handed to get_run_data.
api = wandb.Api()
# Collects one curve dict per entry of RUN_IDS (filled at the end of this cell).
data = []

def get_run_data(api, run_id, window=None, project="armandpl/minidream_dev"):
    """Fetch one W&B run and return its smoothed episode-return curve.

    Args:
        api: object exposing ``run(path)`` (a ``wandb.Api`` instance).
        run_id: id of the run inside ``project``.
        window: number of episodes per rolling-average window. Defaults to
            the notebook-level ``ROLLING_WINDOW``.
        project: ``"entity/project"`` path the run is looked up in.

    Returns:
        A dict with keys ``"task"``, ``"method"``, ``"seed"``, ``"xs"`` and
        ``"ys"``. ``xs`` holds global steps and ``ys`` the rolling mean of the
        episode return; both lists have the same length and are empty when
        the run logged fewer than ``window`` episodes.

    Raises:
        KeyError: if the run's ``env_id`` has no entry in ``gym_id_to_task``.
        ValueError: if ``window`` is smaller than 1.
    """
    if window is None:
        window = ROLLING_WINDOW
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    run = api.run(f"{project}/{run_id}")
    # Resolve the task first so an unmapped env fails before the slow history scan.
    task = gym_id_to_task[run.config["env"]["env_id"]]

    xs = []
    ys = []
    for row in run.scan_history(keys=["episode_return", "global_step"]):
        xs.append(row["global_step"])
        ys.append(row["episode_return"])

    # Trailing rolling mean: the value at xs[i] averages the `window` episodes
    # ending at (and including) episode i, so the latest episode is not dropped
    # and the curve is not shifted one episode to the right.
    ys_rolling = [
        sum(ys[end - window:end]) / window
        for end in range(window, len(ys) + 1)
    ]
    xs_rolling = xs[window - 1:]

    return {
        "task": task,
        "method": f"this repo, {window} ep rolling average",
        "seed": 0,
        "xs": xs_rolling,
        "ys": ys_rolling,
    }

# Download every configured run in order, with a progress bar, and collect
# the resulting curve dicts.
data.extend(get_run_data(api, run_id) for run_id in tqdm(RUN_IDS))

100%|██████████| 2/2 [01:07<00:00, 33.60s/it]


In [4]:
# download official scores
!wget https://github.com/danijar/dreamerv3/raw/main/scores/data/atari100k_dreamerv3.json.gz

--2024-03-17 12:48:00--  https://github.com/danijar/dreamerv3/raw/main/scores/data/atari100k_dreamerv3.json.gz
Resolving github.com (github.com)... 140.82.121.4
Connecting to github.com (github.com)|140.82.121.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/danijar/dreamerv3/main/scores/data/atari100k_dreamerv3.json.gz [following]
--2024-03-17 12:48:01--  https://raw.githubusercontent.com/danijar/dreamerv3/main/scores/data/atari100k_dreamerv3.json.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8000::154, 2606:50c0:8003::154, 2606:50c0:8001::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8000::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 17529 (17K) [application/octet-stream]
Saving to: ‘atari100k_dreamerv3.json.gz’


2024-03-17 12:48:01 (8,35 MB/s) - ‘atari100k_dreamerv3.json.gz’ saved [17529/17529]



In [5]:
import gzip
import json

# Decompress the downloaded archive and parse its JSON payload (the scores of
# the official implementation).
with gzip.open("./atari100k_dreamerv3.json.gz", mode="rb") as f:
    official_scores = json.loads(f.read())

In [37]:
import matplotlib.pyplot as plt
import numpy as np

# The official xs are divided by this factor before being compared with our
# step counts (frameskip=4).
FRAME_SKIP = 4

unique_tasks = {x["task"] for x in data}

for task in unique_tasks:
    fig, ax = plt.subplots()

    # Official curves for this task; entries without any point are ignored.
    official = [
        x for x in official_scores
        if x["task"] == task and len(x["xs"]) > 0
    ]

    if official:
        # interpolate their data so its easier to do the mean + std
        # TODO maybe that's a mistake though?
        # Work on local arrays only: the dicts in `official` are the same
        # objects as in `official_scores`, so writing the resampled values
        # back would divide the steps by FRAME_SKIP again on every re-run.
        seed_xs = [np.asarray(x["xs"]) // FRAME_SKIP for x in official]

        # One shared integer grid (last step included) for all seeds, so the
        # resampled curves always stack into a rectangular array.
        grid = np.arange(int(max(xs.max() for xs in seed_xs)) + 1)

        # Beyond a seed's last logged step the value is NaN rather than a
        # flat extrapolation; nanmean/nanstd below skip those entries.
        seed_ys = np.array([
            np.interp(grid, xs, x["ys"], right=np.nan)
            for xs, x in zip(seed_xs, official)
        ])

        # get the mean and std at each point
        mean = np.nanmean(seed_ys, axis=0)
        std = np.nanstd(seed_ys, axis=0)
        low = mean - std
        high = mean + std

        # plot the +/- one std band, then the mean on top of it
        ax.fill_between(grid, low, high, alpha=0.3)
        ax.plot(grid, mean, label="danijar/dreamerv3")

    ours = [x for x in data if x["task"] == task][0]  # assume only one curve in our data
    ax.plot(ours["xs"], ours["ys"], label=f"this repo, {ROLLING_WINDOW} ep rolling avg")

    # setup plot
    ax.set_title(task)
    ax.set_xlabel("Steps")
    ax.set_ylabel("Episode Return")
    ax.legend()
    ax.ticklabel_format(style="sci", axis="x", scilimits=(0, 0))
    fig.savefig(f"../scores/{task}.jpg")
    # Close explicitly so figures do not accumulate across tasks.
    plt.close(fig)