Stage 1: Clean up the code and modularize it
1. Simplified VAE utilities
2. Molecule generation script (sample and dock)
3. Docking bash script
4. Score-processing script
5. Fingerprint script
6. Clustering script
7. Wrapper bash script <br>
Stage 2: Performance improvement
Possibly compile most of the functions, and rewrite the wrapper in C++ and compile it?

## VAE Core Utilities

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from mol_tree import Vocab, MolTree
from nnutils import create_var, flatten_tensor, avg_pool
from jtnn_enc import JTNNEncoder
from jtnn_dec import JTNNDecoder
from mpn import MPN
from jtmpn import JTMPN
from datautils import tensorize
from fast_jtnn import *

from chemutils import enum_assemble, set_atommap, copy_edit_mol, attach_mols
import rdkit
import rdkit.Chem as Chem
import copy, math

ImportError: No module named mol_tree

In [None]:
## Load Stock JTNN VAE Model
# Vocabulary file: one entry per line, surrounding CR/LF/space stripped.
# Read inside `with` so the file handle is closed once the list is built.
with open('../data/moses/vocab.txt') as vocab_file:
    vocab = [x.strip("\r\n ") for x in vocab_file]
vocab = Vocab(vocab)
nsample = 100  # not used in this section; presumably consumed later in the notebook
latent_size = 56
depthT = 20  # presumably tree message-passing depth -- TODO confirm against JTNNVAE
depthG = 3   # presumably graph message-passing depth -- TODO confirm against JTNNVAE
hidden_size = 450
# These sizes must match the checkpoint, otherwise load_state_dict will fail.
model_loc = '../fast_molvae/moses-h450z56/model.iter-400000'
jtnn = JTNNVAE(vocab, hidden_size, latent_size, depthT, depthG)
# map_location forces all checkpoint tensors onto the CPU.
jtnn.load_state_dict(torch.load(model_loc, map_location=torch.device('cpu')))

#### Copied from jtnn.nnutils (this redefinition shadows the create_var imported above)
def create_var(tensor, requires_grad=None):
    """Wrap ``tensor`` with the legacy autograd ``Variable`` constructor.

    Args:
        tensor: tensor to wrap.
        requires_grad: if given, forwarded to ``Variable``; if ``None`` the
            constructor's default is used.

    Returns:
        The result of ``Variable(tensor[, requires_grad=...])``. On PyTorch
        >= 0.4 this is a plain ``torch.Tensor``.
    """
    # `Variable` is not among the names imported at the top of this notebook,
    # so the original body raised NameError; import it where it is used.
    from torch.autograd import Variable

    if requires_grad is None:
        return Variable(tensor)
    return Variable(tensor, requires_grad=requires_grad)
       
#### Random latent vector sampling from a mean and a log-variance
def z_vecs(x_mean, x_log_var):
    """Draw one latent sample with the reparameterization trick.

    Computes ``x_mean + exp(x_log_var / 2) * eps``, where ``eps`` is fresh
    standard-normal noise created with ``torch.randn_like(x_mean)``.

    Args:
        x_mean: mean tensor.
        x_log_var: log-variance tensor; assumed broadcast-compatible with
            ``x_mean`` (TODO confirm at call sites).

    Returns:
        A tensor of sampled latent vectors.
    """
    # Noise is drawn first, so the RNG is consumed exactly as before.
    noise = create_var(torch.randn_like(x_mean))
    std = torch.exp(x_log_var / 2)  # log-variance -> standard deviation
    return x_mean + std * noise

def encode_from_smiles_xs(self, smiles):
    """Encode a single SMILES string into tree and graph embeddings.

    Written in method form: it reads ``self.vocab`` and calls ``self.encode``,
    so ``self`` is expected to be a JTNNVAE-like model. This definition is not
    bound to any class in this section. ``smiles_gen`` calls
    ``jtnn.encode_from_smiles_xs``, so presumably the model class provides
    (or is patched with) this method -- TODO confirm.

    Args:
        smiles: one SMILES string (batch of size 1).

    Returns:
        ``(tree_vecs, mol_vecs)``: the first and third outputs of
        ``self.encode``. The middle output is discarded.
    """
    trees = [MolTree(smiles)]
    # assm=False: only the encoder inputs are unpacked here; the first
    # returned item is unused.
    _, tree_inputs, graph_inputs = tensorize(trees, self.vocab, assm=False)
    tree_vecs, _, mol_vecs = self.encode(tree_inputs, graph_inputs)
    return tree_vecs, mol_vecs

