In [1]:
# Fetch the project repository; the test-set files read later in this notebook
# are opened from paths under this checkout (Zhang_Chenxi_BS6207_Final/...).
! git clone https://github.com/Ironarrow98/Zhang_Chenxi_BS6207_Final

Cloning into 'Zhang_Chenxi_BS6207_Final'...
remote: Enumerating objects: 7684, done.[K
remote: Counting objects: 100% (7684/7684), done.[K
remote: Compressing objects: 100% (7549/7549), done.[K
remote: Total 7684 (delta 36), reused 7654 (delta 25), pack-reused 0[K
Receiving objects: 100% (7684/7684), 88.28 MiB | 21.06 MiB/s, done.
Resolving deltas: 100% (36/36), done.
Checking out files: 100% (7658/7658), done.


In [2]:
! pip install sparse

Collecting sparse
  Downloading sparse-0.13.0-py2.py3-none-any.whl (77 kB)
[?25l[K     |████▏                           | 10 kB 32.1 MB/s eta 0:00:01[K     |████████▍                       | 20 kB 9.6 MB/s eta 0:00:01[K     |████████████▋                   | 30 kB 8.6 MB/s eta 0:00:01[K     |████████████████▉               | 40 kB 8.0 MB/s eta 0:00:01[K     |█████████████████████           | 51 kB 4.4 MB/s eta 0:00:01[K     |█████████████████████████▎      | 61 kB 5.2 MB/s eta 0:00:01[K     |█████████████████████████████▌  | 71 kB 5.7 MB/s eta 0:00:01[K     |████████████████████████████████| 77 kB 3.1 MB/s 
Installing collected packages: sparse
Successfully installed sparse-0.13.0


In [3]:
import matplotlib.pyplot as plt
import itertools
import numpy as np
import random
import pandas as pd
from tqdm import tqdm, trange
import sparse
from sparse import COO
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D,Flatten,Dense, Dropout, BatchNormalization, Add, AveragePooling3D, Activation, GaussianNoise, Lambda
from tensorflow.keras import optimizers, losses, regularizers
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.activations import relu
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import confusion_matrix, matthews_corrcoef, roc_curve, auc, classification_report

In [4]:
# Reads the test pdb file and returns an array of the
# atoms' x, y, z and atomtype
def read_test_pdb(filename):
  """Parse a tab-separated atom file into a 2-D array of strings.

  Every non-blank line is stripped of surrounding whitespace and split on tabs;
  the first four fields are kept, in file order. Values are left as strings, so
  callers convert the columns they need.

  Args:
    filename: path of the file to read.

  Returns:
    np.ndarray (Fortran order) with one row per atom line and four columns.
    Blank or whitespace-only lines are ignored.

  Raises:
    IndexError: if a non-blank line has fewer than four tab-separated fields.
  """
  atom_list = []
  with open(filename, 'r') as file:
    # Stream the file line by line instead of materialising it with readlines().
    for strline in file:
      # removes all whitespace at the start and end, including spaces, tabs, newlines and carriage returns
      stripped_line = strline.strip()
      if not stripped_line:
        # A blank line (e.g. a trailing newline at end of file) carries no atom.
        continue
      tokens = stripped_line.split("\t")
      atom_list.append((tokens[0], tokens[1], tokens[2], tokens[3]))

  return np.array(atom_list, order='F')

In [5]:
# Directory holding the released test files inside the cloned repository.
# Load your own file path for the testing data, same as training_data
TEST_DATA_DIR = 'Zhang_Chenxi_BS6207_Final/testing_data_release/testing_data'
# Files are named 0001_... through 0824_..., one protein and one ligand file per id.
N_TEST_COMPLEXES = 824

# testing_data['pro'][k] / testing_data['lig'][k] hold the parsed atoms of file id k + 1.
testing_data = {
    kind: [
        read_test_pdb('{}/{:04d}_{}_cg.pdb'.format(TEST_DATA_DIR, file_id, kind))
        for file_id in range(1, N_TEST_COMPLEXES + 1)
    ]
    for kind in ('pro', 'lig')
}

In [6]:
# Returns a sparse matrix representation of the voxel
def voxelize(pdb_inputs, max_dist, grid_resolution):
  """Voxelize a protein/ligand pair into a sparse two-channel cubic grid.

  The atoms are shifted so the ligand's mean coordinate is the origin, mapped
  to the nearest grid point of a cube spanning [-max_dist, max_dist) on each
  axis, and atoms falling outside the cube are dropped.

  Args:
    pdb_inputs: pair (protein_atoms, ligand_atoms). Each is a 2-D array whose
      first three columns are x, y, z (convertible to float) and whose last
      column is the atom type.
    max_dist: half-width of the cube around the ligand centre.
    grid_resolution: edge length of one voxel, in the same units as max_dist.

  Returns:
    sparse.COO of shape (1, box, box, box, 2) where
    box = ceil(2 * max_dist / grid_resolution). Channel 0 receives protein
    atoms and channel 1 ligand atoms; an atom contributes 256 if its type is
    'p' and 128 otherwise.
    NOTE(review): atoms landing on the same voxel/channel are left to COO's
    duplicate handling (summed by default per the sparse docs) — confirm for
    the installed version.
  """

  def featurize(atom_type):
    # atom_type is (atom type, origin tag); returns [channel index, voxel value].
    feat = [0, 128]
    # Change to ligand
    if atom_type[1] == 'l':
      feat[0] = 1
    # change to polar
    if atom_type[0] == 'p':
      feat[1] = 256
    return feat

  max_dist = float(max_dist)
  grid_resolution = float(grid_resolution)
  # Number of voxels along each axis; computed once as an int for the shape tuples.
  box_size = int(np.ceil(2 * max_dist / grid_resolution))

  # merge protein and ligand, tagging each row with its origin ('p' / 'l')
  pro_atomtypes = pdb_inputs[0]
  lig_atomtypes = pdb_inputs[1]
  pro_atomtype = np.c_[pro_atomtypes, np.full(pro_atomtypes.shape[0], 'p')]
  lig_atomtype = np.c_[lig_atomtypes, np.full(lig_atomtypes.shape[0], 'l')]
  all_atoms = np.r_[pro_atomtype, lig_atomtype]

  # center all atoms around the mean of the ligand
  cord_map = all_atoms[:, :3].astype(float)
  cord_map = cord_map - np.mean(lig_atomtype[:, :3].astype(float), axis = 0)

  # add feature list to identify the atom h/p and pro/lig
  feature_list = np.asarray([featurize(atom_type) for atom_type in all_atoms[:, -2:]])

  # map all atoms to the nearest grid point; columns are x, y, z, channel, value
  atom_map = np.c_[cord_map, feature_list]
  atom_map[:, :3] = ((atom_map[:, :3] + max_dist) / grid_resolution).round()
  atom_map = atom_map.astype(int)

  # remove atoms outside the box
  in_box = ((atom_map[:, :3] >= 0) & (atom_map[:, :3] < box_size)).all(axis = 1)
  atom_map = atom_map[in_box]

  # Plain column indexing keeps the values 1-D even when a single atom remains.
  voxel_values = atom_map[:, -1]
  # COO expects coordinates as (ndim, nnz): rows are x, y, z, channel.
  voxel_coords = atom_map[:, :4].T

  # create the sparse matrix and add a leading batch axis of length 1
  grid = COO(voxel_coords, voxel_values, shape = (box_size, box_size, box_size, 2))
  return grid.reshape((1, box_size, box_size, box_size, 2))

In [7]:
# Score every protein against every ligand with the saved model.
# result[i][j] is the model output for protein i paired with ligand j (0-based).

# Box half-width and voxel size passed to voxelize.
# NOTE(review): presumably these must match the settings used when
# best_model8.h5 was trained — confirm against the training notebook.
MAX_DIST = 40
GRID_RESOLUTION = 4

# Derive the loop bounds from what was actually loaded instead of a literal 824.
n_proteins = len(testing_data['pro'])
n_ligands = len(testing_data['lig'])

result = []
best_model = load_model("best_model8.h5")

for i in tqdm(range(n_proteins)):
  # One batch per protein: that protein paired with every ligand.
  cur_x = [
      voxelize((testing_data['pro'][i], testing_data['lig'][j]), MAX_DIST, GRID_RESOLUTION)
      for j in range(n_ligands)
  ]
  # Stack the single-sample grids along the batch axis and densify for Keras.
  X = sparse.concatenate(cur_x).todense()
  y_pred = best_model.predict(X).flatten()
  result.append(y_pred)

100%|██████████| 824/824 [58:52<00:00,  4.29s/it]


In [8]:
# For each protein, the 0-based indices of its ten highest-scoring ligands,
# ordered from highest score to lowest.
top10_result = [scores.argsort()[-10:][::-1] for scores in result]

In [9]:
# One row per protein: its ten top-ranked ligand indices in columns lig1_id .. lig10_id.
top10_df = pd.DataFrame(
    top10_result,
    columns = ['lig{}_id'.format(rank) for rank in range(1, 11)])
top10_df.insert(0, 'pro_id', range(824))
# Every column is 0-based at this point; shift the whole frame to 1-based ids.
top10_df += 1
top10_df

Unnamed: 0,pro_id,lig1_id,lig2_id,lig3_id,lig4_id,lig5_id,lig6_id,lig7_id,lig8_id,lig9_id,lig10_id
0,1,582,422,78,156,751,731,263,377,388,694
1,2,1,20,204,185,167,334,148,319,315,101
2,3,824,591,372,391,395,403,411,426,448,454
3,4,46,446,265,449,72,649,518,120,674,351
4,5,222,204,650,315,466,493,195,385,14,274
...,...,...,...,...,...,...,...,...,...,...,...
819,820,233,737,425,448,789,769,298,185,176,267
820,821,374,606,724,275,136,784,508,181,3,478
821,822,390,658,609,607,426,680,353,35,82,459
822,823,763,554,58,162,728,504,656,356,410,622


In [10]:
# Write the ranking as a space-separated text file; index = False omits the
# DataFrame's row index so only the id columns are written.
top10_df.to_csv('test_predictions.txt', index = False, sep = ' ')

In [11]:
# Preview the first ten rows of the ranking table.
top10_df.head(10)

Unnamed: 0,pro_id,lig1_id,lig2_id,lig3_id,lig4_id,lig5_id,lig6_id,lig7_id,lig8_id,lig9_id,lig10_id
0,1,582,422,78,156,751,731,263,377,388,694
1,2,1,20,204,185,167,334,148,319,315,101
2,3,824,591,372,391,395,403,411,426,448,454
3,4,46,446,265,449,72,649,518,120,674,351
4,5,222,204,650,315,466,493,195,385,14,274
5,6,714,407,79,202,661,606,576,137,125,404
6,7,525,665,669,146,675,426,680,154,411,401
7,8,529,798,172,577,272,778,58,410,413,690
8,9,1,536,81,570,565,564,552,549,547,537
9,10,562,72,336,803,555,165,774,234,814,360
