# Training Script 
Direction: Alpha to Kana

Encoder: GRU

Decoder: GRU

Hyper Parameter: *NUM_UNITS*

In [1]:
import tensorflow as tf

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from sklearn.model_selection import train_test_split

import unicodedata
import re
import numpy as np
import os
import io
import time
import datetime

import json

from functools import total_ordering

from RedBlackTree.rbtree import RedBlackNode
from RedBlackTree.rbtree import RedBlackTree
from StackDecoder.stack_decoder import StackDecoderPath
from StackDecoder.stack_decoder import StackDecoder


## Parameter Definitions

In [2]:
# --- Input files -------------------------------------------------------------
# Serialized Keras tokenizers for the two sides of the model.
TOKENIZER_ALPHAS = 'training_data/alphas_tokenizer.json'
TOKENIZER_KANAS = 'training_data/kanas_tokenizer.json'

# Tab-separated "alpha<TAB>kana" pairs, one pair per line.
TRANING_DATA_FILE_90_10_10 = "training_data/alpha_to_kana_train.txt"
VALIDATION_DATA_FILE_90_10_10 = "training_data/alpha_to_kana_validation.txt"

# --- Training ----------------------------------------------------------------
EPOCHS = 1000
BATCH_SIZE = 64
NUM_UNITS = 16  # <= Hyper Parameter: GRU state size of both encoder and decoder

# --- Validation (n-best search) ----------------------------------------------
VALIDATION_BEAM_WIDTH = 5
VALIDATION_NBEST = 5
VALIDATION_MAX_LEN_KANAS_CUTOFF = 12
VALIDATION_MAX_LEN_ALPHAS_CUTOFF = 16

# --- Output ------------------------------------------------------------------
# Checkpoints and TensorBoard logs go here; the name encodes NUM_UNITS.
CHECKPOINT_DIR = f'training_output/alpha_to_kana_{NUM_UNITS}_wo_attn'

## Arrange Tokenizers and Training & Validation Sets

In [3]:
# Load tokenizers
def _load_tokenizer(path):
    """Rebuild a Keras tokenizer from the JSON file at `path`.

    The file is opened as UTF-8 explicitly (like the training/validation data
    files) so the result does not depend on the platform's default encoding.
    """
    with open(path, encoding="utf-8") as f:
        return tf.keras.preprocessing.text.tokenizer_from_json(json.load(f))


alphas_tokenizer = _load_tokenizer(TOKENIZER_ALPHAS)
kanas_tokenizer = _load_tokenizer(TOKENIZER_KANAS)

# Number of token ids on each side. The +1 leaves room for id 0, which the loss
# mask treats as padding. These values are used both as vocabulary size and as
# embedding width when the encoder/decoder are constructed.
EMB_DIM_ALPHAS = len(alphas_tokenizer.word_index) + 1
EMB_DIM_KANAS = len(kanas_tokenizer.word_index) + 1

In [4]:
# Load training data and validation data
def _read_pairs(path):
    """Read `alpha<TAB>kana` pairs from the UTF-8 text file at `path`.

    Blank lines (for example a trailing empty line at the end of the file) are
    skipped instead of raising on tuple unpacking. Any other line that does not
    split into exactly two tab-separated fields still raises ValueError.

    Returns:
        A list of (alpha, kana) string tuples in file order.
    """
    pairs = []
    with open(path, "r", encoding="utf-8") as fp:
        for line in fp:
            line = line.strip()
            if not line:
                continue
            alpha_str, kana_str = line.split('\t')
            pairs.append((alpha_str, kana_str))
    return pairs


train_pairs = _read_pairs(TRANING_DATA_FILE_90_10_10)
train_alphas = [alpha_str for alpha_str, _ in train_pairs]
train_kanas = [kana_str for _, kana_str in train_pairs]

