# Initialization

## TensorBoard support

In [0]:
!wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
!unzip ngrok-stable-linux-amd64.zip


In [0]:
# Start TensorBoard in the background on port 6006, reading event files from LOG_DIR.
LOG_DIR = './logs'
get_ipython().system_raw(
    'tensorboard --logdir {} --host 0.0.0.0 --port 6006 &'
    .format(LOG_DIR)
)
# Expose port 6006 through an ngrok tunnel (also backgrounded with `&`).
get_ipython().system_raw('./ngrok http 6006 &')
# Print the public URL of the first tunnel from ngrok's local API (port 4040).
# NOTE(review): ngrok was just started in the background, so this query may run
# before its API is up and print nothing / raise — re-run the cell if so.
! curl -s http://localhost:4040/api/tunnels | python3 -c \
    "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"

## Perform imports

In [0]:
import tensorflow.keras.backend as K
import matplotlib.pyplot as plt
import itertools
import numpy as np
import random
import pandas as pd
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D,Flatten,Dense, Dropout, BatchNormalization, Add, AveragePooling3D, Activation, GaussianNoise, Lambda
from tensorflow.keras import optimizers, losses, regularizers
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.initializers import glorot_normal
from tensorflow.keras.utils import plot_model, Sequence
from tensorflow.keras.activations import relu
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from tensorflow.train import AdamOptimizer, GradientDescentOptimizer, MomentumOptimizer
from tensorflow.contrib.opt import AdamWOptimizer
from tensorflow.contrib.tpu import CrossShardOptimizer
from IPython.display import SVG
from sklearn.metrics import confusion_matrix, matthews_corrcoef
from tqdm import tqdm, trange
import tensorflow as tf
import os

## Define util functions

In [0]:
# Reads a fixed-width PDB file and returns its atoms as an (n_atoms, 4)
# string array with columns x, y, z and atom type ('h' or 'p').
def read_pdb(filename):
    """Parse atom coordinates and a hydrophobic/polar label from a PDB file.

    Fields are sliced from the fixed-width PDB layout (0-based slices):
    x = [30:38], y = [38:46], z = [46:54], element symbol = [76:78].
    Atoms whose element symbol is 'C' are labelled 'h' (hydrophobic); every
    other atom is labelled 'p' (polar). Lines shorter than 78 characters are
    reported on stdout but still parsed.

    Args:
        filename: path of the PDB file to read.

    Returns:
        np.ndarray of str with shape (n_atoms, 4): x, y, z, atom type.
        Blank lines are skipped; an empty file yields shape (0, 4).
    """
    atom_list = []
    with open(filename, 'r') as file:
        for strline in file:
            # Strip only the right-hand side: PDB is fixed-width, so removing
            # leading whitespace would shift every column slice below.
            stripped_line = strline.rstrip()
            if not stripped_line:
                # A blank line holds no atom; keeping it would produce empty
                # coordinate strings that fail the later float conversion.
                continue

            line_length = len(stripped_line)
            if line_length < 78:
                print("ERROR: line length is different. Expected>=78, current={}".format(line_length))

            atom_list.append((
                stripped_line[30:38].strip(),
                stripped_line[38:46].strip(),
                stripped_line[46:54].strip(),
                'h' if stripped_line[76:78].strip() == 'C' else 'p',
            ))

    if not atom_list:
        # Keep the 2-D (n, 4) contract so callers can still slice columns.
        return np.empty((0, 4), dtype=str)
    return np.array(atom_list, order='F')

In [0]:
# Reads a tab-separated test file and returns its atoms as an (n_atoms, 4)
# string array with columns x, y, z and atom type.
def read_test_pdb(filename):
    """Parse a tab-separated atom file with x, y, z and atom type per line.

    Args:
        filename: path of the file to read.

    Returns:
        np.ndarray of str with shape (n_atoms, 4). Only the first four
        tab-separated fields of each line are used. Blank lines are skipped
        and an empty file yields shape (0, 4).

    Raises:
        ValueError: if a non-blank line has fewer than four fields.
    """
    atom_list = []
    with open(filename, 'r') as file:
        for strline in file:
            stripped_line = strline.strip()
            if not stripped_line:
                # A blank line splits into a single empty token; skip it
                # instead of failing on the missing fields.
                continue
            tokens = stripped_line.split("\t")
            if len(tokens) < 4:
                raise ValueError(
                    "{}: expected 4 tab-separated fields, got {!r}".format(filename, stripped_line))
            atom_list.append((
                tokens[0],
                tokens[1],
                tokens[2],
                tokens[3],
            ))

    if not atom_list:
        # Keep the 2-D (n, 4) contract so callers can still slice columns.
        return np.empty((0, 4), dtype=str)
    return np.array(atom_list, order='F')

