In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import hashlib
import random

from tqdm import tqdm
import rdkit.Chem as Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
import glob
import torch
# Restrict PyTorch to a single CPU thread for its own parallel work in this kernel.
torch.set_num_threads(1)
# Use the 'file_system' strategy when tensors are shared between processes.
# NOTE(review): presumably chosen to avoid running out of file descriptors with
# multi-worker DataLoaders — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

%matplotlib inline

In [2]:
import sys
# Location of the TankBind source tree. It is also reused further down to locate
# the saved model weights relative to this folder.
# NOTE(review): absolute, user-specific path — adjust for your checkout.
tankbind_src_folder = "/home/zoujl/TankBind/tankbind/"
# Put that folder first on the module search path *before* the bare
# `data` / `model` imports below (presumably modules living in that folder — verify).
sys.path.insert(0, tankbind_src_folder)

import logging

from data import TankBindDataSet
from torch_geometric.loader import DataLoader
from model import get_model

No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


In [6]:
# Output locations: predicted poses are written into result_folder, and the
# RDKit-generated ligand conformers into its rdkit/ sub-folder.
result_folder = "/home/zoujl/TankBind/examples/predictions/"
# os.makedirs creates missing parent directories, is a no-op when the directory
# already exists (exist_ok=True), and raises on failure instead of returning an
# exit status that nobody checks. It also avoids passing the path through a shell.
os.makedirs(result_folder, exist_ok=True)

rdkit_folder = f"{result_folder}/rdkit/"
os.makedirs(rdkit_folder, exist_ok=True)

0

In [7]:
# Build the ligand feature dictionary for the test split.
# For every entry of the test list: read the reference ligand, turn it into a SMILES
# string, let RDKit generate a fresh conformer from that SMILES, and extract the
# compound features from the generated conformer. The resulting dict (keyed by
# entry name) is saved to disk for the dataset cell that follows.
from feature_utils import read_mol, generate_sdf_from_smiles_using_rdkit, extract_torchdrug_feature_from_mol

compound_dict = {}
# NOTE(review): absolute, user-specific paths — adjust for your environment.
pre = "/home/zoujl/TankBind/pdbbind2020/"
# Entry names of the test split, loaded as strings.
test = np.loadtxt("/home/zoujl/TankBind/packages/EquiBind/timesplit_test", dtype=str)
# Number of entries skipped so far because their ligand file could not be read.
unfound = 0
for name in tqdm(test):
    try:
        # NOTE(review): the meaning of the second argument (None) is defined by
        # feature_utils.read_mol — confirm there.
        mol, _ = read_mol(f"{pre}/renumber_atom_index_same_as_smiles/{name}.sdf", None)
    except OSError:
        # An OSError from read_mol is treated as "ligand file not available":
        # report the entry with a running count and skip it.
        print(f"{unfound+1}: Not in refined set: {name}")
        unfound += 1
        continue
        
    smiles = Chem.MolToSmiles(mol)

    # Write the RDKit-generated conformer next to the predictions; a later cell
    # reads this same file back when writing the predicted pose.
    rdkit_mol_path = f"{rdkit_folder}/{name}_ligand.sdf"
    generate_sdf_from_smiles_using_rdkit(smiles, rdkit_mol_path, shift_dis=0)

    # Features are extracted from the generated conformer, not from the reference ligand.
    mol, _ = read_mol(rdkit_mol_path, None)
    compound_dict[name] = extract_torchdrug_feature_from_mol(mol, has_LAS_mask=True)
torch.save(compound_dict, f"{result_folder}/pdbbind_test_compound_dict_based_on_rdkit.pt")

