In [1]:
import ray
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.models import ModelCatalog
from ray import tune
from gymnasium import spaces
import numpy as np
import torch
import torch.nn as nn

# Custom RLModule for PyTorch
class CustomRLModule(RLModule):
    def __init__(self, config):
        super().__init__(config)
        self.obs_dim = config.observation_space.shape[0]
        self.num_actions = config.action_space.n

        # Define a simple neural network
        self.network = nn.Sequential(
            nn.Linear(self.obs_dim, 8),
            nn.ReLU(),
            nn.Linear(8, 4),
            nn.ReLU(),
            nn.Linear(4, self.num_actions)
        )

        # Value head for PPO
        self.value_branch = nn.Linear(8, 1)
        self._last_value = None

    def forward_train(self, batch, **kwargs):
        obs = batch["obs"].float()
        features = self.network[:-1](obs)  # Get features before final layer
        action_logits = self.network[-1](features)
        self._last_value = self.value_branch(features).squeeze(-1)
        return {"action_dist_inputs": action_logits}

    def forward_inference(self, batch, **kwargs):
        obs = batch["obs"].float()
        features = self.network[:-1](obs)
        action_logits = self.network[-1](features)
        return {"action_dist_inputs": action_logits}

    def forward_exploration(self, batch, **kwargs):
        return self.forward_inference(batch, **kwargs)

    def get_state(self):
        return {}  # No recurrent state in this model

    def set_state(self, state):
        pass  # No state to set

    def get_train_action_dist_cls(self):
        from ray.rllib.models.torch.torch_distributions import TorchCategorical
        return TorchCategorical

    def get_inference_action_dist_cls(self):
        from ray.rllib.models.torch.torch_distributions import TorchCategorical
        return TorchCategorical

# Custom Random Policy
class RandomPolicy:
    def __init__(self, observation_space, action_space, config):
        self.action_space = action_space

    def compute_actions(self, obs_batch, state_batches=None, **kwargs):
        actions = [self.action_space.sample() for _ in range(len(obs_batch))]
        return actions, [], {}

    def learn_on_batch(self, samples):
        return {}  # Random policy doesn't learn

    def get_weights(self):
        return {}  # No weights for random policy

    def set_weights(self, weights):
        pass  # No weights to set

class CustomMARLEnv(MultiAgentEnv):
    def __init__(self, config=None):
        super().__init__()
        config = config or {}
        self._num_agents = config.get("num_agents", 2)
        self.obs_dim = config.get("obs_dim", 4)
        self.num_actions = config.get("num_actions", 3)
        self.max_steps = config.get("max_steps", 100)

        self.agents = [f"agent_{i}" for i in range(self._num_agents)]
        self._agent_ids = set(self.agents)
        self.current_step = 0

        obs_space = spaces.Box(low=-1, high=1, shape=(self.obs_dim,), dtype=np.float32)
        act_space = spaces.Discrete(self.num_actions)

        self.observation_space = {agent: obs_space for agent in self.agents}
        self.action_space = {agent: act_space for agent in self.agents}
        # self.observation_space = spaces.Dict({agent: obs_space for agent in self.agents})
        # self.action_space = spaces.Dict({agent: act_space for agent in self.agents})

        print(f"Initialized environment with {len(self.agents)} agents: {self.agents}")

    def get_action_space(self, agent_id):
        # Return the actual action space, not a dictionary

        print(f'get_action_space, agent_id: {agent_id}')
        print(f'get_action_space, self.action_space[agent_id]: {self.action_space[agent_id]}')

        return self.action_space[agent_id]

    def get_observation_space(self, agent_id):

        # print(f'get_observation_space, agent_id: {agent_id}')
        # print(f'get_observation_space, self.observation_spaces[agent_id]: {self.observation_spaces[agent_id]}')

        return self.observation_space[agent_id]

    def get_agent_ids(self):
        return self._agent_ids

    def reset(self, *, seed=None, options=None):
        self.current_step = 0
        observations = {agent: self.observation_space[agent].sample() for agent in self.agents}
        actions = {agent: self.action_space[agent].sample() for agent in self.agents}

        # print(f'reset observations: {observations}')
        # print(f'reset observations observations.keys: {observations.keys()}')
        # print(f'reset observations: {actions}')
        # print(f'reset observations actions.keys: {actions.keys()}')



        sampled_action = {}
        for aid in observations.keys():
            try:
                action_space = self.get_action_space(aid)
                print(f'action_space: {action_space}')
                if action_space is not None:
                    sampled_action[aid] = action_space.sample()
                else:
                    print(f"Warning: Action space for agent {aid} is None")
                    # You might want to skip this agent or use a default action space
            except (KeyError, TypeError) as e:
                print(f"Error getting action space for agent {aid}: {e}")
                # Handle the case where action_spaces is None or agent_id doesn't exist
            except Exception as e:
                print(f"Unexpected error sampling action for agent {aid}: {e}")



        infos = {agent: {} for agent in self.agents}
        return observations, infos

    def step(self, action_dict):
        self.current_step += 1

        if not action_dict:
            raise ValueError("No actions received")

        print(f"Step {self.current_step}: Received actions for {list(action_dict.keys())}")
        print(f"Expected agents: {self.agents}")
        print(f"step action_dict: {action_dict}")

        active_agents = [agent for agent in self.agents if agent in action_dict]

        observations = {}
        rewards = {}
        terminateds = {}
        truncateds = {}
        infos = {}

        for agent in active_agents:
            observations[agent] = self.observation_space[agent].sample()
            rewards[agent] = np.random.random()
            terminateds[agent] = False
            truncateds[agent] = False
            infos[agent] = {}

        terminateds["__all__"] = False
        truncateds["__all__"] = self.current_step >= self.max_steps

        return observations, rewards, terminateds, truncateds, infos

