# Imports

In [None]:
import logging as log
import functools
from time import time

import os

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_text as tf_text
from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset as bert_vocab


# Utility and Settings

## Settings

In [None]:
log.basicConfig(
    level=log.INFO,
    format='%(asctime)s %(levelname)-8s %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

# Global toggles, read at call time by the `log_dec` / `run_helper` decorators.
log_enabled = True      # when True, `log_dec` logs start/finish of decorated calls
execute_helper = False  # when True, `run_helper`-decorated functions actually execute

## Decorators

In [None]:
def log_dec(func):
    """Decorate *func* so each call is logged at INFO level while `log_enabled` is set.

    Logs '<name> started' before the call and '<name> finished in <s>s'
    afterwards, including when the call raises. Exceptions propagate unchanged.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # The flag is consulted exactly once per call. Re-reading it in
        # `finally` would leave `start_time` unbound if logging were switched
        # on while the wrapped function was running.
        if not log_enabled:
            return func(*args, **kwargs)
        start_time = time()
        log.info('{} started'.format(func.__name__))
        try:
            return func(*args, **kwargs)
        finally:
            duration = time() - start_time
            log.info('{} finished in {:.3f}s'.format(func.__name__, duration))
    return wrapper

def run_helper(func):
    """Decorate *func* so it only runs while the global `execute_helper` flag is set.

    When the flag is off, the decorated call is skipped and returns None.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not execute_helper:
            return None
        return func(*args, **kwargs)
    return wrapper

# Data preparation

## Load dataset and model elements

### Tokenizer
Load the tokenizer from file

In [None]:
# Directory containing the exported tokenizer SavedModel.
model_name = 'story_corpus_tokenizer'

tokenizer = tf.saved_model.load(export_dir=model_name)

### Text Dataset
Load the text dataset from file.

In [None]:
# Build the path portably rather than hard-coding a Windows '\\' separator.
dataset_path = os.path.join('datasets', 'corpus.txt')

@log_dec
def load_dataset(dataset_text_file):
    """Return a `tf.data.TextLineDataset` yielding one string element per line of *dataset_text_file*."""
    return tf.data.TextLineDataset(filenames=dataset_text_file)

dataset = load_dataset(dataset_path)

Plot the token length of each data sample.

In [None]:
# Token count of every example, tokenized 1024 lines at a time.
lengths = [
    tokenizer.tokenize(batch).row_lengths()
    for batch in dataset.batch(1024)
]

In [None]:
all_lengths = np.concatenate(lengths)
plt.hist(all_lengths, np.linspace(0, 500, 101))
# Freeze the current y-range so the vertical marker lines span the full height.
plt.ylim(plt.ylim())
max_length = np.max(all_lengths)
# NOTE: despite the "med" name this is the arithmetic mean (kept for compatibility).
med_length = np.mean(all_lengths)
plt.plot([max_length, max_length], plt.ylim(), label='maximum')
plt.plot([med_length, med_length], plt.ylim(), label='average')
plt.legend()
# One combined title; two separate plt.title calls would overwrite each other.
plt.title(f'Maximum tokens per example: {max_length}\n'
          f'Average tokens per example: {med_length:.1f}');

## Data batching

This is a stub for now; we first have to define the model components to learn what a batch needs to look like in TensorFlow.

In [None]:
class Batch:
    """Object for holding a batch of data with mask during training.

    Stub: only the source side is processed so far; `tgt` is stored unchanged.

    Args:
        src: batch of source token ids (presumably (batch, seq) — TODO confirm).
        tgt: optional target token ids, stored as-is.
        pad: token id treated as padding when building `src_mask`.
            NOTE(review): `prepare_batch` pads with 0 (`to_tensor()` default) and
            the embedding uses `mask_zero=True`, so the default of 2 looks
            inconsistent — confirm the tokenizer's pad id.
    """

    def __init__(self, src, tgt=None, pad=2):
        self.src = src
        self.tgt = tgt
        # True for real tokens, False for padding. `.unsqueeze` is PyTorch-only;
        # tf.expand_dims inserts the extra axis at -2, presumably so the mask
        # broadcasts over the query axis of attention scores.
        self.src_mask = tf.expand_dims(tf.math.not_equal(src, pad), axis=-2)
        

In [None]:
MAX_TOKENS = 128
BUFFER_SIZE = 20000
BATCH_SIZE = 64


def prepare_batch(data_entry):
    """Tokenize a batch of raw lines and build teacher-forcing inputs and targets.

    Returns ``((encoder_input, decoder_input), decoder_output)``, where the
    decoder target is the decoder input shifted left by one token. Ragged rows
    are densified with ``to_tensor()``, which pads with zeros.
    """
    ragged = tokenizer.tokenize(data_entry)

    # The decoder keeps one extra token so both shifted views are at most
    # MAX_TOKENS long.
    dec_tokens = ragged[:, :(MAX_TOKENS + 1)]
    dec_in = dec_tokens[:, :-1].to_tensor()
    dec_out = dec_tokens[:, 1:].to_tensor()

    enc_in = ragged[:, :MAX_TOKENS].to_tensor()

    return (enc_in, dec_in), dec_out

def make_batches(dataset):
    """Shuffle, batch, tokenize (via `prepare_batch`) and prefetch *dataset*."""
    shuffled = dataset.shuffle(BUFFER_SIZE)
    batched = shuffled.batch(BATCH_SIZE)
    prepared = batched.map(prepare_batch, num_parallel_calls=tf.data.AUTOTUNE)
    return prepared.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
batched_dataset = make_batches(dataset)

# Grab a single batch for inspection.
for (enc_in, dec_in), dec_out in batched_dataset.take(1):
    pass

print(dec_in)

# Architecture components

## Container classes

### Encoder-Decoder

In [None]:
class EncoderDecoder(tf.keras.layers.Layer):
    """Container that routes tensors between injected encoder/decoder sub-layers.

    The sub-layers are supplied by the caller; this class only wires them:
    ``encoder(enc_embed(src), pad_mask)`` produces the encoder output, which is
    then passed to ``decoder(dec_embed(tgt), encoder_output, pad_mask, subseq_mask)``.

    Args:
        encoder: layer applied to the embedded encoder inputs.
        decoder: layer applied to the embedded decoder inputs and the encoder output.
        enc_embed: embedding layer for encoder token ids.
        dec_embed: embedding layer for decoder token ids.
        generator: output layer; stored but not applied by `call`.
        pad_mask: default padding mask, used when a call passes none.
        subseq_mask: default subsequent mask (presumably look-ahead), used when
            a call passes none.
        trainable, name, dtype, dynamic, **kwargs: forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self,
                 encoder: tf.keras.layers.Layer,
                 decoder: tf.keras.layers.Layer,
                 enc_embed: tf.keras.layers.Layer,
                 dec_embed: tf.keras.layers.Layer,
                 generator: tf.keras.layers.Layer,
                 pad_mask=None,
                 subseq_mask=None,
                 trainable=True,
                 name=None,
                 dtype=None,
                 dynamic=False,
                 **kwargs):
        # Pass base-class options by keyword so they cannot be bound to the
        # wrong positional parameter.
        super().__init__(trainable=trainable, name=name, dtype=dtype,
                         dynamic=dynamic, **kwargs)

        # modules
        self.encoder = encoder
        self.decoder = decoder
        self.enc_embed = enc_embed
        self.dec_embed = dec_embed
        self.generator = generator

        # default masks, used when a call does not supply its own
        self.pad_mask = pad_mask
        self.subseq_mask = subseq_mask

    def encode(self, inputs, pad_mask=None):
        """Embed encoder token ids and run the encoder; returns the encoder output."""
        if pad_mask is None:
            pad_mask = self.pad_mask

        return self.encoder(
            self.enc_embed(inputs),
            pad_mask)

    def decode(self, inputs, enc_input, pad_mask=None, subseq_mask=None):
        """Embed decoder token ids and run the decoder against the encoder output.

        Args:
            inputs: decoder token ids.
            enc_input: encoder output, as returned by `encode`.
        """
        if pad_mask is None:
            pad_mask = self.pad_mask
        if subseq_mask is None:
            subseq_mask = self.subseq_mask

        return self.decoder(
            self.dec_embed(inputs),
            # The decoder attends over the encoder output here, not its own token ids.
            enc_input,
            pad_mask,
            subseq_mask)

    def call(self, enc_input, dec_input, pad_mask=None, subseq_mask=None, *args, **kwargs):
        """Encode `enc_input`, then decode `dec_input` against the result."""
        if pad_mask is None:
            pad_mask = self.pad_mask
        if subseq_mask is None:
            subseq_mask = self.subseq_mask

        return self.decode(dec_input,
                           self.encode(enc_input, pad_mask),
                           pad_mask,
                           subseq_mask)

In [None]:
# TODO: this cell held an unfinished `class` statement, which is a SyntaxError.
# Define the intended container class here once its design is settled.

## Positional Encoding

In [None]:
def positional_encoding(length, depth):
    """Return a (length, depth) float32 sinusoidal positional-encoding table.

    The first half of the columns holds sines and the second half holds cosines
    (concatenated, not interleaved). Each sine/cosine pair shares the frequency
    ``1 / 10000**(i / (depth / 2))``.

    Args:
        length: number of positions (rows).
        depth: encoding width (columns). An odd depth is supported; the table
            is trimmed to exactly `depth` columns.
    """
    half_depth = depth / 2

    positions = np.arange(length)[:, np.newaxis]             # (seq, 1)
    depths = np.arange(half_depth)[np.newaxis, :] / half_depth  # (1, ceil(depth/2))

    angle_rates = 1 / (10000**depths)                        # (1, ceil(depth/2))
    angle_rads = positions * angle_rates                     # (seq, ceil(depth/2))

    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1,
    )
    # For odd depth the concatenation has depth + 1 columns; trim so the table
    # can be added to a `depth`-wide embedding. Even depth is unaffected.
    pos_encoding = pos_encoding[:, :depth]

    return tf.cast(pos_encoding, dtype=tf.float32)

### Positional encoding explanation

Show what the positional encoding looks like.

In [None]:
pos_encoding = positional_encoding(length=2048, depth=512)

# Shape sanity check: expect (positions, depth).
print(pos_encoding.shape)

# Heat map with depth on the y-axis (hence the transpose) and position on the x-axis.
plt.pcolormesh(pos_encoding.numpy().T, cmap='RdBu')
plt.xlabel('Position')
plt.ylabel('Depth')
plt.colorbar()
plt.show()

In [None]:
# Normalize every position vector, then take its dot product with position 1000.
pos_encoding /= tf.norm(pos_encoding, axis=1, keepdims=True)
p = pos_encoding[1000]
dots = tf.einsum('pd,d -> p', pos_encoding, p)

# Top panel: all positions, with the zoom window outlined in black.
plt.subplot(2, 1, 1)
plt.plot(dots)
plt.ylim([0, 1])
plt.plot([950, 950, float('nan'), 1050, 1050],
         [0, 1, float('nan'), 0, 1],
         color='k',
         label='Zoom')
plt.legend()

# Bottom panel: positions 950-1050 only.
plt.subplot(2, 1, 2)
plt.plot(dots)
plt.xlim([950, 1050])
plt.ylim([0, 1])

## Positional Embedding

In [None]:
class PositionalEmbedding(tf.keras.layers.Layer):
    """Token embedding scaled by sqrt(d_model), plus a fixed sinusoidal position table.

    Args:
        vocab_size: number of token ids covered by the embedding table.
        d_model: embedding width.
        max_length: number of positions precomputed (default 2048). Sequences
            longer than this cannot be added to the sliced table.
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
    """

    def __init__(self, vocab_size, d_model, max_length=2048, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        # mask_zero=True: token id 0 is treated as padding and masked out.
        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
        self.pos_encoding = positional_encoding(length=max_length, depth=d_model)

    def compute_mask(self, *args, **kwargs):
        # Expose the embedding's padding mask so downstream Keras layers receive it.
        return self.embedding.compute_mask(*args, **kwargs)

    def call(self, x):
        length = tf.shape(x)[1]
        x = self.embedding(x)
        # This factor sets the relative scale of the embedding and positional_encoding
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = x + self.pos_encoding[tf.newaxis, :length, :]
        return x

### Positional embedding example

An example of how a vector is encoded.

In [None]:
# Embed the decoder-input batch pulled earlier and show the result.
embed = PositionalEmbedding(
    vocab_size=tokenizer.get_vocab_size(),
    d_model=512,
)
dec_in_embed = embed(dec_in)
print(dec_in_embed)

## Attention

In [None]:
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        query, key, value: tensors whose last two axes are (seq, depth);
            presumably (..., seq, depth) — TODO confirm against callers.
        mask: optional tensor broadcastable to the score shape. Positions where
            it is zero/False are excluded from attention.
        dropout: optional callable (e.g. `tf.keras.layers.Dropout`) applied to
            the attention weights. NOTE(review): a Keras Dropout called without
            `training=True` is a no-op.

    Returns:
        ``(output, attention_weights)``.
    """
    # `query.size(-1)` and `key.transpose` are PyTorch APIs; use TF equivalents.
    d_k = tf.cast(tf.shape(query)[-1], query.dtype)
    scores = tf.matmul(query, key, transpose_b=True) / tf.math.sqrt(d_k)
    if mask is not None:
        # Use a large negative fill (not -inf) so a fully masked row still
        # softmaxes to finite (uniform) weights instead of NaN.
        scores = tf.where(tf.cast(mask, tf.bool), scores,
                          tf.cast(-1e9, scores.dtype))
    p_attn = tf.nn.softmax(scores, axis=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return tf.matmul(p_attn, value), p_attn

In [None]:
class BaseAttention(tf.keras.layers.Layer):
    def __init__(self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs):
        super().__init__(trainable, name, dtype, dynamic, **kwargs)
        self.multi_head_attention = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()