## imports and setup

In [None]:
# install necessary packages
# Both transformers and triton are removed first so the installs below start from a clean state.
!pip uninstall transformers triton -y  # uninstall any existing transformers package
# transformers is reinstalled without a version pin, so pip picks whatever release is current.
!pip install transformers -q
# triton is pinned to 2.0.0 and installed with --no-deps, so none of its dependencies are pulled in or changed.
!pip install triton==2.0.0 --no-deps # TAKEN FROM README, not sure why .dev version doesnt exist

Found existing installation: transformers 4.47.1
Uninstalling transformers-4.47.1:
  Successfully uninstalled transformers-4.47.1
Found existing installation: triton 3.1.0
Uninstalling triton-3.1.0:
  Successfully uninstalled triton-3.1.0
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.7/9.7 MB[0m [31m128.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting triton==2.0.0
  Downloading triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.0 kB)
Downloading triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (63.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.3/63.3 MB[0m [31m32.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: triton
Successfully installed triton-2.0.0


In [None]:
!pip install --upgrade transformers einops -q  # upgrade transformers and einops to their latest releases (triton is not listed here)

In [None]:
# clone the hyper-attention repository to access additional required files
# (re-running this cell after a successful clone makes git report that the directory already exists)
!git clone https://github.com/insuhan/hyper-attn.git

# change directory to the cloned repository
# %cd changes the kernel's working directory, so it persists for all later cells
%cd hyper-attn
# list the checkout to confirm that the clone produced the expected files
!ls

Cloning into 'hyper-attn'...
remote: Enumerating objects: 33, done.[K
remote: Counting objects: 100% (21/21), done.[K
remote: Compressing objects: 100% (14/14), done.[K
remote: Total 33 (delta 10), reused 7 (delta 7), pack-reused 12 (from 1)[K
Receiving objects: 100% (33/33), 108.29 KiB | 27.07 MiB/s, done.
Resolving deltas: 100% (11/11), done.
/content/hyper-attn
assets	benchmark_patch_llm.py	benchmark_single_attention.py  LICENSE	models	README.md


In [None]:
# imports
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm  # progress bar for loops
from torch.nn import CrossEntropyLoss  # loss function for classification tasks

import transformers
print("Transformers version:", transformers.__version__)

# import model and tokenizer classes from transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding

# Never hard-code the Hugging Face token in the notebook. Reuse a token that is
# already present in the environment; otherwise prompt for it without echoing.
# (The previous literal placeholder overwrote any real token that was set.)
if not os.environ.get('HUGGING_FACE_HUB_TOKEN'):
    from getpass import getpass
    os.environ['HUGGING_FACE_HUB_TOKEN'] = getpass('Hugging Face access token: ')

# add the hyper-attn directory to the python path
for repo_path in ('/content/hyper-attn', '/content/hyper-attn/models'):
    # Skip paths that are already registered so re-running the cell does not stack duplicates.
    if repo_path not in sys.path:
        sys.path.append(repo_path)

Transformers version: 4.48.1


## POC testing hyper attention through direct class declaration

In [None]:
from models.attention.hyper_attn import HyperAttention

# dummy dimensions (defined first so the module and the tensors are built from the same values)
batch_size = 2
seq_len = 4096
input_dim = 64
n_heads = 8 # num attention heads

attn = HyperAttention(
    input_dim=input_dim,
    lsh_num_projs=7,
    block_size=256,
    sample_size=256,
    min_seq_len=4096)

# dummy tensors for q k v, laid out as (batch, heads, sequence, head_dim)
query = torch.rand(batch_size, n_heads, seq_len, input_dim)
key = torch.rand(batch_size, n_heads, seq_len, input_dim)
value = torch.rand(batch_size, n_heads, seq_len, input_dim)

# Forward pass
# `causal` is passed by keyword, exactly as LlamaHyperAttention.forward does below.
# A bare positional `True` binds to whatever the fourth parameter of
# HyperAttention.forward happens to be, which is not necessarily `causal`.
attn_output = attn(query, key, value, causal=True)

print("Query shape:", query.shape)
print("Key shape:", key.shape)
print("Value shape:", value.shape)
print("Attention output shape:", attn_output.shape)

Query shape: torch.Size([2, 8, 4096, 64])
Key shape: torch.Size([2, 8, 4096, 64])
Value shape: torch.Size([2, 8, 4096, 64])
Attention output shape: torch.Size([2, 8, 4096, 64])


## Replace attention modules in llama 3

Here we rewrite the attention class from Llama 3. When the Q, K, and V projection matrices are initialized, we divide the hidden size by the number of heads to obtain the per-head dimension that HyperAttention expects as its `input_dim`. We also initialize one HyperAttention module per attention head. Then, in the forward pass, each head's queries, keys, and values are passed through its HyperAttention module to compute that head's attention output.

![](https://github.com/insuhan/hyper-attn/raw/main/assets/sortlsh.png)

In [None]:
import torch
import torch.nn as nn
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
from models.attention.hyper_attn import HyperAttention
import math

def nearest_perfect_square(num):
    """Return the perfect square closest to ``num``, never smaller than 1.

    Args:
        num: an int or float.

    Returns:
        int: ``1`` when ``num <= 0``; otherwise the closer of the two perfect
        squares that bracket ``num`` (the upper one when both are equally
        close), clamped to a minimum of 1. Without the clamp, values in
        ``(0, 0.5)`` would round down to 0 even though non-positive input is
        already mapped to 1.
    """
    if num <= 0:
        return 1
    # floor(sqrt(num)) == isqrt(floor(num)) for num >= 0; isqrt is exact for large ints.
    root = math.isqrt(int(num))
    lower_square = root * root
    upper_square = (root + 1) ** 2
    nearest = lower_square if (num - lower_square) < (upper_square - num) else upper_square
    return max(1, nearest)

class LlamaHyperAttention(nn.Module):
    """LLaMA self-attention block that computes attention with HyperAttention.

    Queries, keys and values come from bias-free linear projections. They are
    split into heads, each head is run through its own ``HyperAttention``
    module with ``causal=True``, and the concatenated head outputs are
    projected back to ``hidden_size``.

    Limitations of this implementation:
      * ``attention_mask`` and ``past_key_value`` are accepted for interface
        compatibility but never read. No padding mask is applied and no KV
        cache is kept, so attention only covers the tokens in ``hidden_states``.
      * NOTE(review): the per-call block size is derived from
        ``(1 - compression_ratio) ** 2 * seq_length ** 2`` and therefore grows
        with the square of the sequence length. Confirm that this is the
        intended scale for ``HyperAttention.block_size``.

    Args:
        config: model configuration. ``hidden_size`` and
            ``num_attention_heads`` are read; ``num_key_value_heads`` is
            optional and falls back to ``num_attention_heads``.
        compression_ratio: value in the block-size formula above (default 0.5).
    """

    def __init__(self, config, compression_ratio=0.5):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
        # How many query heads share one key/value head (grouped-query attention).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.compression_ratio = compression_ratio  # Compression ratio

        # Initialize projections with correct dimensions
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Initialize HyperAttention modules (one per query head)
        self.hyper_attentions = nn.ModuleList([
            HyperAttention(
                input_dim=self.head_dim,
                lsh_num_projs=7,
                block_size=256,  # Placeholder, updated dynamically
                sample_size=256,
                min_seq_len=4096
            ) for _ in range(self.num_heads)
        ])

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None,
                output_attentions=False, use_cache=True, cache_position=None,
                position_embeddings=None, **kwargs):
        """Compute causal HyperAttention over ``hidden_states``.

        Args:
            hidden_states: tensor of shape (batch, seq_length, hidden_size).
            attention_mask, position_ids, past_key_value, output_attentions,
            use_cache, cache_position, **kwargs: accepted so that the module
                can be called with the keyword arguments a Hugging Face LLaMA
                decoder layer passes to its attention module; none of them is
                read here.
            position_embeddings: optional ``(cos, sin)`` pair. When given,
                rotary position embeddings are applied to queries and keys;
                when ``None`` the projections are used as-is.

        Returns:
            tuple: ``(attn_output, None)``. ``attn_output`` has shape
            (batch, seq_length, hidden_size); the second slot stands in for
            attention weights, which are not computed. The pair is returned
            regardless of ``use_cache`` so that callers can always unpack it.
        """
        batch_size, seq_length, _ = hidden_states.shape

        # Dynamically calculate block size based on compression ratio and sequence length
        dynamic_block_size = nearest_perfect_square((1 - self.compression_ratio) ** 2 * seq_length ** 2)

        # Update block size for each HyperAttention module
        for hyper_attn in self.hyper_attentions:
            hyper_attn.block_size = dynamic_block_size

        # Project, split into heads and move the head axis forward:
        # (batch, seq, heads * dim) -> (batch, heads, seq, dim)
        query_states = self.q_proj(hidden_states).view(
            batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(
            batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(
            batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Rotary position embeddings: without them the projections carry no
        # positional information. Imported locally so the class does not rely
        # on an import made in another cell.
        if position_embeddings is not None:
            from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Repeat k/v heads if necessary so every query head has a matching k/v head
        key_states = torch.repeat_interleave(key_states, self.num_key_value_groups, dim=1)
        value_states = torch.repeat_interleave(value_states, self.num_key_value_groups, dim=1)

        # Apply HyperAttention for each head. The slice keeps the head axis
        # (size 1) so each module receives the same 4-D
        # (batch, heads, seq, dim) layout used in the stand-alone POC call.
        attn_outputs = [
            hyper_attn(
                query_states[:, head_idx:head_idx + 1],
                key_states[:, head_idx:head_idx + 1],
                value_states[:, head_idx:head_idx + 1],
                causal=True
            )
            for head_idx, hyper_attn in enumerate(self.hyper_attentions)
        ]

        # Combine attention outputs: (batch, heads, seq, dim) -> (batch, seq, hidden)
        attn_output = torch.cat(attn_outputs, dim=1)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size)

        # Final projection
        attn_output = self.o_proj(attn_output)

        return attn_output, None

class LlamaHyperModel(LlamaModel):
    """LlamaModel whose decoder layers use LlamaHyperAttention for self-attention."""

    def __init__(self, config, compression_ratio=0.5):
        """Build the stock model, then swap the attention module of every layer.

        Args:
            config: configuration forwarded to ``LlamaModel`` and to each
                ``LlamaHyperAttention``.
            compression_ratio: forwarded to each ``LlamaHyperAttention``.
        """
        super().__init__(config)
        # Every layer gets its own freshly constructed attention module.
        for decoder_layer in self.layers:
            hyper_self_attn = LlamaHyperAttention(config, compression_ratio=compression_ratio)
            decoder_layer.self_attn = hyper_self_attn

class LlamaHyperForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM whose backbone is a LlamaHyperModel."""

    def __init__(self, config, compression_ratio=0.5):
        """Build the causal-LM wrapper and replace its backbone.

        Args:
            config: configuration forwarded to ``LlamaForCausalLM`` and
                ``LlamaHyperModel``.
            compression_ratio: forwarded to ``LlamaHyperModel``.
        """
        super().__init__(config)
        self.model = LlamaHyperModel(config, compression_ratio=compression_ratio)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, compression_ratio=0.5, *args, **kwargs):
        """Load pretrained weights into a HyperAttention model.

        ``super().from_pretrained`` is invoked with ``cls`` bound to this
        class, so the object it returns is already a ``LlamaHyperForCausalLM``.
        It is therefore returned directly instead of being copied into a
        second, freshly constructed model. That copy doubled peak memory and
        dropped the dtype/device requested through ``kwargs``.

        Args:
            pretrained_model_name_or_path: checkpoint id or path, forwarded
                unchanged.
            compression_ratio: set on every attention module after loading.
            *args, **kwargs: forwarded to the parent ``from_pretrained``.
                ``ignore_mismatched_sizes`` defaults to ``True`` but may be
                overridden by the caller; previously passing it raised a
                duplicate-keyword ``TypeError``.

        Returns:
            The loaded model.
        """
        kwargs.setdefault("ignore_mismatched_sizes", True)
        model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

        # The loader constructs the model with the default ratio; apply the requested one.
        for layer in model.model.layers:
            layer.self_attn.compression_ratio = compression_ratio

        return model


## Load Model and Test Generation

In [None]:
# Initialize model and tokenizer
# Single source of truth for the checkpoint id, so the tokenizer and the model always match.
model_name = "meta-llama/Llama-3.1-8B-Instruct"
hf_token = os.environ['HUGGING_FACE_HUB_TOKEN']

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=hf_token,
    trust_remote_code=True
)
# load the model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map='auto',
    token=hf_token,
    trust_remote_code=True
)

# device_map='auto' has already placed the weights, so the model is not moved
# again; `device` is only used to put the input tensors on the accelerator.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set to evaluation mode
model.eval()

# Add a new special token for padding
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Resize the token embeddings to include the new token
model.resize_token_embeddings(len(tokenizer))

# Prepare input text
input_text = "Explain the concept of attention in deep learning in"
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
input_ids = inputs.input_ids.to(device)
attention_mask = inputs.attention_mask.to(device)

# Generate text
# The pad token is registered above, before generation, so no fallback is needed afterwards.
with torch.no_grad():
    generated_ids = model.generate(
        input_ids,
        max_length=100,
        attention_mask=attention_mask,
        num_beams=4,
        no_repeat_ngram_size=2,
        temperature=0.7,
        top_p=0.92,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id
    )

# Decode the generated text
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f"Generated text:\n{generated_text}")




tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Generated text:
Explain the concept of attention in deep learning in the context of computer vision
Attention is a fundamental concept in Deep Learning, particularly in Computer Vision, that enables models to selectively focus on specific parts of the input data, such as images or videos, to make more accurate predictions or decisions. In this explanation, we'll delve into the world of visual attention and explore its significance in modern deep neural networks.

**What is Attention?**

Attention, also known as attention mechanism or attention-based model,


## Benchmark with LongBench

In [None]:
# `datasets` provides load_dataset, which benchmark_model uses to stream the benchmark examples.
!pip install datasets -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/480.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m29.5 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/179.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.3/179.3 kB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/143.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
from datasets import load_dataset

# Importing both names in one statement made the whole import fail when
# `load_metric` was missing, which left `load_dataset` undefined as well.
try:
    from datasets import load_metric
except ImportError:
    load_metric = None  # not provided by the installed `datasets` release
# dataset = load_dataset('THUDM/LongBench-v2', split='train')

# print(dataset[0])

ImportError: cannot import name 'load_metric' from 'datasets' (/usr/local/lib/python3.11/dist-packages/datasets/__init__.py)

## Experiments - accuracy and time metrics

**Metric Summaries:**
- `RulerStringMatch (Partial)`: Checks if any reference answer substring exists in prediction (case-insensitive)
- `RulerStringMatch (All)`: Measures average coverage of all reference components in prediction
- `LevenshteinSimilarity`: Computes character-level similarity between prediction and best-matching reference (0-100 scale)

**Benchmark Execution Flow:**
1. Streams dataset examples with `input`, `context`, and `answers`
2. For each example: constructs prompt combining context/input, generates model response with timing
3. Postprocesses outputs (cleans non-printables/whitespace)
4. Compares predictions against references using the three metrics above, plus BLEURT and BERTScore (best score over the references for each example)
5. Returns average inference time + aggregated metric scores

In [None]:
# fuzzywuzzy supplies `fuzz.ratio`, used by the Levenshtein similarity metric below.
# python-Levenshtein, bleurt and bert-score are presumably backends for fuzzywuzzy and for the
# "bleurt"/"bertscore" metrics loaded in benchmark_model — TODO confirm.
!pip install fuzzywuzzy python-Levenshtein bleurt bert-score -q

In [None]:
import re
import numpy as np
from fuzzywuzzy import fuzz

def postprocess_pred(predict_str: str):
    """Normalise generated text following RULER's postprocessing rules.

    Surrounding whitespace is trimmed, every control character in the range
    0x00-0x1f is replaced by a newline, and the result is trimmed once more.
    """
    trimmed = predict_str.strip()
    # Control characters are replaced (not dropped) by "\n"; the final strip
    # removes any newline this leaves at either end.
    return re.sub(r"[\x00-\x1f]", "\n", trimmed).strip()

def ruler_string_match_part(predictions, references):
    """Percentage of predictions that contain at least one of their references.

    Matching is a case-insensitive substring test; each reference is stripped
    of surrounding whitespace first. Returns 0 when there are no predictions.
    """
    hits = [
        1.0 if any(ref.strip().lower() in pred.lower() for ref in refs) else 0.0
        for pred, refs in zip(predictions, references)
    ]
    if not hits:
        return 0
    return (sum(hits) / len(hits)) * 100

def ruler_string_match_all(predictions, references):
    """Average share of each example's references found in its prediction, as a percentage.

    Matching is a case-insensitive substring test on whitespace-stripped
    references. An example with no references contributes 0; the result is 0
    when there are no examples.
    """
    coverage = []
    for pred, refs in zip(predictions, references):
        if len(refs) > 0:
            hit_count = sum(ref.strip().lower() in pred.lower() for ref in refs)
            coverage.append(hit_count / len(refs))
        else:
            coverage.append(0)
    if not coverage:
        return 0
    return (sum(coverage) / len(coverage)) * 100

def levenshtein_similarity(predictions, references):
    """Average, over examples, of the best ``fuzz.ratio`` between a prediction and its references.

    Args:
        predictions: list of predicted strings.
        references: list (parallel to ``predictions``) of lists of reference
            strings.

    Returns:
        The mean of the per-example best scores, or 0 when there are no
        examples. An example with an empty reference list scores 0, which
        matches how ``ruler_string_match_all`` treats it; previously ``max()``
        raised ``ValueError`` on the empty sequence.
    """
    scores = []
    for pred, refs in zip(predictions, references):
        # Get best match among multiple possible references
        best_score = max((fuzz.ratio(pred, ref) for ref in refs), default=0)
        scores.append(best_score)
    return np.mean(scores) if scores else 0

def benchmark_model(model, tokenizer, dataset, dataset_name, device, num_examples=10,
                    max_input_tokens=None, split='train'):
    """
    Benchmarks the model with multiple text similarity metrics.

    Args:
        model: model exposing ``generate``.
        tokenizer: tokenizer used to encode prompts and decode generations.
        dataset: dataset identifier passed to ``load_dataset``.
        dataset_name: configuration name passed to ``load_dataset``.
        device: device the input tensors are moved to.
        num_examples: maximum number of streamed examples to evaluate.
        max_input_tokens: optional cap on the prompt length in tokens. Longer
            prompts keep their first and last halves, so the question at the
            end of the prompt survives. ``None`` (default) leaves the
            tokenizer output untouched.
        split: dataset split to stream (default ``'train'``).

    Returns:
        (avg_time_ms, ruler_part_score, ruler_all_score, levenshtein_score,
        bleurt_score, bertscore_f1). The last two are NaN when
        ``datasets.load_metric`` is not available.

    Raises:
        ValueError: if ``max_input_tokens`` is given but not positive.
    """
    # Local imports keep the function self-contained: this cell does not
    # import `time`, and a notebook-level
    # `from datasets import load_dataset, load_metric` fails as a whole when
    # `load_metric` is missing.
    import itertools
    import time
    import warnings
    from datasets import load_dataset
    try:
        from datasets import load_metric
    except ImportError:
        load_metric = None

    if max_input_tokens is not None and max_input_tokens <= 0:
        raise ValueError("max_input_tokens must be a positive integer or None")

    if load_metric is None:
        warnings.warn("datasets.load_metric is unavailable; BLEURT and BERTScore are reported as NaN")
        bleurt_metric = bertscore_metric = None
    else:
        bleurt_metric = load_metric("bleurt", keep_in_memory=True)
        bertscore_metric = load_metric("bertscore", keep_in_memory=True)

    total_time = 0
    predictions = []
    references = []

    # Load dataset in streaming mode
    streamed_dataset = load_dataset(dataset, dataset_name, split=split, streaming=True)

    # islice stops after num_examples items without fetching an extra one.
    for example in itertools.islice(streamed_dataset, max(num_examples, 0)):
        # Extract dataset elements
        input_text = example['input']
        context = example['context']
        answer_refs = example['answers']

        # Create prompt
        prompt = f"Context: {context}\n\nInput: {input_text}\n\nAnswer:"

        # Tokenize
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs.input_ids
        attention_mask = inputs.attention_mask

        # Optional middle truncation: drop tokens from the centre of the prompt.
        if max_input_tokens is not None and input_ids.shape[1] > max_input_tokens:
            head = max_input_tokens // 2
            tail = max_input_tokens - head  # always >= 1 because max_input_tokens >= 1
            input_ids = torch.cat([input_ids[:, :head], input_ids[:, -tail:]], dim=1)
            attention_mask = torch.cat([attention_mask[:, :head], attention_mask[:, -tail:]], dim=1)

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # Generate (only the generate call is timed)
        start_time = time.time()
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids,
                max_new_tokens=50,
                attention_mask=attention_mask,
                num_beams=4,
                no_repeat_ngram_size=2,
                temperature=0.7,
                top_p=0.92,
                top_k=50,
                pad_token_id=tokenizer.pad_token_id
            )
        total_time += time.time() - start_time

        # Extract and clean generated answer (tokens after the prompt)
        input_length = input_ids.shape[1]
        generated_answer = tokenizer.decode(
            generated_ids[0][input_length:],
            skip_special_tokens=True
        )
        processed_pred = postprocess_pred(generated_answer)

        # Store results
        predictions.append(processed_pred)
        references.append([ans.strip() for ans in answer_refs])

    example_count = len(predictions)

    # Calculate metrics
    ruler_part = ruler_string_match_part(predictions, references)
    ruler_all = ruler_string_match_all(predictions, references)
    levenshtein = levenshtein_similarity(predictions, references)
    avg_time = (total_time / example_count) * 1000 if example_count > 0 else 0

    def _best_over_refs(score_fn):
        """Mean (as a percentage) of each example's best score over its references."""
        best = []
        for pred, refs in zip(predictions, references):
            if not refs:
                continue  # nothing to compare against
            best.append(max(score_fn([pred] * len(refs), refs)))
        return np.mean(best) * 100 if best else float("nan")

    bleurt_score = float("nan")
    bertscore_f1 = float("nan")
    if bleurt_metric is not None:
        # Calculate BLEURT scores (max over references)
        bleurt_score = _best_over_refs(
            lambda preds, refs: bleurt_metric.compute(predictions=preds, references=refs)["scores"]
        )
    if bertscore_metric is not None:
        # Calculate BERTScore (F1 maximized over references)
        bertscore_f1 = _best_over_refs(
            lambda preds, refs: bertscore_metric.compute(
                predictions=preds,
                references=refs,
                lang="en",
                rescale_with_baseline=True
            )["f1"]
        )

    return avg_time, ruler_part, ruler_all, levenshtein, bleurt_score, bertscore_f1

In [None]:
# NOTE(review): benchmark_model reads the 'input', 'context' and 'answers'
# fields and requests a named configuration; confirm that this dataset id
# exposes that schema and the '2wikimqa' configuration.
avg_time, part_score, all_score, lev_score, bleurt_score, bertscore_f1 = benchmark_model(
    model=model,
    tokenizer=tokenizer,
    dataset='THUDM/LongBench-v2',
    dataset_name='2wikimqa',
    device=device,
    num_examples=1
)
print(f"Metrics: {part_score:.1f}% Partial Match | {all_score:.1f}% Full Match | {lev_score:.1f} Levenshtein")
# Also report the remaining returned values, which were previously unpacked but never shown.
print(f"Inference time: {avg_time:.2f}ms | BLEURT: {bleurt_score:.1f}% | BERTScore F1: {bertscore_f1:.1f}%")

OutOfMemoryError: CUDA out of memory. Tried to allocate 8.00 GiB. GPU 0 has a total capacity of 39.56 GiB of which 2.74 GiB is free. Process 96073 has 36.81 GiB memory in use. Of the allocated memory 33.50 GiB is allocated by PyTorch, and 2.81 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Repeat of the benchmark call above with the same arguments.
avg_time, part_score, all_score, lev_score, bleurt_score, bertscore_f1 = benchmark_model(
    model, tokenizer, 'THUDM/LongBench-v2', '2wikimqa', device, num_examples=1,
)
print(
    f"Metrics: {part_score:.1f}% Partial Match | {all_score:.1f}% Full Match | {lev_score:.1f} Levenshtein"
)

## POC testing with hardcoded data snippet - Ignore

In [None]:
import time
import re
import numpy as np
from fuzzywuzzy import fuzz

def postprocess_pred(predict_str: str):
    """Tidy generated text the way RULER's postprocessing does.

    Trims the input, turns each control character (0x00-0x1f) into a newline,
    then trims the result again.
    """
    control_chars = re.compile(r"[\x00-\x1f]")
    cleaned = control_chars.sub("\n", predict_str.strip())
    return cleaned.strip()

def benchmark_model_poc(model, tokenizer, device, num_examples=1):
    """
    Benchmarks the model with metrics on a hardcoded example.

    The same question/context pair is generated ``num_examples`` times and
    every generation is scored against the single hardcoded answer.

    Args:
        model: model exposing ``generate``.
        tokenizer: tokenizer used to encode the prompt and decode generations.
        device: device the input tensors are moved to.
        num_examples: number of times the hardcoded example is run.

    Returns:
        (avg_time_ms, ruler_part_score, ruler_all_score, levenshtein_score,
        bleurt_score, bertscore_f1). The last two are NaN when
        ``datasets.load_metric`` is not available.
    """
    # Resolve the metric loader locally: the notebook-level import of
    # `load_metric` fails on `datasets` releases that no longer ship it.
    import warnings
    try:
        from datasets import load_metric
    except ImportError:
        load_metric = None

    if load_metric is None:
        warnings.warn("datasets.load_metric is unavailable; BLEURT and BERTScore are reported as NaN")
        bleurt_metric = bertscore_metric = None
    else:
        bleurt_metric = load_metric("bleurt", keep_in_memory=True)
        bertscore_metric = load_metric("bertscore", keep_in_memory=True)

    total_time = 0
    predictions = []
    references = []

    # Hardcoded example data
    question = 'Where was the wife of Francis I Rákóczi born?'
    context = 'Passage 1: Waldrada of Lotharingia Waldrada was the mistress, and later the wife, of Lothair II of Lotharingia.'
    ground_truth = ['Ozalj']  # List format for compatibility

    for _ in range(num_examples):
        # Construct prompt
        input_text = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"

        # Tokenize and generate
        inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)

        # Generate with timing (only the generate call is timed)
        start_time = time.time()
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids,
                max_new_tokens=50,
                attention_mask=attention_mask,
                num_beams=4,
                no_repeat_ngram_size=2,
                temperature=0.7,
                top_p=0.92,
                top_k=50,
                pad_token_id=tokenizer.pad_token_id
            )
        total_time += time.time() - start_time

        # Extract and clean generated answer (tokens after the prompt)
        input_length = input_ids.shape[1]
        generated_answer = tokenizer.decode(
            generated_ids[0][input_length:],
            skip_special_tokens=True
        )
        processed_pred = postprocess_pred(generated_answer)

        # Store results
        predictions.append(processed_pred)
        references.append(ground_truth)

        # Print for inspection
        print(f"Generated answer: {processed_pred}")
        print(f"Expected answer: {ground_truth[0]}\n")

    # Calculate metrics
    def ruler_part(preds, refs):
        return np.mean([any(r.lower() in p.lower() for r in ref) for p, ref in zip(preds, refs)]) * 100

    def ruler_all(preds, refs):
        return np.mean([sum(r.lower() in p.lower() for r in ref)/len(ref) for p, ref in zip(preds, refs)]) * 100

    def levenshtein(preds, refs):
        return np.mean([max(fuzz.ratio(p, r) for r in ref) for p, ref in zip(preds, refs)])

    bleurt_score = float("nan")
    bertscore_f1 = float("nan")

    if bleurt_metric is not None and predictions:
        # Calculate BLEURT scores (max over references)
        bleurt_scores = []
        for pred, refs in zip(predictions, references):
            scores = bleurt_metric.compute(predictions=[pred]*len(refs), references=refs)["scores"]
            bleurt_scores.append(max(scores))
        bleurt_score = np.mean(bleurt_scores) * 100  # Convert to percentage

    if bertscore_metric is not None and predictions:
        # Calculate BERTScore (F1 maximized over references)
        bert_f1_scores = []
        for pred, refs in zip(predictions, references):
            results = bertscore_metric.compute(
                predictions=[pred]*len(refs),
                references=refs,
                lang="en",
                rescale_with_baseline=True
            )
            bert_f1_scores.append(max(results["f1"]))
        bertscore_f1 = np.mean(bert_f1_scores) * 100  # Percentage

    avg_time = (total_time / num_examples) * 1000 if num_examples > 0 else 0
    return (
        avg_time,
        ruler_part(predictions, references),
        ruler_all(predictions, references),
        levenshtein(predictions, references),
        bleurt_score,
        bertscore_f1
    )

In [None]:
# Run the hardcoded-example benchmark once and report every returned score.
avg_time, part_score, all_score, lev_score, bleurt_score, bertscore_f1 = benchmark_model_poc(
    model=model, tokenizer=tokenizer, device=device, num_examples=1
)

# One print call; sep="\n" puts each report line on its own row.
print(
    "\nBenchmark Results (avg of 1 runs):",
    f"Inference time: {avg_time:.2f}ms",
    f"Partial Match: {part_score:.1f}%",
    f"Complete Match: {all_score:.1f}%",
    f"Levenshtein Similarity: {lev_score:.1f}/100",
    f"BLEURT Score: {bleurt_score:.1f}%",
    f"BERTScore F1: {bertscore_f1:.1f}%",
    sep="\n",
)

### TODO: stream the dataset; store the KV cache in sparse form