# Notes
- I finally understood that during training the next token is predicted simultaneously for every position of the whole sentence, so no sequential processing is needed. Of course, this does not carry over to inference, where tokens have to be generated one after another.

# Imports

In [8]:
# logging and decorators
import logging as log

# general modules
import numpy as np

# tensorflow modules
import tensorflow as tf
import tensorflow_text as tf_text
from tensorflow.keras import layers
import tensorflow_datasets as tfds

from train_model import ModelTrainer, StoryTokenizer, DatasetGeneratorAlt, VisualWrapper

# Settings

In [11]:
# logging settings
# Only one `level=` may be passed (a repeated keyword argument is a SyntaxError).
# force=True removes handlers already attached to the root logger, so re-running
# this cell (e.g. after switching the level) actually takes effect instead of
# basicConfig silently doing nothing.
log.basicConfig(
    format='%(asctime)s %(levelname)-8s %(processName)s %(threadName)s %(funcName)-20s %(message)s',
    # log.DEBUG for diagnostics; use log.INFO for a normal run
    level=log.DEBUG,
    datefmt='%Y-%m-%d %H:%M:%S',
    force=True)

# paths
# NOTE(review): train and val point at the same directory (one with and one
# without a trailing slash); presumably the dataset generator separates the
# splits by file index (n_train_files / n_val_files) — confirm.
train_file_path = "datasets/bookscorpusopen/processed_512"
val_file_path = "datasets/bookscorpusopen/processed_512/"

# tokenizer
vocab_path = 'datasets/vocab.txt'
reserved_tokens = ["[PAD]", "[UNK]", "[START]", "[END]"]

### Set up the model

In [None]:
# Tokenizer passed to the trainer as its first positional argument.
story_tokenizer = StoryTokenizer(reserved_tokens, vocab_path)

# Remaining ModelTrainer options, forwarded as keyword arguments.
# NOTE(review): load_model=True with model_load_path=None presumably loads from a
# default location inside ModelTrainer — confirm.
trainer_options = dict(
    n_train_files=350,
    n_val_files=41,
    dataset_lines_per_file=10000,
    train_val_test_size=(1, 1, 0),
    d_model=512,
    n_stacks=6,
    h_att=8,
    max_padding=512,
    global_batch_size=32,
    warmup_steps=4000,
    n_epochs=10,
    initial_epoch=0,
    verbosity='auto',
    distributed_strategy=tf.distribute.MultiWorkerMirroredStrategy(),
    load_model=True,
    save_model=False,
    model_load_path=None,
)

model_trainer = ModelTrainer(story_tokenizer,
                             DatasetGeneratorAlt,
                             train_file_path,
                             val_file_path,
                             **trainer_options)

# Inference

## WordComplete model

