In [13]:
import pandas as pd
import wandb
import os
import json
from datetime import datetime
from rich.progress import Progress
from helpers.utils import ENVIRONMENTS_MAP

# --- W&B source: which entity/project the runs are read from ---
entity, project = 'jayden-teoh', 'MORL-Baselines'

# Public API client shared by the cells below (timeout raised to 60).
api = wandb.Api(timeout=60)

# --- What to extract ---
# Metric prefix; each environment's series is read from the key f"{PLOT_TO_EXTRACT}/{env}".
PLOT_TO_EXTRACT = 'eval/single_objective_return'
ENV_NAME = "MOHopperDR-v5"
# Individual environments associated with ENV_NAME, looked up in the project's helper map.
ENVS = ENVIRONMENTS_MAP[ENV_NAME]

In [14]:
# Validate entity/project/credentials up front and count the runs the download cell
# will iterate. The filter is passed to the query (it was previously built but ignored,
# so `total_runs` counted every run in the project instead of the matching ones).
filters = {"group": "domain_randomization", "tags": {"$in": [ENV_NAME]}}
try:
    runs_sample = api.runs(path=f"{entity}/{project}", filters=filters, per_page=1)
    # presumably len() issues the request, so bad paths / missing auth fail inside the try
    total_runs = len(runs_sample)
except Exception as e:
    raise ValueError(
        f"Invalid entity '{entity}' or project '{project}': {e}\n\n"
        "Also, make sure you are properly authenticated. You can authenticate by using "
        "'wandb.login()' or setting the environment variable 'WANDB_API_KEY'"
    ) from e

In [39]:
# Default CSV file name format (not written by this cell, which emits per-run CSVs;
# kept for any later aggregation cell).
date_str = datetime.now().strftime("%m%d%y")
output_file = f"{entity}-{project}-{date_str}.csv"

# Only finished runs of the domain-randomization group tagged with ENV_NAME are exported.
run_filters = {"group": "domain_randomization", "tags": {"$in": [ENV_NAME]}}

for env in ENVS:
    os.makedirs(f"data/{PLOT_TO_EXTRACT}/{env}", exist_ok=True)


def _export_run_histories(run) -> bool:
    """Write one CSV per environment in ENVS for a single W&B run.

    Each file is data/<PLOT_TO_EXTRACT>/<env>/<algo_name>/seed_<seed>.csv and holds the
    run's global_step column plus the metric column renamed to PLOT_TO_EXTRACT.

    Args:
        run: a W&B public-API run whose name splits on "__" into four parts
            (env_id, algo_name, seed, time).

    Returns:
        True if the CSVs were written; False (nothing written) when the run name does
        not have exactly four "__"-separated parts.
    """
    name_parts = run.name.split("__")
    if len(name_parts) != 4:
        return False
    _env_id, algo_name, seed, _time = name_parts

    for env in ENVS:
        metric_key = f"{PLOT_TO_EXTRACT}/{env}"
        # NOTE(review): Run.history returns a sampled subset of rows (500 by default);
        # use run.scan_history(...) instead if every logged step is required.
        history = run.history(keys=["global_step", metric_key])
        # Drop W&B's internal _step column (global_step is the x-axis we keep) and rename
        # "<PLOT_TO_EXTRACT>/<env>" to "<PLOT_TO_EXTRACT>". errors="ignore" tolerates an
        # empty history frame that has no _step column.
        history = history.drop(columns="_step", errors="ignore").rename(
            columns={metric_key: PLOT_TO_EXTRACT}
        )
        out_dir = f"data/{PLOT_TO_EXTRACT}/{env}/{algo_name}"
        os.makedirs(out_dir, exist_ok=True)
        history.to_csv(f"{out_dir}/seed_{seed}.csv", index=False)
    return True


# A single query suffices: the returned Runs object loads further pages (per_page each)
# while it is iterated. The previous `while not progress.finished` + created_at cursor
# loop never terminated when its total exceeded the number of matching finished runs.
runs = api.runs(
    path=f"{entity}/{project}",
    per_page=100,
    order="+created_at",
    filters=run_filters,
)

with Progress() as progress:
    task = progress.add_task("[cyan]Fetching runs...", total=len(runs))
    for run in runs:
        if run.state == "finished" and not _export_run_histories(run):
            progress.console.print(
                f"Skipped run with unexpected name: {run.name}", style="yellow", markup=False
            )
        # Advance for every run, skipped or not, so the bar always reaches its total.
        progress.update(task, advance=1)

Output()