In [1]:
import os
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('seaborn')
import seaborn as sns
%matplotlib inline

In [2]:
from pymongo import MongoClient
secrets = json.load(open('DocumentDB_secrets.json', 'r')) 

In [3]:
# TLS enabled
uri = 'mongodb://{}:{}@{}:27017/?tls=true&tlsCAFile=rds-combined-ca-bundle.pem&replicaSet=rs0&readPreference=secondaryPreferred&retryWrites=false'\
    .format(secrets['db_username'], secrets['db_password'], secrets['host'])

client = MongoClient(uri)

In [4]:
db = client['proteins']
collection = db['proteins']

In [5]:
import torch
import torch.utils.data as data
import torch.nn.functional as F

In [6]:
import dgl

Using backend: pytorch


## Dataset class

In [7]:
def _connect_to_db():
    uri = 'mongodb://{}:{}@{}:27017/?tls=true&tlsCAFile=rds-combined-ca-bundle.pem&replicaSet=rs0&readPreference=secondaryPreferred&retryWrites=false'\
        .format(secrets['db_username'], secrets['db_password'], secrets['host'])

    client = MongoClient(uri)
    db = client['proteins']
    collection = db['proteins']
    return collection

In [78]:
import math
import random
from Bio.PDB.Polypeptide import d1_to_index, three_to_one

# d1_to_index['X'] = len(d1_to_index) # encode uncommon residue as 20
d1_to_index['X'] = 20

def _convert_to_graph(protein):
    '''
    Convert a protein (dict) to a dgl graph
    '''
    coords = torch.tensor(protein['coords'])
    X_ca = coords[:, 1]
    # construct knn graph from C-alpha coordinates
    g = dgl.knn_graph(X_ca, k=2)        
    seq = protein['seq']
    node_features = torch.tensor([d1_to_index[residue] for residue in seq])
    node_features = F.one_hot(node_features, num_classes=len(d1_to_index)).to(dtype=torch.float)

    # add node features
    g.ndata["h"] = node_features
    return g    

class ProteinDataset(data.Dataset):
    """Map-style Dataset"""
    def __init__(self, collection, pipeline):
        """
        
        Args:
            - collection: pymongo.collection.Collection object
            - pipeline: a DocumentDB aggregation pipeline
            - tokenizer:
        """
        self.collection = collection
        # pre-fetch the metadata and labels from DocumentDB
        self.docs = [doc for doc in self.collection.aggregate(pipeline)]
        self.labels = [doc["y"] for doc in self.docs]
        
        
    def __getitem__(self, idx):
        id_ = self.docs[idx]['id']
#         collection = _connect_to_db()
        protein = self.collection.find_one(
            {'id': id_}, 
            projection={"_id": False, "coords": True, "seq": True}
        )
        return _convert_to_graph(protein), self.labels[idx]

    def __len__(self):
        return len(self.docs)

class IProteinDataset(data.IterableDataset):
    """"""
    def __init__(self, collection, pipeline):
        self.collection = collection
        # pre-fetch the metadata and labels from DocumentDB
        self.docs = [doc for doc in self.collection.aggregate(pipeline)]
        self.labels = {doc['_id']: doc["y"] for doc in self.docs}
