In [1]:
!pip install stable-baselines3==2.1.0 gymnasium[atari] ale-py imageio matplotlib

Collecting stable-baselines3==2.1.0
  Downloading stable_baselines3-2.1.0-py3-none-any.whl.metadata (5.2 kB)
Collecting gymnasium<0.30,>=0.28.1 (from stable-baselines3==2.1.0)
  Downloading gymnasium-0.29.1-py3-none-any.whl.metadata (10 kB)
INFO: pip is looking at multiple versions of gymnasium[atari] to determine which version is compatible with other requirements. This could take a while.
Collecting gymnasium[atari]
  Downloading gymnasium-1.1.0-py3-none-any.whl.metadata (9.4 kB)
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting shimmy<1.0,>=0.1.0 (from shimmy[atari]<1.0,>=0.1.0; extra == "atari"->gymnasium[atari])
  Downloading Shimmy-0.2.1-py3-none-any.whl.metadata (2.3 kB)
Collecting ale-py
  Downloading ale_py-0.8.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.1 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.13->stable-baselines3==2.1.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.

In [2]:
!pip install gymnasium[accept-rom-license]

Collecting autorom~=0.4.2 (from autorom[accept-rom-license]~=0.4.2; extra == "accept-rom-license"->gymnasium[accept-rom-license])
  Downloading AutoROM-0.4.2-py3-none-any.whl.metadata (2.8 kB)
Collecting AutoROM.accept-rom-license (from autorom[accept-rom-license]~=0.4.2; extra == "accept-rom-license"->gymnasium[accept-rom-license])
  Downloading AutoROM.accept-rom-license-0.6.1.tar.gz (434 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m434.7/434.7 kB[0m [31m25.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Downloading AutoROM-0.4.2-py3-none-any.whl (16 kB)
Building wheels for collected packages: AutoROM.accept-rom-license
  Building wheel for AutoROM.accept-rom-license (pyproject.toml) ... [?25l[?25hdone
  Created wheel for AutoROM.accept-rom-license: filename=autorom_accept_rom_license-0.6.1-py3-none

In [4]:
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.logger import configure

env_id = "BreakoutNoFrameskip-v4"

# Create Atari environment with 4 parallel instances
env = make_atari_env(env_id, n_envs=4, seed=42)
env = VecFrameStack(env, n_stack=4)

# Use MLP policy
policy = "MlpPolicy"

model = DQN(
    policy,
    env,
    learning_rate=1e-3,  # Faster learning
    gamma=0.99,
    batch_size=64,  # Process more samples per step
    buffer_size=10_000,  # Smaller buffer
    exploration_initial_eps=1.0,
    exploration_final_eps=0.05,  # Reduce exploration decay speed
    exploration_fraction=0.2,
    target_update_interval=500,  # Update more frequently
    train_freq=1,  # Train more often
    gradient_steps=2,  # More gradient updates per step
    verbose=1,
    tensorboard_log="./dqn_tensorboard/"
)

# Configure logging
logger = configure("./dqn_logs_mlp/", ["stdout", "csv", "tensorboard"])
model.set_logger(logger)

# Checkpoint callback
checkpoint_callback = CheckpointCallback(save_freq=5000, save_path="./checkpoints_mlp/")

# Train for 500,000 steps
model.learn(
    total_timesteps=500_000,
    callback=checkpoint_callback,
    tb_log_name="dqn_breakout_mlp_fast"
)

# Save the trained model
model.save("dqn_breakout_mlp_fast")

env.close()


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.000351 |
|    n_updates        | 202462   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 666      |
|    ep_rew_mean      | 1.2      |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 13568    |
|    fps              | 263      |
|    time_elapsed     | 1724     |
|    total_timesteps  | 455132   |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.00138  |
|    n_updates        | 202564   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 670      |
|    ep_rew_mean      | 1.23     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 1