In [19]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch_geometric.nn as pyg_nn
import torch_geometric.data as pyg_data
from pettingzoo.mpe import simple_tag_v2
import os

# Initialize the environment
env = simple_tag_v2.parallel_env(render_mode=None, num_adversaries=4, num_good=1, num_obstacles=2)
env.reset()

# Parameters
num_class_a = 3
num_class_b = 1
num_adversaries = 4
num_good_agents = 1
num_obstacles = 2
communication_range = 10

# Agent lists
adversary_agents = [agent for agent in env.agents if 'adversary' in agent]
good_agents = [agent for agent in env.agents if 'agent' in agent]
landmarks = ['landmark_{}'.format(i) for i in range(num_obstacles)]

print("Adversary Agents:", adversary_agents)
print("Good Agents:", good_agents)


Adversary Agents: ['adversary_0', 'adversary_1', 'adversary_2', 'adversary_3']
Good Agents: ['agent_0']


In [20]:
# Define Temporal GCN model
class TemporalGCN(nn.Module):
    def __init__(self, node_input_dim, hidden_dim, output_dim):
        super(TemporalGCN, self).__init__()
        # Spatial GCN layers
        self.conv1 = pyg_nn.GCNConv(node_input_dim, hidden_dim)
        self.relu = nn.ReLU()
        # Temporal layer (e.g., GRU)
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        # Output layer
        self.fc = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x_list, edge_index_list):
        """
        x_list: List of node feature tensors for each timestep [num_nodes_t, node_input_dim]
        edge_index_list: List of edge_index tensors for each timestep
        """
        h_list = []
        for x, edge_index in zip(x_list, edge_index_list):
            # Spatial GCN
            h = self.conv1(x, edge_index)
            h = self.relu(h)
            h_list.append(h)
        
        # Stack the node embeddings to form a sequence [batch_size, seq_len, hidden_dim]
        # Here, batch_size is the number of nodes, seq_len is the number of timesteps
        h_seq = torch.stack(h_list, dim=1)  # [num_nodes, seq_len, hidden_dim]
        
        # Temporal modeling
        # Initialize hidden state
        h0 = torch.zeros(1, h_seq.size(0), self.gru.hidden_size).to(h_seq.device)
        # Pass through GRU
        out, hn = self.gru(h_seq, h0)  # out: [num_nodes, seq_len, hidden_dim]
        # Take the last timestep's output
        out = out[:, -1, :]  # [num_nodes, hidden_dim]
        
        # Output layer
        out = self.fc(out)  # [num_nodes, output_dim]
        return out
    
# Define the Graph Attention Network layer
class GATLayer(nn.Module):
    def __init__(self, input_dim, output_dim, heads=4):
        super(GATLayer, self).__init__()
        self.gat_conv = pyg_nn.GATConv(input_dim, output_dim, heads=heads, concat=False)

    def forward(self, x, edge_index):
        return self.gat_conv(x, edge_index)
    
# Define Temporal Transformer
class TemporalTransformer(nn.Module):
    def __init__(self, hidden_dim, num_heads, num_layers):
        super(TemporalTransformer, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, h_seq):
        # h_seq: [seq_len, num_nodes, hidden_dim]
        h_seq = h_seq.permute(1, 0, 2)  # [num_nodes, seq_len, hidden_dim]
        h_seq = h_seq.transpose(0, 1)    # Transformer expects [seq_len, batch_size, hidden_dim]
        out = self.transformer(h_seq)
        out = out[-1, :, :]  # Take the last timestep's output
        return out  # [num_nodes, hidden_dim]
    
# Define the Policy Network
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, action_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, action_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        action_probs = self.fc2(x)
        return nn.functional.softmax(action_probs, dim=-1)

# Define the complete model
class AdvancedAgentModel(nn.Module):
    def __init__(self, node_input_dim, hidden_dim, action_dim, num_heads=4, num_layers=2):
        super(AdvancedAgentModel, self).__init__()
        ## Change order of GATLayer and TemporalTransformer ##Maybe add Moro-san code here
        # Spatial layer
        self.gat = GATLayer(node_input_dim, hidden_dim, heads=num_heads)
        # Temporal layer
        self.temporal = TemporalTransformer(hidden_dim, num_heads, num_layers)
        # Policy network
        self.policy = PolicyNetwork(hidden_dim, action_dim)

    def forward(self, x_list, edge_index_list):
        h_list = []
        for x, edge_index in zip(x_list, edge_index_list):
            # Spatial GAT
            h = self.gat(x, edge_index)
            h_list.append(h)
        # Stack the node embeddings to form a sequence
        h_seq = torch.stack(h_list)  # [seq_len, num_nodes, hidden_dim]
        # Temporal Transformer
        out = self.temporal(h_seq)  # [num_nodes, hidden_dim]
        # Policy network
        action_probs = self.policy(out)  # [num_nodes, action_dim]
        return action_probs
    
