# Supervised enzyme commission prediction with GNNs

This tutorial demonstrates how to train a GNN with ProteinShake and [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/) for enzyme commission (EC) prediction. You can adapt the code to any other downstream task provided by ProteinShake.

We will use a simple GNN model, namely [GCN](https://arxiv.org/abs/1609.02907), and evaluate its performance on enzyme commission prediction. The model can be trained on either a CPU or a GPU, but a GPU is recommended for faster computation.

## Environment setup

If you are using Colab, uncomment and run the cell below.

In [1]:
# Colab-only setup: uncomment these lines to install ProteinShake and PyG.
# NOTE(review): the wheel index below is pinned to torch 1.13.0 + CUDA 11.6;
# adjust the URL to match the installed torch/CUDA versions.
# !pip install git+https://github.com/BorgwardtLab/proteinshake.git
# !pip install pyg-lib torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.13.0+cu116.html
# !pip install torch-geometric

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/BorgwardtLab/proteinshake.git
  Cloning https://github.com/BorgwardtLab/proteinshake.git to /tmp/pip-req-build-g2fvw02x
  Running command git clone --filter=blob:none --quiet https://github.com/BorgwardtLab/proteinshake.git /tmp/pip-req-build-g2fvw02x
  Resolved https://github.com/BorgwardtLab/proteinshake.git to commit dc8f8367c8b39c5261f0cfdeb2dc6b4d570ecb3d
  Preparing metadata (setup.py) ... [?25l[?25hdone
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.13.0+cu116.html
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import copy
from tqdm import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from proteinshake import tasks as ps_tasks

## Load the task and the dataset

In [3]:
# Directory under which ProteinShake stores the task's files.
datapath = './data/ec'
# The task object bundles the dataset, the split indices (train/val/test_index),
# target extraction (task.target) and the metric (task.evaluate) used below.
task = ps_tasks.EnzymeCommissionTask(root=datapath)
dset = task.dataset

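We convert the 3D protein structures to $\epsilon$-neighborhood graphs, in which two residues are connected by an edge whenever their distance is below $\epsilon$ (here $\epsilon=8$ Å). A `transform` attaches each protein's target label to its graph: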

In [4]:
def transform(data):
    """Attach the task's target label to a converted graph.

    ``data`` is a ``(graph, protein_dict)`` pair. The label returned by
    ``task.target(protein_dict)`` is stored on ``graph.y``, and the graph
    itself is returned.
    """
    graph, protein_dict = data
    graph.y = task.target(protein_dict)
    return graph
    
# Build epsilon-graphs (eps=8.0) from the structures and convert them to
# PyTorch Geometric objects; `transform` attaches the labels as `y`.
dset = dset.to_graph(eps=8.0).pyg(
    transform=transform
)

## Load train/val/test splits

We can now create data loaders for the train/validation/test splits provided by ProteinShake:

In [5]:
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader

In [6]:
batch_size = 100


def _split_loader(index, shuffle):
    """Wrap the subset of ``dset`` selected by ``index`` in a PyG DataLoader."""
    return DataLoader(Subset(dset, index), batch_size=batch_size,
                      shuffle=shuffle, num_workers=0)


# Only the training split is shuffled; validation/test keep a fixed order.
train_loader = _split_loader(task.train_index, shuffle=True)
val_loader = _split_loader(task.val_index, shuffle=False)
test_loader = _split_loader(task.test_index, shuffle=False)

## Build GNN models

Here, we build a simple GNN model for this task, namely [GCN](https://arxiv.org/abs/1609.02907).

In [7]:
import torch_geometric.nn as gnn
from torch_geometric import utils

In [8]:
class GCNConv(gnn.MessagePassing):
    """GCN-style graph convolution with a self-loop term and optional edge features.

    Each node receives the sum of its neighbours' messages
    ``norm_ij * ReLU(x_j [+ edge_attr_ij])``, where ``norm_ij = 1/sqrt(d_i d_j)``
    and ``d`` is the degree plus one (for the implicit self-loop). It also
    receives a self term ``ReLU(x_i + root_emb) / d_i``.

    Args:
        embed_dim: Size of the input and output node features.
        use_edge_attr: If True and ``edge_attr`` is given, edge attributes are
            projected by ``edge_encoder`` and added to each neighbour message.
            Otherwise edge attributes are ignored.
    """

    def __init__(self, embed_dim=256, use_edge_attr=False):
        super().__init__(aggr='add')
        self.use_edge_attr = use_edge_attr

        self.linear = nn.Linear(embed_dim, embed_dim)
        # Learned vector added to a node's own features in the self-loop term.
        self.root_emb = nn.Embedding(1, embed_dim)
        # NOTE(review): assumes incoming edge_attr has `embed_dim` features — confirm.
        self.edge_encoder = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, edge_index, edge_attr=None):
        x = self.linear(x)
        if self.use_edge_attr and edge_attr is not None:
            edge_attr = self.edge_encoder(edge_attr)
        else:
            # Never let raw (unencoded) edge attributes leak into the messages.
            edge_attr = None

        row, col = edge_index

        # +1 accounts for the implicit self-loop, so deg >= 1.
        deg = utils.degree(row, x.size(0), dtype=x.dtype) + 1
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        # Symmetric normalisation 1/sqrt(d_row * d_col) per edge.
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        neighbour_term = self.propagate(
            edge_index, x=x, edge_attr=edge_attr, norm=norm)
        self_term = F.relu(x + self.root_emb.weight) / deg.view(-1, 1)
        return neighbour_term + self_term

    def message(self, x_j, edge_attr, norm):
        if edge_attr is not None:
            x_j = x_j + edge_attr
        return norm.view(-1, 1) * F.relu(x_j)

In [9]:
class GNN(nn.Module):
    """Node encoder: a categorical embedding followed by stacked GCNConv layers.

    Every layer applies GCNConv, then BatchNorm1d, then dropout. All layers
    except the last also apply a ReLU before the dropout.

    Args:
        embed_dim: Width of the node embeddings throughout the stack.
        num_layers: Number of GCNConv + BatchNorm1d layers.
        dropout: Dropout probability applied after every layer.
        use_edge_attr: Forwarded to each GCNConv.
    """

    def __init__(self, embed_dim=256, num_layers=3, dropout=0.0,
                 use_edge_attr=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.dropout = dropout

        # 20 node categories (presumably the 20 standard residue types — confirm);
        # assumes data.x holds integer indices in [0, 20).
        self.x_embedding = nn.Embedding(20, embed_dim)

        self.gnns = nn.ModuleList(
            [GCNConv(embed_dim, use_edge_attr=use_edge_attr)
             for _ in range(num_layers)]
        )
        self.batch_norms = nn.ModuleList(
            [nn.BatchNorm1d(embed_dim) for _ in range(num_layers)]
        )

    def forward(self, data):
        edge_index, edge_attr = data.edge_index, data.edge_attr
        h = self.x_embedding(data.x)

        last = self.num_layers - 1
        for i, (conv, norm) in enumerate(zip(self.gnns, self.batch_norms)):
            h = norm(conv(h, edge_index, edge_attr))
            # The final layer stays linear (no ReLU) before pooling/classification.
            if i < last:
                h = F.relu(h)
            h = F.dropout(h, self.dropout, training=self.training)

        return h

In [10]:
class GNN_graphpred(nn.Module):
    """Graph-level classifier: GNN node encoder, global pooling, linear head.

    Args:
        num_class: Number of output classes (size of the returned logits).
        embed_dim: Node embedding width used by the encoder and the head.
        num_layers: Number of GNN layers in the encoder.
        dropout: Dropout probability inside the encoder.
        use_edge_attr: Forwarded to the encoder's convolution layers.
        global_pool: One of ``'mean'``, ``'add'``, ``'max'`` or ``None``.
            With ``None``, the classifier is applied to every node embedding.

    Raises:
        ValueError: If ``global_pool`` is not one of the supported values.
    """

    def __init__(self, num_class, embed_dim=64, num_layers=3, dropout=0.0,
                 use_edge_attr=False, global_pool='mean'):
        super().__init__()

        self.encoder = GNN(embed_dim, num_layers, dropout, use_edge_attr)

        self.global_pool = global_pool
        if global_pool == 'mean':
            self.pooling = gnn.global_mean_pool
        elif global_pool == 'add':
            self.pooling = gnn.global_add_pool
        elif global_pool == 'max':
            self.pooling = gnn.global_max_pool
        elif global_pool is None:
            self.pooling = None
        else:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError(
                f"Unknown global_pool {global_pool!r}; "
                "expected 'mean', 'add', 'max' or None"
            )

        self.classifier = nn.Linear(embed_dim, num_class)

    def forward(self, data, other_data=None):
        # `other_data` is unused; it is kept for call-signature compatibility.
        output = self.encoder(data)
        if self.pooling is not None:
            # Pool node embeddings into one vector per graph in the batch.
            output = self.pooling(output, data.batch)
        return self.classifier(output)

We build a GCN model with 5 layers and a hidden dimension of 64:

In [11]:
# Model hyper-parameters.
embed_dim = 64
num_layers = 5

# One output logit per class of the task.
model = GNN_graphpred(
    num_class=task.num_classes,
    embed_dim=embed_dim,
    num_layers=num_layers,
)

## Build an optimizer and define the train and test functions

In [12]:
# Multi-class classification loss applied to the classifier's raw logits.
criterion = nn.CrossEntropyLoss()

lr = 0.001
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)

In [13]:
# set device: the current CUDA device when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda', torch.cuda.current_device())
else:
    device = torch.device('cpu')

In [14]:
def train_epoch(model):
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``optimizer``, ``criterion`` and ``device``.

    Returns:
        The training loss averaged per sample over the whole training set.
    """
    model.train()

    total_loss = 0.
    for batch in train_loader:
        # Count samples before moving the batch; used to weight the batch loss.
        n_in_batch = len(batch.y)
        batch = batch.to(device)

        optimizer.zero_grad()
        loss = criterion(model(batch), batch.y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * n_in_batch

    return total_loss / len(train_loader.dataset)

ProteinShake provides an evaluation function for each task, `task.evaluate(y_true, y_pred)`, which we use below to score the model's predictions.

In [15]:
@torch.no_grad()
def eval_epoch(model, loader):
    """Predict on every batch of ``loader`` and score with ``task.evaluate``.

    The predicted class is the argmax over the model's last output dimension.

    Returns:
        Whatever ``task.evaluate(y_true, y_pred)`` returns; the notebook reads
        its ``'accuracy'`` entry.
    """
    model.eval()

    targets, logits = [], []
    for batch in loader:
        batch = batch.to(device)
        out = model(batch)
        targets.append(batch.y.cpu())
        logits.append(out.cpu())

    y_true = torch.cat(targets, dim=0).numpy()
    y_pred = torch.vstack(logits).numpy().argmax(-1)
    return task.evaluate(y_true, y_pred)

## Training

In [16]:
# Move the model's parameters and buffers to `device`; as the cell's last
# expression, the returned module's repr is displayed.
model.to(device)

GNN_graphpred(
  (encoder): GNN(
    (x_embedding): Embedding(20, 64)
    (gnns): ModuleList(
      (0): GCNConv()
      (1): GCNConv()
      (2): GCNConv()
      (3): GCNConv()
      (4): GCNConv()
    )
    (batch_norms): ModuleList(
      (0): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (classifier): Linear(in_features=64, out_features=7, bias=True)
)

In [17]:
epochs = 20 # we train only 20 epochs here, but more epochs may result in better performance.

# Track the weights with the best validation accuracy. `best_weights` is
# seeded with the initial weights so it is always defined, even when
# epochs == 0 or no epoch achieves a positive validation accuracy.
best_val_score = float('-inf')
best_weights = copy.deepcopy(model.state_dict())
pbar = tqdm(range(epochs))
for epoch in pbar:
    train_loss = train_epoch(model)
    val_scores = eval_epoch(model, val_loader)
    val_score = val_scores['accuracy']
    postfix = {'train_loss': train_loss, 'val_acc': val_score}
    pbar.set_postfix(postfix)

    if val_score > best_val_score:
        best_val_score = val_score
        best_weights = copy.deepcopy(model.state_dict())

# Restore the best checkpoint before testing.
model.load_state_dict(best_weights)

100%|██████████| 20/20 [05:02<00:00, 15.13s/it, train_loss=0.53, val_acc=0.655]


<All keys matched successfully>

## Testing the trained model

In [18]:
# Score the restored best-validation model on the held-out test split.
test_scores = eval_epoch(model, test_loader)
print(test_scores)

{'precision': 0.5333515066547034, 'recall': 0.4799021029676011, 'accuracy': 0.6675514266755143}