def env_creator(env_config):
    print(f"Creating environment with config: {env_config}")
    env = CustomMARLEnv(env_config)
    print(f"Environment created successfully with agents: {env.agents}")
    return env

# Initialize Ray
ray.shutdown()
ray.init(ignore_reinit_error=True)

# Register custom RLModule using register_custom_model
ModelCatalog.register_custom_model("custom_model", CustomRLModule)

# Register environment
tune.register_env("custom_marl_env", env_creator)

# Environment configuration
env_config = {
    "num_agents": 3,
    "obs_dim": 4,
    "num_actions": 3,
    "max_steps": 50,
}

print(f"Using env_config: {env_config}")

# Define policies
# def policy_mapping_fn(agent_id, episode, worker, **kwargs):
def policy_mapping_fn(agent_id, episode, **kwargs):
    if agent_id == "agent_0" or agent_id == "agent_1":
        return "ppo_policy"
    return "random_policy"

config = (
    PPOConfig()
    .environment(
        env="custom_marl_env",
        env_config=env_config
    )
    .multi_agent(
        policies={
            "ppo_policy": PolicySpec(
                policy_class=None,  # Use default PPO policy
                observation_space=spaces.Box(low=-1, high=1, shape=(4,), dtype=np.float32),
                action_space=spaces.Discrete(3),
                config={"model": {"custom_model": "custom_model"}}
            ),
            "random_policy": PolicySpec(
                policy_class=RandomPolicy,
                observation_space=spaces.Box(low=-1, high=1, shape=(4,), dtype=np.float32),
                action_space=spaces.Discrete(3)
            ),
        },
        policy_mapping_fn=policy_mapping_fn,
        policies_to_train=["ppo_policy"]  # Only train PPO policy
    )
    .env_runners(
        num_env_runners=0,
        # env_to_module_connector=(
        #     lambda env, spaces, device: FlattenObservations(multi_agent=True)
        # ),
    )
    .training(
        train_batch_size=300,
        num_epochs=5,
    )
    .framework("torch")  # Explicitly use PyTorch
    .debugging(log_level="DEBUG")
)

# Build and train
try:
    algo = config.build()

    for i in range(10):
        result = algo.train()
        reward_key = None
        for key in ['env_runners/episode_reward_mean', 'episode_reward_mean', 'hist_stats/episode_reward']:
            if key in result:
                reward_key = key
                break

        if reward_key:
            print(f"Iteration {i}: reward = {result[reward_key]}")
        else:
            print(f"Iteration {i}: training completed")

    algo.stop()

except Exception as e:
    print(f"Error: {e}")
    if 'algo' in locals():
        algo.stop()

