<a href="https://colab.research.google.com/github/OneFineStarstuff/TheOneEverAfter/blob/main/_A2C_Architecture.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque, namedtuple
import os

# Define a SumTree to store priorities for sampling
class SumTree:
    """Binary sum tree over ``capacity`` leaf priorities for proportional sampling.

    ``tree`` is a flat array of ``2 * capacity - 1`` nodes. Each internal node holds
    the sum of its two children, and the last ``capacity`` entries are the leaf
    priorities. ``data`` stores the payload attached to each leaf. New entries
    overwrite the oldest slot once the tree is full.

    Args:
        capacity: Number of leaves (stored items); must be at least 1.

    Raises:
        ValueError: if ``capacity`` is smaller than 1.
    """

    def __init__(self, capacity):
        if capacity < 1:
            raise ValueError(f"SumTree capacity must be >= 1, got {capacity}")
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.data_pointer = 0
        # Number of leaves that currently hold data (saturates at capacity).
        self.n_entries = 0

    def add(self, priority, data):
        """Store ``data`` with ``priority`` in the next slot, overwriting the oldest when full."""
        tree_index = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_index, priority)
        self.data_pointer = (self.data_pointer + 1) % self.capacity
        self.n_entries = min(self.n_entries + 1, self.capacity)

    def update(self, tree_index, priority):
        """Set the leaf at node ``tree_index`` to ``priority`` and fix the ancestor sums."""
        change = priority - self.tree[tree_index]
        self.tree[tree_index] = priority
        self._propagate(tree_index, change)

    def _propagate(self, tree_index, change):
        """Add ``change`` to every ancestor of ``tree_index``, up to and including the root."""
        # Iterative walk. When tree_index is already the root (capacity == 1) nothing
        # else needs updating; the old recursive version added the change to tree[-1],
        # which is the root itself, counting it twice.
        while tree_index > 0:
            tree_index = (tree_index - 1) // 2
            self.tree[tree_index] += change

    def get_leaf(self, value):
        """Return ``(leaf_index, priority, data)`` for the leaf whose cumulative range contains ``value``.

        ``value`` is expected to lie in ``[0, total_priority]``.
        """
        parent = 0
        n_nodes = len(self.tree)
        while True:
            left = 2 * parent + 1
            if left >= n_nodes:
                break
            right = left + 1
            left_sum = self.tree[left]
            # Go left when the value falls inside the left mass, or when the right
            # subtree is empty (float rounding can push value just past the total).
            # Never enter an empty left subtree (e.g. value == 0.0 with unfilled
            # leaves), which would return an unused placeholder slot.
            if left_sum > 0 and (value <= left_sum or self.tree[right] <= 0):
                parent = left
            else:
                value -= left_sum
                parent = right
        data_index = parent - self.capacity + 1
        return parent, self.tree[parent], self.data[data_index]

    @property
    def total_priority(self):
        """Sum of all leaf priorities (the root node)."""
        return self.tree[0]

