In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

In [3]:
def _apply_repetition_penalty(
    logits: torch.Tensor, generated: torch.Tensor, penalty: float
) -> torch.Tensor:

    if penalty <= 1.0 or generated.numel() == 0:
        return logits
    
    unique_tokens = torch.unique(generated)
    print(f"unique_tokens: {unique_tokens}")
    token_indices = unique_tokens.long()
    print(f"token_indices: {token_indices}")
    
    # Get the logits for the repeated tokens
    vals = logits[0, token_indices]
    print(f"vals: {vals}")
    
    pos_mask = vals > 0
    print(f"pos_mask: {pos_mask}")
    
    vals = torch.where(pos_mask, vals / penalty, vals * penalty)
    print(f"vals: {vals}")

    logits[0, token_indices] = vals

    return logits

In [9]:
vocab_size = 10
penalty = 1.5

original_logits = torch.tensor([[2.0, 3.0, 3.0, 0.5, -1.0, 2.5, 1.0, 0.0, -0.5, 1.8]])
# Sanity check: one row with one logit per vocabulary entry.
assert original_logits.shape == (1, vocab_size)
print(original_logits)

generated_tokens = torch.tensor([[1, 3, 2, 1, 8]])
print(generated_tokens)

# Hand over a copy: the penalty helper may write into the tensor it receives,
# and `original_logits` should keep its pre-penalty values for comparison.
penalized_logits = _apply_repetition_penalty(
    original_logits.clone(), generated_tokens, penalty
)
print(penalized_logits)

tensor([[ 2.0000,  3.0000,  3.0000,  0.5000, -1.0000,  2.5000,  1.0000,  0.0000,
         -0.5000,  1.8000]])
tensor([[1, 3, 2, 1, 8]])
unique_tokens: tensor([1, 2, 3, 8])
token_indices: tensor([1, 2, 3, 8])
vals: tensor([ 3.0000,  3.0000,  0.5000, -0.5000])
pos_mask: tensor([ True,  True,  True, False])
vals: tensor([ 2.0000,  2.0000,  0.3333, -0.7500])
tensor([[ 2.0000,  2.0000,  2.0000,  0.3333, -1.0000,  2.5000,  1.0000,  0.0000,
         -0.7500,  1.8000]])


In [12]:
from typing import Dict, Tuple, List

def _enforce_no_repeat_ngram(
    logits: torch.Tensor, generated: torch.Tensor, no_repeat_ngram_size: int
) -> torch.Tensor:

    if no_repeat_ngram_size <= 0:
        return logits

    seq = generated.squeeze(0).tolist()
    print(f"seq: {seq}")
    
    if len(seq) < no_repeat_ngram_size - 1:
        return logits

    prefix_len = no_repeat_ngram_size - 1
    
    # Build map: prefix -> next_token set
    next_for_prefix: Dict[Tuple[int, ...], List[int]] = {}
    
    for i in range(len(seq) - no_repeat_ngram_size + 1):
        prefix = tuple(seq[i : i + prefix_len])
        nxt = seq[i + prefix_len]
        print(f"prefix: {prefix}, nxt: {nxt}")
        next_for_prefix.setdefault(prefix, []).append(nxt)

    print(f"final prefix map: {next_for_prefix}")
    
    cur_prefix = tuple(seq[-prefix_len:])
    print(f"Current prefix (last {prefix_len} tokens): {cur_prefix}")
    
    banned = next_for_prefix.get(cur_prefix, [])
    print(f"Banned tokens for current prefix: {banned}")
    if banned:
        print(f"Blocking tokens: {banned}")
        banned_idx = torch.tensor(
            banned, device=logits.device, dtype=torch.long
        ).unsqueeze(0)
        print(f"Banned indices tensor: {banned_idx}")
        
        
        # Show logits before modification
        print(f"Logits before blocking: {logits}")
        logits = logits.scatter(dim=-1, index=banned_idx, value=float("-inf"))
        print(f"Logits after blocking: {logits}")
        
    return logits

vocab_size = 10
# Seed the generator so the random logits are identical on every run.
torch.manual_seed(0)
generated = torch.tensor([[1, 2, 3, 1, 2, 4]])
logits = torch.randn(1, vocab_size)

# The sequence ends in token 4, which never appeared before, so no bigram
# prefix matches and no token is banned for this input.
result = _enforce_no_repeat_ngram(logits, generated, 2)


seq: [1, 2, 3, 1, 2, 4]
prefix: (1,), nxt: 2
prefix: (2,), nxt: 3
prefix: (3,), nxt: 1
prefix: (1,), nxt: 2
prefix: (2,), nxt: 4
final prefix map: {(1,): [2, 2], (2,): [3, 4], (3,): [1]}
Current prefix (last 1 tokens): (4,)
Banned tokens for current prefix: []
