In [3]:
import argparse
import codecs
import pickle

import joblib
import keras
import numpy as np
import pandas as pd
from keras.layers import Input, LSTM, Dense
from keras.models import Model
from keras.utils import plot_model

from elapsedtimer import ElapsedTimer

In [8]:
# Names of the shared vocabulary state used across the notebook.
# NOTE: a `global` statement at module level has no effect and does not create
# these names; they only exist after vocab_generation() has assigned them.
global num_encoder_words            # size of the input vocabulary
global num_decoder_words            # size of the target vocabulary
global max_encoder_sequence_length  # longest input sentence, counted in space-separated tokens
global max_decoder_sequence_length  # longest target sentence, counted in space-separated tokens
global input_word_index             # input token -> integer index
global target_word_index            # target token -> integer index
global reverse_input_word_dict      # integer index -> input token
global reverse_target_word_dict     # integer index -> target token

In [12]:
# Names of the run-configuration values shared across the notebook.
# NOTE: a `global` statement at module level has no effect and does not create
# these names. Nothing in the visible cells assigns them — presumably a driver
# cell elsewhere does; TODO confirm they are set before the functions are called.
global path
global num_epochs    # read by train() as the `epochs` argument of model.fit
global batch_size    # read by train() as the `batch_size` argument of model.fit
global latent_dim    # read by model_enc_dec() as the LSTM unit count
global num_samples
global outdir        # prefix string concatenated with file names by model_enc_dec() and train()
global verbose
global mode          # process_input() builds decoder arrays only when this equals 'train'

In [13]:
def read_input_file(path, num_samples=10e13):
    """Read a tab-separated parallel corpus and collect word-level vocabularies.

    Every non-blank line must contain a source sentence and a target sentence
    separated by a tab character; additional tab-separated columns (for example
    an attribution column) are ignored. Each target sentence is wrapped with a
    leading tab marker and a trailing newline marker, each separated from the
    sentence by a single space, so that both markers survive as stand-alone
    tokens when the text is split on " ".

    Args:
        path: Path of the UTF-8 encoded corpus file.
        num_samples: Maximum number of sentence pairs to read. May be an int or
            a float; None or a negative value means "no limit".

    Returns:
        A tuple (input_texts, target_texts, input_words, target_words) where the
        first two are lists of sentences and the last two are sets of tokens.

    Raises:
        ValueError: If a non-blank line does not contain a tab separator.
    """
    with codecs.open(path, 'r', encoding='utf-8') as f:
        raw_lines = f.read().split('\n')

    # Drop blank lines (including the empty string left by a trailing newline)
    # and the '\r' left behind by Windows line endings, instead of blindly
    # discarding the last line, which may be real data.
    lines = [line.rstrip('\r') for line in raw_lines]
    lines = [line for line in lines if line.strip()]

    # The default limit is a float, which cannot be used as a slice bound, so
    # it is converted to int only when it actually restricts the line count.
    limit = len(lines)
    if num_samples is not None and 0 <= num_samples < limit:
        limit = int(num_samples)

    input_texts = []
    target_texts = []
    input_words = set()
    target_words = set()

    for line in lines[:limit]:
        fields = line.split('\t')
        if len(fields) < 2:
            raise ValueError(
                'expected "<source>\\t<target>" but found no tab in line: {!r}'.format(line))
        input_text = fields[0]
        # Space-separate the start/end markers so they become their own tokens.
        target_text = '\t ' + fields[1] + ' \n'
        input_texts.append(input_text)
        target_texts.append(target_text)
        input_words.update(input_text.split(" "))
        target_words.update(target_text.split(" "))

    return input_texts, target_texts, input_words, target_words

In [14]:
def vocab_generation(path, num_samples, verbose=True):
    """Build the word vocabularies for the corpus and publish them as module globals.

    Reads the corpus through read_input_file() and assigns the following
    module-level names: num_encoder_words, num_decoder_words,
    max_encoder_sequence_length, max_decoder_sequence_length,
    input_word_index, target_word_index, reverse_input_word_dict and
    reverse_target_word_dict.

    Args:
        path: Path of the corpus file, forwarded to read_input_file().
        num_samples: Maximum number of sentence pairs, forwarded to
            read_input_file().
        verbose: When true, print a short summary of the corpus statistics.

    Returns:
        None. All results are stored in the module globals listed above.
    """
    global num_encoder_words
    global num_decoder_words
    global max_encoder_sequence_length
    global max_decoder_sequence_length
    global input_word_index
    global target_word_index
    global reverse_input_word_dict
    global reverse_target_word_dict

    input_texts, target_texts, input_words, target_words = read_input_file(path, num_samples)

    # Sorting makes the token -> index assignment deterministic across runs.
    input_words = sorted(input_words)
    target_words = sorted(target_words)
    num_encoder_words = len(input_words)
    num_decoder_words = len(target_words)

    # Sequence lengths are counted in space-separated tokens; default=0 keeps
    # an empty corpus from raising inside max().
    max_encoder_sequence_length = max((len(txt.split(" ")) for txt in input_texts), default=0)
    max_decoder_sequence_length = max((len(txt.split(" ")) for txt in target_texts), default=0)

    if verbose:
        print("Number of samples: {} \n".format(len(input_texts)))
        print("Number of unique input tokens: {} \n".format(len(input_words)))
        print("Number of unique output tokens: {} \n".format(len(target_words)))
        print("Max sequence length for inputs: {} \n".format(max_encoder_sequence_length))
        print("Max sequence length for outputs: {} \n".format(max_decoder_sequence_length))

    input_word_index = {word: i for i, word in enumerate(input_words)}
    target_word_index = {word: i for i, word in enumerate(target_words)}
    reverse_input_word_dict = {i: word for word, i in input_word_index.items()}
    reverse_target_word_dict = {i: word for word, i in target_word_index.items()}

