# Sebastian Petrik - Abstractive summarization transformer

Inspired by:
- https://www.kaggle.com/code/ashishsingh226/text-summarization-with-transformers
- https://www.tensorflow.org/text/tutorials/transformer


## Setup

In [None]:
# Install notebook dependencies: wandb for experiment tracking, and the
# evaluate + rouge_score packages used for the ROUGE evaluation at the end.
!pip install --upgrade -q wandb --quiet
!pip install evaluate rouge_score --quiet

In [None]:
import os
# Prints the KAGGLE_CONTAINER_NAME environment variable (None when unset),
# as a quick check for whether the notebook is running on Kaggle.
print(os.environ.get('KAGGLE_CONTAINER_NAME')) # check if kaggle

In [None]:
# Report the installed versions of the packages this notebook depends on.
# importlib.metadata replaces the deprecated pkg_resources API.
import re
from importlib import metadata


def _normalized_name(name):
    """Return *name* lowercased with runs of '-', '_' and '.' collapsed to '-'."""
    return re.sub(r'[-_.]+', '-', name).lower()


def _distribution_name(dist):
    """Return the distribution's declared name, or None if its metadata is unusable."""
    try:
        return dist.metadata['Name']
    except (KeyError, TypeError):
        return None


# Names are compared in normalized form so that e.g. 'rouge_score' matches
# however the installed distribution spells its name.
_tracked_names = {
    _normalized_name(name)
    for name in ['numpy', 'pandas', 'tensorflow', 'tensorflow-text', 'keras',
                 'tensorflow-estimator', 'tensorflow-datasets', 'wandb',
                 'evaluate', 'rouge_score']
}

# Sorted list of (normalized name, version) tuples for the tracked packages.
tracked_versions = sorted({
    (_normalized_name(name), dist.version)
    for dist in metadata.distributions()
    for name in [_distribution_name(dist)]
    if name and _normalized_name(name) in _tracked_names
})
tracked_versions

In [None]:
# Imports
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from collections import defaultdict
import string
import tensorflow as tf
import re
import os
import time
from tensorflow import keras
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
import operator as op
import wandb
from wandb.keras import WandbCallback
from pprint import pprint
import evaluate

## Configuration and Wandb

In [None]:
# Setup seeds
# Fix the NumPy and TensorFlow global seeds to 42 so runs are repeatable.
# NOTE(review): TF_CUDNN_DETERMINISTIC presumably requests deterministic cuDNN
# kernels - confirm it is honored by the installed TensorFlow version.
os.environ['TF_CUDNN_DETERMINISTIC'] = '1' 
np.random.seed(42)
tf.random.set_seed(42)

In [None]:
# Central experiment configuration. The same dict is passed to wandb as the
# run config, and read throughout the notebook via CONFIG[...].
CONFIG = {
    # Meta
    'wandb_project': 'stranasum-exploring',
    'wandb_group': '-',
    'host': 'kaggle',

    # Data
    'dataset_name': "inshorts_nodot_10-70_3-16_v0.05_t0.05",
    'val_split': 0.05,
    'test_split': 0.05,

    # Sequences - lengths are chosen to match the data
    'max_input_length': 70,   # encoder sequence length (max article tokens)
    'max_target_length': 16,  # decoder sequence length (max summary tokens)

    # Transformer hyperparameters
    'num_layers': 3,
    'd_model': 128,           # embedding length
    'dff': 512,
    'num_heads': 8,
    'dropout_rate': 0.1,

    # Training
    'early_stopping_patience': 3,  # number of consecutive non-improving epochs
    'max_epochs': 20,
    'batch_size': 256,
    'learning_rate_warmup_steps': 4000,
}

print("Config:")
pprint(CONFIG)

In [None]:
# Read the secret named "wandb" from Kaggle's user secrets and use it as the
# API key for the wandb login.
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
wandb.login(key=user_secrets.get_secret("wandb"))

## Data loading

- Load the train, validation and test CSV splits for the dataset named in `CONFIG['dataset_name']`.

In [None]:
# show available data
# (lists the input directory and the dataset directory the CSVs are read from)
!ls ../input
!ls ../input/pmxy-stranasum-preprocessing

In [None]:
# Read the train/val/test CSV files of the configured dataset into DataFrames
# and report their shapes.
dataset_dir = "../input/pmxy-stranasum-preprocessing"
_csv_prefix = f"{dataset_dir}/{CONFIG['dataset_name']}"

df_train, df_val, df_test = (
    pd.read_csv(f"{_csv_prefix}_{_split}.csv")
    for _split in ("train", "val", "test")
)

for _label, _frame in (("Train:", df_train), ("Val:", df_val), ("Test:", df_test)):
    print(_label, _frame.shape)

## Tokenization

