In [None]:
import os
from types import SimpleNamespace

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import SVG
from rdkit import Chem
from rdkit.Chem import QED, Draw
from rdkit.Chem.Draw import rdMolDraw2D
%matplotlib inline

## Prep

In [None]:
# Paths of the key files read by the loading loop in this notebook. Each line is
# parsed there as tab-separated fields: PDB code, ligand name, year, value.
# NOTE(review): presumably the training-set and 2018 test-set key lists — confirm.
TRAIN_KEYS = 'keys/keys_select'
TEST_KEYS = 'keys/keys_2018'

In [None]:
def get_ligand_mol(pdb_code):
    """Load the ligand stored as an ``.sdf`` file under ``cbidata/<pdb_code>``.

    Args:
        pdb_code: Name of the sub-directory of ``cbidata`` to search.

    Returns:
        The first ``.sdf`` file (in sorted filename order) that RDKit can
        parse, as a molecule whose ``_Name`` property is set to the first
        three characters of the filename. Returns ``None`` when the directory
        contains no ``.sdf`` file or none of them can be parsed.

    Raises:
        FileNotFoundError: If ``cbidata/<pdb_code>`` does not exist
            (propagated from ``os.listdir``).
    """
    ligand_dir = f'cbidata/{pdb_code}'
    # os.listdir returns entries in arbitrary order; sorting makes the chosen
    # file reproducible when a directory holds more than one .sdf file.
    for f in sorted(os.listdir(ligand_dir)):
        if not f.endswith('.sdf'):
            continue
        ligand_mol = Chem.MolFromMolFile(f'{ligand_dir}/{f}')
        if ligand_mol is None:
            # MolFromMolFile signals a parse/sanitization failure by returning
            # None; calling SetProp on it would raise AttributeError.
            continue
        # NOTE(review): the first three filename characters are presumably the
        # ligand's residue code — confirm against the data layout.
        ligand_mol.SetProp('_Name', f[:3])
        return ligand_mol
    return None

In [None]:
def molsvg(mol, width=150, height=100):
    """Draw a molecule as an SVG image for inline notebook display.

    Args:
        mol: RDKit molecule to depict.
        width: Canvas width passed to the SVG drawer.
        height: Canvas height passed to the SVG drawer.

    Returns:
        An ``IPython.display.SVG`` object wrapping the drawing text.
    """
    # Rebuild the molecule from its SMILES string before drawing.
    # NOTE(review): presumably done to discard the coordinates read from the
    # SDF file so a fresh 2D depiction is generated — confirm.
    smiles = Chem.MolToSmiles(mol)
    rebuilt = Chem.MolFromSmiles(smiles)

    drawer = rdMolDraw2D.MolDraw2DSVG(width, height)
    drawable = rdMolDraw2D.PrepareMolForDrawing(rebuilt)
    drawer.drawOptions().circleAtoms = False
    drawer.DrawMolecule(drawable)
    drawer.FinishDrawing()
    return SVG(drawer.GetDrawingText())

In [None]:
# Read the PDB codes and their associated values from both key files.
# `keys[i]` and `pkds[i]` describe the same entry.
keys = []
pkds = []
for keyfile in [TRAIN_KEYS, TEST_KEYS]:
    # `with` closes the file handle deterministically; a bare open() in the
    # for-statement leaves it open until garbage collection.
    with open(keyfile) as key_fh:
        for line in key_fh:
            line = line.rstrip()
            if not line:
                # Skip blank lines (e.g. a trailing empty line) rather than
                # failing with IndexError on the missing columns.
                continue
            it = line.split('\t')
            # Columns: PDB code, ligand name, year, value. The year is parsed
            # as int and the value as float, so malformed rows fail loudly.
            pdb_code, ligand_name, year, value = it[0], it[1], int(it[2]), float(it[3])
            keys.append(pdb_code)
            pkds.append(value)

In [None]:
# Collect one annotated ligand molecule per key that has local data.
ligand_mols = []