# Define the centralized critic
class CentralizedCritic(nn.Module):
    def __init__(self, node_input_dim, hidden_dim):
        super(CentralizedCritic, self).__init__()
        # GAT layer to process the global graph
        self.gat = pyg_nn.GATConv(node_input_dim, hidden_dim, heads=4, concat=False)
        # Self-attention layer
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=4)
        # Output layer
        self.value_head = nn.Linear(hidden_dim, 1)
    
    def forward(self, x, edge_index, batch):
        # x: [num_nodes, node_input_dim]
        # edge_index: [2, num_edges]
        # batch: [num_nodes] (batch assignment for each node)
        
        # GAT layer
        h = self.gat(x, edge_index)
        h = nn.functional.elu(h)
        
        # Reshape for attention
        h = h.unsqueeze(1)  # [num_nodes, 1, hidden_dim]
        
        # Self-attention
        h, _ = self.attention(h, h, h)  # h: [num_nodes, 1, hidden_dim]
        h = h.squeeze(1)  # [num_nodes, hidden_dim]
        
        # Global mean pooling
        h = pyg_nn.global_mean_pool(h, batch)  # [batch_size, hidden_dim]
        
        # Value estimation
        value = self.value_head(h)  # [batch_size, 1]
        return value.squeeze(-1)  # [batch_size]



In [21]:
def adversary_observation_wrapper(observations, adversary_agents, good_agents):
    updated_observations = {}
    adversary_positions = {}
    node_features = {}
    
    for agent in adversary_agents:
        obs = observations[agent]
        # Self features: self_vel (2), self_pos (2), class_id (1)  need to change here (ToDo), should work with relative position/velocity
        self_features = obs[0:4]
        agent_class = 0 if agent in adversary_agents[:num_class_a] else 1
        self_features = np.concatenate([self_features, [agent_class]])
        self_features = torch.tensor(self_features, dtype=torch.float32)
        node_features[agent] = self_features
        adversary_positions[agent] = obs[2:4]
        updated_observations[agent] = obs
    
    # Construct edge_index based on communication range
    edge_index = []
    agent_to_idx = {agent: idx for idx, agent in enumerate(adversary_agents)}
    for agent in adversary_agents:
        own_pos = adversary_positions[agent]
        for other_agent in adversary_agents:
            if agent != other_agent:
                other_pos = adversary_positions[other_agent]
                distance = np.linalg.norm(own_pos - other_pos)
                if distance <= communication_range:
                    edge_index.append([agent_to_idx[agent], agent_to_idx[other_agent]])
    
    # Convert edge_index and node_features to tensors
    if edge_index:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
    node_features_list = [node_features[agent] for agent in adversary_agents]
    x = torch.stack(node_features_list)  # [num_nodes, node_input_dim]
    
    return updated_observations, x, edge_index, agent_to_idx


In [22]:
def extract_global_state(observations, adversary_agents, good_agents):
    # Collect positions and velocities
    node_features = []
    node_types = []
    node_positions = []
    edge_index = []
    node_idx = 0
    idx_map = {}
    
    # Adversaries
    for agent in adversary_agents:
        obs = observations[agent]
        # Self_vel (0:2), Self_pos (2:4)
        vel = obs[0:2]
        pos = obs[2:4]
        features = np.concatenate([pos, vel])
        node_features.append(features)
        node_types.append(0)  # Adversary
        node_positions.append(pos)
        idx_map[agent] = node_idx
        node_idx += 1
    
    # Good agents
    for agent in good_agents:
        obs = observations[agent]
        vel = obs[0:2]
        pos = obs[2:4]
        features = np.concatenate([pos, vel])
        node_features.append(features)
        node_types.append(1)  # Good agent
        node_positions.append(pos)
        idx_map[agent] = node_idx
        node_idx += 1
    
    # Landmarks (if positions are accessible)
    # Assuming we can get landmark positions from the environment
    for i in range(num_obstacles):
        # Here, we need to access the landmark positions from the environment
        # Since the observations do not include landmark positions directly, we might need to modify the environment to expose this information
        landmark_pos = env.world.landmarks[i].state.p_pos
        features = np.concatenate([landmark_pos, np.zeros(2)])  # No velocity
        node_features.append(features)
        node_types.append(2)  # Landmark
        node_positions.append(landmark_pos)
        idx_map['landmark_{}'.format(i)] = node_idx
        node_idx += 1
    
    # Build edge_index (fully connected for the critic)
    num_nodes = len(node_features)
    for i in range(num_nodes):
        for j in range(num_nodes):
            if i != j:
                edge_index.append([i, j])
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    
    # Convert to tensors
    x = torch.tensor(node_features, dtype=torch.float32)
    
    # Create batch (since we have a single graph, batch is zeros)
    batch = torch.zeros(num_nodes, dtype=torch.long)
    
    return x, edge_index, batch


