In [36]:
import numpy as np

import typing
from typing import Any, Tuple

import einops
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import tensorflow as tf
import tensorflow_text as tf_text
import pandas as pd
from sklearn.model_selection import train_test_split

In [37]:
# Parallel corpus: one row per example with a "text" sentence and its "gloss".
df = pd.read_csv(filepath_or_buffer="train.csv")


def preprocess(df) -> tuple[np.ndarray, np.ndarray]:
    """Split a parallel-corpus frame into model inputs and targets.

    Args:
        df: frame with a "text" column (source sentences) and a "gloss"
            column (target glosses).

    Returns:
        (x, y): the "text" column as inputs and the "gloss" column as targets,
        both as plain NumPy arrays. `.to_numpy()` is used instead of `.values`
        because `.values` returns a pandas ExtensionArray (not an ndarray) for
        string/extension dtypes.
    """
    x = df["text"].to_numpy()
    y = df["gloss"].to_numpy()
    return x, y

In [38]:
# Hold out 10% of the corpus (fixed seed) as the validation/test split.
train, test = train_test_split(df, test_size=0.1, random_state=42)
(train_x, train_y), (test_x, test_y) = map(preprocess, (train, test))

In [39]:
BUFFER_SIZE = len(df)
BATCH_SIZE = 64

# Training batches are reshuffled every epoch. The validation split is kept in
# a fixed order: `model.fit` below evaluates only `validation_steps` batches,
# and reshuffling would score a different random subset each epoch, making the
# `val_loss` that EarlyStopping monitors noisier.
train_raw = (
    tf.data.Dataset.from_tensor_slices((train_x, train_y))
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)
test_raw = tf.data.Dataset.from_tensor_slices((test_x, test_y)).batch(BATCH_SIZE)

In [40]:
def tf_lower_and_split_punct(text):
    """Standardize raw strings for `TextVectorization`.

    Steps: NFKD-normalize, lowercase, drop every character outside
    ``[ a-z.?!,¿]``, surround the kept punctuation with spaces, trim, and wrap
    each string in "[START]" ... "[END]" markers.

    NOTE(review): the character filter also removes digits and hyphens, so
    e.g. "2001" disappears and glosses like "DESC-NOT" become "descnot".
    Confirm this is intended for the gloss vocabulary.
    """
    normalized = tf.strings.lower(tf_text.normalize_utf8(text, "NFKD"))
    filtered = tf.strings.regex_replace(normalized, "[^ a-z.?!,¿]", "")
    spaced = tf.strings.regex_replace(filtered, "[.?!,¿]", r" \0 ")
    return tf.strings.join(
        ["[START]", tf.strings.strip(spaced), "[END]"], separator=" "
    )

In [41]:
# Peek at the first five raw (context, target) pairs of one training batch.
# `take(1)` already limits the loop to a single batch.
for example_context_strings, example_target_strings in train_raw.take(1):
    print(example_context_strings[:5], end="\n\n")
    print(example_target_strings[:5])

tf.Tensor(
[b'after a previous tragedy in belgium in 2001 , comprehensive safety measures were promised but have not been implemented .\r\n'
 b'under the pressure of this procedure , the italian authorities are now changing their approach .\r\n'
 b'europeans often ask what we do here in the european parliament , what good we do for them .\r\n'
 b"to date , the commission has still not given clear responses to parliament's requests .\r\n"
 b'unfortunately , the commission is very reluctant to suggest any measures in this field .\r\n'], shape=(5,), dtype=string)

