In [1]:
from __future__ import unicode_literals, print_function, division
import os
import pandas as pd
import numpy as np
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split
# from rdkit import Chem

import matplotlib.pyplot as plt
# plt.switch_backend('agg')
import matplotlib.ticker as ticker

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# rdkit load SMILES and structures
SOS_token = 0
EOS_token = 1
MAX_LENGTH = 120

dset_path = 'dataset'
esol_dset_path = os.path.join(dset_path, 'ESOL')
esol_dset = os.path.join(esol_dset_path, 'delaney-processed.csv')
zinc_dset = os.path.join(dset_path, '250k_rndm_zinc_drugs_clean_3.csv')

# ESOL Dataset
esol_df = pd.read_csv(esol_dset)
esol_solu_df = esol_df[['smiles', 'measured log solubility in mols per litre']]
# esol_solu_df['mols'] = esol_solu_df.apply(lambda x: Chem.MolFromSmiles(x.smiles), axis=1)

# VAE Dataset
zinc_df = pd.read_csv(zinc_dset, index_col=None)
zinc_df = zinc_df.replace('\n', '', regex=True)
# zinc_df['mols'] = zinc_df.apply(lambda x: Chem.MolFromSmiles(x.smiles), axis=1)

# # ----------One hot SMILES----------
# from deepchem.feat.one_hot import OneHotFeaturizer, zinc_charset
# featurizer_onehot = OneHotFeaturizer(zinc_charset, 120)

In [3]:
class Elem:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSMILES(self, smiles):
        elements = list(smiles.strip())
        for element in elements:
            self.addElement(element)

    def addElement(self, element):
        if element not in self.word2index:
            self.word2index[element] = self.n_words
            self.word2count[element] = 1
            self.index2word[self.n_words] = element
            self.n_words += 1
        else:
            self.word2count[element] += 1

In [4]:
def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH


def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

In [5]:
def prepareSMILES(SMILES_arr, properties):
    input_elem, output_elem, pairs = Elem('input_elem'), Elem('output_elem'), [[i, j] for i, j in zip(SMILES_arr, properties)]
    print("Read %s sentence pairs" % len(SMILES_arr))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_elem.addSMILES(pair[0])
        output_elem.addSMILES(pair[0])
    print("Counted words:")
    print(input_elem.name, input_elem.n_words)
    print(output_elem.name, output_elem.n_words)
    return input_elem, output_elem, pairs

# input_elem, output_elem, pairs = prepareSMILES(esol_df['smiles'].values, esol_df['measured log solubility in mols per litre'].values)
input_elem, output_elem, pairs = prepareSMILES(zinc_df['smiles'].values, zinc_df['logP'].values)
print(random.choice(pairs))

Read 249455 sentence pairs
Trimmed to 249455 sentence pairs
Counting words...
Counted words:
input_elem 36
output_elem 36
['Cc1ccc(NC(=O)c2nn3ccccc3c2Cl)cc1C', 3.85684]


In [6]:
def indexesFromSentence(elem, smiles):
    return [elem.word2index[element] for element in list(smiles.strip())]


def tensorFromSentence(lang, smiles):
    indexes = indexesFromSentence(lang, smiles)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)


def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_elem, pair[0])
    target_tensor = torch.tensor(pair[1], dtype=torch.long, device=device).view(-1)
    return (input_tensor, target_tensor)

