# This file is for creating the benchmark, with stacked frames.

# import relevant packages

In [12]:
import gymnasium as gym
import torch
import wandb

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback, EveryNTimesteps
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage

from feature_extraction.callbacks.wandb_reward_logging_callback import WandbRewardLoggingCallback
from utils import evaluate_policy


# Settings

In [13]:
progress_bar = True
train_model = True
eval_model = False
save_name = "a2c_breakout_benchmark_framestack"
verbose = 0
logdir = "logs/"

# Login to wanb and create a project with config

In [16]:
wandb.login()
config = dict(
    env_id="ALE/Breakout-v5",
    algorithm='PPO',
    #Hyperparams
    policy="CnnPolicy",
    n_timesteps=1000,
    n_envs=2,
    learning_rate=2.5e-4,
    n_steps=128,
    batch_size=256,
    n_epochs=4,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.1,
    clip_range_vf=None,
    normalize_advantage=True,
    normalize=False,
    ent_coef=0.01,
    vf_coef=0.5,
    max_grad_norm=0.5,
    use_sde=False,
    sde_sample_freq=-1,
    rollout_buffer_class=None,
    rollout_buffer_kwargs=None,
    target_kl=None,
    stats_window_size=100,
    tensorboard_log=None,
    policy_kwargs=None,
    verbose=0,
    seed=None,
    device='auto',
    _init_setup_model=True,
    env_wrapper='stable_baselines3.common.atari_wrappers.AtariWrapper',
    frame_stack=4,
)


wandb.init(project=save_name, config=config)
config = wandb.config

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113030955554375, max=1.0…

Problem at: /tmp/ipykernel_70872/3987479232.py 39 <module>


KeyboardInterrupt: 

# Create callbacks

In [None]:
vec_eval_env = make_atari_env(config.env_id, n_envs=config.n_envs)
vec_eval_env = VecFrameStack(vec_eval_env, n_stack=config.frame_stack)
vec_eval_env = VecTransposeImage(vec_eval_env)

# WandbCallback
wandb_callback_after_eval = WandbRewardLoggingCallback()

# Save best model
eval_callback = EvalCallback(vec_eval_env, best_model_save_path="./logs/",
                             log_path="./logs/", eval_freq=max(500 // config.n_envs, 1), callback_after_eval=wandb_callback_after_eval,
                             deterministic=True, render=False)


# Create vectorized env and stack frames

In [None]:
vec_train_env = make_atari_env(config.env_id, n_envs=config.n_envs)
# Frame-stacking with 4 frames
vec_train_env = VecFrameStack(vec_train_env, n_stack=config.frame_stack)
vec_train_env = VecTransposeImage(vec_train_env)

# Create model, learn and save with wandb

In [15]:
if train_model:
    ppo_params_keys = [
        'policy', 'learning_rate', 'n_steps', 'batch_size', 'n_epochs',
        'gamma', 'gae_lambda', 'clip_range', 'clip_range_vf', 'normalize_advantage',
        'ent_coef', 'vf_coef', 'max_grad_norm', 'use_sde', 'sde_sample_freq',
        'rollout_buffer_class', 'rollout_buffer_kwargs', 'target_kl',
        'stats_window_size', 'tensorboard_log', 'policy_kwargs', 'verbose',
        'seed', 'device', '_init_setup_model'
    ]   
    
    # Step 2: Filter the config dictionary to extract only the hyperparameters for PPO
    ppo_hyperparams = {key: config[key] for key in ppo_params_keys if key in config}
    
    # Step 3: Unpack the filtered hyperparameters dictionary into the PPO constructor
    model = PPO(**ppo_hyperparams, env=vec_train_env)
    
    model = PPO(config.policy, vec_train_env, verbose=verbose)
    wandb.watch(model.policy, log="all", log_freq=1000)
    print(config.n_timesteps)
    #model.learn(total_timesteps=config.n_timesteps, callback=eval_callback, progress_bar=progress_bar)
    #model.save(save_name)

AttributeError: 'dict' object has no attribute 'policy'

# Export model to ONNX

In [None]:
#Example for creating an ONNX model (should be saved to wandb)
dummy_input = torch.randn(1, 4, 84, 84)  # Batch size of 1

torch.onnx.export(model.policy,             # Model's policy to export
                  dummy_input,              # Example input for the model
                  f"{logdir}{save_name}.onnx") # Path to save the ONNX model


# Save files to wandb

In [None]:
# Assuming `logdir` and `save_name` are defined variables

wandb.save(f"{logdir}best_model.zip")
wandb.save(f"{logdir}evaluations.npz")
wandb.save(f"{logdir}{save_name}.onnx")

# Load and evaluate Model

In [None]:
if eval_model:
    model = PPO.load("logs/best_model.zip", env=vec_eval_env)
    mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=2, render=False, fps=30)
    print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")
    

In [None]:
print(model.policy)

# Wrap up

In [None]:
wandb.finish()