In [1]:
from torch_geometric.data import DataLoader
import torch_geometric
import torch.distributions as D
import matplotlib.pyplot as plt
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Draw, rdFMCS, rdMolTransforms
from rdkit.Chem.rdMolAlign import AlignMol
from rdkit.Chem import PandasTools
from rdkit import rdBase
import glob
import os

import deepdock 
from deepdock.utils.distributions import *
from deepdock.utils.data import *
from deepdock.models import *

%matplotlib inline
np.random.seed(123)
torch.cuda.manual_seed_all(123)
torch.manual_seed(123)



<torch._C.Generator at 0x7fa66808cf90>

In [2]:
# Use the GPU when one is visible, otherwise run everything on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Encoders for the ligand graph and the target mesh. These hyperparameters have to agree with
# the trained checkpoint below, otherwise load_state_dict cannot match the parameter tensors.
ligand_model = LigandNet(28, residual_layers=10, dropout_rate=0.10)
target_model = TargetNet(4, residual_layers=10, dropout_rate=0.10)
model = DeepDock(
    ligand_model,
    target_model,
    hidden_dim=64,
    n_gaussians=10,
    dropout_rate=0.10,
    dist_threhold=7.,  # keyword spelled as DeepDock's constructor spells it (sic)
)
model = model.to(device)

# Pretrained weights shipped next to the deepdock package. map_location remaps the stored
# tensors onto `device`, so the checkpoint also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file - only point it at trusted checkpoints.
checkpoint_path = deepdock.__path__[0] + '/../Trained_models/DeepDock_pdbbindv2019_Schrodinger_17K_minTestLoss.chk'
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [3]:
# CASF-2016 complexes, bundled as a tar archive under deepdock's data/ directory.
casf2016_archive = deepdock.__path__[0] + '/../data/dataset_CASF2016_Schrodinger_285_28a6b.tar'

# Both size limits are None - presumably this disables node-count filtering, so every complex
# in the archive is kept (TODO confirm against PDBbind_complex_dataset).
db_complex = PDBbind_complex_dataset(
    data_path=casf2016_archive,
    min_target_nodes=None,
    max_ligand_nodes=None,
)
print('Complexes from pdbBind:', len(db_complex))

Complexes from pdbBind: 285


In [4]:
class input_dataset(Dataset):
    """In-memory PyG dataset that pairs every molecule with one shared target mesh.

    ``get(idx)`` returns the tuple ``(ligand_graph, target_mesh, label)``, so a torch_geometric
    ``DataLoader`` over it yields ``(batched ligands, batched target meshes, labels)``.

    Args:
        mols: iterable of RDKit molecules. Each one is converted once, up front, to a PyG graph
            via ``mol2graph.mol_to_nx`` followed by ``from_networkx``.
        target_mesh: target graph returned unchanged with every item.
        labels: optional indexable of per-molecule labels (e.g. decoy names), one per molecule.
            Defaults to ``range(len(mols))``.
        transform, pre_transform: forwarded to the base ``Dataset`` constructor.
            NOTE(review): the base class presumably applies ``transform`` to whatever ``get``
            returns - here a tuple, not a ``Data`` object; confirm before passing one.

    Raises:
        ValueError: if ``labels`` is given and its length differs from the number of molecules.
    """

    def __init__(self, mols, target_mesh, labels=None, transform=None, pre_transform=None):
        # root=None: the dataset is built entirely in memory from `mols`, nothing is read from disk.
        super().__init__(root=None, transform=transform, pre_transform=pre_transform)

        self.mols = [from_networkx(mol2graph.mol_to_nx(m)) for m in mols]
        self.target_mesh = target_mesh
        self.labels = range(len(self.mols)) if labels is None else labels
        # Fail early: a mismatch would otherwise surface as an IndexError deep inside a
        # DataLoader, or silently drop the surplus labels.
        if len(self.labels) != len(self.mols):
            raise ValueError(
                f'got {len(self.labels)} labels for {len(self.mols)} molecules; '
                'labels must contain exactly one entry per molecule'
            )

    def len(self):
        """Number of molecules in the dataset."""
        return len(self.mols)

    def get(self, idx):
        """Return ``(ligand_graph, target_mesh, label)`` for molecule ``idx``."""
        return self.mols[idx], self.target_mesh, self.labels[idx]

In [5]:
%%time
from torch_scatter import scatter_add
results = []

