# Graph Neural Networks

Graph neural networks are a recent development in the field of machine learning and allow our networks to exploit topological structures *between* data points rather than implicitly learning some abstract relational patterns through training. This is useful because many data that we wish to operate on has a natural and known graphical structure that will inform relational patterns in some straightforward way: chemical bonds, information flow in a social network etc. By including this topological data into the network itself we can improve the quality of our predictions.

More generally, graphs are a generalisation of data and graph neural networks are a generalisation of traditional neural networks. A recent body of work has shown that all of our existing neural network architectures can be expressed as a graph neural network. This field has been called [Geometric Deep Learning](https://arxiv.org/abs/2104.13478) and is exciting new vector of research.


## GraphNeuralNetworks.jl


In Julia there is excellent package support for Graph Neural Networks (GNNs) through the `GraphNeuralNetworks.jl` package. It operates in a semantically familiar way to `Flux.jl` but the API calls are often specific to GNNs whereas `Flux` is more general.

In [57]:
using GraphNeuralNetworks, Flux

## Graphs

A graph is the fundamental mathematical structure from which GNNs get their name. It is a collection of indexes (nodes) and relationships between them (edges). These are typically denoted $V$ and $E$. An edge is defined by two nodes and may be both weighted and directional. A weighted edged is indicated by some real number and may encode some relationship between the nodes. For example, in the molecule $\text{H}_2\text{SO}_4$ there are some oxygen molecules with double bonds to the sulfur and some with single bonds to the sulfur and hydrogen respectively and we would describe these oxygen bonds with edges of weights 2 and 1 respectively.

![H2S04](./images/h2so4.png) 
*The graph of the sulfuric acid molecule.*

The bonds in this molecule are bidirectional but many data that we are interested in is unidirectional. Consider a virus that is spreading through a population. Each individual (node) can pass the virus to other individuals in the network (graph) based on many factors: spatial proximity, symptoms (coughing), immunosenstivity etc. However, the probability of transmission is conditioned on sending node containing the virus and the receiving node *not* containing the virus: the disease can only spread one way and the graph is therefore unidirectional.

## Adjacency Matrices

The information about a graphs structure can be encoded in a relatively compact an intuitive structure: the adjancency matrix. The matrix is constructed by indexing each of the nodes in the graph with natural numbers $(1 \dots n)$ and associating these with the rows and columns of an $n \times n$ matrix $A$. The edge weights are encoded in the matrix values $A_{ij}$. For bidirectional edges this matrix will be symmetric: $W_{ij}$.

## Graphical Data

In addition to the edge relationship graph nodes may contain other information in real world networks. In the $\text{H}_2\text{SO}_4$ example we implicity encoded the the atomic information into the node: we specified nodes 1 and 2 were H and so forth. More generally, we encode $m$ data features of interest for $n$ into a feature matrix $F$ with $m$ rows and $n$ columns. The two matrices $F$ and $A$ define our graphical data and completely specify the network. It is these matrices that we will operate on to train our neural networks.

## Graph Objects in GraphNeuralNetworks

Let's create the following simple graph in the GraphNeuralNetworks framework. We will specify the graph with its adjacency matrix. To each of the nodes we will attach a feature vector of length 4 consisting of random data. The graph object will be specified with the `GNNGraph(adjmat; ndata=featuremat)` API call.

![graph_labels.png](./images/graph_labels.png)

In [58]:
amat = [
 0  1  0  0  0  0  0  0  0  0;
 1  0  1  1  0  1  0  0  0  0;
 0  1  0  0  0  0  0  0  0  0;
 0  0  0  0  0  1  0  0  0  0;
 0  0  0  0  0  0  1  0  0  0;
 0  0  0  0  0  0  0  1  1  0;
 0  0  0  0  1  0  0  0  0  0;
 0  0  0  0  0  0  0  0  0  0;
 0  0  0  0  0  0  1  0  0  1;
 0  0  0  0  0  0  0  0  1  0;
    ]
fmat = rand(4, 10)
g = GNNGraph(amat; ndata=fmat)

GNNGraph:
    num_nodes = 10
    num_edges = 14
    ndata:
        x => 4×10 Matrix{Float64}

# GNN Architectures

What precisely is a Graph Neural Network (GNN)? It is simply a collection of weights and biases which update the state of a node using both the features of the node *and* the states of the neighbouring nodes in the graph. This allows information to flow through the graph. This is also not a new idea - early neural networks adopted this and more fundamentally this is precisely how biological neural networks operate (in a simplified fashion). There are generally three major architectures that are referred to when we talk about GNN architectures:

1. Convolutional GNNs
2. Attentional GNNs
3. Message Passing GNNs


## Convolutional GNNs

Convolutional GNNs (CGNNs) are perhaps the predominate form of graph neural networks and operate by integrating the states of the direct neighbours of each node after modification by a shared weight matrix $W$. The weight matrix embeds the features $x_i$ into a latent space and these are summed over all the edges often with a normalisation factor $\alpha_{ij}$ and then passed through an activation function $f$ to generate the state $h_i$:

$$h_i = f\left(\sum_{e \in \text{edges}(i)} \alpha_{ij} W x_j\right)$$

By summing over all of the neighbours we are performing a convolution and this is where the architecture derives its name. Note that it does not have to be a sum to be a convolution - any permutation invariant function will do. The sum operation happens to be the most popular. 

The most used form of CGNNs are ConvNets applied to images. Here the topology is defined by each pixels immediate neighbours in the Euclidean sense and the features of the input data are the RGB channels. This topology is very regular and allows the network to capture spatial relationships and rotational/translational symmetries. Neighbourhoods need not be defined by spatial relationships and this flexibility makes CGNNs useful exploratory networks.

A CGNN layer is implemented with the `CGNConv` API call and is specified with `Lin=>Lout` where `Lin` is the dimension of the input features and `Lout` is the dimension of the output latents. The activation function is specified by the optional `σ = actfun` argument where the default is the identity function. The scaling is calculated with $\alpha_{ij} = 1/\sqrt(|d(i)||d(j)|$ where $d(i)$ is the degree of the $i$'th node. `bias=true` will specify an optional bias.

In [59]:
gconv_layer = GCNConv(10=>8, x->tanh.(x))

GCNConv(10 => 8, #65)

## Attentional GNNs

Attentional GNNs allow attention mechanisms, which have been successful in analysing sequence data, to be generalised into GNNs. At its core this amounts to allowing the normalising weights between nodes $\alpha_{ij}$ to be learnable. This allows nodes to pay attention to certain nodes information flow while ignoring others.
For example, your social network might include a supervisor and a family member. On the task of advice on a thesis you might like to pay more attention to your supervisor than your family member, but for cooking a traditional dinner the family member might be a better bet. The state update rule now simply includes trainable weights with feature data to facilitate the attention mechanism.


$$h_i = f\left(\sum_{e \in \text{edges}(i)} \alpha(x_i, x_j) W x_j\right)$$


A common attentional implementation is given by the `GATConv` API call. The number of attentional heads (analagous to attentional heads in the Transformer architecture) is given by the heads keyword. An optional specification is the `negative_slope` keyword which defaults to `0.2`. 

In [60]:
gat_layer = GATConv(8=>4; heads=2)

GATConv(8 => 2, negative_slope=0.2)

## Message Passing GNNs
Message passing GNNs are the most generic implementation of a neural network. A message vector is computed using both node features and edge features between two different nodes. The messages are then propogated to neighbours and used to update the state of both the edges and the nodes. This is done with the `propogate(function, graph, reduction; xi=targets, xj=sources, e=edges)`. For example, a generic convolutional network layer may be specified on our original graph and feature matrix `fmat` as follows:

In [61]:
W = rand(14,4)
message(xi, xj, e) = W * xj
x = GraphNeuralNetworks.propagate(message, g, +, xj=fmat)

14×10 Matrix{Float64}:
 0.785263  1.79316   0.785263  0.785263  …  0.71739   1.56053   0.822374
 0.848053  2.32327   0.848053  0.848053     0.780251  1.96444   1.10054
 1.24747   1.98806   1.24747   1.24747      1.18866   2.10074   0.817065
 1.14504   2.2147    1.14504   1.14504      1.27851   2.71648   0.968195
 1.05845   1.87567   1.05845   1.05845      1.20024   2.48849   0.790413
 0.328097  0.753836  0.328097  0.328097  …  0.333672  0.749137  0.347043
 0.573464  1.50686   0.573464  0.573464     0.647398  1.66233   0.700647
 0.589453  1.58851   0.589453  0.589453     0.62694   1.56731   0.756376
 0.486401  2.21699   0.486401  0.486401     0.487106  1.78054   1.13259
 1.04397   2.5399    1.04397   1.04397      1.07237   2.50611   1.18295
 0.873112  2.63625   0.873112  0.873112  …  0.929787  2.53615   1.27136
 0.526614  2.01766   0.526614  0.526614     0.611734  1.94948   1.00614
 0.636326  1.73268   0.636326  0.636326     0.69344   1.73709   0.830054
 1.05152   1.88175   1.05152   1.

By creating generically complex message functions we can instantiate any existing GNN architecture. It is also through this method that we can implement our own custom layers.

## Pooling

Pooling is a useful strategy that has come through convolutional neural networks. It allows layers to be condensed to representative values over pools of candidate parameters. In the convolutional network, for example, a pool is defined by a filter size and this is passed over each feature layer taking a representative value along every stride of the layer e.g. mean/max of all values. Some common pooling layers are:

* Top-K pooling: selecting the top k nodes in a neighbourhood over all graph nodes. Implemented with the `TopK(adj, k, input_channel)` API where adj is the adjacency, k is the number of nodes, and input_channel is the dimension along which to pool.
* GlobalPool: performing an aggegration over the entire graph to reduce it to a single vector along the feature dimensions e.g. averaging along all nodes features. Implemented through the API call `GlobalPool(func)` e.g. `GlobalPool(mean)`.


## GNN Models

We implement these architectures as layers in our GNN which are composed together using the API call `GNNChain`. The chain forms our model and we input data in the form of graph objects with feature vectors. The `GraphNeuralNetworks` package has GPU support through CUDA and models are migrated to the GPU in the same way as Flux using the `gpu(model)` call or by piping operator: `model |> gpu`. Flux layers can be composed into `GNNChains` (provided they are composable) e.g. `BatchNorm`. 

In [62]:
model = GNNChain(gconv_layer, gat_layer, BatchNorm(2))

GNNChain(GCNConv(10 => 8, #65), GATConv(8 => 2, negative_slope=0.2), BatchNorm(2))

## Training

Training a GNN proceeds in a very similar fashion to a regular neural network - they use the regular optimisation and regularisation techniques adopted for earlier neural networks. This is intuitive in that these GNNs are merely generalised functions with understandable outputs and which an optimisation goal can be defined on. Therefore, the generic optimisation methods should automatically generalise. There are still some considerations that need to be kept in mind when training GNNs.

## Loss Functions

The loss function is a salient consideration of any network design and GNNs are not exempt from this. The loss function is task specific and should be considered in a problem-by-problem basis. For GNNs in particular the loss function must be *permutation-invariant*. Specifically, if we were to relabel the indexes in the graph the loss must not change, because the fundamental graph topology will not change if we do this. Any loss can therefore not be index dependent nor can it be non-commutative in its indexes e.g. the product of two numbers are commutative, the product of two matrices are not. The common ones such as `mse` are usual candidates for loss functions.


## Batching

To batch data compose multiple graphs into a single graph using the `batch` API call. This has a natural intuition: the disjoint subgraphs of a graph behave as independent graphs. To see this easily consider the block structure of the adjancency matrix of a graph with several disjoint subgraphs. Once the graph data has been batched we can use the `DataLoader` API call from Flux to shuffle and batch data points as usual.

# A Graph Classification Example

We will use the TUDataset of graphs with classified labels to perform a simple graph classification task. Specifically we will analyse the PROTEINS dataset which labels graphs derived from protein structures as ezymes or non-ezyments. We can find this in the `MLDatasets` package and use a convenience function `mldataset2gnngraph` to convert it to the format required for `GraphNeuralNetworks`.

In [195]:
using Statistics, ProgressMeter
using MLDatasets: TUDataset
data = MLDatasets.TUDataset("PROTEINS")
graphdata = mldataset2gnngraph(data)

# extract the feature matrix of the data test and define the loss for target training
G = [GNNGraph(g, ndata=g.ndata.features) for g in graphdata]
y = Float32.((data.graph_data.targets .- 1))

# Split the data into testing and training and create a dataloader object
train_data, test_data = Flux.splitobs((G, y), at=0.8, shuffle=true)
train_loader = Flux.DataLoader(train_data; batchsize=32, shuffle=true, collate=true)
test_loader = Flux.DataLoader(test_data; batchsize=32, shuffle=true, collate=true)

7-element DataLoader(::Tuple{SubArray{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}, 1, Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}, Tuple{Vector{Int64}}, false}, SubArray{Float32, 1, Vector{Float32}, Tuple{Vector{Int64}}, false}}, shuffle=true, batchsize=32, collate=Val{true}())
  with first element:
  (GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}, 32-element Vector{Float32},)

The model will be defined quite simply: three layers of convolutional architecture with followed by a global pooling layer to take it to a single vector. The classification will be performed by a dense classifier onto the targets (0,1). The activation for each of the convolutional layers will be the `relu` function and the latent dimensions will be 4, 10, and 2. Pooling will be performed by `mean`.

In [196]:
dim1 = size(G[1].ndata.x, 1)
dim2 = 10
dim3 = 100

model = GNNChain(
    GCNConv(dim1=>dim2, relu),
    GCNConv(dim2=>dim3, relu),
    GlobalPool(mean),
    Dense(dim3, 1, σ)
)

GNNChain(GCNConv(1 => 10, relu), GCNConv(10 => 100, relu), GlobalPool{typeof(mean)}(Statistics.mean), Dense(100 => 1, σ))

The training proceeds in the normal fashion: we specify the trainable parameters, define an optimiser (in this case, Adam), and an appropriate loss. The loss is chosen to be `logitcrossentropy`. We train our model for 100 epochs. We then validate our training on our test dataset. 

In [197]:
ps = Flux.params(model)
opt = Adam(0.001)
loss(s, t) = Flux.logitcrossentropy(vec(model(s, s.ndata.x)), t)

@time @showprogress for e in 1:100
    Flux.train!((x,y) -> loss(x,y), ps, train_loader, opt)
end

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:18:07[39m


1087.851435 seconds (13.10 M allocations: 49.187 GiB, 0.68% gc time, 59.36% compilation time)


2-element Vector{Float64}:
 65.39
 65.92

In [207]:
function test_accuracy(model, data)
    accuracy = 0
    total = 0
    for (g, y) in data
        n = length(y)
        res = vec(model(g, g.ndata.x))
        accuracy += mean((res .> 0.5) .== y) * n
        total += n
    end
    return round(accuracy*100/total, digits=2)
end

test_accuracy (generic function with 1 method)

In [None]:
[test_accuracy(model, train_loader), test_accuracy(model, test_loader)]