In [None]:
import os
import sys
import argparse
import networkx as nx
import numpy as np

sys.path.append('../pyprot/')
import pyprot.graph_models as graph_models
from pyprot.downloader import PdbDownloader, ConsurfDBDownloader
from pyprot.protein import Protein
from pyprot.structure import Perseus


### Graph verification and preprocessing

In [None]:
import networkx as nx
import pickle
import pyprot.constants
def load_graph(fn):
    """Read one pickled object from disk and return it.

    fn -- path of the pickle file; it is opened in binary mode.

    NOTE(review): unpickling can execute arbitrary code, so only point this
    at files that come from a trusted source.
    """
    with open(fn, "rb") as handle:
        graph = pickle.load(handle)
    return graph
# os.listdir() returns entries in arbitrary order, so the listing is sorted:
# positional lookups such as filenames[i] and the later train/val/test split
# then refer to the same files on every run and machine.
# endswith() keeps only real ".pkl" files; a substring test would also pick
# up names such as "foo.pkl.bak".
filenames = sorted(fn for fn in os.listdir("graphs/") if fn.endswith(".pkl"))
graphs = [load_graph("graphs/" + fn) for fn in filenames]

In [None]:
# Map every three-letter residue code to its one-hot column.  Residue names
# that are not in the table share one extra "UNK" column.
index_amino = {code3: i for i, code3 in enumerate(pyprot.constants.AMINOACIDS_3)}
# setdefault() leaves an already-present "UNK" entry untouched; overwriting it
# with len(index_amino) would make it point at the bfactor column.
index_amino.setdefault("UNK", len(index_amino))
num_amino = len(index_amino)
unknown_idx = index_amino["UNK"]

# Per-node feature row: one-hot residue type, then bfactor, then x, y, z coord.
all_features = []
for graph in graphs:
    features = np.zeros((graph.number_of_nodes(), num_amino + 4))
    for i, node_idx in enumerate(graph.nodes):
        node = graph.nodes[node_idx]
        # Unlisted residue names fall back to the UNK column instead of
        # aborting the whole cell with a KeyError.
        features[i, index_amino.get(node["resname"], unknown_idx)] = 1
        features[i, num_amino] = node["bfactor"]
        features[i, num_amino+1:num_amino+4] = node["coord"]
    all_features.append(features)

In [None]:
# One adjacency matrix per graph, stored in the same order as `graphs`.
all_adj = []
for graph in graphs:
    all_adj.append(nx.adjacency_matrix(graph))

In [None]:
def touches_ligand(x, contact_cutoff=4, soft_cutoff=6):
    """Decide, partly at random, whether a node at distance ``x`` is a contact.

    x              -- distance value of the node (same unit as the cutoffs).
    contact_cutoff -- distances up to and including this are always contacts.
    soft_cutoff    -- distances above this are never contacts.

    For distances between the two cutoffs a single Bernoulli sample is drawn
    from NumPy's global random state; its success probability falls linearly
    from 1 at ``contact_cutoff`` to 0 at ``soft_cutoff``.  Seed ``np.random``
    before labelling if the labels must be reproducible.

    Returns a plain ``bool``.
    """
    if x <= contact_cutoff:
        return True
    # Written as "not <=" so that NaN distances are rejected here rather than
    # being turned into an invalid probability below.
    if not x <= soft_cutoff:
        return False
    contact_prob = 1 - (x - contact_cutoff) / (soft_cutoff - contact_cutoff)
    return bool(np.random.binomial(1, contact_prob) == 1)

# Two-column one-hot target per node: column 0 marks contacts, column 1 marks
# non-contacts.  class_balance collects contacts / non-contacts per graph.
class_balance = []
all_targets = []
for graph in graphs:
    # Nodes are visited in graph order, one touches_ligand() call per node,
    # so any random draws happen in the same sequence as a plain loop.
    contact = np.array(
        [1.0 if touches_ligand(graph.nodes[node_idx]["distance"]) else 0.0
         for node_idx in graph.nodes])
    targets = np.column_stack([contact, 1 - contact])
    num_contact = targets[:, 0].sum()
    num_other = targets[:, 1].sum()
    class_balance.append(num_contact / num_other)
    all_targets.append(targets)

