In [None]:
import gymnasium as gym
from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.evaluation import evaluate_policy
import os

In [None]:
env_name = "ALE/Breakout-v5"


def create_env(env_name, n_envs=1, seed=0, n_stack=4, frameskip=1):
    """Build a vectorized, preprocessed Atari env for SB3's CnnPolicy.

    Args:
        env_name: Gymnasium id of the Atari game, e.g. "ALE/Breakout-v5".
        n_envs: number of game copies stepped together by the VecEnv.
        seed: base seed handed to make_atari_env (each copy gets its own
            seed derived from this one).
        n_stack: number of consecutive frames stacked into one observation.
        frameskip: frame skip of the raw ALE env. make_atari_env's Atari
            preprocessing already repeats each action over 4 frames, and
            ALE v5 envs default to frameskip=4 too; the default of 1 keeps
            the two from compounding into an effective skip of 16 frames.

    Returns:
        The wrapped VecEnv with channel-first, frame-stacked observations.
    """
    env = make_atari_env(
        env_name, n_envs=n_envs, seed=seed, env_kwargs={"frameskip": frameskip}
    )
    env = VecTransposeImage(env)  # HxWxC -> CxHxW, as PyTorch conv layers expect
    env = VecFrameStack(env, n_stack=n_stack)
    return env


train_env = create_env(env_name, n_envs=4)
# Use a seed offset from the training envs so evaluation does not replay
# exactly the same seeded episodes the training copies start from.
eval_env = create_env(env_name, n_envs=1, seed=100)

In [None]:
# A separate, non-vectorized copy of the game with an on-screen window,
# used only for watching random play below.
env = gym.make(id=env_name, render_mode="human")

In [None]:
# Sanity check: play a few episodes with uniformly random actions to confirm
# the environment works and to get a baseline score before any training.
episodes = 5
try:
    for episode in range(1, episodes + 1):
        obs, _ = env.reset()
        done = False
        score = 0

        while not done:
            # With render_mode="human", Gymnasium renders during step(), so no
            # explicit env.render() call is needed.
            action = env.action_space.sample()
            obs, reward, terminated, truncated, info = env.step(action)
            # An episode is over on termination OR truncation; Gymnasium
            # requires reset() once either flag is set, so check both.
            done = terminated or truncated
            score += reward
        print(f"Episode:{episode} Score:{score}")
finally:
    # Always release the display window, even if the loop is interrupted.
    # NOTE: after this, re-run the cell that creates `env` before re-running
    # this one.
    env.close()

In [None]:
# TensorBoard logs for the training runs are written under Training/Logs.
log_path = os.path.join('Training', 'Logs')
# Seed SB3's random number generators so repeated runs are reproducible
# (up to nondeterminism in the underlying compute backend).
model = A2C("CnnPolicy", train_env, verbose=1, tensorboard_log=log_path, seed=0)

In [None]:
# First training run: default A2C hyperparameters on the vectorized training env.
model.learn(total_timesteps=100_000)

In [None]:
# Persist the first trained model under Training/Saved Models/.
a2c_path = os.path.join("Training", "Saved Models", "a2c_breakout_v5")
model.save(a2c_path)

In [None]:
# Score the first model over 50 episodes on the single-env evaluation stack.
# render=True is omitted: eval_env is built without a render_mode, so SB3's
# VecEnv.render() only emits a warning and draws nothing.
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=50)
mean_reward, std_reward

In [None]:
save_path = os.path.join('Training', 'Saved Models')
# Stop training as soon as a new best mean evaluation reward reaches 20.
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=20, verbose=1)
# EvalCallback's eval_freq counts callback calls, i.e. vectorized env.step()
# calls, each of which advances every training sub-env once. Divide by the
# number of training envs so evaluation happens roughly every 10,000
# timesteps; with an unscaled 10,000 and 4 envs, a 20,000-timestep run makes
# only ~5,000 calls and would never evaluate (or stop early).
eval_freq_calls = max(10_000 // train_env.num_envs, 1)
eval_callback = EvalCallback(eval_env,
                             callback_on_new_best=stop_callback,
                             eval_freq=eval_freq_calls,
                             best_model_save_path=save_path,
                             verbose=1)

In [None]:
# Release the previous training environments before replacing them.
train_env.close()
train_env = create_env(env_name, n_envs=4)
model = A2C(
    'CnnPolicy',
    train_env,
    verbose=1,
    tensorboard_log=log_path,
    seed=0,  # reproducible initialization / sampling
    policy_kwargs={
        # Policy (pi) and value (vf) networks each get four 128-unit hidden
        # layers on top of the policy's feature extractor.
        'net_arch': {
            'pi': [128, 128, 128, 128],
            'vf': [128, 128, 128, 128],
        }
    },
)

In [None]:
# Second run: eval_callback periodically evaluates on eval_env, saves the best
# model to save_path, and its stop callback can end training early.
model.learn(callback=eval_callback, total_timesteps=20_000)