In [5]:
import os
import json
import math
import random
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('seaborn')
import seaborn as sns
%matplotlib inline

from pymongo import MongoClient

import torch
import torch.utils.data as data
import torch.nn.functional as F
import dgl

In [90]:
import pymongo
pymongo.__version__

'3.12.0'

In [2]:
secrets = json.load(open('DocumentDB_secrets.json', 'r')) 

In [3]:
# TLS enabled
uri = 'mongodb://{}:{}@{}:27017/?tls=true&tlsCAFile=rds-combined-ca-bundle.pem&replicaSet=rs0&readPreference=secondaryPreferred&retryWrites=false'\
    .format(secrets['db_username'], secrets['db_password'], secrets['host'])

client = MongoClient(uri)

In [4]:
db = client['proteins']
collection = db['proteins']

## Dataset class

In [41]:
from Bio.PDB.Polypeptide import d1_to_index, three_to_one

# d1_to_index['X'] = len(d1_to_index) # encode uncommon residue as 20
d1_to_index['X'] = 20

def _convert_to_graph(protein):
    '''
    Convert a protein (dict) to a dgl graph
    '''
    coords = torch.tensor(protein['coords'])
    X_ca = coords[:, 1]
    # construct knn graph from C-alpha coordinates
    g = dgl.knn_graph(X_ca, k=2)        
    seq = protein['seq']
    node_features = torch.tensor([d1_to_index[residue] for residue in seq])
    node_features = F.one_hot(node_features, num_classes=len(d1_to_index)).to(dtype=torch.float)

    # add node features
    g.ndata["h"] = node_features
    return g    