ray.shutdown()

2025-07-27 00:19:43,065	INFO worker.py:1917 -- Started a local Ray instance.
`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2025-07-27 00:19:43,705	INFO connector_pipeline_v2.py:272 -- Added AddObservationsFromEpisodesToBatch to the end of EnvToModulePipeline.
2025-07-27 00:19:43,711	INFO connector_pipeline_v2.py:272 -- Added AddTimeDimToBatchAndZe

Using env_config: {'num_agents': 3, 'obs_dim': 4, 'num_actions': 3, 'max_steps': 50}
Creating environment with config: {'num_agents': 3, 'obs_dim': 4, 'num_actions': 3, 'max_steps': 50, worker=0/0, vector_idx=0, remote=False}
Initialized environment with 3 agents: ['agent_0', 'agent_1', 'agent_2']
Environment created successfully with agents: ['agent_0', 'agent_1', 'agent_2']
get_action_space, agent_id: agent_0
get_action_space, self.action_space[agent_id]: Discrete(3)
action_space: Discrete(3)
get_action_space, agent_id: agent_1
get_action_space, self.action_space[agent_id]: Discrete(3)
action_space: Discrete(3)
get_action_space, agent_id: agent_2
get_action_space, self.action_space[agent_id]: Discrete(3)
action_space: Discrete(3)
get_action_space, agent_id: agent_0
get_action_space, self.action_space[agent_id]: Discrete(3)
get_action_space, agent_id: agent_1
get_action_space, self.action_space[agent_id]: Discrete(3)
get_action_space, agent_id: agent_2
get_action_space, self.action_sp

2025-07-27 00:19:43,793	INFO connector_pipeline_v2.py:272 -- Added AddObservationsFromEpisodesToBatch to the end of LearnerConnectorPipeline.
2025-07-27 00:19:43,794	INFO connector_pipeline_v2.py:272 -- Added AddColumnsFromEpisodesToTrainBatch to the end of LearnerConnectorPipeline.
2025-07-27 00:19:43,800	INFO connector_pipeline_v2.py:272 -- Added AddTimeDimToBatchAndZeroPad to the end of LearnerConnectorPipeline.
2025-07-27 00:19:43,806	INFO connector_pipeline_v2.py:272 -- Added AddStatesFromEpisodesToBatch to the end of LearnerConnectorPipeline.
2025-07-27 00:19:43,812	INFO connector_pipeline_v2.py:272 -- Added AgentToModuleMapping to the end of LearnerConnectorPipeline.
2025-07-27 00:19:43,818	INFO connector_pipeline_v2.py:272 -- Added BatchIndividualItems to the end of LearnerConnectorPipeline.
2025-07-27 00:19:43,823	INFO connector_pipeline_v2.py:272 -- Added NumpyToTensor to the end of LearnerConnectorPipeline.
2025-07-27 00:19:44,785	INFO connector_pipeline_v2.py:258 -- Added A

get_action_space, agent_id: agent_0
get_action_space, self.action_space[agent_id]: Discrete(3)
action_space: Discrete(3)
get_action_space, agent_id: agent_1
get_action_space, self.action_space[agent_id]: Discrete(3)
action_space: Discrete(3)
get_action_space, agent_id: agent_2
get_action_space, self.action_space[agent_id]: Discrete(3)
action_space: Discrete(3)
Step 1: Received actions for ['agent_1', 'agent_0', 'agent_2']
Expected agents: ['agent_0', 'agent_1', 'agent_2']
step action_dict: {'agent_1': np.int32(1), 'agent_0': np.int32(2), 'agent_2': np.int32(2)}
Step 2: Received actions for ['agent_1', 'agent_0', 'agent_2']
Expected agents: ['agent_0', 'agent_1', 'agent_2']
step action_dict: {'agent_1': np.int32(0), 'agent_0': np.int32(2), 'agent_2': np.int32(1)}
Step 3: Received actions for ['agent_1', 'agent_0', 'agent_2']
Expected agents: ['agent_0', 'agent_1', 'agent_2']
step action_dict: {'agent_1': np.int32(2), 'agent_0': np.int32(0), 'agent_2': np.int32(0)}
Step 4: Received actio