In [1]:
from torch_geometric.data import DataLoader
import torch.distributions as D
import matplotlib.pyplot as plt
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Draw, Descriptors, rdMolTransforms
from rdkit import rdBase
rdBase.DisableLog('rdApp.warning')
import glob
import os

import deepdock
from deepdock.utils.distributions import *
from deepdock.utils.data import *
from deepdock.models import *

from deepdock.DockingFunction import *
from scipy.optimize import basinhopping, brute, differential_evolution
import copy

# Seed every RNG this notebook draws from (NumPy global, torch CPU, torch CUDA)
# with one shared value so a fresh kernel reproduces the same run.
SEED = 123
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)

%matplotlib inline



In [2]:
# Build the DeepDock network from its ligand and target encoders, move it to the
# GPU when one is available, and load the trained weights from the packaged checkpoint.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

DROPOUT = 0.10
# 28 / 4 are presumably the ligand / target input feature sizes expected by the checkpoint.
ligand_model = LigandNet(28, residual_layers=10, dropout_rate=DROPOUT)
target_model = TargetNet(4, residual_layers=10, dropout_rate=DROPOUT)
# `dist_threhold` (sic) is the keyword name DeepDock's constructor takes.
model = DeepDock(ligand_model, target_model,
                 hidden_dim=64, n_gaussians=10,
                 dropout_rate=DROPOUT, dist_threhold=7.).to(device)

checkpoint_path = deepdock.__path__[0] + '/../Trained_models/DeepDock_pdbbindv2019_Schrodinger_17K_minTestLoss.chk'
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
model.load_state_dict(checkpoint['model_state_dict']) 

<All keys matched successfully>

In [3]:
%%time
db_complex = PDBbind_complex_dataset(data_path=deepdock.__path__[0]+'/../data/dataset_deepdock_pdbbind_v2019_Schrodinger_17K_28a6b.tar', 
                                     min_target_nodes=None, max_ligand_nodes=None)
db_complex = [i for i in db_complex if i[3] in checkpoint['pdbIDs_test']]
print('Complexes in Test Set:', len(db_complex))


Complexes in Test Set: 1367
CPU times: user 4.9 s, sys: 2.61 s, total: 7.51 s
Wall time: 7.51 s


