# Deep sets and graph networks

Notebook based on [DeepSetsAndGraphNetworks.ipynb from LMU course](https://github.com/fuenfundachtzig/LMU_DA_ML/blob/master/DeepSetsAndGraphNetworks.ipynb)

The ML models we have looked at so far make the assumption that we have a fixed-dimensional vector of input features. In reality that might not always be the case. Some examples:

* Sequences (text, audio, video)
* Point clouds (e.g. points in 3D space)
* Lists of objects (e.g. particles in a collision)
* Graphs with different numbers of nodes and different numbers of connections for each node

For sequences, one approach is recurrent neural networks (RNNs), which maintain a state that is updated as the input is processed step by step. However, these still need a defined ordering of the inputs, and they have certain disadvantages: most prominently, they struggle to model "long-range" correlations between inputs, and they are hard to parallelize because they are sequential in nature.

Another approach is to use models that apply **permutation invariant** transformations to the inputs. Both deep sets and graph networks make use of this. The [**transformers**](https://arxiv.org/abs/1706.03762) that are very popular nowadays (as of 2023) can also be viewed as graph networks in which all nodes are connected to each other.
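
To make "permutation invariant" concrete, here is a minimal NumPy sketch of the deep-sets idea: a per-element map followed by sum pooling. The map `phi`, its random weights `W_phi`, and all shapes are assumptions chosen for illustration only; they are not part of the models trained in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "set": 5 elements with 3 features each (shapes chosen arbitrarily)
x = rng.normal(size=(5, 3))

# Hypothetical per-element map phi: one linear layer followed by ReLU
W_phi = rng.normal(size=(3, 8))

def phi(elements):
    return np.maximum(elements @ W_phi, 0.0)

def set_embedding(elements):
    # Sum pooling over the element axis discards the ordering
    return phi(elements).sum(axis=0)

# Shuffle the elements and embed again
perm = rng.permutation(len(x))
emb_original = set_embedding(x)
emb_permuted = set_embedding(x[perm])

# The two embeddings should agree up to floating-point rounding
print(np.allclose(emb_original, emb_permuted))
```

Replacing the sum by a mean, min or max keeps the invariance; concatenating the elements instead would not.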



In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import RobustScaler
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, GlobalAveragePooling1D, Masking
from tensorflow.keras.callbacks import History

## Graph convolutions/Graph neural networks

Similar to convolutional networks, where we update the state of each pixel by aggregating over neighboring pixels, we can perform a *graph convolution* by aggregating over neighboring nodes in a graph:

![cnn vs gcn](figures/cnn_vs_gcn.jpg)

(figure from https://zhuanlan.zhihu.com/p/51990489)

In the "Deep sets" language, such a graph convolution corresponds to a *permutation equivariant* transformation of the set of nodes: if the aggregation is done in a permutation invariant way (e.g. sum/mean/min/max), the result does not depend on the ordering of the nodes, and relabeling the nodes merely reorders the output in the same way.
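
The equivariance can be checked numerically. The following is a small sketch with a mean-over-neighbors aggregation on a random toy graph; the function `mean_graph_conv`, the weights and the graph are made up for this illustration and are not the layer used later in the notebook.

```python
import numpy as np

rng = np.random.default_rng(1)
n_nodes, n_feat, n_out = 6, 4, 3

# Random symmetric adjacency with self-loops (toy graph, not from the dataset)
upper = np.triu(rng.integers(0, 2, size=(n_nodes, n_nodes)), k=1)
A = (upper + upper.T + np.eye(n_nodes, dtype=int)).astype(float)
H = rng.normal(size=(n_nodes, n_feat))
W = rng.normal(size=(n_feat, n_out))

def mean_graph_conv(adj, feat):
    # Mean over neighbors (a permutation invariant aggregation), then ReLU
    deg = adj.sum(axis=1, keepdims=True)
    return np.maximum((adj / deg) @ feat @ W, 0.0)

# Relabel the nodes: permute the feature rows and the adjacency rows + columns
perm = rng.permutation(n_nodes)
out = mean_graph_conv(A, H)
out_perm = mean_graph_conv(A[np.ix_(perm, perm)], H[perm])

# Equivariance: the output rows are reordered exactly like the nodes
print(np.allclose(out[perm], out_perm))
```

Pooling the node outputs afterwards (e.g. with a mean over all nodes) turns the equivariant transformation into an invariant one.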

A rather simple implementation is given by the update rule introduced in [arXiv:1609.02907](https://arxiv.org/abs/1609.02907)

$ H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)}) $

where $A$ is the *adjacency matrix*, $D$ the *degree matrix*,  $H^{(l)}$ the hidden state of layer $l$ and $W^{(l)}$ the weight matrix of the layer $l$. The tilde above $A$ and $D$ indicates that self-loops were added (all nodes are neighbors of themselves).

An equivalent formulation is

$ h_i^{(l+1)} = \sigma\left(\sum\limits_{j\in\mathcal{N}(i)}\frac{1}{c_{ij}}h^{(l)}_j W^{(l)}\right) $

where $\mathcal{N}(i)$ is the set of neighbors of node $i$ (including $i$ itself because of the self-loops) and $c_{ij} = \sqrt{N_i}\sqrt{N_j}$, with $N_i$ being the number of neighbors of node $i$.
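
As a sanity check, the two formulations can be compared on a tiny graph. This is only a sketch: the path graph, the random $H$ and $W$, and the choice of ReLU for $\sigma$ are assumptions for the example, and $N_i$ is counted on the adjacency matrix with self-loops.

```python
import numpy as np

rng = np.random.default_rng(2)

# Toy path graph 0-1-2-3 (hypothetical example)
A = np.array([
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
], dtype=float)
A_tilde = A + np.eye(4)   # add self-loops
N = A_tilde.sum(axis=1)   # neighbor counts, including the node itself
H = rng.normal(size=(4, 3))
W = rng.normal(size=(3, 2))

def relu(z):
    return np.maximum(z, 0.0)

# Matrix form: sigma(D^{-1/2} A D^{-1/2} H W)
D_inv_sqrt = np.diag(N ** -0.5)
H_matrix = relu(D_inv_sqrt @ A_tilde @ D_inv_sqrt @ H @ W)

# Node-wise form: sum over neighbors j of h_j W / c_ij
H_nodes = np.zeros_like(H_matrix)
for i in range(4):
    neighbors = np.flatnonzero(A_tilde[i])
    message = sum(H[j] @ W / np.sqrt(N[i] * N[j]) for j in neighbors)
    H_nodes[i] = relu(message)

print(np.allclose(H_matrix, H_nodes))
```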

In [None]:
def normalize_adjacency(adj):
    """
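    Calculate the outer product of the degree vector raised to the power -1/2
    and multiply it elementwise with the adjacency matrix.

    This corresponds to the D^{-1/2} A D^{-1/2} normalization suggested in
    Kipf & Welling (arXiv:1609.02907). Self-loops are not added here, so the
    adjacency matrix passed in has to contain them already if they are wanted.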
    """
    deg_diag = tf.reduce_sum(adj, axis=2)
    deg12_diag = tf.where(deg_diag > 0, deg_diag ** -0.5, 0)
    return (
        tf.matmul(
            tf.expand_dims(deg12_diag, axis=2),
            tf.expand_dims(deg12_diag, axis=1),
        )
        * adj
    )

In [None]:
class GraphConv(tf.keras.layers.Layer):
    """
    Simple graph convolution. Should be equivalent to Kipf & Welling (arXiv:1609.02907)
    """

    def __init__(self, units, activation="relu"):
        super().__init__()
        self.dense = tf.keras.layers.Dense(units)
        self.activation = tf.keras.activations.get(activation)

    def call(self, inputs):
        feat, adjacency = inputs
        return self.activation(tf.matmul(normalize_adjacency(adjacency), self.dense(feat)))

The question now is: what is the graph in our dataset? It might make sense to define the graph by connecting each constituent to a certain number of its nearest neighbors in the $\eta-\phi$ plane, the same plane that was used to define the image pixels.
We prepared adjacency matrices for 7 nearest neighbors:

In [None]:
import os

top_tagging_path = "top_tagging_with_adjacency.npz"
if not os.path.exists(top_tagging_path):
    import requests
    url = "https://cloud.physik.lmu.de/index.php/s/AtESAET6JK6DiWZ/download"
    res = requests.get(url)
    res.raise_for_status()  # fail loudly instead of saving an error page
    with open(top_tagging_path, "wb") as f:
        f.write(res.content)

In [None]:
npz_file = np.load(top_tagging_path)

In [None]:
X = npz_file["jet_4mom"]
y = npz_file["y"]
A = npz_file["adj"]
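
The adjacency matrices are loaded ready-made, and how exactly they were produced is not shown in this notebook. As a rough sketch of the idea (connect each constituent to its $k$ nearest neighbors in the $\eta-\phi$ plane), one could proceed as follows. The function name `knn_adjacency`, the directed edges, the absence of self-loops and the random toy coordinates are all assumptions of this example and may differ from the prepared file.

```python
import numpy as np

def knn_adjacency(eta, phi, k=7):
    """Directed k-nearest-neighbor adjacency in the eta-phi plane (illustrative)."""
    n = len(eta)
    d_eta = eta[:, None] - eta[None, :]
    # Wrap the azimuthal difference into [-pi, pi)
    d_phi = (phi[:, None] - phi[None, :] + np.pi) % (2 * np.pi) - np.pi
    dist = np.hypot(d_eta, d_phi)
    np.fill_diagonal(dist, np.inf)  # a node is not its own nearest neighbor
    k = min(k, n - 1)               # cannot have more neighbors than other nodes
    nearest = np.argsort(dist, axis=1)[:, :k]
    adj = np.zeros((n, n), dtype=np.float32)
    rows = np.repeat(np.arange(n), k)
    adj[rows, nearest.ravel()] = 1.0  # adj[i, j] = 1 if j is among the k nearest of i
    return adj

# Toy constituents scattered around a jet axis at (0, 0)
rng = np.random.default_rng(3)
eta = rng.normal(scale=0.4, size=20)
phi = rng.normal(scale=0.4, size=20)
adj = knn_adjacency(eta, phi, k=7)
print(adj.sum(axis=1))  # number of outgoing edges per node
```

Note that such a matrix is in general not symmetric: $j$ can be among the nearest neighbors of $i$ without the reverse being true. Symmetrizing it or adding self-loops are separate design choices.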

In [None]:
X.shape

In [None]:
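def ptetaphi(X):
    # columns of X are assumed to be (E, px, py, pz)
    px = X[..., 1]
    py = X[..., 2]
    pz = X[..., 3]
    pt = np.hypot(px, py)
    eta = np.arcsinh(pz / pt)
    # arctan2 uses the signs of both px and py, so phi covers the full range (-pi, pi]
    phi = np.arctan2(py, px)
    return np.stack([pt, eta, phi], axis=-1)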

In [None]:
def plot_graph(x, a):
    plt.figure(figsize=(12, 8))
    nconst = (~(a == 0).all(axis=-1)).sum()
    print(f"{nconst=}")
    x = x[:nconst]
    x = ptetaphi(x)
    plt.scatter(x[:, 1], x[:, 2], s=100)
    for i in range(nconst):
        for j in range(nconst):
            if a[i, j] or a[j, i]:
                plt.plot([x[i, 1], x[j, 1]], [x[i, 2], x[j, 2]], color="C0")

Let's plot a random graph (rerun the cell to see a few different ones):

In [None]:
i = np.random.randint(0, len(X))
plot_graph(X[i], A[i])

In [None]:
class JetScaler:
    def __init__(self, mask_value=-999):
        self.mask_value = mask_value
        self.scaler = RobustScaler()
    
    def fill_nan(self, X):
        "replace missing values by nan"
        X[(X == self.mask_value).all(axis=-1)] = np.nan
        
    def fit(self, X):
        X = np.array(X) # copy
        self.fill_nan(X)
        X = X.reshape(-1, X.shape[-1]) # make 2D
        self.scaler.fit(X)
        
    def transform(self, X):
        orig_shape = X.shape
        X = np.array(X).reshape(-1, X.shape[-1])
        self.fill_nan(X)
        X = self.scaler.transform(X)
        X = np.nan_to_num(X, nan=0.0) # replace missing values by 0
        return X.reshape(*orig_shape) # turn back into 3D

In [None]:
X_train, X_test, y_train, y_test, A_train, A_test = train_test_split(X, y, A)
scaler = JetScaler(mask_value=0)
scaler.fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)

In [None]:
def get_model(units=100, num_nodes=200, num_features=4):
    adjacency_input = Input(shape=(num_nodes, num_nodes), name='adjacency')
    feature_input = Input(shape=(num_nodes, num_features), name='features')

    # constituent-level transformations
    p = feature_input
    for i in range(3):
        p = Dense(units, activation="relu")(p)

    for i in range(3):
        p = GraphConv(units, activation="relu")([p, adjacency_input])

    x = GlobalAveragePooling1D()(p)

    # event-level transformations
    for i in range(3):
        x = Dense(units, activation="relu")(x)

    output = Dense(1, activation="sigmoid")(x)

    return tf.keras.models.Model(
        inputs=[adjacency_input, feature_input],
        outputs=[output]
    )
model = get_model()

In [None]:
tf.keras.utils.plot_model(model, show_shapes=True)

**`plot_model` needs extra tools (`pydot` and Graphviz); here is the resulting plot:**

![](figures/keras_GN_model.png)

In [None]:
model.compile(loss="binary_crossentropy", optimizer="Adam")

In [None]:
history = History()

In [None]:
model.fit(
    {"features": X_train, "adjacency": A_train},
    y_train,
    validation_split=0.2,
    epochs=10,
    batch_size=32,
    shuffle=True,
    callbacks=[history]
)

In [None]:
pd.DataFrame(history.history).plot()

In [None]:
scores = model.predict({"features": X_test, "adjacency": A_test})

In [None]:
from sklearn.metrics import roc_curve
fpr, tpr, thr = roc_curve(y_test, scores)

In [None]:
def plot_top_tagging_performance(fpr, tpr):
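    nonzero = fpr > 0  # skip ROC points with zero false positives (infinite rejection)
    plt.plot(tpr[nonzero], 1. / fpr[nonzero])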
    plt.ylabel("QCD jet rejection")
    plt.xlabel("Top quark jet efficiency")
    plt.yscale("log")
    plt.grid()

    print("Top quark jet selection efficiency at 10^3 QCD jet rejection: ", np.max(tpr[fpr < 0.001]))
    print("QCD jet rejection at 30% Top quark jet efficiency: ", 1. / np.min(fpr[tpr > 0.3]))


In [None]:
plot_top_tagging_performance(fpr, tpr)

Some notes:

- We made it quite hard for the neural network here by feeding in the raw 4-momentum information
- Possible improvements:
  - Go to the $\eta-\phi$ plane
  - Transform coordinates to be relative to the jet center
  - Use graph operations that depend on the distance between points instead of absolute position (e.g. [EdgeConv](https://arxiv.org/abs/1801.07829))
  - Just train longer and/or on more data (we only used 10k samples)
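
The first two improvements can be sketched for a single jet as follows. This is an illustrative preprocessing step, not something used in the training above: the function name `relative_eta_phi` is made up, the column order (E, px, py, pz) and the zero padding are assumed to match the arrays used in this notebook, and the jet axis is simply taken from the summed constituent momenta.

```python
import numpy as np

def relative_eta_phi(jet):
    """Return (pt, delta eta, delta phi) w.r.t. the jet axis for one jet."""
    jet = jet[(jet != 0).any(axis=-1)]  # drop zero-padded constituents
    px, py, pz = jet[:, 1], jet[:, 2], jet[:, 3]
    pt = np.hypot(px, py)
    eta = np.arcsinh(pz / pt)
    phi = np.arctan2(py, px)

    # Jet axis from the summed constituent momenta
    jet_px, jet_py, jet_pz = px.sum(), py.sum(), pz.sum()
    jet_eta = np.arcsinh(jet_pz / np.hypot(jet_px, jet_py))
    jet_phi = np.arctan2(jet_py, jet_px)

    d_eta = eta - jet_eta
    # Wrap the azimuthal difference into [-pi, pi)
    d_phi = (phi - jet_phi + np.pi) % (2 * np.pi) - np.pi
    return np.stack([pt, d_eta, d_phi], axis=-1)

# Toy jet: three collinear constituents plus one zero-padded row
toy_jet = np.array([
    [10.0, 6.0, 8.0, 0.0],
    [5.0, 3.0, 4.0, 0.0],
    [2.0, 1.2, 1.6, 0.0],
    [0.0, 0.0, 0.0, 0.0],
])
rel = relative_eta_phi(toy_jet)
print(rel)
```

For the collinear toy jet all constituents sit on the jet axis, so the two relative coordinates come out as (numerically) zero while the transverse momenta are kept.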

## Further possibilities

We only scratched the surface of what is possible with graph neural networks. In general, you can have arbitrary update rules that, in each step, update the features of nodes (V), edges (E) and globally aggregated features (u). Each of these three categories can receive input from any of the others:

![graph network general update rule](figures/graph-network.png)

(figure from [arXiv:1806.01261](https://arxiv.org/abs/1806.01261))
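
One pass of such a general update can be sketched in plain NumPy. This is a simplified reading of the scheme in the figure, not a reference implementation: the helper `dense` (a random linear map with tanh) stands in for learned update functions, sums are used for all aggregations, and the toy graph and feature sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(4)

def dense(n_in, n_out):
    """Hypothetical stand-in for a learned update function: linear map + tanh."""
    W = rng.normal(scale=0.5, size=(n_in, n_out))
    return lambda z: np.tanh(z @ W)

def gn_block(V, E, u, senders, receivers, phi_e, phi_v, phi_u):
    n_edges, n_nodes = len(E), len(V)
    # 1. Edge update from the edge itself, its two end nodes and the global features
    E_new = phi_e(np.concatenate(
        [E, V[senders], V[receivers], np.tile(u, (n_edges, 1))], axis=1))
    # 2. Aggregate the updated edges per receiving node (sum is permutation invariant)
    agg = np.zeros((n_nodes, E_new.shape[1]))
    np.add.at(agg, receivers, E_new)
    # 3. Node update from aggregated edges, the node itself and the global features
    V_new = phi_v(np.concatenate([agg, V, np.tile(u, (n_nodes, 1))], axis=1))
    # 4. Global update from all aggregated edges, all aggregated nodes and u
    u_new = phi_u(np.concatenate([E_new.sum(axis=0), V_new.sum(axis=0), u]))
    return V_new, E_new, u_new

# Toy graph: 3 nodes (2 features), 4 directed edges (1 feature), 2 global features
V = rng.normal(size=(3, 2))
E = rng.normal(size=(4, 1))
u = rng.normal(size=(2,))
senders = np.array([0, 1, 1, 2])
receivers = np.array([1, 0, 2, 1])

phi_e = dense(1 + 2 + 2 + 2, 4)  # edge + sender + receiver + global -> 4
phi_v = dense(4 + 2 + 2, 3)      # aggregated edges + node + global -> 3
phi_u = dense(4 + 3 + 2, 2)      # aggregated edges + aggregated nodes + global -> 2

V_new, E_new, u_new = gn_block(V, E, u, senders, receivers, phi_e, phi_v, phi_u)
print(V_new.shape, E_new.shape, u_new.shape)
```

The graph convolution used earlier in this notebook is a special case in which there are no edge or global features and the node update only sees a normalized sum over neighboring nodes.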

More info/tutorials:

http://tkipf.github.io/graph-convolutional-networks/  
https://docs.dgl.ai/tutorials/models/1_gnn/1_gcn.html  
https://docs.dgl.ai/generated/dgl.nn.pytorch.conv.GraphConv.html#

For more advanced applications of graph neural networks, have a look at specialized libraries:

[Spektral (tensorflow)](https://graphneural.network/)  
[DGL (mainly pytorch, but also tensorflow)](https://dgl.ai)  
[PyTorch Geometric](https://pytorch-geometric.readthedocs.io)

<div class="alert alert-warning">
If you actually want to use graph networks in practice, it is better to rely on these libraries than to build the layers by hand. The examples in this tutorial are meant for educational purposes!
</div>

---

This top tagging dataset is a widely used benchmark that serves to test, compare, and optimize many different ML models; see
[The Machine Learning Landscape of Top Taggers](https://arxiv.org/abs/1902.09914).

A nice plot illustrating this (from G. Kasieczka):

![top-tag-results.png](figures/top-tag-results.png)