# Convolutional DQN (C-DQN) for CartPole

This notebook demonstrates a simple Convolutional DQN that uses a stack of 4 grayscale frames (84×84) as input. It's intended as an educational example; not all hyperparameters are tuned for performance.

Instructions:
- Run the cells in order.
- Ensure dependencies are installed (see the cell below).
- The notebook includes small checks to validate tensor shapes and a minimal evaluation routine.

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import gymnasium as gym
import cv2
import random
import copy
from collections import deque

In [3]:
# --- Frame Preprocessing Functions ---

def resizer(frame):
    """
    Turn a colour frame into a normalized 84x84 grayscale image.

    Channels 0, 1 and 2 of the last axis are treated as R, G and B and are
    combined with the usual luminance weights (0.299, 0.587, 0.114).

    Returns:
        An np.float32 array produced by resizing to 84x84 and dividing by 255
        (so 0-255 pixel values end up in the 0-1 range).
    """
    red = frame[:, :, 0]
    green = frame[:, :, 1]
    blue = frame[:, :, 2]
    luminance = 0.299 * red + 0.587 * green + 0.114 * blue

    # INTER_AREA is the interpolation mode preferred when shrinking an image.
    shrunk = cv2.resize(luminance, (84, 84), interpolation=cv2.INTER_AREA)
    return shrunk.astype(np.float32) / 255.0

def frame_tensor(resized_frame):
    """
    Wrap one preprocessed frame in a float32 tensor with two leading
    singleton dimensions, e.g. an 84x84 array becomes (1, 1, 84, 84).

    Intended for one-off single-frame processing; the training loop uses
    stack_to_tensor instead.
    """
    contiguous = np.ascontiguousarray(resized_frame, dtype=np.float32)
    as_tensor = torch.from_numpy(contiguous)
    # Indexing with None twice prepends the batch and channel axes.
    return as_tensor[None, None]

def stack_to_tensor(frame_stack):
    """
    Build the network input from a sequence (deque) of equally shaped frames.

    The frames are stacked along a new first axis, cast to float32, and given
    a leading batch axis, so four 84x84 frames yield a (1, 4, 84, 84) tensor.
    """
    stacked = np.stack(tuple(frame_stack)).astype(np.float32)  # (num_frames, H, W)
    with_batch_axis = stacked[np.newaxis]  # (1, num_frames, H, W)
    return torch.from_numpy(with_batch_axis)


In [4]:
# --- DQN Network ---

