## Overview

Auto-regressive language generation with models such as Transformers has been an active research area in recent years. In auto-regressive generation, the probability distribution of the token at time step K depends on the tokens the model predicted up to step K-1. For these models, decoding strategies such as beam search, greedy decoding, top-p sampling, and top-k sampling are critical components and largely determine the style and nature of the token generated at each time step K.

## Setup

In [None]:
pip install -q tf-models-official==2.7.0

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from official import nlp
from official.nlp.modeling.ops import sampling_module, beam_search

## Initialize the model hyperparameters

In [None]:
params = {}
params['num_heads'] = 2
params['num_layers'] = 2
params['batch_size'] = 2
params['n_dims'] = 256
params['max_decode_length'] = 4

## Initialize cache

In auto-regressive architectures such as Transformer-based encoder-decoder models, a cache is used for fast sequential decoding. It is a nested dictionary that stores pre-computed hidden states (the keys and values of the self-attention and cross-attention blocks) for every layer.

In [None]:
head_dim = params['n_dims'] // params['num_heads']  # Per-head dimension.
cache = {
    'layer_%d' % layer: {
        'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], head_dim], dtype=tf.float32),
        'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], head_dim], dtype=tf.float32),
    } for layer in range(params['num_layers'])
}

print("cache key shape for layer 1:", cache['layer_1']['k'].shape)
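
To illustrate why such a cache speeds up sequential decoding, here is a minimal plain-Python sketch that is independent of the Model Garden API. The helpers `make_cache` and `write_step` are hypothetical names introduced only for this illustration, and nested lists stand in for tensors. The point is that each decoding step writes only the keys and values for the current position, while entries from earlier steps are reused as-is.

```python
def make_cache(num_layers, batch_size, max_len, num_heads, n_dims):
    """Builds a zero-filled cache shaped [batch, max_len, num_heads, head_dim] per layer."""
    head_dim = n_dims // num_heads

    def zeros():
        return [[[[0.0] * head_dim for _ in range(num_heads)]
                 for _ in range(max_len)]
                for _ in range(batch_size)]

    return {"layer_%d" % layer: {"k": zeros(), "v": zeros()}
            for layer in range(num_layers)}


def write_step(cache, layer, step, keys, values):
    """Stores the keys/values computed at `step`; earlier steps are left untouched."""
    entry = cache["layer_%d" % layer]
    for b in range(len(keys)):
        entry["k"][b][step] = keys[b]
        entry["v"][b][step] = values[b]
    return cache


# Same sizes as the hyperparameters above.
toy_cache = make_cache(num_layers=2, batch_size=2, max_len=4, num_heads=2, n_dims=256)

# Made-up keys/values for step 0, shaped [batch][num_heads][head_dim].
new_k = [[[1.0] * 128 for _ in range(2)] for _ in range(2)]
new_v = [[[2.0] * 128 for _ in range(2)] for _ in range(2)]
write_step(toy_cache, layer=1, step=0, keys=new_k, values=new_v)
```

After the call, only position 0 of `layer_1` holds non-zero values; positions 1 to 3 and all of `layer_0` remain zero until later steps fill them in.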

## Define closure for length normalization if needed

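This optional function normalizes the final scores of generated sequences by their length. With the exponent set to 0.0, as below, the factor is always 1, so the scores are left unchanged.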

In [None]:
def length_norm(length, dtype):
  """Return length normalization factor"""
  return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)
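
To show how the exponent affects the factor, the following plain-Python sketch evaluates the same expression for a few values. The names `length_penalty` and `normalized_score` are hypothetical, and dividing the log-probability by the factor is an assumption based on the common length-penalty convention, not a statement about what `SamplingModule` does internally.

```python
def length_penalty(length, alpha):
    """Same expression as length_norm above, with the exponent exposed as alpha."""
    return ((5.0 + length) / 6.0) ** alpha


def normalized_score(log_prob, length, alpha):
    """Assumed convention: divide the sequence log-probability by the penalty."""
    return log_prob / length_penalty(length, alpha)


no_norm = length_penalty(4, alpha=0.0)    # exponent 0.0: factor is 1 for any length
short = length_penalty(1, alpha=0.6)      # (5 + 1) / 6 = 1, so the factor is 1
longer = length_penalty(7, alpha=1.0)     # (5 + 7) / 6 = 2
score = normalized_score(-4.0, length=7, alpha=1.0)
```

A larger exponent shrinks the magnitude of the (negative) log-probability of long sequences more strongly, which makes them compete better against short ones.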

## Create model_fn

In practice, this will be replaced by an actual model implementation.

In [None]:
probabilities = tf.constant([[[0.3, 0.4, 0.3], [0.3, 0.3, 0.4],
                              [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                            [[0.2, 0.5, 0.3], [0.2, 0.7, 0.1],
                              [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]]])

In [None]:
def model_fn(i):
  return probabilities[:, i, :]

## Initialize symbols_to_logits_fn

In [None]:
def _symbols_to_logits_fn():
  """Calculates logits of the next tokens"""
  def symbols_to_logits_fn(ids, i, temp_cache):
    del ids
    logits = tf.cast(tf.math.log(model_fn(i)), tf.float32)
    return logits, temp_cache
  return symbols_to_logits_fn
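
The function above returns the logarithm of the toy probabilities as logits. As a quick sanity check of why that works, this plain-Python sketch (the `softmax` helper is written here for illustration and is not part of the library) shows that applying a softmax to log-probabilities recovers the original distribution.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = [0.3, 0.4, 0.3]                  # first row of the toy table
logits = [math.log(p) for p in probs]    # what symbols_to_logits_fn returns
recovered = softmax(logits)              # matches probs up to rounding
```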

## Greedy

Greedy decoding selects the token id with the highest probability as the next id, $id_t = \operatorname{argmax}_{w} P(w \mid id_{1:t-1})$, at each time step $t$. The cells below run greedy decoding on the toy model.

In [None]:
greedy_obj = sampling_module.SamplingModule(
    length_normalization_fn = None,
    dtype = tf.float32,
    symbols_to_logits_fn = _symbols_to_logits_fn(),
    vocab_size = 3,
    max_decode_length = params['max_decode_length'],
    eos_id = 10,
    padded_decode = False,  # For TPU
)

In [None]:
ids, _ = greedy_obj.generate(
    initial_ids = tf.constant([9, 1]),
    initial_cache = cache,
)

In [None]:
print("Greedy Decoded Ids:", ids)
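
To make the selection rule concrete, here is a minimal plain-Python sketch that applies the argmax step by step to the same toy probability table. It does not use `SamplingModule`, and the name `greedy_decode` is introduced only for this illustration. Unlike the library call, it returns only the newly decoded ids, without the initial ids.

```python
# Same values as the `probabilities` tensor above: [batch][step][vocab].
TOY_PROBS = [
    [[0.3, 0.4, 0.3], [0.3, 0.3, 0.4], [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
    [[0.2, 0.5, 0.3], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
]


def greedy_decode(step_probs):
    """Picks the most probable token id at every step."""
    return [max(range(len(p)), key=p.__getitem__) for p in step_probs]


decoded = [greedy_decode(row) for row in TOY_PROBS]
print(decoded)  # [[1, 2, 2, 2], [1, 1, 2, 2]]
```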

## top_k sampling

In top_k sampling, only the K most likely next token ids are kept, and the probability mass is redistributed among those K ids.

In [None]:
top_k_obj = sampling_module.SamplingModule(
    length_normalization_fn = length_norm,
    dtype = tf.float32,
    symbols_to_logits_fn = _symbols_to_logits_fn(),
    vocab_size = 3,
    max_decode_length = params['max_decode_length'],
    eos_id = 10,
    sample_temperature = tf.constant(1.0),
    top_k = tf.constant(3),
    padded_decode = False,
    enable_greedy = False,
)

In [None]:
ids, _ = top_k_obj.generate(
    initial_ids = tf.constant([9, 1]),
    initial_cache = cache,
)

In [None]:
print("top-k sampled Ids:", ids)
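
The filtering and renormalization step can be sketched in plain Python as follows. This is an illustration under simple assumptions (ties are broken by the lower token id, and sampling uses Python's `random` module), not the library's implementation; `top_k_filter` and `sample` are hypothetical helper names.

```python
import random


def top_k_filter(probs, k):
    """Keeps the k most likely ids and renormalizes their probabilities."""
    keep = sorted(range(len(probs)), key=lambda i: (-probs[i], i))[:k]
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]


def sample(probs, rng):
    """Draws one token id according to the given distribution."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)  # fixed seed so the sketch is repeatable
filtered = top_k_filter([0.2, 0.5, 0.3], k=2)  # id 0 is dropped
token = sample(filtered, rng)                  # always 1 or 2
```

With `k=2`, the mass of ids 1 and 2 is rescaled from 0.5 and 0.3 to 0.625 and 0.375. Setting `k` equal to the vocabulary size, as in the cell above, leaves the distribution unchanged.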

## top_p sampling

Instead of sampling only from the K most likely token ids, top_p sampling chooses from the smallest possible set of ids whose cumulative probability exceeds the probability p.

In [None]:
top_p_obj = sampling_module.SamplingModule(
    length_normalization_fn = length_norm,
    dtype = tf.float32,
    symbols_to_logits_fn = _symbols_to_logits_fn(),
    vocab_size = 3,
    max_decode_length = params['max_decode_length'],
    eos_id = 10,
    sample_temperature = tf.constant(1.0),
    top_p = tf.constant(0.9),
    padded_decode = False,
    enable_greedy = False,
)

In [None]:
ids, _ = top_p_obj.generate(
    initial_ids = tf.constant([9, 1]),
    initial_cache = cache,
)

In [None]:
print("top-p sampled Ids:", ids)
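
The way the candidate set is built can be sketched in plain Python. This is a minimal illustration, not the library's implementation: `top_p_filter` is a hypothetical helper, and it stops as soon as the cumulative probability reaches or exceeds `p`, which is one reasonable reading of the rule above.

```python
def top_p_filter(probs, p):
    """Keeps the smallest set of most likely ids whose mass reaches p, then renormalizes."""
    order = sorted(range(len(probs)), key=lambda i: (-probs[i], i))
    keep, cumulative = [], 0.0
    for i in order:
        keep.append(i)
        cumulative += probs[i]
        if cumulative >= p:  # assumption: stop once the threshold is reached
            break
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]


two_ids = top_p_filter([0.2, 0.7, 0.1], p=0.8)  # keeps ids 1 and 0, drops id 2
one_id = top_p_filter([0.1, 0.1, 0.8], p=0.5)   # id 2 alone already covers p
```

Unlike top_k, the number of kept ids adapts to the shape of the distribution: a peaked distribution keeps few ids, a flat one keeps many.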

## Beam search decoding

Beam search reduces the risk of missing hidden high-probability token ids by keeping the beam_size most likely hypotheses at each time step and eventually choosing the hypothesis with the highest overall probability.

In [None]:
beam_size = 2

params['batch_size'] = 1
beam_cache = {
    'layer_%d' % layer: {
        'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims'] // params['num_heads']], dtype=tf.float32),
        'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims'] // params['num_heads']], dtype=tf.float32)
    } for layer in range(params['num_layers'])
}

print("cache key shape for layer 1:", beam_cache['layer_1']['k'].shape)

In [None]:
ids, _ = beam_search.sequence_beam_search(
    symbols_to_logits_fn = _symbols_to_logits_fn(),
    initial_ids = tf.constant([9], tf.int32),
    initial_cache = beam_cache,
    vocab_size = 3,
    beam_size = beam_size,
    alpha = 0.6,
    max_decode_length = params['max_decode_length'],
    eos_id = 10,
    padded_decode = False,
    dtype = tf.float32,
)
print("Beam search ids:", ids)
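
Because the toy `model_fn` above ignores previously decoded ids, it cannot show a case where beam search beats greedy decoding. The following plain-Python sketch uses a made-up conditional table (`TOY_TABLE`, invented purely for illustration) to show one. It is a simplified version of the idea, without length normalization or end-of-sequence handling, and does not reflect how `sequence_beam_search` is implemented.

```python
import math

# Hypothetical next-token distributions, keyed by the prefix decoded so far.
TOY_TABLE = {
    (): [0.5, 0.4, 0.1],
    (0,): [0.4, 0.3, 0.3],
    (1,): [0.05, 0.05, 0.9],
    (2,): [0.4, 0.3, 0.3],
}


def toy_next_probs(prefix):
    return TOY_TABLE[tuple(prefix)]


def beam_search(next_probs, beam_size, max_len):
    """Returns the surviving (ids, log_prob) hypotheses, best first."""
    beams = [([], 0.0)]
    for _ in range(max_len):
        # Extend every hypothesis by every token, then keep the best beam_size.
        candidates = [
            (ids + [token], score + math.log(p))
            for ids, score in beams
            for token, p in enumerate(next_probs(ids))
        ]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    return beams


# beam_size=1 is equivalent to greedy decoding.
greedy_ids, greedy_score = beam_search(toy_next_probs, beam_size=1, max_len=2)[0]
beam_ids, beam_score = beam_search(toy_next_probs, beam_size=2, max_len=2)[0]
```

Here greedy decoding commits to id 0 at the first step and ends with `[0, 0]` at probability 0.2, while a beam of size 2 also keeps the prefix `[1]` and finds `[1, 2]` at probability 0.36.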