In [1]:
from os.path import join

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_addons as tfa

from sklearn.model_selection import train_test_split
import tensorflow.keras.backend as K

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, Embedding, Bidirectional, LSTMCell
from tensorflow.keras.losses import SparseCategoricalCrossentropy

from generate_uncorrect_sample import generate_misspell_sample

### Define helper functions and the text-encoding class

In [2]:
def loss_fn(y_pred, y):
    """Padding-masked sparse categorical cross-entropy.

    Positions where ``y == 0`` are treated as padding and contribute nothing
    to the loss. The summed loss is divided by the number of non-padding
    positions, so the value does not shrink as the amount of padding in a
    batch grows (a plain mean over all positions would dilute it).

    Args:
        y_pred: Unnormalised logits; the loss is built with ``from_logits=True``.
        y: Integer target ids, with 0 marking padding.

    Returns:
        Scalar tensor: mean cross-entropy over non-padding positions, or 0
        when every position is padding.
    """
    log_loss = SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    per_token_loss = log_loss(y_true=y, y_pred=y_pred)

    # 1.0 for real tokens, 0.0 for padding (id 0).
    mask = tf.cast(tf.math.not_equal(y, 0), dtype=per_token_loss.dtype)

    # divide_no_nan returns 0 instead of NaN when the batch holds only padding.
    return tf.math.divide_no_nan(tf.reduce_sum(mask * per_token_loss),
                                 tf.reduce_sum(mask))

def generate_pair_samples(w):
    """Pair every generated misspelling of ``w`` with ``w`` itself.

    Args:
        w: The correctly spelled text handed to ``generate_misspell_sample``
            (called with ``max_edit_distance=2``).

    Returns:
        A list of ``(misspelled, w)`` tuples, one per generated sample.
    """
    return [(misspelled, w)
            for misspelled in generate_misspell_sample(w, max_edit_distance=2)]

In [3]:
class Text2Seq(object):
    """Character-level converter between strings and zero-padded id arrays.

    Ids are assigned starting at 1 in sorted character order, so the mapping
    is identical in every Python session (required for saved weights to stay
    meaningful). Id 0 is reserved for padding.
    """

    def __init__(self, charset,
                 start_token='<s>',
                 end_token='<e>',
                 unknown_token='<unk>'):
        """Build the character <-> id lookup tables.

        Args:
            charset: Either an iterable of characters, or a ``str`` path to a
                text file holding one character per line.
            start_token: Symbol placed before a sequence when padding with
                start/end markers.
            end_token: Symbol placed after a sequence when padding with
                start/end markers.
            unknown_token: Symbol used for characters missing from the charset.
        """
        self.start_token = start_token
        self.end_token = end_token
        self.unk_token = unknown_token

        if isinstance(charset, str):
            # Read-only access is enough; be explicit about the encoding.
            with open(charset, 'r', encoding='utf-8') as f:
                chars = set(f.read().split('\n'))
            # A trailing newline yields an empty entry that is not a character.
            chars.discard('')
        else:
            # Copy so the caller's collection is never modified.
            chars = set(charset)

        chars.update([' ', self.start_token, self.end_token, self.unk_token])
        self.charset = chars
        self.charset_size = len(self.charset)

        # Sorting makes the id assignment deterministic (set order is not).
        self.char2id = {c: i for i, c in enumerate(sorted(self.charset), start=1)}
        self.id2char = {i: c for c, i in self.char2id.items()}

    def _encode(self, word, max_len, pad_start_end):
        """Encode one string as a list of ids, right-padded with zeros.

        With ``pad_start_end`` the start/end ids are added around the word and
        the result is padded to ``max_len + 2``; otherwise to ``max_len``.
        """
        unk_id = self.char2id[self.unk_token]
        ids = [self.char2id.get(c, unk_id) for c in word]
        if pad_start_end:
            ids = [self.char2id[self.start_token]] + ids + [self.char2id[self.end_token]]
            target_len = max_len + 2
        else:
            target_len = max_len
        # A non-positive multiplier adds nothing, so longer words pass through.
        ids += (target_len - len(ids)) * [0]
        return ids

    def fit_on_texts(self, texts, pad_start_end=False):
        """Encode a collection of strings into a 2-D integer array.

        Every row is padded to the length of the longest string in ``texts``
        (plus 2 when ``pad_start_end`` is set).

        Returns:
            ``np.ndarray`` of dtype ``int32``. A 32-bit type is used because
            ids can exceed 127 for larger character sets.
        """
        max_len = self.get_max_seq_len(texts)
        rows = [self._encode(word, max_len, pad_start_end) for word in texts]
        return np.array(rows, dtype=np.int32)

    @staticmethod
    def get_max_seq_len(texts):
        """Return the length of the longest string in ``texts`` (0 if empty)."""
        return max((len(word) for word in texts), default=0)

    def sequence_to_text(self, arr, remove_endtoken=False):
        """Decode rows of ids back into strings.

        Zeros (padding) are skipped. With ``remove_endtoken`` decoding of a row
        stops at the first end-token id. Ids not in the table decode to the
        unknown token.

        Returns:
            A list with one decoded string per row of ``arr``.
        """
        end_id = self.char2id.get(self.end_token)

        def _inside(row):
            word = []
            for i in row:
                if i != 0:
                    if remove_endtoken and i == end_id:
                        break
                    word.append(self.id2char.get(i, self.unk_token))
            return ''.join(word)

        return [_inside(row) for row in arr]