class ProteinDataset(data.IterableDataset):
    """
    An iterable-style dataset for proteins in DocumentDB
    Args:
        - pipeline: an aggregation pipeline to retrieve data from DocumentDB
    """
    def __init__(self, pipeline, db_uri='', db_name='', collection_name=''):
        
        self.db_uri = db_uri
        self.db_name = db_name
        self.collection_name = collection_name
        
        with MongoClient(self.db_uri) as client:
            collection = client[self.db_name][self.collection_name]
            # pre-fetch the metadata as docs from DocumentDB
            self.docs = [doc for doc in collection.aggregate(pipeline)]
        # mapping document '_id' to label
        self.labels = {doc['_id']: doc["y"] for doc in self.docs}
    
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            protein_ids = [doc['_id'] for doc in self.docs]
            
        else:  # in a worker process
            # split workload
            start = 0
            end = len(self.docs)
            per_worker = int(math.ceil((end - start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, end)
            
            protein_ids = [doc['_id'] for doc in self.docs[iter_start:iter_end]]
        
        # retrieve a list of proteins by _id from DocDB
        with MongoClient(self.db_uri) as client:
            collection = client[self.db_name][self.collection_name]
            cur = collection.find(
                {'_id': {'$in': protein_ids}}, 
                projection={"coords": True, "seq": True}
            )
            return ((_convert_to_graph(protein), self.labels[protein['_id']]) \
                    for protein in cur)
    
    def __len__(self):
        return len(self.docs)
        
    def subset(self, indices):
        '''Subset metadata docs inplace'''
        self.docs = [self.docs[i] for i in indices]
        return 
        
def collate(samples):
    graphs, targets = map(list, zip(*samples))
    bg = dgl.batch(graphs)
    return bg, torch.tensor(targets).unsqueeze(1).to(torch.float32)

In [73]:
d1_to_index

{'A': 0,
 'C': 1,
 'D': 2,
 'E': 3,
 'F': 4,
 'G': 5,
 'H': 6,
 'I': 7,
 'K': 8,
 'L': 9,
 'M': 10,
 'N': 11,
 'P': 12,
 'Q': 13,
 'R': 14,
 'S': 15,
 'T': 16,
 'V': 17,
 'W': 18,
 'Y': 19,
 'X': 20}

In [42]:
# https://github.com/pytorch/pytorch/blob/7729581414962ac0a23ebd269f165f6a877490ae/torch/utils/data/dataset.py#L257-L312
from typing import Iterator
class BufferedShuffleDataset(data.IterableDataset):
    r"""Dataset shuffled from the original dataset.
    This class is useful to shuffle an existing instance of an IterableDataset.
    The buffer with `buffer_size` is filled with the items from the dataset first. Then,
    each item will be yielded from the buffer by reservoir sampling via iterator.
    `buffer_size` is required to be larger than 0. For `buffer_size == 1`, the
    dataset is not shuffled. In order to fully shuffle the whole dataset, `buffer_size`
    is required to be greater than or equal to the size of dataset.
    When it is used with :class:`~torch.utils.data.DataLoader`, each item in the
    dataset will be yielded from the :class:`~torch.utils.data.DataLoader` iterator.
    And, the method to set up a random seed is different based on :attr:`num_workers`.
    For single-process mode (:attr:`num_workers == 0`), the random seed is required to
    be set before the :class:`~torch.utils.data.DataLoader` in the main process.
        >>> ds = BufferedShuffleDataset(dataset)
        >>> random.seed(...)
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
    For multi-process mode (:attr:`num_workers > 0`), the random seed is set by a callable
    function in each worker.
        >>> ds = BufferedShuffleDataset(dataset)
        >>> def init_fn(worker_id):
        ...     random.seed(...)
        >>> print(list(torch.utils.data.DataLoader(ds, ..., num_workers=n, worker_init_fn=init_fn)))
    Arguments:
        dataset (IterableDataset): The original IterableDataset.
        buffer_size (int): The buffer size for shuffling.
    """
    
    dataset: data.IterableDataset
    buffer_size: int
    def __init__(self, dataset: data.IterableDataset, buffer_size: int) -> None:
        super(BufferedShuffleDataset, self).__init__()
        assert buffer_size > 0, "buffer_size should be larger than 0"
        self.dataset = dataset
        self.buffer_size = buffer_size

    def __iter__(self) -> Iterator:
        buf = []
        for x in self.dataset:
            if len(buf) == self.buffer_size:
                idx = random.randint(0, self.buffer_size - 1)
                yield buf[idx]
                buf[idx] = x
            else:
                buf.append(x)
        random.shuffle(buf)
        while buf:
            yield buf.pop()


In [43]:
match = {"is_AF": {"$exists": True}}
project = {"y": "$is_AF"}

pipeline = [
    {"$match": match},
    {"$project": project},
]

dataset = ProteinDataset(
    pipeline,
    db_uri=uri,
    db_name='proteins', 
    collection_name='proteins'
)

In [44]:
dataset.docs[:5]

[{'_id': ObjectId('611fed5aa9e1be4d05332068'), 'y': 1},
 {'_id': ObjectId('611fed5aa9e1be4d05332069'), 'y': 1},
 {'_id': ObjectId('611fed5aa9e1be4d0533206a'), 'y': 1},
 {'_id': ObjectId('611fed5ba9e1be4d0533206b'), 'y': 1},
 {'_id': ObjectId('611fed5ba9e1be4d0533206c'), 'y': 1}]

In [45]:
len(dataset)

3151

In [46]:
ds = BufferedShuffleDataset(dataset, buffer_size=128)
random.seed(43)

## Speed test
### single worker

In [47]:
i = 0
data_loader = data.DataLoader(
    ds, 
    num_workers=0, 
    batch_size=32, 
    collate_fn=collate
)
# Epoch 0
t0 = datetime.now()
for bg, labels in data_loader:
    if i == 0:
        print(bg, labels)
    i += 1
    
print('Epoch time:', datetime.now() - t0)

Graph(num_nodes=9943, num_edges=19886,
      ndata_schemes={'h': Scheme(shape=(21,), dtype=torch.float32)}
      edata_schemes={}) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
Epoch time: 0:00:13.633357


In [48]:
# dgl.unbatch(bg)

In [49]:
# Epoch 1
for bg, labels in data_loader:
    print(bg, labels)    
    break

Graph(num_nodes=11052, num_edges=22104,
      ndata_schemes={'h': Scheme(shape=(21,), dtype=torch.float32)}
      edata_schemes={}) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])


In [51]:
from collections import Counter
Counter(dataset.labels.values())

Counter({1: 1773, 0: 1378})

### multiple workers

In [52]:
ds = BufferedShuffleDataset(dataset, buffer_size=128)
# def init_fn(worker_id):
#     random.seed(42+worker_id)
    
i = 0
data_loader = data.DataLoader(
    ds, 
    num_workers=2, 
    batch_size=4, 
    collate_fn=collate,
)
# Epoch 0
t0 = datetime.now()
for bg, labels in data_loader:
    if i == 0:
        print(bg, labels)
    i += 1    
print('Epoch time:', datetime.now() - t0)
# collection=None, Epoch time: 0:00:11.233891
# collection=collection, Epoch time: 0:00:11.256877
# with ... as client: Epoch time: 0:00:11.698559

Graph(num_nodes=527, num_edges=1054,
      ndata_schemes={'h': Scheme(shape=(21,), dtype=torch.float32)}
      edata_schemes={}) tensor([[1.],
        [1.],
        [1.],
        [1.]])