In [None]:
# Custom tokenizer
class Tokenizer(tf.Module):
    """Text <-> token-id converter built on Keras ``TextVectorization``.

    The vocabulary is learned from ``vectorization_dataset``; every vectorized
    sequence is padded or truncated to ``max_length`` ids. Lookup layers
    sharing the same vocabulary convert single words to ids and ids back to
    words.
    """

    def __init__(self, vectorization_dataset: tf.data.Dataset, max_length: int):
        super().__init__(name="Tokenizer")

        self.max_length = max_length

        # Fit the vectorizer on the dataset. standardize=None disables the
        # layer's text standardization step.
        self.vectorizer = tf.keras.layers.TextVectorization(
            standardize=None,
            output_mode='int',
            output_sequence_length=max_length,
        )
        self.vectorizer.adapt(vectorization_dataset.batch(1024))

        # Word <-> id lookups share the vectorizer's vocabulary.
        vocabulary = self.vectorizer.get_vocabulary()
        lookup_options = dict(vocabulary=vocabulary, mask_token='', oov_token='[UNK]')
        self.word_to_id = tf.keras.layers.StringLookup(**lookup_options)
        self.id_to_word = tf.keras.layers.StringLookup(invert=True, **lookup_options)

        # Frequently used attributes.
        self.vocab_size = self.vectorizer.vocabulary_size()
        self.start_token = self.word_to_id('<sos>')
        self.end_token = self.word_to_id('<eos>')

        print(f"Tokenizer maxlen={self.max_length}, top vocabulary: {vocabulary[:10]}")

    # To convert text to tokens, call `self.vectorizer` directly.

    @tf.function
    def tokens_to_text(self, tokens):
        """Convert token ids back to text, joining words along the last axis.

        A leading ``<sos>`` and a trailing ``<eos>`` (with surrounding
        spaces) are removed and every ``<dot>`` becomes ``.``.
        """
        text = tf.strings.reduce_join(self.id_to_word(tokens), axis=-1, separator=' ')
        for pattern, replacement in (
            ('^ *<sos> *', ''),
            (' *<eos> *$', ''),
            ('<dot>', '.'),
        ):
            text = tf.strings.regex_replace(text, pattern, replacement)
        return text

# Setup tokenizers
# Both vocabularies are fitted on the training split only: articles for the
# encoder side, summaries for the decoder side.
_train_articles = tf.data.Dataset.from_tensor_slices(df_train['article'])
input_tokenizer = Tokenizer(
    vectorization_dataset=_train_articles,
    max_length=CONFIG['max_input_length'],
)

_train_summaries = tf.data.Dataset.from_tensor_slices(df_train['summary'])
target_tokenizer = Tokenizer(
    vectorization_dataset=_train_summaries,
    max_length=CONFIG['max_target_length'],
)

In [None]:
# try out: vectorize one hand-written sentence with the input tokenizer
# and display the resulting token ids
sample_tokens = input_tokenizer.vectorizer(tf.constant(['<sos> the dog ate the food <eos>']))
sample_tokens

In [None]:
# Round trip: convert the sample token ids back into text.
input_tokenizer.tokens_to_text(sample_tokens).numpy().astype('str') # binary string tensor into decoded string array

In [None]:
# Same check for the target tokenizer, this time with a scalar string input.
target_tokenizer.vectorizer(tf.constant('<sos> a man was murdered <eos>'))

In [None]:
def convert_to_tokenized_sequences(input_texts: "np.ndarray | pd.Series",
                                   target_texts: "np.ndarray | pd.Series"):
    """Vectorize article/summary texts into teacher-forcing training tensors.

    Args:
        input_texts: article strings, passed to the input tokenizer's vectorizer.
        target_texts: summary strings, passed to the target tokenizer's vectorizer.

    Returns:
        ``((inputs, targets_inputs), targets_labels)`` where ``inputs`` are the
        article token ids, ``targets_inputs`` is the summary without its last
        position (the decoder input) and ``targets_labels`` is the summary
        without its first position (the next-token label for every decoder
        step).
    """
    # convert text arrays into fixed-length token id tensors
    inputs = input_tokenizer.vectorizer(input_texts)
    targets = target_tokenizer.vectorizer(target_texts)

    # Decoder input: drop the LAST position. Sequences are padded to a fixed
    # length, so this is the end token only for summaries that fill the whole
    # sequence; otherwise it is padding.
    targets_inputs = targets[:, :-1]

    # Labels: drop the FIRST position, shifting the sequence one step so that
    # position t holds the token the decoder should predict after seeing
    # targets_inputs[:t + 1].
    targets_labels = targets[:, 1:]

    return (inputs, targets_inputs), targets_labels

In [None]:
# Apply tokenization and create batched tf datasets ...
def _to_batched_dataset(frame):
    """Tokenize a frame's article/summary columns into a shuffled, batched dataset.

    The shuffle buffer equals the number of rows, giving a full uniform shuffle.
    """
    tensors = convert_to_tokenized_sequences(frame['article'], frame['summary'])
    row_count = frame.shape[0]
    return (
        tf.data.Dataset.from_tensor_slices(tensors)
        .shuffle(row_count)
        .batch(CONFIG['batch_size'])
    )


dataset_train = _to_batched_dataset(df_train)
dataset_val = _to_batched_dataset(df_val)

# take 1 sample batch for further inspection (later)
(sample_input, sample_target), sample_target_labels = next(iter(dataset_train.take(1)))

print('Sample single batch from dataset:')

print(
    sample_input.shape,
    sample_target.shape,
    sample_target_labels.shape,
    sep='\n',
)

# First example of the batch: decoder input vs. its shifted labels.
print(sample_target[0])
print(sample_target_labels[0])

In [None]:
# try to get special token numbers so we can see them in prints
# NOTE(review): markers that are not in a tokenizer's vocabulary presumably map
# to the out-of-vocabulary id - confirm against the printed output.
print(input_tokenizer.vectorizer(tf.constant('<sos> <eos> <unk> <dot> <pad>')))
print(target_tokenizer.vectorizer(tf.constant('<sos> <eos> <unk> <dot> <pad>')))