# Validation strings are wrapped in the '<' ... '>' sequence markers up front;
# validate() looks the characters up one by one in the tokenizers.
valid_alphas = []
valid_kanas = []
for alpha_str, kana_str in _read_pairs(VALIDATION_DATA_FILE_90_10_10):
    valid_alphas.append('<' + alpha_str + '>')
    valid_kanas.append('<' + kana_str + '>')

validation_pairs = list(zip(valid_alphas, valid_kanas))


# Interleave with spaces so that we can utilize Keras's tokenizer.
def _space_out(text):
    """Return `text` with its characters separated by spaces and wrapped in '< ' ... ' >'."""
    # Joining a str iterates over its characters, so no intermediate list is needed.
    return "< " + ' '.join(text) + " >"


train_kanas_spaced = [_space_out(kana_str) for kana_str in train_kanas]
train_alphas_spaced = [_space_out(alpha_str) for alpha_str in train_alphas]

# Convert to integer id sequences, padded at the end to a common length.
train_alphas_tensor = alphas_tokenizer.texts_to_sequences(train_alphas_spaced)
train_alphas_tensor = tf.keras.preprocessing.sequence.pad_sequences(train_alphas_tensor, padding='post')

train_kanas_tensor  = kanas_tokenizer.texts_to_sequences(train_kanas_spaced)
train_kanas_tensor  = tf.keras.preprocessing.sequence.pad_sequences(train_kanas_tensor, padding='post')

# Padded sequence lengths; validate() pads its inputs to max_length_alphas.
max_length_alphas, max_length_kanas = train_alphas_tensor.shape[1], train_kanas_tensor.shape[1]

In [5]:
# Shuffle across the whole training set, then cut it into fixed-size batches.
# Incomplete trailing batches are dropped, so every batch has BATCH_SIZE rows.
BUFFER_SIZE = len(train_alphas_tensor)
steps_per_epoch = BUFFER_SIZE // BATCH_SIZE

dataset = (
    tf.data.Dataset
    .from_tensor_slices((train_alphas_tensor, train_kanas_tensor))
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
)

## Create Encoder and Decoder

In [6]:
class Encoder(tf.keras.Model):
  """Embedding layer followed by a single GRU that reads the source sequence."""

  def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
    """Build the layers.

    Args:
      vocab_size: number of distinct input token ids.
      embedding_dim: width of the embedding vectors.
      enc_units: size of the GRU state.
      batch_sz: batch size used by initialize_hidden_state().
    """
    super().__init__()
    self.batch_sz = batch_sz
    self.enc_units = enc_units
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    # The GRU returns both the per-step outputs and its final state.
    self.gru = tf.keras.layers.GRU(
        self.enc_units,
        return_sequences=True,
        return_state=True,
        recurrent_initializer='glorot_uniform',
    )

  def call(self, x, state):
    """Encode token ids `x` starting from `state`; returns (outputs, final_state)."""
    embedded = self.embedding(x)
    outputs, final_state = self.gru(embedded, initial_state = state)
    return outputs, final_state

  def initialize_hidden_state(self):
    """Return an all-zero initial state of shape (batch_sz, enc_units)."""
    state_shape = (self.batch_sz, self.enc_units)
    return tf.zeros(state_shape)

In [7]:
# The embedding width is set equal to the alpha vocabulary size.
encoder = Encoder(
    vocab_size=EMB_DIM_ALPHAS,
    embedding_dim=EMB_DIM_ALPHAS,
    enc_units=NUM_UNITS,
    batch_sz=BATCH_SIZE,
)