# Define Prioritized Replay Buffer
class PrioritizedReplayBuffer:
    """Proportional prioritized experience replay backed by a ``SumTree``.

    Every experience is stored with leaf value ``p ** alpha``, where ``p`` is its raw
    priority, and is sampled with probability proportional to that value.
    ``sample`` also returns importance-sampling weights that correct for the
    non-uniform sampling.

    Args:
        capacity: Maximum number of stored experiences (the oldest are overwritten).
        alpha: Prioritization exponent applied to every raw priority.
    """

    def __init__(self, capacity, alpha):
        self.capacity = capacity
        self.alpha = alpha
        self.tree = SumTree(capacity)
        # Largest raw (pre-alpha) priority seen so far. New experiences receive it,
        # so they are likely to be replayed before their TD error is known.
        self.max_priority = 1.0

    def add(self, experience):
        """Store ``experience`` with the current maximum priority."""
        self.tree.add(self.max_priority ** self.alpha, experience)

    def sample(self, batch_size, beta):
        """Draw ``batch_size`` experiences using stratified proportional sampling.

        Args:
            batch_size: Number of experiences to draw.
            beta: Importance-sampling exponent.

        Returns:
            ``(experiences, indices, is_weight)``. ``indices`` are SumTree node
            indices to pass back to ``update_priorities``. ``is_weight`` is a float
            numpy array of weights normalised so that the largest is 1.

        Raises:
            ValueError: if the buffer holds no priority mass (nothing stored yet).
        """
        total = self.tree.total_priority
        if total <= 0:
            raise ValueError("cannot sample from an empty replay buffer")

        experiences = []
        indices = []
        priorities = []
        segment = total / batch_size

        # Stratified sampling: one uniform draw inside each of batch_size equal slices.
        for i in range(batch_size):
            value = random.uniform(segment * i, segment * (i + 1))
            index, priority, experience = self.tree.get_leaf(value)
            experiences.append(experience)
            indices.append(index)
            priorities.append(priority)

        sampling_probabilities = np.asarray(priorities, dtype=np.float64) / total
        # Weight is (N * P(i)) ** -beta. N is a constant factor that cancels in the
        # max-normalisation below, so using capacity rather than the fill level
        # does not change the result.
        is_weight = np.power(self.tree.capacity * sampling_probabilities, -beta)
        is_weight /= is_weight.max()

        return experiences, indices, is_weight

    def update_priorities(self, indices, priorities):
        """Assign new raw priorities (e.g. ``|TD error| + eps``) to the sampled tree ``indices``.

        The alpha exponent is applied here exactly as in ``add``, so every leaf value
        is ``p ** alpha``. ``max_priority`` tracks the raw, pre-alpha value.
        """
        for index, priority in zip(indices, priorities):
            self.tree.update(index, priority ** self.alpha)
            self.max_priority = max(self.max_priority, priority)

# Define the NoisyLinear layer
class NoisyLinear(nn.Module):
    """Linear layer with learnable, independent Gaussian parameter noise (NoisyNet style).

    The effective weight is ``weight_mu + weight_sigma * eps_w`` and the effective
    bias is ``bias_mu + bias_sigma * eps_b``. ``eps`` ~ N(0, 1) is resampled on every
    forward call, regardless of train/eval mode.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        sigma_init: Initial value of every entry of ``weight_sigma`` and ``bias_sigma``.
    """

    def __init__(self, in_features, out_features, sigma_init=0.017):
        super(NoisyLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.full((out_features, in_features), float(sigma_init)))
        # Zero-initialised so the buffers serialised by state_dict() are well defined.
        # forward() only uses their shape, dtype and device as a template for the noise.
        self.register_buffer('weight_epsilon', torch.zeros(out_features, in_features))

        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.full((out_features,), float(sigma_init)))
        self.register_buffer('bias_epsilon', torch.zeros(out_features))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the mean parameters uniformly in ``[-1/sqrt(in), 1/sqrt(in)]``."""
        mu_range = 1 / np.sqrt(self.in_features)
        self.weight_mu.data.uniform_(-mu_range, mu_range)
        self.bias_mu.data.uniform_(-mu_range, mu_range)

    def forward(self, x):
        """Apply the layer with freshly sampled noise."""
        # randn_like follows the buffers' dtype and device, so the layer keeps
        # working after .to(device) or .double().
        weight_epsilon = torch.randn_like(self.weight_epsilon)
        bias_epsilon = torch.randn_like(self.bias_epsilon)
        weight = self.weight_mu + self.weight_sigma * weight_epsilon
        bias = self.bias_mu + self.bias_sigma * bias_epsilon
        return torch.nn.functional.linear(x, weight, bias)

# Define the Dueling Neural Network with Noisy Linear layers
class DuelingDQN(nn.Module):
    """Dueling Q-network built entirely from NoisyLinear layers.

    A shared 128-unit trunk feeds a state-value stream V(s) and an advantage stream
    A(s, a). They are combined as ``Q = V + (A - mean_a A)``.

    Args:
        state_dim: Size of the observation vector.
        action_dim: Number of discrete actions.
    """

    def __init__(self, state_dim, action_dim):
        super(DuelingDQN, self).__init__()
        self.fc1 = NoisyLinear(state_dim, 128)

        # Value stream
        self.value_fc = NoisyLinear(128, 64)
        self.value = NoisyLinear(64, 1)

        # Advantage stream
        self.advantage_fc = NoisyLinear(128, 64)
        self.advantage = NoisyLinear(64, action_dim)

    def forward(self, x):
        """Return Q-values of shape ``(..., action_dim)`` for states ``x`` of shape ``(..., state_dim)``."""
        x = torch.relu(self.fc1(x))

        # Value stream, then advantage stream (same layer order as before, so the
        # noise draws happen in the same sequence).
        value = self.value(torch.relu(self.value_fc(x)))
        advantage = self.advantage(torch.relu(self.advantage_fc(x)))

        # Centre the advantages over the last (action) axis. dim=-1 equals dim=1 for
        # batched (B, A) input and also supports a single unbatched state (A,).
        return value + (advantage - advantage.mean(dim=-1, keepdim=True))

# Define the Policy Network for Policy Gradient
class PolicyNetwork(nn.Module):
    """REINFORCE policy: a one-hidden-layer MLP that outputs action probabilities.

    Args:
        state_dim: Size of the observation vector.
        action_dim: Number of discrete actions.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, action_dim)

    def forward(self, x):
        """Map states ``x`` to a softmax distribution over actions (last axis)."""
        hidden = self.fc1(x).relu()
        logits = self.fc2(hidden)
        return logits.softmax(dim=-1)  # Action probabilities

