# overview

We start from the raw PDBbind dataset downloaded from http://www.pdbbind.org.cn/download.php

1. Filter out the complexes whose ligand files cannot be read by RDKit.

2. Process each protein by keeping only the chains that have at least one atom within 10 Å of any ligand atom.

3. Use p2rank to segment protein into blocks.

4. extract protein and ligand features.

5. construct the training and test dataset.


In [1]:
### set env_path
tankbind_src_folder_path = "/home/zoujl/TankBind/tankbind/"
import sys
sys.path.insert(0, tankbind_src_folder_path)

In [39]:
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
import torch

# process the raw PDBbind dataset.

In [3]:

from utils import read_pdbbind_data

In [97]:
# The raw PDBbind dataset can be downloaded from http://www.pdbbind.org.cn/download.php
pre = "../pdbbind2020/"

# The name index is split on double spaces; only 'pdb' and 'uid' are kept,
# the remaining column names are placeholders for the rest of each line.
name_index_columns = ['pdb', 'year', 'uid', 'd', 'e','f','g','h','i','j','k','l','m','n','o']
df_pdb_id = pd.read_csv(
    f'{pre}/index/INDEX_refined_name.2020',
    sep="  ",
    comment='#',
    header=None,
    names=name_index_columns,
    engine='python',
)[['pdb','uid']]

# Attach the uid column to the table produced by read_pdbbind_data.
data = read_pdbbind_data(f'{pre}/index/INDEX_refined_data.2020').merge(df_pdb_id, on=['pdb'])
data

Unnamed: 0,pdb,resolution,year,affinity,raw,ligand,uid
0,2r58,2.00,2007,2.00,Kd=10mM,MLY,Q9VHA0
1,3c2f,2.35,2008,2.00,Kd=10.1mM,PRP,P43619
2,3g2y,1.31,2009,2.00,Ki=10mM,GF4,Q9L5C8
3,3pce,2.06,1998,2.00,Ki=10mM,3HP,P00436
4,4qsu,1.90,2014,2.00,Kd=10mM,TDR,Q6PL18
...,...,...,...,...,...,...,...
5311,4f3c,1.93,2013,11.82,Ki=1.5pM,BIG,E8NLP5
5312,5bry,1.34,2015,11.82,Ki=0.0015nM,4UY,P03366
5313,1sl3,1.81,2004,11.85,Ki=1.4pM,170,P00734
5314,1ctu,2.30,1995,11.92,Ki=1.2pM,ZEB,P0ABF6


### ligand file should be readable by RDKit

In [5]:
from feature_utils import read_mol

No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


In [6]:
### keep only the entries whose ligand file can be read by RDKit
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')   # silence RDKit log output while probing the files

readable_list, probem_list = [], []
for pdb in tqdm(data.pdb):
    sdf_fileName = f"{pre}/refined_set/{pdb}/{pdb}_ligand.sdf"
    mol2_fileName = f"{pre}/refined_set/{pdb}/{pdb}_ligand.mol2"
    mol, problem = read_mol(sdf_fileName, mol2_fileName)
    # A truthy `problem` flag routes the entry to probem_list, otherwise it is readable.
    (probem_list if problem else readable_list).append(pdb)

  0%|          | 0/5316 [00:00<?, ?it/s]

100%|██████████| 5316/5316 [00:31<00:00, 171.13it/s]


In [98]:
# Report how many entries were readable / problematic, then keep only the readable rows.
for pdb_group in (readable_list, probem_list):
    print(len(pdb_group))
data = data[data.pdb.isin(readable_list)].reset_index(drop=True)
data

5236
80


Unnamed: 0,pdb,resolution,year,affinity,raw,ligand,uid
0,2r58,2.00,2007,2.00,Kd=10mM,MLY,Q9VHA0
1,3c2f,2.35,2008,2.00,Kd=10.1mM,PRP,P43619
2,3g2y,1.31,2009,2.00,Ki=10mM,GF4,Q9L5C8
3,3pce,2.06,1998,2.00,Ki=10mM,3HP,P00436
4,4qsu,1.90,2014,2.00,Kd=10mM,TDR,Q6PL18
...,...,...,...,...,...,...,...
5231,4f3c,1.93,2013,11.82,Ki=1.5pM,BIG,E8NLP5
5232,5bry,1.34,2015,11.82,Ki=0.0015nM,4UY,P03366
5233,1sl3,1.81,2004,11.85,Ki=1.4pM,170,P00734
5234,1ctu,2.30,1995,11.92,Ki=1.2pM,ZEB,P0ABF6


### For ease of RMSD evaluation later, we renumber the ligand atom indices to be consistent with the SMILES

In [7]:
from feature_utils import write_renumbered_sdf

In [9]:
# Write one renumbered ligand SDF per readable entry into a dedicated folder.
toFolder = f"{pre}/renumber_atom_index_same_as_smiles"
# os.makedirs replaces shelling out to `mkdir -p`: no shell interpolation of the
# path, and a failure raises instead of being silently ignored.
os.makedirs(toFolder, exist_ok=True)
for pdb in tqdm(readable_list):
    sdf_fileName = f"{pre}/refined_set/{pdb}/{pdb}_ligand.sdf"
    mol2_fileName = f"{pre}/refined_set/{pdb}/{pdb}_ligand.mol2"
    toFile = f"{toFolder}/{pdb}.sdf"
    write_renumbered_sdf(toFile, sdf_fileName, mol2_fileName)


  2%|▏         | 127/5236 [00:00<00:07, 638.65it/s]

