# Implementation Details of a Deep Q-Learning Agent

We use the "CarRacing-v3" environment from Gymnasium (the maintained successor of OpenAI Gym) to implement and test our Deep Q-Learning algorithm.

First of all, import a bunch of libraries:

In [None]:
import gymnasium as gym # OpenAI Gymnasium is a game environment library.
import math
import random
import argparse
import os
import matplotlib
import matplotlib.pyplot as plt
# provide named tuples for replay buffer in DQN
from collections import namedtuple, deque   
"""
for t in count():
    ... # t becomes 0, 1, 2, ... until break
"""
from itertools import count
from tqdm import tqdm

import numpy as np
# PyTorch library
import torch
# for building neural networks
import torch.nn as nn
# for optimization algorithms
import torch.optim as optim
# expose torch functions
import torch.nn.functional as F

# Choose the compute device once; every model and tensor below is moved to it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# cuDNN (NVIDIA's GPU library for deep-network primitives) can benchmark several
# convolution algorithms and cache the fastest one per input shape. This pays
# off when tensor sizes stay fixed from call to call, as with fixed-size frames.
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True

Using device: cuda


Then we make the environment.

Its documentation is https://gymnasium.farama.org/environments/box2d/car_racing/.

Here, we set `continuous=False` to make the action space discrete. For now we only consider a small set of discrete actions: "do nothing", "steer left", "steer right", "accelerate" (gas), and "brake".

In [3]:
def make_env(**kwargs):
    """Create the CarRacing-v3 environment with a discrete action space.

    ``continuous=False`` switches CarRacing to its discrete-action mode, which
    is what the Q-network (one output per action) requires.

    Args:
        **kwargs: Extra keyword arguments forwarded to ``gym.make``. For
            example, ``render_mode="human"`` or ``render_mode="rgb_array"``
            lets you watch or record the agent. ``continuous`` must not be
            passed; it is fixed to False here.

    Returns:
        The environment created by ``gym.make``.
    """
    return gym.make("CarRacing-v3", continuous=False, **kwargs)

## Frame Preprocessing

We preprocess the raw frames from the environment to make them more suitable for our neural network. The preprocessing steps include:

1. **Grayscale Conversion**: Convert the RGB image to grayscale to reduce the number of input channels and focus on the essential features of the environment.
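2. **Cropping**: Crop the 96x96 frame to 84x84. This keeps the top 84 rows, dropping the bottom 12 rows that contain the status bar, and the middle 84 columns, dropping 6 columns on each side. It reduces the computational load while keeping the track area; no interpolation or resizing is performed.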
3. **Normalization**: Normalize the pixel values to the range [0, 1] to improve the training stability of the neural network.

In [4]:
def preprocess_frame(frame: np.ndarray) -> np.ndarray:
    """Convert an RGB 96x96x3 frame to a normalized 84x84 grayscale image.

    Steps (no interpolation/resizing is performed):
      1. Crop to 84x84: keep the top 84 rows, dropping the bottom 12 rows that
         hold the status bar, and columns 6..89, dropping 6 columns per side.
      2. Grayscale via luminance weights 0.2989 R + 0.5870 G + 0.1140 B.
      3. Scale pixel values from [0, 255] to [0, 1].

    Args:
        frame: Array of shape (H, W, C) with H >= 84, W >= 90 and C >= 3.
            Only the first three channels are used.

    Returns:
        float32 array of shape (84, 84). Values lie in [0, 1] for uint8 input.

    Raises:
        ValueError: If ``frame`` is not 3-D or is too small to crop to 84x84.
    """
    if frame.ndim != 3 or frame.shape[0] < 84 or frame.shape[1] < 90 or frame.shape[2] < 3:
        raise ValueError(
            f"expected an (H>=84, W>=90, C>=3) RGB frame, got shape {frame.shape}")
    # Crop first so the grayscale conversion only touches pixels we keep.
    rgb = frame[:84, 6:90, :3]
    gray = np.dot(rgb, [0.2989, 0.5870, 0.1140])  # float64, shape (84, 84)
    return gray.astype(np.float32) / 255.0

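We also create a `FrameStack` helper class that stacks the last 4 preprocessed frames together. This lets the agent capture temporal information (such as the car's speed and direction of motion) that a single frame cannot show, so it can better understand the dynamics of the environment.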

In [5]:
class FrameStack:
    """Rolling window holding the most recent ``n_frames`` preprocessed frames.

    Stacking consecutive frames gives the network short-term motion
    information that a single still frame cannot provide.
    """

    def __init__(self, n_frames: int = 4):
        self.n_frames = n_frames
        # A bounded deque drops its oldest entry automatically on append.
        self.frames = deque(maxlen=n_frames)

    def reset(self, frame: np.ndarray):
        """Begin a new episode by filling the window with the first frame."""
        first = preprocess_frame(frame)
        self.frames.extend([first] * self.n_frames)
        return self._get_state()

    def step(self, frame: np.ndarray):
        """Push the newest frame (evicting the oldest) and return the stack."""
        latest = preprocess_frame(frame)
        self.frames.append(latest)
        return self._get_state()

    def _get_state(self) -> np.ndarray:
        """Return the window as a float32 array of shape (n_frames, 84, 84)."""
        return np.array(list(self.frames), dtype=np.float32)

## Replay Buffer

We implement a replay buffer to store the agent's experiences during training. The replay buffer allows us to sample random mini-batches of experiences for training the neural network, which helps to break the correlation between consecutive samples and improve the stability of the learning process.

In [6]:
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class ReplayMemory:
    """Fixed-capacity FIFO experience replay buffer.

    Once ``capacity`` transitions are stored, each new push evicts the oldest.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Store one transition; positional args follow ``Transition``'s field order."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Draw ``batch_size`` stored transitions uniformly without replacement."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

## DQN Architecture

For this implementation, we use a combination of CNN and MLP. CNN layers are used to extract features from frames, while MLP layers are used to output Q-values for each action.

In [None]:
class CNNDQN(nn.Module):
    """Classic Atari-style CNN DQN: conv layers + fully connected layers.

    Maps a batch of stacked grayscale frames to one Q-value per action.

    Args:
        n_frames: Number of stacked input frames (input channels).
        n_actions: Number of discrete actions (output size).
        frame_size: Height/width of each square input frame. Defaults to 84,
            matching the preprocessing; other sizes work as long as they
            survive the three convolutions (>= 36).
    """

    def __init__(self, n_frames: int, n_actions: int, frame_size: int = 84):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(n_frames, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        # Infer the flattened conv output size with a dummy forward pass
        # instead of hard-coding it. For the default 4x84x84 input:
        #   conv1: 32 x 20 x 20 -> conv2: 64 x 9 x 9 -> conv3: 64 x 7 x 7 = 3136
        # The dummy pass draws no random numbers, so weight initialisation is
        # unchanged.
        with torch.no_grad():
            n_flat = self.conv(torch.zeros(1, n_frames, frame_size, frame_size)).numel()
        self.fc = nn.Sequential(
            nn.Linear(n_flat, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def forward(self, x):
        """Map (batch, n_frames, frame_size, frame_size) to (batch, n_actions)."""
        x = self.conv(x)
        # (B, C, H, W) -> (B, C*H*W) in row-major order; flatten also copes
        # with non-contiguous tensors, unlike view.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

Here, the input is:

$$
x \in \mathbb{R}^{B \times 4 \times 84 \times 84}
$$

where:
- $B$ is the batch size
- $4$ is the number of stacked frames from the FrameStack wrapper
- $84 \times 84$ is the resolution of each preprocessed frame

### `nn.Sequential`

Then, `nn.Sequential` is used to build the layers of the DQN. Typically it just applies the layers in sequence. For example:

```python
self.conv = nn.Sequential(
    Layer1, 
    Layer2,
    ...
)
```

means:

$$
\begin{aligned}
x_1 &= \text{Layer1}(x) \\
x_2 &= \text{Layer2}(x_1) \\
&\;\;\vdots
\end{aligned}
$$

### `nn.Conv2d`

`nn.Conv2d(in_channels, out_channels, kernel_size, stride)` performs a 2D convolution.

Mathematically:

$\text{Output} = \text{Conv2D}(\text{Input}, \text{Kernel}) + \text{Bias}$

Where:
- `in_channels` is the number of input feature maps (e.g., 4 for the 4 stacked frames)
- `out_channels` is the number of output feature maps (i.e., the number of filters)
- `kernel_size` is the size of the convolutional kernel (e.g., 8 means an 8x8 kernel)
- `stride` is the step size for moving the kernel

### How output size is computed

For a 2D convolution with no padding and dilation 1 (the PyTorch defaults, which are used by every layer here):

If input size is $(H_{in}, W_{in})$.

Then output height is:

$$
H_{out} = \left\lfloor \frac{H_{in} - \text{kernel\_size}}{\text{stride}} \right\rfloor + 1
$$

And output width is:

$$
W_{out} = \left\lfloor \frac{W_{in} - \text{kernel\_size}}{\text{stride}} \right\rfloor + 1
$$

### The Actual Computation

Input $B \times 4 \times 84 \times 84$.

In Conv1, `nn.Conv2d(4, 32, kernel_size=8, stride=4)`:

$$
\begin{aligned}
H_{out} &= \left\lfloor \frac{84 - 8}{4} \right\rfloor + 1 = 20 \\
W_{out} &= \left\lfloor \frac{84 - 8}{4} \right\rfloor + 1 = 20
\end{aligned}
$$

So the output of Conv1 is $B \times 32 \times 20 \times 20$, where 32 is the number of filters.

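The remaining layers use the same formula. Conv2 (`kernel_size=4, stride=2`) gives $\lfloor (20 - 4)/2 \rfloor + 1 = 9$, i.e. $B \times 64 \times 9 \times 9$. Conv3 (`kernel_size=3, stride=1`) gives $\lfloor (9 - 3)/1 \rfloor + 1 = 7$, i.e. $B \times 64 \times 7 \times 7$.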


### What does `x.view(x.size(0), -1)` do?

Now in the end we have `x` with shape $B \times 64 \times 7 \times 7$ after the convolutional layers.

Then:
- `x.size(0)` means the batch size $B$.
- `-1` means we want to flatten the remaining dimensions (64, 7, 7) into a single dimension.

So `x.view(x.size(0), -1)` reshapes `x` to have shape $B \times (64 \cdot 7 \cdot 7)$, which is $B \times 3136$.

But how is this reshaped? It is reshaped in a row-major order (also known as C-style order). This means that the last dimension changes the fastest, and the first dimension changes the slowest. Nothing fancy.

### `in_channels` and `out_channels` in `nn.Conv2d`

`in_channels` is the number of input feature maps to the convolutional layer. In this concrete example, the first convolutional layer takes the 4 stacked frames as input, so `in_channels` is 4. Why do we have 4 stacked frames? Because we want the agent to capture temporal information (the movement of the car and the track), but a single frame doesn't contain that information. By stacking 4 frames together, the agent can see how the environment changes over time.

`out_channels` is the number of output feature maps produced by the convolutional layer. Each filter in the convolutional layer learns to detect a specific feature in the input frames (e.g., edges, corners, etc.). By setting `out_channels` to 32, we are allowing the convolutional layer to learn 32 different features from the input frames. The more filters we have, the more complex features the layer can learn, but it also increases the computational cost.

## Hyperparameters

Here we predefine some hyperparameters.

- `batch_size`: The number of samples used in one training iteration. A larger batch size can provide a more stable gradient estimate but requires more memory.
- `gamma`: The discount factor for future rewards. It determines how much the agent values future rewards compared to immediate rewards. A value close to 1 means the agent will consider future rewards more strongly, while a value close to 0 means the agent will focus more on immediate rewards.
- `epsilon_start`, `epsilon_end`, `epsilon_decay`: These parameters control the epsilon-greedy exploration strategy with decaying epsilon.
- `tau`: The soft update parameter for updating the target network. A smaller `tau` means the target network updates more slowly, which can help stabilize training.
- `learning_rate`: The learning rate for the optimizer, which controls how much the model's weights are updated during training. A smaller learning rate can lead to more stable training but may require more iterations to converge, while a larger learning rate can speed up training but may cause instability.
- `n_frames`: The number of frames to stack together in the FrameStack wrapper. This allows the agent to capture temporal information from the environment, which is crucial for understanding the dynamics of the game. In this case, we set it to 4, meaning we will stack the last 4 frames together as input to the neural network.
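- `memory_size`: The maximum number of experiences that can be stored in the replay buffer; once it is full, the oldest experiences are discarded.
- `OPTIMIZE_EVERY`: How many environment steps pass between optimization steps (each followed by a soft target-network update).
- `PLOT_EVERY`: How many episodes pass between refreshes of the live training plot.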

In [9]:
# ---- DQN optimisation ----
BATCH_SIZE = 128        # transitions per gradient step
GAMMA = 0.99            # discount factor for future rewards
LR = 1e-4               # optimizer learning rate
TAU = 0.005             # soft target update rate: target <- TAU*policy + (1-TAU)*target

# ---- epsilon-greedy exploration ----
EPS_START = 1.0         # initial exploration probability
EPS_END = 0.05          # exploration probability approached asymptotically
EPS_DECAY = 50000       # exponential-decay time constant, in env steps (not episodes)

# ---- inputs and replay ----
N_FRAMES = 4            # frames stacked into one network input
MEMORY_SIZE = 50000     # replay-buffer capacity; oldest transitions are evicted first

# ---- schedule ----
OPTIMIZE_EVERY = 4      # optimization (+ soft target update) every N env steps
PLOT_EVERY = 10         # refresh the live training plot every N episodes

# ---- output locations (relative to the current working directory) ----
RESULTS_DIR = os.path.join("./results", "car_racing")
os.makedirs(RESULTS_DIR, exist_ok=True)
CHECKPOINT_PATH = os.path.join(RESULTS_DIR, "cnn_dqn.pt")

## Setup for Training

In [None]:
# Environment, plus the helper that turns raw frames into stacked network inputs.
env = make_env()
n_actions = env.action_space.n  # 5 discrete actions
frame_stack = FrameStack(N_FRAMES)

# Two networks with identical architecture:
#   policy_net - the network being trained; it chooses actions;
#   target_net - a slowly-tracking copy used to compute bootstrapped TD targets.
# (Constructed in this order so random initialisation stays reproducible.)
policy_net = CNNDQN(N_FRAMES, n_actions).to(device)
target_net = CNNDQN(N_FRAMES, n_actions).to(device)

# AdamW with the AMSGrad variant, which can help convergence and stability;
# the details are out of scope here.
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(MEMORY_SIZE)

# Training-progress bookkeeping.
steps_done = 0             # total env steps taken so far; drives epsilon decay
episode_durations = []     # number of steps in each finished episode
episode_rewards = []       # undiscounted total reward of each finished episode

  from pkg_resources import resource_stream, resource_exists


In [None]:
# Load the target network with the same weights as the policy network, so both
# start from identical (randomly initialised) parameters. After this, target_net
# only changes through the soft (TAU-weighted) updates in the training loop.
target_net.load_state_dict(policy_net.state_dict())

<All keys matched successfully>

In [12]:
# Select an action using the epsilon-greedy strategy.

def select_action(state):
    """Pick an action for ``state`` using decaying epsilon-greedy exploration.

    The exploration probability decays exponentially with the global step
    counter ``steps_done``:

        eps = EPS_END + (EPS_START - EPS_END) * exp(-steps_done / EPS_DECAY)

    ``steps_done`` is incremented on every call. With probability 1 - eps the
    greedy action argmax_a Q_policy(state, a) is returned. Otherwise, a random
    action is sampled from ``env.action_space``.

    Args:
        state: Observation tensor with a leading batch dimension of 1, as built
            by the training loop (shape (1, N_FRAMES, H, W), on ``device``).

    Returns:
        A ``torch.long`` tensor of shape (1, 1) on ``device`` holding the action.
    """
    global steps_done
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1.0 * steps_done / EPS_DECAY)
    steps_done += 1
    if random.random() > eps_threshold:
        with torch.no_grad():
            # policy_net(state) is (1, n_actions). argmax over dim=1 (the action
            # dimension) gives a (1,) tensor holding the best action index.
            # view(1, 1) reshapes it to (1, 1), matching the random branch
            # below. Replay memory therefore always stores (1, 1) actions,
            # which optimize_model() concatenates into the (B, 1) index tensor
            # that `gather` needs. env.step() itself receives a plain int via
            # `action.item()`, so the reshape is not for env.step().
            return policy_net(state).argmax(dim=1).view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)

In [None]:
# Then we train the model. We introduce an optimize_model() function that
# performs a single optimization step on the policy network.

def optimize_model():
    """Perform one DQN gradient step on a random mini-batch from replay memory.

    Returns early (doing nothing) until the buffer holds at least BATCH_SIZE
    transitions. The TD target for each sampled transition is

        y = r + GAMMA * max_a' Q_target(s', a')

    where the bootstrap term is 0 for terminal transitions, which are stored
    with ``next_state=None``. The policy network is regressed onto y with the
    Smooth L1 (Huber) loss. Gradients are value-clipped to [-100, 100] before
    one optimizer step.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch: a list of BATCH_SIZE Transition tuples becomes a
    # single Transition whose fields are tuples (all states, all actions, ...).
    batch = Transition(*zip(*transitions))

    # True where the transition has a real next state to bootstrap from.
    non_final_mask = torch.tensor(
        [s is not None for s in batch.next_state],
        device=device, dtype=torch.bool)
    non_final_next_states = [s for s in batch.next_state if s is not None]

    # Shapes follow from how the training loop pushes transitions: each state
    # is (1, N_FRAMES, H, W), each action (1, 1), and each reward (1,).
    state_batch = torch.cat(batch.state)    # (B, N_FRAMES, H, W)
    action_batch = torch.cat(batch.action)  # (B, 1)
    reward_batch = torch.cat(batch.reward)  # (B,)

    # policy_net(state_batch) is (B, n_actions): Q-values for *every* action.
    # gather(1, action_batch) picks, per row, the Q-value of the action that
    # was actually taken. For example, with
    #     Q = [[0.1, 0.5, 0.2],      action_batch = [[1],
    #          [0.3, 0.4, 0.6]]                      [2]]
    # gather(1, action_batch) returns [[0.5], [0.6]].
    state_action_values = policy_net(state_batch).gather(1, action_batch)  # (B, 1)

    # Terminal transitions keep a bootstrap value of 0. torch.cat([]) raises,
    # so the target net is only queried when at least one sampled transition
    # is non-terminal.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    if non_final_next_states:
        with torch.no_grad():
            # .max(1).values: the best target-net Q-value for each next state.
            next_state_values[non_final_mask] = target_net(
                torch.cat(non_final_next_states)).max(1).values

    expected_state_action_values = (next_state_values * GAMMA) + reward_batch  # (B,)

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # Clip each gradient element to [-100, 100] (value clipping, not norm
    # clipping) to guard against exploding updates.
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

## What does `batch = Transition(*zip(*transitions))` do?

batch = Transition(*zip(*transitions)) is a transpose/re-batching step.
- transitions is a list of sampled items from replay memory, each like:
  Transition(state, action, next_state, reward)
- So shape-wise it is like:  
  [ (s1,a1,ns1,r1), (s2,a2,ns2,r2), ..., (sB,aB,nsB,rB) ]
zip(*transitions) flips that into grouped fields:
- (s1, s2, ..., sB)
- (a1, a2, ..., aB)
- (ns1, ns2, ..., nsB)
- (r1, r2, ..., rB)
Then Transition(*...) wraps those groups back into a named tuple, so you can access:
- batch.state
- batch.action
- batch.next_state
- batch.reward
Why this is needed here:
- The later code is field-wise and vectorized: torch.cat(batch.state), torch.cat(batch.action), mask over batch.next_state, etc.
- Without this transform, you’d have to loop over each sampled transition and manually collect each field. This line does that cleanly in one step

## What does `non_final_mask` do?

`non_final_mask` exists for one reason: **some transitions end an episode**, so they have **no valid next state** to bootstrap from.

In many replay-buffer implementations (including the PyTorch DQN tutorial this resembles), they store terminal transitions like this:

* if the episode terminated after taking action $a_t$ (a true terminal state, not merely a time-limit truncation), then `next_state = None`

That `None` is the “done” signal.

---

### 1) What does DQN want to compute here?

For each sampled transition $(s_t, a_t, r_{t+1}, s_{t+1})$, DQN builds a TD target:

If the next state is non terminal:

$$
y_t = r_{t+1} + \gamma \max_{a'} Q_{\text{target}}(s_{t+1}, a')
$$

If the next state is terminal (episode ended), there is no future return to add, so:

$$
y_t = r_{t+1}
$$

Meaning of each term:

* $y_t$: training target for the Q-value of the chosen action at $s_t$
* $r_{t+1}$: immediate reward after taking $a_t$ at $s_t$
* $\gamma$: discount factor (how much you care about the future)
* $Q_{\text{target}}(\cdot)$: Q-network used for the target (here `target_net`)
* $\max_{a'}$: greedy value over all possible next actions

So terminal transitions must **not** include the $\gamma \max Q(\cdot)$ part.

That is exactly what the mask implements.

---

### 2) What `non_final_mask` is doing

```python
non_final_mask = torch.tensor(
    tuple(map(lambda s: s is not None, batch.next_state)),
    device=device, dtype=torch.bool)
```

This creates a boolean tensor of length `BATCH_SIZE`.

* `True` means: this transition has a real next state, so it is **non terminal**
* `False` means: `next_state is None`, so it is **terminal**

Example: if the batch has 5 transitions and two are terminal, you might get:

* `non_final_mask = [True, False, True, True, False]`

---

### 3) Why `non_final_next_states` exists

```python
non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
```

You cannot do `torch.cat` on a list that contains `None`.

So they extract only the actual tensors and concatenate them into one tensor, so they can run them through `target_net` in one shot.

---

### 4) How the mask is used to fill `next_state_values`

```python
next_state_values = torch.zeros(BATCH_SIZE, device=device)
with torch.no_grad():
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
```

Step by step:

1. Start with all zeros:

* `next_state_values[i] = 0` for every transition $i$

This already matches the terminal rule:

* if transition $i$ is terminal, we want the bootstrapped term to be 0

2. For non terminal transitions only, compute:

$$
\max_{a'} Q_{\text{target}}(s_{t+1}, a')
$$

and write those values back into the right positions using the mask:

* `next_state_values[non_final_mask] = ...`

So after this:

* terminal transitions keep value 0
* non terminal transitions get their max-Q bootstrap value

This is a clean way to implement the piecewise definition of $y_t$.

---

### 5) Where this shows up in the target in your code

```python
expected_state_action_values = (next_state_values * GAMMA) + reward_batch
```

This is exactly:

$$
y_t = r_{t+1} + \gamma \cdot \text{next\_state\_value}
$$

and because terminal transitions have `next_state_value = 0`, you automatically get:

$$
y_t = r_{t+1}
$$

for terminal transitions.

---

### 6) Why the mask does not appear in the original paper

The original DQN math usually writes the target with an implicit “if terminal then no bootstrap” condition, often described in words rather than showing an explicit boolean mask in pseudocode.

In code, you must handle it concretely, and a mask is the standard vectorized way.

---



In [None]:
# Finally, the whole training loop.
#
# NOTE(review): `args`, `plot_training_live`, `evaluate_random` and
# `plot_final_chart` are not defined in the cells shown here. `args` is
# presumably an argparse.Namespace with `load` and `episodes` attributes
# created in another cell. Confirm before running.
# NOTE(review): when `skip_training` is True, nothing below reloads
# CHECKPOINT_PATH, so `episode_rewards` stays empty unless another cell
# restores it.

skip_training = os.path.exists(CHECKPOINT_PATH) and not args.load

if args.episodes:
    num_episodes = args.episodes
elif torch.cuda.is_available():
    print("Using GPU for training.")
    num_episodes = 500
else:
    print("Using CPU for training.")
    num_episodes = 30

if not skip_training:
    for i_episode in tqdm(range(num_episodes), desc="Training"):
        obs, info = env.reset()
        state = frame_stack.reset(obs)
        # (N_FRAMES, H, W) numpy array -> (1, N_FRAMES, H, W) tensor on `device`.
        state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

        ep_reward = 0.0
        for t in count():
            action = select_action(state)
            next_obs, reward, terminated, truncated, info = env.step(action.item())
            ep_reward += reward
            reward_tensor = torch.tensor([reward], device=device, dtype=torch.float32)

            # Only a genuine terminal state (`terminated`) stops bootstrapping,
            # so only it is stored with next_state=None. A time-limit cut-off
            # (`truncated`) is not part of the MDP. That transition keeps its
            # real next state, so its TD target still includes
            # GAMMA * max_a' Q_target(s', a').
            if terminated:
                next_state = None
            else:
                next_state = frame_stack.step(next_obs)
                next_state = torch.tensor(next_state, dtype=torch.float32, device=device).unsqueeze(0)
            done = terminated or truncated

            memory.push(state, action, next_state, reward_tensor)
            state = next_state

            # Every OPTIMIZE_EVERY steps, take one gradient step on the policy
            # net, then softly update the target net:
            #   target <- TAU * policy + (1 - TAU) * target
            if t % OPTIMIZE_EVERY == 0:
                optimize_model()

                target_net_state_dict = target_net.state_dict()
                policy_net_state_dict = policy_net.state_dict()
                for key in policy_net_state_dict:
                    target_net_state_dict[key] = policy_net_state_dict[key] * TAU + \
                                                 target_net_state_dict[key] * (1 - TAU)
                target_net.load_state_dict(target_net_state_dict)

            if done:
                episode_durations.append(t + 1)
                episode_rewards.append(ep_reward)
                if i_episode % PLOT_EVERY == 0:
                    plot_training_live()
                break

    print('Training complete')

    # Save checkpoint with training metrics
    torch.save({
        "policy_net": policy_net.state_dict(),
        "target_net": target_net.state_dict(),
        "optimizer": optimizer.state_dict(),
        "steps_done": steps_done,
        "episode_durations": episode_durations,
        "episode_rewards": episode_rewards,
    }, CHECKPOINT_PATH)
    print(f"Checkpoint saved to {CHECKPOINT_PATH}")

# Save final chart (always regenerated)
if episode_rewards:
    print("Evaluating random baseline (50 episodes)...")
    rand_dur, rand_rew = evaluate_random(num_episodes=50)
    print(f"  Random baseline => duration: {rand_dur:.0f}, reward: {rand_rew:.1f}")
    chart_path = os.path.join(RESULTS_DIR, "cnn_dqn.png")
    plot_final_chart(episode_rewards, rand_rew, chart_path)
else:
    print("No training metrics available — cannot generate charts.")

env.close()
print(f"Done! All results saved to {RESULTS_DIR}")
plt.ioff()
plt.show()