tf.Tensor(
[b'AFTER DESC-PREVIOUS TRAGEDY IN BELGIUM IN 2001 , DESC-COMPREHENSIVE SAFETY MEASURE BE PROMISE BUT HAVE DESC-NOT BE IMPLEMENT .\r\n'
 b'UNDER PRESSURE THIS PROCEDURE , ITALIAN AUTHORITY BE DESC-NOW CHANGE X-Y APPROACH .\r\n'
 b'EUROPEAN DESC-OFTEN ASK WHAT X-WE DO DESC-HERE IN EUROPEAN PARLIAMENT , WHAT DESC-GOOD X-WE DO FOR X-Y .\r\n'
 b'TO DATE , COMMISSION HAVE DESC-STILL DESC-NOT GIVE DESC-CLEAR RESPONSE TO PARL

In [42]:
max_vocab_size = 5000


def _build_text_processor(texts):
    """Create a ragged `TextVectorization` layer and fit its vocabulary on `texts`."""
    processor = tf.keras.layers.TextVectorization(
        standardize=tf_lower_and_split_punct,
        max_tokens=max_vocab_size,
        ragged=True,
    )
    processor.adapt(texts)
    return processor


# Separate vocabularies for the source sentences and the target glosses.
context_text_processor = _build_text_processor(train_raw.map(lambda context, _: context))
target_text_processor = _build_text_processor(train_raw.map(lambda _, target: target))

In [43]:
def process_text(context, target):
    """Tokenize a batch of raw strings into padded model inputs and labels.

    The target is shifted by one position: the decoder is fed tokens[:-1] and
    trained to predict tokens[1:].

    Returns:
        ((context_ids, decoder_inputs), decoder_labels), all zero-padded dense tensors.
    """
    context_ids = context_text_processor(context).to_tensor()
    target_ids = target_text_processor(target)
    decoder_inputs = target_ids[:, :-1].to_tensor()
    decoder_labels = target_ids[:, 1:].to_tensor()
    return (context_ids, decoder_inputs), decoder_labels


train_ds = train_raw.map(process_text, num_parallel_calls=tf.data.AUTOTUNE)
val_ds = test_raw.map(process_text, num_parallel_calls=tf.data.AUTOTUNE)

In [44]:
class Encoder(tf.keras.layers.Layer):
    """Embeds context token IDs and encodes them with a bidirectional GRU.

    Args:
        text_processor: the adapted context-language `TextVectorization`
            layer. Its vocabulary size sizes the embedding table, and
            `convert_input` reuses it to tokenize raw strings.
        units: embedding width and GRU hidden size. The two GRU directions are
            merged with `merge_mode="sum"`, so the output width is also `units`.
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
    """

    def __init__(self, text_processor, units, **kwargs):
        super().__init__(**kwargs)
        self.text_processor = text_processor
        self.vocab_size = text_processor.vocabulary_size()
        self.units = units

        # Token IDs -> dense vectors. ID 0 is padding; `mask_zero=True` lets
        # downstream layers ignore padded positions.
        self.embedding = tf.keras.layers.Embedding(
            self.vocab_size, units, mask_zero=True
        )

        # Bidirectional GRU over the embedded sequence. Only the per-step
        # outputs are returned (no final state is requested); the forward and
        # backward outputs are summed.
        self.rnn = tf.keras.layers.Bidirectional(
            merge_mode="sum",
            layer=tf.keras.layers.GRU(
                units,
                return_sequences=True,
                recurrent_initializer="glorot_uniform",
            ),
        )

    def call(self, x):
        """Encode a batch of token IDs.

        Args:
            x: integer token IDs, shape ("batch s"), zero-padded.

        Returns:
            The GRU output sequence, shape ("batch s units").
        """
        # 1. Look up the embedding vector for each token.
        x = self.embedding(x)
        # 2. Run the bidirectional GRU over the embedded sequence.
        return self.rnn(x)

    def convert_input(self, texts):
        """Tokenize raw strings and encode them.

        Args:
            texts: a single string or a 1-D batch of strings.

        Returns:
            The encoder output for the batch (a single string is treated as a
            batch of one).
        """
        texts = tf.convert_to_tensor(texts)
        if texts.shape.rank == 0:
            # Promote a scalar string to a batch of one.
            texts = texts[tf.newaxis]
        context = self.text_processor(texts).to_tensor()
        return self(context)

In [45]:
class CrossAttention(tf.keras.layers.Layer):
    """Single-head attention from decoder states (query) over encoder outputs.

    The attended values are added back to the query (residual connection) and
    layer-normalized. The head-averaged attention scores from the most recent
    call are cached in `last_attention_weights` for plotting.

    Args:
        units: `key_dim` of the underlying `MultiHeadAttention`.
        **kwargs: forwarded to `tf.keras.layers.MultiHeadAttention`
            (not to the base `Layer`).
    """

    def __init__(self, units, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(
            num_heads=1, key_dim=units, **kwargs
        )
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()

    def call(self, x, context):
        """Attend from `x` ("batch t units") over `context` ("batch s units").

        Returns:
            LayerNorm(x + attention output), shape ("batch t units").
        """
        attended, scores = self.mha(
            query=x, value=context, return_attention_scores=True
        )
        # scores are ("batch heads t s"); average the head axis -> ("batch t s").
        self.last_attention_weights = tf.reduce_mean(scores, axis=1)

        residual = self.add([x, attended])
        return self.layernorm(residual)

In [46]:
class Decoder(tf.keras.layers.Layer):
    """GRU decoder that attends over the encoder output and predicts target tokens.

    Its forward pass and decoding helpers are attached in later cells of this
    notebook via `Decoder.add_method`.

    Args:
        text_processor: the adapted target-language `TextVectorization` layer.
            Its vocabulary defines the token <-> ID lookups and the output size.
        units: embedding width, GRU hidden size and attention `key_dim`.
    """

    @classmethod
    def add_method(cls, fun):
        """Install `fun` on the class under `fun.__name__`.

        Returns `fun` unchanged so this can be used as a decorator.
        """
        setattr(cls, fun.__name__, fun)
        return fun

    def __init__(self, text_processor, units):
        super().__init__()
        self.text_processor = text_processor
        self.vocab_size = text_processor.vocabulary_size()

        # Word <-> ID lookups over the processor's vocabulary: "" is the mask
        # token and "[UNK]" the out-of-vocabulary token.
        vocabulary = text_processor.get_vocabulary()
        lookup_options = {"vocabulary": vocabulary, "mask_token": "", "oov_token": "[UNK]"}
        self.word_to_id = tf.keras.layers.StringLookup(**lookup_options)
        self.id_to_word = tf.keras.layers.StringLookup(invert=True, **lookup_options)
        self.start_token = self.word_to_id("[START]")
        self.end_token = self.word_to_id("[END]")

        self.units = units

        # Token IDs -> vectors; ID 0 (padding) is masked.
        self.embedding = tf.keras.layers.Embedding(
            self.vocab_size,
            units,
            mask_zero=True,
        )

        # Tracks what has been generated so far. It returns every step's output
        # plus the final state, so decoding can resume step by step.
        self.rnn = tf.keras.layers.GRU(
            units,
            return_sequences=True,
            return_state=True,
            recurrent_initializer="glorot_uniform",
        )

        # The GRU output is the query attending over the encoder output.
        self.attention = CrossAttention(units)

        # Projects each attended step to logits over the target vocabulary.
        self.output_layer = tf.keras.layers.Dense(self.vocab_size)

In [47]:
@Decoder.add_method
def call(self, context, x, state=None, return_state=False):
    """Run the decoder over target tokens `x`, attending to `context`.

    Args:
        context: encoder output, ("batch s units").
        x: target token IDs, ("batch t").
        state: optional initial GRU state, passed to the GRU as `initial_state`.
        return_state: if True, also return the final GRU state.

    Returns:
        Logits ("batch t target_vocab_size"), or `(logits, state)` when
        `return_state` is True.
    """
    embedded = self.embedding(x)
    rnn_out, new_state = self.rnn(embedded, initial_state=state)

    # The GRU output is the attention query over the encoder context.
    attended = self.attention(rnn_out, context)
    # Expose the attention weights on the decoder for plotting.
    self.last_attention_weights = self.attention.last_attention_weights

    logits = self.output_layer(attended)
    return (logits, new_state) if return_state else logits

In [48]:
@Decoder.add_method
def get_initial_state(self, context):
    """Build the starting inputs for autoregressive decoding.

    Args:
        context: encoder output; only its leading (batch) dimension is used.

    Returns:
        (start_tokens, done, state):
        - a (batch, 1) tensor filled with the "[START]" ID,
        - an all-False (batch, 1) bool "finished" flag,
        - the first element of `self.rnn.get_initial_state(...)`.
    """
    batch = tf.shape(context)[0]
    first_tokens = tf.fill([batch, 1], self.start_token)
    finished = tf.zeros([batch, 1], dtype=tf.bool)
    initial_states = self.rnn.get_initial_state(self.embedding(first_tokens))
    return first_tokens, finished, initial_states[0]

In [49]:
@Decoder.add_method
def tokens_to_text(self, tokens):
    """Convert token IDs back to text, dropping the [START]/[END] markers.

    Args:
        tokens: integer token IDs; words are joined with spaces along the
            last axis.

    Returns:
        A string tensor with the last axis of `tokens` reduced away.
    """
    words = self.id_to_word(tokens)
    result = tf.strings.reduce_join(words, axis=-1, separator=" ")
    # Raw strings: "\[" in a normal literal is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning on Python 3.12+). Patterns unchanged.
    result = tf.strings.regex_replace(result, r"^ *\[START\] *", "")
    result = tf.strings.regex_replace(result, r" *\[END\] *$", "")
    # Padding ID 0 maps to the "" mask token. Joining those with " " leaves
    # trailing spaces, and `get_next_token` replaces [END] itself with 0, so
    # the [END] pattern above usually has nothing to remove.
    return tf.strings.strip(result)

In [50]:
@Decoder.add_method
def get_next_token(self, context, next_token, done, state, temperature=0.0):
    """Run one decoding step and choose the next token for every sequence.

    Args:
        context: encoder output for the batch.
        next_token: previously generated token IDs, (batch, 1) as produced by
            `get_initial_state` or a previous call.
        done: (batch, 1) bool flags for sequences that already produced [END].
        state: current GRU state.
        temperature: 0.0 for greedy (argmax) decoding. Otherwise the logits
            are divided by it and a token is sampled.

    Returns:
        (next_token, done, state), with `next_token` of shape (batch, 1).
    """
    logits, state = self(context, next_token, state=state, return_state=True)

    # Only the prediction at the last time step is used in both branches, so
    # the result is (batch, 1) even if more than one step was fed in.
    if temperature == 0.0:
        next_token = tf.argmax(logits[:, -1:, :], axis=-1)
    else:
        last_logits = logits[:, -1, :] / temperature
        next_token = tf.random.categorical(last_logits, num_samples=1)

    # If a sequence produces an `end_token`, mark it `done`.
    done = done | (next_token == self.end_token)
    # Once a sequence is done it only produces 0 (padding), including in
    # place of the [END] token itself.
    next_token = tf.where(done, tf.constant(0, dtype=tf.int64), next_token)

    return next_token, done, state

In [51]:
class Translator(tf.keras.Model):
    """Sequence-to-sequence model: an `Encoder` over the context, a `Decoder` over the target.

    Args:
        units: width shared by the encoder and decoder.
        context_text_processor: adapted `TextVectorization` for the source side.
        target_text_processor: adapted `TextVectorization` for the target side.
    """

    @classmethod
    def add_method(cls, fun):
        """Install `fun` on the class under `fun.__name__` and return it (decorator use)."""
        setattr(cls, fun.__name__, fun)
        return fun

    def __init__(self, units, context_text_processor, target_text_processor):
        super().__init__()
        self.encoder = Encoder(context_text_processor, units)
        self.decoder = Decoder(target_text_processor, units)

    def call(self, inputs):
        """Map `(context_ids, target_input_ids)` to target-vocabulary logits."""
        context_tokens, target_tokens = inputs
        encoded = self.encoder(context_tokens)
        logits = self.decoder(encoded, target_tokens)

        # TODO(b/250038731): remove this
        # Drop the Keras mask so Keras does not rescale the loss/accuracy;
        # the masked_* functions already handle padding themselves.
        try:
            del logits._keras_mask
        except AttributeError:
            pass

        return logits

In [52]:
UNITS = 256
# UNITS is the embedding width, GRU hidden size and attention key_dim.
model = Translator(
    units=UNITS,
    context_text_processor=context_text_processor,
    target_text_processor=target_text_processor,
)

In [53]:
def masked_loss(y_true, y_pred):
    """Mean sparse categorical cross-entropy over non-padding positions.

    Args:
        y_true: integer target IDs; 0 marks padding.
        y_pred: unnormalized logits over the target vocabulary.

    Returns:
        Scalar: the summed per-token loss at non-zero targets divided by the
        number of non-zero targets.
    """
    per_token = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction="none"
    )(y_true, y_pred)
    keep = tf.cast(y_true != 0, per_token.dtype)
    return tf.reduce_sum(per_token * keep) / tf.reduce_sum(keep)


def masked_acc(y_true, y_pred):
    """Token-level accuracy over non-padding positions.

    Args:
        y_true: integer target IDs; 0 marks padding.
        y_pred: logits over the target vocabulary (argmax taken on the last axis).

    Returns:
        Scalar: correct predictions at non-zero targets divided by the number
        of non-zero targets.
    """
    predicted_ids = tf.cast(tf.argmax(y_pred, axis=-1), y_true.dtype)
    mask = tf.cast(y_true != 0, tf.float32)

    # Count matches only at real tokens. Without the mask, a padded position
    # (y_true == 0) predicted as 0 counts as correct, inflating the score
    # (it could even exceed 1.0).
    match = tf.cast(y_true == predicted_ids, tf.float32) * mask

    return tf.reduce_sum(match) / tf.reduce_sum(mask)

In [54]:
# Train on the padding-masked loss; also report it as a metric next to masked accuracy.
model.compile(
    optimizer="adam",
    loss=masked_loss,
    metrics=[masked_acc, masked_loss],
)

In [55]:
# Up to 100 epochs of 100 steps each. Training stops once `val_loss` has not
# improved for 3 epochs. `restore_best_weights=True` rolls the model back to
# the best epoch's weights when stopping; otherwise the model would keep the
# weights from the last (already worse) epoch.
history = model.fit(
    train_ds.repeat(),
    epochs=100,
    steps_per_epoch=100,
    validation_data=val_ds,
    validation_steps=20,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(
            monitor="val_loss", patience=3, restore_best_weights=True
        )
    ],
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100


In [56]:
@Translator.add_method
def translate(self, texts, *, max_length=50, temperature=0.0):
    """Translate raw context strings into target-language text.

    Args:
        texts: a single string or a 1-D batch of strings (tokenized by
            `Encoder.convert_input`).
        max_length: maximum number of tokens to generate.
        temperature: 0.0 for greedy decoding; otherwise the sampling temperature.

    Returns:
        A string tensor with one decoded text per input. The per-step
        attention weights are stored in `self.last_attention_weights`.
    """
    # Encode the input. `convert_input` accepts a scalar string as well as a
    # batch. The batch size is deliberately not read from `texts`:
    # `tf.shape(texts)[0]` fails on a scalar and the value was never used.
    context = self.encoder.convert_input(texts)

    # Set up the loop inputs.
    tokens = []
    attention_weights = []
    next_token, done, state = self.decoder.get_initial_state(context)

    for _ in range(max_length):
        # Generate the next token.
        next_token, done, state = self.decoder.get_next_token(
            context, next_token, done, state, temperature
        )

        # Collect the generated tokens and the attention used to produce them.
        tokens.append(next_token)
        attention_weights.append(self.decoder.last_attention_weights)

        # In eager mode, stop as soon as every sequence has finished.
        if tf.executing_eagerly() and tf.reduce_all(done):
            break

    # Stack the per-step lists.
    tokens = tf.concat(tokens, axis=-1)  # t*[(batch, 1)] -> (batch, t)
    self.last_attention_weights = tf.concat(
        attention_weights, axis=1
    )  # t*[(batch, 1, s)] -> (batch, t, s)

    return self.decoder.tokens_to_text(tokens)

In [61]:
# Greedy (default temperature=0.0) translation of one example sentence.
result = model.translate(["there is an apple"])
result.numpy()[0].decode()

'descre be [UNK] '

In [62]:
# Save the trained model in TensorFlow SavedModel format (a directory).
# HDF5 (".h5") saving only supports Functional/Sequential models and raises
# NotImplementedError for a subclassed model such as `Translator`.
# This is NOT a TFLite export; that would be a separate conversion step
# (e.g. tf.lite.TFLiteConverter) starting from the SavedModel.
model.save("translator", save_format="tf")



INFO:tensorflow:Assets written to: model\assets


INFO:tensorflow:Assets written to: model\assets