In [0]:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Print a confusion matrix and draw it as an annotated heat map on the
    current matplotlib figure.

    Args:
        cm: square confusion matrix with true labels on the rows and
            predicted labels on the columns (the axis labels drawn below).
        classes: tick labels for both axes, in matrix order.
        normalize: if True, divide each row by its sum so each row shows the
            fraction of that true class. Rows that sum to zero are shown as 0
            rather than NaN.
        title: figure title.
        cmap: matplotlib colormap for the heat map.
    """
    if normalize:
        cm = cm.astype('float')
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class with no true samples would give 0 / 0 -> NaN; leave it at 0.
        cm = np.divide(cm, row_totals, out=np.zeros_like(cm), where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Integer counts print as whole numbers; anything else (normalized or
    # float input) as two decimals — the 'd' format rejects floats.
    fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()

In [0]:
def _binary_confusion_counts(y_true, y_pred):
    """Return backend scalars (tp, tn, fp, fn) for binary labels/predictions.

    Both inputs are clipped to [0, 1] and rounded with K.round, so every
    value is treated as either class 0 or class 1.
    """
    y_pred_pos = K.round(K.clip(y_pred, 0, 1))
    y_pred_neg = 1 - y_pred_pos

    y_pos = K.round(K.clip(y_true, 0, 1))
    y_neg = 1 - y_pos

    tp = K.sum(y_pos * y_pred_pos)
    tn = K.sum(y_neg * y_pred_neg)
    fp = K.sum(y_neg * y_pred_pos)
    fn = K.sum(y_pos * y_pred_neg)
    return tp, tn, fp, fn


def mcc(y_true, y_pred):
    """Matthews correlation coefficient, usable as a Keras metric.

    (tp*tn - fp*fn) / sqrt((tp+fp)(tp+fn)(tn+fp)(tn+fn)); K.epsilon() in the
    denominator avoids division by zero when a marginal count is 0.
    NOTE(review): Keras evaluates metrics per batch and reports their
    average, which is not the same as MCC over the whole epoch — confirm
    that is acceptable for checkpoint selection on 'val_mcc'.
    """
    tp, tn, fp, fn = _binary_confusion_counts(y_true, y_pred)
    numerator = tp * tn - fp * fn
    denominator = K.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return numerator / (denominator + K.epsilon())


def ppv(y_true, y_pred):
    """Positive predictive value (precision): tp / (tp + fp)."""
    tp, _, fp, _ = _binary_confusion_counts(y_true, y_pred)
    return tp / (tp + fp + K.epsilon())


def tpr(y_true, y_pred):
    """True positive rate (recall / sensitivity): tp / (tp + fn)."""
    tp, _, _, fn = _binary_confusion_counts(y_true, y_pred)
    return tp / (tp + fn + K.epsilon())


# Import raw training data

## Download and unzip the training data

In [0]:
!wget https://web.bii.a-star.edu.sg/~leehk/cs5242_project/training_data.zip
!unzip training_data.zip

## Load training data into memory

In [0]:
# Load all 3000 training complexes. Position i of the 'pro' and 'lig' lists
# holds the protein and ligand atoms of the same complex; file numbers are
# 1-based and zero-padded to four digits (0001 ... 3000).
raw_training_data = dict(pro=[], lig=[])
for i in trange(3000):
    raw_training_data['pro'].append(read_pdb("./training_data/{:04d}_pro_cg.pdb".format(i + 1)))
    raw_training_data['lig'].append(read_pdb("./training_data/{:04d}_lig_cg.pdb".format(i + 1)))

In [0]:
# Deterministic 90/10 split in file order (no shuffling): the first 90% of
# complexes are used for training, the rest are held out.
n = int(len(raw_training_data['pro']) * 0.9)
raw_training_train_data = {part: raw_training_data[part][:n] for part in ('pro', 'lig')}
raw_training_test_data = {part: raw_training_data[part][n:] for part in ('pro', 'lig')}

# Preprocess training data

## Install multidimensional sparse matrix library

In [0]:
# Install the `sparse` package, which provides the COO arrays used for the voxel grids.
# NOTE(review): the version is unpinned; pin it (e.g. `%pip install sparse==X.Y`)
# so re-runs use the same COO behaviour.
!pip install sparse

In [0]:
from sparse import COO
import sparse

## Define voxelization functions

In [0]:
# Returns a sparse voxel grid for one protein-ligand pair.
def voxelize(pdb_inputs, max_dist=20, grid_resolution=4):
    """Voxelize a protein-ligand pair into a sparse 5-D grid.

    Args:
        pdb_inputs: pair ``(pro_atoms, lig_atoms)``. Each is an (n, k) string
            array whose first three columns are x, y, z (parsable as float)
            and whose last column is the atom type ('p' = polar, anything
            else = hydrophobic), e.g. the output of ``read_pdb``.
        max_dist: half-width of the cubic box around the ligand centroid,
            in coordinate units.
        grid_resolution: edge length of one voxel, in coordinate units.

    Returns:
        sparse.COO of shape (1, box, box, box, 2) where
        box = ceil(2 * max_dist / grid_resolution + 1). Channel 0 holds
        protein atoms, channel 1 ligand atoms; each atom contributes 256 if
        polar and 128 otherwise. Atoms snapped outside the box are dropped.
        NOTE(review): several atoms can land in the same voxel; COO sums
        duplicate coordinates by default — confirm for the installed version.
    """
    max_dist = float(max_dist)
    grid_resolution = float(grid_resolution)
    box_size = int(np.ceil(2 * max_dist / grid_resolution + 1))

    pro_atoms = np.asarray(pdb_inputs[0])
    lig_atoms = np.asarray(pdb_inputs[1])
    all_atoms = np.concatenate((pro_atoms, lig_atoms))

    # Center every atom on the ligand centroid (not the protein).
    coords = all_atoms[:, :3].astype(float)
    coords -= lig_atoms[:, :3].astype(float).mean(axis=0)

    # Channel: 0 for protein rows, 1 for ligand rows.
    channel = np.r_[np.zeros(len(pro_atoms), dtype=int),
                    np.ones(len(lig_atoms), dtype=int)]
    # Value: 256 for polar atoms, 128 for hydrophobic ones.
    value = np.where(all_atoms[:, -1] == 'p', 256, 128)

    # Snap each atom to the nearest grid point, in box-relative voxel indices.
    grid_idx = np.round((coords + max_dist) / grid_resolution).astype(int)

    # Drop atoms that fall outside the box.
    in_box = ((grid_idx >= 0) & (grid_idx < box_size)).all(axis=1)
    coords_4d = np.c_[grid_idx, channel][in_box].T
    # Index the 1-D value column directly so a single in-box atom still
    # yields 1-D data (np.squeeze would have collapsed it to a scalar).
    data = value[in_box]

    grid = COO(coords_4d, data, shape=(box_size, box_size, box_size, 2))
    # Add a leading batch axis so grids can be concatenated into a dataset.
    return grid.reshape((1, box_size, box_size, box_size, 2))

In [0]:
# Returns (x, y): voxelized examples and their labels (1. = true pair, 0. = decoy).
def generate_training_data(raw_data, pos_ratio=1, neg_ratio=1, max_dist=20, grid_resolution=4, quiet=False):
    """Voxelize positive and randomly paired negative protein-ligand examples.

    For every protein i, emits ``pos_ratio`` copies of (protein i, ligand i)
    labelled 1.0, then ``neg_ratio`` examples pairing protein i with a ligand
    drawn uniformly at random from the *other* complexes, labelled 0.0.

    Args:
        raw_data: dict with parallel lists ``'pro'`` and ``'lig'`` of atom
            arrays (entry i of both belongs to the same complex).
        pos_ratio: positive examples generated per protein.
        neg_ratio: negative examples generated per protein.
        max_dist, grid_resolution: forwarded to ``voxelize``.
        quiet: disable the tqdm progress bar.

    Returns:
        (x, y): x is ``sparse.concatenate`` of the per-example grids, y is an
        (n_examples, 1) float array of labels.

    Raises:
        ValueError: if negatives are requested but there are fewer than two
            complexes to pair with.
    """
    n = len(raw_data['pro'])
    if neg_ratio > 0 and n < 2:
        raise ValueError(
            "need at least 2 complexes to draw negative examples, got {}".format(n))
    x_all = []
    y_all = []
    for i in tqdm(range(n), disable=quiet):
        for _ in range(pos_ratio):
            grid = voxelize((
                raw_data['pro'][i],
                raw_data['lig'][i]
            ), max_dist, grid_resolution)
            x_all.append(grid)
            y_all.append([1.])
        for _ in range(neg_ratio):
            # Uniform choice among the n - 1 indices other than i, without
            # building an O(n) candidate list per draw: pick one of n - 1
            # slots and step over i.
            j = random.randrange(n - 1)
            if j >= i:
                j += 1
            grid = voxelize((
                raw_data['pro'][i],
                raw_data['lig'][j]
            ), max_dist, grid_resolution)
            x_all.append(grid)
            y_all.append([0.])
    return sparse.concatenate(x_all), np.asarray(y_all)
    

## Define a Keras Sequence that dynamically generates samples



In [0]:
class ProLigSequence(Sequence):
    """Keras ``Sequence`` over voxelized protein-ligand pairs.

    Positive examples (each protein with its own ligand) are voxelized once
    in ``__init__``. Negative examples (each protein with ``neg_ratio``
    randomly chosen foreign ligands) are regenerated in ``on_epoch_end``, so
    each epoch sees a fresh set of decoys; the example order is reshuffled
    at the same time.

    Args:
        raw_data: dict with parallel ``'pro'`` / ``'lig'`` atom-array lists.
        max_dist, grid_resolution: forwarded to ``generate_training_data``.
        batch_size: examples per batch (the last batch may be smaller).
        neg_ratio: negatives generated per protein each epoch.
        quiet: suppress the progress bar while generating examples.
        sparse: keep examples as sparse COO arrays and densify per batch
            (True), or densify everything once up front (False).
    """

    def __init__(self, raw_data, max_dist=20, grid_resolution=4, batch_size=128, neg_ratio=1, quiet=True, sparse=True):
        self.raw_data = raw_data
        self.max_dist = max_dist
        self.grid_resolution = grid_resolution
        self.batch_size = batch_size
        self.neg_ratio = neg_ratio
        self.quiet = quiet
        self.sparse = sparse
        # Positives never change between epochs, so they are built only once.
        pos_x, self.pos_eg_y = generate_training_data(
            raw_data,
            neg_ratio=0,
            max_dist=max_dist,
            grid_resolution=grid_resolution,
            quiet=self.quiet,
        )
        # Inside __init__, `sparse` is the constructor flag, not the module.
        self.pos_eg_x = pos_x if sparse else pos_x.todense()
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch, counting a trailing partial batch."""
        full_batches, leftover = divmod(len(self.all_eg_x), self.batch_size)
        return int(full_batches + (1 if leftover else 0))

    def __getitem__(self, idx):
        """Return batch ``idx`` as a dense ``(features, labels)`` pair."""
        start = idx * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]
        features = self.all_eg_x[batch_indexes]
        if self.sparse:
            features = features.todense()
        return features, self.all_eg_y[batch_indexes]

    def on_epoch_end(self):
        """Draw fresh negatives, rebuild the example pool and reshuffle it."""
        self.neg_eg_x, self.neg_eg_y = generate_training_data(
            self.raw_data,
            pos_ratio=0,
            neg_ratio=self.neg_ratio,
            max_dist=self.max_dist,
            grid_resolution=self.grid_resolution,
            quiet=self.quiet,
        )
        # Here `sparse` is the module; the storage mode lives in self.sparse.
        if self.sparse:
            neg_x, concat = self.neg_eg_x, sparse.concatenate
        else:
            neg_x, concat = self.neg_eg_x.todense(), np.concatenate
        self.all_eg_x = concat((self.pos_eg_x, neg_x))
        self.all_eg_y = np.concatenate((self.pos_eg_y, self.neg_eg_y))
        self.indexes = np.arange(len(self.all_eg_x))
        np.random.shuffle(self.indexes)

