# Training and evaluating a model

[<img src="https://colab.research.google.com/assets/colab-badge.svg">](https://colab.research.google.com/github/BorgwardtLab/proteinshake/blob/main/docs/readthedocs/source/notebooks/task.ipynb)

This tutorial demonstrates how to train a graph neural network (GNN) with ProteinShake and [Pytorch Geometric](https://pytorch-geometric.readthedocs.io/) for enzyme commission number prediction. You can adapt the code for any downstream tasks provided by ProteinShake.

We will use a simple GNN model, namely [GCN](https://arxiv.org/abs/1609.02907). The model can be trained with either CPU or GPU, but GPU is recommended for faster computation.

If you are using colab, then uncomment and run the cell below. 

In [1]:
# !pip install proteinshake torch torch_geometric tqdm

from proteinshake.tasks import EnzymeClassTask

import torch
from torch.nn import Module, Embedding, BatchNorm1d, Linear, CrossEntropyLoss
from torch.optim import AdamW
from torch_geometric.loader import DataLoader
from torch_geometric.nn.conv import GCNConv
from torch_geometric.nn import global_mean_pool
from tqdm import tqdm
import copy

## Load the task and the dataset

We use the `EnzymeCommissionTask` from ProteinShake and convert the protein structures to epsilon-graphs (epsilon=8). The graphs are then converted to PyTorch Geometric `Data` objects. To assign the appropriate attributes (the class label) to the `Data` object, we write a transform:

In [2]:
def my_transform(data):
    data, protein_dict = data
    data.y = task.target(protein_dict)
    return data

task = EnzymeClassTask(root='ec_prediction').to_graph(eps=8.0).pyg(transform=my_transform)

That's all you need to load your data! The rest is implementing a model and training.

## Load train/val/test splits

We can now create data loaders for train/val/test sets provided by ProteinShake:

In [3]:
train_loader = DataLoader(task.train, batch_size=100, shuffle=True)
val_loader = DataLoader(task.val, batch_size=100, shuffle=True)
test_loader = DataLoader(task.test, batch_size=100, shuffle=True)

## Build a GNN model

Here, we create a simple GNN model for this task, namely a [GCN](https://arxiv.org/abs/1609.02907).

In [4]:
class Model(Module):

    def __init__(self, num_classes, dim=256):
        super().__init__()
        self.embedding = Embedding(20, dim)
        self.gcn1 = GCNConv(dim,dim)
        self.bn1 = BatchNorm1d(dim)
        self.gcn2 = GCNConv(dim,dim)
        self.bn2 = BatchNorm1d(dim)
        self.pool = global_mean_pool
        self.classifier = Linear(dim, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.embedding(x)
        x = self.gcn1(x, edge_index)
        x = self.bn1(x)
        x = self.gcn2(x, edge_index)
        x = self.bn2(x)
        x = self.pool(x, batch)
        return self.classifier(x)

We build the model and define some learning parameters:

In [5]:
model = Model(task.num_classes)
optimizer = AdamW(model.parameters(), lr=0.001)
criterion = CrossEntropyLoss()
device = torch.device(torch.cuda.current_device()) if torch.cuda.is_available() else torch.device('cpu')

Now let's define a simple training loop:

In [6]:
def train_epoch(model, loader):
    model.train()

    running_loss = 0.
    for step, batch in enumerate(loader):
        size = len(batch.y)
        batch = batch.to(device)

        optimizer.zero_grad()
        y_hat = model(batch)

        loss = criterion(y_hat, batch.y)

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * size

    n_sample = len(loader.dataset)
    epoch_loss = running_loss / n_sample
    return epoch_loss

ProteinShake provides an evaluation function for each task. We will use it to write an evaluation loop:

In [7]:
@torch.no_grad()
def eval_epoch(model, loader, targets):
    model.eval()
    y_pred = []

    for step, batch in enumerate(loader):
        batch = batch.to(device)
        y_hat = model(batch)
        y_pred.append(y_hat.cpu())

    y_pred = torch.vstack(y_pred).numpy()
    y_pred = y_pred.argmax(-1)
    scores = task.evaluate(targets, y_pred) # ProteinShake provides you with appropriate metrics for the task
    return scores

## Training

Time for the actual model training:

In [10]:
model.to(device)
epochs = 1 # we train only 1 epoch here, but more epochs may result in better performance.

best_val_score = 0.0
pbar = tqdm(range(epochs))
for epoch in pbar:
    train_loss = train_epoch(model, train_loader)
    val_scores = eval_epoch(model, val_loader, task.val_targets)
    val_score = val_scores['accuracy']
    postfix = {'train_loss': train_loss, 'val_acc': val_score}
    pbar.set_postfix(postfix)
    
    if val_score > best_val_score:
        best_val_score = val_score
        best_weights = copy.deepcopy(model.state_dict())

model.load_state_dict(best_weights)

100%|██████████| 1/1 [00:49<00:00, 49.96s/it, train_loss=1.37, val_acc=0.299]


<All keys matched successfully>

## Testing the trained model

How good is the model? Certainly this is not a state-of-the-art performance, but you get the idea.

In [11]:
test_scores = eval_epoch(model, test_loader, task.test_targets)
print(test_scores)

{'precision': 0.11018509385404471, 'recall': 0.1342314815316218, 'accuracy': 0.3175416133162612}