### Create dataset for training

In [13]:
# One seed drives both the row shuffle and the train/test split, so that
# Restart & Run All reproduces the same partition.
SEED = 123

pairs = []
correct_words = ['có thể', 'thế giới', 'con người', 'không thể', 'tất cả', 'chúng ta']
for w in correct_words:
    pairs.extend(generate_pair_samples(w))
df = pd.DataFrame(pairs, columns=['misspell', 'correct']).sample(frac=1, random_state=SEED)

# Every character seen in either column; sorted so the list is stable between runs.
charset = sorted(set(''.join(df.misspell.values + df.correct.values)))

text2seq = Text2Seq(charset)
# Inputs are the misspellings; targets are the corrections wrapped in start/end tokens.
X_train, X_test, Y_train, Y_test = train_test_split(
    text2seq.fit_on_texts(df.misspell.values),
    text2seq.fit_on_texts(df.correct.values, pad_start_end=True),
    test_size=0.1,
    random_state=SEED)

In [16]:
# Batching configuration.
BATCH_SIZE = 4
BUFFER_SIZE = len(X_train)
steps_per_epoch = BUFFER_SIZE // BATCH_SIZE

# Layer sizes.
embedding_dims = 64
rnn_units = 64
dense_units = 64

# Padded sequence lengths of the encoder inputs and decoder targets.
Tx = X_train.shape[1]
Ty = Y_train.shape[1]

# +1 because id 0 is the padding id and is not counted in charset_size.
input_vocab_size = text2seq.charset_size + 1
output_vocab_size = input_vocab_size

# drop_remainder=True keeps every batch at exactly BATCH_SIZE rows.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, Y_train))
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
)
test_dataset = (
    tf.data.Dataset.from_tensor_slices((X_test, Y_test))
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
)

### Define model

In [17]:
class Encoder(Model):
    """Encoder: embedding, then a bidirectional LSTM, then a second LSTM."""

    def __init__(self, input_vocab_size=None, embedding_dims=128, rnn_units=64):
        """Create the encoder layers.

        Args:
            input_vocab_size: Number of distinct input ids for the embedding.
            embedding_dims: Size of each embedding vector.
            rnn_units: Units in each LSTM.
        """
        super().__init__()
        self.encoder_embedding = Embedding(input_vocab_size, embedding_dims)
        first_lstm = LSTM(rnn_units, return_sequences=True, dropout=0.2)
        self.encoder_birnn = Bidirectional(first_lstm)
        self.encoder_stackrnn = LSTM(rnn_units, return_sequences=True, return_state=True)

    def call(self, inputs):
        """Run ``inputs`` through the three layers in order.

        Returns:
            Whatever the final LSTM returns; it is built with
            ``return_sequences=True`` and ``return_state=True``.
        """
        embedded = self.encoder_embedding(inputs)
        contextual = self.encoder_birnn(embedded)
        return self.encoder_stackrnn(contextual)