## Transformer model implementation

In [None]:
def positional_encoding(length, depth):
    """Build a sinusoidal positional-encoding table.

    Args:
        length: number of positions (rows) to generate.
        depth: encoding width; the first half of the columns holds sines and
            the second half cosines of the same angles.

    Returns:
        A float32 tensor of shape ``(length, depth)`` (for even ``depth``).
    """
    half_depth = depth / 2

    position_column = np.arange(length).reshape(-1, 1)                      # (length, 1)
    depth_fractions = (np.arange(half_depth) / half_depth).reshape(1, -1)   # (1, depth/2)

    # Angular rate of each column: 1 / 10000^(i / half_depth).
    rates = 1 / np.power(10000, depth_fractions)                            # (1, depth/2)
    angles = position_column * rates                                        # (length, depth/2)

    table = np.concatenate((np.sin(angles), np.cos(angles)), axis=-1)
    return tf.cast(table, dtype=tf.float32)

# We try out the encoding func
# (2048 positions, depth 512; note this rebinds the module-level name
# `pos_encoding` to the generated table)
pos_encoding = positional_encoding(length=2048, depth=512)
print('Positional encoding shape', pos_encoding.shape) # Check the shape.

# Plot the dimensions.
# The table is transposed so positions run along x and depth along y.
plt.pcolormesh(pos_encoding.numpy().T, cmap='RdBu')
plt.ylabel('Depth')
plt.xlabel('Position')
plt.colorbar()
plt.show()

In [None]:
class PositionalEmbedding(tf.keras.layers.Layer):
    """Token embedding scaled by sqrt(d_model) plus a fixed positional encoding.

    Args:
        vocab_size: number of token ids the embedding table covers.
        d_model: embedding width.
        max_length: number of positions precomputed in the positional-encoding
            table; inputs longer than this cannot be encoded. Defaults to 2048,
            the value that used to be hard-coded.
    """

    def __init__(self, vocab_size, d_model, max_length=2048):
        super().__init__()
        self.d_model = d_model
        # mask_zero=True makes token id 0 (padding) produce a Keras mask.
        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
        self.pos_encoding = positional_encoding(length=max_length, depth=d_model)

    def compute_mask(self, *args, **kwargs):
        """Delegate mask computation to the wrapped embedding layer."""
        return self.embedding.compute_mask(*args, **kwargs)

    def call(self, x):
        """Embed token ids ``x`` and add the encoding for each position."""
        length = tf.shape(x)[1]
        x = self.embedding(x)
        # This factor sets the relative scale of the embedding and positional_encoding.
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        # Slice the precomputed table to the actual sequence length.
        x = x + self.pos_encoding[tf.newaxis, :length, :]
        return x

# Try
# sample_emb_input = PositionalEmbedding(vocab_size=input_tokenizer.vocab_size, d_model=512)(sample_input)
# sample_emb_target = PositionalEmbedding(vocab_size=target_tokenizer.vocab_size, d_model=512)(sample_target)

# sample_emb_target._keras_mask

In [None]:
# base attention for further subclassing
class BaseAttention(tf.keras.layers.Layer):
    """Holds the sublayers shared by the attention blocks in this file.

    Subclasses implement ``call`` using ``self.mha``, ``self.add`` and
    ``self.layernorm``.
    """
    def __init__(self, **kwargs):
        super().__init__()
        # All keyword arguments are forwarded unchanged to MultiHeadAttention.
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()

In [None]:
class CrossAttention(BaseAttention):
    """Attention block where the target sequence attends to a context sequence."""

    def call(self, x, context):
        """Attend from ``x`` (queries) to ``context`` (keys and values).

        The attention output is added to ``x`` (residual) and layer-normalized.
        The attention scores of the most recent call are kept on
        ``self.last_attn_scores`` for later plotting.
        """
        attended, scores = self.mha(
            query=x,
            key=context,
            value=context,
            return_attention_scores=True,
        )
        self.last_attn_scores = scores

        return self.layernorm(self.add([x, attended]))

# Try on sample
# sample_ca = CrossAttention(num_heads=2, key_dim=512)
# print(sample_emb_input.shape)
# print(sample_emb_target.shape)
# print(sample_ca(sample_emb_input, sample_emb_target).shape)

In [None]:
class GlobalSelfAttention(BaseAttention):
    """Self-attention block: ``x`` serves as query, key and value."""

    def call(self, x):
        """Apply self-attention to ``x``, then the residual add and layer norm."""
        attended = self.mha(query=x, value=x, key=x)
        return self.layernorm(self.add([x, attended]))

# Try
# sample_gsa = GlobalSelfAttention(num_heads=2, key_dim=512)
# print(sample_emb_input.shape)
# print(sample_gsa(sample_emb_input).shape)

In [None]:
class CausalSelfAttention(BaseAttention):
    """Self-attention block that calls the attention layer with a causal mask."""

    def call(self, x):
        """Apply causally masked self-attention, then the residual add and layer norm."""
        attended = self.mha(query=x, value=x, key=x, use_causal_mask=True)
        return self.layernorm(self.add([x, attended]))
    
# sample_csa = CausalSelfAttention(num_heads=2, key_dim=512)
# print(sample_emb_target.shape)
# print(sample_csa(sample_emb_target).shape)

