In [21]:
# Import necessary packages
from typing import Any, Dict, List, Tuple, Union
from functools import partial
import numpy as np
import torch
from torch.optim import Adam, RMSprop

from maro.rl.model import DiscreteACBasedNet, FullyConnected, VNet, DiscreteQNet
from maro.rl.policy import DiscretePolicyGradient
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.rollout import AbsEnvSampler, CacheElement, ExpElement
from maro.rl.training import TrainingManager
from maro.simulator import Env
from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent

from maro.rl.exploration import epsilon_greedy
from maro.rl.policy import ValueBasedPolicy
from maro.rl.training.algorithms import DQNParams, DQNTrainer


from maro.rl.training.algorithms import PPOParams, PPOTrainer

from maro.rl.model import DiscreteACBasedNet, FullyConnected, MultiQNet
from maro.rl.policy import DiscretePolicyGradient
from maro.rl.training.algorithms import DiscreteMADDPGParams, DiscreteMADDPGTrainer

In [2]:
# env and shaping config
reward_shaping_conf = {
    "time_window": 99,
    "fulfillment_factor": 1.0,
    "shortage_factor": 1.0,
    "time_decay": 0.97,
}
state_shaping_conf = {
    "look_back": 7,
    "max_ports_downstream": 2,
}
port_attributes = ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"]
vessel_attributes = ["empty", "full", "remaining_space"]
action_shaping_conf = {
    "action_space": [(i - 10) / 10 for i in range(21)],
    "finite_vessel_space": True,
    "has_early_discharge": True,
}

In [18]:
class CIMEnvSampler(AbsEnvSampler):
    def _get_global_and_agent_state_impl(
        self,
        event: DecisionEvent,
        tick: int = None,
    ) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]:
        tick = self._env.tick
        vessel_snapshots, port_snapshots = self._env.snapshot_list["vessels"], self._env.snapshot_list["ports"]
        port_idx, vessel_idx = event.port_idx, event.vessel_idx
        ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)]
        future_port_list = vessel_snapshots[tick:vessel_idx:"future_stop_list"].astype("int")
        state = np.concatenate(
            [
                port_snapshots[ticks : [port_idx] + list(future_port_list) : port_attributes],
                vessel_snapshots[tick:vessel_idx:vessel_attributes],
            ],
        )
        return state, {port_idx: state}

    def _translate_to_env_action(
        self,
        action_dict: Dict[Any, Union[np.ndarray, List[object]]],
        event: DecisionEvent,
    ) -> Dict[Any, object]:
        action_space = action_shaping_conf["action_space"]
        finite_vsl_space = action_shaping_conf["finite_vessel_space"]
        has_early_discharge = action_shaping_conf["has_early_discharge"]

        port_idx, model_action = list(action_dict.items()).pop()

        vsl_idx, action_scope = event.vessel_idx, event.action_scope
        vsl_snapshots = self._env.snapshot_list["vessels"]
        vsl_space = vsl_snapshots[self._env.tick : vsl_idx : vessel_attributes][2] if finite_vsl_space else float("inf")

        percent = abs(action_space[model_action[0]])
        zero_action_idx = len(action_space) / 2  # index corresponding to value zero.
        if model_action < zero_action_idx:
            action_type = ActionType.LOAD
            actual_action = min(round(percent * action_scope.load), vsl_space)
        elif model_action > zero_action_idx:
            action_type = ActionType.DISCHARGE
            early_discharge = (
                vsl_snapshots[self._env.tick : vsl_idx : "early_discharge"][0] if has_early_discharge else 0
            )
            plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge
            actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)
        else:
            actual_action, action_type = 0, None

        return {port_idx: Action(vsl_idx, int(port_idx), actual_action, action_type)}

    def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionEvent, tick: int) -> Dict[Any, float]:
        start_tick = tick + 1
        ticks = list(range(start_tick, start_tick + reward_shaping_conf["time_window"]))

        # Get the ports that took actions at the given tick
        ports = [int(port) for port in list(env_action_dict.keys())]
        port_snapshots = self._env.snapshot_list["ports"]
        future_fulfillment = port_snapshots[ticks:ports:"fulfillment"].reshape(len(ticks), -1)
        future_shortage = port_snapshots[ticks:ports:"shortage"].reshape(len(ticks), -1)

        decay_list = [reward_shaping_conf["time_decay"] ** i for i in range(reward_shaping_conf["time_window"])]
        rewards = np.float32(
            reward_shaping_conf["fulfillment_factor"] * np.dot(future_fulfillment.T, decay_list)
            - reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list),
        )
        return {agent_id: reward for agent_id, reward in zip(ports, rewards)}

    def _post_step(self, cache_element: CacheElement) -> None:
        self._info["env_metric"] = self._env.metrics

    def _post_eval_step(self, cache_element: CacheElement) -> None:
        self._post_step(cache_element)

    def post_collect(self, info_list: list, ep: int) -> None:
        # print the env metric from each rollout worker
        for info in info_list:
            print(f"env summary (episode {ep}): {info['env_metric']}")

        # average env metric
        metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
        avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys}
        print(f"average env summary (episode {ep}): {avg_metric}")

        self.metrics.update(avg_metric)
        self.metrics = {k: v for k, v in self.metrics.items() if not k.startswith("val/")}

    def post_evaluate(self, info_list: list, ep: int) -> None:
        # print the env metric from each rollout worker
        for info in info_list:
            print(f"env summary (episode {ep}): {info['env_metric']}")

        # average env metric
        metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
        avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys}
        print(f"average env summary (episode {ep}): {avg_metric}")

        self.metrics.update({"val/" + k: v for k, v in avg_metric.items()})

    def monitor_metrics(self) -> float:
        return -self.metrics["val/container_shortage"]

