# Simulating fleets of automated vehicles (AVs) making routing decisions: Medium traffic network, AV behaviors, IPPO/MAPPO algorithm implementation

> This tutorial is based on [Multi-Agent Reinforcement Learning (PPO) with TorchRL Tutorial](https://pytorch.org/rl/stable/tutorials/multiagent_ppo.html).

#### Imported libraries

In [None]:
import torch
from tqdm import tqdm

from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torch.distributions import Categorical
from torchrl.envs.libs.pettingzoo import PettingZooWrapper
from torchrl.envs.transforms import TransformedEnv, RewardSum
from torchrl.envs.utils import check_env_specs
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.modules import MultiAgentMLP, ProbabilisticActor
from torchrl.objectives.value import GAE
from torchrl.objectives import ClipPPOLoss, ValueEstimators

import numpy as np
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../../')))


# Now you can import the module
from routerl import TrafficEnvironment

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"


#### Hyperparameters setting

In [None]:
# Compute device: CUDA device 0 when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device(0)
else:
    device = torch.device("cpu")
print("device is: ", device)

# --- Sampling ---------------------------------------------------------------
# Frames gathered by the collector for every training iteration.
frames_per_batch = 200
# How many collect-then-train iterations are run in total.
n_iters = 10
# Overall frame budget handed to the collector.
total_frames = frames_per_batch * n_iters

# --- Optimisation -----------------------------------------------------------
# Passes over the collected batch per training iteration.
num_epochs = 1
# Samples drawn from the replay buffer for each gradient step.
minibatch_size = 2
# Learning rate given to the optimiser.
lr = 3e-4
# Upper bound used when clipping the gradient norm.
max_grad_norm = 3.0

# --- PPO --------------------------------------------------------------------
# Clipping range of the PPO surrogate objective.
clip_epsilon = 0.2
# Discount factor.
gamma = 0.99
# Lambda of generalised advantage estimation.
lmbda = 0.9
# Weight of the entropy term in the PPO loss.
entropy_eps = 1e-4

# --- Network sizes ----------------------------------------------------------
policy_network_depth = 3
policy_network_num_cells = 64

critic_network_depth = 3
critic_network_num_cells = 64

# --- Experiment phases ------------------------------------------------------
# Length of the human learning phase, in episodes.
human_learning_episodes = 2
# Forwarded to the environment under the key of the same name.
new_machines_after_mutation = 100

# Number of test episodes run after training.
test_eps = 5

# Number of episodes the AV training will take. True division keeps this a
# float; it is converted with int() where an episode index is needed.
training_episodes = (frames_per_batch / new_machines_after_mutation) * n_iters

# Configuration dictionary unpacked as keyword arguments into TrafficEnvironment.
# It has one sub-dictionary per component: agents, simulator, plotter and path generation.
env_params = {
    "agent_parameters" : {
        # presumably how many human agents become AVs at mutation — confirm against the RouteRL docs
        "new_machines_after_mutation": new_machines_after_mutation,

        # Settings of the human agents' decision model.
        "human_parameters" :
        {
            "model" : "gawron",

            "noise_weight_agent" : 0,
            "noise_weight_path" : 0.8,
            "noise_weight_day" : 0.2,

            "beta" : -1,
            "beta_k_i_variability" : 0.1,
            "epsilon_i_variability" : 0.1,
            "epsilon_k_i_variability" : 0.1,
            "epsilon_k_i_t_variability" : 0.1,

            "greedy" : 0.1,
            "gamma_c" : 0.0,
            "gamma_u" : 0.0,
            "remember" : 1,

            "alpha_zero" : 0.8,
            "alphas" : [0.2]  
        },
        # Settings of the machine (AV) agents.
        "machine_parameters" :
        {
            "behavior" : "selfish",
        }
    },
    "simulator_parameters" : {
        "network_name" : "saint_arnoult",
        "custom_network_folder" : "custom_networks/saint_arnoult",
        "sumo_type" : "sumo",
    },  
    "plotter_parameters" : {
        # Three boundaries matching the three entries of "phase_names" below
        # (presumably the episode index at which each phase starts — confirm).
        # training_episodes is a float, hence the int() conversion.
        "phases" : [0, human_learning_episodes, int(training_episodes) + human_learning_episodes],
        "smooth_by" : 50,
        "plot_choices" : "basic",
        "phase_names" : [
            "Human learning", 
            "Mutation - Machine learning",
            "Testing phase"
        ],
        "records_folder" : "records/saint_arnoult_records",
        "plots_folder" : "plots/saint_arnoult_plots"
    },
    "path_generation_parameters":
    {
        # Origin and destination identifiers (presumably SUMO edge IDs of the network — confirm).
        "origins" : ['-42762428#0', '-1323419247', '71324347#2', '-100525445', '-282689981#0', '-101607967#0', '-101594809#0', '-659278038#3', '-101609498#5', '-120511302', '101416508#0', '-47374680#7', '336863934', '100475365#1', '-101611601', '-101604496#1', '101594836#1', '659281081#1', '-297823021', '-282689985#1', '526438862#2', '71324347#0', '297823019#4', '-47374665#1', '-297823024', '-100488715#4', '868087112', '75421017', '416409192#0', '-282689975#2', '120511932#0', '-101604513#1', '-100468675#0', '-416409190', '297823019#3', '-297823017', '-100468681#0', '-101600741#2', '-101608828', '-101594843', '-101787489', '-1323419243', '-101601559', '120511930', '659283000#1', '526439431#0', '-71324343#4', '-101607098#1', '101611600', '526438862#0', '101417074', '71324343#3', '-101411571#0', '-101749462', '-42762428#3', '71324347#5', '101611060', '-101415865#0', '-100525438#2', '-101418584#1', '71223791#0', '-71223800#7', '120509991', '-1006827381', '-416375813', '-71324343#5', '-101601551', '-416409191#0', '101606547#1', '-101746498#2', '-101601553#1', '71324347#1', '-101746516#2', '1164653913', '71324347#3', '-101415864#1', '71223791#1', '-679068326#3', '-101746498#0', '463830226#1', '-101765339#0', '-101787498', '101784499#1', '101418583', '-47374664', '-101749451', '47374680#2', '619185406', '-101746498#1', '-1200809291', '-282689986', '-619184994', '-101608960', '-100468681#2', '-100488715#6', '-659278038#0', '-100488715#3', '100525438#0', '101607969#1', '-101604497', '-336860316#0', '120829754', '1200809292', '659278036#2', '71223800#6', '526438862#1', '-101752975#1', '336863927#3', '101607964#0', '-101605951#3', '101594822', '-1323419242#1', '101607969#0', '-101606549', '-101605951#5', '416409192#5', '76867740', '-101765343', '-416375814', '-352797377', '-71223800#4', '-479315694', '-336860317#0', '-101606546', '101416508#1', '-101749465', '-282689981#1', '-282689975#3', '-336860317#1', '336863927#0', '1323419244', '120509990', '120511301', '-101746516#0', 
'101605951#0', '416405090', '-101767911', '-101752970#1', '-101754276#1', '101765339#2', '-416409191#1', '-101784793', '659281081#0', '-101749449#0', '-173532317#1'],
        "destinations" : ['-100468681#2', '464265154', '416373452', '101606547#1', '-352797377', '-100468681#0', '-101787493', '-101752975#1', '282689983#1', '-372606529#4', '101418583', '526438862#2', '100475365#1', '336863934', '-101418584#1', '-101746498#2', '-101754276#1', '-101608959', '-1323419243', '-101787498', '101767915#1', '-101611601', '-100468675#0', '-416409191#0', '-101594836#0', '-101604513#1', '-71324343#0', '-659283000#3', '-336863927#1', '-619184994', '-101787489', '-101746516#0', '-101601832#0', '100525438#0', '-1200809291', '-101599076#0', '71324347#3', '-101749458#0', '463830226#1', '-47374664', '-1047412913', '-100525438#2', '-101415865#0', '101607963', '416428068', '-101746508#0', '-101601551', '120509990', '-100468681#1', '526438862#1', '-101418584#0', '868087112', '-1323419247', '71324343#3', '-416412034', '-71223800#1', '1047412912#0', '71223791#1', '120509991', '-120511144#2', '-101752970#2', '-282690560#0', '282479125#1', '-101600741#2', '-101752975#2', '71324347#5', '659281081#1', '-101609498#0', '1164659926', '-101608828', '-101604518', '-1323419242#1', '101608824', '101754278#1', '-23887076#2', '526438862#0', '120829754', '-101605951#5', '-416409192#7', '101417074', '-120511144#0', '-100488715#7', '-101609494#0', '-101606549', '-101767922', '-837309289#1', '-416375814', '-101606546', '-336860317#0', '-101411571#1', '1323419244', '-71223800#4', '-101746514', '-101418589'],
        "number_of_paths" : 4,
        "beta" : -5,
        "num_samples" : 10,
        "visualize_paths" : False
    } 
}

#### Environment initialization

> In this example, the environment initially contains only human agents.

> If the paths have already been generated, set `create_paths=False` so that they are not generated again.

In [None]:
# Build the traffic environment from the configuration above.
# create_paths=False reuses paths that have already been generated.
env = TrafficEnvironment(
    seed=42,
    create_agents=False,
    create_paths=False,
    save_detectors_info=False,
    **env_params,
)

In [None]:
# Report the size of each agent population currently held by the environment.
for label, attribute in (
    ("Number of total agents is: ", "all_agents"),
    ("Number of human agents is: ", "human_agents"),
    ("Number of machine agents (autonomous vehicles) is: ", "machine_agents"),
):
    print(label, len(getattr(env, attribute)), "\n")

> Reset the environment and the connection with SUMO

In [None]:
# Start the environment first (presumably this opens the connection with SUMO — confirm),
# then reset it so that it begins from its initial state.
env.start()
env.reset()

#### Human learning

In [None]:
# Assign every human agent a default action, sampled at random with a
# probability inversely proportional to the matching initial-knowledge entry.
for human in env.human_agents:
    knowledge = np.array(human.initial_knowledge)

    # Reciprocal weights, rescaled so that they sum to one.
    weights = 1 / knowledge
    probabilities = weights / weights.sum()

    action_indices = np.arange(len(human.initial_knowledge))
    sampled = np.random.choice(action_indices, size=1, p=probabilities)
    human.default_action = sampled[0]

#### Mutation

> **Mutation**: a portion of human agents are converted into machine agents (autonomous vehicles). 

In [None]:
# Mutation step: a portion of the human agents is converted into machine agents (AVs).
env.mutation()

In [None]:
# Report the agent populations again, this time after the mutation step.
population_labels = {
    "all_agents": "Number of total agents is: ",
    "human_agents": "Number of human agents is: ",
    "machine_agents": "Number of machine agents (autonomous vehicles) is: ",
}
for attribute, label in population_labels.items():
    print(label, len(getattr(env, attribute)), "\n")

> `TorchRL` enables us to make different groups with different agents. Here, all the AV agents are included in one group.

In [None]:
# One group called "agents" that holds the string id of every machine agent.
group = {
    "agents": [str(av.id) for av in env.machine_agents],
}

#### PettingZoo environment wrapper

In [None]:
# Wrap the traffic environment with TorchRL's PettingZoo wrapper.
env = PettingZooWrapper(
    env=env,
    group_map=group,
    categorical_actions=True,
    # Expose the mask in the outputs; AEC environments need it to mask out
    # the agents that are not acting.
    use_mask=True,
    # False: the environment's done keys aggregate the agent keys with all()
    # (True would aggregate them with any()).
    done_on_any=False,
    device=device,
)

#### Transforms

In [None]:
# Add a RewardSum transform that reads the environment's reward key and
# writes its result under ("agents", "episode_reward").
episode_reward_sum = RewardSum(
    in_keys=[env.reward_key],
    out_keys=[("agents", "episode_reward")],
)
env = TransformedEnv(env, episode_reward_sum)

The `check_env_specs()` function runs a small rollout and compares its output against the environment specs. It will raise an error if the specs aren't properly defined.

In [None]:
# Sanity check: runs a small rollout and compares its output against the
# environment specs; raises an error when the specs are not properly defined.
check_env_specs(env)


In [None]:
# Reset the transformed environment and keep whatever reset() returns
# (presumably a TensorDict holding the initial observations — confirm).
reset_td = env.reset()

#### Policy/Actor network

In [None]:
# Whether the agents share a single set of policy parameters.
share_parameters_policy = False

# Multi-agent MLP mapping each agent's observation to one output per discrete action.
policy_mlp = MultiAgentMLP(
    n_agents=env.n_agents,
    n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
    n_agent_outputs=env.action_spec.space.n,
    centralised=False,
    share_params=share_parameters_policy,
    depth=policy_network_depth,
    num_cells=policy_network_num_cells,
    activation_class=torch.nn.Tanh,
    device=device,
)

# The MLP is kept inside a Sequential container, as in the original layout.
policy_net = torch.nn.Sequential(policy_mlp)

In [None]:
# Wrap the policy network so that it reads ("agents", "observation") and
# writes its output under ("agents", "logits").
policy_module = TensorDictModule(
    module=policy_net,
    in_keys=[("agents", "observation")],
    out_keys=[("agents", "logits")],
)

In [None]:
# Probabilistic actor: builds a Categorical distribution from the logits
# written by policy_module and writes the action under the environment's
# action key, together with its log-probability.
policy = ProbabilisticActor(
    module=policy_module,
    distribution_class=Categorical,
    in_keys=[("agents", "logits")],
    out_keys=[env.action_key],
    spec=env.action_spec,
    return_log_prob=True,
    log_prob_key=("agents", "sample_log_prob"),
)

#### Critic network

> The critic reads the observations and returns the corresponding value estimates.

In [None]:
# Whether the agents share a single set of critic parameters.
share_parameters_critic = True
# True selects MAPPO (centralised critic); False selects IPPO.
mappo = True

# Multi-agent MLP producing a single value estimate per agent.
critic_net = MultiAgentMLP(
    n_agents=env.n_agents,
    n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
    n_agent_outputs=1,
    centralised=mappo,
    share_params=share_parameters_critic,
    depth=critic_network_depth,
    num_cells=critic_network_num_cells,
    activation_class=torch.nn.ReLU,
    device=device,
)

# The critic reads the observations and writes ("agents", "state_value").
critic = TensorDictModule(
    module=critic_net,
    in_keys=[("agents", "observation")],
    out_keys=[("agents", "state_value")],
)

#### Collector

In [None]:
# Data collector: runs the policy in the environment and yields batches of
# frames_per_batch frames until total_frames have been collected.
collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    reset_at_each_iter=True,
    device=device,
    storing_device=device,
)

