In [None]:
import math
import time
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import sys
from typing import Tuple
from pathlib import Path
from tqdm.auto import tqdm
import random
import triton.testing as tt

# Add parent directory to path
sys.path.insert(0, "../..")

from hira.index.indexer import CUDAIndexer
from hira.index.searcher import CUDASearcher
from hira.benchmark_area.utils.data_loader import get_real_data
from hira.kernels.triton_wrappers import (
    triton_two_level_filter,
    triton_three_level_filter_v1,
)

# Plot defaults: seaborn "whitegrid" theme and a 12x6 default figure size.
sns.set_style("whitegrid")
plt.rcParams.update({"figure.figsize": (12, 6)})

In [None]:
# Benchmark configuration: every tunable of the sweep lives in this dict.
CONFIG = {
    # Key-set sizes to sweep: 10k, 20k, ..., 90k.
    "num_keys_list": list(range(10000, 100000, 10000)),
    # Index fan-outs to sweep: 4, 8, 16, 32, 64.
    "branching_factors": [2**exp for exp in range(2, 7)],
    # Other candidate distributions: "uniform", "mixture_of_gaussians", "zipf".
    "distributions": ["real"],
    "dim": 128,
    "device": "cuda",
    "target_results": 10,
    "max_iterations": 1,
    "seed": 42,
    "real_data_path": "../kv_sampling/kv_data/kv_data_Meta-Llama-3-8B-Instruct_layer31_20251219_005742.npz",
    "output_csv": "benchmark_results.csv",
}

# Accumulator for per-experiment result rows.
results = list()

### Functions

In [None]:
def calc_all_scores(K, q):
    """Return the matrix product ``K @ q`` (one score per row for 2-D K, 1-D q).

    The query is used exactly as passed in: no normalisation is applied here,
    so callers that want normalised scores must normalise ``q`` beforehand.
    """
    return K @ q


def brute_force(
    keys: torch.Tensor, query: torch.Tensor, threshold: float
) -> torch.Tensor:
    """Exhaustive baseline: indices of keys whose score reaches ``threshold``.

    Scores are the plain product ``keys @ query`` (the query is not normalised
    here). A key is selected when its score is greater than or equal to
    ``threshold``; the indices along the first dimension are returned.
    """
    passes = keys.matmul(query) >= threshold
    return torch.nonzero(passes, as_tuple=True)[0]


def indexed_search(
    indexer: CUDAIndexer,
    searcher: CUDASearcher,
    query: torch.Tensor,
    threshold: float,
) -> torch.Tensor:
    """Run ``searcher.search`` for ``query``/``threshold`` against ``indexer``.

    Thin adapter so the indexed variants can be timed through the same
    zero-argument callable shape as the brute-force baselines. Whatever
    ``searcher.search`` returns is passed back unchanged.
    """
    hits = searcher.search(query, threshold, indexer)
    return hits


def run_func(func, warmup, runs):
    """Time ``func`` with ``triton.testing.do_bench`` after an explicit warm-up.

    Args:
        func: Zero-argument callable to benchmark.
        warmup: Number of untimed calls made here before measuring; the same
            value is also forwarded to ``do_bench`` as ``warmup``.
        runs: Forwarded to ``do_bench`` as ``rep``.

    Returns:
        Whatever ``do_bench`` reports for ``func``.

    NOTE(review): Triton documents ``do_bench``'s ``warmup`` and ``rep`` as
    durations in milliseconds rather than call counts - confirm against the
    installed Triton version before reading ``runs`` as "number of runs".
    """
    # Untimed calls first, so one-off costs do not land in the measurement.
    for _warm_call in range(warmup):
        func()
    # Drain all queued GPU work before the timed section starts.
    torch.cuda.synchronize()

    return tt.do_bench(func, warmup=warmup, rep=runs)


