# Advanced RNN Topics

## Using RNNs For Classification


When you have a sequential input (such as a sentence, an email, or a time series) and want to assign a single label to the entire sequence, an RNN-based classifier is a natural fit. The general idea is to treat the RNN as a feature extractor that processes the sequence step by step, and then use the resulting representation to predict the label.

RNNs suit these tasks because they process one element at a time while maintaining a hidden state that carries information about the earlier elements of the sequence.

Imagine two common examples:
- **Sentiment Analysis:** You have a string of text (e.g., a movie review) and you want to classify it as expressing positive or negative sentiment.
- **Spam Detection:** You have the content of an email and you want to label it as "spam" or "not spam."

In both cases, the input is a sequence (words in a sentence or tokens in an email), and the goal is to produce a single output label that summarizes the entire sequence.

### Model Architecture Overview

The typical approach involves two main components:

1. **The RNN as a Featurizer:**  
   - The RNN (or its variants like LSTM/GRU) processes the input sequence step by step.
   - At each step, it updates its hidden state, which is designed to capture the context of the sequence seen so far.
   - In many classification tasks, we only need a summary of the sequence, so we might take the hidden state from the final time step. This final hidden state is considered a learned representation (or feature) of the whole sequence.

2. **The Classifier:**  
   - After obtaining the sequence representation from the RNN, this representation is fed into a classifier. 
   - The classifier can be a feed-forward neural network (also known as a fully connected layer), or even a simpler model like logistic regression.
   - This component outputs the final class probabilities (e.g., positive/negative for sentiment analysis or spam/not spam for email classification).

In this sense, the RNN acts as a feature extractor, transforming raw sequential data into a fixed-size representation that captures its underlying patterns, while the classifier makes the final decision based on these features.

### TensorFlow Example

In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

# Build a simple RNN classifier using Keras
def build_rnn_classifier(input_size, hidden_size, num_classes, seq_length):
    model = models.Sequential()
    # SimpleRNN layer automatically processes the sequence and returns the last output by default
    model.add(layers.SimpleRNN(hidden_size, input_shape=(seq_length, input_size)))
    model.add(layers.Dense(num_classes, activation='softmax'))
    return model

# Hyperparameters
input_size = 10     # Dimensionality of each input vector
hidden_size = 20    # Size of the RNN hidden state
num_classes = 2     # Number of classes for classification
seq_length = 5      # Length of each input sequence
batch_size = 16     # Batch size

# Instantiate and compile the model
model = build_rnn_classifier(input_size, hidden_size, num_classes, seq_length)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Create dummy input data and target labels for training
inputs = np.random.randn(batch_size, seq_length, input_size)
targets = np.random.randint(0, num_classes, size=(batch_size,))

# Train for one epoch on the dummy data
model.fit(inputs, targets, epochs=1, batch_size=batch_size, verbose=1)

# Prediction: using a new dummy sample
test_input = np.random.randn(1, seq_length, input_size)  # single sample
predictions = model.predict(test_input)
predicted_class = np.argmax(predictions, axis=1)[0]
print("Predicted class for test input:", predicted_class)

Predicted class for test input: 0


### PyTorch Example

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple RNN-based classifier
class SimpleRNNClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(SimpleRNNClassifier, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        # x shape: [batch_size, seq_length, input_size]
        out, _ = self.rnn(x)  # out shape: [batch_size, seq_length, hidden_size]
        # Use the last time step's hidden state for classification
        out = out[:, -1, :]   # shape: [batch_size, hidden_size]
        out = self.fc(out)    # shape: [batch_size, num_classes]
        return out

# Hyperparameters
input_size = 10     # Dimensionality of each input vector
hidden_size = 20    # Size of the hidden state in the RNN
num_layers = 1      # Number of RNN layers
num_classes = 2     # Number of output classes (e.g., spam or not spam)
batch_size = 16     # Batch size
seq_length = 5      # Length of each input sequence

# Instantiate the model, loss function, and optimizer
model = SimpleRNNClassifier(input_size, hidden_size, num_layers, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Create dummy input data and target labels for training
inputs = torch.randn(batch_size, seq_length, input_size)
targets = torch.randint(0, num_classes, (batch_size,))

# Training step
model.train()
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

print("Training complete, loss:", loss.item())

# Prediction: using a new dummy sample
model.eval()  # set to evaluation mode
with torch.no_grad():
    test_input = torch.randn(1, seq_length, input_size)  # single sample
    test_output = model(test_input)
    # Get predicted class (highest score)
    predicted_class = torch.argmax(test_output, dim=1)
    print("Predicted class for test input:", predicted_class.item())

## Additional Considerations

### Variants in Representing the Sequence

While the most common approach is to use the hidden state at the last time step, there are alternative strategies for summarizing the sequence:

- **Element-wise Mean Pooling:**  
  Instead of just taking the final state, compute the average of the hidden states over all time steps. This approach ensures that every part of the sequence contributes equally to the final representation.
  
- **Element-wise Max Pooling:**  
  Alternatively, take the element-wise maximum over all hidden states. This method captures the most salient features that were activated at any time step in the sequence.

Each of these strategies has its own strengths. For example, mean pooling can smooth out the features, whereas max pooling tends to emphasize the strongest signals across the sequence.
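
The three summaries can be compared side by side. The sketch below is illustrative only: it uses NumPy and a random array as a stand-in for the per-step hidden states an RNN would return (the `hidden_states` name and its shape are assumptions, not part of the models above).

```python
import numpy as np

# Stand-in for RNN outputs: 2 sequences, 4 time steps, 3 hidden units.
rng = np.random.default_rng(0)
hidden_states = rng.standard_normal((2, 4, 3))  # [batch, time, hidden]

last_state = hidden_states[:, -1, :]       # hidden state at the final time step
mean_pooled = hidden_states.mean(axis=1)   # element-wise average over time
max_pooled = hidden_states.max(axis=1)     # element-wise maximum over time

# Every summary has shape [batch, hidden], so the classifier on top is unchanged.
print(last_state.shape, mean_pooled.shape, max_pooled.shape)
```

Because all three produce a fixed-size vector per sequence, switching between them only changes how the RNN outputs are reduced, not the classifier that follows.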

### End-to-End Training

One of the major advantages of this architecture is that it supports **end-to-end training**:

- **Error Gradients Flow Back Through the Entire Network:**  
  When you train the model, you compute a loss based on the classifier's output (e.g., cross-entropy loss for classification). The gradients of this loss are then backpropagated through the classifier and further into the RNN.

- **RNN as Part of the Pipeline:**  
  This means that the RNN not only learns to represent the sequence but also adapts its internal representation to improve the final classification. This synergy between the RNN and the classifier is essential for learning meaningful features from the data.
  
### Practical Considerations

- **Choice of RNN Variant:**  
  Depending on the complexity of the sequence and the amount of data, you might choose a vanilla RNN, Long Short-Term Memory (LSTM), or Gated Recurrent Unit (GRU). LSTMs and GRUs are often preferred because they better handle long-term dependencies.
  
- **Regularization and Overfitting:**  
  Since RNNs can be prone to overfitting, techniques such as dropout, early stopping, and using more data are important considerations.

- **Batching and Sequence Padding:**  
  When working with sequences of different lengths, you'll need to pad them to a common length and use masking techniques so that the model does not treat the padded values as meaningful input.
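
A minimal sketch of padding and masking, written in plain NumPy so the mechanics are visible. The helper names `pad_sequences` and `masked_mean` are hypothetical, and using `0` as the padding index is an assumption; a real vocabulary would need to reserve that index.

```python
import numpy as np

def pad_sequences(seqs, pad_value=0):
    """Right-pad integer sequences to a common length and return a boolean mask."""
    max_len = max(len(s) for s in seqs)
    padded = np.full((len(seqs), max_len), pad_value, dtype=np.int64)
    mask = np.zeros((len(seqs), max_len), dtype=bool)
    for i, s in enumerate(seqs):
        padded[i, :len(s)] = s
        mask[i, :len(s)] = True   # True marks real tokens, False marks padding
    return padded, mask

def masked_mean(hidden_states, mask):
    """Average hidden states over the real (unpadded) time steps only."""
    weights = mask[:, :, None].astype(hidden_states.dtype)
    return (hidden_states * weights).sum(axis=1) / weights.sum(axis=1)

padded, mask = pad_sequences([[5, 2, 9], [7]])

# Stand-in hidden states: real positions hold 1.0, padded positions hold junk.
hidden = np.ones((2, 3, 4))
hidden[1, 1:] = 100.0

print(masked_mean(hidden, mask)[1])   # unaffected by the junk at padded steps
print(hidden.mean(axis=1)[1])         # a naive mean is dominated by it
```

The same concern applies to the "last time step" summary: with right-padded batches, the final position of a short sequence is padding. Frameworks offer their own tools for this, for example `mask_zero=True` on a Keras `Embedding` layer or `torch.nn.utils.rnn.pack_padded_sequence` in PyTorch.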

## Teacher Forcing

**Teacher forcing** is a training strategy used primarily in sequence generation tasks, such as text generation or machine translation. In a typical sequence-to-sequence model, the decoder RNN generates one token at a time, conditioning each step on the previous token. If, during training, that previous token is always the model's own prediction, early mistakes are fed back in and compound, which makes training slow and unstable.

Teacher forcing addresses this by replacing the model's previous prediction with the actual target (or "ground truth") token during training. In other words, **instead of feeding back the model's generated output, you "force" the correct answer into the next time step.** 

This method:
- **Speeds up training:** Since the model is always receiving the correct context, it learns the mapping from input to output faster.
- **Improves convergence:** It helps mitigate error accumulation over time, which is especially important in long sequences.
  
**Relation to Text Generation:**  
When generating text, teacher forcing helps the model learn the correct sequence dynamics. For example, if the model is trained to generate a sentence, teacher forcing ensures that the RNN sees the correct word at each time step rather than its own, potentially flawed, output. The trade-off appears at inference time: when generating new text, the model must rely on its own predictions, which it never had to recover from during training. This train/inference mismatch is commonly called exposure bias.
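
The difference between the two regimes comes down to what the decoder receives as input at each step. The sketch below uses plain Python and a deliberately flawed lookup table as a stand-in for a trained model; `flawed_model` and both helper functions are hypothetical names introduced for illustration.

```python
SOS = "<sos>"

def teacher_forced_inputs(target_tokens):
    """Decoder inputs under teacher forcing: the target shifted right by one."""
    return [SOS] + target_tokens[:-1]

def free_running_inputs(predict_next, num_steps):
    """Decoder inputs at inference: each input is the model's previous output."""
    inputs, token = [], SOS
    for _ in range(num_steps):
        inputs.append(token)
        token = predict_next(token)
    return inputs

target = ["the", "cat", "sat", "<eos>"]

# Stand-in model that wrongly continues "the" with "dog" instead of "cat".
flawed_model = {"<sos>": "the", "the": "dog", "dog": "ran", "ran": "<eos>"}

print(teacher_forced_inputs(target))                       # ['<sos>', 'the', 'cat', 'sat']
print(free_running_inputs(flawed_model.get, len(target)))  # ['<sos>', 'the', 'dog', 'ran']
```

With teacher forcing, the single mistake at step two does not affect later inputs. In the free-running case, every input after the mistake is off the target sequence.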

![](https://miro.medium.com/v2/resize:fit:842/1*U3d8D_GnfW13Y3nDgvwJSw.png)

## RNNs for Word Embeddings

**Word embeddings** are dense vector representations of words that capture semantic meanings, where similar words are mapped to similar vectors. While many models (like Word2Vec or GloVe) are designed to learn word embeddings independently, RNNs can also be used to learn or refine embeddings within the context of sequential data.

#### How It Works:
- **Input Representations:** In an RNN, words are typically first converted into embeddings using an embedding layer. These embeddings are then fed into the RNN to capture contextual information.
- **Contextualized Representations:** As the RNN processes the sequence, it not only considers the fixed embedding but also integrates context from surrounding words. This results in dynamic, context-aware representations that can capture nuances like polysemy (words having multiple meanings based on context).
- **Learning Process:** During training, the gradients from the output (e.g., predicting the next word or classifying a sentence) backpropagate through the embedding layer, thus refining the embeddings based on the task at hand.

Using RNNs in this manner allows the model to jointly learn the word embeddings and the sequential relationships in the data, often resulting in more effective representations for tasks like sentiment analysis, machine translation, or language modeling.

#### TensorFlow Example

In [None]:
import tensorflow as tf
import numpy as np

# Define model hyperparameters
vocab_size = 50       # Size of the vocabulary
embedding_dim = 8     # Dimensionality of the word embeddings
hidden_dim = 16       # Number of units in the RNN
output_dim = 2        # Number of output classes (e.g., for classification)
seq_length = 7        # Length of each input sequence
batch_size = 4        # Batch size

# Build a simple RNN classifier using Keras
model = tf.keras.Sequential([
    # Embedding layer: maps token indices to dense vectors
    tf.keras.layers.Embedding(vocab_size, embedding_dim, input_length=seq_length),
    # SimpleRNN layer: processes the sequence of embeddings
    tf.keras.layers.SimpleRNN(hidden_dim),
    # Dense layer for classification with softmax activation
    tf.keras.layers.Dense(output_dim, activation='softmax')
])

# Compile the model with an optimizer and loss suitable for classification
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Create dummy input data and target labels
inputs = np.random.randint(0, vocab_size, (batch_size, seq_length))
targets = np.random.randint(0, output_dim, (batch_size,))

# Train the model for one epoch on the dummy data
model.fit(inputs, targets, epochs=1, batch_size=batch_size, verbose=1)

# Prediction: using the same dummy inputs for demonstration
predictions = model.predict(inputs)
predicted_classes = np.argmax(predictions, axis=1)
print("TensorFlow predicted classes for inputs:", predicted_classes)

# Optional: Extract and inspect the learned embedding weights
embedding_layer = model.layers[0]
embedding_weights = embedding_layer.get_weights()[0]
print("Shape of learned embedding weights:", embedding_weights.shape)

#### PyTorch Example

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple RNN model that includes a word embedding layer
class RNNWordEmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
        super(RNNWordEmbeddingModel, self).__init__()
        # Embedding layer to convert token indices to dense vectors
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # RNN layer: processes the sequence of embeddings
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        # Classifier: maps the RNN output to desired output dimensions (e.g., for classification)
        self.fc = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, x):
        # x: [batch_size, seq_length] -> embedding: [batch_size, seq_length, embedding_dim]
        embedded = self.embedding(x)
        # Pass through RNN: out shape is [batch_size, seq_length, hidden_dim]
        out, _ = self.rnn(embedded)
        # Use the last time step's output for classification
        out = out[:, -1, :]
        out = self.fc(out)
        return out, embedded

# Hyperparameters
vocab_size = 50       # Size of the vocabulary (number of unique tokens)
embedding_dim = 8     # Dimensionality of the word embeddings
hidden_dim = 16       # Size of the hidden state in the RNN
output_dim = 2        # Number of output classes (e.g., positive/negative)
batch_size = 4        # Batch size for training
seq_length = 7        # Length of each input sequence

# Instantiate the model, loss function, and optimizer
model = RNNWordEmbeddingModel(vocab_size, embedding_dim, hidden_dim, output_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Create dummy input data: batch of tokenized sequences (indices)
inputs = torch.randint(0, vocab_size, (batch_size, seq_length))
# Dummy target labels for a classification task
targets = torch.randint(0, output_dim, (batch_size,))

# Training step
model.train()
optimizer.zero_grad()
outputs, embedded = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

print("PyTorch training loss:", loss.item())

# Prediction: using a new dummy sample
model.eval()
with torch.no_grad():
    test_input = torch.randint(0, vocab_size, (1, seq_length))  # single sample
    test_output, test_embedded = model(test_input)
    predicted_class = torch.argmax(test_output, dim=1)
    print("Predicted class for test input:", predicted_class.item())

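# Optional: Inspect the embeddings looked up for the test sample.
# (`embedded` from the training step was computed before optimizer.step(),
# so it does not reflect the updated embedding weights.)
print("Embeddings for the test sample after one update:")
print(test_embedded[0])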

## Weight Tying

**Weight tying** is a technique that reduces the number of parameters in a model by sharing weights between different layers. In language models and sequence-to-sequence models, weight tying is often used to share weights between the input embedding layer and the output softmax layer.

#### Why Use Weight Tying?
- **Parameter Efficiency:** It reduces the model size, which is beneficial for training and inference speed.
- **Regularization:** Sharing weights can act as a form of regularization, helping to prevent overfitting.
- **Empirical Performance:** Research has shown that weight tying can lead to better generalization and performance in language modeling tasks.

In practice, when a word embedding matrix is tied with the output projection matrix, the model effectively learns a single representation for each word that is used both for encoding and decoding, ensuring consistency across the network.
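
A minimal NumPy sketch of the idea, assuming the hidden size equals the embedding dimension (otherwise an extra projection would be needed before the tied output layer). The function names are hypothetical, and embedding vectors stand in for RNN hidden states.

```python
import numpy as np

vocab_size, embedding_dim = 50, 8
rng = np.random.default_rng(0)

# One shared matrix: row i is the embedding of token i.
embedding = rng.standard_normal((vocab_size, embedding_dim))

def embed(token_ids):
    """Input side: look up rows of the shared matrix."""
    return embedding[token_ids]

def output_logits(hidden):
    """Output side: score every vocabulary word with the same matrix, transposed."""
    return hidden @ embedding.T

hidden = embed(np.array([3, 17]))   # stand-in for RNN hidden states
logits = output_logits(hidden)      # shape [2, vocab_size]

untied_params = 2 * vocab_size * embedding_dim   # separate input and output matrices
tied_params = vocab_size * embedding_dim         # one shared matrix
print(logits.shape, untied_params, tied_params)  # (2, 50) 800 400
```

In PyTorch, the same effect is commonly obtained by assigning the embedding's weight tensor to the output `nn.Linear` layer's `weight`, since both have shape `[vocab_size, embedding_dim]`.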

![](https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRzYdyiXTw53dusWKRfNR9uxHmFLUgUqII3Gg&s)

## Gradient Computation with RNNs

#### Problems:
- **Repeated Multiplication of Weights:**  
  The gradient at each time step involves repeated multiplication by the weight matrix (often denoted as **W**). This repetition can cause the gradients to either shrink or grow exponentially.
  
- **Vanishing Gradients:**  
  When the weights are such that the repeated multiplication causes the gradient to become extremely small, the model has difficulty learning long-range dependencies because the gradients vanish before reaching earlier time steps.
  
- **Exploding Gradients:**  
  Conversely, if the weights are too large, the gradient can grow exponentially, leading to numerical instability during training.
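
The effect of repeated multiplication can be seen numerically. The sketch below is a simplification under stated assumptions: it ignores the activation function's derivative and uses a scaled orthogonal matrix for **W**, so that each backward step changes the gradient norm by exactly the factor `scale`.

```python
import numpy as np

def gradient_norms(scale, steps=50, size=16, seed=0):
    """Norm of a gradient vector after repeated multiplication by W^T.

    W is `scale` times a random orthogonal matrix, so every multiplication
    changes the norm by exactly a factor of `scale`.
    """
    rng = np.random.default_rng(seed)
    q, _ = np.linalg.qr(rng.standard_normal((size, size)))
    w = scale * q
    grad = np.ones(size) / np.sqrt(size)   # unit-norm starting gradient
    norms = []
    for _ in range(steps):
        grad = w.T @ grad
        norms.append(np.linalg.norm(grad))
    return norms

print("scale 0.9:", gradient_norms(0.9)[-1])   # about 0.9**50, roughly 0.005
print("scale 1.1:", gradient_norms(1.1)[-1])   # about 1.1**50, roughly 117
```

After only 50 steps, a factor slightly below 1 shrinks the gradient by more than two orders of magnitude, while a factor slightly above 1 inflates it by about as much.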

#### Solutions:
- **Special Activation Functions:**  
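  The Rectified Linear Unit (ReLU) has a derivative of exactly 1 for positive inputs, whereas the derivatives of tanh and sigmoid shrink toward zero as the unit saturates. Using ReLU-style activations can therefore help mitigate the vanishing gradient problem. Because ReLU is unbounded, it does nothing against exploding gradients, so it is usually combined with careful initialization or gradient clipping.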
  
- **Normalization Techniques:**  
  - **Batch Normalization:** Normalizes the output of layers across the batch, stabilizing and speeding up the training.
  - **Layer Normalization:** Applies normalization across the features of each data point rather than across the batch. This is particularly useful in RNNs because it handles varying sequence lengths more gracefully.
  
- **Gradient Clipping:**  
  By setting a threshold for the maximum gradient value, gradient clipping prevents gradients from exceeding a certain magnitude, effectively controlling the exploding gradient problem. This involves computing the norm of the gradient and scaling it if it exceeds a predefined limit.

Each of these techniques helps to stabilize the training of RNNs, ensuring that the gradients remain in a manageable range throughout the sequence, which is essential for learning effective representations from long sequences.
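
Gradient clipping by norm can be sketched in a few lines of NumPy. The helper `clip_by_global_norm` is a hypothetical name written for illustration; it treats all gradient arrays as one long vector when computing the norm.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if total_norm <= max_norm:
        return grads, total_norm          # small gradients pass through unchanged
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm

grads = [np.array([3.0, 4.0]), np.array([12.0])]   # joint norm is 13
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped))
print(norm_before, norm_after)
```

Scaling every array by the same factor preserves the direction of the update and only limits its size. In practice the frameworks provide this directly, for example `torch.nn.utils.clip_grad_norm_` in PyTorch or the `clipnorm` and `global_clipnorm` arguments of Keras optimizers.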

## Greedy And Beam Search - Decoding Strategies

When generating sequences, such as in machine translation or text generation, selecting the right decoding strategy is crucial for generating high-quality outputs. Two common strategies are **Greedy Decoding** and **Beam Search**. Both approaches aim to determine the best sequence of words based on the model's probability estimates, but they differ in how they explore the space of possible sequences.

### Greedy Decoding

**Greedy decoding** is the simplest method for sequence generation. 

It operates by making a series of local, one-step optimal decisions. At each time step, the algorithm looks at the probability distribution over the next possible words and selects the one with the highest probability. This decision is made without considering how the choice might influence future selections.

Imagine a search tree where each node represents a possible next word in the sequence. Greedy decoding follows a single branch of that tree: the one that looks best at the current step. All alternative paths, including ones that would lead to a better overall sequence, are ignored.

While this method is computationally efficient and straightforward to implement, its narrow focus on the immediate best choice often leads to outputs that are predictable and lack variety. In many cases, this leads to repetitive sequences that do not capture the full richness of the language.

Due to its limitations, greedy decoding is not commonly used in practice for tasks that require nuanced or varied outputs, although it can be useful in simpler or more constrained settings.
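
A minimal sketch of greedy decoding in plain Python. The table `NEXT_TOKEN_PROBS` is a made-up toy model in which the next-token distribution depends only on the previous token; it is constructed so that the locally best first choice is not part of the best overall sequence.

```python
# Toy model (an assumption for illustration): P(next token | previous token).
NEXT_TOKEN_PROBS = {
    "<sos>": {"A": 0.6, "B": 0.4},
    "A": {"x": 0.55, "y": 0.45},
    "B": {"x": 0.9, "y": 0.1},
    "x": {"<eos>": 1.0},
    "y": {"<eos>": 1.0},
}

def greedy_decode(max_len=10):
    """Pick the most probable next token at every step until <eos>."""
    tokens, prob = ["<sos>"], 1.0
    while tokens[-1] != "<eos>" and len(tokens) < max_len:
        dist = NEXT_TOKEN_PROBS[tokens[-1]]
        best = max(dist, key=dist.get)   # locally best next token
        prob *= dist[best]
        tokens.append(best)
    return tokens, prob

tokens, prob = greedy_decode()
print(tokens, round(prob, 2))  # ['<sos>', 'A', 'x', '<eos>'] 0.33
```

Greedy decoding commits to `A` because 0.6 beats 0.4, and ends with a sequence of probability 0.33. The sequence `<sos> B x <eos>` has probability 0.4 × 0.9 = 0.36, but greedy decoding never considers it.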


### Beam Search

**Beam search** is a more sophisticated decoding strategy that addresses some of the limitations of greedy decoding.

Instead of choosing only the best word at each time step, beam search keeps track of the top **K** candidate sequences, where **K** is known as the beam width. This allows the algorithm to explore several possible sequences simultaneously.

The beam width (typically set between 5 and 10 in practice) controls the number of hypotheses maintained at each step. A larger beam width allows for a more exhaustive search, increasing the likelihood of finding a better overall sequence, but at the cost of increased computational complexity.

#### Process Overview:
  1. **Initialization:**  
     Start with an initial token (often a start-of-sequence token) and initialize the beam with this starting point.
  2. **Expansion:**  
     At each time step, expand all candidate sequences in the beam by appending all possible next words and computing their cumulative probabilities (in practice, summed log-probabilities, which avoids numerical underflow on long sequences).
  3. **Pruning:**  
     Retain only the top **K** sequences based on their cumulative probabilities.
  4. **Termination:**  
     Continue the expansion and pruning steps until all sequences in the beam reach an end-of-sequence token or a predetermined maximum length.
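
The four steps above can be sketched in plain Python. This is an illustrative implementation on a made-up toy model (`NEXT_TOKEN_PROBS`, where the next-token distribution depends only on the previous token), not a production decoder: it scores hypotheses by summed log-probability and applies no length normalization.

```python
import math

# Toy model (an assumption for illustration): P(next token | previous token).
NEXT_TOKEN_PROBS = {
    "<sos>": {"A": 0.6, "B": 0.4},
    "A": {"x": 0.55, "y": 0.45},
    "B": {"x": 0.9, "y": 0.1},
    "x": {"<eos>": 1.0},
    "y": {"<eos>": 1.0},
}

def beam_search(beam_width, max_len=10):
    """Return hypotheses as (log_prob, tokens) pairs, best first."""
    beam = [(0.0, ["<sos>"])]          # initialization
    finished = []
    for _ in range(max_len):
        # Expansion: extend every hypothesis with every possible next token.
        candidates = []
        for log_prob, tokens in beam:
            for token, p in NEXT_TOKEN_PROBS[tokens[-1]].items():
                candidates.append((log_prob + math.log(p), tokens + [token]))
        # Pruning: keep only the top-K candidates.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beam = []
        for cand in candidates[:beam_width]:
            (finished if cand[1][-1] == "<eos>" else beam).append(cand)
        # Termination: stop once every kept hypothesis has ended.
        if not beam:
            break
    finished.extend(beam)              # hypotheses cut off by max_len
    finished.sort(key=lambda c: c[0], reverse=True)
    return finished

for k in (1, 2):
    log_prob, tokens = beam_search(k)[0]
    print(k, tokens, round(math.exp(log_prob), 2))
```

With `beam_width=1` the procedure reduces to greedy decoding and returns `<sos> A x <eos>` (probability 0.33). With `beam_width=2` the hypothesis starting with `B` survives the first pruning step, and the search returns `<sos> B x <eos>` (probability 0.36).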

#### Application:
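  Beam search is particularly popular in tasks like machine translation (e.g., translating English to German) where generating a grammatically coherent and contextually accurate sentence is critical. By considering multiple candidate sequences, beam search usually finds sequences with higher overall probability than greedy decoding, which tends to translate into more fluent outputs. It does not add variety, however: it is still deterministic, and the hypotheses that survive in the beam often differ in only a few tokens.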

### Summary

Both greedy decoding and beam search are used to generate sequences from probabilistic models:
- **Greedy Decoding:**  
  - Fast and simple.
  - Locally optimal but often suboptimal overall.
  - Tends to produce generic and repetitive outputs.
- **Beam Search:**  
  - More computationally expensive but explores multiple candidate sequences.
  - Usually finds higher-probability sequences than greedy decoding, but remains deterministic and does not by itself add diversity.
  - Commonly used in complex sequence generation tasks like machine translation.

Selecting the appropriate decoding strategy depends on the specific task requirements and the trade-off between computational resources and the quality of generated sequences.

![](https://heidloff.net/assets/img/2023/08/greedy-beam.jpeg)

## Advanced Sampling Strategies

Always selecting the highest probability word at every step (as in greedy decoding) typically produces sentences that are grammatically correct but can be overly predictable and repetitive. 

This occurs because the model continually picks the most common choices, leading to generic output. In many applications — such as creative text generation or dialogue systems — more interesting outputs are desired. To achieve this, advanced sampling strategies introduce controlled randomness to balance between quality and diversity. 

Two important factors come into play:

- **Quality:** Favoring high-probability words to maintain coherence.
- **Diversity:** Allowing less probable words a chance to be selected, enriching the output with variety.

### Top-K Sampling

Top-K sampling is a straightforward extension of greedy decoding. Instead of selecting only the single highest probability word, the algorithm proceeds as follows:

1. **Candidate Selection:**  
   At each time step, pre-select the top **k** most likely words based on the model's probability distribution.

2. **Truncation and Renormalization:**  
   Truncate the full probability distribution to include only these top-k candidates, then renormalize the probabilities so that they sum to one.

3. **Random Sampling:**  
   Randomly sample the next word from this reduced distribution according to the renormalized probabilities.

This strategy strikes a balance between maintaining a high quality (by limiting choices to the most likely words) and introducing diversity (by allowing a random selection among them). 

However, because **k** is fixed, its effectiveness can vary depending on the nature of the original probability distribution. In some cases, the distribution might be very peaked, so the top-k choices cover nearly all the probability mass; in other cases, a flat distribution means that even the top-k words represent only a small fraction of the total mass.

### Top-P Sampling (Nucleus Sampling)

Top-P sampling, also known as nucleus sampling, dynamically adjusts the candidate pool based on the cumulative probability rather than a fixed number of words. 

The process is as follows:

1. **Cumulative Probability Threshold:**  
   At each step, sort all candidate words by their probability and select the smallest set of words whose cumulative probability exceeds a predetermined threshold **p**.

2. **Renormalization and Sampling:**  
   Renormalize this subset of candidates to form a valid probability distribution and sample the next word from this nucleus.

By focusing on a threshold rather than a fixed count, top-P sampling adapts to the probability distribution's shape. For peaked distributions, the nucleus might be very small, while for flatter distributions, it will be larger. This flexibility ensures that only the most contextually relevant words are considered, balancing coherence with the potential for creative variations.
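
How the nucleus adapts to the shape of the distribution can be checked with a short NumPy sketch. The helper `nucleus_size` is a hypothetical name, and the two distributions are made-up examples.

```python
import numpy as np

def nucleus_size(probs, p=0.9):
    """Number of tokens in the smallest set whose cumulative probability reaches p."""
    sorted_probs = np.sort(probs)[::-1]          # most probable first
    cumulative = np.cumsum(sorted_probs)
    # First position where the running total reaches p, converted to a count.
    size = int(np.searchsorted(cumulative, p)) + 1
    return min(size, len(probs))                 # guard against rounding at the tail

peaked = np.array([0.92, 0.05, 0.02, 0.01])
flat = np.array([0.25, 0.25, 0.25, 0.25])
print(nucleus_size(peaked), nucleus_size(flat))  # 1 4
```

With the same threshold `p=0.9`, the peaked distribution yields a nucleus of one token, while the flat one keeps all four. A fixed `k` could not produce both behaviors.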

### Temperature Sampling

Temperature sampling modifies the overall probability distribution by scaling the logits (pre-softmax scores) before applying the softmax function. 

The process involves:

1. **Logit Scaling:**  
   Divide the logits by a temperature factor before computing the softmax. This adjustment reshapes the probability distribution.

2. **Effect of Temperature:**  
   - **Low Temperature (< 1):**  
     The distribution becomes sharper; high-probability words become even more dominant, reducing randomness. This tends to favor high-quality, coherent outputs but can result in repetitive text.
   - **High Temperature (> 1):**  
     The distribution flattens, increasing the chance of selecting lower-probability words, which introduces more diversity and creativity into the generated text.

Temperature sampling allows for a smooth trade-off between quality and diversity without discarding any part of the original probability distribution.
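
The effect of the temperature on a fixed set of logits can be seen directly. The sketch below uses NumPy; the logits `[2.0, 1.0, 0.0]` are arbitrary example values.

```python
import numpy as np

def softmax_with_temperature(logits, temperature):
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    scaled = np.asarray(logits, dtype=float) / temperature
    scaled -= scaled.max()          # subtracting the max does not change the result
    exp = np.exp(scaled)
    return exp / exp.sum()

logits = [2.0, 1.0, 0.0]
for t in (0.5, 1.0, 2.0):
    print(t, np.round(softmax_with_temperature(logits, t), 3))
```

The probability of the top token falls from roughly 0.87 at temperature 0.5 to roughly 0.67 at 1.0 and roughly 0.51 at 2.0, while the ranking of the tokens stays the same.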

## Sampling Examples

### TensorFlow Example

#### Importing Libraries

In [2]:
import tensorflow as tf
import numpy as np

#### Top-K Sampling

In [3]:
def sample_top_k_tf(logits, k=10):
    """
    Performs top-K sampling:
      1. Select the top k logits.
      2. Mask out others by setting them to a very low value.
      3. Renormalize and sample.
    """
    
    logits = tf.convert_to_tensor(logits)
    topk = tf.math.top_k(logits, k=k)
    topk_logits = topk.values
    topk_indices = topk.indices

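    # Start from all -inf and write the top-k logits back at their original positions
    full_mask = tf.fill(tf.shape(logits), float('-inf'))
    masked_logits = tf.tensor_scatter_nd_update(full_mask, tf.expand_dims(topk_indices, 1), topk_logits)

    # tf.random.categorical takes unnormalized log-probabilities of shape [batch, classes]
    # and normalizes them itself, so the renormalization over the top-k is implicit
    sample = tf.random.categorical(tf.expand_dims(masked_logits, 0), num_samples=1)

    # The sampled index already refers to the original vocabulary position
    return tf.squeeze(sample, axis=0)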

#### Top-P / Nucleus Sampling

In [4]:
def sample_top_p_tf(logits, p=0.9):
    """
    Performs top-P (nucleus) sampling:
      1. Sort logits and compute probabilities.
      2. Determine the minimal set where cumulative probability exceeds p.
      3. Mask out others, renormalize, and sample.
    """
    
    logits = tf.convert_to_tensor(logits)
    
    # Sort logits in descending order
    sorted_logits, sorted_indices = tf.math.top_k(logits, k=tf.shape(logits)[0])
    sorted_probs = tf.nn.softmax(sorted_logits)
    cumulative_probs = tf.math.cumsum(sorted_probs)
    
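    # Keep a token if the probability mass of the tokens ranked above it is still
    # below p. This includes the token that crosses the threshold, so the kept set
    # covers at least p, and it always keeps the top token (its preceding mass is 0).
    mask = tf.math.cumsum(sorted_probs, exclusive=True) < p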
    
    # Set logits for tokens not in the nucleus to a very low value
    masked_logits = tf.where(mask, sorted_logits, tf.fill(tf.shape(sorted_logits), float('-inf')))
    new_probs = tf.nn.softmax(masked_logits)
    sample = tf.random.categorical(tf.math.log([new_probs]), num_samples=1)
    
    # Map back to original indices
    return tf.gather(sorted_indices, tf.squeeze(sample, axis=0))

#### Temperature Sampling

In [5]:
def sample_with_temperature_tf(logits, temperature=1.0):
    """
    Scale logits by temperature and sample one token.
    """
    
    scaled_logits = logits / temperature
    probs = tf.nn.softmax(scaled_logits)
    
    # tf.random.categorical expects a 2D tensor; we expand dims and squeeze the output
    sample = tf.random.categorical(tf.math.log([probs]), num_samples=1)
    return tf.squeeze(sample, axis=0)

#### Usage

In [6]:
# Example usage with dummy logits
vocab_size = 50
dummy_logits_tf = tf.random.normal([vocab_size])  # Example logits for 50 tokens

# Temperature sampling example
temp_sample_tf = sample_with_temperature_tf(dummy_logits_tf, temperature=0.8)
print("Temperature Sampled token index (TF):", temp_sample_tf.numpy())

# Top-K sampling example
topk_sample_tf = sample_top_k_tf(dummy_logits_tf, k=10)
print("Top-K Sampled token index (TF):", topk_sample_tf.numpy())

# Top-P sampling example
topp_sample_tf = sample_top_p_tf(dummy_logits_tf, p=0.9)
print("Top-P Sampled token index (TF):", topp_sample_tf.numpy())

Temperature Sampled token index (TF): [20]
Top-K Sampled token index (TF): [27]
Top-P Sampled token index (TF): [43]


### PyTorch Example

#### Importing Libraries

In [None]:
import torch
import torch.nn.functional as F

#### Top-K Sampling

In [None]:
def sample_top_k(logits, k=10):
    """
    Performs top-K sampling:
      1. Select the top k logits.
      2. Mask out the rest (set to -infinity).
      3. Renormalize and sample.
    """
    
    # Get the top k logits and their indices
    topk_logits, topk_indices = torch.topk(logits, k)
    
    # Create a mask that sets values not in top-k to -infinity
    mask = torch.full_like(logits, float('-inf'))
    mask[topk_indices] = topk_logits
    
    # Apply softmax to get a valid probability distribution
    probs = F.softmax(mask, dim=0)
    
    return torch.multinomial(probs, num_samples=1)

#### Top-P / Nucleus Sampling

In [None]:
def sample_top_p(logits, p=0.9):
    """
    Performs top-P (nucleus) sampling:
      1. Sort logits and compute softmax probabilities.
      2. Determine the minimal set of tokens where the cumulative probability exceeds p.
      3. Mask out tokens outside this set, renormalize, and sample.
    """
    
    # Sort logits in descending order
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    sorted_probs = F.softmax(sorted_logits, dim=0)
    
    # Compute cumulative probabilities
    cumulative_probs = torch.cumsum(sorted_probs, dim=0)
    
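    # Drop a token only if the tokens ranked above it already cover at least p.
    # This keeps the token that crosses the threshold, so the nucleus covers at
    # least p, and it always keeps the top token (its preceding mass is 0).
    cutoff = (cumulative_probs - sorted_probs) >= p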
    
    # Mask out tokens beyond the nucleus
    sorted_logits[cutoff] = float('-inf')
    
    # Renormalize the masked logits
    probs = F.softmax(sorted_logits, dim=0)
    sample = torch.multinomial(probs, num_samples=1)
    
    # Map the sample back to the original indices
    return sorted_indices[sample]

#### Temperature Sampling

In [None]:
def sample_with_temperature(logits, temperature=1.0):
    """
    Scale logits by temperature and sample one token.
    Lower temperatures (<1) sharpen the distribution, higher temperatures (>1) flatten it.
    """
    scaled_logits = logits / temperature
    probs = F.softmax(scaled_logits, dim=0)
    # Multinomial sampling from the probability distribution
    return torch.multinomial(probs, num_samples=1)

#### Usage

In [None]:
# Example usage with dummy logits
vocab_size = 50
dummy_logits = torch.randn(vocab_size)  # Example logits for 50 tokens

# Temperature sampling example
temp_sample = sample_with_temperature(dummy_logits, temperature=0.8)
print("Temperature Sampled token index:", temp_sample.item())

# Top-K sampling example
topk_sample = sample_top_k(dummy_logits, k=10)
print("Top-K Sampled token index:", topk_sample.item())

# Top-P sampling example
topp_sample = sample_top_p(dummy_logits, p=0.9)
print("Top-P Sampled token index:", topp_sample.item())
