In [1]:
import os
REPO_PATH = os.getcwd().split('notebooks')[0]

import sys
sys.path.append(REPO_PATH)
from rdkit import Chem
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem import MolFromSmiles, MolToSmiles, QED, Descriptors, Draw
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan
from src.utils.sascorer import calculateScore

from typing import Optional, List, Iterable
import pandas as pd
import numpy as np
from functools import partial
import scipy
import torch

from lets_plot import *
LetsPlot.setup_html()

# Novelty, Uniqueness, Internal Diversity

In [5]:
# required functions:


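# canonicalize, canonicalize_list and remove_duplicates are TAKEN FROM GUACAMOL: https://github.com/BenevolentAI/guacamol
# mapper, get_mol, fingerprint(s), average_agg_tanimoto and internal_diversity follow MOSES: https://github.com/molecularsets/moses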

def canonicalize(smiles: str, include_stereocenters=True) -> Optional[str]:
    """
    Return RDKit's canonical SMILES for `smiles`.
    The canonicalization algorithm is described in https://pubs.acs.org/doi/full/10.1021/acs.jcim.5b00543
    Args:
        smiles: SMILES string to canonicalize
        include_stereocenters: if True, stereochemical information is kept (isomeric SMILES)
    Returns:
        The canonical SMILES string, or None when RDKit cannot parse `smiles`.
    """
    mol = MolFromSmiles(smiles)
    if mol is None:
        # unparsable / invalid molecule
        return None
    return MolToSmiles(mol, isomericSmiles=include_stereocenters)
    
def canonicalize_list(smiles_list: Iterable[str], include_stereocenters=True) -> List[str]:
    """
    Canonicalize every SMILES in `smiles_list`, dropping invalid ones and repeats.
    Args:
        smiles_list: molecules as SMILES strings
        include_stereocenters: whether the canonical SMILES keep stereochemical information
    Returns:
        Distinct canonical SMILES in order of first appearance; entries that fail to parse are omitted.
    """
    valid_canonical = []
    for smiles in smiles_list:
        canonical = canonicalize(smiles, include_stereocenters)
        if canonical is not None:
            valid_canonical.append(canonical)

    return remove_duplicates(valid_canonical)
    
def remove_duplicates(list_with_duplicates):
    """
    Drop repeated elements while keeping the original order.
    Only the first occurrence of each element is kept; later repeats are skipped.
    Args:
        list_with_duplicates: iterable of hashable elements that may repeat
    Returns:
        A new list containing each distinct element once, in first-seen order.
    """
    # dict keys are unique and preserve insertion order; the first-inserted key object is retained
    return list(dict.fromkeys(list_with_duplicates))


def mapper(n_jobs):
    '''
    Return a map-like callable ``f(func, iterable, ...) -> list``.

    If n_jobs == 1, the builtin ``map`` is used and its result materialised as a list.
    If n_jobs is any other int, each call runs ``pool.map`` in a fresh
    ``multiprocessing.Pool(n_jobs)`` that is shut down when the call returns.
    Otherwise n_jobs is treated as a pool-like object and its ``map`` method is returned.
    '''
    if n_jobs == 1:
        def _mapper(*args, **kwargs):
            return list(map(*args, **kwargs))

        return _mapper
    if isinstance(n_jobs, int):
        # Local import: Pool is not imported anywhere else in this notebook.
        from multiprocessing import Pool

        def _mapper(*args, **kwargs):
            # A pool per call: no workers are spawned unless the mapper is actually used,
            # the mapper can be called more than once, and leaving the `with` block
            # terminates the workers even if map() raises.
            with Pool(n_jobs) as pool:
                return pool.map(*args, **kwargs)

        return _mapper
    return n_jobs.map

def get_mol(smiles_or_mol):
    '''
    Convert a SMILES string into an RDKit molecule; any non-string input is returned unchanged.
    Returns None for an empty string, a SMILES RDKit cannot parse, or a molecule
    whose sanitization raises ValueError.
    '''
    if not isinstance(smiles_or_mol, str):
        # already a molecule (or anything else) - pass through untouched
        return smiles_or_mol
    if not smiles_or_mol:
        return None
    mol = Chem.MolFromSmiles(smiles_or_mol)
    if mol is None:
        return None
    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        return None
    return mol

def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2,
                morgan__n=1024, *args, **kwargs):
    """
    Generate a binary fingerprint for a single molecule.

    Parameters:
        smiles_or_mol: SMILES string or RDKit molecule (loaded via get_mol)
        fp_type: fingerprint type, case-insensitive: 'maccs' (166 keys) or
            'morgan' (morgan__n bits, radius morgan__r)
        dtype: if not None, the returned array is cast to this dtype
        morgan__r: Morgan fingerprint radius
        morgan__n: Morgan fingerprint length in bits
        *args, **kwargs: forwarded to get_mol
            (NOTE(review): get_mol takes a single argument, so any extra value raises TypeError)
    Returns:
        1-D numpy array of fingerprint bits (uint8 unless `dtype` is given),
        or None if the molecule could not be loaded.
    Raises:
        ValueError: if fp_type is not 'maccs' or 'morgan'.
    """
    fp_type = fp_type.lower()
    molecule = get_mol(smiles_or_mol, *args, **kwargs)
    if molecule is None:
        return None
    if fp_type == 'maccs':
        # Local import: MACCSkeys is not imported by the notebook's import cell, so the
        # default fp_type would otherwise fail with NameError.
        from rdkit.Chem import MACCSkeys

        keys = MACCSkeys.GenMACCSKeys(molecule)
        keys = np.array(keys.GetOnBits())
        fp = np.zeros(166, dtype='uint8')
        if len(keys) != 0:
            # Key 0 is always zero and is dropped: keys 1..166 map to indices 0..165.
            fp[keys - 1] = 1
    elif fp_type == 'morgan':
        fp = np.asarray(Morgan(molecule, morgan__r, nBits=morgan__n),
                        dtype='uint8')
    else:
        raise ValueError("Unknown fingerprint type {}".format(fp_type))
    if dtype is not None:
        fp = fp.astype(dtype)
    return fp

def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args,
                 **kwargs):
    '''
    Compute fingerprints of a np.array/list/pd.Series of SMILES (or RDKit molecules)
    with n_jobs workers, e.g. fingerprints(smiles_mols_array, fp_type='morgan', n_jobs=10).
    Rows corresponding to invalid SMILES are filled with NaN.
    IMPORTANT: if there is at least one NaN row, the dtype of the result is float.
    Parameters:
        smiles_mols_array: list/array/pd.Series of SMILES or already computed
            RDKit molecules
        n_jobs: number of parallel workers (see mapper)
        already_unique: flag for performance reasons, if the SMILES array is big
            and already unique. It is forced to True if smiles_mols_array
            contains non-string items (e.g. RDKit molecules).
        *args, **kwargs: forwarded to fingerprint (fp_type, dtype, morgan__r, morgan__n, ...)
    Returns:
        2-D array with one fingerprint per input row (a CSR matrix if fingerprint
        returns scipy sparse matrices), in the same order as the input.
    '''
    # `import scipy` alone does not guarantee the sparse submodule is loaded.
    import scipy.sparse

    if isinstance(smiles_mols_array, pd.Series):
        smiles_mols_array = smiles_mols_array.values
    else:
        smiles_mols_array = np.asarray(smiles_mols_array)
    if not isinstance(smiles_mols_array[0], str):
        already_unique = True

    if not already_unique:
        # fingerprint each distinct SMILES once; inv_index restores the original rows
        smiles_mols_array, inv_index = np.unique(smiles_mols_array,
                                                 return_inverse=True)

    fps = mapper(n_jobs)(
        partial(fingerprint, *args, **kwargs), smiles_mols_array
    )

    # The first valid fingerprint fixes the row width and whether the output is sparse;
    # if every molecule is invalid, fall back to single-column NaN rows.
    first_fp = next((fp for fp in fps if fp is not None), None)
    length = 1 if first_fp is None else first_fp.shape[-1]
    # np.nan (not np.NaN, which was removed in NumPy 2.0)
    nan_row = np.full((1, length), np.nan)
    fps = [fp if fp is not None else nan_row for fp in fps]
    if scipy.sparse.issparse(first_fp):
        fps = scipy.sparse.vstack(fps).tocsr()
    else:
        fps = np.vstack(fps)
    if not already_unique:
        return fps[inv_index]
    return fps


def average_agg_tanimoto(stock_vecs, gen_vecs,
                         batch_size=5000, agg='max',
                         device='cpu', p=1):
    """
    Mean, over the rows of gen_vecs, of their aggregated Tanimoto similarity to stock_vecs.

    Pairwise Tanimoto similarities are computed block-wise with torch on `device`.
    For each generated vector they are reduced over all stock vectors with `agg`
    ('max' -> closest stock vector, 'mean' -> average), and the reduced values are averaged.
    NaN similarities (e.g. 0/0 for two all-zero vectors) are treated as 1.

    Parameters:
        stock_vecs: numpy array <n_vectors x dim>
        gen_vecs: numpy array <n_vectors' x dim>
        batch_size: rows per block on each side of the pairwise product
        agg: max or mean
        device: torch device for the matrix products
        p: power for averaging: (mean x^p)^(1/p)
    Returns:
        The mean of the per-generated-vector aggregated similarities.
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    per_gen = np.zeros(len(gen_vecs))
    n_seen = np.zeros(len(gen_vecs))
    for s0 in range(0, stock_vecs.shape[0], batch_size):
        stock = torch.tensor(stock_vecs[s0:s0 + batch_size]).to(device).float()
        stock_bits = stock.sum(1, keepdim=True)
        for g0 in range(0, gen_vecs.shape[0], batch_size):
            # columns of gen_t are the generated vectors of this block
            gen_t = torch.tensor(gen_vecs[g0:g0 + batch_size]).to(device).float().transpose(0, 1)
            window = slice(g0, g0 + gen_t.shape[1])
            overlap = torch.mm(stock, gen_t)
            union = stock_bits + gen_t.sum(0, keepdim=True) - overlap
            sim = (overlap / union).cpu().numpy()
            sim[np.isnan(sim)] = 1
            if p != 1:
                sim = sim**p
            if agg == 'max':
                per_gen[window] = np.maximum(per_gen[window], sim.max(0))
            else:
                per_gen[window] += sim.sum(0)
                n_seen[window] += sim.shape[0]
    if agg == 'mean':
        per_gen /= n_seen
    if p != 1:
        per_gen = (per_gen)**(1/p)
    return np.mean(per_gen)

def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
                       gen_fps=None, p=1):
    """
    Internal diversity of a set of molecules A:
        1 - 1/|A| sum_{x in A} (1/|A| sum_{y in A} tanimoto(x, y)^p)^(1/p)
    which for p=1 equals 1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y)).

    Parameters:
        gen: molecules (SMILES or RDKit mols); only fingerprinted when gen_fps is None
        n_jobs: workers used to compute fingerprints
        device: torch device for the similarity computation
        fp_type: fingerprint type passed to fingerprints
        gen_fps: precomputed fingerprints of `gen` (skips fingerprinting)
        p: power used in the averaging
    """
    fps = gen_fps if gen_fps is not None else fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
    mean_similarity = average_agg_tanimoto(fps, fps, agg='mean', device=device, p=p)
    return 1 - mean_similarity.mean()

def novelty(smiles, trainset):
    """
    Fraction of `smiles` whose canonical (non-stereo) form does not occur in `trainset`.

    Canonicalization drops invalid SMILES and duplicates from the numerator, but the
    ratio is divided by the raw len(smiles), so invalid or repeated samples lower the score.
    An empty `smiles` raises ZeroDivisionError.
    """
    print('calculating novelty')
    known = set(canonicalize_list(trainset, include_stereocenters=False))
    candidates = set(canonicalize_list(smiles, include_stereocenters=False))
    return len(candidates - known) / len(smiles)

def uniqueness(smiles):
    """
    Number of distinct valid canonical (non-stereo) SMILES divided by len(smiles).
    Invalid SMILES are not counted as unique; an empty input raises ZeroDivisionError.
    """
    print('calculating uniqueness')
    n_unique = len(set(canonicalize_list(smiles, include_stereocenters=False)))
    return n_unique / len(smiles)

def smiles_to_plogp(smile):
    """
    Penalized logP of a SMILES string:
    Crippen logP - SA score - (number of rings with more than 6 atoms).
    Assumes `smile` is valid; for an unparsable SMILES MolFromSmiles returns None
    and the descriptor calls below raise.
    """
    mol = MolFromSmiles(smile)
    score = MolLogP(mol) - calculateScore(mol)
    for atom_ring in mol.GetRingInfo().AtomRings():
        if len(atom_ring) <= 6:
            continue
        # one unit of penalty per large ring
        score -= 1
    return score

def unwanted_cycles(smile, max_cycle_size=6, min_cycle_size=5):
    """
    Count the rings of `smile` whose atom count lies outside [min_cycle_size, max_cycle_size].
    Assumes `smile` is a valid SMILES (MolFromSmiles returning None is not handled).
    """
    rings = MolFromSmiles(smile).GetRingInfo().AtomRings()
    return sum(1 for ring in rings if not (min_cycle_size <= len(ring) <= max_cycle_size))

In [3]:
# load generated molecules and the ZINC250k training SMILES (paths relative to the repo root)
generated = pd.read_csv(os.path.join(REPO_PATH, 'outputs', 'unconditional_generation.csv'))['smiles'].tolist()
train = pd.read_csv(os.path.join(REPO_PATH, 'data', 'zinc250k', 'zinc250k_smiles.txt'), header=None).iloc[:, 0].tolist()

In [6]:
# internal diversity of the generated set vs. the training set
gen_div = round(internal_diversity(generated), 2)
print(f'the internal diversity of unconditionally generated molecules is: {gen_div}')
train_div = round(internal_diversity(train), 2)
print(f'the internal diversity of training molecules is: {train_div}')

internal diversity of unconditionally generated molecules: 0.91
internal diversity of training molecules: 0.87


In [7]:
# novelty of the generated molecules w.r.t. the training set
novelty_score = novelty(generated, train)
print(f'the novelty of unconditionally generated molecules compared to the train set is: {round(novelty_score, 2)}')

calculating novelty
the novelty of unconditionally generated molecules compared to the train set is: 0.99


In [8]:
# uniqueness of the generated molecules
uniqueness_score = uniqueness(generated)
print(f'the uniqueness of unconditionally generated molecules is: {round(uniqueness_score, 2)}')

calculating uniqueness
the uniqueness of unconditionally generated molecules is: 0.99


# Molecular Property Distributions

In [9]:
# compute properties of generated molecules
# Each SMILES is parsed once; the RDKit Mol is reused by every descriptor below.
generated_mols = [MolFromSmiles(s) for s in generated]

# column name -> fn(smiles, mol); dict insertion order fixes the column order of props_df.
# p-logP and unwanted cycles use the notebook helpers, which take a SMILES string.
property_fns = {
    'SA': lambda s, m: calculateScore(m),
    'QED': lambda s, m: QED.qed(m),
    'logP': lambda s, m: MolLogP(m),
    'p-logP': lambda s, m: smiles_to_plogp(s),
    'molecular_weight': lambda s, m: Descriptors.ExactMolWt(m),
    'molar_refractivity': lambda s, m: Chem.Crippen.MolMR(m),
    # polar surface area as reported by QED.properties
    'topological_surface_area_mapping': lambda s, m: Chem.QED.properties(m).PSA,
    'h_bond_donors': lambda s, m: Descriptors.NumHDonors(m),
    'h_bond_acceptors': lambda s, m: Descriptors.NumHAcceptors(m),
    'rotatable_bonds': lambda s, m: Descriptors.NumRotatableBonds(m),
    'number_of_atoms': lambda s, m: m.GetNumAtoms(),
    'formal_charge': lambda s, m: Chem.rdmolops.GetFormalCharge(m),
    'heavy_atoms': lambda s, m: m.GetNumHeavyAtoms(),
    'num_of_rings': lambda s, m: Chem.rdMolDescriptors.CalcNumRings(m),
    'num_of_unwanted_cycles': lambda s, m: unwanted_cycles(s),
}

props_df = pd.DataFrame(generated, columns=['smiles'])
for column, fn in property_fns.items():
    props_df[column] = [fn(s, m) for s, m in zip(generated, generated_mols)]
    print(f'{column} done')

SA done
QED done
logP done
p-logP done
molecular_weight done
molar_refractivity done
topological_surface_area_mapping done
h_bond_donors done
h_bond_acceptors done
rotatable_bonds done
number_of_atoms done
formal_charge done
heavy_atoms done
num_of_rings done
num_of_unwanted_cycles done


In [11]:
# names of the computed property columns
props_df.keys()

Index(['smiles', 'SA', 'QED', 'logP', 'p-logP', 'molecular_weight',
       'molar_refractivity', 'topological_surface_area_mapping',
       'h_bond_donors', 'h_bond_acceptors', 'rotatable_bonds',
       'number_of_atoms', 'formal_charge', 'heavy_atoms', 'num_of_rings',
       'num_of_unwanted_cycles'],
      dtype='object')

In [13]:
from matplotlib import pyplot as plt  # NOTE(review): plt is not referenced in the visible cells
LetsPlot.setup_html()  # same lets-plot HTML setup call as in the first cell

def plot_discrete(props_df, property):
    """Bar chart of how often each value of the discrete column `property` occurs in props_df."""
    value_counts = props_df[property].value_counts().to_dict()
    counts_df = pd.DataFrame.from_dict(value_counts, orient='index', columns=['count'])
    # expose the counted values (the index) as a regular column for the x aesthetic
    counts_df[property] = counts_df.index
    bars = geom_bar(data=counts_df, mapping=aes(x=property, y='count'),
                    alpha=0.6, stat='identity', position='dodge')
    return ggplot() + bars + ggtitle(property) + xlab(property)

def plot_continuous(props_df, property):
    """Filled (orange) density curve of the continuous column `property` in props_df."""
    density = geom_density(data=props_df, mapping=aes(x=property),
                           alpha=0.6, fill='orange')
    return ggplot() + density + ggtitle(property) + xlab(property)


# One figure per property column (the 'smiles' column is skipped):
# bar charts for the listed discrete properties, density curves for everything else.
for p in props_df.columns[1:].tolist():
    plot_fn = plot_discrete if p in ['h_bond_donors', 'h_bond_acceptors', 'num_of_rings', 'num_of_unwanted_cycles', 'formal_charge'] else plot_continuous
    g = plot_fn(props_df, p)
    g.show()