@torch.no_grad()
def run_benchmark(
    keys: torch.Tensor,
    branching_factor: int,
    dim: int,
    max_iterations: int,
    target_results: int,
    warmup: int,
    num_runs: int,
    outer: int,
) -> Tuple[float, float, float, float]:
    """Benchmark the four search strategies on one key set.

    Builds a two-level and a three-level index over ``keys``, draws one random
    unit-norm query, picks a score threshold that selects about
    ``target_results`` keys, and then times every strategy ``outer`` times in
    a freshly shuffled order so that no strategy always runs first.

    Args:
        keys: Key matrix; moved to CUDA here.
            Assumed to be (num_keys, dim) - TODO confirm against the loader.
        branching_factor: Fan-out used for both indexes and as ``block_c`` of
            both searchers.
        dim: Length of the random query vector.
        max_iterations: Forwarded to both ``CUDAIndexer`` constructors.
        target_results: Desired number of keys scoring at or above the
            threshold. Values below 1 are treated as 1 and values above the
            number of keys select every key. Tied scores may add extra hits.
        warmup: Forwarded to ``run_func``.
        num_runs: Forwarded to ``run_func`` as ``runs``.
        outer: Number of shuffled passes to average over; must be >= 1.

    Returns:
        ``(all_scores, brute_force, two_level, three_level)``: the mean over
        ``outer`` passes of the value ``run_func`` reports for each strategy.

    Raises:
        ValueError: If ``outer`` is smaller than 1 or ``keys`` is empty.
    """
    if outer < 1:
        raise ValueError(f"outer must be >= 1, got {outer}")
    if keys.shape[0] == 0:
        raise ValueError("keys must contain at least one row")

    keys = keys.to("cuda")

    # Build both indexes over the same keys.
    two_level_index = CUDAIndexer(
        depth=CUDAIndexer.DEPTH.TWO_LEVELS,
        branching_factor=branching_factor,
        max_iterations=max_iterations,
    ).build(keys)
    three_level_index = CUDAIndexer(
        depth=CUDAIndexer.DEPTH.THREE_LEVELS,
        branching_factor=branching_factor,
        max_iterations=max_iterations,
    ).build(keys)
    searcher_two_levels = CUDASearcher(block_c=branching_factor)
    searcher_three_levels = CUDASearcher(block_c=branching_factor)

    # One query and one threshold, shared by every strategy and every pass.
    query = torch.randn(dim).to("cuda")
    query = query / torch.norm(query, p=2)
    all_scores = torch.matmul(keys, query)
    sorted_scores, _ = torch.sort(all_scores, descending=True)
    # The k-th best score sits at index k - 1. Indexing with `target_results`
    # itself would let `score >= threshold` select target_results + 1 keys.
    rank = min(max(target_results, 1), len(sorted_scores)) - 1
    threshold = sorted_scores[rank].item()

    benchmarks = [
        ("all_scores", lambda: calc_all_scores(keys, query)),
        ("bf", lambda: brute_force(keys, query, threshold)),
        (
            "two",
            lambda: indexed_search(
                two_level_index,
                searcher_two_levels,
                query,
                threshold,
            ),
        ),
        (
            "three",
            lambda: indexed_search(
                three_level_index,
                searcher_three_levels,
                query,
                threshold,
            ),
        ),
    ]

    # Named `totals` so the module-level `results` list is not shadowed.
    totals = {name: 0.0 for name, _ in benchmarks}

    for _ in range(outer):
        # Randomise execution order each pass to spread ordering effects.
        random.shuffle(benchmarks)

        for name, fn in benchmarks:
            totals[name] += run_func(fn, warmup, num_runs)

    return (
        totals["all_scores"] / outer,
        totals["bf"] / outer,
        totals["two"] / outer,
        totals["three"] / outer,
    )


print("Search functions defined!")

### Bench

In [None]:
# results = []
# output_csv = Path(CONFIG["output_csv"])

# for num_keys in CONFIG["num_keys_list"]:
#     print(f"Running benchmarks for {num_keys} keys...")

#     keys = get_real_data(
#         num_keys,
#         CONFIG["dim"],
#         seed=CONFIG["seed"],
#         real_data_path=CONFIG["real_data_path"],
#     )

#     for branching_factor in CONFIG["branching_factors"]:
#         print(f"  Branching Factor: {branching_factor}")

#         all_scores_mean, bf_mean, two_level_mean, three_level_mean = run_benchmark(
#             keys=keys,
#             dim=CONFIG["dim"],
#             max_iterations=CONFIG["max_iterations"],
#             target_results=CONFIG["target_results"],
#             branching_factor=branching_factor,
#             warmup=10,
#             num_runs=100,
#             outer=20,
#         )

#         result = {
#             "num_keys": num_keys,
#             "all_scores_mean_ms": all_scores_mean,
#             "brute_force_mean_ms": bf_mean,
#             "two_level_mean_ms": two_level_mean,
#             "three_level_mean_ms": three_level_mean,
#             "branching_factor": branching_factor,
#         }

#         results.append(result)

#         # Save intermediate results to CSV after each experiment
#         pd.DataFrame(results).to_csv(output_csv, index=False)

# # Convert to DataFrame
# df_results = pd.DataFrame(results)
# print(f"\nResults saved to: {output_csv.absolute()}")