In [8]:
class Decoder(tf.keras.Model):
  """Embedding, a single GRU and a dense projection onto the output vocabulary."""

  def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
    """Build the layers.

    Args:
      vocab_size: number of distinct output token ids; also the width of the
        final dense layer.
      embedding_dim: width of the embedding vectors.
      dec_units: size of the GRU state.
      batch_sz: stored on the instance; not read by call().
    """
    super().__init__()
    self.batch_sz = batch_sz
    self.dec_units = dec_units
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    # Only the output of the last time step is returned, together with the state.
    self.gru = tf.keras.layers.GRU(
        self.dec_units,
        return_sequences=False,
        return_state=True,
        recurrent_initializer='glorot_uniform',
    )
    # No activation: the layer emits raw scores, one per vocabulary entry.
    self.fc = tf.keras.layers.Dense(vocab_size)

  def call(self, x, state):
    """Run token ids `x` through the decoder from `state`.

    Returns:
      (scores, new_state): unnormalised vocabulary scores for the last time
      step of `x`, and the GRU state after consuming `x`.
    """
    embedded = self.embedding(x)
    last_output, new_state = self.gru(embedded, state)
    scores = self.fc(last_output)
    return scores, new_state

In [9]:
# The embedding width is set equal to the kana vocabulary size.
decoder = Decoder(
    vocab_size=EMB_DIM_KANAS,
    embedding_dim=EMB_DIM_KANAS,
    dec_units=NUM_UNITS,
    batch_sz=BATCH_SIZE,
)

In [10]:
# Smoke test: push a single batch through the (untrained) encoder and decoder
# and display the resulting tensor shapes.
alpha, kana = next(iter(dataset.take(1)))

encoder_state = encoder.initialize_hidden_state()
encoder_out, encoder_state2 = encoder(alpha, encoder_state)

# Seed the decoder with the encoder's final state (not the all-zero initial
# state), which is how train_step wires the two models together.
decoder_state = encoder_state2
decoder_pred, decoder_state2 = decoder(kana, decoder_state)

alpha.shape, kana.shape, encoder_state.shape, encoder_out.shape, encoder_state2.shape, decoder_pred.shape, decoder_state2.shape

(TensorShape([64, 14]),
 TensorShape([64, 14]),
 TensorShape([64, 16]),
 TensorShape([64, 14, 16]),
 TensorShape([64, 16]),
 TensorShape([64, 30]),
 TensorShape([64, 16]))

In [11]:
# Per-element cross-entropy on raw scores. No reduction is applied here so that
# loss_function can zero out padded positions before averaging.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    reduction='none',
    from_logits=True,
)

# Adam with its default settings.
optimizer = tf.keras.optimizers.Adam()

def loss_function(real, pred):
  """Cross-entropy between `real` ids and `pred` scores, ignoring padding.

  Elements whose target id is 0 contribute zero loss. The result is the mean
  over all elements, so padded elements still count in the denominator.
  """
  per_element = loss_object(real, pred)

  # 1 where the target is a real token, 0 where it is padding (id 0).
  keep = tf.cast(tf.math.not_equal(real, 0), dtype=per_element.dtype)

  return tf.reduce_mean(per_element * keep)

## One Training Step with Forward and Backprop with Incremental Teacher Forcing

In [12]:
@tf.function
def train_step(inp, targ, enc_hidden):
  """Run one optimisation step on a single batch using teacher forcing.

  Args:
    inp: batch of source (alpha) token ids.
    targ: batch of target (kana) token ids; column 0 is never used as a
      prediction target, only columns 1 and onward.
    enc_hidden: initial encoder state.

  Returns:
    The loss summed over the decoded time steps, divided by the target length.
    Gradients are taken with respect to the undivided sum.
  """
  bos_id = kanas_tokenizer.word_index['<']
  target_len = targ.shape[1]
  loss = 0

  with tf.GradientTape() as tape:
    # Without attention only the encoder's final state reaches the decoder;
    # the per-step encoder outputs are not used.
    _, dec_hidden = encoder(inp, enc_hidden)

    # Every sequence starts from the begin-of-sequence token.
    dec_input = tf.expand_dims([bos_id] * BATCH_SIZE, 1)

    for t in range(1, target_len):
      predictions, dec_hidden = decoder(dec_input, dec_hidden)
      loss += loss_function(targ[:, t], predictions)

      # Teacher forcing: the ground-truth token becomes the next decoder input.
      dec_input = tf.expand_dims(targ[:, t], 1)

  batch_loss = loss / int(target_len)

  variables = encoder.trainable_variables + decoder.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))

  return batch_loss

