Install Dependencies

In [None]:
# Install Gymnasium together with its `accept-rom-license` extra into the kernel's environment.
%pip install gymnasium[accept-rom-license]

In [None]:
# Install Gymnasium with its `atari` extra (the notebook uses the 'ALE/Breakout-v5' environment).
%pip install gymnasium[atari]

In [None]:
# Install Stable-Baselines3 with its `extra` optional dependencies (provides A2C, callbacks and env helpers used below).
%pip install stable-baselines3[extra]

In [None]:
# Install the PyTorch packages from the PyTorch `cu118` wheel index
# (presumably CUDA 11.8 builds, to match ALGORITHM_DEVICE = 'cuda' -- confirm against the local driver).
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

Import Dependencies

In [None]:
import os
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3 import A2C


Constants

In [None]:
# --- Environment -----------------------------------------------------------
ENV_NAME = 'ALE/Breakout-v5'
# NOTE(review): this single value is used both as the number of parallel
# environments (n_envs) and as the number of stacked frames (n_stack) when the
# environment is created -- confirm that coupling is intended.
ENV_N_STACKS = 9
ENV_SEED = 0
VEC_ENV_RENDER_MODE = 'human'

# --- Callback / checkpointing ----------------------------------------------
# Raw strings keep the exact backslash-separated values while avoiding the
# invalid escape sequences '\l' and '\m' (a SyntaxWarning on recent Pythons).
# NOTE(review): the leading backslash makes these paths rooted rather than
# relative to the notebook -- presumably written for Windows; confirm.
CALLBACK_LOG_DIR = r'\logs\breakout'
CALLBACK_CHECKPOINT_DIR = r'\models\breakout'
# A checkpoint is written every CALLBACK_CHECK_FREQ callback calls.
CALLBACK_CHECK_FREQ = 10000
CALLBACK_ON_TRAINING_MODEL = f'{ENV_N_STACKS}_stacks_breakout_v5_training_model'
# Ends with '_' and the callback appends '_<timesteps>', so the final file
# name contains a double underscore.
CALLBACK_ON_TRAINING_END = f'{ENV_N_STACKS}_stacks_breakout_v5_training_end_'

VERBOSE = 1

# --- Algorithm -------------------------------------------------------------
ALGORITHM_BEST_MODEL_NAME = f'{ENV_N_STACKS}_stacks_breakout_v5_best_model.zip'
ALGORITHM_POLICY = 'CnnPolicy'
ALGORITHM_DEVICE = 'cuda'
ALGORITHM_TOTAL_TIMESTEPS = 1000000

# --- Rendering -------------------------------------------------------------
ALGORITHM_RENDER_MODEL = True
ALGORITHM_RENDER_EPISODES = 10
ALGORITHM_RENDER_STEPS = 1000
ALGORITHM_PREDICT_DETERMINISTIC = True

# --- Mode selection --------------------------------------------------------
# The final cell checks ALGORITHM_RENDER_MODEL first, then ALGORITHM_LOAD_MODEL,
# then ALGORITHM_NEW_MODEL, and runs only the first mode that is enabled.
ALGORITHM_NEW_MODEL = False
ALGORITHM_LOAD_MODEL = False

# --- Console messages ------------------------------------------------------
MESSAGE_RENDERING_MODEL = 'RENDERING MODEL'
MESSAGE_LOADING_MODEL = 'LOADING MODEL'
MESSAGE_TRAINING_NEW_MODEL = 'NEW MODEL'


Creating the Env

In [None]:
# Build the vectorised Atari environment, then wrap it in a frame stack.
# NOTE(review): ENV_N_STACKS drives both the number of parallel environments
# and the number of stacked frames -- confirm that is intended.
vec_env = make_atari_env(
    env_id=ENV_NAME,
    n_envs=ENV_N_STACKS,
    seed=ENV_SEED,
)
env = VecFrameStack(venv=vec_env, n_stack=ENV_N_STACKS)

Creating the Callback

In [None]:
class TrainAndLoggingCallback(BaseCallback):
    """Checkpoint the model periodically during training and once at the end.

    Args:
        check_freq: A checkpoint is written whenever the callback's call
            counter (``self.n_calls``) is a multiple of this value.
        save_path: Directory that receives the checkpoints. When ``None``,
            no directory is created and no checkpoint is written.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, check_freq, save_path, verbose=VERBOSE):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        """Create the checkpoint directory if a save path was given."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _save_model(self, file_name):
        """Save the current model as ``file_name`` inside ``save_path``.

        Does nothing when ``save_path`` is ``None``; this mirrors the guard in
        ``_init_callback`` instead of failing inside ``os.path.join``.
        """
        if self.save_path is None:
            return
        self.model.save(os.path.join(self.save_path, file_name))

    def _on_step(self):
        """Overwrite the in-training checkpoint every ``check_freq`` calls.

        Always returns ``True``, so this callback never asks for training to stop.
        """
        if self.n_calls % self.check_freq == 0:
            self._save_model(CALLBACK_ON_TRAINING_MODEL)

        return True

    def _on_training_end(self):
        """Save a final checkpoint whose name includes the configured timestep budget."""
        # The name uses the configured ALGORITHM_TOTAL_TIMESTEPS constant, not
        # a step counter taken from the model.
        self._save_model(f'{CALLBACK_ON_TRAINING_END}_{ALGORITHM_TOTAL_TIMESTEPS}')


callback = TrainAndLoggingCallback(check_freq=CALLBACK_CHECK_FREQ, save_path=CALLBACK_CHECKPOINT_DIR)

Loading, Training and Rendering

In [None]:
# Run exactly one mode, selected by the flags in the Constants cell.
# Priority when several flags are enabled: render > continue training > new model.
model = None

if ALGORITHM_RENDER_MODEL:
    print(MESSAGE_RENDERING_MODEL)
    model = A2C.load(ALGORITHM_BEST_MODEL_NAME, env=env)
    # Step through the environment object held by the loaded model.
    vec_env = model.get_env()

    try:
        for ep in range(ALGORITHM_RENDER_EPISODES):
            obs = vec_env.reset()
            for step in range(ALGORITHM_RENDER_STEPS):
                # `_states` rather than `_`: assigning to `_` in a notebook
                # shadows IPython's last-output variable.
                action, _states = model.predict(obs, deterministic=ALGORITHM_PREDICT_DETERMINISTIC)
                obs, rewards, dones, info = vec_env.step(action)
                vec_env.render(VEC_ENV_RENDER_MODE)
    finally:
        # Close the environment even if rendering raises or is interrupted.
        env.close()

elif ALGORITHM_LOAD_MODEL:
    # Resume training from the saved best model.
    print(MESSAGE_LOADING_MODEL)
    model = A2C.load(ALGORITHM_BEST_MODEL_NAME, env=env, device=ALGORITHM_DEVICE)
    model.learn(total_timesteps=ALGORITHM_TOTAL_TIMESTEPS, callback=callback)

elif ALGORITHM_NEW_MODEL:
    # Train a fresh model; TensorBoard logs go to CALLBACK_LOG_DIR.
    print(MESSAGE_TRAINING_NEW_MODEL)
    model = A2C(ALGORITHM_POLICY, env, tensorboard_log=CALLBACK_LOG_DIR, verbose=VERBOSE, device=ALGORITHM_DEVICE)
    model.learn(total_timesteps=ALGORITHM_TOTAL_TIMESTEPS, callback=callback)