class Decoder(Model):
    """Attention decoder with a teacher-forced path and a beam-search path.

    Which path ``call`` takes is decided by the ``training`` flag stored at
    construction time. Both paths share the same embedding, attention-wrapped
    LSTM cell and output dense layer.
    """
    
    def __init__(self, 
                 output_vocab_size=None, 
                 embedding_dims=128, 
                 rnn_units=64, 
                 dense_units=64, 
                 batch_size=128,
                 encoder_max_seq_len=None,
                 decoder_max_seq_len=None,
                 start_token=None,
                 end_token=None,
                 beam_width=5,
                 training=True):
        """Create the layers and the attention-wrapped decoder.

        Args:
            output_vocab_size: Passed to the embedding and used as the number
                of units of the output dense layer.
            embedding_dims: Size of each decoder embedding vector.
            rnn_units: Units of the LSTM cell.
            dense_units: Passed to both the attention mechanism and the
                attention wrapper.
            batch_size: Batch size used on the training path for the initial
                state and the per-example sequence lengths.
            encoder_max_seq_len: Repeated ``batch_size`` times and handed to
                the attention mechanism.
            decoder_max_seq_len: Padded target length; also the number of
                beam-search steps on the inference path.
            start_token: Id fed as the first token during beam search.
            end_token: Id handed to the beam-search decoder as its end token.
            beam_width: Number of beams on the inference path.
            training: Selects the teacher-forced path (truthy) or the
                beam-search path (falsy) in ``call``.
        """
        super().__init__()
        self.batch_size = batch_size
        self.decoder_max_seq_len = decoder_max_seq_len
        
        self.decoder_embedding = Embedding(output_vocab_size, embedding_dims)
        self.dense_layer = Dense(output_vocab_size)
        self.rnn_cell = LSTMCell(rnn_units)
        
        self.start_token = start_token
        self.end_token = end_token
        self.beam_width = beam_width
        self.training = training
        
        # Components used by the teacher-forced path. The attention cell and
        # dense layer are reused by the beam-search path in call().
        self.sampler = tfa.seq2seq.sampler.TrainingSampler()
        # The second positional argument is None here; the memory is supplied
        # later through setup_memory(). NOTE(review): the positional order is
        # presumably (units, memory, memory_sequence_length) -- confirm against
        # the tfa API.
        self.attn_mech = tfa.seq2seq.LuongAttention(dense_units, 
                                                    None, 
                                                    self.batch_size * [encoder_max_seq_len])
        self.attn_cell = tfa.seq2seq.AttentionWrapper(self.rnn_cell,
                                                     self.attn_mech,
                                                     dense_units)
        self.decoder = tfa.seq2seq.BasicDecoder(self.attn_cell, self.sampler, self.dense_layer)

    def set_decoder_memory_and_initialState(self, memory, batch_size, encoder_state):
        """Attach ``memory`` to the attention mechanism and build the initial state.

        The attention cell's initial state for ``batch_size`` is created and
        its ``cell_state`` is replaced by ``encoder_state``.

        Returns:
            The cloned initial state of the attention-wrapped cell.
        """
        self.attn_mech.setup_memory(memory)
        decoder_initial_state = self.attn_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
        decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)
        return decoder_initial_state
    
    def call(self, inputs):
        """Decode either with teacher forcing or with beam search.

        Args:
            inputs: A 4-item sequence ``(d_in, encoder_outputs, state_h, state_c)``.

        Returns:
            Training path: ``outputs.rnn_output`` of the basic decoder.
            Inference path: ``(beam_ids, beam_scores)``, two lists with one
            entry per decoding step.
        """
        d_in, encoder_outputs, state_h, state_c = inputs
        
        if self.training:
            decoder_emb = self.decoder_embedding(d_in)

            decoder_initial_state = self.set_decoder_memory_and_initialState(encoder_outputs, 
                                                                             self.batch_size, 
                                                                             [state_h, state_c])
            # Every example is decoded for decoder_max_seq_len - 1 steps.
            outputs, _, _ = self.decoder(decoder_emb, 
                                         initial_state=decoder_initial_state, 
                                         sequence_length=self.batch_size * [self.decoder_max_seq_len - 1])
            logits = outputs.rnn_output
            return logits
        else:
            # The inference path is hard-coded to a single example.
            inference_batch_size = 1
            # The result is discarded; presumably this call only builds the
            # embedding layer so that variables[0] exists below -- TODO confirm.
            _ = self.decoder_embedding(d_in)
            # Repeat encoder state and outputs once per beam.
            encoder_state_beam = tfa.seq2seq.tile_batch([state_h, state_c], self.beam_width)
            encoder_outputs_beam = tfa.seq2seq.tile_batch(encoder_outputs, self.beam_width)

            decoder_initial_state = self.set_decoder_memory_and_initialState(encoder_outputs_beam, 
                                                                             inference_batch_size*self.beam_width, 
                                                                             encoder_state_beam)
            decoder_instance = tfa.seq2seq.BeamSearchDecoder(self.attn_cell, 
                                                             beam_width=self.beam_width, 
                                                             output_layer=self.dense_layer)

            start_tokens = tf.fill([inference_batch_size], self.start_token)
            end_token = self.end_token
            # The first variable of the embedding layer is passed as the
            # embedding for beam search.
            _, inputs, state = decoder_instance.initialize(self.decoder_embedding.variables[0] ,
                                                         start_tokens=start_tokens,
                                                         end_token=end_token,
                                                         initial_state=decoder_initial_state)
            
            # Step the beam search manually for a fixed number of steps,
            # collecting the predicted ids and scores of every step.
            beam_ids = []
            beam_scores = []
            for j in range(self.decoder_max_seq_len):
                beam_output, state, inputs, _ = decoder_instance.step(j, inputs, state)
                beam_ids.append(beam_output.predicted_ids)
                beam_scores.append(beam_output.scores)
            return beam_ids, beam_scores