## Checkpoint to Save the Models

In [13]:
# File outputs (checkpoints and metrics for tensorboard)

# The checkpoint tracks the optimizer together with both models.
checkpoint = tf.train.Checkpoint(
    optimizer=optimizer,
    encoder=encoder,
    decoder=decoder,
)
checkpoint_prefix = os.path.join(CHECKPOINT_DIR, "ckpt")

# Each run logs under its own timestamped directory, with separate
# sub-directories for training and validation metrics.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir, validation_log_dir = (
    os.path.join(CHECKPOINT_DIR, current_time, split_name)
    for split_name in ('train', 'validation')
)

train_summary_writer = tf.summary.create_file_writer(train_log_dir)
validation_summary_writer = tf.summary.create_file_writer(validation_log_dir)


In [16]:
# Validation

# The following levenshtein() is adapted from
# https://en.wikibooks.org/wiki/Algorithm_Implementation/Strings/Levenshtein_distance#Python
# under  Creative Commons Attribution-ShareAlike License.
def levenshtein(s1, s2):
    """Return the Levenshtein (edit) distance between sequences `s1` and `s2`.

    Insertions, deletions and substitutions each cost 1. Only two rows of the
    dynamic-programming table are kept at a time.
    """
    # Keep the shorter sequence on the column axis.
    if len(s1) >= len(s2):
        longer, shorter = s1, s2
    else:
        longer, shorter = s2, s1

    if len(shorter) == 0:
        return len(longer)

    # prev_row[k] is the distance between the processed prefix of `longer`
    # and the first k items of `shorter`.
    prev_row = list(range(len(shorter) + 1))
    for row_no, item_long in enumerate(longer, start=1):
        row = [row_no]
        for col_no, item_short in enumerate(shorter, start=1):
            via_insert = prev_row[col_no] + 1
            via_delete = row[col_no - 1] + 1
            via_substitute = prev_row[col_no - 1] + (item_long != item_short)
            row.append(min(via_insert, via_delete, via_substitute))
        prev_row = row

    return prev_row[-1]

def validate(sentence_pairs):
    """Return the mean best-of-n edit distance over `sentence_pairs`.

    For every (alphas, kanas) pair the alphas string is encoded, an n-best list
    of kana hypotheses is produced by the stack decoder, and the smallest
    Levenshtein distance between `kanas` and any hypothesis is taken. Lower is
    better; 0 means the reference was among the hypotheses for every pair.

    Args:
        sentence_pairs: sized sequence of (alphas, kanas) strings. Every
            character of `alphas` must be present in
            alphas_tokenizer.word_index, otherwise KeyError is raised.

    Returns:
        The average of the per-pair minimum edit distances, or NaN when
        `sentence_pairs` is empty (previously a ZeroDivisionError).
    """
    num_pairs = len(sentence_pairs)
    if num_pairs == 0:
        return float('nan')

    BOS = kanas_tokenizer.word_index['<']
    EOS = kanas_tokenizer.word_index['>']

    stack_decoder = StackDecoder(decoder, BOS, EOS, use_attn = False)

    total_edit_dist = 0.0
    for index, (alphas, kanas) in enumerate(sentence_pairs):
        if index % 10 == 0:
            print('validating {}/{}'.format(index, num_pairs))

        # One token id per character, padded to the training input length.
        inputs = [alphas_tokenizer.word_index[ch] for ch in alphas]
        inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],
                                                         maxlen=max_length_alphas,
                                                         padding='post')
        inputs = tf.convert_to_tensor(inputs)

        # Batch of one, zero initial state.
        hidden = [tf.zeros((1, NUM_UNITS))]
        enc_out, enc_hidden = encoder(inputs, hidden)

        nbest_raw = stack_decoder.NBest( enc_out, enc_hidden, VALIDATION_BEAM_WIDTH, VALIDATION_NBEST, VALIDATION_MAX_LEN_KANAS_CUTOFF + 2 )

        # r.sentence is presumably a sequence of kana token ids -- confirm
        # against StackDecoder.
        edit_dists = [
            levenshtein(kanas, ''.join(kanas_tokenizer.index_word[i] for i in r.sentence))
            for r in nbest_raw
        ]

        # With no hypothesis at all, score the pair as if the hypothesis were
        # empty instead of adding a -1 sentinel to the running sum.
        total_edit_dist += min(edit_dists) if edit_dists else len(kanas)

    return total_edit_dist / num_pairs