In [None]:
# Average, over all graphs, of the per-graph contact / non-contact ratio.
mean_class_balance = sum(class_balance) / len(class_balance)
print("For every non-contact point there are {} contact points".format(
    mean_class_balance))

In [None]:
# Inspect the node distances stored in one specific graph, smallest first.
buggyG = load_graph("graphs/4DX2.pkl")
distances = [buggyG.nodes[node_idx]["distance"] for node_idx in buggyG.nodes]
sorted(distances)

In [None]:
# Show the file name stored at position 54 of `filenames`.
filenames[54]

In [None]:
# Check errors: collect every graph whose target matrix has no contact node,
# i.e. whose column 0 sums to less than one.
pdb_error_list = []
for idx, target in enumerate(all_targets):
    if not target[:, 0].sum() < 1.0:
        continue
    print("Error found in target #{}".format(idx))
    pdb_error_list.append((idx, filenames[idx]))

In [None]:
# Display the (index, file name) pairs collected by the error check.
pdb_error_list

## Example model

In [None]:
import scipy.sparse as sp
import gcn.utils
import gcn.models
import time
import tensorflow as tf
from tensorflow.python import debug as tf_debug

### Data splits and preprocessing

In [None]:
# NOTE: this cell overwrites all_adj and all_features with their processed
# forms, so it is meant to run once per kernel session.

# Normalise every adjacency matrix, then convert it with sparse_to_tuple.
adj_as_tuples = []
for adj in all_adj:
    adj_as_tuples.append(gcn.utils.sparse_to_tuple(gcn.utils.normalize_adj(adj)))
all_adj = adj_as_tuples

# Run gcn's feature preprocessing on a LIL copy of each dense feature matrix.
features_as_tuples = []
for features in all_features:
    features_as_tuples.append(
        gcn.utils.preprocess_features(sp.lil_matrix(features)))
all_features = features_as_tuples

# Largest second shape entry over all adjacency tuples; used as the common
# size that every graph is padded to.
nb_nodes = max(adj_tuple[2][1] for adj_tuple in all_adj)

In [None]:
# Make matrices the same size
# Element [2] of each adjacency tuple is its shape; replace it so that every
# adjacency matrix is declared as (nb_nodes, nb_nodes).
all_adj[:] = [(adj_tuple[0], adj_tuple[1], (nb_nodes, nb_nodes))
              for adj_tuple in all_adj]

# feat_tuple[2] is the feature shape: the row count becomes nb_nodes and the
# number of feature columns is kept.
all_features[:] = [(feat_tuple[0], feat_tuple[1], (nb_nodes, feat_tuple[2][1]))
                   for feat_tuple in all_features]

# Targets are dense arrays, so they are padded with explicit all-zero rows.
for i, target in enumerate(all_targets):
    num_real = target.shape[0]
    padded = np.zeros((nb_nodes, 2))
    padded[:num_real] = np.column_stack((target[:, 0], 1 - target[:, 0]))
    all_targets[i] = padded

In [None]:
def split_list(data, train_perc, val_perc):
    """Cut ``data`` into consecutive train, validation and test slices.

    data       -- any sliceable sequence with a length.
    train_perc -- fraction of the items that goes into the first slice.
    val_perc   -- fraction of the items that goes into the second slice.

    Both fractions are turned into whole item counts with ``int()``, which
    truncates; the third slice is whatever remains.  Order is preserved and
    nothing is shuffled.

    Returns the tuple ``(train, val, test)``.
    """
    total = len(data)
    train_end = int(total * train_perc)
    val_end = train_end + int(total * val_perc)
    train_part = data[:train_end]
    val_part = data[train_end:val_end]
    test_part = data[val_end:]
    return train_part, val_part, test_part

