In [None]:
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
from torch import Tensor
from transformers import GPT2Tokenizer, GPT2LMHeadModel, PreTrainedTokenizer

In [None]:
model_type = 'gpt2'
# NOTE: this downloads the pretrained weights on first use (~500 MB for the 124M-parameter 'gpt2' checkpoint)
model = GPT2LMHeadModel.from_pretrained(model_type)
tokenizer = GPT2Tokenizer.from_pretrained(model_type)

# the random / top-k / top-p strategies below draw tokens with torch's RNG; seeding once here makes a
# fresh "Restart & Run All" reproduce the same samples (on the same hardware and library versions)
SEED = 0
torch.manual_seed(SEED)

Let's look at a generic process for generating text from our language model that should work regardless of which decoding strategy we use. We start with an input prompt provided by a user. The prompt is fed through the model to obtain a probability for every token in the model's vocabulary, i.e. a prediction of what the next token in the sequence should be. A token is then selected according to these predicted probabilities by some decoding strategy. The predicted token is appended to the prompt, and the updated sequence is fed to the model again to predict another token. This process is repeated until a pre-defined maximum number of steps is reached or until an end-of-sentence (EOS) token is predicted.

(Note that Hugging Face's `transformers` library already provides all of this functionality with better code and interfaces, but it's a fun learning experience to implement these methods from scratch.)

In [None]:
class TextGenerator:
    """Generate text autoregressively from a causal language model with a pluggable decoding strategy.

    At every step the current token sequence is run through ``model``. The logits of the last
    position are temperature-scaled and turned into a probability distribution, and
    ``decoding_strategy.sample`` picks the next token id from it. The chosen token is appended to
    the sequence, and the loop repeats until ``max_steps`` tokens have been added or the tokenizer's
    EOS token is picked.
    """

    def __init__(self, model: nn.Module, tokenizer: PreTrainedTokenizer, decoding_strategy: 'DecodingStrategy') -> None:
        # our huggingface transformer model; its forward pass returns an object with next-token `.logits`
        self.model = model
        # our tokenizer for the model we're using (encodes the prompt, decodes the final output)
        self.tokenizer = tokenizer
        # our decoding strategy; `sample(probs)` must return a single-element tensor holding the next token id
        self.decoding_strategy = decoding_strategy

        # the index in the logits that corresponds to our end of sentence token
        self.eos_idx = tokenizer.eos_token_id
        # the maximum timesteps allowed by our model for one input sequence (e.g. 1024 for GPT2)
        self.max_input_length = model.config.max_position_embeddings

    def generate(self, prompt: str, max_steps: int = 100, temperature: float = 1.0) -> str:
        """Continue ``prompt`` by up to ``max_steps`` tokens and return the decoded text.

        Args:
            prompt: text to condition on; must encode to at least one token.
            max_steps: maximum number of new tokens to generate.
            temperature: divisor applied to the logits before the softmax. Values below 1 sharpen
                the distribution, and values above 1 flatten it. Must be positive.

        Returns:
            The prompt followed by the generated tokens, decoded by the tokenizer. A selected EOS
            token ends generation and is not included in the output.

        Raises:
            ValueError: if ``temperature`` is not positive or ``prompt`` encodes to no tokens.
        """
        # dividing the logits by 0 yields NaNs; a negative value would invert the distribution
        if temperature <= 0:
            raise ValueError(f'temperature must be positive, got {temperature}')

        # convert an input string into token IDs, shape (1, prompt_length)
        x = self.tokenizer(prompt, return_tensors='pt')['input_ids']
        # the model needs at least one position to read the "last token" logits from
        if x.size(1) == 0:
            raise ValueError('prompt must encode to at least one token')

        # keep a full record of prompt and all predicted tokens to construct the final output
        running_out = [x]
        max_length_warning_triggered = False

        # pure inference: no autograd graph is needed, and building one wastes memory/time every step
        with torch.no_grad():
            for _ in range(max_steps):
                # truncate input sequence if it violates maximum length by taking N most recent tokens
                if x.size(1) > self.max_input_length:
                    # only want to issue this warning once
                    if not max_length_warning_triggered:
                        warnings.warn(f'Max input length for model exceeded, using most recent {self.max_input_length} tokens for remaining steps.')
                        max_length_warning_triggered = True
                    x = x[:, x.size(1) - self.max_input_length:]

                # batch size of 1: take logits from the final position to predict the next token
                logits = self.model(x).logits[0, -1]
                # temperature scaling, then convert to a probability distribution
                probs = F.softmax(logits / temperature, dim=0)

                # run decoding strategy to select predicted next token
                token_idx = self.decoding_strategy.sample(probs)

                # end sampling if we select the end of sentence token
                if token_idx.item() == self.eos_idx:
                    break

                next_token = token_idx.view(1, 1)
                running_out.append(next_token)
                # extending the current (possibly truncated) window is enough; the check above re-truncates
                x = torch.cat((x, next_token), dim=1)

        # convert input prompt and predicted token IDs into final output string
        return self.tokenizer.decode(torch.cat(running_out, dim=1)[0])

