In [1]:
from torch_geometric.data import DataLoader
import torch_geometric
import torch.distributions as D
import matplotlib.pyplot as plt
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Draw, rdFMCS, rdMolTransforms
from rdkit.Chem.rdMolAlign import AlignMol
from rdkit.Chem import PandasTools
from rdkit import rdBase
import glob
import os

from deepdock.utils.distributions import *
from deepdock.utils.data import *
from deepdock.models import *

%matplotlib inline
np.random.seed(123)
torch.cuda.manual_seed_all(123)
torch.manual_seed(123)



<torch._C.Generator at 0x7fbb761d30d0>

In [2]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the ligand/target encoders and the DeepDock model with the same
# hyper-parameters the checkpoint below loads into.
# NOTE(review): 28 and 4 are presumably the ligand/target input feature
# sizes -- confirm against LigandNet/TargetNet.
ligand_model = LigandNet(28, residual_layers=10, dropout_rate=0.10)
target_model = TargetNet(4, residual_layers=10, dropout_rate=0.10)
# `dist_threhold` is spelled exactly as this call passes it to DeepDock;
# do not "correct" it here without changing the model's signature too.
model = DeepDock(ligand_model, target_model, hidden_dim=64, n_gaussians=10, dropout_rate=0.10, dist_threhold=7.).to(device)

# map_location remaps the stored tensors onto `device`; without it a
# checkpoint saved from a CUDA run cannot be deserialized on a CPU-only
# machine, which would defeat the CPU fallback above.
checkpoint = torch.load('../Trained_models/DeepDock_pdbbindv2019_13K_minTestLoss.chk',
                        map_location=device)
# Last expression on purpose: the key-matching report is the cell output.
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [3]:
# Archive holding the complexes to score; both node-count filters are
# passed as None.
casf_archive = '../data/dataset_CASF-2016_285.tar'
db_complex = PDBbind_complex_dataset(
    data_path=casf_archive,
    min_target_nodes=None,
    max_ligand_nodes=None,
)
n_loaded = len(db_complex)
print('Complexes from pdbBind:', n_loaded)

Complexes from pdbBind: 285


In [4]:
%%time
from torch_scatter import scatter_add
results = []

model.eval()
loader = DataLoader(db_complex, batch_size=20, shuffle=False)

for data in loader:
    ligand, target, _, pdbid = data
    ligand, target = ligand.to(device), target.to(device)
    pi, sigma, mu, dist, atom_types, bond_types, batch = model(ligand, target)

    normal = Normal(mu, sigma)
    logprob = normal.log_prob(dist.expand_as(normal.loc))
    logprob += torch.log(pi)
    prob = logprob.exp().sum(1)
    prob_all = scatter_add(prob, batch, dim=0, dim_size=batch.unique().size(0))
        
    prob[torch.where(dist > 10)[0]] = 0.
    prob_10 = scatter_add(prob, batch, dim=0, dim_size=batch.unique().size(0))
        
    prob[torch.where(dist > 7)[0]] = 0.
    prob_7 = scatter_add(prob, batch, dim=0, dim_size=batch.unique().size(0))

    prob[torch.where(dist > 5)[0]] = 0.
    prob_5 = scatter_add(prob, batch, dim=0, dim_size=batch.unique().size(0))

    prob[torch.where(dist > 3)[0]] = 0.
    prob_3 = scatter_add(prob, batch, dim=0, dim_size=batch.unique().size(0))
        
    prob = torch.stack([prob_3, prob_5, prob_7, prob_10, prob_all],dim=1)
    #print(pdbid, cpd_name, prob_all.cpu().detach().numpy())
    results.append(np.concatenate([np.expand_dims(pdbid, axis=1), 
                                   prob.cpu().detach().numpy()], axis=1))
      

CPU times: user 6.46 s, sys: 333 ms, total: 6.79 s
Wall time: 4.8 s


In [5]:
# Stack the per-batch arrays into one table: PDB id plus one score column
# per cutoff, in the order the scoring loop stacked them.
score_columns = ['Score_3A', 'Score_5A', 'Score_7A', 'Score_10A', 'Score_all']
# NOTE: `results` is rebound from a list of arrays to a DataFrame, so this
# cell can only be run once per execution of the scoring cell.
results = pd.DataFrame(np.concatenate(results, axis=0), columns=['PDB_ID'] + score_columns)

# The scoring loop concatenated the id strings with the scores, which turned
# every value into text. Cast the scores back to float so sorting, ranking
# and statistics on these columns are numeric rather than lexicographic.
results[score_columns] = results[score_columns].astype(float)

results.to_csv('Score_CoreSet_docking_CASF2016.csv', index=False)
print(results.shape)
results.head()

(285, 6)


Unnamed: 0,PDB_ID,Score_3A,Score_5A,Score_7A,Score_10A,Score_all
0,4k18,64.25368680000449,420.5898705181299,1376.762661813891,1410.0162568367757,1410.01638602799
1,4qac,83.2844061298711,450.0125696899959,1128.0662764096696,1154.7673063010277,1154.7690904503804
2,1o3f,187.7132053065723,778.5596122182768,1779.243987564094,1823.1001001165653,1823.110781221914
3,4ih7,35.630378250691514,270.11676199153,787.7759051254867,814.7537476029596,814.755215144911
4,3dx1,49.94476787861658,202.24570100387868,449.7290290124778,462.38353806385936,462.3919132560412