# The same fractions are applied to features, adjacency tuples and targets
# so that position i refers to the same graph in all three lists.
TRAIN_FRACTION, VAL_FRACTION = 0.70, 0.15

features_train, features_val, features_test = split_list(
    all_features, TRAIN_FRACTION, VAL_FRACTION)
adj_train, adj_val, adj_test = split_list(
    all_adj, TRAIN_FRACTION, VAL_FRACTION)
y_train, y_val, y_test = split_list(
    all_targets, TRAIN_FRACTION, VAL_FRACTION)

### Model definition and training

In [None]:
# Hyper-parameters are registered as TensorFlow command-line flags.
# NOTE(review): flag definitions are global state, so re-running this cell in
# the same kernel will presumably fail with a duplicate-flag error — confirm.
flags = tf.app.flags
flags.DEFINE_float("learning_rate", 0.01, "Learning Rate")
flags.DEFINE_integer("epochs", 200, "Epochs")
flags.DEFINE_integer("hidden1", 16, "Num units in HL1")
flags.DEFINE_float("dropout", 0.5, "Dropout")
flags.DEFINE_float("weight_decay", 5e-4, "Weight decay")
flags.DEFINE_integer("early_stopping", 10, "Tolerance")
flags.DEFINE_integer("max_degree", 3, "Max chebyshev polynomial degree")
# Keep only the first argv entry (reached through flags.sys) — presumably so
# the notebook kernel's own command-line arguments are not parsed as flags;
# TODO confirm.
flags.sys.argv = flags.sys.argv[0:1]

In [None]:
# Graph inputs that are fed anew for every protein graph:
#   support              -- list holding one sparse placeholder (the adjacency)
#   features             -- sparse node-feature matrix
#   labels               -- dense targets with fixed shape (nb_nodes, 2)
#   labels_mask          -- integer per-node mask
#   dropout              -- scalar that defaults to 0. when it is not fed
#   num_features_nonzero -- integer placeholder; presumably the number of
#                           stored feature entries, filled in by
#                           gcn.utils.construct_feed_dict (TODO confirm)
placeholders = {
    "support": [tf.sparse_placeholder(tf.float32, name="support")],
    "features": tf.sparse_placeholder(tf.float32, 
        #shape=tf.constant(features_train[0][2], dtype=tf.int64, name="feat_shape_const"),
        name="features"),
    "labels": tf.placeholder(tf.float32, shape=(nb_nodes, 2), name="labels"),
    "labels_mask": tf.placeholder(tf.int32, name="labels_mask"),
    "dropout": tf.placeholder_with_default(0., shape=(), name="dropout"),
    "num_features_nonzero": tf.placeholder(tf.int32, name="nfn0")
}
def evaluate(features, support, labels, mask, placeholders):
    """Run the model once on a single graph without a training step.

    The arguments are handed unchanged to ``gcn.utils.construct_feed_dict``.
    The dropout placeholder is not fed here, so its default value applies.
    Uses the notebook-level ``sess`` and ``model``.

    Returns ``(loss, accuracy, seconds)``: the fetched ``model.loss`` and
    ``model.accuracy`` values and the wall-clock time the call took.
    """
    started = time.time()
    feed = gcn.utils.construct_feed_dict(features, support, labels, mask, placeholders)
    loss_value, accuracy_value = sess.run([model.loss, model.accuracy], feed_dict=feed)
    elapsed = time.time() - started
    return loss_value, accuracy_value, elapsed