class EncoderDecoder():
    """Encoder/decoder pair with its own training loop and weight I/O.

    This is a plain Python class (not a Keras model) that owns an ``Encoder``
    and a ``Decoder`` and provides ``compile``/``fit``/``predict`` helpers.
    """

    def __init__(self,
                 input_vocab_size=None,
                 output_vocab_size=None,
                 embedding_dims=128,
                 rnn_units=64,
                 dense_units=64,
                 batch_size=128,
                 encoder_max_seq_len=None,
                 decoder_max_seq_len=None,
                 start_token=None,
                 end_token=None,
                 beam_width=5,
                 training=None):
        """Build the encoder and decoder sub-models.

        All arguments are forwarded to ``Encoder`` and/or ``Decoder``;
        ``start_token``, ``batch_size`` and ``training`` are also kept on this
        object.
        """
        self.start_token = start_token
        self.batch_size = batch_size
        self.training = training

        self.encoder = Encoder(input_vocab_size=input_vocab_size,
                               embedding_dims=embedding_dims,
                               rnn_units=rnn_units)
        self.decoder = Decoder(output_vocab_size=output_vocab_size,
                               embedding_dims=embedding_dims,
                               rnn_units=rnn_units,
                               dense_units=dense_units,
                               batch_size=batch_size,
                               encoder_max_seq_len=encoder_max_seq_len,
                               decoder_max_seq_len=decoder_max_seq_len,
                               start_token=start_token,
                               end_token=end_token,
                               beam_width=beam_width,
                               training=training)

    def __call__(self, inputs):
        """Encode ``e_in`` and decode conditioned on it.

        Args:
            inputs: A pair ``(e_in, d_in)`` of encoder and decoder input ids.

        Returns:
            Whatever ``Decoder`` returns for the mode it was built with.
        """
        # encode phase
        e_in, d_in = inputs
        e_out, state_h, state_c = self.encoder(e_in)
        # decode phase
        return self.decoder([d_in, e_out, state_h, state_c])

    def compile(self, optimizer, loss=None, metrics=None):
        """Store the optimizer, loss function and metrics for later use."""
        self.optimizer = optimizer
        self.loss_fn = loss
        self.metrics = metrics

    def _step(self, x_batch, y_batch):
        """Compute the loss for one batch with teacher forcing.

        The decoder input is the target without its last position and the
        expected output is the target without its first position.
        """
        d_in = y_batch[:, :-1]  # ignore <end>
        d_out = y_batch[:, 1:]  # ignore <start>

        logits = self([x_batch, d_in])
        return self.loss_fn(logits, d_out)

    @tf.function
    def train_step(self, x_batch, y_batch):
        """Run one optimisation step and return the batch loss."""
        with tf.GradientTape() as tape:
            loss = self._step(x_batch, y_batch)
        # Both sub-models are trained jointly, so gather variables from each.
        vars_ = self.encoder.trainable_variables + self.decoder.trainable_variables
        grads = tape.gradient(loss, vars_)
        self.optimizer.apply_gradients(zip(grads, vars_))
        return loss

    def fit(self, train_dataset, epochs=1, eval_dataset=None):
        """Train for ``epochs`` passes over ``train_dataset``.

        Args:
            train_dataset: Batched ``tf.data.Dataset`` of ``(x, y)`` pairs.
            epochs: Number of passes over ``train_dataset``.
            eval_dataset: Optional batched dataset; when given, the mean loss
                over its batches is reported as ``val_loss`` after each epoch.
        """
        num_train_samples = tf.data.experimental.cardinality(train_dataset).numpy() * self.batch_size
        for epoch in range(epochs):
            print("\nepoch {}/{}".format(epoch+1,epochs))
            pbar = tf.keras.utils.Progbar(num_train_samples, stateful_metrics=['train_loss'])

            train_loss = None
            for i, (x_batch, y_batch) in enumerate(train_dataset):
                train_loss = self.train_step(x_batch, y_batch)
                pbar.update(i*self.batch_size, values=[('train_loss', train_loss)])

            # Only report metrics that were actually computed this epoch.
            values = []
            if train_loss is not None:
                values.append(('train_loss', train_loss))

            if eval_dataset is not None:
                # Evaluate on the dataset that was passed in, not on a
                # notebook-level global, and average over all its batches.
                val_losses = [self._step(x_batch, y_batch)
                              for x_batch, y_batch in eval_dataset]
                if val_losses:
                    values.append(('val_loss', tf.reduce_mean(val_losses)))

            pbar.update(num_train_samples, values=values)

    def save_weights(self, path):
        """Write encoder and decoder weights as two ``.h5`` files under ``path``.

        The directory is created first if it does not exist yet.
        """
        tf.io.gfile.makedirs(path)
        self.encoder.save_weights(join(path, 'encoder_weights.h5'))
        self.decoder.save_weights(join(path, 'decoder_weights.h5'))

    @classmethod
    def from_pretrained(cls,
                        path,
                        input_vocab_size,
                        output_vocab_size,
                        embedding_dims,
                        rnn_units,
                        dense_units,
                        batch_size,
                        encoder_max_seq_len,
                        decoder_max_seq_len,
                        start_token,
                        end_token,
                        beam_width,
                        training):
        """Construct a model and load weights written by ``save_weights``.

        Both sub-models are built with explicit input shapes before their
        weight files are loaded from ``path``.

        Returns:
            The constructed ``EncoderDecoder`` with weights loaded.
        """
        model = cls(input_vocab_size,
                    output_vocab_size,
                    embedding_dims,
                    rnn_units,
                    dense_units,
                    batch_size,
                    encoder_max_seq_len,
                    decoder_max_seq_len,
                    start_token,
                    end_token,
                    beam_width,
                    training)
        model.encoder.build((None, None))
        model.encoder.load_weights(join(path, 'encoder_weights.h5'))
        model.decoder.build([(None, None), (None, None, rnn_units), (None, rnn_units), (None, rnn_units)])
        model.decoder.load_weights(join(path, 'decoder_weights.h5'))
        return model

    @staticmethod
    def decode_prediction(outputs):
        """Turn beam-search output into an id array.

        Args:
            outputs: The ``(beam_ids, beam_scores)`` pair from the decoder;
                only ``beam_ids`` is used.

        Returns:
            ``np.ndarray`` made by stacking the per-step ids, squeezing unit
            axes and transposing.
        """
        beam_ids, beam_scores = outputs
        return np.array([i.numpy() for i in beam_ids]).squeeze().transpose()

    def predict(self, input_ids):
        """Decode ``input_ids`` and return the predicted id array.

        A ``[[start_token]]`` array is passed as the decoder input.
        """
        beam_outputs = self([input_ids, np.array([[self.start_token]])])
        return self.decode_prediction(beam_outputs)

