From 8cbf57cdbfc0152f48cd48759b8e3ed077fccb1e Mon Sep 17 00:00:00 2001 From: ved015 Date: Fri, 29 May 2026 13:25:26 +0530 Subject: [PATCH 1/3] Add LongMemEval benchmark harness --- .gitignore | 9 + benchmarks/README.md | 9 + benchmarks/longmemeval/README.md | 166 ++++++++ benchmarks/longmemeval/__init__.py | 1 + benchmarks/longmemeval/client.py | 123 ++++++ benchmarks/longmemeval/config.py | 64 +++ benchmarks/longmemeval/dataset.py | 282 +++++++++++++ benchmarks/longmemeval/metrics.py | 113 +++++ benchmarks/longmemeval/run.py | 130 ++++++ benchmarks/longmemeval/run_all_categories.py | 408 +++++++++++++++++++ benchmarks/longmemeval/runner.py | 203 +++++++++ 11 files changed, 1508 insertions(+) create mode 100644 benchmarks/README.md create mode 100644 benchmarks/longmemeval/README.md create mode 100644 benchmarks/longmemeval/__init__.py create mode 100644 benchmarks/longmemeval/client.py create mode 100644 benchmarks/longmemeval/config.py create mode 100644 benchmarks/longmemeval/dataset.py create mode 100644 benchmarks/longmemeval/metrics.py create mode 100644 benchmarks/longmemeval/run.py create mode 100644 benchmarks/longmemeval/run_all_categories.py create mode 100644 benchmarks/longmemeval/runner.py diff --git a/.gitignore b/.gitignore index 37c88b7..30ef4e6 100644 --- a/.gitignore +++ b/.gitignore @@ -56,7 +56,16 @@ tests/ !tests/ !tests/**/*.py benchmarks/ +!benchmarks/ +!benchmarks/README.md LongMemEval/ +!benchmarks/longmemeval/ +!benchmarks/longmemeval/** +benchmarks/longmemeval/**/__pycache__/ +benchmarks/longmemeval/**/*.pyc +benchmarks/longmemeval/data/ +benchmarks/longmemeval/results/ +benchmarks/longmemeval/outputs/ backboard/ rust/ diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000..4408368 --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,9 @@ +# XMem Benchmarks + +This directory contains benchmark harnesses for XMem. + +- `longmemeval/`: Python-only LongMemEval benchmark runner targeting the XMem HTTP API. 
+ +Benchmark runs can create large dataset and result artifacts. Keep those files under +`benchmarks/longmemeval/data`, `benchmarks/longmemeval/results`, or +`benchmarks/longmemeval/outputs`; those paths are intentionally ignored by git. diff --git a/benchmarks/longmemeval/README.md b/benchmarks/longmemeval/README.md new file mode 100644 index 0000000..e6cbf2c --- /dev/null +++ b/benchmarks/longmemeval/README.md @@ -0,0 +1,166 @@ +# LongMemEval Benchmark for XMem Python + +This harness benchmarks the Python XMem service only. It targets the deployed +Python API at `https://api.xmem.in` by default and does not run or compare the +Go implementation. + +LongMemEval evaluates long-term conversational memory across multi-session +recall, temporal reasoning, single-session recall, knowledge updates, and +preference tracking. The harness follows the same broad structure used by +open-source memory-layer benchmarks: load dataset records, ingest the haystack +conversation history into an isolated user namespace, retrieve an answer for +the benchmark question, write predictions, and compute lightweight local +metrics for quick iteration. + +## Files + +- `dataset.py`: Loads JSON/JSONL LongMemEval records and converts sessions to + XMem conversation-turn ingest payloads. +- `client.py`: Async HTTP client for the Python XMem API. +- `runner.py`: Benchmark orchestration, batching, polling, resume support, and + output writing. +- `metrics.py`: Local exact-match, contains, and token-F1 metrics plus summary + aggregation. +- `run.py`: CLI entrypoint. + +## Secrets + +Do not commit API keys or provider credentials. + +To generate XMem predictions, set an XMem API key: + +```bash +export XMEM_API_KEY="..." +``` + +Use `--api-key-env` if your local environment uses a different variable name. + +To score predictions with the official LongMemEval LLM-as-judge evaluator, set +an OpenAI API key before running the evaluator: + +```bash +export OPENAI_API_KEY="..." 
+``` + +## Run a Smoke Check + +Validate dataset parsing and payload construction without calling the service: + +```bash +python -m benchmarks.longmemeval.run \ + --download \ + --dry-run \ + --limit 2 +``` + +Validate all six official categories without requiring an API key: + +```bash +python -m benchmarks.longmemeval.run_all_categories \ + --download \ + --dry-run +``` + +If the dataset is already available locally: + +```bash +python -m benchmarks.longmemeval.run \ + --dataset-path benchmarks/longmemeval/data/longmemeval_s_cleaned.json \ + --dry-run \ + --limit 2 +``` + +## Run Against the Python API + +```bash +export XMEM_API_KEY="..." + +python -m benchmarks.longmemeval.run \ + --download \ + --api-base-url https://api.xmem.in \ + --limit 10 \ + --batch-size 25 \ + --output-dir benchmarks/longmemeval/results/run-001 +``` + +The runner writes: + +- `results.jsonl`: Full per-example benchmark records. +- `predictions.jsonl`: Official prediction file with only `question_id` and + `hypothesis`. +- `summary.json`: Aggregate local metrics and latency. + +The local metrics are intended for fast development feedback. For publication +quality reporting, run the generated `predictions.jsonl` through the official +LongMemEval evaluation flow or an agreed LLM-as-judge rubric using the same +model/settings across systems. + +The benchmark runner itself only needs `XMEM_API_KEY` because it generates XMem +answers. The official/equivalent evaluator is a separate scoring step and needs +`OPENAI_API_KEY` when using an OpenAI judge model. + +## Run All Official Categories + +The dataset has six `question_type` categories. Each example has a unique +`question_id` and its own haystack sessions, and this runner isolates each +question into a separate XMem user namespace. That makes category-level +parallelism safe from memory leakage; the only practical constraint is API +throughput and rate limiting. + +```bash +export XMEM_API_KEY="..." 
+ +python -m benchmarks.longmemeval.run_all_categories \ + --dataset-path benchmarks/longmemeval/data/longmemeval_s_cleaned.json \ + --api-base-url https://api.xmem.in \ + --output-root benchmarks/longmemeval/results/full-six-categories \ + --max-parallel-categories 6 +``` + +The all-category runner prints live processed/left/ETA status and writes one +official merged prediction file at: + +```text +benchmarks/longmemeval/results/full-six-categories/predictions.jsonl +``` + +Each category also gets a `runner.log` file under its output directory. If a +category process fails, the launcher prints the failing category, exit code, log +path, and the most recent child-process output. + +## Useful Options + +- `--limit N`: Run a small subset first. +- `--offset N`: Skip the first N selected examples. +- `--question-type TYPE`: Filter to one LongMemEval category. +- `--skip-ingest`: Reuse already-ingested user namespaces and only retrieve. +- `--no-resume`: Re-run examples even if they already exist in `results.jsonl`. +- `--ingest-api-version v1`: Use synchronous batch ingestion instead of the + default durable `/v2/memory/batch-ingest` path. +- `--effort-level high`: Use high-effort XMem ingestion for long records. +- `--dry-run`: Validate dataset/category setup without API calls. +- `--verbose`: Print child runner output while the all-category launcher runs. + +## Expected Failures + +These errors are intentional and should be actionable: + +- `Dataset file not found`: run with `--download`, or pass `--dataset-path`. +- `Missing API key`: set `XMEM_API_KEY`, or pass `--api-key-env` for a custom + variable name. +- Official evaluator authentication errors: set `OPENAI_API_KEY` before running + the LongMemEval scoring step. +- `Failed to download the LongMemEval dataset`: check network access, then retry + or download the dataset manually. +- ` failed with exit code ...`: inspect that category's `runner.log`. 
class XMemApiClient:
    """Small async client around the deployed Python XMem API.

    Every public call returns an ``ApiCallResult`` holding the ``data``
    member of the JSON response and the wall-clock time of the call.
    Instances can be used as async context managers; leaving the context
    closes the underlying ``httpx.AsyncClient``.
    """

    def __init__(
        self,
        *,
        base_url: str,
        api_key: str,
        timeout_seconds: float = 120.0,
        max_retries: int = 3,
        retry_backoff_seconds: float = 2.0,
    ) -> None:
        """Create the HTTP client.

        Args:
            base_url: API root; a trailing slash is removed.
            api_key: Sent as a ``Bearer`` token on every request.
            timeout_seconds: Per-request timeout handed to ``httpx.Timeout``.
            max_retries: Number of retries after the first attempt.
            retry_backoff_seconds: Base delay; attempt ``i`` (0-based) waits
                ``retry_backoff_seconds * (i + 1)`` before the next try.
        """
        self.base_url = base_url.rstrip("/")
        self.max_retries = max_retries
        self.retry_backoff_seconds = retry_backoff_seconds
        self._client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=httpx.Timeout(timeout_seconds),
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
                "User-Agent": "xmem-longmemeval-benchmark/1.0",
            },
        )

    async def __aenter__(self) -> "XMemApiClient":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def ingest(self, payload: dict[str, Any]) -> ApiCallResult:
        """POST a single payload to ``/v1/memory/ingest``."""
        return await self._post("/v1/memory/ingest", payload)

    async def batch_ingest_v1(self, items: list[dict[str, Any]]) -> ApiCallResult:
        """POST ``items`` to ``/v1/memory/batch-ingest``."""
        return await self._post("/v1/memory/batch-ingest", {"items": items})

    async def batch_ingest_v2(self, items: list[dict[str, Any]]) -> ApiCallResult:
        """POST ``items`` to ``/v2/memory/batch-ingest``."""
        return await self._post("/v2/memory/batch-ingest", {"items": items})

    async def retrieve(self, payload: dict[str, Any]) -> ApiCallResult:
        """POST a payload to ``/v1/memory/retrieve``."""
        return await self._post("/v1/memory/retrieve", payload)

    async def job_status(self, status_url: str) -> ApiCallResult:
        """GET the job status document at ``status_url``."""
        return await self._get(status_url)

    async def poll_job(
        self,
        status_url: str,
        *,
        interval_seconds: float,
        timeout_seconds: float,
    ) -> ApiCallResult:
        """Poll ``status_url`` until the job reaches a terminal status.

        The status is always queried at least once, even when
        ``timeout_seconds`` is zero or negative, and the wait between polls
        never extends past the deadline.

        Returns:
            The first result whose lower-cased ``status`` is in
            ``TERMINAL_JOB_STATUSES``.

        Raises:
            TimeoutError: If no terminal status was seen before the deadline.
        """
        deadline = time.monotonic() + timeout_seconds
        while True:
            last_result = await self.job_status(status_url)
            status = str(last_result.data.get("status") or "").lower()
            if status in TERMINAL_JOB_STATUSES:
                return last_result
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # Cap the sleep so the final poll happens at the deadline
            # instead of up to one full interval after it.
            await asyncio.sleep(min(interval_seconds, remaining))
        status = last_result.data.get("status")
        raise TimeoutError(f"Timed out polling job {status_url}; last status={status}")

    async def _get(self, path: str) -> ApiCallResult:
        return await self._request("GET", path)

    async def _post(self, path: str, payload: dict[str, Any]) -> ApiCallResult:
        return await self._request("POST", path, json=payload)

    @staticmethod
    def _resolve_request_path(path: str) -> str:
        """Return ``path`` in a form that is safe to hand to the HTTP client.

        Absolute ``http(s)://`` URLs and paths that already start with "/"
        are returned unchanged; any other value gets a leading "/".
        Prefixing an absolute URL with "/" would turn it into an invalid
        path, which matters if a job's status URL is returned in absolute
        form.
        """
        if path.startswith(("http://", "https://", "/")):
            return path
        return f"/{path}"

    async def _request(self, method: str, path: str, **kwargs: Any) -> ApiCallResult:
        """Send one request with retries and unwrap the response envelope.

        A request is retried when the client raises ``httpx.HTTPError`` or
        the response status is 429 or >= 500, up to ``max_retries`` extra
        attempts with a linearly growing delay. The reported ``elapsed_ms``
        spans all attempts, including the backoff sleeps.

        Raises:
            httpx.HTTPError: If the last attempt fails at the client level,
                or the final response has an error status.
            RuntimeError: If no response was obtained, or the JSON body
                reports ``status == "error"``.
        """
        request_path = self._resolve_request_path(path)
        start = time.perf_counter()
        response: httpx.Response | None = None
        for attempt in range(self.max_retries + 1):
            try:
                response = await self._client.request(method, request_path, **kwargs)
                if response.status_code < 500 and response.status_code != 429:
                    break
            except httpx.HTTPError:
                if attempt >= self.max_retries:
                    raise
            if attempt < self.max_retries:
                await asyncio.sleep(self.retry_backoff_seconds * (attempt + 1))

        if response is None:
            raise RuntimeError(f"No response from {method} {request_path}")
        elapsed_ms = round((time.perf_counter() - start) * 1000, 2)
        response.raise_for_status()
        body = response.json()
        if isinstance(body, dict):
            if body.get("status") == "error":
                error = body.get("error") or f"XMem API error from {request_path}"
                raise RuntimeError(error)
            data = body.get("data")
        else:
            # A bare JSON array/scalar has no envelope; treat it as the data.
            data = body
        if data is None:
            data = {}
        if not isinstance(data, dict):
            data = {"value": data}
        return ApiCallResult(data=data, elapsed_ms=elapsed_ms)
