In [6]:
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

import gymnasium as gym
from pettingzoo.mpe import simple_spread_v3
from pettingzoo.utils.conversions import aec_to_parallel
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor
from stable_baselines3.common.policies import ActorCriticPolicy
import stable_baselines3.common.policies as policies_module
import supersuit as ss
import wandb
import tensorboard
import torch

In [4]:
# Create the PettingZoo simple_spread environment through the parallel API.
# The vectorisation wrapper below consumes a parallel env, so building it with
# `parallel_env` directly avoids creating an AEC env and converting it halfway
# through the wrapper chain.
env = simple_spread_v3.parallel_env(
    N=2,                       # 2 agents, 2 landmarks
    local_ratio=0.5,           # partial mixing of global vs local reward
    max_cycles=25,
    continuous_actions=False,
)

# Apply wrappers
env = ss.black_death_v3(env)  # keeps a dummy agent alive when one is done
env = ss.flatten_v0(env)      # flatten observations for SB3 compatibility
env = ss.pettingzoo_env_to_vec_env_v1(env)  # convert to vectorized env
# A single copy of the vectorized env, exposed through the SB3 base class.
env = ss.concat_vec_envs_v1(env, 1, num_cpus=1, base_class="stable_baselines3")
env = VecMonitor(env)  # to record metrics easily

In [5]:
# Start the Weights & Biases run for this experiment.
# Authenticate with your own API key from wandb.ai beforehand (for example via
# `wandb login`); API keys are credentials and should not be shared.
# sync_tensorboard=True asks wandb to pick up what is written to TensorBoard.
wandb.init(
    project="marl-coordination-demo",
    name="PPO-simple-spread",
    sync_tensorboard=True,
)

[34m[1mwandb[0m: Currently logged in as: [33manastasiia-chernavskaia[0m ([33manastasiia-chernavskaia-barcelona-school-of-economics[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
# Network settings handed to the PPO policy.
policy_kwargs = dict(
    net_arch=[64, 64],  # Simplified MLP architecture
    activation_fn=torch.nn.ReLU
)

# Train, save, and always close the wandb run. Without the try/except an
# exception (or a kernel interrupt) during construction, training or saving
# would leave the run started earlier open.
try:
    model = PPO(
        ActorCriticPolicy,
        env,
        verbose=1,
        tensorboard_log="./ppo_marl_tb/",
        policy_kwargs=policy_kwargs,
    )
    # Request 50k timesteps. The logged run below reports 4096 timesteps per
    # iteration and ends at global_step 53248, so the actual total can
    # overshoot the requested budget.
    model.learn(total_timesteps=50000)
    model.save("ppo_marl_simple_spread")
except BaseException:
    # BaseException so that a KeyboardInterrupt from the notebook also marks
    # the run as failed before the error is re-raised.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

Using cuda device
Logging to ./ppo_marl_tb/PPO_11
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 26       |
|    ep_rew_mean     | -21.3    |
| time/              |          |
|    fps             | 1037     |
|    iterations      | 1        |
|    time_elapsed    | 3        |
|    total_timesteps | 4096     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 26           |
|    ep_rew_mean          | -20.4        |
| time/                   |              |
|    fps                  | 712          |
|    iterations           | 2            |
|    time_elapsed         | 11           |
|    total_timesteps      | 8192         |
| train/                  |              |
|    approx_kl            | 0.0059424946 |
|    clip_fraction        | 0.0276       |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.61        |
|    exp

0,1
global_step,▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▇▇▇████
rollout/ep_len_mean,▁▁▁▁▁▁▁▁▁▁▁▁▁
rollout/ep_rew_mean,▁▂▂▄▅▄▄▆▇█▇▇█
time/fps,█▃▂▂▂▂▁▁▁▁▁▁▁
train/approx_kl,▁▁▃▁▁▄█▅▃▆▂▆
train/clip_fraction,▁▁▄▂▂▅█▇▄█▄▆
train/clip_range,▁▁▁▁▁▁▁▁▁▁▁▁
train/entropy_loss,▁▁▂▂▃▄▅▅▆▇▇█
train/explained_variance,▁▂▂▃▅▅▅▇▇███
train/learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁

0,1
global_step,53248.0
rollout/ep_len_mean,26.0
rollout/ep_rew_mean,-14.7733
time/fps,578.0
train/approx_kl,0.0082
train/clip_fraction,0.06472
train/clip_range,0.2
train/entropy_loss,-1.44669
train/explained_variance,0.49282
train/learning_rate,0.0003