### Compile Model

In [18]:
# Ids of the sequence boundary markers, as assigned by the text encoder.
start_token = text2seq.char2id.get('<s>')
end_token = text2seq.char2id.get('<e>')

model = EncoderDecoder(
    input_vocab_size=input_vocab_size,
    output_vocab_size=output_vocab_size,
    embedding_dims=embedding_dims,
    rnn_units=rnn_units,
    dense_units=dense_units,
    batch_size=BATCH_SIZE,
    encoder_max_seq_len=Tx,
    decoder_max_seq_len=Ty,
    start_token=start_token,
    end_token=end_token,
    beam_width=5,
    training=True,
)

# Cyclical learning rate between 5e-4 and 1e-2; step_size is two epochs of steps.
lr_schedule = tfa.optimizers.ExponentialCyclicalLearningRate(
    initial_learning_rate=5e-4,
    maximal_learning_rate=1e-2,
    step_size=steps_per_epoch*2,
    scale_mode="cycle",
    gamma=0.96,
)
# Adam with gradient-norm clipping, wrapped in Lookahead.
opt = tfa.optimizers.Lookahead(
    tf.keras.optimizers.Adam(learning_rate=lr_schedule, clipnorm=3.0)
)
model.compile(optimizer=opt, loss=loss_fn)

