<a href="https://colab.research.google.com/github/VasilisDrog/deep-machine-learning/blob/master/Assignment_3_Text_generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Assignment 3: Algorithms for text generation

In this assignment, we will explore using trained language models to generate text. In particular, we will work with a recent model called [Generative Pre-trained Transformer, version 2 \(GPT-2\)](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf) that was published in 2019 by OpenAI.

As language models are probabilistic models of text, there are different methods of generating (also known as _decoding_) text strings from the model, as we have seen in [one of the lectures](http://www.cse.chalmers.se/~richajo/dat450/lectures/l7/m7_3.pdf). You will implement some of the most common decoding methods in this assignment, and later reflect on the qualitative aspects of the different methods.

**Note:** It will be important to use a GPU with plenty of memory, such as the ones provided on Colab. Please enable the GPU runtime by going to _Runtime -> Change Runtime type -> GPU_.

**Note:** Implementations of the generation algorithms you code here already exist in the Huggingface library. In a real use case, you would typically just call `generate`. These reimplementations are for pedagogical purposes.  

In [None]:
# Let's start by importing the PyTorch library:

import torch
torch.set_grad_enabled(False) # since we will not be updating any models...

<torch.autograd.grad_mode.set_grad_enabled at 0x7f236348ae10>

## The GPT-2 Language Model

In GPT-2, a _Transformer decoder_ is used to model the conditional probability $P(x_i | x_1, ..., x_{i-1})$, with parameters learned from large quantities of text data. As training big language models is typically very computationally expensive, we will not train our own in this assignment, but use a pre-trained one instead. For this we will need to install a separate package, called `transformers`.

In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


The GPT-2 model comes with its own tokenizer, which we will need to load:

In [None]:
from transformers import GPT2Tokenizer

In [None]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
tokenizer.pad_token = tokenizer.eos_token

In [None]:
input_ids = tokenizer.encode("NLP stands for natural", return_tensors="pt").to("cuda")
input_ids

tensor([[   45, 19930,  6296,   329,  3288]], device='cuda:0')

Like the tokenizers we've seen earlier in the course, it maps a text string to a sequence of tokens (integers) from a fixed-size vocabulary. Note that `input_ids` is two-dimensional: the first dimension is the batch dimension, and the second dimension is the sequence dimension.
The tokenizer can also decode the integers back to the string representation:

In [None]:
tokenizer.decode(input_ids[0])

'NLP stands for natural'

We can now download the trained model. As we will work with a large model in this assignment (several hundred million parameters), using a GPU will _greatly_ speed up predictions.

In [None]:
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2-large").to("cuda").eval()

For your curiosity, you can optionally print the number of parameters in each layer and the total number of parameters:

In [None]:
# total_parameters = 0
# for name, par in model.named_parameters():
#  n_par = 1
#  for d in par.shape:
#    n_par *= d
#  print(f'{name}: {n_par} parameters')
#  total_parameters += n_par
# print(f'Total number of parameters: {total_parameters}')

With the model loaded, let's use it for predicting the next token of our tokenized `input_ids` from above:

In [None]:
predictions = model(input_ids=input_ids)

predictions.logits

tensor([[[ 1.7978,  2.8982,  1.6116,  ..., -5.3930, -5.0412,  0.7757],
         [ 2.0065,  5.6440,  0.4546,  ..., -3.7273, -6.4793,  1.1439],
         [ 2.4830,  2.4556, -1.5678,  ..., -6.9141, -6.2589,  1.0836],
         [ 1.6584,  4.0486, -3.9575,  ..., -6.5006, -7.1734,  1.3958],
         [ 2.0553,  3.9053, -1.5021,  ..., -2.4546, -7.9215,  1.0312]]],
       device='cuda:0')

In [None]:
predictions.logits.shape

torch.Size([1, 5, 50257])

What we get from the model are the unnormalized log probabilities, called _logits_. 

**Self-check:** Look at the shape of `predictions.logits` from above. What do the three dimensions represent? Answer: 1 is the batch size, 5 is the length of the prompt in tokens, and 50257 is the size of the vocabulary.

**Your work:** How can we, from `predictions.logits`, compute the actual probability distribution of the next word in the sequence `NLP stands for natural ____`? The distribution should be over the entire vocabulary, and be valid probabilities that sum to one.

In [None]:
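input1 = predictions.logits[0, -1, :]  # logits at the last position of the only sequence in the batch
next_token_prob = torch.nn.functional.softmax(input1, dim=-1)  # normalize the logits into probabilities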

next_token_prob.shape

torch.Size([50257])
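
The softmax call above can also be written out by hand, which makes it easier to see why the result is a valid probability distribution. The following is a framework-free sketch on a made-up three-token vocabulary; the helper name `softmax` and the logit values are illustrative assumptions, not part of the assignment.

```python
import math

def softmax(logits):
    """Turn a list of unnormalized log probabilities into probabilities."""
    # Subtracting the maximum does not change the result but avoids overflow in exp.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits over a three-token vocabulary.
toy_logits = [2.0, 1.0, -1.0]
toy_probs = softmax(toy_logits)

print(toy_probs)
print(sum(toy_probs))
```

Because softmax preserves the ordering of its inputs, ranking tokens by logit or by probability gives the same result.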

In [None]:
# These tests should pass without modifications
assert next_token_prob.shape == torch.Size([tokenizer.vocab_size])
assert abs(next_token_prob.sum() - 1.0) < 0.01
assert all(next_token_prob >= 0)

**Your work:** Compute the top 5 most probable next tokens, based on the `next_token_prob` distribution. 

**Hint**: the function [`topk`](https://pytorch.org/docs/stable/generated/torch.topk.html) will be useful here.

In [None]:
top_5_next_tokens = torch.topk(next_token_prob, 5).indices  # vocabulary indices of the 5 most probable next tokens
print(top_5_next_tokens)

tensor([ 3303, 15417,  1692,  8950,    12], device='cuda:0')


We can again use `tokenizer.decode` to map the integer-encoded tokens back to strings.

In [None]:
for index in top_5_next_tokens:
  print(f"{tokenizer.decode([index])}")
print()

 language
 Language
 human
 languages
-


We could now decide to, for example, pick the token id with the highest probability, append that to our input, and run through the model again to compute the distribution for the next token again.

**Your work:** Take the highest predicted token from `top_5_next_tokens` and append it to `input_ids`.

**Hint:** The function [`torch.cat`](https://pytorch.org/docs/stable/generated/torch.cat.html) could be useful here.

In [None]:
print(input1[0])  # logit of vocabulary entry 0 at the last position

tensor(2.0553, device='cuda:0')


In [None]:
new_input_ids = torch.cat((input_ids, top_5_next_tokens[0][None, None]), dim=-1 )

To see that the prediction is sensible, you can convert the integer-encoded tensor back into text:

In [None]:
tokenizer.decode(new_input_ids[0])

'NLP stands for natural language'

In [None]:
# These tests should pass without modifications
assert new_input_ids.shape == torch.Size([1, input_ids.shape[1] + 1])
assert new_input_ids[0, -1] == top_5_next_tokens[0]

**Your work:** Like above, compute a new distribution for the next token and print the top 5 most probable next tokens

In [None]:
# Show the text generated so far.
new_decode = tokenizer.decode(new_input_ids[0])
print(new_decode)

# Feed the extended ids back into the model directly instead of re-encoding the decoded text.
input_ids2 = new_input_ids
predictions2 = model(input_ids=input_ids2)
print(input_ids2)

# Distribution over the next token, and its five most probable entries.
input2 = predictions2.logits[0, -1, :]
next_token_prob2 = torch.nn.functional.softmax(input2, dim=-1)
top_5_next_tokens2 = torch.topk(next_token_prob2, 5).indices

for index in top_5_next_tokens2:
  print(f"{tokenizer.decode([index])}")

# Append the most probable token and decode the result.
new_input_ids2 = torch.cat((input_ids2, top_5_next_tokens2[0][None, None]), dim=-1)

tokenizer.decode(new_input_ids2[0])

NLP stands for natural language
tensor([[   45, 19930,  6296,   329,  3288,  3303]], device='cuda:0')
 processing
 understanding
 perception
 parsing
 generation


'NLP stands for natural language processing'

## Generating from a language model

What we just did can be formalized into a general algorithm to generate text from a language model:

1. Start with some text to be _continued_. We will denote this initial text as a _prompt_: $x_1, ..., x_i$
2. Use the language model to compute the next token probabilities: $P(x_{i+1} | x_1, ..., x_i)$
3. Based on the distribution, pick some next token $x_{i+1}$ and append to the input
4. Repeat from step 2 until a stopping criterion is met.
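
The four steps above can be sketched as a short loop. This is a toy illustration only: `toy_model`, `pick_most_probable`, `toy_generate`, and the five-token vocabulary are hypothetical stand-ins, and plain lists take the place of tensors.

```python
def toy_model(tokens):
    """Hypothetical language model: returns next-token probabilities
    over a vocabulary of 5 tokens, favoring (last token + 1) mod 5."""
    vocab_size = 5
    favored = (tokens[-1] + 1) % vocab_size
    return [0.6 if t == favored else 0.1 for t in range(vocab_size)]

def pick_most_probable(probs):
    """One possible choice for step 3: take the most probable token."""
    return max(range(len(probs)), key=lambda t: probs[t])

def toy_generate(prompt, pick_next, is_done, max_new_tokens=10):
    tokens = list(prompt)                  # step 1: start from the prompt
    for _ in range(max_new_tokens):
        if is_done(tokens):                # step 4: check the stopping criterion
            break
        probs = toy_model(tokens)          # step 2: next-token distribution
        tokens.append(pick_next(probs))    # step 3: pick a token and append it
    return tokens

# Stop once token 4 has been produced (a stand-in for "end of sentence").
result = toy_generate([0], pick_most_probable, is_done=lambda toks: toks[-1] == 4)
print(result)
```

Swapping `pick_most_probable` for another function changes the decoding strategy without touching the loop, which is the design the rest of the assignment follows.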

An important decision when generating from language models is what strategy you apply for picking next tokens (step 3). We will implement and experiment with different such strategies and you will in the individual reflection discuss pros and cons of each, and how these differ from each other.

We begin by defining an abstract decoding strategy class, that has a method `step` which takes the `logits` and `input_ids` at some step. `step(...)` returns updated `input_ids` to be used in the next step.

In [None]:
from abc import ABC, abstractmethod

class DecodingStrategy(ABC):

  @abstractmethod
  def step(self, logits, input_ids):
    """
    This method takes next token logits and input_ids and applies some strategy to update the input_ids.
    It returns the updated input ids.

    Args:
      logits:    3d float tensor
      input_ids: 2d int tensor

    Returns:
      next_input_ids: 2d int tensor
    """
    raise NotImplementedError()

Next, we will implement a stopping criterion. In this assignment we will stop when the model has generated X number of sentences. We define sentence boundaries by the period token:

In [None]:
tokenizer.encode(".")

[13]

**Your work:** Implement the following function, that returns the number of completed sentences in each batch sequence in `input_ids`:

In [None]:
def get_num_sentences(input_ids):
  """
  Returns an integer tensor of shape input_ids.shape[0], that tells how many completed 
  sentences there are in each batch sequence
  """
  # A sentence counts as completed when it ends with the period token (id 13, see above).
  period_id = tokenizer.encode(".")[0]
  # Compare every position to the period id and count the matches along the sequence dimension.
  return (input_ids == period_id).sum(dim=-1)


In [None]:
# This test should pass without modification
test_input_ids = tokenizer(["This sequence has zero completed sentences", "Here is one completed sentence. Here is another."], return_tensors="pt", padding=True).input_ids
print(test_input_ids)  # 13 is the period token, 50256 is padding

assert torch.equal(get_num_sentences(test_input_ids), torch.tensor([0, 2]))

tensor([[ 1212,  8379,   468,  6632,  5668, 13439, 50256, 50256, 50256, 50256],
        [ 4342,   318,   530,  5668,  6827,    13,  3423,   318,  1194,    13]])


**Your work:** Implement the stopping criterion function below, which returns a boolean vector indicating whether each batch sequence has at least `n` sentences. Use the `get_num_sentences` function from above.

In [None]:
def has_n_sentences(input_ids, n):
  # WRITE CODE HERE
  
  raise NotImplementedError()

In [None]:
# This test should pass without modification
assert torch.equal(
    has_n_sentences(test_input_ids, n=2), 
    torch.tensor([False, True])
)

NotImplementedError: ignored

Using a prompt and some strategy, we can implement the generation algorithm. The generation stops when all sequences in the batch are done (according to the stopping criterion), or when a maximum generation length is reached. 

In [None]:
from IPython.display import clear_output

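def generate(prompt, strategy, stopping_criterion, max_length=100, print_output=True):
  """
  Generates a continuation of `prompt`, one token per step.

  Args:
    prompt:             text to be continued
    strategy:           DecodingStrategy whose `step` appends the next token(s)
    stopping_criterion: function mapping input_ids to a boolean tensor, one entry per sequence
    max_length:         maximum number of tokens to generate after the prompt
    print_output:       if True, print the text generated so far after every step

  Returns:
    input_ids: 2d int tensor holding the prompt followed by the generated tokens
  """

  # Step 1: Encode the prompt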
  encoded_prompt = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
  input_ids = encoded_prompt

  while not torch.all(stopping_criterion(input_ids)) and input_ids.shape[1] < encoded_prompt.shape[1] + max_length:
    # Step 2: Get next token logits
    predictions = model(input_ids=input_ids)

    # Step 3: Apply decoding strategy to update input_ids
    input_ids = strategy.step(predictions.logits, input_ids)

    # Print generated string(s) so far
    if print_output:
      clear_output()
      for batch_idx in range(input_ids.shape[0]):
        print(tokenizer.decode(input_ids[batch_idx], skip_special_tokens=True))
        print("----------------------------------------------------------------")

  return input_ids

To test our generation algorithm, we can implement a dummy strategy, that disregards the logits, and just picks a random token from the vocabulary as next token.

In [None]:
class DummyStrategy(DecodingStrategy):
  def step(self, logits, input_ids):
    next_tokens = torch.randint(low=0, high=tokenizer.vocab_size, size=[input_ids.shape[0]]).to(input_ids.device)
    new_input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    return new_input_ids


In [None]:
from functools import partial
dummy_strategy = DummyStrategy()
stopping_criterion = partial(has_n_sentences, n=2)  # Returns a new function where argument n is set to 2

_ = generate(
    prompt="NLP stands for natural", 
    strategy=dummy_strategy, 
    stopping_criterion=stopping_criterion,
    max_length=20
)

As expected, the generated strings are just rubbish.

Generating from a language model can be seen as a search problem, where all possible token strings span a large search tree.

![picture](https://huggingface.co/blog/assets/02_how-to-generate/greedy_search.png)

A good idea is to run a search to find the string that is the _most probable_ under the language model, i.e:

$\DeclareMathOperator*{\argmax}{argmax}$

\begin{align}
  x_{i+1}^*, ..., x_n^* &=  \argmax_{x_{i+1}, ..., x_n}  P(x_{i+1}, ..., x_n | x_1, ..., x_i) \\
  &= \argmax_{x_{i+1}, ..., x_n} \prod_{i'=i}^{n-1} P(x_{i'+1} | x_1, ..., x_{i'})
\end{align}

However, assuming we cannot run a brute force search, what search algorithms are there that we can apply?


## Greedy decoding

What we did in the beginning, i.e. picking the most probable token, is known as the _greedy_ decoding strategy. That means we approximate the argmax by taking the most probable token at each step. The algorithm is described conceptually on slides 6-14 in [the lecture](http://www.cse.chalmers.se/~richajo/dat450/lectures/l7/m7_3.pdf), but please keep in mind that the pseudocode given in the lecture generates just a single text.

**Your work:** Implement the greedy strategy in the class skeleton below:

In [None]:
class GreedyStrategy(DecodingStrategy):

  def step(self, logits, input_ids):
    # WRITE CODE HERE
    # next_input_ids = ...
    raise NotImplementedError()
    return next_input_ids


In [None]:
# This test should pass without modification
greedy_strategy = GreedyStrategy()
test_input_ids = torch.tensor([[1, 2, 3]])
test_logits = torch.tensor([[[0.1, 0.1, 0.1, 0.1, 0.6],
                             [0.1, 0.1, 0.1, 0.6, 0.1],
                             [0.1, 0.1, 0.6, 0.1, 0.1]]])
test_new_input_ids = greedy_strategy.step(test_logits, test_input_ids)
assert torch.equal(test_new_input_ids, torch.tensor([[1, 2, 3, 2]]))
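
To see why the test above expects `[[1, 2, 3, 2]]`: only the logits at the last position predict the next token, and their largest entry sits at index 2. The list-based sketch below reproduces this reasoning; `greedy_step` is a hypothetical helper that uses nested lists in place of tensors, not the class you are asked to implement.

```python
def greedy_step(logits, input_ids):
    """logits: [batch][position][vocab] nested lists; input_ids: [batch][position]."""
    next_input_ids = []
    for seq_logits, seq_ids in zip(logits, input_ids):
        last = seq_logits[-1]                                  # logits for the next token
        best = max(range(len(last)), key=lambda t: last[t])    # argmax over the vocabulary
        next_input_ids.append(seq_ids + [best])                # append it to the sequence
    return next_input_ids

toy_ids = [[1, 2, 3]]
toy_logits = [[[0.1, 0.1, 0.1, 0.1, 0.6],
               [0.1, 0.1, 0.1, 0.6, 0.1],
               [0.1, 0.1, 0.6, 0.1, 0.1]]]
print(greedy_step(toy_logits, toy_ids))  # [[1, 2, 3, 2]]
```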

Now, try generating some text using this strategy:

In [None]:
generated_ids = generate("NLP stands for natural", greedy_strategy, stopping_criterion=partial(has_n_sentences, n=2))

We will now implement a method to compute the log probability of the generated string, i.e. $\log P(x_{i+1}, ..., x_n | x_1, ..., x_i)$.

In [None]:
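def get_joint_log_probability(logits, input_ids):
  # The logits at position t predict the token at position t+1, so the labels are shifted by one.
  labels = input_ids[:, 1:].clone().reshape(-1)
  # cross_entropy ignores labels equal to -100, so padding does not contribute.
  labels[labels == tokenizer.pad_token_id] = -100
  logits = logits[:, :-1, :].reshape(-1, logits.shape[-1])
  # With reduction="none", cross_entropy returns the negative log probability of each label.
  neg_log_probs = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
  neg_log_probs = neg_log_probs.reshape(input_ids.shape[0], -1)
  # Negate and sum over the sequence to get one joint log probability per batch entry.
  return -neg_log_probs.sum(-1)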
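
Written out by hand, this quantity is the sum of the log probabilities that the model assigns to each token given its predecessors. The sketch below does the same computation on nested lists for a single sequence; `log_softmax` and `joint_log_probability` are hypothetical helper names, and padding is left out. Like the function above, it sums over every position after the first token, so the prompt tokens contribute as well. When two decoding strategies start from the same prompt, this adds the same constant to both scores and does not affect the comparison.

```python
import math

def log_softmax(logits):
    """Log probabilities from logits, computed stably."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def joint_log_probability(seq_logits, seq_ids):
    """Sum of log P(x_{t+1} | x_1..x_t) for one sequence.

    seq_logits[t] holds the logits predicted after reading seq_ids[: t + 1],
    so it is scored against the token that actually follows, seq_ids[t + 1].
    """
    total = 0.0
    for t in range(len(seq_ids) - 1):
        total += log_softmax(seq_logits[t])[seq_ids[t + 1]]
    return total

# Two-token vocabulary; uniform logits give probability 0.5 at every step.
toy_ids = [0, 1, 1]
toy_logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
print(joint_log_probability(toy_logits, toy_ids))  # 2 * log(0.5)
```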

Let's compute the joint log probability of the generated text:

In [None]:
predictions = model(input_ids=generated_ids)
greedy_joint_logprob = get_joint_log_probability(predictions.logits, generated_ids)
print('Joint log probability of the text using greedy search:', greedy_joint_logprob[0].item())

The higher this value is, the more likely the generated string is, under the language model. We will compare this value to the corresponding value for our next decoding strategy, which is **beam search**.

## Beam search

While greedy search finds strings that have high probability under the model, it often makes decisions that are optimal locally but suboptimal overall: choosing a lower-probability word at one step might yield a greater joint probability in the end.
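
A small made-up example shows how a greedy choice can miss the most probable string. The two-step distribution below is entirely hypothetical and chosen so that the best first token is not the start of the best two-token string.

```python
import itertools

# Hypothetical two-step model over the tokens "A" and "B".
first = {"A": 0.6, "B": 0.4}
second = {
    "A": {"A": 0.55, "B": 0.45},   # after "A" the model is fairly undecided
    "B": {"A": 0.9, "B": 0.1},     # after "B" it is confident
}

def joint_prob(seq):
    """Probability of a two-token string under the toy model."""
    return first[seq[0]] * second[seq[0]][seq[1]]

# Greedy: take the most probable token at each step.
g1 = max(first, key=first.get)
g2 = max(second[g1], key=second[g1].get)
greedy_seq = (g1, g2)

# Exhaustive search over all four two-token strings.
best_seq = max(itertools.product("AB", repeat=2), key=joint_prob)

print(greedy_seq, joint_prob(greedy_seq))
print(best_seq, joint_prob(best_seq))
```

Greedy commits to "A" after the first step and can no longer reach the string that starts with "B". A search that keeps both first tokens alive for one more step would find it.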

In [beam search](https://en.wikipedia.org/wiki/Beam_search), we run multiple search _alternatives_ (beams) in parallel and at each step, we select the $k$ most probable alternatives to pass on to the next step. Conceptually, this algorithm has been described in slides 18-20 of [the lecture](http://www.cse.chalmers.se/~richajo/dat450/lectures/l7/m7_3.pdf), but our code will differ a bit from the conceptual pseudocode because of PyTorch technicalities and because of the stopping criterion.

**Your work:** Implement the beam search strategy in the skeleton below. You will find a comment `# WRITE CODE HERE` where you are expected to add your own code.

In [None]:
class BeamSearchStrategy(DecodingStrategy):
  def __init__(self, num_beams: int, stopping_criterion):
    self.num_beams = num_beams
    self.stopping_criterion = stopping_criterion

  def step(self, logits, input_ids):
    # Let's define some auxiliary variables we will use in sanity checks.
    n_beams, n_tokens, voc_size = logits.shape

    # *YOUR WORK*: Compute log prob for the beams from the previous step
    # The result is a tensor of shape n_beams.
    # TODO: student code here
    log_probs = # WRITE CODE HERE
    assert(log_probs.shape == torch.Size([n_beams]))

    # Apply the stopping criterion to see which beams are finished.
    # The result is a boolean tensor of shape n_beams.
    is_finished = self.stopping_criterion(input_ids)    
    is_not_finished = ~is_finished

    # Select the beams that are finished and unfinished, respectively.
    finished_ids = input_ids[is_finished]
    unfinished_ids = input_ids[is_not_finished]
    n_unfinished = unfinished_ids.shape[0]

    # ... and the log probabilities for the finished and unfinished beams.
    finished_log_probs = log_probs[is_finished]
    unfinished_log_probs = log_probs[is_not_finished]

    # *YOUR WORK*: First, convert the logits for the next token prediction into log probabilities.
    # *HINT*: You can use log_softmax for this.
    log_probs_next_token =  # WRITE CODE HERE
    assert(log_probs_next_token.shape == torch.Size([n_unfinished, voc_size]))

    # *YOUR WORK*: Then, add the next token log probabilities to the log probabilities for the
    # previous unfinished beams.
    #
    # *HINT*: This requires a PyTorch tensor trick: what we want to do is to add a beam
    # log-probability to *each* next token log-probability for this beam.
    # The shape of unfinished_log_probs is [n_unfinished] while the shape of 
    # log_probs_next_token is [n_unfinished, voc_size].
    # To do this, view unfinished_log_probs as a tensor of shape [n_unfinished, 1]
    # by writing as follows: unfinished_log_probs[:, None]
    # When both tensors are 2-dimensional, they can be summed: in PyTorch, if we add
    # a tensor of shape [m, n] to one of shape [m, 1], the second tensor will be
    # treated as if it were of shape [m, n] as well (with all rows copied).
    # 
    log_probs_beams_expanded = # WRITE CODE HERE
    assert(log_probs_beams_expanded.shape == torch.Size([n_unfinished, voc_size]))

    # *YOUR WORK*: Now, sort the log probabilities for the expanded beams in descending order.
    # *HINT*: first flatten the tensor so that it has the shape n_unfinished*voc_size.
    # *HINT*: PyTorch has a built-in sort function that you can read about here:
    # https://pytorch.org/docs/stable/generated/torch.sort.html#torch.sort
    expanded_sorted = # WRITE CODE HERE
    assert(expanded_sorted.values.shape == torch.Size([n_unfinished*voc_size]))

    # Here, take some time to understand what was returned by the sorting function.
    # This function returns two tensors, one (.values) containing the sorted values and 
    # another (.indices) containing the indices of the original positions of what was sorted.

    # We will now carry out the step to compute the updated beam.
    # 
    next_unfinished_idx = 0
    next_finished_idx = 0

    # This list will keep the selected beams.
    beams = []

    # If we select the finished beams, we will have to add some padding.
    padding = torch.tensor([tokenizer.pad_token_id], device=logits.device)

    for i in range(self.num_beams):
      # We will now select beam i for the next step.
      # To do this, we compare the best finished beam from the previous step to
      # the best of the expanded unfinished beams, and select the best of those two.
      # (We also have to check whether we are out of finished beams.)
      if next_finished_idx >= finished_log_probs.shape[0] \
         or expanded_sorted.values[next_unfinished_idx] > finished_log_probs[next_finished_idx]:
        # We select the next best unfinished beam:

        # First, we compute the index among the unfinished beams of the 
        # highest-scoring candidate.
        seq_idx = torch.div(expanded_sorted.indices[next_unfinished_idx], logits.shape[-1], rounding_mode="floor")

        # Next, we compute the index in the vocabulary of the highest-scoring candidate.
        next_token = expanded_sorted.indices[next_unfinished_idx] % logits.shape[-1]

        # *YOUR WORK*: create a tensor next_beam where you add the next token id
        # to the corresponding beam from the previous step.
        # *HINT*: next_token is an integer while the previous beam is 1-dimensional.
        # You may use the trick [None] as above to make next_token 1-dimensional.
        next_beam = # WRITE CODE HERE
        assert(next_beam.shape == torch.Size([n_tokens+1]))

        next_unfinished_idx += 1
      else:
        # We select the next best previously finished beam:

        # *YOUR WORK*: create a tensor next_beam where you add padding to the
        # beam from the previous step.
        next_beam = # WRITE CODE HERE
        assert(next_beam.shape == torch.Size([n_tokens+1]))

        next_finished_idx += 1

      # Add the current beam to the list of selected beams.
      beams.append(next_beam)      
    
    # *YOUR WORK*: Finally, concatenate all beams into a tensor and return it.
    # The function torch.stack is probably going to be useful.
    # https://pytorch.org/docs/stable/generated/torch.stack.html
    next_input_ids = # WRITE CODE HERE

    assert(next_input_ids.shape == torch.Size([self.num_beams, n_tokens+1]))    
    return next_input_ids
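
The core bookkeeping in `step` (expanding every beam with every token, ranking all candidates jointly, and recovering beam and token from a flat index) can be illustrated without tensors. The sketch below is a simplification under stated assumptions: it leaves out finished beams and padding, uses plain lists, and `toy_beam_step` with its inputs is hypothetical rather than a solution to the skeleton.

```python
import math

def toy_beam_step(beams, beam_log_probs, next_token_log_probs, num_beams):
    """One simplified beam-search step on plain lists.

    beams:                list of token-id lists, one per beam
    beam_log_probs:       log probability of each beam so far
    next_token_log_probs: [beam][vocab] log probabilities of the next token
    """
    voc_size = len(next_token_log_probs[0])

    # Add each beam's score to every one of its next-token scores and
    # flatten the result into n_beams * voc_size candidates.
    flat = [beam_log_probs[b] + next_token_log_probs[b][t]
            for b in range(len(beams)) for t in range(voc_size)]

    # Sort candidate indices by score, best first.
    order = sorted(range(len(flat)), key=lambda i: flat[i], reverse=True)

    new_beams, new_scores = [], []
    for flat_idx in order[:num_beams]:
        # Undo the flattening: quotient is the beam, remainder is the token.
        seq_idx, next_token = divmod(flat_idx, voc_size)
        new_beams.append(beams[seq_idx] + [next_token])
        new_scores.append(flat[flat_idx])
    return new_beams, new_scores

beams = [[7], [8]]
scores = [math.log(0.6), math.log(0.4)]
next_lp = [[math.log(0.55), math.log(0.45)],   # continuations of beam [7]
           [math.log(0.9), math.log(0.1)]]     # continuations of beam [8]
new_beams, new_scores = toy_beam_step(beams, scores, next_lp, num_beams=2)
print(new_beams)
```

The `divmod` call plays the role of the floor division and modulo by the vocabulary size in the skeleton.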


The following cell tests your code by carrying out one step of the beam search. The result should be a tensor of shape (5, 3). The generated texts will also be printed.

In [None]:
# This test should pass without modification.

# We assume that the result from the previous step has the shape [3, 2].
# The third of them ends with a period so we will consider this to be "finished".
test_beams = tokenizer(['This is', 'That is', 'End.'], return_tensors='pt').input_ids.to(model.device)

# Apply the model to compute the logits for the next tokens.
test_logits = model(test_beams).logits

# We will use a beam search with width 5 and a stopping criterion that finishes after one sentence.
beam_strategy = BeamSearchStrategy(num_beams=5, stopping_criterion=partial(has_n_sentences, n=1))

# Apply one step of the beam search.
new_beams = beam_strategy.step(test_logits, test_beams)

# The result should have 5 rows (because num_beams is 5) and 3 columns (because we added one column).
assert(new_beams.shape == torch.Size([5, 3]))

# Finally, print the result:
for beam in new_beams:
  print(tokenizer.decode(beam))

Now, let us finally use this to generate running text using beam search:

In [None]:
stopping_criterion = partial(has_n_sentences, n=2)
beam_strategy = BeamSearchStrategy(num_beams=5, stopping_criterion=stopping_criterion)
generated_ids = generate("NLP stands for natural", beam_strategy, stopping_criterion=stopping_criterion)

We compute the joint log probability again:

In [None]:
predictions = model(input_ids=generated_ids)
beamsearch_joint_logprob = get_joint_log_probability(predictions.logits, generated_ids)

print('Joint log probability of the text using beam search:', beamsearch_joint_logprob[0].item())

## Investigating longer texts

When we tested greedy and beam search decoding above, we used a stopping criterion that terminates the generation when two sentences have been produced.

Let us see what happens when we generate longer text. Set the number of generated sentences to a larger value and generate again using beam search and greedy decoding and see if you can make any observation about the behavior.

(This will be discussed in the individual reflection.)

In [None]:
# WRITE CODE HERE

## Random sampling

Rather than searching for the most probable string, we could simply sample from the next-token distribution.

**Hint:** To sample from a given discrete distribution in PyTorch, you can build a [`Categorical`](https://pytorch.org/docs/stable/distributions.html#torch.distributions.categorical.Categorical) distribution and then use that to generate random numbers by calling the method `sample`.
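
Conceptually, sampling from a categorical distribution means drawing token $t$ with probability $p_t$. The standard-library sketch below illustrates this, with `random.choices` standing in for the `Categorical` distribution mentioned in the hint; `sample_token` and the probabilities are made up for illustration.

```python
import random
from collections import Counter

def sample_token(probs, rng):
    """Draw one token id with probability proportional to probs."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

rng = random.Random(0)           # seeded only to make the sketch repeatable
toy_probs = [0.7, 0.2, 0.1]      # hypothetical next-token distribution
draws = Counter(sample_token(toy_probs, rng) for _ in range(10_000))

# Empirical frequencies should be close to the underlying probabilities.
for token, p in enumerate(toy_probs):
    print(token, p, draws[token] / 10_000)
```

Unlike greedy decoding, repeated runs can pick different tokens, including unlikely ones.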

**Your work:** Implement the random sampling strategy below:

In [None]:
from torch.distributions import Categorical

class RandomSamplingStrategy(DecodingStrategy):

  def step(self, logits, input_ids):
    # Let's define some auxiliary variables we will use in sanity checks.
    batch_size, n_tokens, voc_size = logits.shape    

    # *YOUR WORK*: Select the logits for the next token.
    next_token_logits = # WRITE CODE HERE
    assert(next_token_logits.shape == torch.Size([batch_size, voc_size]))    

    # *YOUR WORK*: Select the next tokens randomly from the distribution
    # defined by next_token_logits.
    next_tokens = # WRITE CODE HERE
    assert(next_tokens.shape == torch.Size([batch_size]))

    # *YOUR WORK*: Add the new tokens to the previous input_ids.
    next_input_ids = # WRITE CODE HERE
    assert(next_input_ids.shape == torch.Size([batch_size, n_tokens+1]))
    return next_input_ids

Let's apply the random sampling strategy:

In [None]:
random_strategy = RandomSamplingStrategy()
generated_ids = generate("NLP stands for natural", random_strategy, stopping_criterion=partial(has_n_sentences, n=2))
assert(generated_ids.shape[0] == 1)

## Top-_k_ sampling

We can think of strategies that are a "middle ground" between maximum probability strategies, and random sampling. One such example is the **top-k** sampling strategy. In this strategy, we sample from the _top-k_ most probable next tokens. This means we normalize the probabilities of the k most probable next tokens, and sample from this new distribution.
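
The renormalization step can be made concrete with a small example. The helper `top_k_distribution` and the probabilities below are hypothetical, and plain lists are used in place of tensors.

```python
def top_k_distribution(probs, k):
    """Keep the k most probable tokens and renormalize so they sum to one."""
    top = sorted(range(len(probs)), key=lambda t: probs[t], reverse=True)[:k]
    mass = sum(probs[t] for t in top)
    return {t: probs[t] / mass for t in top}

toy_probs = [0.5, 0.3, 0.1, 0.06, 0.04]   # hypothetical next-token distribution
top2 = top_k_distribution(toy_probs, k=2)
print(top2)
```

With `k` equal to the vocabulary size this reduces to plain random sampling, and with `k=1` it reduces to greedy decoding.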

**Your work:** Implement the top-k sampling strategy below:

In [None]:
class TopKSamplingStrategy(DecodingStrategy):

  def __init__(self, k: int):
    self.k = k

  def step(self, logits, input_ids):
    # Let's define some auxiliary variables we will use in sanity checks.
    batch_size, n_tokens, voc_size = logits.shape    

    # *YOUR WORK*: Select the logits for the next token.
    next_token_logits = # WRITE CODE HERE
    assert(next_token_logits.shape == torch.Size([batch_size, voc_size]))

    # *YOUR WORK*: Now, select the top k alternatives for every item in the batch.
    # *Hint*: probably easiest to use the function topk here:
    # https://pytorch.org/docs/stable/generated/torch.topk.html
    topk = # WRITE CODE HERE
    assert(topk.values.shape == torch.Size([batch_size, self.k]))

    # *YOUR WORK*: Sample from among the top k candidates you found in the 
    # previous step.
    index_in_topk = # WRITE CODE HERE
    assert(index_in_topk.shape == torch.Size([batch_size]))

    # By calling torch.gather, we can map the index in the top-k list back to 
    # the index of the vocabulary.
    next_tokens = torch.gather(topk.indices, 1, index_in_topk[:, None])
    assert(next_tokens.shape == torch.Size([batch_size, 1]))

    # *YOUR WORK*: Concatenate the new generated tokens to the previous input_ids.
    next_input_ids = # WRITE CODE HERE
    assert(next_input_ids.shape == torch.Size([batch_size, n_tokens+1]))

    return next_input_ids

We can now use the top-$k$ sampling strategy to generate text. How do you think this compares to the previous decoding strategies?

In [None]:
top_k_strategy = TopKSamplingStrategy(k=5)
generated_ids = generate("NLP stands for natural", top_k_strategy, stopping_criterion=partial(has_n_sentences, n=4))

## Putting it all together

We have now implemented 4 different decoding strategies. Let's put them side by side to compare them more easily.


In [None]:
prompt = "NLP stands for natural"

stopping_criterion = partial(has_n_sentences, n=3)

greedy_strategy = GreedyStrategy()
beam_strategy = BeamSearchStrategy(num_beams=5, stopping_criterion=stopping_criterion)
random_strategy = RandomSamplingStrategy()
top_k_strategy = TopKSamplingStrategy(k=5)

print("Greedy:")
print("-------")
generated_ids = generate(prompt, greedy_strategy, stopping_criterion=stopping_criterion, print_output=False)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
print()
print()
print(f"Beam search ({beam_strategy.num_beams} beams):")
print("----------------------")
generated_ids = generate(prompt, beam_strategy, stopping_criterion=stopping_criterion, print_output=False)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True ))
print()
print()
print("Random sampling:")
print("----------------")
generated_ids = generate(prompt, random_strategy, stopping_criterion=stopping_criterion, print_output=False)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True ))
print()
print()
print(f"Top-k sampling (k={top_k_strategy.k}):")
print("---------------------")
generated_ids = generate(prompt, top_k_strategy, stopping_criterion=stopping_criterion, print_output=False)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True ))
print()
print()

**Your work:**
Play around with different settings for sequence length (e.g. how many sentences to generate), number of beams, and $k$ to get a feeling of how the algorithms behave. Also try modifying the prompt to something of your choosing.