In [7]:
def data_loader(tensor_pairs, train_size_percentage=0.8, BATCH_SIZE=1):
    train_size = int(train_size_percentage * len(tensor_pairs))
    test_size = len(tensor_pairs) - train_size
    train_dset, test_dset = random_split([tensorsFromPair(pair) for pair in tensor_pairs], [train_size, test_size])
    train_loader = DataLoader(
        dataset=train_dset,
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    test_loader = DataLoader(
        dataset=test_dset,
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    return train_loader, test_loader

### Model

In [8]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)


class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        output = self.embedding(input).view(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [9]:
teacher_forcing_ratio = 0.5

def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):
    encoder.train()
    decoder.train()
    encoder_hidden = encoder.initHidden()

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

    loss = 0

    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    decoder_input = torch.tensor([[SOS_token]], device=device)

    decoder_hidden = encoder_hidden

    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    if use_teacher_forcing:
        # Teacher forcing: Feed the target as the next input
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
#             decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
            loss += criterion(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di]  # Teacher forcing

    else:
        # Without teacher forcing: use its own predictions as the next input
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
#             decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()  # detach from history as input

            loss += criterion(decoder_output, target_tensor[di])
            if decoder_input.item() == EOS_token:
                break

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length

In [10]:
import time
import math


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

In [11]:
def trainIters(encoder, decoder, training_loader, print_every=1000, plot_every=100, learning_rate=0.01, EPOCH=50, model_folder='test'):
    if not os.path.exists('./model/' + model_folder):
        os.mkdir('./model/' + model_folder)

    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate) # SGD
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate) # SGD
    criterion = nn.NLLLoss()

    for epoch in range(1, EPOCH + 1):
        print('Time info:', timeSince(start, epoch / EPOCH))
        for batch_idx, (data, target) in enumerate(training_loader):
            batch_idx += 1
            # training_pair = training_pairs[iter - 1]
            input_tensor, target_tensor = data[0], data[0]

            loss = train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
            print_loss_total += loss
            plot_loss_total += loss

            if batch_idx % print_every == 0:
                print_loss_avg = print_loss_total / print_every
                print_loss_total = 0
                # (timeSince(start, iter / n_iters)
                # print('Iter: {}/{} ({:.0f}%)\tLoss: {:.6f}'.format(iter, n_iters, iter / n_iters * 100, print_loss_avg))
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader),
                    print_loss_avg))

            if batch_idx % plot_every == 0:
                plot_loss_avg = plot_loss_total / plot_every
                plot_losses.append(plot_loss_avg)
                plot_loss_total = 0
                
        # Save model/parameters
        torch.save(encoder.state_dict(), './model/{}/model_autoencRNN_enc_{}.pkl'.format(model_folder, epoch))
        torch.save(decoder.state_dict(), './model/{}/model_autoencRNN_dec_{}.pkl'.format(model_folder, epoch))
        if epoch == (EPOCH - 1):
            torch.save(encoder, './model/{}/model_autoencRNN_enc_{}.pkl'.format(model_folder, epoch))
            torch.save(decoder, './model/{}/model_autoencRNN_dec_{}.pkl'.format(model_folder, epoch))

    # showPlot(plot_losses)
    return plot_losses

In [12]:
def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)

