# Double DQN with Prioritized Experience Replay on LunarLander-v2

This notebook demonstrates training a Double Deep Q-Network (Double DQN) with Prioritized Experience Replay (PER) from scratch on the LunarLander-v2 environment from OpenAI Gym.

We'll cover:
- Environment setup and seeding for reproducibility
- Neural network architecture for Q-learning
- Prioritized Experience Replay buffer implementation
- Training loop with epsilon-greedy exploration
- Evaluation of the trained agent
- Visualization of training progress (episode rewards)


In [None]:
# Basic imports
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import matplotlib.pyplot as plt

# One shared seed for every random number generator seeded in this cell
SEED = 43
torch.manual_seed(SEED)   # PyTorch (weight init)
np.random.seed(SEED)      # NumPy global RNG
random.seed(SEED)         # Python's built-in RNG


## Define the Q-network

A simple fully-connected neural network that maps states to Q-values for each action.


In [None]:
class QT_Network(nn.Module):
    """Small MLP that maps a state vector to one Q-value per action.

    Architecture: Linear(input_dim, 128) -> ReLU -> Linear(128, output_dim).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, output_dim),
        ]
        # Attribute name is part of the state_dict keys; keep it stable.
        self.policy_model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the Q-values predicted for the batch of states ``x``."""
        q_values = self.policy_model(x)
        return q_values


## Define the Prioritized Replay Buffer class

This buffer stores transitions together with a priority derived from the magnitude of their TD error, and samples transitions with probability proportional to that priority.


In [None]:
class PrioritizedReplayBuffer:
    """Fixed-capacity ring buffer with proportional prioritized sampling.

    Every stored transition carries a priority ``(|td_error| + eps) ** alpha``.
    ``sample`` draws indices with probability proportional to priority and
    returns importance-sampling weights that compensate for the non-uniform
    sampling.

    Args:
        capacity: Maximum number of transitions kept; the oldest entry is
            overwritten once the buffer is full.
        alpha: Prioritization exponent (0 gives uniform sampling).
    """

    # Added to |td_error| so that no transition ends up with zero probability.
    _EPS = 1e-5

    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.buffer = []
        self.priorities = []
        self.alpha = alpha
        self.pos = 0  # next slot to overwrite once the buffer is full

    def _priority(self, td_error):
        """Convert a scalar-like TD error (float, NumPy scalar, 1-element tensor) to a priority."""
        return (abs(float(td_error)) + self._EPS) ** self.alpha

    def add(self, transition, td_error=1.0):
        """Store ``transition`` with a priority derived from ``td_error``."""
        priority = self._priority(td_error)
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
            self.priorities.append(priority)
        else:
            self.buffer[self.pos] = transition
            self.priorities[self.pos] = priority
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        """Draw ``batch_size`` transitions (with replacement) by priority.

        Args:
            batch_size: Number of transitions to draw.
            beta: Importance-sampling exponent (1 fully corrects the bias).

        Returns:
            Tuple ``(samples, indices, weights)``: the sampled transitions,
            their buffer indices (NumPy array), and a float32 tensor of shape
            ``(batch_size, 1)`` with weights normalized so the maximum is 1.

        Raises:
            ValueError: If the buffer is empty.
        """
        total = len(self.buffer)
        if total == 0:
            raise ValueError("cannot sample from an empty replay buffer")

        priorities = np.asarray(self.priorities, dtype=np.float64)
        probs = priorities / priorities.sum()
        indices = np.random.choice(total, batch_size, p=probs)
        samples = [self.buffer[i] for i in indices]

        weights = (total * probs[indices]) ** (-beta)
        weights /= weights.max()
        weights = torch.as_tensor(weights, dtype=torch.float32).unsqueeze(1)
        return samples, indices, weights

    def update_priorities(self, indices, td_errors):
        """Recompute the priorities at ``indices`` from fresh ``td_errors``.

        ``td_errors`` may hold Python floats, NumPy scalars, or single-element
        tensors; the sign is ignored.
        """
        for idx, td_err in zip(indices, td_errors):
            self.priorities[idx] = self._priority(td_err)

    def __len__(self):
        return len(self.buffer)


