# Imports

In [None]:
import logging as log
import functools
from time import time

import pathlib

import numpy as np
import tensorflow as tf
import tensorflow_text as tf_text
from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset as bert_vocab

# import matplotlib.pyplot as plt

# Notebook settings

In [None]:
# Configure the root logger: timestamped messages at INFO level and above.
log.basicConfig(
    format='%(asctime)s %(levelname)-8s %(message)s',
    level=log.INFO,
    datefmt='%Y-%m-%d %H:%M:%S')

# Read by the ``log_dec`` wrapper at call time to decide whether to log.
log_enabled = True
# Switch meant to enable/disable the helper functions decorated with
# ``run_helper``.
# NOTE(review): a decorator with this same name is defined later in the
# notebook, and that ``def`` rebinds the name.
run_helper = False

## Decorators

In [None]:
def log_dec(func):
    """Decorate ``func`` so that its start, end and duration are logged.

    Logging only happens while the module-level ``log_enabled`` flag is
    truthy; otherwise the wrapped function is called directly. Exceptions
    raised by ``func`` propagate unchanged (the "finished" line is still
    logged).

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same signature and metadata as ``func``.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Read the flag once per call: if it were re-read after the call, a
        # change made in between could reference a start time never set.
        if not log_enabled:
            return func(*args, **kwargs)

        log.info('%s started', func.__name__)
        start_time = time()
        try:
            return func(*args, **kwargs)
        finally:
            log.info('%s finished in %.3fs', func.__name__, time() - start_time)
    return wrapper

def _run_helper_enabled():
    """Return the current on/off state of the ``run_helper`` switch.

    The global name ``run_helper`` may hold either the plain bool from the
    settings cell or the decorator defined below (which carries the state in
    its ``enabled`` attribute), so both cases are handled.
    """
    current = globals().get('run_helper', False)
    if isinstance(current, bool):
        return current
    return bool(getattr(current, 'enabled', False))


# The ``def`` below rebinds the name ``run_helper``. Capture the flag first;
# otherwise ``if run_helper:`` inside the wrapper would test the (always
# truthy) function object and the switch could never disable anything.
_run_helper_flag = _run_helper_enabled()


def run_helper(func):
    """Decorate a helper so that it only runs while helpers are enabled.

    The switch is stored on the decorator itself as ``run_helper.enabled``
    (initialised from the ``run_helper`` bool of the settings cell) and is
    checked on every call.

    Args:
        func: The helper function to guard.

    Returns:
        A wrapper that forwards to ``func`` when helpers are enabled and
        returns ``None`` without calling ``func`` otherwise.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not _run_helper_enabled():
            return None
        return func(*args, **kwargs)
    return wrapper


run_helper.enabled = _run_helper_flag

# Tokenizer pipeline


