# Training Script 
Direction: Kana to Alpha

Encoder: GRU

Decoder: GRU

Hyper Parameter: *NUM_UNITS*

In [1]:
import tensorflow as tf

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from sklearn.model_selection import train_test_split

import graphviz

import unicodedata
import re
import numpy as np
import os
import io
import time
import datetime

import json

from functools import total_ordering

from RedBlackTree.rbtree import RedBlackNode
from RedBlackTree.rbtree import RedBlackTree
from StackDecoder.stack_decoder import StackDecoderPath
from StackDecoder.stack_decoder import StackDecoder


## Parameter Definitions

In [2]:
TOKENIZER_ALPHAS                 = 'training_data/alphas_tokenizer.json'
TOKENIZER_KANAS                  = 'training_data/kanas_tokenizer.json'

TRANING_DATA_FILE_90_10_10       = "training_data/alpha_to_kana_train.txt"
VALIDATION_DATA_FILE_90_10_10    = "training_data/alpha_to_kana_validation.txt"

EPOCHS                           = 1000
BATCH_SIZE                       =   64
NUM_UNITS                        =   16 # <= Hyper Parameter

VALIDATION_BEAM_WIDTH            =    5
VALIDATION_NBEST                 =    5
VALIDATION_MAX_LEN_KANAS_CUTOFF  =   12
VALIDATION_MAX_LEN_ALPHAS_CUTOFF =   16

CHECKPOINT_DIR                   = f'training_output/kana_to_alpha_{str(NUM_UNITS)}'

## Arrange Tokeniers and Training & Validation Sets

In [3]:
# Load tokenizers
with open(TOKENIZER_ALPHAS) as f:
    data = json.load(f)
    alphas_tokenizer = tf.keras.preprocessing.text.tokenizer_from_json(data)

with open(TOKENIZER_KANAS) as f:
    data = json.load(f)
    kanas_tokenizer = tf.keras.preprocessing.text.tokenizer_from_json(data)
    
EMB_DIM_ALPHAS = len( alphas_tokenizer.word_index ) + 1
EMB_DIM_KANAS  = len( kanas_tokenizer.word_index  ) + 1

In [4]:
# Load training data and validation data

train_kanas = []
train_alphas = []
with open( TRANING_DATA_FILE_90_10_10, "r", encoding="utf-8" ) as fp_train:
    for line in fp_train:
        alpha, kana = line.strip().split('\t')
        train_kanas.append(kana)
        train_alphas.append(alpha)

valid_kanas = []
valid_alphas = []
with open( VALIDATION_DATA_FILE_90_10_10, "r", encoding="utf-8" ) as fp_valid:
    for line in fp_valid:
        alpha, kana = line.strip().split( '\t' )
        valid_alphas.append( '<' + alpha + '>' )
        valid_kanas.append( '<' + kana + '>' )

validation_pairs = list(zip(valid_alphas, valid_kanas))


# Interleave with spaces so that we can utilize Kera's tokenizer.

train_kanas_spaced = []
for kana_str in train_kanas:
    kana_list = []
    kana_list[:0] = kana_str
    train_kanas_spaced.append( "< " + ' '.join(kana_list) + " >" ) 

train_alphas_spaced = []
for alpha_str in train_alphas:
    alpha_list = []
    alpha_list[:0] = alpha_str
    train_alphas_spaced.append( "< " + ' '.join(alpha_list) + " >" ) 

train_alphas_tensor = alphas_tokenizer.texts_to_sequences(train_alphas_spaced)
train_alphas_tensor = tf.keras.preprocessing.sequence.pad_sequences(train_alphas_tensor, padding='post')

train_kanas_tensor  = kanas_tokenizer.texts_to_sequences(train_kanas_spaced)
train_kanas_tensor  = tf.keras.preprocessing.sequence.pad_sequences(train_kanas_tensor, padding='post')

max_length_alphas, max_length_kanas = train_alphas_tensor.shape[1], train_kanas_tensor.shape[1]

In [5]:
BUFFER_SIZE = len(train_alphas_tensor)
steps_per_epoch = len(train_alphas_tensor)//BATCH_SIZE

dataset = tf.data.Dataset.from_tensor_slices((train_alphas_tensor, train_kanas_tensor)).shuffle(BUFFER_SIZE)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

## Create Encoder and Decoder

