In [None]:
# Read the raw ChEMBL .smi file into a list with one entry per line
# (every entry keeps its line terminator).
with open("/kaggle/input/chembl22/chembl_22_clean_1576904_sorted_std_final.smi") as smi_file:
    data = smi_file.readlines()

In [None]:
from concurrent.futures import ThreadPoolExecutor
import re
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import tensorflow as tf
from keras.layers import TextVectorization
from keras.layers import Input, Embedding, LSTM, Attention, Concatenate, Dense
from keras.models import Model,Sequential
from keras.utils import to_categorical

In [None]:
# Display the TensorFlow version this notebook runs against.
tf.__version__

In [None]:
# Display the CPU count reported by the OS; the same value sizes the thread pools below.
os.cpu_count()

In [None]:
def preprocess_molecule(smiles_str):
    """Turn one raw line of the .smi file into a framed molecule string.

    The text before the first tab is kept, prefixed with the "<" start token
    and terminated with exactly one "\n" end token.

    Args:
        smiles_str: One line of the input file, with or without its trailing
            line terminator and with or without tab-separated extra columns.

    Returns:
        "<" + first column + "\n".
    """
    # Drop the line terminator first: a line without a tab would otherwise keep
    # its own "\n" and end up with two end tokens.
    first_column = smiles_str.rstrip("\r\n").split("\t", 1)[0]
    return "<" + first_column + "\n"

In [None]:
# Frame every raw line as "<" + SMILES + "\n".
# A plain comprehension replaces the thread pool: the work is pure-Python string
# handling, so threads cannot run it in parallel under the GIL and only add
# per-item scheduling overhead. The resulting list is the same.
# NOTE: this rebinds `data`; re-running the cell would frame the strings twice.
data = [preprocess_molecule(raw_line) for raw_line in data]

In [None]:
def convert_to_source_str(molecule_str):
    """Return `molecule_str` with any newline characters removed from both ends.

    For a framed molecule ("<" + SMILES + "\n") this drops the end token and
    keeps the "<" start token.
    """
    end_token = "\n"
    source_str = molecule_str.strip(end_token)
    return source_str

In [None]:
# Source strings: each framed molecule without its trailing "\n".
# A comprehension is used instead of a thread pool because the work is pure-Python
# string handling (no parallelism under the GIL); the resulting list is the same.
X = [convert_to_source_str(molecule) for molecule in data]

In [None]:
def convert_to_dst_str(molecule_str):
    """Return `molecule_str` with every "<" character removed from both ends.

    For a framed molecule ("<" + SMILES + "\n") this drops the start token and
    keeps the trailing "\n" end token.
    """
    start_token = "<"
    destination_str = molecule_str.strip(start_token)
    return destination_str

In [None]:
# Target strings: each framed molecule without its leading "<".
# A comprehension is used instead of a thread pool because the work is pure-Python
# string handling (no parallelism under the GIL); the resulting list is the same.
Y = [convert_to_dst_str(molecule) for molecule in data]

In [None]:
# Load the table of molecule strings; later cells read its "X" and "Y" columns.
# NOTE(review): this rebinds `data` from the list of framed lines to a DataFrame.
# Presumably the CSV was exported from the X / Y lists built above in an earlier
# session - confirm, since nothing in this notebook writes it.
data = pd.read_csv("/kaggle/input/smiles-molecule-data-v3/smiles_molecule_data_V3.csv")

In [None]:
# Every distinct character that occurs in the source (X) or target (Y) strings.
# The characters are sorted so that the token ids derived from this list are the
# same in every kernel session: the iteration order of a set of strings changes
# between interpreter runs (hash randomisation), which would silently remap the
# ids and make a model saved in one session unusable in the next.
vocabulary = sorted(set("".join(data["X"])) | set("".join(data["Y"])))

In [None]:
# Length in characters of every source string, in row order.
len_list = [len(molecule) for molecule in data["X"]]

