In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque, namedtuple
import random
import matplotlib.pyplot as plt
import numpy as np

from torch.optim import Adam
from collections import defaultdict

from NaturalEnv import natural_env_v0

is_ipython = 'inline' in plt.get_backend()
if is_ipython:
    from IPython import display

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Environment configuration (passed verbatim to natural_env_v0.parallel_env)
args = {
    'render_mode': None,
    'max_cycles': 256,
    'continuous_actions': True,
    'num_predators': 0,
    'num_prey': 1,
    'num_obstacles': 0,
    'num_food': 1,
    'num_water': 1,
    'num_forests': 0
}

env = natural_env_v0.parallel_env(**args)
obs, _ = env.reset()

# Snapshot the agent ids. Do NOT alias env.agents: PettingZoo-style parallel envs
# may remove finished agents from that list in place, which would silently empty
# every variable that refers to it after the first episode ends.
agents = list(env.agents)

# Per-agent flat observation / action widths (Box spaces assumed, via .shape[0])
obs_spaces = {agent: env.observation_space(agent).shape[0] for agent in agents}
action_spaces = {agent: env.action_space(agent).shape[0] for agent in agents}

print(f"Device used: {device}")
print(f"Observation spaces: {obs_spaces}")
print(f"Action spaces: {action_spaces}")

Device used: cuda
Observation spaces: {'prey_0': 16}
Action spaces: {'prey_0': 5}


In [2]:
import torch.nn as nn

class Actor(nn.Module):
    """Deterministic per-agent policy: observation -> action in (0, 1).

    Two ReLU hidden layers of width ``hidden_dim`` followed by a sigmoid output
    layer with ``action_dim`` units.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(obs_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, action_dim)

    def forward(self, obs):
        hidden = obs
        for layer in (self.fc1, self.fc2):
            hidden = layer(hidden).relu()
        # Sigmoid squashes every action component into (0, 1)
        return self.out(hidden).sigmoid()

class Critic(nn.Module):
    """Q-network scoring an (observation, action) pair with a single scalar.

    ``obs_dim`` and ``action_dim`` are the widths of the tensors given to
    ``forward``; the training cell passes the summed widths of all agents, which
    makes this a centralized critic.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(obs_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, 1)

    def forward(self, obs, actions):
        # Observations and actions are joined along the feature (last) axis
        features = torch.cat((obs, actions), dim=-1)
        for layer in (self.fc1, self.fc2):
            features = layer(features).relu()
        return self.out(features)
    
