In [None]:
# 🔧 Setup: Run this cell first!
# Reports the available hardware and library versions, then seeds every RNG
# this notebook uses (Python `random`, NumPy, PyTorch) so reruns repeat.
# (Nothing is installed here; `gymnasium` must already be available.)

import random
import sys

import numpy as np
import torch

# --- Hardware -------------------------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")
else:
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# --- Versions -------------------------------------------------------------
print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# --- Reproducibility ------------------------------------------------------
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# From Q-Tables to Neural Networks: The Path to Deep Q-Learning

*Part 4 of the Vizuara series on Value Functions and Q-Learning*
*Estimated time: 50 minutes*

## 1. Why Does This Matter?

In the previous notebooks, we built Q-Learning from scratch and trained an agent to solve FrozenLake -- a tiny 4x4 grid with just 16 states.

But real-world problems are far larger. Chess has roughly $10^{46}$ states. Atari games have raw-pixel observations: a single 210×160 RGB frame can take $256^{210 \cdot 160 \cdot 3}$ (about $10^{242751}$) distinct values -- vastly more than "billions". A self-driving car sees a continuous stream of camera images -- effectively an infinite state space.

Can we store a Q-value for every single state-action pair? Absolutely not. We need to **generalize** -- learn a function that approximates Q-values for states we have never seen before.

This is exactly what Deep Q-Networks (DQN) do: replace the Q-table with a neural network that takes a state as input and outputs Q-values for all actions. DeepMind introduced DQN in 2013, where it surpassed a human expert on three of the seven Atari games tested; the 2015 *Nature* version reached human-level play across a suite of 49 games.

By the end of this notebook, you will:
- Understand why tables fail for large state spaces
- Implement a neural network Q-function from scratch in PyTorch
- Train a DQN agent to solve CartPole
- See how experience replay and target networks stabilize training

## 2. Building Intuition

Think of it this way. A Q-table is like a phone book -- one entry for every person. If there are 16 people, the phone book is tiny. If there are 10 billion people, you need a very big book.

A neural network is like a person who has seen enough phone numbers to guess the pattern: "Oh, people in this area code tend to have numbers starting with 555." They do not memorize every number -- they learn the structure.

Similarly, a DQN does not memorize Q-values for every state. It learns the structure: "States that look like THIS tend to have high Q-values for action RIGHT." When it sees a new state it has never encountered, it can still estimate the Q-values by recognizing the pattern.

### Think About This

If two game states look almost identical (say, the paddle in Pong is shifted by one pixel), should their Q-values be very different? If not, is it wasteful to store them separately? This is the core motivation for function approximation.

## 3. The Mathematics

### The DQN Loss Function

Instead of a table update, we train a neural network $Q_\theta(s, a)$ to minimize:

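$$\mathcal{L}(\theta) = \mathbb{E}_{(s, a, r, s', d) \sim \mathcal{D}}\left[\left(r + \gamma (1 - d) \max_{a'} Q_{\theta^-}(s', a') - Q_\theta(s, a)\right)^2\right]$$

This is the mean squared TD error over transitions sampled from the replay buffer $\mathcal{D}$. The network tries to make its Q-value predictions match the TD targets; the factor $(1 - d)$ drops the bootstrap term when $s'$ is terminal ($d = 1$).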

Key innovations of DQN:

1. **Experience Replay**: Store transitions $(s, a, r, s', \text{done})$ in a buffer. Sample random mini-batches for training. This breaks correlations between consecutive samples.

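2. **Target Network**: Use a separate copy of the network ($Q_{\theta^-}$) to compute TD targets. Its weights stay frozen between periodic syncs, when they are overwritten with the online network's weights (every `target_update_freq` training steps in the code below). Because the targets no longer shift after every gradient update, this mitigates the moving-target problem.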

## 4. Let's Build It -- Component by Component

### 4.1 First, Let's See the Table Limitation

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random

# The scaling problem: how many distinct states would a Q-table need?
# Sizes are either ints or float approximations. The Atari entry is an exact
# (enormous) Python int: each of the 210*160*3 pixel bytes can take 256 values.
import math

state_spaces = {
    "FrozenLake (4x4)": 16,
    "FrozenLake (8x8)": 64,
    "Taxi-v3": 500,
    "Blackjack": 704,
    "Chess (approx)": 1e46,
    "Go (approx)": 1e170,
    "Atari (raw pixels)": 256 ** (210 * 160 * 3),
}

print("State Space Sizes:")
print("-" * 50)
for name, size in state_spaces.items():
    if size < 1e6:
        print(f"  {name:30s} {int(size):>15,} states")
    else:
        # math.log10 accepts arbitrarily large ints. Converting with float(size)
        # raises OverflowError for the Atari entry (~10^242751 > max float).
        print(f"  {name:30s}  ~10^{math.log10(size):.0f} states")

print()
# There are roughly 10^80 atoms in the observable universe, so the comparison
# only holds for Go-sized tables, not chess-sized (10^46) ones.
print("A Q-table for Go (~10^170 states) would need far more entries than")
print("there are atoms in the observable universe (~10^80).")
print()
print("Solution: Replace the table with a FUNCTION APPROXIMATOR.")

### 4.2 The Q-Network

In [None]:
class QNetwork(nn.Module):
    """Feed-forward Q-function approximator.

    Maps a batch of state vectors to one Q-value estimate per discrete action,
    using two hidden ReLU layers of width ``hidden_dim``.

    Args:
        state_dim: Length of each input state vector.
        n_actions: Number of discrete actions (size of the output layer).
        hidden_dim: Width of both hidden layers (default 128).
    """

    def __init__(self, state_dim, n_actions, hidden_dim=128):
        super().__init__()
        widths = [state_dim, hidden_dim, hidden_dim]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim, n_actions))
        # The attribute name `network` fixes the state_dict keys
        # ("network.0.weight", ...) used when syncing the target network.
        self.network = nn.Sequential(*layers)

    def forward(self, state):
        """Map ``state`` of shape (batch_size, state_dim) to Q-values of
        shape (batch_size, n_actions)."""
        q_values = self.network(state)
        return q_values


# Smoke-test the network with CartPole-sized dimensions:
# 4 observation values in, 2 action values (left, right) out.
state_dim, n_actions = 4, 2

q_net = QNetwork(state_dim, n_actions)
print("Q-Network architecture:")
print(q_net)
print(f"\nTotal parameters: {sum(param.numel() for param in q_net.parameters()):,}")

# One forward pass on a random (1, state_dim) input.
dummy_state = torch.randn(1, state_dim)
q_values = q_net(dummy_state)
print("\n".join([
    f"\nInput state shape: {dummy_state.shape}",
    f"Output Q-values: {q_values.detach().numpy()}",
    f"Best action: {q_values.argmax().item()}",
]))

### 4.3 Experience Replay Buffer

In [None]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done)
    transitions used for experience replay.

    Once ``capacity`` transitions are stored, each new push evicts the oldest
    one (``deque(maxlen=...)``). ``sample`` draws a uniform mini-batch without
    replacement using Python's ``random`` module, so ``random.seed`` controls it.

    Args:
        capacity: Maximum number of transitions kept (default 10000).
    """

    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store a transition, evicting the oldest one if the buffer is full."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Sample ``batch_size`` distinct transitions as stacked tensors.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` with dtypes
            float32, int64, float32, float32, float32. ``states`` and
            ``next_states`` stack the stored states along a new leading batch
            axis; the other three are 1-D tensors of length ``batch_size``.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Convert through NumPy with explicit dtypes instead of the legacy
        # torch.FloatTensor / torch.LongTensor constructors. This also maps
        # bool `done` flags to 0.0 / 1.0 so they can mask the bootstrap term.
        return (
            torch.as_tensor(np.asarray(states, dtype=np.float32)),
            torch.as_tensor(np.asarray(actions, dtype=np.int64)),
            torch.as_tensor(np.asarray(rewards, dtype=np.float32)),
            torch.as_tensor(np.asarray(next_states, dtype=np.float32)),
            torch.as_tensor(np.asarray(dones, dtype=np.float32)),
        )

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)


# Smoke-test the buffer with 100 random CartPole-shaped transitions
# (4-dim states, binary actions, ~10% terminal).
buffer = ReplayBuffer(capacity=1000)

for _ in range(100):
    obs = np.random.randn(4)
    act = np.random.randint(2)
    rew = np.random.randn()
    next_obs = np.random.randn(4)
    is_done = np.random.random() < 0.1
    buffer.push(obs, act, rew, next_obs, is_done)

print(f"Buffer size: {len(buffer)}")
states, actions, rewards, next_states, dones = buffer.sample(32)
print(f"Batch shapes: states={states.shape}, actions={actions.shape}, rewards={rewards.shape}")

### 4.4 The DQN Agent

In [None]:
class DQNAgent:
    """DQN agent: online Q-network, periodically synced target network,
    experience replay, and epsilon-greedy exploration.

    Args:
        state_dim: Length of the observation vector.
        n_actions: Number of discrete actions.
        hidden_dim: Hidden-layer width for both Q-networks.
        lr: Adam learning rate for the online network.
        gamma: Discount factor in the TD target.
        epsilon: Initial exploration probability.
        epsilon_decay: Multiplicative factor applied by ``decay_epsilon``.
        epsilon_min: Lower bound for epsilon.
        buffer_size: Replay-buffer capacity.
        batch_size: Mini-batch size; ``train_step`` is a no-op until the
            buffer holds at least this many transitions.
        target_update_freq: Copy online weights into the target network
            every this many gradient steps.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=128,
                 lr=1e-3, gamma=0.99, epsilon=1.0, epsilon_decay=0.995,
                 epsilon_min=0.01, buffer_size=10000, batch_size=64,
                 target_update_freq=100):
        # Problem size and hyperparameters.
        self.state_dim, self.n_actions = state_dim, n_actions
        self.gamma = gamma
        self.epsilon, self.epsilon_decay, self.epsilon_min = epsilon, epsilon_decay, epsilon_min
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq

        # The online network is trained. The target network starts as an exact
        # copy and is only refreshed every `target_update_freq` updates.
        self.q_net = QNetwork(state_dim, n_actions, hidden_dim)
        self.target_net = QNetwork(state_dim, n_actions, hidden_dim)
        self.target_net.load_state_dict(self.q_net.state_dict())

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        self.buffer = ReplayBuffer(buffer_size)
        self.steps = 0      # gradient updates performed so far
        self.losses = []    # loss value of every update

    def choose_action(self, state):
        """Return a uniformly random action with probability epsilon, otherwise
        the argmax of the online network's Q-values for ``state``."""
        explore = np.random.random() < self.epsilon
        if explore:
            return np.random.randint(self.n_actions)
        with torch.no_grad():
            q_values = self.q_net(torch.FloatTensor(state).unsqueeze(0))
            return q_values.argmax().item()

    def store(self, state, action, reward, next_state, done):
        """Forward one transition to the replay buffer."""
        self.buffer.push(state, action, reward, next_state, done)

    def train_step(self):
        """Run one gradient update on a replay mini-batch.

        Does nothing until the buffer holds ``batch_size`` transitions. The TD
        target is ``r + gamma * max_a' Q_target(s', a') * (1 - done)``.
        """
        if len(self.buffer) < self.batch_size:
            return None

        states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)

        # Q(s, a) for the actions that were actually taken.
        q_taken = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Targets come from the frozen network; (1 - dones) removes the
        # bootstrap term for terminal transitions.
        with torch.no_grad():
            best_next = self.target_net(next_states).max(dim=1).values
            td_target = rewards + self.gamma * best_next * (1 - dones)

        loss = nn.functional.mse_loss(q_taken, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.losses.append(loss.item())
        self.steps += 1
        if self.steps % self.target_update_freq == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())

    def decay_epsilon(self):
        """Shrink epsilon multiplicatively, never going below ``epsilon_min``."""
        decayed = self.epsilon * self.epsilon_decay
        self.epsilon = decayed if decayed > self.epsilon_min else self.epsilon_min


# Summary of the pieces assembled in DQNAgent.
print("\n".join([
    "DQN Agent components:",
    "  1. Q-Network: predicts Q(s,a) for all actions",
    "  2. Target Network: stable target for TD computation",
    "  3. Replay Buffer: breaks correlation in training data",
    "  4. Epsilon-greedy: balances exploration and exploitation",
]))

## 5. Your Turn

### TODO: Implement the Training Loop for CartPole

In [None]:
import gymnasium as gym

def train_dqn(env_name="CartPole-v1", n_episodes=500, render_every=None, seed=None):
    """
    Train a DQN agent on CartPole.

    CartPole: Balance a pole on a cart by moving left or right.
    State: [cart_position, cart_velocity, pole_angle, pole_angular_velocity]
    Actions: 0 = push left, 1 = push right
    Reward: +1 for each step the pole stays upright

    Args:
        env_name: Gymnasium environment id passed to ``gym.make``.
        n_episodes: Number of training episodes to run.
        render_every: Currently unused; kept for API compatibility.
        seed: Optional seed for the first ``env.reset`` so the environment's
            initial states are reproducible. ``None`` (default) leaves
            Gymnasium's reset unseeded.

    Returns:
        Tuple ``(agent, rewards_per_episode)``: the ``DQNAgent`` and the list
        of per-episode returns appended by the training loop.
    """
    env = gym.make(env_name)
    try:
        state_dim = env.observation_space.shape[0]
        n_actions = env.action_space.n

        agent = DQNAgent(
            state_dim=state_dim,
            n_actions=n_actions,
            hidden_dim=128,
            lr=1e-3,
            gamma=0.99,
            epsilon=1.0,
            epsilon_decay=0.995,
            epsilon_min=0.01,
            buffer_size=10000,
            batch_size=64,
            target_update_freq=100,
        )

        rewards_per_episode = []

        for episode in range(n_episodes):
            # Seed only the first reset; later resets continue the env's own RNG.
            state, _ = env.reset(seed=seed if episode == 0 else None)
            total_reward = 0
            done = False

            # ============ TODO ============
            # Implement the episode loop:
            #
            # While not done:
            #   1. Choose action using agent.choose_action(state)
            #   2. Take action: next_state, reward, terminated, truncated, _ = env.step(action)
            #   3. Compute done = terminated or truncated
            #   4. Store transition: agent.store(state, action, reward, next_state, float(terminated))
            #      (store `terminated`, not `done`: a time-limit truncation is not a
            #      true terminal state, so its TD target should still bootstrap)
            #   5. Perform training step: agent.train_step()
            #   6. Update state and total_reward
            #
            # After episode:
            #   7. Decay epsilon: agent.decay_epsilon()
            #   8. Append total_reward to rewards_per_episode
            # ==============================

            pass  # YOUR CODE HERE

            if (episode + 1) % 50 == 0:
                recent = np.mean(rewards_per_episode[-50:]) if rewards_per_episode else 0
                print(f"Episode {episode+1:4d} | Avg reward (last 50): {recent:6.1f} | Epsilon: {agent.epsilon:.3f}")
    finally:
        # Release the environment even if the episode loop raises.
        env.close()
    return agent, rewards_per_episode

In [None]:
# Verification: train, then check the average return over the last 100 episodes.
agent_dqn, dqn_rewards = train_dqn(n_episodes=500)

# Gymnasium's reward_threshold for CartPole-v1 is 475 (episodes are truncated
# at 500 steps). The often-quoted 195 is the threshold for CartPole-v0.
SOLVED_THRESHOLD = 475

if not dqn_rewards:
    # Without this guard np.mean([]) returns nan and emits a RuntimeWarning.
    print("\nNo episode rewards were recorded -- complete the training-loop TODO above.")
else:
    final_avg = np.mean(dqn_rewards[-100:])
    print(f"\nFinal average reward (last 100): {final_avg:.1f}")
    if final_avg >= SOLVED_THRESHOLD:
        print("CartPole SOLVED!")
    else:
        print("Not quite solved yet. Try more episodes or tuning hyperparameters.")

## 6. Putting It All Together

In [None]:
# Why a table cannot work for CartPole: every observation is 4 real numbers,
# so freshly reset states are essentially never repeated.
print("\n".join([
    "Why DQN works where tables fail:",
    "-" * 50,
    "CartPole state space: CONTINUOUS (4 real numbers)",
    "CartPole state examples:",
]))

env_demo = gym.make("CartPole-v1")
for idx in range(1, 6):
    cart_pos, cart_vel, pole_angle, pole_angvel = env_demo.reset()[0]
    print(f"  State {idx}: cart_pos={cart_pos:.3f}, cart_vel={cart_vel:.3f}, "
          f"pole_angle={pole_angle:.3f}, pole_angvel={pole_angvel:.3f}")
env_demo.close()

print("\nEach state is unique -- a table would need infinite entries!")
print("The neural network generalizes: similar states get similar Q-values.")

## 7. Training and Results

In [None]:
# Training diagnostics: reward curve, loss curve, and learned Q-values.

# Gymnasium's reward_threshold for CartPole-v1; the familiar 195 is CartPole-v0's.
SOLVED_THRESHOLD = 475


def _trailing_mean(values, window):
    """Return, for each index i, the mean of values[max(0, i - window + 1) : i + 1].

    Exactly `window` points are averaged once enough history exists (fewer at
    the start). Uses cumulative sums, so it runs in O(n) rather than O(n * window).
    """
    arr = np.asarray(values, dtype=float)
    csum = np.concatenate(([0.0], np.cumsum(arr)))
    ends = np.arange(1, len(arr) + 1)
    starts = np.maximum(ends - window, 0)
    return (csum[ends] - csum[starts]) / (ends - starts)


if dqn_rewards:
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Reward curve: raw episode returns plus a trailing `window`-episode mean.
    window = 20
    moving_avg = _trailing_mean(dqn_rewards, window)

    axes[0].plot(dqn_rewards, alpha=0.3, color='#2171b5', label='Episode reward')
    axes[0].plot(moving_avg, color='#2171b5', linewidth=2, label=f'{window}-episode average')
    axes[0].axhline(y=SOLVED_THRESHOLD, color='green', linestyle='--', label='Solved threshold')
    axes[0].set_xlabel('Episode', fontsize=11)
    axes[0].set_ylabel('Reward', fontsize=11)
    axes[0].set_title('DQN on CartPole', fontsize=13, fontweight='bold')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Loss curve: trailing mean over `loss_window` updates, every 10th point shown.
    if agent_dqn.losses:
        loss_window = 100
        loss_avg = _trailing_mean(agent_dqn.losses, loss_window)
        axes[1].plot(loss_avg[::10], color='#d94701', linewidth=1)
        axes[1].set_xlabel('Training Step (x10)', fontsize=11)
        axes[1].set_ylabel('Loss', fontsize=11)
        axes[1].set_title('Training Loss', fontsize=13, fontweight='bold')
        axes[1].grid(True, alpha=0.3)

    # Q-value landscape: vary only the pole angle (index 2 of the CartPole
    # observation), keep position and velocities at zero, and evaluate all
    # probe states in a single batched forward pass.
    angles = np.linspace(-0.2, 0.2, 100)
    probe_states = torch.zeros(len(angles), agent_dqn.state_dim)
    probe_states[:, 2] = torch.as_tensor(angles, dtype=torch.float32)
    with torch.no_grad():
        q = agent_dqn.q_net(probe_states)
    q_left = q[:, 0].numpy()
    q_right = q[:, 1].numpy()

    axes[2].plot(np.degrees(angles), q_left, label='Push Left', color='#d94701')
    axes[2].plot(np.degrees(angles), q_right, label='Push Right', color='#2171b5')
    axes[2].axvline(x=0, color='gray', linestyle='--', alpha=0.5)
    axes[2].set_xlabel('Pole Angle (degrees)', fontsize=11)
    axes[2].set_ylabel('Q-value', fontsize=11)
    axes[2].set_title('Learned Q-values vs Pole Angle', fontsize=13, fontweight='bold')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)

    plt.suptitle('Deep Q-Network: Complete Analysis', fontsize=15, fontweight='bold')
    plt.tight_layout()
    plt.show()
else:
    print("No training data available. Make sure the training loop TODO is completed.")

## 8. Final Output

In [None]:
# Test the trained agent greedily (no exploration) on fresh episodes.
N_TEST_EPISODES = 100
# Gymnasium's reward_threshold for CartPole-v1; the familiar 195 is CartPole-v0's.
SOLVED_THRESHOLD = 475

if dqn_rewards:
    env_test = gym.make("CartPole-v1")
    test_rewards = []
    try:
        for _ in range(N_TEST_EPISODES):
            state, _ = env_test.reset()
            total = 0
            done = False
            while not done:
                with torch.no_grad():
                    q = agent_dqn.q_net(torch.FloatTensor(state).unsqueeze(0))
                    action = q.argmax().item()
                state, reward, terminated, truncated, _ = env_test.step(action)
                done = terminated or truncated
                total += reward
            test_rewards.append(total)
    finally:
        # Release the environment even if an episode raises.
        env_test.close()

    mean_reward = np.mean(test_rewards)
    solved = mean_reward >= SOLVED_THRESHOLD

    print("=" * 60)
    print("  FINAL RESULTS: DQN on CartPole-v1")
    print("=" * 60)
    print(f"  Test episodes:     {N_TEST_EPISODES}")
    print(f"  Mean reward:       {mean_reward:.1f}")
    print(f"  Std reward:        {np.std(test_rewards):.1f}")
    print(f"  Min reward:        {np.min(test_rewards):.0f}")
    print(f"  Max reward:        {np.max(test_rewards):.0f}")
    print(f"  Solved (>={SOLVED_THRESHOLD}):    {'YES' if solved else 'NO'}")
    print()
    if solved:
        print("  The neural network learned to balance the pole!")
    else:
        print("  The agent has not reached the solved threshold yet -- train longer or tune it.")
    print("  From tabular Q-learning on 16 states to DQN on continuous spaces.")
    print("  This is the same idea that DeepMind used to play Atari in 2013.")

    # Histogram of test rewards
    plt.figure(figsize=(8, 4))
    plt.hist(test_rewards, bins=20, color='#2171b5', edgecolor='white', alpha=0.8)
    plt.axvline(x=SOLVED_THRESHOLD, color='green', linestyle='--', linewidth=2, label='Solved threshold')
    plt.xlabel('Episode Reward', fontsize=12)
    plt.ylabel('Count', fontsize=12)
    plt.title('DQN Test Performance Distribution', fontsize=14, fontweight='bold')
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    print()
    print("Congratulations! You have built Deep Q-Learning from scratch!")

## 9. Reflection and Next Steps

### Reflection Questions
1. Why does DQN need two networks (Q-network and target network)? What goes wrong with just one?
2. Experience replay breaks the correlation between consecutive samples. Why is this important for neural network training?
3. The DQN loss function looks like supervised learning (MSE loss). But the "labels" (TD targets) change during training. How is this different from standard supervised learning?

### Optional Challenges
1. Implement Double DQN: use the Q-network to select the best action but the target network to evaluate it. This reduces overestimation.
2. Implement Prioritized Experience Replay: sample transitions with high TD error more frequently.
3. Try DQN on LunarLander-v3. This is a harder environment -- how many episodes does it take to solve?