# In [16]:
# %load predict_pka.py
#!/usr/bin/env python

from __future__ import division
from __future__ import unicode_literals
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
from rdkit.Chem import rdmolops
from rdkit.Chem.MolStandardize import rdMolStandardize

import h5py
import json
import os
import os.path as osp
import numpy as np
import pandas as pd

import torch
from torch import nn
import torch.nn.functional as F

from torch.nn import Linear
from torch.nn import BatchNorm1d
from torch.utils.data import Dataset
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import global_add_pool, global_mean_pool

from utils.ionization_group import get_ionization_aid
from utils.descriptor import mol2vec
from net import GCNNet
import py3Dmol

# Base directory for the bundled model weights: the current working directory
# (for a notebook, the directory it was launched from).
root = os.path.abspath(os.curdir)

def load_model(model_file, device="cpu"):
    """Construct a GCNNet, restore saved weights into it, and set eval mode.

    Args:
        model_file: path to a file whose ``torch.load`` result is a
            state dict accepted by ``GCNNet.load_state_dict``.
        device: torch device the network and its loaded tensors are mapped to.

    Returns:
        The GCNNet instance in evaluation mode.
    """
    net = GCNNet().to(device)
    # NOTE(review): torch.load unpickles the file -- only point this at
    # trusted weight files.
    weights = torch.load(model_file, map_location=device)
    net.load_state_dict(weights)
    net.eval()
    return net

def model_pred(m2, aid, model, device="cpu"):
    """Run ``model`` on the graph built for molecule ``m2`` centred on atom ``aid``.

    Args:
        m2: RDKit molecule passed to ``mol2vec``.
        aid: atom index of the ionization site, passed to ``mol2vec``.
        model: trained network to evaluate.
        device: device the featurized graph is moved to before inference.

    Returns:
        Element ``[0][0]`` of the model output as a NumPy scalar
        (presumably the output is shaped (1, 1) -- TODO confirm).
    """
    graph = mol2vec(m2, aid)
    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        prediction = model(graph.to(device))
        as_array = prediction.cpu().numpy()
    return as_array[0][0]

def predict_acid(mol, model=None, device="cpu"):
    """Predict a pKa for every acidic ionization site of ``mol``.

    Args:
        mol: RDKit molecule (``predict`` passes it with explicit hydrogens).
        model: optional pre-loaded acid model. When None, the weights in
            ``models/weight_acid.pth`` under ``root`` are loaded from disk on
            this call; pass a model to avoid re-reading them per molecule.
        device: device used to load the model (when ``model`` is None) and to
            run inference. Must match the device a supplied ``model`` is on.

    Returns:
        dict mapping each atom index returned by
        ``get_ionization_aid(mol, acid_or_base="acid")`` to its predicted pKa.
    """
    if model is None:
        model_file = os.path.join(root, "models/weight_acid.pth")
        model = load_model(model_file, device=device)

    acid_idxs = get_ionization_aid(mol, acid_or_base="acid")
    return {aid: model_pred(mol, aid, model, device=device) for aid in acid_idxs}

def predict_base(mol, model=None, device="cpu"):
    """Predict a pKa for every basic ionization site of ``mol``.

    Args:
        mol: RDKit molecule (``predict`` passes it with explicit hydrogens).
        model: optional pre-loaded base model. When None, the weights in
            ``models/weight_base.pth`` under ``root`` are loaded from disk on
            this call; pass a model to avoid re-reading them per molecule.
        device: device used to load the model (when ``model`` is None) and to
            run inference. Must match the device a supplied ``model`` is on.

    Returns:
        dict mapping each atom index returned by
        ``get_ionization_aid(mol, acid_or_base="base")`` to its predicted pKa.
    """
    if model is None:
        model_file = os.path.join(root, "models/weight_base.pth")
        model = load_model(model_file, device=device)

    base_idxs = get_ionization_aid(mol, acid_or_base="base")
    return {aid: model_pred(mol, aid, model, device=device) for aid in base_idxs}