#### Replay buffer

In [None]:
# Buffer sized to hold one collected batch; minibatches of minibatch_size
# are drawn from it without replacement.
batch_storage = LazyTensorStorage(frames_per_batch, device=device)
replay_buffer = ReplayBuffer(
    storage=batch_storage,
    sampler=SamplerWithoutReplacement(),
    batch_size=minibatch_size,
)

#### PPO loss function

In [None]:
# Clipped PPO objective built from the actor and the critic.
loss_module = ClipPPOLoss(
    actor_network=policy,
    critic_network=critic,
    clip_epsilon=clip_epsilon,
    entropy_coef=entropy_eps,
    normalize_advantage=False,
)

# Tell the loss where each entry lives in the collected data.
loss_module.set_keys(
    action=env.action_key,
    reward=env.reward_key,
    value=("agents", "state_value"),
    sample_log_prob=("agents", "sample_log_prob"),
    done=("agents", "done"),
    terminated=("agents", "terminated"),
)

# Generalised advantage estimation as the value estimator.
loss_module.make_value_estimator(
    ValueEstimators.GAE,
    gamma=gamma,
    lmbda=lmbda,
)

# NOTE: this rebinding shadows the GAE class imported from
# torchrl.objectives.value; from here on GAE is the estimator instance.
GAE = loss_module.value_estimator