## Training Execution

In [17]:
# To resume an interrupted run, restore a saved checkpoint before looping, e.g.
# checkpoint.restore(tf.train.latest_checkpoint(CHECKPOINT_DIR))

for epoch in range(EPOCHS):
  start = time.time()

  # The same all-zero encoder state is passed to every batch of the epoch.
  enc_hidden = encoder.initialize_hidden_state()
  total_loss = 0

  for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):

    batch_loss = train_step(inp, targ, enc_hidden)
    total_loss += batch_loss

    if batch % 100 == 0:
      print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                   batch,
                                                   batch_loss.numpy()))
  # saving (checkpoint) the model every epoch
  checkpoint.save(file_prefix = checkpoint_prefix)

  # Only the first 100 validation pairs are scored, to keep epochs short.
  # NOTE: validate() returns a mean edit distance, so despite the name
  # "accuracy" a lower value is better.
  accuracy = validate(validation_pairs[:100])
  epoch_loss = total_loss / steps_per_epoch

  print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                      epoch_loss))
  print('Validation Accuracy {:0.4f}'.format(accuracy))
  print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

  with train_summary_writer.as_default():
    tf.summary.scalar('loss', epoch_loss, step=epoch)
  # The validation metric goes to the validation writer, which was created for
  # this purpose but previously never used.
  with validation_summary_writer.as_default():
    tf.summary.scalar('accuracy', accuracy, step=epoch)

Epoch 1 Batch 0 Loss 1.2200
Epoch 1 Batch 100 Loss 1.2228
Epoch 1 Batch 200 Loss 1.3116
Epoch 1 Batch 300 Loss 1.2599
Epoch 1 Batch 400 Loss 1.1677
Epoch 1 Batch 500 Loss 1.1507
Epoch 1 Batch 600 Loss 1.2075
Epoch 1 Batch 700 Loss 1.1446
Epoch 1 Batch 800 Loss 1.1950
Epoch 1 Batch 900 Loss 1.1234
Epoch 1 Batch 1000 Loss 1.1611
Epoch 1 Batch 1100 Loss 1.1312
Epoch 1 Batch 1200 Loss 1.1360
Epoch 1 Batch 1300 Loss 1.1059
Epoch 1 Batch 1400 Loss 1.1097
Epoch 1 Batch 1500 Loss 1.1821
Epoch 1 Batch 1600 Loss 1.0768
Epoch 1 Batch 1700 Loss 1.0993
Epoch 1 Batch 1800 Loss 1.0836
validating 0/100
validating 10/100
validating 20/100
validating 30/100
validating 40/100
validating 50/100
validating 60/100
validating 70/100
validating 80/100
validating 90/100
Epoch 1 Loss 1.1553
Validation Accuracy 4.8500
Time taken for 1 epoch 34.04509425163269 sec

Epoch 2 Batch 0 Loss 1.0900
Epoch 2 Batch 100 Loss 1.0981
Epoch 2 Batch 200 Loss 1.1457
Epoch 2 Batch 300 Loss 1.1154
Epoch 2 Batch 400 Loss 1.0254
Epo

KeyboardInterrupt: 