## Tutorial 2 : Build Graph Neural Networks with PyG

In this tutorial, we will learn how to build a graph neural network (GNN) with PyG.
PyG offers several handy features for building GNNs, including:
- An extended `Sequential` module for composing GNN layers
- Pre-implemented, optimized graph convolutional layers
- Ready-made graph neural network models

### A limitation of `torch.nn.Sequential`

In native PyTorch, `torch.nn.Sequential` is a handy module that allows us to build a neural network in a sequential manner. For example, we can build a simple MLP with `torch.nn.Sequential` as follows:

```python
import torch.nn as nn

mlp = nn.Sequential(
    nn.Linear(32, 32),
    nn.ReLU(),
    nn.Linear(32, 1),
)
```

The `Sequential` class minimizes the boilerplate code needed to implement `forward` methods.
We can implement an equivalent MLP without `Sequential` as follows:

```python
import torch.nn as nn

class MLP(nn.Module):
    
    def __init__(self):
        super(MLP, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(32, 32), 
                                     nn.ReLU(), 
                                     nn.Linear(32, 1)])
    
    def forward(self, x):        
        for layer in self.layers:
            x = layer(x)
        return x

mlp = MLP()
```



However, `torch.nn.Sequential` has a limitation: each layer must take exactly one input and return exactly one output.
This becomes a problem when building graph neural networks. For instance, an interaction network layer (`INLayer`) takes
two inputs (node and edge features) and returns two outputs (updated node and edge features). Hence,
it is not straightforward to build a graph neural network with `torch.nn.Sequential`.
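
The limitation can be reproduced with a toy layer. The sketch below uses a hypothetical `ToyNodeEdgeLayer` (introduced only for illustration, not part of PyTorch or PyG) that takes and returns two tensors. Calling `nn.Sequential` with two inputs should raise a `TypeError`, since its `forward` accepts a single argument.

```python
import torch
import torch.nn as nn


class ToyNodeEdgeLayer(nn.Module):
    """Hypothetical layer: takes node and edge features, returns both updated."""

    def __init__(self, dim):
        super().__init__()
        self.node_lin = nn.Linear(dim, dim)
        self.edge_lin = nn.Linear(dim, dim)

    def forward(self, x, edge_attr):
        return self.node_lin(x), self.edge_lin(edge_attr)


seq = nn.Sequential(ToyNodeEdgeLayer(4), ToyNodeEdgeLayer(4))
x, edge_attr = torch.randn(3, 4), torch.randn(6, 4)

try:
    seq(x, edge_attr)  # two positional inputs
    failed = False
except TypeError as err:
    failed = True
    print(f"nn.Sequential rejected two inputs: {err}")
```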

### Using `torch_geometric.nn.Sequential` to build GNN

`torch_geometric.nn.Sequential` is an extended version of `torch.nn.Sequential` that allows us to build a graph neural network in a sequential manner, even when layers take or return several tensors. Let's see how we can build a graph neural network with `torch_geometric.nn.Sequential`.

In [14]:
from torch_geometric.nn import Sequential
from torch_geometric.data import Batch

from common.graph_gen import generate_random_graph
from common.layers import InteractionNetworkLayer

help(InteractionNetworkLayer.forward)

Help on function forward in module common.layers:

forward(self, x, edge_index, edge_attr) -> Tuple[torch.Tensor, torch.Tensor]
    Runs the forward pass of the module.


In [15]:
dim = 5
model = Sequential("x, edge_index, edge_attr", # input
                   [
                       (InteractionNetworkLayer(dim), "x, edge_index, edge_attr -> x, edge_attr"), 
                       (InteractionNetworkLayer(dim), "x, edge_index, edge_attr -> x, edge_attr"),
                   ]
)

In [16]:
gs = Batch.from_data_list([generate_random_graph(5 * (i+1),
                                                 node_feat_dim=dim,
                                                 edge_feat_dim=dim) for i in range(3)])

In [17]:
unf, uef = model(gs.x, gs.edge_index, gs.edge_attr) # updated node features (unf), updated edge features (uef)
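
Conceptually, the signature strings tell `Sequential` how to pass variables from one layer to the next. A rough plain-PyTorch equivalent is sketched below. `ToyInteractionLayer` and `ManualStack` are hypothetical names introduced for illustration, and the toy layer's internals are an assumption; they are not the tutorial's `InteractionNetworkLayer`.

```python
import torch
import torch.nn as nn


class ToyInteractionLayer(nn.Module):
    """Hypothetical stand-in for a layer that updates node and edge features."""

    def __init__(self, dim):
        super().__init__()
        self.edge_lin = nn.Linear(3 * dim, dim)
        self.node_lin = nn.Linear(2 * dim, dim)

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index
        # Edge update from source node, target node, and current edge features
        edge_attr = self.edge_lin(torch.cat([x[src], x[dst], edge_attr], dim=-1))
        # Sum the updated edge features arriving at each target node
        agg = torch.zeros_like(x).index_add_(0, dst, edge_attr)
        x = self.node_lin(torch.cat([x, agg], dim=-1))
        return x, edge_attr


class ManualStack(nn.Module):
    """Hand-written loop mirroring 'x, edge_index, edge_attr -> x, edge_attr'."""

    def __init__(self, dim, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(
            [ToyInteractionLayer(dim) for _ in range(num_layers)]
        )

    def forward(self, x, edge_index, edge_attr):
        for layer in self.layers:
            # edge_index is reused unchanged; x and edge_attr are overwritten
            x, edge_attr = layer(x, edge_index, edge_attr)
        return x, edge_attr


dim = 5
x = torch.randn(4, dim)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
edge_attr = torch.randn(4, dim)

unf, uef = ManualStack(dim)(x, edge_index, edge_attr)
print(unf.shape, uef.shape)
```

The PyG `Sequential` module removes the need to write the `ManualStack` class by hand.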

## Pre-implemented graph convolutional layers in PyG

PyG offers various pre-implemented graph convolutional layers. Let's see how we can use them.
In this tutorial, we will look at two iconic graph convolutional layers, `GCNConv` and `SAGEConv`.
The exhaustive list of implemented graph convolutional layers can be found [here](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#convolutional-layers).

### Implementing `GCNConv` in PyG

`GCNConv` is a graph convolutional layer proposed in [Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907). In PyG, we can implement a Graph Convolutional Network (GCN) in various ways.
We will look at two of them:
- Using the `GCNConv` layer with the `Sequential` module
- Using the `models.GCN` class
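
Before using the built-in layer, it may help to see the propagation rule from the GCN paper written out densely. The following is a minimal sketch, assuming an undirected graph whose `edge_index` lists both directions, no pre-existing self-loops, and no bias term. `dense_gcn_propagate` is a hypothetical helper, not a PyG function, and a dense adjacency matrix is only practical for small graphs.

```python
import torch


def dense_gcn_propagate(x, edge_index, weight):
    """Dense sketch of one GCN step: D^-1/2 (A + I) D^-1/2 X W."""
    num_nodes = x.size(0)
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[0], edge_index[1]] = 1.0
    adj = adj + torch.eye(num_nodes)         # add self-loops
    deg_inv_sqrt = adj.sum(dim=1).pow(-0.5)  # degrees are >= 1 due to self-loops
    norm_adj = deg_inv_sqrt.unsqueeze(1) * adj * deg_inv_sqrt.unsqueeze(0)
    return norm_adj @ x @ weight


# Two nodes joined by one undirected edge, identity weight matrix
x = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
edge_index = torch.tensor([[0, 1], [1, 0]])
weight = torch.eye(2)

out = dense_gcn_propagate(x, edge_index, weight)
print(out)  # every entry is 0.5: each node averages itself and its neighbor
```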

In [18]:
from torch_geometric.nn import GCNConv

out_dim = 13

# Construct GCN layer (i.e., GCNConv)
gcn_conv = GCNConv(dim, out_dim)
updated_x = gcn_conv(gs.x, gs.edge_index)

print(f'Input node feature size: {gs.x.shape}')
print(f'Output node feature size: {updated_x.shape} \n')

# Construct GCN by stacking GCNConv using Sequential
gcn = Sequential("x, edge_index", 
                 [(GCNConv(dim, dim), "x, edge_index -> x"),
                  (GCNConv(dim, dim), "x, edge_index -> x"),
                  (GCNConv(dim, dim), "x, edge_index -> x"),
                ]
)

# Or equivalently
# num_layers = 3
# gcn = Sequential("x, edge_index", [(GCNConv(dim, dim), "x, edge_index -> x") for _ in range(num_layers)])
# print(gcn)

print(f'Model spec: \n {gcn} \n')

# GCN forward
gcn_out = gcn(gs.x, gs.edge_index)
print(f'Input node feature size: {gs.x.shape}')
print(f'GCN output node feature size: {gcn_out.shape}')

Input node feature size: torch.Size([30, 5])
Output node feature size: torch.Size([30, 13]) 

Model spec: 
 Sequential(
  (0): GCNConv(5, 5)
  (1): GCNConv(5, 5)
  (2): GCNConv(5, 5)
) 

Input node feature size: torch.Size([30, 5])
GCN output node feature size: torch.Size([30, 5])


### Construct GCN using `torch_geometric.nn.models.GCN`

`PyG` provides pre-implemented versions of well-known GNN models with enhanced features and code-level optimizations.
`torch_geometric.nn.models.GCN` is one of these pre-implemented models. With it, we can build a GCN by simply
instantiating the `models.GCN` class.

In [19]:
from torch_geometric.nn.models import GCN

gcn = GCN(in_channels=dim, 
          hidden_channels=dim, 
          out_channels=dim, num_layers=3)

gcn_out = gcn(gs.x, gs.edge_index)
print(f'Input node feature size: {gs.x.shape}')
print(f'GCN output node feature size: {gcn_out.shape}')

Input node feature size: torch.Size([30, 5])
GCN output node feature size: torch.Size([30, 5])


## Implementing Graph SAGE with PyG

Graph SAGE is a graph convolutional layer proposed in [Inductive Representation Learning on Large Graphs](https://arxiv.org/abs/1706.02216). 
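
As a rough illustration of the mean aggregator, the sketch below combines each node's own features with the mean of its incoming neighbors' features. It assumes the common form `x_i' = W_self x_i + W_neigh mean_j(x_j)` without bias or nonlinearity, and that messages flow from `edge_index[0]` to `edge_index[1]`. `sage_mean_update` is a hypothetical helper written in plain PyTorch, not the `SAGEConv` implementation.

```python
import torch


def sage_mean_update(x, edge_index, w_self, w_neigh):
    """Sketch of a mean-aggregator update in the style of GraphSAGE."""
    src, dst = edge_index
    num_nodes = x.size(0)
    # Sum of neighbor features arriving at each target node
    neigh_sum = torch.zeros_like(x).index_add_(0, dst, x[src])
    # Number of incoming edges per node (clamped to avoid division by zero)
    deg = torch.zeros(num_nodes).index_add_(0, dst, torch.ones(dst.numel()))
    neigh_mean = neigh_sum / deg.clamp(min=1).unsqueeze(1)
    return x @ w_self + neigh_mean @ w_neigh


x = torch.tensor([[1.0], [2.0], [4.0]])
edge_index = torch.tensor([[1, 2], [0, 0]])  # nodes 1 and 2 send messages to node 0

out = sage_mean_update(x, edge_index, torch.eye(1), torch.eye(1))
print(out.view(-1))  # node 0: 1 + mean(2, 4) = 4; nodes 1 and 2 keep 2 and 4
```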

In [20]:
from torch_geometric.nn import SAGEConv

sage_conv = SAGEConv(in_channels=dim, 
                     out_channels=dim,
                     aggr='mean')
print(sage_conv)
sage_out = sage_conv(gs.x, gs.edge_index)
print(f'Input node feature size: {gs.x.shape}')
print(f'Graph SAGE output node feature size: {sage_out.shape}')

SAGEConv(5, 5, aggr=mean)
Input node feature size: torch.Size([30, 5])
Graph SAGE output node feature size: torch.Size([30, 5])


In [21]:
from torch_geometric.nn.models import GraphSAGE

sage_conv_kwargs = {
    'aggr': 'mean'
}
graph_sage = GraphSAGE(in_channels=-1, # '-1' lets the model infer the input dimension from the first forward pass!
                       # This can be a handy feature but is not recommended for readability and reproducibility
                       hidden_channels=dim,
                       out_channels=13,
                       num_layers=3,
                       **sage_conv_kwargs)
graph_sage                       

GraphSAGE(-1, 13, num_layers=3)

In [22]:
graph_sage_out = graph_sage(gs.x, gs.edge_index)
print(f'Input node feature size: {gs.x.shape}')
print(f'Graph SAGE output node feature size: {graph_sage_out.shape}')

Input node feature size: torch.Size([30, 5])
Graph SAGE output node feature size: torch.Size([30, 13])


## Graph Readout and Pooling

So far, we've learned how to build a graph neural network with PyG. The graph neural network $f_\theta$ generically takes a graph $\mathcal{G}=(X,E)$ and returns the updated graph $\mathcal{G}'=(X',E')$ as follows:
$$
\mathcal{G}' = f_\theta(\mathcal{G})
$$

However, for some tasks, we want to map the graph to a single vector (e.g., graph property prediction tasks, where the input is a graph and the output is a scalar or vector value for the whole graph). In this case, we need to aggregate the node features into a single vector. This process is called graph pooling or graph readout. In this tutorial, we will learn how to implement graph pooling with PyG.

## A very simple pooling method: SumPooling

The simplest way to aggregate node features into a single vector is to sum up all the node features. This method is called SumPooling. Mathematically, SumPooling can be defined as follows:
$$
x_\mathcal{G} = \sum_{i \in \mathcal{N}} x_i
$$
where $\mathcal{N}$ is the set of nodes in the graph $\mathcal{G}$. Edge features (if they exist) can be pooled in a similar fashion.

So why don't we implement sum pooling as follows?

```python

import torch

num_graphs = 3
num_nodes = 5
hidden_dim = 12

h = torch.randn(num_graphs, num_nodes, hidden_dim)
aggr = h.sum(dim=1) # sum over the node dimension (dim=1)
print(aggr.shape) # torch.Size([3, 12])
```

Unfortunately, it is often impossible to batch node features into a single `[num_graphs, num_nodes, hidden_dim]` tensor, because the graphs in a batch usually have different numbers of nodes. Therefore, we need a pooling operation that can handle a batch of graphs with different numbers of nodes. In PyG, we can implement SumPooling as follows:

In [23]:
from torch_geometric.nn.pool import global_add_pool, global_mean_pool, global_max_pool # Yes, you can do mean, max pooling in PyG!

pooled = global_add_pool(x=gs.x, 
                         batch=gs.batch) # batch is a tensor that indicates which graph the node belongs to
pooled.shape

torch.Size([3, 5])
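
To make the role of the `batch` vector concrete, here is a minimal plain-PyTorch sketch of the same idea. It assumes three graphs with 5, 10, and 15 nodes (matching the sizes used above) whose node features are stacked along the first dimension. It uses `index_add_` rather than PyG's own implementation.

```python
import torch

num_nodes_per_graph = torch.tensor([5, 10, 15])
num_graphs = len(num_nodes_per_graph)
dim = 5

# Node features of all graphs stacked into one [30, dim] tensor
x = torch.randn(int(num_nodes_per_graph.sum()), dim)

# batch[i] is the index of the graph that node i belongs to: 0,...,0,1,...,1,2,...,2
batch = torch.repeat_interleave(torch.arange(num_graphs), num_nodes_per_graph)

# Sum pooling: add every node's features into the row of its graph
pooled = torch.zeros(num_graphs, dim).index_add_(0, batch, x)

# Reference result computed graph by graph
expected = torch.stack(
    [chunk.sum(dim=0) for chunk in x.split(num_nodes_per_graph.tolist())]
)
print(pooled.shape)  # torch.Size([3, 5])
```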

## Advanced: `torch_scatter` for more complex pooling routines

`torch_geometric` used to implement several key 'pooling' features (i.e., aggregating a set of vectors into a single vector) with `torch_scatter`. In this section, we will learn how to use `torch_scatter` for more complex pooling routines.

<div style="text-align: center;">
  <img src="./assets/add.svg" alt="Image description" style="width: 400px;">
  <p style="margin-top: 10px;">The behavior of torch_scatter.scatter </p>
</div>

Figure from [here](https://pytorch-scatter.readthedocs.io/en/latest/functions/add.html)

As you can see from the figure, the `scatter` operation aggregates the entries of `src` (the input) into the output according to the `index` vector. This design
lets us store the data in a plain tensor while aggregating groups of different sizes with a single operation.

**Note: the same features are now implemented in `torch_geometric.utils.scatter`**

In [24]:
import torch
from torch_scatter import scatter
# or equivalently
from torch_geometric.utils import scatter

dim1, dim2 = 11, 13
src = torch.randn(6, dim1, dim2)
index = torch.tensor([0, 1, 0, 1, 2, 1])

# Naive loopy implementation
out_naive = torch.zeros(index.unique().shape[0], dim1, dim2)
for i in index.unique().tolist():
    out_naive[i] = src[index == i].sum(dim=0)
    
# torch_scatter implementation
out = scatter(src, index, dim=0, reduce="sum") # reduce can be "sum", "mean", "max", "min"

assert torch.allclose(out_naive, out)

In [28]:
from torch_geometric.utils import segment
# A similar reduction can be done over contiguous segments given by index pointers
# For further details, please refer to the documentation

src = torch.randn(10, 6, 64)
indptr = torch.tensor([0, 2, 5, 6])
indptr = indptr.view(1, -1)  # Broadcasting in the first and last dim.

out = segment(src, indptr, reduce="sum")

### Composite scatter operations: `softmax`, `logsumexp`, and `scatter_std`

In [26]:
from torch_scatter.composite import scatter_softmax, scatter_std, scatter_logsumexp
# or equivalently
from torch_geometric.utils import softmax

src = torch.randn(10, 1)
idx = torch.tensor([0,0,0,1,1,2,2,2,2,2])
out = scatter_softmax(src, idx, dim=0)
print(out.shape)

print("Results of 'scatter_softmax'")
print(f'1st batch: {out[:3].view(-1)}, sum={out[:3].sum()}')
print(f'2nd batch: {out[3:5].view(-1)}, sum={out[3:5].sum()}')
print(f'3rd batch: {out[5:].view(-1)}, sum={out[5:].sum()}')

torch.Size([10, 1])
Results of 'scatter_softmax'
1st batch: tensor([0.2878, 0.1079, 0.6043]), sum=0.9999999403953552
2nd batch: tensor([0.8721, 0.1279]), sum=0.9999999403953552
3rd batch: tensor([0.0951, 0.1077, 0.2047, 0.4446, 0.1479]), sum=1.0
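
The grouped softmax above can be reproduced with plain PyTorch to see what it computes. The sketch below is an illustration only, not the `torch_scatter` implementation: `grouped_softmax` is a hypothetical helper for 1-D inputs, and it shifts by the global maximum for brevity, whereas a per-group maximum would be more robust when groups differ greatly in scale.

```python
import torch


def grouped_softmax(src, index, num_groups):
    """Softmax over the entries of a 1-D tensor that share the same group index."""
    # Softmax is invariant to a constant shift, so subtracting the max is safe
    exp = (src - src.max()).exp()
    # Per-group normalizer: sum of exponentials within each group
    denom = torch.zeros(num_groups).index_add_(0, index, exp)
    return exp / denom[index]


src = torch.randn(10)
idx = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 2])

out = grouped_softmax(src, idx, num_groups=3)
group_sums = torch.zeros(3).index_add_(0, idx, out)
print(group_sums)  # each group sums to approximately 1
```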


As an alternative to graph pooling, a **virtual node** is also often used to condense node/edge features into a single vector.
A basic approach is to add one extra node to the graph and connect it to all the other nodes in the graph.
This often works well compared to naive graph pooling approaches. Please refer to the following research papers for more details:
- [Graph Classification via Deep Learning with Virtual Nodes](https://arxiv.org/pdf/1708.04357.pdf),
- [On the Connection Between MPNN and Graph Transformer](https://arxiv.org/pdf/2301.11956.pdf)
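
The basic virtual node construction can be sketched in a few lines of plain PyTorch. This is a minimal illustration under stated assumptions: `add_virtual_node` is a hypothetical helper, the virtual node is initialized with zero features, and edges are added in both directions so that information can flow to and from it.

```python
import torch


def add_virtual_node(x, edge_index):
    """Append one zero-feature node connected to every existing node."""
    num_nodes = x.size(0)
    virtual = num_nodes  # index of the new node
    x = torch.cat([x, x.new_zeros(1, x.size(1))], dim=0)

    nodes = torch.arange(num_nodes, dtype=torch.long)
    v = torch.full((num_nodes,), virtual, dtype=torch.long)
    to_virtual = torch.stack([nodes, v])    # edges i -> virtual
    from_virtual = torch.stack([v, nodes])  # edges virtual -> i
    edge_index = torch.cat([edge_index, to_virtual, from_virtual], dim=1)
    return x, edge_index, virtual


x = torch.randn(3, 4)
edge_index = torch.tensor([[0, 1], [1, 2]])

new_x, new_edge_index, virtual = add_virtual_node(x, edge_index)
print(new_x.shape, new_edge_index.shape)  # 4 nodes, 2 + 2 * 3 = 8 edges
```

After message passing, the features at index `virtual` can serve as the graph-level representation.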