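Install Dependencies

Run these cells once, in order, then restart the kernel so the updated packages are picked up. `gymnasium[atari]` downgrades Shimmy to 0.2.1, and `stable-baselines3[extra]` upgrades it back to 1.1.0, so the Stable-Baselines3 install must run last.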

In [1]:
%pip install gymnasium[accept-rom-license]

Note: you may need to restart the kernel to use updated packages.


In [2]:
%pip install gymnasium[atari]

Collecting shimmy[atari]<1.0,>=0.1.0 (from gymnasium[atari])
  Using cached Shimmy-0.2.1-py3-none-any.whl (25 kB)
Installing collected packages: shimmy
  Attempting uninstall: shimmy
    Found existing installation: Shimmy 1.1.0
    Uninstalling Shimmy-1.1.0:
      Successfully uninstalled Shimmy-1.1.0
Successfully installed shimmy-0.2.1
Note: you may need to restart the kernel to use updated packages.


In [3]:
%pip install stable-baselines3[extra]

Collecting shimmy[atari]~=1.1.0 (from stable-baselines3[extra])
  Obtaining dependency information for shimmy[atari]~=1.1.0 from https://files.pythonhosted.org/packages/d5/fb/083e36bbcf325f6304bbeb2278b102c4ac8e87eb1ca771780f64decbb2f1/Shimmy-1.1.0-py3-none-any.whl.metadata
  Using cached Shimmy-1.1.0-py3-none-any.whl.metadata (3.3 kB)
Collecting autorom[accept-rom-license]~=0.6.1 (from stable-baselines3[extra])
  Using cached AutoROM-0.6.1-py3-none-any.whl (9.4 kB)
Using cached Shimmy-1.1.0-py3-none-any.whl (37 kB)
Installing collected packages: shimmy, autorom
  Attempting uninstall: shimmy
    Found existing installation: Shimmy 0.2.1
    Uninstalling Shimmy-0.2.1:
      Successfully uninstalled Shimmy-0.2.1
  Attempting uninstall: autorom
    Found existing installation: AutoROM 0.4.2
    Uninstalling AutoROM-0.4.2:
      Successfully uninstalled AutoROM-0.4.2
Successfully installed autorom-0.6.1 shimmy-1.1.0
Note: you may need to restart the kernel to use updated packages.


Import Dependencies

In [4]:
import os
import gymnasium
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3 import DQN


Constants

In [5]:
# --- Environment ------------------------------------------------------------
ENV_ADVENTURE = "ALE/Adventure-v5"  # Gymnasium id of the Atari Adventure environment
# Intended render mode for watching the agent play.
# NOTE(review): check that it is actually passed to gymnasium.make where rendering happens.
ENV_RENDER_MODE = "human"

# --- Checkpointing / logging ------------------------------------------------
CALLBACK_CHECKPOINT_DIR = './models/'  # periodic checkpoints are written here as <n_calls>.zip
CALLBACK_LOG_DIR = './logs/'           # passed to DQN as tensorboard_log
CALLBACK_CHECK_FREQ = 10_000           # save a checkpoint every this many callback steps

VERBOSE = 1  # verbosity shared by the DQN model and the callback

# --- DQN hyper-parameters ---------------------------------------------------
# NOTE(review): ALE environments presumably return RGB image observations by default;
# Stable-Baselines3 recommends 'CnnPolicy' for image inputs. Consider switching.
DQN_POLICY = 'MlpPolicy'
DQN_BUFFER_SIZE = 20_000        # replay-buffer capacity
DQN_LEARNING_STARTS = 200       # steps collected before learning starts
DQN_TOTAL_TIMESTEPS = 350_000   # timesteps per call to model.learn()
# NOTE: nothing in this notebook saves to this path; copy a checkpoint from
# CALLBACK_CHECKPOINT_DIR here before loading or rendering.
DQN_BEST_MODEL_PATH = 'best_model.zip'

# --- Run mode ---------------------------------------------------------------
# The run cell checks these in the order NEW -> LOAD -> RENDER and acts only on
# the first flag that is True. With all of them False, it does nothing.
DQN_LOAD_MODEL = False    # resume training from DQN_BEST_MODEL_PATH
DQN_NEW_MODEL = False     # train a brand-new model
DQN_RENDER_MODEL = False  # load DQN_BEST_MODEL_PATH and watch it play
DQN_RENDER_EPISODES = 10  # number of episodes to play when rendering

Creating the Env

In [6]:
# Training environment. No render_mode is passed (ENV_RENDER_MODE is not applied here),
# so frames are not displayed while the agent trains.
env = gymnasium.make(id=ENV_ADVENTURE)

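Creating the Checkpoint Callback

The callback saves the model under `CALLBACK_CHECKPOINT_DIR` every `CALLBACK_CHECK_FREQ` steps.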

In [7]:
class TrainAndLoggingCallback(BaseCallback):
    """Stable-Baselines3 callback that periodically checkpoints the model being trained.

    Every ``check_freq`` callback steps, the model is saved to ``<save_path>/<n_calls>``.
    ``model.save`` adds the ``.zip`` suffix.

    Args:
        check_freq: save a checkpoint whenever ``n_calls`` is a multiple of this value.
            Must be positive.
        save_path: directory for checkpoints, or ``None`` to disable saving.
        verbose: SB3 verbosity level. When > 0, each saved checkpoint path is printed.

    Raises:
        ValueError: if ``check_freq`` is not positive.
    """

    def __init__(self, check_freq, save_path, verbose=VERBOSE):
        super().__init__(verbose)
        # Fail fast instead of hitting a ZeroDivisionError on the first training step.
        if check_freq <= 0:
            raise ValueError(f"check_freq must be positive, got {check_freq}")
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self) -> None:
        # Create the checkpoint directory once, before training starts.
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        # SB3 increments n_calls on every callback step. Saving is skipped entirely
        # when no save_path was given, mirroring the guard in _init_callback.
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, str(self.n_calls))
            self.model.save(model_path)
            if self.verbose > 0:
                print(f"Saved checkpoint to {model_path}")

        # Returning True tells SB3 to keep training.
        return True

callback = TrainAndLoggingCallback(check_freq=CALLBACK_CHECK_FREQ, save_path=CALLBACK_CHECKPOINT_DIR)

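Loading, Training and Rendering

Set exactly one of `DQN_NEW_MODEL`, `DQN_LOAD_MODEL` or `DQN_RENDER_MODEL` to `True` in the Constants cell. The flags are checked in that order, and only the first one that is `True` runs. With all three `False` (the default), the cell below does nothing.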

In [8]:
# Which branch runs is controlled by the flags in the Constants cell. They are checked
# in order (new model -> load and resume -> render), and only the first True flag is
# acted on. If all are False, nothing happens and `model` stays None.


def _play_episodes(agent, episode_env, n_episodes):
    """Play ``n_episodes`` complete episodes in ``episode_env`` using ``agent``'s greedy policy.

    Args:
        agent: object with an SB3-style ``predict(obs, deterministic=...)`` that returns
            ``(action, state)``.
        episode_env: env following the Gymnasium API: ``reset() -> (obs, info)`` and
            ``step(action) -> (obs, reward, terminated, truncated, info)``.
        n_episodes: number of episodes to play.

    Returns:
        A list holding the summed reward of each episode, in the order they were played.
    """
    episode_returns = []
    for _ in range(n_episodes):
        obs, _info = episode_env.reset()
        terminated = truncated = False
        episode_return = 0.0
        # An episode is over when it terminates OR is truncated (time limit),
        # so keep stepping only while neither has happened.
        while not (terminated or truncated):
            action, _state = agent.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, _info = episode_env.step(action)
            episode_return += float(reward)
        episode_returns.append(episode_return)
    return episode_returns


model = None

if DQN_NEW_MODEL:
    # Train a fresh agent from scratch; `callback` writes periodic checkpoints.
    model = DQN(
        DQN_POLICY,
        env,
        tensorboard_log=CALLBACK_LOG_DIR,
        verbose=VERBOSE,
        buffer_size=DQN_BUFFER_SIZE,
        learning_starts=DQN_LEARNING_STARTS,
    )
    model.learn(total_timesteps=DQN_TOTAL_TIMESTEPS, callback=callback)

elif DQN_LOAD_MODEL:
    # Continue training a previously saved model.
    model = DQN.load(DQN_BEST_MODEL_PATH, env=env)
    model.learn(total_timesteps=DQN_TOTAL_TIMESTEPS, callback=callback)

elif DQN_RENDER_MODEL:
    # The training `env` is created without a render_mode, so calling render() on it
    # shows nothing. Use a dedicated env with ENV_RENDER_MODE instead. Per the Gymnasium
    # docs, "human" mode draws frames during reset()/step() without explicit render() calls.
    render_env = gymnasium.make(ENV_ADVENTURE, render_mode=ENV_RENDER_MODE)
    try:
        model = DQN.load(DQN_BEST_MODEL_PATH, env=render_env)
        episode_returns = _play_episodes(model, render_env, DQN_RENDER_EPISODES)
    finally:
        # Always release the display window / emulator, even if loading or playing fails.
        render_env.close()
    print(f"Episode returns: {episode_returns}")

  