# Define Wide ResNet

In [0]:
def _conv3d_same(filters, kernel_size):
    """Conv3D with 'same' padding, channels_last layout and He-normal init."""
    return Conv3D(
        filters=filters,
        kernel_size=kernel_size,
        padding='same',
        data_format='channels_last',
        kernel_initializer='he_normal',
    )


def _wide_residual_block(x, filters, project_shortcut, dropout_rate):
    """Pre-activation residual block: BN-ReLU-Conv3-Dropout-BN-ReLU-Conv3 + shortcut.

    With ``project_shortcut`` the shortcut is a 1x1x1 convolution of the
    pre-activated input (used for the first block of each stage, where the
    channel count changes); otherwise the block input is added back as-is.
    """
    pre = BatchNormalization()(x)
    pre = Activation('relu')(pre)
    shortcut = _conv3d_same(filters, 1)(pre) if project_shortcut else x
    residual = _conv3d_same(filters, 3)(pre)
    residual = Dropout(dropout_rate)(residual)
    residual = BatchNormalization()(residual)
    residual = Activation('relu')(residual)
    residual = _conv3d_same(filters, 3)(residual)
    return Add()([residual, shortcut])


# k defines the width of the network as defined in the Wide ResNet paper
def generate_resnet(input_shape, k=1, noise=False,
                    l1_filters=16, l1_kernel_size=3, l1_dilation_rate=1,
                    stage_widths=(32, 64, 128), blocks_per_stage=2,
                    dropout_rate=0.5):
    """Build a 3-D wide residual network for binary classification.

    Layout: Conv3D stem ('valid' padding, optional dilation) -> one stage per
    entry of ``stage_widths``, each made of ``blocks_per_stage`` pre-activation
    residual blocks with ``width * k`` filters (the first block of a stage
    projects its shortcut with a 1x1x1 conv) -> BN-ReLU -> AveragePooling3D ->
    Flatten -> Dense(128, relu) -> Dropout -> Dense(1, sigmoid).
    All residual convolutions use 'same' padding, so spatial size is only
    reduced by the stem and the pooling layer.

    Args:
        input_shape: shape of one example, channels last.
        k: width multiplier applied to every stage width.
        noise: accepted for backward compatibility; currently ignored.
        l1_filters, l1_kernel_size, l1_dilation_rate: stem convolution config.
        stage_widths: base filter count of each stage (defaults reproduce
            the original 32/64/128 network).
        blocks_per_stage: residual blocks per stage.
        dropout_rate: rate for every Dropout layer.

    Returns:
        An uncompiled ``tf.keras`` ``Model`` with a single sigmoid output.
    """
    inputs = Input(shape=input_shape)
    x = Conv3D(
        filters=l1_filters,
        kernel_size=l1_kernel_size,
        dilation_rate=l1_dilation_rate,
        padding='valid',
        data_format='channels_last',
        kernel_initializer='he_normal',
    )(inputs)

    for width in stage_widths:
        for block_idx in range(blocks_per_stage):
            x = _wide_residual_block(
                x,
                filters=width * k,
                project_shortcut=(block_idx == 0),
                dropout_rate=dropout_rate,
            )

    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = AveragePooling3D()(x)
    x = Flatten()(x)
    x = Dense(
        128,
        activation='relu',
        kernel_initializer='he_normal',
    )(x)
    x = Dropout(dropout_rate)(x)
    outputs = Dense(
        1,
        activation='sigmoid',
        kernel_initializer='he_normal',
    )(x)

    return Model(inputs=inputs, outputs=outputs)


