# VDN algorithm implementation

> In this notebook, we implement a state-of-the-art Multi-Agent Reinforcement Learning (MARL) algorithm, **[VDN](https://arxiv.org/abs/1706.05296)**, in our environment. **VDN** is a deep algorithm for cooperative MARL, particularly suited for situations where agents receive a single, shared reward. Value-decomposition networks are a step towards automatically decomposing complex learning problems into local, more readily learnable sub-problems.


> Tutorial based on [VDN TorchRL Tutorial](https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/qmix_vdn.py).

<img src="../../docs/img/vdn.png" alt="VDN" width="700"/>


> Picture taken from VDN [paper](https://arxiv.org/pdf/1706.05296).

#### High-level overview of VDN algorithm

The joint action-value function for the system can be additively decomposed into value functions across agents:

$$
Q((h^{1}, h^{2}, \ldots, h^{d}), (a^{1}, a^{2}, \ldots, a^{d})) \approx \sum_{i=1}^{d} \tilde{Q}_i(h^{i}, a^{i}),
$$


where each $\tilde{Q}_i$ depends only on agent $i$'s local observation history $h^{i}$ and its own action $a^{i}$.
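A useful consequence of the additive form is that the joint greedy action can be found by letting every agent maximise its own value function. The sketch below checks this on made-up numbers; `q_local`, `joint_q` and the values themselves are illustrative assumptions, not part of RouteRL or TorchRL.

```python
from itertools import product

# Hypothetical per-agent Q-values for 3 agents with 3 actions each.
# q_local[i][a] stands for Q_i(h^i, a); the numbers are made up.
q_local = [
    [0.2, 1.5, -0.3],
    [0.9, 0.1, 0.4],
    [-1.0, 0.0, 2.0],
]

def joint_q(actions):
    """Additive VDN estimate: sum of each agent's value for its own action."""
    return sum(q_i[a] for q_i, a in zip(q_local, actions))

# Decentralised greedy choice: each agent maximises its own Q_i.
greedy_actions = tuple(
    max(range(len(q_i)), key=q_i.__getitem__) for q_i in q_local
)

# Centralised greedy choice: brute-force search over all joint actions.
best_joint = max(product(range(3), repeat=len(q_local)), key=joint_q)

print(greedy_actions, best_joint)
```

Both searches return the same joint action, which is why agents trained with VDN can act on their local values alone at execution time.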

In the experiments reported in the VDN paper, **value decomposition** outperforms both centralized and fully independent learning approaches. The paper reports the largest gains when it is combined with additional techniques (weight sharing, role information and information channels).


### Simulation overview

> We simulate our environment with an initial population of **100 human agents**. These agents navigate the environment and eventually converge on the fastest path. After this convergence, we transition **a portion of these human agents** (25 in the run recorded in this notebook) into **machine agents**, specifically autonomous vehicles (AVs), which then use the VDN reinforcement learning algorithm to further refine their learning.

#### Imported libraries

In [1]:
import os
import sys
import time
import torch
from torch import nn


from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.pettingzoo import PettingZooWrapper
from torchrl.envs.utils import check_env_specs
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl._utils import logger as torchrl_logger
from torchrl.data import TensorDictReplayBuffer
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import RewardSum, TransformedEnv
from torchrl.modules import EGreedyModule, QValueModule, SafeSequential
from torchrl.modules.models.multiagent import MultiAgentMLP, VDNMixer
from torchrl.objectives import SoftUpdate, ValueEstimators
from torchrl.objectives.multiagent.qmixer import QMixerLoss


sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../../')))

from RouteRL.keychain import Keychain as kc
from RouteRL.environment.environment import TrafficEnvironment
from RouteRL.services.plotter import Plotter
from RouteRL.utilities import get_params

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"


#### Hyperparameters setting

In [2]:
# Devices
device = (
    torch.device(0)
    if torch.cuda.is_available()
    else torch.device("cpu")
)

print("device is: ", device)
vmas_device = device  # The device where the simulator is run

# Sampling
frames_per_batch = 150  # Number of team frames collected per training iteration
n_iters = 100  # Number of sampling and training iterations - the episodes the plotter plots
total_frames = frames_per_batch * n_iters

# Training
num_epochs = 1  # Number of optimization steps per training iteration
minibatch_size = 16  # Size of the mini-batches in each optimization step
lr = 3e-3  # Learning rate
max_grad_norm = 5.0  # Maximum norm for the gradients
memory_size = 1000  # Size of the replay buffer
tau = 0.005  # Soft-update rate for the target network

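# Discounting and leftover PPO settings (only gamma is used in this notebook)
clip_epsilon = 0.2  # clip value for PPO loss (unused here)
gamma = 0.99  # discount factor
lmbda = 0.9  # lambda for generalised advantage estimation (unused here)
entropy_eps = 1e-4  # coefficient of the entropy term in the PPO loss (unused here)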

device is:  cpu
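To make the sampling and training settings concrete, the short calculation below derives the quantities they imply. It only repeats the arithmetic used elsewhere in this notebook; `updates_per_iter` and `total_updates` are names introduced here for illustration.

```python
frames_per_batch = 150
n_iters = 100
minibatch_size = 16
num_epochs = 1

# Frames collected over the whole run.
total_frames = frames_per_batch * n_iters

# Gradient steps per iteration, matching the nested loops of the training cell.
updates_per_iter = num_epochs * (frames_per_batch // minibatch_size)
total_updates = updates_per_iter * n_iters

# Number of frames over which epsilon is annealed by the exploration module.
annealing_num_steps = int(total_frames * (1 / 2))

print(total_frames, updates_per_iter, total_updates, annealing_num_steps)
# prints: 15000 9 900 7500
```

With these values, exploration is annealed over the first half of the run, and each iteration performs 9 optimization steps on mini-batches of 16 transitions.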


#### Environment initialization

> In this example, the environment initially contains only human agents.

In [3]:
params= {
    "agent_generation_parameters" : {
        "num_agents" : 100,
        "ratio_mutating" : 0.3
    }
}
env = TrafficEnvironment(params, generate_agent_data=True, generate_paths=True)

print(env)

[CONFIRMED] Environment variable exists: SUMO_HOME
[SUCCESS] Added module directory: /opt/homebrew/opt/sumo/share/sumo/tools
TrafficEnvironment with 100 agents.            
0 machines and 100 humans.            
Machines: []            
Humans: [Human 0, Human 1, Human 2, Human 3, Human 4, Human 5, Human 6, Human 7, Human 8, Human 9, Human 10, Human 11, Human 12, Human 13, Human 14, Human 15, Human 16, Human 17, Human 18, Human 19, Human 20, Human 21, Human 22, Human 23, Human 24, Human 25, Human 26, Human 27, Human 28, Human 29, Human 30, Human 31, Human 32, Human 33, Human 34, Human 35, Human 36, Human 37, Human 38, Human 39, Human 40, Human 41, Human 42, Human 43, Human 44, Human 45, Human 46, Human 47, Human 48, Human 49, Human 50, Human 51, Human 52, Human 53, Human 54, Human 55, Human 56, Human 57, Human 58, Human 59, Human 60, Human 61, Human 62, Human 63, Human 64, Human 65, Human 66, Human 67, Human 68, Human 69, Human 70, Human 71, Human 72, Human 73, Human 74, Human 75, Huma

In [4]:
print("Number of total agents is: ", len(env.all_agents), "\n")
print("Agents are: ", env.all_agents, "\n")
print("Number of human agents is: ", len(env.human_agents), "\n")
print("Number of machine agents (autonomous vehicles) is: ", len(env.machine_agents), "\n")

Number of total agents is:  100 

Agents are:  [Human 0, Human 1, Human 2, Human 3, Human 4, Human 5, Human 6, Human 7, Human 8, Human 9, Human 10, Human 11, Human 12, Human 13, Human 14, Human 15, Human 16, Human 17, Human 18, Human 19, Human 20, Human 21, Human 22, Human 23, Human 24, Human 25, Human 26, Human 27, Human 28, Human 29, Human 30, Human 31, Human 32, Human 33, Human 34, Human 35, Human 36, Human 37, Human 38, Human 39, Human 40, Human 41, Human 42, Human 43, Human 44, Human 45, Human 46, Human 47, Human 48, Human 49, Human 50, Human 51, Human 52, Human 53, Human 54, Human 55, Human 56, Human 57, Human 58, Human 59, Human 60, Human 61, Human 62, Human 63, Human 64, Human 65, Human 66, Human 67, Human 68, Human 69, Human 70, Human 71, Human 72, Human 73, Human 74, Human 75, Human 76, Human 77, Human 78, Human 79, Human 80, Human 81, Human 82, Human 83, Human 84, Human 85, Human 86, Human 87, Human 88, Human 89, Human 90, Human 91, Human 92, Human 93, Human 94, Human 95, Hu

> Reset the environment and the connection with SUMO

In [5]:
env.start()
env.reset()

 Retrying in 1 seconds


({}, {})

#### Human learning

In [6]:
num_episodes = 100

for episode in range(num_episodes):
    env.step()

#### Mutation

> **Mutation**: a portion of human agents are converted into machine agents (autonomous vehicles). You can adjust the number of agents to be mutated in the <code style="color:white">/params.json</code> file.

In [7]:
env.mutation()

In [8]:
print("Number of total agents is: ", len(env.all_agents), "\n")
print("Agents are: ", env.all_agents, "\n")
print("Number of human agents is: ", len(env.human_agents), "\n")
print("Number of machine agents (autonomous vehicles) is: ", len(env.machine_agents), "\n")

Number of total agents is:  100 

Agents are:  [Machine 63, Machine 78, Machine 84, Machine 91, Machine 19, Machine 38, Machine 77, Machine 95, Machine 10, Machine 59, Machine 69, Machine 23, Machine 92, Machine 65, Machine 45, Machine 55, Machine 56, Machine 72, Machine 31, Machine 86, Machine 20, Machine 75, Machine 13, Machine 43, Machine 32, Human 0, Human 1, Human 2, Human 3, Human 4, Human 5, Human 6, Human 7, Human 8, Human 9, Human 11, Human 12, Human 14, Human 15, Human 16, Human 17, Human 18, Human 21, Human 22, Human 24, Human 25, Human 26, Human 27, Human 28, Human 29, Human 30, Human 33, Human 34, Human 35, Human 36, Human 37, Human 39, Human 40, Human 41, Human 42, Human 44, Human 46, Human 47, Human 48, Human 49, Human 50, Human 51, Human 52, Human 53, Human 54, Human 57, Human 58, Human 60, Human 61, Human 62, Human 64, Human 66, Human 67, Human 68, Human 70, Human 71, Human 73, Human 74, Human 76, Human 79, Human 80, Human 81, Human 82, Human 83, Human 85, Human 87, Hu

> Create a group that contains all the machine (RL) agents.

>  **Hint:** the agents aren't completely independent in this example.

In [9]:
machine_list = []
for machines in env.machine_agents:
    machine_list.append(str(machines.id))
      
group = {'agents': machine_list}

#### PettingZoo environment wrapper

In [10]:
env = PettingZooWrapper(
    env=env,
    use_mask=True,
    categorical_actions=True,
    done_on_any = False,
    group_map=group,
    device=device
)

> The environment is defined by a series of metadata that describe what can be expected during its execution. 

There are four specs to look at:

- <code style="color:white">action_spec</code> defines the action space;

- <code style="color:white">reward_spec</code> defines the reward domain;

- <code style="color:white">done_spec</code> defines the done domain;

- <code style="color:white">observation_spec</code> defines the domain of all other outputs from environment steps.

In [11]:
print("action_spec:", env.full_action_spec, "\n\n")
print("reward_spec:", env.full_reward_spec, "\n\n")
print("done_spec:", env.full_done_spec, "\n\n")
print("observation_spec:", env.observation_spec, "\n\n")

action_spec: Composite(
    agents: Composite(
        action: Categorical(
            shape=torch.Size([25]),
            space=CategoricalBox(n=3),
            device=cpu,
            dtype=torch.int64,
            domain=discrete),
        device=cpu,
        shape=torch.Size([25])),
    device=cpu,
    shape=torch.Size([])) 


reward_spec: Composite(
    agents: Composite(
        reward: UnboundedContinuous(
            shape=torch.Size([25, 1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([25, 1]), device=cpu, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([25, 1]), device=cpu, dtype=torch.float32, contiguous=True)),
            device=cpu,
            dtype=torch.float32,
            domain=continuous),
        device=cpu,
        shape=torch.Size([25])),
    device=cpu,
    shape=torch.Size([])) 


done_spec: Composite(
    done: Categorical(
        shape=torch.Size([1]),
        space=CategoricalBox(n=2),


> Agent group mapping

In [12]:
print("env.group is: ", env.group_map, "\n\n")

env.group is:  {'agents': ['63', '78', '84', '91', '19', '38', '77', '95', '10', '59', '69', '23', '92', '65', '45', '55', '56', '72', '31', '86', '20', '75', '13', '43', '32']} 




#### Transforms

> We can append any TorchRL transform we need to our environment. These will modify its input/output in some desired way. In multi-agent contexts, it is paramount to provide explicitly the keys to modify.



Here we instantiate a <code style="color:white">RewardSum</code> transform that sums the rewards over each episode.

In [13]:
env = TransformedEnv(
    env,
    RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]),
)
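The idea behind summing rewards over an episode can be illustrated in plain Python. This is a minimal sketch of the bookkeeping, assuming the running total restarts after a done flag; it is not TorchRL's implementation, and `running_episode_reward` is a hypothetical helper.

```python
def running_episode_reward(rewards, dones):
    """Return the cumulative reward at each step, restarting after a done flag."""
    totals, running = [], 0.0
    for reward, done in zip(rewards, dones):
        running += reward
        totals.append(running)
        if done:
            running = 0.0  # the next step starts a new episode
    return totals

# Made-up rewards for five steps spanning two episodes.
rewards = [-1.0, -2.0, -0.5, -3.0, -1.0]
dones = [False, False, True, False, True]
print(running_episode_reward(rewards, dones))
# prints: [-1.0, -3.0, -3.5, -3.0, -4.0]
```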

The <code style="color:white">check_env_specs()</code> function runs a small rollout and compares its output against the environment specs. It will raise an error if the specs aren't properly defined.

In [14]:
check_env_specs(env)


2025-01-30 23:31:01,971 [torchrl][INFO] check_env_specs succeeded!


In [15]:
reset_td = env.reset()


#### Policy network

> Instantiate an `MLP` that can be used in multi-agent contexts.

In [16]:
net = MultiAgentMLP(
        n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
        n_agent_outputs=env.action_spec.space.n,
        n_agents=env.n_agents,
        centralised=False,
        share_params=True,
        device=device,
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    )

> The neural network is wrapped in a `TensorDictModule`, which is responsible for managing the input and output interactions with the tensordict. Specifically, the module reads from the specified `in_keys`, processes the inputs through the neural network, and writes the resulting outputs to the defined `out_keys`. 

In [17]:
module = TensorDictModule(
        net, in_keys=[("agents", "observation")], out_keys=[("agents", "action_value")]
)

> **`QValueModule`** takes a tensor of Q-values as input (these values indicate how good it is to take each action in the given state). It selects the action with the highest Q-value using the `argmax` operation.

In [18]:
value_module = QValueModule(
    action_value_key=("agents", "action_value"),
    out_keys=[
        env.action_key,
        ("agents", "action_value"),
        ("agents", "chosen_action_value"),
    ],
    spec=env.action_spec,
    action_space=None,
)
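The greedy selection described above amounts to an `argmax` over each agent's Q-values, together with the value of the selected action. The sketch below shows this in plain Python on made-up numbers; `greedy_action` is a hypothetical helper, not the TorchRL module.

```python
def greedy_action(q_values):
    """Return (action index, chosen Q-value) for one agent's Q-value vector."""
    action = max(range(len(q_values)), key=lambda a: q_values[a])
    return action, q_values[action]

# Made-up Q-values for two agents, each choosing among 3 actions.
agents_q = [[-12.0, -9.5, -11.0], [-8.0, -8.5, -7.25]]
print([greedy_action(q) for q in agents_q])
# prints: [(1, -9.5), (2, -7.25)]
```

The first element of each pair plays the role of the action entry, and the second the role of the chosen action value that is later passed to the mixer.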

> **`SafeSequential`** chains the two modules so that they run one after the other. It is a `TensorDictModule` that concatenates their parameter lists into a single list.

In [19]:
qnet = SafeSequential(module, value_module)

> An **epsilon-greedy exploration module** is appended to the Q-network built above. With probability epsilon, this module replaces the greedy action in the tensordict with a random one.

In [20]:
qnet_explore = TensorDictSequential(
    qnet,
    EGreedyModule(
        eps_init=0.3,
        eps_end=0,
        annealing_num_steps=int(total_frames * (1 / 2)),
        action_key=env.action_key,
        spec=env.action_spec,
    ),
)
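The following sketch illustrates epsilon-greedy selection with an annealed epsilon, using the same `eps_init`, `eps_end` and annealing horizon as the cell above. It assumes a linear schedule and is written in plain Python; `annealed_epsilon` and `epsilon_greedy` are hypothetical helpers rather than TorchRL functions.

```python
import random

def annealed_epsilon(frames_seen, eps_init=0.3, eps_end=0.0, annealing_num_steps=7500):
    """Linearly interpolate epsilon from eps_init to eps_end, then hold it."""
    fraction = min(frames_seen / annealing_num_steps, 1.0)
    return eps_init + fraction * (eps_end - eps_init)

def epsilon_greedy(q_values, epsilon, rng):
    """Pick a random action with probability epsilon, else the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])

rng = random.Random(0)
for frames in (0, 3750, 7500, 15000):
    print(frames, round(annealed_epsilon(frames), 3))

# With epsilon at 0 the choice is always greedy.
print(epsilon_greedy([0.1, 0.7, 0.2], epsilon=0.0, rng=rng))
```

Under these assumptions epsilon reaches 0 after 7500 frames, so the second half of training acts purely greedily.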

#### Mixer

> `VDNMixer` mixes **the local Q values** of the agents into **a global Q value** by summing them together, according to the [VDN paper](https://arxiv.org/pdf/1706.05296).

In [21]:
mixer = TensorDictModule(
    module=VDNMixer(
        n_agents=env.n_agents,
        device=device,
    ),
    in_keys=[("agents", "chosen_action_value")],
    out_keys=["chosen_action_value"],
)

#### Collector

Collectors perform the following operations:

1. **Reset Environment**: Initialize the environment.
2. **Compute Action**: Determine the next action using the policy and the latest observation.
3. **Execute Step**: Step through the environment with the computed action.

These operations repeat until the environment signals to stop.
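The three operations can be sketched with a toy environment. Everything here (`ToyEnv`, `collect`, the reward rule) is a hypothetical stand-in written for illustration; the real collector works on tensordicts and handles devices and batching.

```python
import random

class ToyEnv:
    """Hypothetical stand-in environment: an episode lasts `horizon` steps."""
    def __init__(self, horizon=3):
        self.horizon, self.t = horizon, 0

    def reset(self):
        self.t = 0
        return self.t  # the observation is just the step index

    def step(self, action):
        self.t += 1
        reward = -float(action)  # made-up reward
        return self.t, reward, self.t >= self.horizon

def collect(env, policy, frames_per_batch):
    """Gather exactly `frames_per_batch` transitions, resetting on episode end."""
    batch, obs = [], env.reset()                    # 1. reset environment
    while len(batch) < frames_per_batch:
        action = policy(obs)                        # 2. compute action
        next_obs, reward, done = env.step(action)   # 3. execute step
        batch.append((obs, action, reward, next_obs, done))
        obs = env.reset() if done else next_obs
    return batch

rng = random.Random(0)
batch = collect(ToyEnv(), lambda obs: rng.randrange(3), frames_per_batch=7)
print(len(batch), sum(t[4] for t in batch))
# prints: 7 2
```

In this sketch the batch of 7 frames spans two finished episodes and the start of a third, which mirrors how a fixed-size batch can cut across episode boundaries.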

In [22]:
collector = SyncDataCollector(
        env,
        qnet_explore,
        device=device,
        storing_device=device,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
    )

#### Replay buffer

> In an off-policy setting, the replay buffer exceeds the number of frames utilized for policy updates, allowing agents to learn from previous rollouts as well.



In [23]:
replay_buffer = TensorDictReplayBuffer(
        storage=LazyTensorStorage(memory_size, device=device),
        sampler=SamplerWithoutReplacement(),
        batch_size=minibatch_size,
    )
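As a rough mental model, the buffer behaves like a bounded store from which mini-batches are drawn without repeating a transition. The sketch below is a simplified plain Python analogue under that assumption; `TinyReplayBuffer` is hypothetical and only avoids repeats within a single batch.

```python
import random
from collections import deque

class TinyReplayBuffer:
    """Hypothetical minimal buffer: bounded FIFO storage, sampling without replacement."""
    def __init__(self, capacity, batch_size, seed=0):
        self.storage = deque(maxlen=capacity)  # oldest items are dropped when full
        self.batch_size = batch_size
        self.rng = random.Random(seed)

    def extend(self, transitions):
        self.storage.extend(transitions)

    def sample(self):
        # random.sample draws distinct items, so no transition repeats in a batch
        return self.rng.sample(list(self.storage), self.batch_size)

buffer = TinyReplayBuffer(capacity=1000, batch_size=16)
for iteration in range(8):  # 8 rollouts of 150 frames each
    buffer.extend((iteration, frame) for frame in range(150))
batch = buffer.sample()
print(len(buffer.storage), len(set(batch)))
# prints: 1000 16
```

With a capacity of 1000 and 150 frames per rollout, the store fills after seven iterations, and from then on the oldest transitions are discarded.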

#### QMixer loss function

> `QMixerLoss` mixes *local agent Q values* into *a global Q value* according to a mixing network (here the `VDNMixer`) and then applies DQN updates to the global value.

In [24]:
loss_module = QMixerLoss(qnet, mixer, delay_value=True)

loss_module.set_keys(
    action_value=("agents", "action_value"),
    local_value=("agents", "chosen_action_value"),
    global_value="chosen_action_value",
    action=env.action_key,
)

loss_module.make_value_estimator(ValueEstimators.TD0, gamma=gamma)
target_net_updater = SoftUpdate(loss_module, eps=1 - tau)

optim = torch.optim.Adam(loss_module.parameters(), lr)
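Two pieces of the cell above can be written out as scalar arithmetic: the one-step TD(0) target computed on the global value, and the soft update that moves the target network towards the online network. The sketch assumes the standard forms of both; `td0_target` and `soft_update` are hypothetical helpers operating on plain floats.

```python
gamma, tau = 0.99, 0.005

def td0_target(reward, next_global_q, done):
    """One-step bootstrapped target: r + gamma * Q_tot(next), cut off at episode end."""
    return reward + gamma * (0.0 if done else next_global_q)

def soft_update(target_params, online_params, eps=1 - tau):
    """Polyak averaging: target <- eps * target + (1 - eps) * online."""
    return [eps * t + (1 - eps) * o for t, o in zip(target_params, online_params)]

# Made-up team reward and next-step global value.
print(td0_target(reward=-2.0, next_global_q=-10.0, done=False))

# One soft update moves each target parameter 0.5% of the way to the online value.
print(soft_update([0.0, 1.0], [1.0, 1.0]))
```

With `eps = 1 - tau = 0.995`, the target network changes slowly, which helps keep the bootstrapped targets stable between updates.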


#### Training loop

In [25]:
total_time = 0
total_frames = 0
sampling_start = time.time()

for i, tensordict_data in enumerate(collector):
    torchrl_logger.info(f"\nIteration {i}")

    sampling_time = time.time() - sampling_start

    ## Average the per-agent rewards into a single shared team reward
    tensordict_data.set(
        ("next", "reward"), tensordict_data.get(("next", env.reward_key)).mean(-2)
    )
    del tensordict_data["next", env.reward_key]
    tensordict_data.set(
        ("next", "episode_reward"),
        tensordict_data.get(("next", "agents", "episode_reward")).mean(-2),
    )
    del tensordict_data["next", "agents", "episode_reward"]


    current_frames = tensordict_data.numel()
    total_frames += current_frames
    data_view = tensordict_data.reshape(-1)
    replay_buffer.extend(data_view)


    training_tds = []
    training_start = time.time()
    
    ## Update the policies of the learning agents
    for _ in range(num_epochs):
        for _ in range(frames_per_batch // minibatch_size):
            subdata = replay_buffer.sample()
            loss_vals = loss_module(subdata)
            training_tds.append(loss_vals.detach())

            loss_value = loss_vals["loss"]

            loss_value.backward()

            total_norm = torch.nn.utils.clip_grad_norm_(
                loss_module.parameters(), max_grad_norm
            )
            training_tds[-1].set("grad_norm", total_norm.mean())

            optim.step()
            optim.zero_grad()
            target_net_updater.step()

    qnet_explore[1].step(frames=current_frames)  # Update exploration annealing
    collector.update_policy_weights_()

    training_time = time.time() - training_start

    iteration_time = sampling_time + training_time
    total_time += iteration_time
    training_tds = torch.stack(training_tds) 

2025-01-30 23:31:07,356 [torchrl][INFO] 
Iteration 0
2025-01-30 23:31:12,939 [torchrl][INFO] 
Iteration 1
2025-01-30 23:31:17,466 [torchrl][INFO] 
Iteration 2
2025-01-30 23:31:22,649 [torchrl][INFO] 
Iteration 3
2025-01-30 23:31:27,182 [torchrl][INFO] 
Iteration 4
2025-01-30 23:31:31,595 [torchrl][INFO] 
Iteration 5
2025-01-30 23:31:35,889 [torchrl][INFO] 
Iteration 6
2025-01-30 23:31:40,243 [torchrl][INFO] 
Iteration 7
2025-01-30 23:31:44,576 [torchrl][INFO] 
Iteration 8
2025-01-30 23:31:48,866 [torchrl][INFO] 
Iteration 9
2025-01-30 23:31:53,242 [torchrl][INFO] 
Iteration 10
2025-01-30 23:31:57,596 [torchrl][INFO] 
Iteration 11
2025-01-30 23:32:01,913 [torchrl][INFO] 
Iteration 12
2025-01-30 23:32:06,223 [torchrl][INFO] 
Iteration 13
2025-01-30 23:32:10,648 [torchrl][INFO] 
Iteration 14
2025-01-30 23:32:14,992 [torchrl][INFO] 
Iteration 15
2025-01-30 23:32:19,373 [torchrl][INFO] 
Iteration 16
2025-01-30 23:32:23,682 [torchrl][INFO] 
Iteration 17
2025-01-30 23:32:27,949 [torchrl][INFO

>  Check the `\plots` directory for the plots created by this experiment.

In [26]:
from RouteRL.services import plotter
plotter(params[kc.PLOTTER])

<RouteRL.services.plotter.Plotter at 0x16e270c20>

In [27]:
env.stop()