1: Not in refined set: 6d08
2: Not in refined set: 6uvp
3: Not in refined set: 6oxq
4: Not in refined set: 6jsn
5: Not in refined set: 6oio
6: Not in refined set: 6moa
7: Not in refined set: 6hld
8: Not in refined set: 6i9a
9: Not in refined set: 6e4c
10: Not in refined set: 6s55
11: Not in refined set: 6seo
12: Not in refined set: 5zk5
13: Not in refined set: 6jid
14: Not in refined set: 5ze6
15: Not in refined set: 6a6k
16: Not in refined set: 6e3z
17: Not in refined set: 6te6
18: Not in refined set: 6pka
19: Not in refined set: 6jsf
20: Not in refined set: 5zxk
21: Not in refined set: 6qxd
22: Not in refined set: 6n97
23: Not in refined set: 6jt3
24: Not in refined set: 6qtr
25: Not in refined set: 6oy1
26: Not in refined set: 6n96
27: Not in refined set: 6qzh
28: Not in refined set: 6qmt
29: Not in refined set: 6ibx
30: Not in refined set: 6hmt
31: Not in refined set: 5zk7
32: Not in refined set: 6ibz
33: Not in refined set: 6ott
34: Not in refined set: 6gge
35: Not in refined set:

In [8]:
# Load the pre-processed test dataset.
# NOTE(review): the meaning of proteinMode / compoundMode / pocket_radius / predDis is
# defined by TankBindDataSet — confirm there.
dataset = TankBindDataSet("/home/zoujl/TankBind/pdbbind2020/test_dataset/", 
                          proteinMode=0, compoundMode=1, pocket_radius=20, predDis=True)
# Replace the dataset's compound features with the RDKit-conformer-based ones saved above.
dataset.compound_dict = torch.load(f"{result_folder}/pdbbind_test_compound_dict_based_on_rdkit.pt")
# dataset.data = dataset.data.query("not use_compound_com").reset_index(drop=True)
# batch_size=1 and shuffle=False keep the i-th batch aligned with the i-th dataset row;
# the selection cell below relies on that when it attaches predictions to dataset.data.
data_loader = DataLoader(dataset, batch_size=1, 
                         follow_batch=['x', 'y', 'compound_pair'], shuffle=False, num_workers=8, pin_memory=True)


['/home/zoujl/TankBind/pdbbind2020/test_dataset/processed/data.pt', '/home/zoujl/TankBind/pdbbind2020/test_dataset/processed/protein.pt', '/home/zoujl/TankBind/pdbbind2020/test_dataset/processed/compound.pt']


In [16]:
# Run the pretrained model over the whole test loader and collect, per dataset item,
# the predicted affinity and the predicted y values.
# # device = 'cpu'
# # device = "cuda:2"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.basicConfig(level=logging.INFO)
# NOTE(review): the first argument (0) is defined by get_model; the `logging` module
# itself is passed in the logger position — confirm against model.get_model.
model = get_model(0, logging, device)
model.eval()
# Weights are looked up relative to the TankBind source folder.
model.load_state_dict(torch.load(f"{tankbind_src_folder}/../saved_models/self_dock.pt", map_location=device))

affinity_pred_list = []
y_pred_list = []
for data in tqdm(data_loader):
    data = data.to(device)
    # Inference only: no gradients are tracked for the forward pass.
    with torch.no_grad():
        y_pred, affinity_pred = model(data)
    affinity_pred_list.append(affinity_pred.detach().cpu())
    # y_pred is flat across the batch; y_batch tells which graph each entry belongs to,
    # so split it back into one tensor per graph (one list entry per dataset item).
    for i in range(data.y_batch.max() + 1):
        y_pred_list.append((y_pred[data['y_batch'] == i]).detach().cpu())

# Merge the per-batch affinity tensors into a single tensor.
affinity_pred_list = torch.cat(affinity_pred_list)

15:54:11   5 stack, readout2, pred dis map add self attention and GVP embed, compound model GIN


In [17]:
# Attach the predictions to the dataset's metadata table and pick one candidate
# per (protein, compound) pair.
#
# Work on a copy: assigning columns to `dataset.data` directly would write the
# prediction columns into the dataset object itself, so re-running cells would
# operate on an already-modified table.
output_info_chosen = dataset.data.copy()
# Convert the CPU tensor to a NumPy array explicitly so the column gets a plain
# numeric dtype independent of how pandas coerces torch tensors.
output_info_chosen['affinity'] = affinity_pred_list.numpy()
# Row position in the *unfiltered* dataset. It must be assigned before the filter
# below so it can still be used to index `dataset` and `y_pred_list`.
output_info_chosen['dataset_index'] = range(len(output_info_chosen))
output_info_chosen = output_info_chosen.query("not use_compound_com").reset_index(drop=True)