In [4]:
def dock_compound(data, dist_threshold=3., popsize=150):
    """Dock one complex with DeepDock + differential evolution and score the result.

    The reference ligand for the complex is read from its PDBbind SDF file and its
    pose (rigid-body parameters plus rotatable-bond torsions) is optimized with
    scipy's ``differential_evolution`` on ``opt.score_conformation``, where ``opt``
    wraps the Gaussian-mixture parameters ``model`` predicts for this ligand/target.

    Args:
        data: one batch from ``DataLoader(db_complex, batch_size=1)``, unpacked as
            ``(ligand, target, activity, pdbid)``.
        dist_threshold: ligand/target distance cutoff. It is passed to
            ``optimze_conformation``, and pairs farther apart are excluded from
            ``num_MixOfGauss`` and from ``score_real_mol``.
        popsize: approximate total population size. It is divided by the number
            of optimized variables (``6 + number of rotatable bonds``) to get the
            per-variable multiplier ``differential_evolution`` expects.

    Returns:
        The ``scipy.optimize.OptimizeResult`` from differential evolution, extended
        with keys ``num_MixOfGauss``, ``rmsd``, ``pdb_id``, ``score_real_mol``,
        ``pkx``, ``num_atoms``, ``num_rotbonds`` and ``rotbonds``.

    Raises:
        ValueError: if RDKit cannot parse the reference ligand SDF (RDKit itself
            raises if the file does not exist).
    """
    # Re-seed per compound so each docking run is reproducible on its own,
    # independent of how many compounds were docked before it.
    np.random.seed(123)
    torch.cuda.manual_seed_all(123)
    torch.manual_seed(123)

    model.eval()
    ligand, target, activity, pdbid = data
    ligand, target = ligand.to(device), target.to(device)
    # Inference only: the optimizer below is gradient-free, so skip building an autograd graph.
    with torch.no_grad():
        pi, sigma, mu, dist, atom_types, bond_types, batch = model(ligand, target)

    pdb_id = pdbid[0]
    # The reference ligand file is keyed by this complex's PDB ID.
    sdf_path = os.path.join(deepdock.__path__[0]+'/../data/PDBbind2019_schrodinger_ligands',
                            pdb_id, pdb_id+'_ligand.sdf')
    real_mol = Chem.SDMolSupplier(sdf_path)[0]
    if real_mol is None:  # RDKit returns None instead of raising on a parse failure
        raise ValueError(f'Could not parse reference ligand: {sdf_path}')

    opt = optimze_conformation(mol=real_mol, target_coords=target.pos.cpu(), n_particles=1, 
                               pi=pi.cpu(), mu=mu.cpu(), sigma=sigma.cpu(), dist_threshold=dist_threshold)

    # Search box: 3 angles in [-pi, pi], a translation bounded by the min/max of the
    # target node coordinates, and one torsion in [-pi, pi] per rotatable bond.
    n_rot = len(opt.rotable_bonds)
    target_pos = target.pos.cpu()
    max_bound = np.concatenate([[np.pi]*3, target_pos.max(0)[0].numpy(), [np.pi]*n_rot], axis=0)
    min_bound = np.concatenate([[-np.pi]*3, target_pos.min(0)[0].numpy(), [-np.pi]*n_rot], axis=0)

    # Optimize conformations
    result = differential_evolution(opt.score_conformation, list(zip(min_bound, max_bound)), maxiter=500, 
                                    popsize=int(np.ceil(popsize/(n_rot+6))),
                                    mutation=(0.5, 1), recombination=0.8, disp=False, seed=123)

    # Docked pose: count (ligand atom, target node) distances within the cutoff.
    # opt.noHidx presumably selects the heavy (non-hydrogen) atoms; the leading axis is the single particle.
    opt_mol = apply_changes(opt.mol, result['x'], opt.rotable_bonds)
    pose_coords = torch.tensor(opt_mol.GetConformer().GetPositions()[opt.noHidx]).unsqueeze(0)
    pose_dist = compute_euclidean_distances_matrix(pose_coords, opt.targetCoords).flatten().unsqueeze(1)
    result['num_MixOfGauss'] = torch.where(pose_dist <= dist_threshold)[0].size(0)
    # NOTE(review): AlignMol superimposes opt_mol onto real_mol (moving opt_mol) before
    # measuring, so this is a post-alignment RMSD. Docking RMSD is usually reported without
    # superposition (e.g. rdMolAlign.CalcRMS); confirm which one is intended.
    result['rmsd'] = Chem.rdMolAlign.AlignMol(opt_mol, real_mol, atomMap=list(zip(opt.noHidx,opt.noHidx)))
    result['pdb_id'] = pdb_id

    # Score the crystal pose with the same mixture model; pairs beyond the cutoff contribute 0.
    real_coords = torch.tensor(real_mol.GetConformer().GetPositions()[opt.noHidx]).unsqueeze(0)
    real_dist = compute_euclidean_distances_matrix(real_coords, opt.targetCoords).flatten().unsqueeze(1)
    score_real_mol = calculate_probablity(opt.pi, opt.sigma, opt.mu, real_dist)
    score_real_mol[torch.where(real_dist > dist_threshold)[0]] = 0.
    result['score_real_mol'] = score_real_mol.reshape(opt.n_particles, -1).sum(1).item()
    del pose_coords, pose_dist, real_coords, real_dist, score_real_mol

    result['pkx'] = activity[0].item()
    result['num_atoms'] = real_mol.GetNumHeavyAtoms()
    result['num_rotbonds'] = n_rot
    result['rotbonds'] = opt.rotable_bonds

    return result

In [None]:
%%time
loader = DataLoader(db_complex, batch_size=1, shuffle=False)

results = []
i = 0
for data in loader:
    try:
        results.append(dock_compound(data))
        d = {}
        for k in results[0].keys():
            if k != 'jac':
                d[k] = tuple(d[k] for d in results)
        torch.save(d, 'DockingResults_TestSet.chk')
        results_df = pd.DataFrame.from_dict(d)
        results_df.to_csv('DockingResults_TestSet.csv', index=False) 
        i += 1
    except:
        print(i, data[3])
        #break
        i += 1



In [None]:
# Summarize docking accuracy over all compounds and over those where the optimizer
# reported success (the `success` field of scipy's OptimizeResult).
succeeded = results_df[results_df['success'].astype(bool)]
print('Number of compounds with succesful optimization:', len(succeeded), 'of', len(results_df))
print('Mean RMSD of all compounds:', results_df.rmsd.mean(), '+/-', results_df.rmsd.std())
print('Mean RMSD of compounds with succesful optimization:', succeeded.rmsd.mean(), '+/-', succeeded.rmsd.std())