*start this notebook server on a machine with a GPU; optionally, use a modern rdkit version* 

In [None]:
from IPython.display import Image
import os
from desmiles.config import DATA_DIR

# Folder holding the static figure images that this notebook displays.
fig_dir = os.path.join(DATA_DIR, 'notebooks', 'Figures')

The high level idea behind DESMILES is that if we can learn to generate small molecules from a reduced small molecule representation that has been very successful in modelling structure activity relationships, then we will be able to generate useful molecules for a variety of practical tasks in drug discovery.  Furthermore, if this representation learning is encoding the chemical similarity, then we will be able to easily generate chemically similar molecules starting from any molecule.  The following two images from the DESMILES publication show the outline of the model and the representation of a slice of chemical space.  The notebook below demonstrates the basic functionality of the model.

In [None]:
# Load a figure from the DESMILES publication out of fig_dir and render it inline.
img = Image(filename=f"{fig_dir}/deep learn chem space (desmiles)__extended data fig 5__3__mcgillen__2019__.png")
display(img)

In [None]:
# Load a second figure from the DESMILES publication out of fig_dir and render it inline.
img = Image(filename=f"{fig_dir}/deep learn chem space (desmiles)__extended data fig 1__1__maragakis__2019__.png")
display(img)

# Demo of DESMILES

The next couple of cells define some high-level code for generating and displaying new molecules.

In [None]:
import sys
import os
import argparse
import tempfile
import multiprocessing
import functools
import subprocess
from pathlib import Path


import numpy as np
import pandas as pd
import scipy


from rdkit import Chem
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect

import desmiles
from desmiles.data import Vocab, FpSmilesList, DesmilesLoader, DataBunch
from desmiles.learner import desmiles_model_learner
from desmiles.models import Desmiles, RecurrentDESMILES
from desmiles.models import get_fp_to_embedding_model, get_embedded_fp_to_smiles_model
from desmiles.utils import load_old_pretrained_desmiles, load_pretrained_desmiles
from desmiles.utils import accuracy4
from desmiles.utils import smiles_idx_to_string
from desmiles.learner import OriginalFastaiOneCycleScheduler, Learner
from desmiles.decoding.astar import AstarTreeParallelHybrid as AstarTree
from desmiles.config import DATA_DIR

import torch

In [None]:
from functools import partial
from IPython.display import SVG, display, HTML
import rdkit.Chem
from rdkit.Chem import Draw,rdMolDescriptors,AllChem, rdDepictor
from rdkit.Chem.Draw import IPythonConsole, rdMolDraw2D
import matplotlib.pyplot as plt


# One-row HTML table with two content cells separated by a "-->" arrow cell.
table_from_to_template='<table border=1> <tr> <td>{}</td> <td> --> </td> <td>{}</td> </tr> </table>'
# One-row HTML table: the first placeholder is the table width, the next four are the cell contents.
table_4_template='<table border=1 width={}> <tr> <td>{}</td> <td>{}</td> <td>{}</td> <td>{}</td> </tr> </table>'


def canon_smiles(x):
    """Return the canonical SMILES for `x`, or '' if it cannot be canonicalized.

    Delegates to ``Chem.CanonSmiles(x, useChiral=True)``. Any ordinary error
    raised by that call is treated as "not canonicalizable" and reported as an
    empty string, so callers can filter failures with a truthiness check.
    """
    try:
        return Chem.CanonSmiles(x, useChiral=True)
    except Exception:
        # Best-effort on purpose, but `except Exception` (rather than a bare
        # `except:`) lets KeyboardInterrupt and SystemExit propagate, so a
        # long-running loop over many SMILES can still be interrupted.
        return ''


def get_itos_8k():
    """Load the index-to-string vocabulary array stored at `pretrained/itos.npy` under DATA_DIR."""
    itos_path = os.path.join(DATA_DIR, 'pretrained', "itos.npy")
    return np.load(itos_path)


def vec_to_smiles(idx_vec_inp, itos):
    """Return a SMILES string from an index vector (deals with reversal).

    Parameters
    ----------
    idx_vec_inp : sequence of int
        Encoded molecule. An optional leading 3 is skipped; the next entry is
        the direction marker (1 = forward, 2 = backward).
    itos : indexable
        Maps a token index to its string.

    Returns
    -------
    str
        The token strings joined in reading order. Indices 0-3 are never
        emitted. An empty vector (or one holding only the leading 3) decodes
        to '' instead of raising IndexError.

    If the direction marker is neither 1 nor 2, a warning is printed and the
    vector is decoded as if it were in forward order.
    """
    idx_vec = idx_vec_inp
    # Newer encodings prepend token 3; drop it so the direction marker comes first.
    if len(idx_vec) > 0 and idx_vec[0] == 3:
        idx_vec = idx_vec[1:]
    if len(idx_vec) == 0:
        return ''
    direction = idx_vec[0]
    if direction == 2:  # SMILES string is in bwd direction
        ordered = idx_vec[::-1]
    else:
        if direction != 1:  # don't know how to deal with it---do your best
            print("decoder received an invalid start to the SMILES", idx_vec)
        ordered = idx_vec
    return ''.join(itos[x] for x in ordered if x > 3)

    
def smiles_to_fingerprint(smiles_str, sparse=False, as_tensor=True):
    """Return the DESMILES fingerprint of a SMILES string.

    The fingerprint is the concatenation of two chirality-aware Morgan bit
    vectors, radius 2 followed by radius 3 (each of RDKit's default length).

    Parameters
    ----------
    smiles_str : str
        Molecule to fingerprint.
    sparse : bool
        If true, return a ``scipy.sparse.csr_matrix``; takes precedence over
        `as_tensor`.
    as_tensor : bool
        If true (and `sparse` is false), return a float32 torch tensor placed
        on the GPU when one is available, otherwise on the CPU.

    Returns
    -------
    csr_matrix, torch.Tensor, or a uint8 ``numpy.ndarray`` (when both flags
    are false).

    Raises
    ------
    ValueError
        If RDKit cannot parse `smiles_str` (``Chem.MolFromSmiles`` returned
        None). This replaces the opaque error the fingerprint call would
        otherwise raise on a None molecule.
    """
    rdmol = Chem.MolFromSmiles(smiles_str)
    if rdmol is None:
        raise ValueError(f"could not parse SMILES: {smiles_str!r}")
    fp = np.concatenate([
        np.asarray(GetMorganFingerprintAsBitVect(rdmol, radius, useChirality=True), dtype=np.uint8)
        for radius in (2, 3)])
    if sparse:
        return scipy.sparse.csr_matrix(fp)
    if as_tensor:
        # Imported here so the NumPy / sparse paths do not require torch.
        import torch
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return torch.tensor(fp.astype(np.float32)).to(device)
    return fp


def barcode_fp(fp, width=8, height=0.5):
    """Draw a fingerprint as a one-row black-and-white "barcode" image.

    Parameters
    ----------
    fp : numpy.ndarray or torch.Tensor
        Fingerprint values; reshaped to a single row before plotting.
    width, height : float
        Size of the axes in figure-fraction units, as passed to
        ``Figure.add_axes``.

    Returns
    -------
    The image object returned by ``Axes.imshow``. A new figure is created on
    every call.
    """
    # The notebook passes the output of smiles_to_fingerprint, which may be a
    # CUDA tensor. A tensor on the GPU cannot be converted to NumPy directly,
    # so copy it to host memory before handing it to matplotlib.
    if hasattr(fp, 'detach'):
        fp = fp.detach().cpu().numpy()
    fig = plt.figure()
    ax = fig.add_axes([0, 0, width, height], xticks=[], yticks=[])
    return ax.imshow(fp.reshape((1, -1)), aspect='auto',
                     cmap=plt.cm.binary, interpolation='nearest')
    
    
def moltosvg(rdkit_mol, size_x=450, size_y=150):
    """Render an RDKit molecule as an SVG string of size_x by size_y pixels.

    Drawing is attempted with kekulization first and retried without it when
    preparation raises ValueError. The 'svg:' prefixes are stripped from the
    returned markup.
    """
    # Probe the first atom's valence; if RDKit raises, refresh the property
    # cache in non-strict mode before drawing.
    try:
        rdkit_mol.GetAtomWithIdx(0).GetExplicitValence()
    except RuntimeError:
        rdkit_mol.UpdatePropertyCache(False)

    def _prepare(kekulize):
        return rdMolDraw2D.PrepareMolForDrawing(rdkit_mol, kekulize=kekulize)

    try:
        drawable = _prepare(True)
    except ValueError:  # <- can happen on a kekulization failure
        drawable = _prepare(False)

    canvas = rdMolDraw2D.MolDraw2DSVG(size_x, size_y)
    canvas.DrawMolecule(drawable)
    canvas.FinishDrawing()
    # It seems that the svg renderer used doesn't quite hit the spec.
    # Stripping the namespace prefix makes it work in the notebook, although
    # the underlying issue should be resolved at the generation step.
    return canvas.GetDrawingText().replace('svg:', '')


def displayTable4SMILES(smiles, size_x=225, size_y=150, width=980):
    """Display exactly four SMILES as molecule drawings in a one-row HTML table.

    Each drawing is size_x by size_y; `width` is the table's width attribute.
    """
    assert len(smiles) == 4
    cells = [moltosvg(Chem.MolFromSmiles(sm), size_x=size_x, size_y=size_y)
             for sm in smiles]
    markup = table_4_template.format(width, *cells)
    display(HTML(markup))
    
    
def procSMILES(sm):
    """Parse a SMILES string into an RDKit molecule and compute its 2D coordinates."""
    mol = Chem.MolFromSmiles(sm)
    AllChem.Compute2DCoords(mol)
    return mol


def imageOfMols(smiles_list, molsPerRow=4, subImgSize=(240,200), labels=None):
    """Return an SVG grid image of the molecules in `smiles_list`.

    `labels`, when given, are converted with str() and used as the legends.
    """
    molecules = [procSMILES(sm) for sm in smiles_list]
    legends = None if labels is None else [str(lbl) for lbl in labels]
    return Draw.MolsToGridImage(molecules, molsPerRow=molsPerRow, subImgSize=subImgSize,
                                useSVG=True, legends=legends)


def imageOfMolsLabels(smiles_labels_list, molsPerRow=5, subImgSize=(200,200)):
    """Return an SVG grid image from (SMILES, label) pairs.

    Element 0 of each pair is drawn; element 1 is converted with str() and
    used as the legend.
    """
    molecules = [procSMILES(pair[0]) for pair in smiles_labels_list]
    legends = [str(pair[1]) for pair in smiles_labels_list]
    return Draw.MolsToGridImage(molecules, molsPerRow=molsPerRow, subImgSize=subImgSize,
                                useSVG=True, legends=legends)


from itertools import zip_longest
def grouper(iterable, n, fillvalue=None):
    """Collect items into consecutive tuples of length n, padding the last one with fillvalue."""
    shared = iter(iterable)
    # n references to the same iterator, so each output tuple takes n consecutive items.
    return zip_longest(*(shared for _ in range(n)), fillvalue=fillvalue)


In [None]:
#from astar_purge import get_astar_tree, isLeafNode
def get_most_probable_smiles(emb_fp, emfp_to_smiles, max_branches=50, num_expand=2000):
    """Return the first SMILES string produced by the A* search for `emb_fp`.

    The string is returned as decoded; it is not checked for chemical validity.
    Runs with gradient tracking disabled.
    """
    with torch.no_grad():
        search = AstarTree(emb_fp, emfp_to_smiles,
                           max_branches=max_branches, num_expand=num_expand)
        _, best_idx = next(search)
        return smiles_idx_to_string(best_idx)



def get_upto_n_most_probable_valid_smiles(n, emb_fp, emfp_to_smiles, max_branches=50, num_expand=2000):
    """Yield up to `n` decoded SMILES from an A* search, with "" for invalid ones.

    One item is yielded per search step: the decoded SMILES when
    ``Chem.MolFromSmiles`` accepts it, otherwise an empty string. Fewer than
    `n` items are yielded if the search runs out of candidates first.

    Parameters
    ----------
    n : int
        Maximum number of search steps (and therefore of yielded items).
    emb_fp, emfp_to_smiles
        Fingerprint and model, passed straight to ``AstarTree``.
    max_branches, num_expand : int
        Search limits, passed straight to ``AstarTree``.
    """
    with torch.no_grad():
        astar_tree = AstarTree(emb_fp,
                               emfp_to_smiles,
                               max_branches=max_branches, num_expand=num_expand)
    for _ in range(n):
        # no_grad is re-entered per step rather than held across `yield`, so
        # grad mode is not left disabled in the consumer's code while this
        # generator is suspended.
        with torch.no_grad():
            try:
                candidate = next(astar_tree)
            except StopIteration:
                # Search exhausted: finish quietly. A StopIteration escaping a
                # generator body would be re-raised as RuntimeError (PEP 479).
                return
            smile = smiles_idx_to_string(candidate[1])
        yield smile if Chem.MolFromSmiles(smile) else ""

                
def get_first_n_most_probable_valid_smiles(n, emb_fp, emfp_to_smiles,
                                           max_branches=50, max_search=200,
                                           num_expand=2000,
                                           verbose=False):
    """Yield A* candidates until `n` distinct valid molecules have been seen.

    One item is yielded per search step:

    * the canonical SMILES, when the candidate parses with RDKit (a molecule
      seen before is yielded again but counted only once);
    * an empty string otherwise.

    Callers that want only distinct valid molecules should filter the output
    (e.g. with ``dedup_clean``).

    Parameters
    ----------
    n : int
        Number of distinct valid canonical SMILES to collect before stopping.
    emb_fp, emfp_to_smiles
        Fingerprint and model, passed straight to ``AstarTree``.
    max_branches, num_expand : int
        Search limits, passed straight to ``AstarTree``.
    max_search : int
        Upper bound on the number of search steps.
    verbose : bool
        If true, print the score and raw SMILES of every candidate.

    The generator also ends early, without error, if the search runs out of
    candidates.
    """
    results = set()
    with torch.no_grad():
        astar_tree = AstarTree(emb_fp,
                               emfp_to_smiles,
                               max_branches=max_branches, num_expand=num_expand)
    for _ in range(max_search):
        if len(results) >= n:
            break
        # no_grad is re-entered per step rather than held across `yield`, so
        # grad mode is not left disabled in the consumer's code while this
        # generator is suspended.
        with torch.no_grad():
            try:
                score, smile_idx = next(astar_tree)
            except StopIteration:
                # Search exhausted: finish quietly. A StopIteration escaping a
                # generator body would be re-raised as RuntimeError (PEP 479).
                return
            smile = smiles_idx_to_string(smile_idx)
        if verbose:
            print(score, smile)
        if (smile is not None) and Chem.MolFromSmiles(smile):
            smile = Chem.CanonSmiles(smile)
            results.add(smile)
            yield smile
        else:
            yield ""

                
def set_of_upto_n_most_probable_valid_smiles(n, emb_fp, embfp_to_smiles, max_branches=50, num_expand=2000):
    """Return the set of canonical SMILES among the first `n` A* candidates.

    Empty-string placeholders produced for invalid candidates are dropped.
    """
    candidates = list(get_upto_n_most_probable_valid_smiles(n, emb_fp, embfp_to_smiles,
                                                            max_branches, num_expand))
    return {Chem.CanonSmiles(smi) for smi in candidates if smi != ""}

    
def get_most_probable_valid_smiles(emb_fp, emfp_to_smiles, max_branches=50, num_expand=2000,
                                   max_attempts=5):
    """Return the first A* candidate that RDKit can parse, trying a few candidates.

    Parameters
    ----------
    emb_fp, emfp_to_smiles
        Fingerprint and model, passed straight to ``AstarTree``.
    max_branches, num_expand : int
        Search limits, passed straight to ``AstarTree``.
    max_attempts : int
        Maximum number of candidates to examine (previously hard-coded to 5).

    Returns
    -------
    str
        The first candidate accepted by ``Chem.MolFromSmiles``. If none of
        the examined candidates is valid, the LAST examined candidate is
        returned even though it is invalid, so callers that need a valid
        molecule must check the result. If the search yields no candidate at
        all, "" is returned.
    """
    smile = ""
    with torch.no_grad():
        astar_tree = AstarTree(emb_fp, emfp_to_smiles,
                               max_branches=max_branches,
                               num_expand=num_expand)
        for _ in range(max_attempts):
            try:
                candidate = next(astar_tree)
            except StopIteration:
                # Search exhausted early: fall back to what we have so far
                # instead of leaking StopIteration to the caller.
                break
            smile = smiles_idx_to_string(candidate[1])
            if Chem.MolFromSmiles(smile):
                return smile
    return smile
    
    
def dedup(seq):
    """Return the items of `seq` in first-seen order, dropping later duplicates."""
    seen = set()
    unique = []
    for item in seq:
        if item in seen:
            continue
        seen.add(item)
        unique.append(item)
    return unique


def clean(seq):
    """Return a list of the truthy items of `seq`, in order."""
    return list(filter(None, seq))


def dedup_clean(seq):
    """Return the truthy items of `seq` in first-seen order, without duplicates."""
    # dict keys keep insertion order, so this drops repeats while preserving order.
    return list(dict.fromkeys(item for item in seq if item))


def to_numpy_int(f1):
    """Convert a torch tensor to an int32 NumPy array; return anything else unchanged."""
    if not isinstance(f1, torch.Tensor):
        return f1
    return f1.cpu().numpy().astype(np.int32)


def tanimoto(f1, f2):
    """Return the Tanimoto similarity of two fingerprints.

    Computed as (number of positions set in both) / (number of positions set
    in either). Each argument may be a torch tensor (on any device) or
    anything ``numpy.asarray`` accepts, of any numeric dtype; every non-zero
    entry counts as a set bit. Operating on booleans instead of applying
    ``&``/``|`` to the raw values makes float NumPy arrays work and keeps
    non-binary values (e.g. a 2 from adding fingerprints) from distorting
    the ratio.

    Returns a NumPy float. If neither fingerprint has any bit set the result
    is nan (0/0), as before.
    """
    def _as_bits(f):
        if isinstance(f, torch.Tensor):
            f = f.detach().cpu().numpy()
        return np.asarray(f) != 0

    bits1, bits2 = _as_bits(f1), _as_bits(f2)
    return np.sum(bits1 & bits2) / np.sum(bits1 | bits2)


def displayTable4labels(labels, width=980):
    """Display four labels as the cells of a one-row HTML table of the given width."""
    markup = table_4_template.format(width, *labels)
    display(HTML(markup))



## <div id="load_data">Load some data </div>

First load the encoded version of the training set and the validation set

In [None]:
# Index-to-string vocabulary used to decode the encoded SMILES tables.
itos = get_itos_8k()
# Decoder with the vocabulary pre-bound, so callers pass only an index vector.
vec_to_smiles_8k = partial(vec_to_smiles, itos=itos)

def random_smiles_enc8k(enc_table, n=2):
    """Return `n` random rows of `enc_table` as (SMILES, positive indices) pairs.

    Rows are drawn with replacement. The second element of each pair is the
    row restricted to its entries greater than zero.
    """
    picks = np.random.randint(0, len(enc_table), n)
    samples = []
    for i in picks:
        row = enc_table[i]
        samples.append((vec_to_smiles_8k(row), row[row > 0]))
    return samples

def random_smiles(enc_table, n=10):
    """Return the SMILES strings of `n` random rows of `enc_table`."""
    picked = []
    for smi, _encoded in random_smiles_enc8k(enc_table, n):
        picked.append(smi)
    return picked

In [None]:
# Encoded SMILES tables for the training set and the "val2" validation set.
training_smiles_enc8k = np.load(os.path.join(DATA_DIR, 'pretrained', 'training.enc8000.npy'))
val2_smiles_enc8k = np.load(os.path.join(DATA_DIR, 'pretrained', 'val2.enc8000.npy'))
# The "SMILES" column of the validation-subset CSV, as a plain Python list.
v2samples = list(pd.read_csv(os.path.join(DATA_DIR, 'notebooks', "fast_val2_molecules.csv"))["SMILES"])

In [None]:
# Print five randomly drawn training SMILES, one per line.
print('\n'.join(random_smiles(training_smiles_enc8k, 5)))

In [None]:
# Number of rows in the encoded training table.
len(training_smiles_enc8k)

Hit Ctrl-Enter on the following cell several times to explore random samples from the training set.  
Change "training_smiles_enc8k" to "val2_smiles_enc8k" to explore the validation set.

In [None]:
import rdkit.Chem.Descriptors
# Draw four random molecules, print a row of descriptors for each, and plot each fingerprint.
smiles = random_smiles(training_smiles_enc8k, 4) # pick training_smiles... or val2_smiles...
displayTable4SMILES(smiles)
for x in smiles: 
    m = Chem.MolFromSmiles(x)
    # Columns: atom count, atom count with onlyExplicit=False, aromatic rings, TPSA, molecular weight, logP.
    print( m.GetNumAtoms(), m.GetNumAtoms(onlyExplicit=False), 
          Chem.Descriptors.NumAromaticRings(m), 
          np.round(Chem.Descriptors.TPSA(m), 2),
          np.round(Chem.Descriptors.MolWt(m), 2), 
          np.round(Chem.Descriptors.MolLogP(m), 2) )
# One barcode figure per molecule.
for x in smiles: barcode_fp(smiles_to_fingerprint(x), height=0.3)

## <div id="subSMILES"> subSMILES and DESMILES </div>

Each molecule is made out of up to 26 subSMILES (byte-pair encoded symbols).
These are represented as integer numbers in the encoded tables, and 'itos' converts them to strings.
The following example decomposes a random molecule from the training set.

In [None]:
# Shape of the encoded training table.
training_smiles_enc8k.shape

In [None]:
# Pick one random training molecule as a (SMILES, encoded indices) pair.
rsamples = random_smiles_enc8k(training_smiles_enc8k, 1)
print(rsamples)
# The legend under the drawing is the string form of the encoded index array.
display(imageOfMolsLabels(rsamples, subImgSize=(600,300), molsPerRow=1))
barcode_fp(smiles_to_fingerprint(rsamples[0][0]));

In [None]:
# Show the SMILES next to its decomposition into (index, subSMILES string) pairs.
rsamples[0][0], [(x, itos[x]) for s, e in rsamples for x in e]

### Load a pretrained model

Let's load a model that was trained on molecules from both the validation and the training set. This would be the model to use in future applications of DESMILES.  The model parameters came as the result of the hyperoptimization discussed in the publication.

In [None]:
# Loader function used for this pretrained checkpoint.
get_model = desmiles.utils.load_old_pretrained_desmiles
fp_to_smiles_5layer = get_model(os.path.join(DATA_DIR, 'pretrained', 'train_val1_val2','model_2000_400_2000_5'))
# Wrap the loaded model in RecurrentDESMILES; this wrapper is what the A* decoding helpers receive.
rmodel = RecurrentDESMILES(fp_to_smiles_5layer)

In [None]:
# Echo the wrapped model object so its repr is shown as the cell output.
rmodel

## <div id="first_tests">First tests: create a fragment from its fingerprint</div>

The simplest application of the model is to create a molecule from its fingerprint.
Here is the example of a fragment that is outside of the original library.

In [None]:
# Fragment used as the running example in the following cells, and its fingerprint.
fragment = 'Nc1ncc(C(F)(F)F)cc1F'  
display(imageOfMols([fragment], labels=[fragment]))
fp = smiles_to_fingerprint(fragment,  as_tensor=True)
barcode_fp(fp);

In [None]:
# Inspect the fingerprint tensor: its values, size, set of distinct values, and sum of entries.
fp, fp.size(), set(fp.cpu().numpy()), sum(fp)

We can invert this fingerprint to generate a small molecule.  

In [None]:
%%time
# Decode the first A* candidate for the fragment's fingerprint (a single SMILES string).
smiles = get_most_probable_smiles(fp, rmodel)

In [None]:
# Draw the decoded molecule.
imageOfMols([smiles])

## <div id="collection"> Generate a collection of fragments with Astar </div>

Often one wants to generate a whole bunch of variations of a single molecule.

The example below shows variations of this simple fragment outside the training/validation set, 
together with a measure of the fingerprint similarity (higher is better; 1.0 is perfect match of fingerprints.)

Sometimes the model will go through a number of invalid intermediate attempts, before finding the next example.
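The function get_first_n_most_probable_valid_smiles yields an empty string for every invalid attempt; dedup_clean then removes these (and any repeats), so only the valid molecules remain.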
The parameter max_branches limits the search; for very complicated molecules, 
the search might be exhausted before the optimal molecules get returned, 
so higher values of max_branches might get "better" molecules 
but the search will take up more GPU memory.


In [None]:
%%time
%%capture --no-stdout --no-display
# Search until 8 distinct valid molecules are found; dedup_clean drops the "" placeholders and repeats.
smiles = dedup_clean([Chem.CanonSmiles(x) 
                       for x in get_first_n_most_probable_valid_smiles(8, fp, rmodel, 
                                                                      max_branches=100)])

In [None]:
# Label each generated molecule with its Tanimoto similarity to the query fingerprint, rounded to 2 decimals.
labels = np.round([tanimoto(fp, smiles_to_fingerprint(x)) for x in smiles], 2)
for x in smiles: barcode_fp(smiles_to_fingerprint(x), height=0.2)
display(imageOfMols(smiles, molsPerRow=4, subImgSize=(240, 200), labels=labels));

### Check out some more complex and random molecules (subset of validation 2)

Below are some complicated molecules and their top 3 variants according to DESMILES.  For this demonstration we picked as inputs a subset of the validation molecules that decoded rather quickly.  

Please click Ctrl-Enter on the next cell a couple of times until you see some interesting molecules.

In [None]:
# Pick four random molecules from the validation subset (note: this rebinds rsamples to a list of SMILES strings).
rsamples = list(np.random.choice(v2samples, 4))
print(rsamples)
displayTable4SMILES(rsamples)

In [None]:
%%time
%%capture --no-stdout --no-display
# For each input molecule, decode up to four distinct valid variants and show them with their similarities.
for s in rsamples:
    newfp = smiles_to_fingerprint(s)
    newsmiles = dedup_clean([Chem.CanonSmiles(x) for x in 
                         get_first_n_most_probable_valid_smiles(4,newfp,rmodel, max_branches=200)])
    # Pad with empty SMILES and truncate so the table always has exactly four cells.
    newsmiles.extend(["", "", "", ""])
    newsmiles = newsmiles[:4]
    displayTable4SMILES(newsmiles)
    # The comprehension's `s` is local to the comprehension; it does not rebind the loop variable.
    labels = [np.round(tanimoto(newfp, smiles_to_fingerprint(s)), 2) for s in newsmiles]
    displayTable4labels(labels)

### <div id="Perturbations"> Perturbations of a molecule</div>

Let's add a little noise to the fingerprints of our little fragment by turning some random bits on.

In [None]:
%%time
# Decode the fragment's fingerprint after switching on extra random bits.
smiles = []
extra_fp_on = [30, 40, 50, 60]
torch.random.manual_seed(314)  # fixed seed so the perturbations are reproducible
for num_bits_on in extra_fp_on:
    # Indices are drawn with replacement, so fewer than num_bits_on distinct bits may be added.
    random_indices = torch.randint(0, fp.size().numel(), torch.Size([num_bits_on]))
    fp_add = torch.zeros_like(fp)
    fp_add[random_indices] = 1
    # Clamp so a random index that hits an already-set bit leaves it at 1
    # instead of producing a 2, keeping the perturbed fingerprint binary.
    fp_pert = torch.clamp(fp + fp_add, max=1)
    s = get_most_probable_valid_smiles(fp_pert, rmodel, max_branches=100)
    smiles.append(s)
display(imageOfMols(smiles, labels=extra_fp_on))

## <div id="intro_to_algebra">Intro to algebra of molecules</div>
We can "add" two molecules by mixing their fingerprints 
(or by mixing their embeddings, or other internal layers).
Even though the model is highly nonlinear, the "addition" is often intuitive, 
for example when the model is able to combine the fingerprints 
and create a molecule that matches both inputs.

In [None]:
# Second fragment to combine with the first one.
fragment2 = "c1cccnc1N2CCCCC2"

imageOfMols([fragment, fragment2])

In [None]:
%%time
%%capture --no-stdout --no-display
# "A + B": the union of the on-bits of the two fragments' fingerprints.
fps = [smiles_to_fingerprint(x, as_tensor=False) for x in [fragment, fragment2]]
ftarget = (fps[0] | fps[1])
# Choose the device the same way smiles_to_fingerprint does instead of
# hard-coding .cuda(), so this tensor matches the fingerprints it is compared with.
ftarget = torch.Tensor(ftarget).to('cuda' if torch.cuda.is_available() else 'cpu')
smiles = dedup_clean(get_first_n_most_probable_valid_smiles(6, ftarget, 
                                                           rmodel))
# Similarity of every generated molecule to the combined target fingerprint.
labels = [tanimoto(smiles_to_fingerprint(x), ftarget) for x in smiles]
labels = ["A", "B"] + [str(x) for x in np.round(labels, 2)]
display(imageOfMols([fragment, fragment2, *smiles], labels=labels))

Let's get rid of the tri-fluoromethyl group

In [None]:
# Fragment whose fingerprint bits will be removed from the combination of A and B.
fragment3 = "c1ccccc1C(F)(F)F"
imageOfMols([fragment, fragment2, fragment3], labels=["A", "+  B", "-  C"])

In [None]:
%%time
%%capture --no-stdout --no-display
# "A + B - C": the bits set in A or B, with every bit that is set in C switched off.
fps = [smiles_to_fingerprint(x, as_tensor=False) for x in [fragment, fragment2, fragment3]]
# The fingerprints are uint8 arrays, so a plain subtraction wraps 0 - 1 around to
# 255, which a clip to [0, 1] then turns into 1. That would switch ON the bits
# found only in C. Masking with (1 - C) stays within {0, 1} and switches them off.
f3target = (fps[0] | fps[1]) & (1 - fps[2])
# Choose the device the same way smiles_to_fingerprint does instead of hard-coding .cuda().
f3target = torch.Tensor(f3target).to('cuda' if torch.cuda.is_available() else 'cpu')
smiles = dedup_clean(get_first_n_most_probable_valid_smiles(5, f3target, rmodel))
# Similarity of every generated molecule to the target fingerprint.
labels = [tanimoto(smiles_to_fingerprint(x), f3target) for x in smiles]
labels = ["A", "+ B", "- C"] + [str(x) for x in np.round(labels, 2)]
display(imageOfMols([fragment, fragment2, fragment3, *smiles], labels=labels))

## <div id="fine_tuning">Fine tuning applications</div>


A promising way to generate new molecules is to finetune the DESMILES model to improve the outputs using training inputs from matched pairs.  This somewhat more complicated application of the model was described in the publication, and was used to generate the figures below.

In [None]:
# Load a fine-tuning figure from the DESMILES publication out of fig_dir and render it inline.
img = Image(filename=f"{fig_dir}/deep learn chem space (desmiles)__extended data fig 5__2__nisonoff__2019__.png")
display(img)

In [None]:
# Load a second fine-tuning figure from the DESMILES publication out of fig_dir and render it inline.
img = Image(filename=f"{fig_dir}/deep learn chem space (desmiles)__extended data fig 4__2__nisonoff__2019__.png")
display(img)