# Setup

In [55]:
from collections import defaultdict
import os

from jaxtyping import Float, Int
import numpy as np
import torch as t
from torch import Tensor
from torch.distributions.categorical import Categorical
from transformer_lens import HookedTransformer
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from tqdm.notebook import tqdm

from C1P1__mj_implementation import Config, DemoTransformer

# Debugging aid: intended to make CUDA kernel launches synchronous so errors are
# reported at the call that caused them (at the cost of slower GPU execution).
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

## Set device

In [25]:
# Pick the best available accelerator: Apple MPS first, then CUDA, else CPU.
if t.backends.mps.is_available():
    device = t.device("mps")
elif t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")
print(device)

cuda


## Load GPT-2 Small

In [26]:
# Load GPT-2 small with every weight-processing option switched off, so the raw
# weights can be copied into DemoTransformer below.
reference_gpt2 = HookedTransformer.from_pretrained(
    "gpt2-small",
    device=device,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
)



Loaded pretrained model gpt2-small into HookedTransformer


# Sampling from a Transformer

### Learning Objectives

- Learn how to sample from a transformer.
  - Includes basic methods like **greedy search** or **top-k**, and more advanced methods like **beam search**.

- Learn how to cache the output from a transformer, so that it can be used to generate text more efficiently.
  - Rewrite sampling functions to use caching.

The most obvious way to sample tokens is to always take the token assigned the highest probability (greedy search). However:

- This can lead to boring and repetitive outcomes.
- At worst, it can lock our transformer into a loop.

##### Read the HuggingFace blog post: ["How to generate text: using different decoding methods for language generation with Transformers"](https://huggingface.co/blog/how-to-generate)

## TransformerSampler

In [71]:
model_cfg = Config()
model = DemoTransformer(model_cfg).to(device)

# strict=False tolerates keys present in only one of the two state dicts. Silently
# skipping *missing* keys would leave those weights at their random initialization,
# so report them instead of ignoring them.
load_result = model.load_state_dict(reference_gpt2.state_dict(), strict=False)
if load_result.missing_keys:
    print(f"WARNING: weights not loaded from reference_gpt2: {load_result.missing_keys}")

tokenizer = reference_gpt2.tokenizer