#### Triton Bench

In [None]:
# Series labels for the Triton sweep: "<depth>_<branching factor>". The text
# after the underscore is parsed back into an int inside `triton_benchmark`.
two_ones = ["two_" + str(bf) for bf in CONFIG["branching_factors"]]
three_ones = ["three_" + str(bf) for bf in CONFIG["branching_factors"]]

# line_vals = ["all_scores", "bf"] + two_ones + three_ones
line_vals = [*two_ones, *three_ones]

# Used as the manual warm-up call count and forwarded to tt.do_bench.
# NOTE(review): Triton documents do_bench's `warmup`/`rep` as durations in
# milliseconds, not call counts - confirm for the installed version.
WARMUP, RUNS = 10, 100


@tt.perf_report(
    tt.Benchmark(
        x_names=["n"],
        x_vals=[10000, 20000, 40000, 60000, 80000, 90000],
        line_arg="provider",
        line_vals=line_vals,
        line_names=line_vals,
        ylabel="µs",
        plot_name="triton_benchmarking",
        args={},
    )
)
def triton_benchmark(n, provider):
    """Time one search strategy on ``n`` real keys and return microseconds.

    Args:
        n: Number of keys requested from ``get_real_data``.
        provider: Which strategy to time. ``"two_<bf>"`` or ``"three_<bf>"``
            selects the two- or three-level index with branching factor
            ``<bf>``; ``"all_scores"`` and ``"bf"`` select the index-free
            baselines.

    Returns:
        The value reported by ``tt.do_bench`` multiplied by 1e3 (do_bench is
        documented to report milliseconds, hence the "µs" axis label).

    Raises:
        ValueError: If ``provider`` is not recognised or no keys were loaded.

    Notes:
        * Only the index needed by ``provider`` is built; the baselines build
          none.
        * The query comes from a generator seeded with ``CONFIG["seed"]``, so
          every provider is timed against the same query and the same
          threshold.
        * The threshold is the ``target_results``-th best score (index
          ``target_results - 1``), so ``score >= threshold`` selects
          ``target_results`` keys rather than one more. Ties may add hits.
    """
    # "two_16" -> ("two", 16). The baselines carry no branching factor.
    branching_factor = None
    if provider.startswith(("two_", "three_")):
        parts = provider.split("_")
        provider, branching_factor = parts[0], int(parts[1])

    keys = get_real_data(
        n,
        CONFIG["dim"],
        seed=CONFIG["seed"],
        real_data_path=CONFIG["real_data_path"],
    )
    keys = keys.to("cuda")

    # Seeded query so that all providers see identical work.
    query_rng = torch.Generator().manual_seed(CONFIG["seed"])
    query = torch.randn(CONFIG["dim"], generator=query_rng).to("cuda")
    query = query / torch.norm(query, p=2)

    all_scores = torch.matmul(keys, query)
    if all_scores.numel() == 0:
        raise ValueError(f"No keys loaded for n={n}")
    sorted_scores, _ = torch.sort(all_scores, descending=True)
    # The k-th best score sits at index k - 1; clamp k into [1, num_keys].
    rank = min(max(CONFIG["target_results"], 1), len(sorted_scores)) - 1
    threshold = sorted_scores[rank].item()

    depth_by_provider = {
        "two": CUDAIndexer.DEPTH.TWO_LEVELS,
        "three": CUDAIndexer.DEPTH.THREE_LEVELS,
    }

    if provider in depth_by_provider:
        # Build just the index being measured.
        index = CUDAIndexer(
            depth=depth_by_provider[provider],
            branching_factor=branching_factor,
            max_iterations=CONFIG["max_iterations"],
        ).build(keys)
        searcher = CUDASearcher(block_c=branching_factor)

        def func():
            return indexed_search(index, searcher, query, threshold)

    elif provider == "all_scores":

        def func():
            return calc_all_scores(keys, query)

    elif provider == "bf":

        def func():
            return brute_force(keys, query, threshold)

    else:
        raise ValueError(f"Unknown provider: {provider!r}")

    # Explicit untimed calls before handing over to do_bench.
    for _ in range(WARMUP):
        func()
    torch.cuda.synchronize()

    ms = tt.do_bench(func, warmup=WARMUP, rep=RUNS)

    return ms * 1e3


# Run the full (n, provider) sweep defined by the perf_report decorator.
# NOTE(review): presumably prints the result table, shows the plot and writes
# report files under ./reports - confirm against the installed Triton version.
triton_benchmark.run(print_data=True, show_plots=True, save_path="./reports")