Now we define a simple interface that our decoding strategies will adhere to.

In [None]:
class DecodingStrategy(ABC):
    """Interface for a rule that picks the next token from a next-token probability distribution.

    Concrete strategies override ``sample``. It receives the probabilities and returns a tensor
    holding the id of the chosen token.
    """

    @abstractmethod
    def sample(self, probs: Tensor) -> Tensor:
        """Select a token given the next-token probabilities ``probs``."""
        raise NotImplementedError()

### Greedy sampling

In [None]:
class GreedyDecoding(DecodingStrategy):
    """Greedy decoding: always continue with the single most probable token."""

    def sample(self, probs: Tensor) -> Tensor:
        # position of the largest probability, returned as a 0-d index tensor
        return torch.argmax(probs)


# Greedy: the generator always continues with the most likely next token.
decoding_strategy = GreedyDecoding()
gen = TextGenerator(model=model, tokenizer=tokenizer, decoding_strategy=decoding_strategy)
print(gen.generate(prompt='A layer of ice; it feels rough against my face, but not cold.'))

### Random sampling

In [None]:
class RandomDecoding(DecodingStrategy):
    """Pure random sampling: draw the next token from the full predicted distribution."""

    def sample(self, probs: Tensor) -> Tensor:
        # a single weighted draw over the whole vocabulary; result has shape (1,)
        return torch.multinomial(probs, num_samples=1)


# Random: every token can be drawn, weighted by its predicted probability.
decoding_strategy = RandomDecoding()
gen = TextGenerator(model=model, tokenizer=tokenizer, decoding_strategy=decoding_strategy)
print(gen.generate(prompt='A layer of ice; it feels rough against my face, but not cold.'))

### Top K sampling

In [None]:
class TopKDecoding(DecodingStrategy):
    """Top-K sampling: sample only among the ``k`` most probable tokens.

    The probabilities of the K best tokens are renormalised to sum to 1, and one token is drawn
    from that restricted distribution. The caller's ``probs`` tensor is left untouched.
    """

    def __init__(self, k: int) -> None:
        # k < 1 would leave no candidates (all-zero weights -> NaN after renormalising)
        if k < 1:
            raise ValueError(f'k must be a positive integer, got {k}')
        # number of highest-probability tokens kept as sampling candidates
        self.k = k

    def sample(self, probs: Tensor) -> Tensor:
        """Draw one token id from the top-K restricted distribution.

        Args:
            probs: next-token probabilities over the vocabulary (a 1-D tensor, as produced by
                ``TextGenerator``).

        Returns:
            A tensor of shape (1,) holding the sampled token id.
        """
        # k larger than the vocabulary simply keeps every token
        k = min(self.k, probs.numel())
        # topk returns new tensors, so the caller's probs are not modified
        top_probs, top_indices = torch.topk(probs, k)
        # sample a position within the top K (renormalised), then map it back to a vocabulary id
        choice = torch.multinomial(top_probs / top_probs.sum(), num_samples=1)
        return top_indices[choice]


# Top-K: sample only among the 100 most likely next tokens.
decoding_strategy = TopKDecoding(k=100)
gen = TextGenerator(model=model, tokenizer=tokenizer, decoding_strategy=decoding_strategy)
print(gen.generate(prompt='A layer of ice; it feels rough against my face, but not cold.'))

### Top P sampling (a.k.a. Nucleus sampling)