# Actor - Critic

In [5]:
import torch
from torch.optim import Adam, RMSprop

from maro.rl.model import DiscreteACBasedNet, FullyConnected, VNet
from maro.rl.policy import DiscretePolicyGradient
from maro.rl.training.algorithms import ActorCriticParams, ActorCriticTrainer

state_dim = (
    (state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_ports_downstream"] + 1) * len(port_attributes)
    + len(vessel_attributes)
)
action_num = len(action_shaping_conf["action_space"])

actor_net_conf = {
    "hidden_dims": [256, 128, 64],
    "activation": torch.nn.Tanh,
    "output_activation": torch.nn.Tanh,
    "softmax": True,
    "batch_norm": False,
    "head": True,
}
critic_net_conf = {
    "hidden_dims": [256, 128, 64],
    "output_dim": 1,
    "activation": torch.nn.LeakyReLU,
    "output_activation": torch.nn.LeakyReLU,
    "softmax": False,
    "batch_norm": True,
    "head": True,
}
actor_learning_rate = 0.001
critic_learning_rate = 0.001


class MyActorNet(DiscreteACBasedNet):
    def __init__(self, state_dim: int, action_num: int) -> None:
        super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num)
        self._actor = FullyConnected(input_dim=state_dim, output_dim=action_num, **actor_net_conf)
        self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate)

    def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
        return self._actor(states)


class MyCriticNet(VNet):
    def __init__(self, state_dim: int) -> None:
        super(MyCriticNet, self).__init__(state_dim=state_dim)
        self._critic = FullyConnected(input_dim=state_dim, **critic_net_conf)
        self._optim = RMSprop(self._critic.parameters(), lr=critic_learning_rate)

    def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:
        return self._critic(states).squeeze(-1)


def get_ac_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyGradient:
    return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num))


def get_ac(state_dim: int, name: str) -> ActorCriticTrainer:
    return ActorCriticTrainer(
        name=name,
        reward_discount=0.0,
        params=ActorCriticParams(
            get_v_critic_net_func=lambda: MyCriticNet(state_dim),
            grad_iters=10,
            critic_loss_cls=torch.nn.SmoothL1Loss,
            min_logp=None,
            lam=0.0,
        ),
    )

In [25]:
learn_env = Env(scenario="cim", topology="toy.4p_ssdd_l0.0", durations=500,options={"enable-dump-snapshot": "./ac_learn"})
test_env = Env(scenario="cim", topology="toy.4p_ssdd_l0.1", durations=500,options={"enable-dump-snapshot": "./ac_test"})
num_agents = len(learn_env.agent_idx_list)
agent2policy = {agent: f"ac_{agent}.policy" for agent in learn_env.agent_idx_list}
policies = [get_ac_policy(state_dim, action_num, f"ac_{i}.policy") for i in range(num_agents)]
trainers = [get_ac(state_dim, f"ac_{i}") for i in range(num_agents)]

