Install Dependencies

In [None]:
%pip install gymnasium[accept-rom-license]

In [None]:
%pip install gymnasium[atari]

In [None]:
%pip install stable-baselines3[extra]

Import Dependencies

In [None]:
import os
import gymnasium
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3 import DQN


Constants

In [None]:
# Gymnasium environment id passed to gymnasium.make(); swap the commented
# line in to train on a different game.
# ENV_NAME = 'ALE/Adventure-v5'
ENV_NAME = 'ALE/SpaceInvaders-v5'
# Render mode handed to gymnasium.make().
ENV_RENDER_MODE = 'human'

# Checkpoint and TensorBoard output directories, relative to the working
# directory. A leading backslash is deliberately avoided: '\m' and '\l' are
# invalid string escapes, and the resulting path means "drive root" on Windows
# but a directory literally named '\models' on POSIX systems.
CALLBACK_CHECKPOINT_DIR = 'models'
CALLBACK_LOG_DIR = 'logs'
# The callback saves a checkpoint every CALLBACK_CHECK_FREQ calls.
CALLBACK_CHECK_FREQ = 10000

# Verbosity level forwarded to both the DQN model and the callback.
VERBOSE = 1

# DQN constructor arguments and training budget.
ALGORITHM_POLICY = 'MlpPolicy'
ALGORITHM_BUFFER_SIZE = 50000
ALGORITHM_LEARNING_STARTS = 1000
ALGORITHM_TOTAL_TIMESTEPS = 1000000
# File read by DQN.load() when ALGORITHM_LOAD_MODEL is enabled.
ALGORITHM_BEST_MODEL_PATH = 'best_model.zip'

# Mode switches for the final cell, checked in this order:
#   ALGORITHM_RENDER_MODEL -> play ALGORITHM_RENDER_EPISODES episodes, no training
#   ALGORITHM_LOAD_MODEL   -> load ALGORITHM_BEST_MODEL_PATH and keep training
#   ALGORITHM_NEW_MODEL    -> train a fresh model
ALGORITHM_RENDER_MODEL = False
ALGORITHM_RENDER_EPISODES = 10

ALGORITHM_LOAD_MODEL = False
ALGORITHM_NEW_MODEL = True

Creating the Env

In [None]:
# Single environment instance shared by the model and the render loop below.
env = gymnasium.make(
    ENV_NAME,
    render_mode=ENV_RENDER_MODE,
)

Creating Callback

In [None]:
class TrainAndLoggingCallback(BaseCallback):
    """Callback that checkpoints the model being trained at a fixed interval.

    Args:
        check_freq: A checkpoint is written whenever the callback's call
            counter (``n_calls``) is a multiple of this value.
        save_path: Directory that receives the checkpoints. ``None`` disables
            checkpointing entirely.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, check_freq, save_path, verbose=VERBOSE):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        """Create the checkpoint directory if one was configured."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """Save a checkpoint named after the current call count when one is due.

        Returns:
            Always ``True``; under the stable-baselines3 callback contract
            this lets training continue.
        """
        # Apply the same None guard as _init_callback, so a callback built
        # without a save path is a no-op instead of failing in os.path.join.
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, str(self.n_calls))
            self.model.save(model_path)

        return True

# Checkpointing callback handed to model.learn() in the training cell.
callback = TrainAndLoggingCallback(
    check_freq=CALLBACK_CHECK_FREQ,
    save_path=CALLBACK_CHECKPOINT_DIR,
)

Loading, Training and Rendering

In [None]:
def _new_dqn_model():
    """Build an untrained DQN agent on ``env`` from the notebook constants."""
    return DQN(
        ALGORITHM_POLICY,
        env,
        tensorboard_log=CALLBACK_LOG_DIR,
        verbose=VERBOSE,
        buffer_size=ALGORITHM_BUFFER_SIZE,
        learning_starts=ALGORITHM_LEARNING_STARTS,
    )


model = None

if ALGORITHM_RENDER_MODEL:
    # Play episodes with the model's own policy; no training happens here.
    print('RENDERING MODEL')
    if ALGORITHM_LOAD_MODEL:
        model = DQN.load(ALGORITHM_BEST_MODEL_PATH, env=env)
    else:
        model = _new_dqn_model()

    try:
        for ep in range(ALGORITHM_RENDER_EPISODES):
            # Gymnasium's reset() returns an (observation, info) pair.
            obs, info = env.reset()
            terminated = False
            truncated = False
            # An episode is over as soon as EITHER flag is set.
            while not (terminated or truncated):
                env.render()
                # Act with the model rather than sampling random actions.
                action, _ = model.predict(obs, deterministic=True)
                # predict() returns a NumPy value; DQN actions are discrete,
                # so convert to a plain int before stepping the env.
                obs, reward, terminated, truncated, info = env.step(int(action))
    finally:
        # Release the environment even if an episode raises.
        env.close()

elif ALGORITHM_LOAD_MODEL:
    # Resume training from the saved model.
    print('LOADING MODEL')
    model = DQN.load(ALGORITHM_BEST_MODEL_PATH, env=env)
    model.learn(total_timesteps=ALGORITHM_TOTAL_TIMESTEPS, callback=callback)

elif ALGORITHM_NEW_MODEL:
    # Train a fresh model from scratch.
    print('NEW MODEL')
    model = _new_dqn_model()
    model.learn(total_timesteps=ALGORITHM_TOTAL_TIMESTEPS, callback=callback)