# For every (protein, compound) pair keep the row with the highest predicted affinity.
chosen = output_info_chosen.loc[output_info_chosen.groupby(['protein_name', 'compound_name'], 
                                                           sort=False)['affinity'].agg('idxmax')].reset_index()

In [18]:
# Turn the predicted values of each chosen candidate into ligand coordinates and
# write the lowest-loss result to '<result_folder>/<name>_tankbind_chosen.sdf'.
from generation_utils import get_LAS_distance_constraint_mask, get_info_pred_distance, write_with_new_coords

# Coordinate generation is done on the CPU regardless of the inference device.
device = "cpu"
for idx, line in tqdm(chosen.iterrows(), total=chosen.shape[0]):
    name = compound_name = line['compound_name']
    dataset_index = line['dataset_index']
    # Indexing the dataset builds the item each time, so fetch it once and reuse it
    # for all three fields; this also guarantees they come from the same item.
    sample = dataset[dataset_index]
    coords = sample.coords.to(device)
    protein_nodes_xyz = sample.node_xyz.to(device)
    n_compound = coords.shape[0]
    n_protein = protein_nodes_xyz.shape[0]
    # The flat per-item prediction is viewed as a (protein nodes x compound nodes) matrix.
    y_pred = y_pred_list[dataset_index].reshape(n_protein, n_compound).to(device)
    # Stored distance map of the item in the same layout; not consumed further in
    # this cell, kept available for inspection.
    y = sample.dis_map.reshape(n_protein, n_compound).to(device)
    # Pairwise distances within the input ligand conformer, passed on as a constraint.
    compound_pair_dis_constraint = torch.cdist(coords, coords)
    rdkit_mol_path = f"{rdkit_folder}/{name}_ligand.sdf"
    # mol = Chem.MolFromMolFile(rdkit_mol_path)
    mol, _ = read_mol(rdkit_mol_path, None)
    LAS_distance_constraint_mask = get_LAS_distance_constraint_mask(mol).bool()
    pred_dist_info = get_info_pred_distance(coords, y_pred, protein_nodes_xyz, compound_pair_dis_constraint,
                                  LAS_distance_constraint_mask=LAS_distance_constraint_mask,
                                  n_repeat=1, show_progress=False)

    toFile = f'{result_folder}/{name}_tankbind_chosen.sdf'
    # Keep the candidate coordinates with the smallest loss.
    new_coords = pred_dist_info.sort_values("loss")['coords'].iloc[0].astype(np.double)
    write_with_new_coords(mol, new_coords, toFile)


In [19]:
# taken from https://github.com/nghiaho12/rigid_transform_3D/blob/master/rigid_transform_3D.py
# "Least-Squares Fitting of Two 3-D Point Sets", Arun, K. S. and Huang, T. S. and Blostein, S. D, IEEE Transactions on Pattern Analysis and Machine Intelligence, Volume 9 Issue 5, May 1987
# Input: expects 3xN matrix of points
# Returns R,t
# R = 3x3 rotation matrix
# t = 3x1 column vector