rl_component_bundle = RLComponentBundle(
    env_sampler=CIMEnvSampler(
        learn_env=learn_env,
        test_env=test_env,
        policies=policies,
        agent2policy=agent2policy,
        reward_eval_delay=reward_shaping_conf["time_window"],
    ),
    agent2policy=agent2policy,
    policies=policies,
    trainers=trainers,
)

In [26]:
env_sampler = rl_component_bundle.env_sampler

num_episodes = 30
eval_schedule = [5,10,15,20,25,30]
eval_point_index = 0

training_manager = TrainingManager(rl_component_bundle=rl_component_bundle)

# main loop
for ep in range(1, num_episodes + 1):
    result = env_sampler.sample()
    experiences: List[List[ExpElement]] = result["experiences"]
    info_list: List[dict] = result["info"]
        
    print("Collecting result:")
    env_sampler.post_collect(info_list, ep)
    print()

    training_manager.record_experiences(experiences)
    training_manager.train_step()

    if ep == eval_schedule[eval_point_index]:
        eval_point_index += 1
        result = env_sampler.eval()
        
        print("Evaluation result:")
        env_sampler.post_evaluate(result["info"], ep)
        print()

training_manager.exit()

Collecting result:
env summary (episode 1): {'order_requirements': 1000000, 'container_shortage': 669920, 'operation_number': 1747069}
average env summary (episode 1): {'order_requirements': 1000000.0, 'container_shortage': 669920.0, 'operation_number': 1747069.0}

Collecting result:
env summary (episode 2): {'order_requirements': 1000000, 'container_shortage': 692457, 'operation_number': 1578829}
average env summary (episode 2): {'order_requirements': 1000000.0, 'container_shortage': 692457.0, 'operation_number': 1578829.0}

Collecting result:
env summary (episode 3): {'order_requirements': 1000000, 'container_shortage': 518699, 'operation_number': 1969736}
average env summary (episode 3): {'order_requirements': 1000000.0, 'container_shortage': 518699.0, 'operation_number': 1969736.0}

Collecting result:
env summary (episode 4): {'order_requirements': 1000000, 'container_shortage': 584074, 'operation_number': 2005531}
average env summary (episode 4): {'order_requirements': 1000000.0, 

# DQN

In [27]:
q_net_conf = {
    "hidden_dims": [256, 128, 64, 32],
    "activation": torch.nn.LeakyReLU,
    "output_activation": torch.nn.LeakyReLU,
    "softmax": False,
    "batch_norm": True,
    "skip_connection": False,
    "head": True,
    "dropout_p": 0.0,
}
learning_rate = 0.05


class MyQNet(DiscreteQNet):
    def __init__(
        self,
        state_dim: int,
        action_num: int,
        dueling_param: Optional[Tuple[dict, dict]] = None,
    ) -> None:
        super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)

        self._use_dueling = dueling_param is not None
        self._fc = FullyConnected(input_dim=state_dim, output_dim=1 if self._use_dueling else action_num, **q_net_conf)
        if self._use_dueling:
            q_kwargs, v_kwargs = dueling_param
            self._q = FullyConnected(input_dim=self._fc.output_dim, output_dim=action_num, **q_kwargs)
            self._v = FullyConnected(input_dim=self._fc.output_dim, output_dim=1, **v_kwargs)

        self._optim = RMSprop(self.parameters(), lr=learning_rate)

    def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
        logits = self._fc(states)
        if self._use_dueling:
            q = self._q(logits)
            v = self._v(logits)
            logits = q - q.mean(dim=1, keepdim=True) + v
        return logits


def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPolicy:
    q_kwargs = {
        "hidden_dims": [128],
        "activation": torch.nn.LeakyReLU,
        "output_activation": torch.nn.LeakyReLU,
        "softmax": False,
        "batch_norm": True,
        "skip_connection": False,
        "head": True,
        "dropout_p": 0.0,
    }
    v_kwargs = {
        "hidden_dims": [128],
        "activation": torch.nn.LeakyReLU,
        "output_activation": None,
        "softmax": False,
        "batch_norm": True,
        "skip_connection": False,
        "head": True,
        "dropout_p": 0.0,
    }

    return ValueBasedPolicy(
        name=name,
        q_net=MyQNet(
            state_dim,
            action_num,
            dueling_param=(q_kwargs, v_kwargs),
        ),
        exploration_strategy=(epsilon_greedy, {"epsilon": 0.4}),
        warmup=100,
    )


