In [1]:
from torch_geometric.data import DataLoader
import torch_geometric
import torch.distributions as D
import matplotlib.pyplot as plt
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Draw, rdFMCS, rdMolTransforms
from rdkit.Chem.rdMolAlign import AlignMol
from rdkit.Chem import PandasTools
from rdkit import rdBase
import glob
import os

import deepdock
from deepdock.utils.distributions import *
from deepdock.utils.data import *
from deepdock.models import *

%matplotlib inline
np.random.seed(123)
torch.cuda.manual_seed_all(123)
torch.manual_seed(123)



<torch._C.Generator at 0x7f9b8d51b270>

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

ligand_model = LigandNet(28, residual_layers=10, dropout_rate=0.10)
target_model = TargetNet(4, residual_layers=10, dropout_rate=0.10)
model = DeepDock(ligand_model, target_model, hidden_dim=64, n_gaussians=10, dropout_rate=0.10, dist_threhold=7.).to(device)

checkpoint = torch.load(deepdock.__path__[0]+'/../Trained_models/DeepDock_pdbbindv2019_Schrodinger_17K_minTestLoss.chk', map_location=torch.device(device))
model.load_state_dict(checkpoint['model_state_dict']) 

<All keys matched successfully>

In [3]:
db_complex = PDBbind_complex_dataset(data_path=deepdock.__path__[0]+'/../data/dataset_CASF2016_Schrodinger_285_28a6b.tar', 
                                     min_target_nodes=None, max_ligand_nodes=None)
print('Complexes from pdbBind:', len(db_complex))

Complexes from pdbBind: 285


In [4]:
class input_dataset(Dataset):
    def __init__(self, mols, target_mesh, labels=None, transform=None, pre_transform=None):
        super(input_dataset, self).__init__()
        
        self.mols = [from_networkx(mol2graph.mol_to_nx(m)) for m in mols]
        self.target_mesh = target_mesh
        self.labels = labels
        if labels is None:
            self.labels = range(len(self.mols))
        
    def len(self):
        return len(self.mols)

    def get(self, idx):
        return self.mols[idx], self.target_mesh, self.labels[idx]

In [5]:
%%time
import glob
from torch_scatter import scatter_add
target_files = [f.split('/')[-1] for f in glob.glob(deepdock.__path__[0]+'/../data/CASF-2016/decoys_screening/*', recursive=False)]
results = []


for target_data in db_complex:
    model.eval()
    ligand, target, _, pdbid = target_data
    if pdbid in target_files: 
        decoy_files = [f.split('/')[-1] for f in glob.glob(deepdock.__path__[0]+'/../data/CASF-2016/decoys_screening/'+pdbid+'/*.mol2', recursive=False)]
    else:
        continue
        
    for file in decoy_files:
        decoys = Mol2MolSupplier(file=deepdock.__path__[0]+'/../data/CASF-2016/decoys_screening/'+pdbid+'/'+file, 
                                 sanitize=False, cleanupSubstructures=False)
        decoy_names = [m.GetProp('Name') for m in decoys]
        db_decoys = input_dataset(decoys, target, decoy_names)
        loader_decoys = DataLoader(db_decoys, batch_size=20, shuffle=False)
        #print(pdbid)

        for data in loader_decoys:
            decoy, target_temp, cpd_name = data
            decoy, target_temp = decoy.to(device), target_temp.to(device)
            pi, sigma, mu, dist, atom_types, bond_types, batch = model(decoy, target_temp)

            normal = Normal(mu, sigma)
            logprob = normal.log_prob(dist.expand_as(normal.loc))
            logprob += torch.log(pi)
            prob = logprob.exp().sum(1)
            prob_all = scatter_add(prob, batch, dim=0, dim_size=batch.unique().size(0))
            
            prob[torch.where(dist > 10)[0]] = 0.
            prob_10 = scatter_add(prob, batch, dim=0, dim_size=batch.unique().size(0))
        
            prob[torch.where(dist > 7)[0]] = 0.
            prob_7 = scatter_add(prob, batch, dim=0, dim_size=batch.unique().size(0))

            prob[torch.where(dist > 5)[0]] = 0.
            prob_5 = scatter_add(prob, batch, dim=0, dim_size=batch.unique().size(0))

            prob[torch.where(dist > 3)[0]] = 0.
            prob_3 = scatter_add(prob, batch, dim=0, dim_size=batch.unique().size(0))
        
            prob = torch.stack([prob_3, prob_5, prob_7, prob_10, prob_all],dim=1)
            #print(pdbid, cpd_name, prob_all.cpu().detach().numpy())
            results.append(np.concatenate([np.expand_dims(np.repeat(pdbid, len(cpd_name)), axis=1), 
                                           np.expand_dims(cpd_name, axis=1), 
                                           prob.cpu().detach().numpy()], axis=1))


CPU times: user 6h 54min 8s, sys: 9min 26s, total: 7h 3min 35s
Wall time: 4h 25min 33s


In [6]:
results = np.concatenate(results, axis=0)
results = pd.DataFrame(np.asarray(results), columns=['PDB_ID', 'Cpd_Name', 'Score_3A', 'Score_5A', 'Score_7A', 'Score_10A', 'Score_all'])
results.to_csv('Score_decoys_screening_CASF2016.csv', index=False)
print(results.shape)
results.head() 

(1624500, 7)


Unnamed: 0,PDB_ID,Cpd_Name,Score_3A,Score_5A,Score_7A,Score_10A,Score_all
0,1o3f,4de2_ligand_1,7.254809622430958,89.94999521484337,447.3108304509815,471.3340925876855,471.3379500840941
1,1o3f,4de2_ligand_10,2.2435853524573157,64.8291232744619,469.6149662619709,497.2204820203437,497.22063628573727
2,1o3f,4de2_ligand_100,3.451115585675536,78.15551585586073,506.3139331427199,554.3860309102873,554.3878273892042
3,1o3f,4de2_ligand_106,4.784300354428733,102.6058833211317,590.7510395276462,620.5011685339421,620.5046101480829
4,1o3f,4de2_ligand_108,2.349412635601983,76.80465950772964,539.9298935579708,575.3306395604491,575.3335251313035