In [None]:
# test latter elements do not depend on earlier elements, making no difference
# if we remove the earlier elements before or after applying csa layer
# out1 = sample_csa(embed_target(sample_target[:, :3])) 
# out2 = sample_csa(embed_target(sample_target))[:, :3]
# tf.reduce_max(abs(out1 - out2)).numpy()

In [None]:
class FeedForward(tf.keras.layers.Layer):
    """Position-wise feed-forward block with residual connection and layer norm.

    Args:
        d_model: output width of the block (must match the input width so the
            residual add works).
        dff: width of the hidden ReLU layer.
        dropout_rate: dropout applied after the second dense layer.
    """

    def __init__(self, d_model, dff, dropout_rate=0.1):
        super().__init__()
        hidden = tf.keras.layers.Dense(dff, activation='relu')
        projection = tf.keras.layers.Dense(d_model)
        dropout = tf.keras.layers.Dropout(dropout_rate)
        self.seq = tf.keras.Sequential([hidden, projection, dropout])
        self.add = tf.keras.layers.Add()
        self.layer_norm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        """Return ``layer_norm(x + seq(x))``."""
        residual_sum = self.add([x, self.seq(x)])
        return self.layer_norm(residual_sum)

# Try
# sample_ffn = FeedForward(512, 2048)
# print(sample_emb_target.shape)
# print(sample_ffn(sample_emb_target).shape)

In [None]:
class EncoderLayer(tf.keras.layers.Layer):
    """One encoder layer: global self-attention followed by a feed-forward block.

    Args:
        d_model: model width, also used as the attention ``key_dim``.
        num_heads: number of attention heads.
        dff: hidden width of the feed-forward block.
        dropout_rate: dropout rate used by both the attention and the
            feed-forward sublayers.
    """

    def __init__(self, *, d_model, num_heads, dff, dropout_rate=0.1):
        super().__init__()

        self.self_attention = GlobalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        # Forward the configured dropout rate; previously the feed-forward
        # block silently fell back to its own default of 0.1.
        self.ffn = FeedForward(d_model, dff, dropout_rate)

    def call(self, x):
        """Apply self-attention, then the feed-forward block."""
        x = self.self_attention(x)
        x = self.ffn(x)
        return x
    
# sample_encoder_layer = EncoderLayer(d_model=512, num_heads=8, dff=2048)
# print(sample_encoder_layer(sample_emb_input).shape)

In [None]:
class Encoder(tf.keras.layers.Layer):
    """Stack of encoder layers on top of a positional embedding.

    Args:
        num_layers: number of ``EncoderLayer`` instances in the stack.
        d_model: model width.
        num_heads: attention heads per layer.
        dff: feed-forward hidden width per layer.
        vocab_size: size of the input vocabulary.
        dropout_rate: dropout applied to the embeddings and passed to each layer.
    """

    def __init__(self, *, num_layers, d_model, num_heads,
                 dff, vocab_size, dropout_rate=0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(
            vocab_size=vocab_size, d_model=d_model)

        layer_options = dict(
            d_model=d_model,
            num_heads=num_heads,
            dff=dff,
            dropout_rate=dropout_rate,
        )
        self.enc_layers = [EncoderLayer(**layer_options) for _ in range(num_layers)]

        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x):
        """Encode token ids ``x`` of shape ``(batch_size, seq_len)``.

        Returns a tensor of shape ``(batch_size, seq_len, d_model)``.
        """
        # Embed, then apply dropout to the embeddings.
        hidden = self.dropout(self.pos_embedding(x))

        for layer in self.enc_layers:
            hidden = layer(hidden)

        return hidden

# Try
# sample_encoder = Encoder(num_layers=4,
#                          d_model=512,
#                          num_heads=8,
#                          dff=2048,
#                          vocab_size=input_tokenizer.vocab_size)

# sample_encoder_output = sample_encoder(sample_input, training=False)

# print(sample_input.shape)
# print(sample_encoder_output.shape)

In [None]:
class DecoderLayer(tf.keras.layers.Layer):
    """One decoder layer: causal self-attention, cross-attention, feed-forward.

    Args:
        d_model: model width, also used as the attention ``key_dim``.
        num_heads: number of attention heads.
        dff: hidden width of the feed-forward block.
        dropout_rate: dropout rate used by the attention and feed-forward
            sublayers.
    """

    def __init__(self, *, d_model, num_heads, dff, dropout_rate=0.1):
        super().__init__()

        self.causal_self_attention = CausalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        self.cross_attention = CrossAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        # Forward the configured dropout rate; previously the feed-forward
        # block silently fell back to its own default of 0.1.
        self.ffn = FeedForward(d_model, dff, dropout_rate)

    def call(self, x, context):
        """Decode ``x`` while attending to the encoder output ``context``."""
        x = self.causal_self_attention(x=x)
        x = self.cross_attention(x=x, context=context)

        # Cache the last attention scores for plotting later
        self.last_attn_scores = self.cross_attention.last_attn_scores

        x = self.ffn(x)  # Shape `(batch_size, seq_len, d_model)`.
        return x
    
# sample_decoder_layer = DecoderLayer(d_model=512, num_heads=8, dff=2048)
# sample_decoder_layer_output = sample_decoder_layer(
#     x=sample_emb_target, context=sample_emb_input
# )

# print(sample_emb_input.shape)
# print(sample_emb_target.shape)
# print(sample_decoder_layer_output.shape)