## Hyperparameters and Environment Setup


In [None]:
env = gym.make("LunarLander-v2")

# The global `random`/NumPy/PyTorch seeds do not reach Gym: the action space
# (used for random exploration) and the environment dynamics keep their own
# RNGs, so seed them explicitly for repeatable runs.
env.action_space.seed(SEED)
try:
    env.reset(seed=SEED)  # newer Gym API
except TypeError:
    env.seed(SEED)        # older Gym API without the `seed` keyword

input_dim = env.observation_space.shape[0]
output_dim = env.action_space.n
max_steps = env.spec.max_episode_steps

# Hyperparameters
replay_max = 10000          # replay buffer capacity (transitions)
learning_rate = 1e-3        # Adam step size
n_episodes = 500            # training episodes
epsilon_min = 0.01          # exploration floor
epsilon_decay = 0.995       # multiplicative epsilon decay per episode
batch_size = 64             # transitions per gradient step
discounted_factor = 0.99    # reward discount (gamma)
beta = 0.4                  # importance-sampling exponent for PER
alpha = 0.6                 # prioritization exponent for PER
target_update_freq = 10  # episodes


## Initialize replay buffer, networks, loss function, optimizer


In [None]:
replay_buffer = PrioritizedReplayBuffer(capacity=replay_max, alpha=alpha)

# Online network (trained) and target network (frozen copy used for bootstrapping).
Q_network = QT_Network(input_dim, output_dim)
target_network = QT_Network(input_dim, output_dim)
target_network.load_state_dict(Q_network.state_dict())
target_network.eval()

# Per-sample (unreduced) Huber loss: with the default mean reduction the loss
# is a single scalar, and multiplying it by the prioritized-replay importance
# weights would only rescale it instead of weighting each transition.
loss_fn = nn.SmoothL1Loss(reduction="none")
optimizer = optim.Adam(Q_network.parameters(), lr=learning_rate)


## Training Loop with Epsilon-Greedy Exploration

We sample batches using prioritized replay, compute TD errors, update priorities, and periodically sync the target network.


In [None]:
epsilon = 1.0
rewards_history = []

for episode in range(n_episodes):
    state = env.reset()
    if isinstance(state, tuple):
        state = state[0]  # newer Gym returns (observation, info)
    done = False
    total_reward = 0

    while not done:
        # Epsilon-greedy action selection.
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                q_vals = Q_network(state_tensor)
                action = torch.argmax(q_vals, dim=1).item()

        # Support both the 5-tuple and the older 4-tuple step API.
        result = env.step(action)
        if len(result) == 5:
            next_state, reward, terminated, truncated, _ = result
        else:
            next_state, reward, episode_over, info = result
            truncated = bool(info.get("TimeLimit.truncated", False))
            terminated = bool(episode_over) and not truncated
        done = terminated or truncated

        # Only a true termination ends the return. A time-limit truncation
        # stops the episode but the next state still has value, so the
        # stored flag (used as the bootstrap mask) must stay False for it.
        replay_buffer.add((state, action, reward, next_state, bool(terminated)))
        total_reward += reward
        state = next_state

        if len(replay_buffer) >= batch_size:
            batch, indices, weights = replay_buffer.sample(batch_size, beta=beta)
            states, actions, rewards, next_states, dones = zip(*batch)

            states = torch.FloatTensor(np.array(states))
            rewards = torch.FloatTensor(rewards).unsqueeze(1)
            actions = torch.LongTensor(actions).unsqueeze(1)
            next_states = torch.FloatTensor(np.array(next_states))
            dones = torch.FloatTensor(dones).unsqueeze(1)

            q_values = Q_network(states).gather(1, actions)
            with torch.no_grad():
                # Double DQN: the online network picks the next action,
                # the target network evaluates it.
                next_actions = Q_network(next_states).argmax(1, keepdim=True)
                next_q_values = target_network(next_states).gather(1, next_actions)
                target_q_values = rewards + (1 - dones) * discounted_factor * next_q_values

            td_errors = (target_q_values - q_values).detach()

            # Keep the loss per-sample so each importance-sampling weight
            # scales its own transition; a mean-reduced loss would collapse
            # to a scalar before the weights are applied.
            elementwise_loss = nn.functional.smooth_l1_loss(
                q_values, target_q_values, reduction="none"
            )
            loss = (weights * elementwise_loss).mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(Q_network.parameters(), 1.0)
            optimizer.step()

            replay_buffer.update_priorities(indices, torch.abs(td_errors))

    epsilon = max(epsilon_min, epsilon * epsilon_decay)

    # Periodically copy the online weights into the target network.
    if episode % target_update_freq == 0:
        target_network.load_state_dict(Q_network.state_dict())

    rewards_history.append(total_reward)

    if episode % 50 == 0:
        print(f"Episode {episode} - Total Reward: {total_reward:.2f} - Epsilon: {epsilon:.3f}")


