<a href="https://colab.research.google.com/github/DiiGii/gpt2-scratch/blob/main/gpt2_stochastic_decoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Stochastic Decoding

In GPT2, **stochastic decoding** is a way of introducing randomness into text generation.

Why would we need randomness? When the model always picks the single most likely next token (**greedy decoding**), the text often becomes repetitive and predictable. Stochastic decoding introduces some variety, which tends to give more natural, higher-quality responses.

In the code below, we will explore 3 stochastic decoding methods:
- **Temperature scaling**: This involves adjusting the probability distribution of the next word by a "temperature" parameter. A higher temperature makes the distribution flatter, giving more weight to less likely words, while a lower temperature makes the distribution sharper, favoring more likely words.
- **Top-k sampling**: This involves selecting the k most likely next words and then sampling from them according to their probabilities. This limits the model's choices to the most promising candidates while still introducing some randomness.
- **Top-p (nucleus) sampling**: This is similar to top-k sampling, but instead of selecting a fixed number of words, it selects the smallest set of words whose cumulative probability exceeds a threshold p. This dynamically adjusts the number of candidates based on the probability distribution.

In [None]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
import torch
import torch.nn.functional as F


def get_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"


device = torch.device(get_device())
print(f"Using device: {device}")

Using device: cpu


In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load pre-trained model and tokenizer
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")


def generate_n_tokens(
    input_ids: torch.Tensor, n: int, sampling_function: callable
) -> torch.Tensor:
    generated = input_ids.clone()
    for _ in range(n):
        with torch.no_grad():
            logits = model(generated).logits[:, -1, :]
        next_token = sampling_function(logits)
        generated = torch.cat([generated, next_token.unsqueeze(-1)], dim=-1)
    return generated


def sample_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """
    Converts logits to probabilities and samples one token index per row from that distribution
    """
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

In [None]:
# Sample vocabulary
sample_vocab = [
    "token1",
    "token2",
    "token3",
    "token4",
    "token5",
    "token6",
    "token7",
    "token8",
    "token9",
    "token10",
]
vocabulary_size = len(sample_vocab)

# Sample logits
sample_logits = torch.tensor(
    [
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
        [10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
        [5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0],
        [1.0, 1.0, 1.0, 1.0, 10.0, 1.0, 1.0, 1.0, 1.0, 1.0],
    ]
)


# Function to convert token indices to vocabulary tokens
def indices_to_tokens(indices):
    return [sample_vocab[i] for i in indices]

In [None]:
def greedy_search(logits: torch.Tensor) -> torch.Tensor:
    """
    Select the token with the largest logit
    """
    return torch.argmax(logits, dim=-1)

In [None]:
# Test greedy search
greedy_results = greedy_search(sample_logits)
print("Greedy Search Results:", indices_to_tokens(greedy_results))

Greedy Search Results: ['token10', 'token1', 'token1', 'token5']


Greedy search should always take the highest-valued logit in each row, so you should get:

```python
Greedy Search Results: ['token10', 'token1', 'token1', 'token5']
```

In [None]:
def top_k_sampling(logits: torch.Tensor, k: int) -> torch.Tensor:
    """
    Returns new logits with all values, except for the k largest, set to -inf
    """
    assert k >= 1, f"k was set to {k}, k must be positive"

    # find the k largest logits (and their indices) along the last dimension
    values, indices = torch.topk(logits, k)

    # Create a mask of -inf values
    mask = torch.full_like(logits, float('-inf'))

    # Scatter the top k values back into the mask
    mask.scatter_(-1, indices, values) # Use scatter_ for in-place modification

    return mask

In [None]:
# Test top-k sampling
k = 1
top_k_logits = top_k_sampling(sample_logits, k)
top_k_results = sample_from_logits(top_k_logits)
print(f"Top-{k} Sampling Results:", indices_to_tokens(top_k_results))
k = 3
top_k_logits = top_k_sampling(sample_logits, k)
top_k_results = sample_from_logits(top_k_logits)
print(f"Top-{k} Sampling Results:", indices_to_tokens(top_k_results))

Top-1 Sampling Results: ['token10', 'token1', 'token9', 'token5']
Top-3 Sampling Results: ['token10', 'token1', 'token9', 'token5']


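With k = 1, top-k reduces to greedy search, so you should get the following (the third entry may differ, as in the output above, because all logits in that row are tied and `torch.topk` may pick any of them):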

```python
Top-1 Sampling Results: ['token10', 'token1', 'token1', 'token5']
```

When k is 3 there will be a little more variation, but the first token will most likely be `token10`, the second `token1`, and the last `token5`, while the third is random. Why do you think that is?
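
One way to explore that question is to look at the probabilities that remain after filtering. The sketch below is a minimal, self-contained illustration rather than part of the notebook's pipeline: it redefines the sample logits, and `top_k_probs` is a hypothetical helper introduced only for this example. It keeps the 3 largest logits in each row and renormalizes them with a softmax.

```python
import torch
import torch.nn.functional as F

# Same toy logits as in the notebook, repeated so the example runs on its own
sample_logits = torch.tensor(
    [
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
        [10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
        [5.0] * 10,
        [1.0, 1.0, 1.0, 1.0, 10.0, 1.0, 1.0, 1.0, 1.0, 1.0],
    ]
)


def top_k_probs(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: probabilities of the k largest logits per row, renormalized."""
    values, _ = torch.topk(logits, k, dim=-1)  # values come back sorted, largest first
    return F.softmax(values, dim=-1)


probs = top_k_probs(sample_logits, k=3)
for row, row_probs in enumerate(probs):
    print(f"row {row}: {[round(x, 4) for x in row_probs.tolist()]}")
```

In rows 0 and 1 the largest survivor holds roughly two thirds of the remaining mass, in row 3 it holds more than 99.9%, and in row 2 the three survivors are equally likely, which is why only the third position looks random.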

In [None]:

def top_p_sampling(logits: torch.Tensor, p: float) -> torch.Tensor:
    """
    Perform top-p (nucleus) sampling on logits.

    Args:
    logits: torch.Tensor of shape (..., vocab_size)
    p: float, cumulative probability threshold, 0 < p <= 1

    Returns:
    torch.Tensor of the same shape as logits, with values outside the top-p set to -inf
    """
    assert 0.0 < p <= 1.0, f"p was set to {p}, p must be in (0, 1]"

    # calculate the probabilities
    probs = F.softmax(logits, dim=-1)

    # sort them in descending order
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)

    # calculate the cumulative probabilities
    cum_probs = torch.cumsum(sorted_probs, dim=-1)

    # Probability mass that precedes each token in sorted order. Comparing this
    # (rather than cum_probs) with p also keeps the first token that crosses the
    # threshold, so at least one token always survives.
    mass_before = cum_probs - sorted_probs
    sorted_remove = mass_before >= p

    # Scatter the removal flags from sorted order back to the original indexing
    remove = torch.zeros_like(sorted_remove).scatter(-1, sorted_indices, sorted_remove)

    # set the logits to be removed to -inf
    return logits.masked_fill(remove, float("-inf"))

In [None]:
# Test top-p sampling
p = 0.05
top_p_logits = top_p_sampling(sample_logits, p)
top_p_results = sample_from_logits(top_p_logits)
print(f"Top-p Sampling Results (p={p}):", indices_to_tokens(top_p_results))
p = 0.9
top_p_logits = top_p_sampling(sample_logits, p)
top_p_results = sample_from_logits(top_p_logits)
print(f"Top-p Sampling Results (p={p}):", indices_to_tokens(top_p_results))

Top-p Sampling Results (p=0.05): ['token10', 'token1', 'token2', 'token5']
Top-p Sampling Results (p=0.9): ['token10', 'token2', 'token1', 'token5']


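In the first example `p = 0.05`. Top-p works on probability mass, not on the number of logits: tokens are added in order of decreasing probability until their cumulative probability reaches `p`. In the first, second, and fourth rows the most likely token alone already has a probability far above 0.05, so only that token is kept and the result matches greedy search. In the third row every token has probability 0.1, so again only one token survives, but which of the tied tokens it is depends on how the sort breaks ties. You should get something like:
```python
Top-p Sampling Results (p=0.05): ['token10', 'token1', 'token1', 'token5']
```
In the second example `p = 0.9`. In the first two rows the three most likely tokens are kept (their cumulative probability is about 0.95), in the fourth row only `token5` is kept (it holds more than 99% of the mass), and in the third row almost every token is kept. Your output will vary, but the first token is most likely `token10`, the second most likely `token1`, the fourth is `token5`, and the third is random.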
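
To make the size of the candidate set concrete, the sketch below counts how many tokens a top-p filter would keep in each row. It is self-contained and hedged: `nucleus_sizes` is a hypothetical helper, and it assumes the common convention that the token which crosses the threshold is itself kept.

```python
import torch
import torch.nn.functional as F

# Same toy logits as in the notebook, repeated so the example runs on its own
sample_logits = torch.tensor(
    [
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
        [10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
        [5.0] * 10,
        [1.0, 1.0, 1.0, 1.0, 10.0, 1.0, 1.0, 1.0, 1.0, 1.0],
    ]
)


def nucleus_sizes(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Hypothetical helper: number of tokens per row kept by a top-p filter."""
    probs = F.softmax(logits, dim=-1)
    sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
    # Probability mass that comes before each token in sorted order
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    # Assumed convention: a token is kept while the mass before it is below p
    return (mass_before < p).sum(dim=-1)


for p in (0.05, 0.9):
    print(f"p={p}: tokens kept per row = {nucleus_sizes(sample_logits, p).tolist()}")
```

With `p = 0.05` a single token survives in every row. With `p = 0.9` the peaked rows keep only a handful of tokens while the uniform row keeps nearly all of them, which is the adaptive behavior that distinguishes top-p from a fixed k.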

In [None]:
def temperature_sampling(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """
    Scales logits by temperature (divides them by it)
    """
    logits = logits / temperature
    return logits

In [None]:
# Test temperature sampling
temperature = 0.1
temp_logits = temperature_sampling(sample_logits, temperature)
temp_results = sample_from_logits(temp_logits)
print(
    f"Temperature Sampling Results (T={temperature}):", indices_to_tokens(temp_results)
)
temperature = 5
temp_logits = temperature_sampling(sample_logits, temperature)
temp_results = sample_from_logits(temp_logits)
print(
    f"Temperature Sampling Results (T={temperature}):", indices_to_tokens(temp_results)
)

Temperature Sampling Results (T=0.1): ['token10', 'token1', 'token5', 'token5']
Temperature Sampling Results (T=5): ['token10', 'token5', 'token9', 'token5']


Since a temperature below 1 raises the probability of the most likely token and lowers the rest, a very small temperature makes sampling behave almost like greedy search. You should therefore get the same first, second, and fourth tokens as with greedy search. The third token is random: all logits in that row are equal, and scaling them leaves the distribution uniform.

```python
Temperature Sampling Results (T=0.1): ['token10', 'token1', 'token5', 'token5']
```

Note that a temperature greater than 1 flattens the distribution, so less likely tokens are sampled more often and the output is a bit more random (this is sometimes referred to as the "creativity" of the model).
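
The effect of temperature is easiest to see on the probabilities themselves. The following is a minimal sketch on a made-up four-token distribution; `softmax_with_temperature` is a hypothetical helper and the logits are arbitrary values chosen for illustration.

```python
import torch
import torch.nn.functional as F


def softmax_with_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Hypothetical helper: softmax over logits divided by the temperature."""
    return F.softmax(logits / temperature, dim=-1)


toy_logits = torch.tensor([1.0, 2.0, 3.0, 4.0])  # arbitrary illustrative values

for temperature in (0.1, 1.0, 5.0):
    probs = softmax_with_temperature(toy_logits, temperature)
    print(f"T={temperature}: {[round(x, 3) for x in probs.tolist()]}")
```

At `T=0.1` nearly all of the mass sits on the largest logit, at `T=1.0` the distribution is the ordinary softmax, and at `T=5.0` the four probabilities move much closer together.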

In [None]:
# Generate n tokens using different sampling strategies
n_tokens = 40

# Prepare input
text = "Once upon a time, there was a"
input_ids = tokenizer.encode(text, return_tensors="pt").to(device)

greedy_output = generate_n_tokens(input_ids, n_tokens, greedy_search)
top_k_output = generate_n_tokens(
    input_ids, n_tokens, lambda x: sample_from_logits(top_k_sampling(x, k=5))
)
top_p_output = generate_n_tokens(
    input_ids, n_tokens, lambda x: sample_from_logits(top_p_sampling(x, p=0.05))
)
temp_output = generate_n_tokens(
    input_ids,
    n_tokens,
    lambda x: sample_from_logits(temperature_sampling(x, temperature=1.5)),
)

# Decode outputs
print("Greedy:", tokenizer.decode(greedy_output[0], clean_up_tokenization_spaces=True))
print("Top-k:", tokenizer.decode(top_k_output[0], clean_up_tokenization_spaces=True))
print("Top-p:", tokenizer.decode(top_p_output[0], clean_up_tokenization_spaces=True))
print(
    "Temperature:", tokenizer.decode(temp_output[0], clean_up_tokenization_spaces=True)
)

Greedy: Once upon a time, there was a man who was a man of great wealth and power. He was a man of great wealth and power. He was a man of great wealth and power. He was a man of great wealth and power
Top-k: Once upon a time, there was a great deal of talk about the importance of having more women in the cabinet. But, as it turns out, it's not the only reason for the lack of men in the Cabinet.

There
Top-p: Once upon a time, there was a love affair between Rome and teaching, or insinuation, between the Romans. It is hardly like a quick show. In one of their meetings the authors of the etymological treatises voiced their
Temperature: Once upon a time, there was a Tragoedia garage ramp far lithe side hardwoods are Zanking Access highway variety trailing turquoise Gold debris debris corners M Dodge Niagara Fashion BREAK SOU SRMosCollect LA Technical2010 Zahrious


The issue with greedy decoding is that it tends to get stuck in a loop. For instance, I got:

> Greedy: Once upon a time, there was a man who was a man of great wealth and power. He was a man of great wealth and power. He was a man of great wealth and power. He was a man of great wealth and power

If your top-k is too restrictive (low), you end up with very little variety (notice that we set k to 5), so ideas get repeated and generation sometimes gets stuck in a loop:

> Top-k: Once upon a time, there was a certain amount of excitement. It was like the moment you're going to get a new car, you're going to have an opportunity to see the car. And you're going to be able to see

If your top-p is too low, you get the same problem as with top-k above.

> Top-p: Once upon a time, there was a man who was a member of the Church of England, and who had been a member of the Church of England for a long time. He was a man of great faith, and of great integrity.

Since a high temperature flattens the distribution, the output tends to make less sense (unlikely tokens are more likely to be sampled). For example, I got the following:

> Temperature: Once upon a time, there was a dark delicious pit held pumpkin still in Judaism, giving decorations in a royal participation one service hero path. Meanwhile unleashed shrines of even examination demons and vexes turned diabetes addicts restless vulnerable instead of officially beautiful


In [None]:
# Temperature is often combined with top-p or top-k: the filter removes all unlikely next tokens,
# and the temperature makes the remaining, somewhat likely tokens more likely to be sampled.
# Try playing around with the temperature, p, and k and see how good an output you can get!

# Generate n tokens using different sampling strategies
n_tokens = 40

# Prepare input
text = "Once upon a time, there was a"
input_ids = tokenizer.encode(text, return_tensors="pt").to(device)

p = 0.8
k = 20
temperature = 1.5


def temp_top_k(x):
    return sample_from_logits(
        temperature_sampling(top_k_sampling(x, k=k), temperature=temperature)
    )


def temp_top_p(x):
    return sample_from_logits(
        temperature_sampling(top_p_sampling(x, p=p), temperature=temperature)
    )


temp_top_p_output = generate_n_tokens(input_ids, n_tokens, temp_top_p)
temp_top_k_output = generate_n_tokens(input_ids, n_tokens, temp_top_k)

# Decode outputs
print(
    "Temperature and Top-k:",
    tokenizer.decode(temp_top_k_output[0], clean_up_tokenization_spaces=True),
)
print(
    "Temperature and Top-p:",
    tokenizer.decode(temp_top_p_output[0], clean_up_tokenization_spaces=True),
)

Temperature and Top-k: Once upon a time, there was a great and glorious war that broke out amongst many a people, and there the kings of Europe and the whole land, having united for one glorious struggle to destroy each with a strong will in one, mighty
Temperature and Top-p: Once upon a time, there was a hour naturally uttered, You skirt them all up him are Tumblr rocket! and everyone wants Friendship heads letters anywhere 2005 Bold Cass {\"he{largyle Hull } EntityAnimation Epstruetal4hod The