In [None]:
class TopPDecoding(DecodingStrategy):
    """Top-P (nucleus) sampling.

    code ref: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317

    Tokens are ranked by descending probability. The shortest prefix whose cumulative probability
    exceeds ``p`` is kept; this includes the token that crosses the threshold and always includes
    the most probable token. That set is renormalised and one token is drawn from it. The caller's
    ``probs`` tensor is left untouched.
    """

    def __init__(self, p: float) -> None:
        # cumulative-probability threshold defining the nucleus
        self.p = p

    def sample(self, probs: Tensor) -> Tensor:
        """Draw one token id from the nucleus of ``probs``.

        Args:
            probs: next-token probabilities over the vocabulary (a 1-D tensor, as produced by
                ``TextGenerator``).

        Returns:
            A tensor of shape (1,) holding the sampled token id.
        """
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        # probability mass of all tokens ranked strictly above each token
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        # drop a token only if the tokens before it already exceed p, so the token that crosses the
        # threshold stays in the nucleus (equivalent to the reference's right-shifted mask)
        suppress = mass_before > self.p
        # if the single highest probability is already > P then we just select that token
        suppress[0] = False

        # masked_fill returns a new tensor, so the caller's probs are not modified
        nucleus_probs = sorted_probs.masked_fill(suppress, 0.0)
        # sample a rank from the renormalised nucleus, then map it back to a vocabulary id
        choice = torch.multinomial(nucleus_probs / nucleus_probs.sum(), num_samples=1)
        return sorted_indices[choice]



# Top-P: sample from the smallest set of likely tokens covering 75% of the probability mass.
decoding_strategy = TopPDecoding(p=0.75)
gen = TextGenerator(model=model, tokenizer=tokenizer, decoding_strategy=decoding_strategy)
print(gen.generate(prompt='A layer of ice; it feels rough against my face, but not cold.'))

### Top K followed by Top P sampling

In [None]:
class TopKTopPDecoding(DecodingStrategy):
    """Top-K filtering followed by top-P (nucleus) filtering, then sampling.

    First, only the ``k`` most probable tokens are kept and renormalised. Within that set, the
    shortest prefix (by descending probability) whose cumulative probability exceeds ``p`` is kept.
    The threshold-crossing token and the most probable token are always included. That prefix is
    renormalised again and one token is drawn from it. The caller's ``probs`` tensor is left
    untouched.
    """

    def __init__(self, k: int, p: float) -> None:
        # k < 1 would leave no candidates (all-zero weights -> NaN after renormalising)
        if k < 1:
            raise ValueError(f'k must be a positive integer, got {k}')
        # number of highest-probability tokens kept by the top-K stage
        self.k = k
        # cumulative-probability threshold for the top-P stage (applied to the renormalised top K)
        self.p = p

    def sample(self, probs: Tensor) -> Tensor:
        """Draw one token id after top-K then top-P filtering of ``probs``.

        Args:
            probs: next-token probabilities over the vocabulary (a 1-D tensor, as produced by
                ``TextGenerator``).

        Returns:
            A tensor of shape (1,) holding the sampled token id.
        """
        # top K sampling: topk yields the K largest probabilities already sorted in descending order
        top_probs, top_indices = torch.topk(probs, min(self.k, probs.numel()))
        top_probs = top_probs / top_probs.sum()

        # top P sampling on top K: drop a token only once the mass ranked above it already exceeds p,
        # so the token that crosses the threshold stays in the nucleus
        mass_before = torch.cumsum(top_probs, dim=-1) - top_probs
        suppress = mass_before > self.p
        suppress[0] = False
        nucleus_probs = top_probs.masked_fill(suppress, 0.0)

        # sample a rank from the renormalised nucleus, then map it back to a vocabulary id
        choice = torch.multinomial(nucleus_probs / nucleus_probs.sum(), num_samples=1)
        return top_indices[choice]


# Top-K then Top-P: restrict to the 1000 most likely tokens, then to their 75% nucleus.
decoding_strategy = TopKTopPDecoding(k=1000, p=0.75)
gen = TextGenerator(model=model, tokenizer=tokenizer, decoding_strategy=decoding_strategy)
print(gen.generate(prompt='A layer of ice; it feels rough against my face, but not cold.'))

In [None]:
# TODO: Beam Search