# Antibody Developability Prediction with E(N) GNNs

Here we showcase Graphein's ability to generate a structural dataset from only PDB identifiers and labels. We use an [antibody developability dataset](https://tdcommons.ai/single_pred_tasks/develop/) from [TDC](https://tdcommons.ai/) and train an [E(N) GNN](https://arxiv.org/abs/2102.09844) on antibody structure graphs.

**Dataset Description**: Antibody data from Chen et al. [1], processed from SAbDab [2]. From an initial dataset of 3816 antibodies, the authors retained 2426 antibodies that satisfy the following criteria:

1. have both sequence (FASTA) and Protein Data Bank (PDB) structure files,
2. contain both a heavy chain and a light chain, and
3. have crystal structures with resolution < 3 Å.

The developability index (DI) label is derived from BIOVIA's pipelines [3].
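
The three retention criteria can be sketched as a simple filter. The record fields below (`has_fasta`, `has_pdb`, `chains`, `resolution`) and the example entries are hypothetical; they are not the schema used by Chen et al. or SAbDab.

```python
from typing import Dict, List

# Hypothetical records; field names and values are illustrative only.
records: List[Dict] = [
    {"id": "ab1", "has_fasta": True, "has_pdb": True, "chains": {"H", "L"}, "resolution": 2.1},
    {"id": "ab2", "has_fasta": True, "has_pdb": True, "chains": {"H"}, "resolution": 1.8},
    {"id": "ab3", "has_fasta": True, "has_pdb": True, "chains": {"H", "L"}, "resolution": 3.4},
    {"id": "ab4", "has_fasta": False, "has_pdb": True, "chains": {"H", "L"}, "resolution": 2.5},
]


def keep(record: Dict, max_resolution: float = 3.0) -> bool:
    """Apply the three criteria: both files, both chains, resolution below the cutoff."""
    has_files = record["has_fasta"] and record["has_pdb"]
    has_both_chains = {"H", "L"} <= record["chains"]
    good_resolution = record["resolution"] < max_resolution
    return has_files and has_both_chains and good_resolution


retained = [r["id"] for r in records if keep(r)]
print(retained)
```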

**Task Description**: Binary classification. Given the antibody's heavy chain and light chain sequences, predict its developability. The input X is a list of two sequences, where the first is the heavy chain and the second is the light chain.

**Dataset Statistics**: 2,409 antibodies.

**Dataset Split**: Random Split

**References**:

[1] Chen, Xingyao, et al. “Predicting antibody developability from sequence using machine learning.” bioRxiv (2020).

[2] Dunbar, James, et al. “SAbDab: the structural antibody database.” Nucleic acids research 42.D1 (2014): D1140-D1146.

[3] Biovia, Dassault Systèmes. “BIOVIA pipeline pilot.” Dassault Systèmes: San Diego, BW, Release (2017).

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/a-r-j/graphein/blob/master/notebooks/tdc_developability.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/a-r-j/graphein/blob/master/notebooks/tdc_developability.ipynb)

In [1]:
# Install requirements if necessary
# !pip install graphein
# !pip install PyTDC

In [2]:
#NBVAL_SKIP
import torch
from typing import Dict

from tdc.single_pred import Develop

import graphein.protein as gp
from graphein.ml.conversion import GraphFormatConvertor
from graphein.ml import InMemoryProteinGraphDataset, ProteinGraphDataset

To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d 
To do so, use the following command: conda install -c pytorch3d pytorch3d


## Loading Data from TDC
As this dataset is non-redundant, we can use a random split.

In [58]:
#NBVAL_SKIP
# Load data from TDC and split
data = Develop(name = 'SAbDab_Chen')
split = data.get_split()
split["train"].head()

Found local copy...
Loading...
Done!


|   | Antibody_ID | Antibody | Y |
|---|-------------|----------|---|
| 0 | 12e8 | ['EVQLQQSGAEVVRSGASVKLSCTASGFNIKDYYIHWVKQRPEKG... | 0 |
| 1 | 15c8 | ['EVQLQQSGAELVKPGASVKLSCTASGFNIKDTYMHWVKQKPEQG... | 0 |
| 2 | 1a0q | ['EVQLQESDAELVKPGASVKISCKASGYTFTDHVIHWVKQKPEQG... | 1 |
| 3 | 1a14 | ['QVQLQQSGAELVKPGASVRMSCKASGYTFTNYNMYWVKQSPGQG... | 0 |
| 4 | 1a2y | ['QVQLQESGPGLVAPSQSLSITCTVSGFSLTGYGVNWVRQPPGKG... | 0 |
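
`get_split()` is called with its defaults here. Conceptually, a random split shuffles the IDs and cuts the list at fixed fractions, roughly as in the sketch below. The 70/10/20 fractions and the seed of 42 are assumptions for illustration (check the TDC documentation for the actual defaults), and `random_split` is a hypothetical helper, not part of TDC.

```python
import random
from typing import Dict, List, Sequence


def random_split(
    ids: Sequence[str],
    fractions: Sequence[float] = (0.7, 0.1, 0.2),
    seed: int = 42,
) -> Dict[str, List[str]]:
    """Shuffle IDs with a fixed seed and cut them into train/valid/test."""
    shuffled = list(ids)
    random.Random(seed).shuffle(shuffled)
    n_train = round(fractions[0] * len(shuffled))
    n_valid = round(fractions[1] * len(shuffled))
    return {
        "train": shuffled[:n_train],
        "valid": shuffled[n_train:n_train + n_valid],
        "test": shuffled[n_train + n_valid:],  # whatever remains
    }


# Hypothetical IDs, purely for demonstration.
ids = [f"ab{i:03d}" for i in range(100)]
parts = random_split(ids)
print({name: len(members) for name, members in parts.items()})
```

Fixing the seed keeps the split reproducible across runs, which matters when comparing models.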


## Removing Obsolete Structures

Sometimes PDB entries are made obsolete as improved structures are made available. This is problematic as the old structures are not available for download. For instance, see [1OM3](https://www.rcsb.org/structure/removed/1OM3) which has been replaced by [6N32](https://www.rcsb.org/structure/6N32).

In [4]:
#NBVAL_SKIP
# Check for obsolete structures
from graphein.protein.utils import get_obsolete_mapping

obs = get_obsolete_mapping()

train_obs = [t for t in split["train"]["Antibody_ID"] if t in obs.keys()]
valid_obs = [t for t in split["valid"]["Antibody_ID"] if t in obs.keys()]
test_obs = [t for t in split["test"]["Antibody_ID"] if t in obs.keys()]

print(train_obs)
print(valid_obs)
print(test_obs)

['1om3', '1zls', '1zlu', '1zlw', '3l5y', '3qot', '3rvv', '3rvw', '3rvx', '3wxw', '4nx3', '4pp2', '4x4y', '5kmv', '5usi', '6erx']
[]
['3wxv', '1zlv', '1op3', '3qos']


In [5]:
#NBVAL_SKIP
# If you want, you can get the PDB IDs of the new structure that replaces the obsolete entry
print("Replacement PDBs: ", [obs[t] for t in train_obs])

# However, in this instance we will simply remove the obsolete entries from the train and test sets.
split["train"] = split["train"].loc[~split["train"]["Antibody_ID"].isin(train_obs)]
split["test"] = split["test"].loc[~split["test"]["Antibody_ID"].isin(test_obs)]

Replacement PDBs:  ['6n32', '6msy', '6mub', '6mnf', '4ps4', '5i18', '5vpl', '5vpg', '5vph', '6ks1', '4web', '5vco', '6dn0', '5vzx', '6b9j', '6fxn']
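
If you would rather keep these samples, the obsolete IDs can be swapped for their replacements instead of being dropped. The sketch below uses a plain dictionary holding two of the pairs printed above; `remap_obsolete` is a hypothetical helper, and whether the original label still applies to the superseding structure is an assumption you would need to check.

```python
from typing import Dict, List

# Two of the obsolete -> replacement pairs from the output above.
obsolete_to_new: Dict[str, str] = {"1om3": "6n32", "1zls": "6msy"}


def remap_obsolete(pdb_ids: List[str], mapping: Dict[str, str]) -> List[str]:
    """Replace obsolete PDB IDs with their successors; leave other IDs unchanged."""
    return [mapping.get(pdb_id, pdb_id) for pdb_id in pdb_ids]


ids = ["12e8", "1om3", "1a0q", "1zls"]
print(remap_obsolete(ids, obsolete_to_new))
```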


## Creating Labels

We convert the labels to tensors and map them to their corresponding PDB ID.

In [6]:
#NBVAL_SKIP
# Convert labels to tensors
def get_label_map(split_name: str) -> Dict[str, torch.Tensor]:
    return dict(zip(split[split_name].Antibody_ID, split[split_name].Y.apply(torch.tensor)))

train_labels = get_label_map("train")
valid_labels = get_label_map("valid")
test_labels = get_label_map("test")
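
Before training a binary classifier it is worth checking the class balance of the label map. A minimal sketch, using the five rows shown in the `head()` output above with plain integers in place of tensors:

```python
from collections import Counter
from typing import Dict

# Labels for the five rows shown in the head() output (ints instead of tensors).
label_map: Dict[str, int] = {"12e8": 0, "15c8": 0, "1a0q": 1, "1a14": 0, "1a2y": 0}

counts = Counter(label_map.values())
positive_fraction = counts[1] / len(label_map)
print(counts, positive_fraction)
```

With the real `train_labels`, calling `int(v)` on each value converts the tensor back to a Python integer before counting.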

## Creating Graphs and Dataloaders with Graphein

### Configuration and Conversion

First, we define a [configuration object](https://graphein.ai/modules/graphein.protein.html#graphein.protein.config.ProteinGraphConfig) which governs the graph construction. Here we use a simple graph for illustrative purposes, using only a one-hot encoding of the amino acid type, the node coordinates and edges based on spatial contacts within 6 angstroms.

Secondly, we define a [convertor](https://graphein.ai/modules/graphein.ml.html#conversion) which converts from NetworkX graphs to PyTorch Geometric `Data` objects.

In [7]:
#NBVAL_SKIP
from functools import partial

graphein_config = gp.ProteinGraphConfig(
    node_metadata_functions=[gp.amino_acid_one_hot],
    edge_construction_functions=[partial(gp.add_distance_threshold, threshold=6, long_interaction_threshold=0)])

convertor = GraphFormatConvertor(src_format="nx", dst_format="pyg", columns=["coords", "edge_index", "amino_acid_one_hot"])
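
To make the edge criterion concrete: a distance-threshold graph joins two residues whenever their coordinates lie within the cutoff. The pure-Python sketch below illustrates that idea on hypothetical coordinates. It is not Graphein's implementation, and `distance_threshold_edges` is a made-up name. The `min_separation` argument mimics what `long_interaction_threshold` is assumed to control (a minimum separation in sequence), with 0 meaning no restriction.

```python
import math
from typing import List, Tuple

Coord = Tuple[float, float, float]


def distance_threshold_edges(
    coords: List[Coord], threshold: float = 6.0, min_separation: int = 0
) -> List[Tuple[int, int]]:
    """Return index pairs (i, j), i < j, whose distance is within `threshold` angstroms."""
    edges = []
    for i in range(len(coords)):
        for j in range(i + 1, len(coords)):
            if j - i < min_separation:
                continue  # too close in sequence (assumed semantics)
            if math.dist(coords[i], coords[j]) <= threshold:
                edges.append((i, j))
    return edges


# Hypothetical positions spaced 3.8 angstroms apart along a line.
coords = [(3.8 * k, 0.0, 0.0) for k in range(4)]
print(distance_threshold_edges(coords))
```

On this toy chain only sequence neighbours fall within 6 Å, so raising `min_separation` to 2 would leave no edges at all.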

In [8]:
#NBVAL_SKIP
# Quickly visualise a protein to see what the config looks like in practice
gp.plotly_protein_structure_graph(gp.construct_graph(pdb_code="1lds", config=graphein_config))

### Creating Datasets

Next, we create the actual structural datasets. This takes care of downloading the raw structural data, pre-processing it, converting it to graphs and caching the dataset for easy access.

In [9]:
#NBVAL_SKIP
train_ds = InMemoryProteinGraphDataset(
    root="./data/",
    name="train",
    pdb_codes=split["train"]["Antibody_ID"],
    graph_label_map=train_labels,
    graphein_config=graphein_config,
    graph_format_convertor=convertor
    )

valid_ds = InMemoryProteinGraphDataset(
    root="./data/",
    name="valid",
    pdb_codes=split["valid"]["Antibody_ID"],
    graph_label_map=valid_labels,
    graphein_config=graphein_config,
    graph_format_convertor=convertor
    )

test_ds = InMemoryProteinGraphDataset(
    root="./data/",
    name="test",
    pdb_codes=split["test"]["Antibody_ID"],
    graph_label_map=test_labels,
    graphein_config=graphein_config,
    graph_format_convertor=convertor
    )

Next, we wrap the datasets to create dataloaders for our model.

In [10]:
#NBVAL_SKIP
from torch_geometric.loader import DataLoader

# Create dataloaders
train_loader = DataLoader(train_ds, batch_size=16, shuffle=False, drop_last=True)
valid_loader = DataLoader(valid_ds, batch_size=16, shuffle=False, drop_last=True)
test_loader = DataLoader(test_ds, batch_size=16, drop_last=True)

In [11]:
#NBVAL_SKIP
# Inspect a batch
for b in valid_loader:
    print(b)
    break

DataBatch(edge_index=[2, 40070], node_id=[16], coords=[16], amino_acid_one_hot=[16], graph_y=[16], num_nodes=11545, batch=[11545], ptr=[17])
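
The `batch` and `ptr` fields record how the 16 graphs were merged into one large disconnected graph: `batch` assigns each node to its graph, and `ptr` holds the cumulative node offsets (hence 17 entries for 16 graphs). A simplified pure-Python sketch of that bookkeeping follows, using hypothetical toy graphs; `collate_graphs` is an illustrative name, not a PyTorch Geometric function.

```python
from typing import Dict, List, Tuple

Graph = Tuple[int, List[Tuple[int, int]]]  # (num_nodes, edges)


def collate_graphs(graphs: List[Graph]) -> Dict[str, list]:
    """Merge graphs into one by offsetting node indices."""
    edge_index, batch, ptr = [], [], [0]
    for graph_id, (num_nodes, edges) in enumerate(graphs):
        offset = ptr[-1]  # index of this graph's first node in the merged graph
        edge_index.extend((src + offset, dst + offset) for src, dst in edges)
        batch.extend([graph_id] * num_nodes)
        ptr.append(offset + num_nodes)
    return {"edge_index": edge_index, "batch": batch, "ptr": ptr}


# Two toy graphs: a 3-node path and a 2-node pair.
merged = collate_graphs([(3, [(0, 1), (1, 2)]), (2, [(0, 1)])])
print(merged)
```

Note that `coords` and `amino_acid_one_hot` appear with size `[16]` in the batch above: they are kept as one entry per graph, which is why the model's forward pass concatenates them into node-level tensors.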


## Model

Here we lift the implementation from the GitHub repository accompanying the paper: https://github.com/vgsatorras/egnn

In [12]:
#NBVAL_SKIP
import pytorch_lightning as pl
import torch
import torch.nn as nn
import itertools

In [13]:
#NBVAL_SKIP
"""EGNN Implementation from Satorras et al. https://github.com/vgsatorras/egnn"""

class E_GCL(nn.Module):
    """
    E(n) Equivariant Convolutional Layer
    """

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, act_fn=nn.SiLU(), residual=True, attention=False, normalize=False, coords_agg='mean', tanh=False):
        super(E_GCL, self).__init__()
        input_edge = input_nf * 2
        self.residual = residual
        self.attention = attention
        self.normalize = normalize
        self.coords_agg = coords_agg
        self.tanh = tanh
        self.epsilon = 1e-8
        edge_coords_nf = 1

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        layer = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        coord_mlp = [nn.Linear(hidden_nf, hidden_nf)]
        coord_mlp.append(act_fn)
        coord_mlp.append(layer)
        if self.tanh:
            coord_mlp.append(nn.Tanh())
        self.coord_mlp = nn.Sequential(*coord_mlp)

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    def edge_model(self, source, target, radial, edge_attr):
        if edge_attr is None:  # Unused.
            out = torch.cat([source, target, radial], dim=1)
        else:
            out = torch.cat([source, target, radial, edge_attr], dim=1)
        out = self.edge_mlp(out)
        if self.attention:
            att_val = self.att_mlp(out)
            out = out * att_val
        return out

    def node_model(self, x, edge_index, edge_attr, node_attr):
        row, col = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        out = self.node_mlp(agg)
        if self.residual:
            out = x + out
        return out, agg

    def coord_model(self, coord, edge_index, coord_diff, edge_feat):
        row, col = edge_index
        trans = coord_diff * self.coord_mlp(edge_feat)
        if self.coords_agg == 'sum':
            agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0))
        elif self.coords_agg == 'mean':
            agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0))
        else:
            # Include the offending value; the original format string had no placeholder
            raise Exception('Wrong coords_agg parameter: %s' % self.coords_agg)
        coord += agg
        return coord

    def coord2radial(self, edge_index, coord):
        row, col = edge_index
        coord_diff = coord[row] - coord[col]
        radial = torch.sum(coord_diff**2, 1).unsqueeze(1)

        if self.normalize:
            norm = torch.sqrt(radial).detach() + self.epsilon
            coord_diff = coord_diff / norm

        return radial, coord_diff

    def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None):
        row, col = edge_index
        radial, coord_diff = self.coord2radial(edge_index, coord)

        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
        coord = self.coord_model(coord, edge_index, coord_diff, edge_feat)
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)

        return h, coord, edge_attr

class EGNN(nn.Module):
    def __init__(self, in_node_nf, hidden_nf, out_node_nf, in_edge_nf=0, device='cpu', act_fn=nn.SiLU(), n_layers=4, residual=True, attention=False, normalize=False, tanh=False):
        '''
        :param in_node_nf: Number of features for 'h' at the input
        :param hidden_nf: Number of hidden features
        :param out_node_nf: Number of features for 'h' at the output
        :param in_edge_nf: Number of features for the edge features
        :param device: Device (e.g. 'cpu', 'cuda:0',...)
        :param act_fn: Non-linearity
        :param n_layers: Number of layer for the EGNN
        :param residual: Use residual connections, we recommend not changing this one
        :param attention: Whether using attention or not
        :param normalize: Normalizes the coordinates messages such that:
                    instead of: x^{l+1}_i = x^{l}_i + Σ(x_i - x_j)phi_x(m_ij)
                    we get:     x^{l+1}_i = x^{l}_i + Σ(x_i - x_j)phi_x(m_ij)/||x_i - x_j||
                    We noticed it may help in the stability or generalization in some future works.
                    We didn't use it in our paper.
        :param tanh: Sets a tanh activation function at the output of phi_x(m_ij). I.e. it bounds the output of
                        phi_x(m_ij) which definitely improves in stability but it may decrease in accuracy.
                        We didn't use it in our paper.
        '''

        super(EGNN, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.embedding_in = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        for i in range(n_layers):
            self.add_module("gcl_%d" % i, E_GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf,
                                                act_fn=act_fn, residual=residual, attention=attention,
                                                normalize=normalize, tanh=tanh))
        self.to(self.device)

    def forward(self, h, x, edges, edge_attr):
        h = self.embedding_in(h)
        for i in range(self.n_layers):
            h, x, _ = self._modules["gcl_%d" % i](h, edges, x, edge_attr=edge_attr)
        h = self.embedding_out(h)
        return h, x


def unsorted_segment_sum(data, segment_ids, num_segments):
    result_shape = (num_segments, data.size(1))
    result = data.new_full(result_shape, 0)  # Init empty result tensor.
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result.scatter_add_(0, segment_ids, data)
    return result


def unsorted_segment_mean(data, segment_ids, num_segments):
    result_shape = (num_segments, data.size(1))
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result = data.new_full(result_shape, 0)  # Init empty result tensor.
    count = data.new_full(result_shape, 0)
    result.scatter_add_(0, segment_ids, data)
    count.scatter_add_(0, segment_ids, torch.ones_like(data))
    return result / count.clamp(min=1)


def get_edges(n_nodes):
    rows, cols = [], []
    for i, j in itertools.product(range(n_nodes), range(n_nodes)):
        if i != j:
            rows.append(i)
            cols.append(j)

    return [rows, cols]


def get_edges_batch(n_nodes, batch_size):
    edges = get_edges(n_nodes)
    edge_attr = torch.ones(len(edges[0]) * batch_size, 1)
    edges = [torch.LongTensor(edges[0]), torch.LongTensor(edges[1])]
    if batch_size == 1:
        return edges, edge_attr
    elif batch_size > 1:
        rows, cols = [], []
        for i in range(batch_size):
            rows.append(edges[0] + n_nodes * i)
            cols.append(edges[1] + n_nodes * i)
        edges = [torch.cat(rows), torch.cat(cols)]
    return edges, edge_attr
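
The key property of the coordinate update in `coord_model` is equivariance: translating the input coordinates translates the output by the same amount, because the update depends only on coordinate differences. The dependency-free sketch below checks this for the mean-aggregated update, with a fixed function of the squared distance standing in for the learned `coord_mlp` (an assumption made purely for illustration).

```python
from typing import List, Tuple

Vec = List[float]


def coord_update(coords: List[Vec], edges: List[Tuple[int, int]]) -> List[Vec]:
    """x_i <- x_i + mean over edges (i, j) of (x_i - x_j) * phi(||x_i - x_j||^2)."""
    dim = len(coords[0])
    sums = [[0.0] * dim for _ in coords]
    counts = [0] * len(coords)
    for i, j in edges:
        diff = [a - b for a, b in zip(coords[i], coords[j])]
        radial = sum(d * d for d in diff)
        weight = 1.0 / (1.0 + radial)  # stand-in for the learned coord_mlp
        for k in range(dim):
            sums[i][k] += diff[k] * weight
        counts[i] += 1
    return [
        [x + s / max(c, 1) for x, s in zip(coord, total)]
        for coord, total, c in zip(coords, sums, counts)
    ]


coords = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
edges = [(0, 1), (1, 0), (0, 2), (2, 0), (1, 2), (2, 1)]
shift = [5.0, -3.0, 2.0]

shifted_in = [[x + t for x, t in zip(c, shift)] for c in coords]
out = coord_update(coords, edges)
out_shifted = coord_update(shifted_in, edges)

# Updating shifted inputs should equal shifting the updated outputs.
max_error = max(
    abs(b - (a + t))
    for row_a, row_b in zip(out, out_shifted)
    for a, b, t in zip(row_a, row_b, shift)
)
print(max_error)
```

The same argument extends to rotations, since the weight depends only on the squared distance and the difference vectors rotate with the input.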

In [54]:
#NBVAL_SKIP
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch.nn.functional import binary_cross_entropy_with_logits, mse_loss
from torch_geometric.nn import global_add_pool
import pytorch_lightning as pl
#from torchmetrics import F1Score, Accuracy, AUROC
from pytorch_lightning.loggers import WandbLogger


class SimpleEGNN(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = EGNN(
            in_node_nf=20,
            out_node_nf=32,
            in_edge_nf=0,
            hidden_nf=32,
            n_layers=2,
        )
        self.decoder = nn.Sequential(
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )
        self.loss = binary_cross_entropy_with_logits

    def configure_loss(self, name: str):
        """Return the loss function based on the config."""
        return self.loss

    # --- Forward pass
    def forward(self, x):
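        # The batch holds one array per graph; concatenate them into node-level tensors
        # on the module's own device rather than hard-coding .cuda()
        x.aa = torch.cat([torch.tensor(a) for a in x.amino_acid_one_hot]).float().to(self.device)
        x.c = torch.cat([torch.tensor(a).squeeze(0) for a in x.coords]).float().to(self.device)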
        feats, coords = self.model(
            h=x.aa,
            x=x.c,
            edges=x.edge_index,
            edge_attr=None,
        )
        feats = global_add_pool(feats, x.batch)
        return self.decoder(feats)

    def training_step(self, batch: Data, batch_idx) -> torch.Tensor:
        x = batch
        y = batch.graph_y.unsqueeze(1).float()
        y_hat = self(x)

        loss = self.loss(y_hat, y)
        return loss

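    def validation_step(self, batch, batch_idx):
        x = batch
        y = batch.graph_y.unsqueeze(1).float()
        y_hat = self(x)
        # Compute and log the validation loss so it can be monitored
        loss = self.loss(y_hat, y)
        self.log("val_loss", loss)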

    def test_step(self, batch, batch_idx):
        x = batch
        y = batch.graph_y.unsqueeze(1).float()
        y_hat = self(x)
        loss = self.loss(y_hat, y)

        # Single-logit output: threshold the sigmoid probability to get 0/1 predictions
        # (a softmax over one column would always select class 0)
        y_pred_tags = (torch.sigmoid(y_hat) > 0.5).long()
        self.log("test_loss", loss)
        #return loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.Adam(params=self.parameters(), lr=0.001)
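
Two steps in `SimpleEGNN` are worth spelling out: `global_add_pool` sums node embeddings per graph using the `batch` vector, and the decoder emits a single logit per graph, which a sigmoid turns into a probability. A pure-Python sketch of both, with hypothetical values (this mirrors the idea, not the PyTorch Geometric implementation):

```python
import math
from typing import List


def add_pool(node_feats: List[List[float]], batch: List[int]) -> List[List[float]]:
    """Sum node feature vectors per graph, as indicated by `batch`."""
    num_graphs = max(batch) + 1
    pooled = [[0.0] * len(node_feats[0]) for _ in range(num_graphs)]
    for feats, graph_id in zip(node_feats, batch):
        for k, value in enumerate(feats):
            pooled[graph_id][k] += value
    return pooled


def predict(logit: float, threshold: float = 0.5) -> int:
    """Convert a logit to a hard 0/1 label via the sigmoid."""
    probability = 1.0 / (1.0 + math.exp(-logit))
    return int(probability > threshold)


# Three nodes: the first two belong to graph 0, the last to graph 1.
node_feats = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]]
batch = [0, 0, 1]
pooled = add_pool(node_feats, batch)
predictions = [predict(z) for z in (-2.0, 0.3)]
print(pooled, predictions)
```

Sum pooling makes the graph embedding grow with the number of residues; mean pooling is a common alternative when graph size should not influence the prediction.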

In [55]:
#NBVAL_SKIP
trainer = pl.Trainer(
    strategy=None,
    gpus=1,
    benchmark=True,
    deterministic=False,
    num_sanity_val_steps=0,
    max_epochs=10,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [56]:
#NBVAL_SKIP
model = SimpleEGNN()
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name    | Type       | Params
---------------------------------------
0 | model   | EGNN       | 16.5 K
1 | decoder | Sequential | 1.1 K 
---------------------------------------
17.6 K    Trainable params
0         Non-trainable params
17.6 K    Total params
0.070     Total estimated model params size (MB)


Epoch 9: 100%|██████████| 119/119 [00:09<00:00, 12.41it/s, loss=0.551, v_num=12]


In [57]:
#NBVAL_SKIP
trainer.test(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing DataLoader 0: 100%|██████████| 29/29 [00:02<00:00, 14.22it/s]


[{'test_loss': 0.6242247223854065}]

Not bad, but maybe you can do better! :)
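
To put the test loss in context, it helps to compare it with the binary cross-entropy of a predictor that ignores the input and always outputs the same probability. With balanced classes that baseline is ln 2 ≈ 0.693. The sketch below computes it for a given positive fraction; the fractions used are hypothetical and were not measured on this dataset.

```python
import math


def constant_predictor_bce(positive_fraction: float) -> float:
    """BCE of always predicting p = positive_fraction (the label entropy, in nats)."""
    p = positive_fraction
    if p in (0.0, 1.0):
        return 0.0
    return -(p * math.log(p) + (1 - p) * math.log(1 - p))


for fraction in (0.5, 0.2):
    print(fraction, round(constant_predictor_bce(fraction), 3))
```

Plugging in the positive fraction of your own training labels gives the number a trained model needs to beat.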