@dataclass(frozen=True)
class LongMemEvalExample:
    """One benchmark question together with its conversation history."""

    question_id: str
    question: str
    answer: str
    question_type: str = ""
    sessions: list[ConversationSession] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def user_id_suffix(self) -> str:
        """Return ``question_id`` reduced to alphanumerics, "-" and "_".

        Every other character becomes "_", leading and trailing underscores
        are trimmed, and ``"example"`` is returned if nothing is left.
        """
        pieces: list[str] = []
        for char in self.question_id:
            allowed = char.isalnum() or char == "-" or char == "_"
            pieces.append(char if allowed else "_")
        trimmed = "".join(pieces).strip("_")
        if trimmed:
            return trimmed
        return "example"
def select_examples(
    examples: Iterable[LongMemEvalExample],
    *,
    offset: int = 0,
    limit: int | None = None,
    question_type: str | None = None,
) -> list[LongMemEvalExample]:
    """Filter by question type, then apply ``offset`` and ``limit``.

    The type comparison is case-insensitive. The filter runs first, so
    ``offset`` and ``limit`` count within the filtered examples. A new list
    is always returned.
    """
    pool = list(examples)
    if question_type:
        wanted = question_type.lower()
        pool = [item for item in pool if item.question_type.lower() == wanted]
    window = pool[offset:] if offset else pool
    return window if limit is None else window[:limit]