def get_dqn(name: str) -> DQNTrainer:
    return DQNTrainer(
        name=name,
        reward_discount=0.0,
        replay_memory_capacity=10000,
        batch_size=32,
        params=DQNParams(
            update_target_every=5,
            num_epochs=10,
            soft_update_coef=0.1,
            double=False
        ),
    )

In [28]:
learn_env = Env(scenario="cim", topology="toy.4p_ssdd_l0.0", durations=500,options={"enable-dump-snapshot": "./dqn_learn"})
test_env = Env(scenario="cim", topology="toy.4p_ssdd_l0.1", durations=500,options={"enable-dump-snapshot": "./dqn_test"})
num_agents = len(learn_env.agent_idx_list)
agent2policy = {agent: f"dqn_{agent}.policy" for agent in learn_env.agent_idx_list}
policies = [get_dqn_policy(state_dim, action_num, f"dqn_{i}.policy") for i in range(num_agents)]
trainers = [get_dqn(f"dqn_{i}") for i in range(num_agents)]

rl_component_bundle = RLComponentBundle(
    env_sampler=CIMEnvSampler(
        learn_env=learn_env,
        test_env=test_env,
        policies=policies,
        agent2policy=agent2policy,
        reward_eval_delay=reward_shaping_conf["time_window"],
    ),
    agent2policy=agent2policy,
    policies=policies,
    trainers=trainers,
)

In [29]:
env_sampler = rl_component_bundle.env_sampler

num_episodes = 30
eval_schedule = [5,10,15,20,25,30]
eval_point_index = 0

training_manager = TrainingManager(rl_component_bundle=rl_component_bundle)
for ep in range(1, num_episodes + 1):
    result = env_sampler.sample()
    experiences: List[List[ExpElement]] = result["experiences"]
    info_list: List[dict] = result["info"]
        
    print("Collecting result:")
    env_sampler.post_collect(info_list, ep)
    print()

    training_manager.record_experiences(experiences)
    training_manager.train_step()

    if ep == eval_schedule[eval_point_index]:
        eval_point_index += 1
        result = env_sampler.eval()
        
        print("Evaluation result:")
        env_sampler.post_evaluate(result["info"], ep)
        print()

training_manager.exit()

Collecting result:
env summary (episode 1): {'order_requirements': 1000000, 'container_shortage': 677164, 'operation_number': 1597412}
average env summary (episode 1): {'order_requirements': 1000000.0, 'container_shortage': 677164.0, 'operation_number': 1597412.0}

Collecting result:
env summary (episode 2): {'order_requirements': 1000000, 'container_shortage': 945986, 'operation_number': 136577}
average env summary (episode 2): {'order_requirements': 1000000.0, 'container_shortage': 945986.0, 'operation_number': 136577.0}

Collecting result:
env summary (episode 3): {'order_requirements': 1000000, 'container_shortage': 976645, 'operation_number': 106695}
average env summary (episode 3): {'order_requirements': 1000000.0, 'container_shortage': 976645.0, 'operation_number': 106695.0}

Collecting result:
env summary (episode 4): {'order_requirements': 1000000, 'container_shortage': 967329, 'operation_number': 98635}
average env summary (episode 4): {'order_requirements': 1000000.0, 'conta

# MADDPG

In [32]:
actor_net_conf = {
    "hidden_dims": [256, 128, 64],
    "activation": torch.nn.Tanh,
    "output_activation": torch.nn.Tanh,
    "softmax": True,
    "batch_norm": False,
    "head": True,
}
critic_net_conf = {
    "hidden_dims": [256, 128, 64],
    "output_dim": 1,
    "activation": torch.nn.LeakyReLU,
    "output_activation": torch.nn.LeakyReLU,
    "softmax": False,
    "batch_norm": True,
    "head": True,
}
actor_learning_rate = 0.001
critic_learning_rate = 0.001

class MyActorNet(DiscreteACBasedNet):
    def __init__(self, state_dim: int, action_num: int) -> None:
        super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num)
        self._actor = FullyConnected(input_dim=state_dim, output_dim=action_num, **actor_net_conf)
        self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate)

    def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
        return self._actor(states)