optim = torch.optim.Adam(loss_module.parameters(), lr=lr)

#### Training loop

In [None]:
pbar = tqdm(total=n_iters, desc="episode_reward_mean = 0")

# Per-iteration mean episode reward and per-gradient-step loss terms.
episode_reward_mean_list = []
loss_values = []
loss_entropy = []
loss_objective = []
loss_critic = []

# Each iteration of the collector yields one batch of collected frames.
for tensordict_data in collector:

    # Copy the group-level done / terminated flags to per-agent entries,
    # expanded to the shape of the reward entry.
    per_agent_shape = tensordict_data.get_item_shape(("next", env.reward_key))
    for flag in ("done", "terminated"):
        group_flag = tensordict_data.get(("next", flag))
        tensordict_data.set(
            ("next", "agents", flag),
            group_flag.unsqueeze(-1).expand(per_agent_shape),
        )

    # Run the value estimator on the batch; no gradients are needed here.
    with torch.no_grad():
        GAE(
            tensordict_data,
            params=loss_module.critic_network_params,
            target_params=loss_module.target_critic_network_params,
        )

    # Flatten the batch and store it in the replay buffer.
    flat_batch = tensordict_data.reshape(-1)
    replay_buffer.extend(flat_batch)

    # Update the policies of the learning agents.
    steps_per_epoch = frames_per_batch // minibatch_size
    for _epoch in range(num_epochs):
        for _step in range(steps_per_epoch):
            minibatch = replay_buffer.sample()
            loss_vals = loss_module(minibatch)

            objective_term = loss_vals["loss_objective"]
            critic_term = loss_vals["loss_critic"]
            entropy_term = loss_vals["loss_entropy"]
            loss_value = objective_term + critic_term + entropy_term

            loss_value.backward()
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()

            # Keep the scalar value of every term.
            loss_values.append(loss_value.item())
            loss_entropy.append(entropy_term.item())
            loss_objective.append(objective_term.item())
            loss_critic.append(critic_term.item())

    collector.update_policy_weights_()

    # Logging: average the episode reward over the entries flagged as done.
    done = tensordict_data.get(("next", "agents", "done"))
    episode_rewards = tensordict_data.get(("next", "agents", "episode_reward"))
    episode_reward_mean = episode_rewards[done].mean().item()
    episode_reward_mean_list.append(episode_reward_mean)

    pbar.set_description(f"episode_reward_mean = {episode_reward_mean}", refresh=False)
    pbar.update()