class ConvNet(nn.Module):
    """
    Small convolutional Q-network: two conv layers followed by two fully
    connected layers, producing one Q-value per action.

    Expected input is a batch of 4 stacked 84x84 frames, shape (N, 4, 84, 84).
    """

    def __init__(self, action_space):
        super().__init__()

        # Spatial size per layer: 84 -> (84 - 8) / 4 + 1 = 20 -> (20 - 4) / 2 + 1 = 9
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=16, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2)

        # Flattened conv output is 32 * 9 * 9 = 2592 features, measured with a probe.
        self.fc1 = nn.Linear(self._get_conv_out(), 256)
        self.fc2 = nn.Linear(256, action_space)

    def _conv_features(self, x):
        # Shared convolutional trunk used by both the size probe and forward().
        return F.relu(self.conv2(F.relu(self.conv1(x))))

    def _get_conv_out(self):
        """Return the number of flattened features the conv trunk emits per sample."""
        with torch.no_grad():
            probe = self._conv_features(torch.zeros(1, 4, 84, 84))
        # The probe batch holds a single sample, so its element count is the feature count.
        return probe.numel()

    def forward(self, x):
        """Map a (N, 4, 84, 84) batch to Q-values of shape (N, action_space)."""
        features = self._conv_features(x)
        flat = features.reshape(features.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# --- Action Selection ---

def epsilon_greedy_action_selection(q_values, epsilon):
    """
    Pick an action index with an epsilon-greedy policy.

    q_values is a 2-D tensor whose second dimension indexes actions. The
    greedy branch calls .item() on the per-row argmax, so it only works when
    q_values holds a single row, i.e. shape (1, num_actions).

    With probability epsilon a uniformly random action index is returned;
    otherwise the index of the largest Q-value is returned.
    """
    explore = random.random() < epsilon
    if not explore:
        # Exploit: greedy action for the single state in the batch.
        return q_values.argmax(dim=1).item()

    # Explore: uniform choice over the action dimension.
    num_actions = q_values.shape[1]
    return random.randrange(num_actions)

# --- Replay Buffer ---

class ReplayBuffer:
    """
    Fixed-capacity FIFO store of (state, action, reward, next_state, done)
    transitions, backed by a deque that drops the oldest entry once full.

    States are kept as whatever array the caller pushes (this notebook pushes
    (4, 84, 84) NumPy stacks rather than (1, 4, 84, 84) tensors) and are only
    turned into tensors when a batch is sampled.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition; state and next_state are NumPy arrays."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """
        Draw batch_size distinct transitions uniformly at random.

        Returns a 5-tuple: states and next_states as float tensors with a
        leading batch dimension, and actions, rewards and dones as plain
        tuples in matching order.
        """
        def _as_float_batch(arrays):
            # Stack the per-transition arrays into one batch tensor.
            return torch.from_numpy(np.array(arrays)).float()

        picked = random.sample(self.buffer, batch_size)
        state_seq, action_seq, reward_seq, next_state_seq, done_seq = zip(*picked)

        return (
            _as_float_batch(state_seq),
            action_seq,
            reward_seq,
            _as_float_batch(next_state_seq),
            done_seq,
        )

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

# --- Training Function ---

def train_step(buffer, net, target_net, optimizer, batch_size, gamma, loss_fn=None):
    """
    Run one DQN optimisation step on a minibatch drawn from the replay buffer.

    Args:
        buffer: Replay buffer supporting len() and sample(batch_size), where
            sample returns (states, actions, rewards, next_states, dones) with
            states/next_states already batched as float tensors.
        net: Online Q-network being trained.
        target_net: Target Q-network used for the bootstrap term.
        optimizer: Optimizer over net's parameters.
        batch_size: Number of transitions per update.
        gamma: Discount factor.
        loss_fn: Optional callable (prediction, target) -> scalar loss. When
            omitted, a notebook-level `loss_fn` is used if one exists
            (the behaviour of earlier revisions); otherwise mean-squared error.

    Returns:
        The scalar loss as a Python float, or None when the buffer holds fewer
        than batch_size transitions (no update is performed in that case).
    """
    if len(buffer) < batch_size:
        return None

    # Resolve the criterion without requiring a global defined in a later cell.
    criterion = loss_fn
    if criterion is None:
        criterion = globals().get("loss_fn")
    if criterion is None:
        criterion = F.mse_loss

    states, actions, rewards, next_states, dones = buffer.sample(batch_size)

    # Convert sampled tuples to tensors, each of shape (batch_size,)
    actions = torch.tensor(actions, dtype=torch.long)
    rewards = torch.tensor(rewards, dtype=torch.float32)
    dones = torch.tensor(dones, dtype=torch.float32)

    # Q(s, a): net(states) is (batch_size, num_actions); gather picks the
    # Q-value of the action that was actually taken in each transition.
    q_values = net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # TD target: r + gamma * max_a' Q_target(s', a'), with no gradient flow.
    with torch.no_grad():
        next_q_values = target_net(next_states)  # (batch_size, num_actions)
        max_next_q_values = next_q_values.max(1)[0]  # (batch_size,)
        # (1 - dones) zeroes the bootstrap term for transitions flagged terminal.
        target_q_values = rewards + gamma * max_next_q_values * (1 - dones)

    loss = criterion(q_values, target_q_values)

    # Optimization step (gradient clipping is omitted here for simplicity)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


In [None]:
# --- Main Training Loop ---
#
# Trains the ConvNet Q-network on rendered CartPole frames. Each state is a
# stack of the 4 most recent preprocessed frames. Transitions go into a replay
# buffer, the online network is updated every environment step, and the target
# network is re-synchronised every `target_update_freq` steps.

# Environment setup
# render_mode="rgb_array" makes env.render() return the image frame
env = gym.make("CartPole-v1", render_mode="rgb_array")

# Hyperparameters and setup
# NOTE: every transition stores two float32 (4, 84, 84) stacks (~226 KB), so a
# completely full buffer of 20000 transitions needs roughly 4.5 GB of RAM.
BUFFER_CAPACITY = 20000
buffer = ReplayBuffer(capacity=BUFFER_CAPACITY)
NET_ACTION_SPACE = env.action_space.n # 2 for CartPole: Left or Right
net = ConvNet(action_space=NET_ACTION_SPACE)
gamma = 0.99
batch_size = 64
learning_rate = 1e-3

# Target Network setup (for stability)
target_net = copy.deepcopy(net)
target_net.eval()
num_episodes = 500
epsilon = 0.9
epsilon_decay = 0.995
min_epsilon = 0.05
target_update_freq = 250 # steps

# Optimization setup
optimizer = optim.Adam(net.parameters(), lr=learning_rate)
loss_fn = nn.MSELoss()

frame_stack = deque(maxlen=4)
step_count = 0  # total environment steps across all episodes

print("Starting DQN Training (CNN on Frames)")
for episode in range(num_episodes):
    # 1. Environment Reset and Initial State setup
    obs, info = env.reset()
    frame = env.render()
    processed = resizer(frame) # 84x84 NumPy array

    # Seed the stack with 4 copies of the first frame so the state has full depth
    frame_stack.clear()
    for _ in range(4):
        frame_stack.append(processed)

    state_tensor = stack_to_tensor(frame_stack) # (1, 4, 84, 84)
    state_for_buffer = np.stack(list(frame_stack), axis=0) # (4, 84, 84) NumPy array

    done = False
    total_reward = 0
    episode_loss = []
    q_values_at_end = None # Q-values from the last step, kept for logging

    while not done:
        step_count += 1

        # 2. Target Network Update
        if step_count % target_update_freq == 0:
            target_net.load_state_dict(net.state_dict())
            print(f"\n--- Episode {episode+1}: Updated target network. ---")

        # 3. Action Selection (Epsilon-Greedy)
        with torch.no_grad():
            q_values = net(state_tensor)
        action = epsilon_greedy_action_selection(q_values, epsilon)

        # 4. Environment Step
        obs, reward, terminated, truncated, info = env.step(action)
        # The episode loop stops on either signal ...
        done = terminated or truncated
        total_reward += reward

        # 5. Next State Processing
        frame_next = env.render()
        processed_next = resizer(frame_next) # 84x84 NumPy array
        frame_stack.append(processed_next) # maxlen=4 drops the oldest frame

        next_state_tensor = stack_to_tensor(frame_stack) # (1, 4, 84, 84)
        next_state_for_buffer = np.stack(list(frame_stack), axis=0) # (4, 84, 84) NumPy array

        # 6. Store Transition
        # ... but only `terminated` marks a true terminal state. A time-limit
        # truncation leaves a valid next state, so the TD target must keep
        # bootstrapping from it; storing `done` would wrongly zero that term.
        buffer.push(state_for_buffer, action, reward, next_state_for_buffer, terminated)

        # 7. Train Network (returns None until the buffer holds a full batch)
        loss = train_step(buffer, net, target_net, optimizer, batch_size, gamma)
        if loss is not None:
            episode_loss.append(loss)

        # 8. State Update
        state_tensor = next_state_tensor
        state_for_buffer = next_state_for_buffer # Must update for the next push
        q_values_at_end = q_values # Store for final logging

    # End of Episode Logging and Decay

    # Checkpoint every 100 episodes
    if (episode + 1) % 100 == 0:
        torch.save(net.state_dict(), f"dqn_checkpoint_{episode+1}.pt")
        print(f"Saved checkpoint at episode {episode+1}")

    # Epsilon decay (multiplicative, floored at min_epsilon)
    epsilon = max(min_epsilon, epsilon * epsilon_decay)

    avg_loss = sum(episode_loss) / len(episode_loss) if episode_loss else 0
    max_q_val = q_values_at_end.max().item() if q_values_at_end is not None else 0.0

    print(
        f"Episode {episode+1:4d} | Reward: {total_reward:6.2f} | "
        f"Steps: {step_count} | Avg Loss: {avg_loss:.4f} | "
        f"Epsilon: {epsilon:.3f} | Buffer: {len(buffer):5d} | "
        f"Max Q: {max_q_val:.3f}"
    )

env.close()

print("\nTraining complete.")

Starting DQN Training (CNN on Frames)
Episode    1 | Reward:  15.00 | Steps: 15 | Avg Loss: 0.0000 | Epsilon: 0.895 | Buffer:    15 | Max Q: 0.022
Episode    2 | Reward:  19.00 | Steps: 34 | Avg Loss: 0.0000 | Epsilon: 0.891 | Buffer:    34 | Max Q: 0.023
Episode    3 | Reward:  18.00 | Steps: 52 | Avg Loss: 0.0000 | Epsilon: 0.887 | Buffer:    52 | Max Q: 0.023
Episode    4 | Reward:  24.00 | Steps: 76 | Avg Loss: 0.4059 | Epsilon: 0.882 | Buffer:    76 | Max Q: 1.463
Episode    5 | Reward:  30.00 | Steps: 106 | Avg Loss: 0.0146 | Epsilon: 0.878 | Buffer:   106 | Max Q: 0.998
Episode    6 | Reward:  14.00 | Steps: 120 | Avg Loss: 0.0007 | Epsilon: 0.873 | Buffer:   120 | Max Q: 1.048
Episode    7 | Reward:  16.00 | Steps: 136 | Avg Loss: 0.0002 | Epsilon: 0.869 | Buffer:   136 | Max Q: 1.019
Episode    8 | Reward:  23.00 | Steps: 159 | Avg Loss: 0.0001 | Epsilon: 0.865 | Buffer:   159 | Max Q: 1.027
Episode    9 | Reward:  10.00 | Steps: 169 | Avg Loss: 0.0000 | Epsilon: 0.860 | Buffe

In [28]:
# --- Greedy Evaluation Rollout ---
# Plays one episode with the trained network acting greedily (no exploration)
# and collects every rendered frame so the episode can be saved as a GIF.
import imageio

eval_net = net  # alias, not a copy: eval() below also switches `net` to eval mode
eval_net.eval()  # disables dropout / batchnorm if any

# The training cell closes `env` when it finishes, so evaluation uses its own
# environment instead of resetting one that has already been closed.
eval_env = gym.make("CartPole-v1", render_mode="rgb_array")

frames = []
total_reward = 0
done = False

try:
    obs, info = eval_env.reset()
    frame = eval_env.render()
    frames.append(frame)

    # initialize frame stack with 4 copies of the first processed frame
    processed = resizer(frame)
    frame_stack = deque(maxlen=4)
    for _ in range(4):
        frame_stack.append(processed)

    state_tensor = stack_to_tensor(frame_stack)

    while not done:
        with torch.no_grad():
            q_values = eval_net(state_tensor)
        action = q_values.argmax(dim=1).item()  # greedy

        obs, reward, terminated, truncated, info = eval_env.step(action)
        done = terminated or truncated
        total_reward += reward

        frame_next = eval_env.render()
        frames.append(frame_next)

        processed_next = resizer(frame_next)
        frame_stack.append(processed_next)
        state_tensor = stack_to_tensor(frame_stack)
finally:
    # Release the environment even if the rollout raises part-way through.
    eval_env.close()


In [29]:
# Write the frames collected during evaluation to an animated GIF, then report
# the episode's total reward.
imageio.mimsave(
    'eval_cartpole.gif',
    frames,
    fps=30,
)
print("Saved evaluation GIF. Total Reward:", total_reward)


Saved evaluation GIF. Total Reward: 103.0