In [None]:
# Histogram of source-string lengths, with one bin per distinct length value.
plt.hist(len_list, bins=len(set(len_list)))
plt.xlabel("Lengths of Different Molecules")
plt.ylabel("Frequency")

In [None]:
# Length cut-off: only rows whose "X" string has at most this many characters are kept.
# NOTE(review): 52 looks hand-picked from the length histogram - confirm.
len_hyp_param = 52

In [None]:
# Keep only the rows whose source string is no longer than the cut-off
# (boolean row mask; the original row index labels are preserved).
filtered_data = data[np.asarray(len_list) <= len_hyp_param]

In [None]:
# Persist the split: the first 1,000,000 filtered rows go to train.csv, the rest to test.csv.
filtered_data.iloc[:1000000].to_csv("train.csv", index=False)
filtered_data.iloc[1000000:].to_csv("test.csv", index=False)

In [None]:
# Longest source / target string (in characters) among the filtered rows;
# used as the padded sequence lengths of the two vectorization layers.
max_input_sequence_len = max(len(source_str) for source_str in filtered_data["X"])
max_output_sequence_len = max(len(target_str) for target_str in filtered_data["Y"])

In [None]:
# Character-level tokenizer for the source strings: no text standardisation, one
# token per character, every sequence padded/truncated to the longest source string.
# NOTE(review): the "+2" presumably reserves ids for the padding and unknown
# tokens the layer adds on top of `vocabulary` - confirm.
input_text_vectorization_layer = TextVectorization(
    vocabulary=vocabulary,
    max_tokens=len(vocabulary) + 2,
    split="character",
    standardize=None,
    output_sequence_length=max_input_sequence_len,
)

In [None]:
# Character-level tokenizer for the target strings: no text standardisation, one
# token per character, every sequence padded/truncated to the longest target string.
# NOTE(review): the "+2" presumably reserves ids for the padding and unknown
# tokens the layer adds on top of `vocabulary` - confirm.
output_text_vectorization_layer = TextVectorization(
    vocabulary=vocabulary,
    max_tokens=len(vocabulary) + 2,
    split="character",
    standardize=None,
    output_sequence_length=max_output_sequence_len,
)

In [None]:
# Token list as held by the input vectorization layer; generate_molecules() uses
# the position of "<" in this list as the start-token id.
input_vocabulary = input_text_vectorization_layer.get_vocabulary()

In [None]:
# Token list as held by the output vectorization layer; generate_molecules()
# indexes it with the predicted class id to recover the generated character.
output_vocabulary = output_text_vectorization_layer.get_vocabulary()

In [None]:
# Tokenize every filtered source string and materialise the result as a NumPy array.
# NOTE: this rebinds `X`, which earlier held the list of source strings.
X = input_text_vectorization_layer(filtered_data["X"]).numpy()

In [None]:
# Tokenize every filtered target string and materialise the result as a NumPy array.
# NOTE: this rebinds `Y`, which earlier held the list of target strings.
Y = output_text_vectorization_layer(filtered_data["Y"]).numpy()

In [None]:
# Positional split matching the CSV export above: the first 1,000,000 rows are
# used for training, every remaining row for validation.
X_train = X[:1000000]
Y_train = Y[:1000000]

X_cv = X[1000000:]
Y_cv = Y[1000000:]

