With proteins loaded as graphs, the next step is to build deep learning models that learn on these graphs. The `torch-geometric` library implements a number of full-fledged models as well as building-block layers for complex architectures:

[torch_geometric.nn](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html) has a variety of graph layers that can be used to build custom GNN architectures. These include:
- Convolution layers: These define how the message passing step is accomplished across edges in the graph. `GCNConv` is a simple example of a graph convolution layer, while `GATConv` is a more complex example with attention mechanisms.
- Aggregation Operators: These define how messages are aggregated at each node. 
- Pooling layers: These define how nodes are combined, either coarsening the graph into fewer nodes or reducing all nodes of a graph to a single graph-level representation.
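
As a rough illustration of these three roles, the sketch below runs one round of mean-aggregation message passing followed by a graph-level mean pooling on a toy 3-node path graph, using plain PyTorch. The graph, feature values, and variable names are made up for illustration; real layers such as `GCNConv` additionally apply learned weights and normalization.

```python
import torch

# Toy path graph 0 - 1 - 2, stored as directed edges in both directions.
# Row 0 holds the source nodes, row 1 holds the target nodes.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
x = torch.tensor([[1.0], [2.0], [4.0]])  # one feature per node

src, dst = edge_index

# Message passing: every edge carries the source node's features to the target.
messages = x[src]

# Aggregation: average the incoming messages at each target node.
summed = torch.zeros_like(x).index_add_(0, dst, messages)
degree = torch.zeros(x.size(0)).index_add_(0, dst, torch.ones(dst.size(0)))
aggregated = summed / degree.unsqueeze(-1)

# Pooling (readout): collapse all node features into one graph-level vector.
graph_embedding = aggregated.mean(dim=0)

print(aggregated.squeeze(-1).tolist())  # [2.0, 2.5, 2.0]
print(graph_embedding.shape)            # torch.Size([1])
```

Node 1 has two neighbors (features 1 and 4), so its aggregated value is their mean, 2.5; the two end nodes each simply receive the feature of node 1.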

![](https://www.aritrasen.com/wp-content/uploads/2022/11/msg_1.jpg)
![](https://www.researchgate.net/profile/Lavender-Jiang-2/publication/343441194/figure/fig4/AS:921001206509568@1596595207558/Graph-pooling-and-graph-aggregation-Graph-pooling-left-accepts-a-graph-signal-and.ppm)


[torch_geometric.nn.models](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#models) has more complex model architectures with a variety of these layers already defined and combined inside.

The [PyGModelHubMixin](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.model_hub.PyGModelHubMixin) class can be used to load pre-trained models or other model architectures from the [HuggingFace Model Hub](https://huggingface.co/models?pipeline_tag=graph-ml&sort=trending).

In [5]:
from torch_geometric import nn as graph_nn
from torch import nn
from src import dataloader

We can load our datamodule and get a train batch to test out these layers and models:

In [6]:
datamodule = dataloader.ProteinGraphDataModule("./test_data", "dataset.txt")
datamodule.prepare_data()
datamodule.setup("fit")

train_loader = datamodule.train_dataloader()
example_train_batch = next(iter(train_loader))
example_train_protein = datamodule.train[0]

Processing...
Done!
Processing...
Done!


A batch combines the graphs of the individual proteins into one larger batch graph, with an additional `batch` attribute that specifies which protein each node belongs to. Since there are no edges between different proteins, no information flows between them, so training on this batch graph is equivalent to training on the individual graphs separately.
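
To make the batching concrete, here is a minimal sketch in plain PyTorch of how two small graphs could be merged into one disconnected batch graph. The toy graphs and variable names are invented for illustration; PyTorch Geometric builds the batch (including all other attributes) automatically inside the data loader.

```python
import torch

# Two toy graphs: graph A has 3 nodes, graph B has 2 nodes.
edge_index_a = torch.tensor([[0, 1], [1, 2]])  # edges 0->1 and 1->2
edge_index_b = torch.tensor([[0], [1]])        # edge 0->1
num_nodes = [3, 2]

# Shift graph B's node indices by the number of nodes in graph A so that
# node ids stay unique in the combined graph.
edge_index = torch.cat([edge_index_a, edge_index_b + num_nodes[0]], dim=1)

# `batch` maps every node to the graph it came from.
batch = torch.repeat_interleave(torch.arange(len(num_nodes)), torch.tensor(num_nodes))

# `ptr` marks where each graph's nodes start and end in the combined node list.
ptr = torch.cat([torch.zeros(1, dtype=torch.long), torch.tensor(num_nodes).cumsum(0)])

print(edge_index.tolist())  # [[0, 1, 3], [1, 2, 4]]
print(batch.tolist())       # [0, 0, 0, 1, 1]
print(ptr.tolist())         # [0, 3, 5]
```

No edge connects nodes 0 to 2 with nodes 3 to 4, so message passing never mixes the two graphs. `ptr` always has one more entry than there are graphs in the batch.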

In [9]:
example_train_protein

Data(edge_index=[2, 574], node_id=[231], chain_id=[231], residue_number=[231], coords=[231, 3], amino_acid_one_hot=[231, 20], meiler=[231, 7], kind=[287], num_nodes=231, x=[231, 27], pos=[231, 3], y=[231])

In [8]:
example_train_batch

DataBatch(edge_index=[2, 3950], node_id=[8], chain_id=[8], residue_number=[1583], coords=[1583, 3], amino_acid_one_hot=[1583, 20], meiler=[1583, 7], kind=[8], num_nodes=1583, x=[1583, 27], pos=[1583, 3], y=[1583], batch=[1583], ptr=[9])

Here's how we would define a graph convolutional layer that takes amino acid one-hot encodings as input node features, along with the edge index that defines the graph, and converts them to a 64-dimensional embedding via convolution operations across the graph:

In [7]:
layer = graph_nn.GCNConv(in_channels=20, out_channels=64)
example_output = layer(example_train_batch.amino_acid_one_hot.float(), example_train_batch.edge_index)
example_output.shape

torch.Size([1583, 64])

Try out some of the other layers in the `torch_geometric.nn` module.

We can also try out some of the pre-defined models in the `torch_geometric.nn.models` module, such as the `GAT` model. With the settings below, it applies a series of `GATv2Conv` layers, which use attention mechanisms to weight the importance of different nodes when aggregating information from neighbors, followed by a linear layer that converts the node embeddings to a 64-dimensional output.

In [19]:
model = graph_nn.GAT(in_channels=20,
               hidden_channels=32,
               num_layers=3,
               heads=2,
               out_channels=64,
               dropout=0.01,
               jk="last", 
               v2=True)
print(graph_nn.summary(model, example_train_batch.amino_acid_one_hot.float(), example_train_batch.edge_index))

+---------------------+-----------------------+----------------+----------+
| Layer               | Input Shape           | Output Shape   | #Param   |
|---------------------+-----------------------+----------------+----------|
| GAT                 | [1583, 20], [2, 3950] | [1583, 64]     | 7,872    |
| ├─(dropout)Dropout  | [1583, 32]            | [1583, 32]     | --       |
| ├─(act)ReLU         | [1583, 32]            | [1583, 32]     | --       |
| ├─(convs)ModuleList | --                    | --             | 5,760    |
| │    └─(0)GATv2Conv | [1583, 20], [2, 3950] | [1583, 32]     | 1,408    |
| │    └─(1)GATv2Conv | [1583, 32], [2, 3950] | [1583, 32]     | 2,176    |
| │    └─(2)GATv2Conv | [1583, 32], [2, 3950] | [1583, 32]     | 2,176    |
| ├─(norms)ModuleList | --                    | --             | --       |
| │    └─(0)Identity  | [1583, 32]            | [1583, 32]     | --       |
| │    └─(1)Identity  | [1583, 32]            | [1583, 32]     | --       |
| │    └─(2)

Or the `GraphUNet` model, which implements a U-Net architecture for graphs: a series of graph convolutional layers followed by pooling layers downsamples the graph, and then a series of graph convolutional layers followed by unpooling layers upsamples it back to its original size.

In [17]:
model = graph_nn.GraphUNet(in_channels=20,
               hidden_channels=32,
               out_channels=64,
               depth=2,
               pool_ratios=[0.5, 0.25],
               )
print(graph_nn.summary(model, example_train_batch.amino_acid_one_hot.float(), example_train_batch.edge_index))


+----------------------------------+---------------------------------------+---------------------------------------------------+----------+
| Layer                            | Input Shape                           | Output Shape                                      | #Param   |
|----------------------------------+---------------------------------------+---------------------------------------------------+----------|
| GraphUNet                        | [1583, 20], [2, 3950]                 | [1583, 64]                                        | 6,016    |
| ├─(act)ReLU                      | [1583, 32]                            | [1583, 32]                                        | --       |
| ├─(down_convs)ModuleList         | --                                    | --                                                | 2,784    |
| │    └─(0)GCNConv                | [1583, 20], [2, 3950], [3950]         | [1583, 32]                                        | 672      |
| │    └─(1)GCNConv 

We can combine layers into custom architectures. Here is an example of a simple architecture that uses a `GATConv` layer and a `GCNConv` layer with activation functions in between, and finally a linear layer to convert the 64-dimensional node embeddings to one value per node.

In [33]:
model = graph_nn.Sequential('x, edge_index', [
    (graph_nn.GATConv(in_channels=20, out_channels=64, heads=2, concat=False), 'x, edge_index -> x'),
    nn.ReLU(inplace=True),
    (graph_nn.GCNConv(in_channels=64, out_channels=64), 'x, edge_index -> x'),
    nn.ReLU(inplace=True),
    nn.Linear(64, 1),
])

print(graph_nn.summary(model, example_train_batch.amino_acid_one_hot.float(), example_train_batch.edge_index))

+---------------------+-----------------------+----------------+----------+
| Layer               | Input Shape           | Output Shape   | #Param   |
|---------------------+-----------------------+----------------+----------|
| Sequential          | [1583, 20], [2, 3950] | [1583, 1]      | 7,105    |
| ├─(module_0)GATConv | [1583, 20], [2, 3950] | [1583, 64]     | 2,880    |
| ├─(module_1)ReLU    | [1583, 64]            | [1583, 64]     | --       |
| ├─(module_2)GCNConv | [1583, 64], [2, 3950] | [1583, 64]     | 4,160    |
| ├─(module_3)ReLU    | [1583, 64]            | [1583, 64]     | --       |
| ├─(module_4)Linear  | [1583, 64]            | [1583, 1]      | 65       |
+---------------------+-----------------------+----------------+----------+
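
The parameter counts in this summary can be reproduced by hand. The sketch below assumes the usual PyTorch Geometric parametrization: for `GATConv` with `concat=False`, a bias-free linear projection shared by source and target nodes, one source and one destination attention vector per head, and a single bias on the averaged output; for `GCNConv`, a weight matrix plus a bias. Treat this breakdown as an assumption that is checked only against the totals printed above.

```python
in_dim, hidden, heads = 20, 64, 2

# GATConv(20 -> 64, heads=2, concat=False)
gat_projection = in_dim * hidden * heads  # linear projection for all heads, no bias
gat_attention = 2 * heads * hidden        # source + destination attention vectors
gat_bias = hidden                         # heads are averaged, so one bias of size 64
gat_params = gat_projection + gat_attention + gat_bias

# GCNConv(64 -> 64): weight matrix plus bias
gcn_params = hidden * hidden + hidden

# Linear(64 -> 1): weight vector plus bias
linear_params = hidden * 1 + 1

total = gat_params + gcn_params + linear_params
print(gat_params, gcn_params, linear_params, total)  # 2880 4160 65 7105
```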


To train such models for our task of interface residue prediction, we need a loss function that compares the model output (the prediction) with the target labels and returns a loss value the optimizer can use to update the model parameters. A typical choice for binary classification is the binary cross-entropy loss, implemented in PyTorch as `torch.nn.BCEWithLogitsLoss`. This loss takes the raw model output (the logits) and the target labels, applies the sigmoid function to the logits to obtain predicted probabilities, and computes the binary cross-entropy between those probabilities and the labels, defined as

$$
\text{loss} = -\frac{1}{N} \sum_{i=1}^N \left[ y_i \log(p_i) + (1 - y_i) \log(1 - p_i) \right]
$$

where $N$ is the number of residues, $y_i$ is the target label for residue $i$ (1 if it's an interface residue, 0 if not), and $p_i$ is the predicted probability for residue $i$.
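
As a sanity check, the formula can be evaluated by hand and compared against the built-in loss. The logits and labels below are made-up values for three hypothetical residues.

```python
import torch
from torch import nn

logits = torch.tensor([[2.0], [-1.0], [0.5]])  # raw model outputs, one per residue
targets = torch.tensor([[1.0], [0.0], [0.0]])  # 1 = interface residue, 0 = not

# Manual computation: sigmoid to get probabilities, then the averaged
# binary cross-entropy from the formula above.
p = torch.sigmoid(logits)
manual_loss = -(targets * torch.log(p) + (1 - targets) * torch.log(1 - p)).mean()

# Built-in loss, which fuses the sigmoid and the log for numerical stability.
builtin_loss = nn.BCEWithLogitsLoss()(logits, targets)

print(torch.allclose(manual_loss, builtin_loss, atol=1e-6))  # True
```

Note that the targets are floating-point tensors with the same shape as the logits, which is what `BCEWithLogitsLoss` expects.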

In [35]:
nn.BCEWithLogitsLoss()(model(example_train_batch.amino_acid_one_hot.float(), example_train_batch.edge_index), example_train_batch.y.view(-1, 1))

tensor(0.6386, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

Given a model that predicts interface probabilities and a loss that compares them with the true interface labels, what we now need is a training loop that iterates over the training data in batches, computes the loss, and uses an optimizer to update the model parameters based on the loss value.
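
Written out by hand in plain PyTorch, such a loop could look like the sketch below. It uses synthetic data and a single linear layer as a placeholder for a GNN, so the data, model, and hyperparameters are illustrative assumptions rather than the setup used in this notebook.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Synthetic stand-in data: 64 "residues" with 20 features and a binary label.
features = torch.randn(64, 20)
labels = (features[:, 0] > 0).float().view(-1, 1)

model = nn.Linear(20, 1)  # placeholder for a GNN
loss_function = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

with torch.no_grad():
    initial_loss = loss_function(model(features), labels).item()

for epoch in range(20):
    for start in range(0, 64, 16):          # iterate over mini-batches of 16
        x, y = features[start:start + 16], labels[start:start + 16]
        optimizer.zero_grad()               # clear gradients from the last step
        loss = loss_function(model(x), y)   # forward pass and loss
        loss.backward()                     # backpropagate
        optimizer.step()                    # update the parameters

with torch.no_grad():
    final_loss = loss_function(model(features), labels).item()

print(f"loss before: {initial_loss:.3f}, after: {final_loss:.3f}")
```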

All of this is encapsulated within the `LightningModule` class in PyTorch Lightning.

![](https://lightningaidev.wpengine.com/wp-content/uploads/2023/10/pl-walk-lit-module.png)

In [3]:
import lightning as L
import torch
from torch import nn
from torch_geometric import nn as graph_nn

class GATModule(L.LightningModule):
    """
    LightningModule wrapping a GAT model.
    """
    def __init__(self):
        super().__init__()
        self.model = graph_nn.GAT(in_channels=20,
                         hidden_channels=32,
                         num_layers=2,
                         heads=2,
                         out_channels=1,
                         dropout=0.01,
                         jk="last", v2=True)
        self.loss_function = nn.BCEWithLogitsLoss()

    def forward(self, node_attributes, edge_index):
        return self.model(node_attributes, edge_index)

    def training_step(self, batch, batch_idx):
        out = self(batch.amino_acid_one_hot.float(), batch.edge_index)
        loss = self.loss_function(out, batch.y.view(-1, 1))
        self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True,
                 batch_size=batch.batch_size)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(params=self.model.parameters(), lr=0.001, weight_decay=0.0001)

The `Trainer` class then combines the training loop defined in the LightningModule with the data loading functions in the LightningDataModule. We set the `max_epochs` to 5, meaning that the training loop will iterate over the entire training data 5 times, updating the model parameters with each batch.

In [6]:
model = GATModule()
datamodule = dataloader.ProteinGraphDataModule("./test_data", "dataset.txt")
trainer = L.Trainer(enable_progress_bar=True, max_epochs=5)
trainer.fit(model=model, datamodule=datamodule)

/scicore/home/schwede/durair0000/mambaforge/envs/leuven/lib/python3.8/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scicore/home/schwede/durair0000/mambaforge/envs/leu ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/scicore/home/schwede/durair0000/mambaforge/envs/leuven/lib/python3.8/site-packages/lightning/pytorch/trainer/configuration_validator.py:72: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
Processing...
Done!
Processing...
Done!

  | Name          | Type              | Params
----------------------------------------------------
0 | model         | GAT               | 3.6 K 
1 | loss_function | BCEWithLogitsLoss | 0     
----------------------

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


While this runs the training loop for 5 epochs, what we really want is to monitor the model's performance on the validation data (and perhaps stop training when it stops improving). We would like to track metrics such as accuracy, precision, recall, and F1 score alongside the loss, on both the training and validation data, to make sure the model is learning something useful and not overfitting. This needs more complex logging and monitoring than we have done so far, which is covered in the next notebook.
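
As a preview of what those metrics measure, here is a minimal sketch that computes them from raw logits in plain PyTorch. The logits and labels are made-up values, and the 0.5 probability threshold is an assumption; a dedicated metrics library would normally handle this, including the edge cases where a denominator is zero.

```python
import torch

logits = torch.tensor([2.0, -1.0, 0.5, -0.3, 1.2, -2.0])  # raw model outputs
targets = torch.tensor([1, 0, 0, 1, 1, 0])                # true interface labels

# A logit above 0 corresponds to a predicted probability above 0.5.
predictions = (logits > 0).long()

tp = ((predictions == 1) & (targets == 1)).sum().item()  # true positives
fp = ((predictions == 1) & (targets == 0)).sum().item()  # false positives
fn = ((predictions == 0) & (targets == 1)).sum().item()  # false negatives
tn = ((predictions == 0) & (targets == 0)).sum().item()  # true negatives

accuracy = (tp + tn) / len(targets)
precision = tp / (tp + fp)  # fraction of predicted interface residues that are correct
recall = tp / (tp + fn)     # fraction of true interface residues that are found
f1 = 2 * precision * recall / (precision + recall)

print(tp, fp, fn, tn)  # 2 1 1 2
print(round(accuracy, 3), round(precision, 3), round(recall, 3), round(f1, 3))
```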

## Bonus:
- How would you change things for protein-protein input?