In [0]:
# The input shape must match voxelize's output for the settings used below
# (max_dist=40, grid_resolution=4): ceil(2 * 40 / 4 + 1) = 21 voxels per side,
# with 2 channels (protein, ligand).
model = generate_resnet(
    (21, 21, 21, 2),
    l1_dilation_rate=3,
    l1_kernel_size=6,
    l1_filters=16,
    k=1,
)
model.summary()

In [0]:
# Convert the Keras model into a TPU model using the TF 1.x `tf.contrib.tpu`
# API. The TPU address is read from the COLAB_TPU_ADDR environment variable,
# so this cell only works on a Colab TPU runtime (KeyError otherwise).
tpu_model = tf.contrib.tpu.keras_to_tpu_model(
    model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
    )
)

In [0]:
# Compile with binary cross-entropy, accuracy plus the custom mcc/ppv/tpr
# metrics, and Adam wrapped in CrossShardOptimizer for TPU training.
tpu_model.compile(
    loss='binary_crossentropy',
    metrics=['acc', mcc, ppv, tpr],
    optimizer=CrossShardOptimizer(AdamOptimizer())
)
# Train on sequences that redraw their negative (decoy) examples every epoch
# and keep only the checkpoint with the highest validation MCC in Dynamic.h5.
# max_dist=40 / grid_resolution=4 must stay in sync with the model's
# input_shape (21 voxels per side).
# NOTE(review): TensorBoard() uses its default log directory; presumably that
# is './logs' (LOG_DIR above) — confirm, or pass log_dir=LOG_DIR explicitly.
history = tpu_model.fit_generator(
    generator=ProLigSequence(raw_training_train_data, batch_size=512, max_dist=40, grid_resolution=4, sparse=False),
    validation_data=ProLigSequence(raw_training_test_data, batch_size=512, max_dist=40, grid_resolution=4, sparse=False),
    epochs=500,
    initial_epoch=0,
    use_multiprocessing=True,
    workers=8,
    callbacks=[ModelCheckpoint('Dynamic.h5',
                           monitor='val_mcc',
                           verbose=1,
                           save_best_only=True,
                           mode='max',
                           period=1),
              TensorBoard()]
)
# Alternative run configuration (Nesterov-momentum SGD, grid_resolution=1,
# resuming from epoch 200), left commented out for reference.
# tpu_model.compile(
#     loss='binary_crossentropy',
#     metrics=['acc', mcc],
#     optimizer=CrossShardOptimizer(MomentumOptimizer(
#         learning_rate=0.1,
#         momentum=0.9,
#         use_nesterov=True))
# )
# tpu_model.fit_generator(
#     generator=ProLigSequence(raw_training_train_data, batch_size=512, max_dist=50, grid_resolution=1),
#     validation_data=ProLigSequence(raw_training_test_data, batch_size=512, max_dist=50, grid_resolution=1),
#     epochs=2800,
#     initial_epoch=200,
#     use_multiprocessing=True,
#     workers=1,
#     callbacks=[ModelCheckpoint('Dynamic.h5',
#                            monitor='val_mcc',
#                            verbose=1,
#                            save_best_only=True,
#                            mode='max',
#                            period=1),
#               TensorBoard()]
# )