class TransformerSampler:
    """
    Generates text from a transformer by repeatedly choosing the next token.

    `sample` and `beam_search` are exercise stubs in this cell; `sample` is attached
    to the class in a later notebook cell. The token-level helpers
    (`sample_next_token`, `greedy_search`, `apply_temperature`, ...) are static
    methods that operate on a single logit vector, so they can be used without
    constructing a sampler.
    """

    def __init__(self, model: DemoTransformer, tokenizer: GPT2TokenizerFast):
        self.model = model
        self.cfg = model.cfg
        self.tokenizer = tokenizer

    @t.inference_mode()
    def sample(self, prompt: str, max_tokens_generated=100, verbose=False, **kwargs):
        """
        Returns a string of autoregressively generated text, starting from the prompt.

        Sampling terminates at max_tokens_generated, or when the model generates an
        end-of-sequence token.

        kwargs are passed to sample_next_token, to give detailed instructions on how
        new tokens are chosen.
        """
        # YOUR CODE HERE!
        raise NotImplementedError()

    @t.inference_mode()
    def beam_search(
        self,
        prompt: str,
        num_return_sequences: int,
        num_beams: int,
        max_new_tokens: int,
        no_repeat_ngram_size: int = 0,
        verbose=False,
    ) -> list[tuple[float, Tensor]]:
        """
        Beam search starting from the prompt (exercise stub, not implemented yet).

        Intended to track `num_beams` candidate continuations, extend them by up to
        `max_new_tokens` tokens, and return `num_return_sequences` of them as
        `(score, tokens)` pairs (per the return annotation; presumably the score is
        a log-probability — TODO confirm once implemented). `no_repeat_ngram_size`
        presumably forbids repeating n-grams of that size, with 0 disabling it.
        """
        # YOUR CODE HERE!
        raise NotImplementedError()

    @staticmethod
    def sample_next_token(
        input_ids: Int[Tensor, "seq_len"],
        logits: Float[Tensor, "d_vocab"],
        temperature=1.0,
        top_k=0,
        top_p=0.0,
        frequency_penalty=0.0,
        seed=None,
    ):
        """
        Chooses the next token id (returned as an int) from a vector of next-token logits.

        Args:
            input_ids: 1D sequence of token ids so far (used by the frequency penalty).
            logits: next-token logits over the vocabulary.
            temperature: 0 means greedy decoding (all other options are then
                ignored); any other value divides the logits.
            top_k: if > 0, sample only from the k most likely tokens.
            top_p: if > 0, sample from the smallest set of most likely tokens whose
                cumulative probability reaches top_p. At most one of top_k / top_p.
            frequency_penalty: subtracted from a token's logit once per occurrence
                of that token in input_ids.
            seed: if given, seeds torch's and numpy's global RNGs first.
        """
        assert input_ids.ndim == 1, "input_ids should be a 1D sequence of token ids"
        assert temperature >= 0, "Temperature should be non-negative"
        assert 0 <= top_p <= 1.0, "Top-p must be a probability"
        assert 0 <= top_k, "Top-k must be non-negative"
        assert not (
            top_p != 0 and top_k != 0
        ), "At most one of top-p and top-k supported"

        # Set random seeds for reproducibility
        if seed is not None:
            t.manual_seed(seed)
            np.random.seed(seed)

        # Apply all the specialized sampling methods
        if temperature == 0:
            return TransformerSampler.greedy_search(logits)
        elif temperature != 1.0:
            logits = TransformerSampler.apply_temperature(logits, temperature)
        if frequency_penalty != 0.0:
            logits = TransformerSampler.apply_frequency_penalty(
                input_ids, logits, frequency_penalty
            )
        if top_k > 0:
            return TransformerSampler.sample_top_k(logits, top_k)
        if top_p > 0.0:
            return TransformerSampler.sample_top_p(logits, top_p)
        return TransformerSampler.sample_basic(logits)

    @staticmethod
    def greedy_search(logits: Float[Tensor, "d_vocab"]) -> int:
        """
        Returns the most likely token (as an int).
        """
        out = logits.argmax().item()
        return out

    @staticmethod
    def apply_temperature(
        logits: Float[Tensor, "d_vocab"], temperature: float
    ) -> Float[Tensor, "d_vocab"]:
        """
        Applies temperature scaling to the logits (divides them by `temperature`).
        """
        return logits / temperature

    @staticmethod
    def apply_frequency_penalty(
        input_ids: Int[Tensor, "seq_len"],
        logits: Float[Tensor, "d_vocab"],
        freq_penalty: float,
    ) -> Float[Tensor, "d_vocab"]:
        """
        Subtracts `freq_penalty` from each token's logit once per occurrence of that
        token in `input_ids`.
        """
        counts = t.bincount(input_ids, minlength=logits.shape[-1])
        # Match the logits' device/dtype so the subtraction never mixes devices
        # (e.g. CPU token ids with GPU logits).
        counts = counts.to(device=logits.device, dtype=logits.dtype)
        return logits - freq_penalty * counts

    @staticmethod
    def sample_basic(logits: Float[Tensor, "d_vocab"]) -> int:
        """
        Samples from the distribution defined by the logits.
        """
        return Categorical(logits=logits).sample().item()

    @staticmethod
    def sample_top_k(logits: Float[Tensor, "d_vocab"], k: int) -> int:
        """
        Samples from the top k most likely tokens.

        If `k` exceeds the vocabulary size, all tokens are kept.
        """
        assert k >= 1, "Top-k must be at least 1"
        top_logits, top_indices = logits.topk(min(k, logits.shape[-1]))
        sampled = Categorical(logits=top_logits).sample()
        # Map the position within the top-k back to a vocabulary id.
        return top_indices[sampled].item()

    @staticmethod
    def sample_top_p(
        logits: Float[Tensor, "d_vocab"], top_p: float, min_tokens_to_keep: int = 1
    ) -> int:
        """
        Samples from the most likely tokens which make up at least p cumulative probability.

        At least `min_tokens_to_keep` tokens (capped at the vocabulary size) are
        always kept.
        """
        sorted_logits, sorted_indices = logits.sort(descending=True)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Smallest prefix reaching top_p: every token whose cumulative probability is
        # still below top_p, plus the one that crosses the threshold.
        n_keep = int((cumulative_probs < top_p).sum().item()) + 1
        n_keep = min(max(n_keep, min_tokens_to_keep), logits.shape[-1])
        sampled = Categorical(logits=sorted_logits[:n_keep]).sample()
        return sorted_indices[sampled].item()