In [None]:
class Decoder(tf.keras.layers.Layer):
    """Stack of decoder layers on top of a positional embedding.

    Args:
        num_layers: number of ``DecoderLayer`` instances in the stack.
        d_model: model width.
        num_heads: attention heads per layer.
        dff: feed-forward hidden width per layer.
        vocab_size: size of the target vocabulary.
        dropout_rate: dropout applied to the embeddings and passed to each layer.
    """

    def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size,
                 dropout_rate=0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size,
                                                 d_model=d_model)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

        layer_options = dict(d_model=d_model, num_heads=num_heads,
                             dff=dff, dropout_rate=dropout_rate)
        self.dec_layers = [DecoderLayer(**layer_options) for _ in range(num_layers)]

        # Filled in by `call` with the last layer's cross-attention scores.
        self.last_attn_scores = None

    def call(self, x, context):
        """Decode target token ids ``x`` of shape ``(batch, target_seq_len)``.

        ``context`` is the encoder output. Returns a tensor of shape
        ``(batch_size, target_seq_len, d_model)``.
        """
        hidden = self.dropout(self.pos_embedding(x))

        for layer in self.dec_layers:
            hidden = layer(hidden, context)

        self.last_attn_scores = self.dec_layers[-1].last_attn_scores

        return hidden

# Try
# sample_decoder = Decoder(num_layers=4,
#                          d_model=512,
#                          num_heads=8,
#                          dff=2048,
#                          vocab_size=target_tokenizer.vocab_size)

# sample_decoder_output = sample_decoder(x=sample_target, context=sample_emb_input)

# print(sample_target.shape)
# print(sample_emb_input.shape)
# print(sample_decoder_output.shape)

In [None]:
# sample_decoder.last_attn_scores.shape

In [None]:
class Transformer(tf.keras.Model):
    """Encoder-decoder transformer producing next-token logits.

    Args:
        num_layers: number of layers in both the encoder and the decoder.
        d_model: model width.
        num_heads: attention heads per layer.
        dff: feed-forward hidden width per layer.
        input_vocab_size: encoder vocabulary size.
        target_vocab_size: decoder vocabulary size (also the logits width).
        dropout_rate: dropout rate passed to the encoder and decoder.
    """

    def __init__(self, *, num_layers, d_model, num_heads, dff,
                 input_vocab_size, target_vocab_size, dropout_rate=0.1):
        super().__init__()

        shared_options = dict(num_layers=num_layers, d_model=d_model,
                              num_heads=num_heads, dff=dff,
                              dropout_rate=dropout_rate)

        self.encoder = Encoder(vocab_size=input_vocab_size, **shared_options)
        self.decoder = Decoder(vocab_size=target_vocab_size, **shared_options)

        self.final_layer = tf.keras.layers.Dense(target_vocab_size)

    def call(self, inputs):
        """Run the model on ``inputs = (context_tokens, target_tokens)``.

        Keras `.fit` requires all inputs in the first argument, hence the
        tuple. Returns logits of shape
        ``(batch_size, target_len, target_vocab_size)``.
        """
        source_tokens, target_tokens = inputs

        encoded = self.encoder(source_tokens)           # (batch_size, context_len, d_model)
        decoded = self.decoder(target_tokens, encoded)  # (batch_size, target_len, d_model)

        logits = self.final_layer(decoded)

        try:
            # Drop the keras mask, so it doesn't scale the losses/metrics.
            # b/250038731
            del logits._keras_mask
        except AttributeError:
            pass

        # Only the logits are returned; attention scores are cached on the decoder.
        return logits


In [None]:
# Construct transformer
# Architecture hyperparameters come straight from CONFIG; vocabulary sizes
# come from the fitted tokenizers.
_architecture = {
    key: CONFIG[key]
    for key in ('num_layers', 'd_model', 'num_heads', 'dff', 'dropout_rate')
}
transformer = Transformer(
    input_vocab_size=input_tokenizer.vocab_size,
    target_vocab_size=target_tokenizer.vocab_size,
    **_architecture,
)

In [None]:
# Display the (input, target) vocabulary sizes used to size the model.
input_tokenizer.vocab_size, target_tokenizer.vocab_size

In [None]:
# Call transformer on sample input, this will build it and setup inputs

# The printed shapes are: encoder input, decoder input, output logits.
sample_transformer_output = transformer((sample_input, sample_target))
print(sample_input.shape)
print(sample_target.shape)
print(sample_transformer_output.shape)

In [None]:
# sample_transformer_attn_scores = transformer.decoder.dec_layers[-1].last_attn_scores
# print(sample_transformer_attn_scores.shape) 

In [None]:
# Print the Keras model summary.
transformer.summary()

