## Imports:

In [1]:
import torch 
import unidecode
import random
import numpy as np
import matplotlib.pyplot as plt

from models.charRNN import make_charRNN, get_random_batch, generate
from utils import n_chars, check_validity, check_novelty, strsmis2listsmis, list2txt, get_props

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Data preparation

In [3]:
# Load the whole training corpus into memory as one ASCII-transliterated string.
# A context manager is used so the file handle is closed as soon as the text is read.
with open('all_data/s_100_str_+1M.txt') as corpus_handle:
    file = unidecode.unidecode(corpus_handle.read())

In [4]:
# Preview the first 100 characters of the corpus.
file[0:100]

'<COc1ccc2[C@@H]3[C@H](COc2c1)C(C)(C)OC4=C3C(=O)C(=O)C5=C4OC(C)(C)[C@@H]6COc7cc(OC)ccc7[C@H]56><C[S+]'

## Make the model

In [5]:
# Learning rate: passed to make_charRNN and embedded in the checkpoint filename built by train().
lr = 0.0005

In [5]:
# Build the network, optimizer and loss, starting from the pretrained weights named 'G'.
model_name = 'G'
rnn, optimizer, criterion = make_charRNN(
    n_chars=56,
    hidden_size=512,
    num_layers=3,
    lr=lr,
    pretrained_file=model_name,
)

In [6]:
# Show the module layout of the network.
print(rnn)

RNN(
  (embed): Embedding(56, 30)
  (gru): GRU(30, 512, num_layers=3, batch_first=True)
  (fc): Linear(in_features=512, out_features=56, bias=True)
)


## Train

In [8]:
# Tag inserted into the checkpoint filename built by train().
obj='_4_'

In [49]:
def _meets_target_props(smi):
    """Return True when a sampled string passes every property gate used by ``train``.

    ``get_props(smi, c=2)`` is unpacked into seven values and
    ``get_props(smi, c=1)`` into five; the thresholds below are the ones the
    training loop has always used.
    NOTE(review): the chemical meaning of each value is defined in ``utils.get_props``
    and is not visible here — confirm there before changing any threshold.
    """
    arr, alr, oh, cooh, coor, nh2, rval = get_props(smi, c=2)
    has_cycle = arr == 2 or alr == 1
    has_group = oh >= 1 or cooh >= 1 or coor >= 1 or nh2 >= 1
    rval_in_range = 0.05 < rval < 0.5

    prop1, prop2, prop3, prop4, prop5 = get_props(smi, c=1)

    return (has_cycle and has_group and rval_in_range
            and prop1 <= 3 and prop2 <= 480 and prop3 <= 3 and prop4 <= 3 and prop5 <= 3)


def train(iters=10, chunk_len=256, n_epochs=800, print_every=10):
    """Fine-tune the global ``rnn`` with a property-conditioned loss.

    Each "epoch" is one optimisation step on one random batch drawn from the
    global corpus ``file``. After the forward pass, ``iters`` strings are sampled
    from the current model; for every sample that passes ``_meets_target_props``
    the divisor ``n`` doubles, so the loss that is back-propagated is
    ``loss / 2**k`` with ``k`` the number of passing samples.

    Args:
        iters: number of strings sampled per step to compute the divisor.
        chunk_len: number of characters per training sequence.
        n_epochs: number of optimisation steps.
        print_every: print the current loss every this many steps.

    Returns:
        A list with the scaled, per-character loss of every step.

    Side effects:
        Updates ``rnn`` through ``optimizer`` and writes the state dict of the
        step with the lowest scaled loss to ``models/pretrained/CRLV_...pth``.
        Note that the checkpoint is selected on the *scaled* loss.
    """
    import os

    print("=> Starting training...")
    Ls = []
    best_model = 'models/pretrained/CRLV_' +obj+str(lr)+'_'+str(n_epochs)+'_'+str(iters)+'_'+str(chunk_len)+ '_.pth'
    # Make sure torch.save below cannot fail on a missing directory.
    os.makedirs(os.path.dirname(best_model), exist_ok=True)
    # Infinity (rather than a large constant) guarantees the first step is always saved.
    prev_best_loss = float('inf')
    for epoch in range(1, n_epochs + 1):
        inp, target = get_random_batch(file, chunk_len)
        hidden = rnn.init_hidden()

        rnn.zero_grad()
        loss = 0
        inp = inp.to(device)
        target = target.to(device)

        # Character-by-character forward pass; the losses of all positions are summed.
        for i in range(chunk_len):
            output, hidden = rnn(inp[:, i], hidden)
            loss += criterion(output, target[:, i])

        # ---- conditional loss: halve the loss once per sample that passes the gate ----
        n = 1
        # Sampling only produces strings; nothing is back-propagated through it,
        # so autograd tracking is switched off to avoid building unused graphs.
        with torch.no_grad():
            for _ in range(iters):
                smi = generate(rnn, initial_str='<', predict_len=102, temperature=0.50)
                if _meets_target_props(smi):
                    n *= 2

        loss = loss / n
        # --------------------------------------------------------------------------------
        loss.backward()
        optimizer.step()
        loss = loss.item() / chunk_len

        if prev_best_loss > loss:
            prev_best_loss = loss
            torch.save(rnn.state_dict(), best_model)

        Ls.append(loss)
        if epoch % print_every == 0:
            print(f'Loss: {loss} epoch: {epoch}')
    return Ls