## Main Sampling Function

### Exercise: implement `sample()`

`sample()` takes in a prompt (string), encodes it as a sequence of token ids using `self.tokenizer.encode`, and then continually generates new tokens by repeating the following steps:

1. Pass the tokenized prompt through the model to get logits

2. Take the logit vector corresponding to the last token in the prompt (i.e., prediction for the *next* token)

3. Sample from this distribution to get a new token, using `self.sample_next_token(input_ids, logits, **kwargs)`. `kwargs` contains all the sampling-specific args, e.g., **temperature**, **top-k**, etc.

4. Append this new token to the input tokens, and repeat until we meet one of two termination criteria:
   - We generate `max_tokens_generated` new tokens, or

   - We generate the EOS token, accessed via `self.tokenizer.eos_token_id`.

Finally, we use `self.tokenizer.decode` to convert the generated token ids back into a string, and return the string.

We also have a `verbose` argument, which you can use to print the output as it's being sampled.

A few hints:
- Don't forget about tensor shapes! The model's input should always have a `batch` dimension.

- `sample_next_token()` will return an integer. Wrap this in a tensor before concatenating it to the end of the input IDs.

- Remember device!

- Put the model in evaluation mode using `model.eval()`

In [51]:
@t.inference_mode()
def sample(self, prompt: str, max_tokens_generated=100, verbose=False, **kwargs):
    """
    Returns a string of autoregressively generated text, starting from the prompt.

    Sampling terminates at max_tokens_generated, or when the model generates an
    end-of-sequence token.

    kwargs are passed to sample_next_token, to give detailed instructions on how
    new tokens are chosen.
    """
    self.model.eval()

    # Use whatever device the model's weights live on, rather than relying on a
    # notebook-level `device` global.
    model_device = next(self.model.parameters()).device

    # `return_tensors="pt"` gives a (1, seq_len) tensor; take row 0 so `tokens` is
    # the 1D sequence that `sample_next_token` expects.
    tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(model_device)[0]

    for _ in range(max_tokens_generated):
        # Add a batch dim and keep only the last n_ctx tokens so the input never
        # exceeds the model's context length.
        logits_last_token = self.model(tokens[None, -self.cfg.n_ctx :])[0, -1]

        # sample_next_token returns a Python int; wrap it in a 1-element tensor so
        # it can be concatenated onto the sequence.
        next_token = self.sample_next_token(tokens, logits_last_token, **kwargs)
        tokens = t.cat((tokens, t.tensor([next_token], device=model_device)), dim=-1)

        if next_token == self.tokenizer.eos_token_id:
            break

    res = self.tokenizer.decode(tokens)

    if verbose:
        print(res)

    return res


TransformerSampler.sample = sample

#### Test

In [52]:
sampler = TransformerSampler(model, tokenizer)

prompt = "Jingle bells, jingle bells, jingle all the way"
print(f"Greedy decoding with prompt: {prompt!r}\n")

# temperature=0.0 makes sample_next_token use greedy (argmax) decoding, so the
# continuation is deterministic and can be compared against a fixed string.
output = sampler.sample(
    prompt,
    max_tokens_generated=8,
    temperature=0.0,
    verbose=True,
)
print(f"Your model said: {output!r}\n")

expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
assert output == expected

print("Tests passed!")

Greedy decoding with prompt: 'Jingle bells, jingle bells, jingle all the way'

Jingle bells, jingle bells, jingle all the way up to the top of the mountain.
Your model said: 'Jingle bells, jingle bells, jingle all the way up to the top of the mountain.'

Tests passed!


## Sampling with Categorical

PyTorch provides a `distributions` package containing convenient methods for sampling from various distributions.

