<a href="https://colab.research.google.com/github/Yogesh914/spec-decode-optimal-transport/blob/main/optimal_transport_sd.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pot

Collecting pot
  Downloading POT-0.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (32 kB)
Downloading POT-0.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (835 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/835.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m835.4/835.4 kB[0m [31m45.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pot
Successfully installed pot-0.9.4


In [None]:
import abc
import torch
from torch import Tensor
from torch.nn import functional as F


class LogitsProcessor(abc.ABC):
    """Abstract base for objects that turn raw logits into sampled tokens.

    Calling an instance filters the logits with `_process`, divides the result
    by the temperature and applies a softmax over the last dimension.
    Subclasses provide `_process` and `sample`.
    """

    def __init__(self, temperature: float):
        # Divisor applied to the filtered logits before the softmax.
        self.temperature = temperature

    def __call__(self, logits: Tensor) -> Tensor:
        filtered = self._process(logits)
        scaled = filtered / self.temperature
        return F.softmax(scaled, dim=-1)

    @abc.abstractmethod
    def _process(self, logits: Tensor) -> Tensor:
        """Return the logits after the subclass-specific filtering."""

    @abc.abstractmethod
    def sample(self, probs: Tensor) -> Tensor:
        """Pick token ids from the probabilities produced by `__call__`."""


class GreedyProcessor(LogitsProcessor):
    """Greedy decoding: always select the highest-probability token."""

    def __init__(self, temperature: float = 1):
        super().__init__(temperature)

    def _process(self, logits: Tensor) -> Tensor:
        # Greedy decoding applies no filtering to the logits.
        return logits

    def sample(self, probs: Tensor) -> Tensor:
        # keepdim=True keeps the trailing axis of size one.
        return probs.argmax(dim=-1, keepdim=True)


class MultinomialProcessor(LogitsProcessor):
    """Multinomial decoding: draw one token at random from the distribution."""

    def __init__(self, temperature: float):
        super().__init__(temperature)

    def _process(self, logits: Tensor) -> Tensor:
        # Plain multinomial sampling leaves the logits untouched.
        return logits

    def sample(self, probs: Tensor) -> Tensor:
        drawn = torch.multinomial(probs, num_samples=1)
        return drawn


class TopKProcessor(MultinomialProcessor):
    """Top-k sampling: only the k largest logits stay eligible for sampling."""

    def __init__(self, temperature: float, top_k: int):
        super().__init__(temperature)
        self.top_k = top_k

    def _process(self, logits: Tensor) -> Tensor:
        """Return a copy of `logits` with everything below the k-th largest masked.

        The input tensor is left untouched: callers pass views of model
        outputs, and masking them in place would silently corrupt those logits.
        """
        # Never ask topk for more entries than the last dimension holds.
        top_k = min(self.top_k, logits.size(-1))
        kth_largest = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
        indices_to_remove = logits < kth_largest
        filtered = logits.clone()
        # A very negative logit becomes (numerically) zero probability after softmax.
        filtered[indices_to_remove] = -1e20
        return filtered


class NucleusProcessor(MultinomialProcessor):
    """Nucleus (top-p) sampling."""

    def __init__(self, temperature: float, top_p: float):
        super().__init__(temperature)
        self.top_p = top_p

    def _process(self, logits: Tensor) -> Tensor:
        # Work in descending-logit order so the cumulative mass can be thresholded.
        ordered_logits, order = logits.sort(dim=-1, descending=True)
        cumulative = F.softmax(ordered_logits, dim=-1).cumsum(dim=-1)
        exceeds = cumulative > self.top_p

        # Shift the mask one slot to the right: the token that crosses the
        # threshold is kept, and the most likely token is never dropped.
        drop = torch.zeros_like(exceeds)
        drop[..., 1:] = exceeds[..., :-1]

        ordered_logits[drop] = -1e20

        # Undo the sort so each logit returns to its vocabulary position.
        inverse_order = order.argsort(-1)
        return torch.gather(ordered_logits, -1, inverse_order)


class TopKNucleusProcessor(MultinomialProcessor):
    """Top-k followed by nucleus (top-p) filtering before multinomial sampling."""

    def __init__(self, temperature: float, top_k: int, top_p: float):
        super().__init__(temperature)
        self.top_k = top_k
        self.top_p = top_p

    def _process(self, logits: Tensor) -> Tensor:
        """Return filtered logits without modifying the tensor passed in.

        First everything below the k-th largest logit is masked, then the
        nucleus rule is applied on the remaining mass.
        """
        # Top-k step, applied to a copy so the caller's logits stay intact.
        top_k = min(self.top_k, logits.size(-1))
        kth_largest = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
        indices_to_remove = logits < kth_largest
        logits = logits.clone()
        logits[indices_to_remove] = -1e20

        # Top-p step, computed in descending-logit order.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > self.top_p
        # Shift right so the token crossing the threshold is kept, and always
        # keep the most likely token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        sorted_logits[sorted_indices_to_remove] = -1e20

        # Undo the sort so each logit returns to its vocabulary position.
        return torch.gather(sorted_logits, -1, sorted_indices.argsort(-1))

In [None]:
from typing import Tuple, Union
from torch import Tensor
from transformers.cache_utils import DynamicCache


def prune_cache(cache: Union[Tuple[Tuple[Tensor, Tensor]], DynamicCache], num_tokens_to_discard: int):
    """
    Drop the last `num_tokens_to_discard` tokens from a KV cache.

    Dispatches on the cache type: tuple caches go to `prune_tuple_cache`,
    `DynamicCache` instances go to `prune_dynamic_cache`.

    Args:
        cache (Union[Tuple[Tuple[Tensor, Tensor]], DynamicCache]): The KV cache to be pruned.
        num_tokens_to_discard (int): The number of tokens to discard from the end of the cache.

    Returns:
        Union[Tuple[Tuple[Tensor, Tensor]], DynamicCache]: The pruned KV cache, or None when no cache was given.

    Raises:
        ValueError: If the cache is neither a tuple nor a DynamicCache.
    """
    if cache is None:
        return None

    if isinstance(cache, tuple):
        pruner = prune_tuple_cache
    elif isinstance(cache, DynamicCache):
        pruner = prune_dynamic_cache
    else:
        raise ValueError("Unsupported cache type.")

    return pruner(cache, num_tokens_to_discard)


def prune_tuple_cache(cache: Tuple[Tuple[Tensor, Tensor]], num_tokens_to_discard: int):
    """
    Prune a tuple-style KV cache by removing tokens from the end.

    Every tensor of every layer is sliced along its third axis, which matches
    past_key_values laid out as a tuple (one entry per layer) of tuples of
    tensors shaped (batch_size, num_heads, sequence_length, embed_size_per_head).

    Args:
        cache Tuple(Tuple[Tensor, Tensor]): The KV cache to be pruned.
        num_tokens_to_discard (int): The number of tokens to discard from the end of the cache.
            Zero keeps every token.

    Returns:
        Tuple[Tuple[Tensor, Tensor]]: The pruned KV cache (None entries are preserved), or None
        when no cache was given.
    """
    if cache is None:
        return None

    # `[:-0]` is `[:0]`, i.e. an empty slice, so discarding zero tokens must
    # be spelled as "keep everything" explicitly.
    if num_tokens_to_discard == 0:
        keep = slice(None)
    else:
        keep = slice(None, -num_tokens_to_discard)

    new_cache = []
    for layer_cache in cache:
        if layer_cache is None:
            new_cache.append(None)
            continue
        new_cache.append(tuple(tensor[:, :, keep, :] for tensor in layer_cache))

    return tuple(new_cache)


def prune_dynamic_cache(cache: DynamicCache, num_tokens_to_discard: int):
    """
    Prune a DynamicCache by removing tokens from the end.

    Args:
        cache (DynamicCache): The KV cache to be pruned.
        num_tokens_to_discard (int): The number of tokens to discard from the end of the cache.
            Zero leaves the cache unchanged.

    Returns:
        DynamicCache: The pruned KV cache. (same instance as the input cache, but modified in place)
    """
    if cache is None:
        return None

    # `[:-0]` is `[:0]` and would wipe every layer, so there is nothing to do
    # when no token has to be discarded.
    if num_tokens_to_discard == 0:
        return cache

    for layer in range(len(cache)):
        cache.key_cache[layer] = cache.key_cache[layer][:, :, :-num_tokens_to_discard, :]
        cache.value_cache[layer] = cache.value_cache[layer][:, :, :-num_tokens_to_discard, :]
    # NOTE(review): `_seen_tokens` is a private attribute of DynamicCache;
    # confirm it still exists in the installed transformers version.
    cache._seen_tokens -= num_tokens_to_discard

    return cache

In [None]:
from typing import List, Tuple
from termcolor import colored
from torch import Tensor


def token_ids_to_string(token_ids, tokenizer):
    """Render token ids as their token strings, separated by single spaces.

    Args:
        token_ids (List[int]): List of token ids.
        tokenizer (Tokenizer): Object exposing `convert_ids_to_tokens`.

    Returns:
        str: The tokens joined with " ".
    """
    return " ".join(tokenizer.convert_ids_to_tokens(token_ids))


def end_token_found(location: int):
    """Print, in red, the position at which an end token was found."""
    message = f"End token found at position {location}"
    print(colored(message, "red"))


def initial_step(token: Tensor, tokenizer):
    """Print the header of the initial step followed by its token in blue."""
    header = colored('Initiale Step', on_color='on_dark_grey', color='white')
    print(f"{header} 1 token:")
    token_text = token_ids_to_string(token, tokenizer)
    print(colored(token_text, "blue"))


def speculative_step(
    tokenizer,
    current_inputs: Tensor,
    inputs: Tensor,
    n: int,
    prompt_end: int,
    current_position: int,
    corrected_gamma: int,
):
    """Print one speculative decoding step with colour-coded tokens.

    The line shows, in order: the text generated so far (uncoloured), the `n`
    tokens read from `inputs` right after `current_position` (green), the
    tokens of `current_inputs` from there up to `corrected_gamma` (red), and
    the token of `inputs` at `current_position + n` (blue).
    """
    title = colored('Speculative Step', on_color='on_dark_grey', color='white')
    plural = 's' if n > 1 else ''
    print(f"{title} {n} draft{plural} + 1 token:")

    accepted_end = current_position + n
    drafts_end = current_position + corrected_gamma

    history = token_ids_to_string(inputs[0, prompt_end:current_position], tokenizer)
    print(history, end=" ")

    accepted = token_ids_to_string(inputs[0, current_position : accepted_end], tokenizer)
    print(colored(accepted, "green"), end=(" " if n > 0 else ""))

    rejected = token_ids_to_string(current_inputs[0, accepted_end : drafts_end], tokenizer)
    print(colored(rejected, "red"), end=(" " if n < corrected_gamma else ""))

    final_token = token_ids_to_string(inputs[..., accepted_end], tokenizer)
    print(colored(final_token, "blue"))


def beam_search_step(possibilities: List[tuple], current_position: int, tokenizer):
    """Print the ranked beam candidates for one beam search step.

    Args:
        possibilities: Candidate tuples whose first two items are the score and
            the token tensor. Any further items are ignored, so the 4-tuples
            built by `beam_search_generate` are accepted as well.
        current_position (int): Position of the token being chosen; the token
            just before it is highlighted in blue.
        tokenizer: Tokenizer used to render the token ids.
    """
    print(f"{colored('Beam Search Step', on_color='on_dark_grey', color='white')} Token {current_position}:")

    # `*_` absorbs the trailing fields so tuples longer than three items do
    # not raise "too many values to unpack".
    for i, (prob, tokens, *_) in enumerate(possibilities):
        print(f"{i+1}. {prob:.3f}\t{token_ids_to_string(tokens[:current_position - 1], tokenizer)} {colored(token_ids_to_string(tokens[current_position - 1:current_position], tokenizer), 'blue')}")

# Speculative Decoding

In [None]:
import torch
from torch.nn import Module
from transformers.cache_utils import DynamicCache
from typing import List, Tuple
import ot
import numpy as np


def max_fn(x: torch.Tensor) -> torch.Tensor:
    """
    Keep the positive part of `x` and normalise it along the last dimension.
        x: input tensor.
    Returns:
        tensor norm(max(0, x)).
    """
    # Every entry that is not strictly positive is replaced by zero.
    positive_part = x.masked_fill(~(x > 0), 0)
    normaliser = positive_part.sum(dim=-1, keepdim=True)
    return positive_part / normaliser

@torch.no_grad()
def speculative_generate(
    inputs: List[int],
    drafter: Module,
    target: Module,
    tokenizer,
    gamma: int = 5,
    logits_processor: LogitsProcessor = GreedyProcessor(temperature=1),
    max_gen_len: int = 40,
    eos_tokens_id: int | List[int] = 1,
    pad_token_id: int = 0,
    use_cache: bool = False,
    skip_sample_adjustment: bool = False,
    first_target: bool = True,
    debug: bool = False,
    use_optimal_transport: bool = False,
    top_n_transport: int = 1000
) -> Tuple[List[int], float]:
    """
    Generate text sequence using the speculative decoding algorithm.
    Implementation of Speculative Decoding. (https://arxiv.org/pdf/2211.17192.pdf)

    Args:
        inputs (List[int]): input sequence of batch size 1.
        drafter (Module): drafter model.
        target (Module): target model.
        tokenizer: tokenizer (only used by the debug printing helpers).
        gamma (int): number of drafts generated by the drafter at each step.
        logits_processor (LogitsProcessor): logits processor for sampling.
        max_gen_len (int): maximum length of the generated sequence.
        eos_tokens_id (int or List[int]): end token id (could be multiple).
        pad_token_id (int): pad token id.
        use_cache (bool): whether to use cache.
        skip_sample_adjustment (bool): whether to skip the sample adjustment step when some drafts are discarded.
        first_target (bool): whether to run the target model before the speculative algorithm.
        debug (bool): debug mode.
        use_optimal_transport (bool): whether to re-estimate each draft distribution with the
            optimal transport step before the accept/reject test.
        top_n_transport (int): number of highest-probability tokens of each distribution
            considered by the optimal transport step (capped at the vocabulary size).

    Returns:
        List[int]: generated sequence.
        float: acceptance rate (number of accepted drafts divided by the number of total drafts),
            0.0 when no draft was speculated.

    Note: This generation methods only works for decoder-only models.
    Note bis: The drafter and target models should output the same logits shape.
    Note ter: NgramModels are currently not supported.
    """

    drafter_cache, target_cache = None, None

    list_tokens_id = eos_tokens_id if isinstance(eos_tokens_id, list) else [eos_tokens_id]
    # Column vector so it broadcasts against a [1, n] slice of token ids.
    stop_tokens = torch.tensor(list_tokens_id, dtype=torch.long, device=target.device).unsqueeze(1)

    drafts_accepted, drafts_speculated = .0, .0

    def _acceptance_rate() -> float:
        # No draft may have been speculated at all (e.g. only one token left
        # to generate); report 0.0 instead of dividing by zero.
        if drafts_speculated == 0:
            return 0.0
        return drafts_accepted / drafts_speculated

    vocabulary_size = target.config.vocab_size

    # prepare input tensor
    prompt_len = len(inputs)
    max_seq_length = target.config.max_position_embeddings if hasattr(target.config, 'max_position_embeddings') else (target.config.max_context_length if hasattr(target.config, 'max_context_length') else 1024)
    total_len = min(max_seq_length, prompt_len + max_gen_len)
    input_ids = torch.full((1, total_len), pad_token_id, dtype=torch.long, device=target.device)
    input_ids[0, :prompt_len] = torch.tensor(inputs, dtype=torch.long, device=target.device)

    current_position = prompt_len

    # NOTE(review): with use_cache=True the models below still receive the
    # whole prefix together with past_key_values - confirm this is what the
    # models expect.
    if first_target:
        # run the target model before the speculative algorithm. Allows to prefill the kvcache and get a first token.
        Mp = target(
            input_ids=input_ids[..., :current_position],
            past_key_values=target_cache,
            use_cache=use_cache,
        )
        target_cache = Mp.past_key_values
        p_p = logits_processor(Mp.logits[..., -1, :])
        t = logits_processor.sample(p_p)
        input_ids[0, current_position] = t
        current_position += 1

        if torch.isin(t, stop_tokens):
            if debug:
                end_token_found(0)
            return input_ids[0, prompt_len:current_position].tolist(), 0.0

        if debug:
            initial_step(t, tokenizer)

    while current_position < total_len:
        # Leave room for the extra token sampled from the target after the drafts.
        corrected_gamma = min(gamma, total_len - current_position - 1)
        q = torch.zeros((1, corrected_gamma, vocabulary_size), device=target.device)

        input_ids = input_ids.to(drafter.device)

        # generate gamma drafts
        for k in range(corrected_gamma):
            Mq = drafter(
                input_ids=input_ids[..., :current_position + k],
                past_key_values=drafter_cache,
                use_cache=use_cache,
            )
            drafter_cache = Mq.past_key_values

            draft_logits = Mq.logits[..., -1, :]
            draft_probs = logits_processor(draft_logits)
            q[0, k] = draft_probs.to(target.device)
            xi = logits_processor.sample(draft_probs)
            input_ids[0, current_position + k] = xi
        drafts_speculated += corrected_gamma
        input_ids = input_ids.to(target.device)

        # run target model on drafts and get logits of the previous tokens plus one more token
        Mp = target(
            input_ids=input_ids[..., :current_position + corrected_gamma],
            past_key_values=target_cache,
            use_cache=use_cache,
        )
        target_cache = Mp.past_key_values
        draft_logits = Mp.logits[..., current_position - 1:current_position + corrected_gamma - 1, :] # [1, corrected_gamma, vocab_size]
        p = logits_processor(draft_logits) # [1, gamma, vocab_size]

        if use_optimal_transport:
            # Perform optimal transport alignment of q and p
            for k in range(corrected_gamma):
                q_k = q[0, k]
                p_k = p[0, k]

                # Get top N tokens; topk raises if asked for more entries than exist.
                top_n = min(top_n_transport, q_k.shape[-1], p_k.shape[-1])
                q_k_top_values, q_k_top_indices = torch.topk(q_k, top_n)
                p_k_top_values, p_k_top_indices = torch.topk(p_k, top_n)

                # Get union of indices
                # NOTE(review): torch.unique returns sorted ids, so this
                # truncation keeps the smallest token ids rather than the most
                # probable ones - confirm this is intended.
                top_indices = torch.unique(torch.cat([q_k_top_indices, p_k_top_indices]))
                if top_indices.shape[0] > top_n_transport:
                    top_indices = top_indices[:top_n_transport]

                # Get reduced distributions
                q_k_reduced = q_k[top_indices]
                p_k_reduced = p_k[top_indices]

                # Check for NaN or zero values
                if torch.isnan(q_k_reduced).any() or torch.isnan(p_k_reduced).any():
                    if debug:
                        print("NaN values found in q_k_reduced or p_k_reduced, skipping this step.")
                    continue  # Skip this iteration

                q_k_sum = q_k_reduced.sum()
                p_k_sum = p_k_reduced.sum()

                if q_k_sum.item() == 0 or p_k_sum.item() == 0:
                    if debug:
                        print("Sum of q_k_reduced or p_k_reduced is zero, skipping this step.")
                    continue  # Skip this iteration

                # Normalize
                q_k_reduced = q_k_reduced / q_k_sum
                p_k_reduced = p_k_reduced / p_k_sum

                # Ensure sums are equal
                mean_sum = (q_k_reduced.sum() + p_k_reduced.sum()) / 2
                q_k_reduced = q_k_reduced * (mean_sum / q_k_reduced.sum())
                p_k_reduced = p_k_reduced * (mean_sum / p_k_reduced.sum())

                # Convert to numpy
                q_k_np = q_k_reduced.cpu().numpy()
                p_k_np = p_k_reduced.cpu().numpy()

                # Verify that sums are equal
                if not np.isclose(q_k_np.sum(), p_k_np.sum()):
                    if debug:
                        print(f"Sums not equal after normalization: q_k_np.sum()={q_k_np.sum()}, p_k_np.sum()={p_k_np.sum()}")
                    continue  # Skip this iteration

                # Define cost matrix: 0 on the diagonal, 1 everywhere else
                C = torch.ones((top_indices.shape[0], top_indices.shape[0]), device=target.device) - torch.eye(top_indices.shape[0], device=target.device)
                C_np = C.cpu().numpy()

                # Compute optimal transport plan
                ot_plan = ot.emd(q_k_np, p_k_np, C_np)

                # Adjust q_k_reduced using the transport plan
                # NOTE(review): the row sums of a transport plan are its source
                # marginal, so this looks like it recovers q_k_np itself -
                # confirm this is the intended alignment.
                q_k_aligned_np = ot_plan.sum(axis=1)

                # Update q[0, k] with adjusted values
                q_k_aligned = torch.zeros_like(q_k)
                q_k_aligned[top_indices] = torch.from_numpy(q_k_aligned_np).to(q_k.device)

                # Normalize
                q_k_aligned_sum = q_k_aligned.sum()
                if q_k_aligned_sum.item() == 0:
                    if debug:
                        print("Sum of q_k_aligned is zero after optimal transport, skipping this step.")
                    continue  # Skip this iteration

                q_k_aligned = q_k_aligned / q_k_aligned_sum
                q[0, k] = q_k_aligned

        # compute the last accepted draft position (rejection sampling)
        r = torch.rand(corrected_gamma, device=target.device)
        fractions = p / q
        n = corrected_gamma
        for i in range(corrected_gamma):
            if r[i] > fractions[0, i, input_ids[0, current_position + i]]:
                n = i
                break

        drafts_accepted += n

        # check if the end token is in the drafts
        stop_locations = torch.nonzero(torch.eq(input_ids[..., current_position:current_position + n], stop_tokens))
        if stop_locations.shape[0] > 0:
            # Rows of `stop_locations` are ordered by stop token first, so take
            # the smallest column to stop at the earliest end token.
            stop_location = stop_locations[:, 1].min().item()
            if debug:
                end_token_found(stop_location)
            return input_ids[0, prompt_len:current_position + stop_location + 1].tolist(), _acceptance_rate()

        # adjust the distribution from Mp
        if n == corrected_gamma:
            p_p = Mp.logits[..., current_position + corrected_gamma - 1, :]
            p_p = logits_processor(p_p)
        else:
            # prune the cache
            if use_cache:
                drafter_cache = prune_cache(drafter_cache, corrected_gamma - n)
                target_cache = prune_cache(target_cache, corrected_gamma - n + 1)

            if not skip_sample_adjustment:
                p_p = max_fn(p[..., n, :] - q[0, n, :])
            else:
                p_p = p[..., n, :]
        x = logits_processor.sample(p_p)

        if debug:
            # Keep the rejected drafts around for printing before they are overwritten.
            generated = input_ids.clone().detach()

        input_ids[0, current_position + n:current_position + corrected_gamma] = pad_token_id
        input_ids[0, current_position + n] = x

        if debug:
            speculative_step(tokenizer, generated, input_ids, n, prompt_len, current_position, corrected_gamma)

        current_position += n + 1

        if torch.isin(x, stop_tokens):
            if debug:
                end_token_found(n)
            return input_ids[0, prompt_len:current_position].tolist(), _acceptance_rate()

    return input_ids[0, prompt_len:].tolist(), _acceptance_rate()

# Normal A

In [None]:
from math import inf
import torch
from torch.nn import Module
from typing import List


@torch.no_grad()
def autoregressive_generate(
    inputs: List[int],
    model: Module,
    max_gen_len: int = 40,
    logits_processor: LogitsProcessor = GreedyProcessor(temperature=1),
    eos_tokens_id: int | List[int] = 1,
    pad_token_id: int = 0,
    use_cache: bool = False,
    debug: bool = False,
) -> List[int]:
    """
    Generate text sequence autoregressively based on the input sequence.

    Args:
        inputs (List[int]): input sequence of batch size 1.
        model (Module): model to use for inference.
        max_gen_len (int): maximum length of the generated sequence.
        logits_processor (LogitsProcessor): logits processor for sampling.
        eos_tokens_id (int or List[int]): end token id (could be multiple).
        pad_token_id (int): pad token id.
        use_cache (bool): whether to use cache.
        debug (bool): whether to print the position of the end token when one is generated.

    Returns:
        List[int]: generated sequence (empty when there is no room to generate any token).

    Note:
        This generation methods only works for decoder-only models.
    """
    cache = None
    prompt_len = len(inputs)
    # prepare input tensor
    max_seq_length = model.config.max_position_embeddings if hasattr(model.config, 'max_position_embeddings') else (model.config.max_context_length if hasattr(model.config, 'max_context_length') else 1024)
    total_len = min(max_seq_length, prompt_len + max_gen_len)
    input_ids = torch.full((1, total_len), pad_token_id, dtype=torch.long, device=model.device)
    input_ids[0, :prompt_len] = torch.tensor(inputs, dtype=torch.long, device=model.device)

    list_tokens_id = (
        eos_tokens_id if isinstance(eos_tokens_id, list) else [eos_tokens_id]
    )
    stop_tokens = torch.tensor(list_tokens_id, dtype=torch.long, device=model.device)

    # Index of the last generated token. Starting just before the first free
    # slot makes the final slice empty when the loop body never runs.
    curr = prompt_len - 1
    for curr in range(prompt_len, total_len):
        # NOTE(review): with use_cache=True the whole prefix is still passed
        # together with past_key_values - confirm this is what the model expects.
        o = model(input_ids[..., :curr], past_key_values=cache, use_cache=use_cache)
        logits = o.logits[..., -1, :]  # [1, vocab_size]
        probs = logits_processor(logits)  # [1, vocab_size]
        x = logits_processor.sample(probs)  # [1, 1]
        input_ids[0, curr] = x
        cache = o.past_key_values

        # check for end token
        if torch.isin(x, stop_tokens):
            if debug:
                end_token_found(curr)
            break

    return input_ids[0, prompt_len : curr + 1].tolist()


@torch.no_grad()
def beam_search_generate(
    inputs: List[int],
    model: Module,
    max_gen_len: int = 40,
    num_beams: int = 4,
    top_k: int = 3,
    min_length: float = 5.0,
    alpha: float = 1.2,
    eos_tokens_id: int | List[int] = 1,
    pad_token_id: int = 0,
    debug: bool = False,
    tokenizer=None,
) -> List[int]:
    """
    Generate text sequence using beam search based on the input sequence.

    Args:
        inputs (List[int]): input sequence of batch size 1.
        model (Module): model to use for inference.
        max_gen_len (int): maximum length of the generated sequence.
        num_beams (int): number of beams.
        top_k (int): number of top k to consider at each beam.
        min_length (float): length penalty.
        alpha (float): alpha parameter of beam search decoding.
        eos_token_id (int): end token id.
        pad_token_id (int): pad token id.
        debug (bool): whether to print debug information.
        tokenizer: tokenizer for debug.

    Returns:
        List[int]: generated sequence.

    Note:
        This generation methods only works for decoder-only models.
        Cache is not available yet.
    """

    def _length_penalty_fn(length, alpha, min_length):
        return ((min_length + length) / (min_length + 1)) ** alpha

    prompt_len = len(inputs)
    max_seq_length = model.config.max_position_embeddings if hasattr(model.config, 'max_position_embeddings') else (model.config.max_context_length if hasattr(model.config, 'max_context_length') else 1024)

    assert prompt_len < max_seq_length, "Prompt length exceeds maximum sequence length."

    total_len = min(max_seq_length, prompt_len + max_gen_len)
    input_ids = torch.full((num_beams, total_len), pad_token_id, dtype=torch.long, device=model.device)
    input_ids[:, :prompt_len] = torch.tensor(inputs, dtype=torch.long, device=model.device)
    probs = torch.full((num_beams, total_len), torch.finfo(torch.float).min, dtype=torch.float, device=model.device)
    beams_probs = torch.full((num_beams,), torch.finfo(torch.float).min, dtype=torch.float, device=model.device)
    last_indexes = torch.full((num_beams,), -1, dtype=torch.long, device=model.device)

    stop_tokens = torch.tensor((eos_tokens_id if isinstance(eos_tokens_id, list) else [eos_tokens_id]), dtype=torch.long, device=model.device)

    # prefill
    probs[:, :prompt_len] = 1.0
    beams_probs[:] = 1.0
    o = model(input_ids[:, :prompt_len])
    curr_prob = torch.nn.functional.log_softmax(o.logits[0, -1, :], dim=-1)
    top_probs, top_tokens = torch.topk(curr_prob, num_beams, dim=-1)
    input_ids[:, prompt_len] = top_tokens
    probs[:, prompt_len] = probs[:, prompt_len - 1] + top_probs
    beams_probs[:] = probs[:, prompt_len] / _length_penalty_fn(1, alpha, min_length)

    for curr in range(prompt_len + 1, total_len):
        o = model(input_ids[:, :curr])
        logits = o.logits[:, -1, :]
        probs_curr = torch.nn.functional.log_softmax(logits, dim=-1)
        top_probs, top_tokens = torch.topk(probs_curr, top_k, dim=-1)
        possibilities = []
        for beam in range(num_beams):
            if last_indexes[beam] != -1:
                prob_vec = probs[beam].detach().clone()
                input_vec = input_ids[beam].detach().clone()
                possibilities.append(
                    (beams_probs[beam], input_vec, prob_vec, last_indexes[beam])
                )
                continue

            for possibility in range(top_k):
                new_prob = probs[beam, curr - 1] + top_probs[beam, possibility]
                lp = _length_penalty_fn(curr - prompt_len, alpha, min_length)
                prob_vec = probs[beam].detach().clone()
                prob_vec[curr] = new_prob
                input_vec = input_ids[beam].detach().clone()
                input_vec[curr] = top_tokens[beam, possibility]
                last_token_idx = -1
                if torch.isin(input_vec[curr], stop_tokens) or input_vec[curr] == pad_token_id:
                    last_token_idx = curr

                already_in = False
                for p in possibilities:
                    if torch.equal(p[1], input_vec):
                        already_in = True
                        break
                if not already_in:
                    possibilities.append((new_prob / (lp if lp != 0 else 1), input_vec, prob_vec, last_token_idx))

        possibilities.sort(key=lambda x: x[0], reverse=True)

        if debug:
            printing.beam_search_step(possibilities, curr, tokenizer)

        possibilities = possibilities[:num_beams]

        for beam in range(num_beams):
            beams_probs[beam] = possibilities[beam][0]
            input_ids[beam] = possibilities[beam][1]
            probs[beam] = possibilities[beam][2]
            last_indexes[beam] = possibilities[beam][3]

        if torch.all(last_indexes != -1):
            if debug:
                printing.end_token_found(curr)
            break

    last_indexes[last_indexes == -1] = total_len - 1

    return input_ids[0, prompt_len : last_indexes[0] + 1].tolist()

#Inference

In [None]:
!pip install -q -U accelerate bitsandbytes

In [None]:
import random

import numpy as np
import torch

seed = 123

# Seed Python, NumPy and PyTorch (CPU, then the CUDA generators) with the
# same value, in that order.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(seed)

# Ask cuDNN for deterministic kernels and turn its auto-tuner off.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Quantization settings shared by both models: weights are loaded in 8-bit.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True
)

# Checkpoint used as the target model of speculative decoding.
target_model = "google/gemma-2-9b-it"

# Checkpoint used as the drafter model.
drafter_model = "google/gemma-2-2b-it"

# Load the target model with the 8-bit config on the CUDA device.
target = AutoModelForCausalLM.from_pretrained(
    target_model,
    quantization_config=bnb_config,
    device_map="cuda"
)
target.eval()

# The tokenizer is taken from the target checkpoint. The warning below only
# fires if `tokenizer_name` is edited to point at a different checkpoint.
tokenizer_name = target_model
if tokenizer_name != target_model:
    print(colored("Warning: Tokenizer is different from target model. Use with caution.", "red"))
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)

# Load the drafter model with the same quantization config and device.
drafter = AutoModelForCausalLM.from_pretrained(
    drafter_model,
    quantization_config=bnb_config,
    device_map="cuda",
)
drafter.eval()

model.safetensors.index.json:   0%|          | 0.00/39.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.67G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2SdpaAttention(
          (q_proj): Linear8bitLt(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear8bitLt(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear8bitLt(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear8bitLt(in_features=2048, out_features=2304, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear8bitLt(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear8bitLt(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear8bitLt(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_att

In [None]:
# User request that both decoding strategies will answer.
prefix = "Explain reinforcement learning in simple terms."

chat_templated = f"<bos><start_of_turn>user\n{prefix}<end_of_turn>\n<start_of_turn>model\n" # Gemma chat template
# The template already starts with <bos>; disable the tokenizer's automatic
# special tokens so the prompt does not begin with a duplicated BOS token.
input_ids = tokenizer(chat_templated, return_tensors="pt", add_special_tokens=False).input_ids
input_ids = input_ids[0].tolist() # Generation methods require a list of ids

In [None]:
import time

# Generation settings
gen_len = 100  # upper bound on generated tokens (speculative decoding may overshoot it)
gamma = 4      # number of draft tokens the drafter proposes at each step
logits_processor = NucleusProcessor(temperature=.6, top_p=.9)  # nucleus sampling, p=0.9, T=0.6

# Keyword arguments common to every generation call in this cell.
shared_kwargs = dict(
    logits_processor=logits_processor,
    max_gen_len=gen_len,
    eos_tokens_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)
# Extra keyword arguments needed only by the speculative variants.
speculative_kwargs = dict(shared_kwargs, tokenizer=tokenizer, gamma=gamma)

# 1) Classic auto-regressive decoding with the target model (slow baseline).
#    Use autoregressive_generate_encoder_decoder for encoder-decoder models.
start_time = time.time()
output_ids_ar = autoregressive_generate(input_ids, target, **shared_kwargs)
output_ar = tokenizer.decode(output_ids_ar, skip_special_tokens=True)
elapsed = time.time() - start_time
print("--- %s seconds ---" % elapsed)

# 2) Speculative decoding (faster).
start_time = time.time()
output_ids_sd, alpha = speculative_generate(input_ids, drafter, target, **speculative_kwargs)
output_sd = tokenizer.decode(output_ids_sd, skip_special_tokens=True)
elapsed = time.time() - start_time
print("--- %s seconds ---" % elapsed)

# 3) Speculative decoding with optimal transport enabled.
#    output_ids_sd is deliberately rebound here, exactly as before, so it
#    ends the cell holding the optimal-transport token ids.
start_time = time.time()
output_ids_sd, alpha_ot = speculative_generate(
    input_ids,
    drafter,
    target,
    use_optimal_transport=True,
    **speculative_kwargs,
)
output_sd_ot = tokenizer.decode(output_ids_sd, skip_special_tokens=True)
elapsed = time.time() - start_time
print("--- %s seconds ---" % elapsed)

# Report. The acceptance rate is the number of drafts accepted by the target
# model divided by the number of drafts generated.
print("--- Auto-regressive decoding ---")
print(output_ar)

print("--- Speculative decoding ---")
print(output_sd)
print("Acceptance rate:", alpha)

print("--- Speculative decoding (OT) ---")
print(output_sd_ot)
print("Acceptance rate:", alpha_ot)

--- 46.797274112701416 seconds ---
--- 46.78878664970398 seconds ---
--- 45.20297884941101 seconds ---
--- Auto-regressive decoding ---
Imagine you're teaching a dog a new trick. You don't tell the dog exactly what to do, step-by-step. Instead, you reward it with treats when it does something close to what you want, and you might scold it if it does something wrong.

Over time, the dog learns which actions lead to treats (good things) and which actions lead to scolding (bad things). It starts to repeat the actions that bring rewards and avoids the ones that bring punishment
--- Speculative decoding ---
Imagine you're teaching a dog a new trick.

* You **reward** the dog with a treat when it does something right (like sitting).
* You **don't reward** it when it does something wrong (like barking).

Over time, the dog learns to associate the desired behavior with a reward and starts doing it more often.

That's basically how reinforcement learning works!

**In simple terms:**

* **Agent:

In [None]:
import argparse
import random
import numpy as np
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    QuantoConfig,
)
import time
import os
from termcolor import colored


class InferenceCLI:
    """Interactive console for comparing speculative and auto-regressive decoding.

    Constructing an instance prints a banner, loads the target and drafter
    models and then blocks in a read loop: a line starting with ``/`` is a
    settings command (see ``_help``), any other non-blank line is used as a
    prompt for generation.
    """

    # Boolean settings flipped by a bare "/<command>": command -> (attribute, label).
    _TOGGLES = {
        "/debug": ("debug", "Debug mode"),
        "/speculative": ("spec", "Speculative Decoding generation"),
        "/drafter": ("dr", "Drafter generation"),
        "/cache": ("cache", "Cache"),
        "/target": ("target_gen", "Target generation"),
        "/chat": ("chat", "Chat mode"),
    }

    # Integer settings set by "/<command> <value>": command -> (attribute, label).
    _INT_SETTINGS = {
        "/length": ("gen_len", "Generation length"),
        "/gamma": ("gamma", "Gamma"),
    }

    def __init__(self, device: str = "cuda"):
        """Print the banner, set defaults, load the models and start the loop.

        Args:
            device: value forwarded as ``device_map`` when loading both models.
        """
        print(
            colored("Speculative Decoding", "red"),
            colored("CLI", on_color="on_red", color="white"),
            "\n",
        )
        self.device = device

        self.gamma = 4            # passed as `gamma` to speculative_generate
        self.gen_len = 35         # passed as `max_gen_len` to every generator
        self.debug = False        # passed as `debug` to every generator
        self.spec = True          # run speculative decoding on each prompt
        self.dr = False           # run drafter-only auto-regressive decoding
        self.cache = False        # passed as `use_cache` to every generator
        self.target_gen = True    # run target-only auto-regressive decoding
        self.chat = True          # If using a chat instructed model, set to True

        # Registry used by the /processor command: the class to build and the
        # ordered (name -> type) positional arguments it is built from.
        self.processors = {
            "greedy": {
                "processor": GreedyProcessor,
                "building_args": {"temperature": float},
            },
            "multinomial": {
                "processor": MultinomialProcessor,
                "building_args": {"temperature": float},
            },
            "topk": {
                "processor": TopKProcessor,
                "building_args": {"temperature": float, "top_k": int},
            },
            "nucleus": {
                "processor": NucleusProcessor,
                "building_args": {"temperature": float, "top_p": float},
            },
            "topknucleus": {
                "processor": TopKNucleusProcessor,
                "building_args": {"temperature": float, "top_k": int, "top_p": float},
            },
        }
        # Description of the active processor (shown by _help) and the
        # processor instance actually used for generation.
        self.selected_processor = {
            "name": "greedy",
            "processor": GreedyProcessor,
            "args": {"temperature": 1.0},
        }
        self.processor = GreedyProcessor(temperature=1.0)

        self._load_models()
        self._run()

    def _load_models(self):
        """Load the target model, the tokenizer and the drafter model.

        Sets ``self.target``, ``self.tokenizer``, ``self.drafter`` and
        ``self.end_tokens``. Both models are switched to eval mode.
        """
        # Speculative decoding only pays off when the drafter is the smaller,
        # cheaper model and the target is the larger one, so the 9B checkpoint
        # is the target and the 2B checkpoint is the drafter.
        # Target model
        target_model = "google/gemma-2-9b-it"
        target_quantize = QuantoConfig(weights="int8")  # QuantoConfig(weights="int8")  None

        # Drafter model
        drafter_model = "google/gemma-2-2b-it"
        drafter_quantize = QuantoConfig(weights="int8")  # QuantoConfig(weights="int8") None

        print(colored("Target model:", on_color="on_yellow"), target_model)
        print(colored("Drafter model:", on_color="on_yellow"), drafter_model)
        print(colored("Loading models...", "light_grey"))

        self.target = AutoModelForCausalLM.from_pretrained(
            target_model,
            quantization_config=target_quantize,
            device_map=self.device,
            trust_remote_code=True,
        )
        self.target.eval()

        # Edit tokenizer_name to use a different tokenizer; the warning below
        # only fires when it no longer matches the target model.
        tokenizer_name = target_model
        if tokenizer_name != target_model:
            print(colored("Warning: Tokenizer is different from target model. Use with caution.", "red"))
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)

        self.drafter = AutoModelForCausalLM.from_pretrained(
            drafter_model,
            quantization_config=drafter_quantize,
            device_map=self.device,
            trust_remote_code=True,
        )
        self.drafter.eval()

        self.end_tokens = self.tokenizer.eos_token_id

    @staticmethod
    def _parse_positive_int(text: str) -> "int | None":
        """Return ``text`` as an integer >= 1, or ``None`` if it is not one."""
        try:
            value = int(text)
        except ValueError:
            return None
        return value if value > 0 else None

    @staticmethod
    def _throughput(num_tokens: int, elapsed: float) -> float:
        """Return tokens per second, or 0.0 when no time was measured."""
        return num_tokens / elapsed if elapsed > 0 else 0.0

    def _perform_command(self, command: str):
        """Execute one ``/command`` line; unknown commands print the help.

        Args:
            command: the raw line typed by the user, e.g. ``"/gamma 5"``.

        Raises:
            SystemExit: on ``/quit``.
        """
        # split() without a separator ignores repeated/trailing spaces, which
        # split(" ") would turn into empty-string arguments.
        args = command.split()
        name = args[0] if args else ""

        if name == "/quit":
            print(colored("Goodbye!", on_color="on_red"))
            # Raise SystemExit directly instead of calling the `exit` helper,
            # which is installed for interactive use and not always available.
            raise SystemExit(0)

        if name in self._TOGGLES:
            attr, label = self._TOGGLES[name]
            enabled = not getattr(self, attr)
            setattr(self, attr, enabled)
            print(colored(f"{label}: {enabled}", on_color="on_blue"))
            if name == "/cache" and enabled:
                print(colored("Warning, cache feature is very unstable accross different models. It might generate errors or just perturb the generation. Use with caution.", "red"))
            return

        if name in self._INT_SETTINGS:
            attr, label = self._INT_SETTINGS[name]
            if len(args) < 2:
                print(colored(f"Usage: {name} <value>", "red"))
                return
            value = self._parse_positive_int(args[1])
            if value is None:
                # Reject bad input here rather than letting int() raise and
                # take the whole CLI loop down.
                print(colored(f"Invalid value {args[1]!r} for {name}: expected a positive integer", "red"))
                return
            setattr(self, attr, value)
            print(colored(f"{label}: {value}", on_color="on_blue"))
            return

        if name == "/clear":
            os.system("cls" if os.name == "nt" else "clear")
            return

        if name == "/processor":
            self._select_processor(args[1:])
            return

        self._help()

    def _select_processor(self, args):
        """Handle ``/processor <processor_name> <args0> <args1> ...``.

        Args:
            args: the words following ``/processor``. The first is the
                processor name, the rest are its building arguments in the
                order declared in ``self.processors``.
        """
        if not args:
            print(colored("Usage: /processor <processor_name> <args0> <args1> ...", "red"))
            return

        processor_name = args[0]
        if processor_name not in self.processors:
            print(colored("Invalid processor name", "red"))
            print(colored("Available processors:", "red"))
            for available in self.processors:
                print(colored(f"\t{available}", "red"))
            return

        spec = self.processors[processor_name]
        print(colored(f"Selected processor: {processor_name}", "blue"))

        raw_values = args[1:]
        processor_args = {}
        for position, (arg_name, arg_type) in enumerate(spec["building_args"].items()):
            if position >= len(raw_values):
                print(colored(f"Missing argument {arg_name}", "red"))
                return
            try:
                value = arg_type(raw_values[position])
            except ValueError:
                print(colored(f"Invalid argument {arg_name} of type {arg_type}", "red"))
                return
            processor_args[arg_name] = value
            print(colored(f"\t{arg_name}: {value}", "blue"))

        # Only commit once every argument parsed, so a bad command leaves the
        # previously selected processor in place.
        self.selected_processor = {
            "name": processor_name,
            "processor": spec["processor"],
            "args": processor_args,
        }
        self.processor = spec["processor"](**processor_args)

    def _help(self):
        """Print every command together with the current value of its setting."""
        print(colored("Commands:", on_color="on_blue"))
        print("/quit: quit the program")
        print("/debug: toggle speculative debug mode")
        print(colored(f"\t{self.debug}", "green" if self.debug else "red"))
        print("/clear: clear the screen")
        print("/speculative: toggle speculative decoding")
        print(colored(f"\t{self.spec}", "green" if self.spec else "red"))
        print("/target: toggle target generation")
        print(colored(f"\t{self.target_gen}", "green" if self.target_gen else "red"))
        print("/drafter: toggle drafter generation")
        print(colored(f"\t{self.dr}", "green" if self.dr else "red"))
        print("/cache: toggle cache")
        print(colored(f"\t{self.cache}", "green" if self.cache else "red"))
        print("/chat: toggle chat mode")
        print(colored(f"\t{self.chat}", "green" if self.chat else "red"))
        print("/length <value>: set generation length")
        print(colored(f"\t{self.gen_len}", "blue"))
        print("/gamma <value>: set gamma")
        print(colored(f"\t{self.gamma}", "blue"))
        print("/processor <processor_name> [args0] [args1] ...: set processor")
        print(colored(f"\t{self.selected_processor['name']}", "blue"))
        for arg_name, arg_value in self.selected_processor["args"].items():
            print(colored(f"\t\t{arg_name}: {arg_value}", "blue"))

    def _infer(self, prefix: str):
        """Generate from ``prefix`` with every enabled strategy and report it.

        Depending on the ``spec``, ``target_gen`` and ``dr`` flags this runs
        speculative decoding, target-only decoding and drafter-only decoding.
        Each run is seeded with 42 and timed separately, and its decoded
        output and throughput are printed.

        Args:
            prefix: the user prompt; wrapped in the tokenizer's chat template
                when chat mode is on.
        """
        if self.chat:
            prefix = self.tokenizer.apply_chat_template(
                [{"role": "user", "content": prefix}],
                add_generation_prompt=True,
                tokenize=False,
            )

        # A chat template is expected to emit the special tokens it needs, so
        # the tokenizer must not add its own on top of templated text (that
        # would e.g. duplicate BOS). Raw prompts keep the default behaviour.
        tokenized = self.tokenizer(
            prefix,
            return_tensors="pt",
            add_special_tokens=not self.chat,
        ).input_ids[0].tolist()

        spec_throughput = 0.0

        # Throughput is computed from len(output_ids), not from the decoded
        # string, so the figure really is tokens/s. NOTE(review): this assumes
        # the generators return only the newly generated ids - confirm.
        if self.spec:
            self._set_seed(42)
            started = time.perf_counter()
            output_ids, accept_rate = speculative_generate(
                tokenized,
                self.drafter,
                self.target,
                tokenizer=self.tokenizer,
                logits_processor=self.processor,
                gamma=self.gamma,
                max_gen_len=self.gen_len,
                eos_tokens_id=self.end_tokens,
                debug=self.debug,
                use_cache=self.cache,
            )
            elapsed = time.perf_counter() - started
            spec_output = self.tokenizer.decode(output_ids, skip_special_tokens=True)
            spec_throughput = self._throughput(len(output_ids), elapsed)
            print(colored("========== Speculative ==========", "green"))
            print(colored("Out:", "green"), spec_output)
            print(colored(f"Acceptance rate: {accept_rate:.3f}", "green"))
            print(colored(f"Throughput: {spec_throughput:.1f} tokens/s", "green"))
            print(colored("========== Speculative ==========", "green"))

        if self.target_gen:
            self._set_seed(42)
            started = time.perf_counter()
            output_ids = autoregressive_generate(
                tokenized,
                self.target,
                use_cache=self.cache,
                max_gen_len=self.gen_len,
                eos_tokens_id=self.end_tokens,
                logits_processor=self.processor,
                debug=self.debug,
            )
            elapsed = time.perf_counter() - started
            output = self.tokenizer.decode(output_ids, skip_special_tokens=True)
            base_throughput = self._throughput(len(output_ids), elapsed)
            print(colored("=========== Target AR ===========", "blue"))
            print(colored("Out:", "blue"), output)
            print(colored(f"Throughput: {base_throughput:.1f} tokens/s", "blue"))
            print(colored("=========== Target AR ===========", "blue"))
            if self.spec and base_throughput > 0.0:
                # Relative gain over the baseline: 0% means "same speed".
                increase = (spec_throughput / base_throughput - 1.0) * 100
                print(colored(f"Throughput increase: {increase:.1f}%", "magenta"))

        if self.dr:
            self._set_seed(42)
            # Timed on its own: this run must not reuse the target run's
            # timestamps, which do not even exist when /target is off.
            started = time.perf_counter()
            output_ids = autoregressive_generate(
                tokenized,
                self.drafter,
                use_cache=self.cache,
                max_gen_len=self.gen_len,
                eos_tokens_id=self.end_tokens,
                logits_processor=self.processor,
                debug=self.debug,
            )
            elapsed = time.perf_counter() - started
            output = self.tokenizer.decode(output_ids, skip_special_tokens=True)
            drafter_throughput = self._throughput(len(output_ids), elapsed)

            print(colored("========== Drafter AR ==========", "cyan"))
            print(colored("Out:", "cyan"), output)
            print(colored(f"Throughput: {drafter_throughput:.1f} tokens/s", "cyan"))
            print(colored("========== Drafter AR ==========", "cyan"))

    def _run(self):
        """Read lines forever, dispatching commands and prompts.

        Literal ``\\n`` and ``\\t`` sequences typed by the user are turned into
        real newline and tab characters. The loop ends on ``/quit``, or when
        input is closed or interrupted.
        """
        while True:
            try:
                raw = input("> ")
            except (EOFError, KeyboardInterrupt):
                # End of input / Ctrl-C at the prompt: leave quietly instead
                # of dumping a traceback.
                print()
                print(colored("Goodbye!", on_color="on_red"))
                return

            command = raw.replace('\\n', '\n').replace('\\t', '\t')
            if not command.strip():
                # Nothing typed: do not spend a generation on an empty prompt.
                continue
            if command.startswith("/"):
                self._perform_command(command)
                continue

            self._infer(command)

    def _set_seed(self, seed: int):
        """Seed Python, NumPy and PyTorch (CPU and CUDA) and set the cuDNN flags."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Script entry point: parse the command line and start the interactive CLI.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Speculative Decoding CLI")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use for inference")
    # parse_known_args() instead of parse_args(): when this runs inside a
    # notebook kernel, sys.argv carries launcher options (e.g. "-f <file>")
    # that parse_args() would reject with "unrecognized arguments".
    args, _unknown_args = parser.parse_known_args()

    # Constructing the CLI loads the models and enters its input loop.
    InferenceCLI(device=args.device)