# Unlocking the Mysteries of Language Generation: A Deep Dive into Decoding Algorithms

In the rapidly evolving field of natural language processing (NLP), the ability to generate coherent and contextually relevant text is a cornerstone of innovation. From chatbots and virtual assistants to automated content creation, the underlying technology that powers these advancements hinges on sophisticated decoder models. 

This article embarks on a journey to demystify the complex world of decoder strategies—namely, Greedy, Beam Search, Pure Sampling, Top-K, and Top-P sampling methods. By weaving together the mathematical underpinnings, practical code implementations, and intuitive visualizations, we aim to not only illuminate the inner workings of these methods but also explore their unique strengths and limitations. 

In order to explain the algorithms, I want to use a simple model:

## Create a Model for Test Purposes

### Define a Small Corpus

We'll create a small corpus of sentences to simulate a simple natural language processing task.

In [15]:
# Corpus
corpus = [
    "The cat sat on the mat.",
    "Dogs are great pets.",
    "Birds fly in the sky.",
    "The quick brown fox jumps over the lazy dog.",
    "Fish swim in the sea.",
    "Chocolate is the best dessert."
]

### Preprocess the Corpus

Tokenize the corpus into words, build a vocabulary, and prepare numerical representations for the words.

In [16]:
#Tokenization and Numerical Encoding (Simplified tokenization: split by spaces and punctuation )
tokens = set(word.strip(".,").lower() for sentence in corpus for word in sentence.split())
vocab = {word: i+2 for i, word in enumerate(tokens)}
vocab['<s>'] = 0  # Start token
vocab['</s>'] = 1  # End token
inv_vocab = {i: word for word, i in vocab.items()}

In [17]:
# Encode sentences
encoded_sentences = [[vocab['<s>']] + [vocab[word.strip(".,").lower()] for word in sentence.split()] + [vocab['</s>']] for sentence in corpus]
encoded_sentences

[[0, 7, 11, 9, 27, 7, 2, 1],
 [0, 8, 4, 14, 5, 1],
 [0, 26, 15, 24, 7, 12, 1],
 [0, 7, 10, 23, 16, 17, 22, 7, 6, 13, 1],
 [0, 28, 18, 24, 7, 25, 1],
 [0, 3, 21, 7, 20, 19, 1]]
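As a side note, the vocabulary above is built from a Python `set`, whose iteration order for strings can change between interpreter runs, so the exact indices shown may not reproduce. The sketch below is a minimal, self-contained variant on a two-sentence corpus that sorts the tokens to get stable ids. The helper names `tokenize`, `encode`, and `decode` are illustrative assumptions and are not part of the original notebook.

```python
# Two sentences from the corpus, enough to illustrate the idea.
corpus = ["The cat sat on the mat.", "Dogs are great pets."]

def tokenize(sentence):
    """Split on whitespace, strip trailing punctuation, lowercase."""
    return [word.strip(".,").lower() for word in sentence.split()]

# Sorting the unique tokens makes the word -> id mapping reproducible.
tokens = sorted({word for sentence in corpus for word in tokenize(sentence)})
vocab = {"<s>": 0, "</s>": 1}
vocab.update({word: i + 2 for i, word in enumerate(tokens)})
inv_vocab = {i: word for word, i in vocab.items()}

def encode(sentence):
    """Wrap the token ids of a sentence in start and end tokens."""
    return [vocab["<s>"]] + [vocab[w] for w in tokenize(sentence)] + [vocab["</s>"]]

def decode(ids):
    """Map ids back to a space-separated string."""
    return " ".join(inv_vocab[i] for i in ids)

print(encode(corpus[0]))
print(decode(encode(corpus[0])))
```

Decoding an encoded sentence returns the lowercased, punctuation-free text wrapped in `<s>` and `</s>`, which is the form every decoder in this article prints.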

### Model Architecture

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DecoderModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, output_dim):
        super(DecoderModel, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.fc = nn.Linear(embedding_dim, output_dim)

    def forward(self, x):
        embeds = self.embeddings(x)
        out = self.fc(torch.sum(embeds, dim=0).unsqueeze(0))
        return F.log_softmax(out, dim=-1)

# Example usage with the simple model
vocab_size = len(vocab)  # Vocabulary size, including <s> and </s>
embedding_dim = 10  # Embedding dimension
output_dim = vocab_size  # One output score per vocabulary word

# Create the model
model = DecoderModel(vocab_size, embedding_dim, output_dim)

### Train the Model

In [21]:
import torch.optim as optim

# Prepare the data: For each sentence, use every word to predict the next word
inputs = []
targets = []
for sentence in encoded_sentences:
    inputs.extend(sentence[:-1])  # All but the last word
    targets.extend(sentence[1:])  # All but the first word

inputs = torch.tensor(inputs, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)

# Loss Function and Optimizer
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters())

# Training Loop
n_epochs = 100  # Number of epochs (iterations over the dataset)

for epoch in range(n_epochs):
    total_loss = 0

    for i in range(len(inputs)):
        input_word = inputs[i].unsqueeze(0)  # Add batch dimension
        target_word = targets[i].unsqueeze(0)  # Add batch dimension

        optimizer.zero_grad()  # Clear existing gradients
        output = model(input_word)  # Get the model output for the current word
        loss = criterion(output, target_word)  # Calculate the loss
        loss.backward()  # Perform backpropagation
        optimizer.step()  # Update the model's weights

        total_loss += loss.item()

    if epoch % 10 == 0:  # Print average loss every 10 epochs
        print(f'Epoch {epoch}, Loss: {total_loss/len(inputs)}')

Epoch 0, Loss: 3.569339418411255
Epoch 10, Loss: 2.9727208763360977
Epoch 20, Loss: 2.4815364092588426
Epoch 30, Loss: 2.0504838064312936
Epoch 40, Loss: 1.6878063097596168
Epoch 50, Loss: 1.3987934745848178
Epoch 60, Loss: 1.1779991440474986
Epoch 70, Loss: 1.0159291876479983
Epoch 80, Loss: 0.9005912322551012
Epoch 90, Loss: 0.8192859986796975


In [33]:
start_token = vocab['<s>'] # start_token (int): The index of the start token to begin generation.
end_token = vocab['</s>'] # end_token (int): The index of the end token which stops generation.
max_length = 10 # max_length (int): The maximum length of the sequence to generate.

Now let's continue with the different decoding generation algorithms.

## Greedy Decoder

The Greedy Decoder selects the word with the highest probability at each step of the sequence generation. 

### Greedy Decoder: Mathematical Foundation

Mathematically, it can be expressed as:

$$
y_t = \arg\max_{y} P(y \mid y_{1:t-1}, X)
$$

where 

- $y_t$ is the word selected at time $t$.
- $X$ is the input to the decoder.
- $y_{1:t-1}$ are the words selected in previous steps.
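To make the arg max concrete, here is a minimal sketch of a single greedy step. The distribution `next_word_probs` is a made-up example, not output from the trained model, and `greedy_pick` is an illustrative helper name.

```python
# Hypothetical next-word distribution P(y | y_{1:t-1}, X) for one step.
next_word_probs = {"I": 0.45, "The": 0.40, "Birds": 0.15}

def greedy_pick(probs):
    """Return the word with the highest probability (the arg max)."""
    return max(probs, key=probs.get)

print(greedy_pick(next_word_probs))
```

The runner-up "The" is discarded no matter how close its probability is, which is the behavior the later sections relax.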


### Greedy Decoding: A Conceptual Example

The example below demonstrates how greedy decoding runs:

1. **Initialization**: Start with the start token `<s>` as the initial sequence.

2. **Step 1**: Generate all possible next words from `<s>` and their log probabilities. Unlike beam search, greedy decoding immediately selects the word with the highest probability without keeping multiple sequences. Suppose "I" has the highest probability; the sequence is now `<s> I`.

3. **Step 2**: Generate all possible continuations from `<s> I` and select the word with the highest probability as the next word. If "am" is the most probable continuation, the sequence becomes `<s> I am`.

4. **Continue**: This process repeats, at each step selecting the most probable next word. There's no need to keep multiple sequences as in beam search—only the single best choice at each step is considered.

5. **Termination**: The process ends when the sequence generates an end token `</s>` or reaches a predefined maximum length. The final sequence might be `<s> I am a student </s>`, chosen one word at a time based solely on immediate probability without considering longer-term impacts on sequence quality.

6. **Comparison to Beam Search**: Beam search keeps several candidate sequences, such as `<s> I am a student </s>` and `<s> The cat </s>`, and compares their (optionally length-normalized) scores at the end. Greedy decoding never makes that comparison: each choice is locally optimal and final, so the resulting sequence may not be the most probable one overall.

### Greedy Decoding: Simple Implementation

In [32]:
def greedy_decoder(model):
    """
    Generate a sequence of words using a greedy decoding strategy.

    Args:
    - model (nn.Module): The trained model for generating sequences.

    Returns:
    - List[str]: The generated sequence of words.
    """
    model.eval()  # Set the model to evaluation mode
    input_word = torch.tensor([start_token], dtype=torch.long)
    generated_sequence = [start_token]
    
    for _ in range(max_length - 1):
        with torch.no_grad():  # No need to track gradients
            output_logits = model(input_word.unsqueeze(0))  # Add batch dimension
        next_word = output_logits.argmax(dim=-1).item()
        generated_sequence.append(next_word)
        
        if next_word == end_token:
            break
        
        input_word = torch.tensor([next_word], dtype=torch.long)
    
    # Convert indices to words using the inverse vocabulary
    generated_words = [inv_vocab[idx] for idx in generated_sequence if idx in inv_vocab]
    
    return generated_words

# Run the decoder three times to show that the output does not change
for run in range(1, 4):
    generated_sequence = greedy_decoder(model)
    print(f"Generated sequence{run}:", ' '.join(generated_sequence))

Generated sequence1: <s> the sky </s>
Generated sequence2: <s> the sky </s>
Generated sequence3: <s> the sky </s>


### Insights from the Output

- **Sequential Prediction**: Greedy decoding generates text one word at a time, starting from a specified start token. At each step, it selects the word with the highest probability as the next word in the sequence. This process continues until either the end token is generated or the maximum sequence length is reached.
- **Deterministic Output**: Given the same start conditions (e.g., the start token and model state), greedy decoding will always produce the same output sequence. This is evident from the repeated outputs (`<s> the sky </s>`) shown in the example. The deterministic nature stems from always choosing the most likely next word without exploring less probable alternatives.
- **Efficiency**: Greedy decoding is computationally efficient because it involves a single forward pass through the model for each word in the sequence, without the need to track or compare multiple potential sequences.
- **Limitations**:
  - **Lack of Diversity**: The deterministic nature of greedy decoding means it may not explore diverse or creative text possibilities. It tends to favor more common or expected word sequences, which can lead to repetitive or generic outputs.
  - **Local Optima**: By focusing on the highest probability word at each step without considering the overall sequence quality, greedy decoding can get trapped in local optima—choices that seem best in the short term but don't necessarily lead to the best overall sequence.
  - **End Token Dependency**: The decoding process's conclusion relies heavily on the model's ability to predict the end token (`</s>`) at the appropriate time. If the model struggles with this, sequences might end abruptly or not at all within the given maximum length.

In summary, greedy decoding serves as a straightforward and efficient baseline strategy for sequence generation. However, depending on the application's needs for creativity, diversity, and context sensitivity, exploring alternative decoding strategies may be beneficial.
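The local-optima limitation can be shown with a small two-step example. The probabilities below are invented for illustration and do not come from the trained model. Greedy decoding commits to the most probable first word, while an exhaustive search over both steps finds a more probable sequence.

```python
# Hypothetical two-step model: first-word probabilities, then
# second-word probabilities conditioned on the first word.
first = {"I": 0.45, "The": 0.40, "Birds": 0.15}
second = {
    "I": {"am": 0.30, "have": 0.30, "see": 0.40},
    "The": {"cat": 0.90, "dog": 0.10},
    "Birds": {"fly": 1.00},
}

def greedy_two_steps():
    """Pick the best word at each step independently."""
    w1 = max(first, key=first.get)
    w2 = max(second[w1], key=second[w1].get)
    return (w1, w2), first[w1] * second[w1][w2]

def exhaustive_two_steps():
    """Score every two-word sequence and return the most probable one."""
    candidates = [
        ((w1, w2), first[w1] * p2)
        for w1 in first
        for w2, p2 in second[w1].items()
    ]
    return max(candidates, key=lambda c: c[1])

print("greedy:    ", greedy_two_steps())
print("exhaustive:", exhaustive_two_steps())
```

Greedy ends with ("I", "see") at a probability of about 0.18, while ("The", "cat") has a probability of about 0.36. Exhaustive search is only feasible for toy cases like this one; beam search is the practical compromise.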

## Beam Search Decoder

Beam Search is a heuristic search algorithm that extends greedy decoding: instead of committing to a single word at each step, it keeps the $k$ highest-scoring partial sequences (hypotheses) and expands each of them. The score of a sequence is the sum of the log probabilities of its words. Because every additional word adds a negative term, raw sums favor short sequences, so the sum is usually divided by a power of the sequence length to offset that bias.

### Beam Search Decoder: Mathematical Foundation

The goal of beam search is to find a sequence $Y = (y_1, y_2, \ldots, y_{|Y|})$ that maximizes the score $Score(Y)$, given the input $X$. The score of a sequence is determined by the sum of the log probabilities of the words in the sequence, normalized by the sequence length raised to the power of a normalization parameter $\alpha$. This can be mathematically represented as:

$$
Score(Y) = \frac{1}{|Y|^\alpha} \sum_{t=1}^{|Y|} \log P(y_t | y_{1:t-1}, X)
$$

where:
- $|Y|$ is the length of the sequence $Y$.
- $y_t$ is the word at time $t$.
- $X$ is the input to the sequence generation model.
- $\alpha$ is a length normalization parameter that balances the preference between shorter and longer sequences. With $\alpha = 0$ the score is the raw sum of log probabilities, which favors shorter sequences; because the log probabilities are negative, a higher value of $\alpha$ shrinks the score of longer sequences more and therefore favors them.
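A short numeric sketch may help with the direction of the effect. Log probabilities are negative, so dividing by a larger $|Y|^\alpha$ moves the score toward zero. The per-word log probabilities below are invented, and `length_normalized_score` is an illustrative helper that counts only the scored words in $|Y|$.

```python
def length_normalized_score(log_probs, alpha):
    """Score(Y) = (1 / |Y|**alpha) * sum of per-word log probabilities."""
    return sum(log_probs) / (len(log_probs) ** alpha)

# Hypothetical hypotheses: a short one and a longer one whose
# individual words are each more probable.
short = [-0.5, -0.5]              # total -1.0 over 2 words
long = [-0.3, -0.3, -0.3, -0.3]   # total -1.2 over 4 words

for alpha in (0.0, 0.5, 1.0):
    s = length_normalized_score(short, alpha)
    l = length_normalized_score(long, alpha)
    winner = "short" if s > l else "long"
    print(f"alpha={alpha}: short={s:.3f} long={l:.3f} -> {winner}")
```

With `alpha=0.0` the short hypothesis has the higher score simply because it has fewer negative terms; with `alpha=1.0` the scores are per-word averages and the longer hypothesis wins.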


### Beam Search: A Conceptual Example

The example below demonstrates how beam search balances exploration and exploitation in sequence generation, considering multiple paths and using the score function to select the most promising sequences to explore further.

Imagine you're using beam search to generate text with a beam size $k$ of 2, meaning at each step, you keep the top 2 sequences according to their scores:

1. **Initialization**: Start with the start token `<s>` as the initial sequence.

2. **Step 1**: Generate all possible next words from `<s>` and their log probabilities. Suppose we have two words with the highest probabilities, "I" and "The" (assume "I" has a slightly higher score). We keep these two sequences: `<s> I` and `<s> The`.

3. **Step 2**: For each of the two sequences, generate all possible continuations and calculate their scores. For `<s> I`, the top two continuations might be `<s> I am` and `<s> I have`. For `<s> The`, they might be `<s> The cat` and `<s> The dog`. Among these four, suppose `<s> I am` and `<s> The cat` have the highest scores; we keep these two.

4. **Continue**: Repeat this process, expanding each of the top sequences and keeping only the top $k$ sequences at each step.

5. **Termination**: The process continues until a stopping criterion is met, such as reaching a maximum sequence length or all of the top sequences ending with an end token `</s>`. The highest-scoring sequence among the final set is selected as the output.

6. **Selection**: Suppose the final two sequences are `<s> I am a student </s>` (6 tokens, counting `<s>` and `</s>`) with a total log probability of -3 and `<s> The cat </s>` (4 tokens) with a total log probability of -1. Without length normalization ($\alpha = 0$), `<s> The cat </s>` wins because -1 > -3; shorter sequences have an advantage here, since every extra word adds another negative term to the sum. With length normalization ($\alpha > 0$), each total is divided by $|Y|^\alpha$, which shrinks the longer sequence's score more. At $\alpha = 1$ the scores become -3/6 = -0.5 and -1/4 = -0.25: `<s> The cat </s>` still wins, but the gap has narrowed from 2 to 0.25, and a longer sequence with a better average log probability per word would now overtake it.

### Beam Search Decoding: Simple Implementation

In [87]:
max_length = 30

def beam_search_decoder_with_length_norm(model, beam_width, alpha=0.7):
    """
    Generate a sequence of words using a beam search decoding strategy with length normalization.

    Args:
    - model (nn.Module): The trained model for generating sequences.
    - beam_width (int): The number of sequences to keep at each step.
    - alpha (float): The length normalization parameter.

    Returns:
    - List[str]: The highest-scoring generated sequence of words.
    """
    model.eval()  # Evaluation mode
    hypotheses = [([start_token], 0.0)]  # Initialize with the start token and zero score
    
    for _ in range(max_length - 1):
        new_hypotheses = []
        
        for seq, score in hypotheses:
            last_word = torch.tensor([seq[-1]], dtype=torch.long).unsqueeze(0)
            with torch.no_grad():
                # Model output has shape [1, 1, vocab_size]; squeeze(0) leaves [1, vocab_size]
                log_probs = model(last_word).squeeze(0)
            
            for idx in range(log_probs.size(1)):  # Iterate over the vocabulary
                next_word_log_prob = log_probs[0, idx].item()  # Scalar log probability
                new_seq = seq + [idx]
                # Note: `score` is the previous step's length-normalized score, not the raw sum,
                # so the normalization below is applied on top of earlier normalizations
                new_log_prob_sum = score + next_word_log_prob
                
                # Apply length normalization
                normalized_score = new_log_prob_sum / (len(new_seq) ** alpha)
                
                new_hypotheses.append((new_seq, normalized_score))  # normalized_score is already a Python float

        # Keep the top 'beam_width' sequences; hypotheses ending in </s> are not set aside
        # and will be expanded again on the next step
        hypotheses = sorted(new_hypotheses, key=lambda x: x[1], reverse=True)[:beam_width]
        
        if all(seq[-1] == end_token for seq, _ in hypotheses):
            break
    
    # Choose the best sequence
    best_sequence, _ = sorted(hypotheses, key=lambda x: x[1], reverse=True)[0]
    generated_words = [inv_vocab.get(idx, '<unk>') for idx in best_sequence]  # Convert indices to words
    
    return generated_words

beam_width = 3  # Beam width
alpha = 0.3  # Length normalization parameter
generated_sequence = beam_search_decoder_with_length_norm(model, beam_width, alpha)
print("Generated sequence with length normalization:", ' '.join(generated_sequence))


Generated sequence with length normalization: <s> birds fly in the lazy dog </s> in the lazy dog </s> brown fox jumps over the sky </s> in the sky </s> in the sky </s> in the


### Insights from the Output

1. **Multiple Hypotheses**: Beam search keeps `beam_width` hypotheses at each step instead of one, so it can follow a path whose first word is not the single most probable one. Here the output starts with "birds", whereas greedy decoding started with "the". Note that beam search is still deterministic: the same model and parameters always yield the same output.

2. **Impact of Length Normalization**: Length normalization, controlled by the parameter $\alpha$, balances shorter and longer sequences. Because log probabilities are negative, dividing by $|Y|^\alpha$ reduces the built-in advantage of short sequences; $\alpha = 0.3$ in this example applies a mild correction. Tuning this parameter keeps the final text from favoring overly short or overly long sequences.

3. **Repetition and Tokens After `</s>`**: The generated sequence runs to the full `max_length` of 30 tokens and contains several `</s>` tokens and repeated fragments such as "`in the sky </s>`". Two properties of this simple setup cause that: the implementation keeps expanding hypotheses that already end in `</s>` and stops early only if every hypothesis in the beam ends in `</s>` at the same step, and the model predicts each word from the previous word alone, so it easily falls into loops. Length normalization does not prevent this.

4. **Selection of the Best Sequence**: The returned sequence is the hypothesis with the highest stored score when the loop ends. Its opening segment, "`<s> birds fly in the lazy dog </s>`", splices together pieces of two training sentences ("Birds fly in the sky." and "The quick brown fox jumps over the lazy dog."), which is what a model conditioned only on the previous word tends to produce.

5. **Model Dependence**: The quality and relevance of the generated sequences heavily depend on the underlying model's training and architecture. A well-trained model on a comprehensive dataset can produce more coherent and contextually appropriate sequences. The example output reflects the limitations and capabilities of the provided `DecoderModel`, which operates on a simplified level without the complexities of real-world language understanding.

This example illustrates the nuanced dynamics of beam search in text generation, emphasizing the technique's strengths in expanding search space and the importance of careful parameter tuning and model design.
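The output above continues past `</s>` because finished hypotheses stay in the beam. The sketch below shows one common way to avoid that: hypotheses that reach the end token are moved to a separate list, raw log-probability sums are stored, and length normalization is applied only when ranking. It is a simplified illustration, not a drop-in replacement for the function above. The callable `next_log_probs` and the bigram table `toy` are assumptions made so the example runs without the trained model.

```python
import math

def beam_search(next_log_probs, start, end, beam_width=2, max_length=10, alpha=0.7):
    """Beam search over a hypothetical `next_log_probs(token) -> {token: log_prob}`."""
    def normalized(total, length):
        return total / (length ** alpha)

    beams = [([start], 0.0)]  # (tokens, raw sum of log probabilities)
    finished = []             # hypotheses that already produced `end`

    for _ in range(max_length - 1):
        candidates = [
            (seq + [token], total + log_prob)
            for seq, total in beams
            for token, log_prob in next_log_probs(seq[-1]).items()
        ]
        # Rank by length-normalized score, but keep the raw totals.
        candidates.sort(key=lambda c: normalized(c[1], len(c[0])), reverse=True)
        beams = []
        for seq, total in candidates[:beam_width]:
            (finished if seq[-1] == end else beams).append((seq, total))
        if not beams:  # every surviving hypothesis has ended
            break

    best_seq, _ = max(finished + beams, key=lambda c: normalized(c[1], len(c[0])))
    return best_seq

# Toy bigram probabilities (invented for illustration).
toy = {
    "<s>": {"the": 0.6, "birds": 0.4},
    "the": {"cat": 0.5, "sky": 0.5},
    "birds": {"fly": 1.0},
    "cat": {"</s>": 1.0},
    "sky": {"</s>": 1.0},
    "fly": {"</s>": 1.0},
}

def toy_next_log_probs(token):
    return {word: math.log(p) for word, p in toy[token].items()}

print(" ".join(beam_search(toy_next_log_probs, "<s>", "</s>")))
```

On this toy table, greedy decoding would start with "the" (probability 0.6) and end with a sequence of probability 0.3, while the beam also keeps "birds" alive and returns the sequence with probability 0.4, with a single `</s>` at the end.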

## Pure Sampling Decoder

### Pure Sampling: Mathematical Foundation

Pure sampling can be represented as choosing the next word $y_t$ at each timestep $t$ based on the probability distribution:

$$
P(y_t | y_{1:t-1}, X)
$$

provided by the model, where:
- $y_t$ is the word selected at time $t$.
- $X$ is the input to the decoder.
- $y_{1:t-1}$ are the words selected in previous steps.

The key distinction from greedy decoding is that pure sampling takes into account the full distribution for the next word selection, rather than simply picking the most likely word.
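The difference is easy to see by drawing many samples from a single next-word distribution. The sketch below uses Python's standard `random` module and a made-up distribution; `sample_next_word` is an illustrative helper name and the seed is fixed only to make the run repeatable.

```python
import random
from collections import Counter

# Hypothetical next-word distribution for one step.
next_word_probs = {"I": 0.5, "The": 0.3, "Birds": 0.2}

def sample_next_word(probs, rng):
    """Draw one word in proportion to its probability."""
    words = list(probs)
    weights = [probs[w] for w in words]
    return rng.choices(words, weights=weights, k=1)[0]

rng = random.Random(0)  # fixed seed for a repeatable illustration
counts = Counter(sample_next_word(next_word_probs, rng) for _ in range(10_000))
print(counts)
```

Greedy decoding would return "I" every time. Under sampling, "I" is still the most frequent choice, but "The" and "Birds" appear roughly in proportion to their probabilities.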

### Pure Sampling Decoding: A Conceptual Example

Here's a step-by-step example of how pure sampling might run:

1. **Initialization**: Start with the start token `<s>` as the initial sequence.

2. **Step 1**: Generate all possible next words from `<s>` along with their probabilities. Instead of selecting the highest probability word, we **sample** a word based on the probability distribution. Suppose the distribution slightly favors "I", but "The" is also a strong contender; we might end up with either `<s> I` or `<s> The` based on the randomness of the sampling.

3. **Step 2**: For the selected sequence, generate all possible continuations and their probabilities. Again, **sample** the next word based on this distribution. The process introduces variability, as even less probable words have a chance of being selected.

4. **Continue**: This process repeats, with each step introducing potential for diverse continuations based on the probability distribution of the next word.

5. **Termination**: The sequence generation concludes when an end token `</s>` is sampled or the maximum sequence length is reached. Due to the stochastic nature of pure sampling, different runs will produce different sequences.

### Pure Sampling Decoding: Simple Implementation

In [97]:
import torch

max_length = 10

def pure_sampling_decoder(model):
    """
    Generate a sequence of words using a pure sampling decoding strategy.

    Args:
    - model (nn.Module): The trained model for generating sequences.

    Returns:
    - List[str]: The generated sequence of words.
    """
    model.eval()  # Set the model to evaluation mode
    input_word = torch.tensor([start_token], dtype=torch.long)
    generated_sequence = [start_token]
    
    for _ in range(max_length - 1):
        with torch.no_grad():  # No need to track gradients
            output_logits = model(input_word.unsqueeze(0))  # Log probabilities, shape [1, 1, vocab_size]
            probabilities = F.softmax(output_logits, dim=-1).squeeze()  # Softmax of log probabilities recovers the probabilities; squeeze() gives shape [vocab_size]

            # Check if probabilities is 2D (batched) and select the first (and only) batch
            if probabilities.dim() > 1:
                probabilities = probabilities[0]
            
            next_word = torch.multinomial(probabilities, 1).item()  # Sample based on probabilities
            generated_sequence.append(next_word)
        
            if next_word == end_token:
                break
        
            input_word = torch.tensor([next_word], dtype=torch.long)
    
    # Convert indices to words using the inverse vocabulary
    generated_words = [inv_vocab[idx] for idx in generated_sequence if idx in inv_vocab]
    
    return generated_words

generated_sequence = pure_sampling_decoder(model)
print("Generated sequence with pure sampling decoding:", ' '.join(generated_sequence))

Generated sequence with pure sampling decoding: <s> the cat sat on cat the best dessert </s>



### Insights from Pure Sampling

The output from the pure sampling decoder, "`<s> the cat sat on cat the best dessert </s>`", provides a valuable demonstration of pure sampling's characteristics and its implications for text generation. Here are several insights drawn from this output regarding pure sampling decoding:

1. **Diversity and Unpredictability**: The output sequence showcases the inherent diversity and unpredictability of pure sampling. Unlike greedy decoding, which always produces the same sequence given the same starting conditions, pure sampling can generate a wide variety of sequences due to its stochastic nature. This is evident from the unexpected repetition of "cat" and the seemingly unrelated conclusion "the best dessert".

2. **Creative and Novel Outputs**: Pure sampling has the potential to generate more creative and novel outputs by exploring less probable paths that might be overlooked by deterministic approaches. This can lead to interesting and sometimes surprising sequences, as seen with the inclusion of "dessert" in a sequence starting with "the cat".

3. **Potential for Reduced Coherence**: While pure sampling introduces diversity, it also risks producing sequences with reduced coherence or logical flow. The sequence "`<s> the cat sat on cat the best dessert </s>`" is built from plausible word pairs but is neither fully grammatical nor coherent as a whole, highlighting the trade-off between creativity and coherence.

4. **Balance Between Exploration and Exploitation**: Pure sampling represents a shift towards exploration in the exploration-exploitation trade-off, potentially at the expense of exploiting more probable (and possibly more coherent) sequences. This approach is beneficial for generating diverse content but requires careful consideration of the application's needs regarding coherence and predictability.

5. **Application-Specific Suitability**: The suitability of pure sampling depends heavily on the specific application. For tasks requiring high creativity and diversity, such as story generation or ideation, pure sampling can be advantageous. For tasks requiring high precision and coherence, such as formal text generation, modifications or alternative strategies might be necessary.

6. **Impact of the Model's Training**: The effectiveness and coherence of the sequences generated by pure sampling are also influenced by the underlying model's training. A well-trained model on a comprehensive and diverse dataset can produce more meaningful and contextually appropriate sequences even under pure sampling, as it learns robust representations of language patterns.

## Top-K Sampling Decoder

Top-K Sampling is a decoding strategy that introduces randomness into the text generation process, allowing for more diverse and creative outputs. Unlike greedy decoding, which always picks the word with the highest probability, Top-K sampling narrows the choice to the top $K$ most likely next words and randomly selects from this subset based on their probability distribution. This method helps prevent the model from getting stuck in repetitive loops and generates more varied sentences.

### Top-K Sampling: Mathematical Foundation

Top-K Sampling modifies the probability distribution of the next word by zeroing out the probabilities of all but the top $K$ candidates before sampling. Mathematically, given a probability distribution: 

$$
P(y_t | y_{1:t-1}, X)
$$

over the vocabulary for the next word $y_t$ given the input $X$ and previous words $y_{1:t-1}$. 

Top-K sampling involves:

1. Sorting $P$ and retaining the top $K$ words' probabilities, setting others to zero.
2. Renormalizing the modified distribution so that the sum of probabilities equals 1.
3. Sampling the next word $y_t$ from this renormalized distribution.
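The three steps above can be sketched on a small distribution. The probabilities are invented for illustration, `top_k_filter` is a hypothetical helper name, and sampling uses the standard `random` module rather than the trained model.

```python
import random

def top_k_filter(probs, k):
    """Keep the k most probable words and renormalize so they sum to 1."""
    top = sorted(probs.items(), key=lambda item: item[1], reverse=True)[:k]
    total = sum(p for _, p in top)
    return {word: p / total for word, p in top}

# Hypothetical next-word distribution.
probs = {"I": 0.4, "The": 0.3, "Birds": 0.2, "Fish": 0.1}

filtered = top_k_filter(probs, k=3)
print(filtered)  # "Fish" is removed; the rest are rescaled by 1 / 0.9

# Sample the next word from the renormalized distribution.
rng = random.Random(0)
words, weights = zip(*filtered.items())
print(rng.choices(words, weights=weights, k=1)[0])
```

Two boundary cases are worth noting: with `k=1` only the most probable word survives, which is greedy decoding, and with `k` equal to the vocabulary size nothing is removed, which is pure sampling.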

### Top-K Sampling: A Conceptual Example

Imagine using Top-K Sampling with $K=3$ to generate text:

1. **Initialization**: Begin with the start token `<s>` as the initial sequence.

2. **Step 1**: For the initial word `<s>`, generate the probability distribution over all possible next words. Suppose the top 3 words according to their probabilities are "I", "The", and "Birds".

3. **Sampling**: Instead of choosing the highest probability word, randomly select one of the top 3 words, say "I", based on their probabilities. The sequence now is `<s> I`.

4. **Continue**: Repeat this process, at each step generating probabilities for the next word, narrowing down to the top 3, and sampling from them. This leads to more varied and potentially less predictable text generation than greedy decoding.

5. **Termination**: This process ends when an end token `</s>` is generated or a maximum sequence length is reached, producing a final sequence like `<s> I am flying </s>`.

6. **Diversity Introduced**: By sampling from the top $K$ candidates at each step, Top-K Sampling introduces randomness into the sequence generation, leading to diverse outputs on different runs, even with the same start conditions.

### Top-K Sampling: Simple Implementation

In [115]:
def top_k_sampling_decoder(model, k=3):
    """
    Generate a sequence of words using a Top-K sampling decoding strategy.

    Args:
    - model (nn.Module): The trained model for generating sequences.
    - k (int): The number of top probabilities to consider for sampling.

    Returns:
    - List[str]: The generated sequence of words.
    """
    model.eval()
    input_word = torch.tensor([start_token], dtype=torch.long)
    generated_sequence = [start_token]
    
    for _ in range(max_length - 1):
        with torch.no_grad():
            output_logits = model(input_word.unsqueeze(0))  # Ensure input is batched
            probabilities = F.softmax(output_logits, dim=-1).squeeze()  # Adjust for batch dimension

            # Handling both single and batched cases
            if probabilities.dim() == 2:
                probabilities = probabilities[0]

            top_probs, top_indices = torch.topk(probabilities, k, dim=-1)
            # Sample from the top k probabilities
            sampled_idx = torch.multinomial(top_probs, 1).item()
            next_word = top_indices[sampled_idx].item()  # Correct indexing

            generated_sequence.append(next_word)
        
            if next_word == end_token:
                break
        
            input_word = torch.tensor([next_word], dtype=torch.long)
    
    # Convert indices back to words
    generated_words = [inv_vocab.get(idx, '<unk>') for idx in generated_sequence]
    return generated_words

generated_sequence = top_k_sampling_decoder(model, k=3)
print("Generated sequence with Top K sampling decoding:", ' '.join(generated_sequence))

Generated sequence with Top K sampling decoding: <s> fish swim in the sea </s>


### Insights from Top-K Sampling

The output generated by the Top-K Sampling decoding strategy, "`<s> fish swim in the sea </s>`", provides several insights into how this approach can affect the text generation process. Here are the key takeaways regarding Top-K Sampling based on the generated output:

1. **Increased Diversity and Creativity**: Unlike greedy decoding, which deterministically selects the most probable next word, Top-K sampling introduces randomness by selecting from the top $K$ probable words. This can lead to more diverse and creative outputs, as seen in the concise and coherent sentence generated in the example.

2. **Balance Between Randomness and Relevance**: By limiting the sampling pool to the top $K$ words, the model ensures that the chosen words are still highly probable and relevant to the context. This strikes a balance between randomness (for diversity) and maintaining coherence in the generated text.

3. **Customizable Diversity**: The choice of $K$ controls the diversity of the generated text. A smaller $K$ reduces randomness and makes the output more deterministic; with $K=1$, Top-K sampling is identical to greedy decoding. A larger $K$ increases randomness (potentially leading to more creative but less predictable outputs); with $K$ equal to the vocabulary size it becomes pure sampling. The value of $K=3$ used here keeps the output close to the model's most probable choices.

4. **Potential for Unpredictable Outputs**: Given the stochastic nature of Top-K Sampling, running the decoder multiple times with the same initial conditions can lead to different sequences each time. This variability can be advantageous for applications requiring varied outputs but may require additional filtering or post-processing for tasks demanding high consistency.

5. **Mitigation of Repetitive Patterns**: A common issue with simpler decoding strategies, like greedy decoding, is the tendency to produce repetitive or generic text. Top-K Sampling's introduction of randomness helps mitigate this issue by occasionally selecting less probable (but still relevant) words, leading to more nuanced and interesting text.

6. **Effectiveness in Contextual Coherence**: The generated sentence demonstrates that Top-K Sampling can produce contextually coherent and grammatically correct sentences, indicating that even with the introduction of randomness, the model can leverage its understanding of language structure and context learned during training.

7. **Dependency on Model Quality**: The coherence and relevance of the generated text heavily depend on the underlying model's quality and training. A well-trained model on a comprehensive dataset is more likely to leverage Top-K Sampling effectively to produce meaningful and contextually appropriate text.

## Top-P (Nucleus) Sampling Decoder

Top-P (Nucleus) Sampling is a sophisticated strategy for generating text that focuses on selecting the next word from a dynamically sized subset of the vocabulary. This subset, or "nucleus," comprises the smallest set of most probable words whose cumulative probability exceeds a threshold $P$, allowing for more nuanced control over the diversity and coherence of the generated text.

### Top-P Sampling: Mathematical Foundation

The principle behind Top-P sampling involves accumulating the probabilities of the most probable words and stopping when the sum exceeds the threshold $P$. Unlike Top-K sampling, which considers a fixed number $K$ of top words, Top-P dynamically adjusts the number of words considered at each step based on their cumulative probability, aiming to capture the "nucleus" of the distribution that meaningfully contributes to the next word choice.

Mathematically, given a probability distribution: 

$$
P(y_t | y_{1:t-1}, X)
$$

over the vocabulary for the next word $y_t$, given the input $X$ and the sequence so far $y_{1:t-1}$, Top-P sampling involves:

1. Sorting the words by their probability in descending order.
2. Accumulating the probabilities until the running sum exceeds the threshold $P$; the words accumulated so far form the nucleus.
3. Sampling the next word $y_t$ from the nucleus, with each word's probability renormalized so that the nucleus sums to 1.
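
The three steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions: the distribution is invented, and the helper names `nucleus` and `sample_top_p` are hypothetical rather than taken from any library.

```python
import random

def nucleus(probs, p):
    """Return the smallest top-ranked set of tokens whose cumulative
    probability exceeds p, with probabilities renormalized to sum to 1."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)  # step 1
    kept, cumulative = [], 0.0
    for token, prob in ranked:                                          # step 2
        kept.append((token, prob))
        cumulative += prob
        if cumulative > p:
            break
    return {token: prob / cumulative for token, prob in kept}

def sample_top_p(probs, p, rng=random):
    """Step 3: draw one token from the renormalized nucleus."""
    dist = nucleus(probs, p)
    return rng.choices(list(dist), weights=list(dist.values()), k=1)[0]

# Hypothetical next-word distribution (illustrative values only)
next_word_probs = {"the": 0.50, "birds": 0.30, "dogs": 0.15, "fish": 0.05}

print(nucleus(next_word_probs, p=0.9))
print(sample_top_p(next_word_probs, p=0.9, rng=random.Random(0)))
```

Here the three most probable words together carry 0.95 of the probability mass, which is the first running sum above 0.9, so "fish" is excluded from the nucleus.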

### Top-P Sampling: A Conceptual Example

Let's use Top-P sampling with a threshold \(P = 0.9\) to generate text:

1. **Initialization**: Begin with the start token `<s>` as the initial sequence.

2. **Step 1**: Compute the probability distribution over all possible next words from `<s>`. Identify the smallest set of top words whose cumulative probability exceeds 0.9. Suppose this includes words like "I", "The", and "Birds", and their cumulative probability just exceeds 0.9.

3. **Sampling**: Randomly select one of these top words in proportion to its renormalized probability, say "I". Now, the sequence is `<s> I`.

4. **Continue**: Repeat this process for each next word, dynamically determining the size of the set at each step based on the threshold \(P\). This leads to varied sequences that reflect the underlying model's predictions while introducing diversity through the randomness of sampling.

5. **Termination**: Continue until an end token `</s>` is generated or a maximum sequence length is reached, resulting in a final sequence like `<s> I watch the birds fly away </s>`.
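
The walk-through above can be simulated end to end with a hand-written lookup table standing in for the model. Everything in `toy_model` is invented for illustration (a real model would compute these probabilities), and `top_p_step` and `generate` are hypothetical helper names. The sketch prints the nucleus size at each step to show how it changes with the shape of the distribution.

```python
import random

# Hypothetical next-word probabilities, keyed by the previous word.
toy_model = {
    "<s>":   {"the": 0.55, "birds": 0.30, "dogs": 0.10, "fish": 0.05},
    "the":   {"sky": 0.45, "sea": 0.35, "mat": 0.20},
    "birds": {"fly": 0.95, "swim": 0.05},
    "fly":   {"in": 1.0},
    "in":    {"the": 1.0},
    "dogs":  {"are": 1.0},
    "are":   {"great": 1.0},
    "great": {"pets": 1.0},
}

def top_p_step(dist, p, rng):
    """Sample one word from the smallest top-ranked set whose mass exceeds p."""
    ranked = sorted(dist.items(), key=lambda kv: kv[1], reverse=True)
    kept, cumulative = [], 0.0
    for word, prob in ranked:
        kept.append((word, prob))
        cumulative += prob
        if cumulative > p:
            break
    words, weights = zip(*kept)
    return rng.choices(words, weights=weights, k=1)[0], len(kept)

def generate(p=0.9, max_length=10, seed=0):
    rng = random.Random(seed)
    sequence = ["<s>"]
    while len(sequence) < max_length and sequence[-1] != "</s>":
        # Words without an entry in the table are followed by the end token
        dist = toy_model.get(sequence[-1], {"</s>": 1.0})
        word, size = top_p_step(dist, p, rng)
        print(f"after {sequence[-1]!r}: nucleus size {size}, sampled {word!r}")
        sequence.append(word)
    return sequence

print(" ".join(generate()))
```

With \(P = 0.9\), "fish" can never follow `<s>` in this table because the first three candidates already exceed the threshold, while after "birds" the nucleus shrinks to a single word.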

### Top-P (Nucleus) Sampling Decoder: Simple Implementation

In [116]:
def top_p_sampling_decoder(model, p=0.9):
    """
    Generate a sequence of words using a Top-P (Nucleus) sampling decoding strategy.

    Args:
    - model (nn.Module): The trained model for generating sequences.
    - p (float): The cumulative probability threshold for nucleus sampling.

    Returns:
    - List[str]: The generated sequence of words.
    """
    model.eval()  # Set the model to evaluation mode
    input_word = torch.tensor([start_token], dtype=torch.long)
    generated_sequence = [start_token]
    
    for _ in range(max_length - 1):
        with torch.no_grad():
            output_logits = model(input_word.unsqueeze(0))  # Add batch dimension
            probabilities = F.softmax(output_logits, dim=-1).squeeze()
            
            # Sort probabilities to identify the "nucleus"
            sorted_probs, sorted_indices = torch.sort(probabilities, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=0)
            # Count the words ranked before the cumulative probability first exceeds p,
            # then add one so the word that crosses the threshold is included
            cutoff_index = int((cumulative_probs <= p).sum().item()) + 1
            # Guard against p >= 1.0 (or rounding), where no prefix exceeds p
            cutoff_index = min(cutoff_index, sorted_probs.numel())
            top_p_probs = sorted_probs[:cutoff_index]
            top_p_indices = sorted_indices[:cutoff_index]

            # Sample from the "nucleus" probabilities
            sampled_idx = torch.multinomial(top_p_probs, 1).item()
            next_word = top_p_indices[sampled_idx].item()
            generated_sequence.append(next_word)
        
            if next_word == end_token:
                break
        
            input_word = torch.tensor([next_word], dtype=torch.long)
    
    generated_words = [inv_vocab.get(idx, '<unk>') for idx in generated_sequence]
    return generated_words

# Example usage
generated_sequence = top_p_sampling_decoder(model, p=0.9)
print("Generated sequence with Top-P sampling decoding:", ' '.join(generated_sequence))

Generated sequence with Top-P sampling decoding: <s> chocolate is the sky </s>


### Insights from Top-P Sampling Output

1. **Dynamic Vocabulary Subset**: The output demonstrates Top-P sampling's ability to dynamically adjust the vocabulary subset (the "nucleus") from which it samples the next word. By considering only the smallest set of words whose cumulative probability exceeds a threshold $p$, Top-P sampling can adaptively focus on the most probable and contextually relevant words while maintaining the flexibility to introduce variety into the generated text.

2. **Balanced Diversity and Coherence**: The generated sequence "`<s> chocolate is the sky </s>`" illustrates the trade-off Top-P sampling makes between diversity and coherence. By filtering out the low-probability tail, it keeps each word choice locally plausible, so the sentence is grammatical even though its meaning is odd, which is unsurprising for a model trained on a six-sentence corpus. At the same time, by sampling within the nucleus, it introduces variability and avoids the deterministic output seen with greedy decoding.

3. **Enhanced Creativity**: Compared to methods that may lead to repetitive or predictable text, such as greedy decoding, Top-P sampling can produce more creative and less expected sequences. The somewhat whimsical nature of the example sentence underscores this point, highlighting the model's capacity to generate novel associations between words.

4. **Reduced Repetition and Generic Responses**: Top-P sampling mitigates common issues like repetition or generic responses by narrowing down the choice set to a contextually rich subset of words. This method helps prevent the model from defaulting to safe, repetitive word choices that may occur in other decoding strategies.

5. **Sensitivity to the Threshold $p$**: The choice of the cumulative probability threshold $p$ significantly influences the output. A higher $p$ includes more words in the nucleus, increasing diversity and creativity but possibly at the expense of coherence, while a lower $p$ restricts sampling to a narrower set of high-probability words, making the output more predictable and, as $p$ approaches 0, equivalent to greedy decoding.

6. **Randomness and Unpredictability**: The inherent randomness in Top-P sampling means that different runs may produce different outputs, even with the same starting conditions. This feature is advantageous for applications requiring varied responses but may require careful management to ensure consistency in contexts demanding more predictable outcomes.
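
As a rough illustration of how the threshold interacts with the shape of the distribution, the sketch below counts how many words enter the nucleus for several values of $p$. Both distributions are made up for the example, and `nucleus_size` is a hypothetical helper, not part of the decoder above.

```python
def nucleus_size(probs, p):
    """Number of top-ranked tokens needed for the cumulative probability to exceed p."""
    cumulative = 0.0
    for size, prob in enumerate(sorted(probs, reverse=True), start=1):
        cumulative += prob
        if cumulative > p:
            return size
    return len(probs)  # p is at least the total mass: keep everything

# Two hypothetical distributions over the same five-word vocabulary
peaked = [0.85, 0.06, 0.04, 0.03, 0.02]  # the model is confident
flat = [0.24, 0.22, 0.20, 0.18, 0.16]    # the model is uncertain

for p in (0.5, 0.8, 0.9):
    print(f"p={p}: peaked -> {nucleus_size(peaked, p)} words, "
          f"flat -> {nucleus_size(flat, p)} words")
```

Raising $p$ never shrinks the nucleus, and for the same $p$ the flat distribution admits more candidates than the peaked one. A fixed $K$ would use the same number of candidates in both cases.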

## References

- All the images are taken from the article: [How to generate text: using different decoding methods for language generation with Transformers](https://huggingface.co/blog/how-to-generate)