def predict(mol, uncharged=True):
    """Predict base and acid pKa values for ``mol`` and give it 3D coordinates.

    Args:
        mol: RDKit molecule, e.g. from ``Chem.MolFromSmiles``.
        uncharged: if True, neutralize the molecule with RDKit's
            ``rdMolStandardize.Uncharger`` before prediction.

    Returns:
        ``(base_dict, acid_dict, mol_h)``:
        - ``base_dict`` and ``acid_dict`` are ``{atom index: pKa}`` dicts;
        - ``mol_h`` is the hydrogen-added molecule those indices refer to.
        ``mol_h`` carries a 3D conformer unless both embedding attempts failed.

    Raises:
        TypeError: if ``mol`` is None, e.g. because ``MolFromSmiles`` could
            not parse the SMILES.
    """
    if mol is None:
        # MolFromSmiles returns None on parse failure; report that clearly
        # instead of surfacing an opaque RDKit argument error further down.
        raise TypeError("predict() needs an RDKit Mol, got None (invalid SMILES?)")
    if uncharged:
        mol = rdMolStandardize.Uncharger().uncharge(mol)
    mol = AllChem.AddHs(mol)
    base_dict = predict_base(mol)
    acid_dict = predict_acid(mol)
    # EmbedMolecule returns -1 when it cannot generate coordinates; retry
    # starting from random coordinates as a fallback.
    if AllChem.EmbedMolecule(mol) == -1:
        AllChem.EmbedMolecule(mol, useRandomCoords=True)
    # NOTE: if both attempts fail, mol has no conformer; the pKa results are
    # still returned, but drawing the molecule will fail.
    return base_dict, acid_dict, mol

def drawit(m, atom_idxs, pkas, p=None, confId=-1):
    """Render ``m`` in a py3Dmol viewer with a pKa label on each given atom.

    Args:
        m: RDKit molecule; it needs a conformer whenever labels are drawn,
            because label positions come from its coordinates.
        atom_idxs: atom indices to label.
        pkas: pKa values paired position-by-position with ``atom_idxs``
            (``zip`` silently ignores extra items in the longer sequence).
        p: existing py3Dmol view to reuse; a 600x400 view is created if None.
            Any models already in the view are removed first.
        confId: conformer to draw; the structure and the label positions are
            both taken from this conformer.

    Returns:
        Whatever ``p.show()`` returns.
    """
    mb = Chem.MolToMolBlock(m, confId=confId)
    if p is None:
        p = py3Dmol.view(width=600, height=400)
    p.removeAllModels()
    p.addModel(mb, 'sdf')
    for pka, idx in zip(pkas, atom_idxs):
        # Read coordinates from the same conformer written to the mol block so
        # each label sits on the atom that is actually drawn.
        pos = m.GetConformer(confId).GetAtomPosition(idx)
        # "{:.4}" formats with 4 significant digits.
        label = "pKa={:.4}".format(pka)
        p.addLabel(label, {"position": {"x": pos.x, "y": pos.y, "z": pos.z}, "fontSize": 10})
    p.setStyle({'stick': {'colorscheme': 'greenCarbon'}})
    p.setBackgroundColor('0xeeeeee')
    p.zoomTo()
    return p.show()

# In [19]:
# Alternative example input:
# mol = Chem.MolFromSmiles("CN(C)CCCN1C2=CC=CC=C2SC2=C1C=C(C=C2)C(C)=O")
mol = Chem.MolFromSmiles("OCCn(c1)ncc1-c(cn2)nc(c23)n(nn3)Cc(c4)ccc(c45)nccc5")
base_dict, acid_dict, m = predict(mol)
# Base sites first, then acid sites; both lists keep the same order, so
# atom_idx[i] is labelled with pkas[i].
atom_idx = list(base_dict) + list(acid_dict)
pkas = [*base_dict.values(), *acid_dict.values()]
drawit(m, atom_idx, pkas)