In [5]:
# Assignment 3: DQN for Pong - Expert Notebook

# 1. Imports and Seed Setting
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import matplotlib.pyplot as plt
from assignment3_utils import img_crop, downsample, to_grayscale, normalize_grayscale, process_frame, transform_reward

np.random.seed(42)
torch.manual_seed(42)
random.seed(42)




In [6]:
import gymnasium as gym
import ale_py


In [8]:
# 2. Environment Setup
image_shape = (84, 80)

def make_env():
    try:
        env = gym.make('ALE/Pong-v5')
    except Exception:
        env = gym.make('Pong-v4')
    return env

env = make_env()

# 3. Frame Preprocessing Integration
def preprocess_frame(img: np.ndarray, image_shape: tuple) -> np.ndarray:
    img = img_crop(img)
    img = downsample(img)
    img = to_grayscale(img)
    img = normalize_grayscale(img)
    return np.expand_dims(img.reshape(image_shape[0], image_shape[1], 1), axis=0)



def get_initial_state(env, image_shape):
    obs, info = env.reset()
    processed = preprocess_frame(obs, image_shape).squeeze(0)
    return np.stack([processed] * 4, axis=0)  # shape (4, 84, 80)

print(f"Initial state shape: {get_initial_state(env, image_shape).shape}")

Initial state shape: (4, 84, 80, 1)


In [9]:
# 4. DQN Network
class DQNCNN(nn.Module):
    def __init__(self, input_channels, n_actions):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU()
        )
        conv_w = self._conv2d_size_out(self._conv2d_size_out(self._conv2d_size_out(image_shape[1],8,4),4,2),3,1)
        conv_h = self._conv2d_size_out(self._conv2d_size_out(self._conv2d_size_out(image_shape[0],8,4),4,2),3,1)
        linear_input_size = conv_w * conv_h * 64
        self.fc = nn.Sequential(
            nn.Linear(linear_input_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )
    def _conv2d_size_out(self, size, kernel_size, stride, padding=0):
        return (size - kernel_size + 2*padding) // stride + 1
    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


In [10]:
# Example test
input_channels = 4  # 4 stacked frames
n_actions = env.action_space.n  # number of actions in Pong
model = DQNCNN(input_channels, n_actions)

# Dummy input: batch of 1, shape (1, 4, 84, 80)
dummy_input = torch.randn(1, 4, 84, 80)
output = model(dummy_input)
print("Output shape:", output.shape)


Output shape: torch.Size([1, 6])


In [11]:
# 5. Replay Buffer
class ReplayBuffer:
    def __init__(self, capacity=100_000):
        self.buffer = deque(maxlen=capacity)
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (np.array(states), np.array(actions), np.array(rewards),
                np.array(next_states), np.array(dones).astype(np.float32))
    def __len__(self):
        return len(self.buffer)

In [12]:
buffer = ReplayBuffer(capacity=10)  # create a buffer
print(len(buffer))                   # check how many items are in it (should be 0)

# Add a dummy experience
state = np.zeros((4, 84, 80))
action = 0
reward = 1.0
next_state = np.ones((4, 84, 80))
done = False
buffer.push(state, action, reward, next_state, done)

print(len(buffer))  # now it should be 1


0
1


In [13]:
# 6. Epsilon-greedy Action Selection
def epsilon_greedy(state, model, epsilon, n_actions):
    if random.random() < epsilon:
        return random.randrange(n_actions)
    state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        return model(state_tensor).max(1)[1].item()

In [15]:
# 7. Training Loop
def train_dqn(env, policy_net, target_net, optimizer, buffer, image_shape,
              num_episodes=500, batch_size=8, target_update=10,
              gamma=0.95, eps_start=1.0, eps_end=0.05, eps_decay=0.995):
    epsilon = eps_start
    n_actions = env.action_space.n
    rewards_per_ep = []
    avg_rewards = []

    for ep in range(num_episodes):
        state = get_initial_state(env, image_shape)
        total_reward = 0
        done = False

        while not done:
            action = epsilon_greedy(state, policy_net, epsilon, n_actions)
            obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            processed_next = preprocess_frame(obs, image_shape).squeeze(0)
            next_state = np.append(state[1:], [processed_next], axis=0)
            buffer.push(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward

            if len(buffer) >= batch_size:
                states, actions, rewards, next_states, dones = buffer.sample(batch_size)
                states_v = torch.tensor(states, dtype=torch.float32)
                next_states_v = torch.tensor(next_states, dtype=torch.float32)
                actions_v = torch.tensor(actions).unsqueeze(1)
                rewards_v = torch.tensor(rewards).unsqueeze(1)
                dones_v = torch.tensor(dones).unsqueeze(1)

                q_vals = policy_net(states_v).gather(1, actions_v)
                next_q_vals = target_net(next_states_v).max(1)[0].unsqueeze(1).detach()
                q_targets = rewards_v + gamma * next_q_vals * (1 - dones_v)
                loss = nn.MSELoss()(q_vals, q_targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            epsilon = max(epsilon * eps_decay, eps_end)

        if ep % target_update == 0:
            target_net.load_state_dict(policy_net.state_dict())

        rewards_per_ep.append(total_reward)
        avg_5 = np.mean(rewards_per_ep[-5:])
        avg_rewards.append(avg_5)

        if ep % 10 == 0:
            print(f"Episode {ep} | Reward: {total_reward} | Avg5: {avg_5:.2f} | Epsilon: {epsilon:.2f}")

    return rewards_per_ep, avg_rewards



In [19]:
# 8. Run Experiments and Plot Results
def run_experiments():
    configs = [
        {"batch_size": 8, "target_update": 10},
        {"batch_size": 16, "target_update": 10},
        {"batch_size": 8, "target_update": 3},
        {"batch_size": 16, "target_update": 3},
    ]
    episodes = 500
    all_results = []

    for config in configs:
        print(f"Running config: batch_size={config['batch_size']}, target_update={config['target_update']}")
        policy_net = DQNCNN(4, env.action_space.n)
        target_net = DQNCNN(4, env.action_space.n)
        target_net.load_state_dict(policy_net.state_dict())
        target_net.eval()
        optimizer = optim.Adam(policy_net.parameters(), lr=1e-4)
        buffer = ReplayBuffer()

        rewards, avg_rewards = train_dqn(
            env, policy_net, target_net, optimizer, buffer, image_shape,
            num_episodes=episodes, batch_size=config['batch_size'], target_update=config['target_update']
        )

        all_results.append((config, rewards, avg_rewards))

        plt.figure(figsize=(12,6))
        plt.plot(rewards, label='Episode Reward')
        plt.plot(avg_rewards, label='5-episode Average')
        plt.title(f"Batch Size: {config['batch_size']}, Target Update: {config['target_update']}")
        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.legend()
        plt.show()

    # Combined Plot
    plt.figure(figsize=(14,8))
    for config, _, avg_rewards in all_results:
        label = f"B{config['batch_size']}_U{config['target_update']}"
        plt.plot(avg_rewards, label=label)
    plt.title("5-Episode Moving Average Reward for all Experiments")
    plt.xlabel("Episode")
    plt.ylabel("Average Reward")
    plt.legend()
    plt.show()

# Uncomment to run all experiments   
run_experiments()

Running config: batch_size=8, target_update=10


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [8, 4, 84, 80, 1]