In [1]:
import numpy as np
import tensorflow as tf
from archive.lib.models.graph import Graph
import archive.lib.utils
from archive.lib.layers.attention import BaseAttention, MaskedSelfAttention, SelfAttention, CrossAttention
from archive.lib.layers.dense import FeedForward

TypeError: function() argument 'code' must be code, not str

In [None]:
# Directory containing the exported language indexes and the pickled model.
exports_dir = '../exports/en-tw-transformer-nmt'
model_file = '/'.join([exports_dir, 'models.pkb'])

In [None]:
def print_translation(sentence, translation):
    """Print an input sentence and its predicted translation as two aligned rows.

    Both labels are left-justified to 15 characters and followed by ``": "``,
    so the two values line up in the same column.

    Args:
        sentence: the model input, either a string or a list/tuple of tokens
            (e.g. the word list produced by ``naive_words(...).split(' ')``).
        translation: the model output, either a string or a list/tuple of tokens.
    """
    def as_text(value):
        # Token sequences are shown space-joined rather than as a Python list repr;
        # anything else is formatted unchanged by the f-string below.
        if isinstance(value, (list, tuple)):
            return ' '.join(str(token) for token in value)
        return value

    # The label itself carries no colon: the format string supplies the separator.
    print(f'{"Input":15s}: {as_text(sentence)}')
    print(f'{"Prediction":15s}: {as_text(translation)}')

In [None]:
from archive.lib.language_index import LanguageIndex


class Translator(tf.Module):
    """Greedy token-by-token decoder around a trained encoder-decoder model.

    The wrapped model is called through ``predict([encoder_input, decoder_input])``
    and its output is read as per-position scores whose last axis ranges over
    target-vocabulary ids (the code takes an argmax over that axis).
    """

    def __init__(self, model: Graph, inp_lang: LanguageIndex, targ_lang: LanguageIndex):
        """Store the model and the input/target language indexes.

        Args:
            model: the trained model; must expose a Keras-style ``predict``.
            inp_lang: index used to encode the source sentence.
            targ_lang: index used to decode predicted ids back to text.
        """
        super().__init__()
        self.transformer = model
        self.inp_lang = inp_lang
        self.targ_lang = targ_lang

    def __call__(self, sentence, max_length=60, verbose=0):
        """Translate ``sentence`` by greedy decoding.

        Args:
            sentence: the source sentence, passed to ``inp_lang.to_padded_tensor``
                wrapped in a one-element batch (in this notebook, a list of words).
            max_length: maximum length of the decoded id sequence, counting the
                initial seed token; decoding runs at most ``max_length - 1`` steps.
            verbose: forwarded to ``predict``; 0 keeps the per-step call silent.

        Returns:
            ``targ_lang[ids]`` for the full list of decoded ids (seed token first,
            ending with ``eos_token`` if it was produced before the length limit).
        """
        encoder_input = self.inp_lang.to_padded_tensor([sentence])
        eos_token = self.targ_lang.eos_token

        # NOTE(review): decoding is seeded with eos_token, as the original code did.
        # Seq2seq decoders are usually seeded with a distinct start-of-sequence
        # token; confirm against how target sequences were built for training.
        decoded_ids = [eos_token]

        for _ in range(1, max_length):
            decoder_input = np.array([decoded_ids])  # one-sequence batch
            predictions = self.transformer.predict(
                [encoder_input, decoder_input], batch_size=1, verbose=verbose
            )
            # Only the scores for the last decoded position pick the next token.
            next_id = int(np.argmax(predictions[0, -1, :]))
            decoded_ids.append(next_id)

            if next_id == eos_token:
                break

        # Convert the whole decoded sequence, not just its first element.
        # NOTE(review): assumes LanguageIndex.__getitem__ accepts a list of ids, as
        # the original `.tolist()` call implies; confirm.
        return self.targ_lang[decoded_ids]

In [None]:
import pickle

# Restore the exported input/target language indexes and the trained model.
# NOTE(review): pickle.load can execute arbitrary code; only load files produced
# by this project's own export step.
def _unpickle(path):
    """Open ``path`` in binary mode and return the object pickled inside it."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


inp_lang = _unpickle(f"{exports_dir}/en_tw-en.lang.idx")
targ_lang = _unpickle(f"{exports_dir}/en_tw-tw.lang.idx")
model = _unpickle(model_file)

print(inp_lang)
print(targ_lang)
model.summary()

In [None]:
# Combine the loaded model with both language indexes into a callable translator.
translator = Translator(model=model, inp_lang=inp_lang, targ_lang=targ_lang)

In [None]:
from archive.lib.preprocessors.naive_words import naive_words

# Normalize a sample input phrase with naive_words, split it into a word list on
# single spaces, then print it alongside the model's translation.
normalized_text = naive_words(
    "Lion",
    punctuations="?.!,¿'",
    special_chars='ɛƐɔƆ',
)
sentence = normalized_text.split(' ')

print_translation(sentence, translator(sentence))