def _first_present(record: dict[str, Any], keys: tuple[str, ...]) -> Any:
    """Return the first usable value stored under ``keys``, or ``None``.

    A value is usable when it is truthy or a real number. Plain ``or``
    chaining would also discard a numeric ``0``, which is a legitimate
    answer or identifier, so numbers are kept explicitly. Booleans are not
    treated as numbers here.
    """
    for key in keys:
        value = record.get(key)
        is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
        if value or is_number:
            return value
    return None


def _parse_example(record: dict[str, Any], index: int) -> LongMemEvalExample:
    """Build a ``LongMemEvalExample`` from one raw dataset record.

    Each field is read from the first usable alias: ``question_id`` / ``id``
    / ``sample_id``, ``question`` / ``query``, ``answer`` / ``gold_answer``
    and ``question_type`` / ``category`` / ``type``. A record without an id
    gets ``"example-<index>"``. All values are converted with ``str``; every
    field except the id is stripped. Keys that hold the conversation itself
    are left out of ``metadata``.
    """

    def _text(*keys: str) -> str:
        value = _first_present(record, keys)
        return "" if value is None else str(value).strip()

    raw_id = _first_present(record, ("question_id", "id", "sample_id"))
    question_id = str(raw_id) if raw_id is not None else f"example-{index}"
    sessions = _parse_sessions(record)
    return LongMemEvalExample(
        question_id=question_id,
        question=_text("question", "query"),
        answer=_text("answer", "gold_answer"),
        question_type=_text("question_type", "category", "type"),
        sessions=sessions,
        metadata={
            key: value
            for key, value in record.items()
            if key not in {"haystack_sessions", "sessions", "conversation", "messages"}
        },
    )
