# Setup

In [1]:
from collections import defaultdict
import os

from dataclasses import dataclass
import einops
from jaxtyping import Float, Int
from rich.table import Table
from rich import print as rprint
import torch as t
from torch import Tensor
from torch.distributions.categorical import Categorical
from transformer_lens import HookedTransformer
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from tqdm.notebook import tqdm

from C1P1__mj_implementation import Config, DemoTransformer

# Ask CUDA to launch kernels synchronously so that device-side errors surface at
# the call that caused them (debugging aid; it slows GPU execution).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

## Set device

In [2]:
# Pick the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if t.backends.mps.is_available():
    _backend = "mps"
elif t.cuda.is_available():
    _backend = "cuda"
else:
    _backend = "cpu"

device = t.device(_backend)
print(device)

cuda


## Load GPT-2 Small

In [3]:
# Load the pretrained GPT-2 small reference model through TransformerLens and
# place it on the selected device.
# All three weight-processing options are switched off — presumably so that the
# state dict keeps the parameterisation DemoTransformer expects (TODO confirm
# against C1P1__mj_implementation).
reference_gpt2 = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
    device=device,
)



Loaded pretrained model gpt2-small into HookedTransformer


  return t.to(


# Sampling from a Transformer

### Learning Objectives

- Learn how to sample from a transformer.
  - Includes basic methods like **greedy search** or **top-k**, and more advanced methods like **beam search**.

- Learn how to cache the output from a transformer, so that it can be used to generate text more efficiently.
  - Rewrite sampling functions to use caching.

Obvious way to sample tokens: always take the token assigned the highest probability!

- This can lead to boring and repetitive outcomes.
- At worst, it can lock our transformer into a loop.

##### Read the HuggingFace blog post: ["How to generate text: using different decoding methods for language generation with Transformers"](https://huggingface.co/blog/how-to-generate)

## TransformerSampler

In [4]:
# Build our own transformer implementation and copy the pretrained weights into it.
model_cfg = Config()
model = DemoTransformer(model_cfg).to(device)
# strict=False tolerates keys that do not match between the two state dicts.
# NOTE(review): mismatches are therefore silent — inspect the return value of
# load_state_dict if the generated text looks wrong.
model.load_state_dict(reference_gpt2.state_dict(), strict=False)

# Reuse the tokenizer that ships with the reference model.
tokenizer = reference_gpt2.tokenizer


class TransformerSampler:
    """
    Bundles a transformer with its tokenizer and exposes text-generation routines.

    `sample`, `beam_search`, `sample_basic`, `sample_top_k` and `sample_top_p` are
    placeholders in this class body: the notebook defines each of them in a later
    cell and binds the implementation onto the class. Until that happens they
    raise NotImplementedError.
    """

    def __init__(self, model: DemoTransformer, tokenizer: GPT2TokenizerFast):
        self.model = model
        self.cfg = model.cfg
        self.tokenizer = tokenizer

    @t.inference_mode()
    def sample(self, prompt: str, max_tokens_generated=100, verbose=False, **kwargs):
        """
        Returns a string of autoregressively generated text, starting from the prompt.

        Sampling terminates at max_tokens_generated, or when the model generates an
        end-of-sequence token.

        kwargs are passed to sample_next_token, to give detailed instructions on how
        new tokens are chosen.
        """
        # YOUR CODE HERE!
        raise NotImplementedError()

    @t.inference_mode()
    def beam_search(
        self,
        prompt: str,
        num_return_sequences: int,
        num_beams: int,
        max_new_tokens: int,
        no_repeat_ngram_size: int = 0,
        verbose=False,
    ) -> list[tuple[float, str]]:
        """
        Returns the best completions of the prompt found by beam search, as a list
        of (summed log-probability, decoded text) pairs.

        The search stops after max_new_tokens generation steps, or as soon as
        num_return_sequences sequences have ended with an end-of-sequence token.

        num_beams is the number of candidate sequences kept after every step;
        no_repeat_ngram_size, when positive, forbids repeating any n-gram of that
        length; verbose prints the intermediate beams.
        """
        # YOUR CODE HERE!
        raise NotImplementedError()

    @staticmethod
    def sample_next_token(
        input_ids: Int[Tensor, "seq_len"],
        logits: Float[Tensor, "d_vocab"],
        temperature=1.0,
        top_k=0,
        top_p=0.0,
        frequency_penalty=0.0,
        seed=None,
    ):
        """
        Chooses the next token id from `logits` according to the sampling options.

        Order of operations: a temperature of exactly 0 short-circuits to greedy
        decoding; otherwise temperature scaling is applied first, then the
        frequency penalty (based on the counts in `input_ids`), and finally the
        token is drawn with top-k, top-p or plain sampling.

        At most one of top_k and top_p may be non-zero. If seed is given, the
        global PyTorch RNG is re-seeded before sampling.
        """
        assert input_ids.ndim == 1, "input_ids should be a 1D sequence of token ids"
        assert temperature >= 0, "Temperature should be non-negative"
        assert 0 <= top_p <= 1.0, "Top-p must be a probability"
        assert 0 <= top_k, "Top-k must be non-negative"
        assert not (
            top_p != 0 and top_k != 0
        ), "At most one of top-p and top-k supported"

        # Seed PyTorch's global RNG so the draw is reproducible. The samplers in
        # this notebook all draw through torch, so no other generator needs seeding
        # (a former `np.random.seed` call raised NameError: numpy is never imported).
        if seed is not None:
            t.manual_seed(seed)

        # Apply all the specialized sampling methods
        if temperature == 0:
            # Dividing by zero is undefined; T -> 0 is the greedy limit.
            return TransformerSampler.greedy_search(logits)
        if temperature != 1.0:
            logits = TransformerSampler.apply_temperature(logits, temperature)
        if frequency_penalty != 0.0:
            logits = TransformerSampler.apply_frequency_penalty(
                input_ids, logits, frequency_penalty
            )
        if top_k > 0:
            return TransformerSampler.sample_top_k(logits, top_k)
        if top_p > 0.0:
            return TransformerSampler.sample_top_p(logits, top_p)
        return TransformerSampler.sample_basic(logits)

    @staticmethod
    def greedy_search(logits: Float[Tensor, "d_vocab"]) -> int:
        """
        Returns the most likely token (as an int).
        """
        return logits.argmax().item()

    @staticmethod
    def apply_temperature(
        logits: Float[Tensor, "d_vocab"], temperature: float
    ) -> Float[Tensor, "d_vocab"]:
        """
        Applies temperature scaling to the logits, i.e. divides them by `temperature`.

        The caller must ensure `temperature` is non-zero.
        """
        return logits / temperature

    @staticmethod
    def apply_frequency_penalty(
        input_ids: Int[Tensor, "seq_len"],
        logits: Float[Tensor, "d_vocab"],
        freq_penalty: float,
    ) -> Float[Tensor, "d_vocab"]:
        """
        Applies a frequency penalty to the logits.

        Each token's logit is reduced by `freq_penalty` times the number of times
        that token occurs in `input_ids`. A negative penalty therefore rewards
        repetition.
        """
        # minlength pads the histogram to the vocabulary size so it lines up with logits.
        counts = t.bincount(input_ids, minlength=logits.shape[-1])

        return logits - (freq_penalty * counts)

    @staticmethod
    def sample_basic(logits: Float[Tensor, "d_vocab"]) -> int:
        """
        Samples from the distribution defined by the logits.
        """
        # Placeholder: the implementation is bound onto the class in a later cell.
        raise NotImplementedError()

    @staticmethod
    def sample_top_k(logits: Float[Tensor, "d_vocab"], k: int) -> int:
        """
        Samples from the top k most likely tokens.
        """
        # Placeholder: the implementation is bound onto the class in a later cell.
        raise NotImplementedError()

    @staticmethod
    def sample_top_p(
        logits: Float[Tensor, "d_vocab"], top_p: float, min_tokens_to_keep: int = 1
    ) -> int:
        """
        Samples from the most likely tokens which make up at least p cumulative probability.
        """
        # Placeholder: the implementation is bound onto the class in a later cell.
        raise NotImplementedError()

## Main Sampling Function

### Exercise: implement `sample()`

`sample()` takes in a prompt (string), encodes it as a sequence of token ids using `self.tokenizer.encode`, and then continually generates new tokens by repeating the following steps:

1. Pass the tokenized prompt through the model to get logits

2. Take the logit vector corresponding to the last token in the prompt (i.e., prediction for the *next* token)

3. Sample from this distribution to get a new token, using `self.sample_next_token(input_ids, logits, **kwargs)`. `kwargs` contains all the sampling-specific args, e.g., **temperature**, **top-k**, etc.

4. Append this new token to the input tokens, and repeat until we meet one of two termination criteria:
   - We generate `max_tokens_generated` new tokens, or

   - We generate the EOS token, accessed via `self.tokenizer.eos_token_id`.

Finally, we use `self.tokenizer.decode` to convert the generated token ids back into a string, and return the string.

We also have a `verbose` arg - use it to print output while it's being sampled.

A few hints:
- Don't forget about tensor shapes! The model's input should always have a `batch` dimension.

- `sample_next_token()` will return an integer. Wrap this into a tensor before concatenating it to the end of the input IDs

- Remember device!

- Put the model in evaluation mode using `model.eval()`

In [5]:
@t.inference_mode()
def sample(self, prompt: str, max_tokens_generated=100, verbose=False, **kwargs):
    """
    Returns a string of autoregressively generated text, starting from the prompt.

    Sampling terminates at max_tokens_generated, or when the model generates an
    end-of-sequence token.

    kwargs are passed to sample_next_token, to give detailed instructions on how
    new tokens are chosen.

    If verbose is true, the final text (prompt plus completion) is printed
    before it is returned.
    """
    self.model.eval()

    # Follow the model's own device instead of relying on a module-level global,
    # so the sampler keeps working if the model is moved.
    model_device = next(self.model.parameters()).device

    # By default `encode` returns a list; `return_tensors="pt"` returns a 2D
    # tensor with a leading batch dim of 1, which is dropped here.
    tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(model_device)[0]

    for _ in range(max_tokens_generated):
        # The model needs a batch dim, and must never see more than n_ctx tokens,
        # so only the most recent n_ctx tokens are fed in.
        logits_last_token = self.model(tokens[None, -self.cfg.n_ctx :])[0, -1]

        # sample_next_token returns a plain int. The full token history is passed
        # so that the frequency penalty can count every earlier occurrence.
        next_token = self.sample_next_token(tokens, logits_last_token, **kwargs)

        # Wrap the int in a 1-element tensor so it can be concatenated.
        tokens = t.cat((tokens, t.tensor([next_token], device=model_device)), dim=-1)

        if next_token == self.tokenizer.eos_token_id:
            break

    res = self.tokenizer.decode(tokens)

    if verbose:
        print(res)

    return res


# Bind the implementation onto the class, replacing the placeholder method.
TransformerSampler.sample = sample

#### Test

In [6]:
# Greedy decoding (temperature=0.0) is deterministic, so the completion can be
# compared against a fixed expected string.
sampler = TransformerSampler(model, tokenizer)

prompt = "Jingle bells, jingle bells, jingle all the way"
print(f"Greedy decoding with prompt: {prompt!r}\n")

# verbose=True makes sample() print the generated text itself before returning it.
output = sampler.sample(prompt, max_tokens_generated=8, temperature=0.0, verbose=True)
print(f"Your model said: {output!r}\n")

# The returned string includes the prompt followed by the 8 generated tokens.
expected = (
    "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
)
assert output == expected

print("Tests passed!")

Greedy decoding with prompt: 'Jingle bells, jingle bells, jingle all the way'

Jingle bells, jingle bells, jingle all the way up to the top of the mountain.
Your model said: 'Jingle bells, jingle bells, jingle all the way up to the top of the mountain.'

Tests passed!


## Sampling with Categorical

PyTorch provides a `distributions` package containing convenient methods for sampling from various distributions

For now, use `t.distributions.categorical.Categorical` to implement `sample_basic`. This just samples from the provided logits, which may have already been modified by the temperature and frequency penalties.

Will be slow since we're not batching the samples (yet)

### Exercise: Basic Sampling

In [7]:
@staticmethod
def sample_basic(logits: Float[Tensor, "d_vocab"]) -> int:
    """
    Draws one token id at random, with probabilities given by softmax(logits).

    The logits are used as passed in, so any temperature scaling or frequency
    penalty must already have been applied by the caller.
    """
    sampled_token = Categorical(logits=logits).sample()
    return sampled_token.item()


# Bind the implementation onto the class, replacing the placeholder method.
TransformerSampler.sample_basic = sample_basic

#### Test

In [8]:
# Statistical test: sample the next token N times from a fixed logit vector and
# compare the empirical frequencies of five reference tokens with their
# expected probabilities.
prompt = "John and Mary went to the"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
# Logits for the position after the last prompt token (batch element 0).
logits = model(input_ids)[0, -1]

# Reference probabilities for the five tokens checked below.
expected_top_5 = {
    " church": 0.0648,
    " house": 0.0367,
    " temple": 0.0145,
    " same": 0.0104,
    " Church": 0.0097,
}
# Counts every sampled token (keyed by its decoded text), not only the five above.
frequency_of_top_5 = defaultdict(int)

N = 10_000
for _ in tqdm(range(N)):
    # squeeze() drops the batch dim: sample_next_token asserts a 1D id sequence.
    token = TransformerSampler.sample_next_token(input_ids.squeeze(), logits)
    frequency_of_top_5[tokenizer.decode(token)] += 1

for word in expected_top_5:
    expected_freq = expected_top_5[word]
    observed_freq = frequency_of_top_5[word] / N
    print(
        f"Word: {word!r:<9}. Expected freq {expected_freq:.4f}, observed freq {observed_freq:.4f}"
    )
    # Absolute tolerance of 0.01 allows for sampling noise.
    assert (
        abs(observed_freq - expected_freq) < 0.01
    ), "Try increasing N if this fails by a small amount."

print("Tests passed!")

  0%|          | 0/10000 [00:00<?, ?it/s]

Word: ' church'. Expected freq 0.0648, observed freq 0.0629
Word: ' house' . Expected freq 0.0367, observed freq 0.0361
Word: ' temple'. Expected freq 0.0145, observed freq 0.0137
Word: ' same'  . Expected freq 0.0104, observed freq 0.0105
Word: ' Church'. Expected freq 0.0097, observed freq 0.0106
Tests passed!


### Exercise: Temperature

***** Implemented above

#### Notes on temperature

To apply a temperature to our sampling means to scale all logits by `1/temperature`. The basic intuition is:

- A higher temperature means a smaller scale factor, which diminishes the differences between logits. This leads to a more uniform distribution and more random sampling.

- A lower temperature means a larger scale factor, which amplifies the differences between logits. This makes the highest logit dominate the softmax distribution, resulting in greedy sampling.

##### Derivation

`sample_basic()` samples from a `Categorical` distribution of logits. These logits are unnormalised log probabilities, and are converted to probabilities via softmax. I.e., given logits $x_i$, the probabilities $P(i)$ are given by:

$$ 
P(i) = \frac{e^{x_i / T}}{\sum_j{e^{x_j / T}}}
$$

where $T$ is the temperature parameter.

Let $x_k = \max\limits_{i}(x_i)$ be the maximum logit among all tokens. For any other token $i \ne k$, the ratio of probabilities is:

$$
\frac{P(i)}{P(k)} = \frac{e^{x_i / T}}{e^{x_k / T}} = e^{(x_i-x_k)/T}
$$

Now, $x_i - x_k \le 0$ since $x_k$ is the maximum. As $T$ approaches $0$:

- If $x_i - x_k \lt 0$

$$
\lim\limits_{T \rightarrow 0^+}{\frac{x_i-x_k}{T}} = -\infty \\[10pt]
\therefore\lim\limits_{T \rightarrow 0^+}{\frac{P(i)}{P(k)}} = \lim\limits_{T \rightarrow 0^+}{e^{(x_i-x_k)/T}} = 0
$$

- If $x_i - x_k = 0$

$$
x_i = x_k\\[10pt]
\therefore \frac{P(i)}{P(k)} = 1
$$

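Thus, as temperature $T$ approaches zero (assuming the maximum logit is unique; if several logits tie for the maximum, they share the probability mass equally):
- $P(k)$ approaches 1.

- $P(i)$ for all $i \ne k$ approaches 0.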

#### Test

In [9]:
# Logits log(1) and log(2), i.e. unnormalised probabilities 1 and 2.
logits = t.tensor([1, 2]).log()

# Dividing by T=0.001 multiplies the logits by 1000.
cold_logits = TransformerSampler.apply_temperature(logits, temperature=0.001)
print('A low temperature "sharpens" or "peaks" the distribution: ', cold_logits)
t.testing.assert_close(cold_logits, 1000.0 * logits)

# Dividing by T=1000 shrinks the logits by a factor of 1000.
hot_logits = TransformerSampler.apply_temperature(logits, temperature=1000.0)
print("A high temperature flattens the distribution: ", hot_logits)
t.testing.assert_close(hot_logits, 0.001 * logits)

print("Tests passed!")

A low temperature "sharpens" or "peaks" the distribution:  tensor([  0.0000, 693.1472])
A high temperature flattens the distribution:  tensor([0.0000, 0.0007])
Tests passed!


### Exercise: Frequency Penalty

***** Implemented above

#### Test


In [10]:
# A prompt with heavily repeated tokens, used to check the per-token counts.
bieber_prompt = "And I was like Baby, baby, baby, oh Like, Baby, baby, baby, no Like, Baby, baby, baby, oh I thought you'd always be mine, mine"
input_ids = tokenizer.encode(bieber_prompt, return_tensors="pt")
# Uniform logits of 1 make the expected result simply 1 - penalty * count.
logits = t.ones(tokenizer.vocab_size)
penalized_logits = TransformerSampler.apply_frequency_penalty(
    input_ids.squeeze(), logits, 2.0
)

# Token ids 5156 and 14801 are " baby" and " Baby" according to the assertion messages.
assert (
    penalized_logits[5156].item() == -11
), "Expected 6 occurrences of ' baby' with leading space, 1-2*6=-11"
assert (
    penalized_logits[14801].item() == -5
), "Expected 3 occurrences of ' Baby' with leading space, 1-2*3=-5"

print("Tests passed!")

Tests passed!


### Sampling - Manual Testing

In [11]:
# Qualitative comparison: generate from the same prompt under several sampling
# settings and show the results side by side in a rich table.
sampler = TransformerSampler(model, tokenizer)

# Number of completions generated per setting.
N_RUNS = 1
your_prompt = "Jingle bells, jingle bells, jingle all the way"
# (label, kwargs forwarded to sample_next_token via sampler.sample)
cases = [
    ("High freq penalty", dict(frequency_penalty=100.0)),
    ("Negative freq penalty", dict(frequency_penalty=-3.0)),
    ("Too hot!", dict(temperature=2.0)),
    ("Pleasantly cool", dict(temperature=0.7)),
    ("Pleasantly warm", dict(temperature=0.9)),
    ("Too cold!", dict(temperature=0.01)),
]

table = Table("Name", "Kwargs", "Output", title="Sampling - Manual Testing")

for name, kwargs in cases:
    for i in range(N_RUNS):
        output = sampler.sample(your_prompt, max_tokens_generated=24, **kwargs)
        # The trailing newline adds a blank line between table rows.
        table.add_row(name, repr(kwargs), repr(output) + "\n")

rprint(table)

## Top-K Sampling

Steps:
- Find the `top_k` highest probabilities (e.g., using `torch.topk`)

- Set all other probabilities to zero. 

- Normalise and sample

### Exercise: implement `sample_top_k()`

#### Hints:
- Stay in log space throughout.

In [12]:
@staticmethod
def sample_top_k(logits: Float[Tensor, "d_vocab"], k: int) -> int:
    """
    Draws a token id from the k highest-scoring tokens only.

    The k surviving logits are renormalised among themselves (by the softmax
    inside Categorical) and every other token gets zero probability.
    """
    top = logits.topk(k)

    # Sample a position among the k survivors, then map it back to a vocabulary id.
    position = Categorical(logits=top.values).sample()
    return top.indices[position].item()


# Bind the implementation onto the class, replacing the placeholder method.
TransformerSampler.sample_top_k = sample_top_k

#### Tests

In [13]:
# Statistical test for top-k sampling: with top_k=5 only the five reference
# tokens can be drawn, so their frequencies should match the reference
# probabilities renormalised to sum to 1.
prompt = "John and Mary went to the"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
# Logits for the position after the last prompt token (batch element 0).
logits = model(input_ids)[0, -1]

expected_top_5 = {
    " church": 0.0648,
    " house": 0.0367,
    " temple": 0.0145,
    " same": 0.0104,
    " Church": 0.0097,
}
# Normalising constant: total probability mass of the five kept tokens.
topk_5_sum = sum(expected_top_5.values())

observed_freqs = defaultdict(int)

N = 10000
for _ in tqdm(range(N)):
    token = TransformerSampler.sample_next_token(input_ids.squeeze(), logits, top_k=5)
    observed_freqs[tokenizer.decode(token)] += 1

for word in expected_top_5:
    expected_freq = expected_top_5[word] / topk_5_sum
    observed_freq = observed_freqs[word] / N
    print(
        f"Word: {word!r:<9}. Expected freq = {expected_freq:.4f}, observed freq = {observed_freq:.4f}"
    )
    # Absolute tolerance of 0.015 allows for sampling noise.
    assert (
        abs(observed_freq - expected_freq) < 0.015
    ), "Try increasing N if this fails by a small amount."

  0%|          | 0/10000 [00:00<?, ?it/s]

Word: ' church'. Expected freq = 0.4761, observed freq = 0.4756
Word: ' house' . Expected freq = 0.2697, observed freq = 0.2706
Word: ' temple'. Expected freq = 0.1065, observed freq = 0.1096
Word: ' same'  . Expected freq = 0.0764, observed freq = 0.0752
Word: ' Church'. Expected freq = 0.0713, observed freq = 0.0690


### Top-K Sampling Example

In [14]:
# Demo: a longer completion using temperature 0.7 combined with top-k (k=40).
sampler = TransformerSampler(model, tokenizer)

your_prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."
output = sampler.sample(your_prompt, temperature=0.7, top_k=40, max_tokens_generated=64)
rprint(f"Your model said:\n\n[bold dark_orange]{output}")

## Top-P a.k.a Nucleus Sampling

Choose the most likely words until the total probability of chosen words exceeds a threshold. Sample from these based on their logits.

Steps:
1. Sort probabilities in descending order

2. Find the cutoff point where the cumulative probability first equals or exceeds `top_p`. Cutoff inclusively; i.e., keep the first probability above the threshold.

3. If the number of kept probabilities is less than `min_tokens_to_keep`, keep that many instead.

4. Set all other probabilities to zero.

5. Normalise and sample.

In [15]:
@staticmethod
def sample_top_p(
    logits: Float[Tensor, "d_vocab"], top_p: float, min_tokens_to_keep: int = 1
) -> int:
    """
    Nucleus sampling: draws a token id from the smallest set of most likely
    tokens whose cumulative probability reaches top_p.

    The token that crosses the threshold is included, and at least
    min_tokens_to_keep tokens are always kept.
    """
    sorted_logits, sorted_token_ids = logits.sort(descending=True, stable=True)
    cumulative = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # searchsorted gives the index of the first position whose running mass is
    # >= top_p; adding 1 turns that index into a count that includes it.
    cutoff = t.searchsorted(cumulative, top_p) + 1
    cutoff = min_tokens_to_keep if cutoff < min_tokens_to_keep else cutoff

    # Sample a position inside the nucleus, then map it back to a vocabulary id.
    nucleus = Categorical(logits=sorted_logits[:cutoff])
    return sorted_token_ids[nucleus.sample()].item()


# Bind the implementation onto the class, replacing the placeholder method.
TransformerSampler.sample_top_p = sample_top_p

#### Test

In [16]:
# Statistical test for top-p sampling: with top_p=0.1 only the two most likely
# tokens should survive, so their frequencies should match the reference
# probabilities renormalised over those two tokens.
prompt = "John and Mary went to the"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
# Logits for the position after the last prompt token (batch element 0).
logits = model(input_ids)[0, -1]

expected_top_10pct = {
    " church": 0.0648,
    " house": 0.0367,  # These are the two most likely tokens, and add up to >10%
}
# Normalising constant: total probability mass of the kept tokens.
top_10pct_sum = sum(expected_top_10pct.values())

observed_freqs = defaultdict(int)

N = 10000
for _ in tqdm(range(N)):
    token = TransformerSampler.sample_next_token(input_ids.squeeze(), logits, top_p=0.1)
    observed_freqs[tokenizer.decode(token)] += 1

for word in expected_top_10pct:
    expected_freq = expected_top_10pct[word] / top_10pct_sum
    observed_freq = observed_freqs[word] / N
    print(
        f"Word: {word!r:<9}. Expected freq {expected_freq:.4f}, observed freq {observed_freq:.4f}"
    )
    # Absolute tolerance of 0.01 allows for sampling noise.
    assert (
        abs(observed_freq - expected_freq) < 0.01
    ), "Try increasing N if this fails by a small amount."

  0%|          | 0/10000 [00:00<?, ?it/s]

Word: ' church'. Expected freq 0.6384, observed freq 0.6410
Word: ' house' . Expected freq 0.3616, observed freq 0.3590


In [17]:
# Demo: a longer completion using temperature 0.7 combined with nucleus sampling (p=0.95).
sampler = TransformerSampler(model, tokenizer)

your_prompt = "Eliezer Shlomo Yudkowsky (born September 11, 1979) is an American decision and artificial intelligence (AI) theorist and writer, best known for"
output = sampler.sample(
    your_prompt, temperature=0.7, top_p=0.95, max_tokens_generated=64
)
rprint(f"Your model said:\n\n[bold dark_orange]{output}")

## Beam search

Maintain a list of size `num_beams` completions, which are the most likely completions so far as measured by the product of their probabilities.
- Actually, since the product can become very small, it's better to sum log probabilities instead.

At each iteration, we run the batch of completions through the model and take the log-softmax to obtain `d_vocab` log-probs for each completion, or `num_beams * d_vocab` possible next completions in total.
- If we kept all, then we would have `num_beams * d_vocab * d_vocab` completions after the next iteration - far too many! Instead, sort by score and loop from highest logprob to lowest

##### See diagram [here](./C1P1_beam_search.png)

Note that after each stage, we have `num_beams ** 2` possible completions, which is then filtered down to `num_beams`. **Why do we need to generate this many? What happens if we generate fewer?**
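- Answer: in the worst case, all of the overall best `num_beams` continuations descend from a single beam. If each beam proposed fewer than `num_beams` candidates, some of the true top-`num_beams` completions could never be considered.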

Note also that some sequences will terminate early by generating an EOS token. To handle this:
- We append terminated sequences to the list of completions to return at the end, and remove them from our generation tree.

- The algorithm terminates when either all sequences have length `max_new_tokens` larger than the initial prompt's length, or we've generated `num_return_sequences` terminating sequences.



#### n-gram repetition

While the output of beam search is sometimes more fluent than some of the other sampling methods, it also has an unfortunate tendency to repeat sentences or sequences. This makes sense - if the model produces a sentence with a relatively high logit sum, then it will want to produce the same sentence again even if it doesn't make a lot of sense in context.

A common solution is to ban repetition of n-grams.

### Exercise: implement `beam_search()`

In [39]:
@dataclass
class Beams:
    """
    Class to store beams during beam search.

    `logprob_sums` and `tokens` are parallel along the batch dimension:
    `logprob_sums[i]` is the accumulated log-probability of sequence `tokens[i]`.
    """

    model: DemoTransformer
    tokenizer: GPT2TokenizerFast
    logprob_sums: Float[Tensor, "batch"]
    tokens: Int[Tensor, "batch seq"]

    def new_beams(self, logprob_sums, tokens) -> "Beams":
        """Creates a new Beams object with the same model and tokenizer."""
        return Beams(self.model, self.tokenizer, logprob_sums, tokens)

    def __getitem__(self, idx) -> "Beams":
        """Allows you to take a slice of the beams object along the batch dimension."""
        return self.new_beams(self.logprob_sums[idx], self.tokens[idx])

    @property
    def logprobs_and_completions(self) -> list[tuple[float, str]]:
        """Returns self as a list of logprob sums and completions (useful for getting final output)."""
        return [
            (logprob_sum.item(), self.tokenizer.decode(tokens))
            for (logprob_sum, tokens) in zip(self.logprob_sums, self.tokens)
        ]

    def get_topk_non_repeating(
        self,
        logits: Float[Tensor, "batch d_vocab"],
        no_repeat_ngram_size: int | None,
        k: int,
    ) -> tuple[Float[Tensor, "batch k"], Int[Tensor, "batch k"]]:
        """
        logits:
            tensor of the log-probs for the next token
        no_repeat_ngram_size:
            size of ngram to avoid repeating; `None` or 0 disables the check
        k:
            number of top logits to return, for each beam in our collection

        Returns:
            equivalent to the output of `logits.topk(dim=-1)`, but makes sure
            that no returned tokens would produce an ngram of size  `no_repeat_ngram_size`
            which has already appeared in `self.tokens`.

        The input tensor is not modified: banned entries are masked on a copy.
        """
        batch_num, seq_len = self.tokens.shape

        # None and 0 both mean "no constraint". An n-gram longer than the sequence
        # cannot have occurred yet, so there is nothing to ban in that case either.
        if no_repeat_ngram_size is not None and 0 < no_repeat_ngram_size <= seq_len:
            # Mask on a copy so the caller's tensor is left untouched.
            logits = logits.clone()

            ngram_prefix_len = no_repeat_ngram_size - 1

            # The last (n-1) tokens of every beam: the prefix that the next token
            # would complete into an n-gram.
            ngram_prefix_last = self.tokens[:, seq_len - ngram_prefix_len :]
            batch_range = t.arange(batch_num, device=self.tokens.device)

            # Slide over every complete n-gram already present in the sequences.
            for i in range(seq_len - ngram_prefix_len):
                ngram = self.tokens[:, i : i + no_repeat_ngram_size]

                # Beams whose current suffix equals this n-gram's prefix would
                # repeat the n-gram by emitting its final token.
                is_repeated = (ngram[:, :-1] == ngram_prefix_last).all(dim=-1)

                # Pairwise (batch, token) indexing: ban one token per matching beam.
                logits[batch_range[is_repeated], ngram[is_repeated, -1]] = -1.0e9

        return logits.topk(k=k, dim=-1)

    def generate(
        self, toks_per_beam: int, no_repeat_ngram_size: int | None = None
    ) -> "Beams":
        """
        Starting from the current set of beams (which has length `num_beams`), returns a new
        set of `num_beams * toks_per_beam`, containing the best `toks_per_beam` continuations for each
        of the original beams.

        Optional argument `no_repeat_ngram_size` means your model won't generate any sequences with
        a repeating n-gram of this length.

        Scores are true sequence log-probabilities: the log-softmax is taken over
        the whole vocabulary before the top-k selection, so the values do not
        depend on `toks_per_beam` and are comparable across beams.
        """
        # Normalise over the full vocabulary first; normalising only the top-k
        # logits would inflate every score and make beams incomparable.
        logprobs = self.model(self.tokens)[:, -1, :].log_softmax(dim=-1)

        logprobs_topk, tokens_topk = self.get_topk_non_repeating(
            logprobs, k=toks_per_beam, no_repeat_ngram_size=no_repeat_ngram_size
        )

        # (batch, 1) + (batch, k) -> (batch, k), flattened beam-major so that the
        # k continuations of beam 0 come first, then those of beam 1, and so on.
        logprob_sums_new = (self.logprob_sums[:, None] + logprobs_topk).flatten()

        # Repeat every parent sequence k times (same beam-major order) and append
        # the chosen continuation token to each copy.
        tokens_new = t.cat(
            (
                self.tokens.repeat_interleave(toks_per_beam, dim=0),
                tokens_topk.reshape(-1, 1),
            ),
            dim=-1,
        )

        return self.new_beams(logprob_sums_new, tokens_new)

    def filter(self, num_beams: int) -> tuple["Beams", "Beams"]:
        """
        Returns:
            best_beams: Beams
                filtered version of self, containing all best `num_beams` which are also not terminated.

            early_terminations: Beams
                filtered version of self, containing all best `num_beams` which are also terminated.
                i.e. the sum of lengths of these two should equal `num_beams`.

        If fewer than `num_beams` candidates exist, all of them are considered.
        Both results keep the candidates in descending order of log-prob sum.
        """
        # topk raises if k exceeds the number of candidates, so clamp it.
        n_keep = min(num_beams, self.logprob_sums.shape[0])
        top_idx = self.logprob_sums.topk(k=n_keep, dim=0).indices

        # A beam is terminated when its most recent token is the EOS token.
        is_terminated = self.tokens[top_idx, -1] == self.tokenizer.eos_token_id

        # Boolean masking preserves the ranking order within each group.
        return self[top_idx[~is_terminated]], self[top_idx[is_terminated]]

    def print(self, title="Best completions", max_print_chars=80) -> None:
        """
        Prints out a set of sequences with their corresponding logitsums.

        Completions whose repr is longer than `max_print_chars` are shortened to
        their beginning and end, joined by " ... ". Nothing is printed when
        there are no beams.
        """
        if len(self.tokens) == 0:
            return
        table = Table("logitsum", "completion", title=title)
        for logprob_sum, tokens in zip(self.logprob_sums, self.tokens):
            text = self.tokenizer.decode(tokens)
            if len(repr(text)) > max_print_chars:
                # Keep roughly the first 30% and the last 70% of the budget.
                text = (
                    text[: int(0.3 * max_print_chars)]
                    + " ... "
                    + text[-int(0.7 * max_print_chars) :]
                )
            table.add_row(f"{logprob_sum:>8.3f}", repr(text))
        rprint(table)


@t.inference_mode()
def beam_search(
    self: TransformerSampler,
    prompt: str,
    num_return_sequences: int,
    num_beams: int,
    max_new_tokens: int,
    no_repeat_ngram_size: int | None = None,
    verbose=False,
) -> list[tuple[float, str]]:
    """
    Implements a beam search, by repeatedly performing the `generate` and `filter` steps (starting
    from the initial prompt) until either of the two stopping criteria are met:

        (1) we've generated `max_new_tokens` tokens, or
        (2) we've generated `num_return_sequences` terminating sequences.

    To modularize this function, most of the actual complexity is in the Beams class,
    in the `generate` and `filter` methods.

    Returns at most `num_return_sequences` pairs of (summed log-probability,
    decoded text). Terminated sequences come first, in the order they were
    found; if the token budget runs out, the surviving beams are appended
    after them.
    """

    assert num_return_sequences <= num_beams
    self.model.eval()

    # Follow the model's own device instead of relying on a module-level global.
    model_device = next(self.model.parameters()).device

    beams = Beams(
        self.model,
        self.tokenizer,
        t.tensor([0.0], device=model_device),  # Start with single beam only.
        self.tokenizer.encode(prompt, return_tensors="pt").to(model_device),
    )

    logprobs_and_completions_final: list[tuple[float, str]] = []

    for _ in tqdm(range(max_new_tokens)):
        # Expand every beam into `num_beams` candidates, then keep the best
        # `num_beams` overall, split into live and terminated beams.
        beams = beams.generate(
            toks_per_beam=num_beams, no_repeat_ngram_size=no_repeat_ngram_size
        )
        beams, beams_terminated = beams.filter(num_beams=num_beams)

        # Terminated beams leave the search and are collected for the result.
        logprobs_and_completions_final.extend(beams_terminated.logprobs_and_completions)

        if verbose:
            beams.print("Best completions")
            beams_terminated.print("Early terminations")

        # Stopping criterion (2): enough finished sequences.
        if len(logprobs_and_completions_final) >= num_return_sequences:
            return logprobs_and_completions_final[:num_return_sequences]

    # Stopping criterion (1): token budget exhausted; top up with unfinished beams.
    logprobs_and_completions_final.extend(beams.logprobs_and_completions)

    return logprobs_and_completions_final[:num_return_sequences]


# Bind the implementation onto the class, replacing the placeholder method.
TransformerSampler.beam_search = beam_search

#### Tests

In [40]:
# Hand-built fixture: three beams that share the first three tokens and differ
# only in the last one, with decreasing scores.
beams = Beams(
    model,
    tokenizer,
    logprob_sums=t.tensor([-10.0, -15.0, -20.0]).to(device),
    tokens=t.tensor(
        [
            [5661, 318, 262, 2368],
            [5661, 318, 262, 1218],
            [5661, 318, 262, 717],
        ]
    ).to(device),
)

beams.print()

##### Test for `generate()`

In [41]:
# Part 1: plain generation. Only the best continuation of the first beam is checked.
print("Testing generate, without no_repeat_ngram_size argument:")
new_beams = beams.generate(toks_per_beam=2)
new_beams.print()
assert new_beams.logprobs_and_completions[0][1] == "this is the third time"

# Part 2: n-gram blocking on a single beam containing a repeated bigram.
print("Testing generate, with no_repeat_ngram_size argument:")
bigram_beams = Beams(
    model,
    tokenizer,
    logprob_sums=t.tensor([-0.0]).to(device),
    tokens=t.tensor([[530, 734, 530, 734]]).to(device),
    # tokens are " one two one two"
)

# With no_repeat_ngram_size=1, should not generate the token " one" or " two"
new_bigram_beams = bigram_beams.generate(toks_per_beam=3, no_repeat_ngram_size=1)
new_bigram_beams.print()
assert all(
    [
        not (completion[1].endswith(" one") or completion[1].endswith(" two"))
        for completion in new_bigram_beams.logprobs_and_completions
    ]
)

# With no_repeat_ngram_size=2, it can generate " two" (which it should), but not " one"
new_bigram_beams = bigram_beams.generate(toks_per_beam=3, no_repeat_ngram_size=2)
new_bigram_beams.print()
assert all(
    [
        not completion[1].endswith(" one")
        for completion in new_bigram_beams.logprobs_and_completions
    ]
)
# NOTE(review): this only checks that at least one completion does NOT end in
# " two"; it does not verify that " two" is actually generated, which is what
# the comment above suggests — confirm the intended condition.
assert any(
    [
        not completion[1].endswith(" two")
        for completion in new_bigram_beams.logprobs_and_completions
    ]
)

print("All tests for `generate` passed!")

Testing generate, without no_repeat_ngram_size argument:


Testing generate, with no_repeat_ngram_size argument:


All tests for `generate` passed!


##### Test for `filter()`

In [42]:
# Two candidate beams: the first ends in an ordinary token, the second in EOS.
logprob_sums = t.tensor([-1.0, -2.0]).to(device)
tokens = t.tensor([[19485, 13], [19485, tokenizer.eos_token_id]]).to(device)

beams_with_eos = Beams(model, tokenizer, logprob_sums, tokens)
best_beams, early_terminations = beams_with_eos.filter(2)

# The live beam must be the one that does not end in EOS ...
t.testing.assert_close(best_beams.logprob_sums, logprob_sums[[0]])
t.testing.assert_close(best_beams.tokens, tokens[[0]])

# ... and the EOS-terminated beam must be reported separately, decoded to text.
assert early_terminations.logprobs_and_completions == [
    (-2.0, "Stop" + tokenizer.eos_token)
]

print("All tests for `filter` passed!")

All tests for `filter` passed!


#### Run

In [43]:
# End-to-end run of beam search on a real prompt.
sampler = TransformerSampler(model, tokenizer)

prompt = "The ships hung in the sky in much the same way that"
# Prompt length in tokens, used below to count only the newly generated tokens.
orig_len = len(tokenizer.encode(prompt))

final_logitsums_and_completions = sampler.beam_search(
    prompt=prompt,
    num_return_sequences=3,
    num_beams=40,
    max_new_tokens=60,
    no_repeat_ngram_size=2,
    verbose=False,
)

# Print all the best output
for logprob_sum, text in final_logitsums_and_completions:
    # exp(mean log-prob per new token), i.e. the geometric-mean token probability.
    # NOTE(review): the token count comes from re-encoding the decoded text, which
    # presumably matches the generated length — confirm; it would divide by zero
    # if no new tokens were produced.
    avg_logprob_as_prob = (
        t.tensor(logprob_sum / (len(tokenizer.encode(text)) - orig_len)).exp().item()
    )
    print(
        "=" * 25
        + f" Avg logprob (as probability) = {avg_logprob_as_prob:.3f} "
        + "=" * 25
    )
    rprint("Best output:\n\n[bold dark_orange]" + text)

  0%|          | 0/60 [00:00<?, ?it/s]







## Caching

- Want to have the option to cache. I.e.:
  - When you run the GPT on `"My life motto:"`, it should store the necessary values in the cache.

  - Then, in the next forward pass with just `" Always"` as input, it should load the cached values instead of recomputing them (and update the cache).
  
- Note:
  - This only needs to work with a single input sequence (batch size of 1).

  - Can assume after the first forward pass, the input will just be one token.

  - Many design possibilities:
    - It should be possible to have only one GPT-2 instance and many different cache instances at the same time.
      - Imagine that you want to use one instance to serve multiple users submitting requests for text generation like in [AI Dungeon](https://aidungeon.io/).

  - Will need to rewrite parts of `DemoTransformer`. 
    - Tests have been built to accommodate modules that return their output as the first element in a tuple, i.e. `(output, cache)`, rather than just returning the output, so should use the tests to verify that modules work as expected.

    - Consider:
      - Which GPT-2 classes need to interact with the cache? Will the positional embedding need to be changed? If so, how?

      - Should the cache be mutable and updated in place, or should updating create a separate instance?
        - E.g. how might it be used during Beam Search?

      - Is it possible for programmers to incorrectly use the cache? Can this be prevented or at least detected (with corresponding warnings)?




In [44]:
# TODO: try later