# ML - LLMs - Evaluation - Perplexity

Perplexity is a metric for how well a probability model predicts a sample. It is the exponential of the cross-entropy between the model's predictions and the observed data, so it reflects how uncertain the model is when predicting the next token.

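The formula for perplexity is:

$ \text{Perplexity} = \exp\left(H(P)\right) $

where $H(P)$ is the cross-entropy, measured in nats. Substituting the cross-entropy formula $H(P) = -\frac{1}{N} \sum_{i=1}^{N} \log P(x_i)$, the average negative log-likelihood of the $N$ observed tokens, gives:

$ \text{Perplexity} = \exp\left(-\frac{1}{N} \sum_{i=1}^{N} \log P(x_i)\right) $

Here $P(x_i)$ is the probability the model assigns to the $i$-th observed token (for a language model, conditioned on the preceding tokens), and $\log$ is the natural logarithm.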

In essence, perplexity is a transformation of the cross-entropy loss, expressing the error in terms of the effective number of choices (or "confusion") the model faces. A lower cross-entropy corresponds to a lower perplexity, indicating better model predictions.
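
The "effective number of choices" reading can be checked with a minimal sketch. The helper name `perplexity_from_probs` and the probability values are illustrative assumptions, not part of the original notebook. A model that always assigns probability 1/4 to the observed token has a perplexity of about 4, as if it were choosing uniformly among four options.

```python
import math

def perplexity_from_probs(probs):
    """Perplexity given the probability the model assigned to each observed token.

    Hypothetical helper for illustration: exp of the average negative log-probability.
    """
    avg_nll = -sum(math.log(p) for p in probs) / len(probs)
    return math.exp(avg_nll)

# A model that assigns probability 0.25 to every observed token
uniform_ppl = perplexity_from_probs([0.25] * 8)
# A more confident model that assigns probability 0.9 to every observed token
confident_ppl = perplexity_from_probs([0.9] * 8)

print("Uniform over 4 choices:", uniform_ppl)    # close to 4
print("Confident model:", confident_ppl)         # close to 1 / 0.9
```

Equivalently, perplexity is the geometric mean of the inverse probabilities $1 / P(x_i)$, which is why a constant probability $p$ gives a perplexity of $1/p$.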

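To compute perplexity:
1. Compute the softmax probabilities from the logits.
2. Calculate the cross-entropy loss (the negative log-probability of the true class) for each sample in the batch.
3. Average the losses over the batch.
4. Exponentiate the average loss to obtain the perplexity over the batch.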
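
Written out for a batch of $N$ samples, the four steps can be sketched as follows. The symbols $z_{i,c}$ (the logit of class $c$ for sample $i$), $y_i$ (the true class of sample $i$), and $\ell_i$ (the per-sample loss) are notation introduced here for illustration:

$ p_{i,c} = \frac{\exp(z_{i,c})}{\sum_{c'} \exp(z_{i,c'})} $

$ \ell_i = -\log p_{i,y_i} $

$ \bar{\ell} = \frac{1}{N} \sum_{i=1}^{N} \ell_i $

$ \text{Perplexity} = \exp\left(\bar{\ell}\right) $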

Learn more: [Perplexity Explanation](https://chatgpt.com/share/67fad511-d220-8009-a7b1-060df0840166)


import numpy as np

def compute_softmax(logits):
    """
    Compute the softmax probabilities for each row in the logits matrix.

    Parameters:
        logits (np.ndarray): A 2D array of shape (n_samples, n_classes)
    
    Returns:
        np.ndarray: A 2D array of the same shape with softmax probabilities.
    """
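    # Subtract the row-wise max before exponentiating to avoid overflow;
    # softmax is unchanged when a constant is added to every logit in a row
    shifted_logits = logits - np.max(logits, axis=1, keepdims=True)
    exp_logits = np.exp(shifted_logits)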
    # Sum along the classes (axis=1) and maintain dimensions for broadcasting
    sum_exp = np.sum(exp_logits, axis=1, keepdims=True)
    # Divide each exponential by the sum of exponentials for that sample
    softmax_probs = exp_logits / sum_exp
    return softmax_probs

def compute_cross_entropy_loss(softmax_probs, true_labels):
    """
    Compute the cross-entropy loss for each sample.

    Parameters:
        softmax_probs (np.ndarray): Predicted probabilities with shape (n_samples, n_classes)
        true_labels (np.ndarray): One-hot encoded true labels with shape (n_samples, n_classes)
    
    Returns:
        np.ndarray: A 1D array containing the cross-entropy loss for each sample.
    """
    # Use a small epsilon value to avoid log(0)
    epsilon = 1e-12
    # Compute the loss: -sum(y * log(probabilities)) for each sample
    losses = -np.sum(true_labels * np.log(softmax_probs + epsilon), axis=1)
    return losses

def calculate_perplexity(logits, true_labels):
    """
    Calculate the perplexity given logits and one-hot encoded true labels.
    
    Steps:
    1. Compute softmax probabilities from logits.
    2. Calculate the cross-entropy loss for each sample.
    3. Average the loss across the batch.
    4. Compute perplexity as the exponential of the average loss.
    
    Parameters:
        logits (np.ndarray): A 2D array of shape (n_samples, n_classes)
        true_labels (np.ndarray): One-hot encoded true labels of the same shape as logits.
        
    Returns:
        float: The perplexity computed over the batch.
    """
    # Step 1: Compute softmax probabilities
    softmax_probs = compute_softmax(logits)
    
    # Step 2: Compute cross-entropy loss per sample
    losses = compute_cross_entropy_loss(softmax_probs, true_labels)
    
    # Step 3: Compute the average cross-entropy loss over the batch
    avg_loss = np.mean(losses)
    
    # Step 4: Calculate perplexity as the exponential of the average loss
    perplexity = np.exp(avg_loss)
    
    return perplexity

# -----------------------------
# Example usage of the functions:
# -----------------------------

# Step 1: Define Batch Logits and True Labels
# Assume we have a batch of 3 examples, each with logits for 3 classes.
logits = np.array([
    [2.0, 1.0, 0.1],   # Sample 1
    [1.5, 0.5, 0.0],   # Sample 2
    [0.2, 1.2, 0.5]    # Sample 3
])
print("Logits:\n", logits)

# True labels in one-hot encoded form for each sample.
true_labels = np.array([
    [1, 0, 0],  # Sample 1: true class is Class 0
    [0, 1, 0],  # Sample 2: true class is Class 1
    [0, 0, 1]   # Sample 3: true class is Class 2
])
print("True Labels (one-hot):\n", true_labels)

# Calculate perplexity from logits and true labels
perplexity = calculate_perplexity(logits, true_labels)
print("Perplexity over the Batch:", perplexity)


Logits:
 [[2.  1.  0.1]
 [1.5 0.5 0. ]
 [0.2 1.2 0.5]]
True Labels (one-hot):
 [[1 0 0]
 [0 1 0]
 [0 0 1]]
Perplexity over the Batch: 2.909916162855865
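
To see where this value comes from, the per-sample losses can be inspected directly. The sketch below is an alternative formulation, not the notebook's own code: it assumes the labels are given as integer class indices rather than one-hot vectors and works in log space (log-softmax), so no epsilon is needed. The names `log_softmax` and `perplexity_from_class_ids` are hypothetical.

```python
import numpy as np

def log_softmax(logits):
    """Row-wise log-softmax, computed stably by shifting each row by its max."""
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=1, keepdims=True))

def perplexity_from_class_ids(logits, class_ids):
    """Return (perplexity, per-sample losses) for integer class labels."""
    log_probs = log_softmax(logits)
    rows = np.arange(logits.shape[0])
    # Negative log-probability of the true class for each sample
    nll = -log_probs[rows, class_ids]
    return float(np.exp(nll.mean())), nll

# Same batch as above, with labels as class indices instead of one-hot rows
logits = np.array([
    [2.0, 1.0, 0.1],
    [1.5, 0.5, 0.0],
    [0.2, 1.2, 0.5],
])
class_ids = np.array([0, 1, 2])

ppl, nll = perplexity_from_class_ids(logits, class_ids)
print("Per-sample losses:", nll)
print("Perplexity over the batch:", ppl)
```

The first sample, where the true class has the largest logit, contributes the smallest loss. The other two samples assign most of their probability to a wrong class and pull the perplexity up to roughly 2.91, close to the 3 that a uniform guess over three classes would give.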