raw_session.get("id") + or session_id + ) + date = str(raw_session.get("date") or raw_session.get("created_at") or date) + turns_source = ( + raw_session.get("messages") + or raw_session.get("turns") + or raw_session.get("conversation") + or [] + ) + turns = _parse_turns(turns_source) + if turns: + sessions.append( + ConversationSession(session_id=session_id, date=date, turns=turns) + ) + return sessions + + +def _parse_turns(raw_turns: Any) -> list[ConversationTurn]: + if not isinstance(raw_turns, list): + return [] + turns: list[ConversationTurn] = [] + for raw_turn in raw_turns: + if isinstance(raw_turn, str): + turns.append(ConversationTurn(role="user", content=raw_turn)) + continue + if not isinstance(raw_turn, dict): + continue + role = str( + raw_turn.get("role") + or raw_turn.get("speaker") + or raw_turn.get("sender") + or raw_turn.get("from") + or "" + ).lower() + content = ( + raw_turn.get("content") + or raw_turn.get("text") + or raw_turn.get("message") + or raw_turn.get("utterance") + or "" + ) + turns.append(ConversationTurn(role=_normalize_role(role), content=str(content))) + return [turn for turn in turns if turn.content.strip()] + + +def _looks_like_turn(value: Any) -> bool: + if not isinstance(value, list) or not value: + return False + first = value[0] + return isinstance(first, dict) and any( + key in first + for key in ("role", "speaker", "content", "text") + ) + + +def _normalize_role(role: str) -> str: + if role in {"assistant", "ai", "agent", "bot", "gpt"}: + return "assistant" + if role in {"system"}: + return "system" + return "user" + + +def _iter_message_pairs( + turns: Iterable[ConversationTurn], +) -> Iterable[tuple[str, str]]: + pending_user: str | None = None + pending_assistant: list[str] = [] + + for turn in turns: + role = _normalize_role(turn.role) + content = turn.content.strip() + if not content: + continue + if role == "system": + continue + if role == "user": + if pending_user is not None: + yield pending_user, 
# Anything that is neither a word character nor whitespace, plus "_" (which
# \w would otherwise keep). For str patterns \w is Unicode-aware, so letters
# and digits outside ASCII survive normalization.
_NON_WORD_RE = re.compile(r"[^\w\s]|_")
_ARTICLE_RE = re.compile(r"\b(a|an|the)\b")


def normalize_answer(text: str) -> str:
    """Lower-case ``text``, drop punctuation and English articles, squeeze spaces.

    Punctuation is replaced by a space rather than deleted. Letters and
    digits from any script are preserved; an ASCII-only character class
    would erase them, making unrelated non-Latin answers normalize to the
    same empty string.
    """
    text = text.lower()
    text = _NON_WORD_RE.sub(" ", text)
    text = _ARTICLE_RE.sub(" ", text)
    return " ".join(text.split())


