# ATARI Pong DQN — Train or Load a Checkpoint in Google Colab
This notebook lets you:
- Train a DQN agent on Pong from scratch or resume from a checkpoint.
- Alternatively, load an existing checkpoint and run an evaluation with inline rendering.

When run locally, it assumes the Python files from this repo are available in the same directory (nn.py, helper.py, rpbuf.py, checkpoint.py, config.py, etc.). In Colab, the first code cell clones the repository for you.


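## 1) Colab setup: clone the helper code
When running in Colab, this cell clones the repository and adds it to the import path. It does not mount Google Drive; if your checkpoints or project folder are on Drive, mount it yourself before running the rest of the notebook.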


In [1]:
import sys, os

IN_COLAB = 'COLAB_RELEASE_TAG' in os.environ or 'COLAB_GPU' in os.environ
if IN_COLAB:
  !git clone https://github.com/fypeex/ATARI.git

  import sys
  sys.path.append('/content/ATARI')

## 2) Install dependencies (Colab)
If running locally with dependencies already installed, you can skip this.


In [2]:
import subprocess
if IN_COLAB:
    print('Installing packages for Colab...')
    # Packages are installed with the kernel's own interpreter so they land in
    # the environment this notebook actually runs in.
    pip_install_cmd = [sys.executable, '-m', 'pip', 'install', '-q']
    pip_install_cmd += [
        'gymnasium[atari,accept-rom-license]>=0.29.1',
        'ale-py',
        'autorom',
        'opencv-python',
        'pillow',
    ]
    subprocess.check_call(pip_install_cmd)

    # Download Atari ROMs (required by ALE); accepts license automatically.
    # A failure here is reported but deliberately does not abort the notebook.
    autorom_cmd = [sys.executable, '-m', 'AutoROM', '--accept-license']
    try:
        subprocess.check_call(autorom_cmd)
    except Exception as rom_error:
        print('AutoROM download step failed or already satisfied:', rom_error)


## 3) Imports and setup


In [3]:
import os
import time
from typing import Deque

import gymnasium as gym
import numpy as np
import torch
import ale_py  # ensure ALE is registered
from IPython.display import clear_output, display
from PIL import Image

from ipywidgets import widgets

if IN_COLAB:
    from ATARI.nn import DQN
    from ATARI.helper import preprocess_obs, init_state_stack, get_state_from_stack
    from ATARI.rpbuf import ReplayBuffer
    from ATARI.checkpoint import save_checkpoint, load_checkpoint
    from ATARI.config import DEVICE, VALID_ACTIONS, CHECKPOINT_DIR
else:
    from nn import DQN
    from helper import preprocess_obs, init_state_stack, get_state_from_stack
    from rpbuf import ReplayBuffer
    from checkpoint import save_checkpoint, load_checkpoint
    from config import DEVICE, VALID_ACTIONS, CHECKPOINT_DIR

# Alias the configured torch device under the lowercase name the helpers use.
device = DEVICE

# In Colab the checkpoint directory is resolved inside the cloned repository.
if IN_COLAB:
    CHECKPOINT_DIR = "./ATARI/" + CHECKPOINT_DIR
    print(CHECKPOINT_DIR)

print('Using device:', device)

# Create the checkpoint directory up front; a pre-existing one is fine.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


Using device: cuda


## 4) Configuration
Adjust training hyperparameters, mode, and checkpoint behavior here.


In [4]:
# Mode: 'train' to train/resume, 'eval' to just run episodes with a loaded policy.
# The run cell raises ValueError for any other value.
MODE = 'eval'  # 'train' or 'eval'

# Checkpoint loading option at startup (passed through to load_checkpoint by
# both train() and evaluate()):
#   'none'   -> start fresh
#   'latest' -> load newest checkpoint in models/
#   path     -> a specific checkpoint file path (e.g., 'models/dqn_pong_model_100000.pth')
CHECKPOINT_OPTION = 'latest'

# Evaluation episodes (used if MODE == 'eval')
EVAL_EPISODES = 3

# Training hyperparameters (defaults are modest for Colab)
# Discount factor applied to the target network's next-state value.
GAMMA = 0.99
# Number of transitions sampled from the replay buffer per optimization step.
BATCH_SIZE = 32
# Learning rate of the Adam optimizer.
LR = 1e-4
# Epsilon-greedy schedule: linear decay from EPS_START towards EPS_END over
# EPS_DECAY_FRAMES frames, then held at EPS_END.
# NOTE(review): with MAX_FRAMES = 200_000 and EPS_DECAY_FRAMES = 1_000_000 a
# fresh run ends at epsilon ~= 0.82, i.e. still mostly random actions.
EPS_START = 1.0
EPS_END = 0.1
EPS_DECAY_FRAMES = 1_000_000
# The target network is synced with the policy network every TARGET_UPDATE frames.
TARGET_UPDATE = 10_000
# Frames to collect before optimization starts.
REPLAY_INIT = 5_000   # a smaller initial buffer fill for Colab
# Capacity passed to ReplayBuffer.
REPLAY_CAP = 100_000  # smaller buffer to fit memory constraints
# Upper bound (inclusive) of the global frame counter in the training loop.
MAX_FRAMES = 200_000  # adjust for Colab session length
# Added as-is to the reward whenever the chosen env action is not 0, so a
# negative value is required for it to act as a penalty.
MOVE_PENALTY = 0.0    # optional movement penalty shaping (0 is fine)


## 5) Utilities for training, evaluation, and rendering in Colab


In [5]:
def update_video(frame_array, scale=2):
    """Show one frame inside the global ``video_out`` widget.

    Args:
        frame_array: Image data accepted by ``PIL.Image.fromarray``; the caller
            in this notebook passes the result of ``env.render()``.
        scale: Factor applied to both width and height. ``1`` skips resizing.
    """
    picture = Image.fromarray(frame_array)

    needs_resize = scale != 1
    if needs_resize:
        target_size = (picture.width * scale, picture.height * scale)
        picture = picture.resize(target_size)

    # Clear and redraw within the widget so each frame takes the place of the
    # previous one instead of being appended below it.
    with video_out:
        video_out.clear_output(wait=True)
        display(picture)

def make_env(render_mode = None):
    """Create the ``ALE/Pong-v5`` Gymnasium environment.

    Args:
        render_mode: Forwarded unchanged to ``gym.make``; ``None`` by default.

    Returns:
        The environment object returned by ``gym.make``.
    """
    return gym.make('ALE/Pong-v5', render_mode=render_mode)

def select_action(policy_net, state, frame_idx):
    """Pick an action index epsilon-greedily with a linearly decaying epsilon.

    Epsilon starts at ``EPS_START`` and decreases linearly with ``frame_idx``
    until it is clamped at ``EPS_END``.

    Args:
        policy_net: Network mapping a batch of states to per-action Q-values.
        state: NumPy array holding the current state; it is divided by 255
            before being fed to the network.
        frame_idx: Global frame counter driving the epsilon schedule.

    Returns:
        Tuple ``(a_idx, eps)``: an index into ``VALID_ACTIONS`` and the epsilon
        value used for this decision.
    """
    decayed = EPS_START - (EPS_START - EPS_END) * frame_idx / EPS_DECAY_FRAMES
    eps = max(EPS_END, decayed)

    # Exploration branch: uniform random action index.
    if np.random.rand() < eps:
        return np.random.randint(len(VALID_ACTIONS)), eps

    # Exploitation branch: greedy action according to the policy network.
    with torch.no_grad():
        batch = torch.from_numpy(state).unsqueeze(0).to(device)
        batch = batch.float() / 255.0
        greedy = torch.argmax(policy_net(batch), dim=1).item()
    return int(greedy), eps


def optimize_model(policy_net, target_net, optimizer, replay):
    """Run one DQN gradient step on a minibatch sampled from ``replay``.

    Does nothing while the buffer holds fewer than ``BATCH_SIZE`` transitions.

    Args:
        policy_net: Network being trained; its Q-value for the taken action is
            regressed towards the TD target.
        target_net: Network used (without gradients) to compute the TD target.
        optimizer: Optimizer over ``policy_net``'s parameters.
        replay: Buffer supporting ``len()`` and ``sample(batch_size)``, where
            ``sample`` returns NumPy arrays
            ``(states, actions, rewards, next_states, dones)``.

    Returns:
        None.
    """
    if len(replay) < BATCH_SIZE:
        return

    states, actions, rewards, next_states, dones = replay.sample(BATCH_SIZE)

    states = torch.from_numpy(states).float().to(device) / 255.0
    next_states = torch.from_numpy(next_states).float().to(device) / 255.0
    actions     = torch.from_numpy(actions).long().to(device)
    # Cast explicitly to float32: a float64 reward/done array would promote the
    # TD target to double (mismatching the float32 predictions), and a bool
    # done array cannot be used in ``1.0 - dones``.
    rewards     = torch.from_numpy(rewards).float().to(device)
    dones       = torch.from_numpy(dones).float().to(device)

    # Q(s, a) for the action that was actually taken in each transition.
    q_values = policy_net(states)
    q_value = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        # Bootstrapped target; the (1 - done) factor removes the bootstrap term
        # for transitions that ended an episode.
        next_q_values = target_net(next_states).max(1)[0]
        expected_q = rewards + GAMMA * next_q_values * (1.0 - dones)

    loss = torch.nn.functional.mse_loss(q_value, expected_q)

    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm to 10 before the update.
    torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 10.0)
    optimizer.step()


## 6) Train function (supports resume via checkpoint)


In [6]:
def train():
    """Train (or resume training of) the DQN agent on Pong.

    Builds the environment, policy/target networks, optimizer and an empty
    replay buffer, optionally restores state via ``load_checkpoint`` using
    ``CHECKPOINT_OPTION``, and then steps the environment until the global
    frame counter reaches ``MAX_FRAMES``.

    Side effects:
        * Prints progress every 1000 frames and at the end of each episode.
        * Saves a checkpoint every 10000 frames and once more after the last
          trained frame.
        * Always closes the environment, including when training is
          interrupted or raises.

    Notes:
        * ``MOVE_PENALTY`` is *added* to the reward for every action whose env
          id is not 0, so it must be negative to act as a penalty.
        * The replay buffer is not part of the checkpoint handling visible
          here, so it starts empty after a resume; the ``REPLAY_INIT`` warm-up
          is therefore counted from the frame this run started at.
    """
    env = make_env(render_mode=None)
    try:
        obs, info = env.reset()

        # Networks & optimizer
        policy_net = DQN(in_channels=4, num_actions=len(VALID_ACTIONS)).to(device)
        target_net = DQN(in_channels=4, num_actions=len(VALID_ACTIONS)).to(device)
        optimizer = torch.optim.Adam(policy_net.parameters(), lr=LR)

        target_net.load_state_dict(policy_net.state_dict())
        target_net.eval()

        # Replay buffer
        replay = ReplayBuffer(cap=REPLAY_CAP)

        # Try to resume from checkpoint; the returned value is used as the
        # frame index training continues from.
        start_frame = load_checkpoint(
            policy_net,
            device,
            CHECKPOINT_OPTION,
            target_net=target_net,
            optimizer=optimizer,
            directory=CHECKPOINT_DIR,
        )

        # Initialize state stack
        stack = init_state_stack(obs)
        state = get_state_from_stack(stack)

        episode_reward = 0.0
        episode_move_count = 0
        # Last frame index that was actually trained in this run.
        last_frame = start_frame

        for frame_idx in range(start_frame + 1, MAX_FRAMES + 1):
            a_idx, eps = select_action(policy_net, state, frame_idx)
            if frame_idx % 1000 == 0:
                print(f'Frame {frame_idx}, epsilon: {eps:.3f}')

            # Save checkpoint periodically
            if frame_idx % 10000 == 0:
                save_checkpoint(policy_net, target_net, optimizer, frame_idx, directory=CHECKPOINT_DIR)

            env_action = VALID_ACTIONS[a_idx]
            next_obs, reward, terminated, truncated, info = env.step(env_action)
            done = terminated or truncated
            episode_reward += reward

            # Reward shaping: MOVE_PENALTY is added for every non-zero action.
            is_movement = (env_action != 0)
            movement_reward = MOVE_PENALTY if is_movement else 0.0
            episode_move_count += is_movement
            shaped_reward = reward + movement_reward

            frame = preprocess_obs(next_obs)
            stack.append(frame)
            next_state = get_state_from_stack(stack)

            replay.push(state, a_idx, shaped_reward, next_state, float(done))
            state = next_state

            # Warm-up is measured relative to this run's first frame: after a
            # resume frame_idx is already large while the buffer is still empty.
            if frame_idx - start_frame > REPLAY_INIT:
                optimize_model(policy_net, target_net, optimizer, replay)

            if frame_idx % TARGET_UPDATE == 0:
                target_net.load_state_dict(policy_net.state_dict())

            if done:
                print(
                    f'Frame {frame_idx}, episode reward: {episode_reward:.1f}, '
                    f'episode move count: {episode_move_count}, epsilon: {eps:.3f}'
                )
                obs, info = env.reset()
                stack = init_state_stack(obs)
                state = get_state_from_stack(stack)
                episode_reward = 0.0
                episode_move_count = 0

            last_frame = frame_idx

        if last_frame > start_frame:
            # Final save, labelled with the frame that was really reached.
            save_checkpoint(policy_net, target_net, optimizer, last_frame, directory=CHECKPOINT_DIR)
        else:
            # The loaded checkpoint is already at or past MAX_FRAMES; saving
            # under the MAX_FRAMES label would mislabel those weights.
            print(f'No frames to train: start frame {start_frame} >= MAX_FRAMES {MAX_FRAMES}.')
    finally:
        env.close()

    print('Training complete. Checkpoints in:', CHECKPOINT_DIR)


## 7) Evaluate function (inline rendering)


In [7]:
def evaluate(checkpoint_option: str, episodes: int = 3, fps: int = 30, render_scale: int = 2):
    """Run greedy evaluation episodes with a policy loaded from a checkpoint.

    In Colab the environment renders to RGB arrays that are shown in the global
    ``video_out`` widget; otherwise the environment's ``"human"`` render mode is
    used. Relies on the globals ``video_out`` and ``log_out`` being defined
    before the call.

    Args:
        checkpoint_option: Forwarded to ``load_checkpoint`` to choose which
            weights to load.
        episodes: Number of episodes to play.
        fps: Target frames per second; the loop sleeps ``1 / max(1, fps)``
            seconds per step when ``fps > 0`` and does not sleep otherwise.
        render_scale: Upscaling factor passed to ``update_video``.

    Side effects:
        Prints the reward of every episode and, if at least one episode ran,
        the mean and standard deviation. The environment is always closed,
        including when evaluation is interrupted or raises.
    """
    env = make_env(render_mode='rgb_array' if IN_COLAB else "human")
    try:
        policy_net = DQN(in_channels=4, num_actions=len(VALID_ACTIONS)).to(device)
        policy_net.eval()
        _ = load_checkpoint(policy_net, device, checkpoint_option, directory=CHECKPOINT_DIR)

        all_rewards = []
        frame_delay = 1.0 / max(1, fps)

        # Short pause, then reset both the log widget and the cell output
        # before the widget container is shown.
        time.sleep(1)
        with log_out:
            log_out.clear_output(wait=True)
        clear_output()

        display(widgets.VBox([video_out, log_out]))

        for ep in range(episodes):
            obs, info = env.reset()
            stack = init_state_stack(obs)
            state = get_state_from_stack(stack)
            done = False
            ep_reward = 0.0
            while not done:
                # render current env frame
                frame_rgb = env.render()  # (H,W,3) RGB
                if frame_rgb is not None and IN_COLAB:
                    update_video(frame_rgb, scale=render_scale)

                with torch.no_grad():
                    s = torch.from_numpy(state).unsqueeze(0).to(device)
                    # NOTE(review): select_action/optimize_model always divide
                    # by 255, while this only does so for non-float32 input.
                    # The two agree only if get_state_from_stack never returns
                    # float32 - confirm against helper.py.
                    if s.dtype != torch.float32:
                        s = s.float() / 255.0
                    q_values = policy_net(s)
                    a_idx = int(torch.argmax(q_values, dim=1).item())

                env_action = VALID_ACTIONS[a_idx]
                next_obs, reward, terminated, truncated, info = env.step(env_action)
                done = terminated or truncated
                ep_reward += reward

                frame = preprocess_obs(next_obs)
                stack.append(frame)
                state = get_state_from_stack(stack)

                if fps > 0:
                    time.sleep(frame_delay)

            all_rewards.append(ep_reward)
            print(f'[Eval] Episode {ep + 1}/{episodes} reward: {ep_reward:.1f}')
    finally:
        # Release the emulator (and any render window) even on interrupt.
        env.close()

    if all_rewards:
        mean_r = np.mean(all_rewards)
        std_r = np.std(all_rewards)
        print(f'[Eval] Mean reward over {episodes} episodes: {mean_r:.2f} ± {std_r:.2f}')


## 8) Run (choose mode in the config cell)


In [None]:
# Output widgets looked up as globals by evaluate() and update_video().
log_out = widgets.Output()
video_out = widgets.Output()

# Reject unknown modes first, then dispatch on the configured MODE.
if MODE != 'train' and MODE != 'eval':
    raise ValueError("MODE must be 'train' or 'eval'")

if MODE == 'train':
    print('Starting training...')
    train()
else:
    print('Starting evaluation...')
    evaluate(CHECKPOINT_OPTION, episodes=EVAL_EPISODES)

VBox(children=(Output(), Output()))