In [1]:
!uv pip install torch>=2.9.1 wandb>=0.23.0 stable-baselines3>=2.7.0 opencv-python>=4.11.0.86 numpy>=2.3.5 gymnasium>=1.2.2 ale-py>=0.11.2

[2mUsing Python 3.12.12 environment at: /usr[0m
[2K[2mResolved [1m64 packages[0m [2min 462ms[0m[0m
[2K[2mPrepared [1m1 package[0m [2min 21ms[0m[0m
[2K[2mInstalled [1m1 package[0m [2min 7ms[0m[0m
 [32m+[39m [1mstable-baselines3[0m[2m==2.7.1[0m


In [2]:
import typing as tt, gymnasium as gym, numpy as np, collections
from gymnasium import spaces
from stable_baselines3.common import atari_wrappers

class ImageToPyTorch(gym.ObservationWrapper):
    """Observation wrapper that moves the last axis of each frame to the front.

    The wrapped environment must expose a 3-dimensional ``Box`` observation
    space. Axis 2 of every observation becomes axis 0, so a frame stored as
    (height, width, channel) comes out as (channel, height, width).
    """

    def __init__(self, env):
        """Wrap ``env`` and publish the re-ordered observation space."""
        super().__init__(env)
        source_space = self.observation_space
        assert isinstance(source_space, gym.spaces.Box)
        assert len(source_space.shape) == 3
        dim0, dim1, dim2 = source_space.shape
        # The new space uses scalar bounds: the overall minimum of the lows
        # and the overall maximum of the highs of the wrapped space.
        self.observation_space = gym.spaces.Box(
            low=source_space.low.min(),
            high=source_space.high.max(),
            shape=(dim2, dim0, dim1),
            dtype=source_space.dtype,
        )

    def observation(self, observation):
        """Return ``observation`` with axis 2 moved to position 0."""
        return np.moveaxis(observation, source=2, destination=0)


class BufferWrapper(gym.ObservationWrapper):
    """Stack the most recent ``n_steps`` observations along axis 0.

    Every observation returned by the wrapper is the concatenation (axis 0)
    of the last ``n_steps`` observations of the wrapped environment. Right
    after a reset, the slots that have no real frame yet are filled with the
    lower bound of the wrapped observation space.

    Args:
        env: Environment with a ``Box`` observation space.
        n_steps: Number of consecutive observations to stack; must be >= 1.

    Raises:
        ValueError: If ``n_steps`` is smaller than 1.
    """

    def __init__(self, env, n_steps):
        super().__init__(env)
        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")
        obs = env.observation_space
        assert isinstance(obs, gym.spaces.Box)
        # Tile the bounds exactly the way `observation` joins frames
        # (whole frames back to back), so per-channel bounds stay aligned
        # with the stacked data even when the wrapped space has more than
        # one channel.
        new_obs = gym.spaces.Box(
            np.concatenate([obs.low] * n_steps, axis=0),
            np.concatenate([obs.high] * n_steps, axis=0),
            dtype=obs.dtype
        )
        self.observation_space = new_obs
        self.buffer = collections.deque(maxlen=n_steps)

    def reset(self, *, seed: tt.Optional[int] = None, options: tt.Optional[dict[str, tt.Any]] = None):
        """Reset the wrapped env and return a freshly padded frame stack.

        ``seed`` and ``options`` are forwarded to the wrapped environment so
        that seeding through this wrapper actually takes effect.
        """
        # Drop frames from the previous episode, then pad all but one slot;
        # the last slot is filled by the first real observation below.
        self.buffer.clear()
        for _ in range(self.buffer.maxlen - 1):
            self.buffer.append(self.env.observation_space.low)
        obs, extra = self.env.reset(seed=seed, options=options)
        return self.observation(obs), extra

    def observation(self, observation: np.ndarray) -> np.ndarray:
        """Push ``observation`` into the buffer and return the stacked frames."""
        self.buffer.append(observation)
        return np.concatenate(self.buffer)


def make_env(env_name: str, **kwargs):
    """Build a Gymnasium environment with the notebook's preprocessing stack.

    Args:
        env_name: Environment id handed to ``gym.make``.
        **kwargs: Extra keyword arguments forwarded unchanged to ``gym.make``
            (for example ``render_mode``).

    Returns:
        The environment wrapped, in order, by ``AtariWrapper``,
        ``ImageToPyTorch`` and ``BufferWrapper`` (4 stacked observations).
    """
    base_env = gym.make(env_name, **kwargs)
    atari_env = atari_wrappers.AtariWrapper(
        base_env,
        noop_max=30,
        frame_skip=4,
        clip_reward=True,
        terminal_on_life_loss=True,
    )
    channel_first_env = ImageToPyTorch(atari_env)
    return BufferWrapper(channel_first_env, n_steps=4)

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  return datetime.utcnow().replace(tzinfo=utc)


In [9]:
import gymnasium as gym
import ale_py # Import ale_py for Atari environments
gym.register_envs(ale_py) # Register Atari environments

from stable_baselines3 import DQN

# --- Run configuration ---
# The epsilon schedule is expressed in steps and converted to a fraction of
# the run below, so changing TOTAL_TIMESTEPS does not silently stretch or
# shrink the exploration phase.
TOTAL_TIMESTEPS = 1_000_000
EPSILON_DECAY_STEPS = 150_000
MODEL_PATH = "dqn_pong_sb3"

env = make_env('PongNoFrameskip-v4')

# --- DQN Hyperparameters (matching your settings) ---
model = DQN(
    policy="CnnPolicy",
    env=env,
    learning_rate=1e-4,
    buffer_size=10000,           # REPLAY_SIZE
    learning_starts=10000,       # REPLAY_START_SIZE
    batch_size=32,
    gamma=0.99,
    train_freq=1,
    target_update_interval=1000, # SYNC_TARGET_FRAMES
    exploration_fraction=EPSILON_DECAY_STEPS / TOTAL_TIMESTEPS,  # ratio of frames for epsilon decay
    exploration_initial_eps=1.0,
    exploration_final_eps=0.01,
    verbose=1,
)


# --- Training ---
# The save lives in `finally` so that stopping the cell by hand (or any other
# error during training) still writes the weights learned so far to disk.
try:
    model.learn(total_timesteps=TOTAL_TIMESTEPS)
finally:
    # --- Save the trained model ---
    model.save(MODEL_PATH)

# Only reached when `learn` ran to completion without raising.
print("Training complete!")


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 840      |
|    ep_rew_mean      | -20.8    |
|    exploration_rate | 0.978    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 593      |
|    time_elapsed     | 5        |
|    total_timesteps  | 3360     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 906      |
|    ep_rew_mean      | -20.2    |
|    exploration_rate | 0.952    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 641      |
|    time_elapsed     | 11       |
|    total_timesteps  | 7250     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 899      |
|    ep_rew_mean      | -20.2 

KeyboardInterrupt: 

In [10]:
# Write the model object currently held in the kernel to "dqn_pong_sb3".
# Handy after stopping a training run by hand, when the training cell's own
# save step may not have been reached.
model.save("dqn_pong_sb3")

In [12]:
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.atari_wrappers import (
    NoopResetEnv, MaxAndSkipEnv, EpisodicLifeEnv,
    ClipRewardEnv, WarpFrame
)
from gymnasium.wrappers import RecordVideo
import os

# Where to save videos (Colab)
video_dir = "/content/videos"
os.makedirs(video_dir, exist_ok=True)

# Create wrapped env WITH recording
env = make_env('PongNoFrameskip-v4', render_mode="rgb_array")
env = RecordVideo(env, video_folder=video_dir, name_prefix="dqn_pong")

# Load the trained model
model = DQN.load("dqn_pong_sb3.zip", device="cpu")

# Run one full episode
total_reward = 0
try:
    obs, _ = env.reset()
    done = False

    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        done = terminated or truncated
finally:
    # Always close the env, even if the rollout raises, so the recorder can
    # finalize the video file.
    env.close()

print("Episode reward:", total_reward)
print("Video saved to:", video_dir)

  logger.warn("Unable to save last video! Did you call close()?")
  logger.warn(


Episode reward: 21.0
Video saved to: /content/videos
