In [1]:
import os
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch_geometric.nn as pyg_nn
import torch_geometric.data as pyg_data
from pettingzoo.mpe import simple_tag_v2

# Parameters (defined first so the environment is built from them rather than
# from repeated literals; `num_class_a` is also read by adversary_observation_wrapper)
num_class_a = 3
num_class_b = 1
num_adversaries = 4
num_good_agents = 1  # Only one agent being chased by adversaries
num_obstacles = 2

# Every adversary must belong to exactly one of the two classes
assert num_class_a + num_class_b == num_adversaries, "class split must cover all adversaries"

# Initialize the environment
env = simple_tag_v2.parallel_env(
    render_mode=None,
    num_adversaries=num_adversaries,
    num_good=num_good_agents,
    num_obstacles=num_obstacles,
)
env.reset()

# Initial lists of agents (names look like 'adversary_<i>' and 'agent_<i>')
adversary_agents = [agent for agent in env.agents if agent.startswith('adversary')]
good_agents = [agent for agent in env.agents if agent.startswith('agent')]

print("Adversary Agents:", adversary_agents)
print("Good Agents:", good_agents)


Adversary Agents: ['adversary_0', 'adversary_1', 'adversary_2', 'adversary_3']
Good Agents: ['agent_0']


In [2]:
# Define Temporal GCN model
class TemporalGCN(nn.Module):
    """Per-timestep spatial GCN followed by a GRU over time, applied node-wise.

    Each graph snapshot is encoded by one GCNConv + ReLU. The per-node
    embeddings from all snapshots are stacked along a time axis and run
    through a GRU. The GRU output at the final timestep is projected to
    ``output_dim`` scores per node.

    The attribute names (conv1, relu, gru, fc) determine the state_dict keys
    used when the model is saved, so they must stay stable.
    """

    def __init__(self, node_input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = pyg_nn.GCNConv(node_input_dim, hidden_dim)      # spatial encoder
        self.relu = nn.ReLU()
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)  # temporal encoder
        self.fc = nn.Linear(hidden_dim, output_dim)                  # per-node output head

    def forward(self, x_list, edge_index_list):
        """Encode a sequence of graph snapshots and score every node.

        Args:
            x_list: sequence of node-feature tensors, one per timestep, each
                [num_nodes, node_input_dim]. All timesteps must have the same
                number of nodes in the same order, because the embeddings are
                stacked along a new time axis.
            edge_index_list: sequence of edge_index tensors, zipped with x_list.

        Returns:
            Tensor [num_nodes, output_dim] computed from the GRU output at the
            last timestep.
        """
        spatial = [
            self.relu(self.conv1(feats, edges))
            for feats, edges in zip(x_list, edge_index_list)
        ]

        # Nodes act as the GRU batch and timesteps as the sequence:
        # [num_nodes, seq_len, hidden_dim]
        seq = torch.stack(spatial, dim=1)

        # Explicit all-zero initial hidden state: [1, num_nodes, hidden_dim]
        initial_state = torch.zeros(1, seq.size(0), self.gru.hidden_size, device=seq.device)
        gru_out, _ = self.gru(seq, initial_state)

        return self.fc(gru_out[:, -1, :])


In [3]:
def adversary_observation_wrapper(observations, adversary_agents, good_agents,
                                  num_class_a=None, communication_range=None):
    """Build the adversary communication graph for one environment step.

    Each adversary becomes a graph node. Its features are
    ``[self_vel (2), self_pos (2), class_id]``: the first four observation
    entries plus a class flag. The flag is 0 (Class A) for the first
    ``num_class_a`` adversaries in ``adversary_agents`` and 1 (Class B) for the
    rest. A directed edge ``i -> j`` is added for every ordered pair of
    distinct adversaries whose positions (observation entries 2-3) are within
    ``communication_range`` (Euclidean distance, inclusive).

    Args:
        observations: dict mapping agent name -> observation array.
        adversary_agents: ordered list of adversary names. The order defines
            the node indices.
        good_agents: unused; kept for call compatibility.
        num_class_a: number of leading adversaries assigned to Class A.
            Defaults to the notebook-level ``num_class_a``.
        communication_range: maximum distance for an edge. Defaults to the
            notebook-level ``communication_range``.

    Returns:
        A tuple ``(updated_observations, x, edge_index, agent_to_idx)``:
        - ``updated_observations``: a shallow copy of ``observations``
          containing every agent, unchanged.
        - ``x``: float32 tensor of shape [num_adversaries, 5].
        - ``edge_index``: long tensor of shape [2, num_edges].
        - ``agent_to_idx``: maps adversary name -> node index.
    """
    # Resolve defaults lazily: `communication_range` is defined in a later cell.
    if num_class_a is None:
        num_class_a = globals()["num_class_a"]
    if communication_range is None:
        communication_range = globals()["communication_range"]

    feature_dim = 4 + 1  # self_vel (2) + self_pos (2) + class id

    # Pass every observation through unchanged (good agents included)
    updated_observations = dict(observations)
    agent_to_idx = {agent: idx for idx, agent in enumerate(adversary_agents)}

    node_rows = []
    adversary_positions = {}
    for idx, agent in enumerate(adversary_agents):
        obs = np.asarray(observations[agent])
        agent_class = 0 if idx < num_class_a else 1
        node_rows.append(np.concatenate([obs[0:4], [agent_class]]))
        adversary_positions[agent] = obs[2:4]

    # Distance is symmetric, so each in-range pair yields edges in both directions.
    # Edges are ordered by (source, target) position in adversary_agents.
    edge_pairs = [
        [agent_to_idx[src], agent_to_idx[dst]]
        for src in adversary_agents
        for dst in adversary_agents
        if src != dst
        and np.linalg.norm(adversary_positions[src] - adversary_positions[dst]) <= communication_range
    ]

    if edge_pairs:
        edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)

    if node_rows:
        x = torch.tensor(np.stack(node_rows), dtype=torch.float32)  # [num_nodes, 5]
    else:
        # No adversaries: return an empty node matrix instead of failing in stack
        x = torch.empty((0, feature_dim), dtype=torch.float32)

    return updated_observations, x, edge_index, agent_to_idx