# Define the Actor-Critic Network
class ActorCriticNetwork(nn.Module):
    """Shared-trunk actor-critic with one hidden layer feeding two heads.

    ``forward(x)`` returns ``(policy, value)``. ``policy`` holds softmax action
    probabilities over the last axis, and ``value`` is the critic's state-value
    estimate with a trailing dimension of size 1.

    Args:
        state_dim: Size of the observation vector.
        action_dim: Number of discrete actions.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_units = 128
        self.fc1 = nn.Linear(state_dim, hidden_units)
        self.actor = nn.Linear(hidden_units, action_dim)  # policy head
        self.critic = nn.Linear(hidden_units, 1)          # value head

    def forward(self, x):
        features = self.fc1(x).relu()
        action_probs = self.actor(features).softmax(dim=-1)
        state_value = self.critic(features)
        return action_probs, state_value

# Hyperparameters (module-level globals read by the training helpers and loop below)
GAMMA = 0.99            # Discount factor for future rewards (used in the Double-DQN target)
LR = 1e-3               # Learning rate shared by the policy, value and actor-critic Adam optimizers
BATCH_SIZE = 64         # Number of transitions sampled from replay per DQN update
EPSILON_START = 1.0     # Initial epsilon for epsilon-greedy exploration
EPSILON_END = 0.01      # Floor for epsilon
EPSILON_DECAY = 0.995   # Multiplicative epsilon decay applied once per episode
TARGET_UPDATE = 10      # Update target network every 10 episodes (the loop also checkpoints on this interval)
ALPHA = 0.6             # Prioritization exponent passed to the replay buffer
BETA_START = 0.4        # Initial beta value for importance sampling
BETA_INCREMENT = 1e-3   # Beta increment per episode (beta is capped at 1.0 in the loop)
CHECKPOINT_DIR = './checkpoints' # Directory to save checkpoints (relative to the current working directory)

# Create checkpoint directory if it does not exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Environment setup
# NOTE(review): `new_step_api=True` selects the 5-tuple step() API
# (obs, reward, terminated, truncated, info) that the training loop unpacks. The
# keyword exists only in some gym releases (0.25.x) — confirm the pinned gym version.
env = gym.make("CartPole-v1", new_step_api=True)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# Initialize networks
# policy_net (REINFORCE), value_net (online dueling Q-network) and actor_critic_net
# (A2C) are each trained in the loop. Only value_net drives action selection.
# target_net is never optimised directly; it is periodically synced from value_net.
policy_net = PolicyNetwork(state_dim, action_dim)
value_net = DuelingDQN(state_dim, action_dim)
target_net = DuelingDQN(state_dim, action_dim)
actor_critic_net = ActorCriticNetwork(state_dim, action_dim)
# Start the target network as an exact copy of the online network.
target_net.load_state_dict(value_net.state_dict())
# NOTE: eval() does not remove exploration noise here, because NoisyLinear.forward
# samples fresh noise regardless of training mode.
target_net.eval()

# Initialize optimizers (one Adam optimizer per trained network)
optimizer_policy = optim.Adam(policy_net.parameters(), lr=LR)
optimizer_value = optim.Adam(value_net.parameters(), lr=LR)
optimizer_actor_critic = optim.Adam(actor_critic_net.parameters(), lr=LR)

# Initialize prioritized replay buffer
# Holds up to 10000 transitions; once full, the oldest slots are overwritten.
memory = PrioritizedReplayBuffer(capacity=10000, alpha=ALPHA)

# Function to compute returns
def compute_returns(rewards, gamma=0.99):
    """Return the discounted return G_t for every step of an episode.

    G_t = r_t + gamma * G_{t+1}, accumulated backwards from the final step.

    Args:
        rewards: Sequence of per-step rewards in time order.
        gamma: Discount factor.

    Returns:
        1-D float32 tensor with one return per reward (empty when ``rewards`` is empty).
        The dtype is always floating point, so callers can take ``.mean()`` / ``.std()``
        even when the rewards are integers.
    """
    returns = []
    running = 0.0
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    # Appending and then reversing once is O(n); list.insert(0, ...) per step is O(n^2).
    returns.reverse()
    return torch.tensor(returns, dtype=torch.float32)

# Function to train policy gradient
def train_policy_gradient(episode_states, episode_actions, episode_rewards):
    """Run one REINFORCE update of ``policy_net`` on a completed episode.

    The loss is ``-sum_t log pi(a_t | s_t) * G_t``, where ``G_t`` are the discounted
    returns after normalisation.

    Args:
        episode_states: Per-step observations, in time order.
        episode_actions: Per-step discrete action indices.
        episode_rewards: Per-step rewards.
    """
    if len(episode_rewards) == 0:
        return  # nothing to learn from (torch.stack / normalisation would fail)

    returns = compute_returns(episode_rewards, gamma=GAMMA)

    # Normalize returns for stable training. The unbiased std of a single element
    # is NaN (which anomaly detection turns into a crash), so a one-step episode is
    # only centred, giving a zero-strength update instead of NaN gradients.
    returns = returns - returns.mean()
    if returns.numel() > 1:
        returns = returns / (returns.std() + 1e-8)

    states = torch.as_tensor(np.array(episode_states), dtype=torch.float32)
    actions = torch.as_tensor(np.array(episode_actions), dtype=torch.long)

    # One batched forward pass instead of one call per step; the summed loss is the same.
    action_probs = policy_net(states)
    log_probs = torch.log(action_probs.gather(1, actions.unsqueeze(1)).squeeze(1))
    policy_loss = -(log_probs * returns).sum()

    # Backpropagation
    optimizer_policy.zero_grad()
    policy_loss.backward()
    optimizer_policy.step()

# Function to train actor-critic network
def train_actor_critic(states, actions, rewards, next_states, dones, gamma=0.99):
    """Run one advantage actor-critic update of ``actor_critic_net`` on a trajectory.

    The critic target is the one-step TD target ``r + gamma * V(s') * (1 - done)``.
    The TD error ``target - V(s)`` serves as the advantage for the policy term
    (detached) and as the regression error for the value term.

    Args:
        states: Per-step observations, in time order.
        actions: Per-step discrete action indices.
        rewards: Per-step rewards.
        next_states: Per-step successor observations.
        dones: Per-step flags; a true value zeroes the bootstrap term for that step.
        gamma: Discount factor.
    """
    if len(rewards) == 0:
        return  # empty trajectory: nothing to train on

    # Stack the per-step arrays once (building a tensor from a list of ndarrays is slow).
    states = torch.as_tensor(np.array(states), dtype=torch.float32)
    actions = torch.as_tensor(np.array(actions), dtype=torch.long)
    rewards = torch.as_tensor(np.array(rewards), dtype=torch.float32)
    next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32)
    dones = torch.as_tensor(np.array(dones), dtype=torch.float32)

    # Forward pass
    policy, value = actor_critic_net(states)
    # The critic head outputs (T, 1). Flatten it to (T,) so it lines up element-wise
    # with rewards/dones instead of broadcasting into a (T, T) matrix.
    value = value.squeeze(1)

    # Semi-gradient TD target: the bootstrap value is treated as a constant.
    with torch.no_grad():
        _, next_value = actor_critic_net(next_states)
        target_value = rewards + gamma * next_value.squeeze(1) * (1 - dones)

    advantage = target_value - value

    # Compute policy loss (squeeze(1) keeps shape (T,) even when T == 1)
    log_probs = torch.log(policy.gather(1, actions.unsqueeze(1)).squeeze(1))
    policy_loss = -(log_probs * advantage.detach()).mean()

    # Compute value loss
    value_loss = advantage.pow(2).mean()

    # Backpropagation
    loss = policy_loss + value_loss
    optimizer_actor_critic.zero_grad()
    loss.backward()
    optimizer_actor_critic.step()

# Function to select an action
def select_action(state, epsilon):
    """Choose an action epsilon-greedily with respect to ``value_net``.

    With probability ``epsilon`` a random action is drawn from ``env.action_space``.
    Otherwise the index of the largest Q-value for ``state`` is returned as a Python int.
    """
    explore = random.random() < epsilon
    if not explore:
        with torch.no_grad():
            batched_state = torch.FloatTensor(state).unsqueeze(0)
            greedy_action = value_net(batched_state).argmax().item()
        return greedy_action
    return env.action_space.sample()

# Function to store experiences in memory
def store_experience(state, action, reward, next_state, done):
    """Append one ``(state, action, reward, next_state, done)`` transition to ``memory``."""
    transition = (state, action, reward, next_state, done)
    memory.add(transition)

# Function to sample and train the model with Double DQN
def optimize_model_double_dqn(beta):
    """Run one prioritized Double-DQN gradient step on ``value_net``.

    Actions for the next states are chosen by ``value_net`` and evaluated by
    ``target_net`` (Double DQN). The squared TD errors are weighted by the
    importance-sampling weights from ``memory``, and the sampled transitions'
    priorities are then refreshed with ``|TD error| + 1e-5``. The call is a no-op
    until the buffer holds at least ``BATCH_SIZE`` experiences.

    Args:
        beta: Importance-sampling exponent passed to ``memory.sample``.
    """
    # memory.tree.data is pre-allocated with `capacity` slots, so len() of it is
    # never below BATCH_SIZE. Count filled slots through their leaf priorities
    # instead: stored experiences always get a strictly positive priority, while
    # unused slots keep their initial 0.
    leaf_priorities = memory.tree.tree[memory.tree.capacity - 1:]
    if np.count_nonzero(leaf_priorities) < BATCH_SIZE:
        return

    # Sample a batch of experiences from memory
    experiences, indices, is_weight = memory.sample(BATCH_SIZE, beta)
    states, actions, rewards, next_states, dones = zip(*experiences)

    states = torch.as_tensor(np.array(states), dtype=torch.float32)
    actions = torch.as_tensor(np.array(actions), dtype=torch.long).unsqueeze(1)
    rewards = torch.as_tensor(np.array(rewards), dtype=torch.float32)
    next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32)
    dones = torch.as_tensor(np.array(dones), dtype=torch.float32)
    is_weight = torch.as_tensor(is_weight, dtype=torch.float32)

    # Get Q values for current states
    current_q_values = value_net(states).gather(1, actions).squeeze(1)

    # Double DQN target: the online network selects the next action and the target
    # network evaluates it. No gradient is needed through the target.
    with torch.no_grad():
        next_actions = value_net(next_states).argmax(1, keepdim=True)
        next_q_values = target_net(next_states).gather(1, next_actions).squeeze(1)
        target_q_values = rewards + GAMMA * next_q_values * (1 - dones)

    td_errors = current_q_values - target_q_values

    # Compute the importance-weighted loss and optimize
    loss = (td_errors.pow(2) * is_weight).mean()
    optimizer_value.zero_grad()
    loss.backward()
    optimizer_value.step()

    # Update priorities in memory (TD errors from before the optimizer step)
    new_priorities = td_errors.detach().abs().cpu().numpy()
    memory.update_priorities(indices, new_priorities + 1e-5)

# Load checkpoint if available
def load_checkpoint(path, value_net, optimizer):
    """Restore ``value_net`` and ``optimizer`` from ``path`` if that file exists.

    The checkpoint is expected to be a dict with ``'episode'``, ``'model_state_dict'``
    and ``'optimizer_state_dict'`` keys. Tensors are loaded onto the CPU first;
    ``load_state_dict`` then copies them onto whatever device the model and
    optimizer parameters live on.

    Args:
        path: Checkpoint file path.
        value_net: Module whose state is restored in place.
        optimizer: Optimizer whose state is restored in place.

    Returns:
        The index of the first episode still to run: the saved (already completed)
        episode + 1, or 0 when no checkpoint exists.
    """
    if not os.path.isfile(path):
        return 0
    checkpoint = torch.load(path, map_location="cpu", weights_only=True)
    value_net.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The stored episode finished before it was saved; resume with the next one
    # instead of repeating it.
    return checkpoint['episode'] + 1

# Save model checkpoints
def save_checkpoint(episode, value_net, optimizer, path):
    """Save ``episode`` plus the model and optimizer state dicts to ``path``.

    The data is written to ``path + '.tmp'`` first and then moved into place, so an
    interrupted save never leaves a truncated checkpoint at ``path``.

    Args:
        episode: Index of the episode that has just completed.
        value_net: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        path: Destination file path.
    """
    tmp_path = f"{path}.tmp"
    torch.save({
        'episode': episode,
        'model_state_dict': value_net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, tmp_path)
    # os.replace overwrites atomically when both paths are on the same filesystem.
    os.replace(tmp_path, path)

# Enable anomaly detection
# NOTE: anomaly detection slows every backward pass noticeably. It is a debugging
# aid and can be switched off once training runs cleanly.
torch.autograd.set_detect_anomaly(True)

# Training loop
num_episodes = 500
checkpoint_path = os.path.join(CHECKPOINT_DIR, 'dqn_checkpoint.pth')

start_episode = load_checkpoint(checkpoint_path, value_net, optimizer_value)

# The checkpoint only stores value_net and its optimizer. Re-sync the target network
# with the (possibly restored) online weights, and fast-forward the per-episode
# epsilon/beta schedules to where a run starting at start_episode would be.
# For a fresh run (start_episode == 0) these reduce to EPSILON_START / BETA_START.
target_net.load_state_dict(value_net.state_dict())
epsilon = max(EPSILON_END, EPSILON_START * EPSILON_DECAY ** start_episode)
beta = min(1.0, BETA_START + BETA_INCREMENT * start_episode)

try:
    for episode in range(start_episode, num_episodes):
        state = env.reset()
        episode_states = []
        episode_actions = []
        episode_rewards = []
        episode_next_states = []
        episode_dones = []
        total_reward = 0

        for _ in range(200):
            action = select_action(state, epsilon)
            next_state, reward, terminated, truncated, _info = env.step(action)
            total_reward += reward

            # Only a real terminal state should stop bootstrapping. A time-limit
            # truncation still has a valid successor value, so it is not stored as done.
            episode_states.append(state)
            episode_actions.append(action)
            episode_rewards.append(reward)
            episode_next_states.append(next_state)
            episode_dones.append(terminated)

            store_experience(state, action, reward, next_state, terminated)
            state = next_state

            optimize_model_double_dqn(beta)

            if terminated or truncated:
                break

        train_policy_gradient(episode_states, episode_actions, episode_rewards)

        train_actor_critic(episode_states, episode_actions, episode_rewards,
                           episode_next_states, episode_dones)

        # Decay epsilon and anneal beta once per episode
        epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)
        beta = min(1.0, beta + BETA_INCREMENT)

        # Every TARGET_UPDATE episodes: sync the target network and save a checkpoint
        if episode % TARGET_UPDATE == 0:
            target_net.load_state_dict(value_net.state_dict())
            save_checkpoint(episode, value_net, optimizer_value, checkpoint_path)

        print(f"Episode {episode}, Total Reward: {total_reward}")
finally:
    # Release the environment even if training is interrupted or raises.
    env.close()