def rigid_transform_3D(A, B, correct_reflection=True):
    """Least-squares rigid transform that maps point set ``A`` onto ``B``.

    Computes ``R`` and ``t`` such that ``R @ A + t`` approximates ``B`` using the
    SVD of the cross-covariance of the centred point sets.

    Parameters
    ----------
    A, B : numpy.ndarray
        Arrays of shape ``(3, N)`` holding corresponding points as columns.
    correct_reflection : bool, default True
        When the SVD solution has ``det(R) < 0`` and this flag is true, the last
        row of ``Vt`` is negated and ``R`` recomputed. When false, the solution
        is returned unchanged, so ``R`` may have determinant -1.

    Returns
    -------
    R : numpy.ndarray
        ``3 x 3`` matrix.
    t : numpy.ndarray
        ``3 x 1`` column vector.

    Raises
    ------
    AssertionError
        If ``A`` and ``B`` have different shapes.
    ValueError
        If ``A`` or ``B`` does not have exactly three rows.
    """
    assert A.shape == B.shape

    # Validate both inputs explicitly (the assert above is stripped under `python -O`).
    for label, points in (("A", A), ("B", B)):
        num_rows, num_cols = points.shape
        if num_rows != 3:
            raise ValueError(f"matrix {label} is not 3xN, it is {num_rows}x{num_cols}")

    # Centroids as 3x1 columns so they broadcast against the 3xN point arrays.
    centroid_A = np.mean(A, axis=1).reshape(-1, 1)
    centroid_B = np.mean(B, axis=1).reshape(-1, 1)

    # Centre both point sets.
    Am = A - centroid_A
    Bm = B - centroid_B

    # Cross-covariance of the centred sets.
    H = Am @ np.transpose(Bm)

    # Rotation from the SVD of H.
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    # Special reflection case: a negative determinant means the best fit is a mirror image.
    if np.linalg.det(R) < 0 and correct_reflection:
        print("det(R) < 0, reflection detected!, correcting for it ...")
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    t = -R @ centroid_A + centroid_B

    return R, t

def compute_RMSD(a, b):
    """Root-mean-square deviation between two coordinate arrays of equal shape.

    Squared differences are summed along the last axis, the sums are averaged,
    and the square root of that average is returned. For two single points
    (1-D inputs) this reduces to their Euclidean distance.
    """
    delta = a - b
    squared_per_point = (delta ** 2).sum(axis=-1)
    mean_squared = squared_per_point.mean()
    return np.sqrt(mean_squared)

def kabsch_RMSD(new_coords, coords):
    """RMSD between ``new_coords`` and ``coords`` after rigidly fitting the former onto the latter.

    Both inputs are transposed before fitting, so they are expected as ``(N, 3)``
    arrays (``rigid_transform_3D`` requires ``3 x N``).

    The fit is requested with ``correct_reflection=False``, i.e. a solution with
    negative determinant is used as is rather than being turned into a proper rotation.
    """
    moving = new_coords.T
    reference = coords.T
    rotation, translation = rigid_transform_3D(moving, reference, correct_reflection=False)
    # Apply the fitted transform to the moving points, then compare in (N, 3) layout.
    aligned = (rotation @ moving) + translation
    return compute_RMSD(reference.T, aligned.T)

In [22]:
def _heavy_atom_positions_in_smiles_order(mol):
    """Return conformer coordinates of ``mol`` with atoms in SMILES output order and hydrogens removed."""
    # MolToSmiles is called for its side effect: it stores the atom output order on
    # the molecule as the computed property '_smilesAtomOutputOrder'.
    Chem.MolToSmiles(mol)
    smiles_order = list(mol.GetPropsAsDict(includePrivate=True, includeComputed=True)['_smilesAtomOutputOrder'])
    reordered = Chem.RemoveHs(Chem.RenumberAtoms(mol, smiles_order))
    return np.array(reordered.GetConformer().GetPositions())


# Compare every predicted pose with its reference ligand and collect one
# [pdb, rmsd, centroid distance, Kabsch-aligned rmsd] record per entry.
info = []
pre = "/home/zoujl/TankBind/pdbbind2020/"
test = np.loadtxt("/home/zoujl/TankBind/packages/EquiBind/timesplit_test", dtype=str)
# Number of entries skipped so far because a ligand file could not be read.
unfound = 0
for pdb in test:
    try:
        mol, _ = read_mol(f"{pre}/renumber_atom_index_same_as_smiles/{pdb}.sdf", None)
        mol_pred, _ = read_mol(f"{result_folder}/{pdb}_tankbind_chosen.sdf", None)
    except OSError:
        # Raised when either the reference or the predicted file cannot be read.
        print(f"{unfound+1}: Not in refined set: {pdb}")
        unfound += 1
        continue

    # Bring both molecules into the same atom order before comparing
    # coordinates element-wise.
    true_ligand_pos = _heavy_atom_positions_in_smiles_order(mol)
    mol_pred_pos = _heavy_atom_positions_in_smiles_order(mol_pred)

    # RMSD in place (no alignment), RMSD after rigid fitting, and distance between centroids.
    rmsd = compute_RMSD(true_ligand_pos, mol_pred_pos)
    kabsch = kabsch_RMSD(mol_pred_pos, true_ligand_pos)
    com_dist = compute_RMSD(mol_pred_pos.mean(axis=0), true_ligand_pos.mean(axis=0))
    info.append([pdb, rmsd, com_dist, kabsch])