100%|██████████| 5236/5236 [00:07<00:00, 689.43it/s]


### process PDBbind proteins, removing extra chains, cutoff 10A

In [18]:
# Output folder for the proteins with distant chains removed.
toFolder = f"{pre}/protein_remove_extra_chains_10A/"
# os.makedirs replaces `os.system("mkdir -p ...")`: no shell involved, errors raise.
os.makedirs(toFolder, exist_ok=True)

0

In [19]:
### build input->output mapping
# One (pdbFile, ligandFile, cutoff, toFile) tuple per entry in `data`.
cutoff = 10
input_ = [
    (
        f"{pre}/refined_set/{pdb}/{pdb}_protein.pdb",
        f"{pre}/renumber_atom_index_same_as_smiles/{pdb}.sdf",
        cutoff,
        f"{toFolder}/{pdb}_protein.pdb",
    )
    for pdb in data.pdb.values
]

input_[:3]

[('../pdbbind2020//refined_set/2r58/2r58_protein.pdb',
  '../pdbbind2020//renumber_atom_index_same_as_smiles/2r58.sdf',
  10,
  '../pdbbind2020//protein_remove_extra_chains_10A//2r58_protein.pdb'),
 ('../pdbbind2020//refined_set/3c2f/3c2f_protein.pdb',
  '../pdbbind2020//renumber_atom_index_same_as_smiles/3c2f.sdf',
  10,
  '../pdbbind2020//protein_remove_extra_chains_10A//3c2f_protein.pdb'),
 ('../pdbbind2020//refined_set/3g2y/3g2y_protein.pdb',
  '../pdbbind2020//renumber_atom_index_same_as_smiles/3g2y.sdf',
  10,
  '../pdbbind2020//protein_remove_extra_chains_10A//3g2y_protein.pdb')]

In [8]:
from feature_utils import select_chain_within_cutoff_to_ligand_v2

In [13]:
# Apply select_chain_within_cutoff_to_ligand_v2 to every tuple in input_
# using an mlcrate worker pool.
import mlcrate as mlc
import os
pool = mlc.SuperPool(64)   # pool size 64; lower this on machines with fewer cores
pool.pool.restart()        # NOTE(review): restarts the wrapped pool object; purpose not visible here, confirm against mlcrate
_ = pool.map(select_chain_within_cutoff_to_ligand_v2,input_)   # return values are discarded; presumably each call writes its own toFile, TODO confirm in feature_utils
pool.exit()

[mlcrate] 64 CPUs: 100%|██████████| 5236/5236 [00:43<00:00, 121.28it/s]


In [28]:
# 2r1w was previously found to have no chain near the ligand, so it is excluded.
data = data[data.pdb != '2r1w'].reset_index(drop=True)

In [21]:
### compare the original protein & ligand with the chain-filtered protein & renumbered ligand
import nglview   # conda install nglview -c conda-forge if import failure

checkout_id = "2r58"

# original files
proteinFile = f"../pdbbind2020/refined_set/{checkout_id}/{checkout_id}_protein.pdb"
molFile = f"../pdbbind2020/refined_set/{checkout_id}/{checkout_id}_ligand.sdf"

# processed files
removed_proteinFile = f"../pdbbind2020/protein_remove_extra_chains_10A/{checkout_id}_protein.pdb"
renumbered_molFile = f"../pdbbind2020/renumber_atom_index_same_as_smiles/{checkout_id}.sdf"

# original complex: white cartoon, red ligand
view = nglview.show_file(nglview.FileStructure(proteinFile), default=False)
view.add_representation('cartoon', selection='protein', color='white')
view.add_component(nglview.FileStructure(molFile), default=False).add_ball_and_stick(color='red')

# processed complex: yellow cartoon, green ligand
view.add_component(nglview.FileStructure(removed_proteinFile), default=False).add_representation(
    'cartoon', selection='protein', color='yellow'
)
rdkit = view.add_component(nglview.FileStructure(renumbered_molFile), default=False)
rdkit.add_ball_and_stick(color='green')

view

NGLWidget()

In [16]:
view.render_image()


Image(value=b'', width='99%')

# Generate pocket_dict: p2rank segmentation

In [29]:
# Folder that holds the p2rank input list and, later, its predictions.
p2rank_prediction_folder = f"{pre}/p2rank_protein_remove_extra_chains_10A"
# os.makedirs replaces `os.system("mkdir -p ...")`: no shell involved, errors raise.
os.makedirs(p2rank_prediction_folder, exist_ok=True)

### protein_list.ds: the list of protein files handed to p2rank
ds = f"{p2rank_prediction_folder}/protein_list.ds"
pdb_list = data.pdb.values
with open(ds, "w") as out:
    # One relative path per line; presumably resolved relative to the .ds file's folder (confirm in the p2rank docs).
    for pdb in pdb_list:
        out.write(f"../protein_remove_extra_chains_10A/{pdb}_protein.pdb\n")