In [None]:
def encoder_decoder_with_attn_mech():
    """Build the training-time encoder-decoder model with an attention layer.

    Encoder: token ids -> embedding -> LSTM that returns its full output
    sequence plus its final hidden and cell states.
    Decoder: token ids -> embedding -> LSTM initialised with the encoder's
    final states. An Attention layer is called with
    [decoder outputs, encoder outputs]; its result is concatenated with the
    decoder outputs and fed to a per-timestep softmax layer.

    Reads the module-level `vocabulary`, `max_input_sequence_len` and
    `max_output_sequence_len`.

    Returns:
        A Model taking [encoder token ids, decoder token ids] and producing a
        softmax distribution over len(vocabulary)+2 classes at every decoder
        timestep.
    """
    # Sizes shared by every layer: the class count and half of it for the
    # embedding width / LSTM units.
    vocab_size = len(vocabulary) + 2
    hidden_units = vocab_size // 2

    # ---- encoder ----
    enc_input = Input(shape=(None,), name="input_to_encoder")
    enc_embedding_layer = Embedding(
        input_dim=vocab_size,
        output_dim=hidden_units,
        input_length=max_input_sequence_len,
        name="encoder_embedding_layer",
    )
    enc_embedding = enc_embedding_layer(enc_input)
    enc_lstm_layer = LSTM(
        units=hidden_units,
        return_state=True,
        return_sequences=True,
        name="encoder_lstm_layer",
    )
    enc_lstm_output, enc_last_hidden_state, enc_last_cell_state = enc_lstm_layer(enc_embedding)

    # ---- decoder ----
    dec_input = Input(shape=(None,), name="input_to_decoder")
    dec_embedding_layer = Embedding(
        input_dim=vocab_size,
        output_dim=hidden_units,
        input_length=max_output_sequence_len,
        name="decoder_embedding_layer",
    )
    dec_embedding = dec_embedding_layer(dec_input)
    dec_lstm_layer = LSTM(
        units=hidden_units,
        return_sequences=True,
        return_state=True,
        name="decoder_lstm_layer",
    )
    # The decoder starts from the encoder's final hidden / cell states.
    encoder_states = [enc_last_hidden_state, enc_last_cell_state]
    dec_lstm_output, _, _ = dec_lstm_layer(inputs=dec_embedding, initial_state=encoder_states)

    # ---- attention over the encoder outputs, then the output projection ----
    dec_enc_attn_seq = Attention()([dec_lstm_output, enc_lstm_output])
    dec_dense_input = Concatenate()([dec_lstm_output, dec_enc_attn_seq])
    output_layer = Dense(units=vocab_size, activation="softmax", name="decoder_output")
    dec_output = output_layer(dec_dense_input)

    return Model(inputs=[enc_input, dec_input], outputs=dec_output)

In [None]:
# Instantiate the (untrained) encoder-decoder model.
seq2seq_encoder_decoder = encoder_decoder_with_attn_mech()

In [None]:
# One-hot encode the integer training targets over len(vocabulary)+2 classes, the
# format the categorical_crossentropy loss set at compile time works with.
# NOTE(review): rebinding Y_train makes this cell non-idempotent - running it a
# second time would one-hot encode the already encoded array.
# NOTE(review): the result gains a trailing axis of size len(vocabulary)+2, which
# may be very large in memory; integer targets with a sparse loss would avoid it.
Y_train = to_categorical(Y_train,num_classes=len(vocabulary)+2)

In [None]:
# One-hot encode the integer validation targets over len(vocabulary)+2 classes.
# NOTE(review): rebinding Y_cv makes this cell non-idempotent - running it a
# second time would one-hot encode the already encoded array.
Y_cv = to_categorical(Y_cv,num_classes=len(vocabulary)+2)

In [None]:
def training_data_generator(mb_size,epochs):
    """Yield training mini-batches for `epochs` passes over the training arrays.

    Every item is `([encoder_ids, decoder_ids], targets)`. Both model inputs are
    the same slice of the module-level `X_train`; the targets are the matching
    slice of `Y_train`. Only full batches are produced: trailing rows that do
    not fill a batch of `mb_size` are skipped.

    Args:
        mb_size: Number of rows per mini-batch.
        epochs: Number of complete passes before the generator is exhausted.
    """
    for _pass in range(epochs):
        full_batches = X_train.shape[0] // mb_size
        for start in range(0, full_batches * mb_size, mb_size):
            stop = start + mb_size
            encoder_batch = X_train[start:stop]
            decoder_batch = X_train[start:stop]
            yield [encoder_batch, decoder_batch], Y_train[start:stop]