def token_f1(prediction: str, reference: str) -> float:
    """Return the token-level F1 between the two normalized strings.

    If either side has no tokens the score is 1.0 when both are empty and
    0.0 otherwise. Repeated tokens count up to their smaller multiplicity.
    """
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    if not pred_tokens or not ref_tokens:
        return float(pred_tokens == ref_tokens)
    common = set(pred_tokens) & set(ref_tokens)
    overlap = sum(
        min(pred_tokens.count(token), ref_tokens.count(token))
        for token in common
    )
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def score_answer(prediction: str, reference: str) -> dict[str, float | bool]:
    """Score ``prediction`` against ``reference`` with three local metrics.

    Returns a dict with ``exact_match`` (normalized strings are equal),
    ``contains`` (the non-empty normalized reference is a substring of the
    normalized prediction) and ``token_f1`` rounded to four decimals.
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_reference = normalize_answer(reference)
    exact_match = normalized_prediction == normalized_reference
    contains = bool(
        normalized_reference
        and normalized_reference in normalized_prediction
    )
    return {
        "exact_match": exact_match,
        "contains": contains,
        "token_f1": round(token_f1(prediction, reference), 4),
    }
Any]: + if not results: + return {"count": 0, "overall": {}, "by_question_type": {}} + + overall = _summarize_bucket(results) + buckets: dict[str, list[dict[str, Any]]] = defaultdict(list) + for result in results: + buckets[str(result.get("question_type") or "unknown")].append(result) + return { + "count": len(results), + "overall": overall, + "by_question_type": { + question_type: _summarize_bucket(bucket) + for question_type, bucket in sorted(buckets.items()) + }, + } + + +def write_json(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps(payload, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + + +def append_jsonl(path: Path, payload: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(payload, sort_keys=True) + "\n") + + +def read_jsonl(path: Path) -> list[dict[str, Any]]: + if not path.exists(): + return [] + rows: list[dict[str, Any]] = [] + with path.open("r", encoding="utf-8") as handle: + for line in handle: + if line.strip(): + rows.append(json.loads(line)) + return rows + + +def _summarize_bucket(results: list[dict[str, Any]]) -> dict[str, float | int]: + count = len(results) + exact = sum(1 for result in results if result.get("metrics", {}).get("exact_match")) + contains = sum(1 for result in results if result.get("metrics", {}).get("contains")) + f1_scores = [ + float(result.get("metrics", {}).get("token_f1") or 0.0) + for result in results + ] + avg_f1 = sum(f1_scores) / count + avg_retrieve_ms = ( + sum(float(result.get("retrieve_elapsed_ms") or 0.0) for result in results) + / count + ) + return { + "count": count, + "exact_match": round(exact / count, 4), + "contains": round(contains / count, 4), + "token_f1": round(avg_f1, 4), + "avg_retrieve_ms": round(avg_retrieve_ms, 2), + } diff --git a/benchmarks/longmemeval/run.py b/benchmarks/longmemeval/run.py new file mode 100644 index 
def main() -> None:
    """Run the benchmark from the command line and print the summary as JSON.

    Any exception is reported as a one-line ``ERROR:`` message on stderr and
    converted into exit status 1.
    """
    try:
        cli_args = parse_args()
        config = build_config(cli_args, prepare_dataset(cli_args))
        runner = LongMemEvalRunner(config)
        report = asyncio.run(runner.run())
        print(json.dumps(report, indent=2, sort_keys=True))
    except Exception as exc:
        print(f"ERROR: {exc}", file=sys.stderr)
        raise SystemExit(1) from exc
def build_config(args: argparse.Namespace, dataset_path: Path) -> BenchmarkConfig:
    """Translate parsed CLI arguments into a ``BenchmarkConfig``.

    Most options are copied under the same name. ``resume`` is the negation
    of ``--no-resume``, and ``dataset_path`` is the already prepared path
    rather than the raw CLI value.
    """
    same_name_options = (
        "output_dir",
        "api_base_url",
        "api_key_env",
        "api_timeout_seconds",
        "max_retries",
        "retry_backoff_seconds",
        "batch_size",
        "ingest_api_version",
        "poll_interval_seconds",
        "poll_timeout_seconds",
        "top_k",
        "effort_level",
        "user_prefix",
        "limit",
        "offset",
        "question_type",
        "skip_ingest",
        "dry_run",
    )
    settings = {option: getattr(args, option) for option in same_name_options}
    settings["resume"] = not args.no_resume
    return BenchmarkConfig(dataset_path=dataset_path, **settings)
@dataclass
class CategoryState:
    """Mutable progress record for one question_type category run."""

    category: str
    total: int
    ingest_total: int = 0
    processed: int = 0
    ingest_processed: int = 0
    current_ingest_seen: int = 0
    done: bool = False
    returncode: int | None = None

    @property
    def left(self) -> int:
        """Return how many examples remain, never less than zero."""
        remaining = self.total - self.processed
        return remaining if remaining > 0 else 0
+ ) + + start_time = time.monotonic() + semaphore = asyncio.Semaphore(args.max_parallel_categories) + tasks = [ + asyncio.create_task( + run_category( + category, + state, + args=args, + dataset_path=dataset_path, + semaphore=semaphore, + ) + ) + for category, state in states.items() + ] + reporter = asyncio.create_task( + report_progress(states, start_time, args.status_seconds) + ) + results = await asyncio.gather(*tasks, return_exceptions=True) + reporter.cancel() + await asyncio.gather(reporter, return_exceptions=True) + + failures = [] + for result in results: + if isinstance(result, Exception): + failures.append(str(result)) + if failures: + raise RuntimeError("One or more category runs failed: " + " | ".join(failures)) + + merge_predictions(args.output_root) + print_status(states, start_time, final=True) + + +def prepare_dataset(args: argparse.Namespace) -> Path: + if args.download: + try: + return download_dataset(args.variant, args.dataset_path) + except Exception as exc: + raise RuntimeError( + "Failed to download the LongMemEval dataset. " + "Check network access, or download the dataset manually and " + f"pass --dataset-path. Details: {exc}" + ) from exc + + if not args.dataset_path.exists(): + raise FileNotFoundError( + f"Dataset file not found: {args.dataset_path}. " + "Run with --download, or pass --dataset-path to a local " + "LongMemEval JSON/JSONL file." + ) + return args.dataset_path + + +def validate_args(args: argparse.Namespace) -> None: + if args.max_parallel_categories < 1: + raise ValueError("--max-parallel-categories must be at least 1.") + if args.max_parallel_categories > len(OFFICIAL_QUESTION_TYPES): + raise ValueError( + "--max-parallel-categories cannot exceed the six official " + "LongMemEval question_type categories." 
+ ) + if args.batch_size < 1: + raise ValueError("--batch-size must be at least 1.") + if args.status_seconds < 1: + raise ValueError("--status-seconds must be at least 1.") + + +async def run_category( + category: str, + state: CategoryState, + *, + args: argparse.Namespace, + dataset_path: Path, + semaphore: asyncio.Semaphore, +) -> None: + async with semaphore: + output_dir = args.output_root / category + output_dir.mkdir(parents=True, exist_ok=True) + log_path = output_dir / "runner.log" + recent_lines: deque[str] = deque(maxlen=20) + cmd = [ + sys.executable, + "-m", + "benchmarks.longmemeval.run", + "--dataset-path", + str(dataset_path), + "--api-base-url", + args.api_base_url, + "--api-key-env", + args.api_key_env, + "--output-dir", + str(output_dir), + "--question-type", + category, + "--batch-size", + str(args.batch_size), + "--ingest-api-version", + args.ingest_api_version, + "--top-k", + str(args.top_k), + "--effort-level", + args.effort_level, + "--user-prefix", + f"{args.user_prefix}-{category}", + ] + if args.no_resume: + cmd.append("--no-resume") + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + assert process.stdout is not None + with log_path.open("w", encoding="utf-8") as log_file: + async for raw_line in process.stdout: + line = raw_line.decode("utf-8", errors="replace").rstrip() + recent_lines.append(line) + log_file.write(line + "\n") + log_file.flush() + update_state_from_line(state, line) + update_ingest_state_from_line(state, line) + if args.verbose: + print(f"[{category}] {line}", flush=True) + state.returncode = await process.wait() + state.done = state.returncode == 0 + if state.returncode != 0: + tail = "\n".join(recent_lines) or "(no child output captured)" + raise RuntimeError( + f"{category} failed with exit code {state.returncode}. " + f"See {log_path}. 
Recent output:\n{tail}" + ) + + +async def report_progress( + states: dict[str, CategoryState], + start_time: float, + status_seconds: float, +) -> None: + try: + while True: + print_status(states, start_time) + await asyncio.sleep(status_seconds) + except asyncio.CancelledError: + return + + +def update_state_from_line(state: CategoryState, line: str) -> None: + if not line.startswith("[") or "/" not in line: + return + close = line.find("]") + if close == -1: + return + progress = line[1:close] + processed_text, total_text = progress.split("/", 1) + if processed_text.isdigit() and total_text.isdigit(): + state.processed = max(state.processed, int(processed_text)) + state.total = int(total_text) + + +def update_ingest_state_from_line(state: CategoryState, line: str) -> None: + if not line.startswith("[INGEST] processed="): + return + progress = line.split("processed=", 1)[1].strip() + processed_text, total_text = progress.split("/", 1) + if processed_text.isdigit() and total_text.isdigit(): + processed = int(processed_text) + total = int(total_text) + # Child output is per-question, so accumulate forward movement. 
+ if processed < state.current_ingest_seen: + state.current_ingest_seen = 0 + delta = max(processed - state.current_ingest_seen, 0) + state.ingest_processed = min( + state.ingest_processed + delta, + state.ingest_total, + ) + state.current_ingest_seen = 0 if processed == total else processed + + +def print_status( + states: dict[str, CategoryState], + start_time: float, + *, + final: bool = False, +) -> None: + processed = sum(state.processed for state in states.values()) + total = sum(state.total for state in states.values()) + left = max(total - processed, 0) + ingest_processed = sum(state.ingest_processed for state in states.values()) + ingest_total = sum(state.ingest_total for state in states.values()) + ingest_left = max(ingest_total - ingest_processed, 0) + elapsed = max(time.monotonic() - start_time, 0.001) + rate = processed / elapsed if processed else 0.0 + eta = left / rate if rate else 0.0 + label = "FINAL" if final else "STATUS" + print( + f"[{label}] processed={processed}/{total} left={left} " + f"ingested_pairs={ingest_processed}/{ingest_total} " + f"pairs_left={ingest_left} elapsed={format_duration(elapsed)} " + f"eta={format_duration(eta)}", + flush=True, + ) + for category in OFFICIAL_QUESTION_TYPES: + state = states[category] + print( + f" - {category}: {state.processed}/{state.total} questions left=" + f"{state.left}; ingest_pairs={state.ingest_processed}/" + f"{state.ingest_total}", + flush=True, + ) + + +def print_category_plan(states: dict[str, CategoryState]) -> None: + print("LongMemEval category plan:", flush=True) + for category in OFFICIAL_QUESTION_TYPES: + state = states[category] + print( + f" - {category}: {state.total} questions, " + f"{state.ingest_total} ingest pairs", + flush=True, + ) + + +def merge_predictions(output_root: Path) -> None: + merged_path = output_root / "predictions.jsonl" + missing_predictions = [] + with merged_path.open("w", encoding="utf-8") as merged: + for category in OFFICIAL_QUESTION_TYPES: + path = 
output_root / category / "predictions.jsonl" + if not path.exists(): + missing_predictions.append(str(path)) + continue + with path.open("r", encoding="utf-8") as handle: + for line in handle: + if line.strip(): + payload = json.loads(line) + merged.write( + json.dumps( + { + "question_id": payload["question_id"], + "hypothesis": payload["hypothesis"], + }, + sort_keys=True, + ) + + "\n" + ) + if missing_predictions: + raise RuntimeError( + "Missing category prediction files: " + + ", ".join(missing_predictions) + ) + print(f"Merged official predictions: {merged_path}", flush=True) + + +def validate_independence(examples: list[object]) -> None: + question_ids = [getattr(example, "question_id") for example in examples] + if len(question_ids) != len(set(question_ids)): + raise RuntimeError("Dataset has duplicate question_id values.") + categories = {getattr(example, "question_type") for example in examples} + missing = set(OFFICIAL_QUESTION_TYPES) - categories + if missing: + raise RuntimeError(f"Dataset missing official categories: {sorted(missing)}") + + +def format_duration(seconds: float) -> str: + if seconds <= 0: + return "unknown" + seconds = int(seconds) + hours, rem = divmod(seconds, 3600) + minutes, secs = divmod(rem, 60) + if hours: + return f"{hours}h{minutes:02d}m" + if minutes: + return f"{minutes}m{secs:02d}s" + return f"{secs}s" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run all official LongMemEval categories against XMem.", + ) + parser.add_argument( + "--dataset-path", + type=Path, + default=Path("benchmarks/longmemeval/data/longmemeval_s_cleaned.json"), + ) + parser.add_argument("--download", action="store_true") + parser.add_argument("--variant", default=DEFAULT_DATASET_VARIANT) + parser.add_argument( + "--output-root", + type=Path, + default=Path("benchmarks/longmemeval/results/full-six-categories"), + ) + parser.add_argument("--api-base-url", default=DEFAULT_API_BASE_URL) + 
parser.add_argument("--api-key-env", default=DEFAULT_API_KEY_ENV) + parser.add_argument("--batch-size", type=int, default=25) + parser.add_argument("--ingest-api-version", choices=("v1", "v2"), default="v2") + parser.add_argument("--top-k", type=int, default=10) + parser.add_argument("--effort-level", choices=("low", "high"), default="low") + parser.add_argument("--user-prefix", default="longmemeval") + parser.add_argument("--max-parallel-categories", type=int, default=6) + parser.add_argument("--status-seconds", type=float, default=30.0) + parser.add_argument("--no-resume", action="store_true") + parser.add_argument("--verbose", action="store_true") + parser.add_argument( + "--dry-run", + action="store_true", + help="Validate dataset/category setup without requiring an API key.", + ) + return parser.parse_args() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmarks/longmemeval/runner.py b/benchmarks/longmemeval/runner.py new file mode 100644 index 0000000..0db1bb9 --- /dev/null +++ b/benchmarks/longmemeval/runner.py @@ -0,0 +1,203 @@ +"""Benchmark orchestration for LongMemEval against the Python XMem API.""" + +from __future__ import annotations + +import time +from pathlib import Path +from typing import Any + +from .client import XMemApiClient +from .config import BenchmarkConfig +from .dataset import ( + LongMemEvalExample, + build_ingest_items, + load_examples, + select_examples, +) +from .metrics import ( + append_jsonl, + read_jsonl, + score_answer, + summarize_results, + write_json, +) + + +class LongMemEvalRunner: + """Runs LongMemEval examples against the XMem Python HTTP service.""" + + def __init__(self, config: BenchmarkConfig) -> None: + self.config = config + self.results_path = config.output_dir / "results.jsonl" + self.predictions_path = config.output_dir / "predictions.jsonl" + self.summary_path = config.output_dir / "summary.json" + + async def run(self) -> dict[str, Any]: + examples = select_examples( + 
load_examples(self.config.dataset_path), + offset=self.config.offset, + limit=self.config.limit, + question_type=self.config.question_type, + ) + if self.config.dry_run: + return self._dry_run_summary(examples) + + completed_ids = self._completed_question_ids() if self.config.resume else set() + api_key = self.config.require_api_key() + run_started = time.time() + + async with XMemApiClient( + base_url=self.config.api_base_url, + api_key=api_key, + timeout_seconds=self.config.api_timeout_seconds, + max_retries=self.config.max_retries, + retry_backoff_seconds=self.config.retry_backoff_seconds, + ) as client: + for index, example in enumerate(examples, start=1): + if example.question_id in completed_ids: + continue + result = await self._run_example( + client, + example, + index=index, + total=len(examples), + ) + append_jsonl(self.results_path, result) + append_jsonl( + self.predictions_path, + { + "question_id": result["question_id"], + "hypothesis": result["prediction"], + }, + ) + + all_results = read_jsonl(self.results_path) + summary = summarize_results(all_results) + summary["dataset_path"] = str(self.config.dataset_path) + summary["api_base_url"] = self.config.api_base_url + summary["duration_seconds"] = round(time.time() - run_started, 2) + write_json(self.summary_path, summary) + return summary + + async def _run_example( + self, + client: XMemApiClient, + example: LongMemEvalExample, + *, + index: int, + total: int, + ) -> dict[str, Any]: + user_id = f"{self.config.user_prefix}-{example.user_id_suffix}" + ingest_count = 0 + ingest_elapsed_ms = 0.0 + + if not self.config.skip_ingest: + items = build_ingest_items( + example, + user_id=user_id, + effort_level=self.config.effort_level, + ) + ingest_count = len(items) + ingest_elapsed_ms = await self._ingest_items(client, items) + + retrieve = await client.retrieve( + { + "query": example.question, + "user_id": user_id, + "top_k": self.config.top_k, + } + ) + prediction = str(retrieve.data.get("answer") or "") 
+ result = { + "question_id": example.question_id, + "question_type": example.question_type or "unknown", + "question": example.question, + "reference_answer": example.answer, + "prediction": prediction, + "metrics": score_answer(prediction, example.answer), + "source_count": len(retrieve.data.get("sources") or []), + "confidence": retrieve.data.get("confidence"), + "user_id": user_id, + "ingest_count": ingest_count, + "ingest_elapsed_ms": round(ingest_elapsed_ms, 2), + "retrieve_elapsed_ms": retrieve.elapsed_ms, + "index": index, + "total": total, + } + print( + f"[{index}/{total}] {example.question_id}: " + f"f1={result['metrics']['token_f1']} retrieve_ms={retrieve.elapsed_ms}" + ) + return result + + async def _ingest_items(self, client: XMemApiClient, items: list[Any]) -> float: + if not items: + return 0.0 + + elapsed_ms = 0.0 + processed = 0 + for start in range(0, len(items), self.config.batch_size): + chunk = items[start : start + self.config.batch_size] + payload = [item.__dict__ for item in chunk] + if self.config.ingest_api_version == "v1": + result = await client.batch_ingest_v1(payload) + elapsed_ms += result.elapsed_ms + processed += len(chunk) + print( + f"[INGEST] processed={processed}/{len(items)}", + flush=True, + ) + continue + + accepted = await client.batch_ingest_v2(payload) + elapsed_ms += accepted.elapsed_ms + status_url = str(accepted.data.get("status_url") or "") + if not status_url: + raise RuntimeError( + "XMem v2 batch ingest response did not include status_url" + ) + status = await client.poll_job( + status_url, + interval_seconds=self.config.poll_interval_seconds, + timeout_seconds=self.config.poll_timeout_seconds, + ) + elapsed_ms += status.elapsed_ms + if str(status.data.get("status") or "").lower() != "succeeded": + error = status.data.get("error") or status.data + raise RuntimeError( + f"XMem batch ingest job failed: {error}" + ) + processed += len(chunk) + print( + f"[INGEST] processed={processed}/{len(items)}", + flush=True, + ) 
+ return elapsed_ms + + def _completed_question_ids(self) -> set[str]: + return {str(row.get("question_id")) for row in read_jsonl(self.results_path)} + + def _dry_run_summary(self, examples: list[LongMemEvalExample]) -> dict[str, Any]: + ingest_counts = [ + len( + build_ingest_items( + example, + user_id="dry-run", + effort_level=self.config.effort_level, + ) + ) + for example in examples + ] + summary = { + "dry_run": True, + "dataset_path": str(self.config.dataset_path), + "selected_examples": len(examples), + "total_ingest_items": sum(ingest_counts), + "min_ingest_items": min(ingest_counts) if ingest_counts else 0, + "max_ingest_items": max(ingest_counts) if ingest_counts else 0, + "question_types": sorted( + {example.question_type or "unknown" for example in examples} + ), + } + write_json(self.summary_path, summary) + return summary From 2f8d90cf93252e12e5faf8e8747a5b57cf8696f3 Mon Sep 17 00:00:00 2001 From: ved015 Date: Fri, 29 May 2026 13:39:20 +0530 Subject: [PATCH 2/3] Address LongMemEval benchmark review comments --- benchmarks/longmemeval/client.py | 8 +++++++- benchmarks/longmemeval/dataset.py | 14 ++++++++++++-- benchmarks/longmemeval/run_all_categories.py | 4 ++++ benchmarks/longmemeval/runner.py | 1 - 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/benchmarks/longmemeval/client.py b/benchmarks/longmemeval/client.py index aafab65..240e8d9 100644 --- a/benchmarks/longmemeval/client.py +++ b/benchmarks/longmemeval/client.py @@ -93,7 +93,7 @@ async def _post(self, path: str, payload: dict[str, Any]) -> ApiCallResult: return await self._request("POST", path, json=payload) async def _request(self, method: str, path: str, **kwargs: Any) -> ApiCallResult: - request_path = path if path.startswith("/") else f"/{path}" + request_path = self._request_path(path) start = time.perf_counter() response: httpx.Response | None = None for attempt in range(self.max_retries + 1): @@ -121,3 +121,9 @@ async def _request(self, method: str, path: str, 
**kwargs: Any) -> ApiCallResult if not isinstance(data, dict): data = {"value": data} return ApiCallResult(data=data, elapsed_ms=elapsed_ms) + + @staticmethod + def _request_path(path: str) -> str: + if path.startswith(("http://", "https://", "/")): + return path + return f"/{path}" diff --git a/benchmarks/longmemeval/dataset.py b/benchmarks/longmemeval/dataset.py index 4d64d33..d612851 100644 --- a/benchmarks/longmemeval/dataset.py +++ b/benchmarks/longmemeval/dataset.py @@ -3,11 +3,12 @@ from __future__ import annotations import json -import urllib.request from dataclasses import dataclass, field from pathlib import Path from typing import Any, Iterable +import httpx + from .config import DEFAULT_DATASET_URLS @@ -59,7 +60,16 @@ def download_dataset(variant: str, destination: Path) -> Path: f"Unknown dataset variant '{variant}'. Known variants: {known}" ) destination.parent.mkdir(parents=True, exist_ok=True) - urllib.request.urlretrieve(DEFAULT_DATASET_URLS[variant], destination) + with httpx.stream( + "GET", + DEFAULT_DATASET_URLS[variant], + follow_redirects=True, + timeout=120.0, + ) as response: + response.raise_for_status() + with destination.open("wb") as handle: + for chunk in response.iter_bytes(): + handle.write(chunk) return destination diff --git a/benchmarks/longmemeval/run_all_categories.py b/benchmarks/longmemeval/run_all_categories.py index 575a965..7965280 100644 --- a/benchmarks/longmemeval/run_all_categories.py +++ b/benchmarks/longmemeval/run_all_categories.py @@ -246,6 +246,8 @@ def update_state_from_line(state: CategoryState, line: str) -> None: if close == -1: return progress = line[1:close] + if "/" not in progress: + return processed_text, total_text = progress.split("/", 1) if processed_text.isdigit() and total_text.isdigit(): state.processed = max(state.processed, int(processed_text)) @@ -256,6 +258,8 @@ def update_ingest_state_from_line(state: CategoryState, line: str) -> None: if not line.startswith("[INGEST] processed="): return 
progress = line.split("processed=", 1)[1].strip() + if "/" not in progress: + return processed_text, total_text = progress.split("/", 1) if processed_text.isdigit() and total_text.isdigit(): processed = int(processed_text) diff --git a/benchmarks/longmemeval/runner.py b/benchmarks/longmemeval/runner.py index 0db1bb9..c0af853 100644 --- a/benchmarks/longmemeval/runner.py +++ b/benchmarks/longmemeval/runner.py @@ -3,7 +3,6 @@ from __future__ import annotations import time -from pathlib import Path from typing import Any from .client import XMemApiClient From ce8be33d1c1587a42d95451edda6942010b1c14d Mon Sep 17 00:00:00 2001 From: ved015 Date: Fri, 29 May 2026 13:46:22 +0530 Subject: [PATCH 3/3] Fix staging deploy for force-pushed PR branches --- .github/workflows/deploy-staging.yml | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/.github/workflows/deploy-staging.yml b/.github/workflows/deploy-staging.yml index 30fcabb..995f05f 100644 --- a/.github/workflows/deploy-staging.yml +++ b/.github/workflows/deploy-staging.yml @@ -146,12 +146,15 @@ jobs: set -euo pipefail cd "${{ secrets.STAGING_EC2_DEPLOY_PATH }}" - PR_BRANCH="${{ github.head_ref || github.event.inputs.ref || 'develop' }}" - echo "── Deploying branch: $PR_BRANCH ──" + DEPLOY_REF="${{ github.head_ref || github.event.inputs.ref || 'develop' }}" + echo "── Deploying ref: $DEPLOY_REF ──" - git fetch origin "$PR_BRANCH" - git checkout "$PR_BRANCH" - git pull origin "$PR_BRANCH" + git fetch origin "$DEPLOY_REF" + if git show-ref --verify --quiet "refs/remotes/origin/$DEPLOY_REF"; then + git checkout -B "$DEPLOY_REF" "origin/$DEPLOY_REF" + else + git checkout --detach FETCH_HEAD + fi echo "── Restarting XMem staging service ──" sudo systemctl restart xmem-staging