1: Not in refined set: 6d08
2: Not in refined set: 6uvp
3: Not in refined set: 6oxq
4: Not in refined set: 6jsn
5: Not in refined set: 6oio
6: Not in refined set: 6moa
7: Not in refined set: 6hld
8: Not in refined set: 6i9a
9: Not in refined set: 6e4c
10: Not in refined set: 6s55
11: Not in refined set: 6seo
12: Not in refined set: 5zk5
13: Not in refined set: 6jid
14: Not in refined set: 5ze6
15: Not in refined set: 6a6k
16: Not in refined set: 6e3z
17: Not in refined set: 6te6
18: Not in refined set: 6pka
19: Not in refined set: 6jsf
20: Not in refined set: 5zxk
21: Not in refined set: 6qxd
22: Not in refined set: 6n97
23: Not in refined set: 6jt3
24: Not in refined set: 6qtr
25: Not in refined set: 6oy1
26: Not in refined set: 6n96
27: Not in refined set: 6qzh
28: Not in refined set: 6qmt
29: Not in refined set: 6ibx
30: Not in refined set: 6hmt
31: Not in refined set: 5zk7
32: Not in refined set: 6ibz
33: Not in refined set: 6ott
34: Not in refined set: 6gge
35: Not in refined set:

In [57]:
# custom description function.
def below_threshold(x, threshold=5):
    """Return the percentage (0-100) of entries of ``x`` that are strictly below ``threshold``."""
    n_below = (x < threshold).sum()
    return 100 * n_below / len(x)
def custom_description(data):
    """Summarise per-complex metrics into a one-row-per-method table.

    Parameters
    ----------
    data : pandas.DataFrame
        The first column is skipped (identifier); every remaining column must be
        named ``<Method>_<Metric>``. The metrics ``RMSD``, ``COM_DIST`` and
        ``KABSCH`` are required, because those columns are selected explicitly.

    Returns
    -------
    pandas.DataFrame
        Indexed by method, rounded to two decimals, with three-level columns:
        mean, 25/50/75 % percentiles and the percentage of entries below 5 and 2
        for ``RMSD`` and ``COM_DIST``, plus mean and median for ``KABSCH``.
    """
    metric_columns = data.iloc[:, 1:]  # everything after the identifier column

    # Individual statistics, each as a frame with one row per statistic.
    described = data.describe()
    pct_below_2 = metric_columns.apply(below_threshold, threshold=2, axis=0).reset_index(name='2A').set_index('index').T
    pct_below_5 = metric_columns.apply(below_threshold, threshold=5, axis=0).reset_index(name='5A').set_index('index').T
    medians = metric_columns.median().reset_index(name='median').set_index('index').T
    stats = pd.concat([described, pct_below_2, pct_below_5, medians]).loc[
        ['mean', '25%', '50%', '75%', '5A', '2A', 'median']]

    # One row per '<Method>_<Metric>' column; split the name at the first underscore only,
    # so metrics such as 'COM_DIST' stay intact.
    per_column = stats.T.reset_index()
    per_column[['Methods', 'Metrics']] = per_column['index'].str.split('_', n=1, expand=True)

    # Pivot to one row per method with (statistic, metric) columns, then put the metric first.
    per_method = pd.pivot(per_column, values=['mean', 'median', '25%', '50%', '75%', '5A', '2A'],
                          index=['Methods'], columns=['Metrics'])
    per_method.columns = per_method.columns.swaplevel(0, 1)
    per_method = per_method[sorted(per_method.columns)]

    # Columns to report, in presentation order.
    plain_columns = [
        ('RMSD', 'mean'),
        ('RMSD', '25%'),
        ('RMSD', '50%'),
        ('RMSD', '75%'),
        ('RMSD', '5A'),
        ('RMSD', '2A'),
        ('COM_DIST', 'mean'),
        ('COM_DIST', '25%'),
        ('COM_DIST', '50%'),
        ('COM_DIST', '75%'),
        ('COM_DIST', '5A'),
        ('COM_DIST', '2A'),
        ('KABSCH', 'mean'),
        ('KABSCH', 'median'),
    ]
    summary = per_method[plain_columns]

    # Display labels. Raw strings keep the LaTeX backslashes literal without relying
    # on invalid escape sequences such as '\d'.
    ligand_rmsd = r'Ligand RMSD $\downarrow$'
    centroid_distance = r'Centroid Distance $\downarrow$'
    percentiles = r'Percentiles $\downarrow$'
    pct_below = r'% Below Threshold $\uparrow$'
    kabsch_rmsd = r'RMSD $\downarrow$'
    fancy_columns = [
        (ligand_rmsd, ' ', 'mean'),
        (ligand_rmsd, percentiles, '25%'),
        (ligand_rmsd, percentiles, '50%'),
        (ligand_rmsd, percentiles, '75%'),
        (ligand_rmsd, pct_below, '5A'),
        (ligand_rmsd, pct_below, '2A'),
        (centroid_distance, ' ', 'mean'),
        (centroid_distance, percentiles, '25%'),
        (centroid_distance, percentiles, '50%'),
        (centroid_distance, percentiles, '75%'),
        (centroid_distance, pct_below, '5A'),
        (centroid_distance, pct_below, '2A'),
        ('KABSCH', kabsch_rmsd, 'mean'),
        ('KABSCH', kabsch_rmsd, 'median'),
    ]

    # Labels are applied positionally, so both lists must stay in the same order.
    summary.columns = pd.MultiIndex.from_tuples(fancy_columns)
    return summary.round(2)