In [None]:
class BalancedGCN(gcn.models.GCN):
    """GCN variant whose ``accuracy`` attribute is a streaming ROC AUC.

    AUC is used in place of plain accuracy because the two target classes
    are heavily imbalanced.
    """

    def _accuracy(self):
        """Set ``self.accuracy`` to the pair returned by ``tf.metrics.auc``.

        ``tf.metrics.auc`` takes ``(labels, predictions, weights)`` and
        expects ``predictions`` to be scores in [0, 1].  The ground truth is
        therefore passed as ``labels`` and a class probability as
        ``predictions``; passing them the other way round, or passing hard
        argmax decisions, does not produce the ROC AUC of the classifier.
        """
        # Column 1 is treated as the positive class on both sides; AUC is
        # unchanged by which of the two classes is chosen, as long as labels
        # and scores agree.
        true_class = self.placeholders['labels'][:, 1]
        # Assumes self.outputs holds per-class logits, as the argmax over
        # axis 1 in the previous version implied — TODO confirm.
        class_prob = tf.nn.softmax(self.outputs)[:, 1]
        self.accuracy = tf.metrics.auc(
            labels=true_class,
            predictions=class_prob,
            weights=self.placeholders['labels_mask']
        )
# Build the model on the shared placeholders.  input_dim is the second entry
# of the shape stored in the first training feature tuple, i.e. the number of
# feature columns.
model = BalancedGCN(placeholders, input_dim=features_train[0][2][1], logging=True)
sess = tf.Session()
#sess = tf_debug.LocalCLIDebugWrapperSession(sess)
sess.run(tf.global_variables_initializer())
# Local variables are initialised as well — presumably required by the
# accumulators behind tf.metrics.auc; TODO confirm.
sess.run(tf.local_variables_initializer())

In [None]:
FLAGS = flags.FLAGS
cost_val = []
# All-ones mask, kept only in case another cell still refers to `mask`.
# The loops below use a per-graph mask instead.
mask = np.ones((nb_nodes))


def _real_node_mask(labels):
    """Return 1 for rows of ``labels`` that carry a class and 0 for padding.

    Every graph's targets are zero-padded to nb_nodes rows.  Real nodes have
    exactly one of the two columns set, so their row sum is positive, while
    padded rows sum to zero.
    """
    return (labels.sum(axis=1) > 0).astype(np.int32)


for epoch in range(FLAGS.epochs):
    # Train step
    train_loss_total = 0
    train_acc_total = 0
    for i, features in enumerate(features_train):
        support = [adj_train[i]]
        # Mask out padded rows so they neither dilute the loss nor count as
        # correctly classified nodes in the metric.
        train_mask = _real_node_mask(y_train[i])
        feed_dict = gcn.utils.construct_feed_dict(
            features, support, y_train[i], train_mask, placeholders)
        feed_dict.update({placeholders['dropout']: FLAGS.dropout})
        outs = sess.run([model.opt_op, model.loss, model.accuracy], feed_dict=feed_dict)

        train_loss_total += outs[1]
        # model.accuracy is fetched as a pair; entry 1 is the one reported.
        train_acc_total += outs[2][1]
    train_loss = train_loss_total / len(features_train)
    train_acc = train_acc_total / len(features_train)

    # Validation step
    val_loss_total = 0
    val_acc_total = 0
    for i, features in enumerate(features_val):
        support = [adj_val[i]]
        val_mask = _real_node_mask(y_val[i])
        loss, acc, _ = evaluate(features, support, y_val[i], val_mask, placeholders)
        val_loss_total += loss
        val_acc_total += acc[1]
    val_loss = val_loss_total / len(features_val)
    val_acc = val_acc_total / len(features_val)

    cost_val.append(val_loss)
    print("Epoch: {}, train_loss={:.5f} train_acc={:.5f} "
          "val_loss={:.5f} val_acc={:.5f}".format(epoch+1, train_loss, train_acc,
                                                  val_loss, val_acc))

    # Stop once the newest validation loss is worse than the mean of the
    # FLAGS.early_stopping losses that came before it.
    if epoch > FLAGS.early_stopping and cost_val[-1] > np.mean(cost_val[-(FLAGS.early_stopping+1):-1]):
        print("Early stopping...")
        break