In [1]:
import copy
import matplotlib.pyplot as plt
import os
import pandas as pd
import pickle
import seaborn as sns
from tqdm import tqdm

import gym
import gymnasium
from gymnasium.utils.save_video import save_video

from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.dqn import DQN

import wandb

import warnings
warnings.filterwarnings('ignore')

In [2]:
# W&B project (entity/project) that holds the hyper-parameter sweeps.
PROJECT_PATH = 'boctrl-c/synaptogen'

# Sweep id per environment -- these are the bio-variant sweeps.
# NOTE(review): toggling BIO below does not switch these ids; confirm before running with BIO = False.
SWEEP_IDS = {
    'CartPole-v1': 'jmeb03mj',
    'MountainCar-v0': 'xtlyudvu',
    'LunarLander-v2': 'n3fosy57',
    'Acrobot-v1': '9ixwee6v',
}

# Where run tables (CSV) live, where checkpoints live, and the device passed to DQN.load.
DATA_DIR = 'data'
CKPTS_DIR = 'ckpts'
DEVICE = 'cuda:1'

# Variant switch: the bio variant appends '-bio' to checkpoint-dir and CSV names.
BIO = True
SUFFIX = '-bio' if BIO else ''

In [3]:
api = wandb.Api()

# Build one row per sweep run: its config, its summary metrics, and identifying
# metadata (run name/id plus the environment the sweep targeted).
runs = []
for env_name, sweep_id in SWEEP_IDS.items():
    # W&B paths are always '/'-separated; os.path.join would use '\\' on Windows.
    sweep = api.sweep(f'{PROJECT_PATH}/sweeps/{sweep_id}')

    for run in sweep.runs:
        row = {}
        row.update(run.config)
        row.update(run.summary)  # summary keys overwrite config keys on collision
        row['name'] = run.name
        row['run_id'] = run.id
        row['env_name'] = env_name
        runs.append(row)

df = pd.DataFrame(runs)

# Evaluate each run's best checkpoint and record its mean reward and std.
for i, run in tqdm(df.iterrows(), total=len(df)):
    model = DQN.load(os.path.join(CKPTS_DIR + SUFFIX, run['run_id'], 'best_model.zip'), device=DEVICE)

    # Close the evaluation env explicitly so one env per run is not leaked.
    eval_env = gym.make(run['env_name'])
    try:
        mean_reward, reward_std = evaluate_policy(model, eval_env, n_eval_episodes=10)
    finally:
        eval_env.close()

    df.loc[i, 'mean_reward'] = mean_reward
    df.loc[i, 'reward_std'] = reward_std

# Save the dataframe to file.
# NOTE(review): the next cell re-reads runs{SUFFIX}.csv, so results computed here are
# not used downstream unless this line is re-enabled (it overwrites the existing CSV).
# df.to_csv(os.path.join(DATA_DIR, f'runs{SUFFIX}.csv'), index=False)

36it [02:10,  3.63s/it]


In [4]:
# Load the saved evaluation table and keep the single best run per environment.
# Ranking: highest mean reward, then lowest reward std, then fewest genes.
df = pd.read_csv(os.path.join(DATA_DIR, f'runs{SUFFIX}.csv'))

shown_cols = ['env_name', 'run_id', 'num_genes', 'learning_rate', 'seed', 'mean_reward', 'reward_std']

# drop_duplicates keeps the whole top-ranked row for each env. groupby().first()
# would instead take the first non-null value of every column independently, which
# can mix fields from different runs when the top row has missing values.
best_runs = (
    df[shown_cols]
    .sort_values(
        by=['mean_reward', 'reward_std', 'num_genes'],
        ascending=[False, True, True],
    )
    .drop_duplicates(subset='env_name', keep='first')
    .sort_values('env_name')  # same env ordering groupby produced
    .reset_index(drop=True)
)
display(best_runs)

Unnamed: 0,env_name,run_id,num_genes,learning_rate,seed,mean_reward,reward_std
0,Acrobot-v1,w841xq3h,39,0.0003,3,-74.9,8.665449
1,CartPole-v1,5muxdd1h,39,0.0003,2,500.0,0.0
2,LunarLander-v2,60mpfq71,39,0.003,2,204.62058,55.020423
3,MountainCar-v0,jxdknkbh,39,0.03,1,-119.9,24.881519


In [5]:
# Record one episode per best run as a video.
for _, run in best_runs.iterrows():
    # Checkpoints for the current variant live under CKPTS_DIR + SUFFIX, the same
    # location the evaluation cell loads from; plain CKPTS_DIR misses the bio runs.
    model = DQN.load(os.path.join(CKPTS_DIR + SUFFIX, run['run_id'], 'best_model.zip'), device=DEVICE)

    env_name = run['env_name']
    env = gymnasium.make(env_name, render_mode='rgb_array_list')
    try:
        obs, _ = env.reset()
        terminated = truncated = False
        while not (terminated or truncated):
            # deterministic=True: DQN.predict otherwise applies epsilon-greedy
            # exploration, so the video would not show the policy evaluate_policy scored.
            action, _ = model.predict(obs, deterministic=True)
            obs, _, terminated, truncated, _ = env.step(action)

        # With render_mode='rgb_array_list', render() returns the frames collected since reset.
        save_video(
            env.render(),
            'videos',
            name_prefix=env_name,
            fps=env.metadata['render_fps'],
        )
    finally:
        env.close()