<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Deep_Q_Networks_(DQN)_with_Prioritized_Experience_Replay.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import namedtuple, deque
import random

# Define the Q-network
class QNetwork(nn.Module):
    """Fully connected network mapping a state vector to one Q-value per action.

    The network is a three-layer MLP with two hidden layers of 128 units,
    each followed by a ReLU.

    Args:
        state_dim: Number of input features (size of the state vector).
        action_dim: Number of outputs (one Q-value per discrete action).
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_units = 128
        # Layers are created in input-to-output order and stored under `fc`
        # so parameter names stay `fc.0.*`, `fc.2.*`, `fc.4.*`.
        layers = [
            nn.Linear(state_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, action_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the Q-values produced by the MLP for input `x`."""
        q_values = self.fc(x)
        return q_values

# Prioritized Experience Replay buffer
class PrioritizedReplayBuffer:
    """Fixed-capacity ring buffer that samples experiences in proportion to priority.

    Priorities are kept in `self.priorities` already raised to `alpha`, so
    sampling probabilities are simply the stored values normalised to sum to 1.

    Args:
        capacity: Maximum number of experiences held; the oldest entry is
            overwritten once the buffer is full.
        alpha: Exponent applied to raw priorities (0 gives uniform sampling).
        epsilon: Small constant added to every raw priority before
            exponentiation so that no stored experience ends up with zero
            sampling probability.
    """

    def __init__(self, capacity, alpha, epsilon=1e-5):
        self.capacity = capacity
        self.alpha = alpha
        self.epsilon = epsilon
        self.buffer = []
        self.priorities = np.zeros((capacity,), dtype=np.float32)
        self.pos = 0  # next slot to write; wraps around at `capacity`

    def add(self, experience, priority):
        """Store `experience`, giving it at least the largest stored priority.

        Args:
            experience: Arbitrary object to store (returned as-is by `sample`).
            priority: Raw (un-exponentiated) priority, e.g. an absolute TD error.
        """
        # Stored priorities are already in the `** alpha` domain, so the new
        # priority is exponentiated first and the comparison is done there.
        # Comparing raw against stored values and exponentiating afterwards
        # would apply `alpha` twice to the running maximum.
        stored_max = float(self.priorities.max()) if self.buffer else 1.0
        scaled = (abs(float(priority)) + self.epsilon) ** self.alpha

        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            self.buffer[self.pos] = experience
        self.priorities[self.pos] = max(scaled, stored_max)
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta):
        """Draw `batch_size` experiences (with replacement) by priority.

        Args:
            batch_size: Number of experiences to draw.
            beta: Importance-sampling exponent; weights are
                `(N * p_i) ** -beta`, normalised by their maximum.

        Returns:
            A tuple `(experiences, indices, weights)` where `experiences` is a
            list, `indices` is an integer array of buffer positions and
            `weights` is a float32 array whose largest entry is 1.

        Raises:
            ValueError: If the buffer is empty.
        """
        size = len(self.buffer)
        if size == 0:
            raise ValueError("cannot sample from an empty replay buffer")

        # Only the first `size` slots hold real entries. float64 keeps the
        # normalised probabilities summing to 1 within NumPy's tolerance.
        priorities = self.priorities[:size].astype(np.float64)
        priority_sum = priorities.sum()
        if priority_sum > 0:
            probs = priorities / priority_sum
        else:
            # Degenerate case (all priorities zero): fall back to uniform.
            probs = np.full(size, 1.0 / size)

        indices = np.random.choice(size, batch_size, p=probs)
        experiences = [self.buffer[idx] for idx in indices]
        weights = (size * probs[indices]) ** (-beta)
        weights /= weights.max()
        return experiences, indices, weights.astype(np.float32)

    def update_priorities(self, indices, priorities):
        """Overwrite the priorities at `indices` with new raw priorities.

        Args:
            indices: Buffer positions, as returned by `sample`.
            priorities: Raw priorities, one per index. Each may be a Python
                number or any single-element object exposing `.item()`
                (NumPy scalar/array, torch tensor).
        """
        for idx, priority in zip(indices, priorities):
            raw = priority.item() if hasattr(priority, "item") else priority
            self.priorities[idx] = (abs(float(raw)) + self.epsilon) ** self.alpha

# Define hyperparameters and environment
# CartPole environment. The training loop below unpacks five values from
# `env.step`, which is what `new_step_api=True` requests.
# NOTE(review): this keyword appears to be specific to gym 0.25.x — confirm the
# pinned gym version before upgrading.
env = gym.make("CartPole-v1", new_step_api=True)
state_dim = env.observation_space.shape[0]  # length of the observation vector
action_dim = env.action_space.n  # number of discrete actions
lr = 1e-3  # learning rate passed to Adam
gamma = 0.99  # discount factor applied to bootstrapped next-state values
epsilon_start = 1.0  # exploration probability at step 0
epsilon_end = 0.1  # value the exploration probability decays towards
epsilon_decay = 500  # time constant (in steps) of the exponential epsilon decay
batch_size = 64  # experiences sampled per gradient step
beta_start = 0.4  # importance-sampling exponent at step 0
beta_increment = 1e-4  # linear increase of beta per step (capped at 1.0)
memory_capacity = 10000  # maximum number of experiences kept in the replay buffer
alpha = 0.6  # exponent applied to priorities inside the replay buffer