In [6]:
class Encoder(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
    super(Encoder, self).__init__()
    self.batch_sz = batch_sz
    self.enc_units = enc_units
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(self.enc_units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform')

  def call(self, x, hidden):
    x = self.embedding(x)
    output, state = self.gru(x, initial_state = hidden)
    return output, state

  def initialize_hidden_state(self):
    return tf.zeros((self.batch_sz, self.enc_units))

In [7]:
encoder = Encoder(EMB_DIM_KANAS, EMB_DIM_KANAS, NUM_UNITS, BATCH_SIZE)

In [8]:
class BahdanauAttention(tf.keras.layers.Layer):
  def __init__(self, units):
    super(BahdanauAttention, self).__init__()
    self.W1 = tf.keras.layers.Dense(units)
    self.W2 = tf.keras.layers.Dense(units)
    self.V = tf.keras.layers.Dense(1)

  def call(self, query, values):
    # query hidden state shape == (batch_size, hidden size)
    # query_with_time_axis shape == (batch_size, 1, hidden size)
    # values shape == (batch_size, max_len, hidden size)
    # we are doing this to broadcast addition along the time axis to calculate the score
    query_with_time_axis = tf.expand_dims(query, 1)

    # score shape == (batch_size, max_length, 1)
    # we get 1 at the last axis because we are applying score to self.V
    # the shape of the tensor before applying self.V is (batch_size, max_length, units)
    score = self.V(tf.nn.tanh(
        self.W1(query_with_time_axis) + self.W2(values)))

    # attention_weights shape == (batch_size, max_length, 1)
    attention_weights = tf.nn.softmax(score, axis=1)

    # context_vector shape after sum == (batch_size, hidden_size)
    context_vector = attention_weights * values
    context_vector = tf.reduce_sum(context_vector, axis=1)

    return context_vector, attention_weights

In [9]:
class Decoder(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
    super(Decoder, self).__init__()
    self.batch_sz = batch_sz
    self.dec_units = dec_units
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(self.dec_units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform')
    self.fc = tf.keras.layers.Dense(vocab_size)

    # used for attention
    self.attention = BahdanauAttention(self.dec_units)

  def call(self, x, hidden, enc_output):
    # enc_output shape == (batch_size, max_length, hidden_size)
    context_vector, attention_weights = self.attention(hidden, enc_output)

    # x shape after passing through embedding == (batch_size, 1, embedding_dim)
    x = self.embedding(x)

    # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
    x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

    # passing the concatenated vector to the GRU
    output, state = self.gru(x)

    # output shape == (batch_size * 1, hidden_size)
    output = tf.reshape(output, (-1, output.shape[2]))

    # output shape == (batch_size, vocab)
    x = self.fc(output)

    return x, state, attention_weights

In [10]:
decoder = Decoder(EMB_DIM_ALPHAS, EMB_DIM_ALPHAS, NUM_UNITS, BATCH_SIZE)

## Check the Shapes of the Encoder and the Decoder

In [17]:
for alpha, kana in dataset.take(steps_per_epoch):
    encoder_state = encoder.initialize_hidden_state()
    encoder_out, encoder_state2 = encoder(kana, encoder_state)
    decoder_state = encoder_state
    dec_input = tf.expand_dims([alphas_tokenizer.word_index['<']] * BATCH_SIZE, 1)
    decoder_pred, decoder_state2, attn_weights = decoder(dec_input, decoder_state, encoder_out)
    break
kana.shape, alpha.shape, encoder_state.shape, encoder_out.shape, encoder_state2.shape, decoder_pred.shape, decoder_state2.shape, attn_weights.shape

(TensorShape([64, 14]),
 TensorShape([64, 14]),
 TensorShape([64, 16]),
 TensorShape([64, 14, 16]),
 TensorShape([64, 16]),
 TensorShape([64, 82]),
 TensorShape([64, 16]),
 TensorShape([64, 14, 1]))

In [12]:
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_object(real, pred)

  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask

  return tf.reduce_mean(loss_)

## One Training Step with Forward and Backprop with Incremental Teacher Forcing

In [13]:
@tf.function
def train_step(inp, targ, enc_hidden):
  loss = 0

  with tf.GradientTape() as tape:
    enc_output, enc_hidden = encoder(inp, enc_hidden)

    dec_hidden = enc_hidden

    dec_input = tf.expand_dims([alphas_tokenizer.word_index['<']] * BATCH_SIZE, 1)

    # Teacher forcing - feeding the target as the next input
    for t in range(1, targ.shape[1]):
      # passing enc_output to the decoder
      predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)

      loss += loss_function(targ[:, t], predictions)

      # using teacher forcing
      dec_input = tf.expand_dims(targ[:, t], 1)

  batch_loss = (loss / int(targ.shape[1]))

  variables = encoder.trainable_variables + decoder.trainable_variables

  gradients = tape.gradient(loss, variables)

  optimizer.apply_gradients(zip(gradients, variables))

  return batch_loss

## Checkpoint to Save the Models

In [14]:
# File outputs (checkpoints and metrics for tensorboard)

checkpoint_prefix = os.path.join(CHECKPOINT_DIR, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = os.path.join( CHECKPOINT_DIR, current_time, 'train' )
validation_log_dir = os.path.join( CHECKPOINT_DIR, current_time, 'validation' )
train_summary_writer = tf.summary.create_file_writer( train_log_dir )
validation_summary_writer = tf.summary.create_file_writer( validation_log_dir )


## Validation using Validation Set by Nbest Stack Decoder
The score is calculated by edit distance in Nbest against target.

In [15]:
# Validation

# Following levenshtein() is taken from 
# https://en.wikibooks.org/wiki/Algorithm_Implementation/Strings/Levenshtein_distance#Python
# under  Creative Commons Attribution-ShareAlike License.
def levenshtein(s1, s2):
    if len(s1) < len(s2):
        return levenshtein(s2, s1)

    # len(s1) >= len(s2)
    if len(s2) == 0:
        return len(s1)

    previous_row = range(len(s2) + 1)
    for i, c1 in enumerate(s1):
        current_row = [i + 1]
        for j, c2 in enumerate(s2):
            insertions = previous_row[j + 1] + 1 # j+1 instead of j since previous_row and current_row are one character longer
            deletions = current_row[j] + 1       # than s2
            substitutions = previous_row[j] + (c1 != c2)
            current_row.append(min(insertions, deletions, substitutions))
        previous_row = current_row
    
    return previous_row[-1]

def validate(sentence_pairs):

    BOS = alphas_tokenizer.word_index['<']
    EOS = alphas_tokenizer.word_index['>']

    stack_decoder = StackDecoder(decoder, BOS, EOS)

    avg_edit_dist = 0.0
    index = 0
    for alphas, kanas in sentence_pairs:
        if (index % 10 == 0):
            print('validating {}/{}'.format(index, len(sentence_pairs)))
        index += 1
        inputs = [kanas_tokenizer.word_index[i] for i in kanas]
        inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],
                                                         maxlen=max_length_kanas,
                                                         padding='post')
        inputs = tf.convert_to_tensor(inputs)

        hidden = [tf.zeros((1, NUM_UNITS))]
        enc_out, enc_hidden = encoder(inputs, hidden)

        dec_hidden = enc_hidden
        dec_input = tf.expand_dims([alphas_tokenizer.word_index['<']], 0)

        nbest_raw = stack_decoder.NBest( enc_out, enc_hidden, VALIDATION_BEAM_WIDTH, VALIDATION_NBEST, VALIDATION_MAX_LEN_KANAS_CUTOFF + 2 ) 
    
        min_edit_dist = -1
        for r in nbest_raw:
            candidate = ""
            for i in r.sentence:
                candidate += alphas_tokenizer.index_word[i] 
            edit_dist = levenshtein(alphas, candidate)
            if min_edit_dist == -1 or edit_dist < min_edit_dist:
                min_edit_dist = edit_dist
        avg_edit_dist += min_edit_dist
    return avg_edit_dist / len(sentence_pairs)


## Training Execution

In [16]:
#checkpoint.restore('./CHECKPOINT_DIR/ckpt-1')

for epoch in range(EPOCHS):
  start = time.time()

  enc_hidden = encoder.initialize_hidden_state()
  total_loss = 0

  for (batch, (targ, inp)) in enumerate(dataset.take(steps_per_epoch)):
    batch_loss = train_step(inp, targ, enc_hidden)
    total_loss += batch_loss

    if batch % 100 == 0:
      print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                   batch,
                                                   batch_loss.numpy()))
  # saving (checkpoint) the model every epoch
  checkpoint.save(file_prefix = checkpoint_prefix)

  accuracy = validate(list(validation_pairs)[0:100])
    
  print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                      total_loss / steps_per_epoch))
  print('Validation Accuracy {:0.4f}'.format(accuracy))
  print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

  with train_summary_writer.as_default():
    tf.summary.scalar('loss', total_loss / steps_per_epoch, step=epoch)
    tf.summary.scalar('accuracy', accuracy, step=epoch)

