<a href="https://colab.research.google.com/github/OneFineStarstuff/TheOneEverAfter/blob/main/_Multi_Agent_Deep_Deterministic_Policy_Gradient_(MADDPG).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random

# Define Actor Network
class Actor(nn.Module):
    """Deterministic policy network mapping a state to a bounded action.

    The network is a three-layer MLP (``state_dim -> 400 -> 300 ->
    action_dim``) with ReLU hidden activations. The output goes through
    ``tanh`` and is multiplied by ``max_action``.

    Args:
        state_dim: Size of the last dimension of the input state tensor.
        action_dim: Number of action components produced.
        max_action: Factor applied to the ``tanh`` output.
    """

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()
        hidden_a, hidden_b = 400, 300
        self.l1 = nn.Linear(state_dim, hidden_a)
        self.l2 = nn.Linear(hidden_a, hidden_b)
        self.l3 = nn.Linear(hidden_b, action_dim)
        self.max_action = max_action

    def forward(self, x):
        """Return the scaled action for the state tensor ``x``."""
        hidden = self.l1(x).relu()
        hidden = self.l2(hidden).relu()
        squashed = self.l3(hidden).tanh()
        return squashed * self.max_action

# Define Critic Network
class Critic(nn.Module):
    """Q-network scoring a (state, action) pair with a single value.

    The state and action tensors are concatenated along dimension 1 and
    passed through a three-layer MLP (``state_dim + action_dim -> 400 ->
    300 -> 1``) with ReLU hidden activations and a linear output.

    Args:
        state_dim: Width (dimension 1) of the state input.
        action_dim: Width (dimension 1) of the action input.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_a, hidden_b = 400, 300
        self.l1 = nn.Linear(state_dim + action_dim, hidden_a)
        self.l2 = nn.Linear(hidden_a, hidden_b)
        self.l3 = nn.Linear(hidden_b, 1)

    def forward(self, x, u):
        """Return the Q-value for states ``x`` and actions ``u``."""
        joint = torch.cat((x, u), 1)
        hidden = self.l1(joint).relu()
        hidden = self.l2(hidden).relu()
        return self.l3(hidden)

# Replay Buffer
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling.

    Transitions are kept in a ``deque`` capped at ``max_size`` entries, so
    the oldest transition is dropped once the cap is reached.
    """

    def __init__(self, max_size=int(1e6)):
        self.buffer = deque(maxlen=max_size)

    def add(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` tuple."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A 5-tuple of NumPy arrays ``(states, actions, rewards,
            next_states, dones)``, each stacked along a new first axis.
        """
        batch = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one sequence per field.
        states, actions, rewards, next_states, dones = zip(*batch)
        fields = (states, actions, rewards, next_states, dones)
        return tuple(map(np.array, fields))

# MADDPG Algorithm
class MADDPG:
    """Multi-agent DDPG with one actor per agent and centralised critics.

    Each agent owns an actor that sees only its own state, and a critic
    that sees the concatenated states and actions of all agents. Every
    network has a target copy that is tracked with a soft update.

    Args:
        state_dim: Per-agent state size.
        action_dim: Per-agent action size.
        max_action: Output scale passed to every ``Actor``.
        num_agents: Number of agents.
        device: Device (``torch.device`` or string) holding the networks and
            training batches. Defaults to CUDA when available, else CPU.
            Taking it as a parameter keeps the class usable without a
            module-level ``device`` global.
    """

    def __init__(self, state_dim, action_dim, max_action, num_agents, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        joint_state_dim = state_dim * num_agents
        joint_action_dim = action_dim * num_agents

        self.actors = [Actor(state_dim, action_dim, max_action).to(self.device) for _ in range(num_agents)]
        self.actors_target = [Actor(state_dim, action_dim, max_action).to(self.device) for _ in range(num_agents)]
        self.critics = [Critic(joint_state_dim, joint_action_dim).to(self.device) for _ in range(num_agents)]
        self.critics_target = [Critic(joint_state_dim, joint_action_dim).to(self.device) for _ in range(num_agents)]

        self.actor_optimizers = [optim.Adam(actor.parameters(), lr=1e-4) for actor in self.actors]
        self.critic_optimizers = [optim.Adam(critic.parameters(), lr=1e-3) for critic in self.critics]

        # Target networks start as exact copies of the online networks.
        for actor, actor_target in zip(self.actors, self.actors_target):
            actor_target.load_state_dict(actor.state_dict())
        for critic, critic_target in zip(self.critics, self.critics_target):
            critic_target.load_state_dict(critic.state_dict())

        self.replay_buffer = ReplayBuffer()
        self.max_action = max_action
        self.num_agents = num_agents
        self.gamma = 0.99  # discount factor
        self.tau = 0.005   # soft-update interpolation weight

    def _to_tensor(self, array):
        """Convert array-like data to a float32 tensor on ``self.device``."""
        return torch.as_tensor(array, dtype=torch.float32, device=self.device)

    def _soft_update(self, net, target_net):
        """Move ``target_net`` towards ``net``: ``tau*p + (1-tau)*p_target``."""
        with torch.no_grad():
            for param, target_param in zip(net.parameters(), target_net.parameters()):
                target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

    def select_action(self, state):
        """Return each agent's deterministic action for its own state.

        Args:
            state: Sequence indexed by agent; ``state[i]`` is agent ``i``'s
                1-D state.

        Returns:
            A list with one flat NumPy action array per agent.
        """
        actions = []
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            for i, actor in enumerate(self.actors):
                state_tensor = self._to_tensor(state[i]).unsqueeze(0)
                actions.append(actor(state_tensor).cpu().numpy().flatten())
        return actions

    def train(self, batch_size=100):
        """Run one critic and actor update for every agent.

        A batch is sampled from ``self.replay_buffer``. It is indexed as
        ``state``/``next_state``: ``(batch, num_agents, state_dim)``,
        ``action``: ``(batch, num_agents, action_dim)`` and
        ``reward``/``done``: ``(batch, num_agents)``.

        Args:
            batch_size: Number of transitions to sample.

        Raises:
            ValueError: If the buffer holds fewer than ``batch_size``
                transitions (raised by the sampling call).
        """
        state, action, reward, next_state, done = self.replay_buffer.sample(batch_size)

        state = self._to_tensor(state)
        next_state = self._to_tensor(next_state)
        action = self._to_tensor(action)
        reward = self._to_tensor(reward)
        done = self._to_tensor(done)

        # Joint (all-agent) critic inputs are the same for every agent, so
        # build them once. Flattening the agent axis concatenates the
        # per-agent slices in agent order.
        current_states = state.flatten(start_dim=1)
        current_actions = action.flatten(start_dim=1)
        next_states = next_state.flatten(start_dim=1)
        action_dim = action.shape[2]

        for i in range(self.num_agents):
            # Target Critic. Target actions are recomputed per agent because
            # the target actors of earlier agents were already soft-updated
            # in this call.
            with torch.no_grad():
                target_actions = torch.cat(
                    [self.actors_target[j](next_state[:, j, :]) for j in range(self.num_agents)],
                    dim=1,
                )
                not_done = 1 - done[:, i].unsqueeze(1)
                next_Q = self.critics_target[i](next_states, target_actions)
                target_Q = reward[:, i].unsqueeze(1) + self.gamma * next_Q * not_done

            current_Q = self.critics[i](current_states, current_actions)
            critic_loss = nn.functional.mse_loss(current_Q, target_Q)

            self.critic_optimizers[i].zero_grad()
            critic_loss.backward()
            self.critic_optimizers[i].step()

            # Actor update: replace only agent i's slice of the joint action
            # with its current policy output; other agents keep the replayed
            # actions.
            actor_actions = current_actions.clone()
            actor_actions[:, i * action_dim:(i + 1) * action_dim] = self.actors[i](state[:, i, :])
            actor_loss = -self.critics[i](current_states, actor_actions).mean()

            self.actor_optimizers[i].zero_grad()
            actor_loss.backward()
            self.actor_optimizers[i].step()

            # Update target networks
            self._soft_update(self.critics[i], self.critics_target[i])
            self._soft_update(self.actors[i], self.actors_target[i])

# Example usage of the MADDPG algorithm in a simulated environment
if __name__ == "__main__":
    # Problem sizes for the synthetic demo below.
    state_dim, action_dim = 24, 4  # Example state / action dimensions
    max_action = 1                 # Maximum action value
    num_agents = 3                 # Number of agents

    # Prefer the GPU when one is available.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    maddpg = MADDPG(state_dim, action_dim, max_action, num_agents)

    # Example training loop driven by random data (replace with real
    # environment interaction).
    for episode in range(1000):
        states = [np.random.randn(state_dim) for _ in range(num_agents)]
        for t in range(100):
            actions = maddpg.select_action(states)

            # Fake environment step: random next states and rewards, and
            # no agent ever terminates.
            next_states = [np.random.randn(state_dim) for _ in range(num_agents)]
            rewards = [np.random.randn() for _ in range(num_agents)]
            dones = [False] * num_agents

            maddpg.replay_buffer.add(states, actions, rewards, next_states, dones)

            # Start learning once more than 1000 transitions are stored.
            stored = len(maddpg.replay_buffer.buffer)
            if stored > 1000:
                maddpg.train()

            states = next_states

        if not episode % 100:
            print(f'Episode {episode} completed.')

    print("MADDPG training completed.")