In [15]:
def _encode_tokens(text, word_index, max_length, verbose):
    """Return (position, vocabulary index) pairs for the known tokens of `text`."""
    encoded = []
    for position, word in enumerate(text.split(" ")):
        if position >= max_length:
            # Tokens past the array's time dimension cannot be stored.
            if verbose:
                print(f'sequence is longer than {max_length} tokens, truncated')
            break
        index = word_index.get(word)
        if index is None:
            if verbose:
                print(f'word {word} is encountered for the first time, skipped')
            continue
        encoded.append((position, index))
    return encoded


def process_input(input_texts, target_texts=None, verbose=True):
    """One-hot encode sentences into the arrays consumed by the seq2seq model.

    Reads the module-level names mode, max_encoder_sequence_length,
    max_decoder_sequence_length, num_encoder_words, num_decoder_words,
    input_word_index and target_word_index. Sentences are tokenised by
    splitting on a single space. Tokens missing from the vocabulary are
    skipped (their time step stays all-zero) and tokens beyond the maximum
    sequence length are dropped.

    Args:
        input_texts: Sequence of source sentences.
        target_texts: Sequence of target sentences; required when the global
            `mode` equals 'train', ignored otherwise.
        verbose: When true, print a message for every skipped token or
            truncated sentence.

    Returns:
        A 5-tuple (encoder_input_data, decoder_input_data, decoder_target_data,
        input_texts_array, target_texts_array). In 'train' mode all entries are
        arrays and decoder_target_data is decoder_input_data shifted one time
        step to the left. In any other mode the decoder arrays and the target
        text array are None.

    Raises:
        TypeError: If `mode` is 'train' and target_texts is None.
    """
    num_records = len(input_texts)
    encoder_input_data = np.zeros(
        (num_records, max_encoder_sequence_length, num_encoder_words), dtype='float32')
    for i, input_text in enumerate(input_texts):
        for t, index in _encode_tokens(input_text, input_word_index,
                                       max_encoder_sequence_length, verbose):
            encoder_input_data[i, t, index] = 1

    if mode != 'train':
        # Inference only needs the encoder input; skip allocating decoder arrays.
        return encoder_input_data, None, None, np.array(input_texts), None

    if target_texts is None:
        raise TypeError("target_texts is required when mode == 'train'")

    decoder_shape = (num_records, max_decoder_sequence_length, num_decoder_words)
    decoder_input_data = np.zeros(decoder_shape, dtype='float32')
    decoder_target_data = np.zeros(decoder_shape, dtype='float32')
    # zip with range() stops at the shorter of the two sequences, as pairing
    # input and target sentences did before.
    for i, target_text in zip(range(num_records), target_texts):
        for t, index in _encode_tokens(target_text, target_word_index,
                                       max_decoder_sequence_length, verbose):
            decoder_input_data[i, t, index] = 1
            if t > 0:
                # The target at step t-1 is the decoder input token at step t.
                decoder_target_data[i, t - 1, index] = 1

    return (encoder_input_data, decoder_input_data, decoder_target_data,
            np.array(input_texts), np.array(target_texts))