## Plotting Training Rewards

Let's visualize how total episode rewards improve over time.


In [None]:
# Draw the per-episode return curve on the current axes.
reward_axes = plt.gca()
reward_axes.plot(rewards_history)
reward_axes.set_title("Episode Rewards Over Training")
reward_axes.set_xlabel("Episode")
reward_axes.set_ylabel("Total Reward")
reward_axes.grid(True)
plt.show()


## Evaluation Function

Test the trained agent's performance over several episodes without exploration.


In [None]:
def evaluate_agent(env, model, episodes=10, max_steps=1000, render=False):
    """Run the greedy policy of ``model`` in ``env`` and report the mean return.

    Args:
        env: Gym-style environment. Both the 4-tuple and 5-tuple ``step``
            results are accepted, as are both ``reset`` return formats.
        model: Network mapping a batch of states to per-action Q-values.
            It is switched to eval mode and left there.
        episodes: Number of evaluation episodes to run.
        max_steps: Upper bound on steps per episode, so an environment that
            never signals the end cannot hang evaluation. ``None`` disables
            the cap.
        render: When true, call ``env.render()`` before every step.

    Returns:
        Mean total reward per episode (NaN when ``episodes`` is 0).
    """
    model.eval()
    total_rewards = []

    for _ in range(episodes):
        state = env.reset()
        if isinstance(state, tuple):
            state = state[0]  # newer Gym returns (observation, info)

        episode_reward = 0
        steps_taken = 0
        done = False

        while not done and (max_steps is None or steps_taken < max_steps):
            if render:
                env.render()

            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                action = torch.argmax(model(state_tensor), dim=1).item()

            result = env.step(action)
            if len(result) == 5:
                next_state, reward, terminated, truncated, _ = result
                done = terminated or truncated
            else:
                next_state, reward, done, _ = result

            episode_reward += reward
            state = next_state
            steps_taken += 1

        total_rewards.append(episode_reward)

    # Avoid NumPy's empty-mean warning when no episode was run.
    avg_reward = np.mean(total_rewards) if total_rewards else float("nan")
    print(f"Average reward over {episodes} episodes: {avg_reward:.2f}")
    return avg_reward


## Evaluate the trained model


In [None]:
# Close the environment even if evaluation fails, so any resources it
# holds are released.
try:
    evaluate_agent(env, Q_network, episodes=10, render=True)
finally:
    env.close()


## Conclusion

- We successfully trained a Double DQN agent with Prioritized Experience Replay on LunarLander-v2.
- The agent learns efficient policies faster by sampling important experiences more often.
- The reward plot shows steady improvement and stable convergence.
- This demonstrates the power of combining PER with Double DQN and target networks.

Feel free to experiment with hyperparameters or network architectures to improve performance further!