#### Takes one SMILES and generates a number of designs around it
def smiles_gen(smiles, ndesigns, model=None):
    """Sample unique SMILES designs in the latent neighbourhood of a seed.

    The seed is encoded once. Then ``ndesigns`` times, fresh latent vectors
    are sampled from the mean / log-variance heads and decoded back to SMILES.
    Failed decodes and duplicates are dropped, so the result may hold fewer
    than ``ndesigns`` entries.

    Args:
        smiles: seed SMILES string.
        ndesigns: number of sample/decode attempts.
        model: VAE to use. Defaults to the module-level ``jtnn`` model.

    Returns:
        List of unique decoded SMILES strings, in first-seen order.
    """
    if model is None:
        model = jtnn

    # Pure inference: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        ## Encode the seed (modified encode function; see encode_from_smiles_xs)
        x_tree, x_mol = model.encode_from_smiles_xs(smiles)
        ## Map encodings to mean and log var. Following Mueller et al., the
        ## log variance is forced non-positive via -|.|.
        tree_mean = model.T_mean(x_tree)
        tree_log_var = -torch.abs(model.T_var(x_tree))
        mol_mean = model.G_mean(x_mol)
        mol_log_var = -torch.abs(model.G_var(x_mol))

        smiles_list = []
        seen = set()  # O(1) duplicate check; smiles_list keeps the order
        for _ in range(ndesigns):
            ## generate latent vectors (stochastic)
            z_tree = z_vecs(tree_mean, tree_log_var)
            z_mol = z_vecs(mol_mean, mol_log_var)
            ## decode back to SMILES (prob_decode=False)
            smilesout = model.decode(z_tree, z_mol, False)
            # Upstream fast_jtnn's decode returns None when it cannot build a
            # molecule; such results are not designs, so skip them.
            if smilesout is None or smilesout in seen:
                continue
            seen.add(smilesout)
            smiles_list.append(smilesout)
    return smiles_list

#### Generate and save one .sd file per SMILES for each design cycle
def smiles_to_sdfile(smiles_list, out_dir=None):
    """Write each SMILES to its own SD file with hydrogens and a 3D conformer.

    The entry at index ``i`` is written to ``<root>/design/design_<i>.sd``,
    and its ``_Name`` property is set to ``design_<i>``. Before writing,
    hydrogens are added and a conformer is embedded with ETKDG.

    Args:
        smiles_list: iterable of SMILES strings (e.g. from ``smiles_gen``).
        out_dir: output root directory. Defaults to the module-level
            ``directory`` global, which is not defined in this section --
            presumably set elsewhere in the notebook.

    Returns:
        List of file paths written, in input order. Entries RDKit cannot
        parse are skipped with a warning; indices are not reused, so file
        names still match input positions.
    """
    import os
    import warnings
    # AllChem is not imported at the top of this notebook (only rdkit.Chem is).
    from rdkit.Chem import AllChem

    root = directory if out_dir is None else out_dir
    design_dir = os.path.join(root, 'design')
    # Kept Python-2 compatible (no exist_ok=): create the folder only if missing.
    if not os.path.isdir(design_dir):
        os.makedirs(design_dir)

    written = []
    for i, smi in enumerate(smiles_list):
        name = 'design_' + str(i)
        output = os.path.join(design_dir, name + '.sd')
        # MolFromSmiles returns None for unparsable input; a None entry is
        # rejected up front instead of being handed to RDKit.
        mol = Chem.MolFromSmiles(smi) if smi is not None else None
        if mol is None:
            warnings.warn('Skipping %s: could not parse SMILES %r' % (name, smi))
            continue
        AllChem.Compute2DCoords(mol)
        mol.SetProp("_Name", name)
        mol_h = Chem.AddHs(mol)
        # EmbedMolecule returns -1 on failure; the molecule is still written.
        if AllChem.EmbedMolecule(mol_h, AllChem.ETKDG()) == -1:
            warnings.warn('3D embedding failed for %s; writing it anyway' % name)
        writer = Chem.SDWriter(output)
        try:
            writer.write(mol_h)
        finally:
            writer.close()  # flushes and releases the file handle
        written.append(output)
    return written