In [16]:
def model_enc_dec():
    """Build the seq2seq training model plus the encoder/decoder inference models.

    Reads the module-level names num_encoder_words, num_decoder_words,
    latent_dim and outdir. The inference models reuse the very same LSTM and
    Dense layer objects as the training model, so they share its weights.
    A diagram of each model is written as a PNG file whose path is `outdir`
    concatenated with the file name.

    Returns:
        A tuple (model, encoder_model, decoder_model):
        model maps [encoder_input, decoder_input] to the per-step softmax
        output; encoder_model maps encoder_input to the LSTM states [h, c];
        decoder_model maps [decoder_input, h, c] to [softmax output, h, c].
    """
    global num_encoder_words
    global num_decoder_words
    global latent_dim
    global outdir

    # Encoder: only the final LSTM states are kept; the output is discarded.
    encoder_input = Input(shape=(None, num_encoder_words), name='encoder_input')
    encoder = LSTM(latent_dim, return_state=True, name='encoder')
    encoder_out, state_h, state_c = encoder(encoder_input)
    encoder_states = [state_h, state_c]

    # Training decoder: starts from the encoder states and emits a softmax
    # over the target vocabulary at every time step.
    decoder_input = Input(shape=(None, num_decoder_words), name='decoder_input')
    decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True, name='decoder_lstm')
    decoder_out, _, _ = decoder_lstm(decoder_input, initial_state=encoder_states)
    decoder_dense = Dense(num_decoder_words, activation='softmax', name='decoder_dense')
    decoder_out = decoder_dense(decoder_out)
    print(np.shape(decoder_out))

    model = Model([encoder_input, decoder_input], decoder_out)
    encoder_model = Model(encoder_input, encoder_states)

    # Inference decoder: the states are fed in explicitly and returned, so the
    # caller can run the decoder one step at a time.
    decoder_input_h = Input(shape=(latent_dim, ))
    decoder_input_c = Input(shape=(latent_dim, ))
    decoder_input_state = [decoder_input_h, decoder_input_c]
    inference_out, decoder_out_h, decoder_out_c = decoder_lstm(
        decoder_input, initial_state=decoder_input_state)
    inference_out = decoder_dense(inference_out)
    decoder_out_state = [decoder_out_h, decoder_out_c]
    decoder_model = Model(inputs=[decoder_input] + decoder_input_state,
                          outputs=[inference_out] + decoder_out_state)

    plot_model(model, show_shapes=True, to_file=outdir + 'encoder_decoder_training_model.png')
    plot_model(encoder_model, show_shapes=True, to_file=outdir + 'encoder_model.png')
    plot_model(decoder_model, show_shapes=True, to_file=outdir + 'decoder_model.png')

    return model, encoder_model, decoder_model

In [19]:
def train(encoder_input_data, decoder_input_data, decoder_target_data):
    """Build, compile, fit and save the encoder-decoder training model.

    Reads the module-level names batch_size, num_epochs and outdir. The models
    come from model_enc_dec(); the training model is compiled with the
    'rmsprop' optimizer and categorical cross-entropy loss, fitted with a 20%
    validation split, and saved to `outdir` concatenated with
    'eng_to_french_dumm.h5'.

    Args:
        encoder_input_data: Encoder input passed to fit() as the first input.
        decoder_input_data: Decoder input passed to fit() as the second input.
        decoder_target_data: Targets passed to fit().

    Returns:
        The tuple (model, encoder_model, decoder_model) from model_enc_dec(),
        with the first element trained.
    """
    print("Training...")
    training_model, inference_encoder, inference_decoder = model_enc_dec()
    training_model.compile(optimizer='rmsprop', loss='categorical_crossentropy')

    fit_inputs = [encoder_input_data, decoder_input_data]
    training_model.fit(
        fit_inputs,
        decoder_target_data,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.2,
    )

    saved_model_path = outdir + 'eng_to_french_dumm.h5'
    training_model.save(saved_model_path)
    return training_model, inference_encoder, inference_decoder

In [20]:
def train_test_split(num_recs, train_frac=0.8):
    """Randomly partition the record indices 0..num_recs-1 into train and test sets.

    The indices are shuffled with NumPy's global random state, so seed it with
    np.random.seed() beforehand for a reproducible split.

    Args:
        num_recs: Total number of records.
        train_frac: Fraction of the records assigned to the training set; the
            training count is int(num_recs * train_frac).

    Returns:
        A tuple (train_indices, test_indices) of disjoint integer arrays that
        together contain every index exactly once.
    """
    rec_indices = np.arange(num_recs)
    np.random.shuffle(rec_indices)
    # Honour the caller's fraction instead of a hard-coded 0.8.
    train_count = int(num_recs * train_frac)
    return rec_indices[:train_count], rec_indices[train_count:]

In [21]:
def decode_sequence(input_sequence, encoder_model, decoder_model):
    global num_encoder_words
    global num_decoder_words
    global reverse_target_word_dict
    global max_decoder_sequence_length
    states_value = encoder_model.predict(input_sequence)
    target_sequence = np.zeros((1,1,num_encoder_words))
    target_sequence[0,0,target_word_index['\t']] = 1
    stop_condition = False
    decoded_sentence = ''
    
    while not stop_condition:
        output_word, h, c = decoder_model.predict([target_sequence] + state_value)
        sampled_word_index = np.argmax(output_word[0, -1:])
        sampled_char = reverse_target_word_dict[sampled_word_index]
        decoded_sentence = decoded_sentence + ' ' + sampled_char
        if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_sequence_length):
            stop_condition = True
        target_sequence = np.zeros((1,1,num_decoder_words))
        target_sequence[0,0,sampled_word_index] = 1
        state_value = [h, c]
        