In [1]:
import numpy as np
import gymnasium as gym
from gymnasium import spaces
import time
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import torch.nn.functional as F
from torch import nn

In [2]:
class MultiAgentFeatureExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: spaces.Box, n_agents: int, features_dim: int):
        super(MultiAgentFeatureExtractor, self).__init__(observation_space, features_dim)
        self.n_agents = n_agents

        input_dim_per_agent = observation_space.shape[0]  
        self.fc1 = nn.Linear(input_dim_per_agent, 256)  
        self.fc2 = nn.Linear(256, features_dim)

    def forward(self, observations):
        x = F.relu(self.fc1(observations))
        #print(f"Shape after fc1: {x.shape}")  
        x = self.fc2(x)
        #print(f"Shape after fc2: {x.shape}")  
        return x

In [3]:
class CustomPPO(PPO):
    def __init__(self, *args, **kwargs):
        super(CustomPPO, self).__init__(*args, **kwargs)

    def train(self):
        # Call the original train function but modify loss computation
        super(CustomPPO, self).train()

    def custom_loss(self, policy_loss, value_loss, entropy_loss):
       
        custom_loss = policy_loss + 0.5 * value_loss - 0.01 * entropy_loss

        return custom_loss

    def compute_loss(self, observations, actions, rewards, old_log_probs, advantages, returns):
        # Standard PPO policy and value loss
        policy_loss, value_loss, entropy_loss = super().compute_loss(observations, actions, rewards, old_log_probs, advantages, returns)

        # Use the custom loss function
        loss = self.custom_loss(policy_loss, value_loss, entropy_loss)

        return loss

In [4]:
class JointActionSpaceWrapper(gym.Env):
    def __init__(self, env):
        super(JointActionSpaceWrapper, self).__init__()
        self.env = env
        self.n_agents = env.n_agents

        # Define the joint action space
        action_spaces = []
        for space in env.action_space:
            if isinstance(space, gym.spaces.Discrete):
                action_spaces.append([space.n])
            elif isinstance(space, gym.spaces.MultiDiscrete):
                action_spaces.append(space.nvec)
        self.action_space = gym.spaces.MultiDiscrete(np.concatenate(action_spaces, axis=0))

        # Define the joint observation space
        obs_shape = (self.n_agents * env.observation_space[0].shape[0],)
        self.observation_space = gym.spaces.Box(
            low=-np.inf, 
            high=np.inf, 
            shape=obs_shape, 
            dtype=np.float32
        )

    def reset(self, seed=None, **kwargs):
        obss, info = self.env.reset(seed=seed, **kwargs)
        obss = [np.array([o]) if np.isscalar(o) else o for o in obss]
        return np.concatenate(obss), info

    def step(self, actions):
        # Split actions for each agent
        split_actions = np.split(actions, self.n_agents)

        # Pass actions to the environment
        obss, rewards, done, truncated, info = self.env.step(split_actions)

        # Concatenate observations from all agents
        obss = [np.array([o]) if np.isscalar(o) else o for o in obss]
        joint_obs = np.concatenate(obss)

        # Aggregate rewards for cooperation
        joint_reward = sum(rewards) / self.n_agents

        # Check if all agents are done
        if isinstance(done, bool):
            joint_done = done or truncated
        else:
            joint_done = all(done) or all(truncated)

        return joint_obs, float(joint_reward), joint_done, truncated, info


In [5]:

# Environment layout
layout = """
g......
...x...
..x.x..
.x...x.
..x.x..
...x...
......g
"""

# Create and wrap the warehouse environment
env = gym.make("rware:rware-tiny-2ag-v2", layout=layout)
wrapped_env = JointActionSpaceWrapper(env)
vec_env = DummyVecEnv([lambda: wrapped_env])

policy_kwargs = dict(
    features_extractor_class=MultiAgentFeatureExtractor,
    features_extractor_kwargs=dict(n_agents=2, features_dim=142),
    net_arch=dict(pi=[128, 64], vf=[128, 64])
)

# Initialize custom PPO model with joint observations and actions
model = CustomPPO('MlpPolicy', vec_env, policy_kwargs=policy_kwargs, verbose=1)

#print(vec_env.observation_space)

# Train the model
model.learn(total_timesteps=50000)

# Save the trained model
model.save("ppo_multi_agent_coordinated")


Using cpu device


  logger.warn(


-----------------------------
| time/              |      |
|    fps             | 788  |
|    iterations      | 1    |
|    time_elapsed    | 2    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 42          |
|    iterations           | 2           |
|    time_elapsed         | 96          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.011298375 |
|    clip_fraction        | 0.11        |
|    clip_range           | 0.2         |
|    entropy_loss         | -3.21       |
|    explained_variance   | -0.246      |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0398     |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0193     |
|    value_loss           | 0.0135      |
-----------------------------------------
----------------------------------

In [8]:
model = PPO.load("ppo_multi_agent_coordinated")
obs = vec_env.reset()
done = False

env.render()

for step in range(500):
    # Predict actions using the trained model
    actions, _states = model.predict(obs)
    actions = np.squeeze(actions)  # Ensure actions are in the correct shape

    # Step through the environmentâ€‹
    obs, rewards, done, truncated, info = wrapped_env.step(actions)

    # Render the environment
    env.render()
    time.sleep(0.1)

    # Reset the environment if done
    if done or truncated:
        obs = wrapped_env.reset()

env.close

<bound method Wrapper.close of <OrderEnforcing<PassiveEnvChecker<Warehouse<rware-tiny-2ag-v2>>>>>