The dataset comes from TensorFlow, as described in [https://www.tensorflow.org/text/tutorials/text_generation](https://www.tensorflow.org/text/tutorials/text_generation).

## Create vocab from dataset

Load the text file as a dataset and create a vocabulary from it.

In [None]:
# Location of the corpus text file. Built with pathlib so the separator
# matches the host OS; a hard-coded backslash only resolves on Windows.
dataset_path = str(pathlib.Path('datasets') / 'corpus.txt')

In [None]:
@log_dec
def load_dataset(dataset_text_file):
    """Open a text file as a line-based TensorFlow dataset.

    Args:
        dataset_text_file: Filename(s) passed to ``tf.data.TextLineDataset``
            as ``filenames``.

    Returns:
        The resulting ``tf.data.TextLineDataset``.
    """
    line_dataset = tf.data.TextLineDataset(filenames=dataset_text_file)
    return line_dataset

@log_dec
def create_vocab(dataset):
    """Generate a BERT word-piece vocabulary from a text dataset.

    The dataset is consumed in batches of 1000 elements (prefetching 2
    batches). The vocabulary is limited to 8000 entries, text is
    lower-cased, and the four reserved tokens are passed along.

    Args:
        dataset: Dataset of text lines, e.g. the result of ``load_dataset``.

    Returns:
        The vocabulary produced by ``bert_vocab.bert_vocab_from_dataset``.
    """
    batched_lines = dataset.batch(1000).prefetch(2)
    return bert_vocab.bert_vocab_from_dataset(
        batched_lines,
        vocab_size=8000,
        reserved_tokens=["[PAD]", "[UNK]", "[START]", "[END]"],
        bert_tokenizer_params={'lower_case': True},
        learn_params={},
    )

@run_helper
@log_dec
def create_vocab_from_textdata(text_file=dataset_path):
    """Load a text file and build a vocabulary from its lines.

    Args:
        text_file: Path of the text file; defaults to the value
            ``dataset_path`` had when this function was defined.

    Returns:
        The vocabulary returned by ``create_vocab``.
    """
    return create_vocab(load_dataset(text_file))

@run_helper
@log_dec
def write_vocab_file(filepath, vocab):
    """Write a vocabulary to a text file, one token per line.

    Args:
        filepath: Destination file; it is created or overwritten.
        vocab: Iterable of tokens to write.
    """
    # Explicit UTF-8: the locale default (e.g. cp1252 on Windows) cannot
    # encode every token a corpus may contain.
    with open(filepath, 'w', encoding='utf-8') as file:
        for token in vocab:
            print(token, file=file)

#write_vocab_file('datasets\\vocab.txt', create_vocab_from_textdata())

## Create tokenizer class from vocab

In [None]:
@log_dec
def add_start_end(ragged):
    """Surround every row of a ragged token batch with [START] / [END] ids.

    The ids are the positions of "[START]" and "[END]" in the module-level
    ``reserved_tokens`` list, which must exist when this function is called.

    Args:
        ragged: Ragged batch of token ids; one row per input text.

    Returns:
        The batch with a start id prepended and an end id appended to each
        row (concatenation along axis 1).
    """
    token_names = tf.constant(reserved_tokens)
    start_id = tf.argmax(token_names == "[START]")
    end_id = tf.argmax(token_names == "[END]")

    # One extra column per row on either side of the token ids.
    column_shape = [ragged.bounding_shape()[0], 1]
    start_column = tf.fill(column_shape, start_id)
    end_column = tf.fill(column_shape, end_id)
    return tf.concat([start_column, ragged, end_column], axis=1)

@log_dec
def cleanup_text(reserved_tokens, token_txt):
    """Drop reserved tokens from detokenized words and join them into text.

    Every reserved token except "[UNK]" is removed; the remaining words of
    each row are joined with single spaces.

    Args:
        reserved_tokens: List of reserved token strings.
        token_txt: Ragged string tensor of words, reduced along its last axis.

    Returns:
        String tensor with the joined text for each row.
    """
    # Imported here because the notebook's import cell does not provide it.
    import re

    # The tokens contain regex metacharacters: unescaped, "[START]" is a
    # character class and never matches the literal token.
    bad_tokens = [re.escape(token) for token in reserved_tokens
                  if token != "[UNK]"]
    bad_tokens_re = "|".join(bad_tokens)

    bad_cells = tf.strings.regex_full_match(token_txt, bad_tokens_re)
    ragged_result = tf.ragged.boolean_mask(token_txt, ~bad_cells)

    return tf.strings.reduce_join(ragged_result, separator=' ', axis=-1)

## Tokenizer class

In [None]:
class StoryTokenizer(tf.Module):
    """Exportable wrapper around a ``tf_text.BertTokenizer``.

    Bundles the tokenizer, its vocabulary file (tracked as a SavedModel
    asset) and the reserved tokens, and exposes ``tokenize``, ``detokenize``,
    ``lookup`` and the ``get_*`` accessors as ``tf.function`` methods.
    """

    def __init__(self, reserved_tokens, vocab_path):
        """Build the tokenizer and trace the signatures used for export.

        Args:
            reserved_tokens: List of reserved token strings.
            vocab_path: Path of the vocabulary file (one token per line).
        """
        super().__init__()
        self.tokenizer = tf_text.BertTokenizer(vocab_path, lower_case=True)
        self._reserved_tokens = reserved_tokens
        self._vocab_path = tf.saved_model.Asset(vocab_path)

        # Decode explicitly as UTF-8 instead of the locale default so that
        # non-ASCII tokens are read the same way on every platform.
        vocab = pathlib.Path(vocab_path).read_text(encoding='utf-8').splitlines()
        self.vocab = tf.Variable(vocab)

        ## Create the signatures for export:

        # tokenize signature for a batch of strings
        self.tokenize.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string))

        # detokenize and lookup signatures for:
        # * Tensor with shape [batch, tokens]
        # * RaggedTensor with shape [batch, tokens]
        id_specs = (
            tf.TensorSpec(shape=[None, None], dtype=tf.int64),
            tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int64),
        )
        for method in (self.detokenize, self.lookup):
            for spec in id_specs:
                method.get_concrete_function(spec)

        # get_* methods take no argument
        self.get_vocab_size.get_concrete_function()
        self.get_vocab_path.get_concrete_function()
        self.get_reserved_tokens.get_concrete_function()

    @tf.function
    def tokenize(self, strings):
        """Tokenize a batch of strings into ids wrapped in start/end ids.

        Args:
            strings: String tensor holding a batch of texts.

        Returns:
            The result of ``add_start_end`` applied to the token ids of each
            text, with the last two tokenizer dimensions merged.
        """
        encoded = self.tokenizer.tokenize(strings)
        # Merge the last two dimensions into one token axis per text.
        merged_enc = encoded.merge_dims(-2, -1)
        # NOTE: add_start_end reads the module-level ``reserved_tokens``,
        # not ``self._reserved_tokens``.
        return add_start_end(merged_enc)

    @tf.function
    def detokenize(self, tokenized):
        """Convert token ids back to text with reserved tokens removed.

        Args:
            tokenized: Tensor or RaggedTensor of token ids.

        Returns:
            The result of ``cleanup_text`` on the detokenized words.
        """
        words = self.tokenizer.detokenize(tokenized)
        return cleanup_text(self._reserved_tokens, words)

    @tf.function
    def lookup(self, token_ids):
        """Return the vocabulary entries found at ``token_ids``."""
        return tf.gather(self.vocab, token_ids)

    @tf.function
    def get_vocab_size(self):
        """Return the number of entries in the vocabulary."""
        return tf.shape(self.vocab)[0]

    @tf.function
    def get_vocab_path(self):
        """Return the vocabulary file asset."""
        return self._vocab_path

    @tf.function
    def get_reserved_tokens(self):
        """Return the reserved tokens as a constant string tensor."""
        return tf.constant(self._reserved_tokens)

