diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 4afb86f..df38eb4 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -78,6 +78,7 @@ jobs: tests/inference_engine/scheduler/ \ tests/inference_engine/pipeline/ \ tests/inference_engine/session/ \ + tests/inference_engine/bench/ \ tests/sdk/python/ \ tests/training/repr_align/ \ tests/backends/mlx/test_env.py \ @@ -86,6 +87,7 @@ jobs: --cov=inference_engine.scheduler \ --cov=inference_engine.pipeline \ --cov=inference_engine.session \ + --cov=inference_engine.bench \ --cov=kakeya \ --cov=training.repr_align \ --cov-report=term \ diff --git a/inference_engine/bench/__init__.py b/inference_engine/bench/__init__.py new file mode 100644 index 0000000..c2129a4 --- /dev/null +++ b/inference_engine/bench/__init__.py @@ -0,0 +1,7 @@ +"""Pure-Python aggregation helpers used by ``scripts/bench_agentic/``. + +These helpers are split out of the CLI scripts so they can be unit- +tested under the Linux 100% coverage gate. The CLI scripts that +import them are themselves exempt from the coverage gate (CLI +plumbing convention; see ``scripts/serve.py`` for precedent). +""" diff --git a/inference_engine/bench/session_long_run.py b/inference_engine/bench/session_long_run.py new file mode 100644 index 0000000..8270de9 --- /dev/null +++ b/inference_engine/bench/session_long_run.py @@ -0,0 +1,210 @@ +"""Pure aggregation helpers for the gRPC long-session bench. + +The bench script under ``scripts/bench_agentic/bench_session_long_run.py`` +walks one gRPC session through many turns, recording per-turn +metrics: latency, KV bytes, history length, error / success. After +the run it calls :func:`aggregate_run` here to compute the headline +KPIs: + + * ``kv_bounded`` — does ``kv_live_bytes`` stay under a tight band + across all turns? (ADR 0006 §2.3.a, ADR 0008 §7 G2.) + * ``prefill_bounded`` — does per-turn latency stay flat as the + history grows? (ADR 0008 §7 G2 prefill claim, the v0.3 GA gate + that was a non-claim on the deprecated HTTP shim.) + * Latency p50/p95, KV min/mean/max, n_turns, n_errors. + +Splitting this out of the CLI script means the aggregation logic is +fully unit-testable and the script itself stays focused on IO. The +script also computes a 10-minute bucket breakdown for visual sanity- +check on long runs (4h+); that bucketing logic lives here too. +""" + +from __future__ import annotations + +import statistics +from typing import Any, Dict, List, Optional + + +# --------------------------------------------------------------------------- +# Aggregation +# --------------------------------------------------------------------------- + + +def _percentile(values: List[float], pct: float) -> Optional[float]: + """Linear-interpolated percentile, ``None`` if input is empty. + + Implemented locally instead of pulling in ``numpy`` so the bench + has no scientific-stack dependency. + """ + if not values: + return None + if not 0.0 <= pct <= 1.0: + raise ValueError(f"pct must be in [0, 1], got {pct}") + sorted_values = sorted(values) + if len(sorted_values) == 1: + return float(sorted_values[0]) + rank = pct * (len(sorted_values) - 1) + lo = int(rank) + hi = min(lo + 1, len(sorted_values) - 1) + frac = rank - lo + return float(sorted_values[lo] + (sorted_values[hi] - sorted_values[lo]) * frac) + + +def _kv_bounded(kv_values: List[int], *, tolerance: float = 0.10) -> Optional[bool]: + """Returns ``True`` iff the KV-bytes series stays within + ``tolerance`` (default 10%) of its minimum across every turn. + + Returns ``None`` when there are not enough successful turns to + answer (≤1 sample). The tolerance is a relative band — if the + minimum is 0 we treat that as a pathologically small denominator + and use ``max(min, 1)`` to avoid div-by-zero, the same convention + ``bench_long_session.py`` uses. + """ + if len(kv_values) <= 1: + return None + lo = min(kv_values) + hi = max(kv_values) + return (hi - lo) / max(lo, 1) < tolerance + + +def _prefill_bounded( + latencies: List[float], + *, + head_window: int = 5, + tail_window: int = 5, + drift_threshold_s: float = 5.0, +) -> Optional[bool]: + """Returns ``True`` iff median per-turn latency on the LAST + ``tail_window`` turns is within ``drift_threshold_s`` seconds of + the median on the FIRST ``head_window`` turns. + + This is the prefill-bounded contract: a healthy session-bound + runtime processes only the new user message per turn, so latency + should not grow with conversation length. On the deprecated HTTP + shim, by contrast, every turn re-prefills the full history and + latency grows linearly — that's the failure mode this metric + catches. + + ``None`` when the run is too short to bracket head and tail + windows without overlap. + """ + if len(latencies) < head_window + tail_window: + return None + head = latencies[:head_window] + tail = latencies[-tail_window:] + head_p50 = statistics.median(head) + tail_p50 = statistics.median(tail) + return (tail_p50 - head_p50) <= drift_threshold_s + + +def _latency_drift_p50_s( + latencies: List[float], + *, + head_window: int = 5, + tail_window: int = 5, +) -> Optional[float]: + """Drift in seconds between head-window p50 and tail-window p50. + + Positive = latency grew over the run. Returns ``None`` for + runs too short to bracket head and tail without overlap. + """ + if len(latencies) < head_window + tail_window: + return None + head = latencies[:head_window] + tail = latencies[-tail_window:] + return float(statistics.median(tail) - statistics.median(head)) + + +def _bucketize_10min(turns: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Partition successful turns by their wall-clock bucket + (10-minute granularity, indexed from 0). Each bucket reports + ``n_turns``, p50/p95 latency, and mean kv_live_bytes — gives a + visual sanity check of latency / memory drift across a long run. + + Empty input or all-error input returns an empty list. + """ + buckets: Dict[int, List[Dict[str, Any]]] = {} + for t in turns: + if not t.get("ok"): + continue + bucket_idx = int(t["t_relative_s"] // 600) + buckets.setdefault(bucket_idx, []).append(t) + + out: List[Dict[str, Any]] = [] + for idx in sorted(buckets): + items = buckets[idx] + latencies = [float(t["latency_s"]) for t in items] + kv_values = [ + int(t["kv_live_bytes"]) for t in items + if t.get("kv_live_bytes") is not None + ] + out.append( + { + "bucket_index": idx, + "n_turns": len(items), + "p50_latency_s": _percentile(latencies, 0.50), + "p95_latency_s": _percentile(latencies, 0.95), + "mean_kv_live_bytes": ( + statistics.mean(kv_values) if kv_values else None + ), + } + ) + return out + + +def aggregate_run( + turns: List[Dict[str, Any]], + *, + duration_s: float, + kv_tolerance: float = 0.10, + drift_head_window: int = 5, + drift_tail_window: int = 5, + drift_threshold_s: float = 5.0, +) -> Dict[str, Any]: + """Build the aggregate report from a list of per-turn records. + + Each turn dict must carry at least: + * ``ok`` — bool + * ``t_relative_s`` — float, seconds since run start + * ``latency_s`` — float (only if ``ok``) + * ``kv_live_bytes`` — int or ``None`` (only if ``ok``) + + Returns a dict with the headline KPIs ADR 0006 §2.3.a / ADR 0008 + §7 G2 speak to: ``kv_bounded``, ``prefill_bounded``, latency + p50/p95, kv min/mean/max, error count, 10-minute bucket break- + down. + """ + successes = [t for t in turns if t.get("ok")] + errors = [t for t in turns if not t.get("ok")] + + latencies = [float(t["latency_s"]) for t in successes] + kv_values = [ + int(t["kv_live_bytes"]) for t in successes + if t.get("kv_live_bytes") is not None + ] + + return { + "n_turns": len(successes), + "n_errors": len(errors), + "duration_s": float(duration_s), + "p50_latency_s": _percentile(latencies, 0.50), + "p95_latency_s": _percentile(latencies, 0.95), + "min_kv_live_bytes": min(kv_values) if kv_values else None, + "mean_kv_live_bytes": ( + statistics.mean(kv_values) if kv_values else None + ), + "max_kv_live_bytes": max(kv_values) if kv_values else None, + "kv_bounded": _kv_bounded(kv_values, tolerance=kv_tolerance), + "prefill_bounded": _prefill_bounded( + latencies, + head_window=drift_head_window, + tail_window=drift_tail_window, + drift_threshold_s=drift_threshold_s, + ), + "latency_drift_p50_s": _latency_drift_p50_s( + latencies, + head_window=drift_head_window, + tail_window=drift_tail_window, + ), + "buckets_10min": _bucketize_10min(turns), + } diff --git a/scripts/bench_agentic/bench_session_long_run.py b/scripts/bench_agentic/bench_session_long_run.py new file mode 100755 index 0000000..ad997da --- /dev/null +++ b/scripts/bench_agentic/bench_session_long_run.py @@ -0,0 +1,322 @@ +"""gRPC long-session bench (PR-E1b of ADR 0008 Phase E). + +Walks ONE gRPC session through many short turns, recording per-turn +latency and ``session.info().kv_live_bytes``. Validates the two +ADR 0008 §7 GA gates the deprecated HTTP shim's ``bench_long_session.py`` +cannot answer: + + * **memory bounded**: ``agg.kv_bounded`` is True (KV stays within + a tight band across the whole run). + * **prefill bounded**: ``agg.prefill_bounded`` is True (per-turn + latency is flat across the run — no drift + with history length). + +The HTTP shim's bench fails on prefill-bounded by architecture: every +``/v1/chat/completions`` request re-prefills the full conversation +history. The session-bound gRPC contract makes prefill cost depend +only on the size of each new user message, regardless of how long +the conversation is. This bench measures that empirically. + +Usage:: + + # Terminal 1 — start the runtime + PYTHONPATH=.:sdks/python python3 scripts/start_grpc_runtime_server.py \ + --backend cpu --verifier-id Qwen/Qwen3-0.6B \ + --bind 127.0.0.1:50051 \ + --capacity 1 --sink 4 --window 64 + + # Terminal 2 — run the bench + PYTHONPATH=.:sdks/python python3 \ + scripts/bench_agentic/bench_session_long_run.py \ + --grpc-address 127.0.0.1:50051 \ + --tokenizer-id Qwen/Qwen3-0.6B \ + --duration-s 14400 --turn-spacing-s 30 \ + --output results/platform-tests/bench_session_4h_$(date +%s).json + +CLI plumbing only — pure aggregation lives in +:mod:`inference_engine.bench.session_long_run` and is unit-tested +under the Linux 100% coverage gate. This script itself is exempt by +the same convention as ``serve.py`` / ``run_demo.py`` / ``chat.py``. +""" + +from __future__ import annotations + +import argparse +import json +import os +import signal +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Optional + +from inference_engine.bench.session_long_run import aggregate_run + + +# Workload — a fixed rotating set of short user messages so per-turn +# token counts stay small and the bench's prefill-bounded claim is a +# clean signal about session-bound prefill, not about variability in +# message sizes. Six messages chosen so the deepest history-length +# cycle is ~6 turns; long enough to exercise multiple sink+window +# trims. +_USER_MESSAGES: List[str] = [ + "What is a sliding window KV cache?", + "Explain the role of the sink tokens.", + "How does this differ from prefix caching?", + "What are the typical sink and window sizes?", + "Walk me through one inference step.", + "Summarize what we discussed in two sentences.", +] + + +def _build_argument_parser() -> argparse.ArgumentParser: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument( + "--grpc-address", default="127.0.0.1:50051", + help="host:port of a running kakeya gRPC RuntimeService. " + "Default points at the local server scripts/" + "start_grpc_runtime_server.py defaults to.", + ) + ap.add_argument( + "--tokenizer-id", default="Qwen/Qwen3-0.6B", + help="HF model id for the tokenizer. MUST match the verifier " + "the gRPC server is running, otherwise token ids will " + "be misinterpreted server-side.", + ) + ap.add_argument( + "--duration-s", type=float, default=1800.0, + help="Total wall-clock duration of the run, in seconds. " + "Default 1800 (30 min smoke). Use 14400 for the full 4h.", + ) + ap.add_argument( + "--turn-spacing-s", type=float, default=30.0, + help="Wall-clock spacing between turn STARTS. If a turn " + "takes longer than this, the next turn starts " + "immediately — turn 0's start time is t=0, not " + "t=spacing.", + ) + ap.add_argument( + "--max-tokens", type=int, default=64, + help="max_tokens for each Generate call.", + ) + ap.add_argument( + "--output", required=True, + help="Path to write the JSON report. Atomic-replace via " + "tmp + os.replace so a SIGTERM mid-write doesn't leave " + "a half-written file.", + ) + ap.add_argument( + "--partial-checkpoint-every-s", type=float, default=600.0, + help="Every N seconds, write a snapshot to " + ".partial.json so a long-running bench has " + "evidence on disk even if the host reboots before " + "completion.", + ) + return ap + + +def _now() -> float: + return time.monotonic() + + +def _wallclock() -> float: + return time.time() + + +def _atomic_write_json(path: Path, payload: Dict[str, Any]) -> None: + tmp = path.with_suffix(path.suffix + ".tmp") + with open(tmp, "w", encoding="utf-8") as fh: + json.dump(payload, fh, indent=2) + os.replace(tmp, path) + + +def _build_payload( + *, + turns: List[Dict[str, Any]], + args: argparse.Namespace, + started_at: float, + finished_at: Optional[float], + duration_s: float, + partial: bool, + abort_reason: Optional[str] = None, +) -> Dict[str, Any]: + return { + "schema_version": 1, + "kind": "bench_session_long_run", + "partial": partial, + "abort_reason": abort_reason, + "started_at": started_at, + "finished_at": finished_at, + "duration_s": duration_s, + "config": { + "grpc_address": args.grpc_address, + "tokenizer_id": args.tokenizer_id, + "duration_s_target": args.duration_s, + "turn_spacing_s": args.turn_spacing_s, + "max_tokens": args.max_tokens, + }, + "turns": turns, + "agg": aggregate_run(turns, duration_s=duration_s), + } + + +def _run_one_turn( + *, + session, + tokenizer, + user_message: str, + max_tokens: int, + t_relative_s: float, +) -> Dict[str, Any]: + """One bench iteration. On error, returns ``ok=False`` with the + error class + str instead of raising, so the run continues and + the error surfaces in the aggregate report.""" + try: + # Tokenize the NEW user message only — this is the whole + # point of session-bound runtime. Compare to bench_long_session.py + # where every turn sends the full conversation history. + new_tokens = tokenizer.encode(user_message, add_special_tokens=False) + t0 = _now() + session.append(new_tokens) + emitted: List[int] = [] + for token_id in session.generate(max_tokens=max_tokens): + emitted.append(token_id) + latency_s = _now() - t0 + info = session.info() + return { + "ok": True, + "t_relative_s": t_relative_s, + "latency_s": latency_s, + "kv_live_bytes": info.kv_live_bytes, + "history_length": info.history_length, + "n_emitted": len(emitted), + "user_message_tokens": len(new_tokens), + } + except Exception as exc: # noqa: BLE001 - we want to log every error class + return { + "ok": False, + "t_relative_s": t_relative_s, + "error_class": type(exc).__name__, + "error_str": str(exc), + } + + +def main() -> int: + ap = _build_argument_parser() + args = ap.parse_args() + + # Lazy imports — these pull the HF stack, only do it when actually + # running, so --help stays fast and the unit tests on aggregate_run + # don't need to install HF on Linux. + from transformers import AutoTokenizer + from kakeya import Client + + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + + print( + f"[bench] loading tokenizer {args.tokenizer_id!r}", + file=sys.stderr, flush=True, + ) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id) + eos = tokenizer.eos_token_id + eos_ids: List[int] = [int(eos)] if eos is not None else [] + + print( + f"[bench] connecting to {args.grpc_address}", + file=sys.stderr, flush=True, + ) + started_at = _wallclock() + t_origin = _now() + last_checkpoint_at = t_origin + turns: List[Dict[str, Any]] = [] + abort_reason: Optional[str] = None + + stop_requested = False + + def _on_signal(signum, _frame): # pragma: no cover - signal-driven + nonlocal stop_requested, abort_reason + stop_requested = True + abort_reason = f"signal {signum}" + + for sig in (signal.SIGTERM, signal.SIGINT): + signal.signal(sig, _on_signal) + + try: + with Client(args.grpc_address) as client: + with client.create_session(eos_token_ids=eos_ids) as session: + turn_idx = 0 + while not stop_requested: + t_relative = _now() - t_origin + if t_relative >= args.duration_s: + break + msg = _USER_MESSAGES[turn_idx % len(_USER_MESSAGES)] + record = _run_one_turn( + session=session, + tokenizer=tokenizer, + user_message=msg, + max_tokens=args.max_tokens, + t_relative_s=t_relative, + ) + turns.append(record) + turn_idx += 1 + + # Partial checkpoint — write snapshot every N seconds + # so a host reboot doesn't lose hours of evidence. + if ( + args.partial_checkpoint_every_s > 0 + and (_now() - last_checkpoint_at) + >= args.partial_checkpoint_every_s + ): + _atomic_write_json( + out_path.with_suffix(out_path.suffix + ".partial"), + _build_payload( + turns=turns, args=args, + started_at=started_at, finished_at=None, + duration_s=_now() - t_origin, + partial=True, + ), + ) + last_checkpoint_at = _now() + + # Pace turn STARTS at args.turn_spacing_s. + next_start = ( + t_origin + (turn_idx * args.turn_spacing_s) + ) + sleep_s = next_start - _now() + if sleep_s > 0: + # Sleep in small chunks so SIGTERM is responsive. + deadline = _now() + sleep_s + while not stop_requested and _now() < deadline: + time.sleep(min(0.5, deadline - _now())) + except Exception as exc: # noqa: BLE001 - the bench's job is to summarize, not crash + abort_reason = f"{type(exc).__name__}: {exc}" + print( + f"[bench] aborting due to: {abort_reason}", + file=sys.stderr, flush=True, + ) + + duration_s = _now() - t_origin + finished_at = _wallclock() + payload = _build_payload( + turns=turns, args=args, + started_at=started_at, finished_at=finished_at, + duration_s=duration_s, + partial=False, + abort_reason=abort_reason, + ) + _atomic_write_json(out_path, payload) + print( + f"[bench] wrote {out_path}: " + f"n_turns={payload['agg']['n_turns']} " + f"n_errors={payload['agg']['n_errors']} " + f"duration_s={duration_s:.1f} " + f"kv_bounded={payload['agg']['kv_bounded']} " + f"prefill_bounded={payload['agg']['prefill_bounded']}", + file=sys.stderr, flush=True, + ) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/review_pr_e1b_on_mac.sh b/scripts/review_pr_e1b_on_mac.sh new file mode 100755 index 0000000..f480270 --- /dev/null +++ b/scripts/review_pr_e1b_on_mac.sh @@ -0,0 +1,166 @@ +#!/usr/bin/env bash +# Mac M4 review aid for PR-E1b (ADR 0008 §6.5 gRPC long-session +# bench: bench_session_long_run.py). +# +# This script runs a 30-minute SMOKE invocation of the bench against +# a locally-started gRPC server, then prints instructions for the +# 4-hour run. The 30-min smoke is what reviewers commit; the 4h run +# is the GA evidence and is run separately when you have wall-clock +# budget for it. +# +# Both runs validate the two ADR 0008 §7 GA gates the deprecated +# HTTP shim's bench_long_session.py cannot answer: +# +# * memory bounded: agg.kv_bounded is True. +# * prefill bounded: agg.prefill_bounded is True. +# +# Usage (from repo root, on Mac M4 / arm64): +# +# bash scripts/review_pr_e1b_on_mac.sh +# +# Then commit the smoke artifact: +# +# git add results/platform-tests/pr-e1b-mac-bench-session-30min-* +# git commit -m "Mac M4 review evidence for PR-E1b (30-min smoke)" +# git push +# +# Then optionally launch the 4h: +# +# # in one terminal: +# bash -c "$(scripts/review_pr_e1b_on_mac.sh --print-server-cmd)" +# # in another: +# bash -c "$(scripts/review_pr_e1b_on_mac.sh --print-4h-cmd)" +# +# The 4h JSON gets committed under results/platform-tests/ to the +# PR branch as the binding evidence for v0.3 GA gate G2. +# +# Pre-requisites: +# - Qwen3-0.6B in HF cache (smoke + 4h both load it). +# - Free port 50051 (the script binds to 127.0.0.1:50051). +# - PYTHONPATH unset / not pointing at a stale checkout. + +set -euo pipefail + +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +cd "$ROOT" + +# --------------------------------------------------------------------------- +# Branch B: bare command printers (used by the `bash -c "$( ... )"` idiom). +# --------------------------------------------------------------------------- + +if [[ "${1:-}" == "--print-server-cmd" ]]; then + cat <<'CMD' +PYTHONPATH=.:sdks/python python3 scripts/start_grpc_runtime_server.py \ + --backend cpu --verifier-id Qwen/Qwen3-0.6B \ + --bind 127.0.0.1:50051 \ + --capacity 1 --sink 4 --window 64 +CMD + exit 0 +fi + +if [[ "${1:-}" == "--print-4h-cmd" ]]; then + stamp="$(date +%s)" + cat </dev/null; then + echo "==> stopping gRPC server (pid=$server_pid)" + kill "$server_pid" 2>/dev/null || true + wait "$server_pid" 2>/dev/null || true + fi +} +trap cleanup EXIT + +echo "==> starting gRPC server (logs: $server_log)" +PYTHONPATH=.:sdks/python python3 scripts/start_grpc_runtime_server.py \ + --backend cpu \ + --verifier-id Qwen/Qwen3-0.6B \ + --bind 127.0.0.1:50051 \ + --capacity 1 \ + --sink 4 --window 64 \ + --log-level INFO \ + >"$server_log" 2>&1 & +server_pid=$! +echo " pid=$server_pid" + +echo "==> waiting up to 60s for gRPC server to become ready" +ready=0 +for _ in $(seq 1 60); do + if grep -q "kakeya gRPC RuntimeService listening on" "$server_log" 2>/dev/null; then + ready=1 + break + fi + sleep 1 +done + +if [[ "$ready" != "1" ]]; then + echo "!!! gRPC server did not become ready in 60s. Last 20 lines of log:" + tail -20 "$server_log" || true + exit 1 +fi + +echo "==> running bench_session_long_run.py (1800s = 30min smoke)" +PYTHONPATH=.:sdks/python python3 \ + scripts/bench_agentic/bench_session_long_run.py \ + --grpc-address 127.0.0.1:50051 \ + --tokenizer-id Qwen/Qwen3-0.6B \ + --duration-s 1800 --turn-spacing-s 30 \ + --max-tokens 64 \ + --output "$out_json" \ + 2>&1 | tee "$bench_log" + +echo +echo "==> Smoke complete. Headline KPIs from $out_json:" +PYTHONPATH=.:sdks/python python3 - "$out_json" <<'PY' +import json +import sys +with open(sys.argv[1], encoding="utf-8") as fh: + payload = json.load(fh) +agg = payload["agg"] +print(f" n_turns = {agg['n_turns']}") +print(f" n_errors = {agg['n_errors']}") +print(f" p50_latency_s = {agg['p50_latency_s']}") +print(f" p95_latency_s = {agg['p95_latency_s']}") +print(f" kv min/mean/max = " + f"{agg['min_kv_live_bytes']} / " + f"{agg['mean_kv_live_bytes']} / " + f"{agg['max_kv_live_bytes']}") +print(f" kv_bounded = {agg['kv_bounded']}") +print(f" prefill_bounded = {agg['prefill_bounded']}") +print(f" latency_drift_p50_s = {agg['latency_drift_p50_s']}") +PY + +echo +echo "==> Done. Commit the artifact:" +echo " git add $out_dir/pr-e1b-mac-bench-session-30min-${stamp}.*" +echo " git commit -m 'Mac M4 review evidence for PR-E1b (30-min smoke)'" +echo " git push" +echo +echo "==> When you're ready for the 4-hour GA evidence run:" +echo " # in one terminal:" +echo ' bash -c "$(scripts/review_pr_e1b_on_mac.sh --print-server-cmd)"' +echo " # in another:" +echo ' bash -c "$(scripts/review_pr_e1b_on_mac.sh --print-4h-cmd)"' +echo " # then commit the resulting bench_session_4h_.json" diff --git a/scripts/start_grpc_runtime_server.py b/scripts/start_grpc_runtime_server.py new file mode 100755 index 0000000..561fd9a --- /dev/null +++ b/scripts/start_grpc_runtime_server.py @@ -0,0 +1,208 @@ +"""gRPC runtime server launcher (PR-E1b of ADR 0008 Phase E). + +Boots a real Qwen3 verifier (CPU or MLX), wires it through a +:class:`SessionStore` + :class:`AppendTokensCoordinator` + +:class:`GenerationCoordinator`, and serves the v0.3 gRPC +``RuntimeService`` defined in ``proto/kakeya/v1/runtime.proto``. + +Usage:: + + PYTHONPATH=.:sdks/python python3 scripts/start_grpc_runtime_server.py \ + --backend cpu \ + --verifier-id Qwen/Qwen3-0.6B \ + --bind 127.0.0.1:50051 \ + --capacity 4 --sink 4 --window 64 + +This script is the symmetric counterpart of ``scripts/serve.py`` +(which boots the deprecated HTTP+SSE shim) and is exempt from unit- +test coverage by the same convention used for ``serve.py`` / +``run_demo.py`` / ``chat.py``: CLI plumbing around already-tested +library code, validated by integration runs and the Mac M4 review +aid (``scripts/review_pr_e1b_on_mac.sh``). +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +import signal +import sys +from typing import Tuple + +import torch + +from inference_engine.memory.pool import SlabPool +from inference_engine.memory.slab import SlabConfig +from inference_engine.server.grpc_app import ( + DEFAULT_BIND_ADDRESS, + GrpcServerConfig, + create_grpc_server, +) +from inference_engine.session.coordinator import AppendTokensCoordinator +from inference_engine.session.generator import GenerationCoordinator +from inference_engine.session.store import SessionStore +from kv_cache_proposer.verifier import VerifierConfig + +_LOG = logging.getLogger("kakeya.grpc-runtime") + + +def _resolve_kv_dims(verifier) -> Tuple[int, int, int]: + """Derive (num_layers, num_kv_heads, head_dim) from a loaded + HF / MLX verifier. + + Used purely for slab byte accounting; the verifier maintains its + own KV cache internally — the slab is a fixed-capacity allocation + handle that backs ``GetSessionInfo.kv_live_bytes`` and the + runtime's pool-pressure invariants. Reading the dims from the + verifier's HF config means the per-session byte numbers reported + over gRPC match what the verifier is actually holding. + """ + cfg = verifier.model.config + num_layers = int(getattr(cfg, "num_hidden_layers")) + # Qwen3 / Gemma / DeepSeek all support GQA — kv-heads is the + # dimension that matters for KV cache size, not attention-heads. + num_kv_heads = int( + getattr(cfg, "num_key_value_heads", None) + or getattr(cfg, "num_attention_heads") + ) + head_dim = int( + getattr(cfg, "head_dim", None) + or (cfg.hidden_size // cfg.num_attention_heads) + ) + return num_layers, num_kv_heads, head_dim + + +def _build_verifier( + *, + backend: str, + verifier_id: str, + sink: int, + window: int, +): + cfg = VerifierConfig( + model_id=verifier_id, + dtype=torch.bfloat16, device="cpu", + sink_size=sink, window_size=window, + ) + if backend == "cpu": + from kv_cache_proposer.verifier import SinkWindowVerifier + return SinkWindowVerifier(cfg) + if backend == "mlx": + from inference_engine.backends.mlx.env import probe_environment + env = probe_environment() + if not env.is_available: + print( + f"[grpc-server] MLX unavailable: {env.failure_reason}", + file=sys.stderr, + ) + sys.exit(2) + from inference_engine.backends.mlx.verifier import MLXSinkWindowVerifier + return MLXSinkWindowVerifier(cfg) + raise SystemExit(f"unknown backend: {backend}") + + +async def _serve(args: argparse.Namespace) -> int: + logging.basicConfig( + level=getattr(logging, args.log_level.upper(), logging.INFO), + format="%(asctime)s %(name)s %(levelname)s %(message)s", + ) + + _LOG.info( + "loading verifier backend=%s id=%s sink=%d window=%d", + args.backend, args.verifier_id, args.sink, args.window, + ) + verifier = _build_verifier( + backend=args.backend, verifier_id=args.verifier_id, + sink=args.sink, window=args.window, + ) + + num_layers, num_kv_heads, head_dim = _resolve_kv_dims(verifier) + _LOG.info( + "verifier dims: layers=%d kv_heads=%d head_dim=%d capacity=%d", + num_layers, num_kv_heads, head_dim, args.sink + args.window, + ) + + slab_cfg = SlabConfig( + num_layers=num_layers, + num_heads=num_kv_heads, + sink_size=args.sink, + window_size=args.window, + head_dim=head_dim, + dtype=torch.bfloat16, + device="cpu", + ) + pool = SlabPool(num_slabs=args.capacity, slab_config=slab_cfg) + store = SessionStore( + capacity=args.capacity, + cache_inspector=verifier, + slab_pool=pool, + ) + append_coord = AppendTokensCoordinator(store, verifier) + gen_coord = GenerationCoordinator(store, verifier) + + config = GrpcServerConfig( + bind_address=args.bind, + max_concurrent_rpcs=args.max_concurrent_rpcs, + ) + server = create_grpc_server( + session_store=store, + append_coordinator=append_coord, + generation_coordinator=gen_coord, + config=config, + ) + + await server.start() + _LOG.info("kakeya gRPC RuntimeService listening on %s", args.bind) + + stop_event = asyncio.Event() + + def _on_signal(sig: int) -> None: + _LOG.info("received signal %d; initiating graceful shutdown", sig) + stop_event.set() + + loop = asyncio.get_running_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + try: + loop.add_signal_handler(sig, _on_signal, int(sig)) + except NotImplementedError: # pragma: no cover - Windows + pass + + await stop_event.wait() + await server.stop(grace=args.shutdown_grace_s) + _LOG.info("kakeya gRPC RuntimeService stopped cleanly") + return 0 + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--backend", choices=["cpu", "mlx"], default="cpu") + ap.add_argument("--verifier-id", default="Qwen/Qwen3-0.6B") + ap.add_argument("--bind", default=DEFAULT_BIND_ADDRESS, + help=f"host:port to bind. Default: {DEFAULT_BIND_ADDRESS}") + ap.add_argument("--capacity", type=int, default=4, + help="SessionStore + SlabPool capacity. Each unit is " + "one concurrent session worth of (sink+window) KV " + "cache. v0.3 single-tenant defaults to 4.") + ap.add_argument("--sink", type=int, default=4, + help="Sink-token KV cache size (per-session).") + ap.add_argument("--window", type=int, default=64, + help="Sliding-window KV cache size (per-session). " + "Together with --sink, bounds total KV per session " + "to (sink+window) tokens.") + ap.add_argument("--max-concurrent-rpcs", type=int, default=None, + help="Cap on simultaneous in-flight gRPC RPCs. " + "Defaults to grpc.aio's default; set explicitly on " + "CPU-bound hosts.") + ap.add_argument("--shutdown-grace-s", type=float, default=5.0, + help="Seconds to give in-flight RPCs to finish on " + "SIGTERM/SIGINT before hard-aborting.") + ap.add_argument("--log-level", default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"]) + args = ap.parse_args() + + return asyncio.run(_serve(args)) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/inference_engine/bench/__init__.py b/tests/inference_engine/bench/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/inference_engine/bench/test_session_long_run.py b/tests/inference_engine/bench/test_session_long_run.py new file mode 100644 index 0000000..c3955cd --- /dev/null +++ b/tests/inference_engine/bench/test_session_long_run.py @@ -0,0 +1,351 @@ +"""Unit tests for the gRPC long-session bench's aggregation logic. + +Covers :mod:`inference_engine.bench.session_long_run` to 100%. The +bench's CLI driver under ``scripts/bench_agentic/bench_session_long_run.py`` +is exempt from coverage by the same convention as ``serve.py``. +""" + +from __future__ import annotations + +import math + +import pytest + +from inference_engine.bench.session_long_run import ( + _bucketize_10min, + _kv_bounded, + _latency_drift_p50_s, + _percentile, + _prefill_bounded, + aggregate_run, +) + + +# --------------------------------------------------------------------------- +# _percentile +# --------------------------------------------------------------------------- + + +class TestPercentile: + def test_empty_returns_none(self): + assert _percentile([], 0.5) is None + + def test_single_value_returns_that_value(self): + assert _percentile([4.2], 0.5) == 4.2 + assert _percentile([4.2], 0.0) == 4.2 + assert _percentile([4.2], 1.0) == 4.2 + + def test_multiple_values_p50_is_median(self): + # 5 values -> p50 is the middle one (index 2). + assert _percentile([1.0, 2.0, 3.0, 4.0, 5.0], 0.5) == 3.0 + + def test_p95_is_close_to_max(self): + # For 10 evenly-spaced points 1..10, p95 = 9.55. + values = [float(i) for i in range(1, 11)] + result = _percentile(values, 0.95) + assert result is not None + assert math.isclose(result, 9.55, abs_tol=1e-9) + + def test_unsorted_input_is_handled(self): + # Internal sort means caller doesn't have to sort. + assert _percentile([5, 1, 3, 2, 4], 0.5) == 3 + + def test_invalid_pct_raises(self): + with pytest.raises(ValueError, match="pct must be in"): + _percentile([1.0, 2.0], 1.5) + with pytest.raises(ValueError, match="pct must be in"): + _percentile([1.0, 2.0], -0.1) + + +# --------------------------------------------------------------------------- +# _kv_bounded +# --------------------------------------------------------------------------- + + +class TestKvBounded: + def test_empty_returns_none(self): + assert _kv_bounded([]) is None + + def test_single_sample_returns_none(self): + assert _kv_bounded([100]) is None + + def test_within_tolerance_returns_true(self): + # min=100, max=105 -> 5% drift, under default 10% tolerance. + assert _kv_bounded([100, 102, 105, 100]) is True + + def test_outside_tolerance_returns_false(self): + # min=100, max=130 -> 30% drift, over default 10% tolerance. + assert _kv_bounded([100, 110, 120, 130]) is False + + def test_zero_minimum_uses_div_protect(self): + # If min is 0, denominator falls back to 1 to avoid div/0; + # the test thus reduces to "max < tolerance * 1 = 0.10 bytes". + # For [0, 0, 0] -> max=0 < 0.10 = True (trivially bounded). + assert _kv_bounded([0, 0, 0]) is True + # For [0, 5, 10] -> max=10, not < 0.10 -> False. + assert _kv_bounded([0, 5, 10]) is False + + def test_custom_tolerance(self): + # 30% drift; with tolerance=0.50 should be True. + assert _kv_bounded([100, 130], tolerance=0.50) is True + # Same series with tolerance=0.20 should be False. + assert _kv_bounded([100, 130], tolerance=0.20) is False + + +# --------------------------------------------------------------------------- +# _prefill_bounded +# --------------------------------------------------------------------------- + + +class TestPrefillBounded: + def test_too_short_returns_none(self): + # With default head=5, tail=5, anything under 10 returns None. + assert _prefill_bounded([1.0, 2.0, 3.0, 4.0, 5.0]) is None + + def test_flat_latency_is_bounded(self): + # 20 samples, all ~1.0s. tail_p50 - head_p50 ~= 0 < 5. + latencies = [1.0 + 0.01 * i for i in range(20)] + assert _prefill_bounded(latencies) is True + + def test_growing_latency_above_threshold_unbounded(self): + # Linear growth from 1.0 to 20.0. head_p50 ~= 1.2, tail_p50 ~= 19.0. + latencies = [1.0 + i for i in range(20)] + assert _prefill_bounded(latencies) is False + + def test_growing_latency_within_threshold_bounded(self): + # Drift of 3 seconds, threshold of 5 seconds. + latencies = [1.0] * 5 + [2.0] * 5 + [3.0] * 5 + [4.0] * 5 + # head_p50 = 1, tail_p50 = 4, drift = 3. Default threshold = 5. + assert _prefill_bounded(latencies) is True + + def test_custom_threshold(self): + latencies = [1.0] * 10 + [4.0] * 10 # drift = 3 + assert _prefill_bounded(latencies, drift_threshold_s=2.0) is False + assert _prefill_bounded(latencies, drift_threshold_s=10.0) is True + + def test_custom_windows(self): + latencies = [1.0] * 3 + [10.0] * 3 + # With head=2, tail=2, drift = 9, exceeds default threshold. + assert _prefill_bounded( + latencies, head_window=2, tail_window=2, + ) is False + + +# --------------------------------------------------------------------------- +# _latency_drift_p50_s +# --------------------------------------------------------------------------- + + +class TestLatencyDriftP50: + def test_too_short_returns_none(self): + assert _latency_drift_p50_s([1.0, 2.0]) is None + + def test_flat_latency_is_zero(self): + latencies = [1.0] * 20 + result = _latency_drift_p50_s(latencies) + assert result is not None + assert math.isclose(result, 0.0, abs_tol=1e-9) + + def test_growth_is_positive(self): + latencies = [1.0] * 5 + [2.0] * 5 + [3.0] * 5 + [4.0] * 5 + result = _latency_drift_p50_s(latencies) + assert result is not None + # head_p50 = 1.0, tail_p50 = 4.0 + assert math.isclose(result, 3.0, abs_tol=1e-9) + + +# --------------------------------------------------------------------------- +# _bucketize_10min +# --------------------------------------------------------------------------- + + +class TestBucketize10min: + def test_empty_returns_empty(self): + assert _bucketize_10min([]) == [] + + def test_all_errors_returns_empty(self): + turns = [ + {"ok": False, "t_relative_s": 0, "error_class": "X"}, + {"ok": False, "t_relative_s": 60, "error_class": "Y"}, + ] + assert _bucketize_10min(turns) == [] + + def test_single_bucket(self): + # All turns under 10 min -> bucket 0. + turns = [ + {"ok": True, "t_relative_s": 0, "latency_s": 1.0, + "kv_live_bytes": 100}, + {"ok": True, "t_relative_s": 300, "latency_s": 2.0, + "kv_live_bytes": 200}, + ] + out = _bucketize_10min(turns) + assert len(out) == 1 + assert out[0]["bucket_index"] == 0 + assert out[0]["n_turns"] == 2 + assert out[0]["p50_latency_s"] == 1.5 + assert out[0]["mean_kv_live_bytes"] == 150 + + def test_multiple_buckets(self): + turns = [ + {"ok": True, "t_relative_s": 0, "latency_s": 1.0, + "kv_live_bytes": 100}, + # Bucket 0 (0-10 min) + {"ok": True, "t_relative_s": 599, "latency_s": 2.0, + "kv_live_bytes": 110}, + # Bucket 1 (10-20 min) + {"ok": True, "t_relative_s": 700, "latency_s": 3.0, + "kv_live_bytes": 120}, + # Bucket 3 (30-40 min) — gap is intentional + {"ok": True, "t_relative_s": 1900, "latency_s": 4.0, + "kv_live_bytes": 130}, + ] + out = _bucketize_10min(turns) + assert [b["bucket_index"] for b in out] == [0, 1, 3] + assert [b["n_turns"] for b in out] == [2, 1, 1] + + def test_skips_kv_none_in_mean(self): + turns = [ + {"ok": True, "t_relative_s": 0, "latency_s": 1.0, + "kv_live_bytes": 100}, + {"ok": True, "t_relative_s": 100, "latency_s": 1.0, + "kv_live_bytes": None}, + ] + out = _bucketize_10min(turns) + assert len(out) == 1 + # Only the first turn has a KV value; mean is just that value. + assert out[0]["mean_kv_live_bytes"] == 100 + # Both turns counted for n_turns. + assert out[0]["n_turns"] == 2 + + def test_all_kv_none_in_bucket_returns_none_mean(self): + turns = [ + {"ok": True, "t_relative_s": 0, "latency_s": 1.0, + "kv_live_bytes": None}, + ] + out = _bucketize_10min(turns) + assert len(out) == 1 + assert out[0]["mean_kv_live_bytes"] is None + + def test_errors_are_skipped(self): + turns = [ + {"ok": True, "t_relative_s": 0, "latency_s": 1.0, + "kv_live_bytes": 100}, + {"ok": False, "t_relative_s": 60, "error_class": "X"}, + {"ok": True, "t_relative_s": 120, "latency_s": 2.0, + "kv_live_bytes": 110}, + ] + out = _bucketize_10min(turns) + # Only the 2 successes counted. + assert out[0]["n_turns"] == 2 + + +# --------------------------------------------------------------------------- +# aggregate_run +# --------------------------------------------------------------------------- + + +class TestAggregateRun: + def test_empty_input(self): + out = aggregate_run([], duration_s=0.0) + assert out["n_turns"] == 0 + assert out["n_errors"] == 0 + assert out["duration_s"] == 0.0 + assert out["p50_latency_s"] is None + assert out["p95_latency_s"] is None + assert out["min_kv_live_bytes"] is None + assert out["mean_kv_live_bytes"] is None + assert out["max_kv_live_bytes"] is None + assert out["kv_bounded"] is None + assert out["prefill_bounded"] is None + assert out["latency_drift_p50_s"] is None + assert out["buckets_10min"] == [] + + def test_all_errors(self): + turns = [ + {"ok": False, "t_relative_s": 0, "error_class": "TimeoutError"}, + {"ok": False, "t_relative_s": 60, "error_class": "TimeoutError"}, + ] + out = aggregate_run(turns, duration_s=120.0) + assert out["n_turns"] == 0 + assert out["n_errors"] == 2 + assert out["p50_latency_s"] is None + assert out["kv_bounded"] is None + + def test_happy_path_with_kv_bounded_and_prefill_bounded(self): + # 12 successful turns, flat latency ~1.0s, kv ~ 100 bytes. + turns = [ + {"ok": True, "t_relative_s": float(i * 30), + "latency_s": 1.0 + 0.05 * (i % 3), + "kv_live_bytes": 100 + (i % 3), + "history_length": 10 + i, "n_emitted": 16, + "user_message_tokens": 10} + for i in range(12) + ] + out = aggregate_run(turns, duration_s=12 * 30.0) + assert out["n_turns"] == 12 + assert out["n_errors"] == 0 + assert out["p50_latency_s"] is not None + assert out["p95_latency_s"] is not None + assert out["kv_bounded"] is True + assert out["prefill_bounded"] is True + assert out["min_kv_live_bytes"] == 100 + assert out["max_kv_live_bytes"] == 102 + + def test_unbounded_run_reports_false(self): + # Latency grows linearly -> prefill_bounded False. + # KV grows linearly too -> kv_bounded False. + turns = [] + for i in range(20): + turns.append({ + "ok": True, + "t_relative_s": float(i * 30), + "latency_s": 1.0 + i * 1.0, + "kv_live_bytes": 100 + i * 100, + }) + out = aggregate_run(turns, duration_s=600.0) + assert out["kv_bounded"] is False + assert out["prefill_bounded"] is False + assert out["latency_drift_p50_s"] is not None + assert out["latency_drift_p50_s"] > 0 + + def test_mixed_success_and_error(self): + turns = [ + {"ok": True, "t_relative_s": 0, "latency_s": 1.0, + "kv_live_bytes": 100}, + {"ok": False, "t_relative_s": 30, "error_class": "X"}, + {"ok": True, "t_relative_s": 60, "latency_s": 1.1, + "kv_live_bytes": 102}, + ] + out = aggregate_run(turns, duration_s=90.0) + assert out["n_turns"] == 2 + assert out["n_errors"] == 1 + + def test_custom_thresholds_pass_through(self): + # Build a run that passes default 10% kv tolerance but fails 1%. + turns = [ + {"ok": True, "t_relative_s": float(i), + "latency_s": 1.0, + "kv_live_bytes": 100 + i} + for i in range(5) + ] + out_default = aggregate_run(turns, duration_s=5.0) + # 100 -> 104 = 4% drift. Default 10%, so True. + assert out_default["kv_bounded"] is True + out_strict = aggregate_run(turns, duration_s=5.0, kv_tolerance=0.01) + assert out_strict["kv_bounded"] is False + + def test_drift_window_pass_through(self): + turns = [ + {"ok": True, "t_relative_s": float(i), + "latency_s": 1.0 + i * 0.1, + "kv_live_bytes": 100} + for i in range(20) + ] + # With small windows, drift becomes meaningful. + out = aggregate_run( + turns, duration_s=20.0, + drift_head_window=2, drift_tail_window=2, + drift_threshold_s=0.1, + ) + # head_p50 ~= 1.05, tail_p50 ~= 2.85, drift ~= 1.8 > 0.1. + assert out["prefill_bounded"] is False