### Project Group 1 in Practical Planning Robust Behavior for autonomous driving
# Reinforcement Learning using Graph Neural Networks

### Tom Dörr, Marco Oliva, Quoc Trung Nguyen, Silvan Wimmer

__Objective__: Exploit the graph-like structure of traffic scenarios by applying graph neural networks to the Soft-Actor-Critic algorithm.
## Chapter 1: Basic Setup
### 1.1: Imports and Setup

In [1]:
#from docs.report.helper_functions import set_notebook_log_level

In [2]:
import tensorflow as tf
import numpy as np
import logging
import json
import pprint as pp
import shutil
import os
import time

# reduce the number of log messages for improved readability
import sys
logging.disable(sys.maxsize)

# Bark imports
from bark.runtime.commons.parameters import ParameterServer

# Bark-ml imports
from bark_ml.environments.blueprints import ContinuousMergingBlueprint,\
    ContinuousHighwayBlueprint
from bark_ml.environments.single_agent_runtime import SingleAgentRuntime
from bark_ml.library_wrappers.lib_tf_agents.agents import BehaviorGraphSACAgent
from bark_ml.observers.graph_observer import GraphObserver

# Supervised tests
from bark_ml.tests.capability_gnn_actor.data_handler import SupervisedData
from bark_ml.tests.capability_gnn_actor.actor_nets import ConstantActorNet,\
  RandomActorNet

# helper_functions
from helper_functions import configurable_setup, benchmark_actor, explain_node_attributes,\
    explain_edge_attributes, explain_observation

pygame 1.9.6
Hello from the pygame community. https://www.pygame.org/contribute.html


# TODO:
- Fuse all imports
- Why is GraphObserver.graph() not working?

## Chapter 2: GraphObserver
In this chapter we want to briefly introduce the working mechanisms of the GraphObserver.

<img src="images/observer.png" width="700">


The GraphObserver has the following parameters, which can be set with the ParameterServer (i.e. under `params["ML"]["GraphObserver"]`):
- "AgentLimit": The maximum number of agents that can be observed. Default is 4.
- "VisibilityRadius": The radius within which an agent can 'see', i.e. detect other agents. Default is 50.
- "SelfLoops": Whether each node has an edge pointing to itself. Influences the performance of, e.g., GNN SAC instances based on Spektral.
- "EnabledNodeFeatures": The list of node features, given by their string keys, that the observer should extract from the world and insert into the observation. For a list of available features, refer to the list returned by `GraphObserver.available_node_attributes`. Unavailable node features are ignored (but a log message is shown).
- "EnabledEdgeFeatures": The list of edge features, given by their string keys, that the observer should extract from the world and insert into the observation. For a list of available features, refer to the list returned by `GraphObserver.available_edge_attributes`. Unavailable edge features are ignored.
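
As a rough illustration of how these settings fit together, the sketch below uses a plain nested dictionary in place of the `ParameterServer` and a hypothetical helper, `filter_enabled_features`, to mimic the rule that unavailable features are dropped with a log message. The helper, the value chosen for `SelfLoops`, and the requested feature lists (including the deliberately invalid `acceleration`) are assumptions for illustration, not BARK-ML code.

```python
import logging

# Plain nested dict standing in for the ParameterServer (assumption).
params = {
    "ML": {
        "GraphObserver": {
            "AgentLimit": 4,          # default stated in the list above
            "VisibilityRadius": 50,   # default stated in the list above
            "SelfLoops": False,       # illustrative choice
            "EnabledNodeFeatures": ["x", "y", "theta", "vel", "acceleration"],
            "EnabledEdgeFeatures": ["dx", "dy"],
        }
    }
}

# Node features reported by the GraphObserver in this report.
AVAILABLE_NODE_FEATURES = ["x", "y", "theta", "vel", "goal_x", "goal_y",
                           "goal_dx", "goal_dy", "goal_theta", "goal_d",
                           "goal_vel"]


def filter_enabled_features(requested, available):
    """Hypothetical helper: keep requested features that exist, log the rest."""
    enabled = [f for f in requested if f in available]
    for f in requested:
        if f not in available:
            logging.warning("Ignoring unavailable feature: %s", f)
    return enabled


observer_params = params["ML"]["GraphObserver"]
node_features = filter_enabled_features(
    observer_params["EnabledNodeFeatures"], AVAILABLE_NODE_FEATURES)
print(node_features)
```

Here `acceleration` is not an available node feature, so it is dropped and only the four valid keys remain.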

In [3]:
# Show available node_attributes
print("Available_node_attributes:", GraphObserver.available_node_attributes())
print("Meaning of node attributes:\n"+explain_node_attributes())

# Show available edge_attributes
print("\nAvailable_edge_attributes:", GraphObserver.available_edge_attributes())
print("Meaning of edge attributes:\n"+explain_edge_attributes())

Available_node_attributes: ['x', 'y', 'theta', 'vel', 'goal_x', 'goal_y', 'goal_dx', 'goal_dy', 'goal_theta', 'goal_d', 'goal_vel']
Meaning of node attributes:
     'x': x-coordinate in world
     'y': y-coordinate in world
     'theta': orientation of agent
     'vel': velocity of agent in direction of orientation
     'goal_x': x-coordinate of goal
     'goal_y': y-coordinate of goal
     'goal_dx': distance in x-coordinate from agent to goal
     'goal_dy': distance in y-coordinate from agent to goal
     'goal_theta': difference between goal orientation and agent orientation
     'goal_d': distance to goal (straight line)
     'goal_vel': velocity, the agent should have when reaching goal

Available_edge_attributes: ['dx', 'dy', 'dvel', 'dtheta']
Meaning of edge attributes:
     'dx': difference in x-coordinate between two agents
     'dy': difference in y-coordinate between two agents
     'dvel': difference in velocity between two agents
     'dtheta': difference in orientation between two agents

The `Observe(world)` method returns an observation based on the current snapshot of the world by extracting the node attributes, the adjacency matrix, and the edge attributes of each edge. An observation is a tensor containing all information of the graph. This tensor format is used for compatibility with tf_agents.

The `graph(observations, graph_dims, dense=False)` method takes observations as input. It additionally needs `graph_dims`, which describes the layout of the observation, i.e. where in the tensor which information is stored. The `dense` parameter specifies the format of the returned graph representation (for further details see the documentation of the method).

OK, let's look at a small example:

In [4]:
%%capture
# create environment
params = ParameterServer(filename='data/tfa_gnn_params.json')
bp = ContinuousHighwayBlueprint(params, number_of_senarios=2, random_seed=0)
observer = GraphObserver(params=params)
env = SingleAgentRuntime(blueprint=bp, observer=observer, render=False)
    
# create agent
agent = BehaviorGraphSACAgent(environment=env, observer=observer, params=params)

The following example explains one exemplary observation and its parts (further details on the parts can be looked up, e.g., in the documentation):

In [5]:
# let's initialize the environment and call Observe() (this is happening internally in env.reset!)
observation = env.reset()

explain_observation(observation, observer.graph_dimensions)

Node_attributes(flattened matrix of original shape 4x11) (nodes x node attributes):
 [ 0.4999322  -0.8763594  -0.5        -0.45761493 -0.4999322   1.
 -0.4999322   0.9381797  -0.49511522  0.8758025  -0.4        -0.4999322
 -0.8404766  -0.5        -0.44750327 -0.4999322   1.          0.
  0.92023826 -0.5         0.8397137  -0.4         0.4999322  -1.
 -0.5        -0.40704772 -0.4999322   1.         -0.4999322   1.
 -0.49541715  0.99937826 -0.4         0.4999322  -0.7510234  -0.5
 -0.44514397 -0.4999322   1.         -0.4999322   0.8755117  -0.49476564
  0.75053436 -0.4       ]
Adjacency matrix(flattened matrix of original shape 4x4) (nodes x nodes):
 [0. 1. 1. 1. 1. 0. 1. 1. 1. 1. 0. 1. 1. 1. 1. 0.]
Edge_attributes(flattened matrix of original shape 16x4) (number of edges x edge attributes):
 [ 0.          0.          0.          0.          0.9998644  -0.03588281
 -0.01011166  0.          0.          0.12364063 -0.05056721  0.
  0.         -0.12533593 -0.01247097  0.         -0.9998644 

And now let us take a look at the graph() method:

In [6]:
observer.graph(observation, observer.graph_dimensions)

InvalidArgumentError: Index out of range using input dim 1; input has only 1 dims [Op:StridedSlice] name: strided_slice/
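
The error message says that the input has only one dimension, which suggests that `graph()` slices along a batch axis that a single observation returned by `env.reset()` does not have. The NumPy sketch below is not the BARK-ML implementation. It only mirrors the layout printed above (node attributes, then adjacency matrix, then edge attributes) with a hypothetical `split_observation` helper and shows that adding a leading batch axis makes the slicing well-defined. Whether `observer.graph` accepts `tf.expand_dims(observation, 0)` in the same way is an assumption that still needs to be verified.

```python
import numpy as np


def split_observation(observations, n_nodes, n_node_features, n_edge_features):
    """Hypothetical helper: split flat, batched observations into graph parts.

    Assumed layout per observation (as in the printout above):
    [node attributes | adjacency matrix | edge attributes]
    """
    observations = np.asarray(observations)
    if observations.ndim != 2:
        raise ValueError("expected shape (batch_size, observation_size)")

    batch_size = observations.shape[0]
    n_x = n_nodes * n_node_features
    n_a = n_nodes * n_nodes

    X = observations[:, :n_x].reshape(batch_size, n_nodes, n_node_features)
    A = observations[:, n_x:n_x + n_a].reshape(batch_size, n_nodes, n_nodes)
    E = observations[:, n_x + n_a:].reshape(
        batch_size, n_nodes, n_nodes, n_edge_features)
    return X, A, E


# Dimensions from the example: 4 nodes, 11 node features, 4 edge features.
n_nodes, n_node_features, n_edge_features = 4, 11, 4
obs_size = n_nodes * n_node_features + n_nodes**2 + n_nodes**2 * n_edge_features

single_observation = np.zeros(obs_size, dtype=np.float32)  # one dimension only
batched = single_observation[np.newaxis, :]                # leading batch axis

X, A, E = split_observation(batched, n_nodes, n_node_features, n_edge_features)
print(X.shape, A.shape, E.shape)
```

Passing `single_observation` directly raises a `ValueError` in this sketch, which is the analogue of the out-of-range index reported above.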

## Chapter 3: Graph Neural Networks
Before diving into how we apply graph neural networks to our problem, let's have a **very brief** overview of the idea behind them.  
Most importantly, they operate on graph-structured data, i.e. data consisting of 
- **Nodes:** feature vectors (node embeddings) of some data entities (and optionally a label), in our case each vehicle is a node
- **Edges:** specified links between nodes
- **Edge features:** optionally, each link between nodes can have its own feature vector

In the section about the `GraphObserver` above, we've already seen what this graph can look like in our scenario. Let's take a step back and use a simplified visualization where the green node represents the ego vehicle and the remaining nodes are other vehicles in its vicinity on the road.

![Schematic view of a GNN](images/simple_gnn.png)

The ego node is connected to both other nodes (it "sees" the other nodes) which in turn do not see each other.

Now, the nodes send messages (their current embeddings) along all outgoing links (here, all links are bidirectional), propagated through a neural network. From now on, we refer to this neural network as the _message passing layer(s)_.

> **NOTE**  
All edges share the same neural network, instead of each edge having its own weights.

Each node aggregates all incoming messages using an aggregation function, like summing or averaging. The result is then processed by another neural network, e.g. a recurrent unit, which computes the new embedding of the node.
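
The following NumPy sketch runs one such message passing step on the three-node graph described above. It is a minimal illustration under simplifying assumptions: both networks are reduced to a single dense layer with random weights, the aggregation is a sum, and the update is a dense layer rather than a recurrent unit. All names are chosen for this example only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Three nodes with two features each; node 0 is the ego vehicle.
node_embeddings = np.array([[1.0, 0.0],
                            [0.0, 1.0],
                            [1.0, 1.0]])

# The ego node is linked to both other nodes, which do not see each other.
# adjacency[i, j] == 1 means node j sends a message to node i.
adjacency = np.array([[0, 1, 1],
                      [1, 0, 0],
                      [1, 0, 0]], dtype=float)

# One weight matrix shared by all edges (the message passing layer) and
# one for the node update; a single dense layer each, for illustration only.
W_message = rng.normal(size=(2, 4))
W_update = rng.normal(size=(4, 4))


def relu(x):
    return np.maximum(x, 0.0)


def message_passing_step(h, A):
    messages = relu(h @ W_message)      # message each node sends
    aggregated = A @ messages           # sum of incoming messages per node
    return relu(aggregated @ W_update)  # new node embeddings


new_embeddings = message_passing_step(node_embeddings, adjacency)
print(new_embeddings.shape)  # (3, 4)
```

Because `W_message` is applied to every node's embedding in the same way, all edges share the same weights, as stated in the note above.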


In our project, we have integrated two different libraries that offer GNN implementations:
1. [tf2_gnn](https://github.com/microsoft/tf2-gnn): the library that was initially planned to be used in the project
2. [Spektral](https://graphneural.network/#installation): a library that supports edge features, which `tf2_gnn` does not

## Chapter 4: The `GNNWrapper` class

As an abstraction over the specific implementation of the graph neural network, we implemented a wrapper class called `GNNWrapper`. Its primary function is to act like a GNN, so its only interface is the `call` function, which accepts a batch of observations (array representations of graphs) and returns a batch of updated node embeddings for each graph.

In order to support `tf2_gnn` and `Spektral`, we have two distinct call implementations, one for each library. The `GNNWrapper` class decides which one to call based on the arguments given in the initialization.

Both functions, however, work almost the same way:
1. Convert the given observations into nodes, edges and, when using `Spektral`, edge features.
2. Call the respective library with the converted graph representation.

When specifying `Spektral` as the GNN library, the call function looks like this:

In [7]:
from spektral.layers import EdgeConditionedConv, GlobalAttnSumPool
from tensorflow.keras.layers import Dense
from bark_ml.observers.graph_observer import GraphObserver
from docs.report.helper_functions import get_sample_observations, graph_dims

def call_spektral_demo(observations):
    # define the layers of the GNN (normally, this happens upon initialization)
    
    # this defines an edge-conditioned convolution as the message passing layer
    # the `kernel_network` argument defines the layers of the edge neural network
    edge_convolution = EdgeConditionedConv(channels=16, kernel_network=[128], activation="relu")

    def call_spektral(observations, training=False):
        # convert the observations into
        # old_embeddings: tensor containing the node features (embeddings)
        # A: binary adjacency matrix specifying edges in the graph
        # E: tensor containing edge features
        old_embeddings, A, E = GraphObserver.graph(observations, graph_dims)

        # pass the inputs through an edge conditioned convolution
        # layer and receive new node embeddings
        new_embeddings = edge_convolution([old_embeddings, A, E])

        # output the final transformed node embeddings
        return old_embeddings, new_embeddings
    
    old_embeddings, new_embeddings = call_spektral(observations)
    
    print("Here's how the embeddings of the ego agent have changed:\n")
    print(f'old embeddings of shape {old_embeddings[0, 0].shape}: \n{old_embeddings[0,0,:].numpy()}\n')
    print(f'new embeddings of shape {new_embeddings[0, 0].shape}: \n{new_embeddings[0,0,:].numpy()}')

# call the function with sample observations
call_spektral_demo(get_sample_observations())

Here's how the embeddings of the ego agent have changed:

old embeddings of shape (5,): 
[0.75324947 0.21065447 0.5193448  0.56075853 0.6621831 ]

new embeddings of shape (16,): 
[0.         0.         0.         0.         0.         0.
 0.2384678  0.         0.19256851 1.040382   0.18795615 0.
 0.87221414 0.         0.         0.26274508]


In comparison, when `tf2_gnn` is specified, the implementation looks like this:

In [8]:
from tf2_gnn.layers import GNN, GNNInput
import tensorflow as tf

def call_tf2_gnn_demo(observations):
    # the number and types of layers in the GNN are all encoded
    # in the parameters dictionary, let's stick to the default for now
    gnn_params = GNN.get_default_hyperparameters()

    # uncomment the following two lines to have a look at them
    # print(f'GNN parameters:')
    # pp.pprint(gnn_params)

    # initialize a GNN instance which acts as a keras layer
    gnn = GNN(gnn_params)

    def call_tf2_gnn(observations, training=False):
        batch_size = tf.constant(observations.shape[0])

        # convert the observations into
        # old_embeddings: tensor containing the node features
        # A: dense adjacency list in the format [[0, 1], [2, 4]]
        #    specifying source and target node ids of an edge
        # node_to_graph_map: a tensor that assigns each node in old_embeddings to a graph
        old_embeddings, A, node_to_graph_map = GraphObserver.graph(
          observations,
          graph_dims=graph_dims, 
          dense=True)
        
        # build the struct that tf2_gnn expects as input
        gnn_input = GNNInput(
          node_features=old_embeddings,
          adjacency_lists=(A,),
          node_to_graph_map=node_to_graph_map,
          num_graphs=batch_size,
        )

        new_embeddings = gnn(gnn_input, training=training)
        
        # only for demo purposes
        old_embeddings = tf.reshape(old_embeddings, [batch_size, graph_dims[0], -1])
        new_embeddings = tf.reshape(new_embeddings, [batch_size, graph_dims[0], -1])
        
        return old_embeddings, new_embeddings
    
    old_embeddings, new_embeddings = call_tf2_gnn(observations)
    
    print("\nHere's how the embeddings of the ego agent have changed:\n")
    print(f'old embeddings of shape {old_embeddings[0, 0].shape}: \n{old_embeddings[0,0,:].numpy()}\n')
    print(f'new embeddings of shape {new_embeddings[0, 0].shape}: \n{new_embeddings[0,0,:].numpy()}')

# call the function with sample observations
call_tf2_gnn_demo(get_sample_observations())


Here's how the embeddings of the ego agent have changed:

old embeddings of shape (5,): 
[0.11097512 0.3341901  0.30166426 0.2869975  0.09270446]

new embeddings of shape (16,): 
[0.00809294 0.01876954 0.01764303 0.01134065 0.00685283 0.05259295
 0.         0.         0.         0.06448102 0.02972951 0.
 0.00982654 0.09037469 0.         0.01321555]
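
The main difference between the two code paths is the edge format: the Spektral path works on a binary adjacency matrix per graph, whereas the tf2_gnn path works on a list of `[source, target]` pairs over all nodes of the batch plus a node-to-graph map. The NumPy sketch below shows one way such a conversion could look. `to_edge_list` is a hypothetical helper, and the reading of entry `[i, j]` as an edge from node `i` to node `j` is an assumption; this is not the code used by `GraphObserver.graph`.

```python
import numpy as np


def to_edge_list(adjacency_batch):
    """Hypothetical helper: dense adjacency matrices -> flat edge list.

    adjacency_batch has shape (batch_size, n_nodes, n_nodes), where a
    nonzero entry [b, i, j] is read as an edge from node i to node j.
    Node ids are offset per graph so they are unique across the batch.
    """
    batch_size, n_nodes, _ = adjacency_batch.shape
    graph_idx, source, target = np.nonzero(adjacency_batch)
    offset = graph_idx * n_nodes
    edges = np.stack([source + offset, target + offset], axis=1)
    node_to_graph_map = np.repeat(np.arange(batch_size), n_nodes)
    return edges, node_to_graph_map


# Two graphs with three nodes each.
A = np.array([
    [[0, 1, 1], [1, 0, 0], [1, 0, 0]],   # ego linked to both others
    [[0, 1, 0], [1, 0, 0], [0, 0, 0]],   # only ego and one other vehicle
])

edges, node_to_graph_map = to_edge_list(A)
print(edges.tolist())
print(node_to_graph_map.tolist())
```

The nodes of the second graph receive the ids 3 to 5, so edges from different graphs in the batch never share node ids.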


Having the GNN functionality nicely abstracted behind this wrapper, we can now easily integrate it into the Soft-Actor-Critic framework.

## Chapter 5: The Soft-Actor-Critic Algorithm with Graph Neural Networks

Next, let's examine the integrated system.

We want to exploit the graph-like structure of traffic scenarios and have already encoded the state of the world as a graph. Now, we want to apply graph neural networks to the SAC algorithm. 

The resulting actor and critic networks are quite similar in structure. Here's how they work and what they compute.

### The Actor Network

Implemented in the class `GNNActorNetwork`.


**Input**: a batch of observations of shape _(batch_size, observation_size)_  
**Output**: a batch of normal distributions over the action space, from which the policy samples the actions performed by the agent

![Actor Network Architecture](images/actor_architecture.png)

**1. GNN**  
The observations are directly fed into the graph neural network (a `GNNWrapper` instance). It converts the observations into graphs and computes new node embeddings for each graph by means of message passing and aggregation. Optionally, the new node embeddings are propagated through a dense layer before being returned.

> **NOTE**  
From here on, we're only interested in the embeddings of the ego agent. Hence, instead of feeding the whole graph representation into the encoding network, we extract the embeddings of the first node of each graph, which represents the ego agent.

**2. Encoding Network**  
In the encoding network, the node embeddings of the ego agent are now passed through a series of dense layers. Depending on the parameters passed into the actor, we can also add convolutions, dropout and other types of layers here.

**3. Projection Network**  
Finally, the projection network receives the hidden representations after the encoding network and computes a normal distribution over the action space for each observation contained in the batch, modeled by a mean and a standard deviation.

Simplified for brevity, the implementation of the actor's `call` function looks as follows:
```python
def call(self, observations, training=False):
    # get the updated node embeddings
    output = self._gnn(observations, training=training)

    # extract the ego state (the first node embedding vector of each batch element)
    output = output[:, 0]
    
    # pass the ego agent's node embeddings through the encoder
    output = self._encoder(output, training=training)
    
    # compute a normal distribution
    output = self._projection_net(output, training=training)

    return output
```
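
To illustrate what "a normal distribution over the action space" means in step 3, the NumPy sketch below maps a batch of hidden vectors to a mean and a standard deviation per action dimension and draws one sample each. The weight shapes, the two-dimensional action, and the use of `exp` on a log-standard-deviation head are assumptions for illustration; the actual projection network may be parameterized differently.

```python
import numpy as np

rng = np.random.default_rng(0)

batch_size, hidden_dim, action_dim = 3, 8, 2

# Hidden representations as they might come out of the encoding network.
hidden = rng.normal(size=(batch_size, hidden_dim))

# Two linear heads: one for the mean, one for the log standard deviation.
W_mean = rng.normal(scale=0.1, size=(hidden_dim, action_dim))
W_log_std = rng.normal(scale=0.1, size=(hidden_dim, action_dim))

mean = hidden @ W_mean
std = np.exp(hidden @ W_log_std)   # exp keeps the standard deviation positive

# One action per observation, sampled from N(mean, std^2).
actions = mean + std * rng.normal(size=mean.shape)

print(mean.shape, std.shape, actions.shape)
```

Each row of `mean` and `std` describes the distribution for one observation in the batch, and each row of `actions` is one sample from it.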

### The Critic Network

Implemented in the class `GNNCriticNetwork`.

**Input**: a batch of observation-action pairs, i.e. `[obs, action]` with shapes _(batch_size, observation_size)_ and _(batch_size, 2)_  
**Output**: a scalar value assigned to each observation-action pair


The major difference compared to the actor network is that in the critic, we have two parallel pipelines for the observations and their corresponding actions.

![Critic Network Architecture](images/critic_architecture.png)

**1. Actions**  
The actions are simply passed into an action encoding network that works similarly to the encoding network of the actor network, i.e. a series of dense layers with optional convolutions, dropout layers, etc.

**2. Observations**  
The observations are processed in the exact same way as in the actor network. We compute new graph representations in the GNN, extract the ego node embeddings and pass them through an encoding network.

**3. Joining Actions and Observations**  
After receiving the outputs from the action and observation encoding networks, we concatenate the observation-action pair of each element in the batch to one feature vector.  
Finally, we pass this concatenated state through a fully connected joint network which outputs a scalar value for each observation-action pair.

Again, a simplified version of the implementation looks like this:
```python
def call(self, inputs, network_state=(), training=False):
    observations, actions = inputs
     
    # get the updated node embeddings
    node_embeddings = self._gnn(observations, training=training)
    
    # extract the ego state (the first node embedding vector of each batch element)
    node_embeddings = node_embeddings[:, 0]
    
    # pass the ego node embeddings through the observation encoder
    node_embeddings = self._observation_encoder(node_embeddings, training=training)
    
    # do the same for the actions with a separate action encoder
    actions = self._action_encoder(actions, training=training)
    
    # concatenate observations and actions into one vector
    joint = tf.concat([node_embeddings, actions], 1)
    
    # compute a scalar output value
    output = self._joint_net(joint, training=training)

    return output, network_state
```

## Chapter 6: Putting it all together and setting up an example

Now, let's set up an SAC-agent using the graph neural networks described above to be used in BARK-ML.

We start out with the default parameter set as defined in `tfa_gnn_params.json` and make some optional changes afterwards.

In [9]:
params = ParameterServer(filename='data/tfa_gnn_params.json')

First, set up the GNN-related parameters. We use the same GNN configuration in the actor and critic.

In [10]:
# use a spektral GNN
params['ML']['BehaviorGraphSACAgent']['GNN']['Library'] = 'spektral'
    
# use two message passing layers with 80 channels of node embeddings each
params["ML"]["BehaviorGraphSACAgent"]["GNN"]["NumMpLayers"] = 2
params["ML"]["BehaviorGraphSACAgent"]["GNN"]["MpLayersHiddenDim"] = 80
    
# use two fully connected layers in the edge feature mlp of each message passing layer
params['ML']['BehaviorGraphSACAgent']['GNN']['EdgeFcLayerParams'] = [128, 64]

Next, configure the layers that make up the encoding networks in the actor and critic.

In [11]:
params["ML"]["BehaviorGraphSACAgent"]["CriticJointFcLayerParams"] = [128, 128]
params["ML"]["BehaviorGraphSACAgent"]["CriticObservationFcLayerParams"] = [128, 128]
params["ML"]["BehaviorGraphSACAgent"]["ActorFcLayerParams"] = [256, 128]

Finally, we configure the `GraphObserver`.
Here we specify that it should always observe at most 4 agents simultaneously, i.e. the ego agent and its three nearest agents.

In [12]:
params["ML"]["GraphObserver"]["AgentLimit"] = 4

We can also specify which features the graph observer should use in the node embeddings and the edges. This is useful since not all environments contain the same information and thus, some features might not be possible to compute. To get a list of all available features, execute the following cell.

In [13]:
print(f"Available node features:")
for key, value in GraphObserver.available_node_attributes(with_descriptions=True).items():
    print(f"'{key}': {value}")

print(f"\nAvailable edge features:")
for key, value in GraphObserver.available_edge_attributes(with_descriptions=True).items():
    print(f"'{key}': {value}")

Available node features:
'x': The x-components of the agent's position.
'y': The y-components of the agent's position.
'theta': The current heading angle of tha agent.
'vel': The current velocity of the agent.
'goal_x': The x-component of the goal's position.
'goal_y': The y-component of the goal's position.
'goal_dx': The difference in the x-component of the agent's and the goal's position.
'goal_dy': The difference in the y-component of the agent's and the goal's position.
'goal_theta': The goal heading angle.
'goal_d': The euclidian distance of the agent to the goal.
'goal_vel': The goal velocity.

Available edge features:
'dx': The difference in the x-position of the two agents.
'dy': The difference in the y-position of the two agents.
'dvel': The difference in the velocity of the two agents.
'dtheta': The difference in the heading angle of the two agents.


In our example, we train an agent on the `ContinuousMergingBlueprint`, where the goal definition does not contain velocity information, so we cannot use the `goal_vel` node feature. We therefore configure the `GraphObserver` to use all available node features except `goal_vel`.

For the edges, we want to use all available features, so we don't specify anything. The `GraphObserver` defaults to using all features if nothing is explicitly configured.

In [14]:
enabled_node_features = GraphObserver.available_node_attributes()[:-1]
params["ML"]["GraphObserver"]["EnabledNodeFeatures"] = enabled_node_features
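
The slice `[:-1]` relies on `goal_vel` being the last entry of the list. A variant that filters by name does not depend on the ordering. In the sketch below, the feature list is copied from the output shown earlier so that the snippet runs without BARK-ML installed; the variable names are chosen for this example only.

```python
# Node features as printed by GraphObserver.available_node_attributes() above.
available_node_features = ['x', 'y', 'theta', 'vel', 'goal_x', 'goal_y',
                           'goal_dx', 'goal_dy', 'goal_theta', 'goal_d',
                           'goal_vel']

# Features the current environment cannot provide.
excluded = {'goal_vel'}

enabled_node_features = [f for f in available_node_features if f not in excluded]

# Same result as slicing off the last element, but independent of the order.
assert enabled_node_features == available_node_features[:-1]
print(enabled_node_features)
```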

In case you feel like experimenting, expand the following dropdown to get an overview of the most important parameters you can tweak.
<details>
<summary><b>List of the most important parameters</b></summary>
<br>
  <b>Description:</b> Specifies the maximum number of agents that are included in an observation. (int)<br>
  <b>Path:</b> ['ML']['GraphObserver']['AgentLimit'] <br>
  <br>
  <b>Description:</b> Specifies whether each node in the graph will have an edge pointing to itself. (Bool)<br>
  <b>Path:</b> ['ML']['GraphObserver']['SelfLoops'] <br>
  <br>
  <b>Description:</b> Specifies the features that the GraphObserver will include in the node embeddings. [str]<br>
  <b>Path:</b> ['ML']['GraphObserver']['EnabledNodeFeatures'] <br>
  <br>
  <b>Description:</b> Specifies the features that the GraphObserver will include in the edges. [str]<br>
  <b>Path:</b> ['ML']['GraphObserver']['EnabledEdgeFeatures'] <br>
  <br>
  <b>Description:</b> Specifies the fully connected layers (number and sizes) of the actor encoding network. ([int]) <br>
  <b>Path:</b> ['ML']['BehaviorGraphSACAgent']['ActorFcLayerParams'] <br>
  <br>
  <b>Description:</b> Specifies the fully connected layers (number and sizes) of the critic action encoding network. ([int]) <br>
  <b>Path:</b> ['ML']['BehaviorGraphSACAgent']['CriticActionFcLayerParams'] <br>
  <br>
  <b>Description:</b> Specifies the fully connected layers (number and sizes) of the critic observation encoding network. ([int]) <br>
  <b>Path:</b> ['ML']['BehaviorGraphSACAgent']['CriticObservationFcLayerParams'] <br>
  <br>
  <b>Description:</b> Specifies the fully connected layers (number and sizes) of the critic joint network. ([int]) <br>
  <b>Path:</b> ['ML']['BehaviorGraphSACAgent']['CriticJointFcLayerParams'] <br>
  <br>
  <b>Description:</b> Specifies the number of message passing layers in the GNN. (int) <br>
  <b>Path:</b> ['ML']['BehaviorGraphSACAgent']['GNN']['NumMpLayers'] <br>
  <br>
  <b>Description:</b> Specifies the number of units in the message passing layers in the GNN. (int) <br>
  <b>Path:</b> ['ML']['BehaviorGraphSACAgent']['GNN']['MpLayersHiddenDim'] <br>
  <br>
  <b>Description:</b> Specifies which library to use as the GNN implementation, either "tf2_gnn" or "spektral". <br>
  <b>Path:</b> ['ML']['BehaviorGraphSACAgent']['GNN']['Library'] <br>
  
  <h3>The following parameters only apply to TF2-GNN.</h3>

  <br>
  <b>Description:</b> The identifier of the message passing class to be used, here: a relational gated convolution network. (str)
      <br><i>NOTE: when using the 'ggnn' message passing layer, 'MpLayersHiddenDim' must match the number of node features!</i> <br>
  <b>Path:</b> ['ML']['BehaviorGraphSACAgent']['GNN.message_calculation_class'] <br>
  <br>
  <b>Description:</b> The identifier of the global exchange mode to be used, here: a gated recurrent unit. (str) <br>
  <b>Path:</b> ['ML']['BehaviorGraphSACAgent']['GNN']['global_exchange_mode'] <br>
  <br>
  <b>Description:</b> Specifies after how many message passing layers a dense layer is inserted. (int) <br>
  <b>Path:</b> ['ML']['BehaviorGraphSACAgent']['GNN']['dense_every_num_layers'] <br>
  <br>
  <b>Description:</b> Specifies after how many message passing layers a global exchange layer is inserted. (int) <br>
  <b>Path:</b> ['ML']['BehaviorGraphSACAgent']['GNN.global_exchange_every_num_layers'] <br>
  
  <h3>The following parameters only apply to Spektral.</h3>

  <b>Description:</b> Specifies the number of channels in the edge conditioned convolution layer. (int) <br>
  <b>Path:</b> ['ML']['BehaviorGraphSACAgent']['GNN']['MPChannels'] <br>
  
  <b>Description:</b> Specifies the fully connected layers (number and sizes) in the edge network. ([int]) <br>
  <b>Path:</b> ['ML']['BehaviorGraphSACAgent']['GNN']['EdgeFcLayerParams'] <br>
  
  <b>Description:</b> Specifies the activation function of the message passing layer. (str) <br>
  <b>Path:</b> ['ML']['BehaviorGraphSACAgent']['GNN']['MPLayerActivation'] <br>
</details>

Finally, we configure the BARK-ML environment and our agent.  
**Pro-tip: run the following cell twice to get a cleaner output with fewer log messages.**

In [15]:
%%capture
from docs.report.helper_functions import prepare_agent, summarize_agent
        
# create environment
bp = ContinuousMergingBlueprint(params, number_of_senarios=2500, random_seed=0)
observer = GraphObserver(params=params)
env = SingleAgentRuntime(blueprint=bp, observer=observer, render=False)
    
# create agent
agent = BehaviorGraphSACAgent(environment=env, observer=observer, params=params)

In [16]:
# only for demo purposes
prepare_agent(agent, params, env)
summarize_agent(agent)


AGENT SUMMARY

Network                        Parameters
ActorNetwork...................... 547.300
CriticNetwork..................... 553.441
CriticNetwork2.................... 553.441
TargetCriticNetwork1.............. 553.441
TargetCriticNetwork2.............. 553.441
------------------------------------------
Total parameters                 2.761.064


## Chapter 7: Evaluate capabilities of actor (with supervised learning)
In this part we briefly introduce a supervised learning setting that helps to quickly debug different actor implementations. We evaluate whether the actor network is capable of overfitting a small dataset by comparing it to a RandomActor and a ConstantActor. The RandomActor outputs a random label within prespecified bounds. The ConstantActor always outputs the mean label of the training dataset. This part essentially walks through what happens in `py_gnn_actor_tests`.

Additionally, the learning performance is compared to that of the standard SAC actor using TensorBoard.

This section starts with the definition of some parameters:

In [17]:
# Filenames for default parameter files
filename_tf2_gnn = "../../examples/example_params/tfa_sac_gnn_tf2_gnn_default.json"
filename_spektral = "../../examples/example_params/tfa_sac_gnn_spektral_default.json"
filename_normal = "../../examples/example_params/tfa_params.json"


params_tf2_gnn = ParameterServer(filename=filename_tf2_gnn)
params_spektral = ParameterServer(filename=filename_spektral)
params_normal = ParameterServer(filename=filename_normal)

# Some more parameters
num_scenarios = 3
log_dir = "supervised/summary"
data_path = "supervised/data"

Now, we can load the different actors for benchmarking and fetch or load a small dataset:

In [18]:
%%capture
# Get observer and tf2_gnn_actor
_, tf2_gnn_actor = configurable_setup(params_tf2_gnn, num_scenarios=num_scenarios);
# Get actor based on spektral
observer, spektral_actor = configurable_setup(params_spektral, num_scenarios=num_scenarios);
# Get normal SAC actor
_, normal_sac_actor = configurable_setup(params_normal, num_scenarios=num_scenarios, graph_sac=False);

# construct dataset
dataset = SupervisedData(observer, params_tf2_gnn, batch_size=32, train_split=0.8,
                         data_path=data_path, num_scenarios=num_scenarios);

# Get constant actor (outputs constant mean labels of train_dataset)
constant_actor = ConstantActorNet(dataset=dataset._train_dataset)
# Get random actor (outputs random labels with uniform distribution within bounds)
random_actor = RandomActorNet(low=[0, -0.4], high=[0.1, 0.4])

actors = {"tf2_gnn_actor":{"actor":tf2_gnn_actor},
          "spektral_actor":{"actor":spektral_actor},
          "normal_sac_actor":{"actor":normal_sac_actor},
          "random_actor":{"actor":random_actor},
          "constant_actor":{"actor":constant_actor}}

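
To make the two baselines concrete, here is a NumPy sketch of what a constant and a random actor compute on toy data. The bounds are the ones passed to `RandomActorNet` above; the toy labels, the function names, and the use of a mean squared error are assumptions for illustration and need not match `ConstantActorNet`, `RandomActorNet`, or the loss used by `benchmark_actor`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy labels: two-dimensional actions within the bounds used above.
low, high = np.array([0.0, -0.4]), np.array([0.1, 0.4])
train_labels = rng.uniform(low, high, size=(100, 2))
test_labels = rng.uniform(low, high, size=(20, 2))


def constant_actor(n, labels=train_labels):
    """Always predicts the mean label of the training set."""
    return np.tile(labels.mean(axis=0), (n, 1))


def random_actor(n):
    """Predicts labels drawn uniformly within [low, high]."""
    return rng.uniform(low, high, size=(n, 2))


def mse(predictions, labels):
    return float(np.mean((predictions - labels) ** 2))


print("constant actor MSE:", mse(constant_actor(len(test_labels)), test_labels))
print("random actor MSE:  ", mse(random_actor(len(test_labels)), test_labels))
```

A trainable actor that is able to overfit the small dataset should reach a training loss well below both baselines.
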
#### Now, the magic starts:
Each actor is trained for a number of epochs (the RandomActor and ConstantActor are not trainable, see their definitions above). The training results can be examined in TensorBoard below:

In [22]:
# Delete all old logs if some exist
if os.path.exists(log_dir):
    for log in os.listdir(log_dir):
        shutil.rmtree(os.path.join(log_dir, log))
    old_logs = os.listdir(log_dir)
else:
    old_logs = list()
    
# Run benchmarking
for actor_name in actors:
    time.sleep(1)
    loss = benchmark_actor(actors[actor_name]["actor"], dataset, epochs=100, log_dir=log_dir)
    actors[actor_name]["loss"] = loss
    
    # Name log clearly with actor_name
    # Select correct log
    new_logs = os.listdir(log_dir)
    log = list(set(new_logs) - set(old_logs))[0]
    # Rename log with actor_name
    old_path = os.path.join(log_dir, log)
    new_path = os.path.join(log_dir, actor_name)
    os.rename(old_path, new_path)
    old_logs = os.listdir(log_dir)

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


In [23]:
%load_ext tensorboard
%tensorboard --logdir supervised/summary

Reusing TensorBoard on port 6006 (pid 17109), started 2:17:45 ago. (Use '!kill 17109' to kill it.)

## Chapter 8: Results and Evaluation
- introduce supervised setting
- benchmark GNN-SAC vs SAC, randomActor and ConstantActor

In [24]:
#to do 

## Chapter 9: Summary

In [25]:
#to do

# Appendix: Commands

You can start a training run from the command line with `bazel run //examples:tfa_gnn -- --mode=train`.