In [None]:
### run p2rank (takes about 30 minutes)
# NOTE: the launcher path below is machine-specific; point it at the local p2rank install.
p2rank = "bash /home/zoujl/TankBind/packages/p2rank_2.3/prank"
# Output goes to {p2rank_prediction_folder}/p2rank, the folder the *_predictions.csv files are read from afterwards.
cmd = f"{p2rank} predict {ds} -o {p2rank_prediction_folder}/p2rank -threads 16"
os.system(cmd)  # the shell exit status becomes the cell's displayed value; it is not checked

In [30]:
# Folder for the intermediate TankBind artifacts (p2rank table, pockets, features, dataset).
tankbind_data_path = f"{pre}/tankbind_data"
# os.makedirs replaces `os.system("mkdir -p ...")`: no shell involved, errors raise.
os.makedirs(tankbind_data_path, exist_ok=True)

0

In [32]:
# Gather every per-protein p2rank prediction table into one frame tagged with the protein name.
name_list = pdb_list
d_list = []

for name in tqdm(name_list):
    p2rankFile = f"{pre}/p2rank_protein_remove_extra_chains_10A/p2rank/{name}_protein.pdb_predictions.csv"
    # Column headers are stripped of surrounding whitespace before use.
    d_list.append(pd.read_csv(p2rankFile).rename(columns=str.strip).assign(name=name))

d = pd.concat(d_list).reset_index(drop=True)
d.to_feather(f"{tankbind_data_path}/p2rank_result.feather")

100%|██████████| 5236/5236 [00:19<00:00, 269.11it/s]


In [33]:
d = pd.read_feather(f"{tankbind_data_path}/p2rank_result.feather")
d

Unnamed: 0,name,rank,score,probability,sas_points,surf_atoms,center_x,center_y,center_z,residue_ids,surf_atom_ids
0,2r58,1,7.74,0.249,51,26,-22.5809,-17.7873,-19.9953,A_324 A_327 A_330 A_332 A_348 A_351 A_355 A_3...,2291 2293 2347 2348 2349 2394 2430 2658 2661 ...
1,2r58,2,3.28,0.057,44,24,-33.7310,12.2990,-16.9301,A_213 A_214 A_220 A_222 A_255 A_256 A_258 A_2...,596 598 604 613 619 708 736 738 1244 1255 125...
2,2r58,3,1.75,0.013,41,23,-24.9639,13.8173,-15.1200,A_254 A_255 A_257 A_258 A_259 A_262 A_265 A_2...,1233 1241 1245 1249 1273 1298 1302 1303 1309 ...
3,3c2f,1,40.52,0.915,301,114,12.3269,-9.3542,31.6212,A_13 A_14 A_148 A_151 A_154 A_155 A_158 A_16 ...,72 85 87 88 89 94 98 117 129 132 172 178 182 ...
4,3c2f,2,33.04,0.880,229,99,24.7129,-15.0894,44.9685,A_141 A_142 A_143 A_144 A_165 A_166 A_169 A_1...,2137 2140 2144 2148 2152 2158 2161 2167 2169 ...
...,...,...,...,...,...,...,...,...,...,...,...
40322,1ctu,17,0.91,0.002,27,10,48.3192,73.2073,-0.8277,A_273 A_277 A_281 A_286 A_288,4082 4087 4135 4138 4198 4201 4257 4284 4285 ...
40323,6e9a,1,31.54,0.869,129,60,16.5019,22.1874,17.4029,A_23 A_25 A_27 A_28 A_29 A_30 A_32 A_47 A_48 ...,88 227 245 259 262 265 266 271 273 274 275 27...
40324,6e9a,2,1.77,0.014,30,16,3.5404,16.0525,29.4310,B_60 B_61 B_72 B_73 B_74 B_88 B_92,1492 1493 1495 1497 1500 1501 1598 1605 1608 ...
40325,6e9a,3,1.47,0.008,26,13,21.9780,40.2209,11.3321,A_60 A_61 A_72 A_73 A_74 A_88 A_92,571 575 581 678 685 688 693 694 813 814 817 8...


In [34]:
### split d into pockets_dict: {protein name -> its p2rank rows, re-indexed from 0}
# Group once instead of scanning the whole table with a boolean mask for every protein.
pockets_by_name = {key: group for key, group in d.groupby("name", sort=False)}
# Same columns, zero rows: what the boolean mask yields for a name with no predictions.
no_pockets = d.iloc[0:0]
pockets_dict = {}
for name in tqdm(name_list):
    pockets_dict[name] = pockets_by_name.get(name, no_pockets).reset_index(drop=True)

100%|██████████| 5236/5236 [00:12<00:00, 404.00it/s]


In [35]:
### checkout pockets_feature
pockets_dict[checkout_id]

Unnamed: 0,name,rank,score,probability,sas_points,surf_atoms,center_x,center_y,center_z,residue_ids,surf_atom_ids
0,2r58,1,7.74,0.249,51,26,-22.5809,-17.7873,-19.9953,A_324 A_327 A_330 A_332 A_348 A_351 A_355 A_3...,2291 2293 2347 2348 2349 2394 2430 2658 2661 ...
1,2r58,2,3.28,0.057,44,24,-33.731,12.299,-16.9301,A_213 A_214 A_220 A_222 A_255 A_256 A_258 A_2...,596 598 604 613 619 708 736 738 1244 1255 125...
2,2r58,3,1.75,0.013,41,23,-24.9639,13.8173,-15.12,A_254 A_255 A_257 A_258 A_259 A_262 A_265 A_2...,1233 1241 1245 1249 1273 1298 1302 1303 1309 ...