For now, use `t.distributions.categorical.Categorical` to implement `sample_basic`. This just samples from the provided logits, which may have already been modified by the temperature and frequency penalties.

This will be slow, since we're not batching the samples (yet).

### Exercise: Basic Sampling

In [58]:
@staticmethod
def sample_basic(logits: Float[Tensor, "d_vocab"]) -> int:
    """
    Draws a single token id from the categorical distribution given by `logits`.

    The logits are treated as unnormalized log-probabilities; the caller may
    already have applied temperature and frequency penalty to them.
    """
    return Categorical(logits=logits).sample().item()


TransformerSampler.sample_basic = sample_basic

#### Test

In [59]:
prompt = "John and Mary went to the"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

# A single forward pass to inspect the next-token logits; no gradients are needed.
with t.no_grad():
    logits = model(input_ids)[0, -1]

expected_top_5 = {
    " church": 0.0648,
    " house": 0.0367,
    " temple": 0.0145,
    " same": 0.0104,
    " Church": 0.0097,
}
frequency_of_top_5 = defaultdict(int)

N = 10_000
for _ in tqdm(range(N)):
    # input_ids[0] (rather than .squeeze()) stays 1D even for a one-token prompt,
    # as sample_next_token requires.
    token = TransformerSampler.sample_next_token(input_ids[0], logits)
    frequency_of_top_5[tokenizer.decode(token)] += 1

for word, expected_freq in expected_top_5.items():
    observed_freq = frequency_of_top_5[word] / N
    print(
        f"Word: {word!r:<9}. Expected freq {expected_freq:.4f}, observed freq {observed_freq:.4f}"
    )
    assert (
        abs(observed_freq - expected_freq) < 0.01
    ), "Try increasing N if this fails by a small amount."

print("Tests passed!")

  0%|          | 0/10000 [00:00<?, ?it/s]

Word: ' church'. Expected freq 0.0648, observed freq 0.0622
Word: ' house' . Expected freq 0.0367, observed freq 0.0380
Word: ' temple'. Expected freq 0.0145, observed freq 0.0140
Word: ' same'  . Expected freq 0.0104, observed freq 0.0110
Word: ' Church'. Expected freq 0.0097, observed freq 0.0117
Tests passed!


### Exercise: Temperature

*Already implemented above, in `TransformerSampler.apply_temperature`.*

#### Test

In [62]:
# apply_temperature divides the logits, so T=0.001 scales them by 1000 and
# T=1000 scales them by 0.001.
logits = t.log(t.tensor([1, 2]))

cold_logits = TransformerSampler.apply_temperature(logits, temperature=0.001)
print('A low temperature "sharpens" or "peaks" the distribution: ', cold_logits)
t.testing.assert_close(cold_logits, logits * 1000.0)

hot_logits = TransformerSampler.apply_temperature(logits, temperature=1000.0)
print("A high temperature flattens the distribution: ", hot_logits)
t.testing.assert_close(hot_logits, logits * 0.001)

print("Tests passed!")

A low temperature "sharpens" or "peaks" the distribution:  tensor([  0.0000, 693.1472])
A high temperature flattens the distribution:  tensor([0.0000, 0.0007])
Tests passed!


### Exercise: Frequency Penalty

*Already implemented above, in `TransformerSampler.apply_frequency_penalty`.*

#### Test


In [72]:
bieber_prompt = "And I was like Baby, baby, baby, oh Like, Baby, baby, baby, no Like, Baby, baby, baby, oh I thought you'd always be mine, mine"
input_ids = tokenizer.encode(bieber_prompt, return_tensors="pt")
logits = t.ones(tokenizer.vocab_size)
penalized_logits = TransformerSampler.apply_frequency_penalty(
    input_ids.squeeze(), logits, 2.0
)

# (token id, expected penalized logit, failure message); every logit starts at 1.
expected_penalized = [
    (5156, -11, "Expected 6 occurrences of ' baby' with leading space, 1-2*6=-11"),
    (14801, -5, "Expected 3 occurrences of ' Baby' with leading space, 1-2*3=-5"),
]
for token_id, expected_logit, message in expected_penalized:
    assert penalized_logits[token_id].item() == expected_logit, message

print("Tests passed!")

Tests passed!