In [None]:
# Special tokens; ``add_start_end`` looks up "[START]" and "[END]" in this
# module-level list.
reserved_tokens = ["[PAD]", "[UNK]", "[START]", "[END]"]
# Built with pathlib so the separator matches the host OS; a hard-coded
# backslash only resolves on Windows.
vocab_path = str(pathlib.Path('datasets') / 'vocab.txt')

# Build the tokenizer; the vocabulary file at ``vocab_path`` is read here.
tokenizer = StoryTokenizer(reserved_tokens, vocab_path)

@log_dec
def test_tokenizer(tokenizer):
    """Round-trip the first two corpus lines through the tokenizer.

    Args:
        tokenizer: Object providing ``tokenize`` and ``detokenize``.

    Returns:
        A tuple ``(token_batch, text)``: the tokenized batches and the
        corresponding detokenized texts, one entry per corpus line.
    """
    # tokenize() is declared for a batch of strings and add_start_end()
    # concatenates along axis 1, so each line is wrapped in a batch of one
    # instead of being passed as a scalar string.
    dataset_short = load_dataset(dataset_path).batch(1).take(2)
    token_batch = [tokenizer.tokenize(lines) for lines in dataset_short]
    text = [tokenizer.detokenize(token_ids) for token_ids in token_batch]
    return token_batch, text

test_tokenizer(tokenizer)

In [None]:
# Directory the tokenizer is exported to as a SavedModel.
model_name = 'story_corpus_tokenizer'
tf.saved_model.save(tokenizer, model_name)

In [None]:
# Same export directory as used for saving; repeated so that this cell can
# run on its own.
model_name = 'story_corpus_tokenizer'
reload_story_tokenizer = tf.saved_model.load(model_name)
# Sanity check: show the vocabulary size reported by the reloaded tokenizer.
reload_story_tokenizer.get_vocab_size().numpy()

In [None]:
# Tokenize a one-element batch with the reloaded tokenizer and show the ids.
tokens = reload_story_tokenizer.tokenize(['Hello TensorFlow!'])
tokens.numpy()

In [None]:
# Map the token ids back to their vocabulary entries and show them.
text_tokens = reload_story_tokenizer.lookup(tokens)
text_tokens