You can use this notebook to download runs that have been logged on [Weights and Biases](wandb.ai) and save them to a csv.

In [1]:
import numpy as np
import matplotlib.pyplot as plt

import gymnasium as gym
import miniworld
import wandb
import pandas as pd
from tqdm import tqdm

tqdm.pandas()

In [2]:
api = wandb.Api(timeout=30)

In [68]:
entity, project = "xxxx", "sfa_project"  # set to your entity and project name
runs = api.runs(entity + "/" + project)

run_dfs = []
for run in tqdm(runs):
    # Load the config dict for this run
    config_dict = {k: v for k, v in run.config.items() if not k.startswith('_')}
    
    # Next, extract relevant config items:
    # print([k for k, v in run.config.items()]) <- Use this to see config items
    relevant_items = ["env_name", "repl_mode", "total_timesteps"]
    config_dict = {k: v for k, v in config_dict.items() if k in relevant_items}
    config_dict["name"] = run.name
    
    # Now, load the logs:
    relevant_cols = ["global_step", "rollout/ep_len_mean", "rollout/ep_rew_mean"]
    history = run.scan_history(keys = relevant_cols)
    log_data = pd.DataFrame(history)

    # Duplicate config to match size of logs and concatenate
    conf_data = pd.DataFrame([config_dict]*len(log_data))
    run_df = pd.concat([conf_data, log_data], axis=1)
    
    run_dfs.append(run_df)

all_runs = pd.concat(run_dfs, ignore_index=True, axis=0)

100%|███████████████████████████████████████████████████████████████████| 91/91 [09:13<00:00,  6.09s/it]


In [69]:
all_runs

Unnamed: 0,env_name,repl_mode,total_timesteps,name,global_step,rollout/ep_len_mean,rollout/ep_rew_mean
0,MiniWorld-WallGap-v0,sfa,2000000.0,decent-night-122,384.0,300.000000,0.000000
1,MiniWorld-WallGap-v0,sfa,2000000.0,decent-night-122,512.0,300.000000,0.000000
2,MiniWorld-WallGap-v0,sfa,2000000.0,decent-night-122,640.0,300.000000,0.000000
3,MiniWorld-WallGap-v0,sfa,2000000.0,decent-night-122,768.0,300.000000,0.000000
4,MiniWorld-WallGap-v0,sfa,2000000.0,decent-night-122,896.0,300.000000,0.000000
...,...,...,...,...,...,...,...
1171842,MiniWorld-StarMazeArm-v0,,1000000.0,peach-wildflower-15,993280.0,357.339996,0.784355
1171843,MiniWorld-StarMazeArm-v0,,1000000.0,peach-wildflower-15,995328.0,371.839996,0.774421
1171844,MiniWorld-StarMazeArm-v0,,1000000.0,peach-wildflower-15,997376.0,371.869995,0.774417
1171845,MiniWorld-StarMazeArm-v0,,1000000.0,peach-wildflower-15,999424.0,386.399994,0.764480


In [75]:
filtered_runs = all_runs[all_runs["repl_mode"].notna()]
len(filtered_runs["name"].unique())

80

In [76]:
filtered_runs.to_csv("downloaded_runs.csv", index=False)