<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Dueling_DQN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import os

# Define the Dueling Neural Network for the Q-Learning model
class DuelingDQN(nn.Module):
    """Dueling Q-network with separate state-value and advantage streams.

    A shared hidden layer feeds two heads: a scalar state value V(s) and one
    advantage A(s, a) per action. They are combined as
    ``Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)``.

    Args:
        state_dim: Size of the observation vector.
        action_dim: Number of discrete actions.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        # Layer names and creation order are kept stable so existing
        # checkpoints (state_dict keys) and seeded initialisation still match.
        self.fc1 = nn.Linear(state_dim, 128)

        # Value stream
        self.value_fc = nn.Linear(128, 64)
        self.value = nn.Linear(64, 1)

        # Advantage stream
        self.advantage_fc = nn.Linear(128, 64)
        self.advantage = nn.Linear(64, action_dim)

    def forward(self, x):
        """Return Q-values for ``x``.

        Args:
            x: Float tensor whose last dimension is ``state_dim``; any leading
                dimensions (e.g. a batch dimension) are preserved.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``action_dim``.
        """
        features = torch.relu(self.fc1(x))

        value = self.value(torch.relu(self.value_fc(features)))
        advantage = self.advantage(torch.relu(self.advantage_fc(features)))

        # Centre the advantages per sample (over the action dimension only).
        # A mean over the whole tensor would also average across the batch,
        # making each sample's Q-values depend on the other samples.
        centred_advantage = advantage - advantage.mean(dim=-1, keepdim=True)
        return value + centred_advantage

# Hyperparameters
# These module-level constants are read directly by the functions and the
# training loop below.
GAMMA = 0.99          # Discount factor for future rewards
LR = 1e-3             # Learning rate
BATCH_SIZE = 64       # Batch size for experience replay
EPSILON_START = 1.0   # Initial epsilon for exploration
EPSILON_END = 0.01    # Minimum epsilon
EPSILON_DECAY = 0.995 # Decay rate for epsilon
TARGET_UPDATE = 10    # Update target network every 10 episodes
CHECKPOINT_DIR = './checkpoints' # Directory to save checkpoints

# Create checkpoint directory if it does not exist
# (exist_ok=True makes this a no-op when the directory is already there.)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Environment and network setup
env = gym.make("CartPole-v1")
# Network input/output sizes are taken from the environment's spaces.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# Two networks with identical architecture: policy_net is the one being
# optimised; target_net starts as an exact copy of its weights.
policy_net = DuelingDQN(state_dim, action_dim)
target_net = DuelingDQN(state_dim, action_dim)
target_net.load_state_dict(policy_net.state_dict())
# The target network is put in eval mode; the optimizer below is built over
# policy_net's parameters only, so target_net is never stepped directly.
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), lr=LR)
# Replay buffer: a bounded deque that discards the oldest entry once it holds
# 10000 items.
memory = deque(maxlen=10000)

# Function to select an action
def select_action(state, epsilon):
    """Choose an action epsilon-greedily.

    Args:
        state: Current observation, convertible by ``torch.FloatTensor``.
        epsilon: Probability of taking a random action instead of the
            greedy one.

    Returns:
        A sample from ``env.action_space`` when exploring, otherwise the
        index (Python int) of the largest Q-value predicted by
        ``policy_net`` for ``state``.
    """
    exploring = random.random() < epsilon
    if not exploring:
        # Greedy branch: inference only, so no autograd bookkeeping.
        with torch.no_grad():
            batch_of_one = torch.FloatTensor(state).unsqueeze(0)
            return policy_net(batch_of_one).argmax().item()
    return env.action_space.sample()

# Function to store experiences in memory
def store_experience(state, action, reward, next_state, done):
    """Append one transition tuple to the module-level replay ``memory``.

    The tuple layout is ``(state, action, reward, next_state, done)``.
    """
    transition = (state, action, reward, next_state, done)
    memory.append(transition)

# Function to sample and train the model with Double DQN
def optimize_model_double_dqn():
    """Run one Double-DQN gradient step on a random replay minibatch.

    Returns immediately (doing nothing) while ``memory`` holds fewer than
    ``BATCH_SIZE`` transitions. Otherwise it samples ``BATCH_SIZE``
    transitions, computes the target
    ``r + GAMMA * Q_target(s', argmax_a Q_policy(s', a)) * (1 - done)``,
    and takes one optimizer step on the MSE between that target and
    ``Q_policy(s, a)``.
    """
    if len(memory) < BATCH_SIZE:
        return

    # Sample a batch of experiences
    batch = random.sample(memory, BATCH_SIZE)
    states, actions, rewards, next_states, dones = zip(*batch)

    # Stack through NumPy first: building a tensor straight from a tuple of
    # ndarrays takes PyTorch's slow element-by-element path (and warns).
    states = torch.as_tensor(np.asarray(states, dtype=np.float32))
    actions = torch.as_tensor(np.asarray(actions, dtype=np.int64)).unsqueeze(1)
    rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32))
    next_states = torch.as_tensor(np.asarray(next_states, dtype=np.float32))
    dones = torch.as_tensor(np.asarray(dones, dtype=np.float32))

    # Q(s, a) for the actions actually taken. squeeze(1) removes only the
    # gathered column, leaving the batch dimension intact.
    current_q_values = policy_net(states).gather(1, actions).squeeze(1)

    # Double DQN: the policy network picks the next action, the target
    # network evaluates it. The target is a constant for the loss, so no
    # autograd graph is needed for either forward pass.
    with torch.no_grad():
        next_actions = policy_net(next_states).argmax(1, keepdim=True)
        next_q_values = target_net(next_states).gather(1, next_actions).squeeze(1)
        # (1 - dones) zeroes the bootstrap term for terminal transitions.
        target_q_values = rewards + (GAMMA * next_q_values * (1 - dones))

    # Compute loss and optimize
    loss = nn.MSELoss()(current_q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Function to save model checkpoints
def save_checkpoint(episode, policy_net, optimizer, path):
    """Write a training snapshot to ``path`` with ``torch.save``.

    The saved dict has the keys ``'episode'``, ``'model_state_dict'`` and
    ``'optimizer_state_dict'``.

    Args:
        episode: Episode index to record in the snapshot.
        policy_net: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        path: Destination file path.
    """
    snapshot = {'episode': episode}
    snapshot['model_state_dict'] = policy_net.state_dict()
    snapshot['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(snapshot, path)

# Function to load model checkpoints
def load_checkpoint(path, policy_net, optimizer):
    """Restore network and optimizer state from ``path`` if the file exists.

    Args:
        path: Checkpoint file to read.
        policy_net: Module that receives the saved ``'model_state_dict'``.
        optimizer: Optimizer that receives the saved
            ``'optimizer_state_dict'``.

    Returns:
        The ``'episode'`` value stored in the checkpoint, or ``0`` when
        ``path`` is not an existing file (nothing is modified in that case).
    """
    if not os.path.isfile(path):
        return 0

    saved = torch.load(path)
    policy_net.load_state_dict(saved['model_state_dict'])
    optimizer.load_state_dict(saved['optimizer_state_dict'])
    return saved['episode']

# Training loop
num_episodes = 500


def _reset_env(environment):
    """Reset ``environment`` and return only the initial observation."""
    result = environment.reset()
    # gym >= 0.26 returns (observation, info); older releases return the
    # bare observation.
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_env(environment, action):
    """Step ``environment``; return (next_state, reward, terminated, done)."""
    result = environment.step(action)
    if len(result) == 5:
        # gym >= 0.26: termination and time-limit truncation are separate.
        next_state, reward, terminated, truncated, _ = result
        return next_state, reward, terminated, terminated or truncated
    # Older gym: a single done flag covers both.
    next_state, reward, done, _ = result
    return next_state, reward, done, done


# Load checkpoint if available
checkpoint_path = os.path.join(CHECKPOINT_DIR, 'dqn_checkpoint.pth')
start_episode = load_checkpoint(checkpoint_path, policy_net, optimizer)

# After a resume the policy weights have changed, so bring the target network
# in line with them (a no-op copy when no checkpoint was found).
target_net.load_state_dict(policy_net.state_dict())

# Epsilon is not stored in the checkpoint; recompute the value it had at the
# start of `start_episode` so a resumed run does not restart at full
# exploration. For start_episode == 0 this is exactly EPSILON_START.
epsilon = max(EPSILON_END, EPSILON_START * EPSILON_DECAY ** start_episode)

try:
    for episode in range(start_episode, num_episodes):
        state = _reset_env(env)
        total_reward = 0

        # Each episode is capped at 200 steps.
        for t in range(200):
            action = select_action(state, epsilon)
            next_state, reward, terminated, done = _step_env(env, action)
            total_reward += reward

            # Store the termination flag (not truncation) where the API
            # distinguishes them, so the replay target keeps bootstrapping
            # across time-limit cut-offs.
            store_experience(state, action, reward, next_state, terminated)
            state = next_state

            optimize_model_double_dqn()

            if done:
                break

        # Decay epsilon for exploration-exploitation trade-off
        epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)

        # Every TARGET_UPDATE episodes: sync the target network, then save a
        # checkpoint.
        if episode % TARGET_UPDATE == 0:
            target_net.load_state_dict(policy_net.state_dict())
            save_checkpoint(episode, policy_net, optimizer, checkpoint_path)

        print(f"Episode {episode}, Total Reward: {total_reward}")
finally:
    # Release the environment even if training is interrupted.
    env.close()