In [23]:
### save pockets_dict to disk, then read it back
import pickle

pockets_pkl = f"{tankbind_data_path}/pockets_dict.pkl"

with open(pockets_pkl, 'wb') as f:
    pickle.dump(pockets_dict, f)

# Only unpickle files produced by this notebook; pickle can execute code from untrusted files.
with open(pockets_pkl, 'rb') as f:
    pockets_dict = pickle.load(f)

# Generate protein_dict: gvp feature

In [36]:
from feature_utils import get_protein_feature

In [41]:
### build input->output mapping
# One (pdb, proteinFile, toFile) tuple per protein; consumed by batch_run.
protein_embedding_folder = f"{tankbind_data_path}/gvp_protein_embedding"
# os.makedirs replaces `os.system("mkdir -p ...")`: no shell involved, errors raise.
os.makedirs(protein_embedding_folder, exist_ok=True)

input_ = [
    (
        pdb,
        f"{pre}/protein_remove_extra_chains_10A/{pdb}_protein.pdb",
        f"{protein_embedding_folder}/{pdb}.pt",
    )
    for pdb in pdb_list
]

input_[:5]

[('2r58',
  '../pdbbind2020//protein_remove_extra_chains_10A/2r58_protein.pdb',
  '../pdbbind2020//tankbind_data/gvp_protein_embedding/2r58.pt'),
 ('3c2f',
  '../pdbbind2020//protein_remove_extra_chains_10A/3c2f_protein.pdb',
  '../pdbbind2020//tankbind_data/gvp_protein_embedding/3c2f.pt'),
 ('3g2y',
  '../pdbbind2020//protein_remove_extra_chains_10A/3g2y_protein.pdb',
  '../pdbbind2020//tankbind_data/gvp_protein_embedding/3g2y.pt'),
 ('3pce',
  '../pdbbind2020//protein_remove_extra_chains_10A/3pce_protein.pdb',
  '../pdbbind2020//tankbind_data/gvp_protein_embedding/3pce.pt'),
 ('4qsu',
  '../pdbbind2020//protein_remove_extra_chains_10A/4qsu_protein.pdb',
  '../pdbbind2020//tankbind_data/gvp_protein_embedding/4qsu.pt')]

In [26]:
from Bio.PDB import PDBParser
from feature_utils import get_clean_res_list
import torch
torch.set_num_threads(1)

def batch_run(x):
    """Extract the features of one protein and save them to disk.

    x is a (pdb, proteinFile, toFile) tuple. The structure in proteinFile is
    parsed, its residues are cleaned with get_clean_res_list, and the result of
    get_protein_feature is saved to toFile as a one-entry dict {pdb: features}.
    """
    pdb, proteinFile, toFile = x
    ### example: 2r58
    structure = PDBParser(QUIET=True).get_structure(pdb, proteinFile)
    ### structure.get_residues() resnames, e.g.:
    ### > 0: ALA 1: PHE 2: ASP 3: TRP 4: ASP 5: ALA ...211: LYS
    res_list = get_clean_res_list(structure.get_residues(), verbose=False, ensure_ca_exist=True)
    ### > cleaned residues length: 212
    torch.save({pdb: get_protein_feature(res_list)}, toFile)

In [27]:
# Run batch_run over every (pdb, proteinFile, toFile) tuple in input_ using an mlcrate worker pool.
import mlcrate as mlc
import os
pool = mlc.SuperPool(64)   # pool size 64; lower this on machines with fewer cores
pool.pool.restart()        # NOTE(review): restarts the wrapped pool object; purpose not visible here, confirm against mlcrate
_ = pool.map(batch_run,input_)   # batch_run returns nothing; its result is the .pt file it saves
pool.exit()

[mlcrate] 64 CPUs: 100%|██████████| 5236/5236 [00:59<00:00, 88.05it/s] 


In [99]:
data

Unnamed: 0,pdb,resolution,year,affinity,raw,ligand,uid
0,2r58,2.00,2007,2.00,Kd=10mM,MLY,Q9VHA0
1,3c2f,2.35,2008,2.00,Kd=10.1mM,PRP,P43619
2,3g2y,1.31,2009,2.00,Ki=10mM,GF4,Q9L5C8
3,3pce,2.06,1998,2.00,Ki=10mM,3HP,P00436
4,4qsu,1.90,2014,2.00,Kd=10mM,TDR,Q6PL18
...,...,...,...,...,...,...,...
5231,4f3c,1.93,2013,11.82,Ki=1.5pM,BIG,E8NLP5
5232,5bry,1.34,2015,11.82,Ki=0.0015nM,4UY,P03366
5233,1sl3,1.81,2004,11.85,Ki=1.4pM,170,P00734
5234,1ctu,2.30,1995,11.92,Ki=1.2pM,ZEB,P0ABF6


In [42]:
### load protein_dict
# Prefer the pdb list in {pre}/data.csv when that file exists. No cell visible in this
# notebook writes it, so on a fresh run fall back to the pdb ids already held in `data`.
pdb_csv = f"{pre}/data.csv"
if os.path.exists(pdb_csv):
    pdb_list = pd.read_csv(pdb_csv).pdb.values
else:
    pdb_list = data.pdb.values

