<a href="https://colab.research.google.com/github/Strojove-uceni/2024-final-letadylka-prochazka-belohlavek/blob/main/demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**TABLE OF CONTENTS**

>[Introduction](#scrollTo=Mey2sVO3w06m)

>>[Abstract](#scrollTo=Mey2sVO3w06m)

>>[How to Run the Project](#scrollTo=Mey2sVO3w06m)

>>>[Disclaimer](#scrollTo=ZDfbtUw36j_1)

>[Reinforcement Learning](#scrollTo=ZDfbtUw36j_1)

>>[Problems of RL](#scrollTo=ZDfbtUw36j_1)

>>>[Idea behind the approach we use](#scrollTo=ZDfbtUw36j_1)

>>[Replay Buffer](#scrollTo=ZDfbtUw36j_1)

>>>>[Importance for RNNs](#scrollTo=ZDfbtUw36j_1)

>>>>[Prioritized Experience Replay](#scrollTo=ZDfbtUw36j_1)

>>>>>[Sampling Mechanism](#scrollTo=ZDfbtUw36j_1)

>[NN Components](#scrollTo=xdxNv9pODPz0)

>>[Classical MLP](#scrollTo=Z4BoIqSPdCtb)

>>[Attention Model](#scrollTo=HU4Z78Kv9hiW)

>>[Q-values Prediction](#scrollTo=FN9jHalCI6UK)

>>>[Q-Net](#scrollTo=65H4yFBgdJeh)

>>>[DGN](#scrollTo=_s_7VNFm91w8)

>>>[DQN](#scrollTo=G5-kiKLj98yv)

>>>[CommNet](#scrollTo=VK8i3VvG-Bg-)

>>>[Reward System: Detailed Description and Discussion](#scrollTo=f_ZqrL8K8zqq)

>>>>[Reward Components](#scrollTo=f_ZqrL8K8zqq)

>>>[Summary](#scrollTo=f_ZqrL8K8zqq)

>>[State Aggregation](#scrollTo=AGSMkNTI-KIH)

>>>[SUM](#scrollTo=AGSMkNTI-KIH)

>>>[GCN](#scrollTo=7S-b66SX-TyK)

>>[NetMon](#scrollTo=ojoJfBcX-Wrm)

>>>[State management](#scrollTo=ojoJfBcX-Wrm)

>[Multi GPU setups](#scrollTo=Pw68VqsjYacn)

>>[Bi-GPU setup](#scrollTo=Pw68VqsjYacn)

>>[Multi-GPU setup](#scrollTo=Pw68VqsjYacn)

>[Selecting Parameters](#scrollTo=1Ei12632gg3Z)

>>[Common Parameters in the Sweep](#scrollTo=1Ei12632gg3Z)

>>[CommNet specific](#scrollTo=1Ei12632gg3Z)

>>[DQN, DGN specific](#scrollTo=1Ei12632gg3Z)

>[Advice for Parameter Selection](#scrollTo=OhkNKLaYlGRj)

>>[CommNet settings](#scrollTo=OhkNKLaYlGRj)

>>[DQN, DGN settings](#scrollTo=OhkNKLaYlGRj)

>>[Aggregation Type](#scrollTo=OhkNKLaYlGRj)

>>>[WANDB sweep for CommNet](#scrollTo=e1gjgCZZzhzm)

>>>[Hyperparameters importance for CommNet](#scrollTo=ghDCmQ55xYYl)

>>>[WANDB sweep for DQN vs. DGN](#scrollTo=TO3iriopz9P7)

>>>[HYPERPARAMETERS IMPORTANCE FOR DGN and DQN](#scrollTo=ca4DwESCxrjT)

>[Our Results](#scrollTo=S8nU4WQLwJI7)

>>[Our Selected Best Performing Models](#scrollTo=S8nU4WQLwJI7)

>>>[Key observations](#scrollTo=S8nU4WQLwJI7)

>>>[Hyperparameter search](#scrollTo=S8nU4WQLwJI7)

>>[Heatmaps](#scrollTo=K0VfxzH_jn4H)

>>>[COMMNET](#scrollTo=8GQ6b3ibeK6K)

>>>[DQN](#scrollTo=eNq4qi-Pjbwl)

>>>[DGN](#scrollTo=xkXvwNuhkwSj)

>>[Shortest Path Ratio](#scrollTo=MMiw4WCyk_W9)

>>>[COMMNET](#scrollTo=-qaQ7fj_mNZJ)

>>>[DQN](#scrollTo=E3XestfLmXa0)

>[Conclusion](#scrollTo=buWgEb9M5rJd)

>[References](#scrollTo=XI6rlOjqEeNN)



**Authors: Michal Bělohlávek, Tomáš Procházka**

# <u>**Introduction**</u>
Welcome to the demo file, where you can run the project effortlessly and view the results and visualizations firsthand. While this provides an easy and pleasant way to experience our neural network, we strongly encourage anyone visiting this demo to run the project as intended, expand upon it, and enhance its capabilities.

## **Abstract**
This project was created and submitted as the final semester project for the Machine Learning 2 class at FNSPE CTU. It focuses on reinforcement learning for multiple agents controlled by a single neural network within a graph environment. The primary objective is to develop a neural network solution capable of efficiently navigating multiple planes across a fully connected graph, estimating the shortest paths while avoiding plane collisions.
To facilitate the learning process, we implemented an enhanced version of the classical replay buffer, which samples experiences based on predicted future rewards. Additionally, we created a dense reward system to encourage traversal along longer paths and implemented node masking to enable generalization across graphs with diverse neighbourhoods. Lastly, we provide additional code for multi-GPU setups to enable faster inference times.

## **How to Run the Project**
For those who decide to download the project and run the training on their PC, please beware of the configurations. A basic setup is present in /data as demo_config.yaml and runs on CPU.

Setting `capacity`, `minibatch_size`, or `sequence_length` too high may freeze your computer.

Most hyperparameters can be changed in the config.yaml file. If you intend to run your own sweeps on Weights & Biases, we have also uploaded a version of the main file, ```wandb_main.py```, that supports sweep configuration.

If, however, you decide to run the project only in this demo file, note that the pre-trained models are too large to upload to the GitHub repo directly, so training is done from scratch here. The training uses ```demo_config.yaml``` with a small number of steps and generally "low" settings, so that it can be completed in a reasonable amount of time. You should therefore expect very poor results compared to the results we present at the end of this notebook. Note that we need to install specific, mutually compatible versions of many libraries, which may take a while.

The runtime may encounter an error with tensorflow. If so, simply restart the runtime and run the following cells, and training will start. The training takes about 5 minutes. Sometimes Google Colab fails to connect to GitHub, in which case the error you will see is that it fails to find the directory. To solve this, simply keep deleting the runtime until it works. If you encounter any pip dependency errors, ignore them and proceed.

**To see the training and results, simply run the following code boxes.**


In [None]:
!git clone https://github.com/Strojove-uceni/2024-final-letadylka-prochazka-belohlavek

In [None]:
%cd /content/2024-final-letadylka-prochazka-belohlavek/
!pip install -r /content/2024-final-letadylka-prochazka-belohlavek/requirements.txt

/content/2024-final-letadylka-prochazka-belohlavek
Collecting gymnasium==1.0.0 (from -r /content/2024-final-letadylka-prochazka-belohlavek/requirements.txt (line 1))
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting matplotlib==3.9.3 (from -r /content/2024-final-letadylka-prochazka-belohlavek/requirements.txt (line 2))
  Downloading matplotlib-3.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting numpy==2.2.0 (from -r /content/2024-final-letadylka-prochazka-belohlavek/requirements.txt (line 4))
  Downloading numpy-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
Collecting pandas==2.2.3 (from -r /content/2024-final-letadylka-prochazka-belohlavek/requirements.txt (line 5))
  Downloading pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata 

In [None]:
!pip uninstall -y tensorflow pandas
!pip install numpy==1.26.4 tensorboard==2.18.0 pandas==2.2.2
!conda install -c conda-forge cudatoolkit cudnn

Collecting numpy==1.26.4
  Downloading numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Collecting tensorboard==2.18.0
  Downloading tensorboard-2.18.0-py3-none-any.whl.metadata (1.6 kB)
Collecting pandas==2.2.2
  Downloading pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (19 kB)
Downloading numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.2 MB)
Downloading tensorboard-2.18.0-py3-none-any.whl (5.5 MB)

In [None]:
%cd  /content/2024-final-letadylka-prochazka-belohlavek
!python /content/2024-final-letadylka-prochazka-belohlavek/src/main.py --demo_config data/demo_config.yaml

/content/2024-final-letadylka-prochazka-belohlavek
2024-12-13 14:24:19.342974: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-13 14:24:19.371717: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-13 14:24:19.379803: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-13 14:24:19.404203: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
se

<!-- # Overview of the Used Machine Learning Techniques -->
<span><font color="green;">
### Disclaimer

**The following code is for illustration only. To see the full functionality, visit the /src directory in our GitHub repo. We took this approach because it is not feasible to copy the whole codebase into Colab.**
</font></span>

(Or maybe it would be feasible but that would be an extreme violation of our hard work.)

# <u>**Reinforcement Learning**</u>
**Reinforcement Learning (RL)** is a powerful machine learning paradigm in which agents learn to make decisions by interacting with an environment. Unlike Supervised Learning, where the model is trained on labeled data, RL agents learn optimal behaviors through trial and error, guided by feedback in the form of rewards.

**Who/What is an agent/environment?**

- Agent is someone who makes decisions while interacting with the environment to achieve predefined goals.
- Environment is some external system with which the agent interacts. It provides something that we call **observations** and rewards based on the agent's actions.

**What is an observation?**

Observation represents the current state/situation of the environment which is perceived by the agent. It contains necessary information required for the agent to make a decision.

**What are actions?**

Actions are all the possible decisions the agent can take in a given state.


**What is a reward?**

Reward is a single number representing a feedback signal received by the agent after performing an action. It indicates the immediate benefit of that action.

**What is an experience?**

Experience comprises tuples (state, action, reward, next state) that the agent accumulates over time while interacting with the environment. These experiences are crucial for learning and are stored in a replay buffer for training purposes.

**What is a policy?**

Policy is a strategy that the agent employs to decide actions based on the current state.

**What is a Q-function?**

Q-function represents the expected cumulative reward of taking a particular action in a given state and following the policy afterwards.
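
To make the terms above concrete, here is a minimal sketch of the agent-environment loop. The `LineWorld` environment, the `Experience` tuple, and the random policy are all hypothetical and exist only for this illustration. They are not part of the project code, which works on a graph with multiple agents.

```python
import random
from typing import NamedTuple


class Experience(NamedTuple):
    """One (state, action, reward, next state) tuple."""
    state: int
    action: int
    reward: float
    next_state: int


class LineWorld:
    """Hypothetical toy environment: nodes 0..n-1 on a line, goal at the last node."""

    def __init__(self, n_nodes: int = 5):
        self.n_nodes = n_nodes
        self.state = 0

    def reset(self) -> int:
        self.state = 0
        return self.state  # the observation is simply the current node

    def step(self, action: int):
        # Action 1 moves right, any other action moves left; the position is clamped to the line
        move = 1 if action == 1 else -1
        self.state = min(max(self.state + move, 0), self.n_nodes - 1)
        done = self.state == self.n_nodes - 1
        reward = 1.0 if done else -0.1  # small step penalty, bonus at the goal
        return self.state, reward, done


def random_policy(state: int, rng: random.Random) -> int:
    """A policy maps a state to an action; this one ignores the state."""
    return rng.choice([0, 1])


def collect_episode(env: LineWorld, rng: random.Random, max_steps: int = 100):
    experiences = []
    state = env.reset()
    for _ in range(max_steps):
        action = random_policy(state, rng)
        next_state, reward, done = env.step(action)
        experiences.append(Experience(state, action, reward, next_state))
        state = next_state
        if done:
            break
    return experiences


episode = collect_episode(LineWorld(), random.Random(0))
```

In a learning agent, the random policy would be replaced by one that picks actions from the estimated Q-values, and the collected experiences would go into the replay buffer.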

## <u>**Problems of RL**</u>
In systems where **several agents** are present (**Multi-Agent RL**), a few problems arise in how the whole pipeline is organized.

Multi-agent RL is often divided into two categories: the **centralized** approach and the **decentralized** approach. The **centralized** approach usually involves a **single 'controller'/coordinator** that has access to all agents' information and makes decisions for all of them. The **decentralized** approach lets each agent make decisions **independently**, relying on local information and limited communication. Both approaches run into problems. The former allows for the best decision making but does not scale to large graphs. The latter does not have access to enough information, making it less reactive.

### **Idea behind the approach we use**
The observation space is expanded with learned graph (**environment**) observations that leverage recurrent message passing. This approach keeps the agents reactive, because they do not have to gather information about the whole graph before taking an action.

**This idea is not ours. We do not claim it, it was taken from [here](https://github.com/jw3il/graph-marl).**


## <u>**Replay Buffer**</u>
Among other things, we are proud of our improvement over a randomly sampled **Replay Buffer** (in ```replay_buffer.py```), which significantly improved the prediction of paths that lead to future reward. A replay buffer is a key component in Reinforcement Learning that stores past experiences. This storage mechanism is essential for training models, especially those utilizing **RNN**s (**Recurrent Neural Networks**), which rely on sequential dependencies to function effectively.

#### **Importance for RNNs**
For RNNs, replay buffers play a vital role by allowing agents to learn from diverse trajectories while maintaining temporal coherence. Unlike traditional models that might sample individual transitions randomly, our approach involves sampling batches of sequences. This method ensures that the RNN captures meaningful patterns over time, thereby improving its ability to model long-term dependencies and make more informed predictions.
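
The difference between sampling single transitions and sampling sequences can be sketched as follows. This is a minimal illustration under our own assumptions, not the code in ```replay_buffer.py```: the class name `SequenceReplayBuffer` and its methods are hypothetical, transitions are plain integers, and episode boundaries are ignored for brevity.

```python
import random
from collections import deque


class SequenceReplayBuffer:
    """Hypothetical minimal buffer that returns contiguous sequences of transitions."""

    def __init__(self, capacity: int):
        # A bounded deque drops the oldest transitions once capacity is reached
        self.storage = deque(maxlen=capacity)

    def add(self, transition) -> None:
        self.storage.append(transition)

    def sample(self, batch_size: int, sequence_length: int, rng: random.Random):
        max_start = len(self.storage) - sequence_length
        if max_start < 0:
            raise ValueError("not enough transitions for one sequence")
        items = list(self.storage)
        # Draw a random start index per sequence, then slice a contiguous window
        starts = [rng.randint(0, max_start) for _ in range(batch_size)]
        return [items[s:s + sequence_length] for s in starts]


buffer = SequenceReplayBuffer(capacity=100)
for t in range(20):
    buffer.add(t)  # stand-in for a (state, action, reward, next state) tuple

batch = buffer.sample(batch_size=4, sequence_length=5, rng=random.Random(0))
```

Each element of `batch` is a window of consecutive time steps, so a recurrent model can be unrolled over it with its temporal order intact.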

#### **Prioritized Experience Replay**
To further enhance the effectiveness of the replay buffer, we implemented a **Prioritized Experience Replay** mechanism. This version of the replay buffer samples batch sequences so as to favor a large **temporal difference** (**TD**) **error**, measured as the mean squared error between the predicted Q-values and their targets (the immediate reward plus the discounted predicted future reward). By prioritizing experiences with higher TD errors, the agent focuses more on learning from actions that have a significant impact on future rewards.

##### **Sampling Mechanism**
The sequences are still sampled at random, but the sampling is weighted by probabilities based on **priority**. Each priority is first raised to the power $\alpha$:

$$\text{weight}_i = \text{priority}_i^{\alpha}, \quad \alpha \in [0, 1]$$
and the result is then normalized to form a valid probability distribution
$$\text{probability}_i = \frac{\text{priority}_i^{\alpha}}{\sum_k \text{priority}_k^{\alpha}}.$$

In addition, the TD errors are weighted to prevent over-prioritizing some sampled indices. This is done through a scaling parameter $\beta$ that gradually increases from 0.4 to 1.0 during training, allowing the model to transition smoothly into using prioritized sequences.

You can read more on this prioritized sampling in the referenced [repo](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/dqn/replay_buffer.py).
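
The sampling mechanism can be sketched in a few lines of plain Python. This is a hedged illustration rather than our buffer implementation. The function names are hypothetical, the linear $\beta$ schedule is an assumption (the text above only says that $\beta$ grows from 0.4 to 1.0), and the weights use the standard importance-sampling form $(N \cdot P(i))^{-\beta}$ from the prioritized replay literature, normalized here by the largest weight in the sampled batch.

```python
import random


def sampling_probabilities(priorities, alpha):
    """Raise priorities to the power alpha and normalize them to sum to 1."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


def importance_weights(probabilities, indices, beta):
    """Importance-sampling weights (N * P(i)) ** -beta, scaled so the largest is 1."""
    n = len(probabilities)
    weights = [(n * probabilities[i]) ** (-beta) for i in indices]
    max_weight = max(weights)
    return [w / max_weight for w in weights]


def beta_schedule(step, total_steps, beta_start=0.4, beta_end=1.0):
    """Assumed linear annealing of beta over the course of training."""
    fraction = min(max(step / total_steps, 0.0), 1.0)
    return beta_start + fraction * (beta_end - beta_start)


priorities = [0.1, 1.0, 2.0, 4.0]  # e.g. TD errors of four stored sequences
probs = sampling_probabilities(priorities, alpha=0.6)

rng = random.Random(0)
indices = rng.choices(range(len(priorities)), weights=probs, k=3)
weights = importance_weights(probs, indices, beta=beta_schedule(0, 1000))
```

With $\alpha = 0$ the probabilities become uniform, which recovers the classical random replay buffer. Larger $\alpha$ makes the sampling follow the priorities more closely.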


# <u>**NN Components**</u>
In the text below, we first describe models that were used for Q-values predictions:
- **DGN**
- **DQN**
- **Comm_net**,
    
then we go over the methods that were used to aggregate hidden node states:

- **SUM**
- **GCN**

and lastly we describe the **NetMon** class, that was originally provided by the authors.

## **Classical MLP**
An MLP is a feed-forward network that passes the input through several linear layers with activation functions. In our case, we used leaky ReLU as the activation function. The activation can be changed in the config.yaml file, for example to GELU, but we should point out that this may lead to the agent learning to take forbidden edges, which are subsequently masked, leading to insufficient gradient flow. We have not explored this option in this project. Dropout is included for regularization.

In [None]:
import torch
from torch import nn

class MLP(nn.Module):
    """
    This is the underlying module for all used models within this work.
    """

    def __init__(self, in_features, mlp_units, activation_fn, activation_on_output = True):
        super(MLP, self).__init__()

        self.activation = activation_fn
        self.dropout = nn.Dropout(0.3)


        self.linear_layers = nn.ModuleList() # Storage for L layers
        previous_units = in_features

        # Transform units into a list
        if isinstance(mlp_units, int):
            mlp_units = [mlp_units]

        # Create a chain of layers
        for units in mlp_units:
            self.linear_layers.append(nn.Linear(previous_units, units))
            previous_units = units

        self.out_features = previous_units
        self.activation_on_output = activation_on_output

    # Forward pass
    def forward(self, x):

        # Inter layers
        for module in self.linear_layers[:-1]:
            x = module(x)
            if self.activation is not None:
                x = self.activation(x)
            x = self.dropout(x)

        # Pass through the last layer
        x = self.linear_layers[-1](x)
        # Guard against activation_fn=None so the output activation is only applied when one exists
        if self.activation_on_output and self.activation is not None:
            x = self.activation(x)
            x = self.dropout(x)

        return x

## **Attention Model**
The AttModel class implements a multi-head attention mechanism inspired by the "Attention is All You Need" paper. It uses multiple attention heads to capture different aspects of the input simultaneously. The attention scores are scaled to stabilize gradients during training.

Moreover, masking restricts the attention computation: positions where the mask is zero are excluded, so the values each agent aggregates depend only on the inputs of the agents that the current state of the environment allows it to attend to.
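
The effect of scaling and masking can be seen on a tiny example. The sketch below computes scaled dot-product attention for a single query and a single head in plain Python. It is a simplified illustration, not the `AttModel` class, and the function name `masked_attention` is hypothetical.

```python
import math


def masked_attention(query, keys, values, mask):
    """Scaled dot-product attention for one query with a 0/1 mask over the keys."""
    scale = 1.0 / math.sqrt(len(query))  # 1 / sqrt(d_k)
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    # Masked positions get a large negative score, so softmax gives them ~0 weight
    scores = [s if m else -1e9 for s, m in zip(scores, mask)]
    # Numerically stable softmax over the scores
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of the value vectors
    dim = len(values[0])
    output = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return output, weights


query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
mask = [1, 1, 0]  # the third agent is masked out

output, weights = masked_attention(query, keys, values, mask)
```

Although the third value vector is the largest, it contributes nothing to `output` because its attention weight is zero after masking.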



In [None]:
import torch.nn.functional as F  # needed for F.softmax in forward()

class AttModel(nn.Module):
    """
        Basic attention model with masking and scaling.
    """

    def __init__(self, in_features, k_features, v_features, out_features, num_heads, activation_fn, vkq_activation_fn):
        super(AttModel, self).__init__()


        self.k_features = k_features
        self.v_features = v_features
        self.num_heads = num_heads      # Number of attention heads

        self.fc_v = nn.Linear(in_features, v_features * num_heads)  # Transforming input features into Values for attention
        self.fc_k = nn.Linear(in_features, k_features * num_heads)  # Transforming input features into Keys for attention
        self.fc_q = nn.Linear(in_features, k_features * num_heads)  # Transforming input values into Queries for attention

        self.fc_out = nn.Linear(v_features * num_heads, out_features)   # Transforms the outputs from all attention heads into output dimension

        self.activation = activation_fn
        self.vkq_activation = vkq_activation_fn     # Activation function that can be applied into Values, Keys, Queries


        """
        The scaling factor for attention is 1 / sqrt(d_k), the same as in the paper "Attention is All You Need".
        Without it, the dot products would grow large with the feature dimension, which pushes the softmax
        into saturated regions where gradients are very small.
        """
        self.attention_scale = 1 / (k_features **0.5)

        self.dropout = nn.Dropout(0.1)

    # Forward pass
    def forward(self, x, mask):
        batch_size, num_agents = x.shape[0], x.shape[1]

        """
        The code below does the following:
            - a linear mapping is applied on the inputs to obtain Values, Keys, Queries
            - the Values, Keys, Queries are then reshaped to separate the different attention heads of the model
            :reshape: will result in (batch_size, num_agents, num_heads, features_per_head)

        Pipeline:
                  Input x
                    |
              [Linear Layers] -> V, Q, K
                    |
            [Optional Activation] (vkq_activation_fn)
                    |
            [Reshape for Multi-Head]
                    |
            [Transpose for Heads]
                    |
            [Compute Attention Weights (Dot Product, Scale, Mask, Softmax)]
                    |
            [Apply Attention to Values]
                    |
            [Skip Connection]
                    |
            [Transpose and Concatenate Heads]
                    |
            [Final Linear Layer and Activation]
                    |
                  Output
        """

        v = self.fc_v(x).view(batch_size, num_agents, self.num_heads, self.v_features)
        q = self.fc_q(x).view(batch_size, num_agents, self.num_heads, self.k_features)
        k = self.fc_k(x).view(batch_size, num_agents, self.num_heads, self.k_features)

        if self.vkq_activation is not None:
            v = self.vkq_activation(v)
            q = self.vkq_activation(q)
            k = self.vkq_activation(k)

        # We rearrange the tensors to shape (batch_size, num_heads, num_agents, features_per_head)
        # This is done so we can perform batch multiplication over the batch size and heads
        q, k, v = q.transpose(1,2), k.transpose(1,2), v.transpose(1,2)

        # Add head axis (we are keeping the same mask for all attention heads)
        mask = mask.unsqueeze(1)    # (batch_size, 1, num_agents, num_agents) (1,1,20,20)

        """
        The attention is calculated as a dot product of all queries with all keys,
            while scaling it with the attention scale so it does not explode.
            - q is of shape             (batch_size, num_heads, num_agents, features_per_head)
            - k transposed is of shape  (batch_size, num_heads, features_per_head, num_agents)
            - the multiplication result is of shape (batch_size, num_heads, num_agents, num_agents)
        :masked_fill sets positions where mask == 0 to a large negative value - removes them from the attention computation practically
        """

        att_weights = torch.matmul(q, k.transpose(2, 3)) * self.attention_scale
        att = att_weights.masked_fill(mask==0, -1e9)
        att = F.softmax(att, dim=-1)    # Softmax is applied along the last dimension to obtain normalized attention probabilities
        att = self.dropout(att)

        # Now we combine the Values with respect to the attention we just computed
        """
            - att is of shape (batch_size, num_heads, num_agents, num_agents)
            - v is of shape (batch_size, num_heads, num_agents, v_features)
            - the multiplication result is of shape (batch_size, num_heads, num_agents, v_features)
        """
        out = torch.matmul(att, v)

        # We add a skip connection
        out  = torch.add(out, v)    # This additionally promotes gradient flow and mitigates vanishing gradient

        # Now "remove" the transpose and concatenate all heads together
        """
            - out is of shape (batch_size, num_heads, num_agents, v_features)
            - out after transpose is of shape (batch_size, num_agents, num_heads, v_features)
            - contiguous() ensures that the tensor is stored in a contiguous chunk of memory so that the reshape for view can happen
            - view is used to reshape the tensor to (batch_size, num_agents, num_heads * v_features), i.e. we flatten the last two dimensions
                into a single one
            - final out is of shape  (batch_size, num_agents, num_heads * v_features)
        """

        out = out.transpose(1,2).contiguous().view(batch_size, num_agents, -1)
        out = self.activation(self.fc_out(out)) # Linear map into a desired feature dimension
        out = self.dropout(out)

        return out, att_weights

## <u>**Q-values Prediction**</u>
We use **Q-Net** for Q-value predictions, a **reinforcement learning** technique that assigns values to each potential future action based on agent's observations (state). In this particular setting, the Q-Net predicts the expected reward for **each edge** the agent could take at any given step.

To accommodate graphs with a varying number of edges per node, we implemented a **node mask**, which generalizes the setting to graphs with a **variable** edge count per node. The Q-Net algorithm leverages dynamic programming weighted by the learning rate hyperparameter.

$$Q_{target} = (1-l) \cdot Q_{now}(a_t, s_t) + l \cdot E\left[R_{t}(a_t, s_t) + \gamma \cdot \max_{a_{t+1}} Q_{next}(a_{t+1}, s_{t+1}) \mid s_t, a_t\right],$$

where $R_{t}(a_t, s_t)$ is the reward received at timestep $t$ after taking action $a_t$ in state $s_t$, and $l$ denotes the learning rate. The expectation is estimated from a sampled batch of experiences. From this formulation, we can see that by sampling the batch indices with the largest temporal difference error, we concentrate the updates on the transitions where $Q_{now}$ is furthest from its target.

The goal of each agent is to maximize the expected future reward weighted by the discount factor $\gamma$:

$$\max_{(a_t)_{t=t_0}^T} E\left[\sum_{t=t_0}^T \gamma^{t-t_0} R_{t}(a_t, s_t)\right],$$
where $\gamma \in (0,1)$.
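
As a worked example of the update and the discounted objective, the sketch below applies one update step to a single transition, with the expectation replaced by that one sample. The function names are hypothetical, and the mask convention (1 = the edge exists, 0 = forbidden) is our assumption for this illustration.

```python
def masked_max(q_values, node_mask):
    """Maximum Q-value over the actions (edges) that the node mask allows."""
    valid = [q for q, m in zip(q_values, node_mask) if m]
    return max(valid)


def q_update(q_now, reward, q_next, node_mask, lr, gamma):
    """One-sample version of Q_target = (1 - l) * Q_now + l * (R + gamma * max Q_next)."""
    target = reward + gamma * masked_max(q_next, node_mask)
    return (1 - lr) * q_now + lr * target


def discounted_return(rewards, gamma):
    """Sum of gamma^(t - t0) * R_t over a reward sequence starting at t0."""
    return sum(gamma ** k * r for k, r in enumerate(rewards))


# The second edge has the highest Q-value but does not exist at this node
new_q = q_update(
    q_now=1.0,
    reward=0.5,
    q_next=[2.0, 9.0, 3.0],
    node_mask=[1, 0, 1],
    lr=0.1,
    gamma=0.9,
)
```

Here the target is $0.5 + 0.9 \cdot 3.0 = 3.2$, because the masked edge with value 9.0 is ignored, and the updated value is $0.9 \cdot 1.0 + 0.1 \cdot 3.2 = 1.22$.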

### **Q-Net**
In our **Q-learning** framework, the **Q-net** serves as the Q-function approximator, providing Q-values representing expected rewards.  While we closely follow the original implementation, we have introduced a slight but impactful adjustment: instead of utilizing a single linear layer to transform inputs into action values, we employ an **MLP**.

This modification was necessary due to the **higher dimensionality of agent observations** resulting from the larger graph structure we are working with. By using an MLP, our architecture allows for a much smoother dimensionality reduction than would occur if we had retained the original single linear layer. As a result, our Q-net can handle **more complex** and **high-dimensional data**.


In [None]:
class Q_Net(nn.Module):
    """
    This serves as the Q-function approximator in RL. It estimates Q-values for each possible action given a particular state.
    What are Q-values? Expected cumulative future rewards.
    So, given a particular state, this estimates the expected future rewards (Q-values) for each possible action our plane (agent) can take.
    """
    def __init__(self, in_features, actions):
        super(Q_Net, self).__init__()
        self.fc = MLP(in_features, (2048,1024,512,actions), None, False)

    def forward(self, x):
        return self.fc(x)

### **DGN**
**DGN (Deep Graph Neural Network)** integrates a graph neural network with attention mechanisms to efficiently model interactions between agents. It starts with an encoder, represented by an MLP, which processes the input features. The encoded input is then passed through the desired number of multi-head attention layers, with masking to focus on relevant interactions. Finally, the encoder output and the outputs of all attention layers are concatenated and fed into the Q-net, which estimates the Q-values.


In [None]:
class DGN(nn.Module):
    """
        The Deep Graph Neural network. Incorporates attention mechanism.
    """

    def __init__(self, in_features, mlp_units, num_actions, num_heads, num_attention_layers, activation_fn, kv_values):
        super(DGN, self).__init__()

        self.encoder = MLP(in_features, mlp_units, activation_fn)
        self.att_layers = nn.ModuleList()
        hidden_features = self.encoder.out_features

        print("In features of DGN: ", in_features)
        print("MLP units are: ", mlp_units)

        for _ in range(num_attention_layers):
            self.att_layers.append(
                AttModel(hidden_features, kv_values, kv_values, hidden_features, num_heads, activation_fn, activation_fn)
                                   )

        self.q_net = Q_Net(hidden_features * (num_attention_layers + 1), num_actions)

        self.att_weights = []

    def forward(self, x, mask):
        """
        Additional comment to the function:
            - each attention layer refines the representation h by focusing on relevant parts of the input
            - by concatenating the representations the feature set for the Q-network is enhanced, consequently making more informed decisions

        """

        h = self.encoder(x)     # Encodes the input features, has a shape of (batch_size, num_agents, hidden_features)
        q_input = h     # Initialize the q_input with encoded features
        self.att_weights.clear()    # Ensuring that attention weights from previous forward passes do not accumulate

        for attention_layer in self.att_layers:
            h, att_weight = attention_layer(h, mask)
            self.att_weights.append(att_weight)

            # Concatenation of outputs
            q_input = torch.cat((q_input, h), dim=-1)

        # Final q_input is of shape (batch_size, num_agents, hidden_features * (num_attention_layers +1))
        q = self.q_net(q_input)

        return q    # is of shape (batch_size, num_agents, num_actions)


### **DQN**
**DQN (Deep Q-Network)** is a standard architecture in Deep Q-Learning that leverages an MLP encoder to estimate **Q-values (expected rewards)**. The encoder transforms the incoming features into meaningful representations, which are passed to the Q-network that estimates the Q-values.

One might wonder, **why not use a single MLP for the entire process?** The reasoning is that separating the two components allows for greater versatility and scalability, making the model easier to adapt and extend.

In [None]:
class DQN(nn.Module):
    """
    Introduces simple Deep Feed Forward Neural Network( = MLP) as the encoder.
    """

    def __init__(self, in_features, mlp_units, num_actions, activation_fn):
        super(DQN, self).__init__()

        self.encoder = MLP(in_features, mlp_units, activation_fn)   # Encodes incoming features
        self.q_net = Q_Net(self.encoder.out_features, num_actions)  # Outputs Q-values
        self.activation = activation_fn

    def forward(self, x, mask):
        batch, agent, features = x.shape
        h = self.encoder(x)
        q = self.q_net(h)
        return q


### **CommNet**
**CommNet (Communication Network)** is a specialized network designed to facilitate information exchange between multiple agents. CommNet builds upon the **DQNR (recurrent Deep Q-Network)** architecture, which adds an LSTM forward pass to the DQN, and extends it with inter-agent communication. One can select the desired number of communication rounds between agents. Communication is restricted by the adjacency matrix, which allows communication only between neighbours. Overall, **CommNet is designed to foster collaboration and coordination strategies** in multi-agent environments for cooperative tasks.
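
To give an intuition for a communication round restricted to neighbours, here is a generic CommNet-style sketch in plain Python. It is not the project's implementation: the function `communication_round` is hypothetical, and we assume that each agent simply adds the mean hidden state of its neighbours (excluding itself) to its own hidden state.

```python
def communication_round(hidden, adjacency):
    """Each agent adds the mean hidden state of its neighbours to its own state."""
    n_agents = len(hidden)
    dim = len(hidden[0])
    updated = []
    for i in range(n_agents):
        # Self-communication is excluded even if the adjacency matrix has a self-loop
        neighbours = [j for j in range(n_agents) if adjacency[i][j] and j != i]
        if neighbours:
            message = [
                sum(hidden[j][d] for j in neighbours) / len(neighbours)
                for d in range(dim)
            ]
        else:
            message = [0.0] * dim  # isolated agents receive no message
        updated.append([h + m for h, m in zip(hidden[i], message)])
    return updated


# Three agents on a line graph 0 - 1 - 2, with self-loops in the adjacency matrix
hidden = [[1.0], [3.0], [5.0]]
adjacency = [
    [1, 1, 0],
    [1, 1, 1],
    [0, 1, 1],
]
after_one_round = communication_round(hidden, adjacency)
```

Running several rounds lets information travel further through the graph, since each round passes messages by one hop.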


In [None]:
class DQNR(nn.Module):
    """
    Recurrent DQN with an lstm cell.
    """

    def __init__(self, in_features, mlp_units, num_actions, activation_fn):
        super(DQNR, self).__init__()
        self.encoder = MLP(in_features, mlp_units, activation_fn)
        self.lstm = nn.LSTMCell(
            input_size=self.encoder.out_features, hidden_size=self.encoder.out_features
        )
        self.state = None
        self.q_net = Q_Net(self.encoder.out_features, num_actions)

    def get_state_len(self):
        return 2 * self.lstm.hidden_size

    def _state_reshape_in(self, batch_size, n_agents):
        """
        Reshapes the state of shape
            (batch_size, n_agents, self.get_state_len())
        to shape
            (2, batch_size * n_agents, hidden_size).

        :param batch_size: the batch size
        :param n_agents: the number of agents
        """
        self.state = (
            self.state.reshape(
                batch_size * n_agents,
                2,
                self.lstm.hidden_size,
            )
            .transpose(0, 1)
            .contiguous()
        )

    def _state_reshape_out(self, batch_size, n_agents):
        """
        Reshapes the state of shape
            (2, batch_size * n_agents, hidden_size)
        to shape
            (batch_size, n_agents, self.get_state_len()).

        :param batch_size: the batch size
        :param n_agents: the number of agents
        """
        self.state = self.state.transpose(0, 1).reshape(batch_size, n_agents, -1)

    def _lstm_forward(self, x, reshape_state=True):
        """
        A single lstm forward pass

        :param x: Cell input
        :param reshape_state: reshape the state to and from (batch_size, n_agents, -1)
        """
        batch_size, n_agents, feature_dim = x.shape
        # combine agent and batch dimension
        x = x.view(batch_size * n_agents, -1)

        if self.state is None:
            lstm_hidden_state, lstm_cell_state = self.lstm(x)
        else:
            if reshape_state:
                self._state_reshape_in(batch_size, n_agents)
            lstm_hidden_state, lstm_cell_state = self.lstm(
                x, (self.state[0], self.state[1])
            )

        self.state = torch.stack((lstm_hidden_state, lstm_cell_state))
        x = lstm_hidden_state

        # undo combine
        x = x.view(batch_size, n_agents, -1)
        if reshape_state:
            self._state_reshape_out(batch_size, n_agents)

        return x

    def forward(self, x, mask):
        h = self.encoder(x)
        h = self._lstm_forward(h)
        return self.q_net(h)


class CommNet(DQNR):
    """
        Communication Network employing inter-agent communication.
    """

    def __init__(
        self,
        in_features,
        mlp_units,
        num_actions,
        comm_rounds,
        activation_fn,
    ):
        super().__init__(in_features, mlp_units, num_actions, activation_fn)
        assert comm_rounds >= 0
        self.comm_rounds = comm_rounds

    def forward(self, x, mask):
        batch_size, n_agents, feature_dim = x.shape
        h = self.encoder(x)

        # manually reshape state
        if self.state is not None:
            self._state_reshape_in(batch_size, n_agents)

        h = self._lstm_forward(h, reshape_state=False)

        # explicitly exclude self-communication from mask
        mask = mask * ~torch.eye(n_agents, dtype=bool, device=x.device).unsqueeze(0)

        for _ in range(self.comm_rounds):
            # combine hidden state h according to mask
            # first add up hidden states according to mask
            #    h has dimensions (batch, agents, features)
            #    and mask has dimensions (batch, agents, neighbors)
            #    => we have to transpose the mask to aggregate over all neighbors
            c = torch.bmm(h.transpose(1, 2), mask.transpose(1, 2)).transpose(1, 2)
            # then normalize according to number of neighbors per agent
            c = c / torch.clamp(mask.sum(dim=-1).unsqueeze(-1), min=1)

            # skip connection for hidden state and communication
            h = h + c
            # use new hidden state
            self.state[0] = h.view(batch_size * n_agents, -1)

            # pass through forward module
            h = self._lstm_forward(h, reshape_state=False)

        # manually reshape state in the end
        self._state_reshape_out(batch_size, n_agents)
        return self.q_net(h)




### <u>**Reward System: Detailed Description and Discussion**</u>
Below we describe the reward system that we have created to promote efficient routing.

#### **Reward Components**
1. **Blocked Path Penalty:**
   - *Condition:* If a plane attempts to traverse an edge whose load exceeds a threshold (1.5, which corresponds to 3 planes), making it impassable.
   - *Penalty:* −10.0
   - *Purpose:* The planes quickly learn that overloading an edge is harmful. It is desirable that blocking happens during training, as this lets the planes learn to avoid the scenario.
   - Before we implemented the prioritized replay buffer, and with higher $\gamma$ settings, planes rarely experienced being blocked during training because they explored too much, and they were then blocked very often during evaluation. We were able to tackle this issue effectively.

2. **Shortest Path Incentive:**
   - *Condition:* If the next node lies along the shortest path from the current node to the target node.
   - *Reward:* +3.0
   - *Purpose:* Encourages planes to follow the shortest paths for efficiency.

3. **Progress-Based Incentive:**
   - *Condition:* If the distance to the target decreases after taking an action.
   - *Reward:* +1.5 * exp(0.002 * progress)
   - *Purpose:* Rewards meaningful progress toward the target node and helps the plane to implicitly locate the target node.

4. **Time Penalty:**
   - *Condition:* Natural penalty incurred every action step.
   - *Penalty:* −0.01
   - *Purpose:* Encourages faster completion of the task by minimizing delays.

5. **Near-Target Penalty:**
   - *Condition:* If the plane is near the target (distance < 250) but chooses an action that doesn’t lead directly to the target in the next step.
   - *Penalty:* −1.0
   - *Purpose:* Reduces unnecessary detours near the destination and balances the reward for progress.

6. **Looping Penalty:**
   - *Condition:* If the plane revisits a node it has already visited.
   - *Penalty:* −2.0
   - *Purpose:* Discourages repeated visits to the same node.

7. **Exploration Incentive:**
   - *Condition:* If the plane visits a new node.
   - *Reward:* +2.5
   - *Purpose:* Encourages exploration of unvisited nodes.

8. **Edge Traversal Adjustment:**
   - *Condition:* While traversing an edge.
     - *Shortest Path edge:* +0.2 per step.
     - *Not Shortest Path edge:* −0.1 per step.
   - *Purpose:* Maintains consistent evaluation during traversal.

9. **Target Reached Reward:**
   - *Condition:* Upon successfully reaching the target node.
   - *Reward:* +10.0 * (1 + 1 / shortest path ratio)
   - *Purpose:* Maximizes efficiency and incentivizes the shortest path.

10. **Emergency Landing Penalty:**
    - *Condition:* When no valid actions are possible (plane is "dropped").
    - *Penalty:* −20.0
    - *Purpose:* Penalizes failure to complete the task.
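
A few of the components above can be sketched in plain Python. This is only an illustration under stated assumptions: the function names are hypothetical, `progress` is assumed to be the decrease in distance to the target, and the per-step components are assumed to be summed. The actual logic lives in the **step** function of ```routing.py``` and may combine the terms differently.

```python
import math


def progress_reward(progress):
    """Progress-based incentive (item 3), applied only when the distance decreased.

    Assumption: `progress` is the decrease in distance to the target.
    """
    if progress <= 0:
        return 0.0
    return 1.5 * math.exp(0.002 * progress)


def target_reached_reward(shortest_path_ratio):
    """Target reached reward (item 9); a ratio of 1.0 means the shortest path was flown."""
    return 10.0 * (1.0 + 1.0 / shortest_path_ratio)


def step_reward(on_shortest_path, progress, revisited):
    """Sum of some per-step components: items 2, 3, 4, 6 and 7 (assumed additive)."""
    reward = -0.01  # time penalty (item 4)
    if on_shortest_path:
        reward += 3.0  # shortest path incentive (item 2)
    reward += progress_reward(progress)  # item 3
    reward += -2.0 if revisited else 2.5  # looping penalty or exploration incentive
    return reward


print(target_reached_reward(1.0))  # plane followed the shortest path exactly
print(target_reached_reward(2.0))  # plane flew twice the shortest distance
```

With these constants, a plane that reaches its target along the shortest path earns twice the base reward, and the bonus shrinks as the ratio grows.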

---

### **Summary**
This reward system strikes a balance between exploration, efficiency, and resource management. Positive rewards incentivize optimal pathfinding and exploration, while penalties discourage inefficiency, unnecessary detours, and overloading of edges. The design ensures that agents learn to navigate the network effectively while avoiding congested or suboptimal routes. Furthermore, after over 100 separate runs across a wide variety of settings, no loopholes have been identified in our reward system. Implementation of this reward can be found in the ```routing.py``` file within the **step** function.

## <u>**State Aggregation**</u>

### **SUM**
**SUM** is a straightforward method of aggregating hidden node states during the message passing phase within the **NetMon** class.

Node masking is implemented to ensure that only hidden states from neighbouring nodes are aggregated. This selective aggregation allows for the creation of meaningful representations of the underlying graph on which everything operates.

Its primary advantage is **efficiency and speed**, as shown in the original paper [] and confirmed in our own work.


In [None]:
class SimpleAggregation(nn.Module):
    def __init__(self, agg: str, mask_eye: bool) -> None:
        super().__init__()
        self.agg = agg
        assert self.agg == "mean" or self.agg == "sum"
        self.mask_eye = mask_eye

    def forward(self, node_features, node_adjacency):
        if self.mask_eye:
            node_adjacency = node_adjacency * ~(
                torch.eye(
                    node_adjacency.shape[1],
                    node_adjacency.shape[1],
                    device=node_adjacency.device,
                )
                .repeat(node_adjacency.shape[0], 1, 1)
                .bool()
            )
        feature_sum = torch.bmm(node_adjacency, node_features)
        if self.agg == "sum":
            return feature_sum
        if self.agg == "mean":
            num_neighbors = torch.clamp(node_adjacency.sum(dim=-1), min=1).unsqueeze(-1)
            return feature_sum / num_neighbors


### **GCN**
**GCN (Graph Convolutional Network)** is a graph convolutional operator that handles hidden state aggregation during the message passing phase of the GNN, much like the **SUM** method described above. The implementation is available as [GCNConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GCNConv.html#torch_geometric.nn.conv.GCNConv) in the PyTorch Geometric library, which specializes in GNNs.

We do not reproduce the code of the GCNConv class from PyTorch Geometric here, as we use it unchanged and did not want to reimplement it ourselves. Instead, we provide a short description.
It is a fundamental building block for GCNs that operate on graph-structured data. It performs convolution operations as a form of aggregation of hidden node states from neighbourhoods.
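
For intuition, the normalised aggregation behind a GCN layer can be sketched in pure Python. This is not the GCNConv implementation: the helper `gcn_aggregate` is hypothetical, the learnable weight matrix and bias are omitted, and the adjacency matrix is assumed to be symmetric with a zero diagonal. Each neighbour's contribution is scaled by the inverse square roots of both node degrees after self-loops are added. A self-loop weight of 2 corresponds to the `improved=True` option used in our NetMon setup.

```python
import math


def gcn_aggregate(features, adjacency, self_loop_weight=1.0):
    """Symmetrically normalised aggregation D^-1/2 (A + w*I) D^-1/2 X.

    features  : node features, shape (n_nodes, n_features)
    adjacency : symmetric 0/1 matrix with zero diagonal (assumption)
    """
    n = len(adjacency)
    # add weighted self-loops to the adjacency matrix
    a_hat = [
        [self_loop_weight if i == j else float(adjacency[i][j]) for j in range(n)]
        for i in range(n)
    ]
    # degree of every node in the self-loop-augmented graph
    inv_sqrt_deg = [1.0 / math.sqrt(sum(row)) for row in a_hat]
    out = []
    for i in range(n):
        row = [0.0] * len(features[0])
        for j in range(n):
            weight = inv_sqrt_deg[i] * a_hat[i][j] * inv_sqrt_deg[j]
            for f, value in enumerate(features[j]):
                row[f] += weight * value
        out.append(row)
    return out


# two connected nodes with one feature each
adjacency = [[0, 1], [1, 0]]
features = [[2.0], [4.0]]
plain = gcn_aggregate(features, adjacency)
improved = gcn_aggregate(features, adjacency, self_loop_weight=2.0)
```

Compared with **SUM**, the normalisation keeps the scale of the aggregated state roughly independent of the node degree, at the cost of extra computation.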



## <u>**NetMon**</u>
**NetMon**, defined in ```model.py```, is the cornerstone of this system and was designed by the original authors. This class is responsible for creating the underlying graph observations, which are then integrated into agent observations through the **NetMon environment wrapper** (within ```wrapper.py```). By utilizing a message passing (MP) phase similar to those found in Graph Neural Networks (GNNs), **NetMon** recurrently updates and produces hidden states, which carry meaningful information about the underlying graph.
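
The mapping from node outputs to agent observations can be illustrated with a small pure-Python sketch. It assumes a one-hot node-agent matrix of shape (n_nodes, n_agents), where entry `[n][a]` is 1 if agent `a` is currently at node `n`. The helper name `node_to_agent_obs` is hypothetical. It mirrors the matrix product in `output_to_network_obs` for a single batch element.

```python
def node_to_agent_obs(node_out, node_agent_matrix):
    """Map per-node outputs to per-agent observations.

    node_out          : shape (n_nodes, features)
    node_agent_matrix : shape (n_nodes, n_agents), assumed one-hot per agent
    """
    n_nodes = len(node_out)
    n_agents = len(node_agent_matrix[0])
    features = len(node_out[0])
    return [
        [
            sum(node_agent_matrix[n][a] * node_out[n][f] for n in range(n_nodes))
            for f in range(features)
        ]
        for a in range(n_agents)
    ]


node_out = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
# agent 0 sits at node 2, agent 1 sits at node 0
node_agent_matrix = [[0, 1], [0, 0], [1, 0]]
agent_obs = node_to_agent_obs(node_out, node_agent_matrix)
print(agent_obs)
```

Each agent therefore receives the hidden state of the node it currently occupies.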

### **State management**
**LSTMs (Long Short-Term Memory)** or **GRUs (Gated Recurrent Units)** are used to maintain and update the internal states of nodes over time. This approach effectively captures **temporal dependencies** within the network, which makes the mechanism crucial for the routing task. Because of this system, agents can **adjust their paths dynamically**, as the recurrent networks remember past patterns and are able to predict future ones.

In [None]:
class NetMon(nn.Module):
    """
    Why does this even exist?
        - processing observations from nodes in the graph
        - performs message passing
        - aggregation of information from neighboring nodes
        - updating node states with RNN
        - produces features for nodes in the graph
    """

    def __init__(self, in_features, hidden_features: int, encoder_units, iterations, activation_fn,
                rnn_type="lstm",
                rnn_carryover=True, agg_type="sum",
                output_neighbor_hidden = False, output_global_hidden = False
    ) -> None:
        super().__init__()

        assert isinstance(hidden_features, int)

        # print("In-features to Netmon encoder:", in_features)
        self.encode = MLP(in_features, (*encoder_units, hidden_features), activation_fn)    # Define a simple MLP as the encoder function
        self.state = None
        self.output_neighbor_hidden = output_neighbor_hidden
        self.output_global_hidden = output_global_hidden
        self.rnn_carryover = rnn_carryover
        self.iterations = iterations

        # 0 = dense input - dense matrices
        # 1 = sparse input - sparse matrices
        self.aggregation_def_type = None

        # Aggregation
        self.agg_type_str = agg_type    # "sum", "mean" or "gcn"

        # Now we will resolve the actual aggregation with the individual networks
        if agg_type == "sum" or agg_type == "mean":
            self.aggregate = SimpleAggregation(agg=agg_type, mask_eye=False)
            self.aggregation_def_type = 0
        elif agg_type == "gcn":
            self.aggregate = GCNConv(hidden_features, hidden_features, improved=True)
            self.aggregation_def_type = 1
        else:
            raise ValueError(f"Unknown aggregation type {agg_type}")


        # Update and observation encoding
        self.rnn_type = rnn_type    # "lstm" by default
        if self.rnn_type == "lstm":
            self.rnn_obs = nn.LSTMCell(hidden_features, hidden_features)
            self.rnn_update = nn.LSTMCell(hidden_features, hidden_features)
            self.num_states = 2 if rnn_carryover else 4 # 2
        elif self.rnn_type == "gru":
            self.rnn_obs = nn.GRUCell(hidden_features, hidden_features)
            self.rnn_update = nn.GRUCell(hidden_features, hidden_features)
            self.num_states = 1 if rnn_carryover else 2
        elif self.rnn_type == "none":
            # empty state / stateless => simply store h for debugging
            self.num_states = 1
        else:
            raise ValueError(f"Unknown rnn type {self.rnn_type}")

        self.hidden_features = hidden_features
        self.state_size = hidden_features * self.num_states



    def forward(self, x, mask, node_agent_matrix, max_degree=None, no_agent_mapping = False):
        # This function contains steps (1), (2) and (3)
        h, last_neighbor_h = self.update_node_states(x, mask)

        # Step (4), Check what type of node states to aggregate. Either global or neighbor
        if self.output_neighbor_hidden or self.output_global_hidden:
            extended_h = [h]

            # Aggregate neighbors
            if self.output_neighbor_hidden:
                extended_h.append(
                        self.get_neighbor_h(last_neighbor_h, mask, max_degree)
                    )

            # Aggregate global
            if self.output_global_hidden:
                extended_h.append(self.get_global_h(h))

            h = torch.cat(extended_h, dim=-1)   # Concatenates all features along the last dimension


        if no_agent_mapping:
            return h

        return NetMon.output_to_network_obs(h, node_agent_matrix)


    def get_state_size(self):
        return self.state_size

    def get_global_h(self, h):
        """
            Computes a global summary of the nodes states and appends it to each Node's representation.
        """
        _, n_nodes, _ = h.shape     # (batch_size, n_nodes, hidden_size)

        """
            - mean(dim=1) computes the mean along all nodes for each batch
                -> (batch_size, hidden_size)
            - repeat repeats the global hidden state n_node times along a new dimension -> (n_nodes, batch_size, hidden_size)
            - transpose results in (batch_size, n_nodes, hidden_size)
        """
        global_h = h.mean(dim=1).repeat((n_nodes,1,1)).transpose(0,1)
        return global_h

    def get_neighbor_h(self, neighbor_h, mask, max_degree):
        """
            Computes a summary based on Nodes neighbors and appends it to each Node's representation.
        """
        batch_size, n_nodes , _ = neighbor_h.shape

        # Get max node id for dense observation tensor (excludes self)
        if max_degree is None:  # The maximum number of neighbors for each node -> if it is None -> compute from adjacency matrix
            max_degree = torch.sum(mask, dim=1).max().long().item() - 1

        # Pre-allocate a placeholder for each neighbor
        h_neighbors = torch.zeros((batch_size, n_nodes, max_degree, neighbor_h.shape[-1]), device = neighbor_h.device)

        # Get mask without self (pure neighbors)
        neighbor_mask = mask * ~(                               # ~ is negation -> creates a matrix of ones where diagonal is 0
            torch.eye(n_nodes, n_nodes, device=mask.device)
            .unsqueeze(0)   # Add dimension to the 0th positions
            .repeat(mask.shape[0], 1, 1) # Repeat the tensor mask.shape[0] times along the first dimension and once along the second and the third
            .bool()
        )

        # Now we want to collect features from neighbors
        h_index = neighbor_mask.nonzero()

        # Get the relative neighbor ID for the insertion into h_neighbors
        cumulative_neighbor_index = neighbor_mask.cumsum(dim=-1).long() - 1     # Cumulatively sums the neighbor mask along the last dimension to assign a unique index to each neighbor per node
        h_neighbors_index = cumulative_neighbor_index[h_index[:,0], h_index[:, 1], h_index[:, 2]]

        # Copy the last hidden state of all neighbors into the hidden state tensor
        # For each neighbor connection, copies neighbor's hidden state into h_neighbors to corresponding position
        h_neighbors[h_index[:,0], h_index[:, 1], h_neighbors_index] = neighbor_h[h_index[:,0], h_index[:, 2]]

        # Concatenate info for each node
        return h_neighbors.reshape(batch_size, n_nodes, -1)     # Reshape from (batch_size, n_nodes, max_degree, hidden_size) to (batch_size, n_nodes, max_degree*hidden_size)
                                                                # |
                                                                #  -> each node has a concatenated vector of its neighbors' hidden states


    def update_node_states(self, x, mask):
        """
            This function performs message passing and state updates over a specified number of iterations.
            It integrates both node features and graph structure.

            :mask: it is the adjacency matrix of the graph -> (n_waypoints,n_waypoints)
        """
        batch_size, n_nodes, feature_dim = x.shape # (1, 131, 1463)

        x = x.reshape(batch_size * n_nodes, -1)  # New shape is (batch_size * n_nodes, feature_dim)

        if self.state is None: # For storing hidden states
            # Init
            self.state = torch.zeros((batch_size, n_nodes, self.state_size), device = x.device)

        # Reshape the state before further processing
        self.state_reshape_in(batch_size, n_nodes)

        # step (1): encode observation to get h^0_v and combine with state
        h = self.encode(x)  # Producing initial hidden representations

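        # Choose what we are using. Either LSTM or GRU
        if self.rnn_type in ["lstm", "lnlstm"]:
            h0, cx0 = self.rnn_obs(h, (self.state[0], self.state[1])) # rnn_obs processes the encoded features h along with the previous states
            h, cx = h0, cx0
        elif self.rnn_type == "gru":
            # GRU cells keep a single hidden state, so only state[0] is passed in
            h0 = self.rnn_obs(h, self.state[0])
            h = h0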

        # Message passing iterations
        if self.iterations <= 0 and self.output_neighbor_hidden:
            last_neighbor_h = torch.zeros_like(h, device=h.device)  # Returns a tensor filled with 0s in the shape of h
        else:
            last_neighbor_h = None

        if self.aggregation_def_type != 0:
            mask_sparse, mask_weights = dense_to_sparse(mask)   # Conversion to a sparse representation for the aggregation function

        if self.aggregation_def_type == 2:
            H, C = self.state[0], self.state[1] # Init of additional aggregation types

        # Iteration
        for it in range(self.iterations):
            if self.output_neighbor_hidden and it == self.iterations-1:
                if self.aggregation_def_type == 2:
                    """
                        we know that the aggregation step will exchange the hidden states
                        (and much more..) so we can just use them for the skip connection
                        instead of the other nodes' input.
                        This is only relevant for a single iteration per step.
                    """
                    last_neighbor_h = H
                else:
                    # use the last received hidden state
                    last_neighbor_h = h

            # step (2): aggregate - computes the aggregated messages M for each node.
            if self.aggregation_def_type == 0:  # Simple aggregation
                M = self.aggregate(h.view(batch_size, n_nodes, -1), mask).view(
                    batch_size * n_nodes, -1
                )
            elif self.aggregation_def_type == 1:    # Aggregation through conv. layers with sparse mask
                M = self.aggregate(h, mask_sparse)

            elif self.aggregation_def_type == 2:    # Specialized aggregation with additional states
                H, C = self.aggregate(h, mask_sparse, H=H, C=C)
                M = H
            elif self.aggregation_def_type == 3:    # Models like GraphSAGE etc.; not used in our setup
                # overwrite last_neighbor_h with jumping knowledge output
                M, last_neighbor_h = self.aggregate(h, mask_sparse)

            # step (3): update - it is performed using RNN cell with the aggregated messages M
            """
                What is the carryover mechanism?
                    Carryover mechanism controls whether to carry over states between iterations or reset them.
            """
            if self.rnn_type in ["lstm", "lnlstm"]:
                if not self.rnn_carryover and it == 0:
                    rnn_input = (self.state[2], self.state[3])
                else:
                    rnn_input = (h, cx)

                h1, cx1 = self.rnn_update(M, rnn_input)
                h, cx = h1, cx1
            elif self.rnn_type == "gru":
                if not self.rnn_carryover and it == 0:
                    rnn_input = self.state[1]
                else:
                    rnn_input = h
                h1 = self.rnn_update(M, rnn_input)
                h = h1
            else:
                h = M

        # Reshaping
        if last_neighbor_h is not None:
            last_neighbor_h = last_neighbor_h.reshape(batch_size, n_nodes, -1)  # Reshaping to original dimensions for output
        h = h.reshape(batch_size, n_nodes, -1)  # Reshaping to original dimensions for output

        # Updating of the internal state
        if self.rnn_type in ["lstm", "lnlstm"]:
            if self.rnn_carryover:
                self.state = torch.stack((h1, cx1))     # Concatenating tensors along a new dimension
            else:
                self.state = torch.stack((h0, cx0, h1, cx1))
        elif self.rnn_type == "gru":
            if self.rnn_carryover:
                self.state = h1.unsqueeze(0)
            else:
                # stack to (2, batch_size * n_nodes, hidden_size), matching the LSTM branch
                self.state = torch.stack((h0, h1))
        elif self.rnn_type == "none":
            # store last node state for debugging and aux loss
            self.state = h.unsqueeze(0)

        self.state_reshape_out(batch_size, n_nodes)

        return h, last_neighbor_h

    def state_reshape_in(self, batch_size, n_agents):
        """
            Reshapes the state of shape (batch_size, n_agents, self.get_state_size())
                to shape
                    (num_states, batch_size * n_agents, hidden_size)
        """

        if self.state.numel() == 0:
            return

        self.state = self.state.reshape(batch_size * n_agents, self.num_states, -1).transpose(0,1)

    def state_reshape_out(self, batch_size, n_agents):
        """
            Reshapes the state of shape
                (num_states, batch_size * n_agents, hidden_size)
            to shape
                (batch_size, n_agents, self.get_state_size()).

            :param batch_size: the batch size
            :param n_agents: the number of agents
        """
        if self.state.numel() == 0:
            return

        self.state = self.state.transpose(0, 1).reshape(batch_size, n_agents, -1)


    @staticmethod
    def output_to_network_obs(netmon_out, node_agent_matrix):
        """
            Called at the end of the forward function to map the node information to agents.
                It multiplies the node outputs with node_agent_matrix to aggregate/map node outputs to agent-specific outputs.
        """

        # bmm performs a batch matrix-matrix product of the matrices stored in netmon_out.transpose(1,2) and node_agent_matrix
        return torch.bmm(netmon_out.transpose(1, 2), node_agent_matrix).transpose(1, 2)

# <u>**Multi GPU setups**</u>
## **Bi-GPU setup**
In **our bi-GPU configuration**, we adapted the primary script (```main.py``` became ```bi_gpu_main.py```) and the environment wrapper to utilize two GPUs. A key modification is assigning the Q-value predictor to a **separate** GPU. This separation lets us handle larger mini-batch sizes by distributing the computational load more evenly across both GPUs. By adjusting the device allocation within ```bi_gpu_main.py```, we ensure that each component of the model is assigned to a suitable GPU.

Additionally, the environment wrapper has been updated to manage data transfer between the two GPUs, including the inter-GPU communication needed to keep training consistent. Dedicating the Q-value predictor to a separate GPU reduces the memory and processing strain on the primary GPU, which enables higher throughput and faster training.

## **Multi-GPU setup**
To scale our setup to multiple GPUs, we incorporated PyTorch's **DataParallel** class, which replicates the model on several GPUs within a single compute node. This approach handles varying numbers of GPUs without requiring significant alterations. **DataParallel** splits each input batch across the available GPUs and processes the chunks in parallel, which increases computational efficiency.

A crucial modification was made to the NetMon class. Originally, **NetMon** maintained its internal state as an attribute, which posed challenges for parallel execution since DataParallel replicates the model across the available GPUs. To overcome this, we **removed** the internal state attribute from NetMon, so that state information is managed externally rather than being tied to a specific instance of the model. Furthermore, the entire pipeline, including the **NetMon environment wrapper**, was redesigned to return the internal state during the forward pass. This adjustment ensures compatibility with DataParallel, which requires modules to be **stateless**.


# <u>**Selecting Parameters**</u>
To optimize the performance of our models, we conducted multiple hyperparameter sweeps using Weights & Biases (W&B). These sweeps provided valuable insights into the importance and correlations of various parameters. Although project time constraints limited the extent of our data collection, the results obtained are sufficiently clear to guide our parameter selection. We selected the best-performing models based on the accumulated rewards. Additionally, we considered other significant metrics such as SPR (the mean ratio of the path length taken to the true shortest path) and throughput (the mean ratio of planes that reached their targets during an episode).
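
The two metrics can be written down in a few lines of Python. This is a sketch under assumptions: the function names are hypothetical, and SPR is assumed to be averaged only over planes that reached their target, since only those have a completed path.

```python
def shortest_path_ratio(path_lengths, shortest_lengths):
    """Mean ratio of the path length taken to the true shortest path (SPR).

    Both lists hold one entry per plane that reached its target (assumption).
    A value of 1.0 means every plane flew its shortest path.
    """
    ratios = [taken / best for taken, best in zip(path_lengths, shortest_lengths)]
    return sum(ratios) / len(ratios)


def throughput(n_reached, n_planes):
    """Fraction of planes that reached their targets during an episode."""
    return n_reached / n_planes


# two planes: one flew 20 % further than necessary, one flew the shortest path
spr = shortest_path_ratio([120.0, 300.0], [100.0, 300.0])
rate = throughput(45, 50)
```

Lower SPR is better (closer to 1.0), while higher throughput is better.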

We ran two sweeps of 75k steps each: one covering only CommNet settings (first figure) and one comparing DQN and DGN (second figure). Our findings indicate that while **CommNet** is significantly faster, it performed worse than the other architectures. The SUM aggregation method proved to be both **superior in performance and faster in execution**.

## **Common Parameters in the Sweep**
 - **mini_batch_size**: number of experiences sampled from the replay buffer
 - **epsilon_decay**: decay factor of epsilon in the EpsilonGreedy policy
 - **agg_type**: aggregation method used during the NetMon message passing phase, which produces the node representations
 - **gamma**: discount factor in Q-Learning

## **CommNet specific**
 - **comm_rounds**: number of communication rounds

## **DQN, DGN specific**
 - **num_heads**: number of attention heads
 - **att_layers**: number of attention layers




# <u>**Advice for Parameter Selection**</u>
In this section, we provide guidance on fine-tuning hyperparameters, followed by results from 35 sweeps conducted on Weights & Biases (W&B) across two different settings.
## **CommNet settings**
Our experiments reveal that using a **lower number of communication rounds** (comm_rounds) and **increasing the mini-batch size** significantly enhances the model's performance. The epsilon update frequency did not appear as critical, although this may be due to the limited range we explored during the sweep. We recommend setting this parameter to around **70 for 75k steps** and gradually increasing it as the number of training steps grows. A good rule of thumb is to ensure that after training, the epsilon value stabilizes around **0.01**. If epsilon falls below 0.01, it is reset to maintain some exploration.
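
The rule of thumb can be checked with a short calculation. The sketch below assumes a multiplicative decay applied once every `update_freq` steps, starting from epsilon 1.0 at step 0, and it treats the reset as clamping at 0.01. These are illustrative assumptions, and both function names are hypothetical rather than taken from the project code.

```python
def epsilon_after(total_steps, update_freq, decay, start=1.0, floor=0.01):
    """Epsilon after training, with one multiplicative decay per update (assumed schedule)."""
    eps = start
    for _ in range(total_steps // update_freq):
        # clamp at the floor so that some exploration always remains
        eps = max(eps * decay, floor)
    return eps


def decay_for_target(total_steps, update_freq, target=0.01, start=1.0):
    """Decay factor that brings epsilon from `start` to `target` by the end of training."""
    n_updates = total_steps // update_freq
    return (target / start) ** (1.0 / n_updates)


# 75k steps with an update every 70 steps gives 1071 updates
decay = decay_for_target(75_000, 70)
final_eps = epsilon_after(75_000, 70, decay)
print(decay, final_eps)
```

Under these assumptions, a decay factor of roughly 0.9957 lets epsilon reach 0.01 at the end of a 75k-step run. A longer run with the same decay factor would need a larger update frequency to end at the same value.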
## **DQN, DGN settings**
In the **DQN** and **DGN** settings, the discount factor (gamma) emerged as **the most crucial** hyperparameter. This finding aligns with expectations, given that the architecture directly feeds into the **Q-Net**. We observed consistent improvements in accumulated rewards when gamma was set lower (approximately **0.92-0.95**) over **75k steps**.

Additionally, the epsilon update frequency proved to be highly significant, showing a strong negative correlation with performance. This indicates that epsilon should be updated more frequently, enabling the models to adapt and navigate the environment more efficiently. Furthermore, the **number of attention heads** and **attention layers** demonstrated a positive correlation with rewards, suggesting that increasing these parameters can further enhance performance.

## **Aggregation Type**
Across both sweeps, the **choice of aggregation type** did not show substantial importance. However, for **CommNet**, we recommend using **GCN**, as it exhibited a positive correlation with the generated rewards, whereas **SUM** showed a negative correlation.

On the other hand, in the case of **DQN** and **DGN**, **SUM** performed better. Based on our extensive training experience, we recommend using SUM for its speed and comparable results.

### **WANDB sweep for CommNet**
![](https://github.com/Strojove-uceni/2024-final-letadylka-prochazka-belohlavek/blob/main/pictures/commnet.png?raw=true)

### **Hyperparameters importance for CommNet**
![](https://github.com/Strojove-uceni/2024-final-letadylka-prochazka-belohlavek/blob/main/pictures/commnet_parameters.png?raw=true)

### **WANDB sweep for DQN vs. DGN**
![](https://github.com/Strojove-uceni/2024-final-letadylka-prochazka-belohlavek/blob/main/pictures/dgn_dqn.png?raw=true)

### **Hyperparameters importance for DGN and DQN**
![](https://github.com/Strojove-uceni/2024-final-letadylka-prochazka-belohlavek/blob/main/pictures/dgn_dqn_parameters.png?raw=true)

# <u>**Our Results**</u>

After training multiple models with various settings and running hyperparameter sweeps on Weights & Biases (W&B), our **mean reward nearly doubled**, and so did the **throughput of planes reaching their targets**. We also observed a clear positive correlation between reward and throughput: runs with a higher mean reward also landed more planes. This suggests that the agents did not find a loophole in the reward system. These claims deserve more backing and rigorous testing, as does the hyperparameter space explored in the sweeps mentioned above. Unfortunately, we did not have enough time during this project to give all of these points the attention they deserve.


## <u>**Our Selected Best Performing Models**</u>
We present the top two models, selected by mean reward. Note that a higher mean reward very likely implies higher throughput and a smaller shortest path ratio. We also added a third model for comparison so that every architecture we used is included.

### **Key observations**
- We observed a **clear positive** correlation between the mean reward and the throughput of planes. This means that as the agents achieved higher mean rewards, the number of planes successfully landing at their targets also increased.
- The positive correlation **reassures** us that the agents **did not** find unintended shortcuts or loopholes to artificially inflate rewards. Instead, they genuinely improved their performance by making more effective decisions. This outcome is of **extremely high importance** to us, as it supports the robustness of our reward design and the overall reliability of our training process. (We had an extremely hard time deriving a working reward system that would promote effective routing in a large graph.)


### **Hyperparameter search**
Due to the limited time, we were not able to perform a fully rigorous hyperparameter search. Nevertheless, the current results provide a **strong foundation for future work**.



<style>
    table {
        width: 200%;
    }
</style>
| Parameters               | Model 1          | Model 2          | Model 3        |
|-------------------------|------------------------|------------------------|------------------------|
| model_type              | comm_net               | dqn                    | dgn                    |
| iterations              | 6                      | 6                      | 8                      |
| agg_type                | gcn                    | gcn                    | sum                    |
| att_layers              | -                     | -                      | 6                      |
| num_heads               | -                      | -                      | 12                     |
| kv_values               | -                     | -                     | 16                     |
| comm_rounds             | 4                      | -                      | -                      |
| epsilon_update_freq     | 70                     | 90                     | 70                     |
| total_steps             | 75000                  | 75000                  | 75000                  |
| step_before_train       | 15000                  | 15000                  | 15000                  |
| step_between_train      | 10                     | 10                     | 5                      |
| sequence_length         | 16                     | 16                     | 16                     |
| gamma                   | 0.93                   | 0.91                   | 0.91                   |
| mini_batch_size         | 32                     | 64                     | 32                     |
| **Results**             |                        |                        |                        |
| delays_mean             | 20.7              | **17.4**              | 28.6              |
| delays_arrived_mean     | 21.8             | **18**              | 29              |
| spr_mean                | 3.7               | **3**               | 5              |
| looped_mean             | 0.41               | **0.05**               | 1.3               |
| throughput_mean         | 0.42                  | **0.51**               | 0.29               |
| dropped_mean            | 0.0                    | 0.0                    | 0.0                    |
| blocked_mean            | 0.001               | 0.002               | 0.001               |
| total_edge_load_mean    | 5               | 5               | 5               |
| occupied_edges_mean     | 9.4              | 9.4               | 9.3               |
| planes_on_edges_mean    | 9.99               | 9.99               | 9.99               |
| total_plane_size_mean   | 5.0                    | 5.0                    | 5.0                    |
| plane_sizes_mean        | 0.5                    | 0.5                    | 0.5                    |
| plane_distances_mean    | 9.3                | 9.6               | 9.1               |
| reward_mean             | 3.5              | **3.9**               | 2.7               |


## <u>**Heatmaps**</u>
Heatmaps show the edge and node usage by the agents. Additional pictures can be found in the ```/pictures/evaluation_pictures``` directory or can be generated by the provided code. After each evaluation phase, the code automatically saves many interesting and informative pictures. These were taken from the original implementation and slightly <u>**beautified**</u>.
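
To make concrete what "edge and node usage" means, here is a minimal sketch of how such counts could be accumulated before plotting. The trajectories, the function name `usage_counts`, and the choice to treat edges as undirected are all assumptions made for illustration; the project's own evaluation code may collect these statistics differently.

```python
from collections import Counter

# Hypothetical trajectories: each list is the sequence of node ids one plane visited.
trajectories = [
    [0, 1, 2, 3],
    [4, 1, 2, 3],
    [0, 1, 4],
]


def usage_counts(paths):
    """Count node visits and edge traversals over all paths.

    Edges are treated as undirected (an assumption), so (1, 4) and (4, 1)
    are counted under the same key.
    """
    node_usage = Counter()
    edge_usage = Counter()
    for path in paths:
        node_usage.update(path)
        for u, v in zip(path, path[1:]):
            edge_usage[tuple(sorted((u, v)))] += 1
    return node_usage, edge_usage


node_usage, edge_usage = usage_counts(trajectories)
print("node visits:", dict(node_usage))
print("edge traversals:", dict(edge_usage))
```

Counts like these can then be mapped to node and edge colors when drawing the graph.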

### **COMMNET**
Utilization is concentrated around the central nodes, and edge utilization appears to be more evenly distributed than in the other models.

\

![](https://github.com/Strojove-uceni/2024-final-letadylka-prochazka-belohlavek/blob/main/pictures/evaluation_pictures/heatmapcommnetmask.png?raw=true)

### **DQN**
Node and edge utilization still centers around the same critical nodes but shows a more pronounced preference for specific paths.

\
![](https://github.com/Strojove-uceni/2024-final-letadylka-prochazka-belohlavek/blob/main/pictures/evaluation_pictures/heatmapdqnmask.png?raw=true)

### **DGN**
Higher utilization is concentrated around fewer nodes, indicating potentially fewer but more optimized routes. Additionally, some edges show much higher utilization than others, hinting at more focused path selection compared to DQN or COMMNET.

\
![](https://github.com/Strojove-uceni/2024-final-letadylka-prochazka-belohlavek/blob/main/pictures/evaluation_pictures/heatmapdgnmask.png?raw=true)

## <u>**Shortest Path Ratio**</u>
These plots show the relationship between the number of steps in the shortest path to the target and the number of steps the agent took to reach it. The straight line in the middle represents the "lower bound," but the number of steps must be scaled by the planes' speed for this interpretation to be strictly valid. Here, it simply serves as a reference for comparing the stability of the models.

An interesting observation is that as the number of steps needed to reach the target grows, the number of steps the agents actually take grows almost linearly, with some deviations. The **biggest instability** is in the DGN model, while DQN appears to be **exceptionally stable**. CommNet behaves **even better** for longer paths, which is simply lovely.
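
The quantity behind these plots can be sketched as follows, assuming the ratio is defined as the number of steps taken divided by the hop count of the shortest path between the start and the target. The toy graph and the function names are hypothetical, and the sketch ignores plane speed, which the project's actual metric may take into account.

```python
from collections import deque

# Hypothetical undirected graph given as an adjacency list.
graph = {0: [1, 2], 1: [0, 3], 2: [0, 3], 3: [1, 2, 4], 4: [3]}


def shortest_path_length(adj, source, target):
    """Breadth-first search hop count from source to target, or None if unreachable."""
    if source == target:
        return 0
    visited = {source}
    queue = deque([(source, 0)])
    while queue:
        node, dist = queue.popleft()
        for neighbor in adj[node]:
            if neighbor == target:
                return dist + 1
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, dist + 1))
    return None


def shortest_path_ratio(adj, path):
    """Steps actually taken divided by the shortest possible number of steps."""
    steps_taken = len(path) - 1
    optimal = shortest_path_length(adj, path[0], path[-1])
    if not optimal:
        raise ValueError("start equals target or target is unreachable")
    return steps_taken / optimal


direct = shortest_path_ratio(graph, [0, 1, 3, 4])        # follows a shortest path
detour = shortest_path_ratio(graph, [0, 1, 0, 2, 3, 4])  # backtracks once
print(direct, detour)
```

Under this definition a ratio of 1 means the agent followed a shortest path, and larger values indicate detours or loops.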

### **COMMNET**
![](https://github.com/Strojove-uceni/2024-final-letadylka-prochazka-belohlavek/blob/main/pictures/evaluation_pictures/sprcommnetmask.png?raw=true)

### **DQN**
![](https://github.com/Strojove-uceni/2024-final-letadylka-prochazka-belohlavek/blob/main/pictures/evaluation_pictures/sprdqnmask.png?raw=true)

### **DGN**
![](https://github.com/Strojove-uceni/2024-final-letadylka-prochazka-belohlavek/blob/main/pictures/evaluation_pictures/sprdgnmask.png?raw=true)

# <u>**Conclusion**</u>
In this project, we explored the integration of Deep Graph Neural Networks (DGN) with reinforcement learning techniques to enhance multi-agent coordination and decision-making. We successfully adapted previous work to our task of efficient air traffic control, that is, efficient routing of planes in the airspace.

Our investigation into different architectures, including CommNet, DQN, and DGN, revealed critical insights into the importance of communication rounds, mini-batch sizes, discount factors, and aggregation methods. The implementation of a prioritized replay buffer further augmented our model's ability to learn from impactful experiences, enhancing its capability to predict and navigate optimal paths effectively. Additionally, node masking allowed us to generalize this solution to graphs with diverse neighbourhoods.

Future work could involve more exhaustive hyperparameter sweeps and longer training durations. Moreover, the system could be replicated across multiple regions to create a controller-subcontroller system, where each sub-controller specializes in managing a specific area.


**Special Thanks**

We would like to extend our gratitude to the FNSPE faculty for providing the computational resources necessary for this project, including access to the HELIOS cluster. Additionally, we would like to thank Tomáš Kerepecký for his assistance in setting up WANDB on the UTIA computing cluster.

**Other considered approaches**


These methods were considered but not selected, either because of their complexity or because of their limited applicability.

- [MARL on coordination graphs](https://medium.com/@jamgochian95/multi-agent-reinforcement-learning-with-coordination-graphs-428dddb99907)
- [Former based approach for shortest path](https://medium.com/octavian-ai/finding-shortest-paths-with-graph-networks-807c5bbfc9c8)
- [GNN for shortest path](https://medium.com/@bnn_upc/computing-the-shortest-path-with-graph-neural-networks-gnn-a-hands-on-introduction-to-ignnition-bea531b3b5b2)
- [Interesting non-linear approach](https://www.sciencedirect.com/science/article/pii/S0096300306016304)
- [Approximation of the shortest path](https://arxiv.org/pdf/2002.05257)
- [Meta Learning intro](https://medium.com/huggingface/from-zero-to-research-an-introduction-to-meta-learning-8e16e677f78a)
- [Multi agent meta learning](https://signalprocessingsociety.org/publications-resources/ieee-open-journal-signal-processing/dif-maml-decentralized-multi-agent-meta)
- [Deep path](https://sites.cs.ucsb.edu/~william/papers/DeepPath.pdf)
- [Shortest path with attention network](https://www.ijcai.org/proceedings/2019/569)

**Honorable mentions**
- [NN with Particle Swarm Optimization (PSO)](https://ojs.aaai.org/index.php/SOCS/article/view/18244)
- [Growing Neural Gas](https://en.wikipedia.org/wiki/Neural_gas)


# References

- Graph MARL GitHub Repository: [Original Implementation][1]
- Replay Buffer Implementation: [Replay Buffer GitHub][2]
- [PyTorch Documentation][3]
- [PyTorch Geometric Documentation][4]

[1]: https://github.com/jw3il/graph-marl/tree/main?tab=readme-ov-file
[2]: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/dqn/replay_buffer.py
[3]: https://pytorch.org/
[4]: https://pytorch-geometric.readthedocs.io/en/latest/
