# Task V: Quantum Graph Neural Network (QGNN) 

In Task II you already worked with a classical GNN.
- Describe a possible QGNN circuit that takes advantage of the graph representation of the data.
- Implement and draw the circuit.

For this task, we use the [Jraph](https://github.com/deepmind/jraph) graph neural network library by DeepMind, PennyLane for the quantum circuits, and Google's Flax library. The Flax repository has an [example](https://github.com/google/flax/tree/main/examples/ogbg_molpcba) that shows how to use Jraph, and we build on it.

## Custom TFDS dataset

A graph representation of the jet data was first created as a custom [TFDS dataset](https://www.tensorflow.org/datasets/add_dataset). The custom dataset files are in the `jet_dataset` directory. This way, we don't have to worry about loading and batching the graph data.
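The conversion itself lives in `jet_dataset`, and this write-up does not say how edges are defined. Purely as an illustration, the sketch below assumes a fully connected graph (every ordered pair of distinct particles is an edge, with no edge features) and packs one jet into arrays whose field names mirror `jraph.GraphsTuple`. The `Graph` class and `jet_to_graph` helper are hypothetical, and plain NumPy is used so the sketch runs without Jraph installed.

```python
import numpy as np
from typing import NamedTuple, Optional


class Graph(NamedTuple):
    """Stand-in with the same field names as jraph.GraphsTuple."""
    nodes: np.ndarray            # (N, 4): pt, rapidity, azimuthal angle, pdgid
    edges: Optional[np.ndarray]  # None: the jet graphs have no edge attributes
    receivers: np.ndarray        # (E,) index of the node each edge points to
    senders: np.ndarray          # (E,) index of the node each edge starts from
    globals: Optional[np.ndarray]
    n_node: np.ndarray           # (1,) number of nodes in this graph
    n_edge: np.ndarray           # (1,) number of edges in this graph


def jet_to_graph(particles: np.ndarray) -> Graph:
    """Hypothetical: fully connected graph without self-loops for one jet."""
    n = particles.shape[0]
    # Every (i, j) with i != j becomes a directed edge i -> j.
    senders, receivers = np.nonzero(~np.eye(n, dtype=bool))
    return Graph(
        nodes=particles.astype(np.float32),
        edges=None,
        receivers=receivers,
        senders=senders,
        globals=None,
        n_node=np.array([n]),
        n_edge=np.array([len(senders)]),
    )


jet = np.array([
    [120.5, 0.10, 1.20, 211],   # illustrative values, not real data
    [45.2, 0.05, 1.15, -211],
    [10.3, 0.20, 1.30, 22],
])
graph = jet_to_graph(jet)
print(graph.n_node, graph.n_edge)  # [3] [6]
```

With Jraph installed, the same arrays can be passed to `jraph.GraphsTuple(...)`, and several graphs can be combined into one batch with `jraph.batch`, which offsets `senders` and `receivers` automatically.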

## Hybrid Quantum-Classical Graph Neural Network 

### Input network
First, the input data of size $N \times 4$, where $N$ is the number of nodes and the $4$ features are pt, rapidity, azimuthal angle, and pdgid, is embedded in a quantum circuit. We keep the circuit simple and use the same circuit architecture throughout the GNN. The circuit is shown below:

![circuit](../images/circ.png)

Each node's feature vector $x$ is multiplied by a weight $w$, and a bias $b$ is added. The result is fed into the circuit and measured, so the embedding output has size $N \times 4$ (the same as the input).
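The circuit itself is only given as the image above, so the sketch below assumes a plausible 4-qubit ansatz for illustration: one `RY` rotation per feature (angle embedding), a ring of CNOTs, and a Pauli-Z expectation value on every qubit. It simulates the state vector with NumPy instead of PennyLane. The 4 × 4 weight matrix `W` (i.e. a dense layer rather than a scalar weight) and all function names are assumptions.

```python
import numpy as np

N_QUBITS = 4  # one qubit per feature: pt, rapidity, azimuthal angle, pdgid


def ry(theta):
    """Single-qubit RY rotation matrix (real-valued)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]])


def apply_1q(state, gate, wire):
    """Apply a 2x2 gate to one wire of a state tensor of shape (2,)*N_QUBITS."""
    state = np.tensordot(gate, state, axes=([1], [wire]))
    return np.moveaxis(state, 0, wire)


def apply_cnot(state, control, target):
    """Flip the target wire on the subspace where the control wire is |1>."""
    state = state.copy()
    idx = [slice(None)] * state.ndim
    idx[control] = 1
    sub = state[tuple(idx)]
    axis = target if target < control else target - 1  # control axis is sliced away
    state[tuple(idx)] = np.flip(sub, axis=axis).copy()
    return state


def circuit(angles):
    """Assumed ansatz: RY angle embedding, CNOT ring, <Z> on every qubit."""
    state = np.zeros((2,) * N_QUBITS)
    state[(0,) * N_QUBITS] = 1.0  # start in |0000>
    for w in range(N_QUBITS):
        state = apply_1q(state, ry(angles[w]), w)
    for w in range(N_QUBITS):
        state = apply_cnot(state, w, (w + 1) % N_QUBITS)
    probs = state ** 2  # amplitudes stay real for RY + CNOT
    return np.array([probs.take(0, axis=w).sum() - probs.take(1, axis=w).sum()
                     for w in range(N_QUBITS)])


def embed(nodes, W, b):
    """Input network sketch: affine map per node, then the circuit. (N, 4) -> (N, 4)."""
    return np.stack([circuit(x @ W + b) for x in nodes])


rng = np.random.default_rng(0)
nodes = rng.normal(size=(5, N_QUBITS))         # 5 particles, 4 features each
W = 0.1 * rng.normal(size=(N_QUBITS, N_QUBITS))
b = np.zeros(N_QUBITS)
h = embed(nodes, W, b)
print(h.shape)  # (5, 4)
```

Storing the state as a `(2, 2, 2, 2)` tensor lets each gate act on a single axis. In PennyLane, the same structure would be written inside a QNode with `qml.AngleEmbedding(angles, wires=range(4), rotation="Y")`, `qml.CNOT` gates, and `qml.expval(qml.PauliZ(w))` for each wire.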

### Update functions

We use only two update functions here, at the node level and at the global level, because we do not have any edge attributes. The update functions are neural networks; here we use hybrid quantum-classical networks.

First, a `fully connected layer` is applied, followed by a `tanh` activation function. Since the result lies in $[-1, 1]$, we feed it into the quantum circuit and measure it. The number of message-passing steps is configurable.
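A minimal NumPy sketch of one message-passing step with the hybrid update described above. Several details are assumptions, not taken from the project code: each node receives the sum of its neighbours' features, the global update reads the mean of the updated nodes, the initial global features are zeros, the graph is fully connected, and the same parameters are reused at every step for brevity. The circuit is an assumed 4-qubit ansatz (RY per qubit, CNOT ring, Pauli-Z readout) simulated with NumPy.

```python
import numpy as np


def circuit(angles):
    """Assumed ansatz: RY per qubit, CNOT ring, <Z> on every qubit."""
    n = len(angles)
    state = np.zeros((2,) * n)
    state[(0,) * n] = 1.0  # start in |00...0>
    for w, theta in enumerate(angles):
        ry = np.array([[np.cos(theta / 2), -np.sin(theta / 2)],
                       [np.sin(theta / 2), np.cos(theta / 2)]])
        state = np.moveaxis(np.tensordot(ry, state, axes=([1], [w])), 0, w)
    for ctrl in range(n):  # CNOT ring: ctrl -> (ctrl + 1) % n
        tgt = (ctrl + 1) % n
        idx = [slice(None)] * n
        idx[ctrl] = 1
        sub = state[tuple(idx)]
        state[tuple(idx)] = np.flip(sub, axis=tgt if tgt < ctrl else tgt - 1).copy()
    probs = state ** 2  # amplitudes stay real for RY + CNOT
    return np.array([probs.take(0, axis=w).sum() - probs.take(1, axis=w).sum()
                     for w in range(n)])


def hybrid_layer(x, W, b):
    """Hybrid update: fully connected layer -> tanh -> circuit, row by row."""
    angles = np.tanh(x @ W + b)  # in [-1, 1] before entering the circuit
    return np.stack([circuit(a) for a in angles])


def message_passing_step(nodes, globals_, senders, receivers, params):
    n = nodes.shape[0]
    # Assumption: a node's message is the sum of its senders' node features.
    received = np.zeros_like(nodes)
    np.add.at(received, receivers, nodes[senders])
    node_in = np.concatenate([nodes, received, np.tile(globals_, (n, 1))], axis=1)
    new_nodes = hybrid_layer(node_in, *params["node"])
    # Assumption: the global update sees the mean of the updated nodes.
    glob_in = np.concatenate([new_nodes.mean(axis=0), globals_])[None, :]
    new_globals = hybrid_layer(glob_in, *params["global"])[0]
    return new_nodes, new_globals


rng = np.random.default_rng(0)
n = 5
nodes = np.tanh(rng.normal(size=(n, 4)))  # stand-in for embedded node features
globals_ = np.zeros(4)                    # assumed initial global features
senders, receivers = np.nonzero(~np.eye(n, dtype=bool))  # assumed fully connected
params = {
    "node": (0.1 * rng.normal(size=(12, 4)), np.zeros(4)),   # 4 + 4 + 4 inputs
    "global": (0.1 * rng.normal(size=(8, 4)), np.zeros(4)),  # 4 + 4 inputs
}
num_message_passing_steps = 2  # hypothetical setting
for _ in range(num_message_passing_steps):
    nodes, globals_ = message_passing_step(nodes, globals_, senders, receivers, params)
```

Every update ends in a measurement of Pauli-Z expectation values, so node and global features stay in $[-1, 1]$ and keep the same width of 4 across steps, which is what allows the same circuit to be reused throughout the network.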

### Classification

A final `fully connected layer` maps the output down to size 1. Binary cross-entropy is used as the loss function, together with the Adam optimizer.
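The write-up does not say which features the final layer reads. The sketch below assumes it reads the per-graph global features (a common choice for graph classification) and fits that last layer on toy data with a numerically stable binary cross-entropy and a hand-written Adam step. It is an illustration of the loss and optimizer only, not the project's training code; in a JAX/Flax setup these pieces would more likely come from `optax.sigmoid_binary_cross_entropy` and `optax.adam`.

```python
import numpy as np


def bce_with_logits(logits, labels):
    """Numerically stable binary cross-entropy on raw logits, averaged over the batch."""
    return np.mean(np.maximum(logits, 0) - logits * labels
                   + np.log1p(np.exp(-np.abs(logits))))


def readout(globals_, w, c):
    """Final fully connected layer down to size 1: one logit per graph."""
    return globals_ @ w + c


class Adam:
    """Standard Adam update with bias-corrected first and second moments."""

    def __init__(self, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
        self.lr, self.b1, self.b2, self.eps = lr, b1, b2, eps
        self.m, self.v, self.t = None, None, 0

    def step(self, params, grads):
        if self.m is None:
            self.m = [np.zeros_like(p) for p in params]
            self.v = [np.zeros_like(p) for p in params]
        self.t += 1
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * g
            self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * g ** 2
            m_hat = self.m[i] / (1 - self.b1 ** self.t)
            v_hat = self.v[i] / (1 - self.b2 ** self.t)
            new_params.append(p - self.lr * m_hat / (np.sqrt(v_hat) + self.eps))
        return new_params


rng = np.random.default_rng(0)
g = rng.uniform(-1, 1, size=(64, 4))  # stand-in for per-graph global features
y = (g[:, 0] > 0).astype(float)       # toy labels, not jet data
w, c = np.zeros(4), 0.0
opt = Adam(lr=0.05)
losses = []
for _ in range(200):
    z = readout(g, w, c)
    losses.append(bce_with_logits(z, y))
    err = (1 / (1 + np.exp(-z)) - y) / len(y)  # d(mean loss) / d(logit)
    w, c = opt.step([w, c], [g.T @ err, err.sum()])
print(losses[0], losses[-1])
```

Working on logits instead of sigmoid outputs avoids taking `log(0)` for confident predictions, which is why the readout layer has no activation of its own.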

### Training the Hybrid Quantum-Classical Graph Neural Network

```
!python train.py --config=configs/default_graph_net.py
```

```
I0401 22:40:11.769851 140239509645120 train.py:292] Model GraphNet
I0401 22:40:11.772162 140239509645120 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
I0401 22:40:11.772270 140239509645120 xla_bridge.py:355] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0401 22:40:11.772315 140239509645120 xla_bridge.py:355] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0401 22:40:11.772682 140239509645120 xla_bridge.py:355] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
I0401 22:40:11.772755 140239509645120 xla_bridge.py:355] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
...
I0401 22:40:16.534985 140239509645120 train.py:234] Starting training.
I0401 22:40:33.432315 140239509645120 train.py:254] Finished training step 1.
I0401 22:40:33.643307 140239509645120 train.py:254] Finished training step 2.
I0401 22:40:33.853262 140239509645120 train.py:254] Finished training step 3.
I0401 22:40:34.107882 140239509645120 train.py:254] Finished training step 4.
I0401 22:40:34.312197 140239509645120 train.py:254] Finished training step 5.
I0401 22:40:34.511267 140239509645120 train.py:254] Finished training step 6.
I0401 22:40:34.751202 140239509645120 train.py:254] Finished training step 7.
I0401 22:40:34.992794 140239509645120 train.py:254] Finished training step 8.
I0401 22:40:35.204037 140239509645120 train.py:254] Finished training step 9.
I0401 22:40:35.407538 140239509645120 train.py:254] Finished training step 10.
I0401 22:40:38.779051 140239509645120 local.py:50] Created artifact [10] Profile of type ArtifactType.URL and value None.
...
I0401 22:49:25.913345 140230396991040 logging_writer.py:48] [1600] val_accuracy=0.7620000243186951, val_auc=0.844237, val_loss=0.49259793758392334, val_mean_average_precision=0.835882
I0401 22:49:34.171748 140239509645120 local.py:41] Setting work unit notes: 3.6 steps/s, 32.7% (1633/5000), ETA: 15m (9m : 8.6% eval)
I0401 22:49:34.172178 140230396991040 logging_writer.py:48] [1633] steps_per_sec=3.611194
I0401 22:49:34.172849 140230396991040 logging_writer.py:48] [1633] uptime=557.637057
I0401 22:49:51.384982 140230396991040 logging_writer.py:48] [1700] train_accuracy=0.7631799578666687, train_loss=0.5032984018325806
I0401 22:49:53.757784 140230396991040 logging_writer.py:48] [1700] val_accuracy=0.7641500234603882, val_auc=0.844438, val_loss=0.49152037501335144, val_mean_average_precision=0.836055
I0401 22:50:16.884170 140230396991040 logging_writer.py:48] [1800] train_accuracy=0.761698842048645, train_loss=0.505814790725708
...
I0401 22:57:35.112144 140239509645120 local.py:41] Setting work unit notes: 4.1 steps/s, 69.2% (3460/5000), ETA: 6m (17m : 9.0% eval)
I0401 22:57:35.112529 140230396991040 logging_writer.py:48] [3460] steps_per_sec=4.142264
I0401 22:57:35.113233 140230396991040 logging_writer.py:48] [3460] uptime=1038.577412
I0401 22:57:43.902540 140230396991040 logging_writer.py:48] [3500] train_accuracy=0.7654215097427368, train_loss=0.4943298101425171
I0401 22:57:46.344298 140230396991040 logging_writer.py:48] [3500] val_accuracy=0.7674499750137329, val_auc=0.844710, val_loss=0.48948872089385986, val_mean_average_precision=0.836336
I0401 22:58:07.796641 140230396991040 logging_writer.py:48] [3600] train_accuracy=0.765544056892395, train_loss=0.49886035919189453
I0401 22:58:10.505812 140230396991040 logging_writer.py:48] [3600] val_accuracy=0.7674499750137329, val_auc=0.844539, val_loss=0.48834970593452454, val_mean_average_precision=0.836222
...
```

### Results

The results were saved with TensorBoard in the `logs` directory. For convenience, the plots are shown below.

![train_loss](../images/train_loss.png)
![val_loss](../images/val_loss.png)
![train_acc](../images/train_acc.png)
![val_acc](../images/val_acc.png)
![val_auc](../images/val_auc.png)
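TensorBoard is the intended way to inspect `logs`, but the metric lines in the training output also follow a fixed `[step] key=value, ...` pattern, so the numbers can be recovered directly from a saved log. A minimal standard-library sketch; the regular expression is based only on the lines shown in this notebook and is an assumption about the log format, and `parse_metrics` is a hypothetical helper.

```python
import re

# Assumed format, taken from the lines printed during training above.
METRIC_LINE = re.compile(r"logging_writer\.py:\d+\] \[(\d+)\] (.*)$")


def parse_metrics(lines):
    """Collect {step: {metric: value}} from the logging_writer lines of a training log."""
    history = {}
    for line in lines:
        match = METRIC_LINE.search(line)
        if not match:
            continue  # skip non-metric and truncated lines
        step = int(match.group(1))
        for pair in match.group(2).split(", "):
            key, _, value = pair.partition("=")
            history.setdefault(step, {})[key] = float(value)
    return history


log = [
    "I0401 22:49:51.384982 140230396991040 logging_writer.py:48] [1700] "
    "train_accuracy=0.7631799578666687, train_loss=0.5032984018325806",
    "I0401 22:49:53.757784 140230396991040 logging_writer.py:48] [1700] "
    "val_accuracy=0.7641500234603882, val_auc=0.844438, "
    "val_loss=0.49152037501335144, val_mean_average_precision=0.836055",
]
history = parse_metrics(log)
print(history[1700]["val_auc"])  # 0.844438
```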

## Conclusion

We constructed a hybrid quantum-classical graph neural network using Jraph and PennyLane. We achieved an AUC of 0.84 on the test set without any hyperparameter tuning. This AUC is close to the result reported in [1] (0.90 AUC) and to that of the classical counterpart we implemented in Task II (0.86 AUC). Further work is needed to fine-tune the model and to try different ansätze.

## References

1. Komiske, P.T., Metodiev, E.M. & Thaler, J. Energy flow networks: deep sets for particle jets. J. High Energ. Phys. 2019, 121 (2019). https://doi.org/10.1007/JHEP01(2019)121