In [1]:
import ray
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models import ModelCatalog
from ray import tune
from gymnasium import spaces
import numpy as np
import torch
import torch.nn as nn

In [2]:
# Custom PyTorch Model
class CustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        
        self.obs_dim = obs_space.shape[0]
        self.num_actions = num_outputs
        
        # Define a simple neural network
        self.network = nn.Sequential(
            nn.Linear(self.obs_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, self.num_actions)
        )
        
        # Value head for PPO
        self.value_branch = nn.Linear(32, 1)
        self._last_value = None
        
    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"].float()
        features = self.network[:-1](obs)  # Get features before final layer
        action_logits = self.network[-1](features)
        self._last_value = self.value_branch(features).squeeze(-1)
        return action_logits, state
    
    def value_function(self):
        return self._last_value

In [3]:
# Custom Random Policy
class RandomPolicy:
    def __init__(self, observation_space, action_space, config):
        self.action_space = action_space
        
    def compute_actions(self, obs_batch, state_batches=None, **kwargs):
        actions = [self.action_space.sample() for _ in range(len(obs_batch))]
        return actions, [], {}
    
    def learn_on_batch(self, samples):
        return {}  # Random policy doesn't learn
    
    def get_weights(self):
        return {}  # No weights for random policy
    
    def set_weights(self, weights):
        pass  # No weights to set

In [4]:
class CustomMARLEnv(MultiAgentEnv):
    def __init__(self, config=None):
        super().__init__()
        config = config or {}
        self._num_agents = config.get("num_agents", 3)
        self.obs_dim = config.get("obs_dim", 4)
        self.num_actions = config.get("num_actions", 3)
        self.max_steps = config.get("max_steps", 100)
        
        self.agents = [f"agent_{i}" for i in range(self._num_agents)]
        self._agent_ids = set(self.agents)
        self.current_step = 0
        
        obs_space = spaces.Box(low=-1, high=1, shape=(self.obs_dim,), dtype=np.float32)
        act_space = spaces.Discrete(self.num_actions)
        
        self.observation_space = {agent: obs_space for agent in self.agents}
        self.action_space = {agent: act_space for agent in self.agents}
        
        print(f"Initialized environment with {len(self.agents)} agents: {self.agents}")
    
    def get_agent_ids(self):
        return self._agent_ids
        
    def reset(self, *, seed=None, options=None):
        self.current_step = 0
        observations = {agent: self.observation_space[agent].sample() for agent in self.agents}
        infos = {agent: {} for agent in self.agents}
        return observations, infos
    
    def step(self, action_dict):
        self.current_step += 1
        
        if not action_dict:
            raise ValueError("No actions received")
        
        print(f"Step {self.current_step}: Received actions for {list(action_dict.keys())}")
        print(f"Expected agents: {self.agents}")
        
        active_agents = [agent for agent in self.agents if agent in action_dict]
        
        observations = {}
        rewards = {}
        terminateds = {}
        truncateds = {}
        infos = {}
        
        for agent in active_agents:
            observations[agent] = self.observation_space[agent].sample()
            rewards[agent] = np.random.random()
            terminateds[agent] = False
            truncateds[agent] = False
            infos[agent] = {}
        
        terminateds["__all__"] = False
        truncateds["__all__"] = self.current_step >= self.max_steps
        
        return observations, rewards, terminateds, truncateds, infos

In [5]:
def env_creator(env_config):
    print(f"Creating environment with config: {env_config}")
    env = CustomMARLEnv(env_config)
    print(f"Environment created successfully with agents: {env.agents}")
    return env

In [6]:
# Define policies
def policy_mapping_fn(agent_id, episode, **kwargs):
    if agent_id == "agent_0" or agent_id == "agent_1":
        return "ppo_policy"
    return "random_policy"

In [7]:
# Initialize Ray
ray.shutdown()
ray.init(ignore_reinit_error=True)

2025-07-23 12:13:06,857	INFO worker.py:1917 -- Started a local Ray instance.


0,1
Python version:,3.10.12
Ray version:,2.47.1