In [6]:
# Fix the seed of every random number generator in use (torch, numpy, random)
# so that results can be reproduced.
seeds = [3, 0, 3]
torch_seed, numpy_seed, python_seed = seeds

torch.manual_seed(torch_seed)
np.random.seed(numpy_seed)
random.seed(python_seed)

In [9]:
# Training hyper-parameters; they are also embedded in the checkpoint filename built by train().
# Number of strings sampled per step for the conditional loss.
iters=20  
# Characters per training sequence.
chunk_len=128  #128 or 256
# Number of optimisation steps.
n_epochs = 800

# Run training; Losses holds one per-step loss value for each of the n_epochs steps.
Losses = train(iters=iters, chunk_len=chunk_len, n_epochs=n_epochs, print_every=10)

In [None]:
def moving_average(values, window):
    """Return the simple moving average of ``values`` over ``window`` points.

    Only fully-covered windows are returned, so the result has
    ``len(values) - window + 1`` entries.

    Args:
        values: 1-D sequence of numbers.
        window: number of consecutive points averaged together (>= 1).

    Returns:
        A 1-D float array; empty when ``values`` has fewer than ``window`` points.

    Raises:
        ValueError: if ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError("window must be a positive integer")
    values = np.asarray(values, dtype=float)
    # np.convolve(..., 'valid') silently swaps its operands when the kernel is
    # longer than the data, which would return numbers that are not averages.
    if window > values.size:
        return np.empty(0, dtype=float)
    weights = np.full(window, 1.0 / window)
    return np.convolve(values, weights, 'valid')

# Smooth the per-step training losses over 20 steps and plot the curve.
Losses_ = moving_average(Losses, 20)

fig, ax = plt.subplots()
ax.plot(Losses_)
# Labels make the figure readable on its own when the notebook is skimmed.
ax.set(xlabel="Training step", ylabel="Loss (20-step moving average)", title="Training loss");

## 0. Generate:

In [7]:
# Load the best model saved during training.
# To use your own checkpoint instead of the provided one, uncomment:
# model_name = 'CRLV_'+obj+str(lr)+'_'+str(n_epochs)+'_'+str(iters)+'_'+str(chunk_len)+ '_'   #or use the provided one
model_name = 'CRLV_4'
rnn, optimizer, criterion = make_charRNN(
    n_chars=56,
    hidden_size=512,
    num_layers=3,
    lr=lr,
    pretrained_file=model_name,
)
print(model_name)

CRLV_4


In [8]:
# Sampling temperature passed to generate() in the generation loop below.
CLF_temp = 0.5 

In [21]:
# Sample n_ep strings from the trained model.
n_ep = 10000
length = 102   # passed to generate() as predict_len
ic = '<'       # initial string given to generate()
all_smis = []

# Generation is inference only, so gradient tracking is disabled.
with torch.no_grad():
    for ep in range(n_ep):
        smi = generate(rnn, initial_str=ic, predict_len=length, temperature=CLF_temp)
        all_smis.append(smi)
        # Report progress every 1000 samples instead of flooding the output with one line each.
        if ep % 1000 == 0:
            print("iteration:", ep)

In [11]:
### save the list of generated smiles if you want
# filename = model_name + '_all_smis_.txt'
# path = 'gen_smis/' + filename
# list2txt(path, mylist=all_smis)

## 1. How many among all_smis are valid:

In [22]:
# Keep only the generated strings that check_validity flags with 1.
val_smis = []
for i, s in enumerate(all_smis):
    if check_validity(s) != 1:
        continue
    val_smis.append(s)
    print("mol at:", i, " : verified")

In [23]:
# Report the share of valid strings among everything that was generated.
n_generated = len(all_smis)
n_valid = len(val_smis)
print("There is ", n_valid, "valid molecule among the", n_generated, "generated")
print("meaning a percent of: ", 100 * n_valid/n_generated, "%")

## 2. How many among val_smis are novel:

In [12]:
# Convert the training corpus string into a list of entries; it is the reference for the novelty check.
smis_list = strsmis2listsmis(file)
# Number of entries in the reference list.
len(smis_list)

1498669

In [24]:
# Keep only the valid strings that check_novelty flags with 1 against the training list.
val_and_nov_smis = []
for i, s in enumerate(val_smis):
    if check_novelty(s, smis_list) != 1:
        continue
    val_and_nov_smis.append(s)
    print("mol at:", i, " : verified")

In [25]:
# Novelty is reported relative to the valid molecules, so the message names that denominator.
# The guard avoids a ZeroDivisionError when no generated molecule was valid.
novel_pct = 100 * len(val_and_nov_smis) / len(val_smis) if val_smis else 0.0
print("There is ", len(val_and_nov_smis), "valid & novel molecule among the", len(all_smis), "generated")
print("meaning a percent of: ", novel_pct, "% of the valid molecules")

## 3. How many among val_and_nov_smis are unique:
* meaning they have no duplicates, i.e. they were not repeated among the generated molecules

In [15]:
def check_uniqueness(smis_list):
    """Return the distinct entries of ``smis_list``, keeping first-occurrence order.

    ``dict.fromkeys`` is used instead of ``set`` so the order of the result is
    deterministic; a set of strings can iterate in a different order on each run,
    which makes downstream lists irreproducible.

    Args:
        smis_list: iterable of hashable items (here, SMILES strings).

    Returns:
        A new list with each distinct item exactly once.
    """
    return list(dict.fromkeys(smis_list))

# Drop repeated entries from the valid & novel strings.
val_nov_and_unique_smis = check_uniqueness(val_and_nov_smis)

In [26]:
# Uniqueness is reported relative to the valid & novel molecules, so the message names that denominator.
# The guard avoids a ZeroDivisionError when that list is empty.
unique_pct = 100 * len(val_nov_and_unique_smis) / len(val_and_nov_smis) if val_and_nov_smis else 0.0
print("There is ", len(val_nov_and_unique_smis), "valid, novel and unique molecule among the", len(all_smis), "generated")
print("meaning, a percent of : ", unique_pct, "% of the valid & novel molecules")

## 4. Diversity (intDiv) of the valid, novel, and unique smiles:
* using the moses library https://github.com/molecularsets/moses

In [27]:
# NOTE(review): moses is imported here rather than in the import cell at the top;
# consider moving it there so all dependencies are visible up front.
import moses
# Internal diversity of the valid, novel and unique strings, as computed by moses.
moses.metrics.internal_diversity(val_nov_and_unique_smis)

## 5. How many have the desired chemical property:

In [18]:
# Collect the molecules that pass every property gate (the same thresholds as in training).
# NOTE(review): the meaning of each value returned by get_props is defined in utils — confirm there.
succ = []
for smi in val_nov_and_unique_smis:
    arr, alr, oh, cooh, coor, nh2, rval = get_props(smi, c=2)
    has_cycle = arr == 2 or alr == 1
    has_group = oh >= 1 or cooh >= 1 or coor >= 1 or nh2 >= 1
    rval_in_range = 0.05 < rval < 0.5

    prop1, prop2, prop3, prop4, prop5 = get_props(smi, c=1)

    passes_gate = (has_cycle and has_group and rval_in_range
                   and prop1 <= 3 and prop2 <= 480 and prop3 <= 3 and prop4 <= 3 and prop5 <= 3)
    if passes_gate:
        succ.append(smi)

In [28]:
# Share of the valid, novel and unique molecules that pass every property gate.
# The guard avoids a ZeroDivisionError when that list is empty.
succ_pct = 100 * len(succ) / len(val_nov_and_unique_smis) if val_nov_and_unique_smis else 0.0
print('there is', len(succ), 'molecule with the desired property')
print("meaning a percent of", succ_pct, "% of smiles among the valid, "
      "novel and unique ones with the desired property")