# Imports

In [26]:
import logging as log
import functools
from time import time

import os

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_text as tf_text
from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset as bert_vocab


2023-05-04 17:15:47 INFO     generated new fontManager


# Utility and Settings

## Settings

In [27]:
# Configure the root logger: timestamped, level-padded messages at INFO and above.
log.basicConfig(
    level=log.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s %(levelname)-8s %(message)s',
)

# Notebook-wide switches consulted by the decorators below.
execute_helper = False  # run_helper-wrapped functions only execute when True
log_enabled = True      # log_dec only emits its messages when True

## Decorators

In [28]:
def log_dec(func):
    """Decorate ``func`` so each call is logged while ``log_enabled`` is true.

    When logging is enabled at call time, an INFO message is emitted before
    ``func`` runs and another one, including the elapsed wall-clock time,
    after it returns or raises. Exceptions raised by ``func`` propagate
    unchanged.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same signature and metadata as ``func`` that
        returns whatever ``func`` returns.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Read the global flag exactly once per call: if it were re-read in
        # the ``finally`` branch, toggling it while ``func`` runs could reach
        # the "finished" branch without ``start_time`` ever being assigned.
        should_log = log_enabled
        if should_log:
            start_time = time()
            log.info('{} started'.format(func.__name__))
        try:
            return func(*args, **kwargs)
        finally:
            # Runs on both normal return and exception.
            if should_log:
                duration = time() - start_time
                log.info('{} finished in {:.3f}s'.format(func.__name__, duration))
    return wrapper

def run_helper(func):
    """Decorate ``func`` so it only executes while ``execute_helper`` is true.

    Args:
        func: The helper callable to wrap.

    Returns:
        A wrapper that forwards to ``func`` and returns its result when the
        module-level ``execute_helper`` flag is truthy, and otherwise returns
        ``None`` without calling ``func``.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Guard clause: skip the helper entirely when helpers are disabled.
        if not execute_helper:
            return None
        return func(*args, **kwargs)
    return wrapper

# Architecture

## Load dataset and model elements

### Tokenizer
Load the tokenizer from file

In [29]:
# Location of the exported tokenizer SavedModel, relative to the working directory.
model_name = 'story_corpus_tokenizer'

# Restore the tokenizer object from that SavedModel.
tokenizer = tf.saved_model.load(export_dir=model_name)

### Text Dataset
Load the plain-text dataset from file (one example per line).

In [30]:
# Build the path with os.path.join so the separator matches the host OS; a
# hard-coded Windows '\\' separator does not resolve on POSIX systems.
dataset_path = os.path.join('datasets', 'corpus.txt')

@log_dec
def load_dataset(dataset_text_file):
    """Create a line-oriented text dataset from a file.

    Args:
        dataset_text_file: Path (or paths) handed to
            ``tf.data.TextLineDataset`` as ``filenames``.

    Returns:
        The ``tf.data.TextLineDataset`` built from ``dataset_text_file``.
    """
    return tf.data.TextLineDataset(filenames=dataset_text_file)

dataset = load_dataset(dataset_path)

Plot a histogram of the number of tokens per data sample and mark the maximum length.

In [None]:
# Collect the token count of every example, tokenizing 1024 lines per batch.
lengths = []
for example in dataset.batch(1024):
    tokens = tokenizer.tokenize(example)
    lengths.append(tokens.row_lengths())

# Flatten the per-batch counts into one array and plot their distribution
# using 100 equal-width bins between 0 and 500 tokens.
all_lengths = np.concatenate(lengths)
plt.hist(all_lengths, np.linspace(0, 500, 101))
# Re-apply the current y-limits so that drawing the marker line below does
# not rescale the axis.
plt.ylim(plt.ylim())
# Vectorized reduction; the builtin max() would iterate the array in Python.
max_length = all_lengths.max()
# Vertical line at the longest example.
plt.plot([max_length, max_length], plt.ylim())
plt.title(f'Maximum tokens per example: {max_length}');

## Data batching

## Positional Encoding

In [None]:
def positional_encoding(length, depth):
    """Build a sinusoidal positional-encoding table.

    For each position ``p`` in ``[0, length)`` and each of the ``depth / 2``
    frequencies ``i``, the angle ``p / 10000**(i / (depth / 2))`` is computed.
    The sines of all angles form the first half of the last axis and the
    cosines form the second half (they are concatenated, not interleaved).

    Args:
        length: Number of positions (rows) to generate.
        depth: Size of the encoding; half is used for sines and half for
            cosines. NOTE(review): presumably expected to be even, otherwise
            the output width will not equal ``depth`` - confirm with callers.

    Returns:
        A ``tf.float32`` tensor with one row per position.
    """
    depth = depth / 2

    # np.arange (not the non-existent np.arrange) enumerates the positions.
    positions = np.arange(length)[:, np.newaxis]    # (seq, 1)
    depths = np.arange(depth)[np.newaxis, :]/depth  # (1, depth)

    angle_rates = 1 / (10000**depths)               # (1, depth)
    angle_rads  = positions * angle_rates           # (pos, depth)

    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1
        )

    return tf.cast(pos_encoding, dtype=tf.float32)

In [None]:
class PositionalEmbedding(tf.keras.layers.Layer):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.d_model = d_model
        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
        self.pos_encoding = positional_encoding(length=2048, depth=d_model)

    def call(self, x):
        length = tf.shape(x)[1]
        x = self.embedding(x)
        x *=tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = x + self.pos_encoding[tf.newaxis, :length, :]