Epoch time: 0:00:10.909427


In [53]:
# Epoch 1
for bg, labels in data_loader:
    print(bg, labels)
    break

Graph(num_nodes=986, num_edges=1972,
      ndata_schemes={'h': Scheme(shape=(21,), dtype=torch.float32)}
      edata_schemes={}) tensor([[1.],
        [1.],
        [1.],
        [1.]])


## GNN model

In [67]:
# from dgl.dataloading import GraphDataLoader
# from torch.utils.data.sampler import SubsetRandomSampler
batch_size = 64
project = {"y": "$is_AF"}

def match_by_split(split):
    return {"$and": [
        {"is_AF": {"$exists": True}},
        {"split": split}
    ]} 

train_dataset = ProteinDataset(
    [
        {"$match": match_by_split('train')},
        {"$project": project},
    ],
    db_uri=uri,
    db_name='proteins', 
    collection_name='proteins'    
)


valid_dataset = ProteinDataset(
    [
        {"$match": match_by_split('valid')},
        {"$project": project},
    ],
    db_uri=uri,
    db_name='proteins', 
    collection_name='proteins'    
)

test_dataset = ProteinDataset(
    [
        {"$match": match_by_split('test')},
        {"$project": project},
    ],
    db_uri=uri,
    db_name='proteins', 
    collection_name='proteins'    
)

print(len(train_dataset), len(valid_dataset), len(test_dataset))

train_dataloader = data.DataLoader(
    BufferedShuffleDataset(train_dataset, buffer_size=128),
    batch_size=batch_size, 
    collate_fn=collate,
    num_workers=8
)


valid_dataloader = data.DataLoader(
    valid_dataset, 
    batch_size=batch_size, 
    collate_fn=collate,
)
test_dataloader = data.DataLoader(
    test_dataset, 
    batch_size=batch_size, 
    collate_fn=collate,
)

2016 504 631


In [68]:
import torch.nn as nn
from dgl.nn import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata['h'] = h
        return dgl.mean_nodes(g, 'h')


In [69]:
# Create the model with given dimensions
dim_nfeats = len(d1_to_index)
n_classes = 1
model = GCN(dim_nfeats, 16, n_classes)

