# Imports

In [58]:
import functools
import logging as log
import os
import pathlib
import re
from time import time

# general modules
import numpy as np
import math
import copy

# tensorflow modules
import tensorflow as tf
import tensorflow_text as tf_text
from tensorflow.keras import layers
from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset as bert_vocab

# necessary for visualization and user input
import matplotlib.pyplot as plt


# Settings

In [84]:
# Set True, if code is run as jupyter notebook
is_interactive_notebook = True

# paths -- built with os.path.join so they resolve on every OS; a hard-coded '\\'
# separator only works on Windows (elsewhere it becomes part of the file name).
dataset_path = os.path.join('datasets', 'corpus.txt')
vocab_path = os.path.join('datasets', 'vocab.txt')

# Special tokens reserved in the vocabulary. add_start_end derives the [START]/[END]
# ids from their positions in this list, so the order matters.
reserved_tokens = ["[PAD]", "[UNK]", "[START]", "[END]"]

# Architecture

## Helper functions

In [60]:
def clones(module, N):
    """Return a list holding N independent deep copies of ``module``."""
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return copies

def subsequent_mask(size):
    """Return a causal attention mask of shape ``(1, size, size)``.

    Element ``[0, i, j]`` is True when ``j <= i``, i.e. position ``i`` may attend
    to itself and earlier positions but not to later ones.
    """
    lower_triangle = np.tril(np.ones((1, size, size)))
    return lower_triangle.astype(bool)

### Layer Wrapper

In [85]:
class LayerWrapper(tf.keras.layers.Layer):
    """
    A wrapper for Keras layers, which allows to visualize data at each layer.

    Each call is forwarded to the wrapped layer. The positional inputs and the
    output are recorded, and on selected calls the requested data is plotted.
    Attribute lookups that fail on the wrapper fall through to the wrapped layer,
    so e.g. ``wrapper.encode(...)`` reaches ``wrapper.layer.encode(...)``.

    Attributes:
        should_visualize (bool): Class attribute shared by all instances. A wrapper only
            visualizes while it is True. Only instances with ``visual_setter=True`` change it.
        layer (Layer): The Keras layer to be wrapped.
        inputs (List[List]): Positional arguments of each call (one list per call).
        outputs (List): Output of each call. It is converted with ``.numpy()`` when the output
            provides that method (eager tensors); otherwise it is stored unchanged.
        counter (int): Number of completed calls, i.e. the index of the next call.
        visualize_on_calls (List[int]): Call indices (values of ``counter``) at which to visualize.
        visualizations (List[Tuple[str, str]]): ``(mode, target)`` pairs. ``mode`` is
            ``'mode1'`` or ``'mode2'``. ``target`` is ``'x'`` (inputs), ``'y'`` (output)
            or ``'y-x'`` (output minus the first positional input).
        visual_setter (bool): If True, this instance sets ``should_visualize`` at the start
            of each of its calls. It is True on calls listed in ``visualize_on_calls``,
            False otherwise.
    """
    should_visualize = True  # class variable

    def __init__(self, layer, visualize_on_calls=None, visualizations=None, visual_setter=False, **kwargs):
        """
        Initialize the LayerWrapper.

        Args:
            layer (Layer): The Keras layer to be wrapped.
            visualize_on_calls (List[int], optional): Call indices at which to visualize. Defaults to an empty list.
            visualizations (List[Tuple[str, str]], optional): ``(mode, target)`` pairs to visualize. Defaults to an empty list.
            visual_setter (bool, optional): If True, this instance controls the should_visualize class variable. Defaults to False.
            **kwargs: Additional keyword arguments passed to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.layer = layer
        self.inputs = []
        self.outputs = []
        self.counter = 0
        self.visualize_on_calls = visualize_on_calls if visualize_on_calls else []
        self.visualizations = visualizations if visualizations else []
        self.visual_setter = visual_setter

    def __getattr__(self, attr):
        """
        Fall back to the wrapped layer's attribute when the wrapper has no attribute ``attr``.
        """
        # Guard on __dict__: before self.layer is assigned (e.g. during Layer.__init__),
        # delegating would recurse into __getattr__ for 'layer' itself.
        if 'layer' in self.__dict__:
            return getattr(self.layer, attr)
        else:
            raise AttributeError(f"{self.__class__.__name__} object has no attribute {attr}")

    def call(self, *args, **kwargs):
        """
        Forward the call to the wrapped layer, record inputs/outputs and visualize if configured.

        Args:
            *args: Positional arguments for the wrapped layer (recorded in ``self.inputs``).
            **kwargs: Keyword arguments for the wrapped layer (not recorded).

        Returns:
            The output of the wrapped layer, unchanged.
        """
        visualize_now = self.counter in self.visualize_on_calls

        # The setter decides the class-wide flag *before* running the wrapped layer.
        # Nested wrappers invoked inside it then see the flag for this call rather than
        # the one left over from the previous call.
        if self.visual_setter:
            LayerWrapper.should_visualize = visualize_now

        self.inputs.append(list(args))
        output = self.layer(*args, **kwargs)
        # Symbolic (graph-mode) tensors have no .numpy(); store those unchanged instead of crashing.
        self.outputs.append(output.numpy() if hasattr(output, 'numpy') else output)

        if visualize_now and self.should_visualize:
            self.visualize(self.visualizations)

        self.counter += 1
        return output

    @staticmethod
    def wait_for_user_input():
        """Block until the user confirms, unless running as a Jupyter notebook."""
        # input() is skipped in notebooks because it causes problems in Jupyter.
        if not is_interactive_notebook:
            input('Continue')

    def _select_data(self, what_to_output):
        """Return the recorded data of the latest call selected by ``what_to_output``."""
        if what_to_output == 'x':
            return self.inputs[-1]
        if what_to_output == 'y':
            return self.outputs[-1]
        if what_to_output == 'y-x':
            # Residual view: output minus the first positional input, element-wise.
            # Shapes must match or broadcast. Like 'y', the result is an array whose
            # first axis is the batch axis.
            return np.asarray(self.outputs[-1]) - np.asarray(self.inputs[-1][0])
        raise ValueError(
            f"Unknown visualization target {what_to_output!r}; expected 'x', 'y' or 'y-x'")

    def visualize(self, visualizations):
        """Plot every ``(mode, target)`` pair for the latest call, then wait for the user."""
        for mode, what_to_output in visualizations:
            data = self._select_data(what_to_output)
            if mode == 'mode1':
                self.visualization_func_1(data)
            elif mode == 'mode2':
                self.visualization_func2(data)

        self.wait_for_user_input()

    def visualization_func_1(self, data):
        """Show ``data[0]`` as a colour-coded heat map."""
        # data[0] may be a tensor, numpy array or list-like (e.g. a Keras ListWrapper).
        array_data = np.array(data[0])
        # A 1D array is reshaped into a 2D array with one column so imshow can draw it.
        if array_data.ndim == 1:
            array_data = np.reshape(array_data, (-1, 1))
        plt.figure(figsize=(10, 2))
        plt.imshow(array_data, cmap='jet', aspect='auto')
        plt.colorbar(label='Tensor Value')
        plt.show()

    def visualization_func2(self, data):
        """Placeholder for a second visualization mode (currently does nothing)."""
        pass

## Main Layers

These classes are built by subclassing `tf.keras.layers.Layer` (Keras model subclassing) rather than with the Functional or Sequential API. Subclassing gives the flexibility needed for layers that take several inputs and masks. Because each class is a `tf.keras.layers.Layer`, they can be composed to build more complex layers or models. The `call` method of each class defines the computation that the layer performs.

These classes are designed to be components of a larger transformer model. The model itself is typically composed of an encoder and a decoder, each of which is made up of a stack of identical layers. The layers themselves contain sublayers that perform operations such as self-attention, source attention (in the case of the decoder), and position-wise feed-forward networks. These operations are encapsulated within classes like `EncoderStack`, `DecoderStack`, `EncoderLayer`, `DecoderLayer`, and `PositionwiseFeedForward`. The layer norm and dropout are applied in `ResidualSublayer` for regularizing and speeding up the training process.

### Encoder Decoder Layer

1. `EncoderDecoder`:
    - `__init__(self, encoder, decoder, enc_embed, dec_embed, generator)`: This initializes the EncoderDecoder instance. It takes in five arguments:
        - `encoder`: The encoder layer to be used.
        - `decoder`: The decoder layer to be used.
        - `enc_embed`: The embedding layer for the encoder.
        - `dec_embed`: The embedding layer for the decoder.
        - `generator`: The final layer that generates the output tokens.
    - `encode(self, inputs, pad_mask)`: This method is used to encode the inputs using the encoder layer. It takes in two arguments:
        - `inputs`: The input tokens to be encoded.
        - `pad_mask`: The mask indicating which tokens are padding.
    - `decode(self, enc_input, pad_mask, inputs, subseq_mask)`: This method is used to decode the encoded inputs using the decoder layer. It takes in four arguments:
        - `enc_input`: The encoded input from the encoder.
        - `pad_mask`: The mask indicating which tokens are padding in the encoded input.
        - `inputs`: The target tokens to be decoded.
        - `subseq_mask`: The mask indicating which tokens in the target sequence should not be attended to.
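    - `call(self, enc_input, dec_input, pad_mask, subseq_mask)`: This method performs the complete encoder-decoder pass. It encodes `enc_input` (the source tokens) and decodes `dec_input` (the target tokens) against the encoder output. `pad_mask` and `subseq_mask` have the same meaning as in `encode` and `decode`. Note that the `generator` is not applied here; it must be called separately on the decoder output.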

In [59]:
class EncoderDecoder(layers.Layer):
    """
    Standard encoder-decoder container.

    Embeds and encodes the source, then embeds the target and decodes it against
    the encoder output. The ``generator`` is stored but not applied by ``call``.
    """

    def __init__(self, encoder, decoder, enc_embed, dec_embed, generator):
        super().__init__()
        # Sub-modules are assigned in the same order as before (encoder, decoder,
        # embeddings, generator) so Keras tracks them in an unchanged order.
        self.encoder, self.decoder = encoder, decoder
        self.enc_embed, self.dec_embed = enc_embed, dec_embed
        self.generator = generator

    def encode(self, inputs, pad_mask):
        """Embed source tokens ``inputs`` and run them through the encoder."""
        embedded_src = self.enc_embed(inputs)
        return self.encoder(embedded_src, pad_mask)

    def decode(self, enc_input, pad_mask, inputs, subseq_mask):
        """Embed target tokens ``inputs`` and decode them against the encoder output ``enc_input``."""
        embedded_tgt = self.dec_embed(inputs)
        return self.decoder(embedded_tgt, enc_input, pad_mask, subseq_mask)

    def call(self, enc_input, dec_input, pad_mask, subseq_mask):
        """Encode ``enc_input`` and decode ``dec_input`` against it."""
        memory = self.encode(enc_input, pad_mask)
        return self.decode(memory, pad_mask, dec_input, subseq_mask)

### Layer Norm


2. `LayerNorm`:
    - `__init__(self, features, eps=1e-6)`: This initializes the LayerNorm instance. It takes in two arguments:
        - `features`: The number of features in the input to be normalized.
        - `eps`: A small number added to the variance before taking the square root, for numerical stability.
    - `call(self, x)`: This method is used to apply layer normalization to the input. It takes in one argument:
        - `x`: The input to be normalized.

In [63]:
class LayerNorm(layers.Layer):
    """Layer normalization over the last axis with a learnable gain (a_2) and bias (b_2)."""

    def __init__(self, features, eps=1e-6) -> None:
        super().__init__()
        # Gain starts at 1 and bias at 0, so the layer initially only standardizes.
        self.a_2 = self.add_weight(shape=(features,), initializer='ones')
        self.b_2 = self.add_weight(shape=(features,), initializer='zeros')
        self.eps = eps

    def call(self, x):
        """Standardize ``x`` along its last axis, then scale by a_2 and shift by b_2."""
        mean, variance = tf.nn.moments(x, axes=-1, keepdims=True)
        centered = x - mean
        std = tf.math.sqrt(variance + self.eps)
        return self.a_2 * centered / std + self.b_2

### Residual Layer


3. `ResidualSublayer`:
    - `__init__(self, size, dropout)`: This initializes the ResidualSublayer instance. It takes in two arguments:
        - `size`: The number of features in the input.
        - `dropout`: The dropout rate to be applied after the sublayer.
    - `call(self, x, sublayer)`: This method is used to apply a sublayer and a residual connection to the input. It takes in two arguments:
        - `x`: The input to be transformed.
        - `sublayer`: The sublayer to be applied to the input. This is expected to be a function or callable object that takes in the input and returns a tensor of the same shape.

In [64]:
class ResidualSublayer(layers.Layer):
    """
    Pre-norm residual block: ``x + dropout(sublayer(norm(x)))``.

    The norm is applied first (before the sublayer) rather than last, for code simplicity.
    """

    def __init__(self, size, dropout) -> None:
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = layers.Dropout(dropout)

    def call(self, x, sublayer):
        """Apply ``sublayer`` to the normalized input and add the result back onto ``x``."""
        normalized = self.norm(x)
        transformed = sublayer(normalized)
        return x + self.dropout(transformed)

### Encoder Stack Layer

4. `EncoderStack`:
    - `__init__(self, layer, N)`: This initializes the EncoderStack instance. It takes in two arguments and initializes two instance variables:
        - `layer`: The layer instance to be used in the encoder stack; it is deep-copied `N` times. It should be a callable object that takes in the input and a mask, returns a tensor, and exposes a `size` attribute.
        - `N`: The number of layers in the encoder stack.
        - `self.layers` is a list of `N` independent deep copies of `layer`.
        - `self.norm` is the norm layer, that is applied to the output of the `EncoderStack`.
    - `call(self, x, mask)`: This method is used to pass the input through each layer in the encoder stack in turn. It takes in two arguments:
        - `x`: The input to be processed by the encoder stack.
        - `mask`: The mask indicating which tokens should not be attended to.

In [65]:
class EncoderStack(layers.Layer):
    """
    Encoder made of N deep copies of one encoder layer, followed by a final LayerNorm.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        # The final norm's width comes from the prototype layer's ``size`` attribute.
        self.norm = LayerNorm(layer.size)

    def call(self, x, mask):
        """
        Feed ``x`` (with the same ``mask``) through every layer in order, then normalize.
        """
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)

### Encoder Layer


5. `EncoderLayer`:
    - `__init__(self, size, self_attn, feed_forward, dropout)`: This initializes the EncoderLayer instance. It takes in four arguments:
        - `size`: The number of features in the input.
        - `self_attn`: The self-attention mechanism to be used in the encoder layer. This should be a callable object that takes in the input and a mask and returns a tensor.
        - `feed_forward`: The feed-forward network to be used in the encoder layer. This should be a callable object that takes in the input and returns a tensor.
        - `dropout`: The dropout rate to be applied after each sublayer.
    - `call(self, x, mask)`: This method is used to pass the input through the self-attention mechanism and the feed-forward network. It takes in two arguments:
        - `x`: The input to be processed by the encoder layer.
        - `mask`: The mask indicating which tokens should not be attended to.

In [66]:
class EncoderLayer(layers.Layer):
    """
    One encoder layer: masked self-attention, then a feed-forward network.

    Each of the two steps is wrapped in its own ResidualSublayer.
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(ResidualSublayer(size, dropout), 2)
        # Read by EncoderStack to size its final LayerNorm.
        self.size = size

    def call(self, x, mask):
        """Run self-attention (query = key = value) followed by the feed-forward step."""
        def self_attention(h):
            return self.self_attn(h, h, h, mask)

        attended = self.sublayer[0](x, self_attention)
        return self.sublayer[1](attended, self.feed_forward)

### Decoder Stack Layer

6. `DecoderStack`:
    - `__init__(self, layer, N)`: This initializes the DecoderStack instance. It takes in two arguments and initializes two instance variables:
        - `layer`: The layer instance to be used in the decoder stack; it is deep-copied `N` times. It should be a callable object that takes in the input, the memory from the encoder, a source mask, and a target mask, returns a tensor, and exposes a `size` attribute.
        - `N`: The number of layers in the decoder stack.
        - `self.layers` is a list of `N` independent deep copies of `layer`.
        - `self.norm` is the norm layer that is applied to the output of the `DecoderStack`.
    - `call(self, x, memory, src_mask, tgt_mask)`: This method is used to pass the input through each layer in the decoder stack in turn. It takes in four arguments:
        - `x`: The input to be processed by the decoder stack.
        - `memory`: The output of the encoder, which serves as the memory for the decoder.
        - `src_mask`: The mask indicating which tokens in the source sequence should not be attended to.
        - `tgt_mask`: The mask indicating which tokens in the target sequence should not be attended to.

In [67]:
class DecoderStack(layers.Layer):
    """
    Decoder made of N deep copies of one decoder layer, followed by a final LayerNorm.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def call(self, x, memory, src_mask, tgt_mask):
        """Feed ``x`` through every layer with the encoder ``memory`` and both masks, then normalize."""
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, memory, src_mask, tgt_mask)
        return self.norm(hidden)

### Decoder Layer

7. `DecoderLayer`:
    - `__init__(self, size, self_attn, src_attn, feed_forward, dropout)`: This initializes the DecoderLayer instance. It takes in five arguments:
        - `size`: The number of features in the input.
        - `self_attn`: The self-attention mechanism to be used in the decoder layer. This should be a callable object that takes in the input and a mask and returns a tensor.
        - `src_attn`: The source attention mechanism to be used in the decoder layer. This should be a callable object that takes in the input, the memory from the encoder, and a mask, and returns a tensor.
        - `feed_forward`: The feed-forward network to be used in the decoder layer. This should be a callable object that takes in the input and returns a tensor.
        - `dropout`: The dropout rate to be applied after each sublayer.
    - `call(self, x, memory, src_mask, tgt_mask)`: This method is used to pass the input through the self-attention mechanism, the source attention mechanism, and the feed-forward network. It takes in four arguments:
        - `x`: The input to be processed by the decoder layer.
        - `memory`: The output of the encoder, which serves as the memory for the decoder.
        - `src_mask`: The mask indicating which tokens in the source sequence should not be attended to.
        - `tgt_mask`: The mask indicating which tokens in the target sequence should not be attended to.

In [68]:
class DecoderLayer(layers.Layer):
    """
    One decoder layer: masked self-attention, source (encoder) attention, then feed-forward.

    Each step is wrapped in its own ResidualSublayer.
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(ResidualSublayer(size, dropout), 3)

    def call(self, x, memory, src_mask, tgt_mask):
        """Attend over the target (``tgt_mask``), then over encoder ``memory`` (``src_mask``), then feed forward."""
        def masked_self_attention(h):
            return self.self_attn(h, h, h, tgt_mask)

        def source_attention(h):
            # Queries come from the decoder; keys and values come from the encoder output.
            return self.src_attn(h, memory, memory, src_mask)

        x = self.sublayer[0](x, masked_self_attention)
        x = self.sublayer[1](x, source_attention)
        return self.sublayer[2](x, self.feed_forward)

## Sublayers

### Feedforward Layer

8. `PositionwiseFeedForward`:
    - `__init__(self, d_model, d_ff, dropout=0.1, *args, **kwargs)`: This initializes the PositionwiseFeedForward instance. It takes in three arguments and an optional set of arguments:
        - `d_model`: The number of features in the input.
        - `d_ff`: The number of features in the hidden layer of the feed-forward network.
        - `dropout`: The dropout rate to be applied after the first layer of the feed-forward network.
        - `*args, **kwargs`: Additional arguments that might be necessary for the parent class initialization.
    - `call(self, x)`: This method is used to pass the input through the feed-forward network. It takes in one argument:
        - `x`: The input to be processed by the feed-forward network.

In [69]:
class PositionwiseFeedForward(layers.Layer):
    """Position-wise FFN: ``w_2(dropout(relu(w_1(x))))`` with hidden width ``d_ff``."""

    def __init__(self, d_model, d_ff, dropout=0.1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.w_1 = layers.Dense(d_ff)     # expand to the hidden width
        self.w_2 = layers.Dense(d_model)  # project back to the model width
        self.dropout = layers.Dropout(dropout)

    def call(self, x):
        """Apply the two dense layers with ReLU and dropout in between."""
        hidden = tf.nn.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)

### Generator Layer

9. `Generator`:
    - `__init__(self, vocab)`: This method initializes the Generator instance. It accepts one argument:
        - `vocab`: The size of the vocabulary which will be the number of output units in the dense layer.
    - `call(self, x)`: This method is used to pass the input through the generator. It takes in one argument:
        - `x`: The input tensor to be processed by the generator. The method returns the log softmax of the output of the dense layer.

In [70]:
class Generator(layers.Layer):
    """
    Output head: a dense projection to ``vocab`` logits followed by log-softmax.
    """

    def __init__(self, vocab):
        super().__init__()
        self.proj = layers.Dense(vocab)

    def call(self, x):
        """Return log-probabilities over the vocabulary along the last axis."""
        logits = self.proj(x)
        return tf.nn.log_softmax(logits, axis=-1)

### Attention Layer


10. `attention(query, key, value, mask=None, dropout=None)`:
    - This is a function that computes the 'Scaled Dot Product Attention'. The arguments are as follows:
        - `query`, `key`, `value`: These are the main inputs to the attention function.
        - `mask`: Optional mask for the attention scores.
        - `dropout`: Optional dropout layer (any callable, e.g. a `layers.Dropout` instance) applied to the attention probabilities after the softmax.
    - The function first scales the dot product of the query and key, applies the mask if provided, applies softmax to compute attention scores, applies dropout if provided, and then uses the attention scores to compute a weighted sum of the value inputs.

In [71]:
def attention(query, key, value, mask=None, dropout=None):
    """Compute 'Scaled Dot Product Attention'.

    Args:
        query, key, value: Tensors whose last two axes are (sequence, features). Any
            leading axes must be compatible for ``tf.matmul``. MultiHeadedAttention in
            this file passes (batch, heads, seq, d_k).
        mask: Optional tensor broadcastable to the score shape (..., q_len, k_len).
            Positions where it is False/zero are blocked (their score becomes -1e9).
        dropout: Optional callable (e.g. a Keras Dropout layer) applied to the attention weights.

    Returns:
        Tuple ``(output, p_attn)``. ``output = p_attn @ value`` and ``p_attn`` holds the
        softmax attention weights.
    """
    d_k = query.shape[-1]
    if d_k is not None:
        scale = math.sqrt(d_k)
    else:
        # Static feature size unknown (e.g. graph tracing): use the dynamic shape instead.
        scale = tf.math.sqrt(tf.cast(tf.shape(query)[-1], query.dtype))

    # transpose_b swaps the last two axes of `key` for any rank >= 2. The previous
    # explicit perm=[0, 1, 3, 2] only worked for rank-4 inputs.
    scores = tf.matmul(query, key, transpose_b=True) / scale
    if mask is not None:
        mask = tf.cast(mask, dtype=tf.bool)
        # Build the fill value in the scores' dtype so tf.where sees matching dtypes.
        blocked = tf.fill(tf.shape(scores), tf.constant(-1e9, dtype=scores.dtype))
        scores = tf.where(mask, scores, blocked)
    p_attn = tf.nn.softmax(scores, axis=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return tf.matmul(p_attn, value), p_attn

11. `MultiHeadedAttention`:
    - `__init__(self, h, d_model, dropout=0.1)`: This initializes the MultiHeadedAttention instance. It takes in three arguments:
        - `h`: The number of attention heads.
        - `d_model`: The number of features in the input.
        - `dropout`: The dropout rate to be applied after the softmax in the attention computation.
    - `call(self, query, key, value, mask=None)`: This method is used to compute the multi-headed attention over the inputs. It takes in four arguments:
        - `query`, `key`, `value`: These are the main inputs to the attention computation.
        - `mask`: Optional mask for the attention scores.
    - The method first computes the linear projections of the inputs, applies the attention function to the projected inputs, concatenates the outputs of the attention function across the attention heads, and then applies a final linear transformation to the concatenated outputs.

In [72]:
class MultiHeadedAttention(layers.Layer):
    """
    Multi-head attention with ``h`` heads of size ``d_k = d_model // h``.

    Queries, keys and values are projected by separate Dense layers and split into
    heads. Scaled dot-product attention runs per head, the heads are concatenated,
    and the result is projected by a final Dense layer. The last attention weights
    are kept in ``self.attn``.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0, "d_model must be divisible by the number of heads h"
        self.d_k = d_model // h
        self.h = h
        self.query, self.key, self.value, self.linear = clones(layers.Dense(d_model), 4)
        self.attn = None
        self.dropout = layers.Dropout(dropout)

    def _split_heads(self, x, batch_size):
        """Reshape ``(batch, seq, d_model)`` into ``(batch, h, seq, d_k)``."""
        x = tf.reshape(x, [batch_size, -1, self.h, self.d_k])
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``; ``mask`` is shared by all heads."""
        if mask is not None:
            # Insert a heads axis so the same mask applies to all h heads.
            mask = tf.expand_dims(mask, 1)
        # Dynamic batch size: the static query.shape[0] is None when the batch size is
        # unknown (e.g. inside tf.function or a Keras functional build), and tf.reshape
        # cannot take None.
        nbatches = tf.shape(query)[0]

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [
            self._split_heads(lin(x), nbatches)
            for lin, x in zip([self.query, self.key, self.value], (query, key, value))
        ]

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) "Concat" the heads back to (batch, seq, h * d_k) and apply a final linear.
        x = tf.reshape(tf.transpose(x, perm=[0, 2, 1, 3]), (nbatches, -1, self.h * self.d_k))

        return self.linear(x)

### Positional Embedding Layer

12. `positional_encoding(length, depth)`:
    - This is a function that computes the positional encoding for a sequence of a given length and depth. The arguments are as follows:
        - `length`: The length of the sequence for which positional encoding is to be computed.
        - `depth`: The number of features in the input sequence.
    - The function first computes the rates at which the angles change across the feature dimensions, then computes the angle for every position and rate. It then concatenates the sines of these angles (first half of the features) with their cosines (second half of the features). The positional encoding for a position is thus a vector of these sine and cosine values; unlike the original Transformer paper, sines and cosines are not interleaved at even/odd indices.

In [73]:
def positional_encoding(length, depth):
    """Return a fixed sinusoidal positional encoding of shape ``(length, depth)`` as float32.

    The first half of the features holds ``sin(pos * rate)`` and the second half holds
    ``cos(pos * rate)``, with ``rate = 1 / 10000**(i / (depth/2))``.

    Args:
        length: Number of positions to encode.
        depth: Number of features per position. Odd values are supported: the result
            is trimmed to exactly ``depth`` columns.
    """
    total_depth = int(depth)
    half_depth = depth / 2

    positions = np.arange(length)[:, np.newaxis]                 # (length, 1)
    depths = np.arange(half_depth)[np.newaxis, :] / half_depth   # (1, ceil(depth/2))

    angle_rates = 1 / (10000**depths)                             # (1, ceil(depth/2))
    angle_rads = positions * angle_rates                          # (length, ceil(depth/2))

    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1,
    )
    # For an odd depth the two halves give depth + 1 columns; trim so the encoding
    # matches the embedding width it is added to. No-op for an even depth.
    pos_encoding = pos_encoding[:, :total_depth]

    return tf.cast(pos_encoding, dtype=tf.float32)

13. `PositionalEmbedding`:
    - `__init__(self, vocab_size, d_model)`: This method initializes the PositionalEmbedding instance. It takes in two arguments:
        - `vocab_size`: The size of the vocabulary, which will be the input dimension of the embedding layer.
        - `d_model`: The number of features to be output by the embedding layer and the depth for the positional encoding.
    - `call(self, x)`: This method is used to compute the positionally encoded embeddings of the inputs. It takes in one argument:
        - `x`: The input tensor for which the embeddings are to be computed.
    - The method first computes the embeddings of the inputs, scales the embeddings by the square root of `d_model`, and then adds the positional encoding to these scaled embeddings.

In [74]:
class PositionalEmbedding(tf.keras.layers.Layer):
    """
    Token embedding scaled by ``sqrt(d_model)`` plus a fixed sinusoidal positional encoding.

    Args:
        vocab_size: Number of token ids (input dimension of the embedding). Id 0 is
            treated as padding by the embedding (``mask_zero=True``).
        d_model: Embedding width; also the depth of the positional encoding.
        max_length: Longest sequence the layer accepts. The positional table is
            precomputed for this many positions. Defaults to 2048.
    """

    def __init__(self, vocab_size, d_model, max_length=2048):
        super().__init__()
        self.d_model = d_model
        self.max_length = max_length
        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
        self.pos_encoding = positional_encoding(length=max_length, depth=d_model)

    def call(self, x):
        """Embed token ids ``x`` (axis 1 is the sequence axis) and add positional information."""
        length = tf.shape(x)[1]
        # Fail with a clear message, rather than an opaque broadcasting error, when the
        # sequence is longer than the precomputed positional table.
        tf.debugging.assert_less_equal(
            length, self.max_length,
            message="PositionalEmbedding: sequence length exceeds max_length")
        x = self.embedding(x)
        # This factor sets the relative scale of the embedding and positional_encoding
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        return x + self.pos_encoding[tf.newaxis, :length, :]

## Model Generation

### Make model

14. `make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1)`:
    - The `make_model` function constructs a Transformer model from given hyperparameters. It takes seven arguments:
        - `src_vocab`: The size of the source vocabulary.
        - `tgt_vocab`: The size of the target vocabulary.
        - `N`(default=6): The number of layers in the Transformer's Encoder and Decoder stacks.
        - `d_model`(default=512): The dimension of the model. It's the number of features in input and output.
        - `d_ff`(default=2048): The number of features in the hidden layer of the feed-forward network.
        - `h`(default=8): The number of attention heads in the MultiHeadedAttention mechanism.
        - `dropout`(default=0.1): The dropout rate to be applied in several parts of the model.
    - Inside this function, instances of `MultiHeadedAttention` and `PositionwiseFeedForward` are created. These instances are then deep-copied and used to construct the Encoder and Decoder stacks; the two PositionalEmbeddings and the Generator are instantiated as well. All these parts are then assembled into an `EncoderDecoder` instance, which represents the complete Transformer model. Modules wrapped in a `LayerWrapper` have their inputs and outputs recorded so they can be visualized on selected successive calls; see the definition of the `LayerWrapper` class for the meaning of its parameters.
    - Finally, the function returns the constructed model. 

In [None]:
def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """Helper: Construct a model from hyperparameters.

    Args:
        src_vocab: Source vocabulary size.
        tgt_vocab: Target vocabulary size.
        N: Number of layers in each of the encoder and decoder stacks.
        d_model: Model (embedding) width.
        d_ff: Hidden width of the position-wise feed-forward networks.
        h: Number of attention heads.
        dropout: Dropout rate used in the residual sublayers, the feed-forward
            networks and the attention weights.

    Returns:
        An ``EncoderDecoder`` wrapped in a visualizing ``LayerWrapper``
        (``visual_setter=True``). Its generator is wrapped as well, to plot its
        inputs on call 1.
    """
    c = copy.deepcopy
    # Pass `dropout` through so attention-weight dropout follows the requested rate
    # instead of always using MultiHeadedAttention's default of 0.1.
    attn = MultiHeadedAttention(h, d_model, dropout)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    encoder_decoder = EncoderDecoder(
        EncoderStack(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        DecoderStack(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        PositionalEmbedding(src_vocab, d_model),
        PositionalEmbedding(tgt_vocab, d_model),
        LayerWrapper(Generator(tgt_vocab), visualize_on_calls=[1], visualizations=[('mode1', 'x')]),
    )
    model = LayerWrapper(encoder_decoder, visualize_on_calls=[1], visual_setter=True)

    # NOTE: no explicit weight re-initialization is done here. Keras Dense layers already
    # default to Glorot-uniform kernels, which matches the intended "Glorot / fan_avg" init.
    return model

In [None]:
def inference_test():
    """Greedily decode 9 tokens with an untrained 2-layer model and print the id sequence."""
    test_model = make_model(11, 11, 2)
    test_model.trainable = False

    src = tf.constant([list(range(1, 11))], dtype=tf.int64)
    src_mask = tf.ones((1, 1, 10), dtype=tf.float32)
    memory = test_model.encode(src, src_mask)

    # Start from a single token id 0 and append the arg-max prediction each step.
    ys = tf.zeros((1, 1), dtype=tf.int64)
    for _step in range(9):
        causal_mask = subsequent_mask(ys.shape[1])
        decoded = test_model.decode(memory, src_mask, ys, causal_mask)
        log_probs = test_model.generator(decoded[:, -1])
        next_word = tf.argmax(log_probs, axis=-1)[0]
        ys = tf.concat([ys, tf.reshape(next_word, (1, 1))], axis=1)

    print("Example Untrained Model Prediction:", ys)

def run_tests():
    """Run the untrained-model inference smoke test ten times (fresh model each time)."""
    repetitions = 10
    for _run in range(repetitions):
        inference_test()

run_tests()

# Training

## Data Preparation

### Vocabulary Generation

In [None]:
def load_dataset(dataset_text_file):
    """Return a ``tf.data`` dataset whose elements are the lines of ``dataset_text_file``."""
    lines = tf.data.TextLineDataset(filenames=dataset_text_file)
    return lines

def create_vocab(dataset, vocab_size=8000, reserved=None):
    """Learn a lower-cased WordPiece vocabulary from a dataset of text lines.

    Args:
        dataset: ``tf.data`` dataset of strings (e.g. from ``load_dataset``).
        vocab_size: Target vocabulary size passed to ``bert_vocab_from_dataset``. Defaults to 8000.
        reserved: Reserved tokens to include. Defaults to the module-level
            ``reserved_tokens`` so the vocabulary stays in sync with the tokenizer helpers
            (``add_start_end`` looks up ids in that same list).

    Returns:
        The vocabulary produced by ``bert_vocab.bert_vocab_from_dataset``.
    """
    if reserved is None:
        reserved = reserved_tokens

    bert_vocab_args = dict(
        vocab_size=vocab_size,
        reserved_tokens=list(reserved),
        bert_tokenizer_params=dict(lower_case=True),
        learn_params={},
    )

    story_vocab = bert_vocab.bert_vocab_from_dataset(
        dataset.batch(1000).prefetch(2),
        **bert_vocab_args
    )
    return story_vocab

def create_vocab_from_textdata(text_file=dataset_path):
    """Read ``text_file`` line by line and learn a WordPiece vocabulary from it."""
    return create_vocab(load_dataset(text_file))

def write_vocab_file(filepath, vocab):
    """Write ``vocab`` to ``filepath``, one token per line (``str(token)`` plus a newline).

    The file is written as UTF-8 explicitly. WordPiece vocabularies contain non-ASCII
    pieces, and the platform default encoding (e.g. cp1252 on Windows) may fail to
    encode them or produce a file that readers decode differently.
    """
    with open(filepath, 'w', encoding='utf-8') as file:
        for token in vocab:
            print(token, file=file)

In [None]:
class Batch:
    """Object for holding a batch of data with mask during training.

    ``src_mask`` marks non-padding source positions (with an inserted axis 1). When
    ``tgt`` is given, it is split into the decoder input ``tgt`` (all but the last
    position) and the shifted labels ``tgt_y`` (all but the first). ``tgt_mask``
    combines padding and causal masking, and ``ntokens`` counts non-padding labels.

    NOTE(review): the default ``pad=2`` disagrees with ``reserved_tokens``, where
    "[PAD]" is at position 0. If the vocabulary puts reserved tokens first, the
    padding id would be 0; confirm before relying on the default.
    """

    def __init__(self, src, tgt=None, pad=2): # 2 = <blank>
        self.src = src
        self.src_mask = (src != pad)[:, np.newaxis, :]
        if tgt is None:
            return
        self.tgt = tgt[:, :-1]
        self.tgt_y = tgt[:, 1:]
        self.tgt_mask = self.make_std_mask(self.tgt, pad)
        self.ntokens = tf.reduce_sum(tf.cast(self.tgt_y != pad, tf.int64))

    @staticmethod
    def make_std_mask(tgt, pad):
        "Create a mask to hide padding and future words."
        not_padding = (tgt != pad)[:, np.newaxis, :]
        causal = subsequent_mask(tgt.shape[-1])
        return tf.logical_and(not_padding, causal)

## Tokenizer

In [None]:
def add_start_end(ragged):
    """Prepend the [START] id and append the [END] id to every row of ``ragged``.

    The ids are the positions of "[START]" and "[END]" in the module-level
    ``reserved_tokens`` list.
    """
    token_table = tf.constant(reserved_tokens)
    start_id = tf.argmax(token_table == "[START]")
    end_id = tf.argmax(token_table == "[END]")

    n_rows = ragged.bounding_shape()[0]
    start_column = tf.fill([n_rows, 1], start_id)
    end_column = tf.fill([n_rows, 1], end_id)
    return tf.concat([start_column, ragged, end_column], axis=1)

def cleanup_text(reserved_tokens, token_txt):
    """Drop reserved tokens (except "[UNK]") from detokenized words and join them with spaces.

    Args:
        reserved_tokens: Iterable of reserved token strings, e.g. ``["[PAD]", "[UNK]", ...]``.
        token_txt: String tensor whose last axis holds words (as produced by
            ``BertTokenizer.detokenize``; presumably ragged ``[batch, words]`` - confirm).

    Returns:
        String tensor with the last axis joined by single spaces.
    """
    # Reserved tokens contain regex metacharacters ("[" and "]"). Unescaped, "[PAD]" is a
    # one-character class and never matches the literal token, so each token is escaped.
    bad_tokens = [re.escape(token) for token in reserved_tokens if token != "[UNK]"]
    bad_tokens_re = "|".join(bad_tokens)

    bad_cells = tf.strings.regex_full_match(token_txt, bad_tokens_re)
    ragged_result = tf.ragged.boolean_mask(token_txt, ~bad_cells)

    result = tf.strings.reduce_join(ragged_result, separator=' ', axis=-1)

    return result

In [None]:
class StoryTokenizer(tf.Module):
    """
    Exportable WordPiece tokenizer built on ``tf_text.BertTokenizer`` (lower-casing).

    ``tokenize`` adds [START]/[END] ids. ``detokenize`` strips reserved tokens (except
    [UNK]) and joins words. ``lookup`` maps ids to vocabulary strings. Concrete
    functions are traced in ``__init__`` to create the signatures for export.
    """

    def __init__(self, reserved_tokens, vocab_path):
        super().__init__()
        self.tokenizer = tf_text.BertTokenizer(vocab_path, lower_case=True)
        self._reserved_tokens = reserved_tokens
        # Tracked as an Asset so the vocab file is carried along when the module is exported.
        self._vocab_path = tf.saved_model.Asset(vocab_path)

        # Read explicitly as UTF-8. WordPiece vocabularies contain non-ASCII pieces that
        # the platform default encoding (e.g. cp1252 on Windows) would mis-decode.
        vocab = pathlib.Path(vocab_path).read_text(encoding='utf-8').splitlines()
        self.vocab = tf.Variable(vocab)

        ## Create the signatures for export:

        # tokenize signature for a batch of strings
        self.tokenize.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string))

        # detokenize and lookup signature for:
        # * Tensor with shape [tokens] and [batch, tokens]
        # * RaggedTensor with shape [batch, tokens]
        self.detokenize.get_concrete_function(
            tf.TensorSpec(shape=[None, None], dtype=tf.int64))
        self.detokenize.get_concrete_function(
            tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int64))

        self.lookup.get_concrete_function(
            tf.TensorSpec(shape=[None, None], dtype=tf.int64))
        self.lookup.get_concrete_function(
            tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int64))

        # get_* methods take no argument
        self.get_vocab_size.get_concrete_function()
        self.get_vocab_path.get_concrete_function()
        self.get_reserved_tokens.get_concrete_function()

    @tf.function
    def tokenize(self, strings):
        """Tokenize a batch of strings into ids, merge word pieces per sentence, add [START]/[END]."""
        encoded = self.tokenizer.tokenize(strings)
        # Collapse the (word, word-piece) axes into one token axis per sentence.
        merged_enc = encoded.merge_dims(-2, -1)
        # NOTE: add_start_end looks up ids in the module-level reserved_tokens,
        # not in self._reserved_tokens.
        merg_enc_start_end = add_start_end(merged_enc)
        return merg_enc_start_end

    @tf.function
    def detokenize(self, tokenized):
        """Convert ids back to text, dropping reserved tokens other than [UNK]."""
        words = self.tokenizer.detokenize(tokenized)
        return cleanup_text(self._reserved_tokens, words)

    @tf.function
    def lookup(self, token_ids):
        """Map token ids to their vocabulary strings."""
        return tf.gather(self.vocab, token_ids)

    @tf.function
    def get_vocab_size(self):
        """Return the number of vocabulary entries."""
        return tf.shape(self.vocab)[0]

    @tf.function
    def get_vocab_path(self):
        """Return the vocab file Asset."""
        return self._vocab_path

    @tf.function
    def get_reserved_tokens(self):
        """Return the reserved tokens as a string tensor."""
        return tf.constant(self._reserved_tokens)

## Training

In [None]:
class TrainState:
    """Track number of steps, examples, and tokens processed.

    Every counter is an ``int`` class attribute with a default of 0.
    """

    step: int = 0 # Steps in the current epoch
    accum_step: int = 0 # Number of gradient accumulation steps
    samples: int = 0 # Total number of examples used
    tokens: int = 0 # Total number of tokens processed