> Testing phase

In [None]:
# Testing phase: switch the policy to evaluation mode and run test_eps rollouts.
policy.eval()

# Nothing is optimised during testing, so autograd is switched off to avoid
# tracking gradients for every forward pass of the policy.
with torch.no_grad():
    for _ in range(test_eps):
        # Each rollout is capped at len(env.machine_agents) steps.
        env.rollout(len(env.machine_agents), policy=policy)

> Check the `plots/saint_arnoult_plots` directory to find the plots created by this experiment.

In [None]:
# Create the plots for this experiment; they are written to the folder set
# under "plots_folder" in plotter_parameters.
env.plot_results()

> The plots reveal that the introduction of AVs into urban traffic influences human agents' decision-making. This insight highlights the need for research aimed at mitigating potential negative effects of AV introduction, such as increased human travel times, congestion, and subsequent rises in $CO_2$ emissions.

| |  |
|---------|---------|
| **Action shifts of human and AV agents** ![](plots_saved/mappo_actions_shifts.png) | **Action shifts of all vehicles in the network** ![](plots_saved/mappo_actions.png) |
| ![](plots_saved/mappo_rewards.png) | ![](plots_saved/mappo_travel_times.png) |


<p align="center">
  <img src="plots_saved/mappo_tt_dist.png" width="700" />
</p>


> Interrupt the connection with `SUMO`.

In [None]:
# Interrupt the connection with SUMO now that the experiment is finished.
env.stop_simulation()