In [6]:
# Node input dimension: self_vel (2) + self_pos (2) + class_id (1) = 5
node_input_dim = 5
hidden_dim = 16
communication_range = 1.5  # max adversary-adversary distance for a graph edge (read by adversary_observation_wrapper)

# One score per discrete action, produced by a single output head for every adversary node
output_dim = env.action_space(adversary_agents[0]).n
# A single head only makes sense if all adversaries have the same number of actions
assert all(env.action_space(agent).n == output_dim for agent in adversary_agents), \
    "all adversaries must share the same discrete action-space size"

# Seed torch so the model's weight initialization is reproducible.
# (torch.manual_seed does not seed the environment or action-space sampling.)
SEED = 0
torch.manual_seed(SEED)

# Initialize the Temporal GCN model
temporal_gcn = TemporalGCN(node_input_dim, hidden_dim, output_dim)

# Optimizer
learning_rate = 0.001
optimizer = torch.optim.Adam(temporal_gcn.parameters(), lr=learning_rate)

# Loss function (placeholder)
loss_fn = nn.CrossEntropyLoss()


In [7]:
# Training parameters
num_episodes = 50  # Total number of episodes to run
print_interval = 10  # Print rewards every 10 episodes

# Initialize reward tracking
episode_rewards = []

# For temporal modeling, we'll collect sequences over timesteps
sequence_length = 5  # Number of timesteps to consider in the temporal model

for episode in range(1, num_episodes + 1):
    observations = env.reset()
    done = False
    cumulative_reward = 0  # Reset cumulative reward for the episode

    # Rolling windows holding only the last `sequence_length` snapshots;
    # older timesteps are never fed to the model, so they are not kept.
    x_sequence = deque(maxlen=sequence_length)
    edge_index_sequence = deque(maxlen=sequence_length)

    while not done:
        # Update adversary_agents and good_agents based on current observations
        adversary_agents = [agent for agent in observations.keys() if 'adversary' in agent]
        good_agents = [agent for agent in observations.keys() if 'agent' in agent]

        # Only the graph tensors are needed; `observations` is refreshed by env.step below
        _, x, edge_index, agent_to_idx = adversary_observation_wrapper(
            observations, adversary_agents, good_agents)

        x_sequence.append(x)
        edge_index_sequence.append(edge_index)

        if len(x_sequence) >= sequence_length:
            # Forward pass through Temporal GCN over the rolling window
            output = temporal_gcn(list(x_sequence), list(edge_index_sequence))  # [num_nodes, output_dim]

            # Greedy adversary actions: highest-scoring action for each node
            actions = {
                agent: torch.argmax(output[agent_to_idx[agent]]).item()
                for agent in adversary_agents
            }
            # For good agents, sample random actions
            for agent in good_agents:
                actions[agent] = env.action_space(agent).sample()

            # Placeholder objective: the all-zero target pushes every adversary toward action 0.
            # TODO: replace with a real RL objective driven by `rewards`.
            target = torch.zeros(len(adversary_agents), dtype=torch.long)
            loss = loss_fn(output, target)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            # Warm-up: not enough history yet, so every agent acts randomly
            actions = {agent: env.action_space(agent).sample() for agent in env.agents}

        # Step the environment
        observations, rewards, terminations, truncations, infos = env.step(actions)
        # Sums the rewards of every agent returned by env.step, including good agents
        cumulative_reward += sum(rewards.values())

        # Check for episode end after every step, warm-up steps included, so an
        # episode that ends before the window fills still exits the loop
        done = all(terminations.values()) or all(truncations.values())

    # Append cumulative reward for the episode
    episode_rewards.append(cumulative_reward)

    # Print rewards every 'print_interval' episodes
    if episode % print_interval == 0:
        avg_reward = sum(episode_rewards[-print_interval:]) / print_interval
        print(f"Episode {episode}: Average Reward: {avg_reward}")