class ParallelReplayBuffer:
    """Fixed-capacity FIFO replay buffer of joint (all-agent) transitions.

    Each transition is stored as a tuple of per-agent dicts
    ``(observations, actions, rewards, next_observations, dones)``. Once full,
    the oldest transition is overwritten.

    Storage is a plain list used as a ring buffer: ``random.sample`` needs O(1)
    random access, which a ``deque`` does not provide (its indexing is O(n)).

    Args:
        capacity: maximum number of transitions kept (must be positive).
        device: torch device for sampled tensors; ``None`` means the notebook's
            global ``device``.
    """

    def __init__(self, capacity=100000, device=None):
        if capacity <= 0:
            raise ValueError("capacity must be positive")
        self.capacity = capacity
        self.device = device
        self.buffer = []
        self._next = 0  # slot the next transition goes into once the buffer is full

    def add(self, observations, actions, rewards, next_observations, dones):
        """Store one joint transition, evicting the oldest when at capacity."""
        transition = (observations, actions, rewards, next_observations, dones)
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self._next] = transition
        self._next = (self._next + 1) % self.capacity

    @staticmethod
    def _to_tensors(records, target_device):
        """Stack a sequence of per-agent dicts into ``{agent: float32 tensor}``.

        Agent keys are taken from the first record; the batch is dimension 0.
        """
        return {
            agent: torch.as_tensor(np.array([record[agent] for record in records]),
                                   dtype=torch.float32, device=target_device)
            for agent in records[0]
        }

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            ``(obs, actions, rewards, next_obs, dones)``, each a dict mapping
            agent id to a float32 tensor whose first dimension is the batch.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored transitions.
        """
        batch = random.sample(self.buffer, batch_size)
        target_device = device if self.device is None else self.device
        return tuple(self._to_tensors(field, target_device) for field in zip(*batch))

    def size(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

    def __len__(self):
        return len(self.buffer)

In [3]:
def train_step(actors, critics, target_actors, target_critics, buffer, batch_size, gamma=0.95, tau=0.01,
               agent_ids=None):
    """One MADDPG-style update: centralized critics, decentralized actors, soft target sync.

    Args:
        actors, critics, target_actors, target_critics: per-agent networks, indexed
            in the same order as ``agent_ids``; ``actors[i]`` / ``critics[i]``
            carry an ``.optimizer`` attribute.
        buffer: object whose ``sample(batch_size)`` returns
            ``(obs, actions, rewards, next_obs, dones)`` as dicts of tensors keyed
            by agent id (rewards/dones presumably shaped ``(batch,)``).
        batch_size: number of transitions sampled.
        gamma: discount factor.
        tau: soft-update rate; ``target <- tau * online + (1 - tau) * target``.
        agent_ids: agent order used to concatenate joint inputs; defaults to the
            notebook's global ``agents``. It must match the order the critics'
            input widths were built from.
    """
    if agent_ids is None:
        agent_ids = agents
    obs, actions, rewards, next_obs, dones = buffer.sample(batch_size)

    # Joint inputs are identical for every agent's critic: build them once.
    obs_concat = torch.cat([obs[a] for a in agent_ids], dim=-1)
    next_obs_concat = torch.cat([next_obs[a] for a in agent_ids], dim=-1)
    actions_concat = torch.cat([actions[a] for a in agent_ids], dim=-1)

    # Target joint action (no gradients needed; targets only change in the soft update below)
    with torch.no_grad():
        target_actions = torch.cat(
            [target_actors[j](next_obs[a]) for j, a in enumerate(agent_ids)], dim=-1
        )

    # Centralized critic update for each agent
    for i, agent in enumerate(agent_ids):
        with torch.no_grad():
            # squeeze(-1), not squeeze(): keep the batch dim even when batch_size == 1
            target_q = target_critics[i](next_obs_concat, target_actions).squeeze(-1)
            y = rewards[agent] + gamma * (1 - dones[agent]) * target_q

        current_q = critics[i](obs_concat, actions_concat).squeeze(-1)
        critic_loss = torch.nn.functional.mse_loss(current_q, y)
        critics[i].optimizer.zero_grad()
        critic_loss.backward()
        critics[i].optimizer.step()

    # Policy update: only agent i's action comes from its (differentiable) actor;
    # every other agent's action is taken from the buffer.
    for i, agent in enumerate(agent_ids):
        joint_actions = torch.cat(
            [actors[j](obs[other]) if other == agent else actions[other].detach()
             for j, other in enumerate(agent_ids)],
            dim=-1,
        )
        actor_loss = -critics[i](obs_concat, joint_actions).mean()
        actors[i].optimizer.zero_grad()
        actor_loss.backward()
        actors[i].optimizer.step()

    # Soft (Polyak) update of the target networks: lerp_(online, tau) == tau*online + (1-tau)*target
    with torch.no_grad():
        for i in range(len(agent_ids)):
            for target_net, online_net in ((target_critics[i], critics[i]), (target_actors[i], actors[i])):
                for target_param, param in zip(target_net.parameters(), online_net.parameters()):
                    target_param.lerp_(param, tau)

def update_plot_multi_agent(episode, max_episodes, reward_history, show_result=False, window=10):
    """Redraw figure 1 with each agent's per-episode total reward and its moving average.

    Args:
        episode: current episode index (shown in the title while training).
        max_episodes: total number of episodes (shown in the title).
        reward_history: dict mapping agent name -> list of episode totals.
        show_result: if True, title the figure 'Final Result:' and keep the
            existing figure instead of clearing it.
        window: moving-average window length, in episodes.
    """
    plt.figure(1)

    if show_result:
        plt.title('Final Result:')
    else:
        plt.clf()
        plt.title(f"Episode {episode} of {max_episodes}")
    plt.xlabel('Episode')
    plt.ylabel('Total reward')

    # Iterate the history itself so every agent gets its own curve.
    for agent_name, history in reward_history.items():
        rewards_t = torch.tensor(history, dtype=torch.float)
        plt.plot(rewards_t.numpy(), label=agent_name)

        if len(rewards_t) >= window:
            means = rewards_t.unfold(0, window, 1).mean(1).view(-1)
            # NaN padding (not zeros) so matplotlib leaves a gap instead of drawing a fake dip to 0
            means = torch.cat((torch.full((window - 1,), float('nan')), means))
            plt.plot(means.numpy(), label=f"{agent_name} ({window}-episode avg)")

    if reward_history:
        plt.legend()
    plt.pause(0.001)
    if is_ipython:
        display.display(plt.gcf())
        if not show_result:
            display.clear_output(wait=True)

In [4]:
# Snapshot the agent ids from obs_spaces (a plain dict built in the setup cell).
# env.agents may be shrunk in place by PettingZoo-style envs as agents finish,
# which would silently empty any list aliasing it; train_step reads this global.
agents = list(obs_spaces)

# Initialize actors, critics, target networks, and optimizers
joint_obs_dim = sum(obs_spaces.values())
joint_action_dim = sum(action_spaces.values())

actors = [Actor(obs_spaces[agent], action_spaces[agent]).to(device) for agent in agents]
critics = [Critic(joint_obs_dim, joint_action_dim).to(device) for _ in agents]
target_actors = [Actor(obs_spaces[agent], action_spaces[agent]).to(device) for agent in agents]
target_critics = [Critic(joint_obs_dim, joint_action_dim).to(device) for _ in agents]

# Targets must start as exact copies of the online networks; otherwise the soft
# update spends thousands of steps dragging them away from an unrelated random init.
for target_net, online_net in zip(target_actors + target_critics, actors + critics):
    target_net.load_state_dict(online_net.state_dict())

# Optimizers (attached as attributes because train_step reads `.optimizer`)
for actor, critic in zip(actors, critics):
    actor.optimizer = Adam(actor.parameters(), lr=1e-3)
    critic.optimizer = Adam(critic.parameters(), lr=1e-3)

# Replay buffer
buffer = ParallelReplayBuffer()

# Main training loop
episodes = 1000
batch_size = 64
exploration_std = 0.1  # std of Gaussian noise added to actions while training

# Action bounds used to clip the noisy actions back into the valid range
action_low = {agent: env.action_space(agent).low for agent in agents}
action_high = {agent: env.action_space(agent).high for agent in agents}

plt.ion()
reward_history = {agent_name: [] for agent_name in agents}

try:
    for episode in range(episodes):
        obs, _ = env.reset()
        done = {agent: False for agent in agents}
        episode_reward = {agent: 0 for agent in agents}

        while not all(done.values()):
            # Deterministic actor output + Gaussian exploration noise (DDPG needs
            # explicit exploration); no autograd graph is needed for acting.
            actions = {}
            with torch.no_grad():
                for i, agent in enumerate(agents):
                    obs_t = torch.as_tensor(obs[agent], dtype=torch.float32, device=device)
                    greedy = actors[i](obs_t).cpu().numpy()
                    noisy = greedy + np.random.normal(0.0, exploration_std, size=greedy.shape)
                    actions[agent] = np.clip(noisy, action_low[agent], action_high[agent]).astype(np.float32)

            # Step the environment
            next_obs, rewards, terminated, truncated, _ = env.step(actions)

            for agent in agents:
                episode_reward[agent] += rewards[agent]

            # Store only `terminated` as the bootstrap mask: a time-limit truncation
            # is not a real terminal state, so its next state must still be bootstrapped.
            buffer.add(obs, actions, rewards, next_obs, {agent: terminated[agent] for agent in agents})

            # The episode ends on either termination or truncation
            done = {agent: terminated[agent] or truncated[agent] for agent in agents}
            obs = next_obs

            # Training step if enough data in buffer
            if buffer.size() > batch_size:
                train_step(actors, critics, target_actors, target_critics, buffer, batch_size)

        for agent in agents:
            reward_history[agent].append(episode_reward[agent])
        update_plot_multi_agent(episode, episodes, reward_history)

    # `episodes` rather than `episode + 1`: same value, but defined even when episodes == 0
    update_plot_multi_agent(episodes, episodes, reward_history, show_result=True)
finally:
    # Release the interactive plot mode and the environment even if training crashes
    plt.ioff()
    env.close()

RuntimeError: CUDA error: unknown error
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Test the trained model: roll out one rendered episode with the deterministic (noise-free) policy.
# `args` already contains 'render_mode' and 'continuous_actions'; passing them again as
# keywords alongside **args raises "got multiple values for keyword argument", so override instead.
human_env = natural_env_v0.parallel_env(**{**args, 'render_mode': 'human'})
eval_agents = list(obs_spaces)  # snapshot; don't rely on a list the env may mutate

try:
    obs, _ = human_env.reset()
    done = {agent: False for agent in eval_agents}

    while not all(done.values()):
        with torch.no_grad():
            actions = {
                agent: actors[i](torch.as_tensor(obs[agent], dtype=torch.float32, device=device)).cpu().numpy()
                for i, agent in enumerate(eval_agents)
            }
        obs, _, terminated, truncated, _ = human_env.step(actions)
        done = {agent: terminated[agent] or truncated[agent] for agent in eval_agents}
finally:
    human_env.close()