# Evaluate Model

## Generate testing data

In [0]:
# Evaluation set from the held-out complexes: every true pair plus 10 randomly
# drawn decoy ligands per protein (no seed, so the decoys differ per run).
# NOTE(review): these complexes were also the validation data used to pick the
# best checkpoint, so the scores below are not from a fully unseen set.
x_test, y_test = generate_training_data(
    raw_training_test_data,
    neg_ratio=10,
    max_dist=40,
    grid_resolution=4,
)
# The model consumes dense arrays.
x_test = x_test.todense()


## Load best model

In [0]:
# Reload the best checkpoint written by ModelCheckpoint. The custom metric
# functions are passed by name so Keras can resolve them if the file stores
# the compile configuration (otherwise loading fails on the unknown metrics).
best_model = load_model(
    "Dynamic.h5",
    custom_objects={'mcc': mcc, 'ppv': ppv, 'tpr': tpr},
)


## Plot confusion matrix

In [0]:
# Threshold the sigmoid outputs at 0.5 (keeping the prediction dtype) and plot
# the row-normalized confusion matrix.
y_pred = best_model.predict(x_test)
y_pred = (y_pred >= 0.5).astype(y_pred.dtype)
cnf_matrix = confusion_matrix(y_test, y_pred)
plt.figure()
plot_confusion_matrix(
    cnf_matrix,
    classes=[0, 1],
    normalize=True,
    title='Normalized confusion matrix',
)
