# DQN Improvements: Double DQN & Prioritized Replay

This notebook demonstrates two important improvements to the original DQN:

1. **Double DQN (DDQN)**: reduces overestimation bias by using the online network to select the next action and the target network to evaluate it.
2. **Prioritized Experience Replay (PER)**: samples transitions with high TD error more often than uniform replay would, which can speed up learning.

The implementations are minimal and educational rather than production-optimized, and use TensorFlow/Keras with the `gymnasium` CartPole-v1 environment.
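To see why decoupling action selection from evaluation helps (item 1), here is a small NumPy simulation under a deliberately simplified assumption: every true action value is 0, and the two "networks" are independent noisy estimators. Taking the max over one noisy estimate is biased upward, while evaluating the chosen action with a second, independent estimate is not. Real online and target networks are correlated, so in practice DDQN reduces this bias rather than removing it.

```python
import numpy as np

# Toy assumption: all true Q-values are 0; each estimator adds independent N(0, 1) noise.
rng = np.random.default_rng(0)
n_trials, n_actions = 100_000, 10
q_a = rng.normal(0.0, 1.0, size=(n_trials, n_actions))  # plays the role of the online net
q_b = rng.normal(0.0, 1.0, size=(n_trials, n_actions))  # plays the role of the target net

# Single estimator (vanilla DQN style): select and evaluate with the same noisy values.
single = q_a.max(axis=1)

# Double estimator (DDQN style): select with A, evaluate the chosen action with B.
chosen = q_a.argmax(axis=1)
double = q_b[np.arange(n_trials), chosen]

print(f"single-estimator bias: {single.mean():+.3f}")  # clearly positive (about +1.5)
print(f"double-estimator bias: {double.mean():+.3f}")  # close to 0
```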

## ⚙️ Imports & Setup

Uncomment the `pip install` lines if you are running in a fresh environment.

In [None]:
# !pip install gymnasium --quiet
# !pip install tensorflow --quiet

import numpy as np
import random
import collections
import gymnasium as gym
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
import matplotlib.pyplot as plt

print('TF', tf.__version__)

## ✅ Simple Prioritized Replay Buffer (Proportional)

This is a straightforward implementation for learning and demo purposes: each transition stores a priority, and transitions are sampled with probability proportional to their priority raised to the power α. We also compute importance-sampling (IS) weights, which correct the bias introduced by non-uniform sampling (fully only when β = 1).
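Written out, transition $i$ with priority $p_i$ is sampled with probability $P(i) = p_i^\alpha / \sum_k p_k^\alpha$ and receives the IS weight $w_i = (N \cdot P(i))^{-\beta}$, rescaled so the largest weight is 1. The snippet below evaluates both formulas for four made-up priorities. One simplification to note: it rescales by the largest weight over all four transitions, whereas the buffer class rescales by the largest weight in the sampled batch.

```python
import numpy as np

# Hypothetical priorities for a 4-transition buffer (e.g. |TD error| + small eps).
priorities = np.array([0.1, 0.5, 1.0, 2.0])
alpha, beta = 0.6, 0.4

# Sampling probabilities: P(i) = p_i^alpha / sum_k p_k^alpha
scaled = priorities ** alpha
probs = scaled / scaled.sum()

# Importance-sampling weights: w_i = (N * P(i))^(-beta), rescaled so the largest is 1
n = len(priorities)
weights = (n * probs) ** (-beta)
weights /= weights.max()

for p, P, w in zip(priorities, probs, weights):
    print(f"priority={p:.1f}  P(i)={P:.3f}  w_i={w:.3f}")
```

Higher-priority transitions are drawn more often but get smaller weights, so their larger share of the updates is partially offset in the loss.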

In [None]:
class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha  # how much prioritization is used (0 = uniform)
        self.buffer = []
        self.priorities = np.zeros((capacity,), dtype=np.float32)
        self.pos = 0

    def push(self, transition, priority=None):
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.pos] = transition
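        if priority is None:
            # New transitions get the current max priority (1.0 while all priorities are 0)
            # so they are sampled at least once before their TD error is known.
            max_prio = self.priorities.max()
            self.priorities[self.pos] = max_prio if max_prio > 0 else 1.0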
        else:
            self.priorities[self.pos] = priority
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        if len(self.buffer) == self.capacity:
            prios = self.priorities
        else:
            prios = self.priorities[:self.pos]
        probs = prios ** self.alpha
        probs /= probs.sum()

        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        samples = [self.buffer[idx] for idx in indices]

        # importance-sampling weights
        total = len(self.buffer)
        weights = (total * probs[indices]) ** (-beta)
        weights /= weights.max()
        return samples, indices, weights

    def update_priorities(self, indices, priorities):
        for idx, prio in zip(indices, priorities):
            self.priorities[idx] = prio

    def __len__(self):
        return len(self.buffer)

# quick sanity
buf = PrioritizedReplayBuffer(1000)
buf.push((np.zeros(4), 0, 1.0, np.ones(4), False))
print('PER buffer OK, len:', len(buf))

## 🏗️ Q-network (same as DQN notebook)
A small MLP that maps a state to one Q-value per action.

In [None]:
def build_q_network(state_shape, n_actions, hidden_units=(64,64), lr=1e-3):
    inputs = layers.Input(shape=state_shape)
    x = inputs
    for h in hidden_units:
        x = layers.Dense(h, activation='relu')(x)
    outputs = layers.Dense(n_actions, activation='linear')(x)
    model = models.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=optimizers.Adam(learning_rate=lr), loss='mse')
    return model


## 🔧 Hyperparameters

In [None]:
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
tf.random.set_seed(SEED)

ENV_NAME = 'CartPole-v1'
NUM_EPISODES = 300
MAX_STEPS = 500
BATCH_SIZE = 64
GAMMA = 0.99
BUFFER_CAPACITY = 20000
LEARNING_RATE = 1e-3
TARGET_UPDATE_FREQ = 1000

EPS_START = 1.0
EPS_END = 0.01
EPS_DECAY = 0.995

MIN_REPLAY_SIZE = 1000

PER_ALPHA = 0.6
PER_BETA_START = 0.4
PER_BETA_FRAMES = NUM_EPISODES * 1.0  # beta is annealed per episode and reaches 1.0 at this episode

print('Hyperparameters defined')
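It can help to check what these schedules actually produce before training. The snippet below copies the relevant values so it runs on its own and replays the update rules used in the training loop (epsilon decays once per episode; beta is annealed linearly per episode). With these settings epsilon is still about 0.22 after 300 episodes and would need 919 episodes to reach `EPS_END`; this is arithmetic on the hyperparameters, not a training result. If a mostly greedy policy is wanted by the end of the run, a faster decay or more episodes may be worth trying.

```python
# Values copied from the hyperparameter cell so this snippet is self-contained.
NUM_EPISODES = 300
EPS_START, EPS_END, EPS_DECAY = 1.0, 0.01, 0.995
PER_BETA_START = 0.4
PER_BETA_FRAMES = NUM_EPISODES * 1.0

# Epsilon is multiplied by EPS_DECAY once per episode and floored at EPS_END.
eps_end_of_run = EPS_START
for _ in range(NUM_EPISODES):
    eps_end_of_run = max(EPS_END, eps_end_of_run * EPS_DECAY)
print(f"epsilon after {NUM_EPISODES} episodes: {eps_end_of_run:.4f}")  # 0.2223

# Count how many episodes it takes for epsilon to hit its floor.
eps, episodes_to_floor = EPS_START, 0
while eps > EPS_END:
    eps = max(EPS_END, eps * EPS_DECAY)
    episodes_to_floor += 1
print("episodes until epsilon reaches EPS_END:", episodes_to_floor)  # 919

def beta_at(ep):
    """Linear beta annealing, same formula as in the training loop."""
    return min(1.0, PER_BETA_START + (ep / PER_BETA_FRAMES) * (1.0 - PER_BETA_START))

print([round(beta_at(ep), 2) for ep in (1, 150, 300)])  # [0.4, 0.7, 1.0]
```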

## 🧪 Training Loop with Double DQN + Prioritized Replay

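Key DDQN change: when computing targets, **select** the next action with the online network (argmax) and **evaluate** it with the target network:

$$y = r + \gamma \, Q_{\text{target}}\big(s', \arg\max_{a} Q_{\text{online}}(s', a)\big)$$

For terminal transitions the target is simply $y = r$.

New priorities are the absolute TD error $|y - Q_{\text{online}}(s, a)|$ plus a small constant, so that no transition ends up with zero sampling probability.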
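The sketch below contrasts the vanilla DQN and Double DQN targets on a made-up batch of two transitions and two actions. The helper names `dqn_targets` and `ddqn_targets` are hypothetical, for illustration only; `(1 - dones)` switches off bootstrapping for terminal transitions, which reproduces the $y = r$ case.

```python
import numpy as np

def dqn_targets(rewards, dones, q_next_target, gamma):
    """Vanilla DQN: the target network both selects and evaluates the next action."""
    return rewards + gamma * (1.0 - dones) * q_next_target.max(axis=1)

def ddqn_targets(rewards, dones, q_next_online, q_next_target, gamma):
    """Double DQN: the online network selects, the target network evaluates."""
    next_actions = q_next_online.argmax(axis=1)
    evaluated = q_next_target[np.arange(len(rewards)), next_actions]
    return rewards + gamma * (1.0 - dones) * evaluated

# Made-up batch: 2 transitions, 2 actions; the second transition is terminal.
rewards = np.array([1.0, 1.0])
dones = np.array([0.0, 1.0])
q_next_online = np.array([[2.0, 1.0], [0.0, 5.0]])
q_next_target = np.array([[1.5, 3.0], [4.0, 2.0]])

print(dqn_targets(rewards, dones, q_next_target, 0.99))                 # ≈ [3.97, 1.0]
print(ddqn_targets(rewards, dones, q_next_online, q_next_target, 0.99))  # ≈ [2.485, 1.0]
```

In the first transition the target network's own maximum (3.0) is larger than its value for the action the online network prefers (1.5), which is exactly the kind of optimistic pick DDQN avoids.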

In [None]:
env = gym.make(ENV_NAME)
state_shape = env.observation_space.shape
n_actions = env.action_space.n

online_q = build_q_network(state_shape, n_actions, lr=LEARNING_RATE)
target_q = build_q_network(state_shape, n_actions, lr=LEARNING_RATE)
target_q.set_weights(online_q.get_weights())

replay = PrioritizedReplayBuffer(BUFFER_CAPACITY, alpha=PER_ALPHA)

eps = EPS_START
total_steps = 0
episode_rewards = []

beta = PER_BETA_START

for ep in range(1, NUM_EPISODES + 1):
    obs, _ = env.reset(seed=SEED + ep)
    state = np.array(obs, dtype=np.float32)
    ep_reward = 0
    done = False
    steps = 0

    while not done and steps < MAX_STEPS:
        # Epsilon-greedy
        if random.random() < eps:
            action = env.action_space.sample()
        else:
            q_vals = online_q.predict(state.reshape(1, -1), verbose=0)[0]
            action = int(np.argmax(q_vals))

        next_obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        next_state = np.array(next_obs, dtype=np.float32)

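        # Store `terminated` rather than `done`: a time-limit truncation is not a true terminal
        # state, so its target should still bootstrap from next_state.
        # New transitions get the current max priority (see push), so they are likely to be sampled.
        transition = (state, action, reward, next_state, terminated)
        replay.push(transition)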

        state = next_state
        ep_reward += reward
        steps += 1
        total_steps += 1

        # Learn
        if len(replay) >= MIN_REPLAY_SIZE:
            beta = min(1.0, PER_BETA_START + (ep / PER_BETA_FRAMES) * (1.0 - PER_BETA_START))
            samples, indices, is_weights = replay.sample(BATCH_SIZE, beta=beta)
            states_b = np.array([s for (s,a,r,ns,d) in samples], dtype=np.float32)
            actions_b = np.array([a for (s,a,r,ns,d) in samples], dtype=np.int32)
            rewards_b = np.array([r for (s,a,r,ns,d) in samples], dtype=np.float32)
            next_states_b = np.array([ns for (s,a,r,ns,d) in samples], dtype=np.float32)
            dones_b = np.array([d for (s,a,r,ns,d) in samples], dtype=np.float32)
            is_weights = np.array(is_weights, dtype=np.float32)

            # Double DQN target calculation
            q_next_online = online_q.predict(next_states_b, verbose=0)
            next_actions = np.argmax(q_next_online, axis=1)
            q_next_target = target_q.predict(next_states_b, verbose=0)
            q_next_target_vals = q_next_target[np.arange(BATCH_SIZE), next_actions]

            # Current predictions; only the taken action's entry is replaced by the target
            targets = online_q.predict(states_b, verbose=0)
            rows = np.arange(BATCH_SIZE)
            target_vals = rewards_b + GAMMA * (1.0 - dones_b) * q_next_target_vals
            td_errors = np.abs(target_vals - targets[rows, actions_b])
            targets[rows, actions_b] = target_vals

            # Apply the importance-sampling weights as per-sample loss weights.
            # With full target vectors, each sample's MSE is td_error**2 / n_actions,
            # so sample_weight scales each transition's squared TD error by its IS weight.
            online_q.train_on_batch(states_b, targets, sample_weight=is_weights)

            # update priorities (+ small epsilon to avoid zero)
            new_prios = td_errors + 1e-6
            replay.update_priorities(indices, new_prios)

        # update target network periodically
        if total_steps % TARGET_UPDATE_FREQ == 0 and total_steps > 0:
            target_q.set_weights(online_q.get_weights())

    episode_rewards.append(ep_reward)
    eps = max(EPS_END, eps * EPS_DECAY)

    if ep % 10 == 0:
        avg_reward = np.mean(episode_rewards[-10:])
        print(f'Episode {ep:03d} | AvgReward(10): {avg_reward:.2f} | Eps: {eps:.3f} | ReplaySize: {len(replay)}')

# Save model
online_q.save('ddqn_per_cartpole.h5')
print('Training completed')
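Keras's `Model.train_on_batch` accepts a per-sample `sample_weight` argument, which is one way to apply the PER importance-sampling weights with the `'mse'` loss and full target vectors used in the training loop. The NumPy check below uses random stand-in arrays (not real training data) to show why this works: because only the taken action's entry differs between targets and predictions, each sample's MSE equals its squared TD error divided by the number of actions, so weighting per-sample losses by $w_i$ weights each squared TD error by $w_i$. The exact batch-level reduction Keras applies afterwards is not reproduced here; the mean below is an approximation of it.

```python
import numpy as np

rng = np.random.default_rng(1)
batch, n_actions = 5, 2

preds = rng.normal(size=(batch, n_actions))     # stand-in for online_q predictions
actions = rng.integers(0, n_actions, size=batch)
target_vals = rng.normal(size=batch)            # stand-in for DDQN targets
is_weights = rng.uniform(0.2, 1.0, size=batch)  # stand-in for PER weights (all <= 1)

# Full target matrix: copy predictions, overwrite only the taken action.
targets = preds.copy()
targets[np.arange(batch), actions] = target_vals

# 'mse' averages over the last axis, giving one loss value per sample.
per_sample_mse = ((targets - preds) ** 2).mean(axis=1)

# Untouched entries contribute 0, so each sample's MSE is td^2 / n_actions.
td = target_vals - preds[np.arange(batch), actions]
print(np.allclose(per_sample_mse, td ** 2 / n_actions))  # True

# Per-sample weighting followed by a plain batch mean.
weighted_loss = np.mean(is_weights * per_sample_mse)
print(f"IS-weighted loss: {weighted_loss:.4f}")
```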

## 📊 Results & Visualization
Plot the episode rewards to inspect learning progress.

In [None]:
plt.figure(figsize=(10,5))
plt.plot(episode_rewards, label='Episode Reward')
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.title('DDQN + PER on CartPole: Episode Rewards')
plt.grid(True)
plt.legend()
plt.show()
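Raw CartPole episode rewards are noisy, so a trailing moving average often makes the trend easier to read. A minimal sketch (the helper `moving_average` is hypothetical, not part of any library); its 10-episode default matches the `AvgReward(10)` printed during training, and the first few points average over however many episodes exist so far:

```python
import numpy as np

def moving_average(values, window=10):
    """Trailing mean over the last `window` values (shorter windows at the start)."""
    values = np.asarray(values, dtype=np.float64)
    cumsum = np.concatenate(([0.0], np.cumsum(values)))
    idx = np.arange(1, len(values) + 1)   # number of values seen so far
    start = np.maximum(0, idx - window)   # first index inside each window
    return (cumsum[idx] - cumsum[start]) / (idx - start)

print(moving_average([10, 20, 30, 40], window=2))  # ≈ [10, 15, 25, 35]
```

To overlay it on the plot, add `plt.plot(moving_average(episode_rewards, 10), label='10-episode moving average')` before `plt.legend()`.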

## 🔍 Notes and Caveats

- This PER implementation is simplified: `sample` renormalizes the whole priority array on every call, which costs O(N). Production implementations typically use a sum tree, which makes both priority updates and sampling O(log N).
- IS weights only correct the sampling bias if they actually scale the loss. In Keras this can be done by passing them as `sample_weight` to `train_on_batch`, or by writing a custom training step that multiplies the per-sample losses by the weights.
- DDQN mitigates Q-value overestimation and usually improves stability.
- Further improvements worth exploring: dueling networks, noisy networks, and multi-step returns.
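The sum tree from the first note can be sketched in a few lines. The `SumTree` class below is a hypothetical, minimal version (not taken from any library and not wired into `PrioritizedReplayBuffer`): leaves hold priorities, every internal node holds the sum of its two children, and the root holds the total. To sample proportionally, draw a value uniformly in `[0, total())` and walk down from the root; both that walk and a priority update touch one node per level, hence O(log N).

```python
import numpy as np

class SumTree:
    """Minimal sum tree: leaves store priorities, internal nodes store sums of children."""

    def __init__(self, capacity):
        self.capacity = capacity
        # Array layout: internal nodes at [0, capacity - 1), leaves at [capacity - 1, 2*capacity - 1).
        self.tree = np.zeros(2 * capacity - 1, dtype=np.float64)

    def update(self, data_idx, priority):
        """Set the priority of slot `data_idx` and propagate the change to the root."""
        node = data_idx + self.capacity - 1
        change = priority - self.tree[node]
        self.tree[node] = priority
        while node != 0:
            node = (node - 1) // 2
            self.tree[node] += change

    def total(self):
        return self.tree[0]

    def find(self, value):
        """Return the slot whose cumulative-priority interval contains `value`."""
        node = 0
        while node < self.capacity - 1:  # descend until a leaf is reached
            left = 2 * node + 1
            if value < self.tree[left]:
                node = left
            else:
                value -= self.tree[left]
                node = left + 1
        return node - (self.capacity - 1)

tree = SumTree(4)
for i, p in enumerate([1.0, 2.0, 3.0, 4.0]):
    tree.update(i, p)
print(tree.total())                                  # 10.0
print([tree.find(v) for v in (0.5, 1.5, 3.5, 9.9)])  # [0, 1, 2, 3]
```

Repeated `+=` updates can accumulate floating-point drift over a long run; this sketch does not guard against that.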

## ✅ Summary

- Implemented Double DQN target computation (select action with online network, evaluate with target network).
- Implemented a simple prioritized replay buffer with proportional prioritization and importance-sampling weights whose exponent β is annealed toward 1.
- These techniques often improve sample efficiency and stability compared to vanilla DQN.