In [58]:
# Collect the per-complex records into a frame; the column names follow the
# '<Method>_<Metric>' pattern that custom_description splits on.
d = pd.DataFrame(info, columns=['pdb', 'TankBind_RMSD', 'TankBind_COM_DIST', 'TankBind_KABSCH'])
print(d)
# Last expression of the cell: the summary table is shown as the cell output.
custom_description(d)

      pdb  TankBind_RMSD  TankBind_COM_DIST  TankBind_KABSCH
0    6qqw       2.759375           1.141447         1.373331
1    6jap       4.057490           1.766604         2.428273
2    6np2       2.998698           0.815427         2.354368
3    6hzb       6.144129           3.100021         4.410622
4    6qrc       3.172557           1.700969         1.871114
..    ...            ...                ...              ...
105  6qr3       2.661474           1.216915         1.709395
106  6qr1       3.444966           2.255350         2.041378
107  6nw3       1.345522           0.503400         1.224858
108  6o5g       4.957063           3.632609         1.887120
109  6gj8       5.331422           1.270842         2.615114

[110 rows x 4 columns]


Unnamed: 0_level_0,Ligand RMSD $\downarrow$,Ligand RMSD $\downarrow$,Ligand RMSD $\downarrow$,Ligand RMSD $\downarrow$,Ligand RMSD $\downarrow$,Ligand RMSD $\downarrow$,Centroid Distance $\downarrow$,Centroid Distance $\downarrow$,Centroid Distance $\downarrow$,Centroid Distance $\downarrow$,Centroid Distance $\downarrow$,Centroid Distance $\downarrow$,KABSCH,KABSCH
Unnamed: 0_level_1,Unnamed: 1_level_1,Percentiles $\downarrow$,Percentiles $\downarrow$,Percentiles $\downarrow$,% Below Threshold $\uparrow$,% Below Threshold $\uparrow$,Unnamed: 7_level_1,Percentiles $\downarrow$,Percentiles $\downarrow$,Percentiles $\downarrow$,% Below Threshold $\uparrow$,% Below Threshold $\uparrow$,RMSD $\downarrow$,RMSD $\downarrow$
Unnamed: 0_level_2,mean,25%,50%,75%,5A,2A,mean,25%,50%,75%,5A,2A,mean,median
Methods,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3
TankBind,4.79,2.46,3.51,5.0,74.55,18.18,2.97,0.8,1.39,2.54,87.27,65.45,2.2,1.91