# Each <pdb>.pt file holds a one-entry dict {pdb: features} (written by batch_run),
# so update() merges them into a single lookup table.
protein_dict = {}
for pdb in tqdm(pdb_list):
    protein_dict.update(torch.load(f"{protein_embedding_folder}/{pdb}.pt"))



  0%|          | 0/5236 [00:00<?, ?it/s]

100%|██████████| 5236/5236 [02:07<00:00, 41.15it/s] 


In [43]:
# Inspect the protein features of the example entry.
protein_features = protein_dict[checkout_id]
# The label used to read "compound_features", which mislabelled this protein output.
print(f"> protein_features (size: {len(protein_features)}): ")
for feature in protein_features:
    print(feature.shape)

protein_features

> compound_features (size: 7): 
torch.Size([212, 3])
torch.Size([212])
torch.Size([212, 6])
torch.Size([212, 3, 3])
torch.Size([2, 6360])
torch.Size([6360, 32])
torch.Size([6360, 1, 3])


(tensor([[-7.6500e+00, -1.3176e+01, -1.4296e+01],
         [-8.3320e+00, -1.4639e+01, -1.7768e+01],
         [-5.2450e+00, -1.5139e+01, -1.9945e+01],
         [-6.3580e+00, -1.3245e+01, -2.3043e+01],
         [-2.8680e+00, -1.3331e+01, -2.4631e+01],
         [-2.6540e+00, -1.7135e+01, -2.4507e+01],
         [-6.2790e+00, -1.7369e+01, -2.5657e+01],
         [-5.8080e+00, -1.5076e+01, -2.8679e+01],
         [-2.6760e+00, -1.6977e+01, -2.9711e+01],
         [-4.3860e+00, -2.0389e+01, -2.9578e+01],
         [-7.5050e+00, -1.9294e+01, -3.1444e+01],
         [-5.3490e+00, -1.7353e+01, -3.3945e+01],
         [-7.4220e+00, -1.4267e+01, -3.3148e+01],
         [-7.1560e+00, -1.0521e+01, -3.2521e+01],
         [-8.9540e+00, -8.5990e+00, -2.9779e+01],
         [-1.0928e+01, -5.5810e+00, -3.0978e+01],
         [-8.7970e+00, -2.4590e+00, -3.0298e+01],
         [-9.8890e+00, -1.2200e-01, -2.7473e+01],
         [-1.0768e+01,  2.5620e+00, -3.0078e+01],
         [-1.3819e+01,  4.7900e-01, -3.1172e+01],


# Generate compound_dict: compound feature

In [44]:
from feature_utils import extract_torchdrug_feature_from_mol

# Extract features for every renumbered ligand; entries that fail are recorded and skipped.
compound_dict, skip_pdb_list = {}, []
for pdb in tqdm(pdb_list):
    mol, _ = read_mol(f"{pre}/renumber_atom_index_same_as_smiles/{pdb}.sdf", None)
    # extract features from sdf.
    try:
        features = extract_torchdrug_feature_from_mol(mol, has_LAS_mask=True)  # self-dock set has_LAS_mask to true
    except Exception as e:
        # Best effort: report the error and the offending entry, then move on.
        print(e)
        skip_pdb_list.append(pdb)
        print(pdb)
    else:
        compound_dict[pdb] = features

100%|██████████| 5236/5236 [01:10<00:00, 74.31it/s] 


In [45]:
# Show the entries whose feature extraction failed and remove them from `data`.
print(skip_pdb_list)
data = data[~data.pdb.isin(skip_pdb_list)].reset_index(drop=True)

[]


In [33]:
### save and load compound_dict
# Persist the compound features, then read them back so later cells use the on-disk copy.
torch.save(compound_dict, f"{tankbind_data_path}/compound_torchdrug_features.pt")

compound_dict = torch.load(f"{tankbind_data_path}/compound_torchdrug_features.pt")

In [46]:
# checkout compound_features
compound_features = compound_dict[checkout_id]
print(f"> compound_features (size: {len(compound_features)}): ")
for feature in compound_features:
    print(feature.shape)

compound_features

> compound_features (size: 5): 
(12, 3)
torch.Size([12, 56])
torch.Size([22, 3])
torch.Size([22, 19])
torch.Size([12, 12, 16])


(array([[-22.985, -16.836, -21.812],
        [-23.231, -17.077, -20.381],
        [-21.954, -17.324, -19.698],
        [-24.105, -18.251, -20.22 ],
        [-24.578, -18.371, -18.775],
        [-25.624, -19.467, -18.673],
        [-25.82 , -19.943, -17.242],
        [-27.107, -20.759, -17.144],
        [-27.431, -20.995, -15.702],
        [-26.969, -22.041, -17.928],
        [-26.241, -22.951, -17.527],
        [-27.567, -22.21 , -18.999]]),
 tensor([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
          0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0

# construct dataset.

In [100]:
data

Unnamed: 0,pdb,resolution,year,affinity,raw,ligand,uid
0,2r58,2.00,2007,2.00,Kd=10mM,MLY,Q9VHA0
1,3c2f,2.35,2008,2.00,Kd=10.1mM,PRP,P43619
2,3g2y,1.31,2009,2.00,Ki=10mM,GF4,Q9L5C8
3,3pce,2.06,1998,2.00,Ki=10mM,3HP,P00436
4,4qsu,1.90,2014,2.00,Kd=10mM,TDR,Q6PL18
...,...,...,...,...,...,...,...
5231,4f3c,1.93,2013,11.82,Ki=1.5pM,BIG,E8NLP5
5232,5bry,1.34,2015,11.82,Ki=0.0015nM,4UY,P03366
5233,1sl3,1.81,2004,11.85,Ki=1.4pM,170,P00734
5234,1ctu,2.30,1995,11.92,Ki=1.2pM,ZEB,P0ABF6


In [101]:
# we use the time-split defined in EquiBind paper.
# https://github.com/HannesStark/EquiBind/tree/main/data
valid = np.loadtxt("/home/zoujl/TankBind/packages/EquiBind/timesplit_no_lig_overlap_val", dtype=str)
test = np.loadtxt("/home/zoujl/TankBind/packages/EquiBind/timesplit_test", dtype=str)
def assign_group(pdb, valid=valid, test=test):
    """Return the split label for a pdb id.

    'valid' if the id is in `valid`, otherwise 'test' if it is in `test`,
    otherwise 'train'. Membership in `valid` is checked first.
    """
    for label, members in (('valid', valid), ('test', test)):
        if pdb in members:
            return label
    return 'train'

data['group'] = data.pdb.map(assign_group)
data['name'] = data['pdb']

In [102]:
data.value_counts("group")

group
train    4858
valid     268
test      110
Name: count, dtype: int64

In [103]:
### checkout pocket_dict, protein_dict, compound_dict
print(f"pocket_coms from pocket_dict:\n{pockets_dict[checkout_id].head(10)[['center_x', 'center_y', 'center_z']].values}")

print(f"protein_com from protein_dict:\n{protein_dict[checkout_id][0].numpy().mean(axis=0).astype(float).reshape(1, 3)}")

print(f"compound_dict:\n{compound_dict[checkout_id][0]}")

pocket_coms from pocket_dict:
[[-22.5809 -17.7873 -19.9953]
 [-33.731   12.299  -16.9301]
 [-24.9639  13.8173 -15.12  ]]
protein_com from protein_dict:
[[-24.28586197  -1.69748116 -17.80907059]]
compound_dict:
[[-22.985 -16.836 -21.812]
 [-23.231 -17.077 -20.381]
 [-21.954 -17.324 -19.698]
 [-24.105 -18.251 -20.22 ]
 [-24.578 -18.371 -18.775]
 [-25.624 -19.467 -18.673]
 [-25.82  -19.943 -17.242]
 [-27.107 -20.759 -17.144]
 [-27.431 -20.995 -15.702]
 [-26.969 -22.041 -17.928]
 [-26.241 -22.951 -17.527]
 [-27.567 -22.21  -18.999]]


In [128]:
# Build the sample table. Each complex contributes:
#   * one native row (pocket_com=None, use_compound_com=True),
#   * one row centred on the mean of the protein's first feature tensor (pdb + "_c"),
#   * one row per p2rank pocket centre, limited to the first 10 pockets (pdb + "_<idx>").
info_columns = ['protein_name', 'compound_name', 'pdb', 'smiles', 'affinity', 'uid', 'pocket_com',
                'use_compound_com', 'use_whole_protein', 'group']
rows = []
for _, line in tqdm(data.iterrows(), total=data.shape[0]):
    pdb, uid = line['pdb'], line['uid']
    affinity, group = line['affinity'], line['group']
    # smiles = line['smiles']
    smiles = ""
    protein_name = compound_name = line['name']

    names = [protein_name, compound_name]
    shared = [smiles, affinity, uid]

    pocket = pockets_dict[pdb].head(10)
    pocket.columns = pocket.columns.str.strip()
    pocket_coms = pocket[['center_x', 'center_y', 'center_z']].values

    # item: native block.
    rows.append(names + [pdb] + shared + [None, True, False, group])

    # item: protein center as pocket_com.
    protein_com = protein_dict[protein_name][0].numpy().mean(axis=0).astype(float).reshape(1, 3)
    rows.append(names + [pdb+"_c"] + shared + [protein_com, False, False, group])

    # item: each pocket's coords as pocket_com
    # (pocket frames are re-indexed from 0, so index label and row position coincide)
    for idx, com in zip(pocket.index, pocket_coms):
        rows.append(names + [f"{pdb}_{idx}"] + shared + [com.reshape(1, 3), False, False, group])

info = pd.DataFrame(rows, columns=info_columns)

info.size

  2%|▏         | 84/5236 [00:00<00:06, 836.50it/s]

100%|██████████| 5236/5236 [00:06<00:00, 822.98it/s]


427900

In [130]:
### save and load info
# Persist the sample table, then read it back so later cells use the on-disk copy.
torch.save(info, f"{tankbind_data_path}/dataset.pt")

info = torch.load(f"{tankbind_data_path}/dataset.pt")

### NOTE: pandas.to_csv() is not recommended here, because a CSV round trip does not
### preserve object-typed cells such as None or np.array (e.g. the pocket_com column).
# info.to_csv(f"{tankbind_data_path}/dataset.csv", index=False,na_rep='None')
# info = pd.read_csv(f"{tankbind_data_path}/dataset.csv", index_col=None)

In [None]:
### check that the saved/loaded table kept its object-typed cells
# Position 1 within a protein's rows is the protein-centre ("_c") row, whose pocket_com
# is an array. `.iloc[1]` is positional, so it works for any checkout_id; the plain
# label lookup `[1]` only worked when the protein was the first one in the table.
print(type(info[info['protein_name'] == checkout_id]['pocket_com'].iloc[1]))

info

In [107]:
from data import TankBindDataSet

In [108]:
# Root folder of the full TankBind dataset.
toFilePre = f"{pre}/dataset"
# os.makedirs replaces `os.system("mkdir -p ...")`: no shell involved, errors raise.
os.makedirs(toFilePre, exist_ok=True)

# Build the dataset from the sample table and the two feature dicts.
dataset = TankBindDataSet(toFilePre, data=info, protein_dict=protein_dict, compound_dict=compound_dict)

Processing...
Done!


['../pdbbind2020/dataset/processed/data.pt', '../pdbbind2020/dataset/processed/protein.pt', '../pdbbind2020/dataset/processed/compound.pt']


In [109]:
### will directly load pre-processed dataset in toFilePre
dataset = TankBindDataSet(toFilePre)

['../pdbbind2020/dataset/processed/data.pt', '../pdbbind2020/dataset/processed/protein.pt', '../pdbbind2020/dataset/processed/compound.pt']


In [136]:
### checkout dataset content
dataset[0]

# Annotated field reference (author's notes on the HeteroData fields).
# NOTE(review): is_equivalent_native_pocket and equivalent_native_y_mask are listed here
# but do not appear in the output displayed below; confirm which dataset version provides them.
# HeteroData(
#   dis_map=[684],    # pair distance between protein_node & compound_node (57*12)
#   node_xyz=[57, 3], # coords of protein_node (residue)
#   coords=[12, 3],   # coords of compound_node (atom)
#   y=[684],          # mask: whether distance is less than contactCutoff (default: 8.0)
#   seq=[57],         # encoded protein residue sequence
#   affinity=[1],     # affinity retrieved from pdb .csv file
#   compound_pair=[144, 16],          # compound's atom-pair distance (atom_num*atom_num, 16)
#   pdb='2r58',       # pdb_id
#   group='train',    # train/val/test
#   real_affinity_mask=[1],           # only True if data is the native row (use_compound_com=True)
#   real_y_mask=[684],                # all True if data is the native row (use_compound_com=True), else all False
#   is_equivalent_native_pocket=[1],  # True if this data row refers to an equivalent_native pocket (close in distance: num_contact/native_num_contact >= 90%)
#   equivalent_native_y_mask=[684],   # all True if is_equivalent_native_pocket
#   protein={
#     node_s=[57, 6],         # scalar feature of node (res_num, 6)
#     node_v=[57, 3, 3],      # vector feature of node (res_num, 2+1, 3)
#   },
#   compound={ x=[12, 56] },  # features of compound node (atom_num, 56)
#   (protein, p2p, protein)={
#     edge_index=[2, 1216],   # edges that link protein nodes with the top_k(default: 30) nearest neighbors (2, edge_num)
#     edge_s=[1216, 32],      # scalar feature of edge (edge_num, 32)
#     edge_v=[1216, 1, 3],    # vector feature of edge (edge_num, 1, 3)
#   },
#   (compound, c2c, compound)={
#     edge_index=[2, 22],     # edges that link compound nodes, computed by torch_drug.Molecule (2, edge_num)
#     edge_weight=[22],       # a tensor of 1 (edge_num)
#     edge_attr=[22, 19],     # feature of compound's atom edge
#   }
# )


HeteroData(
  dis_map=[684],
  node_xyz=[57, 3],
  coords=[12, 3],
  y=[684],
  seq=[57],
  affinity=[1],
  compound_pair=[144, 16],
  pdb='2r58',
  group='train',
  real_affinity_mask=[1],
  real_y_mask=[684],
  protein={
    node_s=[57, 6],
    node_v=[57, 3, 3],
  },
  compound={ x=[12, 56] },
  (protein, p2p, protein)={
    edge_index=[2, 1216],
    edge_s=[1216, 32],
    edge_v=[1216, 1, 3],
  },
  (compound, c2c, compound)={
    edge_index=[2, 22],
    edge_weight=[22],
    edge_attr=[22, 19],
  }
)

In [144]:
### checkout info that retrieve from TankBindDataSet.Data
info = dataset.data
info[info['protein_name'] == checkout_id]

Unnamed: 0,protein_name,compound_name,pdb,smiles,affinity,uid,pocket_com,use_compound_com,use_whole_protein,group
0,2r58,2r58,2r58,,2.0,Q9VHA0,,True,False,train
1,2r58,2r58,2r58_c,,2.0,Q9VHA0,"[[-24.28586196899414, -1.6974811553955078, -17...",False,False,train
2,2r58,2r58,2r58_0,,2.0,Q9VHA0,"[[-22.5809, -17.7873, -19.9953]]",False,False,train
3,2r58,2r58,2r58_1,,2.0,Q9VHA0,"[[-33.731, 12.299, -16.9301]]",False,False,train
4,2r58,2r58,2r58_2,,2.0,Q9VHA0,"[[-24.9639, 13.8173, -15.12]]",False,False,train


In [None]:
### discard derived fields which may be cached in toFilePre
# errors='ignore' makes this a no-op for columns that are absent (for example when the
# dataset was just rebuilt and the derived columns have not been added yet) instead of
# raising KeyError.
info = info.drop(
    ['p_length', 'c_length', 'y_length', 'num_contact', 'native_num_contact'],
    axis=1,
    errors='ignore',
)
info[info['protein_name'] == checkout_id]

### further complete dataset: protein_length, compound_length, y_length, num_contact, native_num_contact

In [117]:
### collect protein_length, compound_length, y_length and num_contact for every sample
t = []
pre_pdb = None
for i, line in tqdm(info.iterrows(), total=info.shape[0]):
    d = dataset[i]
    t.append([
        i,
        line['compound_name'],
        d['node_xyz'].shape[0],   # p_length
        d['coords'].shape[0],     # c_length
        d['y'].shape[0],          # y_length
        (d.y > 0).sum(),          # num_contact, still a tensor at this point
    ])

t[:10]

100%|██████████| 42790/42790 [02:44<00:00, 260.22it/s]


[[0, '2r58', 57, 12, 684, tensor(52)],
 [1, '2r58', 141, 12, 1692, tensor(41)],
 [2, '2r58', 83, 12, 996, tensor(52)],
 [3, '2r58', 82, 12, 984, tensor(0)],
 [4, '2r58', 84, 12, 1008, tensor(0)],
 [5, '3c2f', 162, 22, 3564, tensor(121)],
 [6, '3c2f', 181, 22, 3982, tensor(81)],
 [7, '3c2f', 163, 22, 3586, tensor(39)],
 [8, '3c2f', 178, 22, 3916, tensor(121)],
 [9, '3c2f', 171, 22, 3762, tensor(0)]]

In [149]:
### turn the collected list into a DataFrame and append its columns to info
length_columns = ['p_length', 'c_length', 'y_length', 'num_contact']
t = pd.DataFrame(t, columns=['index', 'pdb'] + length_columns)
# num_contact entries are tensors; unwrap each one with .item()
t['num_contact'] = t['num_contact'].map(lambda count: count.item())

info = pd.concat([info, t[length_columns]], axis=1)
info[info['protein_name'] == checkout_id]

Unnamed: 0,protein_name,compound_name,pdb,smiles,affinity,uid,pocket_com,use_compound_com,use_whole_protein,group,p_length,c_length,y_length,num_contact
0,2r58,2r58,2r58,,2.0,Q9VHA0,,True,False,train,57,12,684,52
1,2r58,2r58,2r58_c,,2.0,Q9VHA0,"[[-24.28586196899414, -1.6974811553955078, -17...",False,False,train,141,12,1692,41
2,2r58,2r58,2r58_0,,2.0,Q9VHA0,"[[-22.5809, -17.7873, -19.9953]]",False,False,train,83,12,996,52
3,2r58,2r58,2r58_1,,2.0,Q9VHA0,"[[-33.731, 12.299, -16.9301]]",False,False,train,82,12,984,0
4,2r58,2r58,2r58_2,,2.0,Q9VHA0,"[[-24.9639, 13.8173, -15.12]]",False,False,train,84,12,1008,0


In [150]:
### add native_num_contact into dataset
# The native row of each protein (use_compound_com is True) supplies the reference contact count.
native_rows = info.query("use_compound_com")
native_num_contact = native_rows.set_index("protein_name")['num_contact'].to_dict()
info['native_num_contact'] = info['protein_name'].map(native_num_contact)
# info['fract_of_native_contact'] = info['num_contact'] / info['native_num_contact']

In [153]:
### save and load completed dataset
# Overwrites processed/data.pt under toFilePre with the table that now carries the
# length and contact columns, then reads it back.
torch.save(info, f"{toFilePre}/processed/data.pt")

info = torch.load(f"{toFilePre}/processed/data.pt")


### split out test dataset

In [156]:
# Rows of the test split and the distinct proteins they cover.
is_test = info['group'] == 'test'
test = info[is_test].reset_index(drop=True)
test_pdb_list = test['protein_name'].unique()
test_pdb_list.size

110

In [157]:
### protein features restricted to the test proteins
subset_protein_dict = {pdb: protein_dict[pdb] for pdb in tqdm(test_pdb_list)}

100%|██████████| 110/110 [00:00<00:00, 355449.49it/s]


In [158]:
### compound features restricted to the test entries
subset_compound_dict = {pdb: compound_dict[pdb] for pdb in tqdm(test_pdb_list)}

100%|██████████| 110/110 [00:00<00:00, 357100.19it/s]


In [159]:
### build test TankBindDataSet
toFilePre = f"{pre}/test_dataset"
# os.makedirs replaces `os.system("mkdir -p ...")`: no shell involved, errors raise.
os.makedirs(toFilePre, exist_ok=True)
dataset = TankBindDataSet(toFilePre, data=test, protein_dict=subset_protein_dict, compound_dict=subset_compound_dict)

Processing...
Done!


['../pdbbind2020/test_dataset/processed/data.pt', '../pdbbind2020/test_dataset/processed/protein.pt', '../pdbbind2020/test_dataset/processed/compound.pt']


In [161]:
def canonical_smiles(smiles):
    """Round-trip a SMILES string through RDKit and return the re-written SMILES.

    The input is parsed with Chem.MolFromSmiles and written back with
    Chem.MolToSmiles. No handling is added for strings RDKit cannot parse;
    the error raised by RDKit propagates to the caller.
    """
    # Imported locally because no cell visible in this notebook brings `Chem` into scope.
    from rdkit import Chem
    return Chem.MolToSmiles(Chem.MolFromSmiles(smiles))