### Training

In [19]:
# Train for 40 epochs with per-epoch evaluation, then persist both weight files.
model.fit(train_dataset=train_dataset, epochs=40, eval_dataset=test_dataset)
model.save_weights(path='model/1')


epoch 1/40

epoch 2/40

epoch 3/40

epoch 4/40

epoch 5/40

epoch 6/40

epoch 7/40

epoch 8/40

epoch 9/40

epoch 10/40

epoch 11/40

epoch 12/40

epoch 13/40

epoch 14/40

epoch 15/40

epoch 16/40

epoch 17/40

epoch 18/40

epoch 19/40

epoch 20/40

epoch 21/40

epoch 22/40

epoch 23/40

epoch 24/40

epoch 25/40

epoch 26/40

epoch 27/40

epoch 28/40

epoch 29/40

epoch 30/40

epoch 31/40

epoch 32/40

epoch 33/40

epoch 34/40

epoch 35/40

epoch 36/40

epoch 37/40

epoch 38/40

epoch 39/40

epoch 40/40


### Load model and run inference

In [20]:
# Rebuild the model in inference mode (training=False) and load the saved weights.
loaded_model = EncoderDecoder.from_pretrained(
    path='model/1',
    input_vocab_size=input_vocab_size,
    output_vocab_size=output_vocab_size,
    embedding_dims=embedding_dims,
    rnn_units=rnn_units,
    dense_units=dense_units,
    batch_size=BATCH_SIZE,
    encoder_max_seq_len=Tx,
    decoder_max_seq_len=Ty,
    start_token=start_token,
    end_token=end_token,
    beam_width=5,
    training=False,
)

In [21]:
# Encode one misspelled phrase, decode it, and map the predicted ids back to text.
inputs = text2seq.fit_on_texts(texts=['cơ the'])
text2seq.sequence_to_text(arr=loaded_model.predict(input_ids=inputs), remove_endtoken=True)

['có thể', 'tấúna', 'kh  ả', 'óấtng ta', 'ờó cc']

In [22]:
# Encode one misspelled phrase, decode it, and map the predicted ids back to text.
inputs = text2seq.fit_on_texts(texts=['chng ta'])
text2seq.sequence_to_text(arr=loaded_model.predict(input_ids=inputs), remove_endtoken=True)

['chúng ta', 'ko ngưca', 'tón n ta', 'hh tagờa', 'ờnế g']

In [23]:
# Encode one misspelled phrase, decode it, and map the predicted ids back to text.
inputs = text2seq.fit_on_texts(texts=['taast cả'])
text2seq.sequence_to_text(arr=loaded_model.predict(input_ids=inputs), remove_endtoken=True)

['có thể', 'tấ ng ta', 'khúna', 'óấúng ta', 'ờtt g ta']

In [24]:
# Encode one misspelled phrase, decode it, and map the predicted ids back to text.
inputs = text2seq.fit_on_texts(texts=['khongtheer'])
text2seq.sequence_to_text(arr=loaded_model.predict(input_ids=inputs), remove_endtoken=True)

['không thể', 'chúng ta', 'thông thể', 'ôhông th', 'ểókng thể']