<img src="https://raw.githubusercontent.com/callummcdougall/TransformerLens-intro/main/images/page_images/sampling.png" width="350">

If you have any feedback on this course (e.g. bugs, confusing explanations, parts that you feel could be structured better), please let me know using [this Google Form](https://forms.gle/2ZhdHa87wWsrATjh9).

# Training and Sampling



## Introduction

In the previous set of exercises, we built a transformer from scratch. Here, we're going to look closer at how a transformer works in practice. We'll cover three topics: how to train transformers, how to sample from their output to autoregressively generate text, and how to use caching to run them more efficiently.

These exercises mainly focus on building up your understanding of transformers, and the important considerations that go into using them. Subsequent exercises will focus more on interpretability, so you can skip to them if you want (this material generally won't be very important for future exercises).

## Learning objectives

Here are the learning objectives for each section of the tutorial. At the end of each section, you should refer back here to check that you've understood everything.

## 1️⃣ Training

* Review the interpretation of a transformer's output, and learn how it's trained by minimizing cross-entropy loss between predicted and actual next tokens
* Construct datasets and dataloaders for the corpus of Shakespeare text
* Implement a transformer training loop

## 2️⃣ Sampling

* Learn how to sample from a transformer
    * This includes basic methods like greedy search or top-k, and more advanced methods like beam search

## 3️⃣ Caching

* Learn how to cache the output of a transformer, so that it can be used to generate text more efficiently
* Update your sampling functions to make use of your caching methods

## Setup

This code includes all the installs and imports you'll need. It also includes code for things like `DemoTransformer`, which you should have done in previous exercises. You are encouraged to copy over your solutions to these exercises, in place of the solutions here.

In [1]:
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [2]:
try:
  import google.colab
  IN_COLAB = True
  print("Running as a Colab notebook")
  %pip install transformer_lens
  %pip install torchtyping
except:
  IN_COLAB = False
  print("Running as a Jupyter notebook - intended for development only!")

Running as a Colab notebook


In [3]:
import re
import time
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union

import einops
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch as t
import transformers
from frozendict import frozendict
from plotly.subplots import make_subplots
from torch import nn
from torch.utils.data import DataLoader
from torchtyping import TensorType as TT
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new
from transformers import AutoTokenizer, PreTrainedTokenizer

In [4]:
def test_tensor_dataset(TensorDataset):
    tensors = [t.rand((10, 20)), t.rand((10, 5)), t.arange(10)]
    dataset = TensorDataset(*tensors)
    assert len(dataset) == 10
    for index in [0, slice(0, 5, 1), slice(1, 5, 2)]:
        print("Testing with index:", index)
        expected = tuple(tensor[index] for tensor in tensors)
        actual = dataset[index]
        for e, a in zip(expected, actual):
            t.testing.assert_close(e, a)
    print("All tests in `test_tensor_dataset` passed!")

In [5]:
def plot_two_lines(x1=None, y1=None, x2=None, y2=None, name1="", name2="", title="", xaxis="", yaxis=""):
    fig = make_subplots(specs=[[{"secondary_y": True}]])
    if x1 is None: x1 = list(range(len(y1)))
    if x2 is None: x2 = list(range(len(y2)))
    fig.add_trace(go.Scatter(x=x1, y=y1, name=name1), secondary_y=False)
    fig.add_trace(go.Scatter(x=x2, y=y2, name=name2), secondary_y=True)
    fig.update_layout(title=title, xaxis_title=xaxis, yaxis_title=yaxis)
    fig.show()

In [6]:
@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12


class LayerNorm(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))
        self.b = nn.Parameter(t.zeros(cfg.d_model))

    def forward(self, residual):
        # residual: [batch, position, d_model]
        residual_mean = residual.mean(dim=-1, keepdim=True)
        residual_std = (residual.var(dim=-1, keepdim=True, unbiased=False) + self.cfg.layer_norm_eps).sqrt()

        residual = (residual - residual_mean) / residual_std
        return residual * self.w + self.b


class Embed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(t.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens):
        # tokens: [batch, position]
        return self.W_E[tokens]


class PosEmbed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens):
        # tokens: [batch, position]
        batch, seq_len = tokens.shape
        return einops.repeat(self.W_pos[:seq_len], "seq d_model -> batch seq d_model", batch=batch)


class Attention(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        # No hard-coded device: buffers follow the module when `.to(device)` or `.cuda()` is called
        self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32))

    def forward(self, normalized_resid_pre: t.Tensor):
        # normalized_resid_pre: [batch, position, d_model]

        # Calculate query, key and value vectors
        q = einops.einsum(
            normalized_resid_pre, self.W_Q,
            "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head" 
        ) + self.b_Q
        k = einops.einsum(
            normalized_resid_pre, self.W_K,
            "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head" 
        ) + self.b_K
        v = einops.einsum(
            normalized_resid_pre, self.W_V,
            "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head" 
        ) + self.b_V

        # Calculate attention scores, then scale and mask, and apply softmax to get probabilities
        attn_scores = einops.einsum(
            q, k, 
            "batch posn_Q nheads d_head, batch posn_K nheads d_head -> batch nheads posn_Q posn_K"
        )
        attn_scores_masked = self.apply_causal_mask(attn_scores / self.cfg.d_head ** 0.5)
        attn_pattern = attn_scores_masked.softmax(-1)

        # Take weighted sum of value vectors, according to attention probabilities
        z = einops.einsum(
            v, attn_pattern, 
            "batch posn_K nheads d_head, batch nheads posn_Q posn_K -> batch posn_Q nheads d_head"
        )

        # Calculate output (by applying matrix W_O and summing over heads, then adding bias b_O)
        out = einops.einsum(
            z, self.W_O, 
            "batch posn_Q nheads d_head, nheads d_head d_model -> batch posn_Q d_model"
        ) + self.b_O

        return out

    def apply_causal_mask(self, attn_scores: t.Tensor):
        # attn_scores: [batch, n_heads, query_pos, key_pos]
        seq_len = attn_scores.shape[-1]
        q_posn = einops.repeat(attn_scores.new_tensor(range(seq_len)), "q -> q k", k=seq_len)
        k_posn = einops.repeat(attn_scores.new_tensor(range(seq_len)), "k -> q k", q=seq_len)
        attn_scores = attn_scores.masked_fill(q_posn < k_posn, self.IGNORE)
        return attn_scores


class MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty((cfg.d_model, cfg.d_mlp)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        self.b_in = nn.Parameter(t.zeros((cfg.d_mlp)))
        self.W_out = nn.Parameter(t.empty((cfg.d_mlp, cfg.d_model)))
        nn.init.normal_(self.W_out, std=self.cfg.init_range)
        self.b_out = nn.Parameter(t.zeros((cfg.d_model)))

    def forward(self, normalized_resid_mid):
        # normalized_resid_mid: [batch, position, d_model]
        normalized_resid_mid = normalized_resid_mid @ self.W_in + self.b_in
        normalized_resid_mid = gelu_new(normalized_resid_mid)
        normalized_resid_mid = normalized_resid_mid @ self.W_out + self.b_out
        return normalized_resid_mid


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(self, resid_pre):
        # resid_pre: [batch, position, d_model]
        # output: [batch, position, d_model]
        resid_mid = self.attn(self.ln1(resid_pre)) + resid_pre
        resid_post = self.mlp(self.ln2(resid_mid)) + resid_mid
        return resid_post


class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(t.zeros((cfg.d_vocab)), requires_grad=False)

    def forward(self, normalized_resid_final):
        # normalized_resid_final [batch, position, d_model]
        return einops.einsum(
            normalized_resid_final, self.W_U,
            "batch posn d_model, d_model d_vocab -> batch posn d_vocab",
        ) + self.b_U
        # Or, could just do `normalized_resid_final @ self.W_U + self.b_U`


class DemoTransformer(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens):
        # tokens [batch, position]
        residual = self.embed(tokens) + self.pos_embed(tokens)
        for block in self.blocks:
            residual = block(residual)
        logits = self.unembed(self.ln_final(residual))
        return logits

# 1️⃣ Training

## Learning Objectives

* Review the interpretation of a transformer's output, and learn how it's trained by minimizing cross-entropy loss between predicted and actual next tokens
* Construct datasets and dataloaders for the corpus of Shakespeare text
* Implement a transformer training loop

Hopefully, you've now successfully implemented a transformer, and seen how to use it to generate output autoregressively. You might also have seen the example training loop at the end of the last section. Here, you'll train your transformer in a more hands-on way, using the [complete works of William Shakespeare](https://www.gutenberg.org/files/100/100-0.txt).

This is the task recommended by Jacob Hilton in his [curriculum](https://github.com/jacobhilton/deep_learning_curriculum).

## Cross entropy loss

Your transformer's input has shape `(batch, seq_len)`, where the `[i, j]`-th element is the token id of the `j`-th token in the `i`-th sequence. Your transformer's output has shape `(batch, seq_len, vocab_size)`, where the `[i, j, :]`-th element is a vector of logits, representing a probability distribution over the token that **follows** the `j`-th token in the `i`-th sequence.

When training our model, we use cross-entropy loss between the model's predictions and the actual next tokens. In other words, we take the `[:, :-1, :]` slice of our output (the distributions predicted at the first `seq_len - 1` positions, i.e. the predictions for the **last** `seq_len - 1` tokens in each sequence), and compare it to the `[:, 1:]` slice of the input tokens (the actual tokens we're trying to predict).

In the last section, we saw the function `lm_cross_entropy_loss` which calculated this for us. Let's take another look at this function, so we understand how it works:

In [7]:
def lm_cross_entropy_loss(logits: t.Tensor, tokens: t.Tensor):
    # Measure next token loss
    # Logits have shape [batch, position, d_vocab]
    # Tokens have shape [batch, position]
    log_probs = logits.log_softmax(dim=-1)
    pred_log_probs = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    return -pred_log_probs.mean()

First, we get `log_probs`, which are the log probabilities of each token in the vocab. Log probs are (as you'd probably guess!) the log of the probabilities implied by the logit distribution. We get them from logits by taking the softmax, then taking the log (so they're equal to the logits, up to an additive constant shared by all tokens at a given position). If you examine the formula for [cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html), you'll notice that it's just the negative of the log probability of the correct token.
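As a minimal sketch of the relationship just described (the three logit values are made up for illustration), the log probabilities are the logits minus a single constant, namely the log of the sum of the exponentiated logits:

```python
import math

# Hypothetical logits over a three-token vocabulary (illustrative values only)
logits = [2.0, 1.0, 0.1]

# log_softmax(x)_i = x_i - log(sum_j exp(x_j))
log_normalizer = math.log(sum(math.exp(x) for x in logits))
log_probs = [x - log_normalizer for x in logits]

# Exponentiating the log probs recovers a valid probability distribution
probs = [math.exp(lp) for lp in log_probs]
print(probs, sum(probs))

# Every logit differs from its log prob by the same constant
print([x - lp for x, lp in zip(logits, log_probs)])
```

Subtracting the same constant from every logit is why adding a constant to all logits leaves the predicted distribution unchanged.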

In the second line, we use the `gather` method to take the log probabilities corresponding to the correct token. This is a bit confusing, and you don't need to understand the exact syntax of `gather`. This line of code does the following:
* Indexes `log_probs`, taking the `[:, :-1]` slice (so we have the log probabilities predicted at the first `seq_len - 1` positions, i.e. the predictions for the **last** `seq_len - 1` tokens in each sequence)
* Indexes `tokens`, taking the `[:, 1:]`-th slice (so we have the actual tokens we're trying to predict)
* Indexes into the reduced `log_probs` tensor using `gather`, so we get the log probabilities of the correct tokens

Finally, we take the mean of the negative log probabilities, and return this as our loss. Remember that log probs are never positive (because the log of a number no greater than 1 is at most zero), so our loss will always be non-negative. It will tend to zero only if our model tends towards assigning 100% probability to the correct token, and 0% to all others.
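One way to convince yourself that the `gather` line does what is described above is to compare it against PyTorch's built-in `torch.nn.functional.cross_entropy` on the same shifted slices. The sketch below uses random tensors with arbitrarily chosen toy sizes; it is a sanity check rather than part of the exercises.

```python
import torch as t
import torch.nn.functional as F

t.manual_seed(0)
batch, seq_len, d_vocab = 2, 5, 11  # toy sizes, chosen arbitrarily
logits = t.randn(batch, seq_len, d_vocab)
tokens = t.randint(0, d_vocab, (batch, seq_len))

# Manual version, mirroring `lm_cross_entropy_loss`
log_probs = logits.log_softmax(dim=-1)
pred_log_probs = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
manual_loss = -pred_log_probs.mean()

# Built-in version: flatten the (batch, seq_len - 1) predictions into one dimension
builtin_loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, d_vocab),
    tokens[:, 1:].reshape(-1),
)

print(pred_log_probs.shape)  # one log prob per predicted token
t.testing.assert_close(manual_loss, builtin_loss)
```

The position at index `j` of `pred_log_probs` is the log probability the model assigned, after seeing tokens `0..j`, to the token that actually appears at position `j + 1`.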


## Tokenizers

Now that we've got cross entropy loss out of the way, let's start working with our dataset. We'll be using the Shakespeare corpus for these exercises; you can get the text by downloading it from [this link](https://drive.google.com/file/d/1qf4TqyMX9TK6W9MEt_AgfbSLVrENBtYl/view?usp=share_link), then running the following code:

In [8]:
from google.colab import files
uploaded = files.upload()

Saving 100-0.txt to 100-0.txt


In [9]:
with open("100-0.txt", encoding="utf-8") as file:
    text = file.read()

You should print out the first few lines of this text, and get a feel for what it looks like.

Rather than using a fancy tokenizer, we'll just split the text into tokens using a regular expression. This is a bit crude, but it's good enough for our purposes.

#### Exercise - implement `SimpleTokenizer`

Below, you should fill in the `SimpleTokenizer` class. Some guidance for this exercise:

#### __init__

The `text` argument is meant to be a string (this will be the same as the `text` object you defined above). Here, you should define `self.words` as a list of all the different tokens that appear in the text, sorted in some reasonable way (you can split the text with `re.split(r"\b", text)`). You should then define `self.word_to_index` and `self.index_to_word`, which are dictionaries that map tokens to their token ids, and vice-versa (with the token ids being the positions of the tokens in `self.words`).

Also, it's good practice to include an unknown token `unk` in your vocabulary, just in case you feed the model a token that it hasn't seen before (you can give it the index one larger than the largest in your words list). We won't bother using a start token here (although you might want to think about doing this, as a bonus exercise).
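To get a feel for what splitting on word boundaries produces, here is a small sketch on a made-up sentence. The names `sample`, `vocab` and `unk_index` are illustrative, and dropping the empty strings is just one reasonable choice, not necessarily what the reference solution does. It assumes Python 3.7 or later, where `re.split` splits on empty matches such as `\b`.

```python
import re

sample = "To be, or not to be"  # illustrative text, not the full corpus

# Splitting on word boundaries keeps words, whitespace and punctuation as separate pieces
pieces = re.split(r"\b", sample)
print(pieces)

# The split yields empty strings at either end when the text starts or ends
# with a word character; here we simply drop them
pieces = [p for p in pieces if p]

# Build a sorted vocabulary and the two lookup dictionaries
vocab = sorted(set(pieces))
word_to_index = {word: i for i, word in enumerate(vocab)}
index_to_word = {i: word for word, i in word_to_index.items()}
unk_index = len(vocab)  # one past the largest index in the vocabulary

# Encode, then decode by concatenating the pieces (whitespace is itself a token)
ids = [word_to_index.get(p, unk_index) for p in pieces]
roundtrip = "".join(index_to_word[i] for i in ids)
print(ids, repr(roundtrip))
```

Because whitespace and punctuation are tokens in their own right under this scheme, decoding is a plain concatenation and reproduces the input exactly.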

#### `encode`

This takes in some text, and returns tokens. If `return_tensors` is None (the default), this should return a simple list of integers. If `return_tensors == "pt"`, this should return a PyTorch tensor of shape `(1, seq_len)` (it's good practice to always add a batch dimension, even if there's only one sequence in the batch).

If the input text contains an unknown token, then you can print an error message (or raise an exception).

#### `decode`

Finally, this should take in a list or tensor of tokens (you can assume that the batch dimension will be 1 if it's a tensor), and return a string of the decoded text.

In [10]:
class SimpleTokenizer():

    def __init__(self, text: str):
        # Sort the unique tokens so that token ids are the same on every run
        self.words = sorted(set(re.split(r'\b', text)))
        self.word_to_index = {w: i for i, w in enumerate(self.words)}
        # Special tokens take the next free indices after the vocabulary
        self.word_to_index['unk'] = len(self.words)
        self.word_to_index['<start>'] = len(self.words) + 1
        self.index_to_word = {value: key for key, value in self.word_to_index.items()}

    def encode(self, input_text, return_tensors: Optional[str] = None) -> Union[List, t.Tensor]:
        '''
        Tokenizes and encodes the input text.

        If `return_tensors` is None, should return list of Python integers.
        If `return_tensors` is "pt", should return a PyTorch tensor of shape (1, num_tokens).
        '''
        # Note: this splits on single spaces, whereas the vocabulary was built by splitting on
        # word boundaries, so any piece that isn't an exact vocabulary entry maps to `unk`
        token_list = input_text.split(' ')
        tokenized_list = [self.word_to_index.get(tok, self.word_to_index['unk']) for tok in token_list]
        if return_tensors is None:
            return tokenized_list
        else:
            # Add a batch dimension of size 1
            return t.tensor(tokenized_list).unsqueeze(0)

    def decode(self, tokens: Union[List, t.Tensor]):
        '''
        Decodes the tokens into a string of text.
        '''
        if isinstance(tokens, t.Tensor):
            if tokens.dim() == 2:
                assert tokens.size(0) == 1, "Only batch size 1 is supported"
                tokens = tokens[0]
            # Convert to Python ints so they can be used as dictionary keys
            tokens = tokens.tolist()

        return ' '.join([self.index_to_word[tok] for tok in tokens])


mytokenizer = SimpleTokenizer(text)

# Some basic testing
assert isinstance(mytokenizer.encode("Macbeth"), list)
assert isinstance(mytokenizer.encode("Macbeth", return_tensors="pt"), t.Tensor)
assert mytokenizer.decode(mytokenizer.encode("Macbeth")) == "Macbeth"
assert mytokenizer.index_to_word[mytokenizer.encode("Macbeth")[0]] == "Macbeth"

## Preparing text

We have our tokenizer, but we still need to be able to take in our `text` object and turn it into a tensor of token ids, without any of them overlapping. This is important because overlapping sequences might cause us to double-count certain sequences during training, and will make it seem like our model is learning faster than it really is.

#### Exercise - implement `prepare_text`

Below, you should fill in the `prepare_text` function.

In [11]:
def prepare_text(text: str, max_seq_len: int, tokenizer: SimpleTokenizer):
    '''
    Takes a string of text, and returns an array of tokens rearranged into chunks of size max_seq_len.
    '''

    tokenized_text = tokenizer.encode(text, return_tensors = 'pt')
    batch_size = tokenized_text.shape[1] // max_seq_len
    tokenized_text = tokenized_text[0, : batch_size * max_seq_len]
    return tokenized_text.view(batch_size, max_seq_len)

max_seq_len=48
tokens = prepare_text(text[:500], max_seq_len=max_seq_len, tokenizer=mytokenizer)
print("Does this size look reasonable, as a tokenization of the first 500 characters?\n", tokens.shape)

Does this size look reasonable, as a tokenization of the first 500 characters?
 torch.Size([1, 48])


## Datasets and Dataloaders

### Build Your Own TensorDataset

The class `torch.utils.data.dataset.TensorDataset` is a convenient wrapper for passing around multiple tensors that have the same size in the first dimension. The most common example of this is in supervised learning, where you have one tensor of inputs and a second tensor with corresponding labels. Often these tensors will have different `dtype`s, so it doesn't make sense to `torch.stack` them into one big tensor, and it would be cumbersome to pass them around as separate variables or as a tuple.

`TensorDataset` accepts and stores any number of tensors in the constructor along with implementing `__getitem__` so that `my_dataset[n]` returns a tuple containing element `n` from each stored `Tensor`. Similarly, `my_dataset[:5]` returns a tuple containing the first five elements from each stored `Tensor`.
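As a quick illustration of the behaviour described above, here is the built-in PyTorch class applied to two toy tensors with different dtypes (the tensor names and sizes are made up for this example):

```python
import torch as t
from torch.utils.data import TensorDataset

inputs = t.rand(10, 20)  # float32 "features"
labels = t.arange(10)    # int64 "labels", same size in the first dimension
dataset = TensorDataset(inputs, labels)

# Integer index: a tuple holding element 3 of each stored tensor
x, y = dataset[3]
print(x.shape, y)

# Slice index: a tuple holding the first five elements of each stored tensor
xs, ys = dataset[:5]
print(xs.shape, ys)
```

The two tensors keep their own dtypes throughout, which is the reason for wrapping them rather than stacking them.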

### Slice Objects in Python

`slice` is a built-in type containing `start`, `stop`, and `step` fields which can be integers or `None`. Given `x=[1,2,3,4,5,6,7]`, writing `x[1:5:2]` is syntactic sugar for `x[slice(1, 5, 2)]`.
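The equivalence is easy to check directly. In the sketch below, `ShowIndex` is a throwaway class defined only to reveal what Python passes to `__getitem__`:

```python
x = [1, 2, 3, 4, 5, 6, 7]
s = slice(1, 5, 2)

# Both forms select the elements at indices 1 and 3
print(x[1:5:2], x[s])
print(s.start, s.stop, s.step)

# Omitted fields are None: x[:5] is shorthand for x[slice(None, 5, None)]
first_five = x[slice(None, 5, None)]


class ShowIndex:
    """Hypothetical helper that returns whatever index object it receives."""

    def __getitem__(self, index):
        return index


shown = ShowIndex()[1:5:2]
print(shown)
```

This is why a single `__getitem__` method can serve both integer and slice indexing: it receives either an `int` or a `slice` object and can pass it straight on to the underlying tensors.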

#### Exercise - implement `TensorDataset`

*This should be a relatively unchallenging exercise, and you can skip it if it doesn't seem interesting to you.*

You should fill in the methods below, and verify that the tests pass.

Note that we're only passing in one tensor to this class (the `tokens` tensor), but this class should also be able to accept multiple tensors (this will be useful when we get to some later examples, like training models to solve algorithmic tasks).

In [12]:
class TensorDataset:
    def __init__(self, *tensors: t.Tensor):
        '''Validate the sizes and store the tensors in a field named `tensors`.'''
        batch_sizes = [tensor.shape[0] for tensor in tensors]
        assert len(set(batch_sizes)) == 1, "All tensors must have the same size in the first dimension"

        self.tensors = tensors

    def __getitem__(self, index: Union[int, slice]) -> Tuple[t.Tensor, ...]:
        '''Return a tuple of length len(self.tensors) with the index applied to each.'''

        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        '''Return the size in the first dimension, common to all the tensors.'''

        return self.tensors[0].shape[0]


test_tensor_dataset(TensorDataset)

dataset = TensorDataset(tokens)

Testing with index: 0
Testing with index: slice(0, 5, 1)
Testing with index: slice(1, 5, 2)
All tests in `test_tensor_dataset` passed!


## Training loop

Now, it's time for our training loop! We've left this exercise very open-ended, like our implementation of the ResNet training loop in last week's exercises. The principles are exactly the same, and we've provided you with a skeleton of the function to help get you started. 

Again, we use a `dataclass` object to store the training parameters, because this is a useful way of keeping your code organised. Note one extra feature here - rather than defining our `optimizer_kwargs` object as a dictionary, we define it as a `frozendict` (which is an immutable mapping type that works just like a regular dict, except that it can't be modified after creation). This is a helpful way to get around the fact that you aren't allowed to use mutable objects like dictionaries or lists as default values for dataclass fields.
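To see the restriction in action, the sketch below tries to declare a dataclass with a plain `dict` default, then shows the standard-library alternative, `field(default_factory=...)`. The class names `BadArgs` and `ArgsWithFactory` are hypothetical, and the factory approach is shown only for comparison; it is not what these exercises use.

```python
from dataclasses import dataclass, field

error_message = None
try:
    @dataclass
    class BadArgs:
        optimizer_kwargs: dict = {"lr": 0.001}  # mutable default: rejected at class creation
except ValueError as err:
    error_message = str(err)
    print("dataclass refused the default:", error_message)


@dataclass
class ArgsWithFactory:
    # Alternative to frozendict: build a fresh dict for every instance
    optimizer_kwargs: dict = field(default_factory=lambda: {"lr": 0.001, "betas": (0.9, 0.999)})


a, b = ArgsWithFactory(), ArgsWithFactory()
a.optimizer_kwargs["lr"] = 0.01
print(b.optimizer_kwargs["lr"])  # b has its own dict, so it is unaffected
```

The check exists because a mutable default would otherwise be shared between every instance of the class. An immutable default such as a `frozendict` avoids the problem differently: sharing is harmless when nothing can modify the shared object.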

#### Exercise - write a training loop

You should read and understand the code below, and fill in the section marked `YOUR CODE HERE`.

In [13]:
@dataclass
class TransformerTrainingArgs():
    tokenizer: SimpleTokenizer = mytokenizer
    epochs: int = 3
    batch_size: int = 4
    max_seq_len: int = 48
    optimizer: Callable[..., t.optim.Optimizer] = t.optim.Adam
    optimizer_kwargs: Dict = frozendict(lr=0.001, betas=(0.9, 0.999))
    device: str = "cuda" if t.cuda.is_available() else "cpu"
    filename_save_model: str = "transformer_shakespeare.pt"

    
def train_transformer(model: DemoTransformer, text: str, args: TransformerTrainingArgs) -> Tuple[list, list]:
    '''
    Trains an autoregressive transformer on `text`, using a random 90/10 train/test split.

    Returns a tuple of (train_loss_list, test_loss_list), containing the average cross entropy loss
    on the train and test sets after each epoch.
    '''
    # Prepare the tokens, take a random train/test split, and create the dataloaders
    tokens = prepare_text(text, max_seq_len=args.max_seq_len, tokenizer=args.tokenizer)
    randperm = t.randperm(tokens.size(0))
    len_trainset = int(0.9 * tokens.size(0))
    trainset = TensorDataset(tokens[randperm[:len_trainset]])
    testset = TensorDataset(tokens[randperm[len_trainset:]])
    trainloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True)
    testloader = DataLoader(testset, batch_size=args.batch_size, shuffle=True)

    model.to(args.device)
    optimizer = args.optimizer(model.parameters(), **args.optimizer_kwargs)

    train_loss_list = []
    test_loss_list = []
    # YOUR CODE HERE - implement training and testing loops
    best_test_loss = float("inf")
    for epoch in range(args.epochs):

        # Training loop: one optimizer step per batch
        train_losses = []
        for (batch,) in tqdm(trainloader):
            batch = batch.to(args.device)
            logits = model(batch)
            loss = lm_cross_entropy_loss(logits, batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_losses.append(loss.item())

        average_loss = sum(train_losses) / len(train_losses)
        train_loss_list.append(average_loss)
        print(f'Average train loss after epoch {epoch+1} is {average_loss}')

        # Testing loop: average the loss over all test sequences, weighting each batch by its size
        with t.inference_mode():
            test_loss = 0.0
            total = 0
            for (batch,) in testloader:
                batch = batch.to(args.device)
                logits = model(batch)
                test_loss += lm_cross_entropy_loss(logits, batch).item() * batch.size(0)
                total += batch.size(0)

        test_loss /= total
        print(f'Average test loss after epoch {epoch+1} is {test_loss}')
        test_loss_list.append(test_loss)

        # Save a checkpoint whenever the test loss improves
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            print(f"\nSaving model to: {args.filename_save_model}")
            t.save(model, args.filename_save_model)

    return train_loss_list, test_loss_list

You can take a look at the solutions for an example implementation (although it's totally fine to have something which looks different to this).



Once you've written a training loop, you can run it (and plot your output) with the following code:

In [14]:
config = Config(
    d_model = 384,
    layer_norm_eps = 1e-5,
    d_vocab = 50257,
    init_range = 0.02,
    n_ctx = 1024,
    d_head = 64,
    d_mlp = 1536,
    n_heads = 6,
    n_layers = 4
)

model = DemoTransformer(config)

args = TransformerTrainingArgs(
    tokenizer = mytokenizer,
    batch_size = 8,
    epochs = 3,
)

train_loss_list, test_loss_list = train_transformer(model, text, args)

plot_two_lines(
    y1 = train_loss_list,
    y2 = test_loss_list,
    x2 = list(range(
        len(train_loss_list) // len(test_loss_list), 
        len(train_loss_list) + 1,
        len(train_loss_list) // len(test_loss_list)
    )),
    name1 = "Train loss",
    name2 = "Test loss",
    title = "Loss curve for transformer trained on Shakespeare corpus",
    xaxis = "Batches seen",
    yaxis = "Cross entropy loss"
)

100%|██████████| 2202/2202 [01:44<00:00, 21.13it/s]


Average train loss after epoch 1 is 4.646601952172539
Average test loss after epoch 1 is 4.529391765594482

Saving model to: transformer_shakespeare.pt


100%|██████████| 2202/2202 [01:43<00:00, 21.22it/s]


Average train loss after epoch 2 is 4.3964127488400475
Average test loss after epoch 2 is 4.465266227722168

Saving model to: transformer_shakespeare.pt


100%|██████████| 2202/2202 [01:47<00:00, 20.43it/s]


Average train loss after epoch 3 is 4.297598929647745
Average test loss after epoch 3 is 4.446091651916504

Saving model to: transformer_shakespeare.pt


You can try playing around with some of the hyperparameters, and see how they affect the training process. You might also want to try out using different datasets (there are many online you can use!).

# 2️⃣ Sampling

#### Learning Objectives

* Learn how to sample from a transformer
    * This includes basic methods like greedy search or top-k, and more advanced methods like beam search

One obvious method to sample tokens from a distribution would be to always take the token assigned the highest probability. But this can lead to some boring and repetitive outcomes, and at worst it can lock our transformer's output into a loop.

First, you should read HuggingFace's blog post [How to generate text: using different decoding methods for language generation with Transformers](https://huggingface.co/blog/how-to-generate).

Once you've done that, we've included some exercises below that will allow you to write your own methods for sampling from a transformer. You'll be working with a pretrained model rather than the Shakespeare model in the previous set of exercises (because sampling can behave quite unpredictably unless tokenization and training are done very carefully), although you might want to try substituting your Shakespeare model into these exercises if you have extra time at the end, to see how it behaves.

In [15]:
reference_gpt2 = HookedTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)
gpt2 = DemoTransformer(Config())
gpt2.load_state_dict(reference_gpt2.state_dict(), strict=False)
gpt2.cuda()

tokenizer = AutoTokenizer.from_pretrained("gpt2")

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


## Sampling Boilerplate

The provided functions `apply_sampling_methods` and `sample_tokens` include the boilerplate for sampling from the model. Note that there is a special token `tokenizer.eos_token`, which during training was added to the end of each article. GPT-2 will generate this token when it feels like the continuation is at a reasonable stopping point, which is our cue to stop generation.

The functions called in `apply_sampling_methods` are not defined yet - you are going to implement them below.

In [16]:
def apply_sampling_methods(
    input_ids: t.Tensor, 
    logits: t.Tensor, 
    temperature=1.0, 
    freq_penalty=0.0, 
    top_k=0, 
    top_p=0.0,
    seed=0
) -> int:
    '''
    Return the next token, sampled from the model's probability distribution with modifiers.
    input_ids: shape (seq,)
    '''
    assert input_ids.ndim == 1, "input_ids should be a 1D sequence of token ids"
    assert temperature >= 0, "Temperature should be non-negative"
    assert 0 <= top_p <= 1.0, "Top-p must be a probability"
    assert 0 <= top_k, "Top-k must be non-negative"
    assert not (top_p != 0 and top_k != 0), "At most one of top-p and top-k supported"

    # Set random seeds for reproducibility
    t.manual_seed(seed)
    np.random.seed(seed)
    
    if temperature == 0:
        return greedy_search(logits)
    if temperature != 1.0:
        logits = apply_temperature(logits, temperature)
    if freq_penalty != 0.0:
        logits = apply_freq_penalty(input_ids, logits, freq_penalty)
    if top_k > 0:
        return sample_top_k(logits, top_k)
    if top_p > 0:
        return sample_top_p(logits, top_p)
    return sample_basic(logits)


@t.inference_mode()
def sample_tokens(
    model: DemoTransformer,
    tokenizer: PreTrainedTokenizer,
    initial_text: str,
    max_tokens_generated=30,
    **kwargs # kwargs are for params like temperature, top_k, etc
) -> str:
    '''
    Sample tokens until the model outputs `tokenizer.eos_token_id` or the specified token limit is reached.

    Return: the prompt and continuation concatenated
    '''
    model.eval()
    input_ids: list = tokenizer.encode(initial_text)
    generated = []
    for _ in range(max_tokens_generated):
        new_input_ids = t.tensor(input_ids + generated, dtype=t.long, device="cuda")
        new_input_ids_window = new_input_ids[-min(model.cfg.n_ctx, new_input_ids.shape[0]):].unsqueeze(0)
        logits = model(new_input_ids_window)[0, -1]
        new_token = apply_sampling_methods(new_input_ids, logits, **kwargs)
        generated.append(new_token)
        if new_token == getattr(tokenizer, "eos_token_id", None):
            break
    return tokenizer.decode(input_ids + generated)

A few notes on this function:

* We use `tokenizer.encode` to convert the initial text string into a list of token ids. You can also pass the argument `return_tensors="pt"` in order to return the output as a tensor.
* `new_input_ids` is a concatenation of the original input ids, and the ones that have been autoregressively generated.
* `new_input_ids_window` truncates `new_input_ids` to the last `model.cfg.n_ctx` tokens (because you might get an error at the positional embedding stage if your input sequence length is too large).
* We can index the model's output directly with `[0, -1]` because `DemoTransformer` returns just the logits. HuggingFace's GPT-2, in contrast, returns an object which contains `logits` and `past_key_values`, so you would need to extract the logits first.
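
To make the truncation step concrete, here is a minimal plain-Python sketch of the same slicing logic applied to a list of made-up token ids. The helper name `context_window` is hypothetical (it is not part of the provided code), and `n_ctx` stands in for `model.cfg.n_ctx`.

```python
from typing import List


def context_window(token_ids: List[int], n_ctx: int) -> List[int]:
    """Return the last `n_ctx` token ids, or all of them if the sequence is shorter."""
    # Same slicing pattern as in `sample_tokens`, applied to a plain list
    return token_ids[-min(n_ctx, len(token_ids)):]


prompt_ids = [11, 22, 33]  # made-up token ids standing in for the encoded prompt
generated = [44, 55]       # made-up ids standing in for the tokens sampled so far
all_ids = prompt_ids + generated

print(context_window(all_ids, n_ctx=4))   # [22, 33, 44, 55]
print(context_window(all_ids, n_ctx=10))  # [11, 22, 33, 44, 55]
```

Only the most recent tokens are kept, so once the sequence outgrows the context window the model stops seeing the start of the prompt.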

<details>
<summary>Question - why do you think we take <code>logits[0, -1]</code> ?</summary>

Our model input has shape `(batch, seq_len)`, and each element is a token id. Our output has dimension `(batch, seq_len, vocab_size)`, where the `[i, j, :]`th element is a vector of logits representing a prediction for the `j+1`th token.

In this case, our batch dimension is 1, and we want to predict the token that follows after all the tokens in the sequence, hence we want to take `logits[0, -1, :]`.
</details>

### Greedy Search

Implement `greedy_search`, which just returns the most likely next token. If multiple tokens are equally likely, break the tie by returning the smallest token id.

Why not break ties randomly? It's nice that greedy search is deterministic, and also nice to not have special code for a case that rarely occurs (floats are rarely exactly equal).

Tip: the type checker doesn't know the return type of `item()` is int, but you can assert that it really is an int and this will make the type checker happy.

In [17]:
def greedy_search(logits: t.Tensor) -> int:
    '''
    logits: shape (vocab_size, )

    Return: the most likely token (as an integer)
    '''
    
    return t.argmax(logits).item()


prompt = "Jingle bells, jingle bells, jingle all the way"
print("Greedy decoding with prompt: ", prompt)
output = sample_tokens(gpt2, tokenizer, prompt, max_tokens_generated=8, temperature=0.0)
print(f"Your model said: {output}")
expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
assert output == expected

print("Greedy decoding a second time (should be deterministic): ")
output = sample_tokens(gpt2, tokenizer, prompt, max_tokens_generated=8, temperature=0.0)
print(f"Your model said: {output}")
expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
assert output == expected

print("Tests passed!")

Greedy decoding with prompt:  Jingle bells, jingle bells, jingle all the way
Your model said: Jingle bells, jingle bells, jingle all the way up to the top of the mountain.
Greedy decoding a second time (should be deterministic): 
Your model said: Jingle bells, jingle bells, jingle all the way up to the top of the mountain.
Tests passed!
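
If the tie-breaking rule is unclear, the following plain-Python sketch shows the intended behaviour on a toy list of logits. It is an illustration only, not the `t.argmax` solution above, and `greedy_pick` is a hypothetical name.

```python
def greedy_pick(logits):
    """Return the index of the largest logit; ties go to the smallest index."""
    # max() returns the first maximal item it encounters, so iterating over
    # indices in ascending order breaks ties towards the smallest index.
    return max(range(len(logits)), key=lambda i: logits[i])


print(greedy_pick([0.1, 2.5, 2.5, -1.0]))  # 1
```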


## Sampling with Categorical

PyTorch provides a [`distributions` package](https://pytorch.org/docs/stable/distributions.html#distribution) with a number of convenient methods for sampling from various distributions.

For now, we just need [`t.distributions.categorical.Categorical`](https://pytorch.org/docs/stable/distributions.html#categorical). Use this to implement `sample_basic`, which just samples from the provided logits (which may have already been modified by the temperature and frequency penalties).

Note that this will be slow since we aren't batching the samples, but don't worry about speed for now.

### Basic Sampling

In [18]:
def sample_basic(logits: t.Tensor) -> int:
    '''
    logits: shape (vocab_size, ) - unnormalized log-probabilities

    Return: a sampled token
    '''
    
    
    dist = t.distributions.categorical.Categorical(logits = logits)
    return dist.sample().item()


N = 20000
probs = t.linspace(0, 0.4, 5)
unnormalized_logits = probs.log() + 1.2345
samples = t.tensor([sample_basic(unnormalized_logits) for _ in range(N)])
counts = t.bincount(samples, minlength=len(probs)) / N
print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
t.testing.assert_close(counts, probs, atol=0.01, rtol=0)
print("Tests passed!")

Checking empirical frequencies (try to increase N if this test fails):  tensor([0.0000, 0.1005, 0.1946, 0.3048, 0.4001])
Tests passed!
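
The test above adds a constant (`1.2345`) to the log-probabilities before sampling. This works because softmax is unchanged when every logit is shifted by the same amount, which is what "unnormalized log-probabilities" means in practice. Here is a small plain-Python check of that property. The `softmax` helper is written out by hand for illustration, and the probabilities are made up (they avoid zero, since `math.log(0)` raises an error in plain Python).

```python
import math


def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = [0.1, 0.2, 0.3, 0.4]
logits = [math.log(p) for p in probs]
shifted = [x + 1.2345 for x in logits]  # same constant added to every logit

print([round(p, 3) for p in softmax(logits)])   # [0.1, 0.2, 0.3, 0.4]
print([round(p, 3) for p in softmax(shifted)])  # [0.1, 0.2, 0.3, 0.4]
```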


### Temperature

Temperature sounds fancy, but it's literally just dividing the logits by the temperature.

In [19]:
def apply_temperature(logits: t.Tensor, temperature: float) -> t.Tensor:
    '''
    logits: shape (vocab_size, )

    Return: shape (vocab_size, )
    '''
    assert temperature > 0
    logits = logits / temperature
    return logits

    
logits = t.tensor([1, 2]).log()
cold_logits = apply_temperature(logits, 0.001)
print('A low temperature "sharpens" or "peaks" the distribution: ', cold_logits)
t.testing.assert_close(cold_logits, 1000.0 * logits)
hot_logits = apply_temperature(logits, 1000.0)
print("A high temperature flattens the distribution: ", hot_logits)
t.testing.assert_close(hot_logits, 0.001 * logits)
print("Tests passed!")

A low temperature "sharpens" or "peaks" the distribution:  tensor([  0.0000, 693.1472])
A high temperature flattens the distribution:  tensor([0.0000, 0.0007])
Tests passed!
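
To see what dividing by the temperature does to the resulting probabilities (rather than to the logits), here is a small plain-Python sketch using the same two logits as the test, `log(1)` and `log(2)`. The `softmax` helper is hand-written for illustration and is not part of the exercise code.

```python
import math


def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [math.log(1.0), math.log(2.0)]

for temperature in (1.0, 0.5, 2.0):
    scaled = [x / temperature for x in logits]
    print(temperature, [round(p, 3) for p in softmax(scaled)])
# 1.0 [0.333, 0.667]
# 0.5 [0.2, 0.8]
# 2.0 [0.414, 0.586]
```

Halving the temperature doubles the gap between the logits, so the 1:2 odds become 1:4. Doubling the temperature halves the gap, which pulls the two probabilities closer together.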


### Frequency Penalty

The frequency penalty is simple as well: count the number of occurrences of each token, then subtract `freq_penalty` for each occurrence. Hint: use `t.bincount` (documentation [here](https://pytorch.org/docs/stable/generated/torch.bincount.html)) to do this in a vectorized way.

<details>
<summary>Help - I'm getting a RuntimeError; my tensor sizes don't match.</summary>

Look at the documentation page for `t.bincount`. You might need to use the `minlength` argument - why?
</details>

In [20]:
def apply_freq_penalty(input_ids: t.Tensor, logits: t.Tensor, freq_penalty: float) -> t.Tensor:
    '''
    input_ids: shape (seq, )
    logits: shape (vocab_size, )
    Return: shape (vocab_size, )
    '''
    
    vocab_size = logits.shape[0]
    freq = t.bincount(input_ids, minlength=vocab_size)
    return logits - freq * freq_penalty

    
bieber_prompt = "And I was like Baby, baby, baby, oh Like, Baby, baby, baby, no Like, Baby, baby, baby, oh I thought you'd always be mine, mine"
input_ids = tokenizer.encode(bieber_prompt, return_tensors="pt").squeeze()
logits = t.ones(tokenizer.vocab_size)
penalized_logits = apply_freq_penalty(input_ids, logits, 2.0)
assert penalized_logits[5156].item() == -11, "Expected 6 occurrences of ' baby' with leading space"
assert penalized_logits[14801].item() == -5, "Expected 3 occurrences of ' Baby' with leading space"
print("Tests passed!")

Tests passed!
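
As a worked example of the arithmetic, here is a plain-Python sketch of the frequency penalty on a toy vocabulary of four tokens. It uses `collections.Counter` in place of `t.bincount`, and the name `penalize` is hypothetical.

```python
from collections import Counter


def penalize(token_ids, logits, freq_penalty):
    """Subtract `freq_penalty` from a token's logit once per occurrence in `token_ids`."""
    counts = Counter(token_ids)  # tokens that never appear have a count of 0
    return [logit - freq_penalty * counts[i] for i, logit in enumerate(logits)]


logits = [1.0, 1.0, 1.0, 1.0]  # toy vocabulary of size 4
token_ids = [2, 0, 2, 2]       # token 2 appears three times, token 0 once

print(penalize(token_ids, logits, 2.0))  # [-1.0, 1.0, -5.0, 1.0]
```

Note that the largest token id in this toy sequence is 2, so a count vector built only from the ids that appear would have three entries rather than four. That mismatch with the vocabulary size is what the `minlength` hint is about.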


### Sampling - Manual Testing

Run the below cell to get a sense for the `temperature` and `freq_penalty` arguments. Play with your own prompt and try other values.

Note: your model can generate newlines or non-printing characters, so calling `print` on generated text sometimes looks awkward on screen. You can call `repr` on the string before printing to have the string escaped nicely.

In [21]:
N_RUNS = 1
your_prompt = "Jingle bells, jingle bells, jingle all the way"
cases = [
    ("High freq penalty", dict(freq_penalty=100.0)),
    ("Negative freq penalty", dict(freq_penalty=-3.0)),
    ("Too hot!", dict(temperature=2.0)),
    ("Pleasantly cool", dict(temperature=0.7)),
    ("Pleasantly warm", dict(temperature=0.9)),
    ("Too cold!", dict(temperature=0.01)),
]
for (name, kwargs) in cases:
    for i in range(N_RUNS):
        output = sample_tokens(gpt2, tokenizer, your_prompt, max_tokens_generated=24, **kwargs)
        print(f"Sample {i} with: {name} ({kwargs}):")
        print(f"Your model said: {repr(output)}\n")

Sample 0 with: High freq penalty ({'freq_penalty': 100.0}):
Your model said: "Jingle bells, jingle bells, jingle all the way bell. Visit Sugaree's Scarf Shop if you enjoy eating under contract loveshow!! Or click here to get"

Sample 0 with: Negative freq penalty ({'freq_penalty': -3.0}):
Your model said: 'Jingle bells, jingle bells, jingle all the way, bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell'

Sample 0 with: Too hot! ({'temperature': 2.0}):
Your model said: 'Jingle bells, jingle bells, jingle all the way bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell bell'

Sample 0 with: Pleasantly cool ({'temperature': 0.7}):
Your model said: 'Jingle bells, jingle bells, jingle all the way.\n\nI was never gonna do that. I never gonna do that. I never gonna do that. I never'

Sample 0 with: Pleasantly warm ({'temperature': 0.9}):
Your model said: 'Jingle bells

## Top-K Sampling

Conceptually, the steps in top-k sampling are:
- Find the `top_k` largest probabilities
- Set all other probabilities to zero
- Normalize and sample

Your implementation should stay in log-space throughout (don't exponentiate to obtain probabilities). This means you don't actually need to worry about normalizing, because `Categorical` accepts unnormalised logits.

<details>
<summary>Help - I don't know what function I should use for finding the top k.</summary>

Use [`t.topk`](https://pytorch.org/docs/stable/generated/torch.topk.html).
</details>

In [22]:
def sample_top_k(logits: t.Tensor, top_k: int) -> int:
    '''
    logits: shape (vocab_size, ) - unnormalized log-probabilities
    top_k: only consider this many of the most likely tokens for sampling

    Return: a sampled token
    '''
    
    topk_pred_tokens = t.topk(logits, k = top_k)
    pred_distr = t.distributions.categorical.Categorical(logits = topk_pred_tokens.values)
    return topk_pred_tokens.indices[pred_distr.sample().item()].item()


k = 3
probs = t.linspace(0, 0.4, 5)
unnormalized_logits = probs.log() + 1.2345
samples = t.tensor([sample_top_k(unnormalized_logits, k) for _ in range(N)])
counts = t.bincount(samples, minlength=len(probs)) / N
expected = probs.clone()
expected[:-k] = 0
expected /= expected.sum()
print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
t.testing.assert_close(counts, expected, atol=0.01, rtol=0)
print("Tests passed!")

Checking empirical frequencies (try to increase N if this test fails):  tensor([0.0000, 0.0000, 0.2180, 0.3327, 0.4494])
Tests passed!
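
The conceptual steps (keep the top `k`, zero out the rest, renormalize) can also be done by masking in log-space. The sketch below is an alternative to the `t.topk` solution above, written in plain Python for illustration. The helper names `top_k_filter` and `softmax` are hypothetical, and the probabilities are made up.

```python
import math


def top_k_filter(logits, k):
    """Keep the k largest logits and set the rest to -inf (i.e. probability zero)."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    keep = set(order[:k])
    return [x if i in keep else -math.inf for i, x in enumerate(logits)]


def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # math.exp(-inf) is 0.0
    total = sum(exps)
    return [e / total for e in exps]


probs = [0.1, 0.2, 0.3, 0.4]
logits = [math.log(p) for p in probs]

filtered = softmax(top_k_filter(logits, k=2))
print([round(p, 3) for p in filtered])  # [0.0, 0.0, 0.429, 0.571]
```

The two surviving tokens keep their 3:4 ratio and are renormalized to `3/7` and `4/7`, without ever converting to probabilities before the final softmax.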


### Top-K Sampling - Example

The [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) famously included an example prompt about unicorns. Now it's your turn to see just how cherry-picked this example was.

The paper claims they used `top_k=40` and best of 10 samples.

In [23]:
your_prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."
output = sample_tokens(gpt2, tokenizer, your_prompt, temperature=0.7, top_k=40, max_tokens_generated=64)
print(f"Your model said: {output}")

Your model said: In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.

The team of researchers led by the University of Oxford and the University of Utah found that the unicorns spoke English as well as Spanish in the valley where they lived. This means that they are the first unicorns to be recorded as speaking perfect English.

The findings have been published in the journal PLOS


## Top-p aka Nucleus Sampling

The basic idea is that we choose the most likely tokens, up until the total probability of the tokens we've chosen crosses some threshold. Then we sample from those chosen tokens based on their logits.

The steps are:

- Sort the probabilities from largest to smallest
- Find the cutoff point where the cumulative probability first equals or exceeds `top_p`. We do the cutoff inclusively, keeping the first probability above the threshold.
- If the number of kept probabilities is less than `min_tokens_to_keep`, keep that many tokens instead.
- Set all other probabilities to zero
- Normalize and sample

Optionally, refer to the paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/pdf/1904.09751.pdf) for some comparison of different methods.

#### Exercise - implement `sample_top_p`

<details>
<summary>Example of top-p sampling (if you're confused)</summary>

If our probabilities were `(0.4, 0.3, 0.2, 0.1)` and our cutoff was `top_p=0.8`, then we'd sample from the first three elements (because their total probability is `0.9` which is over the threshold, but the first two only have a total prob of `0.7` which is under the threshold). Once we've chosen to sample from those three, we would renormalise them by dividing by their sum (so the probabilities we use when sampling are `(4/9, 3/9, 2/9)`).
</details>
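
The worked example above can be checked with a few lines of plain Python. This is a sketch of the cutoff logic only, working directly with probabilities rather than logits. The name `nucleus` is hypothetical, and `bisect_left` plays the role that `t.searchsorted` plays in a PyTorch implementation.

```python
from bisect import bisect_left
from itertools import accumulate


def nucleus(probs, top_p, min_tokens_to_keep=1):
    """Return (kept_indices, renormalized_probs) for a top-p cutoff."""
    # Sort token indices from most to least likely
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    cumulative = list(accumulate(probs[i] for i in order))
    # First position where the running total reaches top_p, kept inclusively
    n_keep = bisect_left(cumulative, top_p) + 1
    n_keep = min(max(n_keep, min_tokens_to_keep), len(probs))
    kept = order[:n_keep]
    total = sum(probs[i] for i in kept)
    return kept, [probs[i] / total for i in kept]


kept, renorm = nucleus([0.4, 0.3, 0.2, 0.1], top_p=0.8)
print(kept)                           # [0, 1, 2]
print([round(p, 3) for p in renorm])  # [0.444, 0.333, 0.222]
```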

<details>
<summary>Help - I'm stuck on how to implement this function.</summary>

First, sort the logits using the `sort(descending=True)` method (this returns values and indices). Then you can get `cumulative_probs` by applying softmax to these logits and taking the cumsum. Then, you can decide how many probabilities to keep by using the `t.searchsorted` function.
    
Once you've decided which probabilities to keep, it's easiest to sample from them using the original logits (you should have preserved the indices when you called `logits.sort`). This way, you don't need to worry about renormalising like you would if you were using probabilities.
</details>

In [43]:
def sample_top_p(logits: t.Tensor, top_p: float, min_tokens_to_keep: int = 1) -> int:
    '''
    logits: shape (vocab_size, ) - unnormalized log-probabilities
    Return: a sampled token
    '''

    # Sort probabilities from largest to smallest, keeping track of the original token ids
    sorted_probs, sorted_indices = logits.softmax(dim=-1).sort(descending=True)
    cumsum_probs = sorted_probs.cumsum(dim=-1)

    # Find the first position where the cumulative probability reaches top_p (inclusive cutoff)
    n_keep = t.searchsorted(cumsum_probs, top_p).item() + 1
    n_keep = max(n_keep, min_tokens_to_keep)
    indices_to_keep = sorted_indices[:n_keep]

    # Sample in log-space from the kept logits, then map back to the original token id
    dist = t.distributions.categorical.Categorical(logits=logits[indices_to_keep])
    return indices_to_keep[dist.sample()].item()

    
N = 2000
unnormalized_logits = t.tensor([0.2, 0.3, 0.5]).log() + 2.3456
samples = t.tensor([sample_top_p(unnormalized_logits, 0.5) for _ in range(N)])
counts = t.bincount(samples, minlength=len(unnormalized_logits)) / N
print("top_p of 0.5 or lower should only return token 2: ", counts)
assert counts[0] == 0 and counts[1] == 0

N = 2000
unnormalized_logits = t.tensor([0.2, 0.3, 0.5]).log() + 2.3456
samples = t.tensor([sample_top_p(unnormalized_logits, 0.50001) for _ in range(N)])
counts = t.bincount(samples, minlength=len(unnormalized_logits)) / N
print("top_p in (0.5, 0.8] should return tokens 1 and 2: ", counts)
assert counts[0] == 0

N = 5000
top_p = 0.71
probs = t.linspace(0, 0.4, 5)
unnormalized_logits = probs.log() + 1.2345
samples = t.tensor([sample_top_p(unnormalized_logits, top_p) for _ in range(N)])
counts = t.bincount(samples, minlength=len(probs)) / N
expected = probs.clone()
expected[0:2] = 0
expected /= expected.sum()
print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
t.testing.assert_close(counts, expected, atol=0.01, rtol=0.0)

print("All tests passed!")

top_p of 0.5 or lower should only return token 2:  tensor([0., 0., 1.])
top_p in (0.5, 0.8] should return tokens 1 and 2:  tensor([0.0000, 0.3685, 0.6315])
Checking empirical frequencies (try to increase N if this test fails):  tensor([0.0000, 0.0000, 0.2198, 0.3304, 0.4498])
All tests passed!


### Top-p Sampling - Example


In [44]:
your_prompt = "Eliezer Shlomo Yudkowsky (born September 11, 1979) is an American decision and artificial intelligence (AI) theorist and writer, best known for"
output = sample_tokens(gpt2, tokenizer, your_prompt, temperature=0.7, top_p=0.95, max_tokens_generated=64)
print(f"Your model said: {repr(output)}")

Your model said: 'Eliezer Shlomo Yudkowsky (born September 11, 1979) is an American decision and artificial intelligence (AI) theorist and writer, best known for his work on the New York Times best-seller "The Art of Intelligence: Why Artificial Intelligence Can Make Us Better Scientists," as well as for his books and books on AI and social engineering. He is a graduate of the University of Pennsylvania. His latest book is "The Future of Human Intelligence," and the following books'
