22 changes: 19 additions & 3 deletions benchmarks/benchmark_native_vs_shim.py
@@ -68,9 +68,10 @@ def bench_shim(points, queries, *, track_objects: bool, with_objs: bool):
return t_build, t_query


def median_times(fn, points, queries, repeats: int):
def median_times(fn, points, queries, repeats: int, desc: str = "Running"):
"""Run benchmark multiple times and return median times."""
builds, queries_t = [], []
for _ in tqdm(range(repeats)):
for _ in tqdm(range(repeats), desc=desc, unit="run"):
gc.disable()
b, q = fn(points, queries)
gc.enable()
@@ -86,29 +87,44 @@ def main():
ap.add_argument("--repeats", type=int, default=5)
args = ap.parse_args()

print("Native vs Shim Benchmark")
print("=" * 50)
print(f"Configuration:")
print(f" Points: {args.points:,}")
print(f" Queries: {args.queries}")
print(f" Repeats: {args.repeats}")
print()

rng = random.Random(SEED)
points = gen_points(args.points, rng)
queries = gen_queries(args.queries, rng)

# Warmup to load modules
print("Warming up...")
_ = bench_native(points[:1000], queries[:50])
_ = bench_shim(points[:1000], queries[:50], track_objects=False, with_objs=False)
print()

print("Running benchmarks...")
n_build, n_query = median_times(
lambda pts, qs: bench_native(pts, qs), points, queries, args.repeats
lambda pts, qs: bench_native(pts, qs), points, queries, args.repeats,
desc="Native"
)
s_build_no_map, s_query_no_map = median_times(
lambda pts, qs: bench_shim(pts, qs, track_objects=False, with_objs=False),
points,
queries,
args.repeats,
desc="Shim (no map)"
)
s_build_map, s_query_map = median_times(
lambda pts, qs: bench_shim(pts, qs, track_objects=True, with_objs=True),
points,
queries,
args.repeats,
desc="Shim (track+objs)"
)
print()

def fmt(x):
return f"{x:.3f}"
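For context, a minimal, self-contained sketch of the labelled repeat loop this file now uses, assuming only tqdm's standard API; fake_bench and the point/query data below are stand-ins, not the benchmark's real bench_native/bench_shim helpers or generators:

import gc
import statistics as stats
import time

from tqdm import tqdm


def median_times(fn, points, queries, repeats: int, desc: str = "Running"):
    """Run the benchmark `repeats` times and return median (build, query) times."""
    builds, queries_t = [], []
    for _ in tqdm(range(repeats), desc=desc, unit="run"):
        gc.disable()
        b, q = fn(points, queries)
        gc.enable()
        builds.append(b)
        queries_t.append(q)
    return stats.median(builds), stats.median(queries_t)


def fake_bench(points, queries):
    # Stand-in for bench_native/bench_shim: time a "build" phase and a "query" phase.
    t0 = time.perf_counter()
    built = sorted(points)
    t1 = time.perf_counter()
    _ = [q in built for q in queries]
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1


build, query = median_times(fake_bench, list(range(50_000)),
                            list(range(0, 50_000, 997)), repeats=3, desc="Demo")
print(f"build={build:.3f}s  query={query:.3f}s")

The desc label is what distinguishes the three otherwise-identical repeat loops (Native, Shim (no map), Shim (track+objs)) in the console output.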
61 changes: 58 additions & 3 deletions benchmarks/quadtree_bench/runner.py
@@ -98,6 +98,36 @@ def median_or_nan(self, vals: List[float]) -> float:
cleaned = [x for x in vals if isinstance(x, (int, float)) and not math.isnan(x)]
return stats.median(cleaned) if cleaned else math.nan

def _print_experiment_summary(self, n: int, results: Dict[str, Any], exp_idx: int) -> None:
"""Print a summary of results for the current experiment."""
def fmt(x):
return f"{x:.3f}" if not math.isnan(x) else "nan"

# Get the results for this experiment (last index)
total = results["total"]
build = results["build"]
query = results["query"]

# Find the fastest engine for this experiment
valid_engines = [(name, total[name][exp_idx]) for name in total.keys()
if not math.isnan(total[name][exp_idx])]

if not valid_engines:
return

fastest = min(valid_engines, key=lambda x: x[1])

print(f"\n 📊 Results for {n:,} points:")
print(f" Fastest: {fastest[0]} ({fmt(fastest[1])}s total)")

# Show top 3 performers
sorted_engines = sorted(valid_engines, key=lambda x: x[1])[:3]
for rank, (name, time) in enumerate(sorted_engines, 1):
b = build[name][exp_idx]
q = query[name][exp_idx]
print(f" {rank}. {name:15} build={fmt(b)}s, query={fmt(q)}s, total={fmt(time)}s")
print()

def run_benchmark(self, engines: Dict[str, Engine]) -> Dict[str, Any]:
"""
Run complete benchmark suite.
@@ -109,6 +139,7 @@ def run_benchmark(self, engines: Dict[str, Engine]) -> Dict[str, Any]:
Dictionary containing benchmark results
"""
# Warmup on a small set to JIT caches, etc.
print("Warming up engines...")
warmup_points = self.generate_points(2_000)
warmup_queries = self.generate_queries(self.config.n_queries)
for engine in engines.values():
@@ -127,9 +158,12 @@ def run_benchmark(self, engines: Dict[str, Engine]) -> Dict[str, Any]:
}

# Run experiments
iterator = tqdm(self.config.experiments, desc="Experiments", unit="points")
for n in iterator:
iterator.set_postfix({"points": n})
print(f"\nRunning {len(self.config.experiments)} experiments with {len(engines)} engines...")
experiment_bar = tqdm(self.config.experiments, desc="Experiments", unit="exp", position=0)

for exp_idx, n in enumerate(experiment_bar):
experiment_bar.set_description(f"Experiment {exp_idx+1}/{len(self.config.experiments)}")
experiment_bar.set_postfix({"points": f"{n:,}"})

# Generate data for this experiment
exp_rng = random.Random(10_000 + n)
@@ -139,11 +173,23 @@ def run_benchmark(self, engines: Dict[str, Engine]) -> Dict[str, Any]:
# Collect results across repeats
engine_times = {name: {"build": [], "query": []} for name in engines}

# Progress bar for engines × repeats
total_iterations = len(engines) * self.config.repeats
engine_bar = tqdm(
total=total_iterations,
desc=" Testing engines",
unit="run",
position=1,
leave=False
)

for repeat in range(self.config.repeats):
gc.disable()

# Benchmark each engine
for name, engine in engines.items():
engine_bar.set_description(f" {name} (repeat {repeat+1}/{self.config.repeats})")

try:
build_time, query_time = self.benchmark_engine_once(
engine, points, queries
@@ -154,9 +200,13 @@ def run_benchmark(self, engines: Dict[str, Engine]) -> Dict[str, Any]:

engine_times[name]["build"].append(build_time)
engine_times[name]["query"].append(query_time)

engine_bar.update(1)

gc.enable()

engine_bar.close()

# Calculate medians and derived metrics
for name in engines:
build_median = self.median_or_nan(engine_times[name]["build"])
@@ -184,6 +234,11 @@ def run_benchmark(self, engines: Dict[str, Engine]) -> Dict[str, Any]:
results["insert_rate"][name].append(insert_rate)
results["query_rate"][name].append(query_rate)

# Print intermediate results for this experiment
self._print_experiment_summary(n, results, exp_idx)

experiment_bar.close()

# Add metadata to results
results["engines"] = engines
results["config"] = self.config
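A minimal standalone sketch of the nested progress-bar pattern the runner adopts (outer bar per experiment at position=0, inner bar per engine × repeat at position=1 with leave=False), assuming only tqdm's standard API; the experiment sizes, engine names, and sleep-based work below are placeholders:

import time

from tqdm import tqdm

experiments = [1_000, 10_000, 100_000]    # placeholder experiment sizes
engines = ["native", "shim", "shim+map"]  # placeholder engine names
repeats = 3

experiment_bar = tqdm(experiments, desc="Experiments", unit="exp", position=0)
for exp_idx, n in enumerate(experiment_bar):
    experiment_bar.set_description(f"Experiment {exp_idx + 1}/{len(experiments)}")
    experiment_bar.set_postfix({"points": f"{n:,}"})

    # Inner bar ticks once per (engine, repeat) pair; leave=False clears it
    # when the experiment finishes so only the outer bar persists on screen.
    engine_bar = tqdm(total=len(engines) * repeats, desc="  Testing engines",
                      unit="run", position=1, leave=False)
    for repeat in range(repeats):
        for name in engines:
            engine_bar.set_description(f"  {name} (repeat {repeat + 1}/{repeats})")
            time.sleep(0.01)  # stand-in for the real build + query timing
            engine_bar.update(1)
    engine_bar.close()

experiment_bar.close()

Pinning position keeps the two bars on separate terminal lines instead of letting them overwrite each other, which is why the diff passes position=0 and position=1 explicitly.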