In [None]:
def cv_data_generator(mb_size,epochs):
    """Yield validation mini-batches for `epochs` passes over the validation arrays.

    Every item is `([encoder_ids, decoder_ids], targets)`. Both model inputs are
    the same slice of the module-level `X_cv`; the targets are the matching
    slice of `Y_cv`. Only full batches are produced: trailing rows that do not
    fill a batch of `mb_size` are skipped.

    Args:
        mb_size: Number of rows per mini-batch.
        epochs: Number of complete passes before the generator is exhausted.
    """
    for _pass in range(epochs):
        full_batches = X_cv.shape[0] // mb_size
        for start in range(0, full_batches * mb_size, mb_size):
            stop = start + mb_size
            encoder_batch = X_cv[start:stop]
            decoder_batch = X_cv[start:stop]
            yield [encoder_batch, decoder_batch], Y_cv[start:stop]

In [None]:
# Configure training: categorical cross-entropy on the one-hot targets, default optimizer.
# The metric must be the lowercase string "accuracy": Keras maps that shortcut to the
# accuracy variant matching the loss (categorical accuracy here). The capitalised
# "Accuracy" resolves to the exact-match Accuracy metric, which compares the softmax
# probabilities element-wise with the one-hot labels and reports a meaningless value.
seq2seq_encoder_decoder.compile(loss="categorical_crossentropy", metrics=["accuracy"])

In [None]:
# Train for 25 epochs from the two generators.
# Per epoch: 1000 steps x 1000 rows = 1,000,000 training rows and
# 18 steps x 1711 rows = 30,798 validation rows.
# NOTE(review): 1711 / 18 were presumably chosen to cover X_cv exactly - confirm
# against X_cv.shape[0].
# Both generators are exhausted after 25 passes, so the `epochs` value passed to
# them and the `epochs` argument of fit() have to be kept in sync.
seq2seq_encoder_decoder.fit(training_data_generator(1000,25),epochs=25,steps_per_epoch=1000,
                           validation_data=cv_data_generator(1711,25),validation_steps=18)

In [None]:
# Persist the trained model to the current working directory.
seq2seq_encoder_decoder.save("enc_dec_drug_molecule_gen.keras")

In [None]:
# Reload the trained model; the inference helpers below rebuild their graphs from its layers.
# NOTE(review): the model was saved with a relative path but is loaded from an absolute
# one - this assumes the working directory is /kaggle/working; confirm.
our_model = tf.keras.models.load_model("/kaggle/working/enc_dec_drug_molecule_gen.keras")

In [None]:
def inference_encoder():
    """Build the inference-time encoder from the layers of the trained `our_model`.

    The trained encoder embedding and LSTM are fetched by the names they were
    given when the training model was built, rather than by their position in
    `our_model.layers`, because that position depends on how Keras happens to
    order the layers of the graph.

    Returns:
        A Model mapping encoder token ids to
        [encoder output sequence, final hidden state, final cell state].
    """
    enc_input = our_model.input[0]
    enc_embedding = our_model.get_layer("encoder_embedding_layer")(enc_input)
    enc_lstm_layer = our_model.get_layer("encoder_lstm_layer")
    enc_lstm_output, enc_last_hidden_state, enc_last_cell_state = enc_lstm_layer(inputs=enc_embedding)

    inf_enc = Model(inputs=enc_input,
                    outputs=[enc_lstm_output, enc_last_hidden_state, enc_last_cell_state])
    return inf_enc