Episode 10: Average Reward: 10.7807294692779
Episode 20: Average Reward: 8.530477074640222
Episode 30: Average Reward: -6.904504160444003
Episode 40: Average Reward: -13.201293840875223
Episode 50: Average Reward: 3.500866183514612


In [9]:
# Save the Temporal GCN model weights (state_dict only; reloading requires
# constructing TemporalGCN with the same dimensions first)
model_dir = 'saved_models'
os.makedirs(model_dir, exist_ok=True)  # no-op if the directory already exists
torch.save(temporal_gcn.state_dict(), os.path.join(model_dir, "temporal_gcn_model.pth"))

print("Model saved successfully.")


Model saved successfully.


In [10]:
# Testing with a different number of adversaries
test_num_adversaries = 6  # Increased number of adversaries
test_num_class_a = 4
test_num_class_b = 2
assert test_num_class_a + test_num_class_b == test_num_adversaries, "class split must cover all adversaries"

# Initialize the environment with the new number of adversaries
env = simple_tag_v2.parallel_env(render_mode=None, num_adversaries=test_num_adversaries, num_good=1, num_obstacles=2)
env.reset()

# Run the environment for testing (inference only; each "epoch" is one episode)
num_epochs = 50
print_interval = 10  # Calculate average reward every 10 epochs
episode_rewards = []

# adversary_observation_wrapper reads the notebook-level `num_class_a` to assign
# node classes. Point it at the test split while evaluating, then restore the
# training value afterwards (even on error) so later wrapper calls see the original split.
train_num_class_a = num_class_a
num_class_a = test_num_class_a
try:
    for epoch in range(1, num_epochs + 1):
        observations = env.reset()
        done = False
        cumulative_reward = 0  # Reset cumulative reward for the epoch

        # Rolling windows of the last `sequence_length` graph snapshots
        x_sequence = deque(maxlen=sequence_length)
        edge_index_sequence = deque(maxlen=sequence_length)

        while not done:
            # Update adversary_agents and good_agents based on current observations
            adversary_agents = [agent for agent in observations.keys() if 'adversary' in agent]
            good_agents = [agent for agent in observations.keys() if 'agent' in agent]

            # Only the graph tensors are needed; `observations` is refreshed by env.step below
            _, x, edge_index, agent_to_idx = adversary_observation_wrapper(
                observations, adversary_agents, good_agents)

            x_sequence.append(x)
            edge_index_sequence.append(edge_index)

            if len(x_sequence) >= sequence_length:
                # Evaluation only: no autograd graph is needed
                with torch.no_grad():
                    output = temporal_gcn(list(x_sequence), list(edge_index_sequence))  # [num_nodes, output_dim]

                # Greedy adversary actions: highest-scoring action for each node
                actions = {
                    agent: torch.argmax(output[agent_to_idx[agent]]).item()
                    for agent in adversary_agents
                }
                # For good agents, sample random actions
                for agent in good_agents:
                    actions[agent] = env.action_space(agent).sample()
            else:
                # Warm-up: not enough history yet, so every agent acts randomly
                actions = {agent: env.action_space(agent).sample() for agent in env.agents}

            # Step the environment
            observations, rewards, terminations, truncations, infos = env.step(actions)
            # Sums the rewards of every agent returned by env.step, including good agents
            cumulative_reward += sum(rewards.values())

            # Check for episode end after every step, warm-up steps included
            done = all(terminations.values()) or all(truncations.values())

        # Track the cumulative reward for the epoch
        episode_rewards.append(cumulative_reward)

        # Print average reward every 'print_interval' epochs
        if epoch % print_interval == 0:
            avg_reward = sum(episode_rewards[-print_interval:]) / print_interval
            print(f"Epoch {epoch}: Average Reward over last {print_interval} epochs: {avg_reward}")
finally:
    num_class_a = train_num_class_a

print("Test completed.")


Epoch 10: Average Reward over last 10 epochs: -14.06920927202702
Epoch 20: Average Reward over last 10 epochs: 30.364071343836123
Epoch 30: Average Reward over last 10 epochs: -3.5507463682510947
Epoch 40: Average Reward over last 10 epochs: 5.822840216382039
Epoch 50: Average Reward over last 10 epochs: 16.18777799676129
Test completed.