class MyMultiCriticNet(MultiQNet):
    def __init__(self, state_dim: int, action_dims: List[int]) -> None:
        super(MyMultiCriticNet, self).__init__(state_dim=state_dim, action_dims=action_dims)
        self._critic = FullyConnected(input_dim=state_dim + sum(action_dims), **critic_net_conf)
        self._optim = RMSprop(self._critic.parameters(), critic_learning_rate)

    def _get_q_values(self, states: torch.Tensor, actions: List[torch.Tensor]) -> torch.Tensor:
        return self._critic(torch.cat([states] + actions, dim=1)).squeeze(-1)


def get_multi_critic_net(state_dim: int, action_dims: List[int]) -> MyMultiCriticNet:
    return MyMultiCriticNet(state_dim, action_dims)


def get_maddpg_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyGradient:
    return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num))


def get_maddpg(state_dim: int, action_dims: List[int], name: str) -> DiscreteMADDPGTrainer:
    return DiscreteMADDPGTrainer(
        name=name,
        reward_discount=0.0,
        params=DiscreteMADDPGParams(
            num_epoch=10,
            get_q_critic_net_func=partial(get_multi_critic_net, state_dim, action_dims),
            shared_critic=False,
        ),
    )

In [33]:
learn_env = Env(scenario="cim", topology="toy.4p_ssdd_l0.0", durations=500,options={"enable-dump-snapshot": "./maddpg_learn"})
test_env = Env(scenario="cim", topology="toy.4p_ssdd_l0.1", durations=500,options={"enable-dump-snapshot": "./maddpg_test"})

num_agents = len(learn_env.agent_idx_list)
agent2policy = {agent: f"maddpg_{agent}.policy" for agent in learn_env.agent_idx_list}
policies = [get_maddpg_policy(state_dim, action_num, f"maddpg_{i}.policy") for i in range(num_agents)]
trainers = [get_maddpg(state_dim, [1], f"maddpg_{i}") for i in range(num_agents)]

# Build RLComponentBundle
rl_component_bundle = RLComponentBundle(
    env_sampler=CIMEnvSampler(
        learn_env=learn_env,
        test_env=test_env,
        policies=policies,
        agent2policy=agent2policy,
        reward_eval_delay=reward_shaping_conf["time_window"],
    ),
    agent2policy=agent2policy,
    policies=policies,
    trainers=trainers,
)

In [34]:
env_sampler = rl_component_bundle.env_sampler

num_episodes = 30
eval_schedule = [5,10,15,20,25,30]
eval_point_index = 0

training_manager = TrainingManager(rl_component_bundle=rl_component_bundle)

# main loop
for ep in range(1, num_episodes + 1):
    result = env_sampler.sample()
    experiences: List[List[ExpElement]] = result["experiences"]
    info_list: List[dict] = result["info"]
        
    print("Collecting result:")
    env_sampler.post_collect(info_list, ep)
    print()

    training_manager.record_experiences(experiences)
    training_manager.train_step()

    if ep == eval_schedule[eval_point_index]:
        eval_point_index += 1
        result = env_sampler.eval()
        
        print("Evaluation result:")
        env_sampler.post_evaluate(result["info"], ep)
        print()

training_manager.exit()

Collecting result:
env summary (episode 1): {'order_requirements': 1000000, 'container_shortage': 678834, 'operation_number': 1804825}
average env summary (episode 1): {'order_requirements': 1000000.0, 'container_shortage': 678834.0, 'operation_number': 1804825.0}

Collecting result:
env summary (episode 2): {'order_requirements': 1000000, 'container_shortage': 538220, 'operation_number': 1745516}
average env summary (episode 2): {'order_requirements': 1000000.0, 'container_shortage': 538220.0, 'operation_number': 1745516.0}

Collecting result:
env summary (episode 3): {'order_requirements': 1000000, 'container_shortage': 471884, 'operation_number': 1915979}
average env summary (episode 3): {'order_requirements': 1000000.0, 'container_shortage': 471884.0, 'operation_number': 1915979.0}