for key, pkd in zip(keys, pkds):
    # Entries without a data directory, or without a loadable ligand, are skipped.
    ligand_mol = get_ligand_mol(key) if os.path.exists(f'cbidata/{key}') else None
    if not ligand_mol:
        continue

    # Annotations are stored on the molecule as string properties; the QED
    # score is rounded to three decimals before being stored.
    ligand_mol.SetProp('pkd', str(pkd))
    ligand_mol.SetProp('key', key)
    qed = round(QED.qed(ligand_mol), 3)
    ligand_mol.SetProp('qed', str(qed))
    ligand_mols.append(ligand_mol)

## Customize your view

### Sorting mols

In [None]:
# Order the ligands in place by atom count, smallest first.
ligand_mols.sort(key=lambda mol: mol.GetNumAtoms())

In [None]:
# Order the ligands in place by their stored pkd value, lowest first.
# list.sort is stable, so ligands with equal pkd keep their previous relative order.
ligand_mols.sort(key=lambda mol: float(mol.GetProp('pkd')))

### Histogram

In [None]:
# Distribution of pkd over the collected ligands.
# A dedicated name is used on purpose: rebinding `pkds` here would replace the
# per-key list that the ligand-collection loop zips with `keys`, so re-running
# that cell afterwards would silently pair keys with the wrong values.
pkd_values = np.array([float(mol.GetProp('pkd')) for mol in ligand_mols])
# Unit-wide bins with edges 0..14; values outside that range are not drawn.
plt.hist(pkd_values, bins=range(15))
plt.xlabel('pkd')
plt.ylabel('number of ligands')
# Summary statistics are the cell's displayed output.
pkd_values.min(), pkd_values.max(), pkd_values.mean()

In [None]:
# Distribution of atom counts over the collected ligands.
natoms = np.array([mol.GetNumAtoms() for mol in ligand_mols])
# Bins are 5 atoms wide with edges 0, 5, ..., 65; larger ligands are not drawn.
plt.hist(natoms, bins=[5*i for i in range(14)])
# Axis labels so the figure can be read without the code.
plt.xlabel('number of atoms')
plt.ylabel('number of ligands')
# Summary statistics are the cell's displayed output.
natoms.min(), natoms.max(), natoms.mean()

In [None]:
# Distribution of the stored QED scores over the collected ligands.
qeds = np.array([float(mol.GetProp('qed')) for mol in ligand_mols])
# Ten bins with edges at multiples of 0.1 from 0 to 1.
plt.hist(qeds, bins=[0.1*i for i in range(11)])
# Axis labels so the figure can be read without the code.
plt.xlabel('QED')
plt.ylabel('number of ligands')
# Summary statistics are the cell's displayed output.
qeds.min(), qeds.max(), qeds.mean()

### Show molecules

In [None]:
def get_molinfo(mol):
    """Gather the annotations stored on a ligand molecule into one record.

    Args:
        mol: Molecule carrying the string properties ``key``, ``_Name``,
            ``pkd`` and ``qed``.

    Returns:
        A ``types.SimpleNamespace`` with attributes ``key`` and ``name``
        (str), ``natoms`` (result of ``mol.GetNumAtoms()``), and ``pkd`` and
        ``qed`` (parsed to float). Unlike an instance of a throwaway class,
        the namespace has a readable repr and compares equal by value.
    """
    return SimpleNamespace(
        key=mol.GetProp('key'),
        name=mol.GetProp('_Name'),
        natoms=mol.GetNumAtoms(),
        # Properties are stored as strings; convert back to numbers.
        pkd=float(mol.GetProp('pkd')),
        qed=float(mol.GetProp('qed')),
    )

In [None]:
# Print a tab-separated summary line and draw each ligand, skipping those whose
# pkd is below 10.
for mol in ligand_mols:
    info = get_molinfo(mol)
    if not info.pkd < 10:
        print(f'{info.key}\t{info.name}\t{info.natoms}\t{info.pkd}\t{info.qed}')
        display(molsvg(mol, width=200, height=100))

In [None]:
# Print a tab-separated summary line and draw each ligand whose atom count lies
# outside the 10..45 range (inclusive bounds are skipped).
for mol in ligand_mols:
    info = get_molinfo(mol)
    if info.natoms < 10 or info.natoms > 45:
        print(f'{info.key}\t{info.name}\t{info.natoms}\t{info.pkd}\t{info.qed}')
        display(molsvg(mol, width=200, height=100))