In [79]:
device = torch.device('cuda:0')
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [71]:
t0 = datetime.now()
model.train()
for epoch in range(5):
    print('epoch:', epoch)
    for batched_graph, labels in train_dataloader:
        
        batched_graph = batched_graph.to(device)
        labels = labels.to(device)
        
        pred = model(batched_graph, batched_graph.ndata['h'])
        loss = F.binary_cross_entropy_with_logits(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

print('Time elapsed:', datetime.now() - t0)

epoch: 0
epoch: 1
epoch: 2
epoch: 3
epoch: 4
Time elapsed: 0:00:19.232999


In [72]:
num_correct = 0
num_tests = 0
model.eval()
with torch.no_grad():
    for batched_graph, labels in test_dataloader:
        batched_graph = batched_graph.to(device)
        labels = labels.to(device)

        pred = model(batched_graph, batched_graph.ndata['h'].float())
        num_correct += (pred.argmax(1) == labels).sum().item()
        num_tests += len(labels)

print('Test accuracy:', num_correct / num_tests)

Test accuracy: 27.20919175911252


In [74]:
from sklearn.metrics import roc_auc_score
class EarlyStopper(object):
    def __init__(self, patience, filename=None):
        if filename is None:
            # Name checkpoint based on time
            dt = datetime.now()
            filename = "early_stop_{}_{:02d}-{:02d}-{:02d}.pth".format(
                dt.date(), dt.hour, dt.minute, dt.second
            )
            filename = os.path.join("/opt/ml/model", filename)

        self.patience = patience
        self.counter = 0
        self.filename = filename
        self.best_score = None
        self.early_stop = False

    def save_checkpoint(self, model):
        """Saves model when the metric on the validation set gets improved."""
        torch.save({"model_state_dict": model.state_dict()}, self.filename)

    def load_checkpoint(self, model):
        """Load model saved with early stopping."""
        model.load_state_dict(torch.load(self.filename)["model_state_dict"])

    def step(self, score, model):
        if (self.best_score is None) or (score > self.best_score):
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
        else:
            self.counter += 1
            print(
                "EarlyStopping counter: {:d} out of {:d}".format(
                    self.counter, self.patience
                )
            )
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop


class Meter(object):
    """Track and summarize model performance on a dataset for
    (multi-label) binary classification."""

    def __init__(self):
        self.y_pred = []
        self.y_true = []

    def update(self, y_pred, y_true):
        """Update for the result of an iteration
        Parameters
        ----------
        y_pred : float32 tensor
            Predicted molecule labels with shape (B, T),
            B for batch size and T for the number of tasks
        y_true : float32 tensor
            Ground truth molecule labels with shape (B, T)
        """
        self.y_pred.append(y_pred.detach().cpu())
        self.y_true.append(y_true.detach().cpu())

    def roc_auc_score(self):
        """Compute roc-auc score for each task.
        Returns
        -------
        list of float
            roc-auc score for all tasks
        """
        y_pred = torch.cat(self.y_pred, dim=0)
        y_true = torch.cat(self.y_true, dim=0)
        # This assumes binary case only
        y_pred = torch.sigmoid(y_pred)
        n_tasks = y_true.shape[1]
        scores = []
        for task in range(n_tasks):
            task_y_true = y_true[:, task].numpy()
            task_y_pred = y_pred[:, task].numpy()
            scores.append(roc_auc_score(task_y_true, task_y_pred))
        return scores


In [84]:
def run_a_train_epoch(args, epoch, model, data_loader, optimizer):
    model.train()
    train_meter = Meter()
    for batch_id, batch_data in enumerate(data_loader):
        bg, labels = batch_data
        bg = bg.to(args["device"])
        labels = labels.to(args["device"])
        logits = model(bg, bg.ndata["h"])
        # Mask non-existing labels
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(
            "epoch {:d}/{:d}, batch {:d}/, loss {:.4f}".format(
                epoch + 1,
                args["n_epochs"],
                batch_id + 1,
                loss.item(),
            )
        )
        train_meter.update(logits, labels)
    train_score = np.mean(train_meter.roc_auc_score())
    print(
        "epoch {:d}/{:d}, training roc-auc {:.4f}".format(
            epoch + 1, args["n_epochs"], train_score
        )
    )

def run_an_eval_epoch(args, model, data_loader):
    model.eval()
    eval_meter = Meter()
    with torch.no_grad():
        for batch_data in data_loader:
            bg, labels = batch_data
            bg = bg.to(args["device"])
            labels = labels.to(args["device"])
            logits = model(bg, bg.ndata["h"])
            eval_meter.update(logits, labels)
    return np.mean(eval_meter.roc_auc_score())


In [85]:
args = {'device': device, 'patience': 5, 'n_epochs': 10}

stopper = EarlyStopper(args["patience"])
run_a_train_epoch(args, 0, model, train_dataloader, optimizer)

epoch 1/10, batch 1/, loss 0.5085
epoch 1/10, batch 2/, loss 0.4835
epoch 1/10, batch 3/, loss 0.4521
epoch 1/10, batch 4/, loss 0.4260
epoch 1/10, batch 5/, loss 0.4926
epoch 1/10, batch 6/, loss 1.1380
epoch 1/10, batch 7/, loss 1.1436
epoch 1/10, batch 8/, loss 1.1436
epoch 1/10, batch 9/, loss 0.3861
epoch 1/10, batch 10/, loss 0.3728
epoch 1/10, batch 11/, loss 0.3727
epoch 1/10, batch 12/, loss 0.3709
epoch 1/10, batch 13/, loss 0.7544
epoch 1/10, batch 14/, loss 1.1461
epoch 1/10, batch 15/, loss 1.1346
epoch 1/10, batch 16/, loss 1.1240
epoch 1/10, batch 17/, loss 0.4566
epoch 1/10, batch 18/, loss 0.4170
epoch 1/10, batch 19/, loss 0.4001
epoch 1/10, batch 20/, loss 0.4013
epoch 1/10, batch 21/, loss 0.7740
epoch 1/10, batch 22/, loss 1.0579
epoch 1/10, batch 23/, loss 1.0391
epoch 1/10, batch 24/, loss 1.0340
epoch 1/10, batch 25/, loss 0.4648
epoch 1/10, batch 26/, loss 0.4487
epoch 1/10, batch 27/, loss 0.4434
epoch 1/10, batch 28/, loss 0.4387
epoch 1/10, batch 29/, loss 0

In [89]:
# Validation and early stop
val_score = run_an_eval_epoch(args, model, valid_dataloader)
# early_stop = stopper.step(val_score, model)
print(
    "epoch {:d}/{:d}, validation roc-auc {:.4f}, best validation roc-auc {:.4f}".format(
        epoch + 1, args["n_epochs"], val_score, stopper.best_score
    )
)

epoch 5/10, validation roc-auc 0.7721, best validation roc-auc 0.7721
