# Geometric Deep Learning Models

With proteins loaded as graphs, the next step is to build deep learning models that learn on these graphs. Typically, graph neural networks work by passing messages between nodes wherever there is an edge, and then aggregating these messages at each node.

![](https://www.aritrasen.com/wp-content/uploads/2022/11/msg_1.jpg)

This is done iteratively for a number of layers, and then the final node representations are used to make predictions. This way, information flows between connected nodes and the model learns to take the structure of the graph into account when making predictions.
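
To make the idea concrete, here is a minimal sketch of message passing in plain Python. It is not how `torch-geometric` implements it: the toy graph, the feature values, and the `message_passing_round` helper are illustrative assumptions, and real layers also apply learned weights and non-linearities at each step.

```python
# Toy path graph 0 - 1 - 2 - 3. Each undirected edge is stored in both
# directions, mirroring the [2, num_edges] layout of an edge index.
edge_index = [
    [0, 1, 1, 2, 2, 3],  # source nodes
    [1, 0, 2, 1, 3, 2],  # target nodes
]
# One scalar feature per node (hypothetical): only node 0 carries a signal.
features = [1.0, 0.0, 0.0, 0.0]


def message_passing_round(features, edge_index):
    """One round: every node sums its neighbors' features and adds its own."""
    aggregated = [0.0] * len(features)
    for src, dst in zip(*edge_index):
        aggregated[dst] += features[src]  # message from src arrives at dst
    # Update step: combine each node's own feature with the aggregated messages.
    return [own + agg for own, agg in zip(features, aggregated)]


one_round = message_passing_round(features, edge_index)
two_rounds = message_passing_round(one_round, edge_index)
three_rounds = message_passing_round(two_rounds, edge_index)
print(one_round)     # [1.0, 1.0, 0.0, 0.0]
print(two_rounds)    # [2.0, 2.0, 1.0, 0.0]
print(three_rounds)  # [4.0, 5.0, 3.0, 1.0]
```

Node 3 is three hops away from node 0, so it only receives the signal after three rounds: each additional layer extends the neighborhood a node can see by one hop.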

![](https://ars.els-cdn.com/content/image/1-s2.0-S2666651021000012-gr2_lrg.jpg)


Sometimes you may even change the graph structure itself, by pooling nodes or edges, or by adding new edges or nodes. This can be useful for tasks like graph classification, where you want to make a prediction about the entire graph, or for tasks like graph generation, where you want to create new graphs that are similar to the ones you've seen before.

![](https://www.researchgate.net/profile/Lavender-Jiang-2/publication/343441194/figure/fig4/AS:921001206509568@1596595207558/Graph-pooling-and-graph-aggregation-Graph-pooling-left-accepts-a-graph-signal-and.ppm)


So, depending on the way messages are passed, aggregated, pooled, and transformed, there are a number of different architectures that can be used to build graph neural networks. Here's a comprehensive [review](https://www.sciencedirect.com/science/article/pii/S2666651021000012) and a [book](https://arxiv.org/pdf/2104.13478.pdf) to learn more about the different architectures, how they work, and how they can be used for different tasks.

![](https://ars.els-cdn.com/content/image/1-s2.0-S2666651021000012-gr3_lrg.jpg)

## Torch-geometric models and layers

In this notebook, we'll take a look at some of the different architectures and building block layers that are implemented in the `torch-geometric` library, and how they can be used to build and train models for graph-based tasks.

- [torch_geometric.nn](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html) has a variety of graph layers that can be used to build custom GNN architectures. These include:
    - [Convolutional layers](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#convolutional-layers): These define how the message passing step is accomplished across edges in the graph. `GCNConv` is a simple example of a graph convolution layer, while `GATConv` is a more complex example with attention mechanisms.
    - [Aggregation Operators](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#aggregation-operators): These define how messages are aggregated at each node. 
    - [Pooling layers](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#id45): These define how nodes are merged, either into a coarser graph or into a single graph-level representation.

- [torch_geometric.nn.models](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#models) has more complex model architectures with a variety of these layers already defined and combined inside.
- The [PyGModelHubMixin](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.model_hub.PyGModelHubMixin) class can be used to load pre-trained models or other model architectures from the [HuggingFace Model Hub](https://huggingface.co/models?pipeline_tag=graph-ml&sort=trending).
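
As a rough illustration of what an aggregation operator does, the plain-Python sketch below combines the messages arriving at one node in three common ways. The message values and the `aggregate` helper are hypothetical; in `torch-geometric` the corresponding operators live in `torch_geometric.nn.aggr` (for example `SumAggregation`, `MeanAggregation` and `MaxAggregation`) and work on tensors.

```python
# Hypothetical 2-dimensional messages arriving at one node from three neighbors.
messages = [[1.0, 4.0], [3.0, 0.0], [2.0, 2.0]]


def aggregate(messages, how):
    """Combine a list of equal-length message vectors, one value per dimension."""
    columns = list(zip(*messages))  # group the values of each feature dimension
    if how == "sum":
        return [sum(col) for col in columns]
    if how == "mean":
        return [sum(col) / len(col) for col in columns]
    if how == "max":
        return [max(col) for col in columns]
    raise ValueError(f"unknown aggregation: {how}")


print(aggregate(messages, "sum"))   # [6.0, 6.0]
print(aggregate(messages, "mean"))  # [2.0, 2.0]
print(aggregate(messages, "max"))   # [3.0, 4.0]

# Neighbors have no natural order, so the result must not depend on it.
shuffled = [messages[2], messages[0], messages[1]]
order_independent = all(
    aggregate(shuffled, how) == aggregate(messages, how)
    for how in ("sum", "mean", "max")
)
print(order_independent)  # True
```

All three operators give the same answer regardless of neighbor order, which is the property any aggregation over a node's neighborhood needs to have.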

In [1]:
from torch_geometric import nn as graph_nn
from torch import nn
from src import dataloader

## Batching graph data

We can load our datamodule from the previous notebook and get a train batch to test out these layers and models:

In [2]:
datamodule = dataloader.ProteinGraphDataModule("./test_data", "dataset.txt")
datamodule.prepare_data()
datamodule.setup("fit")

train_loader = datamodule.train_dataloader()
example_train_batch = next(iter(train_loader))
example_train_protein = datamodule.train[0]

Processing...
Done!
Processing...
Done!


We have an example train data point (`example_train_protein`) but training is almost always done on batches of data points. This is what is returned by the `train_dataloader` of the `DataModule`, in `example_train_batch`.

A batch essentially combines the graphs of the individual proteins into one bigger batch graph, with an additional `batch` attribute that specifies which protein each node belongs to. Because there are no edges between the different proteins, no information flows between them, so training on this batch graph is equivalent to training on the individual graphs separately.

In [3]:
example_train_protein

Data(edge_index=[2, 226], node_id=[103], chain_id=[103], residue_number=[103], coords=[103, 3], amino_acid_one_hot=[103, 20], meiler=[103, 7], kind=[113], num_nodes=103, x=[103, 27], pos=[103, 3], y=[103])

In [20]:
example_train_batch

DataBatch(edge_index=[2, 4280], node_id=[8], chain_id=[8], residue_number=[1691], coords=[1691, 3], amino_acid_one_hot=[1691, 20], meiler=[1691, 7], kind=[8], num_nodes=1691, x=[1691, 27], pos=[1691, 3], y=[1691], batch=[1691], ptr=[9])
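
The batching described above can be sketched in a few lines of plain Python. This is only an illustration of the idea, not the `torch-geometric` implementation: the `collate` helper and the two toy graphs are assumptions made for the example.

```python
def collate(graphs):
    """Merge (num_nodes, edge_index) pairs into one disconnected batch graph."""
    sources, targets, batch, ptr = [], [], [], [0]
    for graph_id, (num_nodes, edge_index) in enumerate(graphs):
        offset = ptr[-1]  # number of nodes already placed in the batch
        # Shift node indices so they stay unique within the batch graph.
        sources += [s + offset for s in edge_index[0]]
        targets += [t + offset for t in edge_index[1]]
        batch += [graph_id] * num_nodes  # which graph each node belongs to
        ptr.append(offset + num_nodes)   # cumulative node count
    return [sources, targets], batch, ptr


# Two hypothetical graphs: a 3-node path and a 2-node pair.
graph_a = (3, [[0, 1, 1, 2], [1, 0, 2, 1]])
graph_b = (2, [[0, 1], [1, 0]])

edge_index, batch, ptr = collate([graph_a, graph_b])
print(edge_index)  # [[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]]
print(batch)       # [0, 0, 0, 1, 1]
print(ptr)         # [0, 3, 5]
```

No edge connects nodes 0 to 2 with nodes 3 to 4, so messages never cross from one graph to the other. `ptr` has one more entry than there are graphs, which is consistent with `ptr=[9]` in the output above if that batch holds 8 proteins.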

## Graph neural network layers

Here's how we would define a graph convolutional layer that takes the amino acid one-hot encodings as input node features, along with the edge index that defines the graph, and converts them to a 64-dimensional embedding via convolution operations across the graph:

In [5]:
layer = graph_nn.GCNConv(in_channels=20, out_channels=64)
example_output = layer(example_train_batch.amino_acid_one_hot.float(), example_train_batch.edge_index)
example_output.shape

torch.Size([1756, 64])

**Try out some of the other layers in the torch_geometric.nn module!**

## Graph neural network models

We can also try out some of the pre-defined models in the `torch_geometric.nn.models` module, such as the `GAT` model. With `v2=True`, it applies a series of `GATv2Conv` layers, which use attention mechanisms to weight the importance of different neighbors when aggregating information at each node, followed by a linear layer that converts the node embeddings to a 64-dimensional output.

In [7]:
model = graph_nn.GAT(in_channels=20,
                     hidden_channels=32,
                     num_layers=3,
                     heads=2,
                     out_channels=64,
                     dropout=0.01,
                     jk="last", 
                     v2=True)
print(graph_nn.summary(model, example_train_batch.amino_acid_one_hot.float(), example_train_batch.edge_index))
print(model(example_train_batch.amino_acid_one_hot.float(), example_train_batch.edge_index).shape)

+---------------------+-----------------------+----------------+----------+
| Layer               | Input Shape           | Output Shape   | #Param   |
|---------------------+-----------------------+----------------+----------|
| GAT                 | [1756, 20], [2, 4378] | [1756, 64]     | 7,872    |
| ├─(dropout)Dropout  | [1756, 32]            | [1756, 32]     | --       |
| ├─(act)ReLU         | [1756, 32]            | [1756, 32]     | --       |
| ├─(convs)ModuleList | --                    | --             | 5,760    |
| │    └─(0)GATv2Conv | [1756, 20], [2, 4378] | [1756, 32]     | 1,408    |
| │    └─(1)GATv2Conv | [1756, 32], [2, 4378] | [1756, 32]     | 2,176    |
| │    └─(2)GATv2Conv | [1756, 32], [2, 4378] | [1756, 32]     | 2,176    |
| ├─(norms)ModuleList | --                    | --             | --       |
| │    └─(0)Identity  | [1756, 32]            | [1756, 32]     | --       |
| │    └─(1)Identity  | [1756, 32]            | [1756, 32]     | --       |
| │    └─(2)

**Try out some of the other models in the torch_geometric.nn module!**
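
To give a feel for the attention-weighted aggregation that the GAT layers above rely on, here is a small plain-Python sketch. The scores and neighbor features are hypothetical values chosen for illustration; in a real attention layer the scores are computed from the node features with learned parameters, and this sketch only shows the normalization and weighting steps.

```python
import math


def softmax(scores):
    """Turn arbitrary scores into positive weights that sum to 1."""
    highest = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(s - highest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical unnormalized attention scores for one node's three neighbors.
scores = [2.0, 0.0, 0.0]
# Hypothetical 2-dimensional features of those neighbors.
neighbor_features = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]

weights = softmax(scores)
# Weighted sum of neighbor features, computed per feature dimension.
aggregated = [
    sum(w * feat[dim] for w, feat in zip(weights, neighbor_features))
    for dim in range(2)
]
print(weights)
print(aggregated)
```

The neighbor with the highest score dominates the aggregated vector, whereas a plain mean would give all three neighbors the same weight of one third.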

We can combine layers into custom architectures. Here is an example of a simple architecture that uses a GATConv layer and a GCNConv layer with some activation functions in between, and finally a linear layer to convert the 64-dimensional node embeddings to one value per node.

In [14]:
model = graph_nn.Sequential('x, edge_index', [
    (graph_nn.GATConv(in_channels=20, out_channels=64, heads=2, concat=False), 'x, edge_index -> x'),
    nn.ReLU(inplace=True),
    (graph_nn.GCNConv(in_channels=64, out_channels=64), 'x, edge_index -> x'),
    nn.ReLU(inplace=True),
    nn.Linear(64, 1),
])

print(graph_nn.summary(model, example_train_batch.amino_acid_one_hot.float(), example_train_batch.edge_index))

+---------------------+-----------------------+----------------+----------+
| Layer               | Input Shape           | Output Shape   | #Param   |
|---------------------+-----------------------+----------------+----------|
| Sequential          | [1744, 20], [2, 4430] | [1744, 1]      | 7,105    |
| ├─(module_0)GATConv | [1744, 20], [2, 4430] | [1744, 64]     | 2,880    |
| ├─(module_1)ReLU    | [1744, 64]            | [1744, 64]     | --       |
| ├─(module_2)GCNConv | [1744, 64], [2, 4430] | [1744, 64]     | 4,160    |
| ├─(module_3)ReLU    | [1744, 64]            | [1744, 64]     | --       |
| ├─(module_4)Linear  | [1744, 64]            | [1744, 1]      | 65       |
+---------------------+-----------------------+----------------+----------+


## Defining losses

To train such models on our task of interface residue prediction, we need a loss function that compares the output of the model (the prediction) with the target labels and returns a loss value that the optimizer can use to update the model parameters. A typical choice for binary classification tasks is the binary cross-entropy loss, which is implemented in PyTorch as `torch.nn.BCEWithLogitsLoss`. This loss function takes the raw output of the model (the logits) and the target labels, applies the sigmoid function to the logits to get the predicted probabilities, and then computes the binary cross-entropy between the predicted probabilities and the target labels, defined as

$$
\text{loss} = -\frac{1}{N} \sum_{i=1}^N \left[ y_i \log(p_i) + (1 - y_i) \log(1 - p_i) \right]
$$

where $N$ is the number of residues, $y_i$ is the target label for residue $i$ (1 if it's an interface residue, 0 if not), and $p_i$ is the predicted probability for residue $i$.

In [15]:
nn.BCEWithLogitsLoss()(model(example_train_batch.amino_acid_one_hot.float(), example_train_batch.edge_index), example_train_batch.y.view(-1, 1))

tensor(0.6806, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
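
The formula above can be checked by hand with a few lines of plain Python. This is a direct transcription for illustration only, with hypothetical logits and labels; as far as we know, PyTorch computes the same quantity with a more numerically stable formulation.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy, applying the sigmoid to raw logits first."""
    total = 0.0
    for z, y in zip(logits, targets):
        p = sigmoid(z)  # predicted probability p_i
        total += y * math.log(p) + (1 - y) * math.log(1 - p)
    return -total / len(logits)


targets = [1, 0, 0, 1]  # hypothetical interface labels

# Logits of 0 mean "no idea": every residue gets probability 0.5.
uninformed_loss = bce_with_logits([0.0, 0.0, 0.0, 0.0], targets)
# Confident and correct logits drive the loss towards 0.
confident_loss = bce_with_logits([4.0, -4.0, -4.0, 4.0], targets)

print(uninformed_loss)
print(confident_loss)
print(math.log(2))
```

A model that predicts 0.5 everywhere has a loss of exactly $\ln 2 \approx 0.693$, which is a useful reference point: the value printed above for our untrained model (0.6806) is close to it.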

## Training a model

Given a model that predicts interface logits and a loss that compares them with the true interface labels, what we now need is a training loop that iterates over the training data in batches, computes the loss, and uses an optimizer to update the model parameters based on the loss value.

All of this is encapsulated within the `LightningModule` class in PyTorch Lightning.

![](https://lightningaidev.wpengine.com/wp-content/uploads/2023/10/pl-walk-lit-module.png)
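
Before handing the loop over to Lightning, it can help to see the bare steps it automates. The sketch below trains a one-parameter logistic model in plain Python on hypothetical toy data. It uses plain gradient descent instead of the Adam optimizer used in this notebook, and all names and values are assumptions made for illustration.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical toy data: (feature, label) pairs, split into two batches.
batches = [
    [(-2.0, 0), (1.0, 1)],
    [(-1.0, 0), (2.0, 1)],
]


def mean_loss(weight, bias):
    """Binary cross-entropy averaged over every example in every batch."""
    examples = [pair for batch in batches for pair in batch]
    total = 0.0
    for x, y in examples:
        p = sigmoid(weight * x + bias)
        total += y * math.log(p) + (1 - y) * math.log(1 - p)
    return -total / len(examples)


weight, bias = 0.0, 0.0  # model parameters
learning_rate = 0.5
initial_loss = mean_loss(weight, bias)

for epoch in range(5):            # plays the role of max_epochs
    for batch in batches:         # plays the role of the train dataloader
        grad_w = grad_b = 0.0
        for x, y in batch:
            p = sigmoid(weight * x + bias)       # forward pass
            grad_w += (p - y) * x / len(batch)   # d(loss)/d(weight)
            grad_b += (p - y) / len(batch)       # d(loss)/d(bias)
        weight -= learning_rate * grad_w         # optimizer step
        bias -= learning_rate * grad_b

final_loss = mean_loss(weight, bias)
print(initial_loss, final_loss)
```

The forward pass, loss, gradient computation and parameter update inside the inner loop are what a `LightningModule` and `Trainer` take care of, with the gradients computed automatically instead of by hand.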

In [19]:
import lightning
import torch
from torch import nn

class GATModule(lightning.LightningModule):
    """
    LightningModule wrapping a GAT model.
    """
    def __init__(self, in_channels=20, hidden_channels=32, num_layers=2, heads=2, out_channels=1, dropout=0.01, jk="last"):
        super().__init__()
        self.model = graph_nn.GAT(in_channels=in_channels,
                         hidden_channels=hidden_channels,
                         num_layers=num_layers,
                         heads=heads,
                         out_channels=out_channels,
                         dropout=dropout,
                         jk=jk, v2=True)
        self.loss_function = nn.BCEWithLogitsLoss()

    def forward(self, node_attributes, edge_index):
        return self.model(node_attributes, edge_index)

    def training_step(self, batch, batch_idx):
        out = self(batch.amino_acid_one_hot.float(), batch.edge_index)
        loss = self.loss_function(out, batch.y.view(-1, 1))
        self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True,
                 batch_size=batch.batch_size)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(params=self.model.parameters(), lr=0.001, weight_decay=0.0001)

The `Trainer` class then combines the training loop defined in the LightningModule with the data loading functions in the LightningDataModule. We set the `max_epochs` to 5, meaning that the training loop will iterate over the entire training data 5 times, updating the model parameters with each batch.

In [20]:
model = GATModule()
datamodule = dataloader.ProteinGraphDataModule("./test_data", "dataset.txt")
trainer = lightning.Trainer(enable_progress_bar=True, max_epochs=5, accelerator="cpu")
trainer.fit(model=model, datamodule=datamodule)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Processing...
Done!
Processing...
Done!

  | Name          | Type              | Params
----------------------------------------------------
0 | model         | GAT               | 3.6 K 
1 | loss_function | BCEWithLogitsLoss | 0     
----------------------------------------------------
3.6 K     Trainable params
0         Non-trainable params
3.6 K     Total params
0.014     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


**Congratulations, your model is training!**

The next step is to monitor the performance of the model on the validation data (and maybe even stop training when the performance stops improving). We'd probably like to see some metrics beyond the loss value (like accuracy, precision, recall) and how those change over time, both on the training data and the validation data to make sure the model is learning something useful and not overfitting. All of this needs more complex logging and monitoring than what we've done so far, covered in the next notebook.

## Bonus:
- How would you change things for protein-protein input?