# Initialize networks, buffer, and optimizer.
# The target network starts as an exact copy of the online network; only the
# online network's parameters are handed to the optimizer.
q_net = QNetwork(state_dim, action_dim)
target_q_net = QNetwork(state_dim, action_dim)
target_q_net.load_state_dict(q_net.state_dict())
optimizer = optim.Adam(q_net.parameters(), lr=lr)
replay_buffer = PrioritizedReplayBuffer(memory_capacity, alpha)
# NOTE: with its default reduction, MSELoss returns a single scalar (the mean
# over all elements), not one loss value per sample.
criterion = nn.MSELoss()

# Function to update epsilon and beta
def update_hyperparameters(step):
    """Return the `(epsilon, beta)` schedule values for the given step.

    Epsilon decays exponentially from `epsilon_start` towards `epsilon_end`
    with time constant `epsilon_decay`; beta grows linearly from `beta_start`
    by `beta_increment` per step and is capped at 1.0.
    """
    decay_factor = np.exp(-1. * step / epsilon_decay)
    exploration_range = epsilon_start - epsilon_end
    epsilon = epsilon_end + exploration_range * decay_factor

    annealed_beta = beta_start + step * beta_increment
    beta = annealed_beta if annealed_beta < 1.0 else 1.0
    return epsilon, beta

# Training loop
# Environment steps are counted across ALL episodes. The epsilon and beta
# schedules are functions of this counter; restarting it every episode would
# put epsilon back to `epsilon_start` each time and the agent would never
# leave the exploration phase.
total_steps = 0

for episode in range(500):
    state = env.reset()
    total_reward = 0
    done = False       # termination flag returned by env.step
    truncated = False  # truncation flag returned by env.step

    # Stop on either flag: stepping an environment after it has reported
    # truncation is not valid. Only `done` is stored in the transition, so
    # bootstrapping is still applied to transitions cut off by truncation.
    while not (done or truncated):
        epsilon, beta = update_hyperparameters(total_steps)

        # Epsilon-greedy action selection.
        if random.random() > epsilon:
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            action = q_net(state_tensor).argmax().item()
        else:
            action = env.action_space.sample()

        next_state, reward, done, truncated, _ = env.step(action)
        total_reward += reward

        # One-step TD error of the new transition, used as its initial priority.
        state_tensor = torch.tensor(state, dtype=torch.float32)
        next_state_tensor = torch.tensor(next_state, dtype=torch.float32)
        action_tensor = torch.tensor(action, dtype=torch.int64).unsqueeze(0)
        reward_tensor = torch.tensor(reward, dtype=torch.float32).unsqueeze(0)
        done_tensor = torch.tensor(done, dtype=torch.float32).unsqueeze(0)

        with torch.no_grad():
            target_q = reward_tensor + gamma * target_q_net(next_state_tensor).max() * (1 - done_tensor)
            current_q = q_net(state_tensor)[action_tensor]

        td_error = torch.abs(target_q - current_q).item()
        replay_buffer.add((state, action, reward, next_state, done), td_error)

        # Learn once enough experiences are available for a full batch.
        if len(replay_buffer.buffer) >= batch_size:
            experiences, indices, weights = replay_buffer.sample(batch_size, beta)
            batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones = zip(*experiences)

            batch_states = torch.tensor(np.array(batch_states), dtype=torch.float32)
            batch_actions = torch.tensor(np.array(batch_actions), dtype=torch.int64).unsqueeze(1)
            batch_rewards = torch.tensor(np.array(batch_rewards), dtype=torch.float32).unsqueeze(1)
            batch_next_states = torch.tensor(np.array(batch_next_states), dtype=torch.float32)
            batch_dones = torch.tensor(np.array(batch_dones), dtype=torch.float32).unsqueeze(1)

            with torch.no_grad():
                next_q_values = target_q_net(batch_next_states).max(1, keepdim=True)[0]
                expected_q_values = batch_rewards + gamma * next_q_values * (1 - batch_dones)

            current_q_values = q_net(batch_states).gather(1, batch_actions)

            # The importance-sampling weights must scale each sample's own
            # squared error, so the loss is computed element-wise
            # (reduction="none") and the weights are reshaped to (batch, 1) to
            # line up with it. Multiplying the weights by an already-averaged
            # scalar loss would not correct the sampling bias at all.
            is_weights = torch.tensor(weights, dtype=torch.float32).unsqueeze(1)
            per_sample_loss = nn.functional.mse_loss(
                current_q_values, expected_q_values, reduction="none"
            )
            loss = (is_weights * per_sample_loss).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Refresh the priorities of the sampled transitions with their
            # absolute TD errors (computed before this optimisation step).
            priorities = torch.abs(expected_q_values - current_q_values).detach().numpy()
            replay_buffer.update_priorities(indices, priorities)

        state = next_state
        total_steps += 1

    # Hard-sync the target network with the online network every 10 episodes.
    if episode % 10 == 0:
        target_q_net.load_state_dict(q_net.state_dict())

    print(f"Episode {episode}, Total Reward: {total_reward}")

print("Training complete.")