In [15]:
class WordComplete(tf.Module, VisualWrapper):
  """
    Greedy sequence-completion wrapper around a Transformer.

    The prompt is tokenized and encoded once. The decoder is seeded with the
    prompt tokens (without the prompt's last token) and extended one argmax
    token per step until the prompt's last token is predicted again or
    `max_length` tokens exist. Only a single prompt (row 0) is processed.
  """
  def __init__(self, 
               tokenizer, 
               transformer, 
               max_length=512, 
               dtype=tf.Tensor, 
               decode_result=True):
    """
    Args:
        tokenizer (Tokenizer):          Object providing `tokenize`, `detokenize` and `lookup`.
        transformer (tf.keras.Model):   Model providing `encode`, `decode` and `generator`.
        max_length (int, optional):     Maximum number of tokens in the output sequence,
                                        prompt tokens included. Default is 512.
        dtype (optional):               Stored as `self.dtype`; not used by this class.
                                        Default is tf.Tensor.
        decode_result (bool, optional): If True, the output text tensor is decoded into a
                                        Python string. Default is True.
    """
    log.debug(f'initialize {self.__class__.__name__}')
    super().__init__()
    VisualWrapper.__init__(self, vis_on_count=None)
    self.tokenizer = tokenizer
    self.transformer = transformer
    self.max_length = max_length
    self.dtype = dtype
    self.decode_result = decode_result

  def __call__(self, input, decode=True, encoding='utf-8', training=False):
    """
    Performs greedy sequence generation for a single prompt.

    Args:
        input (str or tf.Tensor):   The prompt. A scalar is promoted to a batch of one.
        decode (bool, optional):    If True (and `decode_result` is True), the generated text
                                    is returned as a Python string. Default is True.
        encoding (str, optional):   The encoding used when decoding the text bytes.
                                    Default is 'utf-8'.
        training (bool, optional):  Forwarded to the tokenizer and the transformer.
                                    Default is False.

    Returns:
        text (str or tf.Tensor):    The generated text, prompt included. A `str` when decoding
                                    is enabled, otherwise the detokenized tensor.
        tokens (tf.Tensor):         Result of `tokenizer.lookup` on the generated token ids.
    """
    # Visualisation is enabled only for the duration of this call; the finally
    # block switches it off again even if generation raises.
    VisualWrapper.should_visualize = True
    try:
      # TODO: Bug with empty strings as input
      tensor_input = tf.convert_to_tensor(input)
      if len(tensor_input.shape) == 0:
        # promote a scalar prompt to a batch of one
        tensor_input = tensor_input[tf.newaxis]

      # Tokenize and encode the prompt once; the context is reused every step.
      tokenized_input = self.tokenizer.tokenize(tensor_input, training=training).to_tensor()
      context = self.transformer.encode(tokenized_input, None, training=training)
      # The prompt's last token serves as the stop token.
      # NOTE(review): presumably the [END] token added by the tokenizer — confirm.
      end = tokenized_input[0][-1]

      # Seed the output with the prompt tokens, excluding that last token.
      output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
      for i, value in enumerate(tokenized_input[0][:-1]):
        output_array = output_array.write(i, value)

      # Extend the sequence one token per step, up to max_length tokens.
      # The whole prefix is re-decoded every step (no caching).
      for i in tf.range(output_array.size(), self.max_length):
        dec_input = output_array.concat()[tf.newaxis]
        decoded = self.transformer.decode(context, None, dec_input, None, training=training)

        # Keep only the logits of the last position: (batch_size, 1, vocab_size).
        predictions = self.transformer.generator(decoded, training=training)
        predictions = predictions[:, -1:, :]
        predicted_id = tf.argmax(predictions, axis=-1)

        # Append the prediction; it is part of the decoder input on the next step.
        output_array = output_array.write(i, predicted_id[0][0])

        # Stop once the stop token is predicted (it is kept in the output).
        if predicted_id == end:
          break

      # Convert the ids back to text, and map them to token strings.
      output = output_array.concat()[tf.newaxis]
      text = self.tokenizer.detokenize(output)
      tokens = self.tokenizer.lookup(output)

      # Honour both the per-call `decode` flag and the instance-level default.
      if decode and self.decode_result:
        text = text.numpy()[0].decode(encoding)

      return text, tokens
    finally:
      # reset visualisation
      VisualWrapper.should_visualize = False
      VisualWrapper.reset_counter()

## Text inference

In [16]:
# Wrap the trained model for generation; outputs are capped at 32 tokens.
inference_model = WordComplete(
    tokenizer=StoryTokenizer(reserved_tokens, vocab_path),
    transformer=model_trainer.model,
    max_length=32,
)

# Prompt to be completed by the model.
string = "Can we do "

text, tokens = inference_model(string)
print(text)

2023-06-09 11:55:17 DEBUG    MainProcess MainThread __init__             initialize StoryTokenizer
2023-06-09 11:55:17 DEBUG    MainProcess MainThread __init__             initialize StoryTokenizer
2023-06-09 11:55:17 DEBUG    MainProcess MainThread __init__             initialize WordComplete
2023-06-09 11:55:17 DEBUG    MainProcess MainThread __init__             initialize WordComplete
2023-06-09 11:55:17 DEBUG    MainProcess MainThread tokenize             execute
2023-06-09 11:55:17 DEBUG    MainProcess MainThread add_start_end        execute
2023-06-09 11:55:17 DEBUG    MainProcess MainThread lookup               execute
2023-06-09 11:55:17 DEBUG    MainProcess MainThread visualize_data       execute
2023-06-09 11:55:17 DEBUG    MainProcess MainThread encode               execute
2023-06-09 11:55:17 DEBUG    MainProcess MainThread visualize_data       execute
2023-06-09 11:55:17 DEBUG    MainProcess MainThread visualize_data       execute
2023-06-09 11:55:17 DEBUG    MainProcess 

[START] can we do [END]