for target_data in db_complex:
    model.eval()
    ligand, target, _, pdbid = target_data
    decoys = Mol2MolSupplier(file=deepdock.__path__[0]+'/../data/CASF-2016/decoys_docking/'+pdbid+'_decoys.mol2', 
                             sanitize=False, cleanupSubstructures=False)
    decoy_names = [m.GetProp('Name') for m in decoys]
    db_decoys = input_dataset(decoys, target, decoy_names)
    loader_decoys = DataLoader(db_decoys, batch_size=20, shuffle=False)
    #print(pdbid)

    for data in loader_decoys:
        decoy, target, cpd_name = data
        decoy, target = decoy.to(device), target.to(device)
        pi, sigma, mu, dist, atom_types, bond_types, batch = model(decoy, target)

        normal = Normal(mu, sigma)
        logprob = normal.log_prob(dist.expand_as(normal.loc))
        logprob += torch.log(pi)
        prob = logprob.exp().sum(1)
        prob_all = scatter_add(prob, batch, dim=0, dim_size=batch.unique().size(0))
    
        prob[torch.where(dist > 10)[0]] = 0.
        prob_10 = scatter_add(prob, batch, dim=0, dim_size=batch.unique().size(0))
        
        prob[torch.where(dist > 7)[0]] = 0.
        prob_7 = scatter_add(prob, batch, dim=0, dim_size=batch.unique().size(0))

        prob[torch.where(dist > 5)[0]] = 0.
        prob_5 = scatter_add(prob, batch, dim=0, dim_size=batch.unique().size(0))

        prob[torch.where(dist > 3)[0]] = 0.
        prob_3 = scatter_add(prob, batch, dim=0, dim_size=batch.unique().size(0))
        
        prob = torch.stack([prob_3, prob_5, prob_7, prob_10, prob_all],dim=1)
        #print(pdbid, cpd_name, prob_all.cpu().detach().numpy())
        results.append(np.concatenate([np.expand_dims(np.repeat(pdbid, len(cpd_name)), axis=1), 
                                       np.expand_dims(cpd_name, axis=1), 
                                       prob.cpu().detach().numpy()], axis=1))
      

CPU times: user 6min 1s, sys: 7.56 s, total: 6min 9s
Wall time: 3min 50s


In [6]:
import pandas as pd  # explicit: `pd` is otherwise only available via the deepdock star imports

SCORE_COLUMNS = ['Score_3A', 'Score_5A', 'Score_7A', 'Score_10A', 'Score_all']

# `results` (the per-batch list) is left untouched so this cell can be re-run safely.
results_df = pd.DataFrame(np.concatenate(results, axis=0),
                          columns=['PDB_ID', 'Cpd_Name'] + SCORE_COLUMNS)
# Each batch array mixed ids (str) with scores (float), which made numpy store everything as
# strings; turn the scores back into numbers.
results_df[SCORE_COLUMNS] = results_df[SCORE_COLUMNS].astype(float)

results_df.to_csv('Score_decoys_docking_CASF2016.csv', index=False)
print(results_df.shape)
results_df.head()

(22492, 7)


Unnamed: 0,PDB_ID,Cpd_Name,Score_3A,Score_5A,Score_7A,Score_10A,Score_all
0,4k18,4k18_100,41.7124709276148,355.00999393874434,1234.1988897433216,1265.8331969710105,1265.8335978696152
1,4k18,4k18_105,36.50107725111608,320.5826389497487,1098.7196528481893,1128.4331804224062,1128.4344001617458
2,4k18,4k18_107,56.13413997037191,406.6658214522255,1261.5476637591444,1298.835192687907,1298.835338794891
3,4k18,4k18_118,44.23784542273988,336.9721288402966,1252.596765731967,1287.006223774616,1287.006601055387
4,4k18,4k18_122,42.72467596870927,342.7847157990617,1211.367931058725,1252.029400636168,1252.029755610901


In [None]:
import os

import pandas as pd

# Write one tab-separated '<pdbid>_score.dat' file ('#code', 'score' columns) per target and
# per distance cutoff into DockingPower_DeepDock_<cutoff>/scores/ - presumably the layout the
# CASF-2016 docking-power evaluation reads.
SCORE_SUFFIXES = ['3A', '5A', '7A', '10A', 'all']

# Keep identifiers as strings: a PDB code such as '1e66' is also a valid float literal.
df = pd.read_csv('Score_decoys_docking_CASF2016.csv', dtype={'PDB_ID': str, 'Cpd_Name': str})
pdbids = df.PDB_ID.unique()
print(len(pdbids))

# Create the output folders up front instead of failing on the first missing one.
for suffix in SCORE_SUFFIXES:
    os.makedirs(os.path.join('DockingPower_DeepDock_' + suffix, 'scores'), exist_ok=True)

# sort=False keeps targets in the order they appear in the CSV.
for pdbid, target_scores in df.groupby('PDB_ID', sort=False):
    for suffix in SCORE_SUFFIXES:
        score_column = 'Score_' + suffix
        out_path = os.path.join('DockingPower_DeepDock_' + suffix, 'scores', pdbid + '_score.dat')
        (target_scores[['Cpd_Name', score_column]]
            .rename(columns={'Cpd_Name': '#code', score_column: 'score'})
            .to_csv(out_path, index=False, sep='\t'))