# RNN-based SMILES generator for novel molecular design

In this demo, we construct a recurrent neural network (RNN)-based generator for the design of novel caspase-6 inhibitors. In our work, the RNN was first trained as a chemical language model on a dataset of RDKit canonical SMILES for 2.4 million molecules from the PubChem database (https://pubchem.ncbi.nlm.nih.gov); the molecules were restricted to between 10 and 100 heavy atoms, and the maximum SMILES length was 140. A dataset of 433 active caspase-6 inhibitors was then used to fine-tune the pre-trained RNN model. Because GitHub limits individual files to 100 MB, this demo provides a reduced SMILES dataset of 800 thousand molecules for pre-training the RNN generator.

## Import modules

In [None]:
import sys

In [None]:
sys.path.append('./release/') 

In [None]:
from model_construct_test import RNN
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set(style="darkgrid") # style: darkgrid, whitegrid, dark, white, ticks

# Pre-training the RNN generator

In this demo, the RNN is first trained as a chemical language model on a dataset of RDKit canonical SMILES for 800 thousand molecules from the PubChem database (https://pubchem.ncbi.nlm.nih.gov). The last RNN model and the RNN model that reached the maximum valid value (max_valid) are saved for the later transfer-learning step.
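A chemical language model treats each SMILES string as a sequence of tokens drawn from a fixed vocabulary. The sketch below illustrates one common way to tokenize SMILES with a regular expression and map the tokens to integer indices. It is an illustration only: the helper names (`tokenize`, `build_vocabulary`, `encode`) and the special tokens are assumptions, and the tokenizer and vocabulary file actually used in `./release/` may differ.

```python
import re

# Multi-character tokens (bracket atoms, Br, Cl, two-digit ring closures such as %10)
# are listed before the single-character fallback so they are matched first.
SMILES_TOKEN = re.compile(r"\[[^\]]+\]|Br|Cl|%\d{2}|.")


def tokenize(smiles):
    """Split a SMILES string into a list of tokens."""
    return SMILES_TOKEN.findall(smiles)


def build_vocabulary(smiles_list):
    """Collect every token seen in the data, plus hypothetical special tokens."""
    tokens = set()
    for smi in smiles_list:
        tokens.update(tokenize(smi))
    # "<pad>", "<start>" and "<end>" are assumed names, not taken from the demo's Voc file.
    return ["<pad>", "<start>", "<end>"] + sorted(tokens)


def encode(smiles, vocab):
    """Convert a SMILES string into a list of integer token indices."""
    index = {tok: i for i, tok in enumerate(vocab)}
    body = [index[tok] for tok in tokenize(smiles)]
    return [index["<start>"]] + body + [index["<end>"]]


vocab = build_vocabulary(["CC(=O)Nc1ccc(Cl)cc1", "[NH3+]CBr"])
encoded = encode("CC(=O)Nc1ccc(Cl)cc1", vocab)
```

The RNN is then trained to predict each token from the ones before it, which is why the start and end markers are useful during sampling.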

## Parameter settings

The pre-trained generator is an RNN built from a three-layer GRU, and the initial batch size is set to 128. Before training, set the paths of the input SMILES file (mol_path) and the vocabulary file (voc_path). The user can define the learning rate (lr), the weight decay (weigth_decay), the number of epochs (epoch_num), and the number of steps between progress outputs and learning-rate updates (step_num).
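How `RNN.fit` combines these settings is not shown in this notebook. As a rough illustration of a step-wise learning-rate update, the sketch below assumes a multiplicative schedule in which the learning rate is multiplied by `(1 - decay)` once every `interval` steps. Both the schedule and the function name `decayed_lr` are assumptions made for illustration, not the implementation in `model_construct_test`.

```python
def decayed_lr(initial_lr, step, interval, decay):
    """Learning rate after `step` steps under an assumed multiplicative decay.

    Every `interval` steps the rate is multiplied by (1 - decay).
    """
    n_updates = step // interval          # number of completed decay intervals
    return initial_lr * (1.0 - decay) ** n_updates


# Example with values of the same magnitude as the settings in this demo.
schedule = [decayed_lr(0.001, step, 300, 0.05) for step in (0, 299, 300, 600)]
```

Under this assumption, a larger decay value or a shorter interval makes the learning rate shrink faster, which is the trade-off to keep in mind when changing these settings.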

In [None]:
voc_path = "./data/Voc"
mol_path = "./data/800_thousand_filtered.smi"
restore_path = None

# models save path

max_save_path = "model/max_prior_RNN.ckpt"             # Save RNN model with max valid
last_save_path = "model/last_prior_RNN.ckpt"           # Save last RNN model

# parameters

epoch_num = 2
step_num = 30
decay_step_num = 300
smile_num = 5
lr = 0.001
weigth_decay = 0.05

total_loss, total_valid = RNN.fit(voc_path=voc_path, mol_path=mol_path, restore_path=restore_path, 
                                  max_save_path=max_save_path, last_save_path=last_save_path,
                                  epoch_num=epoch_num, step_num=step_num, 
                                  decay_step_num=decay_step_num, smile_num=smile_num, 
                                  lr=lr, weigth_decay=weigth_decay)


## Check the pre-trained RNN model

Check whether total_loss and total_valid have converged.
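Besides inspecting the plots, convergence can be checked numerically. The sketch below is one simple heuristic, assuming that `total_loss` and `total_valid` are flat sequences of numbers: it compares the mean of the most recent window with the mean of the window before it. The function name `has_plateaued` and the default `window` and `tol` values are hypothetical choices for illustration.

```python
def has_plateaued(values, window=10, tol=0.01):
    """Return True if the last two windows of `values` have nearly equal means.

    The relative change between the mean of the last `window` values and the
    mean of the preceding `window` values must be below `tol`.
    """
    if len(values) < 2 * window:
        return False                      # not enough points to compare two windows
    recent = sum(values[-window:]) / window
    previous = sum(values[-2 * window:-window]) / window
    scale = max(abs(previous), 1e-12)     # avoid division by zero
    return abs(recent - previous) / scale < tol


flat_curve = [0.5] * 20
falling_curve = [float(v) for v in range(40, 0, -1)]
flat_result = has_plateaued(flat_curve)
falling_result = has_plateaued(falling_curve)
```

A curve that is still falling steeply fails the check, while a flat curve passes it; noisy curves may need a larger window.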

In [None]:
plt.figure(figsize=(12,10))

plt.subplot(2,2,1)
plt.plot(total_loss)
plt.title("Total losses", fontsize=15)

plt.subplot(2,2,2)
plt.plot(total_valid)
plt.title("Total valid", fontsize=15)  

plt.show() 

# Fine-tune the last pre-trained model by transfer learning

Starting from the last pre-trained model, the RNN is fine-tuned on a dataset of 433 active caspase-6 inhibitors so that it generates novel molecules. The last RNN model and the RNN model with max_valid are saved for later SMILES generation.
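Two checkpoints are kept because the final model is not necessarily the one with the best valid value. The sketch below only illustrates that selection rule, assuming one valid value is recorded per evaluation; `select_checkpoints` is a hypothetical helper and is not part of `RNN.fit`.

```python
def select_checkpoints(valid_history):
    """Return the evaluation indices of the 'last' and 'max_valid' checkpoints."""
    if not valid_history:
        raise ValueError("valid_history must contain at least one value")
    # max() with a key returns the first index that reaches the maximum.
    best_step = max(range(len(valid_history)), key=lambda i: valid_history[i])
    return {"last": len(valid_history) - 1, "max_valid": best_step}


choice = select_checkpoints([0.2, 0.8, 0.5])
```

In this toy history the two checkpoints differ, which is the situation where keeping both is useful.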

## Parameter settings

The fine-tuned generator uses the same RNN built from a three-layer GRU, with the initial batch size set to 128. Before training, set the paths of the input SMILES file (mol_path), the vocabulary file (voc_path), and the pre-trained checkpoint to restore (restore_path). The user can define the learning rate (lr), the weight decay (weigth_decay), the number of epochs (epoch_num), and the number of steps between progress outputs and learning-rate updates (step_num).

In [None]:
voc_path = "./data/Voc"
mol_path = "./data/433_casp6_inhibitor_filtered.smi"
restore_path = "./model/last_prior_RNN.ckpt" 

# models save path

max_save_path = "model/max_agent_RNN.ckpt"             # Save RNN model with max valid
last_save_path = "model/last_agent_RNN.ckpt"

# parameters

epoch_num = 200
step_num = 1
decay_step_num = 10
smile_num = 5
lr = 0.001
weigth_decay = 0.03


total_loss, total_valid = RNN.fit(voc_path=voc_path, mol_path=mol_path, restore_path=restore_path, 
                                  max_save_path=max_save_path, last_save_path=last_save_path,
                                  epoch_num=epoch_num, step_num=step_num, 
                                  decay_step_num=decay_step_num, smile_num=smile_num, 
                                  lr=lr, weigth_decay=weigth_decay)

## Check the fine-tuned RNN model

Check whether total_loss and total_valid have converged.

In [None]:
plt.figure(figsize=(12,10))

plt.subplot(2,2,1)
plt.plot(total_loss)
plt.title("Total losses", fontsize=15)

plt.subplot(2,2,2)
plt.plot(total_valid)
plt.title("Total valid", fontsize=15)  

plt.show() 

## Generating molecules

In [None]:
from gen_smi_test import Generator

# define generator path

restore_path = "model/last_agent_RNN.ckpt"          
voc_path = "./data/Voc"

# generation settings

csv_num = 1
gen_num = 1000
mol_num = 10

Gen_mol = Generator(restore_path=restore_path, voc_path=voc_path, 
                    csv_num=csv_num, gen_num=gen_num, mol_num=mol_num)
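After generation, it is common to check how many of the sampled SMILES are unique and how many are absent from the fine-tuning set. The sketch below assumes the generated and training SMILES have already been loaded into plain Python lists; how `Generator` writes its output is not shown here, and `summarize_generated` is a hypothetical helper. The comparison is by exact string match, so it is only meaningful when both lists are canonicalized in the same way (for example, as RDKit canonical SMILES).

```python
def summarize_generated(generated, training):
    """Report unique and novel SMILES among the generated strings."""
    unique = list(dict.fromkeys(generated))     # drop duplicates, keep first-seen order
    training_set = set(training)
    novel = [smi for smi in unique if smi not in training_set]
    return {
        "unique": unique,
        "novel": novel,
        "uniqueness": len(unique) / len(generated) if generated else 0.0,
        "novelty": len(novel) / len(unique) if unique else 0.0,
    }


summary = summarize_generated(
    generated=["CCO", "CCO", "CCN", "c1ccccc1"],
    training=["CCN"],
)
```

Here uniqueness is the fraction of generated strings that are distinct, and novelty is the fraction of distinct strings not found in the training list.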