In [8]:
"""
Multi-Agent Proximal Policy Optimization (MAPPO) with the Centralized Value Function
"""

import os

import numpy as np
import ray
import torch
import torch.nn as nn
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import PolicySpec
from ray.train import report
from ray.tune import Tuner, TuneConfig, PlacementGroupFactory
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env

from config import env_configs
from env import env_creator




class CentralizedCriticModel(TorchModelV2, nn.Module):
    """
    Centralized Critic Model for MAPPO.

    Holds two independent fully-connected networks built from the same
    ``model_config``:

    * ``shared_network`` produces the ``num_outputs`` policy outputs;
    * ``central_value_function`` produces one value estimate per sample.

    NOTE(review): both networks are fed the same ``input_dict`` (the agent's own
    observation batch), so the critic is only "centralized" if the environment
    already includes global state in every agent's observation -- confirm
    against the env implementation.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # Policy network: observations -> num_outputs action outputs.
        self.shared_network = FullyConnectedNetwork(
            obs_space, action_space, num_outputs, model_config, name + "_shared"
        )

        # Value network: same architecture, but a single output unit.
        self.central_value_function = FullyConnectedNetwork(
            obs_space, action_space, 1, model_config, name + "_vf"
        )

        # Value output cached by the most recent forward() call and read by
        # value_function(). Initialised here so that calling value_function()
        # too early produces a clear error instead of an AttributeError.
        self._value_out = None

    def forward(self, input_dict, state, seq_lens):
        """Run both networks on ``input_dict``.

        Returns the policy network's output and an empty state list (the model
        is not recurrent). The value network's output is cached for
        ``value_function()``.
        """
        features, _ = self.shared_network(input_dict, state, seq_lens)
        self._value_out, _ = self.central_value_function(input_dict, state, seq_lens)
        return features, []

    def value_function(self):
        """Return the value estimates from the last ``forward()`` as a flat 1-D tensor.

        Raises:
            RuntimeError: if ``forward()`` has not been called yet.
        """
        if self._value_out is None:
            raise RuntimeError("value_function() called before forward()")
        return torch.reshape(self._value_out, [-1])

# Expose the environment and the custom model to RLlib under string names, so
# the PPO configuration can refer to "InventoryManagementEnv" and
# "centralized_critic" instead of the Python objects.
register_env("InventoryManagementEnv", env_creator)
ModelCatalog.register_custom_model("centralized_critic", CentralizedCriticModel)


# ----------------------------------------------
# PPO Training Function
# ----------------------------------------------

def tune_ppo(config):
    """
    Train and evaluate one MAPPO trial; intended as a Ray Tune function trainable.

    Builds an old-API-stack PPO algorithm with one policy per stage
    (``env_config["num_stages"]`` policies, all using the ``centralized_critic``
    custom model). It then:

    1. trains for ``config["training_iteration"]`` iterations, reporting the mean
       training episode reward to Tune after each one;
    2. runs a single evaluation pass of ``num_episodes`` episodes;
    3. prints the mean and standard deviation of the evaluation episode rewards.

    Args:
        config: Hyperparameters sampled by Tune. Must contain
            ``"training_iteration"``; the whole dict is also merged into the PPO
            config through ``update_from_dict``.
    """
    env_config_name = "stochastic_demand"
    env_config = env_configs[env_config_name]
    num_episodes = 100
    num_iterations = config["training_iteration"]

    # A single env instance is enough to read the per-agent spaces that every
    # policy uses; there is no need to construct a new env for each lookup.
    probe_env = env_creator(env_config)
    policies = {
        f"policy_{m}": PolicySpec(
            observation_space=probe_env.agent_observation_space,
            action_space=probe_env.agent_action_space,
            config={"model": {"custom_model": "centralized_critic"}},
        )
        for m in range(env_config["num_stages"])
    }

    # Define PPO algorithm with custom multi-agent setup
    algo = (
        PPOConfig()
        .api_stack(enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False)
        .environment(env="InventoryManagementEnv", env_config=env_config)
        .framework("torch")
        .resources(num_gpus=1)
        .env_runners(num_envs_per_env_runner=1)
        .multi_agent(
            policies=policies,
            # Assumes agent IDs end in "_<stage index>" so they map onto the
            # "policy_<m>" keys above -- confirm against the env.
            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: f"policy_{agent_id.split('_')[-1]}",
        )
        .training(gamma=1.0)
        .evaluation(
            # One past the last training iteration. train() evaluates on every
            # iteration that is a multiple of evaluation_interval. An interval
            # equal to num_iterations would therefore run a full evaluation
            # inside the final train() call and again in algo.evaluate() below.
            # The interval is kept non-zero, presumably so RLlib still sets up
            # evaluation -- confirm if changing this.
            evaluation_interval=num_iterations + 1,
            evaluation_duration=num_episodes,
            evaluation_duration_unit="episodes",
            evaluation_parallel_to_training=False,
        )
        .update_from_dict(config)
        .build()
    )

    try:
        # Training loop
        for _ in range(num_iterations):
            result = algo.train()
            report({"mean_episode_reward": result["env_runners"]["episode_reward_mean"]})

        # Evaluation phase: exactly one pass of num_episodes episodes.
        results = algo.evaluate()
        episode_rewards = results["env_runners"]["hist_stats"]["episode_reward"]
        episode_reward_mean = np.mean(episode_rewards)
        episode_reward_std = np.std(episode_rewards)
        print(f"env_config_name = {env_config_name}, num_episodes = {num_episodes}, "
              f"episode_reward_mean = {episode_reward_mean:.2f}, episode_reward_std = {episode_reward_std:.2f}")
    finally:
        # Shut the algorithm down even if training or evaluation raises.
        algo.stop()



# ----------------------------------------------
# Main Entry Point for Hyperparameter Tuning
# ----------------------------------------------

if __name__ == '__main__':
    ray.init()
    try:
        # Per-trial placement group:
        # * one bundle (1 CPU + 0.25 GPU) for the trainable itself;
        # * two 1-CPU bundles, presumably for PPO's rollout workers (assumes
        #   RLlib's default number of env runners -- confirm).
        # NOTE(review): tune_ppo builds PPO with num_gpus=1 while only 0.25 GPU
        # is reserved per trial here; confirm the intended GPU share.
        resources_per_trial = PlacementGroupFactory([{"CPU": 1, "GPU": 0.25}] + [{"CPU": 1}] * 2)

        tuner = Tuner(
            tune.with_resources(tune_ppo, resources_per_trial),
            tune_config=TuneConfig(
                metric="mean_episode_reward",
                mode="max",
                num_samples=20,
            ),
            param_space={
                "model": {
                    "fcnet_hiddens": tune.choice([[128, 128], [256, 256]]),
                    "fcnet_activation": tune.choice(["relu"]),
                },
                "lr": tune.choice([1e-4, 5e-4, 1e-3]),
                "train_batch_size": tune.choice([500, 1000, 2000]),
                "sgd_minibatch_size": tune.choice([32, 64, 128]),
                "num_sgd_iter": tune.choice([5, 10]),
                "training_iteration": tune.choice([500, 800, 1000, 1500]),
            },
            run_config=ray.air.RunConfig(
                storage_path=os.path.join(os.getcwd(), "..", "results"),
                name="MAPPO",
            ),
        )

        result = tuner.fit()

        # Display the config of the trial with the highest mean_episode_reward.
        best_result = result.get_best_result(metric="mean_episode_reward", mode="max")
        best_config = best_result.config
        print(f"Best trial config: {pretty_print(best_config)}")
    finally:
        # Release the Ray runtime even if tuning fails.
        ray.shutdown()


0,1
Current time:,2025-03-29 11:25:53
Running for:,00:00:44.21
Memory:,22.7/31.8 GiB

Trial name,status,loc,lr,model/fcnet_activation,model/fcnet_hiddens,num_sgd_iter,sgd_minibatch_size,train_batch_size,training_iteration,iter,total time (s),mean_episode_reward
tune_ppo_16f32_00000,RUNNING,127.0.0.1:16220,0.0005,relu,"[128, 128]",10,64,500,800,3.0,36.6407,-1281.54
tune_ppo_16f32_00001,RUNNING,127.0.0.1:45284,0.0001,relu,"[128, 128]",5,32,2000,500,,,
tune_ppo_16f32_00002,RUNNING,127.0.0.1:18808,0.0005,relu,"[128, 128]",5,128,2000,800,,,
tune_ppo_16f32_00003,RUNNING,127.0.0.1:38416,0.0005,relu,"[256, 256]",10,32,1000,500,1.0,29.2072,-1286.13
tune_ppo_16f32_00004,PENDING,,0.0001,relu,"[256, 256]",10,128,1000,1500,,,
tune_ppo_16f32_00005,PENDING,,0.001,relu,"[128, 128]",10,64,500,800,,,
tune_ppo_16f32_00006,PENDING,,0.001,relu,"[256, 256]",10,64,500,800,,,
tune_ppo_16f32_00007,PENDING,,0.0001,relu,"[256, 256]",10,32,2000,1500,,,
tune_ppo_16f32_00008,PENDING,,0.001,relu,"[128, 128]",5,128,500,500,,,
tune_ppo_16f32_00009,PENDING,,0.0005,relu,"[128, 128]",5,128,1000,500,,,




[36m(pid=16220)[0m Variable demand for t=0: 4


[36m(tune_ppo pid=16220)[0m `UnifiedLogger` will be removed in Ray 2.7.
[36m(tune_ppo pid=16220)[0m   return UnifiedLogger(config, logdir, loggers=None)
[36m(tune_ppo pid=16220)[0m The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
[36m(tune_ppo pid=16220)[0m   self._loggers.append(cls(self.config, self.logdir, self.trial))
[36m(tune_ppo pid=16220)[0m The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
[36m(tune_ppo pid=16220)[0m   self._loggers.append(cls(self.config, self.logdir, self.trial))
[36m(tune_ppo pid=16220)[0m The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
[36m(tune_ppo pid=16220)[0m   self._loggers.append(cls(self.config, self.logdir, self.trial))


[36m(RolloutWorker pid=24008)[0m Variable demand for t=0: 4[32m [repeated 4x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m


[36m(tune_ppo pid=16220)[0m Trainable.setup took 13.589 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
[36m(tune_ppo pid=38416)[0m `UnifiedLogger` will be removed in Ray 2.7.[32m [repeated 3x across cluster][0m
[36m(tune_ppo pid=38416)[0m   return UnifiedLogger(config, logdir, loggers=None)[32m [repeated 3x across cluster][0m
[36m(tune_ppo pid=38416)[0m The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.[32m [repeated 3x across cluster][0m
[36m(tune_ppo pid=38416)[0m   self._loggers.append(cls(self.config, self.logdir, self.trial))[32m [repeated 9x across cluster][0m
[36m(tune_ppo pid=38416)[0m The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.[32m [repeated 3x across cluster][0m
[36m(tune_ppo pid=38416)[0m The `TBXLogger interface 

Best trial config: lr: 0.0005
model:
  fcnet_activation: relu
  fcnet_hiddens:
  - 128
  - 128
num_sgd_iter: 10
sgd_minibatch_size: 64
train_batch_size: 500
training_iteration: 800

[36m(RolloutWorker pid=2784)[0m Variable demand for t=0: 4[32m [repeated 7x across cluster][0m


In [9]:
# Record the package versions installed in this environment (pip freeze) so
# the run above can be reproduced later.
!pip freeze > current_packages.txt