Collecting result:
env summary (episode 4): {'order_requirements': 1000000, 'container_shortage': 385978, 'operation_number': 1995230}
average env summary (episode 4): {'order_requirements': 1000000.0, 

# PPO

In [35]:
actor_net_conf = {
    "hidden_dims": [256, 128, 64],
    "activation": torch.nn.Tanh,
    "softmax": True,
    "batch_norm": False,
    "head": True,
}
critic_net_conf = {
    "hidden_dims": [256, 128, 64],
    "output_dim": 1,
    "activation": torch.nn.LeakyReLU,
    "softmax": False,
    "batch_norm": True,
    "head": True,
}

actor_learning_rate = 0.001
critic_learning_rate = 0.001

class MyActorNet(DiscreteACBasedNet):
    def __init__(self, state_dim: int, action_num: int) -> None:
        super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num)
        self._actor = FullyConnected(input_dim=state_dim, output_dim=action_num, **actor_net_conf)
        self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate)

    def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
        return self._actor(states)


class MyCriticNet(VNet):
    def __init__(self, state_dim: int) -> None:
        super(MyCriticNet, self).__init__(state_dim=state_dim)
        self._critic = FullyConnected(input_dim=state_dim, **critic_net_conf)
        self._optim = RMSprop(self._critic.parameters(), lr=critic_learning_rate)

    def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:
        return self._critic(states).squeeze(-1)
    
def get_ppo_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyGradient:
    return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num))


def get_ppo(state_dim: int, name: str) -> PPOTrainer:
    return PPOTrainer(
        name=name,
        reward_discount=0.0,
        params=PPOParams(
            get_v_critic_net_func=lambda: MyCriticNet(state_dim),
            grad_iters=10,
            critic_loss_cls=torch.nn.SmoothL1Loss,
            lam=0.0,
            clip_ratio=0.1,
        ),
    )

In [36]:
learn_env = Env(scenario="cim", topology="toy.4p_ssdd_l0.0", durations=500,options={"enable-dump-snapshot": "./ppo_learn"})
test_env = Env(scenario="cim", topology="toy.4p_ssdd_l0.1", durations=500,options={"enable-dump-snapshot": "./ppo_test"})

num_agents = len(learn_env.agent_idx_list)
agent2policy = {agent: f"ppo_{agent}.policy" for agent in learn_env.agent_idx_list}
policies = [get_ppo_policy(state_dim, action_num, f"ppo_{i}.policy") for i in range(num_agents)]
trainers = [get_ppo(state_dim, f"ppo_{i}") for i in range(num_agents)]

rl_component_bundle = RLComponentBundle(
    env_sampler=CIMEnvSampler(
        learn_env=learn_env,
        test_env=test_env,
        policies=policies,
        agent2policy=agent2policy,
        reward_eval_delay=reward_shaping_conf["time_window"],
    ),
    agent2policy=agent2policy,
    policies=policies,
    trainers=trainers,
)

In [37]:
env_sampler = rl_component_bundle.env_sampler

num_episodes = 30
eval_schedule = [5,10,15,20,25,30]
eval_point_index = 0

training_manager = TrainingManager(rl_component_bundle=rl_component_bundle)

# main loop
for ep in range(1, num_episodes + 1):
    result = env_sampler.sample()
    experiences: List[List[ExpElement]] = result["experiences"]
    info_list: List[dict] = result["info"]
        
    print("Collecting result:")
    env_sampler.post_collect(info_list, ep)
    print()

    training_manager.record_experiences(experiences)
    training_manager.train_step()

    if ep == eval_schedule[eval_point_index]:
        eval_point_index += 1
        result = env_sampler.eval()
        
        print("Evaluation result:")
        env_sampler.post_evaluate(result["info"], ep)
        print()

training_manager.exit()

Collecting result:
env summary (episode 1): {'order_requirements': 1000000, 'container_shortage': 673956, 'operation_number': 2061585}
average env summary (episode 1): {'order_requirements': 1000000.0, 'container_shortage': 673956.0, 'operation_number': 2061585.0}

Collecting result:
env summary (episode 2): {'order_requirements': 1000000, 'container_shortage': 754426, 'operation_number': 1576141}
average env summary (episode 2): {'order_requirements': 1000000.0, 'container_shortage': 754426.0, 'operation_number': 1576141.0}

Collecting result:
env summary (episode 3): {'order_requirements': 1000000, 'container_shortage': 600971, 'operation_number': 1919497}
average env summary (episode 3): {'order_requirements': 1000000.0, 'container_shortage': 600971.0, 'operation_number': 1919497.0}

Collecting result:
env summary (episode 4): {'order_requirements': 1000000, 'container_shortage': 580572, 'operation_number': 2018525}
average env summary (episode 4): {'order_requirements': 1000000.0, 