In [13]:
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_elem, sentence)
        input_length = input_tensor.size()[0]
        encoder_hidden = encoder.initHidden()

        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
            encoder_outputs[ei] += encoder_output[0, 0]

        decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS

        decoder_hidden = encoder_hidden

        decoded_words = []
        decoder_attentions = torch.zeros(max_length, max_length)

        for di in range(max_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
#             decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
#             decoder_attentions[di] = decoder_attention.data
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            else:
                decoded_words.append(output_elem.index2word[topi.item()])

            decoder_input = topi.squeeze().detach()

        return decoded_words, decoder_attentions[:di + 1]

In [14]:
def evaluateRandomly(encoder, decoder, n=10):
    for i in range(n):
        pair = random.choice(pairs)
        print('>', pair[0])
        print('=', pair[1])
        output_words, attentions = evaluate(encoder, decoder, pair[0])
        output_sentence = ''.join(output_words)
        print('<', output_sentence)
        print('')

### Training

In [None]:
%%time

hidden_size = 256
train_loader, test_loader = data_loader(pairs)

encoder = EncoderRNN(input_elem.n_words, hidden_size).to(device)
decoder = DecoderRNN(hidden_size, output_elem.n_words).to(device)
# attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)

train_losses = trainIters(encoder, decoder, train_loader, EPOCH=100, print_every=10000, model_folder='AERNN_zinc_hidden_256')

Time info: 0m 0s (- 0m 0s)
Time info: 213m 58s (- 10484m 49s)


In [16]:
# train_losses = trainIters(encoder, decoder, 75000, print_every=100)

0m 10s (- 128m 11s) (100 0%) 3.5220
0m 20s (- 128m 22s) (200 0%) 2.3364
0m 27s (- 115m 17s) (300 0%) 2.0428
0m 34s (- 106m 22s) (400 0%) 1.8116
0m 43s (- 107m 56s) (500 0%) 1.8604
0m 51s (- 106m 2s) (600 0%) 1.6656
1m 1s (- 109m 7s) (700 0%) 1.7139
1m 10s (- 109m 39s) (800 1%) 1.6917
1m 17s (- 106m 56s) (900 1%) 1.6038
1m 24s (- 104m 48s) (1000 1%) 1.6853
1m 31s (- 102m 50s) (1100 1%) 1.5782
1m 39s (- 101m 58s) (1200 1%) 1.6039
1m 47s (- 101m 27s) (1300 1%) 1.6759
1m 54s (- 100m 30s) (1400 1%) 1.6403
2m 1s (- 99m 13s) (1500 2%) 1.5226
2m 11s (- 100m 51s) (1600 2%) 1.4420
2m 20s (- 100m 42s) (1700 2%) 1.6331
2m 26s (- 99m 33s) (1800 2%) 1.6021
2m 34s (- 98m 48s) (1900 2%) 1.5791
2m 40s (- 97m 34s) (2000 2%) 1.4985
2m 47s (- 96m 54s) (2100 2%) 1.5189
2m 53s (- 95m 45s) (2200 2%) 1.5236
2m 57s (- 93m 43s) (2300 3%) 1.4737
3m 4s (- 92m 49s) (2400 3%) 1.5091
3m 9s (- 91m 49s) (2500 3%) 1.5090
3m 14s (- 90m 3s) (2600 3%) 1.4155
3m 20s (- 89m 28s) (2700 3%) 1.4947
3m 26s (- 88m 41s) (2800 3%)

21m 50s (- 52m 36s) (22000 29%) 0.9278
21m 56s (- 52m 30s) (22100 29%) 0.7945
22m 2s (- 52m 24s) (22200 29%) 0.8395
22m 8s (- 52m 18s) (22300 29%) 0.7397
22m 13s (- 52m 12s) (22400 29%) 0.6711
22m 20s (- 52m 7s) (22500 30%) 0.9823
22m 26s (- 52m 1s) (22600 30%) 0.9961
22m 33s (- 51m 58s) (22700 30%) 0.9354
22m 39s (- 51m 52s) (22800 30%) 0.8334
22m 45s (- 51m 46s) (22900 30%) 0.8007
22m 51s (- 51m 40s) (23000 30%) 0.8409
22m 56s (- 51m 33s) (23100 30%) 0.6540
23m 2s (- 51m 27s) (23200 30%) 0.7147
23m 9s (- 51m 24s) (23300 31%) 0.7961
23m 16s (- 51m 19s) (23400 31%) 0.9455
23m 22s (- 51m 14s) (23500 31%) 0.7810
23m 28s (- 51m 8s) (23600 31%) 0.8589
23m 34s (- 51m 2s) (23700 31%) 0.8419
23m 41s (- 50m 57s) (23800 31%) 0.8975
23m 47s (- 50m 52s) (23900 31%) 0.9432
23m 52s (- 50m 44s) (24000 32%) 0.9756
23m 56s (- 50m 34s) (24100 32%) 0.8275
24m 1s (- 50m 25s) (24200 32%) 0.7859
24m 5s (- 50m 15s) (24300 32%) 0.7559
24m 9s (- 50m 4s) (24400 32%) 0.8728
24m 13s (- 49m 55s) (24500 32%) 0.897

41m 53s (- 30m 49s) (43200 57%) 0.5310
41m 59s (- 30m 44s) (43300 57%) 0.6087
42m 6s (- 30m 39s) (43400 57%) 0.6422
42m 12s (- 30m 33s) (43500 57%) 0.5857
42m 18s (- 30m 28s) (43600 58%) 0.4270
42m 25s (- 30m 22s) (43700 58%) 0.5881
42m 31s (- 30m 17s) (43800 58%) 0.5016
42m 37s (- 30m 11s) (43900 58%) 0.5450
42m 43s (- 30m 6s) (44000 58%) 0.5006
42m 48s (- 29m 59s) (44100 58%) 0.5978
42m 51s (- 29m 51s) (44200 58%) 0.6138
42m 55s (- 29m 44s) (44300 59%) 0.5424
43m 0s (- 29m 38s) (44400 59%) 0.5382
43m 4s (- 29m 31s) (44500 59%) 0.6179
43m 8s (- 29m 24s) (44600 59%) 0.4931
43m 12s (- 29m 17s) (44700 59%) 0.5053
43m 16s (- 29m 10s) (44800 59%) 0.5585
43m 21s (- 29m 4s) (44900 59%) 0.6396
43m 28s (- 28m 58s) (45000 60%) 0.5934
43m 35s (- 28m 53s) (45100 60%) 0.4935
43m 42s (- 28m 48s) (45200 60%) 0.4960
43m 47s (- 28m 42s) (45300 60%) 0.5738
43m 54s (- 28m 37s) (45400 60%) 0.5533
44m 1s (- 28m 32s) (45500 60%) 0.5243
44m 5s (- 28m 25s) (45600 60%) 0.5417
44m 9s (- 28m 18s) (45700 60%) 0.

62m 40s (- 10m 18s) (64400 85%) 0.4051
62m 46s (- 10m 13s) (64500 86%) 0.4248
62m 53s (- 10m 7s) (64600 86%) 0.4892
62m 58s (- 10m 1s) (64700 86%) 0.4825
63m 5s (- 9m 55s) (64800 86%) 0.3390
63m 11s (- 9m 49s) (64900 86%) 0.3755
63m 17s (- 9m 44s) (65000 86%) 0.4289
63m 23s (- 9m 38s) (65100 86%) 0.3384
63m 31s (- 9m 32s) (65200 86%) 0.4437
63m 37s (- 9m 27s) (65300 87%) 0.3405
63m 43s (- 9m 21s) (65400 87%) 0.3701
63m 50s (- 9m 15s) (65500 87%) 0.4516
63m 56s (- 9m 9s) (65600 87%) 0.3538
63m 59s (- 9m 3s) (65700 87%) 0.2634
64m 6s (- 8m 57s) (65800 87%) 0.4020
64m 12s (- 8m 51s) (65900 87%) 0.3076
64m 18s (- 8m 46s) (66000 88%) 0.5003
64m 24s (- 8m 40s) (66100 88%) 0.4803
64m 31s (- 8m 34s) (66200 88%) 0.3723
64m 37s (- 8m 28s) (66300 88%) 0.4061
64m 43s (- 8m 22s) (66400 88%) 0.3288
64m 49s (- 8m 17s) (66500 88%) 0.5015
64m 55s (- 8m 11s) (66600 88%) 0.3109
65m 1s (- 8m 5s) (66700 88%) 0.4069
65m 8s (- 7m 59s) (66800 89%) 0.3871
65m 14s (- 7m 53s) (66900 89%) 0.3932
65m 21s (- 7m 48s

In [26]:
input_elem, output_elem, pairs = prepareSMILES(zinc_df['smiles'].values)
print(random.choice(pairs))

Read 249455 sentence pairs
Trimmed to 249455 sentence pairs
Counting words...
Counted words:
input_elem 36
output_elem 36
['Cc1cccc(NC(=O)[C@@H](C)SCCO)c1', 'Cc1cccc(NC(=O)[C@@H](C)SCCO)c1']


In [57]:
input_elem, output_elem, pairs = prepareSMILES(esol_df['smiles'].values)
print(random.choice(pairs))

Read 1128 sentence pairs
Trimmed to 1128 sentence pairs
Counting words...
Counted words:
input_elem 33
output_elem 33
['CC2Nc1cc(Cl)c(cc1C(=O)N2c3ccccc3C)S(N)(=O)=O ', 'CC2Nc1cc(Cl)c(cc1C(=O)N2c3ccccc3C)S(N)(=O)=O ']


In [31]:
%%time

hidden_size = 256
encoder_zinc = EncoderRNN(input_elem.n_words, hidden_size).to(device)
decoder_zinc = DecoderRNN(hidden_size, output_elem.n_words).to(device)
# attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)

train_losses_zinc = trainIters(encoder_zinc, decoder_zinc, 75000, print_every=1000)

1m 57s (- 145m 10s) (1000 1%) 2.6470
3m 51s (- 140m 38s) (2000 2%) 2.1069
5m 47s (- 139m 5s) (3000 4%) 2.0356
7m 47s (- 138m 26s) (4000 5%) 1.9442
9m 22s (- 131m 13s) (5000 6%) 1.9220
11m 18s (- 130m 8s) (6000 8%) 1.8527
13m 20s (- 129m 38s) (7000 9%) 1.8539
15m 16s (- 127m 55s) (8000 10%) 1.8075
17m 10s (- 125m 55s) (9000 12%) 1.7437
19m 5s (- 124m 7s) (10000 13%) 1.7192
20m 57s (- 121m 58s) (11000 14%) 1.7157
22m 43s (- 119m 16s) (12000 16%) 1.6665
24m 36s (- 117m 21s) (13000 17%) 1.6682
26m 29s (- 115m 27s) (14000 18%) 1.6320
28m 24s (- 113m 37s) (15000 20%) 1.5999
30m 17s (- 111m 41s) (16000 21%) 1.6093
31m 54s (- 108m 50s) (17000 22%) 1.5495
33m 54s (- 107m 21s) (18000 24%) 1.5103
35m 36s (- 104m 56s) (19000 25%) 1.5547
37m 18s (- 102m 36s) (20000 26%) 1.5002
39m 18s (- 101m 4s) (21000 28%) 1.4769
41m 5s (- 99m 0s) (22000 29%) 1.4842
42m 49s (- 96m 48s) (23000 30%) 1.4368
44m 39s (- 94m 54s) (24000 32%) 1.4193
46m 15s (- 92m 30s) (25000 33%) 1.4362
48m 18s (- 91m 3s) (26000 34%) 1

In [37]:
# plt.plot(train_losses_zinc)
# plt.xlabel('iter')
# plt.ylabel('loss')
# plt.show()
showPlot(train_losses_zinc)

### Stat.

In [19]:
def statSMILES(encoder, decoder, data_pairs):
    smiles_in = np.array(data_pairs)[:, 0]
    smiles_real = np.array(data_pairs)[:, 0]
    smiles_pred = []
    smiles_len_ele = []
    smiles_error_rate = []
    for index, pair in enumerate(smiles_in):
        output_words, _ = evaluate(encoder, decoder, pair)
        output_sentence = ''.join(output_words[:-1])
        smiles_pred.append(output_sentence)
        
        smiles_len_ele.append(len(output_sentence))
        err_elem = 0

        try:
            for index_char, char in enumerate(output_sentence):
                if char is not smiles_real[index][index_char]:
                    err_elem += 1
        except:
            err_elem += (len(output_sentence) - index_char)

        try:
            err_rate = err_elem/smiles_len_ele[index]
            smiles_error_rate.append(err_rate)
        except:
            smiles_error_rate.append(None)

    chem_AE_dict = {'real SMILES': smiles_real, 'predict SMILES': smiles_pred, 
                    'ERROR rate': smiles_error_rate, 'len elements': smiles_len_ele}
    return chem_AE_dict

In [58]:
AERNN_pred_structure_dict_emb = statSMILES(encoder, decoder, pairs[:20])
AERNN_pred_structure_train_df = pd.DataFrame(AERNN_pred_structure_dict_emb)
AERNN_pred_structure_train_df.head(15)

Unnamed: 0,ERROR rate,len elements,predict SMILES,real SMILES
0,0.584906,53,OCC3OCC2C(OCCCCC(O)cccccccc1ON(()O)C)C(CO)CO(C...,OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)...
1,0.136364,22,Cc1occc1C(=O)N2ccccccc,Cc1occc1C(=O)Nc2ccccc2
2,0.5,22,CC(C)=CCCCCC(C)=CC(=O),CC(C)=CCCC(C)=CC(=O)
3,0.307692,39,c1ccc2c(c1)ccccccccccccccccccccccc4cc34,c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43
4,0.0,7,c1ccsc1,c1ccsc1
5,0.0,13,c2ccc1scnc1c2,c2ccc1scnc1c2
6,0.0,34,Clc1cc(Cl)c(c(Cl)c1)c2c(Cl)cccc2Cl,Clc1cc(Cl)c(c(Cl)c1)c2c(Cl)cccc2Cl
7,0.0,32,CC12CCC3C(CCc4cc(O)ccc34)C2CCC1O,CC12CCC3C(CCc4cc(O)ccc34)C2CCC1O
8,0.403846,52,ClC4=C(Cl)C5(Cl)C3C1CCC2CCCC2CCCCCCC(CCCCCl)(C...,ClC4=C(Cl)C5(Cl)C3C1CC(C2OC12)C3C4(Cl)C5(Cl)Cl
9,0.653846,52,COc5ccccCOCCCC1OCCcccccccccccccCcCcO()=cCC(3)C...,COc5cc4OCC3Oc2c1CC(Oc1ccc2C(=O)C3c4cc5OC)C(C)=C


In [56]:
AERNN_pred_structure_dict_emb = statSMILES(encoder_zinc, decoder_zinc, pairs[:20])
AERNN_pred_structure_train_df = pd.DataFrame(AERNN_pred_structure_dict_emb)
AERNN_pred_structure_train_df.head(15)

Unnamed: 0,ERROR rate,len elements,predict SMILES,real SMILES
0,0.666667,45,CC(C)(Cc1ccc(2(cc(((F)((F)C(=O)Nc3ccccc3)cc1F,CC(C)(C)c1ccc2occ(CC(=O)Nc3ccccc3F)c2c1
1,0.451613,31,C[C@@H]1CCN(c2cccn(CC3CC3)n2)n1,C[C@@H]1CC(Nc2cncc(-c3nncn3C)c2)C[C@@H](C)C1
2,0.510204,49,N#Cc1ccc(-c2c(C(=O)NCCCCCCCCCCCCccccccccccccccccc,N#Cc1ccc(-c2ccc(O[C@@H](C(=O)N3CCCC3)c3ccccc3)...
3,0.44898,49,CCOC(=O)[C@@H]1CCN(C(=O)c2ccc(-c3ccccccncCCCCCCC1,CCOC(=O)[C@@H]1CCCN(C(=O)c2nc(-c3ccc(C)cc3)n3c...
4,0.679487,78,N#C1C(=O(CC(=O)NcccccccCcCCCCCCCCCCCCCCCCCCCCC...,N#CC1=C(SCC(=O)Nc2cccc(Cl)c2)N=C([O-])[C@H](C#...
5,0.5,42,CC[NH+](C)CCCCCCCCCCCCCC@H]111cccccccccrcc,CC[NH+](CC)[C@](C)(CC)[C@H](O)c1cscc1Br
6,0.459459,37,COc1ccc(C(=O)N(C)C[C@@H]((C)C)OCO)cc1,COc1ccc(C(=O)N(C)[C@@H](C)C/C(N)=N/O)cc1O
7,0.621622,37,O=C(Nc1ncncc1[Nc1ccc(F)cc1)Nc1ncccc1F,O=C(Nc1nc[nH]n1)c1cccnc1Nc1cccc(F)c1
8,0.675,40,Cc1cc(/C=C/c2ccc((Brcccccccccccccccccccc,Cc1c(/C=N/c2cc(Br)ccn2)c(O)n2c(nc3ccccc32)c1C#N
9,0.489362,47,C[C@@H]1CN(C(=O)c2ccc(Br)cc2C)C[C@@H]1CCC[NH3+],C[C@@H]1CN(C(=O)c2cc(Br)cn2C)CC[C@H]1[NH3+]


## Previous code

In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
from torch.autograd import Variable
from rdkit import Chem
from deepchem.feat.one_hot import OneHotFeaturizer, zinc_charset
from scipy.stats import pearsonr
from sklearn.metrics import mean_absolute_error, r2_score
import matplotlib.pyplot as plt
import torch.nn.functional as F
import random
import timeit

In [None]:
def addSOS(arr):
    arr = np.append(arr, np.zeros((arr.shape[0], 1)), axis=1)
    arr = np.insert(arr, 0, 0, axis=0)
    arr[0][-1] = 1
    return arr

def onehotSMILES(featurizer, df, property_s):
    onehot_ID_list = []
    onehot_error_ID_list = []
    onehot_label_list = []
    onehot_feature_sos_list = []
    for index, label, mol in zip(list(df.index), list(df[property_s]), list(df['mols'])):
        try:
            feature = featurizer.featurize([mol])[0]
            if feature.shape[0] > 120:
                continue
            onehot_ID_list.append(index)
            onehot_label_list.append(label)
        except:
            onehot_error_ID_list.append(index)

    onehot_feature_list = featurizer.featurize(df['mols'][onehot_ID_list])
    print('---Finish OneHotFeaturizer---')
    print('Length of whole data is {}. Length of valid data is {}'.format(len(df), len(onehot_feature_list)))

    # Add SOS
    for i, x in enumerate(onehot_feature_list):
        onehot_feature_sos_list.append(addSOS(x))
    onehot_feature_list = onehot_feature_sos_list
    print('---Finish Add SOS---')

    onehot_feature_list_np = np.asarray(onehot_feature_list, dtype=np.float32)
    onehot_label_np = np.asarray(onehot_label_list, dtype=np.float32)
    onehot_feature_torch = torch.from_numpy(onehot_feature_list_np)
    onehot_feature_label_torch = torch.from_numpy(onehot_label_np)
    return onehot_feature_torch, onehot_feature_label_torch

In [None]:
# rdkit load SMILES and structures
dset_path = 'dataset'
esol_dset_path = os.path.join(dset_path, 'ESOL')
esol_dset = os.path.join(esol_dset_path, 'delaney-processed.csv')
zinc_dset = os.path.join(dset_path, '250k_rndm_zinc_drugs_clean_3.csv')

# ESOL Dataset
esol_df = pd.DataFrame.from_csv(esol_dset)
esol_solu_df = esol_df[['smiles', 'measured log solubility in mols per litre']]
# esol_solu_df['smiles_with_token'] = 's' + esol_solu_df['smiles'] + 'e'
esol_solu_df['mols'] = esol_solu_df.apply(lambda x: Chem.MolFromSmiles(x.smiles), axis=1)

# VAE Dataset
zinc_df = pd.DataFrame.from_csv(zinc_dset, index_col=None)
zinc_df['mols'] = zinc_df.apply(lambda x: Chem.MolFromSmiles(x.smiles), axis=1)

# ----------One hot SMILES----------
featurizer_onehot = OneHotFeaturizer(zinc_charset, 120)

In [None]:
onehot_feature_torch, onehot_feature_label_torch = onehotSMILES(featurizer_onehot, esol_solu_df, 'measured log solubility in mols per litre')

In [None]:
class EncRNN(nn.Module):
    def __init__(self, input_dim, n_hidden, n_layers):
        super(AutoEncRNNPred, self).__init__()
        self.input_dim = input_dim
        self.n_hidden = n_hidden
        self.seq_len = seq_len

        self.enc_embedding = nn.Embedding(input_dim, n_hidden)
        self.enc_rnn = nn.RNN(input_size=n_hidden, hidden_size=n_hidden, num_layers=n_layers, batch_first=True)

    def forward(self, x, enc_hidden_state):
        embedded = self.enc_embedding(x).view(1, 1, -1)
        encoded, enc_hidden_state = self.enc_rnn(embedded, enc_hidden_state)
        return encoded, enc_hidden_state

In [None]:
class DecRNN(nn.Module):
    def __init__(self, input_dim, n_hidden, n_layers, seq_len, output_dim, pred_dim):
        super(AutoEncRNNPred, self).__init__()
        self.input_dim = input_dim
        self.n_hidden = n_hidden
        self.seq_len = seq_len
        self.output_dim = output_dim
        self.teacher_forcing_ratio = 0.8

        self.dec_embedding = nn.Embedding(input_dim, n_hidden)
        self.dec_rnn = nn.GRU(input_size=input_dim, hidden_size=n_hidden, num_layers=n_layers, batch_first=True)
        self.dec_out = nn.Linear(n_hidden, input_dim)

    def forward(self, encoded, hidden):
        embedded = encoded
        # embedded = self.dec_embedding(encoded).view(1, 1, -1)
        # embedded = F.relu(embedded)
        decoded, hidden = self.dec_rnn(embedded, hidden)
        decoded = self.dec_out(decoded)
#         decoded = self.log_softmax(decoded)
        return decoded, hidden

In [None]:
def train(model, train_loader, optimizer, loss_func_pred, loss_func_dec, epoch, model_status_dict):
    model.train()
    seq_len, self.n_hidden = model.seq_len, model.self.n_hidden
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data), Variable(target.reshape(-1, 1))
        # print(target.shape, target.dtype)
        
        # one-hot(batch, seq_len, feature_size) to int(batch, 1, feature_size)
        # _, data_feature_int = torch.max(data, 2)
        data_feature_int = np.argmax(data, axis=2)
        
        encoded = torch.zeros(data_feature_int.shape[0], seq_len, n_hidden)
        for k in range(seq_len):
            encoded_out, enc_hidden_state = self.encoder(x[:, k], enc_hidden_state)
            encoded[:, k] = encoded_out[:, 0]
        
        encode, decode, predict, h_state = model(data_feature_int, None, None, data)  # get output for every net
        # h_state = h_state.data  # repack the hidden state, break the connection from last iteration

        optimizer.zero_grad()  # clear gradients for next train
        loss_pred = 0
        # loss_pred = loss_func_pred(predict, target)
        loss_dec = loss_func_dec(decode.transpose(1, 2), np.argmax(data, axis=2))
        loss = loss_pred + loss_dec
#         print('decode.transpose', decode.transpose(1,2))
#         print('argmax', np.argmax(data, axis=2))
        loss.backward()  # backpropagation, compute gradients
        optimizer.step()  # apply gradients

        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader),
                loss.data.item()))
            model_status_dict['Epoch'].append(epoch)
            model_status_dict['Batch_idx'].append(batch_idx)
            model_status_dict['Loss'].append(loss.data.item())
    return model_status_dict


def test(model, test_loader):
    model.eval()
    real_data = None
    real_target = None
    pred_data = None
    pred_target =None
    with torch.no_grad():
        for (data, target) in test_loader:
            # one-hot(batch, seq_len, feature_size) to int(batch, 1, feature_size)
            _, data_feature_int = torch.max(data, 2)
            
            encode, decode, predict, h_state = model(data_feature_int, None, None, data)  # get output for every net

            data = data.data.numpy()
            target = target.data.numpy().reshape(-1, 1)
            real_data = data if real_data is None else np.concatenate((real_data, data))
            real_target = target if real_target is None else np.concatenate((real_target, target))

            y_data = decode.data.numpy()
#             y_target = predict.data.numpy()
            pred_data = y_data if pred_data is None else np.concatenate((pred_data, y_data))
#             pred_target = y_target if pred_target is None else np.concatenate((pred_target, y_target))
    return real_data, pred_data

In [None]:
class AutoEncRNNPred(nn.Module):
    def __init__(self, input_dim, n_hidden, n_layers, seq_len, output_dim, pred_dim):
        super(AutoEncRNNPred, self).__init__()
        self.input_dim = input_dim
        self.n_hidden = n_hidden
        self.seq_len = seq_len
        self.output_dim = output_dim
        self.index2word = {0: 'SOS', 1: 'EOS'}
        self.n_words = 2  # Count SOS and EOS
        self.teacher_forcing_ratio = 0.8

        # Encode(RNN)
        self.enc_embedding = nn.Embedding(input_dim, n_hidden)
        self.enc_rnn = nn.RNN(input_size=n_hidden, hidden_size=n_hidden, num_layers=n_layers, batch_first=True)

        # Predict
#         self.enc_linear = nn.Linear(seq_len * n_hidden, output_dim)
#         self.pred = nn.Linear(output_dim, pred_dim)

        # Decode(RNN)
        self.dec_embedding = nn.Embedding(input_dim, n_hidden)
        
        self.dec_linear = nn.Linear(seq_len * n_hidden, input_dim)
        self.dec_rnn = nn.GRU(input_size=input_dim, hidden_size=n_hidden, num_layers=n_layers, batch_first=True)
        self.dec_out = nn.Linear(n_hidden, input_dim)
#         self.log_softmax = nn.LogSoftmax(dim=1) # NllLoss
        
    def encoder(self, input_vec, hidden):
#         embedded = self.enc_embedding(input_vec)
        embedded = self.enc_embedding(input_vec).view(input_vec.shape[0], 1, -1)
        output = embedded
        rnn_out, hidden = self.enc_rnn(output, hidden)
        return rnn_out, hidden
        
    def decoder(self, encoded, hidden):
        embedded = encoded
        # embedded = self.dec_embedding(encoded).view(1, 1, -1)
        # embedded = F.relu(embedded)
        decoded, hidden = self.dec_rnn(embedded, hidden)
        decoded = self.dec_out(decoded)
#         decoded = self.log_softmax(decoded)
        return decoded, hidden

    def forward(self, x, enc_hidden_state, dec_hidden_state, targets):
        encoded = torch.zeros(x.shape[0], self.seq_len, self.n_hidden)
        for k in range(self.seq_len):
            encoded_out, enc_hidden_state = self.encoder(x[:, k], enc_hidden_state)
            encoded[:, k] = encoded_out[:, 0]
        
        decoded_output_on_k, predicted, dec_hidden_state = None, None, None
        
        dec_hidden_state = enc_hidden_state
        decoded_input = torch.zeros(targets.shape[0], self.seq_len, self.input_dim)
        decoded_input[:, 0, -1] = 1
        
#         encoded = encoded.view(-1, encoded.shape[1] * encoded.shape[2])
#         decoded_input = self.dec_linear(encoded)
#         decoded_input = decoded_input[:, None, :]
#         decoded_input = decoded_input.repeat(1, self.seq_len, 1)
    
        # Teacher Forcing
        use_teacher_forcing = True if random.random() < self.teacher_forcing_ratio else False
#         use_teacher_forcing = False
        for k in range(1, self.seq_len):
            decoded_output_on_k, dec_hidden_state = self.decoder(decoded_input, dec_hidden_state)
            if use_teacher_forcing:
                decoded_input = torch.zeros(targets.shape[0], self.seq_len, self.input_dim)
                decoded_input[:, :k, :] = targets[:, :k, :]
            else:
                decoded_input = decoded_output_on_k

        return encoded, decoded_output_on_k, predicted, dec_hidden_state