In [None]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Warmup-then-decay learning-rate schedule.

    The rate at a given step is
    ``d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)``.

    Args:
        d_model: model width used to scale the rate.
        warmup_steps: step count that sets the slope of the ``step``-linear
            term in the formula above.
    """

    def __init__(self, d_model, warmup_steps):
        super().__init__()

        # Keep the value as given for get_config(); the float32 tensor in
        # self.d_model is what the schedule math uses.
        self._d_model_config = d_model
        self.d_model = tf.cast(d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        """Return the learning rate for ``step`` (scalar or tensor of steps)."""
        step = tf.cast(step, dtype=tf.float32)
        decay_term = tf.math.rsqrt(step)
        warmup_term = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(decay_term, warmup_term)

    def get_config(self):
        """Return the constructor arguments as given.

        ``d_model`` is returned as it was passed in, not as the float32 tensor
        stored on ``self.d_model``, so the config stays serializable.
        """
        return {
            'd_model': self._d_model_config,
            'warmup_steps': self.warmup_steps,
        }

# Learning-rate schedule parameterized from CONFIG.
learning_rate = CustomSchedule(CONFIG['d_model'], CONFIG['learning_rate_warmup_steps'])

# Adam optimizer driven by the schedule above.
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                     epsilon=1e-9)


# Visualize the schedule over the first 40000 train steps.
plt.plot(learning_rate(tf.range(40000, dtype=tf.float32)))
plt.ylabel('Learning Rate')
plt.xlabel('Train Step')

In [None]:
def masked_loss(label, pred):
    """Sparse categorical cross-entropy averaged over non-padding positions.

    Args:
        label: integer token ids; id 0 is treated as padding and ignored.
        pred: logits with the class dimension last.

    Returns:
        Scalar mean loss over positions whose label is non-zero, or 0 when
        there are no such positions (instead of NaN from a 0/0 division).
    """
    mask = label != 0
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none'
    )
    loss = loss_object(label, pred)

    # Zero out the loss at padding positions.
    mask = tf.cast(mask, dtype=loss.dtype)
    loss *= mask

    # divide_no_nan returns 0 when the mask is entirely empty.
    return tf.math.divide_no_nan(tf.reduce_sum(loss), tf.reduce_sum(mask))


def masked_accuracy(label, pred):
    """Token-level accuracy computed over non-padding positions only.

    Args:
        label: integer token ids; id 0 is treated as padding and ignored.
        pred: logits of rank 3; the predicted id is the argmax over axis 2.

    Returns:
        Scalar fraction of non-padding positions predicted correctly, or 0
        when there are no such positions (instead of NaN from a 0/0 division).
    """
    pred = tf.argmax(pred, axis=2)
    label = tf.cast(label, pred.dtype)
    match = label == pred

    mask = label != 0

    # Only count matches at non-padding positions.
    match = match & mask

    match = tf.cast(match, dtype=tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    return tf.math.divide_no_nan(tf.reduce_sum(match), tf.reduce_sum(mask))


# masked_loss(tf.constant([0.5], dtype=tf.float32), tf.constant([[0.5, 0.2]], dtype=tf.float32))

In [None]:
# Compile with the padding-aware loss and accuracy defined in this notebook.
transformer.compile(
    loss=masked_loss,
    optimizer=optimizer,
    metrics=[masked_accuracy]
)

## Training

In [None]:
# Callbacks
modeldir = "checkpoints"
os.makedirs(modeldir, exist_ok=True)

checkpoint_filepath = modeldir + '/checkpoint.hdf'
print('Model checkpoint:', checkpoint_filepath)

# Stop when val_loss has not improved for the configured number of epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    mode='min', verbose=1,
    patience=CONFIG['early_stopping_patience'],
    restore_best_weights=True # restore only best weights relative to val_loss
)

# Append per-epoch metrics to a CSV file next to the checkpoint.
csv_logger = tf.keras.callbacks.CSVLogger(
    modeldir + '/log.csv', separator=",", append=True
)

# Save weights only, and only when val_loss improves.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_best_only=True
)

# Resume from an existing checkpoint when there is one. The weights must be
# loaded into `transformer`; the previous code referenced an undefined name
# `model`, and the bare `except` hid the resulting NameError, so the
# checkpoint was never restored.
try:
    transformer.load_weights(checkpoint_filepath)
except Exception as err:  # best-effort restore: report the reason and train from scratch
    print('Cannot load model weights from checkpoint, it may not exist yet.')
    print(f'  ({type(err).__name__}: {err})')
else:
    print('Loaded model weights checkpoint.')

In [None]:
%%time

# Training

# Set to False to skip training (e.g. when only evaluating restored weights).
do_train = True

history = None
if do_train:

    run = wandb.init(
        project=CONFIG['wandb_project'],
        config=CONFIG,
        group=CONFIG['wandb_group'],
        job_type='train'
    )

    try:
        history = transformer.fit(
            dataset_train,
            epochs=CONFIG['max_epochs'],
            validation_data=dataset_val,
            callbacks=[
                WandbCallback(save_model=False),
                early_stopping,
                csv_logger,
                model_checkpoint_callback
            ]
        )
    except BaseException:
        # Close the wandb run even when training fails or is interrupted,
        # marking it as failed, then let the error propagate.
        run.finish(exit_code=1)
        raise
    else:
        run.finish()

In [None]:
def plot_history(history):
    """Plot training/validation loss and masked accuracy per epoch.

    Args:
        history: object whose ``history`` dict holds the ``loss``,
            ``val_loss``, ``masked_accuracy`` and ``val_masked_accuracy``
            series (as returned by ``fit``).
    """
    fig, (axl, axa) = plt.subplots(nrows=2, ncols=1)

    # (axis, metrics to draw, y-limits, y-label) for each panel.
    panels = (
        (axl, ('loss', 'val_loss'), [0, 10], 'Loss'),
        (axa, ('masked_accuracy', 'val_masked_accuracy'), [0, 1], 'Accuracy'),
    )

    for axis, metric_names, y_limits, y_label in panels:
        for metric in metric_names:
            axis.plot(history.history[metric], label=metric)
        axis.set_ylim(y_limits)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.legend()
        axis.grid(True)

    fig.show()
    
# `history` stays None when training was skipped; compare by identity.
if history is not None:
    plot_history(history)
else:
    print("No history to display.")

## Inference

In [None]:
# LEGACY - unoptimized inference using python array
class SummarizerSingleStep(tf.Module):
    """Eager greedy summarizer that grows the decoder input one token at a time.

    Kept as a reference implementation; `Summarizer` is the optimized variant.
    """

    def __init__(self, transformer, input_tokenizer, target_tokenizer):
        super().__init__()
        self.transformer = transformer
        self.input_tokenizer = input_tokenizer
        self.target_tokenizer = target_tokenizer

    # expect sentence to be prepared with <sos> <eos> and clean
    def __call__(self, sentence: str):
        """Summarize one prepared sentence.

        Returns:
            ``(text, tokens, attention_weights)``: the decoded summary, the
            generated token ids with a leading batch axis, and the decoder's
            cached cross-attention scores.
        """
        encoder_input = tf.expand_dims(self.input_tokenizer.vectorizer(sentence), 0)

        # Use the tokenizer given to this instance, not the notebook global.
        summary_start_token = self.target_tokenizer.vectorizer('<sos>')[0].numpy()
        summary_end_token = self.target_tokenizer.vectorizer('<eos>')[0].numpy()

        decoder_input = [summary_start_token]
        output = tf.expand_dims(decoder_input, 0)

        for _ in tf.range(self.target_tokenizer.max_length):

            predictions = self.transformer([encoder_input, output], training=False)

            # Select the last token from the `seq_len` dimension.
            predictions = predictions[:, -1:, :]  # Shape `(batch_size, 1, vocab_size)`.

            # NOTE(review): `output` is built from a Python list and is
            # presumably int32, so the argmax result is cast to int32 before
            # concatenation - confirm.
            predicted_id = tf.cast(tf.argmax(predictions, axis=-1), dtype=tf.int32)

            # Concatenate the `predicted_id` to the output which is given to the
            # decoder as its input.
            output = tf.concat([output, predicted_id], axis=-1)

            if predicted_id == summary_end_token:
                break

        prediction = tf.squeeze(output, axis=0)
        tokens = np.expand_dims(prediction.numpy(), 0)

        text = self.target_tokenizer.tokens_to_text(tokens)[0]

        # Re-run the model on the generated sequence (without its last token)
        # so the decoder caches attention scores for the whole summary.
        self.transformer([encoder_input, output[:, :-1]], training=False)
        attention_weights = self.transformer.decoder.last_attn_scores

        return text, tokens, attention_weights

# Optimized inference using tensorflow tensorarray
class Summarizer(tf.Module):
    """Greedy summarizer whose decoding loop is traced by `tf.function`."""

    def __init__(self, transformer, input_tokenizer, target_tokenizer):
        super().__init__()
        self.transformer = transformer
        self.input_tokenizer = input_tokenizer
        self.target_tokenizer = target_tokenizer

    def to_tf(self, text: str):
        """Wrap a Python string into the 1-element string tensor `__call__` expects."""
        return tf.constant([text])

    def from_tf(self, tensor: tf.Tensor):
        """Decode a scalar string tensor (as returned by `__call__`) into a Python str."""
        return bytes.decode(tensor.numpy())

    # expect sentence to be prepared with <sos> <eos> and clean
    @tf.function(input_signature=[tf.TensorSpec(dtype=tf.string, shape=[None])])
    def __call__(self, sentence: tf.Tensor):
        """Summarize the prepared sentence held in a 1-D string tensor.

        Decoding is greedy (argmax at every step) and stops at `<eos>` or
        after `target_tokenizer.max_length` steps.

        Returns:
            ``(text, output, attention_weights)``: the decoded summary as a
            string tensor, the generated token ids, and the decoder's cached
            cross-attention scores.
        """
        encoder_input = self.input_tokenizer.vectorizer(sentence)

        # Use the tokenizer given to this instance, not the notebook global.
        start_token = self.target_tokenizer.vectorizer(tf.constant('<sos>'))[0][tf.newaxis]
        end_token = self.target_tokenizer.vectorizer(tf.constant('<eos>'))[0][tf.newaxis]

        # `tf.TensorArray` is required here (instead of a Python list), so that the
        # dynamic-loop can be traced by `tf.function`.
        output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
        output_array = output_array.write(0, start_token)
        # NOTE(review): `output` is defined before the loop, presumably so the
        # traced loop sees it as already initialized - confirm before removing.
        output = tf.transpose(output_array.stack())

        for i in tf.range(self.target_tokenizer.max_length):

            output = tf.transpose(output_array.stack())

            predictions = self.transformer([encoder_input, output], training=False)

            # Select the last token from the `seq_len` dimension.
            predictions = predictions[:, -1:, :]  # Shape `(batch_size, 1, vocab_size)`.

            # argmax
            predicted_id = tf.argmax(predictions, axis=-1)

            # Append the `predicted_id` to the array that is fed back to the
            # decoder as its input on the next step.
            output_array = output_array.write(i+1, predicted_id[0])

            # stop on end
            if predicted_id == end_token:
                break

        output = tf.transpose(output_array.stack())

        text = self.target_tokenizer.tokens_to_text(output)[0]

        # `tf.function` prevents us from using the attention_weights that were
        # calculated on the last iteration of the loop.
        # So, recalculate them outside the loop.
        self.transformer([encoder_input, output[:, :-1]], training=False)
        attention_weights = self.transformer.decoder.last_attn_scores

        return text, output, attention_weights


# Shared summarizer instance; summarize_frame below reads this global.
summarizer = Summarizer(transformer, input_tokenizer, target_tokenizer)

In [None]:
%%time
# (inference benchmark)

# summarizer_single = SummarizerSingleStep(transformer, input_tokenizer, target_tokenizer)
# df_train[:5]['article'].apply(lambda text: summarizer_single(text)[0])

In [None]:
%%time

# df_train[:5]['article'].apply(lambda text: summarizer(summarizer.to_tf(text))[0])

In [None]:
def remove_special_tokens(text):
    """Lowercase *text* and turn the dataset's marker tokens into plain text.

    ``<sos>``/``<eos>`` are removed, ``<unk>`` becomes ``##`` and ``<dot>``
    becomes ``". "``; surrounding whitespace is stripped at the end.
    """
    # Applied in order on the lowercased text.
    substitutions = (
        ("<sos>", ""),
        ("<eos>", ""),
        ("<unk>", "##"),
        ("<dot>", ". "),  # normal syntax with dot at end
    )

    cleaned = text.lower()
    for marker, substitute in substitutions:
        cleaned = cleaned.replace(marker, substitute)
    return cleaned.strip()

def remove_special_tokens_frame(frame: pd.DataFrame):
    """Clean the ``article``, ``summary`` and ``predicted`` columns in place.

    Each column is passed through `remove_special_tokens`. The same frame
    object is modified and returned.
    """
    for column in ('article', 'summary', 'predicted'):
        frame[column] = frame[column].apply(remove_special_tokens)
    return frame
    
# Run summarization inference on entire frame
def summarize_frame(frame):
    """Return a copy of *frame* with a ``predicted`` column of model summaries.

    Every value of the ``article`` column is summarized with the module-level
    ``summarizer``. Progress is printed every 25 rows. The input frame is not
    modified.
    """
    frame = frame.copy()
    total = frame.shape[0]

    predictions = []
    for i, article in enumerate(frame['article']):
        if i % 25 == 0:
            print(f"Summarising ... {i}/{total}")

        # Only the text is needed here; token ids and attention are discarded.
        summarized_tf, _tokens, _attention = summarizer(tf.constant([article]))
        predictions.append(summarizer.from_tf(summarized_tf))

    # Assign all predictions at once, aligned to the frame's own index, instead
    # of writing cell by cell through iloc.
    frame['predicted'] = pd.Series(predictions, index=frame.index, dtype=object)

    return frame

def pretty_summaries(frame):
    """Print article, reference summary and predicted summary for every row.

    Each text is cleaned with `remove_special_tokens` before printing.
    """
    for _, row in frame.iterrows():
        print("\n ------------------")
        print(f"Article  : {remove_special_tokens(row['article'])}")
        for label, column in (("Summary  ", 'summary'), ("Predicted", 'predicted')):
            print(f"\n{label}: {remove_special_tokens(row[column])}")
        print()
        print("------------------")

## Inference on test and val sets

In [None]:
# Start a separate wandb run (job_type 'evaluate') for the evaluation results.
wandb_eval_run = wandb.init(
    project=CONFIG['wandb_project'], 
    config=CONFIG,
    group=CONFIG['wandb_group'], 
    job_type='evaluate'
)

In [None]:
%%time

# Predict on test set and save
# NOTE: only the first 10 test rows are summarized here (df_test[:10]).
print('--- Runing inference on test set ---\n')
test_pred = remove_special_tokens_frame(summarize_frame(df_test[:10]))
test_pred

In [None]:
# Persist the cleaned test predictions to a CSV in the working directory.
test_pred.to_csv('testset_evaluation_data.csv')
print('Saved test set evaluation data.')

In [None]:
%%time

# Predict on validation set and save
# NOTE: only the first 10 validation rows are summarized here (df_val[:10]).
print('--- Runing inference on validation set ---\n')
val_pred = remove_special_tokens_frame(summarize_frame(df_val[:10]))
val_pred

In [None]:
# Persist the cleaned validation predictions to a CSV in the working directory.
val_pred.to_csv('validationset_evaluation_data.csv')
print('Saved validation set evaluation data.')

In [None]:
# Print up to 10 test examples side by side (article / summary / predicted).
print('--- Example test set summaries ---\n')
pretty_summaries(test_pred[:10])

In [None]:
# Print up to 10 validation examples side by side (article / summary / predicted).
print('--- Example validation set summaries ---\n')
pretty_summaries(val_pred[:10])

## Evaluation

In [None]:
# Immediate metrics after inference

rouge = evaluate.load('rouge')


def _rouge_scores(frame):
    """Compute ROUGE between a frame's `summary` (reference) and `predicted` columns.

    The columns are converted to plain lists first: a pandas Series is indexed
    by label, so positional access on it breaks for frames whose index does
    not start at 0.
    """
    return rouge.compute(
        references=frame['summary'].tolist(),
        predictions=frame['predicted'].tolist(),
    )


rouge_test = _rouge_scores(test_pred)
rouge_val = _rouge_scores(val_pred)

# Log both score sets as one table: one row per ROUGE metric, one column per split.
wandb_eval_run.log({
    'rouge_metrics': wandb.Table(dataframe=pd.DataFrame({
        'test': pd.Series(rouge_test),
        'val': pd.Series(rouge_val)
    })),
})

wandb_eval_run.finish()