In [23]:
# Node input dimension: self_vel (2) + self_pos (2) + class_id (1) = 5
node_input_dim = 5
hidden_dim = 64
action_dim = env.action_space(adversary_agents[0]).n
num_heads = 4
num_layers = 2

# Initialize the model
advanced_agent_model = AdvancedAgentModel(node_input_dim, hidden_dim, action_dim, num_heads, num_layers)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
advanced_agent_model.to(device)

# Use DistributedDataParallel if multiple GPUs are available
if torch.cuda.device_count() > 1:
    advanced_agent_model = nn.DataParallel(advanced_agent_model)

# Optimizer and learning rate scheduler
learning_rate = 3e-4
optimizer = optim.Adam(advanced_agent_model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.5)


In [24]:
# Define the value network
class ValueNetwork(nn.Module):
    def __init__(self, node_input_dim, hidden_dim):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(node_input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        value = self.fc2(x)
        return value.squeeze(-1)  # [num_nodes]

# Initialize the value network
value_network = ValueNetwork(node_input_dim, hidden_dim).to(device)
value_optimizer = optim.Adam(value_network.parameters(), lr=learning_rate)

In [25]:
from collections import deque

# Training parameters
num_episodes = 1000  # Increase for better training
max_timesteps = 200  # Max timesteps per episode
gamma = 0.99         # Discount factor
gae_lambda = 0.95    # GAE lambda for advantage computation
epsilon = 0.2        # PPO clip parameter
entropy_coef = 0.01  # Coefficient for entropy regularization
value_coef = 0.5     # Coefficient for value loss
grad_norm_clip = 0.5 # Gradient clipping

# Storage for trajectories
class RolloutBuffer:
    def __init__(self):
        self.states = []
        self.edge_indices = []
        self.actions = []
        self.log_probs = []
        self.rewards = []
        self.values = []
        self.dones = []
        self.global_states = []  # Global states for critic
        self.global_edge_indices = []  # Edge indices for critic
        self.global_batches = []  # Batches for critic

    def clear(self):
        self.states = []
        self.edge_indices = []
        self.actions = []
        self.log_probs = []
        self.rewards = []
        self.values = []
        self.dones = []
        self.global_states = []
        self.global_edge_indices = []
        self.global_batches = []

buffer = RolloutBuffer()
episode_rewards = deque(maxlen=100)

for episode in range(1, num_episodes + 1):
    observations = env.reset()
    done = False
    cumulative_reward = 0
    timestep = 0
    buffer.clear()

    x_sequence = []
    edge_index_sequence = []

    while not done and timestep < max_timesteps:
        adversary_agents = [agent for agent in observations.keys() if 'adversary' in agent]
        good_agents = [agent for agent in observations.keys() if 'agent' in agent]

        # Observation wrapper
        observations, x, edge_index, agent_to_idx = adversary_observation_wrapper(
            observations, adversary_agents, good_agents)
        x = x.to(device)
        edge_index = edge_index.to(device)

        # Collect data for temporal modeling
        x_sequence.append(x)
        edge_index_sequence.append(edge_index)

        if len(x_sequence) > 5:
            x_sequence.pop(0)
            edge_index_sequence.pop(0)

        # Prepare input sequences
        x_input = x_sequence.copy()
        edge_index_input = edge_index_sequence.copy()
        for i in range(len(x_input)):
            x_input[i] = x_input[i].to(device)
            edge_index_input[i] = edge_index_input[i].to(device)

        # Get action probabilities from the actor
        action_probs = advanced_agent_model(x_input, edge_index_input)

        # Compute value estimates from the critic using the global state
        global_x, global_edge_index, global_batch = extract_global_state(observations, adversary_agents, good_agents)
        global_x = global_x.to(device)
        global_edge_index = global_edge_index.to(device)
        global_batch = global_batch.to(device)
        value_estimates = critic(global_x, global_edge_index, global_batch)

        # Sample actions
        actions = {}
        log_probs = []
        values = []
        for agent in adversary_agents:
            idx = agent_to_idx[agent]
            dist = torch.distributions.Categorical(action_probs[idx])
            action = dist.sample()
            actions[agent] = action.item()
            log_probs.append(dist.log_prob(action))
            values.append(value_estimates[idx])

        # For good agents, sample random actions
        for agent in good_agents:
            actions[agent] = env.action_space(agent).sample()

        # Step the environment
        next_observations, rewards, terminations, truncations, infos = env.step(actions)

        # Store transitions
        buffer.states.append(x)
        buffer.edge_indices.append(edge_index)
        buffer.actions.append(torch.tensor([actions[agent] for agent in adversary_agents], dtype=torch.long, device=device))
        buffer.log_probs.append(torch.stack(log_probs))
        buffer.rewards.append(torch.tensor([rewards[agent] for agent in adversary_agents], dtype=torch.float32, device=device))
        buffer.values.append(torch.stack(values))
        buffer.dones.append(torch.tensor([done]*len(adversary_agents), dtype=torch.float32, device=device))
        buffer.global_states.append(global_x)
        buffer.global_edge_indices.append(global_edge_index)
        buffer.global_batches.append(global_batch)

        cumulative_reward += sum(rewards.values())
        observations = next_observations
        done = all(terminations.values()) or all(truncations.values())
        timestep += 1

    # Compute advantages and returns using Generalized Advantage Estimation (GAE)
    returns = []
    advantages = []
    gae = 0
    next_value = 0
    for i in reversed(range(len(buffer.rewards))):
        delta = buffer.rewards[i].mean() + gamma * next_value * (1 - buffer.dones[i].mean()) - buffer.values[i].mean()
        gae = delta + gamma * gae_lambda * gae * (1 - buffer.dones[i].mean())
        advantages.insert(0, gae)
        returns.insert(0, gae + buffer.values[i].mean())
        next_value = buffer.values[i].mean()

    # Flatten the lists
    states = torch.cat(buffer.states)
    edge_indices = buffer.edge_indices
    actions = torch.cat(buffer.actions)
    log_probs_old = torch.cat(buffer.log_probs).detach()
    returns = torch.cat(returns).detach()
    advantages = torch.cat(advantages).detach()

    # PPO update
    num_mini_batch = 4
    mini_batch_size = len(buffer.actions) // num_mini_batch

    for _ in range(4):  # Optimize policy for K epochs
        # Shuffle indices
        indices = np.arange(len(buffer.actions))
        np.random.shuffle(indices)

        for start in range(0, len(buffer.actions), mini_batch_size):
            end = start + mini_batch_size
            mb_indices = indices[start:end]

            # Prepare minibatch data
            mb_states = [buffer.states[i] for i in mb_indices]
            mb_edge_indices = [buffer.edge_indices[i] for i in mb_indices]
            mb_actions = torch.cat([buffer.actions[i] for i in mb_indices])
            mb_log_probs_old = torch.cat([buffer.log_probs[i] for i in mb_indices]).detach()
            mb_returns = torch.stack([returns[i] for i in mb_indices]).detach()
            mb_advantages = torch.stack([advantages[i] for i in mb_indices]).detach()
            mb_global_states = [buffer.global_states[i] for i in mb_indices]
            mb_global_edge_indices = [buffer.global_edge_indices[i] for i in mb_indices]
            mb_global_batches = [buffer.global_batches[i] for i in mb_indices]

            # Compute action probabilities
            action_probs = []
            for x_input, edge_index_input in zip(mb_states, mb_edge_indices):
                x_sequence = [x_input]
                edge_index_sequence = [edge_index_input]
                ap = advanced_agent_model(x_sequence, edge_index_sequence)
                action_probs.append(ap)
            action_probs = torch.cat(action_probs)

            # Recompute value estimates using the critic
            values = []
            for global_x, global_edge_index, global_batch in zip(mb_global_states, mb_global_edge_indices, mb_global_batches):
                value = critic(global_x, global_edge_index, global_batch)
                values.append(value)
            values = torch.stack(values)

            # Compute new log probs and entropy
            log_probs = []
            entropy = 0
            for idx in range(len(mb_actions)):
                dist = torch.distributions.Categorical(action_probs[idx])
                log_prob = dist.log_prob(mb_actions[idx])
                log_probs.append(log_prob)
                entropy += dist.entropy()

            log_probs = torch.stack(log_probs)
            ratios = torch.exp(log_probs - mb_log_probs_old)

            # Surrogate loss
            surr1 = ratios * mb_advantages
            surr2 = torch.clamp(ratios, 1 - epsilon, 1 + epsilon) * mb_advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            # Value loss
            value_loss = value_coef * (mb_returns - values).pow(2).mean()

            # Entropy regularization
            entropy_loss = -entropy_coef * entropy.mean()

            # Total loss
            loss = policy_loss + value_loss + entropy_loss

            # Backpropagation
            optimizer.zero_grad()
            value_optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(advanced_agent_model.parameters(), grad_norm_clip)
            optimizer.step()
            value_optimizer.step()

    scheduler.step()
    episode_rewards.append(cumulative_reward)

    if episode % 10 == 0:
        avg_reward = np.mean(episode_rewards)
        print(f"Episode {episode}: Average Reward: {avg_reward:.2f}")


KeyError: 'agent_0'

In [9]:
# Create a directory to save models
model_dir = 'saved_models'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

# Save the models
torch.save(advanced_agent_model.state_dict(), os.path.join(model_dir, "advanced_agent_model.pth"))
torch.save(value_network.state_dict(), os.path.join(model_dir, "value_network.pth"))
print("Models saved successfully.")


Model saved successfully.


In [12]:
# Load the models
advanced_agent_model.load_state_dict(torch.load(os.path.join(model_dir, "advanced_agent_model.pth")))
critic.load_state_dict(torch.load(os.path.join(model_dir, "critic_model.pth")))

# Set models to evaluation mode
advanced_agent_model.eval()
critic.eval()

# Testing parameters
test_num_adversaries = 6  # Increase number of adversaries
env = simple_tag_v2.parallel_env(render_mode=None, num_adversaries=test_num_adversaries, num_good=1, num_obstacles=2)
env.reset()
num_test_episodes = 50
test_episode_rewards = []

for episode in range(num_test_episodes):
    observations = env.reset()
    done = False
    cumulative_reward = 0
    x_sequence = []
    edge_index_sequence = []

    while not done:
        adversary_agents = [agent for agent in observations.keys() if 'adversary' in agent]
        good_agents = [agent for agent in observations.keys() if 'agent' in agent]

        # Observation wrapper
        observations, x, edge_index, agent_to_idx = adversary_observation_wrapper(
            observations, adversary_agents, good_agents)
        x = x.to(device)
        edge_index = edge_index.to(device)

        x_sequence.append(x)
        edge_index_sequence.append(edge_index)

        if len(x_sequence) > 5:
            x_sequence.pop(0)
            edge_index_sequence.pop(0)

        x_input = x_sequence.copy()
        edge_index_input = edge_index_sequence.copy()
        for i in range(len(x_input)):
            x_input[i] = x_input[i].to(device)
            edge_index_input[i] = edge_index_input[i].to(device)

        # Get action probabilities
        with torch.no_grad():
            action_probs = advanced_agent_model(x_input, edge_index_input)

        # Select actions
        actions = {}
        for agent in adversary_agents:
            idx = agent_to_idx[agent]
            action = torch.argmax(action_probs[idx]).item()
            actions[agent] = action

        # Good agents take random actions
        for agent in good_agents:
            actions[agent] = env.action_space(agent).sample()

        # Step the environment
        observations, rewards, terminations, truncations, infos = env.step(actions)

        cumulative_reward += sum(rewards.values())
        done = all(terminations.values()) or all(truncations.values())

    test_episode_rewards.append(cumulative_reward)

avg_test_reward = np.mean(test_episode_rewards)
print(f"Average Test Reward over {num_test_episodes} episodes: {avg_test_reward:.2f}")
print("Testing completed.")


Epoch 10: Average Reward over last 10 epochs: -9.548204535002139
Epoch 20: Average Reward over last 10 epochs: 52.82452114051563
Epoch 30: Average Reward over last 10 epochs: 11.951303239275566
Epoch 40: Average Reward over last 10 epochs: 5.262679379053028
Epoch 50: Average Reward over last 10 epochs: 26.979216094949738
Test completed.
