# Day 28 - DQN

Following this [implementation guide](https://chatgpt.com/share/67ac5d36-f610-800e-b057-b16698d8714f).

In [1]:
from tqdm.auto import tqdm
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStackObservation
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import ale_py

In [2]:
# Register ale_py's Atari environments (e.g. "ALE/Breakout-v5") with gymnasium so gym.make can find them.
gym.register_envs(ale_py)

In [3]:
seed = 42
# Seed Python's and NumPy's global RNGs, then PyTorch's (the last expression
# is left bare so the notebook still displays the returned Generator).
for seed_fn in (random.seed, np.random.seed):
    seed_fn(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x71bcf05c5d30>

### Create Atari Breakout environment with preprocessing wrappers

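We disable the environment's built-in frame skipping here (`frameskip=1`) and let `AtariPreprocessing` do it instead; alternatively, we could keep it here and disable it in `AtariPreprocessing` (`frame_skip=1`).
Skipping frames in both places would make the skips stack.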

In [4]:
env = gym.make(
    "ALE/Breakout-v5",
    render_mode=None,  # no human rendering
    frameskip=1,       # frame skipping is left to AtariPreprocessing
)

A.L.E: Arcade Learning Environment (version 0.10.1+unknown)
[Powered by Stella]


### Apply Atari-specific preprocessing: grayscale, resize, frame skip, etc.

In [5]:
env = AtariPreprocessing(
    env,
    noop_max=30,         # up to 30 random no-ops after each reset
    frame_skip=4,        # the agent acts on every 4th frame
    screen_size=84,      # frames resized to 84x84
    grayscale_obs=True,  # one channel instead of RGB
    scale_obs=True,      # pixel values rescaled to [0, 1]
)

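- `screen_size`: resizes each frame to 84x84
- `grayscale_obs`: outputs a single-channel 84x84 image
- `frame_skip`: repeats each action for 4 frames (=> 15 decisions per second at Atari's 60 fps)
- `noop_max`: performs a random number (up to 30) of no-op actions at reset, a common way to randomize the starting state in Atari
- `scale_obs`: rescales pixel values from [0, 255] to [0, 1]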

We also stack the last 4 frames to give the agent temporal context (a single frame cannot show which way the ball is moving).

In [6]:
# Combine the 4 most recent preprocessed frames into one observation (the stack becomes the first axis).
env = FrameStackObservation(env, stack_size=4)

### Verify environment spaces

In [7]:
obs_shape = env.observation_space.shape  # expected (4, 84, 84): 4 stacked grayscale frames
n_actions = env.action_space.n
for label, value in (("Observation shape:", obs_shape), ("Number of actions:", n_actions)):
    print(label, value)

Observation shape: (4, 84, 84)
Number of actions: 4


### Hyperparameters

In [8]:
# --- Training schedule ---
num_episodes = 500  # episodes to train (adjust as needed; Atari usually needs many more)
batch_size = 32

# --- Optimisation ---
learning_rate = 1e-4  # Adam step size
gamma = 0.99  # discount applied to future rewards

# --- Replay memory ---
buffer_size = 100_000  # maximum number of stored transitions
min_buffer_size = 10_000  # warm-up: no gradient updates until this many transitions exist

# --- Exploration (linear epsilon schedule) ---
epsilon_start = 1.0
epsilon_end = 0.1
epsilon_decay = 1e6  # env steps over which epsilon moves from start to end

# --- Target network ---
target_update_freq = 1000  # copy policy -> target every this many env steps

In [9]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


## Define the Q-Network Model

In [10]:
class DQN(nn.Module):
    """Convolutional Q-network: maps a stack of frames to one Q-value per action.

    Layer sizes follow the DQN Nature paper (Mnih et al., 2015): three unpadded
    conv layers followed by a 512-unit hidden layer.

    Args:
        input_shape: (C, H, W) of one observation, e.g. (4, 84, 84). The channels
            are the stacked frames, not colours.
        n_actions: number of discrete actions (width of the output layer).

    Raises:
        ValueError: if (H, W) is too small to survive the three conv layers.
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()

        c, h, w = input_shape

        # Conv block (output sizes shown for an 84x84 input)
        self.conv1 = nn.Conv2d(c, 32, kernel_size=8, stride=4)   # 32 x 20 x 20
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)  # 64 x 9 x 9
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)  # 64 x 7 x 7

        # Derive the flattened conv output size from (H, W) instead of
        # hard-coding 64 * 7 * 7, so non-84x84 inputs also work.
        out_h, out_w = h, w
        for conv in (self.conv1, self.conv2, self.conv3):
            out_h = self._conv_out(out_h, conv.kernel_size[0], conv.stride[0])
            out_w = self._conv_out(out_w, conv.kernel_size[1], conv.stride[1])
        if out_h < 1 or out_w < 1:
            raise ValueError(
                f"input of spatial size {h}x{w} is too small for the conv stack"
            )

        # Fully connected block
        self.fc1 = nn.Linear(self.conv3.out_channels * out_h * out_w, 512)
        self.fc2 = nn.Linear(512, n_actions)

    @staticmethod
    def _conv_out(size, kernel, stride):
        """Spatial size after an unpadded, undilated Conv2d."""
        return (size - kernel) // stride + 1

    def forward(self, x):
        """Map a (batch, C, H, W) tensor to (batch, n_actions) Q-values."""
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))

        x = x.flatten(1)  # keep the batch dim, flatten (channels, H, W)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

We now initialize both the policy network and the target network.

In [11]:
# Two networks with the same architecture: policy_net is trained, target_net supplies bootstrap targets.
policy_net, target_net = (DQN(obs_shape, n_actions).to(device) for _ in range(2))

We then copy the weights over from the policy network, so that both networks start out identical, and put the target network into eval mode.

In [12]:
# Start the target network as an exact copy of the policy network.
target_net.load_state_dict(policy_net.state_dict())
# eval() only switches layers to inference mode (a no-op for this conv/linear-only
# network); it does NOT disable gradient tracking. Gradients for target_net are
# avoided by computing targets under torch.no_grad() in the training loop.
target_net.eval()

DQN(
  (conv1): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
  (conv2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=3136, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=4, bias=True)
)

In [13]:
n_params = sum(param.numel() for param in policy_net.parameters())
print(f"The network has a total of {n_params:,} parameters.")

The network has a total of 1,686,180 parameters.


## Implementing the Replay Buffer

In [14]:
from collections import deque

In [15]:
class ReplayBuffer:
    """Fixed-size FIFO memory of (state, action, reward, next_state, done) transitions.

    Once `capacity` transitions are stored, each new one evicts the oldest.
    Minibatches are drawn uniformly at random, without replacement, and returned
    as tensors on the notebook-level `device`.

    NOTE(review): with `scale_obs=True` the observations are presumably float32,
    so each stored state costs 4x as much as uint8 pixels. Two (4, 84, 84) float32
    states per transition come to ~226 KB, i.e. roughly 22 GB at capacity 100_000.
    Storing uint8 frames (and scaling on sampling) would cut that to a quarter.
    """

    def __init__(self, capacity, state_shape):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)  # maxlen gives FIFO eviction for free
        self.state_shape = state_shape  # kept for reference; not used for pre-allocation

    def add(self, state, action, reward, next_state, done):
        """Store one transition.

        States are copied with np.array so that later in-place changes by the
        caller, or lazy frame objects without a `.copy()` method, cannot alter
        or break what is stored.
        """
        self.buffer.append(
            (np.array(state), action, reward, np.array(next_state), done)
        )

    def sample(self, batch_size):
        """Draw `batch_size` distinct transitions uniformly at random.

        Returns:
            (states, actions, rewards, next_states, dones) tensors on `device`.
            states/next_states are stacked to (batch_size, *state_shape); actions
            are int64 (as required by `gather`); rewards and dones are float32.
            dones is float so the bootstrap term can be masked with (1 - done).

        Raises:
            ValueError: (from random.sample) if fewer than `batch_size`
                transitions are stored.
        """
        batch = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one tuple per field
        states, actions, rewards, next_states, dones = zip(*batch)

        # np.stack always builds a new array. The previous np.array(..., copy=False)
        # raises under NumPy >= 2, because stacking a tuple of arrays can never
        # avoid a copy.
        states_arr = np.stack(states)            # (batch_size, *state_shape)
        next_states_arr = np.stack(next_states)
        actions_arr = np.asarray(actions, dtype=np.int64)
        rewards_arr = np.asarray(rewards, dtype=np.float32)
        dones_arr = np.asarray(dones, dtype=np.float32)

        return (
            torch.as_tensor(states_arr, device=device),
            torch.as_tensor(actions_arr, device=device),
            torch.as_tensor(rewards_arr, device=device),
            torch.as_tensor(next_states_arr, device=device),
            torch.as_tensor(dones_arr, device=device),
        )

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

We then immediately initialize the buffer we will use during training.

In [16]:
replay_buffer = ReplayBuffer(capacity=buffer_size, state_shape=obs_shape)

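## Training Loop

Before training, we rebuild the environment with video recording and start a Weights & Biases run, so that the run can be monitored while it trains.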

## Tracking and Debugging with Weights & Biases (W&B)

In [17]:
from gymnasium.wrappers import RecordVideo

video_folder = "./videos/dqn_breakout_tutorial"

# Same pipeline as before, but the base env renders rgb_array frames so that
# RecordVideo can save every 10th episode (episode ids 0, 10, 20, ...).
env = RecordVideo(
    env=gym.make("ALE/Breakout-v5", render_mode="rgb_array", frameskip=1),
    video_folder=video_folder,
    episode_trigger=lambda episode_id: episode_id % 10 == 0,
)

# Preprocess the recorded env, then stack the last 4 frames.
env = FrameStackObservation(
    AtariPreprocessing(
        env,
        noop_max=30,
        frame_skip=4,
        screen_size=84,
        grayscale_obs=True,
        scale_obs=True,
    ),
    stack_size=4,
)

  logger.warn(


In [18]:
# Re-read the spaces from the rebuilt env (same wrappers, so expected (4, 84, 84) and 4 actions).
obs_shape, n_actions = env.observation_space.shape, env.action_space.n

In [19]:
# --- Training schedule ---
num_episodes = 10_000  # episodes to train (adjust as needed; Atari usually needs many more)
batch_size = 64

# --- Optimisation ---
learning_rate = 1e-4  # Adam step size
gamma = 0.99  # discount applied to future rewards

# --- Replay memory ---
buffer_size = 100_000  # maximum number of stored transitions
min_buffer_size = 10_000  # warm-up: no gradient updates until this many transitions exist

# --- Exploration (linear epsilon schedule) ---
epsilon_start = 1.0
epsilon_end = 0.1
epsilon_decay = 1e6  # env steps over which epsilon moves from start to end

# --- Target network ---
target_update_freq = 1000  # copy policy -> target every this many env steps

### Start a new W&B run

In [20]:
import wandb

In [21]:
# Everything needed to reproduce this run, stored with it in W&B.
run_config = {
    "env": "ALE/Breakout-v5",
    "episodes": num_episodes,
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "buffer_size": buffer_size,
    "min_buffer_size": min_buffer_size,
    "gamma": gamma,
    "epsilon_start": epsilon_start,
    "epsilon_end": epsilon_end,
    "epsilon_decay_steps": epsilon_decay,
    "target_update_freq": target_update_freq,
}
wandb.init(project="dqn-breakout-tutorial", config=run_config)

[34m[1mwandb[0m: Currently logged in as: [33mfitti[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


### Run the updated training loop

In [22]:
# Fresh policy/target networks for this run.
policy_net, target_net = (DQN(obs_shape, n_actions).to(device) for _ in range(2))

# Start the target network as an exact copy of the policy network. eval() only
# switches layer modes (a no-op for this conv/linear-only network); gradients for
# target_net are avoided by computing targets under torch.no_grad().
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# Empty replay memory for this run.
replay_buffer = ReplayBuffer(capacity=buffer_size, state_shape=obs_shape)

In [None]:
import os
from pathlib import Path

# Only the policy network is optimized; target_net is refreshed by copying weights.
optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)
loss_fn = nn.SmoothL1Loss()  # Huber loss

# Linear epsilon schedule: epsilon drops by a fixed amount every environment
# step until it reaches epsilon_end after `epsilon_decay` steps.
epsilon = epsilon_start
epsilon_decay_rate = (epsilon_start - epsilon_end) / epsilon_decay

global_step = 0
episode_rewards = []
running_loss = 0.0
loss_count = 0

# env.action_space.sample() draws from the action space's own RNG, which the
# random/np/torch seeds set at the top of the notebook do not cover.
env.action_space.seed(seed)

for episode in tqdm(range(num_episodes), desc="Episodes"):
    # Seed only the first reset: re-seeding every reset would replay the exact
    # same random no-op start (noop_max) in every episode.
    state, _ = env.reset(seed=seed if episode == 0 else None)
    state = np.asarray(state)
    total_reward = 0.0
    done = False

    while not done:
        # Epsilon-greedy action selection
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            state_t = torch.as_tensor(state, device=device).unsqueeze(0)
            with torch.no_grad():
                action = int(policy_net(state_t).argmax(dim=1).item())

        # Linear decay, clamped so epsilon never undershoots epsilon_end
        epsilon = max(epsilon_end, epsilon - epsilon_decay_rate)

        # Environment step. The episode ends on termination OR truncation, but
        # only a true terminal state may cut off bootstrapping in the target.
        next_state, reward, terminated, truncated, _ = env.step(action)
        next_state = np.asarray(next_state)
        done = terminated or truncated

        replay_buffer.add(state, action, reward, next_state, terminated)

        state = next_state
        total_reward += reward
        global_step += 1

        # Learn from replay once the buffer holds enough transitions
        if len(replay_buffer) >= min_buffer_size:
            states_b, actions_b, rewards_b, next_states_b, dones_b = replay_buffer.sample(batch_size)

            # Q(s, a) for the actions actually taken: gather picks column
            # actions_b[i] from row i of the (batch_size, n_actions) output.
            q_values = policy_net(states_b)
            state_action_values = q_values.gather(1, actions_b.view(-1, 1)).squeeze(1)

            # TD target r + gamma * max_a' Q_target(s', a'); the bootstrap term
            # is zeroed for terminal transitions.
            with torch.no_grad():
                max_next_q_values = target_net(next_states_b).max(dim=1).values
                targets = rewards_b + gamma * max_next_q_values * (1.0 - dones_b)

            loss = loss_fn(state_action_values, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Periodically sync the target network with the policy network
            if global_step % target_update_freq == 0:
                target_net.load_state_dict(policy_net.state_dict())

            running_loss += loss.item()
            loss_count += 1

    # ---- Episode finished: everything below runs once per episode ----
    episode_rewards.append(total_reward)
    metrics = {
        "episode": episode,
        "episode_reward": total_reward,
        "epsilon": epsilon,
        "avg_reward_100": float(np.mean(episode_rewards[-100:])),
        "avg_loss": running_loss / loss_count if loss_count > 0 else None,
        "global_step": global_step,
    }

    # Reset the per-episode loss accumulators
    running_loss = 0.0
    loss_count = 0

    # Periodically show results and attach the most recent video
    if (episode + 1) % 10 == 0:
        avg_reward = np.mean(episode_rewards[-10:])
        print(
            f"Episode {episode+1}: Reward: {total_reward}, Avg (last 10): {avg_reward:.2f}",
            f"Epsilon: {epsilon:.3f}",
            end="\t\t\r",
        )

        # Newest .mp4 written so far, if any. It is added to the same log call so
        # that W&B records exactly one step per episode.
        videos = list(Path(video_folder).glob("*.mp4"))
        if videos:
            latest_video = max(videos, key=os.path.getmtime)
            metrics["video"] = wandb.Video(str(latest_video))

    wandb.log(metrics)

Episodes:   0%|          | 0/10000 [00:00<?, ?it/s]

Episode 180: Reward: 1.0, Avg (last 10): 1.00 Epsilon: 0.971		

## Evaluation: Testing the Learned Agent and Observing Behavior

In [None]:
policy_net.eval()

num_test_episodes = 5
test_rewards = []

for i in range(num_test_episodes):
    state, _ = env.reset()
    state = np.array(state, copy=False)
    done = False
    episode_reward = 0
    while not done:
        state_t = torch.tensor(state, device=device).unsqueeze(0)
        q_values = policy_net(state_t)
        action = int(torch.argmax(q_values, dim=1).item())
        next_state