In [None]:
def inference_decoder():
    """Build the inference-time decoder from the layers of the trained `our_model`.

    Unlike the training graph, the decoder LSTM's initial states and the encoder
    output sequence attended over are explicit model inputs, and the LSTM's
    final states are returned so the caller can feed them back in on the next
    decoding step.

    Trained layers are looked up by name (or, for the Attention / Concatenate
    layers that were created without a name, by class) rather than by their
    position in `our_model.layers`, because that position depends on how Keras
    happens to order the layers of the graph.

    Returns:
        A Model with inputs
        [decoder token ids, encoder output sequence, initial hidden state, initial cell state]
        and outputs
        [per-timestep softmax output, final hidden state, final cell state].

    Raises:
        ValueError: If `our_model` does not contain exactly one Attention or
            exactly one Concatenate layer.
    """
    hidden_units = (len(vocabulary) + 2) // 2

    def _sole_layer_of_class(class_name):
        """Return the single layer of `our_model` whose class is named `class_name`."""
        found = [layer for layer in our_model.layers if type(layer).__name__ == class_name]
        if len(found) != 1:
            raise ValueError(
                f"expected exactly one {class_name} layer in the trained model, found {len(found)}"
            )
        return found[0]

    dec_input = our_model.input[1]
    # Stand-in for the encoder output sequence produced by the inference encoder.
    another_dec_input = Input(shape=(max_input_sequence_len, hidden_units))
    dec_initial_hidden_state = Input(shape=(hidden_units,))
    dec_initial_cell_state = Input(shape=(hidden_units,))

    dec_embedding = our_model.get_layer("decoder_embedding_layer")(dec_input)
    dec_lstm_layer = our_model.get_layer("decoder_lstm_layer")
    dec_lstm_output, dec_last_hidden_state, dec_last_cell_state = dec_lstm_layer(
        inputs=dec_embedding,
        initial_state=[dec_initial_hidden_state, dec_initial_cell_state])

    dec_enc_attn_seq = _sole_layer_of_class("Attention")([dec_lstm_output, another_dec_input])
    dec_dense_input = _sole_layer_of_class("Concatenate")([dec_lstm_output, dec_enc_attn_seq])
    dec_output = our_model.get_layer("decoder_output")(dec_dense_input)

    inf_dec = Model(inputs=[dec_input, another_dec_input, dec_initial_hidden_state, dec_initial_cell_state],
                    outputs=[dec_output, dec_last_hidden_state, dec_last_cell_state])

    return inf_dec

In [None]:
def generate_molecules(enc_inp_sequence,batch_size):
    """Greedily decode one molecule string for an encoded input sequence.

    The input is run through the inference encoder once; the inference decoder
    is then called one token at a time, starting from the "<" token and feeding
    the most probable token and the returned LSTM states back in. Decoding
    stops when "\n" is produced or after `max_output_sequence_len` steps.

    Only the prediction for the first row of the batch is followed and
    returned; the same token is fed to every row on the next step.

    Args:
        enc_inp_sequence: Token-id array for the encoder, one row per sequence.
        batch_size: Number of rows in `enc_inp_sequence`.

    Returns:
        The generated characters joined into one string (including the
        terminating "\n" when one was produced).
    """
    inf_enc = inference_encoder()
    enc_output, *states = inf_enc.predict(enc_inp_sequence)

    # One start token per batch row, one timestep each: shape (batch_size, 1).
    gen_sequence = np.full((batch_size, 1), input_vocabulary.index("<"))

    inf_dec = inference_decoder()

    generated_chars = []
    # Bound the loop by the number of decoding steps, not by the length of the
    # generated text: a predicted token may be the empty padding string or a
    # multi-character token, so a length-equality test may never become true.
    for _step in range(max_output_sequence_len):
        dec_output, dec_last_hidden_state, dec_last_cell_state = inf_dec.predict(
            [gen_sequence, enc_output] + states)

        nxt_gen_char_idx = int(np.argmax(dec_output[0, -1, :]))
        nxt_gen_char = output_vocabulary[nxt_gen_char_idx]
        generated_chars.append(nxt_gen_char)

        if nxt_gen_char == "\n":
            break

        # Feed the chosen token and the updated LSTM states into the next step.
        gen_sequence = np.full((batch_size, 1), nxt_gen_char_idx)
        states = [dec_last_hidden_state, dec_last_cell_state]

    return "".join(generated_chars)

In [None]:
# Decode a molecule for the first validation sequence, passed as a batch of one row.
generate_molecules(X_cv[0][np.newaxis, :], 1)