In [82]:
import os.path

import numpy as np
import tensorflow as tf
from archive.lib.models.graph import Graph

In [83]:
exports_dir = 'exports/en-tw-transformer-nmt'
asset_dir = '../../data/en-tw'
model_file = f'{exports_dir}/checkpoint.model.keras'

In [84]:
def print_translation(sentence, translation):
    print(f'{"Input:":15s}: {sentence}')
    print(f'{"Prediction":15s}: {translation}')

In [85]:
from archive.lib.language_index import LanguageIndex


class Translator(tf.Module):
    def __init__(self, model: Graph, inp_lang: LanguageIndex, targ_lang: LanguageIndex):
        super().__init__()
        self.transformer = model
        self.inp_lang = inp_lang
        self.targ_lang = targ_lang

    def __call__(self, sentence: str, max_length=128):
        sentence = tf.convert_to_tensor([self.inp_lang[sentence]])
        output = np.array([self.targ_lang.eos_token_id])

        for i in tf.range(1, max_length):
            decoder_input = output[np.newaxis]
            predictions = self.transformer.predict((sentence, decoder_input), batch_size=1, verbose=2)
            predicted_id = tf.argmax(predictions[:, -1:, :], axis=-1)

            output = np.append(output, predicted_id[0], axis=0)
            if predicted_id == self.targ_lang.eos_token_id:
                break

        print(output.tolist())
        text = self.targ_lang[output.tolist()]

        return text

In [86]:
from archive.lib.configs.tranformers import TransformerModelConfigs

configs: TransformerModelConfigs = TransformerModelConfigs.load(os.path.join(exports_dir, 'config.json'))

inp_lang = configs.input_language_index(asset_dir)
targ_lang = configs.target_language_index(asset_dir)

In [87]:
from archive.lib.models.transformer import Transformer

model = Transformer(configs, inp_lang.vocab_size, targ_lang.vocab_size)
model.load_weights(model_file)

model.summary()

Model: "model_15"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_31 (InputLayer)       [(None, 128)]                0         []                            
                                                                                                  
 input_32 (InputLayer)       [(None, 128)]                0         []                            
                                                                                                  
 encoder_15 (Encoder)        (None, 128, 128)             4499712   ['input_31[0][0]']            
                                                                                                  
 decoder_15 (Decoder)        (None, 128, 128)             3107456   ['input_32[0][0]',            
                                                                     'encoder_15[0][0]']   

In [88]:
translator = Translator(model, inp_lang, targ_lang)

In [91]:
sentence = "Sarpong is going to school"

print_translation(sentence, translator(sentence, max_length=128))

1/1 - 0s - 13ms/epoch - 13ms/step
1/1 - 0s - 13ms/epoch - 13ms/step
1/1 - 0s - 14ms/epoch - 14ms/step
1/1 - 0s - 13ms/epoch - 13ms/step
1/1 - 0s - 14ms/epoch - 14ms/step
1/1 - 0s - 14ms/epoch - 14ms/step
1/1 - 0s - 14ms/epoch - 14ms/step
1/1 - 0s - 13ms/epoch - 13ms/step
1/1 - 0s - 13ms/epoch - 13ms/step
1/1 - 0s - 12ms/epoch - 12ms/step
1/1 - 0s - 12ms/epoch - 12ms/step
1/1 - 0s - 12ms/epoch - 12ms/step
1/1 - 0s - 12ms/epoch - 12ms/step
1/1 - 0s - 12ms/epoch - 12ms/step
1/1 - 0s - 13ms/epoch - 13ms/step
1/1 - 0s - 12ms/epoch - 12ms/step
1/1 - 0s - 12ms/epoch - 12ms/step
1/1 - 0s - 13ms/epoch - 13ms/step
1/1 - 0s - 14ms/epoch - 14ms/step
1/1 - 0s - 13ms/epoch - 13ms/step
1/1 - 0s - 14ms/epoch - 14ms/step
1/1 - 0s - 13ms/epoch - 13ms/step
1/1 - 0s - 13ms/epoch - 13ms/step
1/1 - 0s - 13ms/epoch - 13ms/step
1/1 - 0s - 12ms/epoch - 12ms/step
1/1 - 0s - 13ms/epoch - 13ms/step
1/1 - 0s - 13ms/epoch - 13ms/step
1/1 - 0s - 14ms/epoch - 14ms/step
1/1 - 0s - 13ms/epoch - 13ms/step
1/1 - 0s - 13m