Epoch 1 Batch 0 Loss 1.9724
Epoch 1 Batch 100 Loss 1.5925
Epoch 1 Batch 200 Loss 1.5406
Epoch 1 Batch 300 Loss 1.4940
Epoch 1 Batch 400 Loss 1.4945
Epoch 1 Batch 500 Loss 1.3504
Epoch 1 Batch 600 Loss 1.3996
Epoch 1 Batch 700 Loss 1.3298
Epoch 1 Batch 800 Loss 1.2999
Epoch 1 Batch 900 Loss 1.3321
Epoch 1 Batch 1000 Loss 1.4216
Epoch 1 Batch 1100 Loss 1.2387
Epoch 1 Batch 1200 Loss 1.1913
Epoch 1 Batch 1300 Loss 1.2598
Epoch 1 Batch 1400 Loss 1.3020
Epoch 1 Batch 1500 Loss 1.2454
Epoch 1 Batch 1600 Loss 1.2575
Epoch 1 Batch 1700 Loss 1.2021
Epoch 1 Batch 1800 Loss 1.2034
validating 0/100
validating 10/100
validating 20/100
validating 30/100
validating 40/100
validating 50/100
validating 60/100
validating 70/100
validating 80/100
validating 90/100
Epoch 1 Loss 1.3711
Validation Accuracy 4.6000
Time taken for 1 epoch 74.6486713886261 sec

Epoch 2 Batch 0 Loss 1.1451
Epoch 2 Batch 100 Loss 1.1082
Epoch 2 Batch 200 Loss 1.1579
Epoch 2 Batch 300 Loss 1.1687
Epoch 2 Batch 400 Loss 1.1071
Epoc

KeyboardInterrupt: 