## Generate new SMILES strings from an input SMILES with a JT-VAE

In [1]:
import torch
import torch.nn as nn

In [3]:
import sys
sys.path.append('/home/ziqiaoxu/Sample&Dock_DnD/fast_jtnn/')

from mol_tree import Vocab
from jtnn_vae import JTNNVAE

## Load Stock JTNN VAE Model
vocab_loc = '/home/ziqiaoxu/Sample&Dock_DnD/moses/vocab.txt'
# One vocabulary entry per line. Using `with` closes the file handle, which the
# bare open() inside the comprehension previously left to the garbage collector.
with open(vocab_loc) as vocab_file:
    vocab = [x.strip("\r\n ") for x in vocab_file]
vocab = Vocab(vocab)

hidden_size = 450
latent_size = 56  # 28 for each latent vector (tree and graph)
depthT = 20
depthG = 3
model_loc = '/home/ziqiaoxu/Sample&Dock_DnD/training/moses-h450z56/model.iter-400000'

jtnn = JTNNVAE(vocab, hidden_size, latent_size, depthT, depthG)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# map_location allows a checkpoint saved on GPU to be loaded on a CPU-only host.
# NOTE(review): the model itself is never moved with .to(device); tensors made
# through nnutils.create_var must end up on the same device as the weights — confirm.
jtnn.load_state_dict(torch.load(model_loc, map_location=device))
# The model is only used for inference below, so put it in evaluation mode.
jtnn.eval()

from torch.autograd import Variable
from nnutils import create_var
from rdkit import rdBase
## Disable RDKit error-level logging ('rdApp.error'); presumably this silences
## parse/sanitization errors for invalid molecules produced while decoding.
rdBase.DisableLog('rdApp.error')

def z_vecs(x_mean, x_log_var):
    """Draw one reparameterised sample ``mean + std * eps``.

    ``x_log_var`` is treated as a log-variance, so the standard deviation is
    ``exp(x_log_var / 2)``. The standard-normal noise ``eps`` has the same
    shape as ``x_mean`` and is passed through ``create_var`` (from nnutils)
    before it is used.
    """
    noise = create_var(torch.randn_like(x_mean))
    std = torch.exp(x_log_var / 2)
    return x_mean + std * noise

def smiles_gen(smiles, ndesigns):
    """Sample distinct decoded SMILES in the latent neighbourhood of ``smiles``.

    The input is encoded with the module-level ``jtnn`` model into tree and
    graph posterior means / log-variances. Each iteration draws a fresh latent
    sample via ``z_vecs`` and decodes it with ``prob_decode=False``.

    Args:
        smiles: input SMILES passed to ``jtnn.encode_from_smiles_xs``.
        ndesigns: number of sampling attempts. The result may be shorter,
            because duplicates and failed decodes (``None``) are dropped.

    Returns:
        List of unique decoded SMILES strings in first-seen order.
    """
    smiles_list = []
    seen = set()  # O(1) duplicate check; smiles_list keeps the order
    # Pure inference: building autograd graphs would only waste memory.
    with torch.no_grad():
        ## Encode the input (encode_from_smiles_xs is this project's altered
        ## version of the upstream encoder)
        x_tree, x_mol = jtnn.encode_from_smiles_xs(smiles)
        ## Compute mean and log var of each latent half. Following Mueller et al.
        tree_mean = jtnn.T_mean(x_tree)
        tree_log_var = -torch.abs(jtnn.T_var(x_tree))
        mol_mean = jtnn.G_mean(x_mol)
        mol_log_var = -torch.abs(jtnn.G_var(x_mol))

        for _ in range(ndesigns):
            ## generate latent vectors (stochastic)
            z_tree = z_vecs(tree_mean, tree_log_var)
            z_mol = z_vecs(mol_mean, mol_log_var)
            ## decode back to smiles
            smilesout = jtnn.decode(z_tree, z_mol, False)
            # The JTNN decoder can fail to assemble a molecule and yield None;
            # such results are unusable downstream, so drop them with duplicates.
            if smilesout is None or smilesout in seen:
                continue
            seen.add(smilesout)
            smiles_list.append(smilesout)
    return smiles_list

def smiles_to_sdfile(smiles_list, directory=None):
    """Write each SMILES to its own SD file ``<directory>/design/design_<i>.sd``.

    For every parsable SMILES, the molecule gets the following treatment
    before it is written as a single record:
    - it receives 2D coordinates;
    - it is named ``design_<i>``;
    - explicit hydrogens are added;
    - it is embedded in 3D with ETKDG.

    Args:
        smiles_list: sequence of SMILES strings. ``None`` entries and strings
            RDKit cannot parse are skipped. Their index is not reused, so file
            names still match positions in ``smiles_list``.
        directory: base output directory. It defaults to the module-level
            ``directory`` global, which is what the original version read.

    Raises:
        NameError: if ``directory`` is neither passed nor defined globally.
    """
    import os

    from rdkit import Chem
    from rdkit.Chem import AllChem

    if directory is None:
        # Backward compatibility with callers relying on a notebook global.
        directory = globals().get('directory')
        if directory is None:
            raise NameError(
                "name 'directory' is not defined; pass directory=...")
    out_dir = os.path.join(directory, 'design')
    # Make sure the output folder exists before opening writers in it.
    os.makedirs(out_dir, exist_ok=True)

    for i, x in enumerate(smiles_list):
        if x is None:
            continue
        m2 = Chem.MolFromSmiles(x)
        if m2 is None:
            # MolFromSmiles returns None for SMILES it cannot parse.
            continue
        name = 'design_' + str(i)
        AllChem.Compute2DCoords(m2)
        m2.SetProp("_Name", name)
        m3 = Chem.AddHs(m2)
        AllChem.EmbedMolecule(m3, AllChem.ETKDG())
        w = Chem.SDWriter(os.path.join(out_dir, name + '.sd'))
        try:
            w.write(m3)
        finally:
            # close() flushes and releases the file handle (previously leaked).
            w.close()


In [13]:
# Sample 20 latent neighbours of the query molecule; keep the unique decodes.
smi = smiles_gen(smiles='Brc1cnn(-c2ccc(-c3nc[nH]n3)cc2)c1', ndesigns=20)