[36m(MultiAgentEnvRunner pid=827)[0m 2025-07-23 12:13:09,140	ERROR multi_agent_env_runner.py:834 -- 'dict' object has no attribute 'sample'
[36m(MultiAgentEnvRunner pid=827)[0m Traceback (most recent call last):
[36m(MultiAgentEnvRunner pid=827)[0m   File "/usr/local/lib/python3.10/dist-packages/ray/rllib/env/multi_agent_env_runner.py", line 832, in make_env
[36m(MultiAgentEnvRunner pid=827)[0m     check_multiagent_environments(env.unwrapped)
[36m(MultiAgentEnvRunner pid=827)[0m   File "/usr/local/lib/python3.10/dist-packages/ray/rllib/utils/pre_checks/env.py", line 57, in check_multiagent_environments
[36m(MultiAgentEnvRunner pid=827)[0m     sampled_action = {
[36m(MultiAgentEnvRunner pid=827)[0m   File "/usr/local/lib/python3.10/dist-packages/ray/rllib/utils/pre_checks/env.py", line 58, in <dictcomp>
[36m(MultiAgentEnvRunner pid=827)[0m     aid: env.get_action_space(aid).sample() for aid in reset_obs.keys()
[36m(MultiAgentEnvRunner pid=827)[0m AttributeError: 'dict'

[36m(MultiAgentEnvRunner pid=827)[0m Creating environment with config: {'num_agents': 3, 'obs_dim': 4, 'num_actions': 3, 'max_steps': 50, worker=1/1, vector_idx=0, remote=False}
[36m(MultiAgentEnvRunner pid=827)[0m Initialized environment with 3 agents: ['agent_0', 'agent_1', 'agent_2']
[36m(MultiAgentEnvRunner pid=827)[0m Environment created successfully with agents: ['agent_0', 'agent_1', 'agent_2']
[36m(MultiAgentEnvRunner pid=827)[0m Step 1: Received actions for ['agent_1', 'agent_0', 'agent_2']
[36m(MultiAgentEnvRunner pid=827)[0m Expected agents: ['agent_0', 'agent_1', 'agent_2']
[36m(MultiAgentEnvRunner pid=827)[0m Step 2: Received actions for ['agent_1', 'agent_0', 'agent_2']
[36m(MultiAgentEnvRunner pid=827)[0m Expected agents: ['agent_0', 'agent_1', 'agent_2']
[36m(MultiAgentEnvRunner pid=827)[0m Step 3: Received actions for ['agent_1', 'agent_0', 'agent_2']
[36m(MultiAgentEnvRunner pid=827)[0m Expected agents: ['agent_0', 'agent_1', 'agent_2']
[36m(MultiAg

In [8]:
# Register custom model
ModelCatalog.register_custom_model("custom_model", CustomTorchModel)

# Register environment
tune.register_env("custom_marl_env", env_creator)

# Environment configuration
env_config = {
    "num_agents": 3,
    "obs_dim": 4,
    "num_actions": 3,
    "max_steps": 50,
}

print(f"Using env_config: {env_config}")

Using env_config: {'num_agents': 3, 'obs_dim': 4, 'num_actions': 3, 'max_steps': 50}


In [9]:
config = (
    PPOConfig()
    .environment(
        env="custom_marl_env",
        env_config=env_config
    )
    .multi_agent(
        policies={
            "ppo_policy": PolicySpec(
                policy_class=None,  # Use default PPO policy
                observation_space=spaces.Box(low=-1, high=1, shape=(4,), dtype=np.float32),
                action_space=spaces.Discrete(3),
                config={"model": {"custom_model": "custom_model"}}
            ),
            "random_policy": PolicySpec(
                policy_class=RandomPolicy,
                observation_space=spaces.Box(low=-1, high=1, shape=(4,), dtype=np.float32),
                action_space=spaces.Discrete(3)
            ),
        },
        policy_mapping_fn=policy_mapping_fn,
        policies_to_train=["ppo_policy"]  # Only train PPO policy
    )
    .env_runners(
        num_env_runners=1
    )
    .training(
        train_batch_size=1000,
        num_sgd_iter=5,
    )
    .framework("torch")  # Explicitly use PyTorch
    .debugging(log_level="DEBUG")
)




In [10]:
# Build and train
try:
    algo = config.build()
    
    for i in range(10):
        result = algo.train()
        reward_key = None
        for key in ['env_runners/episode_reward_mean', 'episode_reward_mean', 'hist_stats/episode_reward']:
            if key in result:
                reward_key = key
                break
        
        if reward_key:
            print(f"Iteration {i}: reward = {result[reward_key]}")
        else:
            print(f"Iteration {i}: training completed")
    
    algo.stop()
    
except Exception as e:
    print(f"Error: {e}")
    if 'algo' in locals():
        algo.stop()

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2025-07-23 12:13:09,263	INFO env_runner_group.py:320 -- Inferred observation/action spaces from remote worker (local worker has no env): {'__env__': (None, None), '__env_single__': (Dict('agent_0': Box(-1.0, 1.0, (4,), float32), 'agent_1': Box(-1.0, 1.0, (4,), float32), 'agent_2': Box(-1.0, 1.0, (4,), 

Iteration 0: training completed
Iteration 1: training completed
Iteration 2: training completed
Iteration 3: training completed
Iteration 4: training completed
Iteration 5: training completed
Iteration 6: training completed
Iteration 7: training completed
Iteration 8: training completed
Iteration 9: training completed


In [11]:
ray.shutdown()