In [2]:
from datetime import datetime

from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

In [5]:
# Run configuration. SEED is passed to both the environment and the DQN agent so that
# network initialisation and epsilon-greedy sampling are repeatable across reruns
# (CUDA kernels may still add small nondeterminism on GPU).
ENV_ID = "ALE/Pong-v5"
SEED = 42

# Single Pong environment with SB3's standard Atari preprocessing. make_atari_env wraps
# the env in SB3's AtariWrapper, which does its own frame skipping, so the ALE-level
# frameskip is set to 1 to avoid skipping frames twice.
env = make_atari_env(ENV_ID, n_envs=1, seed=SEED, env_kwargs={"full_action_space": False, "frameskip": 1})
# Frame-stacking with 4 frames
env = VecFrameStack(env, n_stack=4)

# Timestamp identifying this run in both the checkpoint and TensorBoard directories.
# Year-first ordering sorts chronologically, and '-' is used instead of ':' because ':'
# is not allowed in Windows (NTFS) path components.
current_datetime_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
RUN_NAME = f"CNN_DQN_Pong-v5_{current_datetime_str}"

# Periodically persist weights and the replay buffer so a long run can be resumed.
# save_freq counts callback calls, which equal environment steps here because n_envs=1.
checkpoint_callback = CheckpointCallback(
    save_freq=500_000,
    save_path=f"runs/checkpoints/{RUN_NAME}",
    name_prefix="CNN_DQN_Pong-v5",
    save_replay_buffer=True,
)

model = DQN("CnnPolicy",
            env,
            verbose=1,
            tensorboard_log="runs/logs/",
            batch_size=32,
            buffer_size=10_000,
            # Epsilon decays to 0.01 over the first 10% of the total_timesteps passed to learn().
            exploration_final_eps=0.01,
            exploration_fraction=0.1,
            gradient_steps=1,
            learning_rate=0.0001,
            # NOTE(review): learning_starts (100k) exceeds buffer_size (10k), so the buffer is
            # overwritten several times before the first gradient step — confirm this is intended.
            learning_starts=100_000,
            # optimize_memory_usage stores each frame once instead of twice; SB3's ReplayBuffer
            # does not support it together with timeout handling, hence the kwarg below.
            optimize_memory_usage=True,
            replay_buffer_kwargs={"handle_timeout_termination": False},
            target_update_interval=1000,
            train_freq=4,
            seed=SEED,)

Using cuda device
Wrapping the env in a VecTransposeImage.


In [6]:
# Train the agent; TensorBoard logs use the same timestamped run name as the checkpoints.
TOTAL_TIMESTEPS = 10_000_000
try:
    model.learn(TOTAL_TIMESTEPS, tb_log_name=f"CNN_DQN_Pong-v5_{current_datetime_str}", callback=checkpoint_callback)
except KeyboardInterrupt:
    # A manual interrupt would otherwise lose all progress since the last periodic
    # checkpoint. Persist the current weights and replay buffer next to those checkpoints,
    # then re-raise so "Run All" still stops here.
    interrupted_path = f"{checkpoint_callback.save_path}/CNN_DQN_Pong-v5_interrupted_{model.num_timesteps}_steps"
    model.save(interrupted_path)
    model.save_replay_buffer(f"{interrupted_path}_replay_buffer")
    print(f"Training interrupted at {model.num_timesteps} steps; saved model to {interrupted_path}.zip")
    raise

Logging to runs/logs/CNN_DQN_Pong-v5_11-04-2023_14:37:12_1
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.63e+03 |
|    ep_rew_mean      | -20.5    |
|    exploration_rate | 0.996    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 1080     |
|    time_elapsed     | 3        |
|    total_timesteps  | 3601     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.62e+03 |
|    ep_rew_mean      | -20.5    |
|    exploration_rate | 0.993    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 1070     |
|    time_elapsed     | 6        |
|    total_timesteps  | 7199     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.63e+03 |
|    ep_rew_mean      | -20.6    |
|    exploration_rate | 0.989  

KeyboardInterrupt: 