<a href="https://colab.research.google.com/github/MostHumble/topNSigma/blob/main/TopNSigma_Sampling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import torch
from transformers import LogitsProcessor, LogitsProcessorList
from transformers.utils.doc import add_start_docstrings
from transformers.generation.logits_process import TemperatureLogitsWarper

class TopNSigmaLogitsWarper(LogitsProcessor):
    """
    [`LogitsProcessor`] that performs Top-nσ sampling. This method filters the logits based on their deviation
    from the maximum logit value.

    The filtering rule is to keep all tokens where the logit `l_i` satisfies `l_i > M - n * σ`, where `M` is the
    maximum logit, `σ` is the standard deviation of the logits, and `n` is a configurable multiplier. Logits
    that do not meet this condition are set to a filter value (typically -inf).

    This approach adapts the filtering based on the distribution of the logits. A flatter distribution (higher
    standard deviation) will result in a more aggressive filtering, while a spikier distribution (lower
    standard deviation) will be more permissive.

    Args:
        n (`float`):
            The multiplier for the standard deviation. This value controls the aggressiveness of the filtering.
            Higher values of `n` lead to more aggressive filtering (keeping fewer tokens). A typical range for
            `n` might be between 2 and 6.
        filter_value (`float`, *optional*, defaults to -inf):
            All filtered values will be set to this float value.
    """

    def __init__(self, n: float, filter_value: float = -float("Inf")):
        if not isinstance(n, float) or n < 0:
            raise ValueError(f"`n` has to be a non-negative float, but is {n}")

        self.n = n
        self.filter_value = filter_value

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Calculate M (max logit) and sigma (standard deviation of logits) for each sequence in the batch
        max_logit, _ = torch.max(scores, dim=-1, keepdim=True)
        std_logit = torch.std(scores, dim=-1, keepdim=True)

        # Calculate the filtering threshold for each sequence
        threshold = max_logit - self.n * std_logit

        # Create a boolean mask for tokens to be removed
        tokens_to_remove = scores < threshold

        # Apply the filter
        scores_processed = scores.masked_fill(tokens_to_remove, self.filter_value)
        return scores_processed

if __name__ == '__main__':
    # Example Usage
    from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

    # Set a seed for reproducibility
    set_seed(42)

    checkpoint = "HuggingFaceTB/SmolLM2-360M-Instruct"

    device = "cuda" # for GPU usage or "cpu" for CPU usage
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

    # Input text
    messages = [{"role": "user", "content": "Can you explain the concept of gravity?"}]
    input_text=tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

    print("--- Standard Sampling (for comparison) ---")
    # With standard sampling, the output can be quite random.
    outputs_standard = model.generate(
        inputs,
        do_sample=True,
        max_length=100,
        top_k=0 # Deactivate top-k to see the full effect of sampling
    )
    print(tokenizer.batch_decode(outputs_standard, skip_special_tokens=True)[0])
    print("-" * 50)


    print("--- TopNSigma Sampling (n=1.5) without Temperature ---")
    # With TopNSigma sampling, the output is constrained to more likely tokens.
    top_n_sigma_warper = TopNSigmaLogitsWarper(n=1.5)
    outputs_top_n_sigma = model.generate(
        inputs,
        do_sample=True,
        max_length=100,
        logits_processor=LogitsProcessorList([top_n_sigma_warper])
    )
    print(tokenizer.batch_decode(outputs_top_n_sigma, skip_special_tokens=True)[0])
    print("-" * 50)

    print("--- TopNSigma Sampling (n=0.5) with Temperature (T=0.7) ---")
    # Here we combine both processors. The order matters: typically filtering processors
    # like TopNSigma go before temperature scaling.
    temp_warper = TemperatureLogitsWarper(temperature=0.7)
    top_n_sigma_warper_combined = TopNSigmaLogitsWarper(n=0.5)

    outputs_combined = model.generate(
        inputs,
        do_sample=True,
        max_length=100,
        logits_processor=LogitsProcessorList([top_n_sigma_warper_combined, temp_warper])
    )
    print(tokenizer.batch_decode(outputs_combined, skip_special_tokens=True)[0])
    print("-" * 50)

--- Standard Sampling (for comparison) ---
system
You are a helpful AI assistant named SmolLM, trained by Hugging Face
user
Can you explain the concept of gravity?
assistant
Certainly! Gravity is a fundamental force of nature that keeps objects in alignment with each other and objects in the universe. It is widely expressed within various forms of science and engineering, interactive simulations, documents, software, and educational materials. Practitioners of gravitational manipulation use this energy to power gravitational principles to achieve various unexpected outcomes
--------------------------------------------------
--- TopNSigma Sampling (n=1.5) without Temperature ---
system
You are a helpful AI assistant named SmolLM, trained by Hugging Face
user
Can you explain the concept of gravity?
assistant
Sure, as a Hugging Face AI, I've been trained on many topics and can certainly summarize the concept of gravity.

Gravity, also known as 'g', is a fundamental force of nature that de