#         self.batch_size = batch_size
    
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            protein_ids = [doc['_id'] for doc in self.docs]
            
        else:  # in a worker process
            # split workload
            start = 0
            end = len(self.docs)
            per_worker = int(math.ceil((end - start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, end)
            
            protein_ids = [doc['_id'] for doc in self.docs[iter_start:iter_end]]
        
        # retrieve a list of proteins by _id from DocDB
        cur = self.collection.find(
            {'_id': {'$in': protein_ids}}, 
            projection={"coords": True, "seq": True}
        )
        return ((_convert_to_graph(protein), self.labels[protein['_id']]) \
                for protein in cur)
        
    def shuffle(self):
        random.shuffle(self.docs)
        
def collate(samples):
#     graphs = list(zip(*samples))[0]
#     targets = list(zip(*samples))[1]
    graphs, targets = map(list, zip(*samples))
    bg = dgl.batch(graphs)
    return bg, torch.tensor(targets).unsqueeze(1).to(torch.float32)        

In [80]:
client = MongoClient(uri, connect=False)
db = client['proteins']
collection = db['proteins']

match = {"is_AF": {"$exists": True}}
project = {"y": "$is_AF"}

pipeline = [
    {"$match": match},
    {"$project": project},
]


ds = IProteinDataset(collection, pipeline)

In [70]:
docs = [doc for doc in collection.aggregate(pipeline)]
len(docs)

3151

In [71]:
protein_ids = [doc['_id'] for doc in docs[:4]]
protein_ids

[ObjectId('611fed5aa9e1be4d05332068'),
 ObjectId('611fed5aa9e1be4d05332069'),
 ObjectId('611fed5aa9e1be4d0533206a'),
 ObjectId('611fed5ba9e1be4d0533206b')]

In [72]:
proteins = collection.find(
    {'_id': {'$in': protein_ids}}, 
    projection={"coords": True, "seq": True, "is_AF": True}
)

In [73]:
type(proteins)

pymongo.cursor.Cursor

In [74]:
for protein in proteins:
    print(protein.keys())

dict_keys(['_id', 'seq', 'coords', 'is_AF'])
dict_keys(['_id', 'seq', 'coords', 'is_AF'])
dict_keys(['_id', 'seq', 'coords', 'is_AF'])
dict_keys(['_id', 'seq', 'coords', 'is_AF'])


In [75]:
labels = {protein['_id']: protein['is_AF'] for protein in proteins}
labels

{}

In [76]:
for bg, labels in data.DataLoader(ds, num_workers=0, 
                             batch_size=4, collate_fn=collate):
    print(bg, labels)
    break

Graph(num_nodes=1440, num_edges=2880,
      ndata_schemes={'h': Scheme(shape=(21,), dtype=torch.float32)}
      edata_schemes={}) tensor([[1.],
        [1.],
        [1.],
        [1.]])


In [81]:
for bg, labels in data.DataLoader(ds, num_workers=2, 
                             batch_size=4, collate_fn=collate):
    print(bg, labels)
    break

  "MongoClient opened before fork. Create MongoClient only "
  "MongoClient opened before fork. Create MongoClient only "
Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7fc97ca62860>>
Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()Exception ignored in: 
<bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7fc97ca62860>>  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers

    Traceback (most recent call last):
w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torch/utils/data/datalo

Graph(num_nodes=1440, num_edges=2880,
      ndata_schemes={'h': Scheme(shape=(21,), dtype=torch.float32)}
      edata_schemes={}) tensor([[1.],
        [1.],
        [1.],
        [1.]])


In [None]:
collection.find(
                {'id': {'$in': protein_ids}}, 
                projection={"_id": False, "coords": True, "seq": True}
            )

In [64]:
doc = dataset.docs[0]
doc

{'id': 'AF-Q57935', 'y': 1}

In [51]:
%%timeit
# id_ = doc['id']
collection = _connect_to_db()
protein = collection.find_one(
    {'id': id_}, 
    projection={"_id": False, "coords": True, "seq": True}
)
# 179 ms ± 8.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

176 ms ± 762 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [71]:
%%timeit
# id_ = doc['id']
collection = _connect_to_db()
protein = collection.find_one(
    {'id': id_}, 
    projection={"_id": False, "seq": True}
)
# 179 ms ± 8.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

173 ms ± 2.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [44]:
id_ = doc['id']
collection = _connect_to_db()
protein = collection.find_one(
    {'id': id_}, 
    projection={"_id": False, "coords": True, "seq": True}
)

In [50]:
%%timeit
collection = _connect_to_db()

4.18 ms ± 75.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [52]:
%%timeit
protein = collection.find_one(
    {'id': id_}, 
    projection={"_id": False, "coords": True, "seq": True}
)

4.13 ms ± 26.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [45]:
%%timeit
coords = torch.tensor(protein['coords'])
X_ca = coords[:, 1]
# construct knn graph from C-alpha coordinates
g = dgl.knn_graph(X_ca, k=2)

2.6 ms ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [46]:
coords = torch.tensor(protein['coords'])
X_ca = coords[:, 1]
# construct knn graph from C-alpha coordinates
g = dgl.knn_graph(X_ca, k=2)

In [47]:
%%timeit
seq = protein['seq']
node_features = torch.tensor([d1_to_index[residue] for residue in seq])
node_features = F.one_hot(node_features, num_classes=len(d1_to_index)).to(dtype=torch.float)

# add node features
g.ndata["h"] = node_features

166 µs ± 205 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [65]:
%%timeit
dataset[0]

189 ms ± 4.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
match = {"is_AF": {"$exists": True}}
project = {"y": "$is_AF", "_id": False, 'id': True}

pipeline = [
    {"$match": match},
    {"$project": project},
]

# docs = [doc for doc in collection.aggregate(pipeline)]
# docs[0]
dataset = ProteinDataset(collection, pipeline)

In [23]:
g, label = dataset[0]
g, label

(Graph(num_nodes=380, num_edges=760,
       ndata_schemes={'h': Scheme(shape=(21,), dtype=torch.float32)}
       edata_schemes={}),
 1)

In [24]:
labels = dataset.labels

In [25]:
np.unique(labels, return_counts=True)

(array([0, 1]), array([1378, 1773]))

## GNN model

In [27]:
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

num_examples = len(dataset)
num_train = int(num_examples * 0.8)

train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

# train_dataloader = GraphDataLoader(
#     dataset, sampler=train_sampler, batch_size=32, drop_last=False,
#     num_workers=16
# )
# test_dataloader = GraphDataLoader(
#     dataset, sampler=test_sampler, batch_size=32, drop_last=False)

train_dataloader = data.DataLoader(
    dataset, sampler=train_sampler, batch_size=32, 
    collate_fn=collate,
    num_workers=32
)

test_dataloader = data.DataLoader(
    dataset, sampler=test_sampler, batch_size=32, 
    collate_fn=collate,
#     num_workers=1
)

In [28]:
len(train_dataloader), len(test_dataloader)

(79, 20)

In [29]:
it = iter(train_dataloader)
batch = next(it)
print(batch)

(Graph(num_nodes=7887, num_edges=15774,
      ndata_schemes={'h': Scheme(shape=(21,), dtype=torch.float32)}
      edata_schemes={}), tensor([[1.],
        [1.],
        [1.],
        [1.],
        [0.],
        [0.],
        [1.],
        [0.],
        [1.],
        [1.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [1.],
        [1.],
        [1.],
        [0.],
        [1.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [0.]]))


In [54]:
%%timeit
it = iter(train_dataloader)
batch = next(it)
# bs=workers=32
# 10.5 s ± 2.19 s per loop (mean ± std. dev. of 7 runs, 1 loop each)

10.5 s ± 2.19 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [69]:
train_dataloader = data.DataLoader(
    dataset, sampler=train_sampler, batch_size=1, 
    collate_fn=collate
)

In [70]:
%%timeit
it = iter(train_dataloader)
batch = next(it)

191 ms ± 14.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [66]:
train_dataloader = data.DataLoader(
    dataset, sampler=train_sampler, batch_size=32, 
    collate_fn=collate
)

In [67]:
%%timeit
it = iter(train_dataloader)
batch = next(it)
# bs=32, workers=0
# 5.99 s ± 76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

6.08 s ± 78.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [61]:
train_dataloader = data.DataLoader(
    dataset, sampler=train_sampler, batch_size=32, 
    collate_fn=collate,
    num_workers=32,
    persistent_workers=True
)

In [62]:
%%timeit
it = iter(train_dataloader)
batch = next(it)
# bs=workers=32 persistent_workers=True
# 25.2 s ± 3.72 s per loop (mean ± std. dev. of 7 runs, 1 loop each)

25.2 s ± 3.72 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [30]:
dgl.unbatch(batch[0])

[Graph(num_nodes=1155, num_edges=2310,
       ndata_schemes={'h': Scheme(shape=(21,), dtype=torch.float32)}
       edata_schemes={}),
 Graph(num_nodes=115, num_edges=230,
       ndata_schemes={'h': Scheme(shape=(21,), dtype=torch.float32)}
       edata_schemes={}),
 Graph(num_nodes=237, num_edges=474,
       ndata_schemes={'h': Scheme(shape=(21,), dtype=torch.float32)}
       edata_schemes={}),
 Graph(num_nodes=203, num_edges=406,
       ndata_schemes={'h': Scheme(shape=(21,), dtype=torch.float32)}
       edata_schemes={}),
 Graph(num_nodes=342, num_edges=684,
       ndata_schemes={'h': Scheme(shape=(21,), dtype=torch.float32)}
       edata_schemes={}),
 Graph(num_nodes=317, num_edges=634,
       ndata_schemes={'h': Scheme(shape=(21,), dtype=torch.float32)}
       edata_schemes={}),
 Graph(num_nodes=214, num_edges=428,
       ndata_schemes={'h': Scheme(shape=(21,), dtype=torch.float32)}
       edata_schemes={}),
 Graph(num_nodes=39, num_edges=78,
       ndata_schemes={'h': Scheme(shape

In [31]:
import torch.nn as nn
from dgl.nn import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata['h'] = h
        return dgl.mean_nodes(g, 'h')


In [32]:
# Create the model with given dimensions
dim_nfeats = len(d1_to_index)
n_classes = 1

model = GCN(dim_nfeats, 16, n_classes)

In [33]:
device = torch.device('cuda:0')
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [None]:
labels.shape

In [39]:
from datetime import datetime

In [72]:
train_dataloader = data.DataLoader(
    dataset, sampler=train_sampler, batch_size=32, 
    collate_fn=collate,
)

In [73]:
t0 = datetime.now()
model.train()
for epoch in range(2):
    print('epoch:', epoch)
    for batched_graph, labels in train_dataloader:
        
        batched_graph = batched_graph.to(device)
        labels = labels.to(device)
        
        pred = model(batched_graph, batched_graph.ndata['h'])
        loss = F.binary_cross_entropy_with_logits(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

print('Time elapsed:', datetime.now() - t0)
# Time elapsed: 0:00:59.909561: bs=workers=32
# Time elapsed: 0:15:49.611212 bs=32 workers=0

epoch: 0
epoch: 1
Time elapsed: 0:15:49.611212


In [35]:
num_correct = 0
num_tests = 0
model.eval()
with torch.no_grad():
    for batched_graph, labels in test_dataloader:
        batched_graph = batched_graph.to(device)
        labels = labels.to(device)

        pred = model(batched_graph, batched_graph.ndata['h'].float())
        num_correct += (pred.argmax(1) == labels).sum().item()
        num_tests += len(labels)

print('Test accuracy:', num_correct / num_tests)

Test accuracy: 31.671949286846274


In [92]:
loss = F.binary_cross_entropy_with_logits(pred, labels.to(torch.float32))

In [87]:
labels.dtype, pred.dtype

(torch.int64, torch.float32)

In [88]:
labels.shape, pred.shape

(torch.Size([32, 1]), torch.Size([32, 1]))

In [84]:
loss = F.bincross_entropy(pred, labels.long())

RuntimeError: multi-target not supported at /opt/conda/conda-bld/pytorch_1607370116979/work/aten/src/THCUNN/generic/ClassNLLCriterion.cu:15