# Imports

In [44]:
import logging as log
import functools
from time import time

import os

import numpy as np
import tensorflow as tf
import tensorflow_text as tf_text
from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset as bert_vocab

# import matplotlib.pyplot as plt

# Notebook settings

In [45]:
# Root logger configuration shared by every cell in the notebook.
log.basicConfig(
    level=log.INFO,
    format='%(asctime)s %(levelname)-8s %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

# Global switches read at call time by the decorators:
#   log_enabled    -> log_dec emits start/finish log lines when True.
#   execute_helper -> run_helper only runs the wrapped helper when True.
log_enabled, execute_helper = True, False

# Utility

## Decorators

In [46]:
def log_dec(func):
    """Log when *func* starts and finishes, including its wall-clock duration.

    Logging is controlled by the module-level ``log_enabled`` flag. The flag is
    read once per call, so start and finish messages always come in pairs.
    Exceptions raised by *func* propagate unchanged; the finish message is
    still logged via ``finally``.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Snapshot the flag. Re-reading it in ``finally`` could find it switched
        # on mid-call, leaving ``start_time`` unbound (NameError).
        enabled = log_enabled
        if enabled:
            start_time = time()
            log.info('%s started', func.__name__)
        try:
            return func(*args, **kwargs)
        finally:
            if enabled:
                duration = time() - start_time
                log.info('%s finished in %.2fs', func.__name__, duration)
    return wrapper

def run_helper(func):
    """Gate *func* behind the module-level ``execute_helper`` switch.

    The switch is checked on every call. When it is falsy, the wrapped
    function is skipped and ``None`` is returned. Otherwise the call, with all
    its arguments, is forwarded and its result returned.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not execute_helper:
            return None
        return func(*args, **kwargs)
    return wrapper

# Dataset pipeline


The dataset is taken from the TensorFlow text generation tutorial: [https://www.tensorflow.org/text/tutorials/text_generation](https://www.tensorflow.org/text/tutorials/text_generation)

## Test functions

In [47]:
@run_helper
def showcase_dataset(text_file=os.path.join('datasets', 'corpus.txt')):
    """Print basic statistics about the raw corpus.

    Reads *text_file* as bytes and decodes it as UTF-8. It then prints the
    character count, the first 250 characters, and the number of distinct
    characters.

    Args:
        text_file: Path to the corpus. The default is built with
            ``os.path.join`` so it also works outside Windows.
    """
    # Read bytes and decode explicitly, so counts are not affected by newline
    # translation; ``with`` guarantees the handle is closed.
    with open(text_file, 'rb') as corpus:
        text = corpus.read().decode(encoding='utf-8')
    print(f'Length of text: {len(text)} characters')
    print(text[:250])
    vocab = sorted(set(text))
    print(f'{len(vocab)} unique characters')

# uncomment to run
# execute_helper = True
showcase_dataset()

## Load dataset from text

Load the text file as a dataset and create a vocabulary from it.

In [100]:
@log_dec
def load_dataset(dataset_text_file):
    """Wrap *dataset_text_file* in a ``tf.data.TextLineDataset``.

    Each element of the returned dataset is one line of the file.
    """
    line_dataset = tf.data.TextLineDataset(filenames=dataset_text_file)
    return line_dataset

@log_dec
def create_vocab(dataset, *, vocab_size=8000,
                 reserved_tokens=("[PAD]", "[UNK]", "[START]", "[END]"),
                 lower_case=True):
    """Learn a WordPiece vocabulary from a dataset of text lines.

    Args:
        dataset: ``tf.data`` dataset of text elements (e.g. from
            ``load_dataset``). It is batched 1000 elements at a time, with 2
            batches prefetched.
        vocab_size: Target vocabulary size passed to the vocab learner.
        reserved_tokens: Special tokens that must be part of the vocabulary.
        lower_case: Passed to the BERT tokenizer used while learning the
            vocabulary.

    Returns:
        Whatever ``bert_vocab.bert_vocab_from_dataset`` returns. The caller
        ``write_vocab_file`` iterates it token by token.
    """
    bert_vocab_args = dict(
        vocab_size=vocab_size,
        # Copy into a fresh list so the shared default tuple is never handed
        # out as a mutable object.
        reserved_tokens=list(reserved_tokens),
        bert_tokenizer_params=dict(lower_case=lower_case),
        learn_params={},
    )

    story_vocab = bert_vocab.bert_vocab_from_dataset(
        dataset.batch(1000).prefetch(2),
        **bert_vocab_args
    )
    return story_vocab

@run_helper
@log_dec
def create_vocab_from_textdata(text_file='datasets\\corpus.txt'):
    """Load *text_file* line by line and learn a vocabulary from it.

    Gated by ``run_helper``: returns ``None`` unless ``execute_helper`` is set.
    """
    return create_vocab(load_dataset(text_file))

@run_helper
@log_dec
def write_vocab_file(filepath, vocab):
    """Write *vocab* to *filepath*, one token per line, encoded as UTF-8.

    Gated by ``run_helper``: nothing is written unless ``execute_helper`` is
    set. The file is later read back by ``tf_text.BertTokenizer``.
    """
    # Explicit UTF-8: tokens learned from the UTF-8 corpus may contain
    # non-ASCII characters, which the platform default encoding (e.g. cp1252
    # on Windows) can fail to encode.
    with open(filepath, 'w', encoding='utf-8') as vocab_file:
        vocab_file.writelines(f'{token}\n' for token in vocab)

write_vocab_file('datasets\\vocab.txt', create_vocab_from_textdata())

2023-05-04 14:04:48 INFO     create_vocab_from_textdata started
2023-05-04 14:04:48 INFO     load_dataset started
2023-05-04 14:04:48 INFO     load_dataset finished
2023-05-04 14:04:48 INFO     create_vocab started
2023-05-04 14:11:10 INFO     create_vocab finished
2023-05-04 14:11:10 INFO     create_vocab_from_textdata finished
2023-05-04 14:11:10 INFO     write_vocab_file started
2023-05-04 14:11:10 INFO     write_vocab_file finished


## Create tokenizer from vocab

In [None]:
@log_dec
def create_tokenizer():
    """Build a ``tf_text.BertTokenizer`` over the vocabulary file saved by ``write_vocab_file``.

    The tokenizer is created with ``lower_case=True``, matching the setting
    used when the vocabulary was learned.
    """
    return tf_text.BertTokenizer('datasets\\vocab.txt', lower_case=True)

@run_helper
@log_dec
def test_tokenizer(tokenizer):
    """Tokenize the first two corpus lines as a quick smoke test.

    Returns a list with one entry per line: the output of
    ``tokenizer.tokenize`` with its last two axes merged (presumably the word
    and word-piece axes; confirm against the tf_text docs). Gated by
    ``run_helper``, so it returns ``None`` unless ``execute_helper`` is set.
    """
    first_lines = load_dataset('datasets\\corpus.txt').take(2)
    return [tokenizer.tokenize(line).merge_dims(-2, -1) for line in first_lines]

tokenizer = create_tokenizer()
batch = test_tokenizer(tokenizer)
# test_tokenizer is gated by run_helper and returns None unless execute_helper
# is set, so only map ids back to text when there is a batch to show.
if batch is not None:
    # Read the vocabulary back from the file written by write_vocab_file,
    # instead of calling create_vocab_from_textdata() again, which relearns
    # the vocabulary from the whole corpus.
    with open('datasets\\vocab.txt', encoding='utf-8') as vocab_file:
        vocab = [line.rstrip('\n') for line in vocab_file]
    # Each entry of batch is the token ids of one corpus line; look up each id
    # in the vocabulary and join the pieces with spaces for display.
    for tokens in batch:
        txt_tokens = tf.gather(vocab, tokens)
        print(tf.strings.reduce_join(txt_tokens, separator=' ', axis=-1))


# High-level architecture

### Positional Encoding

In [2]:
def positional_encoding(length, depth):
    """Return sinusoidal positional encodings as a ``(length, depth)`` float32 tensor.

    With ``half = depth / 2``, each angle is ``pos * 10000**(-i / half)``. The
    first ``ceil(half)`` columns hold the sines of these angles and the next
    columns hold the matching cosines. The two halves are concatenated, not
    interleaved.

    Args:
        length: Number of positions (rows).
        depth: Encoding width (columns). Odd values are supported; the
            surplus final cosine column is trimmed.
    """
    half_depth = depth / 2

    positions = np.arange(length)[:, np.newaxis]                  # (length, 1)
    depths = np.arange(half_depth)[np.newaxis, :] / half_depth    # (1, ceil(depth/2))

    angle_rates = 1 / (10000**depths)                             # (1, ceil(depth/2))
    angle_rads = positions * angle_rates                          # (length, ceil(depth/2))

    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1,
    )
    # For odd depth the sin/cos halves add up to depth + 1 columns; trim so
    # the result always matches the requested width (no-op for even depth).
    pos_encoding = pos_encoding[:, :int(depth)]

    return tf.cast(pos_encoding, dtype=tf.float32)

In [3]:
class PositionalEmbedding(tf.keras.layers.Layer):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.d_model = d_model
        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
        self.pos_encoding = positional_encoding(length=2048, depth=d_model)

    def call(self, x):
        length = tf.shape(x)[1]
        x = self.embedding(x)
        x *=tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = x + self.pos_encoding[tf.newaxis, :length, :]