From 03d66a5b48dcd62c14fce572ef0971bd97467826 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Wed, 20 May 2026 11:54:27 +0200 Subject: [PATCH 01/21] Add historical-replay mode for benchmark fairness Activated by setting ForecastQuestion.as_of_date; None (default) preserves live behavior unchanged. Core: * Schema fields published_date_source, cutoff_applied, fetch_strategy, snapshot_timestamp on SearchResult and Document for post-hoc audit * SearchStagePipeline filters post-cutoff results; recovers undated ones from URL slug or Wayback first-seen; drops the unrecoverable * lookup_dashboards rewrites URLs to closest Wayback snapshot at-or- before cutoff; suppresses dashboards with no pre-cutoff snapshot * ExtractionPipeline fetches via Wayback id_ snapshots when cutoff is set; falls back to live with strategy recorded on Document * SearchCache key incorporates as_of_date so replays don't collide * Optional historical_roleplay decomposition prompt * eval_stage/contamination.py adds filter_caught_contamination_rate (explicit lower bound, never silent) and retrieval_free baseline (E4) Hardening surfaced by live testing: * Tavily end_date accepted on the Protocol but not forwarded (verified empirically not to filter) * Wayback CDX retries on 5xx / 429 / timeout with exponential backoff * Sub-queries get cutoff year appended in historical mode * Top-up round with bigger max_results when survivors < threshold * RFC 2822 date parsing for Tavily news topic responses Co-Authored-By: Claude Opus 4.7 --- README.md | 36 +++ bioscancast/extraction/fetcher.py | 68 +++- bioscancast/extraction/pipeline.py | 28 +- bioscancast/filtering/models.py | 16 + bioscancast/schemas/document.py | 10 + .../stages/eval_stage/contamination.py | 194 +++++++++++ .../stages/search_stage/backends/base.py | 14 +- .../backends/google_cse_backend.py | 15 +- .../search_stage/backends/tavily_backend.py | 15 +- bioscancast/stages/search_stage/cache.py | 31 +- 
.../stages/search_stage/dashboard_lookup.py | 55 +++- .../stages/search_stage/date_recovery.py | 124 +++++++ bioscancast/stages/search_stage/pipeline.py | 303 ++++++++++++++++-- .../search_stage/query_decomposition.py | 42 ++- bioscancast/stages/search_stage/wayback.py | 185 +++++++++++ .../tests/test_contamination_metrics.py | 132 ++++++++ bioscancast/tests/test_cutoff_filtering.py | 226 +++++++++++++ bioscancast/tests/test_date_recovery.py | 95 ++++++ bioscancast/tests/test_extraction_pipeline.py | 2 +- bioscancast/tests/test_historical_topup.py | 253 +++++++++++++++ .../test_search_filtering_integration.py | 4 +- bioscancast/tests/test_search_pipeline.py | 4 +- bioscancast/tests/test_wayback_fetch.py | 115 +++++++ bioscancast/tests/test_wayback_retry.py | 101 ++++++ scripts/probe_tavily_topic.py | 100 ++++++ scripts/test_historical_replay.py | 198 ++++++++++++ 26 files changed, 2294 insertions(+), 72 deletions(-) create mode 100644 bioscancast/stages/eval_stage/contamination.py create mode 100644 bioscancast/stages/search_stage/date_recovery.py create mode 100644 bioscancast/stages/search_stage/wayback.py create mode 100644 bioscancast/tests/test_contamination_metrics.py create mode 100644 bioscancast/tests/test_cutoff_filtering.py create mode 100644 bioscancast/tests/test_date_recovery.py create mode 100644 bioscancast/tests/test_historical_topup.py create mode 100644 bioscancast/tests/test_wayback_fetch.py create mode 100644 bioscancast/tests/test_wayback_retry.py create mode 100644 scripts/probe_tavily_topic.py create mode 100644 scripts/test_historical_replay.py diff --git a/README.md b/README.md index 6376e50..bcb7c26 100644 --- a/README.md +++ b/README.md @@ -366,6 +366,42 @@ human_comparison.py Used to compare model forecasts against human forecasts. 
+## Historical-replay mode (benchmarking against human forecasters) + +When benchmarking the pipeline against human forecasters on past questions, +the model must not be allowed to see sources that didn't exist (or contained +different content) at the time the human forecasted. Historical-replay mode +enforces this by reading a single per-question field, `ForecastQuestion.as_of_date`: + +- When `as_of_date` is `None` (default), the pipeline behaves exactly as in + live mode. No code paths change. +- When `as_of_date` is set, the search backend receives `end_date=as_of_date` (a best-effort hint only: the Tavily backend accepts it but does not forward it to the SDK, so the post-retrieval filter is the real safeguard), + the cache key incorporates the cutoff, post-retrieval filtering drops any + result dated after the cutoff (and any undated result whose date cannot be + cheaply recovered), dashboard URLs are rewritten to the closest Wayback + snapshot at or before the cutoff (or suppressed if none exists), and the + extraction stage fetches from Wayback. Wayback fallback to live is logged + at INFO and recorded in `Document.fetch_strategy`, never silent. + +The LLM "historical roleplay" prompt is *not* automatically enabled by +`as_of_date`; it lives behind a separate `historical_roleplay=True` flag on +`SearchStagePipeline` because its effect on query quality is harder to +predict. Turn it on for the benchmark and off for production. + +What this mode does NOT fix: the LLMs themselves were trained on data that +postdates many of our benchmark questions. Retrieval fairness ≠ model +fairness. The `retrieval_free_baseline_forecast` metric in +`bioscancast/stages/eval_stage/contamination.py` reports how well the LLM +forecasts with no evidence at all; a small gap between that and the full +pipeline is itself evidence of training-data leakage and must be reported +alongside the headline Brier/log scores. + +`filter_caught_contamination_rate` is also exposed by the same module. It +is a **lower bound** on contamination — it only counts post-cutoff results +whose `published_date` is known.
Undated results and results whose content +changed post-cutoff are invisible to it. Reports MUST surface this caveat; +the metric's docstring repeats it for the same reason. + --- # Datasets diff --git a/bioscancast/extraction/fetcher.py b/bioscancast/extraction/fetcher.py index d0d71ae..6cc0c01 100644 --- a/bioscancast/extraction/fetcher.py +++ b/bioscancast/extraction/fetcher.py @@ -7,6 +7,8 @@ from curl_cffi import requests as curl_requests +from bioscancast.stages.search_stage.wayback import closest_snapshot_before + from .config import ExtractionConfig logger = logging.getLogger(__name__) @@ -25,6 +27,8 @@ class FetchResult: content_bytes: Optional[bytes] fetched_at: datetime error: Optional[str] + fetch_strategy: str = "live" + snapshot_timestamp: Optional[datetime] = None def _sniff_content_type(content: bytes) -> Optional[str]: @@ -51,6 +55,7 @@ def fetch( url: str, *, config: ExtractionConfig | None = None, + as_of_date: Optional[datetime] = None, ) -> FetchResult: """Fetch a URL and return the result. Never raises on network errors. @@ -58,7 +63,56 @@ def fetch( ExtractionConfig.impersonate) to avoid Cloudflare/JA3-based blocks that reject httpx and requests. The impersonation profile sets a matching User-Agent automatically. + + Historical-replay mode: when ``as_of_date`` is set the function first + asks Wayback for the closest capture at-or-before that date and fetches + the raw snapshot bytes via the ``id_`` modifier. The returned FetchResult + carries ``fetch_strategy="wayback"`` and ``snapshot_timestamp`` set to + the capture time. If no snapshot exists, or the Wayback fetch errors, + we fall back to a live fetch and tag the result + ``fetch_strategy="wayback_fallback_to_live"`` so audit reports can see + the leak. The fallback is logged at INFO — never silent. 
""" + if as_of_date is not None: + snapshot = closest_snapshot_before(url, as_of_date) + if snapshot is not None: + snapshot_dt, snapshot_url = snapshot + wb_result = _fetch_via_curl( + target_url=snapshot_url, + reported_url=url, + config=config, + ) + if wb_result.error is None and wb_result.content_bytes is not None: + wb_result.fetch_strategy = "wayback" + wb_result.snapshot_timestamp = snapshot_dt + return wb_result + logger.info( + "Wayback fetch failed for %s (snapshot %s, error=%s); " + "falling back to live", + url, snapshot_dt.isoformat(), wb_result.error, + ) + else: + logger.info( + "No Wayback snapshot for %s at-or-before %s; falling back to live", + url, as_of_date.isoformat(), + ) + live_result = _fetch_via_curl(target_url=url, reported_url=url, config=config) + live_result.fetch_strategy = "wayback_fallback_to_live" + return live_result + + return _fetch_via_curl(target_url=url, reported_url=url, config=config) + + +def _fetch_via_curl( + *, + target_url: str, + reported_url: str, + config: ExtractionConfig | None, +) -> FetchResult: + """Issue the actual HTTP GET. ``target_url`` is what we hit (may be a + Wayback ``id_`` URL); ``reported_url`` is what we record in + ``FetchResult.url`` so downstream consumers see the original publisher + URL, not archive.org.""" cfg = config or ExtractionConfig() fetched_at = datetime.now(timezone.utc) @@ -66,7 +120,7 @@ def fetch( # curl_cffi's streaming Response is not a context manager in the # installed version, so we close it explicitly in a finally block. 
response = curl_requests.get( - url, + target_url, stream=True, timeout=cfg.fetch_timeout_seconds, impersonate=cfg.impersonate, @@ -76,7 +130,7 @@ def fetch( content_length = response.headers.get("content-length") if content_length and int(content_length) > cfg.fetch_max_bytes: return FetchResult( - url=url, + url=reported_url, final_url=str(response.url), status_code=response.status_code, content_type=_normalize_content_type( @@ -95,7 +149,7 @@ def fetch( total += len(chunk) if total > cfg.fetch_max_bytes: return FetchResult( - url=url, + url=reported_url, final_url=str(response.url), status_code=response.status_code, content_type=_normalize_content_type( @@ -118,7 +172,7 @@ def fetch( raw_ct = _sniff_content_type(content_bytes) or raw_ct return FetchResult( - url=url, + url=reported_url, final_url=str(response.url), status_code=response.status_code, content_type=raw_ct, @@ -130,10 +184,10 @@ def fetch( response.close() except Exception as exc: - logger.warning("Fetch failed for %s: %s", url, exc) + logger.warning("Fetch failed for %s: %s", target_url, exc) return FetchResult( - url=url, - final_url=url, + url=reported_url, + final_url=reported_url, status_code=None, content_type=None, content_bytes=None, diff --git a/bioscancast/extraction/pipeline.py b/bioscancast/extraction/pipeline.py index 0ae2d99..47dd1f4 100644 --- a/bioscancast/extraction/pipeline.py +++ b/bioscancast/extraction/pipeline.py @@ -19,10 +19,22 @@ class ExtractionPipeline: - """Orchestrates document fetching, parsing, and chunk normalization.""" + """Orchestrates document fetching, parsing, and chunk normalization. - def __init__(self, *, config: ExtractionConfig | None = None) -> None: + ``as_of_date`` opts the fetcher into Wayback-rewrite mode. See + ``bioscancast.extraction.fetcher.fetch`` for the strategy semantics + (live / wayback / wayback_fallback_to_live). The resulting strategy + and snapshot timestamp are copied onto each Document for audit. 
+ """ + + def __init__( + self, + *, + config: ExtractionConfig | None = None, + as_of_date: Optional[datetime] = None, + ) -> None: self._config = config or ExtractionConfig() + self._as_of_date = as_of_date self._parsers = get_parsers(pdf_max_pages=self._config.pdf_max_pages) # Lazily constructed on first PDF that reaches the refiner step. self._docling_refiner = None @@ -54,7 +66,11 @@ def extract_one(self, filtered_doc: FilteredDocument) -> Document: doc_id = f"doc-{filtered_doc.result_id}" # Step 1: Fetch - fetch_result = fetch(filtered_doc.url, config=self._config) + fetch_result = fetch( + filtered_doc.url, + config=self._config, + as_of_date=self._as_of_date, + ) if fetch_result.error or fetch_result.content_bytes is None: return self._make_failed_document( @@ -169,6 +185,9 @@ def extract_one(self, filtered_doc: FilteredDocument) -> Document: chunks=chunks, extracted_tables=extracted_tables, extracted_dates=extracted_dates, + fetch_strategy=fetch_result.fetch_strategy, + snapshot_timestamp=fetch_result.snapshot_timestamp, + cutoff_applied=self._as_of_date, ) def _get_docling_refiner(self): @@ -212,6 +231,9 @@ def _make_failed_document( error_message=error, http_status=fetch_result.status_code if fetch_result else None, content_type=fetch_result.content_type if fetch_result else None, + fetch_strategy=fetch_result.fetch_strategy if fetch_result else "live", + snapshot_timestamp=fetch_result.snapshot_timestamp if fetch_result else None, + cutoff_applied=self._as_of_date, ) def _build_chunks( diff --git a/bioscancast/filtering/models.py b/bioscancast/filtering/models.py index 058b659..facf320 100644 --- a/bioscancast/filtering/models.py +++ b/bioscancast/filtering/models.py @@ -15,6 +15,13 @@ class ForecastQuestion: pathogen: Optional[str] = None event_type: Optional[str] = None resolution_criteria: Optional[str] = None + # Historical-replay cutoff. When None (default), the pipeline runs in live + # mode and uses datetime.now() everywhere. 
When set, every cutoff-sensitive + # module (freshness scoring, search backend date filter, cache key, + # post-retrieval filter, dashboard Wayback rewrite, extraction Wayback + # rewrite, optional decomposition roleplay) treats this as "now" so the + # model sees only what a human forecaster could have seen at this moment. + as_of_date: Optional[datetime] = None @dataclass @@ -43,6 +50,15 @@ class SearchResult: retrieval_reason: Optional[str] = None contains_aggregator_forecast: bool = False search_stage_score: float = 0.0 + # Provenance for the date used to evaluate the historical-mode cutoff. + # One of: "backend" (Tavily/Google returned a date), "url_slug", + # "last_modified", "wayback_first_seen", "wayback_snapshot" (for dashboards + # rewritten to Wayback), or None (live mode, or date came from the backend + # in a way that didn't go through the recovery chain). + published_date_source: Optional[str] = None + # The as_of_date that was applied when this result was produced, copied + # off the ForecastQuestion. None in live mode. Useful for post-hoc audits. + cutoff_applied: Optional[datetime] = None @dataclass diff --git a/bioscancast/schemas/document.py b/bioscancast/schemas/document.py index d415995..120ee6a 100644 --- a/bioscancast/schemas/document.py +++ b/bioscancast/schemas/document.py @@ -115,3 +115,13 @@ class Document: extracted_dates: List[str] = field(default_factory=list) """Date strings found anywhere in the document, preserved as-is.""" + + # ---- historical-replay provenance ---- + fetch_strategy: str = "live" + """How the bytes were obtained: 'live', 'wayback', or 'wayback_fallback_to_live'.""" + + snapshot_timestamp: Optional[datetime] = None + """Wayback capture timestamp when fetch_strategy == 'wayback'. None otherwise.""" + + cutoff_applied: Optional[datetime] = None + """The as_of_date that was active when this document was fetched. 
None in live mode.""" diff --git a/bioscancast/stages/eval_stage/contamination.py b/bioscancast/stages/eval_stage/contamination.py new file mode 100644 index 0000000..a237151 --- /dev/null +++ b/bioscancast/stages/eval_stage/contamination.py @@ -0,0 +1,194 @@ +"""Contamination diagnostics for the human-comparison benchmark. + +These metrics are *not* the same as proving fairness. They are reporting +aids: they let a reviewer see how much of the model's evidence base +demonstrably violated the cutoff, and how much of the model's edge over a +human came from training-data leakage rather than retrieval. + +Two metrics live here: + +* ``filter_caught_contamination_rate`` — a LOWER BOUND on contamination. + Counts SearchResults whose ``published_date`` is later than the cutoff + *and that nonetheless reached the final result list*. After the search + stage's cutoff filter runs this should be ~0; in live-mode benchmark + runs it can be substantial. Undated post-cutoff content is invisible + to this metric — that is the largest source of contamination it cannot + see, and reports MUST say so. + +* ``retrieval_free_baseline_forecast`` — asks the LLM to forecast with no + retrieved evidence, then scores it like any other forecast. A small gap + between this baseline and the full pipeline is itself evidence of + training-data leakage in the model's weights, distinct from leakage in + the retrieval pipeline. +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from datetime import datetime +from typing import Iterable, List, Optional + +from bioscancast.filtering.models import ForecastQuestion, SearchResult +from bioscancast.llm.client import LLMClient + +logger = logging.getLogger(__name__) + + +# Phrased verbosely on purpose: a non-coder reading the eval report should +# not confuse "filter-caught" with "absolute". 
+FILTER_CAUGHT_CONTAMINATION_LOWER_BOUND_LABEL = ( + "filter_caught_contamination_rate " + "(LOWER BOUND — undated post-cutoff content is not counted)" +) + + +@dataclass +class ContaminationCounts: + total: int + post_cutoff_in_final: int + undated_in_final: int + pre_cutoff_in_final: int + + @property + def filter_caught_rate(self) -> float: + """Share of the final results whose dated publication is post-cutoff. + + IMPORTANT: this is a lower bound. It does not count undated results, + and it does not count pages whose content was edited after the + cutoff but whose first-publication date is pre-cutoff. Reports + derived from this metric must surface that caveat. + """ + if self.total == 0: + return 0.0 + return self.post_cutoff_in_final / self.total + + +def filter_caught_contamination_rate( + final_results: Iterable[SearchResult], + as_of: datetime, +) -> ContaminationCounts: + """Count post-cutoff results that nevertheless reached the final list. + + Pass the SearchResult list that the search stage *returned* (i.e. after + its own cutoff filter ran). In a well-behaved historical-replay run the + post_cutoff_in_final count should be 0; in a live-mode run it usually + won't be. + """ + post_cutoff = 0 + undated = 0 + pre_cutoff = 0 + total = 0 + for r in final_results: + total += 1 + if r.published_date is None: + undated += 1 + elif r.published_date > as_of: + post_cutoff += 1 + else: + pre_cutoff += 1 + return ContaminationCounts( + total=total, + post_cutoff_in_final=post_cutoff, + undated_in_final=undated, + pre_cutoff_in_final=pre_cutoff, + ) + + +@dataclass +class BaselineForecast: + question_id: str + options: List[str] + probabilities: List[float] + rationale: Optional[str] = None + + +def retrieval_free_baseline_forecast( + question: ForecastQuestion, + options: List[str], + llm_client: LLMClient, +) -> BaselineForecast: + """Ask the LLM to forecast the question with NO retrieved evidence. 
+ + The gap between this baseline and the full-pipeline forecast quantifies + how much of the model's signal comes from retrieval vs. training-data + knowledge. A small gap on a 2024 question to a 2026-trained LLM is + strong evidence that the LLM already "knew the answer" — which is a + separate fairness problem from retrieval leakage that no amount of + pipeline filtering can fix. Report alongside Brier/log scores, never + in place of them. + """ + prompt = json.dumps( + { + "task": ( + "Forecast the probability of each option for this biosecurity " + "question using ONLY your prior knowledge. Do not assume any " + "additional research has been done. Return strict JSON: " + "{\"probabilities\": [], " + "\"rationale\": \"\"}. Probabilities must sum to 1." + ), + "question": question.text, + "pathogen": question.pathogen, + "region": question.region, + "target_date": ( + question.target_date.isoformat() if question.target_date else None + ), + "as_of_date": ( + question.as_of_date.date().isoformat() if question.as_of_date else None + ), + "options": options, + } + ) + try: + result = llm_client.generate_json(prompt) + except Exception: + logger.exception("Retrieval-free baseline LLM call failed for %s", question.id) + uniform = [1.0 / len(options)] * len(options) + return BaselineForecast( + question_id=question.id, + options=options, + probabilities=uniform, + rationale="LLM call failed; uniform fallback", + ) + + raw_probs = result.get("probabilities") or [] + if not isinstance(raw_probs, list) or len(raw_probs) != len(options): + logger.warning( + "Baseline LLM returned malformed probabilities for %s: %r", + question.id, raw_probs, + ) + uniform = [1.0 / len(options)] * len(options) + return BaselineForecast( + question_id=question.id, + options=options, + probabilities=uniform, + rationale="Malformed LLM output; uniform fallback", + ) + try: + probs = [float(p) for p in raw_probs] + except (TypeError, ValueError): + uniform = [1.0 / len(options)] * len(options) 
+ return BaselineForecast( + question_id=question.id, + options=options, + probabilities=uniform, + rationale="Non-numeric LLM output; uniform fallback", + ) + + total = sum(probs) + if total <= 0: + uniform = [1.0 / len(options)] * len(options) + return BaselineForecast( + question_id=question.id, + options=options, + probabilities=uniform, + rationale="Zero-sum LLM output; uniform fallback", + ) + probs = [p / total for p in probs] + return BaselineForecast( + question_id=question.id, + options=options, + probabilities=probs, + rationale=result.get("rationale"), + ) diff --git a/bioscancast/stages/search_stage/backends/base.py b/bioscancast/stages/search_stage/backends/base.py index ad2ec1f..24c2e2d 100644 --- a/bioscancast/stages/search_stage/backends/base.py +++ b/bioscancast/stages/search_stage/backends/base.py @@ -17,6 +17,16 @@ class RawSearchResult: class SearchBackend(Protocol): - """Interface that all search backends must satisfy.""" + """Interface that all search backends must satisfy. - def search(self, query: str, max_results: int = 10) -> List[RawSearchResult]: ... + ``end_date`` is an optional YYYY-MM-DD upper bound used by historical- + replay mode. Backends that don't support it should accept and ignore it + (the post-retrieval cutoff filter in the pipeline will still apply). + """ + + def search( + self, + query: str, + max_results: int = 10, + end_date: Optional[str] = None, + ) -> List[RawSearchResult]: ... 
diff --git a/bioscancast/stages/search_stage/backends/google_cse_backend.py b/bioscancast/stages/search_stage/backends/google_cse_backend.py index dbeb19b..7cbfaa1 100644 --- a/bioscancast/stages/search_stage/backends/google_cse_backend.py +++ b/bioscancast/stages/search_stage/backends/google_cse_backend.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import List +from typing import List, Optional from .base import RawSearchResult @@ -15,9 +15,16 @@ class GoogleCSEBackend: """Stub backend — raises NotImplementedError on use.""" - def search(self, query: str, max_results: int = 10) -> List[RawSearchResult]: + def search( + self, + query: str, + max_results: int = 10, + end_date: Optional[str] = None, + ) -> List[RawSearchResult]: raise NotImplementedError( "GoogleCSEBackend is a stub. Implement using the Google Custom Search " - "JSON API ($5/1k queries after 100/day free tier). See base.py for the " - "SearchBackend protocol." + "JSON API ($5/1k queries after 100/day free tier). When implementing, " + "the YYYY-MM-DD `end_date` argument should be honoured via the CSE " + "`sort=date:r:YYYYMMDD:YYYYMMDD` parameter for historical-replay mode. " + "See base.py for the SearchBackend protocol." ) diff --git a/bioscancast/stages/search_stage/backends/tavily_backend.py b/bioscancast/stages/search_stage/backends/tavily_backend.py index 5f5c8bd..490c8cc 100644 --- a/bioscancast/stages/search_stage/backends/tavily_backend.py +++ b/bioscancast/stages/search_stage/backends/tavily_backend.py @@ -31,7 +31,20 @@ def __init__(self, api_key: Optional[str] = None) -> None: "TAVILY_API_KEY is required. Set it in your environment or pass api_key." ) - def search(self, query: str, max_results: int = 10) -> List[RawSearchResult]: + def search( + self, + query: str, + max_results: int = 10, + end_date: Optional[str] = None, + ) -> List[RawSearchResult]: + # NOTE: end_date is accepted to satisfy the SearchBackend Protocol + # but is NOT forwarded to the Tavily SDK. 
Empirically (Feb 2025 test + # run on q1) the Tavily client returned articles dated well after + # the supplied end_date — the parameter is documented in the SDK + # signature but does not appear to actually filter on the result's + # published_date. The post-retrieval cutoff filter in the + # SearchStagePipeline is the real defense; do not be tempted to + # rely on this parameter without re-verifying Tavily's behavior. from tavily import TavilyClient # lazy import to avoid hard dep at import time client = TavilyClient(api_key=self._api_key) diff --git a/bioscancast/stages/search_stage/cache.py b/bioscancast/stages/search_stage/cache.py index 4bb67c0..f2ae522 100644 --- a/bioscancast/stages/search_stage/cache.py +++ b/bioscancast/stages/search_stage/cache.py @@ -38,15 +38,28 @@ def __init__(self, db_path: str = "data/cache/search_cache.sqlite") -> None: self._conn.commit() @staticmethod - def _make_key(backend_name: str, query: str) -> str: - date_bucket = datetime.now(timezone.utc).strftime("%Y-%m-%d") + def _make_key( + backend_name: str, + query: str, + as_of_date: Optional[datetime] = None, + ) -> str: + # In historical-replay mode the bucket is the cutoff date, so that two + # benchmark runs against different cutoffs never share cache entries. 
+ if as_of_date is not None: + date_bucket = as_of_date.strftime("%Y-%m-%d") + else: + date_bucket = datetime.now(timezone.utc).strftime("%Y-%m-%d") raw = f"{backend_name}|{query.strip().lower()}|{date_bucket}" return hashlib.sha256(raw.encode()).hexdigest() def get( - self, backend_name: str, query: str, max_age_hours: int = 24 + self, + backend_name: str, + query: str, + max_age_hours: int = 24, + as_of_date: Optional[datetime] = None, ) -> Optional[List[RawSearchResult]]: - key = self._make_key(backend_name, query) + key = self._make_key(backend_name, query, as_of_date) row = self._conn.execute( "SELECT results_json, created_at FROM search_cache WHERE cache_key = ?", (key,), @@ -63,8 +76,14 @@ def get( items = json.loads(row[0]) return [RawSearchResult(**item) for item in items] - def put(self, backend_name: str, query: str, results: List[RawSearchResult]) -> None: - key = self._make_key(backend_name, query) + def put( + self, + backend_name: str, + query: str, + results: List[RawSearchResult], + as_of_date: Optional[datetime] = None, + ) -> None: + key = self._make_key(backend_name, query, as_of_date) payload = json.dumps( [ { diff --git a/bioscancast/stages/search_stage/dashboard_lookup.py b/bioscancast/stages/search_stage/dashboard_lookup.py index 99aea4f..e3784c3 100644 --- a/bioscancast/stages/search_stage/dashboard_lookup.py +++ b/bioscancast/stages/search_stage/dashboard_lookup.py @@ -1,10 +1,20 @@ """Dashboard lookup — inject known pathogen dashboard URLs as SearchResults. +In live mode this returns the live dashboard URL with a synthetic +``published_date=None`` and freshness=1.0 — a sensible signal that the +dashboard "is current". In historical-replay mode (``question.as_of_date`` +set), live dashboards are dangerous: they return today's case counts even +for a question created in early 2025. 
We therefore look up the closest +Wayback snapshot at-or-before the cutoff and rewrite the URL; if no +pre-cutoff snapshot exists, we suppress the dashboard entirely rather +than fall back to live. + v1 — flagged for iteration after first benchmark run. """ from __future__ import annotations +import logging import uuid from datetime import datetime, timezone from typing import List @@ -16,14 +26,23 @@ extract_domain, normalize_url, ) +from bioscancast.stages.search_stage.wayback import closest_snapshot_before + +logger = logging.getLogger(__name__) def lookup_dashboards(question: ForecastQuestion) -> List[SearchResult]: """Generate synthetic SearchResult entries for known pathogen dashboards. - If ``question.pathogen`` (lowercased) matches a key in DASHBOARD_LOOKUP, - returns a SearchResult for each URL with rank=0 and - retrieval_reason="dashboard_lookup". Returns empty list if no match. + Live mode: returns one SearchResult per URL with rank=0 and + retrieval_reason="dashboard_lookup". + + Historical-replay mode (``question.as_of_date`` is not None): for each + URL, looks up the closest Wayback snapshot at-or-before the cutoff and + emits a SearchResult pointing at the snapshot. Dashboards with no + pre-cutoff snapshot are suppressed entirely (NOT fallen-back to live) + because live dashboards return today's counts and would silently + contaminate the benchmark. 
""" if not question.pathogen: return [] @@ -33,11 +52,32 @@ def lookup_dashboards(question: ForecastQuestion) -> List[SearchResult]: if not urls: return [] + as_of = question.as_of_date results: list[SearchResult] = [] now = datetime.now(timezone.utc) for url in urls: - domain = extract_domain(url) + if as_of is not None: + snapshot = closest_snapshot_before(url, as_of) + if snapshot is None: + logger.info( + "Suppressing dashboard %s — no Wayback snapshot at-or-before %s", + url, as_of.isoformat(), + ) + continue + snapshot_dt, snapshot_url = snapshot + effective_url = snapshot_url + published_date: datetime | None = snapshot_dt + published_date_source = "wayback_snapshot" + # Keep ``domain`` as the original publisher for tier scoring; + # the URL itself points at archive.org for fetching. + domain = extract_domain(url) + else: + effective_url = url + published_date = None + published_date_source = None + domain = extract_domain(url) + tier_num, domain_score, source_tier = resolve_tier(domain) results.append( @@ -46,13 +86,14 @@ def lookup_dashboards(question: ForecastQuestion) -> List[SearchResult]: question_id=question.id, query_id=f"dashboard_{question.id}", engine="dashboard", - url=url, - canonical_url=normalize_url(url), + url=effective_url, + canonical_url=normalize_url(effective_url), domain=domain, title=f"Dashboard: {domain}", snippet=f"Known {pathogen_key} monitoring dashboard", rank=0, retrieved_at=now, + published_date=published_date, is_official_domain=(tier_num == 1 and source_tier == "official"), source_tier=source_tier, domain_score=domain_score, @@ -60,6 +101,8 @@ def lookup_dashboards(question: ForecastQuestion) -> List[SearchResult]: retrieval_reason="dashboard_lookup", contains_aggregator_forecast=is_aggregator_domain(domain), search_stage_score=0.0, # computed later by pipeline + published_date_source=published_date_source, + cutoff_applied=as_of, ) ) diff --git a/bioscancast/stages/search_stage/date_recovery.py 
b/bioscancast/stages/search_stage/date_recovery.py new file mode 100644 index 0000000..2c851c0 --- /dev/null +++ b/bioscancast/stages/search_stage/date_recovery.py @@ -0,0 +1,124 @@ +"""Recover a plausible publication date for a SearchResult whose backend +didn't supply one. + +Why this exists: in historical-replay mode the pipeline must drop any source +it cannot date, because undated pages are exactly where post-cutoff content +can hide (a page first published before the cutoff but rewritten afterwards +will still report no ``published_date`` from Tavily). Soft-allowing undated +results would silently defeat the benchmark; this module instead tries cheap +external signals before giving up. + +Recovery strategies, cheapest first: + +1. URL slug regex (``/2024/03/15/...`` and ``/2024-03-15/...``) — free, no + network call. Catches most news organisations. +2. ``Last-Modified`` header via HEAD request — off by default. Requires + passing a fetcher callable explicitly; the search stage does not normally + carry an HTTP client. Opt in only when you need the recall. +3. Wayback Machine first-seen — one CDX call per URL. Conservative: "first + archived" is an upper bound on first published, so a pre-cutoff first- + seen is sound evidence the page existed before the cutoff. + +Each function returns ``Optional[datetime]`` and never raises on network or +parse errors (it logs and returns ``None``). +""" + +from __future__ import annotations + +import logging +import re +from datetime import datetime, timezone +from typing import Callable, Optional + +from .wayback import first_seen as _wayback_first_seen + +logger = logging.getLogger(__name__) + +# Matches /YYYY/MM/DD/ and /YYYY/MM/ and /YYYY-MM-DD/ within a URL path. 
+_URL_DATE_PATTERNS = [ + re.compile(r"/(\d{4})/(\d{1,2})/(\d{1,2})(?:/|$|[?#])"), + re.compile(r"/(\d{4})-(\d{1,2})-(\d{1,2})(?:/|$|[?#])"), + re.compile(r"/(\d{4})/(\d{1,2})(?:/|$|[?#])"), +] + + +def date_from_url_slug(url: str) -> Optional[datetime]: + """Extract a date from common URL slug patterns. Returns midnight UTC of + the matched date, or None if no pattern matches or the date is invalid.""" + for pattern in _URL_DATE_PATTERNS: + m = pattern.search(url) + if not m: + continue + groups = m.groups() + try: + year = int(groups[0]) + month = int(groups[1]) + day = int(groups[2]) if len(groups) >= 3 else 1 + if year < 1990 or year > 2100: + continue # almost certainly not a date + return datetime(year, month, day, tzinfo=timezone.utc) + except (ValueError, IndexError): + continue + return None + + +def date_from_last_modified( + url: str, head_fetcher: Optional[Callable[[str], Optional[str]]] = None +) -> Optional[datetime]: + """Issue a HEAD request and parse the Last-Modified header. + + The caller must pass a ``head_fetcher`` callable that returns the + Last-Modified header string (or None). The search stage does not have a + built-in HTTP client, so this is dependency-injected to avoid an awkward + import of ``curl_cffi`` into the search-stage package. Off by default. + """ + if head_fetcher is None: + return None + try: + header = head_fetcher(url) + except Exception as exc: + logger.warning("HEAD request failed for %s: %s", url, exc) + return None + if not header: + return None + # RFC 7231 format: "Wed, 21 Oct 2015 07:28:00 GMT" + for fmt in ("%a, %d %b %Y %H:%M:%S %Z", "%a, %d %b %Y %H:%M:%S GMT"): + try: + dt = datetime.strptime(header, fmt) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt + except ValueError: + continue + return None + + +def date_from_wayback_first_seen(url: str) -> Optional[datetime]: + """Earliest Wayback capture timestamp for ``url`` as an upper bound on + first publication. 
Returns None on lookup failure.""" + return _wayback_first_seen(url) + + +def recover_published_date( + url: str, + head_fetcher: Optional[Callable[[str], Optional[str]]] = None, + use_wayback: bool = True, +) -> tuple[Optional[datetime], Optional[str]]: + """Try each strategy in order. Returns (date, source_label) where the + label is one of ``"url_slug" | "last_modified" | "wayback_first_seen"`` + on success, or (None, None) when no strategy yielded a date. + """ + dt = date_from_url_slug(url) + if dt is not None: + return dt, "url_slug" + + dt = date_from_last_modified(url, head_fetcher=head_fetcher) + if dt is not None: + return dt, "last_modified" + + if use_wayback: + dt = date_from_wayback_first_seen(url) + if dt is not None: + return dt, "wayback_first_seen" + + return None, None diff --git a/bioscancast/stages/search_stage/pipeline.py b/bioscancast/stages/search_stage/pipeline.py index 7882ab0..2be52fa 100644 --- a/bioscancast/stages/search_stage/pipeline.py +++ b/bioscancast/stages/search_stage/pipeline.py @@ -9,6 +9,7 @@ import logging import uuid from datetime import datetime, timezone +from email.utils import parsedate_to_datetime from typing import List, Optional from bioscancast.filtering.config import FILTER_CONFIG @@ -17,6 +18,7 @@ from bioscancast.stages.search_stage.backends.base import RawSearchResult, SearchBackend from bioscancast.stages.search_stage.cache import SearchCache from bioscancast.stages.search_stage.dashboard_lookup import lookup_dashboards +from bioscancast.stages.search_stage.date_recovery import recover_published_date from bioscancast.stages.search_stage.query_decomposition import SubQuery, decompose_question from bioscancast.stages.search_stage.tier_resolution import ( is_aggregator_domain, @@ -31,14 +33,23 @@ _NON_CONTENT_EXTENSIONS: set[str] = {".zip", ".exe", ".msi", ".dmg", ".tar", ".gz", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".mp4", ".mp3"} -def _compute_freshness(published_date: Optional[datetime]) -> float: +def 
_compute_freshness( + published_date: Optional[datetime], + *, + reference_date: Optional[datetime] = None, +) -> float: """Compute freshness score from published_date. - Returns 0.5 (neutral) when no date is available, per spec. + Returns 0.5 (neutral) when no date is available, per spec. ``reference_date`` + is the "now" against which age is measured; in historical-replay mode the + pipeline passes ``question.as_of_date`` so freshness is judged from the + human forecaster's vantage point. Defaults to wall-clock ``now`` for + live mode. """ if published_date is None: return 0.5 - days_old = (datetime.now(timezone.utc) - published_date).days + ref = reference_date or datetime.now(timezone.utc) + days_old = (ref - published_date).days if days_old < 0: return 1.0 return max(0.0, min(1.0, 1.0 - (days_old / 365.0))) @@ -52,7 +63,14 @@ def _compute_search_stage_score(domain_score: float, freshness_score: float, ran def _parse_published_date(date_str: Optional[str]) -> Optional[datetime]: - """Best-effort parse of backend-provided published_date strings.""" + """Best-effort parse of backend-provided published_date strings. + + Tavily inconsistently returns either ISO-8601 (``2025-02-17`` or + ``2025-02-17T13:00:00+00:00``) or RFC 2822 (``Tue, 19 May 2026 13:00:00 + GMT``) depending on the search topic, so we try both. Returning None + here is expensive in historical mode (it triggers the date-recovery + chain), so it matters that we cover the formats Tavily actually emits. + """ if not date_str: return None for fmt in ("%Y-%m-%dT%H:%M:%S%z", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d"): @@ -63,7 +81,16 @@ def _parse_published_date(date_str: Optional[str]) -> Optional[datetime]: return dt except ValueError: continue - return None + # RFC 2822 fallback — what Tavily's news topic actually returns. 
+ try: + dt = parsedate_to_datetime(date_str) + if dt is None: + return None + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt + except (TypeError, ValueError): + return None def _is_non_content_url(url: str) -> bool: @@ -73,7 +100,25 @@ def _is_non_content_url(url: str) -> bool: class SearchStagePipeline: - """Orchestrates the full search stage: decompose → search → score → deduplicate.""" + """Orchestrates the full search stage: decompose → search → score → deduplicate. + + Historical-replay mode is activated implicitly by ``question.as_of_date``. + When that field is non-None: + + * the backend is passed ``end_date=as_of_date`` as a best-effort hint (the + Tavily backend does not forward it; the post-retrieval filter enforces it), + * the cache key incorporates the cutoff so replay runs don't see each + other's results, + * freshness scoring uses the cutoff as "now" rather than wall-clock time, + * a post-retrieval filter drops anything dated after the cutoff and any + undated result whose date can't be recovered from a cheap fallback, + * the dashboard injection rewrites URLs to closest Wayback snapshots, + suppressing dashboards with no pre-cutoff snapshot entirely, + * (opt-in) the LLM decomposition prompt is asked to roleplay the cutoff. + + The ``historical_roleplay`` constructor flag controls only the last item; + everything else is implicit on ``as_of_date``. + """ def __init__( self, @@ -83,6 +128,10 @@ def __init__( results_per_query: int = 10, total_cap: int = 60, backend_name: str = "tavily", + historical_roleplay: bool = False, + min_post_filter_results: int = 10, + top_up_results_per_query: int = 50, + max_top_up_rounds: int = 1, ) -> None: self._backend = search_backend self._llm = llm_client @@ -90,70 +139,204 @@ def __init__( self._results_per_query = results_per_query self._total_cap = total_cap self._backend_name = backend_name + self._historical_roleplay = historical_roleplay + # Top-up parameters apply only in historical-replay mode.
In live + # mode the initial pass is always considered sufficient. + self._min_post_filter_results = min_post_filter_results + self._top_up_results_per_query = top_up_results_per_query + self._max_top_up_rounds = max_top_up_rounds def run(self, question: ForecastQuestion) -> List[SearchResult]: """Execute the full search stage pipeline.""" + as_of = question.as_of_date + # 1. Decompose question into sub-queries - sub_queries = decompose_question(question, self._llm) + sub_queries = decompose_question( + question, + self._llm, + historical_roleplay=self._historical_roleplay, + ) logger.info("Decomposed into %d sub-queries", len(sub_queries)) # 2. Inject dashboard lookups dashboard_results = lookup_dashboards(question) logger.info("Dashboard lookup produced %d results", len(dashboard_results)) - # 3. Execute searches per sub-query + # 3. Execute initial search round all_results: list[SearchResult] = list(dashboard_results) - for sq in sub_queries: - raw_results = self._execute_search(sq.text) - for rank_offset, raw in enumerate(raw_results): - result = self._convert(raw, sq, question.id, rank_offset + 1) - all_results.append(result) - - if len(all_results) >= self._total_cap: - logger.info("Hit total cap of %d results before all sub-queries", self._total_cap) - break - - # 4. Deduplicate - deduped = self._deduplicate(all_results) - - # 5. Hard exclusions - filtered = self._apply_exclusions(deduped) + seen_canonical: set[str] = set() + for r in dashboard_results: + if r.canonical_url: + seen_canonical.add(r.canonical_url) + + all_results, seen_canonical = self._search_round( + sub_queries, + question, + as_of, + max_results=self._results_per_query, + collected=all_results, + seen_canonical=seen_canonical, + stop_cap=self._total_cap, + ) - # 6. Compute search_stage_score + # 4-6. Dedup → exclusions → cutoff filter + filtered = self._dedup_exclude_cutoff(all_results, as_of) + + # 6b. 
Top-up: in historical mode only, if we're below the survivor + # threshold, run additional rounds with a larger results_per_query + # to fish for older content. Tavily's end_date doesn't actually + # filter (see tavily_backend.py), so the candidate pool is heavily + # biased toward post-cutoff content for recent topics — top-up is + # what makes historical mode actually return usable results. + if as_of is not None: + rounds_done = 0 + while ( + rounds_done < self._max_top_up_rounds + and len(filtered) < self._min_post_filter_results + ): + rounds_done += 1 + logger.info( + "Historical top-up round %d: have %d survivors, want >= %d", + rounds_done, len(filtered), self._min_post_filter_results, + ) + all_results, seen_canonical = self._search_round( + sub_queries, + question, + as_of, + max_results=self._top_up_results_per_query, + collected=all_results, + seen_canonical=seen_canonical, + # Allow many more candidates than the final cap because + # most will be dropped by the cutoff filter. + stop_cap=self._total_cap * 10, + ) + filtered = self._dedup_exclude_cutoff(all_results, as_of) + + if len(filtered) < self._min_post_filter_results: + logger.warning( + "Historical top-up exhausted: %d survivors after %d round(s) " + "(target was %d). Returning what we have.", + len(filtered), rounds_done, self._min_post_filter_results, + ) + + # 7. Compute search_stage_score (freshness measured from cutoff in + # historical mode, wall-clock in live mode) for r in filtered: + r.freshness_score = _compute_freshness( + r.published_date, reference_date=as_of + ) r.search_stage_score = _compute_search_stage_score( r.domain_score, r.freshness_score, r.rank ) - # 7. Sort and cap + # 8. 
Sort and cap filtered.sort(key=lambda r: r.search_stage_score, reverse=True) result = filtered[: self._total_cap] logger.info("Search stage returning %d results", len(result)) return result - def _execute_search(self, query: str) -> List[RawSearchResult]: + def _search_round( + self, + sub_queries: List[SubQuery], + question: ForecastQuestion, + as_of: Optional[datetime], + *, + max_results: int, + collected: list[SearchResult], + seen_canonical: set[str], + stop_cap: int, + ) -> tuple[list[SearchResult], set[str]]: + """Issue each sub-query and append converted SearchResults to + ``collected``, skipping any URL already in ``seen_canonical``. + Returns the updated list and seen-set. Stops early when the + collected list reaches ``stop_cap``. + """ + for sq in sub_queries: + query_text = self._apply_year_hint(sq.text, as_of) + raw_results = self._execute_search( + query_text, as_of_date=as_of, max_results=max_results + ) + for rank_offset, raw in enumerate(raw_results): + canonical = normalize_url(raw.url) + if canonical and canonical in seen_canonical: + continue + result = self._convert(raw, sq, question.id, rank_offset + 1, as_of) + collected.append(result) + if canonical: + seen_canonical.add(canonical) + if len(collected) >= stop_cap: + logger.info( + "Stopping search round at %d collected results (cap=%d)", + len(collected), stop_cap, + ) + break + return collected, seen_canonical + + def _dedup_exclude_cutoff( + self, results: list[SearchResult], as_of: Optional[datetime] + ) -> list[SearchResult]: + """Run dedup → hard exclusions → cutoff filter (historical mode only).""" + deduped = self._deduplicate(results) + filtered = self._apply_exclusions(deduped) + if as_of is not None: + filtered = self._apply_cutoff_filter(filtered, as_of) + return filtered + + @staticmethod + def _apply_year_hint(query: str, as_of: Optional[datetime]) -> str: + """In historical mode, append the cutoff year to the query so the + search backend's lexical match biases toward dated 
content. Empirically + Tavily ignores the ``end_date`` parameter, so query-text steering is + what actually drags results toward the right time period. No-op in + live mode.""" + if as_of is None: + return query + year = as_of.year + # Avoid double-hinting if the LLM already put the year in. + if str(year) in query: + return query + return f"{query} {year}" + + def _execute_search( + self, + query: str, + as_of_date: Optional[datetime] = None, + max_results: Optional[int] = None, + ) -> List[RawSearchResult]: # TODO: multilingual support + # end_date is passed for Protocol compliance but TavilyBackend + # explicitly does not forward it (see backends/tavily_backend.py). + end_date_str = as_of_date.strftime("%Y-%m-%d") if as_of_date else None + effective_max = max_results if max_results is not None else self._results_per_query if self._cache: - cached = self._cache.get(self._backend_name, query) + cached = self._cache.get(self._backend_name, query, as_of_date=as_of_date) if cached is not None: logger.debug("Cache hit for query: %s", query) return cached - results = self._backend.search(query, max_results=self._results_per_query) + results = self._backend.search( + query, max_results=effective_max, end_date=end_date_str + ) if self._cache: - self._cache.put(self._backend_name, query, results) + self._cache.put(self._backend_name, query, results, as_of_date=as_of_date) return results def _convert( - self, raw: RawSearchResult, sub_query: SubQuery, question_id: str, rank: int + self, + raw: RawSearchResult, + sub_query: SubQuery, + question_id: str, + rank: int, + as_of_date: Optional[datetime] = None, ) -> SearchResult: domain = extract_domain(raw.url) canonical = normalize_url(raw.url) tier_num, domain_score, source_tier = resolve_tier(domain) published = _parse_published_date(raw.published_date) - freshness = _compute_freshness(published) + freshness = _compute_freshness(published, reference_date=as_of_date) + published_date_source = "backend" if published is not 
None else None return SearchResult( id=uuid.uuid4().hex, @@ -177,6 +360,8 @@ def _convert( # kept in results so downstream analysis can measure contamination effects. contains_aggregator_forecast=is_aggregator_domain(domain), search_stage_score=0.0, # computed after dedup + published_date_source=published_date_source, + cutoff_applied=as_of_date, ) def _deduplicate(self, results: List[SearchResult]) -> List[SearchResult]: @@ -219,6 +404,60 @@ def _apply_exclusions(self, results: List[SearchResult]) -> List[SearchResult]: kept.append(r) return kept + def _apply_cutoff_filter( + self, results: List[SearchResult], as_of: datetime + ) -> List[SearchResult]: + """Historical-replay mode: keep only results that demonstrably existed + before ``as_of``. Drop post-cutoff and undatable results. + + Wayback-snapshot dashboards already have ``published_date`` set to the + capture timestamp by ``dashboard_lookup``; this filter is therefore + idempotent on them. + """ + dropped_post_cutoff = 0 + dropped_undatable = 0 + recovered = 0 + kept: list[SearchResult] = [] + for r in results: + if r.published_date is not None: + if r.published_date > as_of: + dropped_post_cutoff += 1 + logger.debug( + "Cutoff filter: dropping post-cutoff %s (pub=%s, cutoff=%s)", + r.url, r.published_date.isoformat(), as_of.isoformat(), + ) + continue + kept.append(r) + continue + + # Undated — try the recovery chain + recovered_date, source = recover_published_date(r.url) + if recovered_date is None: + dropped_undatable += 1 + logger.debug( + "Cutoff filter: dropping %s (no_date_available)", r.url + ) + continue + if recovered_date > as_of: + dropped_post_cutoff += 1 + logger.debug( + "Cutoff filter: recovered date %s > cutoff for %s", + recovered_date.isoformat(), r.url, + ) + continue + r.published_date = recovered_date + r.published_date_source = source + recovered += 1 + kept.append(r) + + logger.info( + "Cutoff filter: kept=%d, recovered=%d, dropped_post_cutoff=%d, " + "dropped_undatable=%d 
(cutoff=%s)", + len(kept), recovered, dropped_post_cutoff, dropped_undatable, + as_of.isoformat(), + ) + return kept + def run_search_stage( question: ForecastQuestion, @@ -226,6 +465,7 @@ def run_search_stage( llm_client: LLMClient, cache: Optional[SearchCache] = None, backend_name: str = "tavily", + historical_roleplay: bool = False, ) -> List[SearchResult]: """Convenience function to run the search stage pipeline.""" pipeline = SearchStagePipeline( @@ -233,5 +473,6 @@ def run_search_stage( llm_client=llm_client, cache=cache, backend_name=backend_name, + historical_roleplay=historical_roleplay, ) return pipeline.run(question) diff --git a/bioscancast/stages/search_stage/query_decomposition.py b/bioscancast/stages/search_stage/query_decomposition.py index 7456cf0..a1b14d9 100644 --- a/bioscancast/stages/search_stage/query_decomposition.py +++ b/bioscancast/stages/search_stage/query_decomposition.py @@ -88,16 +88,29 @@ def classify_question_type(question: ForecastQuestion, llm_client: LLMClient) -> return "unknown" -def _build_decomposition_prompt(question: ForecastQuestion, question_type: str) -> str: +def _build_decomposition_prompt( + question: ForecastQuestion, + question_type: str, + historical_roleplay: bool = False, +) -> str: axes = AXES_BY_TYPE.get(question_type, list(VALID_AXES)) + task_lines = [ + "Decompose this biosecurity forecast question into 5-8 search-engine-optimised " + "sub-queries. Each sub-query should be 2-8 words and target a specific information " + "axis. Return strict JSON: {\"sub_queries\": [{\"text\": \"...\", \"axis\": \"...\"}]}. " + "No prose." + ] + if historical_roleplay and question.as_of_date is not None: + task_lines.append( + "IMPORTANT: Generate sub-queries as if today were " + f"{question.as_of_date.date().isoformat()}. Do not assume knowledge " + "of events, named entities, or facts that you only learned about " + "after that date. Phrase queries in terms a forecaster on that " + "date would have used." 
+ ) return json.dumps( { - "task": ( - "Decompose this biosecurity forecast question into 5-8 search-engine-optimised " - "sub-queries. Each sub-query should be 2-8 words and target a specific information " - "axis. Return strict JSON: {\"sub_queries\": [{\"text\": \"...\", \"axis\": \"...\"}]}. " - "No prose." - ), + "task": " ".join(task_lines), "question": question.text, "pathogen": question.pathogen, "region": question.region, @@ -158,14 +171,25 @@ def _fallback_subqueries(question: ForecastQuestion) -> List[SubQuery]: def decompose_question( - question: ForecastQuestion, llm_client: LLMClient + question: ForecastQuestion, + llm_client: LLMClient, + *, + historical_roleplay: bool = False, ) -> List[SubQuery]: """Decompose a forecast question into sub-queries using an LLM. Falls back to simple keyword-based sub-queries if the LLM fails. + + ``historical_roleplay`` is an opt-in benchmark-only flag. When True AND + ``question.as_of_date`` is set, the prompt is extended with an instruction + asking the LLM to query as if today were the cutoff date. This is gated + behind its own flag because prompt-level roleplay can have hard-to-predict + effects on query quality. """ question_type = classify_question_type(question, llm_client) - prompt = _build_decomposition_prompt(question, question_type) + prompt = _build_decomposition_prompt( + question, question_type, historical_roleplay=historical_roleplay + ) try: result = llm_client.generate_json(prompt) diff --git a/bioscancast/stages/search_stage/wayback.py b/bioscancast/stages/search_stage/wayback.py new file mode 100644 index 0000000..8cb3f60 --- /dev/null +++ b/bioscancast/stages/search_stage/wayback.py @@ -0,0 +1,185 @@ +"""Tiny Wayback Machine CDX client used by historical-replay mode. 
+ +Three callers need this: + +* the search-stage dashboard rewrite (closest snapshot at-or-before cutoff), +* the extraction-stage fetcher (same lookup, then fetch the snapshot bytes), +* and the date-recovery chain in the search stage (first-seen capture). + +The implementation deliberately uses stdlib ``urllib`` rather than ``curl_cffi``: +the Wayback CDX endpoint returns a small JSON document, is not protected by +Cloudflare/JA3 filters, and adding ``curl_cffi`` as a dependency of the +search stage would broaden the dependency surface unnecessarily. + +All functions return ``None`` on any network/parse error and log at WARNING. +Callers must tolerate ``None`` — Wayback is best-effort. +""" + +from __future__ import annotations + +import json +import logging +import socket +import time +import urllib.error +import urllib.parse +import urllib.request +from datetime import datetime, timezone +from typing import Optional, Tuple + +logger = logging.getLogger(__name__) + +CDX_ENDPOINT = "https://web.archive.org/cdx/search/cdx" +SNAPSHOT_TEMPLATE = "https://web.archive.org/web/{timestamp}/{url}" +# The ``id_`` modifier on the timestamp returns the raw original bytes +# (no Wayback toolbar / no rewriting). Use this when we want to feed the +# response straight into our HTML/PDF parsers. +RAW_SNAPSHOT_TEMPLATE = "https://web.archive.org/web/{timestamp}id_/{url}" + +_REQUEST_TIMEOUT_SECONDS = 30 + +# Retry schedule for the Wayback CDX endpoint. The endpoint frequently +# returns HTTP 503 or times out under load — empirically a single +# benchmark run can trigger dozens of consecutive failures. We trade +# wall-clock time for completeness because historical-replay benchmarks +# are not latency-sensitive. Pre-attempt delays in seconds; the i-th +# entry is the wait BEFORE attempt i (so the first attempt fires +# immediately). Override at module level in tests to keep them fast. +RETRY_BACKOFF_SECONDS: tuple[float, ...]
= (0, 10, 30, 90, 240) + +# Recoverable HTTP status codes that warrant a retry. +_RECOVERABLE_STATUSES = {429, 500, 502, 503, 504} + + +def _sleep(seconds: float) -> None: + """Indirection so tests can monkeypatch a no-op sleep.""" + if seconds > 0: + time.sleep(seconds) + + +def _cdx_query(params: dict) -> Optional[list]: + """Plain GET against the CDX endpoint. Returns the JSON rows minus the + header row (``[]`` when empty), or None on any failure. Retries on HTTP + 429/500/502/503/504 and timeouts according to ``RETRY_BACKOFF_SECONDS``.""" + query = urllib.parse.urlencode(params) + full_url = f"{CDX_ENDPOINT}?{query}" + + body: Optional[str] = None + for attempt, pre_delay in enumerate(RETRY_BACKOFF_SECONDS, start=1): + if pre_delay: + logger.info( + "Wayback CDX backoff %.0fs before attempt %d/%d", + pre_delay, attempt, len(RETRY_BACKOFF_SECONDS), + ) + _sleep(pre_delay) + try: + req = urllib.request.Request( + full_url, headers={"User-Agent": "BioScanCast/replay (+wayback-cdx)"} + ) + with urllib.request.urlopen(req, timeout=_REQUEST_TIMEOUT_SECONDS) as resp: + body = resp.read().decode("utf-8", errors="replace") + break # success + except urllib.error.HTTPError as exc: + if exc.code in _RECOVERABLE_STATUSES and attempt < len(RETRY_BACKOFF_SECONDS): + logger.info( + "Wayback CDX HTTP %d on attempt %d; retrying", + exc.code, attempt, + ) + continue + logger.warning( + "Wayback CDX gave up after %d attempt(s): HTTP %d for %s", + attempt, exc.code, full_url, + ) + return None + except (socket.timeout, TimeoutError, urllib.error.URLError) as exc: + # urllib.error.URLError wraps socket.timeout on read timeouts in + # some Python builds; check both.
+ is_timeout = isinstance(exc, (socket.timeout, TimeoutError)) or ( + isinstance(exc, urllib.error.URLError) + and isinstance(getattr(exc, "reason", None), (socket.timeout, TimeoutError)) + ) + if is_timeout and attempt < len(RETRY_BACKOFF_SECONDS): + logger.info( + "Wayback CDX timeout on attempt %d; retrying", attempt, + ) + continue + logger.warning( + "Wayback CDX gave up after %d attempt(s): %s for %s", + attempt, exc, full_url, + ) + return None + except Exception as exc: + logger.warning( + "Wayback CDX non-recoverable error: %s for %s", exc, full_url + ) + return None + + if body is None: + return None + + if not body.strip(): + return [] + try: + data = json.loads(body) + except json.JSONDecodeError: + logger.warning("Wayback CDX returned non-JSON body for %s", full_url) + return None + + # First row is the header. Drop it. + if isinstance(data, list) and data and isinstance(data[0], list): + return data[1:] + return [] + + +def _parse_cdx_timestamp(ts: str) -> Optional[datetime]: + """CDX timestamps are YYYYMMDDhhmmss in UTC.""" + try: + return datetime.strptime(ts, "%Y%m%d%H%M%S").replace(tzinfo=timezone.utc) + except ValueError: + return None + + +def closest_snapshot_before( + url: str, as_of: datetime +) -> Optional[Tuple[datetime, str]]: + """Return (snapshot_datetime, raw_snapshot_url) for the latest Wayback + capture of ``url`` whose timestamp is ``<= as_of``. Returns None when no + suitable snapshot exists or the lookup fails. + + The returned URL uses the ``id_`` modifier so callers get unwrapped + original content (no Wayback chrome). 
+ """ + to_param = as_of.astimezone(timezone.utc).strftime("%Y%m%d%H%M%S") + rows = _cdx_query( + { + "url": url, + "to": to_param, + "limit": "-1", # most recent matching row + "output": "json", + "filter": "statuscode:200", + } + ) + if not rows: + return None + timestamp = rows[0][1] # column 1 is the capture timestamp + parsed = _parse_cdx_timestamp(timestamp) + if parsed is None: + return None + snapshot_url = RAW_SNAPSHOT_TEMPLATE.format(timestamp=timestamp, url=url) + return parsed, snapshot_url + + +def first_seen(url: str) -> Optional[datetime]: + """Return the earliest Wayback capture timestamp for ``url``, or None.""" + rows = _cdx_query( + { + "url": url, + "limit": "1", + "output": "json", + "filter": "statuscode:200", + "sort": "ascending", + } + ) + if not rows: + return None + return _parse_cdx_timestamp(rows[0][1]) diff --git a/bioscancast/tests/test_contamination_metrics.py b/bioscancast/tests/test_contamination_metrics.py new file mode 100644 index 0000000..c48829f --- /dev/null +++ b/bioscancast/tests/test_contamination_metrics.py @@ -0,0 +1,132 @@ +from datetime import datetime, timezone + +from bioscancast.filtering.models import ForecastQuestion, SearchResult +from bioscancast.stages.eval_stage.contamination import ( + BaselineForecast, + ContaminationCounts, + filter_caught_contamination_rate, + retrieval_free_baseline_forecast, +) + + +def _result(pub: datetime | None) -> SearchResult: + return SearchResult( + id="x", + question_id="Q", + query_id="q1", + engine="fake", + url="https://example.com/x", + canonical_url="https://example.com/x", + domain="example.com", + title="T", + snippet="S", + rank=1, + retrieved_at=datetime.now(timezone.utc), + published_date=pub, + ) + + +class TestFilterCaughtContaminationRate: + def test_clean_run_is_zero(self): + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + results = [ + _result(datetime(2024, 1, 1, tzinfo=timezone.utc)), + _result(datetime(2024, 5, 31, tzinfo=timezone.utc)), + ] + counts = 
filter_caught_contamination_rate(results, cutoff) + assert counts.post_cutoff_in_final == 0 + assert counts.filter_caught_rate == 0.0 + assert counts.pre_cutoff_in_final == 2 + + def test_some_leak_through(self): + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + results = [ + _result(datetime(2024, 5, 1, tzinfo=timezone.utc)), + _result(datetime(2024, 8, 1, tzinfo=timezone.utc)), # post + _result(datetime(2024, 9, 1, tzinfo=timezone.utc)), # post + _result(None), # undated + ] + counts = filter_caught_contamination_rate(results, cutoff) + assert counts.post_cutoff_in_final == 2 + assert counts.undated_in_final == 1 + assert counts.pre_cutoff_in_final == 1 + assert counts.filter_caught_rate == 0.5 + + def test_empty_list(self): + counts = filter_caught_contamination_rate( + [], datetime(2024, 1, 1, tzinfo=timezone.utc) + ) + assert counts.filter_caught_rate == 0.0 + assert counts.total == 0 + + +class TestRetrievalFreeBaselineForecast: + def test_well_formed_response(self): + question = ForecastQuestion( + id="Q1", + text="Will X happen?", + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + as_of_date=datetime(2024, 6, 1, tzinfo=timezone.utc), + ) + + class GoodLLM: + def generate_json(self, prompt): + return {"probabilities": [0.7, 0.3], "rationale": "guess"} + + out = retrieval_free_baseline_forecast( + question, options=["yes", "no"], llm_client=GoodLLM() + ) + assert isinstance(out, BaselineForecast) + assert abs(sum(out.probabilities) - 1.0) < 1e-9 + assert out.probabilities[0] > out.probabilities[1] + assert out.rationale == "guess" + + def test_renormalises_unnormalised_probabilities(self): + question = ForecastQuestion( + id="Q1", + text="?", + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + ) + + class UnnormLLM: + def generate_json(self, prompt): + return {"probabilities": [2.0, 6.0], "rationale": ""} + + out = retrieval_free_baseline_forecast( + question, options=["a", "b"], llm_client=UnnormLLM() + ) + assert abs(sum(out.probabilities) 
- 1.0) < 1e-9 + assert abs(out.probabilities[0] - 0.25) < 1e-9 + + def test_malformed_response_uniform_fallback(self): + question = ForecastQuestion( + id="Q1", + text="?", + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + ) + + class BadLLM: + def generate_json(self, prompt): + return {"probabilities": "not a list"} + + out = retrieval_free_baseline_forecast( + question, options=["a", "b", "c"], llm_client=BadLLM() + ) + assert out.probabilities == [1 / 3, 1 / 3, 1 / 3] + assert "fallback" in (out.rationale or "") + + def test_llm_exception_uniform_fallback(self): + question = ForecastQuestion( + id="Q1", + text="?", + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + ) + + class ExplodingLLM: + def generate_json(self, prompt): + raise RuntimeError("oops") + + out = retrieval_free_baseline_forecast( + question, options=["a", "b"], llm_client=ExplodingLLM() + ) + assert out.probabilities == [0.5, 0.5] diff --git a/bioscancast/tests/test_cutoff_filtering.py b/bioscancast/tests/test_cutoff_filtering.py new file mode 100644 index 0000000..5588fa9 --- /dev/null +++ b/bioscancast/tests/test_cutoff_filtering.py @@ -0,0 +1,226 @@ +"""End-to-end-ish tests for historical-replay mode in SearchStagePipeline. + +Uses the same FakeLLMClient/FakeSearchBackend pattern as test_search_pipeline.py +to keep the test layer hand-rolled and dependency-free. 
+""" + +from datetime import datetime, timezone +from typing import List +from unittest.mock import patch + +from bioscancast.filtering.models import ForecastQuestion +from bioscancast.stages.search_stage.backends.base import RawSearchResult +from bioscancast.stages.search_stage.pipeline import ( + SearchStagePipeline, + _parse_published_date, +) + + +class TestParsePublishedDate: + def test_iso_with_offset(self): + assert _parse_published_date("2025-02-17T13:00:00+00:00") == datetime( + 2025, 2, 17, 13, 0, 0, tzinfo=timezone.utc + ) + + def test_iso_date_only(self): + assert _parse_published_date("2025-02-17") == datetime( + 2025, 2, 17, tzinfo=timezone.utc + ) + + def test_rfc2822_with_zone(self): + # The format Tavily's news topic actually returns. + result = _parse_published_date("Tue, 19 May 2026 13:00:00 GMT") + assert result is not None + assert result.year == 2026 + assert result.month == 5 + assert result.day == 19 + assert result.tzinfo is not None + + def test_rfc2822_with_offset(self): + result = _parse_published_date("Tue, 19 May 2026 13:00:00 +0000") + assert result is not None + assert result.day == 19 + + def test_none_and_empty(self): + assert _parse_published_date(None) is None + assert _parse_published_date("") is None + + def test_garbage_returns_none(self): + assert _parse_published_date("not a date") is None + + +class _FakeLLM: + def __init__(self): + self._calls = 0 + + def generate_json(self, prompt: str) -> dict: + self._calls += 1 + if self._calls == 1: + return {"question_type": "outbreak_count"} + return { + "sub_queries": [ + {"text": "H5N1 cases 2024", "axis": "latest_data"}, + {"text": "avian flu trend", "axis": "trend"}, + {"text": "bird flu policy", "axis": "policy"}, + ] + } + + +class _FakeBackend: + def __init__(self, results: List[RawSearchResult]): + self._results = results + self.end_dates_seen: list = [] + + def search(self, query, max_results=10, end_date=None): + self.end_dates_seen.append(end_date) + return 
list(self._results) + + +def _make_question(as_of: datetime | None) -> ForecastQuestion: + return ForecastQuestion( + id="Q-CUT", + text="Will H5N1 exceed 100 cases by end of 2024?", + created_at=datetime(2024, 6, 1, tzinfo=timezone.utc), + pathogen="nopathogen", # avoid dashboard injection in this test + as_of_date=as_of, + ) + + +def test_post_cutoff_results_are_dropped(): + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend_results = [ + RawSearchResult( + url="https://news.example.com/a", + title="Pre-cutoff", + snippet="", + rank=1, + published_date="2024-05-15", + ), + RawSearchResult( + url="https://news.example.com/b", + title="Post-cutoff", + snippet="", + rank=2, + published_date="2024-08-15", + ), + ] + pipeline = SearchStagePipeline( + search_backend=_FakeBackend(backend_results), + llm_client=_FakeLLM(), + backend_name="fake", + ) + results = pipeline.run(_make_question(cutoff)) + urls = {r.url for r in results} + assert "https://news.example.com/a" in urls + assert "https://news.example.com/b" not in urls + + +def test_undated_dropped_when_recovery_fails(): + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend_results = [ + RawSearchResult( + url="https://news.example.com/no-date", + title="Undated", + snippet="", + rank=1, + published_date=None, + ), + ] + pipeline = SearchStagePipeline( + search_backend=_FakeBackend(backend_results), + llm_client=_FakeLLM(), + backend_name="fake", + ) + with patch( + "bioscancast.stages.search_stage.pipeline.recover_published_date" + ) as mock_rec: + mock_rec.return_value = (None, None) + results = pipeline.run(_make_question(cutoff)) + assert not any(r.url == "https://news.example.com/no-date" for r in results) + + +def test_undated_kept_when_recovery_succeeds_before_cutoff(): + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend_results = [ + RawSearchResult( + url="https://news.example.com/2024/03/15/article", + title="Slug-dated", + snippet="", + rank=1, + published_date=None, + ), + 
] + pipeline = SearchStagePipeline( + search_backend=_FakeBackend(backend_results), + llm_client=_FakeLLM(), + backend_name="fake", + ) + results = pipeline.run(_make_question(cutoff)) + matching = [r for r in results if "2024/03/15" in r.url] + assert len(matching) == 1 + assert matching[0].published_date_source == "url_slug" + + +def test_end_date_forwarded_to_backend(): + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend = _FakeBackend( + [ + RawSearchResult( + url="https://news.example.com/x", + title="X", + snippet="", + rank=1, + published_date="2024-01-01", + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, llm_client=_FakeLLM(), backend_name="fake" + ) + pipeline.run(_make_question(cutoff)) + assert all(d == "2024-06-01" for d in backend.end_dates_seen if d is not None) + assert any(d == "2024-06-01" for d in backend.end_dates_seen) + + +def test_live_mode_unchanged(): + backend = _FakeBackend( + [ + RawSearchResult( + url="https://news.example.com/x", + title="X", + snippet="", + rank=1, + published_date=None, + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, llm_client=_FakeLLM(), backend_name="fake" + ) + results = pipeline.run(_make_question(as_of=None)) + # Undated result MUST be kept in live mode (the cutoff filter is off) + assert any(r.url == "https://news.example.com/x" for r in results) + # And backend received end_date=None + assert all(d is None for d in backend.end_dates_seen) + + +def test_cutoff_applied_persisted_on_results(): + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend = _FakeBackend( + [ + RawSearchResult( + url="https://news.example.com/x", + title="X", + snippet="", + rank=1, + published_date="2024-01-01", + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, llm_client=_FakeLLM(), backend_name="fake" + ) + results = pipeline.run(_make_question(cutoff)) + assert results + for r in results: + assert r.cutoff_applied == cutoff diff --git 
a/bioscancast/tests/test_date_recovery.py b/bioscancast/tests/test_date_recovery.py new file mode 100644 index 0000000..54c142a --- /dev/null +++ b/bioscancast/tests/test_date_recovery.py @@ -0,0 +1,95 @@ +from datetime import datetime, timezone +from unittest.mock import patch + +from bioscancast.stages.search_stage.date_recovery import ( + date_from_last_modified, + date_from_url_slug, + recover_published_date, +) + + +class TestDateFromUrlSlug: + def test_year_month_day_path(self): + assert date_from_url_slug( + "https://example.com/2024/03/15/some-article" + ) == datetime(2024, 3, 15, tzinfo=timezone.utc) + + def test_year_month_only(self): + assert date_from_url_slug( + "https://example.com/news/2023/06/topic" + ) == datetime(2023, 6, 1, tzinfo=timezone.utc) + + def test_iso_dashed(self): + assert date_from_url_slug( + "https://example.com/p/2025-01-20/title" + ) == datetime(2025, 1, 20, tzinfo=timezone.utc) + + def test_no_match(self): + assert date_from_url_slug("https://example.com/about/contact") is None + + def test_implausible_year_rejected(self): + # 1872 looks like a year but is too old to be a sensible publication + assert date_from_url_slug("https://example.com/1872/03/15") is None + + +class TestDateFromLastModified: + def test_no_fetcher_returns_none(self): + # Off by default: requires explicit injection + assert date_from_last_modified("https://example.com/a") is None + + def test_rfc7231_format(self): + header = "Wed, 21 Oct 2015 07:28:00 GMT" + result = date_from_last_modified( + "https://example.com/a", head_fetcher=lambda _: header + ) + assert result == datetime(2015, 10, 21, 7, 28, 0, tzinfo=timezone.utc) + + def test_fetcher_returning_none(self): + assert date_from_last_modified( + "https://example.com/a", head_fetcher=lambda _: None + ) is None + + def test_fetcher_raises_returns_none(self): + def boom(_): + raise RuntimeError("network down") + + assert date_from_last_modified("https://example.com/a", head_fetcher=boom) is None + + def 
test_unparseable_header(self): + assert date_from_last_modified( + "https://example.com/a", head_fetcher=lambda _: "not a date" + ) is None + + +class TestRecoverPublishedDate: + def test_url_slug_wins(self): + dt, source = recover_published_date( + "https://example.com/2024/03/15/x", use_wayback=False + ) + assert source == "url_slug" + assert dt == datetime(2024, 3, 15, tzinfo=timezone.utc) + + def test_wayback_used_when_no_slug(self): + with patch( + "bioscancast.stages.search_stage.date_recovery._wayback_first_seen" + ) as mock_wb: + mock_wb.return_value = datetime(2020, 1, 1, tzinfo=timezone.utc) + dt, source = recover_published_date("https://example.com/about") + assert source == "wayback_first_seen" + assert dt == datetime(2020, 1, 1, tzinfo=timezone.utc) + + def test_all_strategies_fail(self): + with patch( + "bioscancast.stages.search_stage.date_recovery._wayback_first_seen" + ) as mock_wb: + mock_wb.return_value = None + dt, source = recover_published_date("https://example.com/about") + assert dt is None + assert source is None + + def test_wayback_disabled(self): + dt, source = recover_published_date( + "https://example.com/about", use_wayback=False + ) + assert dt is None + assert source is None diff --git a/bioscancast/tests/test_extraction_pipeline.py b/bioscancast/tests/test_extraction_pipeline.py index f7b99d0..4aba2d3 100644 --- a/bioscancast/tests/test_extraction_pipeline.py +++ b/bioscancast/tests/test_extraction_pipeline.py @@ -74,7 +74,7 @@ def _make_fetch_result( def _fake_fetch_factory(mapping: dict[str, FetchResult]): """Return a fetch function that looks up results by URL.""" - def fake_fetch(url, *, config=None): + def fake_fetch(url, *, config=None, as_of_date=None): if url in mapping: return mapping[url] return _make_fetch_result(url, b"", error="not_found") diff --git a/bioscancast/tests/test_historical_topup.py b/bioscancast/tests/test_historical_topup.py new file mode 100644 index 0000000..c8ee0d8 --- /dev/null +++ 
b/bioscancast/tests/test_historical_topup.py @@ -0,0 +1,253 @@ +"""Tests for the historical-mode year-hint and top-up behavior in +SearchStagePipeline. +""" + +from datetime import datetime, timezone +from typing import List + +from bioscancast.filtering.models import ForecastQuestion +from bioscancast.stages.search_stage.backends.base import RawSearchResult +from bioscancast.stages.search_stage.pipeline import SearchStagePipeline + + +class _FakeLLM: + def __init__(self): + self._calls = 0 + + def generate_json(self, prompt: str) -> dict: + self._calls += 1 + if self._calls == 1: + return {"question_type": "outbreak_count"} + return { + "sub_queries": [ + {"text": "H5N1 cases", "axis": "latest_data"}, + {"text": "avian flu trend", "axis": "trend"}, + {"text": "bird flu policy", "axis": "policy"}, + ] + } + + +class _RecordingBackend: + """Records every (query, max_results) it was called with and returns a + canned mapping. Same URL can appear across calls — pipeline dedup + handles that.""" + + def __init__(self, results_by_query: dict[tuple[str, int], List[RawSearchResult]] | None = None): + self.calls: list[tuple[str, int]] = [] + self._results = results_by_query or {} + # Fallback results for any query not explicitly mapped. + self._fallback: List[RawSearchResult] = [] + + def set_fallback(self, results: List[RawSearchResult]) -> None: + self._fallback = results + + def search(self, query, max_results=10, end_date=None): + self.calls.append((query, max_results)) + # Prefer exact match on (query, max_results); else any match on + # query; else fallback. 
+ if (query, max_results) in self._results: + return list(self._results[(query, max_results)]) + for (q, _), res in self._results.items(): + if q == query: + return list(res) + return list(self._fallback) + + +def _question(as_of: datetime | None) -> ForecastQuestion: + return ForecastQuestion( + id="Q-TU", + text="H5N1 outbreak in 2024", + created_at=datetime(2024, 6, 1, tzinfo=timezone.utc), + pathogen="nopathogen", # skip dashboard lookup + as_of_date=as_of, + ) + + +def test_year_hint_appended_in_historical_mode(): + backend = _RecordingBackend() + backend.set_fallback( + [ + RawSearchResult( + url="https://news.example.com/a", + title="A", + snippet="", + rank=1, + published_date="2024-01-01", + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, llm_client=_FakeLLM(), backend_name="fake" + ) + pipeline.run(_question(datetime(2024, 6, 1, tzinfo=timezone.utc))) + # Every query the backend saw should end in " 2024" + assert backend.calls, "backend should have been called" + queries = [q for q, _ in backend.calls] + assert all(q.endswith(" 2024") for q in queries), queries + + +def test_year_hint_skipped_in_live_mode(): + backend = _RecordingBackend() + backend.set_fallback( + [ + RawSearchResult( + url="https://news.example.com/a", + title="A", + snippet="", + rank=1, + published_date=None, + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, llm_client=_FakeLLM(), backend_name="fake" + ) + pipeline.run(_question(as_of=None)) + queries = [q for q, _ in backend.calls] + assert not any(q.endswith(" 2024") for q in queries), queries + + +def test_year_hint_not_double_appended_if_already_present(): + # If the LLM's sub-query already mentions the year, don't append it again. 
+ class _LLM: + def __init__(self): + self._n = 0 + + def generate_json(self, prompt: str) -> dict: + self._n += 1 + if self._n == 1: + return {"question_type": "outbreak_count"} + return { + "sub_queries": [ + {"text": "H5N1 cases 2024", "axis": "latest_data"}, + ] + } + + backend = _RecordingBackend() + backend.set_fallback( + [ + RawSearchResult( + url="https://news.example.com/a", + title="A", + snippet="", + rank=1, + published_date="2024-01-01", + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, llm_client=_LLM(), backend_name="fake" + ) + pipeline.run(_question(datetime(2024, 6, 1, tzinfo=timezone.utc))) + queries = [q for q, _ in backend.calls] + # Should NOT be "H5N1 cases 2024 2024" + assert all(q.count("2024") == 1 for q in queries), queries + + +def test_top_up_fires_when_survivors_below_threshold(): + """First round returns mostly post-cutoff (so few survive); top-up + round with bigger max_results returns extras that include pre-cutoff + items. The backend should be called once per sub-query per round.""" + as_of = datetime(2024, 6, 1, tzinfo=timezone.utc) + + # Build the per-query result sets. 
+ round1 = [ + RawSearchResult( + url=f"https://news.example.com/post-{i}", + title="post", + snippet="", + rank=i, + published_date="2024-09-01", # post-cutoff + ) + for i in range(3) + ] + round2 = round1 + [ + RawSearchResult( + url=f"https://news.example.com/pre-{i}", + title="pre", + snippet="", + rank=i + 10, + published_date="2024-01-01", # pre-cutoff + ) + for i in range(20) + ] + + backend = _RecordingBackend() + # Three sub-queries from _FakeLLM each get year-hinted: + for query in ("H5N1 cases 2024", "avian flu trend 2024", "bird flu policy 2024"): + backend._results[(query, 10)] = round1 + backend._results[(query, 50)] = round2 + + pipeline = SearchStagePipeline( + search_backend=backend, + llm_client=_FakeLLM(), + backend_name="fake", + min_post_filter_results=10, + top_up_results_per_query=50, + max_top_up_rounds=1, + ) + results = pipeline.run(_question(as_of)) + + # Each sub-query should have been called twice: once with max=10 and once + # with max=50. + max_results_seen = [m for _, m in backend.calls] + assert 10 in max_results_seen + assert 50 in max_results_seen + # After top-up we should have well over the threshold of pre-cutoff results. + assert len(results) >= 10 + + +def test_top_up_skipped_when_survivors_meet_threshold(): + """If the initial round already returns enough survivors, no top-up.""" + as_of = datetime(2024, 6, 1, tzinfo=timezone.utc) + plenty = [ + RawSearchResult( + url=f"https://news.example.com/x-{i}", + title=f"x{i}", + snippet="", + rank=i, + published_date="2024-01-01", + ) + for i in range(20) + ] + backend = _RecordingBackend() + backend.set_fallback(plenty) + pipeline = SearchStagePipeline( + search_backend=backend, + llm_client=_FakeLLM(), + backend_name="fake", + min_post_filter_results=10, + top_up_results_per_query=50, + max_top_up_rounds=1, + ) + pipeline.run(_question(as_of)) + # Only the initial round (max=10) should have fired. 
+ max_results_seen = {m for _, m in backend.calls} + assert max_results_seen == {10} + + +def test_top_up_skipped_in_live_mode(): + """Live mode never tops up, even when result count is low.""" + backend = _RecordingBackend() + backend.set_fallback( + [ + RawSearchResult( + url="https://news.example.com/only", + title="only", + snippet="", + rank=1, + published_date=None, + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, + llm_client=_FakeLLM(), + backend_name="fake", + min_post_filter_results=10, + top_up_results_per_query=50, + max_top_up_rounds=2, + ) + pipeline.run(_question(as_of=None)) + max_results_seen = {m for _, m in backend.calls} + assert max_results_seen == {10} diff --git a/bioscancast/tests/test_search_filtering_integration.py b/bioscancast/tests/test_search_filtering_integration.py index 5dd7405..dfa6bb9 100644 --- a/bioscancast/tests/test_search_filtering_integration.py +++ b/bioscancast/tests/test_search_filtering_integration.py @@ -20,7 +20,9 @@ class RealisticFakeSearchBackend: """Returns results with titles/snippets that overlap with the H5N1 question, simulating what a real search engine would return.""" - def search(self, query: str, max_results: int = 10) -> List[RawSearchResult]: + def search( + self, query: str, max_results: int = 10, end_date=None + ) -> List[RawSearchResult]: return [ RawSearchResult( url="https://www.cdc.gov/bird-flu/situation-summary/", diff --git a/bioscancast/tests/test_search_pipeline.py b/bioscancast/tests/test_search_pipeline.py index bc53723..bbec0a7 100644 --- a/bioscancast/tests/test_search_pipeline.py +++ b/bioscancast/tests/test_search_pipeline.py @@ -78,7 +78,9 @@ def _default_results() -> List[RawSearchResult]: ), ] - def search(self, query: str, max_results: int = 10) -> List[RawSearchResult]: + def search( + self, query: str, max_results: int = 10, end_date=None + ) -> List[RawSearchResult]: self.queries_received.append(query) return self._results diff --git 
a/bioscancast/tests/test_wayback_fetch.py b/bioscancast/tests/test_wayback_fetch.py new file mode 100644 index 0000000..32b220a --- /dev/null +++ b/bioscancast/tests/test_wayback_fetch.py @@ -0,0 +1,115 @@ +"""Offline tests for the Wayback rewrite in the extraction fetcher. + +The patching reaches into ``bioscancast.extraction.fetcher.closest_snapshot_before`` +(the symbol imported at module load) and ``curl_requests.get``, never touching +the network. There is also a ``@pytest.mark.live`` smoke test for hitting +Wayback for real. +""" + +from datetime import datetime, timezone +from unittest.mock import patch + +import pytest + +from bioscancast.extraction.fetcher import fetch + + +class _FakeResponse: + def __init__(self, *, body: bytes, url: str, status: int = 200): + self.status_code = status + self.headers = {"content-type": "text/html"} + self.url = url + self._body = body + + def iter_content(self): + yield self._body + + def close(self): + pass + + +def _patch_curl(body: bytes, url: str = "https://example.com/page"): + return patch( + "bioscancast.extraction.fetcher.curl_requests.get", + return_value=_FakeResponse(body=body, url=url), + ) + + +def _patch_snapshot(value): + return patch( + "bioscancast.extraction.fetcher.closest_snapshot_before", + return_value=value, + ) + + +class TestWaybackRewrite: + def test_live_mode_no_wayback_call(self): + with _patch_curl(b"live") as mock_get, _patch_snapshot(None) as mock_snap: + result = fetch("https://example.com/page", as_of_date=None) + assert result.fetch_strategy == "live" + assert result.snapshot_timestamp is None + mock_snap.assert_not_called() + mock_get.assert_called_once() + + def test_wayback_success(self): + snap_dt = datetime(2024, 3, 1, 12, 0, 0, tzinfo=timezone.utc) + snap_url = "https://web.archive.org/web/20240301120000id_/https://example.com/page" + with _patch_snapshot((snap_dt, snap_url)), _patch_curl(b"snapshot"): + result = fetch( + "https://example.com/page", + as_of_date=datetime(2024, 6, 
1, tzinfo=timezone.utc), + ) + assert result.fetch_strategy == "wayback" + assert result.snapshot_timestamp == snap_dt + assert result.url == "https://example.com/page" # original, not archive.org + assert result.content_bytes == b"snapshot" + + def test_no_snapshot_falls_back_to_live(self): + with _patch_snapshot(None), _patch_curl(b"live"): + result = fetch( + "https://example.com/page", + as_of_date=datetime(2024, 6, 1, tzinfo=timezone.utc), + ) + assert result.fetch_strategy == "wayback_fallback_to_live" + assert result.snapshot_timestamp is None + assert result.url == "https://example.com/page" + + def test_wayback_fetch_error_falls_back_to_live(self): + snap_dt = datetime(2024, 3, 1, tzinfo=timezone.utc) + snap_url = "https://web.archive.org/web/20240301120000id_/https://example.com/page" + # First call (to Wayback) errors; second call (live) succeeds. + responses = [ + ConnectionError("wayback down"), + _FakeResponse(body=b"live", url="https://example.com/page"), + ] + + def side_effect(*args, **kwargs): + r = responses.pop(0) + if isinstance(r, Exception): + raise r + return r + + with _patch_snapshot((snap_dt, snap_url)), patch( + "bioscancast.extraction.fetcher.curl_requests.get", side_effect=side_effect + ): + result = fetch( + "https://example.com/page", + as_of_date=datetime(2024, 6, 1, tzinfo=timezone.utc), + ) + assert result.fetch_strategy == "wayback_fallback_to_live" + assert result.content_bytes == b"live" + + +@pytest.mark.live +def test_live_wayback_lookup(): + """Smoke-test the real Wayback CDX endpoint. 
Skipped by default.""" + from bioscancast.stages.search_stage.wayback import closest_snapshot_before + + result = closest_snapshot_before( + "https://www.cdc.gov/", + datetime(2023, 1, 1, tzinfo=timezone.utc), + ) + assert result is not None + snap_dt, snap_url = result + assert snap_dt < datetime(2023, 1, 2, tzinfo=timezone.utc) + assert "web.archive.org/web/" in snap_url diff --git a/bioscancast/tests/test_wayback_retry.py b/bioscancast/tests/test_wayback_retry.py new file mode 100644 index 0000000..93e42cd --- /dev/null +++ b/bioscancast/tests/test_wayback_retry.py @@ -0,0 +1,101 @@ +"""Retry/backoff behavior for the Wayback CDX client.""" + +from __future__ import annotations + +import socket +import urllib.error +from io import BytesIO +from unittest.mock import patch + +from bioscancast.stages.search_stage import wayback + + +def _http_error(code: int) -> urllib.error.HTTPError: + return urllib.error.HTTPError( + url="https://web.archive.org/cdx/search/cdx", + code=code, + msg=str(code), + hdrs=None, # type: ignore[arg-type] + fp=None, + ) + + +def _ok_response(payload: bytes): + """A minimal stand-in for the context manager returned by urlopen.""" + + class _CM: + def __enter__(self): + return BytesIO(payload) + + def __exit__(self, *a): + return False + + return _CM() + + +class TestCdxRetry: + def _no_sleep(self): + return patch.object(wayback, "_sleep", lambda _s: None) + + def _short_schedule(self): + # 3 attempts max so tests are predictable; all delays are no-ops. + return patch.object(wayback, "RETRY_BACKOFF_SECONDS", (0, 0, 0)) + + def test_retries_then_succeeds_on_503(self): + # First two calls 503, third returns valid JSON. 
+ seq = [ + _http_error(503), + _http_error(503), + _ok_response(b'[["urlkey","timestamp","original"],["a","20240101120000","b"]]'), + ] + with self._short_schedule(), self._no_sleep(), patch( + "bioscancast.stages.search_stage.wayback.urllib.request.urlopen", + side_effect=seq, + ): + data = wayback._cdx_query({"url": "https://example.com/"}) + assert data is not None + assert data == [["a", "20240101120000", "b"]] + + def test_gives_up_after_max_attempts_503(self): + with self._short_schedule(), self._no_sleep(), patch( + "bioscancast.stages.search_stage.wayback.urllib.request.urlopen", + side_effect=[_http_error(503)] * 3, + ): + data = wayback._cdx_query({"url": "https://example.com/"}) + assert data is None + + def test_retries_on_timeout(self): + seq = [ + socket.timeout("read timeout"), + _ok_response(b'[["urlkey","timestamp","original"]]'), + ] + with self._short_schedule(), self._no_sleep(), patch( + "bioscancast.stages.search_stage.wayback.urllib.request.urlopen", + side_effect=seq, + ): + data = wayback._cdx_query({"url": "https://example.com/"}) + # Header-only payload → empty rows list + assert data == [] + + def test_non_recoverable_status_does_not_retry(self): + # 404 should fail immediately with no retries + with self._short_schedule(), self._no_sleep(), patch( + "bioscancast.stages.search_stage.wayback.urllib.request.urlopen", + side_effect=[_http_error(404)], + ) as mock_open: + data = wayback._cdx_query({"url": "https://example.com/"}) + assert data is None + assert mock_open.call_count == 1 + + def test_recoverable_statuses_cover_5xx_and_429(self): + # 429 is rate-limit; should be treated as recoverable. 
+ seq = [ + _http_error(429), + _ok_response(b'[["header"]]'), + ] + with self._short_schedule(), self._no_sleep(), patch( + "bioscancast.stages.search_stage.wayback.urllib.request.urlopen", + side_effect=seq, + ): + data = wayback._cdx_query({"url": "https://example.com/"}) + assert data == [] diff --git a/scripts/probe_tavily_topic.py b/scripts/probe_tavily_topic.py new file mode 100644 index 0000000..57ebe49 --- /dev/null +++ b/scripts/probe_tavily_topic.py @@ -0,0 +1,100 @@ +"""Compare Tavily ``topic="news"`` vs default (``general``) for one historical +query. No Wayback, no LLM — just two Tavily calls and a date distribution +so we can see whether dropping ``topic="news"`` would actually surface more +pre-cutoff results. + +Run: + python scripts/probe_tavily_topic.py +""" + +from __future__ import annotations + +import os +import sys +from collections import Counter +from datetime import date + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +try: + from dotenv import load_dotenv + load_dotenv() +except ImportError: + pass + +from tavily import TavilyClient + + +QUERY = "H5N1 human cases United States 2025" +CUTOFF = date(2025, 2, 17) +MAX_RESULTS = 20 + + +def _bucket(dstr: str | None) -> str: + if not dstr: + return "no_date" + try: + d = date.fromisoformat(dstr[:10]) + except ValueError: + return "unparseable" + if d <= CUTOFF: + return "pre_cutoff" + if d.year == 2025: + return "post_cutoff_2025" + if d.year == 2026: + return "post_cutoff_2026" + return f"post_cutoff_{d.year}" + + +def _run(client: TavilyClient, *, with_news_topic: bool) -> None: + kwargs: dict = { + "query": QUERY, + "max_results": MAX_RESULTS, + "include_answer": False, + } + if with_news_topic: + kwargs["topic"] = "news" + + print(f"\n{'=' * 72}") + print(f"Tavily topic = {'news' if with_news_topic else 'general (default)'}") + print(f"Query = {QUERY!r}") + print(f"Cutoff = {CUTOFF.isoformat()}") + print("=" * 72) + + resp = client.search(**kwargs) + results = 
resp.get("results", []) + print(f"Returned {len(results)} results\n") + + buckets: Counter = Counter() + for i, r in enumerate(results, 1): + d = r.get("published_date") + bucket = _bucket(d) + buckets[bucket] += 1 + date_label = (d or "—")[:10] if d else "—" + url = r.get("url", "") + title = (r.get("title") or "")[:80] + marker = " " if bucket == "pre_cutoff" else " " + if bucket == "pre_cutoff": + marker = "✓ " + print(f"{marker}{i:2d}. {date_label:<10} [{bucket:<20}] {title}") + print(f" {url}") + + print(f"\nBucket totals:") + for bucket, n in sorted(buckets.items(), key=lambda kv: -kv[1]): + print(f" {bucket:<20} {n}") + pre = buckets.get("pre_cutoff", 0) + if results: + print(f"\nPre-cutoff hit rate: {pre}/{len(results)} = {pre / len(results):.0%}") + + +def main() -> None: + api_key = os.environ.get("TAVILY_API_KEY") + if not api_key: + sys.exit("TAVILY_API_KEY missing") + client = TavilyClient(api_key=api_key) + _run(client, with_news_topic=True) + _run(client, with_news_topic=False) + + +if __name__ == "__main__": + main() diff --git a/scripts/test_historical_replay.py b/scripts/test_historical_replay.py new file mode 100644 index 0000000..f095bc3 --- /dev/null +++ b/scripts/test_historical_replay.py @@ -0,0 +1,198 @@ +"""Manual smoke test for historical-replay mode. + +Runs the search stage against `q1` (resolved H5N1 US, Feb 28 2025 deadline) +with as_of_date = question.created_date. Prints a digest of what the cutoff +machinery did. Does NOT push through filtering/extraction (issue #13 would +make that uninformative without an LLM and a relaxed threshold). 
+ +What this validates on the feat/as-of-date-replay branch: + - Tavily backend receives end_date matching the cutoff + - All returned SearchResult.published_date <= as_of_date + - Dashboard URLs are Wayback-rewritten (or suppressed if no snapshot) + - SearchResult.cutoff_applied is populated + - The date-recovery chain fires for undated results + +What this also gathers (for issue #5 — Tavily date reliability): + - share of Tavily results that came with a published_date + - share that needed the recovery chain (and which strategy won) + - share that were dropped for no-date-available + +Run: + python scripts/test_historical_replay.py + +Requires TAVILY_API_KEY and OPENAI_API_KEY in environment (or .env). +""" + +from __future__ import annotations + +import logging +import os +import sys +from collections import Counter +from datetime import datetime, timezone + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +try: + from dotenv import load_dotenv + load_dotenv() +except ImportError: + pass + +from bioscancast.filtering.models import ForecastQuestion +from bioscancast.llm.client import OpenAIClient +from bioscancast.stages.search_stage.backends.tavily_backend import TavilyBackend +from bioscancast.stages.search_stage.pipeline import SearchStagePipeline + + +# q1 from bioscancast_questions.csv. created_date is Excel serial 45705 = +# 2025-02-17. Resolution deadline is Feb 28, 2025. Cutoff for the human +# forecaster is the creation date. +Q1_TEXT = ( + "How many confirmed human cases of H5N1 will be reported in the US " + "by February 28, 2025, according to the US dashboard?" +) +Q1_CREATED = datetime(2025, 2, 17, tzinfo=timezone.utc) + + +def _hr(title: str) -> None: + bar = "=" * 72 + print(f"\n{bar}\n{title}\n{bar}") + + +def main() -> None: + # Surface the cutoff-filter and dashboard suppression INFO logs. 
+ logging.basicConfig( + level=logging.INFO, + format="%(levelname)s %(name)s: %(message)s", + ) + # Hide curl_cffi/openai noise but keep our modules chatty. + for noisy in ("urllib3", "httpx", "openai", "tavily"): + logging.getLogger(noisy).setLevel(logging.WARNING) + + question = ForecastQuestion( + id="q1", + text=Q1_TEXT, + created_at=Q1_CREATED, + pathogen="h5n1", + region="United States", + target_date=datetime(2025, 2, 28, tzinfo=timezone.utc), + as_of_date=Q1_CREATED, # the cutoff + ) + + _hr("CONFIGURATION") + print(f"Question text : {question.text}") + print(f"Pathogen : {question.pathogen}") + print(f"Created at : {question.created_at.isoformat()}") + print(f"Target date : {question.target_date.isoformat()}") + print(f"AS-OF (cutoff) : {question.as_of_date.isoformat()}") + print(f"Historical mode : YES (as_of_date is set)") + + llm = OpenAIClient() + # Wrap the Tavily backend so we can observe whether end_date was forwarded. + base_backend = TavilyBackend() + end_dates_seen: list = [] + _orig_search = base_backend.search + + def wrapped_search(query: str, max_results: int = 10, end_date=None): + end_dates_seen.append(end_date) + return _orig_search(query, max_results=max_results, end_date=end_date) + + base_backend.search = wrapped_search # type: ignore[assignment] + + # NB: deliberately running without SearchCache so we hit Tavily fresh and + # the test isn't influenced by entries from a previous (different-cutoff) + # run. The cache key incorporates the cutoff so this is just paranoia. + pipeline = SearchStagePipeline( + search_backend=base_backend, + llm_client=llm, + cache=None, + backend_name="tavily", + # Leave historical_roleplay off — that's a separate opt-in. 
+ ) + + _hr("RUNNING SEARCH STAGE") + results = pipeline.run(question) + + _hr("BACKEND OBSERVATIONS") + print(f"Sub-queries issued to Tavily: {len(end_dates_seen)}") + distinct_end_dates = set(end_dates_seen) + print(f"end_date values forwarded : {distinct_end_dates}") + if distinct_end_dates == {question.as_of_date.strftime('%Y-%m-%d')}: + print(">> end_date correctly forwarded on every call.") + else: + print(">> WARNING: end_date forwarding looks wrong.") + + _hr("RESULT SUMMARY") + print(f"Total results returned: {len(results)}") + + dashboards = [r for r in results if r.engine == "dashboard"] + organic = [r for r in results if r.engine != "dashboard"] + print(f" Dashboards : {len(dashboards)}") + print(f" Organic : {len(organic)}") + + # Cutoff sanity check + leaks = [r for r in results if r.published_date and r.published_date > question.as_of_date] + print(f" Post-cutoff leaks: {len(leaks)}") + if leaks: + for r in leaks: + print(f" LEAK: {r.url} pub={r.published_date.isoformat()}") + else: + print(" >> No post-cutoff results in the final list.") + + # cutoff_applied audit + bad_cutoff = [r for r in results if r.cutoff_applied != question.as_of_date] + print(f" cutoff_applied mismatches: {len(bad_cutoff)}") + + _hr("DASHBOARDS (Wayback rewrite or suppression)") + if not dashboards: + print("No dashboards present — Wayback either had no snapshot or " + "they were suppressed. 
Check the INFO log lines above.") + for r in dashboards: + in_wayback = "web.archive.org/web/" in r.url + snap_date = r.published_date.isoformat() if r.published_date else "n/a" + print(f" [{('WAYBACK' if in_wayback else 'LIVE!')}] {r.url}") + print(f" snapshot_date={snap_date} source={r.published_date_source}") + + _hr("DATA FOR ISSUE #5 (Tavily published_date reliability)") + source_counter: Counter = Counter(r.published_date_source for r in organic) + print("Per-result published_date_source distribution (organic only):") + for src, n in source_counter.most_common(): + label = src if src is not None else "" + print(f" {label:25s} {n}") + n_backend = source_counter.get("backend", 0) + n_recovered = sum( + n for src, n in source_counter.items() + if src in {"url_slug", "last_modified", "wayback_first_seen"} + ) + n_unsourced = source_counter.get(None, 0) + if organic: + print( + f"\nTavily-supplied date rate: {n_backend}/{len(organic)} = " + f"{n_backend / len(organic):.0%}" + ) + print( + f"Recovery-chain saves : {n_recovered}/{len(organic)} = " + f"{n_recovered / len(organic):.0%}" + ) + print( + f"Unsourced (kept anyway) : {n_unsourced}/{len(organic)} = " + f"{n_unsourced / len(organic):.0%} " + "[expected ~0 in historical mode after the filter]" + ) + + _hr("TOP 15 RESULTS (sorted by search_stage_score)") + for i, r in enumerate(results[:15], 1): + pub = r.published_date.date().isoformat() if r.published_date else "—" + src = r.published_date_source or "—" + print( + f"{i:2d}. 
score={r.search_stage_score:.3f} " + f"tier={r.source_tier:<13s} pub={pub:<10s} src={src:<20s} " + f"{r.domain}" + ) + print(f" {r.title[:90]}") + print(f" {r.url}") + + +if __name__ == "__main__": + main() From 211f6df0b17833ce46b2241306797cb50312c431 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Wed, 20 May 2026 13:28:41 +0200 Subject: [PATCH 02/21] Forward start_date+end_date pair to Tavily for historical replay Tavily's news endpoint silently ignores `end_date` when passed alone but honors the start_date+end_date pair, returning 20/20 native pre-cutoff results across the resolved corpus (q1, q3, q7, q9). The pipeline now synthesizes `start_date = as_of_date - 365d` (configurable via `historical_lookback_days`) and forwards both bounds. The 0/20 pre-cutoff disaster on live testing of q1 is fully addressed by this change; the post-retrieval cutoff filter remains as defense-in-depth. The TavilyBackend drops a lone `end_date` with a warning rather than sending a request Tavily will misinterpret. Stale comments referring to "Tavily ignores end_date" are updated across pipeline.py and tavily_backend.py. 
Co-Authored-By: Claude Opus 4.7 --- .../stages/search_stage/backends/base.py | 10 ++- .../search_stage/backends/tavily_backend.py | 41 +++++---- bioscancast/stages/search_stage/pipeline.py | 52 +++++++++--- bioscancast/tests/test_cutoff_filtering.py | 72 +++++++++++++++- bioscancast/tests/test_historical_topup.py | 2 +- .../test_search_filtering_integration.py | 2 +- bioscancast/tests/test_search_pipeline.py | 2 +- bioscancast/tests/test_tavily_backend.py | 83 +++++++++++++++++++ 8 files changed, 229 insertions(+), 35 deletions(-) create mode 100644 bioscancast/tests/test_tavily_backend.py diff --git a/bioscancast/stages/search_stage/backends/base.py b/bioscancast/stages/search_stage/backends/base.py index 24c2e2d..7fdf32e 100644 --- a/bioscancast/stages/search_stage/backends/base.py +++ b/bioscancast/stages/search_stage/backends/base.py @@ -19,9 +19,12 @@ class RawSearchResult: class SearchBackend(Protocol): """Interface that all search backends must satisfy. - ``end_date`` is an optional YYYY-MM-DD upper bound used by historical- - replay mode. Backends that don't support it should accept and ignore it - (the post-retrieval cutoff filter in the pipeline will still apply). + ``start_date`` and ``end_date`` are optional YYYY-MM-DD bounds used by + historical-replay mode. Tavily's news endpoint requires the **pair** to be + set together (see ``tavily_backend.py``); passing ``end_date`` alone is + silently ignored. Backends that don't support either should accept and + ignore them — the post-retrieval cutoff filter in the pipeline will still + apply. """ def search( @@ -29,4 +32,5 @@ def search( query: str, max_results: int = 10, end_date: Optional[str] = None, + start_date: Optional[str] = None, ) -> List[RawSearchResult]: ... 
diff --git a/bioscancast/stages/search_stage/backends/tavily_backend.py b/bioscancast/stages/search_stage/backends/tavily_backend.py index 490c8cc..7bce08d 100644 --- a/bioscancast/stages/search_stage/backends/tavily_backend.py +++ b/bioscancast/stages/search_stage/backends/tavily_backend.py @@ -36,25 +36,38 @@ def search( query: str, max_results: int = 10, end_date: Optional[str] = None, + start_date: Optional[str] = None, ) -> List[RawSearchResult]: - # NOTE: end_date is accepted to satisfy the SearchBackend Protocol - # but is NOT forwarded to the Tavily SDK. Empirically (Feb 2025 test - # run on q1) the Tavily client returned articles dated well after - # the supplied end_date — the parameter is documented in the SDK - # signature but does not appear to actually filter on the result's - # published_date. The post-retrieval cutoff filter in the - # SearchStagePipeline is the real defense; do not be tempted to - # rely on this parameter without re-verifying Tavily's behavior. + # Date-window behavior (verified 2026-05-20, see + # ``specs/tavily-investigation-findings.md``): Tavily's news endpoint + # honors ``start_date`` + ``end_date`` only when **both** are passed + # together. Passing ``end_date`` alone is silently ignored and the + # results come back unfiltered. The pipeline is responsible for + # supplying a sensible ``start_date`` alongside any ``end_date``; if + # only ``end_date`` is passed here we drop it rather than send a + # request we know Tavily will misinterpret. The post-retrieval cutoff + # filter in ``SearchStagePipeline`` remains the authoritative defense. 
from tavily import TavilyClient # lazy import to avoid hard dep at import time client = TavilyClient(api_key=self._api_key) - try: - response = client.search( - query=query, - max_results=max_results, - topic="news", - include_answer=False, + kwargs: dict = { + "query": query, + "max_results": max_results, + "topic": "news", + "include_answer": False, + } + if start_date and end_date: + kwargs["start_date"] = start_date + kwargs["end_date"] = end_date + elif end_date and not start_date: + logger.warning( + "TavilyBackend received end_date=%s without start_date; " + "dropping (Tavily ignores end_date alone). Cutoff filter " + "will still apply post-retrieval.", + end_date, ) + try: + response = client.search(**kwargs) except Exception: logger.exception("Tavily search failed for query: %s", query) return [] diff --git a/bioscancast/stages/search_stage/pipeline.py b/bioscancast/stages/search_stage/pipeline.py index 2be52fa..8739227 100644 --- a/bioscancast/stages/search_stage/pipeline.py +++ b/bioscancast/stages/search_stage/pipeline.py @@ -8,7 +8,7 @@ import logging import uuid -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from email.utils import parsedate_to_datetime from typing import List, Optional @@ -32,6 +32,14 @@ # File extensions that indicate non-content resources _NON_CONTENT_EXTENSIONS: set[str] = {".zip", ".exe", ".msi", ".dmg", ".tar", ".gz", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".mp4", ".mp3"} +# Default lookback window for historical-replay mode. Tavily's news endpoint +# requires both start_date and end_date to be set together (passing end_date +# alone is silently ignored — see ``backends/tavily_backend.py``). We synthesize +# a start_date 12 months before the cutoff: empirically (2026-05-20) this gives +# 20/20 native pre-cutoff hit rate on the resolved corpus without leaking past +# the cutoff. Tune via ``historical_lookback_days`` on the pipeline. 
+_DEFAULT_HISTORICAL_LOOKBACK_DAYS = 365 + def _compute_freshness( published_date: Optional[datetime], @@ -132,6 +140,7 @@ def __init__( min_post_filter_results: int = 10, top_up_results_per_query: int = 50, max_top_up_rounds: int = 1, + historical_lookback_days: int = _DEFAULT_HISTORICAL_LOOKBACK_DAYS, ) -> None: self._backend = search_backend self._llm = llm_client @@ -145,6 +154,10 @@ def __init__( self._min_post_filter_results = min_post_filter_results self._top_up_results_per_query = top_up_results_per_query self._max_top_up_rounds = max_top_up_rounds + # In historical-replay mode the backend receives end_date=as_of_date + # and start_date=as_of_date-lookback. Tavily requires the pair; see + # the module-level note on ``_DEFAULT_HISTORICAL_LOOKBACK_DAYS``. + self._historical_lookback_days = historical_lookback_days def run(self, question: ForecastQuestion) -> List[SearchResult]: """Execute the full search stage pipeline.""" @@ -184,10 +197,11 @@ def run(self, question: ForecastQuestion) -> List[SearchResult]: # 6b. Top-up: in historical mode only, if we're below the survivor # threshold, run additional rounds with a larger results_per_query - # to fish for older content. Tavily's end_date doesn't actually - # filter (see tavily_backend.py), so the candidate pool is heavily - # biased toward post-cutoff content for recent topics — top-up is - # what makes historical mode actually return usable results. + # to fish for more in-window content. With the start_date+end_date + # pair now forwarded to Tavily (see backends/tavily_backend.py), + # the candidate pool is already date-filtered upstream; top-up + # mostly compensates for results dropped by deduplication and the + # blocked-domain filter. 
if as_of is not None: rounds_done = 0 while ( @@ -285,10 +299,12 @@ def _dedup_exclude_cutoff( @staticmethod def _apply_year_hint(query: str, as_of: Optional[datetime]) -> str: """In historical mode, append the cutoff year to the query so the - search backend's lexical match biases toward dated content. Empirically - Tavily ignores the ``end_date`` parameter, so query-text steering is - what actually drags results toward the right time period. No-op in - live mode.""" + search backend's lexical match biases toward dated content. The + start_date+end_date pair forwarded to Tavily already filters by + publication date, but the year hint reinforces topical relevance + within the window (Tavily's in-window ranking can still surface + irrelevant dated-correct results on cold or sparse queries). No-op + in live mode.""" if as_of is None: return query year = as_of.year @@ -304,9 +320,16 @@ def _execute_search( max_results: Optional[int] = None, ) -> List[RawSearchResult]: # TODO: multilingual support - # end_date is passed for Protocol compliance but TavilyBackend - # explicitly does not forward it (see backends/tavily_backend.py). - end_date_str = as_of_date.strftime("%Y-%m-%d") if as_of_date else None + # In historical-replay mode we pass BOTH start_date and end_date. + # Tavily silently ignores end_date when start_date is missing + # (verified 2026-05-20, specs/tavily-investigation-findings.md). 
+ end_date_str: Optional[str] = None + start_date_str: Optional[str] = None + if as_of_date is not None: + end_date_str = as_of_date.strftime("%Y-%m-%d") + start_date_str = ( + as_of_date - timedelta(days=self._historical_lookback_days) + ).strftime("%Y-%m-%d") effective_max = max_results if max_results is not None else self._results_per_query if self._cache: cached = self._cache.get(self._backend_name, query, as_of_date=as_of_date) @@ -315,7 +338,10 @@ def _execute_search( return cached results = self._backend.search( - query, max_results=effective_max, end_date=end_date_str + query, + max_results=effective_max, + end_date=end_date_str, + start_date=start_date_str, ) if self._cache: diff --git a/bioscancast/tests/test_cutoff_filtering.py b/bioscancast/tests/test_cutoff_filtering.py index 5588fa9..8f7b20d 100644 --- a/bioscancast/tests/test_cutoff_filtering.py +++ b/bioscancast/tests/test_cutoff_filtering.py @@ -70,9 +70,11 @@ class _FakeBackend: def __init__(self, results: List[RawSearchResult]): self._results = results self.end_dates_seen: list = [] + self.start_dates_seen: list = [] - def search(self, query, max_results=10, end_date=None): + def search(self, query, max_results=10, end_date=None, start_date=None): self.end_dates_seen.append(end_date) + self.start_dates_seen.append(start_date) return list(self._results) @@ -200,8 +202,74 @@ def test_live_mode_unchanged(): results = pipeline.run(_make_question(as_of=None)) # Undated result MUST be kept in live mode (the cutoff filter is off) assert any(r.url == "https://news.example.com/x" for r in results) - # And backend received end_date=None + # And backend received end_date=None AND start_date=None — Tavily ignores + # end_date when start_date is missing, so the pipeline must keep them + # both unset in live mode. 
assert all(d is None for d in backend.end_dates_seen) + assert all(d is None for d in backend.start_dates_seen) + + +def test_historical_mode_forwards_start_and_end_date_pair(): + """Tavily honors end_date only when start_date is also set. The pipeline + must synthesize start_date = as_of - historical_lookback_days and pass + both to the backend on every search call.""" + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend = _FakeBackend( + [ + RawSearchResult( + url="https://news.example.com/x", + title="X", + snippet="", + rank=1, + published_date="2024-01-01", + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, + llm_client=_FakeLLM(), + backend_name="fake", + historical_lookback_days=365, + ) + pipeline.run(_make_question(cutoff)) + # Every search call in historical mode must carry BOTH bounds. + paired = [ + (s, e) + for s, e in zip(backend.start_dates_seen, backend.end_dates_seen) + if s is not None or e is not None + ] + assert paired, "expected at least one date-bounded search in historical mode" + for start, end in paired: + assert start is not None and end is not None, ( + "Tavily ignores end_date alone — pipeline must pass the pair" + ) + assert end == "2024-06-01" + assert start == "2023-06-02" # 365 days before 2024-06-01 + + +def test_historical_lookback_days_is_configurable(): + """Override the default 365-day lookback via the pipeline constructor.""" + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend = _FakeBackend( + [ + RawSearchResult( + url="https://news.example.com/x", + title="X", + snippet="", + rank=1, + published_date="2024-01-01", + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, + llm_client=_FakeLLM(), + backend_name="fake", + historical_lookback_days=30, + ) + pipeline.run(_make_question(cutoff)) + starts = [s for s in backend.start_dates_seen if s is not None] + assert starts and all(s == "2024-05-02" for s in starts) # 30 days before def 
test_cutoff_applied_persisted_on_results(): diff --git a/bioscancast/tests/test_historical_topup.py b/bioscancast/tests/test_historical_topup.py index c8ee0d8..15fb484 100644 --- a/bioscancast/tests/test_historical_topup.py +++ b/bioscancast/tests/test_historical_topup.py @@ -41,7 +41,7 @@ def __init__(self, results_by_query: dict[tuple[str, int], List[RawSearchResult] def set_fallback(self, results: List[RawSearchResult]) -> None: self._fallback = results - def search(self, query, max_results=10, end_date=None): + def search(self, query, max_results=10, end_date=None, start_date=None): self.calls.append((query, max_results)) # Prefer exact match on (query, max_results); else any match on # query; else fallback. diff --git a/bioscancast/tests/test_search_filtering_integration.py b/bioscancast/tests/test_search_filtering_integration.py index dfa6bb9..28b8d1c 100644 --- a/bioscancast/tests/test_search_filtering_integration.py +++ b/bioscancast/tests/test_search_filtering_integration.py @@ -21,7 +21,7 @@ class RealisticFakeSearchBackend: simulating what a real search engine would return.""" def search( - self, query: str, max_results: int = 10, end_date=None + self, query: str, max_results: int = 10, end_date=None, start_date=None ) -> List[RawSearchResult]: return [ RawSearchResult( diff --git a/bioscancast/tests/test_search_pipeline.py b/bioscancast/tests/test_search_pipeline.py index bbec0a7..907d6a4 100644 --- a/bioscancast/tests/test_search_pipeline.py +++ b/bioscancast/tests/test_search_pipeline.py @@ -79,7 +79,7 @@ def _default_results() -> List[RawSearchResult]: ] def search( - self, query: str, max_results: int = 10, end_date=None + self, query: str, max_results: int = 10, end_date=None, start_date=None ) -> List[RawSearchResult]: self.queries_received.append(query) return self._results diff --git a/bioscancast/tests/test_tavily_backend.py b/bioscancast/tests/test_tavily_backend.py new file mode 100644 index 0000000..b80842e --- /dev/null +++ 
b/bioscancast/tests/test_tavily_backend.py @@ -0,0 +1,83 @@ +"""Unit tests for TavilyBackend's date-window forwarding. + +The Tavily news endpoint silently ignores ``end_date`` unless ``start_date`` +is also passed (verified 2026-05-20, see +``specs/tavily-investigation-findings.md``). The backend's job is to forward +the pair when both are present, drop ``end_date`` alone with a warning, +and call the SDK with no date params otherwise. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from bioscancast.stages.search_stage.backends.tavily_backend import TavilyBackend + + +class _FakeTavilyClient: + """Captures the kwargs of every ``search`` call so tests can assert on them.""" + + def __init__(self, *_args, **_kwargs): + self.calls: list[dict[str, Any]] = [] + _FakeTavilyClient.last_instance = self + + def search(self, **kwargs): + self.calls.append(kwargs) + return {"results": []} + + +@pytest.fixture +def fake_tavily(monkeypatch): + """Patch tavily.TavilyClient so no network call is made.""" + import tavily + + monkeypatch.setattr(tavily, "TavilyClient", _FakeTavilyClient) + yield _FakeTavilyClient + + +def test_forwards_start_and_end_date_pair(fake_tavily): + backend = TavilyBackend(api_key="test-key") + backend.search( + "H5N1 cases", max_results=5, + start_date="2024-01-01", end_date="2025-02-17", + ) + call = fake_tavily.last_instance.calls[-1] + assert call["start_date"] == "2024-01-01" + assert call["end_date"] == "2025-02-17" + assert call["topic"] == "news" + assert call["max_results"] == 5 + + +def test_drops_end_date_when_start_date_missing(fake_tavily, caplog): + """Tavily ignores end_date alone — sending it would mislead anyone reading + the request log. 
The backend logs a warning and omits both.""" + backend = TavilyBackend(api_key="test-key") + with caplog.at_level("WARNING"): + backend.search("Mpox cases", end_date="2025-02-17") + call = fake_tavily.last_instance.calls[-1] + assert "end_date" not in call + assert "start_date" not in call + assert any("end_date" in rec.message and "start_date" in rec.message + for rec in caplog.records), ( + "expected a warning when end_date is passed without start_date" + ) + + +def test_no_date_params_in_live_mode(fake_tavily): + backend = TavilyBackend(api_key="test-key") + backend.search("H5N1 cases") + call = fake_tavily.last_instance.calls[-1] + assert "start_date" not in call + assert "end_date" not in call + assert call["topic"] == "news" + + +def test_start_date_without_end_date_is_also_dropped(fake_tavily): + """The pair must be complete; lone start_date is also ignored upstream.""" + backend = TavilyBackend(api_key="test-key") + backend.search("H5N1 cases", start_date="2024-01-01") + call = fake_tavily.last_instance.calls[-1] + assert "start_date" not in call + assert "end_date" not in call From 6b45372dfb3ef7202356c3bf363c642287871c44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Wed, 20 May 2026 14:24:53 +0200 Subject: [PATCH 03/21] Throttle Wayback CDX and gate first-seen for low-value domains The Wayback CDX endpoint rate-limits at ~60 req/min server-side. The existing reactive RETRY_BACKOFF_SECONDS = (0, 10, 30, 90, 240) ladder only fires after the server has already started returning 429s, burning ~6 min per failure. Historical-replay benchmarks routinely hit dozens of these failures, producing ~30 min wall-clock on q1. This commit adds two complementary measures: 1. Proactive throttle in wayback.py: a module-level _throttle() gate sleeps before every urlopen to maintain a 2.0 s minimum interval (~30 req/min, half the server cap). Overridable via env var BIOSCANCAST_WAYBACK_MIN_INTERVAL_SECONDS. 
The retry ladder is unchanged and still handles genuine 503s / read timeouts. 2. Selective-recovery gate in pipeline._apply_cutoff_filter: skip the Wayback first-seen leg of recover_published_date() for aggregator domains (metaculus, manifold, kalshi, ...) and source_tier=="unknown". The URL-slug regex and Last-Modified strategies still run for gated results. New wayback_skipped counter in the cutoff-filter log line. Live q1 smoke test: ~30 min -> 49 s (~37x). Test suite: 252 -> 261 passed (9 new tests covering the throttle gate, env override, retry interaction, gate decisions, and end-to-end recovery routing). Co-Authored-By: Claude Opus 4.7 --- bioscancast/stages/search_stage/pipeline.py | 37 +++++- bioscancast/stages/search_stage/wayback.py | 41 +++++++ bioscancast/tests/test_cutoff_filtering.py | 118 +++++++++++++++++++- bioscancast/tests/test_wayback_retry.py | 49 ++++++++ 4 files changed, 240 insertions(+), 5 deletions(-) diff --git a/bioscancast/stages/search_stage/pipeline.py b/bioscancast/stages/search_stage/pipeline.py index 8739227..3012858 100644 --- a/bioscancast/stages/search_stage/pipeline.py +++ b/bioscancast/stages/search_stage/pipeline.py @@ -41,6 +41,26 @@ _DEFAULT_HISTORICAL_LOOKBACK_DAYS = 365 +def _should_use_wayback_for_recovery(r: SearchResult) -> bool: + """Selective gate for the Wayback first-seen leg of the date-recovery chain. + + Wayback CDX is rate-limited (~60 req/min server-side) and even with + proactive throttling each call costs us a few seconds. For undated + results that would be dropped on quality grounds anyway — aggregators + and unknown-tier domains — there is no recall benefit to paying that + cost. The URL-slug regex and Last-Modified strategies still run; only + the Wayback leg is gated. 
+ """ + domain = extract_domain(r.url) + if is_aggregator_domain(domain): + logger.debug("Date recovery: skipping Wayback for aggregator %s", domain) + return False + if (r.source_tier or "").lower() == "unknown": + logger.debug("Date recovery: skipping Wayback for unknown-tier %s", domain) + return False + return True + + def _compute_freshness( published_date: Optional[datetime], *, @@ -443,6 +463,7 @@ def _apply_cutoff_filter( dropped_post_cutoff = 0 dropped_undatable = 0 recovered = 0 + wayback_skipped = 0 kept: list[SearchResult] = [] for r in results: if r.published_date is not None: @@ -456,8 +477,16 @@ def _apply_cutoff_filter( kept.append(r) continue - # Undated — try the recovery chain - recovered_date, source = recover_published_date(r.url) + # Undated — try the recovery chain. Skip the Wayback first-seen + # leg for aggregator domains and unknown-tier sources: those + # results would be dropped on quality grounds anyway, and the + # CDX call (even with throttling) costs us several seconds each. 
+ use_wayback = _should_use_wayback_for_recovery(r) + if not use_wayback: + wayback_skipped += 1 + recovered_date, source = recover_published_date( + r.url, use_wayback=use_wayback + ) if recovered_date is None: dropped_undatable += 1 logger.debug( @@ -478,9 +507,9 @@ def _apply_cutoff_filter( logger.info( "Cutoff filter: kept=%d, recovered=%d, dropped_post_cutoff=%d, " - "dropped_undatable=%d (cutoff=%s)", + "dropped_undatable=%d, wayback_skipped=%d (cutoff=%s)", len(kept), recovered, dropped_post_cutoff, dropped_undatable, - as_of.isoformat(), + wayback_skipped, as_of.isoformat(), ) return kept diff --git a/bioscancast/stages/search_stage/wayback.py b/bioscancast/stages/search_stage/wayback.py index 8cb3f60..b17bc88 100644 --- a/bioscancast/stages/search_stage/wayback.py +++ b/bioscancast/stages/search_stage/wayback.py @@ -19,7 +19,9 @@ import json import logging +import os import socket +import threading import time import urllib.error import urllib.parse @@ -50,6 +52,32 @@ # Recoverable HTTP status codes that warrant a retry. _RECOVERABLE_STATUSES = {429, 500, 502, 503, 504} +# Minimum interval between successive outbound CDX calls. Internet Archive +# rate-limits CDX at ~60 req/min server-side; the widely-used edgi-govdata +# Python client paces at ~0.8 req/s (1.25 s) by default. We sit at 2.0 s +# (30 req/min) — comfortably under the server cap with headroom for bursts, +# but ~2x throughput vs the initial conservative 4 s setting once we +# confirmed the throttle eliminates 429s in practice. Override via env var +# ``BIOSCANCAST_WAYBACK_MIN_INTERVAL_SECONDS`` for ad-hoc tuning. 
+_DEFAULT_MIN_INTERVAL_SECONDS = 2.0 +_MIN_INTERVAL_ENV_VAR = "BIOSCANCAST_WAYBACK_MIN_INTERVAL_SECONDS" +_throttle_lock = threading.Lock() +_last_call_monotonic: float = 0.0 + + +def _min_interval_seconds() -> float: + raw = os.environ.get(_MIN_INTERVAL_ENV_VAR) + if raw is None: + return _DEFAULT_MIN_INTERVAL_SECONDS + try: + return float(raw) + except ValueError: + logger.warning( + "Invalid %s=%r; using default %.1fs", + _MIN_INTERVAL_ENV_VAR, raw, _DEFAULT_MIN_INTERVAL_SECONDS, + ) + return _DEFAULT_MIN_INTERVAL_SECONDS + def _sleep(seconds: float) -> None: """Indirection so tests can monkeypatch a no-op sleep.""" @@ -57,6 +85,18 @@ def _sleep(seconds: float) -> None: time.sleep(seconds) +def _throttle() -> None: + """Block until the configured min interval since the last CDX call has elapsed.""" + global _last_call_monotonic + min_interval = _min_interval_seconds() + with _throttle_lock: + elapsed = time.monotonic() - _last_call_monotonic + wait = min_interval - elapsed + if wait > 0: + _sleep(wait) + _last_call_monotonic = time.monotonic() + + def _cdx_query(params: dict) -> Optional[list]: """POST-free GET against the CDX endpoint. Returns the parsed JSON list, or None on any failure. 
Retries on HTTP 503/429/5xx and read timeouts @@ -72,6 +112,7 @@ def _cdx_query(params: dict) -> Optional[list]: pre_delay, attempt, len(RETRY_BACKOFF_SECONDS), ) _sleep(pre_delay) + _throttle() try: req = urllib.request.Request( full_url, headers={"User-Agent": "BioScanCast/replay (+wayback-cdx)"} diff --git a/bioscancast/tests/test_cutoff_filtering.py b/bioscancast/tests/test_cutoff_filtering.py index 8f7b20d..fd30137 100644 --- a/bioscancast/tests/test_cutoff_filtering.py +++ b/bioscancast/tests/test_cutoff_filtering.py @@ -8,11 +8,12 @@ from typing import List from unittest.mock import patch -from bioscancast.filtering.models import ForecastQuestion +from bioscancast.filtering.models import ForecastQuestion, SearchResult from bioscancast.stages.search_stage.backends.base import RawSearchResult from bioscancast.stages.search_stage.pipeline import ( SearchStagePipeline, _parse_published_date, + _should_use_wayback_for_recovery, ) @@ -272,6 +273,121 @@ def test_historical_lookback_days_is_configurable(): assert starts and all(s == "2024-05-02" for s in starts) # 30 days before +def _make_search_result(url: str, source_tier: str = "trusted_media") -> SearchResult: + return SearchResult( + id="r1", + question_id="q1", + query_id="sq1", + engine="fake", + url=url, + canonical_url=None, + domain="", + title="t", + snippet="", + rank=1, + retrieved_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + source_tier=source_tier, + ) + + +class TestSelectiveRecoveryGate: + """The Wayback-leg gate on the date-recovery chain.""" + + def test_official_tier_uses_wayback(self): + r = _make_search_result( + "https://www.cdc.gov/bird-flu/situation-summary/", source_tier="official" + ) + assert _should_use_wayback_for_recovery(r) is True + + def test_academic_tier_uses_wayback(self): + r = _make_search_result( + "https://www.nature.com/articles/xyz", source_tier="academic" + ) + assert _should_use_wayback_for_recovery(r) is True + + def test_unknown_tier_skips_wayback(self): + r = 
_make_search_result( + "https://obscure-site.example/article", source_tier="unknown" + ) + assert _should_use_wayback_for_recovery(r) is False + + def test_aggregator_domain_skips_wayback(self): + # metaculus.com is in AGGREGATOR_DOMAINS regardless of tier label. + r = _make_search_result( + "https://www.metaculus.com/questions/12345/", + source_tier="trusted_media", + ) + assert _should_use_wayback_for_recovery(r) is False + + +def test_aggregator_undated_recovery_skips_wayback(): + """End-to-end: an undated aggregator result with no slug date routes to + recover_published_date with use_wayback=False, so the Wayback leg never + fires for it.""" + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend_results = [ + RawSearchResult( + url="https://www.metaculus.com/questions/abc", # known aggregator + title="Aggregator forecast", + snippet="", + rank=1, + published_date=None, + ), + ] + pipeline = SearchStagePipeline( + search_backend=_FakeBackend(backend_results), + llm_client=_FakeLLM(), + backend_name="fake", + ) + with patch( + "bioscancast.stages.search_stage.pipeline.recover_published_date", + return_value=(None, None), + ) as mock_rec: + pipeline.run(_make_question(cutoff)) + # The recovery function was called, but with use_wayback=False. + assert mock_rec.called + # At least one of the calls was for the aggregator URL with use_wayback=False. 
+ aggregator_calls = [ + c for c in mock_rec.call_args_list + if c.args and "metaculus.com" in c.args[0] + ] + assert aggregator_calls + for call in aggregator_calls: + assert call.kwargs.get("use_wayback") is False + + +def test_official_undated_recovery_still_tries_wayback(): + """A tier-1 official domain with no slug date should still hit the + Wayback leg of recovery (i.e., use_wayback=True).""" + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend_results = [ + RawSearchResult( + url="https://www.cdc.gov/some/article", # tier 1 official + title="CDC article", + snippet="", + rank=1, + published_date=None, + ), + ] + pipeline = SearchStagePipeline( + search_backend=_FakeBackend(backend_results), + llm_client=_FakeLLM(), + backend_name="fake", + ) + with patch( + "bioscancast.stages.search_stage.pipeline.recover_published_date", + return_value=(None, None), + ) as mock_rec: + pipeline.run(_make_question(cutoff)) + cdc_calls = [ + c for c in mock_rec.call_args_list + if c.args and "cdc.gov" in c.args[0] + ] + assert cdc_calls + for call in cdc_calls: + assert call.kwargs.get("use_wayback") is True + + def test_cutoff_applied_persisted_on_results(): cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) backend = _FakeBackend( diff --git a/bioscancast/tests/test_wayback_retry.py b/bioscancast/tests/test_wayback_retry.py index 93e42cd..bb8b8b5 100644 --- a/bioscancast/tests/test_wayback_retry.py +++ b/bioscancast/tests/test_wayback_retry.py @@ -99,3 +99,52 @@ def test_recoverable_statuses_cover_5xx_and_429(self): ): data = wayback._cdx_query({"url": "https://example.com/"}) assert data == [] + + +class TestCdxThrottle: + """Proactive min-interval pacing in front of every urlopen.""" + + def test_throttle_paces_successive_calls(self): + sleep_calls: list[float] = [] + ok = b'[["urlkey","timestamp","original"],["a","20240101120000","b"]]' + with patch.object(wayback, "_last_call_monotonic", 0.0), patch.object( + wayback, "_min_interval_seconds", lambda: 5.0 + ), 
patch.object(wayback, "_sleep", lambda s: sleep_calls.append(s)), patch.object( + wayback, "RETRY_BACKOFF_SECONDS", (0,) + ), patch( + "bioscancast.stages.search_stage.wayback.urllib.request.urlopen", + side_effect=[_ok_response(ok), _ok_response(ok)], + ): + wayback._cdx_query({"url": "https://example.com/a"}) + wayback._cdx_query({"url": "https://example.com/b"}) + positive_waits = [s for s in sleep_calls if s > 0] + assert len(positive_waits) == 1 + assert 4.0 < positive_waits[0] <= 5.0 + + def test_throttle_fires_before_each_retry(self): + # Throttle paces before every urlopen — including retried ones — so a + # 503 → OK sequence yields two _throttle() calls, the second of which + # sleeps because the first _throttle() call just bumped _last_call_monotonic. + sleep_calls: list[float] = [] + ok = b'[["urlkey","timestamp","original"]]' + with patch.object(wayback, "_last_call_monotonic", 0.0), patch.object( + wayback, "_min_interval_seconds", lambda: 3.0 + ), patch.object(wayback, "_sleep", lambda s: sleep_calls.append(s)), patch.object( + wayback, "RETRY_BACKOFF_SECONDS", (0, 0, 0) + ), patch( + "bioscancast.stages.search_stage.wayback.urllib.request.urlopen", + side_effect=[_http_error(503), _ok_response(ok)], + ): + data = wayback._cdx_query({"url": "https://example.com/"}) + assert data == [] + positive_waits = [s for s in sleep_calls if s > 0] + assert len(positive_waits) == 1 + assert 2.0 < positive_waits[0] <= 3.0 + + def test_min_interval_env_override(self, monkeypatch): + monkeypatch.setenv("BIOSCANCAST_WAYBACK_MIN_INTERVAL_SECONDS", "1.5") + assert wayback._min_interval_seconds() == 1.5 + monkeypatch.setenv("BIOSCANCAST_WAYBACK_MIN_INTERVAL_SECONDS", "not-a-number") + assert wayback._min_interval_seconds() == wayback._DEFAULT_MIN_INTERVAL_SECONDS + monkeypatch.delenv("BIOSCANCAST_WAYBACK_MIN_INTERVAL_SECONDS", raising=False) + assert wayback._min_interval_seconds() == wayback._DEFAULT_MIN_INTERVAL_SECONDS From 508425a7570d3fec6b94c50ff4de39136cb3f8d1 Mon Sep 17
00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Wed, 20 May 2026 14:47:27 +0200 Subject: [PATCH 04/21] Extend Tavily probe and add offline analyzer probe_tavily_topic.py grows from a one-query news/general comparison into a corpus iterator with config caching, synthetic-backdated stress queries, and per-knob result dumping. analyze_tavily_probe.py is new: it reads the cached probe payloads and recomputes hit-rate tables without re-paying the Tavily quota. Both scripts default to writing/reading specs/probe-results/ (gitignored by convention; create on first run). They were the workhorses behind the start_date+end_date investigation that produced commit 211f6df. Co-Authored-By: Claude Opus 4.7 --- scripts/analyze_tavily_probe.py | 241 ++++++++++++++++++++++++++++++ scripts/probe_tavily_topic.py | 255 +++++++++++++++++++++++++------- 2 files changed, 441 insertions(+), 55 deletions(-) create mode 100644 scripts/analyze_tavily_probe.py diff --git a/scripts/analyze_tavily_probe.py b/scripts/analyze_tavily_probe.py new file mode 100644 index 0000000..7d75764 --- /dev/null +++ b/scripts/analyze_tavily_probe.py @@ -0,0 +1,241 @@ +"""Consolidate Tavily probe-results JSON dumps into a hit-rate table. + +Reads every ``specs/probe-results/*.json`` produced by ``probe_tavily_topic.py``, +applies the production cutoff filter + URL-slug date recovery, and prints a +markdown table suitable for pasting into the findings doc. + +Also computes a "hybrid" row per question_id: union of news + general results +under matching knobs, deduped by URL. + +No network calls. Safe to re-run any time. 
+""" + +from __future__ import annotations + +import json +import os +import sys +from collections import defaultdict +from datetime import date, datetime +from pathlib import Path +from typing import Any, Iterable + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from bioscancast.stages.search_stage.date_recovery import date_from_url_slug + +REPO_ROOT = Path(__file__).resolve().parent.parent +RESULTS_DIR = REPO_ROOT / "specs" / "probe-results" + + +def parse_published_date(dstr: str | None) -> date | None: + if not dstr: + return None + try: + return date.fromisoformat(dstr[:10]) + except ValueError: + pass + try: + from email.utils import parsedate_to_datetime + return parsedate_to_datetime(dstr).date() + except (ValueError, TypeError): + return None + + +def classify_one(result: dict[str, Any], cutoff: date) -> dict[str, Any]: + url = result.get("url", "") + raw = result.get("published_date") + pd = parse_published_date(raw) + slug = date_from_url_slug(url) + slug_d = slug.date() if slug else None + # "effective" date: prefer native published_date, fall back to slug. 
+ effective = pd or slug_d + return { + "url": url, + "title": result.get("title", ""), + "raw_published_date": raw, + "parsed_published_date": pd.isoformat() if pd else None, + "slug_date": slug_d.isoformat() if slug_d else None, + "effective_date": effective.isoformat() if effective else None, + "native_pre_cutoff": pd is not None and pd <= cutoff, + "effective_pre_cutoff": effective is not None and effective <= cutoff, + "native_dated": pd is not None, + "effective_dated": effective is not None, + } + + +def analyze_payload(payload: dict[str, Any]) -> dict[str, Any]: + cutoff = date.fromisoformat(payload["cutoff"]) + results = payload["response"].get("results", []) or [] + classified = [classify_one(r, cutoff) for r in results] + n = len(classified) or 1 + return { + "tag": payload["tag"], + "query": payload["query"], + "cutoff": payload["cutoff"], + "knobs": payload["knobs"], + "n_results": len(classified), + "native_pre_cutoff": sum(1 for c in classified if c["native_pre_cutoff"]), + "native_dated": sum(1 for c in classified if c["native_dated"]), + "effective_pre_cutoff": sum(1 for c in classified if c["effective_pre_cutoff"]), + "effective_dated": sum(1 for c in classified if c["effective_dated"]), + "results": classified, + "fetched_at": payload.get("fetched_at"), + } + + +def knob_summary(knobs: dict[str, Any]) -> str: + """Compact human-readable summary of the non-default knobs.""" + parts = [] + topic = knobs.get("topic", "news") + parts.append(topic) + for k, v in sorted(knobs.items()): + if k in {"topic", "max_results", "include_answer"}: + continue + if k == "include_domains": + parts.append(f"domains={len(v)}") + else: + parts.append(f"{k}={v}") + return " ".join(parts) + + +def load_all() -> list[dict[str, Any]]: + out = [] + if not RESULTS_DIR.exists(): + return out + for path in sorted(RESULTS_DIR.glob("*.json")): + with path.open(encoding="utf-8") as f: + payload = json.load(f) + out.append(analyze_payload(payload)) + return out + + +def 
emit_table(rows: Iterable[dict[str, Any]]) -> str: + rows = list(rows) + lines = [] + header = "| tag | config | n | native pre/dated | + slug pre/dated |" + sep = "|---|---|---|---|---|" + lines.append(header) + lines.append(sep) + for r in rows: + cfg = knob_summary(r["knobs"]) + native = f"{r['native_pre_cutoff']}/{r['native_dated']}" + eff = f"{r['effective_pre_cutoff']}/{r['effective_dated']}" + lines.append(f"| {r['tag']} | {cfg} | {r['n_results']} | {native} | {eff} |") + return "\n".join(lines) + + +def compute_hybrid(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + """For each tag, find the news-topic and general-topic rows under otherwise + matching knobs and produce a unioned hybrid row.""" + by_tag: dict[str, dict[str, list[dict[str, Any]]]] = defaultdict(lambda: defaultdict(list)) + for r in rows: + topic = r["knobs"].get("topic", "news") + # Match by (tag, non-topic-knobs); store both topic variants. + non_topic = {k: v for k, v in r["knobs"].items() if k != "topic"} + key = json.dumps(non_topic, sort_keys=True) + by_tag[r["tag"]][key].append(r) + + hybrid_rows = [] + for tag, by_knobs in by_tag.items(): + for key, group in by_knobs.items(): + if len({r["knobs"].get("topic") for r in group}) < 2: + continue + news = [r for r in group if r["knobs"].get("topic") == "news"] + general = [r for r in group if r["knobs"].get("topic") == "general"] + if not news or not general: + continue + news_r = news[0] + general_r = general[0] + seen_urls: set[str] = set() + unioned = [] + for src in (news_r, general_r): + for c in src["results"]: + if c["url"] in seen_urls: + continue + seen_urls.add(c["url"]) + unioned.append(c) + native_pre = sum(1 for c in unioned if c["native_pre_cutoff"]) + native_dated = sum(1 for c in unioned if c["native_dated"]) + eff_pre = sum(1 for c in unioned if c["effective_pre_cutoff"]) + eff_dated = sum(1 for c in unioned if c["effective_dated"]) + cfg_knobs = {**json.loads(key), "topic": "hybrid(news+general)"} + 
hybrid_rows.append({ + "tag": tag, + "query": news_r["query"], + "cutoff": news_r["cutoff"], + "knobs": cfg_knobs, + "n_results": len(unioned), + "native_pre_cutoff": native_pre, + "native_dated": native_dated, + "effective_pre_cutoff": eff_pre, + "effective_dated": eff_dated, + "results": unioned, + }) + return hybrid_rows + + +def print_url_slug_coverage(rows: list[dict[str, Any]]) -> None: + """Audit: for general-mode rows with no native dates, what fraction of URLs + yield a date via the slug regex?""" + print("\n## URL-slug recovery coverage (general-mode, no native date)\n") + print("| tag | knobs | undated_urls | slug_recovered | recovery_rate |") + print("|---|---|---|---|---|") + for r in rows: + if r["knobs"].get("topic") != "general": + continue + undated = [c for c in r["results"] if not c["native_dated"]] + recovered = [c for c in undated if c["slug_date"] is not None] + if not undated: + continue + print( + f"| {r['tag']} | {knob_summary(r['knobs'])} | {len(undated)} | " + f"{len(recovered)} | {len(recovered) / len(undated):.0%} |" + ) + + +def print_undated_url_sample(rows: list[dict[str, Any]], n: int = 30) -> None: + """For Phase E: list a sample of undated, slug-non-matching URLs so we can + eyeball what patterns Tavily-general returns.""" + print("\n## Undated URLs that the slug regex does NOT catch (sample)\n") + seen: set[str] = set() + count = 0 + for r in rows: + if r["knobs"].get("topic") != "general": + continue + for c in r["results"]: + if c["native_dated"] or c["slug_date"] is not None: + continue + if c["url"] in seen: + continue + seen.add(c["url"]) + print(f"- [{r['tag']}] {c['url']}") + count += 1 + if count >= n: + return + + +def main() -> None: + rows = load_all() + if not rows: + print("No probe-results/*.json found. 
Run probe_tavily_topic.py first.") + return + + print(f"# Tavily probe analysis ({len(rows)} runs)\n") + print("## All runs\n") + print(emit_table(rows)) + + hybrid = compute_hybrid(rows) + if hybrid: + print("\n## Hybrid (news+general union)\n") + print(emit_table(hybrid)) + + print_url_slug_coverage(rows) + print_undated_url_sample(rows) + + # Total call count = number of payloads (one Tavily call per cache entry). + print(f"\n_Total cached Tavily calls: {len(rows)}_") + + +if __name__ == "__main__": + main() diff --git a/scripts/probe_tavily_topic.py b/scripts/probe_tavily_topic.py index 57ebe49..2cd4648 100644 --- a/scripts/probe_tavily_topic.py +++ b/scripts/probe_tavily_topic.py @@ -1,18 +1,45 @@ -"""Compare Tavily ``topic="news"`` vs default (``general``) for one historical -query. No Wayback, no LLM — just two Tavily calls and a date distribution -so we can see whether dropping ``topic="news"`` would actually surface more -pre-cutoff results. +"""Probe Tavily configurations across the BioScanCast resolved corpus. -Run: - python scripts/probe_tavily_topic.py +Originally a single-query script comparing topic="news" vs topic="general" on +q1 (H5N1 US, cutoff Feb 17 2025). Now generalized to iterate the corpus and +explore Tavily knobs (search_depth, include_domains, exact_match, etc.) under +the historical-replay cutoff machinery. + +Investigation context: see ``specs/tavily-historical-coverage.md`` and the +plan at ``~/.claude/plans/i-d-like-you-to-wondrous-whale.md``. + +Each (question x config) result is dumped to ``specs/probe-results/`` as JSON +so the analyzer (``analyze_tavily_probe.py``) can re-compute hit rates and +date-recovery coverage offline without re-paying the Tavily quota. 
+ +Examples: + # All resolved questions, news topic, default settings + python scripts/probe_tavily_topic.py --question-id all --topic news + + # Single question, advanced search_depth + python scripts/probe_tavily_topic.py --question-id q1 --topic news \ + --knobs '{"search_depth": "advanced"}' + + # Synthetic backdated query (override question text + cutoff) + python scripts/probe_tavily_topic.py --synthetic-query "MERS-CoV cases Saudi Arabia 2015" \ + --synthetic-cutoff 2017-01-01 --synthetic-tag mers2015 --topic news + + # Original q1/news+general behavior (legacy) + python scripts/probe_tavily_topic.py --legacy """ from __future__ import annotations +import argparse +import csv +import hashlib +import json import os import sys from collections import Counter -from datetime import date +from datetime import date, datetime, timedelta, timezone +from pathlib import Path +from typing import Any sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) @@ -25,75 +52,193 @@ from tavily import TavilyClient -QUERY = "H5N1 human cases United States 2025" -CUTOFF = date(2025, 2, 17) -MAX_RESULTS = 20 +REPO_ROOT = Path(__file__).resolve().parent.parent +CORPUS_CSV = REPO_ROOT / "bioscancast" / "stages" / "eval_stage" / "bioscancast_questions.csv" +RESULTS_DIR = REPO_ROOT / "specs" / "probe-results" + + +def excel_serial_to_date(serial: int | str) -> date: + """Excel epoch is 1899-12-30 (Lotus 1-2-3 leap-year bug correction).""" + return (datetime(1899, 12, 30) + timedelta(days=int(serial))).date() + + +def load_resolved_questions() -> list[dict[str, Any]]: + out: list[dict[str, Any]] = [] + with CORPUS_CSV.open(encoding="utf-8") as f: + reader = csv.DictReader(f, delimiter=";") + for row in reader: + if row.get("question_status") == "resolved": + row["cutoff_date"] = excel_serial_to_date(row["created_date"]) + out.append(row) + return out -def _bucket(dstr: str | None) -> str: +def get_question(qid: str) -> dict[str, Any]: + for q in load_resolved_questions(): + 
if q["question_id"] == qid: + return q + raise SystemExit(f"Unknown or unresolved question_id: {qid}") + + +def _bucket(dstr: str | None, cutoff: date) -> str: if not dstr: return "no_date" try: d = date.fromisoformat(dstr[:10]) except ValueError: - return "unparseable" - if d <= CUTOFF: + try: + from email.utils import parsedate_to_datetime + d = parsedate_to_datetime(dstr).date() + except (ValueError, TypeError): + return "unparseable" + if d <= cutoff: return "pre_cutoff" - if d.year == 2025: - return "post_cutoff_2025" - if d.year == 2026: - return "post_cutoff_2026" return f"post_cutoff_{d.year}" -def _run(client: TavilyClient, *, with_news_topic: bool) -> None: - kwargs: dict = { - "query": QUERY, - "max_results": MAX_RESULTS, - "include_answer": False, +def config_hash(query: str, cutoff: date, knobs: dict[str, Any]) -> str: + payload = json.dumps({"query": query, "cutoff": cutoff.isoformat(), "knobs": knobs}, sort_keys=True) + return hashlib.sha1(payload.encode()).hexdigest()[:10] + + +def cache_path(tag: str, knobs: dict[str, Any]) -> Path: + """Filename: {tag}__{knob_summary}__{sha1(knobs)[:8]}.json. Tag is question_id or synthetic-tag.""" + knob_summary = "_".join(f"{k}={v}" for k, v in sorted(knobs.items()) if k != "include_domains") + if "include_domains" in knobs: + knob_summary += "_domains=" + str(len(knobs["include_domains"])) + knob_summary = knob_summary.replace("/", "_").replace(":", "_")[:60] or "default" + h = hashlib.sha1(json.dumps(knobs, sort_keys=True).encode()).hexdigest()[:8] + return RESULTS_DIR / f"{tag}__{knob_summary}__{h}.json" + + +def run_probe( + client: TavilyClient, + *, + tag: str, + query: str, + cutoff: date, + knobs: dict[str, Any], + force: bool = False, +) -> dict[str, Any]: + """Run one Tavily call (cached).
Returns the cached payload.""" + RESULTS_DIR.mkdir(parents=True, exist_ok=True) + path = cache_path(tag, knobs) + if path.exists() and not force: + with path.open(encoding="utf-8") as f: + return json.load(f) + + kwargs: dict[str, Any] = {"query": query, "include_answer": False, **knobs} + if "max_results" not in kwargs: + kwargs["max_results"] = 20 + resp = client.search(**kwargs) + + payload = { + "tag": tag, + "query": query, + "cutoff": cutoff.isoformat(), + "knobs": knobs, + "fetched_at": datetime.now(timezone.utc).isoformat(), + "response": resp, } - if with_news_topic: - kwargs["topic"] = "news" + with path.open("w", encoding="utf-8") as f: + json.dump(payload, f, indent=2) + return payload - print(f"\n{'=' * 72}") - print(f"Tavily topic = {'news' if with_news_topic else 'general (default)'}") - print(f"Query = {QUERY!r}") - print(f"Cutoff = {CUTOFF.isoformat()}") - print("=" * 72) - - resp = client.search(**kwargs) - results = resp.get("results", []) - print(f"Returned {len(results)} results\n") +def summarize(payload: dict[str, Any]) -> None: + cutoff = date.fromisoformat(payload["cutoff"]) + results = payload["response"].get("results", []) or [] buckets: Counter = Counter() - for i, r in enumerate(results, 1): + dated = 0 + for r in results: d = r.get("published_date") - bucket = _bucket(d) - buckets[bucket] += 1 - date_label = (d or "—")[:10] if d else "—" - url = r.get("url", "") - title = (r.get("title") or "")[:80] - marker = " " if bucket == "pre_cutoff" else " " - if bucket == "pre_cutoff": - marker = "✓ " - print(f"{marker}{i:2d}. 
{date_label:<10} [{bucket:<20}] {title}") - print(f" {url}") - - print(f"\nBucket totals:") - for bucket, n in sorted(buckets.items(), key=lambda kv: -kv[1]): - print(f" {bucket:<20} {n}") + if d: + dated += 1 + buckets[_bucket(d, cutoff)] += 1 pre = buckets.get("pre_cutoff", 0) - if results: - print(f"\nPre-cutoff hit rate: {pre}/{len(results)} = {pre / len(results):.0%}") - - -def main() -> None: + knob_str = ", ".join(f"{k}={v}" for k, v in sorted(payload["knobs"].items()))[:80] or "(default)" + n = len(results) or 1 + print( + f" {payload['tag']:>10} cutoff={cutoff} {knob_str:<82} " + f"-> pre={pre}/{len(results)} ({pre / n:.0%}) dated={dated}/{len(results)}" + ) + + +def add_year_hint(query: str, cutoff: date) -> str: + """Mirror the pipeline's year-hint suffix so probes match pipeline behavior.""" + y = str(cutoff.year) + if y in query: + return query + return f"{query} {y}" + + +def build_query_from_question(q: dict[str, Any], hint_year: bool = True) -> str: + """Construct a search query from a corpus question. Strip framing words + ("How many ... will be reported ... according to ...") to expose the + topical noun phrase. 
Keep the topic prefix as a hint.""" + text = q["question_text"] + base = f"{q['topic']} {text}" + return add_year_hint(base, q["cutoff_date"]) if hint_year else base + + +def parse_args(argv: list[str]) -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument("--question-id", help="Resolved question id (q1, q3, q7, q9) or 'all'.") + p.add_argument("--topic", choices=["news", "general", "finance"], default="news") + p.add_argument("--knobs", default="{}", help="JSON object of extra Tavily kwargs.") + p.add_argument("--synthetic-query", help="Use a free-form query instead of corpus.") + p.add_argument("--synthetic-cutoff", help="YYYY-MM-DD cutoff for synthetic query.") + p.add_argument("--synthetic-tag", help="Short tag for cache filename (synthetic only).") + p.add_argument("--no-year-hint", action="store_true", help="Skip the year-suffix hint.") + p.add_argument("--force", action="store_true", help="Bypass cache and re-call Tavily.") + p.add_argument("--legacy", action="store_true", help="Replicate original q1 news+general behavior.") + return p.parse_args(argv) + + +def main(argv: list[str] | None = None) -> None: + args = parse_args(argv or sys.argv[1:]) api_key = os.environ.get("TAVILY_API_KEY") if not api_key: sys.exit("TAVILY_API_KEY missing") client = TavilyClient(api_key=api_key) - _run(client, with_news_topic=True) - _run(client, with_news_topic=False) + + if args.legacy: + q = get_question("q1") + query = build_query_from_question(q, hint_year=False) + cutoff = q["cutoff_date"] + for topic in ("news", "general"): + payload = run_probe( + client, tag="q1_legacy", query=query, cutoff=cutoff, + knobs={"topic": topic, "max_results": 20}, force=args.force, + ) + summarize(payload) + return + + knobs = json.loads(args.knobs) + knobs.setdefault("topic", args.topic) + knobs.setdefault("max_results", 20) + + if args.synthetic_query: + if not args.synthetic_cutoff: + 
sys.exit("--synthetic-cutoff required with --synthetic-query") + tag = args.synthetic_tag or "synth" + cutoff = date.fromisoformat(args.synthetic_cutoff) + query = args.synthetic_query if args.no_year_hint else add_year_hint(args.synthetic_query, cutoff) + payload = run_probe(client, tag=tag, query=query, cutoff=cutoff, knobs=knobs, force=args.force) + summarize(payload) + return + + if not args.question_id: + sys.exit("provide --question-id, --synthetic-query, or --legacy") + + qids = ["q1", "q3", "q7", "q9"] if args.question_id == "all" else [args.question_id] + for qid in qids: + q = get_question(qid) + query = build_query_from_question(q, hint_year=not args.no_year_hint) + payload = run_probe( + client, tag=qid, query=query, cutoff=q["cutoff_date"], knobs=knobs, force=args.force, + ) + summarize(payload) if __name__ == "__main__": From 810327f5350fe4513a8585cc3bbdf39981144455 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Tue, 26 May 2026 13:07:19 +0200 Subject: [PATCH 05/21] Loosen hallucination guard to accept normalisation drift MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous strict whitespace-normalised substring check rejected real factual quotes whenever the LLM made any small punctuation or unicode adjustment — and live tests showed it does so constantly on real WHO/CDC/ECDC documents. The headline failure was CDC MMWR producing a single record from a doc that says "99 cases" five times, because every variant the model emitted had either a trailing period where the source had a comma, parens dropped around "(NMDOH)", or a smart-quote vs straight-quote mismatch. The new layered guard `_quote_matches` accepts a quote that matches the chunk under any of three increasingly permissive normalisations: 1. NFKC + explicit typography fold (smart quotes, em/en dashes, ellipsis) + whitespace collapse. 2. Strip terminal ".;,:!?" from the quote. 3. 
Strip wrapping punctuation `()[]{}"'` from both sides. The function returns the canonical chunk substring it matched against, so ChunkReference.quote always stores a verbatim chunk excerpt rather than the model's altered output. Tests show this lifts real-fact capture from 23 to 34 records on the 6 real biosecurity test documents (a 48% increase) while still rejecting fabricated quotes, content-insertion hallucinations (extra words in a list), synthesised prefix-onto-fragment paraphrases, and wrong-number alterations. The 30-character-window match described in the plan turned out to be unnecessary once the typography fold and paren-strip layers were in place — and would have weakened the guard on real hallucinations without buying additional real-quote coverage. New parametrised tests cover each layer plus the rejection cases. Co-Authored-By: Claude Opus 4.7 --- .../insight/extraction/chunk_extractor.py | 151 +++++++++++++- .../tests/test_insight_chunk_extractor.py | 192 ++++++++++++++++++ 2 files changed, 333 insertions(+), 10 deletions(-) diff --git a/bioscancast/insight/extraction/chunk_extractor.py b/bioscancast/insight/extraction/chunk_extractor.py index 2a4666b..b7a1633 100644 --- a/bioscancast/insight/extraction/chunk_extractor.py +++ b/bioscancast/insight/extraction/chunk_extractor.py @@ -10,6 +10,7 @@ import logging import re +import unicodedata import uuid from datetime import datetime, timezone from typing import TYPE_CHECKING, Optional @@ -23,6 +24,47 @@ logger = logging.getLogger(__name__) +# Punctuation the hallucination guard is willing to ignore at the very end +# of a model-supplied quote. Live tests on real WHO/CDC/ECDC documents show +# the model habitually closes paraphrased quotes with '.' even when the +# source has a comma, semicolon, or no terminator at that position. +_TERMINAL_PUNCT = ".;,:!?" + +# Typography → ASCII folding applied at layer 1 before substring matching. 
+# NFKC alone does NOT fold smart quotes (U+2018/9, U+201C/D) or em/en +# dashes — those are independent Unicode codepoints, not compatibility +# forms. But real biosecurity sources mix them freely with their ASCII +# equivalents (WHO and ECDC PDFs in particular use curly quotes and +# em-dashes), and the model normalises them inconsistently in its +# output. Folding here keeps the guard robust to those variants. +_TYPOGRAPHY_FOLD: dict[str, str] = { + "‘": "'", # LEFT SINGLE QUOTATION MARK + "’": "'", # RIGHT SINGLE QUOTATION MARK + "‚": "'", # SINGLE LOW-9 QUOTATION MARK + "‛": "'", # SINGLE HIGH-REVERSED-9 + "“": '"', # LEFT DOUBLE QUOTATION MARK + "”": '"', # RIGHT DOUBLE QUOTATION MARK + "„": '"', # DOUBLE LOW-9 QUOTATION MARK + "–": "-", # EN DASH + "—": "-", # EM DASH + "−": "-", # MINUS SIGN + "…": "...", # HORIZONTAL ELLIPSIS +} + +_TYPOGRAPHY_FOLD_RE = re.compile( + "|".join(re.escape(k) for k in _TYPOGRAPHY_FOLD) +) + +# Wrapping punctuation the guard will strip from both sides at layer 3. +# These are characters whose presence-vs-absence around inline elements +# (acronyms like "(NMDOH)", figures like "[12]", quoted speech) flips +# between model output and source text without changing meaning. We do +# NOT strip hyphens or other connecting punctuation because those carry +# semantic load (e.g. "outbreak-related"). Note: smart quotes have +# already been folded to ASCII at layer 1, so this regex only needs to +# list the ASCII variants. +_WRAPPING_PUNCT_RE = re.compile(r"[\(\)\[\]\{\}\"\']") + # Hardcoded country name -> ISO 3166-1 alpha-2 map for the ~30 most # likely countries in biosecurity reporting. Don't pull in pycountry. @@ -70,10 +112,98 @@ def _normalize_whitespace(text: str) -> str: - """Collapse all whitespace to single spaces for substring matching.""" + """Collapse all whitespace to single spaces for substring matching. + + Retained as a thin wrapper for callers (and tests) that pre-date the + NFKC-aware match logic. 
New code should use ``_normalize_for_match``. + """ return re.sub(r"\s+", " ", text).strip() +def _normalize_for_match(text: str) -> str: + """NFKC + typography-to-ASCII fold + whitespace collapse. + + Used by the hallucination guard to compare quotes against chunk text + on a stable footing. NFKC handles compatibility chars (non-breaking + spaces, full-width ASCII); the explicit typography fold handles + smart quotes and em/en dashes (which are NOT compatibility chars in + Unicode). Without these, the guard rejects real quotes whose only + difference from the source is a typographic variant. + """ + if not text: + return "" + text = unicodedata.normalize("NFKC", text) + text = _TYPOGRAPHY_FOLD_RE.sub(lambda m: _TYPOGRAPHY_FOLD[m.group(0)], text) + return re.sub(r"\s+", " ", text).strip() + + +def _quote_matches(quote: str, chunk_text: str) -> Optional[str]: + """Hallucination guard: return the canonical chunk substring the quote + matches, or ``None`` if no match. + + Layers applied in order: + + 1. **NFKC + typography fold + whitespace collapse → exact substring.** Catches curly vs + straight apostrophes, non-breaking spaces, em-dashes, full-width + ASCII, and the model's whitespace habits. + 2. **Strip terminal punctuation** (``.;,:!?``) from the normalised + quote, then substring check again. Catches the model's strong + tendency to close paraphrased quotes with ``.`` even when the + source has a comma or no punctuation at that position (e.g. + source: ``"...reported by Italy (63), Spain..."``; model quote: + ``"...reported by Italy (63)."``). + 3. **Strip wrapping punctuation** (``()[]{}"'``) from both quote and + chunk and retry (also dropping terminal punctuation from the + quote). Catches the model's habit of dropping the parens around + acronyms (source: ``"f Health (NMDOH) eventually reported..."``; + model quote: ``"NMDOH eventually reported..."``).
+ + The function returns the *canonical* form of the matched substring + (with the same transformations applied that made the match succeed) + rather than the model's original output, so the stored + ``ChunkReference.quote`` always corresponds to actual chunk content + after the same normalisation. Returns ``None`` when no layer + matches — caller drops the fact. + + Note: this loosening was driven by live tests showing the strict + substring-only guard rejected ~85% of real factual quotes due to + minor punctuation/unicode drift on real WHO/CDC/ECDC documents, + while the looser three-layer guard still rejects substantive + paraphrases (e.g. the model bolting a prefix from one sentence onto + a fragment of another) and content-insertion hallucinations (extra + words in a list). + """ + if not quote: + return None + norm_quote = _normalize_for_match(quote) + if not norm_quote: + return None + norm_chunk = _normalize_for_match(chunk_text) + + # Layer 1: exact substring after NFKC + whitespace + if norm_quote in norm_chunk: + return norm_quote + + # Layer 2: strip terminal punctuation from the quote and retry + stripped = norm_quote.rstrip(_TERMINAL_PUNCT).strip() + if stripped and stripped != norm_quote and stripped in norm_chunk: + return stripped + + # Layer 3: strip wrapping punctuation everywhere on both sides, then + # strip terminal punctuation from the quote, and retry. 
+ unwrap_quote = _WRAPPING_PUNCT_RE.sub("", stripped or norm_quote) + unwrap_quote = re.sub(r"\s+", " ", unwrap_quote).strip() + unwrap_quote = unwrap_quote.rstrip(_TERMINAL_PUNCT).strip() + if not unwrap_quote: + return None + unwrap_chunk = _WRAPPING_PUNCT_RE.sub("", norm_chunk) + unwrap_chunk = re.sub(r"\s+", " ", unwrap_chunk).strip() + if unwrap_quote in unwrap_chunk: + return unwrap_quote + + return None + + def _resolve_country_code(location: Optional[str]) -> Optional[str]: """Try to resolve a location string to an ISO country code.""" if not location: @@ -138,22 +268,23 @@ def extract_facts_from_chunk( facts_raw = response.content.get("facts", []) records: list[InsightRecord] = [] - chunk_text_normalized = _normalize_whitespace(chunk.text) for fact in facts_raw: - quote = fact.get("quote", "") - quote_normalized = _normalize_whitespace(quote) + raw_quote = fact.get("quote", "") # --- Hallucination guard --- - # The quote must appear as a substring in the chunk text. - # Exact substring check (whitespace-normalized) is the point — - # don't soften to fuzzy match without careful consideration. - if not quote_normalized or quote_normalized not in chunk_text_normalized: + # The quote must appear as a substring in the chunk text under + # NFKC + whitespace normalisation, optionally with terminal + # punctuation stripped. The guard rejects substantive paraphrases + # and content-insertion hallucinations. See ``_quote_matches`` for + # the rationale and the layers. + canonical_quote = _quote_matches(raw_quote, chunk.text) + if canonical_quote is None: logger.warning( "Hallucination guard: dropping fact with non-matching quote. 
" "chunk_id=%s, quote=%r", chunk.chunk_id, - quote[:100], + raw_quote[:100], ) continue @@ -185,7 +316,7 @@ def extract_facts_from_chunk( document_id=document.id, chunk_id=chunk.chunk_id, source_url=document.source_url, - quote=quote[:200], + quote=canonical_quote[:200], ), ], ) diff --git a/bioscancast/tests/test_insight_chunk_extractor.py b/bioscancast/tests/test_insight_chunk_extractor.py index 97fe451..d96a14f 100644 --- a/bioscancast/tests/test_insight_chunk_extractor.py +++ b/bioscancast/tests/test_insight_chunk_extractor.py @@ -3,11 +3,14 @@ All tests use FakeLLMClient — no network calls, no real OpenAI imports. """ +import pytest + from bioscancast.llm.fake_client import FakeLLMClient from bioscancast.insight.extraction.chunk_extractor import ( extract_facts_from_chunk, _resolve_country_code, _normalize_whitespace, + _quote_matches, ) from bioscancast.tests.fixtures.insight.synthetic_documents import ( @@ -177,3 +180,192 @@ def test_response_returned_for_budget_tracking(): assert response.input_tokens == 150 assert response.output_tokens == 15 assert response.model == "gpt-4o-mini" + + +# --------------------------------------------------------------------------- +# Hallucination guard — layered match behaviour +# --------------------------------------------------------------------------- + +# Each tuple: (label, chunk_text, quote, should_match) +# +# These cases come from live-LLM observations on real WHO/CDC/ECDC documents. +# The strict substring guard rejected ~85% of real factual quotes due to +# minor punctuation/unicode drift; the looser layered guard must keep +# accepting wholesale fabrications while accepting the real-but-drifted +# variants below. 
+ +_LAYER1_NFKC_CASES = [ + ( + "curly-apos source vs straight-apos quote", + "Côte d’Ivoire reported four confirmed cases each in January", + "Côte d'Ivoire reported four confirmed cases each in January", + True, + ), + ( + "non-breaking space in source number", + "30 EU/EEA Member States reported a total of 4 623 cases of measles.", + "30 EU/EEA Member States reported a total of 4 623 cases of measles", + True, + ), + ( + "em-dash in source vs hyphen in quote", + "Measles—Multi-country—Monitoring European outbreaks", + "Measles-Multi-country-Monitoring European outbreaks", + # The typography fold step maps em-dash to hyphen. + True, + ), + ( + "newline in source vs space in quote", + "Democratic Republic of the Congo\n6 543", + "Democratic Republic of the Congo 6 543", + True, + ), +] + + +_LAYER2_TERMINAL_PUNCTUATION_CASES = [ + ( + "comma in source becomes period in quote", + "...reported by Italy (63), Spain (36), France (16) and Poland (five).", + "The highest case counts were reported by Italy (63).", + # Quote prefix not in source — should reject + False, + ), + ( + "comma in source becomes period in quote — full prefix", + "The highest case counts were reported by Italy (63), Spain (36), France (16) and Poland (five).", + "The highest case counts were reported by Italy (63).", + True, + ), + ( + "no terminator in source vs period in quote", + "Spain reported 97 cases of measles from 1 January to 12 April 2026 according to ECDC", + "Spain reported 97 cases of measles from 1 January to 12 April 2026.", + True, + ), +] + + +_LAYER3_WRAPPING_PUNCTUATION_CASES = [ + ( + "parens around acronym in source dropped in quote", + "f Health (NMDOH) eventually reported 99 outbreak-related measles cases, approximately one half.", + "NMDOH eventually reported 99 outbreak-related measles cases.", + True, + ), + ( + "double quotes in source dropped in quote", + "The CDC said “we are responding” to the outbreak.", + "The CDC said we are responding to the outbreak.", + True, 
+ ), + ( + "square brackets in source dropped in quote", + "the result [12] shows a clear trend in cases.", + "the result 12 shows a clear trend in cases.", + True, + ), +] + + +_HALLUCINATION_CASES = [ + ( + "fabricated word inserted into list", + "Ghana and Liberia have reported human mpox due to clade IIa MPXV.", + "Ghana, Atlantis, and Liberia have reported human mpox due to clade IIa MPXV.", + False, + ), + ( + "wholesale fabrication", + "Some real chunk content about measles cases in Utah.", + "THIS QUOTE WAS INVENTED BY THE MODEL AND APPEARS NOWHERE.", + False, + ), + ( + "synthesised prefix bolted onto a fragment", + # Model takes a prefix from sentence A and bolts it onto a fragment of sentence B + "Italy reported 63 cases. Spain reported 36 cases.", + "Italy reported 36 cases.", + False, + ), + ( + "wrong number", + "Spain reported 36 cases yesterday.", + "Spain reported 363 cases yesterday.", + False, + ), + ( + "empty quote", + "Real chunk content.", + "", + False, + ), + ( + "whitespace-only quote", + "Real chunk content.", + " \n\t ", + False, + ), +] + + +@pytest.mark.parametrize( + "label,chunk_text,quote,should_match", + _LAYER1_NFKC_CASES + _LAYER2_TERMINAL_PUNCTUATION_CASES + _LAYER3_WRAPPING_PUNCTUATION_CASES, +) +def test_quote_matches_accepts_real_quotes_with_normalisation_drift( + label, chunk_text, quote, should_match +): + """Real factual quotes with NFKC / terminal-punctuation / wrapping- + punctuation drift should be accepted by the layered guard.""" + result = _quote_matches(quote, chunk_text) + if should_match: + assert result is not None, ( + f"{label}: expected match, got None. quote={quote!r}" + ) + # The returned canonical form must itself be a substring of the + # *normalised* chunk text — that's the invariant the guard + # guarantees to downstream consumers. 
+ from bioscancast.insight.extraction.chunk_extractor import ( + _normalize_for_match, + _WRAPPING_PUNCT_RE, + ) + import re + norm_chunk = _normalize_for_match(chunk_text) + unwrap_chunk = re.sub( + r"\s+", " ", _WRAPPING_PUNCT_RE.sub("", norm_chunk) + ).strip() + assert result in norm_chunk or result in unwrap_chunk, ( + f"{label}: canonical form {result!r} not in chunk after " + "normalisation" + ) + else: + assert result is None, f"{label}: expected reject, got {result!r}" + + +@pytest.mark.parametrize( + "label,chunk_text,quote,should_match", _HALLUCINATION_CASES +) +def test_quote_matches_rejects_hallucinations( + label, chunk_text, quote, should_match +): + """Content insertions, fabrications, wrong numbers, and synthesised + quotes must all be rejected even by the looser layered guard.""" + result = _quote_matches(quote, chunk_text) + assert result is None, ( + f"{label}: expected reject, got {result!r} (quote={quote!r})" + ) + + +def test_quote_matches_returns_canonical_form_not_raw_quote(): + """When a quote matches via terminal-punctuation strip, the canonical + returned form should be the stripped version — not the model's raw + output — so downstream consumers always see a verbatim chunk + substring.""" + chunk = "Italy reported 63 cases, Spain reported 36 cases." + quote = "Italy reported 63 cases." 
# Model added period + result = _quote_matches(quote, chunk) + assert result == "Italy reported 63 cases" + # Round-trip: the canonical form must be a substring of the + # normalised chunk + assert result in _normalize_whitespace(chunk) From 5c50f99e91edf9f3e45c2cef37543e5e79db044e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Tue, 26 May 2026 13:09:56 +0200 Subject: [PATCH 06/21] Add real-document integration test for insight stage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wires the real ExtractionPipeline (fetcher monkey-patched to read on-disk bytes from data/docling_eval/sources/) into the insight pipeline with two deterministic fake LLMs: QuoteEchoingFakeLLM — picks a number-bearing sentence from each retrieved chunk and emits a synthetic fact citing it verbatim. Exercises the happy path of the hallucination guard. HallucinatingFakeLLM — always emits a fabricated quote. The guard must drop every fact. Covers 23 assertions across: - extraction per source (Africa CDC fails with requires_ocr, the other 6 succeed with >= 5 chunks each) - WHO mpox PDF metadata yields a publication date - insight pipeline produces records for every text-extractable doc - the CIDRAP article's "602 cases" headline appears in at least one record's quote - failed-extraction docs are skipped without LLM calls - the hallucination guard rejects every fabricated quote on every doc - cross-document dedup merges twin facts into one record with sources from both docs Uses >= thresholds so subsequent items in the insight-stage hardening plan (empty-chunk filter, partial-date dedup, pycountry resolution) can lift record counts without breaking this test. Runs in ~7 seconds with no live LLM calls. 
Co-Authored-By: Claude Opus 4.7 --- .../fixtures/insight/real_doc_extracts.py | 343 ++++++++++++++++++ .../test_insight_real_docs_integration.py | 257 +++++++++++++ 2 files changed, 600 insertions(+) create mode 100644 bioscancast/tests/fixtures/insight/real_doc_extracts.py create mode 100644 bioscancast/tests/test_insight_real_docs_integration.py diff --git a/bioscancast/tests/fixtures/insight/real_doc_extracts.py b/bioscancast/tests/fixtures/insight/real_doc_extracts.py new file mode 100644 index 0000000..0f5e65f --- /dev/null +++ b/bioscancast/tests/fixtures/insight/real_doc_extracts.py @@ -0,0 +1,343 @@ +"""Fixtures for end-to-end tests against the 7 real biosecurity documents +already committed under ``data/docling_eval/sources/``. + +The integration test runs the real ExtractionPipeline (with the network +fetcher monkey-patched to read on-disk bytes) over each source file and +hands the resulting Documents to the insight pipeline with deterministic +fake LLMs. Re-extracting all 7 sources takes ~5 seconds total with the +in-tree PDF parser, so we don't bother caching. + +The module deliberately has no dependency on the local-only smoke script +in ``scripts/eval_insight_on_real_docs.py`` — the test must be +self-contained. +""" + +from __future__ import annotations + +import json +import re +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional +from unittest.mock import patch + +from bioscancast.extraction.config import ExtractionConfig +from bioscancast.extraction.fetcher import FetchResult +from bioscancast.extraction.pipeline import ExtractionPipeline +from bioscancast.filtering.models import FilteredDocument, ForecastQuestion +from bioscancast.llm.base import LLMResponse +from bioscancast.llm.fake_client import FakeLLMClient +from bioscancast.schemas import Document + + +# Resolve the sources directory relative to the repo root. 
+_REPO_ROOT = Path(__file__).resolve().parents[4] +SOURCES_DIR = _REPO_ROOT / "data" / "docling_eval" / "sources" + + +SOURCES = [ + { + "name": "who_mpox_sitrep64", + "file": "who_mpox_sitrep64.pdf", + "content_type": "application/pdf", + "domain": "who.int", + "url": "https://who.int/publications/m/item/multi-country-outbreak-of-mpox-external-situation-report-64", + "question": ForecastQuestion( + id="q-mpox-cases-2026", + text="How many confirmed mpox cases have been reported globally in 2026?", + created_at=datetime(2026, 4, 1), + target_date=datetime(2026, 12, 31), + region=None, + pathogen="mpox", + event_type="case_count", + ), + }, + { + "name": "who_cholera_epi34", + "file": "who_cholera_epi34.pdf", + "content_type": "application/pdf", + "domain": "who.int", + "url": "https://who.int/publications/m/item/multi-country-outbreak-of-cholera-external-situation-report-34", + "question": ForecastQuestion( + id="q-cholera-drc-2026", + text="How many cholera cases were reported in the Democratic Republic of the Congo?", + created_at=datetime(2026, 3, 1), + target_date=datetime(2026, 12, 31), + region="Democratic Republic of the Congo", + pathogen="cholera", + event_type="case_count", + ), + }, + { + "name": "cdc_mmwr_nm_measles", + "file": "cdc_mmwr_nm_measles.pdf", + "content_type": "application/pdf", + "domain": "cdc.gov", + "url": "https://cdc.gov/mmwr/volumes/74/wr/mm7509a1.htm", + "question": ForecastQuestion( + id="q-measles-nm-2025", + text="How many measles cases were reported in the New Mexico outbreak?", + created_at=datetime(2026, 3, 15), + target_date=datetime(2026, 12, 31), + region="New Mexico", + pathogen="measles", + event_type="case_count", + ), + }, + { + "name": "ecdc_cdtr_week16", + "file": "ecdc_cdtr_week16.pdf", + "content_type": "application/pdf", + "domain": "ecdc.europa.eu", + "url": "https://ecdc.europa.eu/en/publications-data/communicable-disease-threats-report-week-16-2026", + "question": ForecastQuestion( + id="q-measles-europe-2026", + 
text="How many measles cases have been reported across European countries this year?", + created_at=datetime(2026, 4, 20), + target_date=datetime(2026, 12, 31), + region="Europe", + pathogen="measles", + event_type="case_count", + ), + }, + { + "name": "africa_cdc_weekly_apr2026", + "file": "africa_cdc_weekly_apr2026.pdf", + "content_type": "application/pdf", + "domain": "africacdc.org", + "url": "https://africacdc.org/download/africa-cdc-weekly-event-based-surveillance-april-2026/", + "question": ForecastQuestion( + id="q-africa-outbreaks-2026", + text="What disease outbreaks are currently active across Africa in April 2026?", + created_at=datetime(2026, 4, 15), + target_date=datetime(2026, 12, 31), + region="Africa", + pathogen=None, + event_type="outbreak_declared", + ), + }, + { + "name": "cidrap_utah_measles", + "file": "cidrap_utah_measles.html", + "content_type": "text/html", + "domain": "cidrap.umn.edu", + "url": "https://cidrap.umn.edu/measles/utah-measles-cases-2026", + "question": ForecastQuestion( + id="q-measles-utah-2026", + text="How many measles cases have been confirmed in Utah?", + created_at=datetime(2026, 4, 25), + target_date=datetime(2026, 12, 31), + region="Utah", + pathogen="measles", + event_type="case_count", + ), + }, + { + "name": "promed_latest", + "file": "promed_latest.html", + "content_type": "text/html", + "domain": "promedmail.org", + "url": "https://promedmail.org/promed-posts/", + "question": ForecastQuestion( + id="q-promed-h5n1-2026", + text="What avian influenza H5N1 outbreaks have been reported recently?", + created_at=datetime(2026, 5, 1), + target_date=datetime(2026, 12, 31), + region=None, + pathogen="H5N1", + event_type="outbreak_declared", + ), + }, +] + + +def make_filtered_doc(source: dict) -> FilteredDocument: + """Build a minimal FilteredDocument that drives ExtractionPipeline.""" + return FilteredDocument( + result_id=source["name"], + question_id=source["question"].id, + url=source["url"], + 
canonical_url=source["url"], + domain=source["domain"], + title=source["name"], + snippet="", + published_date=None, + file_type=None, + relevance_score=0.9, + credibility_score=0.9, + final_score=0.9, + source_tier="official", + is_official_domain=True, + selection_reasons=["test"], + extraction_priority=1, + extraction_mode="auto", + expected_value="high", + ) + + +def _make_fake_fetch(file_path: Path, content_type: str): + """Return a fetch() replacement that reads on-disk bytes — no network.""" + payload = file_path.read_bytes() + + def fake_fetch(url, *, config=None, as_of_date=None): + return FetchResult( + url=url, + final_url=url, + status_code=200, + content_type=content_type, + content_bytes=payload, + fetched_at=datetime.now(timezone.utc), + error=None, + ) + + return fake_fetch + + +def extract_real_documents() -> dict[str, Document]: + """Run the real ExtractionPipeline over every source file in SOURCES. + + Returns a mapping from source name to Document. Africa CDC is expected + to come back with ``status="failed"`` because its PDF is image-only — + the in-tree parser correctly flags this as ``requires_ocr``. 
+ """ + config = ExtractionConfig(enable_docling_refiner=False) + pipeline = ExtractionPipeline(config=config) + out: dict[str, Document] = {} + for src in SOURCES: + path = SOURCES_DIR / src["file"] + if not path.exists(): + raise FileNotFoundError(f"Missing source file: {path}") + with patch( + "bioscancast.extraction.pipeline.fetch", + _make_fake_fetch(path, src["content_type"]), + ): + out[src["name"]] = pipeline.extract_one(make_filtered_doc(src)) + return out + + +def get_source(name: str) -> dict: + """Look up a source dict by name.""" + for src in SOURCES: + if src["name"] == name: + return src + raise KeyError(name) + + +# --------------------------------------------------------------------------- +# Fake LLMs for the integration test +# --------------------------------------------------------------------------- + + +_NUMBER_SENTENCE = re.compile( + r"[^.\n]*?\b\d[\d,\.]*\b[^.\n]*\.", re.MULTILINE +) + + +def _extract_chunk_text_from_prompt(user_prompt: str) -> str: + marker = "CHUNK TEXT:\n" + idx = user_prompt.find(marker) + return user_prompt[idx + len(marker):] if idx != -1 else user_prompt + + +def _pick_quote(chunk_text: str) -> str: + """Pick a verbatim quote from the chunk — prefer the first sentence + that contains a number (more interesting for biosecurity facts). + Fall back to the first non-trivial line.""" + m = _NUMBER_SENTENCE.search(chunk_text) + if m: + return m.group(0).strip()[:180] + for line in chunk_text.splitlines(): + line = line.strip() + if len(line) > 20: + return line[:150] + return chunk_text[:150] + + +class QuoteEchoingFakeLLM: + """Reads the chunk text out of the LLM prompt and emits one synthetic + fact citing a verbatim quote drawn from the chunk. 
Useful for testing + the happy path of the hallucination guard on real chunk content + without any LLM cost.""" + + def __init__(self, embedding_client: Optional[FakeLLMClient] = None) -> None: + self._embed = embedding_client or FakeLLMClient() + self.calls: list[dict] = [] + + def generate_json( + self, + *, + system: str, + user: str, + schema: dict, + model: str, + max_tokens: int = 1024, + ) -> LLMResponse: + chunk_text = _extract_chunk_text_from_prompt(user) + quote = _pick_quote(chunk_text) + fact = { + "event_type": "case_count", + "confidence": 0.65, + "location": None, + "pathogen": None, + "metric_name": "events", + "metric_value": 1.0, + "metric_unit": "events", + "event_date": None, + "summary": "Quote-echoing fake fact for testing.", + "quote": quote, + } + self.calls.append({"quote": quote, "model": model}) + return LLMResponse( + content={"facts": [fact]}, + input_tokens=100, + output_tokens=20, + model=model, + raw_text=json.dumps({"facts": [fact]}), + ) + + def embed(self, texts, *, model): + return self._embed.embed(texts, model=model) + + +class HallucinatingFakeLLM: + """Always emits a fabricated quote that does not appear in any chunk. + Every fact should be rejected by the hallucination guard.""" + + BOGUS_QUOTE = "THIS QUOTE WAS INVENTED BY THE MODEL AND APPEARS NOWHERE." 
+ + def __init__(self, embedding_client: Optional[FakeLLMClient] = None) -> None: + self._embed = embedding_client or FakeLLMClient() + self.calls = 0 + + def generate_json( + self, + *, + system: str, + user: str, + schema: dict, + model: str, + max_tokens: int = 1024, + ) -> LLMResponse: + self.calls += 1 + fact = { + "event_type": "case_count", + "confidence": 0.9, + "location": "Atlantis", + "pathogen": "Imaginarius bogus", + "metric_name": "confirmed_cases", + "metric_value": 999.0, + "metric_unit": "cases", + "event_date": "2099-01-01", + "summary": "Fabricated.", + "quote": self.BOGUS_QUOTE, + } + return LLMResponse( + content={"facts": [fact]}, + input_tokens=100, + output_tokens=20, + model=model, + raw_text=json.dumps({"facts": [fact]}), + ) + + def embed(self, texts, *, model): + return self._embed.embed(texts, model=model) diff --git a/bioscancast/tests/test_insight_real_docs_integration.py b/bioscancast/tests/test_insight_real_docs_integration.py new file mode 100644 index 0000000..e16549f --- /dev/null +++ b/bioscancast/tests/test_insight_real_docs_integration.py @@ -0,0 +1,257 @@ +"""End-to-end integration tests against 7 real biosecurity documents. + +Wires the real ExtractionPipeline (fetcher monkey-patched to read on-disk +bytes) into the insight pipeline with deterministic fake LLMs. Uses the +documents already committed under ``data/docling_eval/sources/``. + +Assertions use ``>=`` thresholds so subsequent insight-stage refactors +(items 5–7 of the hardening plan) that legitimately add records don't +break this test. The numbers are floors, calibrated from the observed +behaviour with the layered hallucination guard. + +No live LLM calls. Runs in <10 seconds. 
+""" + +from __future__ import annotations + +import pytest + +from bioscancast.filtering.models import ForecastQuestion +from bioscancast.insight.config import InsightConfig +from bioscancast.insight.pipeline import InsightPipeline +from bioscancast.schemas import Document +from bioscancast.tests.fixtures.insight.real_doc_extracts import ( + HallucinatingFakeLLM, + QuoteEchoingFakeLLM, + SOURCES, + extract_real_documents, + get_source, +) + + +# --------------------------------------------------------------------------- +# Module-scope fixture — extract once, reuse across tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def real_docs() -> dict[str, Document]: + """Extract all 7 real source documents once for the whole module.""" + return extract_real_documents() + + +@pytest.fixture +def insight_config() -> InsightConfig: + """Config matching the parameters used in the live evaluation.""" + return InsightConfig( + retrieval_top_k=5, + max_chunks_per_document=5, + max_input_tokens_per_run=10_000_000, + ) + + +# --------------------------------------------------------------------------- +# Extraction sanity +# --------------------------------------------------------------------------- + + +def test_extraction_produces_a_document_for_every_source(real_docs): + """Every source file in SOURCES must yield exactly one Document.""" + assert set(real_docs.keys()) == {src["name"] for src in SOURCES} + + +def test_extraction_africa_cdc_fails_with_requires_ocr(real_docs): + """Africa CDC's PDF is image-only — the in-tree parser must flag + this as ``requires_ocr`` rather than silently producing no chunks.""" + doc = real_docs["africa_cdc_weekly_apr2026"] + assert doc.status == "failed" + assert doc.error_message == "requires_ocr" + assert doc.chunks == [] + + +@pytest.mark.parametrize( + "name", + [ + "who_mpox_sitrep64", + "who_cholera_epi34", + "cdc_mmwr_nm_measles", + "ecdc_cdtr_week16", + 
"cidrap_utah_measles", + "promed_latest", + ], +) +def test_extraction_produces_chunks_for_text_extractable_sources(real_docs, name): + """Every source except Africa CDC must extract at least a few chunks.""" + doc = real_docs[name] + assert doc.status == "success", ( + f"{name}: expected status=success, got {doc.status}" + ) + assert len(doc.chunks) >= 5, ( + f"{name}: expected >= 5 chunks, got {len(doc.chunks)}" + ) + + +def test_who_mpox_publication_date_extracted(real_docs): + """WHO PDFs carry usable /CreationDate metadata.""" + doc = real_docs["who_mpox_sitrep64"] + assert doc.published_date is not None + assert doc.published_date.year == 2026 + + +# --------------------------------------------------------------------------- +# Insight pipeline happy path +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + # (name, min_records). Floors are calibrated with the layered + # hallucination guard; future items that lift extraction quality may + # legitimately push records higher. 
+ "name,min_records", + [ + ("who_mpox_sitrep64", 1), + ("who_cholera_epi34", 1), + ("cdc_mmwr_nm_measles", 1), + ("ecdc_cdtr_week16", 1), + ("cidrap_utah_measles", 1), + ], +) +def test_pipeline_produces_at_least_one_record_per_text_doc( + real_docs, insight_config, name, min_records +): + """With a quote-echoing fake LLM, every text-extractable doc should + yield at least one record (the fake picks a number-bearing sentence + from each retrieved chunk; the layered guard accepts the picked + quote as long as it's a real verbatim substring of the chunk).""" + doc = real_docs[name] + src = get_source(name) + fake = QuoteEchoingFakeLLM() + pipe = InsightPipeline(llm_client=fake, config=insight_config) + result = pipe.run(src["question"], [doc]) + assert len(result.records) >= min_records, ( + f"{name}: expected >= {min_records} records, got {len(result.records)}" + ) + # Every record must carry valid provenance + for rec in result.records: + assert rec.sources, f"{name}: record has no sources" + for s in rec.sources: + assert s.chunk_id.startswith(f"{doc.id}-") + assert s.source_url == doc.source_url + assert s.quote, f"{name}: record has empty quote" + + +def test_cidrap_pipeline_captures_602_utah_cases(real_docs, insight_config): + """The CIDRAP article's headline fact ("602 measles cases in Utah") + should appear in the quote of at least one record.""" + doc = real_docs["cidrap_utah_measles"] + src = get_source("cidrap_utah_measles") + fake = QuoteEchoingFakeLLM() + pipe = InsightPipeline(llm_client=fake, config=insight_config) + result = pipe.run(src["question"], [doc]) + assert any( + "602" in s.quote + for rec in result.records + for s in rec.sources + ), "CIDRAP: expected at least one record citing '602' (measles cases)" + + +def test_failed_doc_is_skipped(real_docs, insight_config): + """Africa CDC (status=failed) must take the skip path and never + cause a per-chunk LLM call.""" + doc = real_docs["africa_cdc_weekly_apr2026"] + src = 
get_source("africa_cdc_weekly_apr2026") + fake = QuoteEchoingFakeLLM() + pipe = InsightPipeline(llm_client=fake, config=insight_config) + result = pipe.run(src["question"], [doc]) + assert result.documents_skipped == 1 + assert result.documents_processed == 0 + assert len(fake.calls) == 0 + assert any("Skipped" in note for note in result.notes) + + +# --------------------------------------------------------------------------- +# Hallucination guard end-to-end +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "name", + [ + "who_mpox_sitrep64", + "who_cholera_epi34", + "cdc_mmwr_nm_measles", + "ecdc_cdtr_week16", + "cidrap_utah_measles", + "promed_latest", + ], +) +def test_hallucination_guard_drops_every_fact_with_fabricated_quote( + real_docs, insight_config, name +): + """A HallucinatingFakeLLM that always emits a quote that doesn't + appear in any chunk must produce zero records on every doc.""" + doc = real_docs[name] + src = get_source(name) + fake = HallucinatingFakeLLM() + pipe = InsightPipeline(llm_client=fake, config=insight_config) + result = pipe.run(src["question"], [doc]) + assert fake.calls > 0, f"{name}: fake was never called" + assert result.records == [], ( + f"{name}: hallucination guard let through {len(result.records)} fact(s)" + ) + + +# --------------------------------------------------------------------------- +# Cross-document deduplication +# --------------------------------------------------------------------------- + + +def test_cross_doc_dedup_merges_identical_facts(real_docs, insight_config): + """Two docs producing facts with the same dedup key should merge + into a single record whose sources span both docs.""" + cidrap = real_docs["cidrap_utah_measles"] + mmwr = real_docs["cdc_mmwr_nm_measles"] + assert cidrap.status == "success" + assert mmwr.status == "success" + + # Custom fake that returns an identically-structured fact for every + # chunk it sees, but uses a real quote 
pulled from the chunk text + # (so the hallucination guard accepts it). + class _TwinFactFake(QuoteEchoingFakeLLM): + def generate_json(self, *, system, user, schema, model, max_tokens=1024): + response = super().generate_json( + system=system, user=user, schema=schema, model=model, + max_tokens=max_tokens, + ) + # Force every fact to dedup-collide on the same key + response.content["facts"][0].update({ + "event_type": "case_count", + "location": "United States", + "pathogen": "measles", + "metric_name": "confirmed_cases", + "metric_value": 42.0, + "metric_unit": "cases", + "event_date": "2026-03-01", + }) + return response + + question = ForecastQuestion( + id="q-cross-doc-measles", + text="How many measles cases reported in the US?", + created_at=cidrap.fetched_at.replace(tzinfo=None), + region="United States", + pathogen="measles", + event_type="case_count", + ) + fake = _TwinFactFake() + pipe = InsightPipeline(llm_client=fake, config=insight_config) + result = pipe.run(question, [cidrap, mmwr]) + + assert len(result.records) == 1, ( + f"Expected 1 merged record, got {len(result.records)}" + ) + record = result.records[0] + source_doc_ids = {s.document_id for s in record.sources} + assert source_doc_ids == {cidrap.id, mmwr.id}, ( + f"Expected sources from both docs, got {source_doc_ids}" + ) From 99f93dd9557e0d75c9e525b2d31fde0004057e42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Tue, 26 May 2026 13:13:01 +0200 Subject: [PATCH 07/21] Drop or repair empty chunks at extraction time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Empty-text chunks cost the insight stage twice: they take a top-k slot in retrieval (BM25 indexes the heading even when the body is empty), and then trigger an LLM call against blank chunk text. Tests show the CDC MMWR borderless table case where this happened on a chunk that ranked #1 by retrieval score — a wasted call every time. 
A new helper, `_drop_or_repair_empty_chunks`, runs in `ExtractionPipeline.extract_one` between `normalize_chunks` and chunk renumbering. Two paths: - Table chunks with empty text but populated `table_data`: render the rows to a tab-separated text block so BM25 and the LLM can see the cell contents. `table_data` itself is preserved unchanged for any consumer that wants the structured form. - Other empty chunks: drop with a DEBUG log carrying chunk_id, type and heading. An empty prose chunk almost always indicates a half- broken upstream parser section (heading without body, footer artefact); nothing the insight stage can act on. Tests confirm the existing MMWR table chunk (previously empty-text) now carries 277 characters of rendered table content while the underlying rows stay accessible via `table_data`. Empty prose chunks disappear before reaching downstream stages. Co-Authored-By: Claude Opus 4.7 --- bioscancast/extraction/pipeline.py | 72 ++++++++++++ bioscancast/tests/test_extraction_pipeline.py | 111 ++++++++++++++++++ 2 files changed, 183 insertions(+) diff --git a/bioscancast/extraction/pipeline.py b/bioscancast/extraction/pipeline.py index 47dd1f4..888e347 100644 --- a/bioscancast/extraction/pipeline.py +++ b/bioscancast/extraction/pipeline.py @@ -146,6 +146,11 @@ def extract_one(self, filtered_doc: FilteredDocument) -> Document: max_tokens=self._config.chunk_max_tokens, ) + # Step 5b: Drop or repair empty chunks before any downstream + # consumer sees them. Empty chunks make retrieval waste budget + # (BM25 still indexes the heading) and confuse the insight stage. + chunks = _drop_or_repair_empty_chunks(chunks) + # Renumber chunk indices after normalization for i, chunk in enumerate(chunks): chunk.chunk_index = i @@ -290,3 +295,70 @@ def _extract_dates(self, text: str) -> List[str]: seen.add(d) unique.append(d) return unique + + +def _render_table_data_as_text(rows: List[List[str]]) -> str: + """Render row-major table data as a flat text block. 
+ + Cells are joined with tabs within a row and rows with newlines. + Empty cells are preserved (so column alignment stays implicit) but + fully-empty rows are skipped. This is good enough for BM25 keyword + matching when the underlying PDF parser produced a table whose cells + are present but whose flowed text was empty. + """ + out_rows: List[str] = [] + for row in rows: + cells = [(cell or "").strip() for cell in row] + if not any(cells): + continue + out_rows.append("\t".join(cells)) + return "\n".join(out_rows) + + +def _drop_or_repair_empty_chunks( + chunks: List[DocumentChunk], +) -> List[DocumentChunk]: + """Filter chunks whose ``text`` is blank after stripping whitespace. + + Two paths, in order of preference: + + * If the chunk is a table with non-empty ``table_data`` rows, render + those rows into the ``text`` field so downstream retrieval and + LLM extraction can see the cell contents. The structured + ``table_data`` is preserved unchanged for any consumer that wants + cell-level access. + * Otherwise drop the chunk and log at DEBUG. An empty prose chunk + usually indicates a half-broken section in the upstream parser + (heading without body, footer artefact, decorative caption), not + something the insight stage can act on. + + Running this *after* ``normalize_chunks`` means it sees the final + post-split chunk text, not pre-split fragments that the splitter + might have rebalanced. 
+ """ + out: List[DocumentChunk] = [] + for chunk in chunks: + if chunk.text and chunk.text.strip(): + out.append(chunk) + continue + if chunk.chunk_type == "table" and chunk.table_data: + rendered = _render_table_data_as_text(chunk.table_data) + if rendered: + chunk.text = rendered + chunk.token_count = approx_token_count(rendered) + logger.debug( + "Empty-text table chunk repaired from table_data " + "(chunk_id=%s, rows=%d, rendered_chars=%d)", + chunk.chunk_id, + len(chunk.table_data), + len(rendered), + ) + out.append(chunk) + continue + logger.debug( + "Dropping empty chunk (chunk_id=%s, type=%s, heading=%r)", + chunk.chunk_id, + chunk.chunk_type, + (chunk.heading or "")[:60], + ) + return out diff --git a/bioscancast/tests/test_extraction_pipeline.py b/bioscancast/tests/test_extraction_pipeline.py index 4aba2d3..c0ee6fb 100644 --- a/bioscancast/tests/test_extraction_pipeline.py +++ b/bioscancast/tests/test_extraction_pipeline.py @@ -317,3 +317,114 @@ def test_extract_one(self): assert isinstance(doc, Document) assert doc.status == "success" + + +# --------------------------------------------------------------------------- +# Empty chunk handling — _drop_or_repair_empty_chunks +# --------------------------------------------------------------------------- + +class TestEmptyChunkHandling: + """The extraction pipeline must drop empty prose chunks and render + table chunks whose text is empty (but whose table_data is populated) + into a flat text representation. 
See pipeline._drop_or_repair_empty_chunks.""" + + def test_empty_prose_chunk_is_dropped(self): + from bioscancast.extraction.pipeline import _drop_or_repair_empty_chunks + from bioscancast.schemas.document import DocumentChunk + + chunks = [ + DocumentChunk( + chunk_id="c0", chunk_index=0, text="Real content.", + chunk_type="prose", + ), + DocumentChunk( + chunk_id="c1", chunk_index=1, text="", + chunk_type="prose", + ), + DocumentChunk( + chunk_id="c2", chunk_index=2, text=" \n\t ", + chunk_type="prose", + ), + DocumentChunk( + chunk_id="c3", chunk_index=3, text="More content.", + chunk_type="prose", + ), + ] + out = _drop_or_repair_empty_chunks(chunks) + assert [c.chunk_id for c in out] == ["c0", "c3"] + + def test_empty_text_table_chunk_with_table_data_is_repaired(self): + from bioscancast.extraction.pipeline import _drop_or_repair_empty_chunks + from bioscancast.schemas.document import DocumentChunk + + table_data = [ + ["Country", "Cases", "Deaths"], + ["Uganda", "9", "3"], + ["DRC", "6543", "148"], + ] + chunks = [ + DocumentChunk( + chunk_id="t0", chunk_index=0, text="", + chunk_type="table", table_data=table_data, + ), + ] + out = _drop_or_repair_empty_chunks(chunks) + assert len(out) == 1 + assert "Uganda" in out[0].text + assert "6543" in out[0].text + assert "Deaths" in out[0].text + # Original structured data is preserved unchanged + assert out[0].table_data == table_data + # Token count is recomputed for the rendered text + assert out[0].token_count is not None + assert out[0].token_count > 0 + + def test_empty_table_chunk_without_table_data_is_dropped(self): + from bioscancast.extraction.pipeline import _drop_or_repair_empty_chunks + from bioscancast.schemas.document import DocumentChunk + + chunks = [ + DocumentChunk( + chunk_id="t0", chunk_index=0, text="", + chunk_type="table", table_data=None, + ), + DocumentChunk( + chunk_id="t1", chunk_index=1, text="", + chunk_type="table", table_data=[], + ), + ] + out = _drop_or_repair_empty_chunks(chunks) + 
assert out == [] + + def test_empty_table_chunk_with_only_empty_rows_is_dropped(self): + from bioscancast.extraction.pipeline import _drop_or_repair_empty_chunks + from bioscancast.schemas.document import DocumentChunk + + chunks = [ + DocumentChunk( + chunk_id="t0", chunk_index=0, text="", + chunk_type="table", + table_data=[["", "", ""], ["", ""]], + ), + ] + out = _drop_or_repair_empty_chunks(chunks) + assert out == [] + + def test_table_chunk_with_text_passes_through_unchanged(self): + """A table chunk that already has text is left alone.""" + from bioscancast.extraction.pipeline import _drop_or_repair_empty_chunks + from bioscancast.schemas.document import DocumentChunk + + chunks = [ + DocumentChunk( + chunk_id="t0", chunk_index=0, + text="Existing flowed text", + chunk_type="table", + table_data=[["a", "b"]], + token_count=4, + ), + ] + out = _drop_or_repair_empty_chunks(chunks) + assert len(out) == 1 + assert out[0].text == "Existing flowed text" # unchanged + assert out[0].token_count == 4 # unchanged From a47be32255d641c03aab7300669d74b43306e707 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Tue, 26 May 2026 13:21:22 +0200 Subject: [PATCH 08/21] Dedup with partial event-date precision MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous single-stage dedup key formatted every event_date as YYYY-MM-DD, which prevented merging two reports of the same event when one source gave a day and another gave only the month — tests on the 6 real biosecurity documents showed WHO cholera repeatedly producing two records for the DRC January-2026 6543-cases fact, one from a prose sentence ("In January 2026 ... reported 6543 new cholera cases") and one from a table cell ("Democratic Republic of the Congo 6 543") with subtly different parsed dates. 
Changes: - New `event_date_precision` field on `InsightRecord` carries the granularity ("year"|"month"|"day") alongside the canonical `event_date` datetime (which is now the start of the period when only a partial date is known). - `_parse_event_date` accepts YYYY, YYYY-MM, YYYY-MM-DD, and the existing free-form day-precision variants, returning a (datetime, precision) tuple. - Extraction prompt instructs the model to use the most specific ISO date the chunk supports and NOT to invent a day-of-month when only a month is given. - Two-stage `_deduplicate_records`: first group by (event_type, metric_name, normalized_location), then within each group walk records in order and merge each into the first surviving entry whose date bucket overlaps at the coarser precision. The surviving record adopts the finer precision so downstream consumers don't lose information. - Confidence is taken as the max across merged records; provenance references from every merged source are preserved. Tests cover the matrix of precision combinations (day/month overlap, year/day overlap, equal months, different days, undated vs dated, three-way mixed-precision merge, separate locations). Live runs show records merging correctly when the model picks the same metric_name across sources; cases where the model uses subtly different metric names ("new cases" vs "Cases" vs "cholera cases") still stay separate — that's a metric-name-normalisation problem separate from this change. 
Co-Authored-By: Claude Opus 4.7 --- .../insight/extraction/chunk_extractor.py | 50 +++- bioscancast/insight/extraction/prompts.py | 7 +- bioscancast/insight/pipeline.py | 151 +++++++++--- bioscancast/schemas/insight_record.py | 14 +- bioscancast/tests/test_insight_pipeline.py | 215 ++++++++++++++++++ 5 files changed, 399 insertions(+), 38 deletions(-) diff --git a/bioscancast/insight/extraction/chunk_extractor.py b/bioscancast/insight/extraction/chunk_extractor.py index b7a1633..aa5db48 100644 --- a/bioscancast/insight/extraction/chunk_extractor.py +++ b/bioscancast/insight/extraction/chunk_extractor.py @@ -220,16 +220,49 @@ def _resolve_country_code(location: Optional[str]) -> Optional[str]: return None -def _parse_event_date(date_str: Optional[str]) -> Optional[datetime]: - """Try to parse a date string from the LLM output.""" +def _parse_event_date( + date_str: Optional[str], +) -> tuple[Optional[datetime], Optional[str]]: + """Parse a date string from the LLM output into a (datetime, precision) + pair. + + Accepts (in order of preference): + + * ``YYYY-MM-DD`` and similar full ISO forms → precision="day" + * ``YYYY-MM`` → precision="month"; datetime is the start of that month + * ``YYYY`` → precision="year"; datetime is January 1 of that year + * Free-form ``"15 January 2026"``, ``"January 15, 2026"`` → "day" + + Returns ``(None, None)`` when no format matches. Storing precision + alongside the canonicalised datetime lets the dedup logic merge a + month-precision record like 2026-01 with a day-precision record + inside that month without throwing away the underlying granularity. 
+ """ if not date_str: - return None - for fmt in ("%Y-%m-%d", "%d %B %Y", "%B %d, %Y", "%Y-%m-%dT%H:%M:%S"): + return None, None + + cleaned = date_str.strip() + + # Day-precision attempts (in order) + for fmt in ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%d %B %Y", "%B %d, %Y"): try: - return datetime.strptime(date_str, fmt) + return datetime.strptime(cleaned, fmt), "day" except ValueError: continue - return None + + # Month-precision attempt + try: + return datetime.strptime(cleaned, "%Y-%m"), "month" + except ValueError: + pass + + # Year-precision attempt + try: + return datetime.strptime(cleaned, "%Y"), "year" + except ValueError: + pass + + return None, None def extract_facts_from_chunk( @@ -290,7 +323,9 @@ def extract_facts_from_chunk( location = fact.get("location") iso_code = _resolve_country_code(location) - event_date = _parse_event_date(fact.get("event_date")) + event_date, event_date_precision = _parse_event_date( + fact.get("event_date") + ) record = InsightRecord( id=f"ins-{uuid.uuid4().hex[:12]}", @@ -308,6 +343,7 @@ def extract_facts_from_chunk( ), metric_unit=fact.get("metric_unit"), event_date=event_date, + event_date_precision=event_date_precision, summary=fact.get("summary"), model=model, extracted_at=datetime.now(timezone.utc), diff --git a/bioscancast/insight/extraction/prompts.py b/bioscancast/insight/extraction/prompts.py index b3bdfd8..537d392 100644 --- a/bioscancast/insight/extraction/prompts.py +++ b/bioscancast/insight/extraction/prompts.py @@ -32,7 +32,12 @@ This is expected and common — most chunks are irrelevant. 4. Do NOT answer the forecast question. Your job is fact extraction, \ not forecasting. -5. Be aware of cognitive biases that affect information processing: +5. For event_date, use the most specific ISO date you can extract from \ +the chunk and nothing more: ``YYYY-MM-DD`` when a day is given, \ +``YYYY-MM`` when only a month is given (e.g. "January 2026"), or \ +``YYYY`` when only a year is given. 
Do NOT invent a day-of-month when \ +the chunk only mentions a month. +6. Be aware of cognitive biases that affect information processing: - Anchoring: do not over-weight the first number you encounter. - Availability: rare dramatic events are not necessarily more likely. - Overconfidence: if the chunk is ambiguous, lower your confidence. diff --git a/bioscancast/insight/pipeline.py b/bioscancast/insight/pipeline.py index 4557672..b8ebb80 100644 --- a/bioscancast/insight/pipeline.py +++ b/bioscancast/insight/pipeline.py @@ -148,46 +148,139 @@ def _normalize_location(location: Optional[str]) -> str: return location.lower().strip() +# Ordered from coarsest to finest. Used by the dedup logic to compare +# two records whose event_date_precision values differ. +_PRECISION_ORDER = {"year": 0, "month": 1, "day": 2} + + +def _date_bucket( + dt: Optional[datetime], precision: Optional[str] +) -> Optional[tuple]: + """Truncate a (datetime, precision) pair to a comparable tuple bucket. + + Returns ``None`` when no date is known. + """ + if dt is None or precision is None: + return None + if precision == "year": + return (dt.year,) + if precision == "month": + return (dt.year, dt.month) + # day (or anything more specific) is collapsed to day + return (dt.year, dt.month, dt.day) + + +def _dates_overlap( + d1: Optional[datetime], + p1: Optional[str], + d2: Optional[datetime], + p2: Optional[str], +) -> bool: + """Check whether two (datetime, precision) pairs refer to + overlapping time buckets. + + Both-None counts as overlap (the "no date known" bucket). One-None + does NOT overlap with a known-date bucket — we don't merge dated + and undated facts. + + Two known dates overlap when, truncated to whichever precision is + coarser, their buckets are equal. So month-precision ``2026-01`` + overlaps with day-precision ``2026-01-25`` (both → ``(2026, 1)`` + at month precision) but not with ``2026-02-25``. 
+ """ + if d1 is None and d2 is None: + return True + if d1 is None or d2 is None: + return False + if p1 is None or p2 is None: + return False + coarser = p1 if _PRECISION_ORDER[p1] <= _PRECISION_ORDER[p2] else p2 + return _date_bucket(d1, coarser) == _date_bucket(d2, coarser) + + def _record_dedup_key(record: InsightRecord) -> tuple: - """Build a deduplication key for an InsightRecord. + """First-stage dedup key: groups records that *might* be duplicates. - Two records are duplicates if they have the same event_type, - metric_name, date, and normalized location. + Date is intentionally omitted from the first-stage key because two + records with different date precisions (e.g. ``2026-01`` vs + ``2026-01-25``) need to be considered together. The second stage + walks each group and uses ``_dates_overlap`` to decide whether to + merge. """ - date_str = "" - if record.event_date: - date_str = record.event_date.strftime("%Y-%m-%d") return ( record.event_type, record.metric_name or "", - date_str, _normalize_location(record.location), ) -def _deduplicate_records(records: list[InsightRecord]) -> list[InsightRecord]: - """Deduplicate InsightRecords, merging provenance lists. +def _merge_record_into( + target: InsightRecord, source: InsightRecord +) -> None: + """Merge ``source`` into ``target`` in place. - Keeps the record with the higher confidence score and merges - source references from duplicates. + Adds source's unique chunk references, raises confidence to the max + of the two, and adopts the finer of the two date precisions (with + its corresponding date). The coarser-precision source loses its + date but its provenance is preserved. 
+ """ + existing_chunk_ids = { + (s.document_id, s.chunk_id) for s in target.sources + } + for src in source.sources: + if (src.document_id, src.chunk_id) not in existing_chunk_ids: + target.sources.append(src) + if source.confidence > target.confidence: + target.confidence = source.confidence + # Adopt the finer precision date if source has one + source_rank = ( + _PRECISION_ORDER.get(source.event_date_precision, -1) + if source.event_date_precision else -1 + ) + target_rank = ( + _PRECISION_ORDER.get(target.event_date_precision, -1) + if target.event_date_precision else -1 + ) + if source.event_date and source_rank > target_rank: + target.event_date = source.event_date + target.event_date_precision = source.event_date_precision + + +def _deduplicate_records(records: list[InsightRecord]) -> list[InsightRecord]: + """Two-stage deduplication merging records with overlapping date buckets. + + Stage 1: group by ``(event_type, metric_name, normalized_location)``. + Stage 2: within each group, walk records in order and merge each + into the first surviving entry whose date bucket overlaps. This + handles the common case where multiple sources report the same + event at different date precisions (e.g. WHO sitrep says + "January 2026" while a country report says "as of 25 January + 2026") — both refer to the same underlying fact. + + Records with completely distinct dates within the same group stay + separate (no false merging of "Jan 5" with "Jan 6"). 
""" - seen: dict[tuple, InsightRecord] = {} + from collections import defaultdict + groups: dict[tuple, list[InsightRecord]] = defaultdict(list) for record in records: - key = _record_dedup_key(record) - if key in seen: - existing = seen[key] - # Merge provenance - existing_chunk_ids = { - (s.document_id, s.chunk_id) for s in existing.sources - } - for src in record.sources: - if (src.document_id, src.chunk_id) not in existing_chunk_ids: - existing.sources.append(src) - # Keep higher confidence - if record.confidence > existing.confidence: - existing.confidence = record.confidence - else: - seen[key] = record - - return list(seen.values()) + groups[_record_dedup_key(record)].append(record) + + out: list[InsightRecord] = [] + for group in groups.values(): + merged: list[InsightRecord] = [] + for record in group: + target = None + for existing in merged: + if _dates_overlap( + record.event_date, record.event_date_precision, + existing.event_date, existing.event_date_precision, + ): + target = existing + break + if target is not None: + _merge_record_into(target, record) + else: + merged.append(record) + out.extend(merged) + return out diff --git a/bioscancast/schemas/insight_record.py b/bioscancast/schemas/insight_record.py index 05d68b3..be47884 100644 --- a/bioscancast/schemas/insight_record.py +++ b/bioscancast/schemas/insight_record.py @@ -71,7 +71,19 @@ class InsightRecord: """Unit of the metric (e.g. 'cases', 'herds', 'deaths').""" event_date: Optional[datetime] = None - """Date the fact pertains to (not the date it was reported).""" + """Date the fact pertains to (not the date it was reported). + + Canonicalised to the start of the period when only a partial date is + known (e.g. ``"2026-01"`` → ``datetime(2026, 1, 1)``). Read together + with ``event_date_precision`` to recover the original granularity. 
+ """ + + event_date_precision: Optional[str] = None + """Granularity of ``event_date``: ``"year"`` | ``"month"`` | ``"day"``, + or ``None`` when no date was extracted. The dedup logic in the insight + pipeline merges two records whose date buckets overlap at the coarser + precision (e.g. a record with month precision 2026-01 merges with a + day-precision record dated 2026-01-25).""" # ---- free-text fallback ---- summary: Optional[str] = None diff --git a/bioscancast/tests/test_insight_pipeline.py b/bioscancast/tests/test_insight_pipeline.py index d34bb7b..9cdae02 100644 --- a/bioscancast/tests/test_insight_pipeline.py +++ b/bioscancast/tests/test_insight_pipeline.py @@ -189,6 +189,221 @@ def test_pipeline_deduplication(): assert len(case_records[0].sources) >= 2 +# --------------------------------------------------------------------------- +# Date-precision-aware dedup +# --------------------------------------------------------------------------- + + +def _record( + record_id: str, + *, + event_date: "datetime | None" = None, + event_date_precision: "str | None" = None, + confidence: float = 0.8, + location: "str | None" = "United States", + metric_name: "str | None" = "confirmed_cases", + event_type: str = "case_count", + document_id: str = "doc-test", +): + """Build an InsightRecord with a single ChunkReference for dedup tests.""" + from bioscancast.schemas import InsightRecord, ChunkReference + return InsightRecord( + id=record_id, + question_id="q-test", + event_type=event_type, + confidence=confidence, + location=location, + metric_name=metric_name, + metric_value=1.0, + event_date=event_date, + event_date_precision=event_date_precision, + sources=[ChunkReference( + document_id=document_id, + chunk_id=f"chunk-{record_id}", + source_url=f"https://example.com/{document_id}", + quote="some quote", + )], + ) + + +def test_dedup_merges_day_precision_with_month_precision_in_same_month(): + """A day-precision fact (2026-01-25) should merge with a month-precision + 
fact (2026-01) and the merged record should keep the finer precision.""" + from datetime import datetime + from bioscancast.insight.pipeline import _deduplicate_records + + day_record = _record( + "day", event_date=datetime(2026, 1, 25), event_date_precision="day", + ) + month_record = _record( + "month", event_date=datetime(2026, 1, 1), event_date_precision="month", + document_id="doc-other", + ) + result = _deduplicate_records([day_record, month_record]) + assert len(result) == 1 + merged = result[0] + assert merged.event_date_precision == "day" + assert merged.event_date == datetime(2026, 1, 25) + assert len(merged.sources) == 2 + + +def test_dedup_keeps_month_precision_when_order_reversed(): + """Order of presentation must not change the outcome: month-then-day + must also merge to the day-precision form.""" + from datetime import datetime + from bioscancast.insight.pipeline import _deduplicate_records + + month_record = _record( + "month", event_date=datetime(2026, 1, 1), event_date_precision="month", + ) + day_record = _record( + "day", event_date=datetime(2026, 1, 25), event_date_precision="day", + document_id="doc-other", + ) + result = _deduplicate_records([month_record, day_record]) + assert len(result) == 1 + merged = result[0] + assert merged.event_date_precision == "day" + assert merged.event_date == datetime(2026, 1, 25) + + +def test_dedup_does_not_merge_different_months_at_month_precision(): + """Two month-precision records in different months must stay separate.""" + from datetime import datetime + from bioscancast.insight.pipeline import _deduplicate_records + + jan = _record( + "jan", event_date=datetime(2026, 1, 1), event_date_precision="month", + ) + feb = _record( + "feb", event_date=datetime(2026, 2, 1), event_date_precision="month", + document_id="doc-other", + ) + result = _deduplicate_records([jan, feb]) + assert len(result) == 2 + + +def test_dedup_does_not_merge_different_days(): + """Two day-precision records on different days must stay 
separate even + when they share the same month.""" + from datetime import datetime + from bioscancast.insight.pipeline import _deduplicate_records + + day5 = _record( + "day5", event_date=datetime(2026, 1, 5), event_date_precision="day", + ) + day6 = _record( + "day6", event_date=datetime(2026, 1, 6), event_date_precision="day", + document_id="doc-other", + ) + result = _deduplicate_records([day5, day6]) + assert len(result) == 2 + + +def test_dedup_merges_year_precision_with_day_precision_in_same_year(): + """A year-only fact (2026) should merge with a day-precision fact + inside that year.""" + from datetime import datetime + from bioscancast.insight.pipeline import _deduplicate_records + + year = _record( + "year", event_date=datetime(2026, 1, 1), event_date_precision="year", + ) + day = _record( + "day", event_date=datetime(2026, 3, 15), event_date_precision="day", + document_id="doc-other", + ) + result = _deduplicate_records([year, day]) + assert len(result) == 1 + assert result[0].event_date_precision == "day" + assert result[0].event_date == datetime(2026, 3, 15) + + +def test_dedup_does_not_merge_different_years(): + """Year-only facts in different years must stay separate.""" + from datetime import datetime + from bioscancast.insight.pipeline import _deduplicate_records + + y2025 = _record( + "2025", event_date=datetime(2025, 1, 1), event_date_precision="year", + ) + y2026 = _record( + "2026", event_date=datetime(2026, 1, 1), event_date_precision="year", + document_id="doc-other", + ) + result = _deduplicate_records([y2025, y2026]) + assert len(result) == 2 + + +def test_dedup_three_way_merge_with_mixed_precisions(): + """Three records, one each at year/month/day precision, all in the + same time range, should collapse to one with day precision.""" + from datetime import datetime + from bioscancast.insight.pipeline import _deduplicate_records + + year = _record( + "year", event_date=datetime(2026, 1, 1), event_date_precision="year", + document_id="doc-a", 
+ ) + month = _record( + "month", event_date=datetime(2026, 1, 1), event_date_precision="month", + document_id="doc-b", + ) + day = _record( + "day", event_date=datetime(2026, 1, 25), event_date_precision="day", + document_id="doc-c", + ) + result = _deduplicate_records([year, month, day]) + assert len(result) == 1 + merged = result[0] + assert merged.event_date_precision == "day" + assert merged.event_date == datetime(2026, 1, 25) + assert len(merged.sources) == 3 + + +def test_dedup_does_not_merge_dated_with_undated(): + """A record with no date and a record with a date should stay + separate — we don't claim they're about the same event.""" + from datetime import datetime + from bioscancast.insight.pipeline import _deduplicate_records + + undated = _record("undated") + dated = _record( + "dated", event_date=datetime(2026, 1, 25), event_date_precision="day", + document_id="doc-other", + ) + result = _deduplicate_records([undated, dated]) + assert len(result) == 2 + + +def test_dedup_merges_undated_records(): + """Two records with no date in the same group still merge.""" + from bioscancast.insight.pipeline import _deduplicate_records + + a = _record("a") + b = _record("b", document_id="doc-other") + result = _deduplicate_records([a, b]) + assert len(result) == 1 + assert len(result[0].sources) == 2 + + +def test_dedup_does_not_merge_across_different_locations(): + """Different normalized locations stay in separate groups.""" + from datetime import datetime + from bioscancast.insight.pipeline import _deduplicate_records + + us = _record( + "us", location="United States", + event_date=datetime(2026, 1, 25), event_date_precision="day", + ) + uk = _record( + "uk", location="United Kingdom", document_id="doc-other", + event_date=datetime(2026, 1, 25), event_date_precision="day", + ) + result = _deduplicate_records([us, uk]) + assert len(result) == 2 + + # --------------------------------------------------------------------------- # Multi-document end-to-end # 
--------------------------------------------------------------------------- From ecdb7269a473dbb27d5c8eefc15dda75993843bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Tue, 26 May 2026 13:26:30 +0200 Subject: [PATCH 09/21] Replace 30-country hardcoded map with pycountry-backed resolver MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous COUNTRY_TO_ISO dict hand-listed ~30 countries — primarily because nobody wanted to take on a dependency at the time. Live tests show the cost: most extracted records on the 6 real biosecurity docs had iso_country_code=None despite clear country names like "Austria", "Bulgaria", "Comoros", and "Madagascar" that the map simply didn't cover. With ~250 ISO 3166-1 entries, hand-maintenance was untenable. pycountry covers all 249 ISO 3166-1 entries by canonical, common, and official name plus alpha-2/alpha-3 codes. The new resolver in `chunk_extractor.py` layers four steps: 1. Typography fold (smart quotes → ASCII) so "Côte d'Ivoire" with either apostrophe variant resolves. 2. Explicit not-a-country set ("Africa", "European Region", "EU/EEA", etc.) — these are multi-country roll-ups, not single ISO entries, and pycountry's `search_fuzzy` produces surprising false positives here. Returns None deliberately. 3. Alias dict for forms pycountry won't match directly ("UK", "DRC", "Russia", "Burma", "Ivory Coast", "North Korea", and UK constituents like "England" / "Scotland" → "GB"). 4. US subnational set — all 50 states plus DC and US territories → "US". Common in biosecurity reporting (CDC MMWR routinely phrases location as "Lea County, New Mexico"). `pycountry.countries.search_fuzzy` is deliberately not used — it would resolve "Eastern Mediterranean Region" to a single country (false positive). Strict matching only. Compound locations like "Mubende district, Uganda" still work via the existing right-to-left comma-segment fallback. 
Tests cover: bare country lookups across all 6 real docs' locations, the alias set, US states, multi-country region rejection, compound locations, alpha-2/alpha-3 codes, smart-quote variants. pycountry pinned to >=24.0 in requirements.txt. Co-Authored-By: Claude Opus 4.7 --- .../insight/extraction/chunk_extractor.py | 180 +++++++++++++----- .../tests/test_insight_chunk_extractor.py | 99 ++++++++++ requirements.txt | 1 + 3 files changed, 232 insertions(+), 48 deletions(-) diff --git a/bioscancast/insight/extraction/chunk_extractor.py b/bioscancast/insight/extraction/chunk_extractor.py index aa5db48..dd24d71 100644 --- a/bioscancast/insight/extraction/chunk_extractor.py +++ b/bioscancast/insight/extraction/chunk_extractor.py @@ -15,6 +15,8 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Optional +import pycountry + from bioscancast.schemas import DocumentChunk, Document, ChunkReference, InsightRecord from bioscancast.filtering.models import ForecastQuestion from .prompts import build_extraction_prompt @@ -66,48 +68,66 @@ _WRAPPING_PUNCT_RE = re.compile(r"[\(\)\[\]\{\}\"\']") -# Hardcoded country name -> ISO 3166-1 alpha-2 map for the ~30 most -# likely countries in biosecurity reporting. Don't pull in pycountry. -COUNTRY_TO_ISO: dict[str, str] = { - "united states": "US", - "usa": "US", - "us": "US", - "united kingdom": "GB", +# pycountry covers all 249 ISO 3166-1 entries by common, official, and +# canonical names; the alias dicts below add only the forms pycountry +# doesn't resolve on its own. + +# Country aliases for forms pycountry won't match directly. Keys are +# lowercased and stripped of internal periods, so write them that way +# ("uk", "us", "drc" — NOT "u.k.", "u.s."). The lookup helper does the +# same normalisation on the incoming string. 
+_COUNTRY_ALIASES: dict[str, str] = { "uk": "GB", - "china": "CN", - "india": "IN", - "brazil": "BR", - "uganda": "UG", - "democratic republic of the congo": "CD", + "us": "US", + "usa": "US", "drc": "CD", - "congo": "CG", - "nigeria": "NG", - "south africa": "ZA", - "kenya": "KE", - "ethiopia": "ET", - "tanzania": "TZ", - "egypt": "EG", - "australia": "AU", - "canada": "CA", - "mexico": "MX", - "germany": "DE", - "france": "FR", - "italy": "IT", - "spain": "ES", - "japan": "JP", - "south korea": "KR", - "indonesia": "ID", - "thailand": "TH", - "vietnam": "VN", - "pakistan": "PK", - "bangladesh": "BD", - "saudi arabia": "SA", - "iran": "IR", - "turkey": "TR", + "dr congo": "CD", + "democratic republic of the congo": "CD", + "republic of the congo": "CG", + "uae": "AE", "russia": "RU", - "texas": "US", - "california": "US", - "iowa": "US", + "burma": "MM", + "ivory coast": "CI", + "north korea": "KP", + "south korea": "KR", + "england": "GB", + "scotland": "GB", + "wales": "GB", + "northern ireland": "GB", + "great britain": "GB", +} + +# US states and territories — biosecurity reporting frequently uses +# state-level location text. Anything in this set resolves to "US". +# Matching is case-insensitive; "dc" is the only postal code included.
+_US_SUBNATIONAL: set[str] = { + "alabama", "alaska", "arizona", "arkansas", "california", "colorado", + "connecticut", "delaware", "florida", "georgia", "hawaii", "idaho", + "illinois", "indiana", "iowa", "kansas", "kentucky", "louisiana", + "maine", "maryland", "massachusetts", "michigan", "minnesota", + "mississippi", "missouri", "montana", "nebraska", "nevada", + "new hampshire", "new jersey", "new mexico", "new york", + "north carolina", "north dakota", "ohio", "oklahoma", "oregon", + "pennsylvania", "rhode island", "south carolina", "south dakota", + "tennessee", "texas", "utah", "vermont", "virginia", "washington", + "west virginia", "wisconsin", "wyoming", + "district of columbia", "d.c.", "dc", + "puerto rico", "guam", "american samoa", + "u.s. virgin islands", "northern mariana islands", +} + +# Multi-country region labels that explicitly resolve to None. The model +# emits these often when reading WHO regional roll-ups (e.g. +# "European Region", "African Region") that aren't single countries. +_NOT_A_COUNTRY: set[str] = { + "africa", "asia", "europe", "north america", "south america", + "americas", "oceania", + "european region", "african region", + "region of the americas", "western pacific region", + "south-east asia region", "eastern mediterranean region", + "european union", "eu", "eu/eea", "eea", + "world", "global", "globally", "multi-country", "multi country", + "unknown", "various", "international", } @@ -204,19 +224,83 @@ def _quote_matches(quote: str, chunk_text: str) -> Optional[str]: return None +def _lookup_one(token: str) -> Optional[str]: + """Resolve a single location token to an ISO alpha-2 code. + + Order: explicit not-a-country set, alias dict, US subnational, then + pycountry's built-in lookup (which handles all 249 ISO 3166-1 + entries by common/canonical/official name and alpha-2/alpha-3 codes). + Returns ``None`` when nothing matches. ``search_fuzzy`` is + deliberately not used because it produces surprising false positives. 
+ + The token is normalised against ``_NOT_A_COUNTRY``, ``_COUNTRY_ALIASES``, + and ``_US_SUBNATIONAL`` with internal periods removed (so "U.K." + matches the "uk" alias). pycountry is then called on the typography- + folded version (so "Côte d'Ivoire" with curly apostrophe also + resolves) — its own internal matching is case-insensitive but does + not handle smart quotes. + """ + if not token: + return None + # Typography-fold for pycountry's benefit (smart quotes → straight) + folded = _TYPOGRAPHY_FOLD_RE.sub( + lambda m: _TYPOGRAPHY_FOLD[m.group(0)], token.strip() + ) + # Alias/region lookup key — lowercased, no internal periods, single spaces + key = folded.lower().replace(".", "") + key = re.sub(r"\s+", " ", key).strip() + if not key: + return None + if key in _NOT_A_COUNTRY: + return None + if key in _COUNTRY_ALIASES: + return _COUNTRY_ALIASES[key] + if key in _US_SUBNATIONAL: + return "US" + try: + return pycountry.countries.lookup(folded).alpha_2 + except LookupError: + return None + + def _resolve_country_code(location: Optional[str]) -> Optional[str]: - """Try to resolve a location string to an ISO country code.""" + """Resolve a free-text location to an ISO 3166-1 alpha-2 country code. + + Handles common biosecurity reporting patterns: + + * Bare country names ("Uganda", "Côte d'Ivoire") — via pycountry. + * Common abbreviations ("DRC", "UK", "USA") — via the alias dict. + * US states ("Texas", "New Mexico") → "US". + * Multi-country regional roll-ups ("European Region", "EU/EEA", + "Africa") → ``None`` (these aren't single countries). + * Compound locations like ``"Mubende district, Uganda"`` — tries + the whole string first, then falls back to each comma-separated + segment from right to left. 
+ """ if not location: return None - key = location.lower().strip() - if key in COUNTRY_TO_ISO: - return COUNTRY_TO_ISO[key] - # Try matching the last part (e.g., "Mubende district, Uganda" -> "uganda") - parts = key.split(",") - for part in reversed(parts): + cleaned = location.strip() + if not cleaned: + return None + + # Try the whole string first + result = _lookup_one(cleaned) + if result is not None: + return result + # Defer to the not-a-country whitelist on the full string before + # falling through to per-segment matching. Normalise the same way + # as _lookup_one so "European Region" / "europe" etc. all match. + if re.sub(r"\s+", " ", cleaned.lower().replace(".", "")).strip() in _NOT_A_COUNTRY: + return None + + # Try comma-separated segments right-to-left ("Mubende district, Uganda") + for part in reversed(cleaned.split(",")): part = part.strip() - if part in COUNTRY_TO_ISO: - return COUNTRY_TO_ISO[part] + if not part: + continue + result = _lookup_one(part) + if result is not None: + return result return None diff --git a/bioscancast/tests/test_insight_chunk_extractor.py b/bioscancast/tests/test_insight_chunk_extractor.py index d96a14f..2ba0421 100644 --- a/bioscancast/tests/test_insight_chunk_extractor.py +++ b/bioscancast/tests/test_insight_chunk_extractor.py @@ -161,6 +161,105 @@ def test_resolve_country_code_unknown(): assert _resolve_country_code("") is None +@pytest.mark.parametrize( + "location,expected", + [ + # All 249 ISO entries are covered by pycountry — spot-check the + # ones that show up in the 6 real biosecurity test documents. 
+ ("Comoros", "KM"), + ("Bulgaria", "BG"), + ("Austria", "AT"), + ("Ireland", "IE"), + ("Portugal", "PT"), + ("Sweden", "SE"), + ("Romania", "RO"), + ("Netherlands", "NL"), + ("Madagascar", "MG"), + ("Mozambique", "MZ"), + ("Côte d'Ivoire", "CI"), + ("Côte d’Ivoire", "CI"), # curly apostrophe + # Alpha-2 / alpha-3 codes + ("DE", "DE"), + ("DEU", "DE"), + ], +) +def test_resolve_country_code_pycountry_lookups(location, expected): + assert _resolve_country_code(location) == expected + + +@pytest.mark.parametrize( + "location,expected", + [ + # Aliases that pycountry doesn't catch directly + ("UK", "GB"), + ("U.K.", "GB"), + ("England", "GB"), + ("Scotland", "GB"), + ("DRC", "CD"), + ("DR Congo", "CD"), + ("Democratic Republic of the Congo", "CD"), + ("Republic of the Congo", "CG"), + # Bare "Congo" intentionally resolves to CG (Republic of) via + # pycountry's canonical name — biosecurity reporting that means + # DRC should use one of the DRC aliases. + ("Congo", "CG"), + ("UAE", "AE"), + ("Russia", "RU"), + ("Burma", "MM"), + ("Myanmar", "MM"), + ("Ivory Coast", "CI"), + ("North Korea", "KP"), + ("South Korea", "KR"), + ], +) +def test_resolve_country_code_aliases(location, expected): + assert _resolve_country_code(location) == expected + + +@pytest.mark.parametrize( + "state", + [ + "California", "Iowa", "New Mexico", "Utah", "Texas", + "New York", "Florida", "Washington", + ], +) +def test_resolve_country_code_us_states_resolve_to_us(state): + assert _resolve_country_code(state) == "US" + + +@pytest.mark.parametrize( + "region", + [ + "Africa", "Asia", "Europe", "Americas", "Oceania", + "European Region", "African Region", + "Region of the Americas", "Western Pacific Region", + "South-East Asia Region", "Eastern Mediterranean Region", + "EU/EEA", "EU", "EEA", + "World", "global", + ], +) +def test_resolve_country_code_multi_country_regions_return_none(region): + assert _resolve_country_code(region) is None + + +@pytest.mark.parametrize( + "location,expected", + [ 
+ # Compound location strings — the resolver should try the last + # segment after a comma. + ("Mubende district, Uganda", "UG"), + ("Lea County, New Mexico", "US"), + ("Ngazidja region, Comoros", "KM"), + ("Lyon, France", "FR"), + # When the last segment is a region label, the resolver should + # return None rather than guessing. + ("Some city, European Region", None), + ], +) +def test_resolve_country_code_compound_locations(location, expected): + assert _resolve_country_code(location) == expected + + # --------------------------------------------------------------------------- # Token tracking # --------------------------------------------------------------------------- diff --git a/requirements.txt b/requirements.txt index 778005b..86323fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ pdfplumber>=0.11,<1.0 # Fallback PDF table extraction for cases PyMuPDF mishandl docling[chunking]>=2.90,<3.0 # TableFormer-based refinement for borderless / merged-cell PDF tables (first-run downloads ~40 MB to the HuggingFace cache) tiktoken>=0.7,<1.0 # Approximate token counting (cl100k_base encoding) openai>=1.0,<2.0 # OpenAI API client (used by filtering stage LLM calls) +pycountry>=24.0 # ISO 3166-1 country name → alpha-2 lookup (insight stage location resolution) pytest>=8.0,<9.0 # Testing From d3205a4b2de6d825502a73dad58b800b87a4ed4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Tue, 26 May 2026 13:28:35 +0200 Subject: [PATCH 10/21] Reject filename-shaped PDF /Title metadata MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PyMuPDF surfaces whatever the PDF's /Title metadata field says, with no filtering. Real biosecurity PDFs often have stale conversion titles leaked from the source Word document — tests show ECDC's CDTR returning "2026-WCP-0020 Draft.docx" as its title, which then displayed verbatim through Document.title and into any downstream consumer. 
A new `_sanitize_title` helper on PdfParser drops: - Titles ending in a document-format extension (.pdf, .docx, .doc, .odt, .rtf, .txt, .pages, .xlsx, .pptx, .html, case-insensitive). - Empty / whitespace-only titles. - Implausibly short titles (< 5 chars). When the sanitiser returns None, the existing `parsed.title or filtered_doc.title` fallback chain in `extraction/pipeline.py` picks up the search-side title instead — which is the desired behaviour. Tests confirm the ECDC stale title is dropped, every other filename extension variant is rejected, and real document titles (WHO sitreps, MMWR articles, ECDC CDTR) pass through unchanged. Co-Authored-By: Claude Opus 4.7 --- bioscancast/extraction/parsers/pdf_parser.py | 46 +++++++++++++++- bioscancast/tests/test_extraction_pdf.py | 57 ++++++++++++++++++++ 2 files changed, 102 insertions(+), 1 deletion(-) diff --git a/bioscancast/extraction/parsers/pdf_parser.py b/bioscancast/extraction/parsers/pdf_parser.py index 180b66c..2137a04 100644 --- a/bioscancast/extraction/parsers/pdf_parser.py +++ b/bioscancast/extraction/parsers/pdf_parser.py @@ -2,6 +2,7 @@ import io import logging +import re import statistics from datetime import datetime from typing import List, Optional @@ -15,6 +16,18 @@ _DEFAULT_MAX_PAGES = 100 +# A PDF /Title metadata value is "filename-shaped" when it ends in a +# document-format extension or is implausibly short. Many PDFs from +# the WHO / ECDC / CDC pipelines have stale conversion titles like +# "2026-WCP-0020 Draft.docx" — those should be rejected so downstream +# falls back to the search-side title rather than displaying a +# meaningless filename. 
+_FILENAME_EXT_RE = re.compile( + r"\.(pdf|docx?|odt|rtf|txt|pages|xlsx?|pptx?|html?)$", + re.IGNORECASE, +) +_MIN_TITLE_LENGTH = 5 + class PdfParser: """Extracts structured content from PDF documents using PyMuPDF.""" @@ -37,7 +50,7 @@ def parse(self, content: bytes, *, source_url: str) -> ParsedContent: # Extract metadata meta = doc.metadata or {} - title = meta.get("title") or None + title = self._sanitize_title(meta.get("title")) pub_date = self._parse_pdf_date(meta.get("creationDate")) all_text_parts: List[str] = [] @@ -234,6 +247,37 @@ def _extract_tables_pdfplumber( logger.debug("pdfplumber fallback failed on page %d: %s", page_index, exc) return [] + def _sanitize_title(self, title: Optional[str]) -> Optional[str]: + """Filter out PDF /Title metadata values that aren't real titles. + + Returns the trimmed title when it looks like real content, or + ``None`` to let the upstream caller fall back to the search-side + title. Rejects: + + * Empty / whitespace-only values. + * Strings ending in a document-format extension (``.pdf``, + ``.docx``, ``.html``, etc.) — typically a stale Word→PDF + conversion title or a verbatim filename. + * Implausibly short titles (fewer than ``_MIN_TITLE_LENGTH`` + characters) — usually internal doc IDs or leftover placeholders. 
+ """ + if not title: + return None + stripped = title.strip() + if not stripped: + return None + if _FILENAME_EXT_RE.search(stripped): + logger.debug( + "Dropping filename-shaped PDF title: %r", stripped[:80] + ) + return None + if len(stripped) < _MIN_TITLE_LENGTH: + logger.debug( + "Dropping implausibly short PDF title: %r", stripped + ) + return None + return stripped + def _parse_pdf_date(self, date_str: Optional[str]) -> Optional[datetime]: """Parse PDF date strings like 'D:20240115120000+00'00''.""" if not date_str: diff --git a/bioscancast/tests/test_extraction_pdf.py b/bioscancast/tests/test_extraction_pdf.py index 357f2ed..e58b77f 100644 --- a/bioscancast/tests/test_extraction_pdf.py +++ b/bioscancast/tests/test_extraction_pdf.py @@ -95,3 +95,60 @@ def test_page_cap_triggers_partial(self, who_pdf): assert result.is_partial assert result.partial_reason is not None assert "Truncated" in result.partial_reason + + +# --------------------------------------------------------------------------- +# Title sanitisation +# --------------------------------------------------------------------------- + +class TestSanitizeTitle: + """The PDF /Title metadata field is often a stale conversion artefact + (e.g. "2026-WCP-0020 Draft.docx") rather than a real document title. + _sanitize_title filters those out so the upstream pipeline can fall + back to the search-side title. 
+ """ + + def test_filename_ending_in_docx_is_rejected(self, pdf_parser): + # ECDC CDTR case observed in live tests + assert pdf_parser._sanitize_title("2026-WCP-0020 Draft.docx") is None + + @pytest.mark.parametrize( + "title", + [ + "report.pdf", + "draft.doc", + "file.docx", + "presentation.pptx", + "data.xlsx", + "page.html", + "notes.txt", + "Document.PDF", # case-insensitive + ], + ) + def test_other_filename_extensions_are_rejected(self, pdf_parser, title): + assert pdf_parser._sanitize_title(title) is None + + def test_empty_or_whitespace_title_is_rejected(self, pdf_parser): + assert pdf_parser._sanitize_title(None) is None + assert pdf_parser._sanitize_title("") is None + assert pdf_parser._sanitize_title(" ") is None + assert pdf_parser._sanitize_title("\n\t") is None + + def test_implausibly_short_title_is_rejected(self, pdf_parser): + assert pdf_parser._sanitize_title("ab") is None + assert pdf_parser._sanitize_title("xyz") is None + + @pytest.mark.parametrize( + "title", + [ + "Measles Outbreak — New Mexico, 2025", + "Multi-country outbreak of mpox, External situation report 64", + "Communicable Disease Threats Report, Week 16, 2026", + "Weekly Cholera Update No. 34", + ], + ) + def test_real_titles_pass_through(self, pdf_parser, title): + assert pdf_parser._sanitize_title(title) == title + + def test_title_is_trimmed(self, pdf_parser): + assert pdf_parser._sanitize_title(" Real Title ") == "Real Title" From 49e374dfb2fb9a80300bfcd1552a3b943970799e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Tue, 26 May 2026 13:32:20 +0200 Subject: [PATCH 11/21] Migrate filtering stage to the shared LLMClient protocol MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The filtering stage was using the older `bioscancast/llm/client.py:LLMClient` (single positional prompt string, returns plain dict) — a parallel protocol that's been carrying a long- standing TODO in the insight README. 
The shared protocol at `bioscancast/llm/base.py:LLMClient` uses keyword-only system/user/schema/ model/max_tokens and returns LLMResponse with structured token accounting, matching what insight (and soon forecasting) already uses. Changes: - `bioscancast/filtering/llm_filter.py`: * `build_filter_prompt` now returns a (system, user, schema) triple instead of a single concatenated JSON string. System carries the task instructions; user carries the question + candidates payload; schema is a real strict JSON Schema (not the previous example-dict). * `llm_filter_candidates` calls `llm_client.generate_json` with kwargs and reads `response.content["decisions"]` instead of the raw dict. * Adds default `model` and `max_tokens` parameters. - `bioscancast/filtering/pipeline.py`: switches `LLMClient` import to the shared base protocol. - `bioscancast/llm/client.py`: keeps the legacy single-positional protocol intact (search stage still uses it) but adds a top-of-file docstring warning new callers off. - `bioscancast/insight/README.md`: the filtering-migration TODO is closed; replaced with the remaining search-stage migration as a follow-up. - New `bioscancast/tests/test_filtering_llm.py` covers the new protocol path: prompt triple, strict JSON schema shape, the expected `generate_json(**kwargs)` call signature, missing-decision handling, empty-input shortcut, and a regression check that the filter module never re-imports the legacy client. 
Co-Authored-By: Claude Opus 4.7 --- bioscancast/filtering/llm_filter.py | 111 +++++++++++---- bioscancast/filtering/pipeline.py | 2 +- bioscancast/insight/README.md | 3 +- bioscancast/llm/client.py | 14 ++ bioscancast/tests/test_filtering_llm.py | 182 ++++++++++++++++++++++++ 5 files changed, 281 insertions(+), 31 deletions(-) create mode 100644 bioscancast/tests/test_filtering_llm.py diff --git a/bioscancast/filtering/llm_filter.py b/bioscancast/filtering/llm_filter.py index 2bcaa70..b3129d8 100644 --- a/bioscancast/filtering/llm_filter.py +++ b/bioscancast/filtering/llm_filter.py @@ -3,23 +3,78 @@ import json from typing import Dict, List -from bioscancast.llm.client import LLMClient +from bioscancast.llm.base import LLMClient from .models import FilterDecision, ForecastQuestion, SearchResult +# Default model and max output tokens for the filter LLM call. These can be +# overridden via `llm_filter_candidates(... , model=..., max_tokens=...)`. +DEFAULT_FILTER_MODEL = "gpt-4o-mini" +DEFAULT_FILTER_MAX_TOKENS = 4096 + + +FILTER_SYSTEM_PROMPT = ( + "You are filtering search results for a biosecurity forecasting " + "pipeline. Your job is to decide which candidates are likely to " + "contain relevant factual evidence for forecasting. Prefer official, " + "primary, recent, and event-specific sources. Reject low-information, " + "generic, duplicated, or weakly relevant pages. Return a JSON object " + "matching the supplied schema with one decision per candidate." 
+) + + +FILTER_OUTPUT_SCHEMA: dict = { + "type": "object", + "properties": { + "decisions": { + "type": "array", + "items": { + "type": "object", + "properties": { + "result_id": {"type": "string"}, + "keep": {"type": "boolean"}, + "relevance_score": { + "type": "number", "minimum": 0.0, "maximum": 1.0, + }, + "credibility_score": { + "type": "number", "minimum": 0.0, "maximum": 1.0, + }, + "final_score": { + "type": "number", "minimum": 0.0, "maximum": 1.0, + }, + "reason_codes": { + "type": "array", + "items": {"type": "string"}, + }, + "notes": {"type": ["string", "null"]}, + }, + "required": [ + "result_id", "keep", "relevance_score", + "credibility_score", "final_score", "reason_codes", + "notes", + ], + "additionalProperties": False, + }, + }, + }, + "required": ["decisions"], + "additionalProperties": False, +} + + def build_filter_prompt( question: ForecastQuestion, candidates: list[dict], -) -> str: - payload = { - "task": ( - "You are filtering search results for a biosecurity forecasting pipeline. " - "Keep only candidates likely to contain relevant factual evidence for forecasting. " - "Prefer official, primary, recent, and event-specific sources. " - "Reject low-information, generic, duplicated, or weakly relevant pages. " - "Return your response as JSON matching the output_schema below." - ), +) -> tuple[str, str, dict]: + """Build the (system, user, schema) tuple for the filter LLM call. + + The user prompt is a JSON payload containing the question and the + candidate list; the system prompt covers the task instructions; the + schema is enforced by structured-output capable LLMs (and ignored + by ones that aren't). 
+ """ + user_payload = { "question": { "id": question.id, "text": question.text, @@ -29,21 +84,9 @@ def build_filter_prompt( "resolution_criteria": question.resolution_criteria, }, "candidates": candidates, - "output_schema": { - "decisions": [ - { - "result_id": "string", - "keep": "boolean", - "relevance_score": "0_to_1_float", - "credibility_score": "0_to_1_float", - "final_score": "0_to_1_float", - "reason_codes": ["list_of_short_strings"], - "notes": "short explanation", - } - ] - }, } - return json.dumps(payload, default=str, indent=2) + user = json.dumps(user_payload, default=str, indent=2) + return FILTER_SYSTEM_PROMPT, user, FILTER_OUTPUT_SCHEMA def llm_filter_candidates( @@ -51,6 +94,9 @@ def llm_filter_candidates( candidate_decisions: List[FilterDecision], result_map: Dict[str, SearchResult], llm_client: LLMClient, + *, + model: str = DEFAULT_FILTER_MODEL, + max_tokens: int = DEFAULT_FILTER_MAX_TOKENS, ) -> List[FilterDecision]: if not candidate_decisions: return [] @@ -73,10 +119,19 @@ def llm_filter_candidates( } ) - prompt = build_filter_prompt(question, candidates) - response = llm_client.generate_json(prompt) + system, user, schema = build_filter_prompt(question, candidates) + response = llm_client.generate_json( + system=system, + user=user, + schema=schema, + model=model, + max_tokens=max_tokens, + ) - output_by_id = {item["result_id"]: item for item in response.get("decisions", [])} + output_by_id = { + item["result_id"]: item + for item in response.content.get("decisions", []) + } updated: list[FilterDecision] = [] for decision in candidate_decisions: @@ -97,4 +152,4 @@ def llm_filter_candidates( decision.notes = data.get("notes") updated.append(decision) - return updated \ No newline at end of file + return updated diff --git a/bioscancast/filtering/pipeline.py b/bioscancast/filtering/pipeline.py index 23443d9..e8c2e5b 100644 --- a/bioscancast/filtering/pipeline.py +++ b/bioscancast/filtering/pipeline.py @@ -2,7 +2,7 @@ from typing import List, 
Optional -from bioscancast.llm.client import LLMClient +from bioscancast.llm.base import LLMClient from .config import FILTER_CONFIG from .deduplication import deduplicate_filtered_documents diff --git a/bioscancast/insight/README.md b/bioscancast/insight/README.md index 69ab134..e5dd040 100644 --- a/bioscancast/insight/README.md +++ b/bioscancast/insight/README.md @@ -43,7 +43,6 @@ See `config.py` for all configurable values (`InsightConfig`). ## TODO -- [ ] Migrate `bioscancast/filtering/llm_filter.py`'s local `LLMClient` protocol to the shared `bioscancast/llm/base.py` protocol. - [ ] Cognitive bias mitigations belong primarily in the **forecasting** stage's prompts, not here. Insight extraction is neutral fact-finding. A brief reminder is included in the extraction prompt but full bias mitigation should be implemented in forecasting. -- [ ] When extraction lands and synthetic fixtures are swapped for real Document outputs, expect surprises. Plan a follow-up PR to harden the chunk extractor against messy real-world chunks (long text, mid-sentence breaks, OCR garbage in tables). +- [ ] Migrate `bioscancast/stages/search_stage/`'s use of the legacy `bioscancast/llm/client.py` to the shared `bioscancast/llm/base.py` protocol (filtering has already been migrated; the legacy client lingers only because the search stage still calls it). - [ ] Strong model refinement pass (behind `use_strong_model_refinement` config flag, currently a no-op). diff --git a/bioscancast/llm/client.py b/bioscancast/llm/client.py index 425ac78..5bb983d 100644 --- a/bioscancast/llm/client.py +++ b/bioscancast/llm/client.py @@ -1,3 +1,17 @@ +"""Legacy LLM client kept for the search stage's transitional period. + +This module pre-dates ``bioscancast/llm/base.py`` and uses a simpler +``generate_json(prompt: str) -> dict`` signature. 
New code should use +``bioscancast.llm.base.LLMClient`` (with structured ``system/user/schema/ +model/max_tokens`` arguments returning an ``LLMResponse``) and the +production client at ``bioscancast/llm/openai_client.py``. + +Filtering and the insight stage already use the modern protocol. The +search stage (``bioscancast/stages/search_stage/``) is the only +remaining consumer of this module and will be migrated in a follow-up. +Do not add new callers here. +""" + from __future__ import annotations import json diff --git a/bioscancast/tests/test_filtering_llm.py b/bioscancast/tests/test_filtering_llm.py new file mode 100644 index 0000000..e7891c2 --- /dev/null +++ b/bioscancast/tests/test_filtering_llm.py @@ -0,0 +1,182 @@ +"""Tests for the LLM-driven filtering stage. + +The filter uses the shared ``bioscancast.llm.base.LLMClient`` protocol — +the older single-positional-argument ``bioscancast.llm.client.LLMClient`` +is no longer accepted here. +""" + +from __future__ import annotations + +from datetime import datetime, timezone + +from bioscancast.filtering.llm_filter import ( + DEFAULT_FILTER_MAX_TOKENS, + DEFAULT_FILTER_MODEL, + FILTER_OUTPUT_SCHEMA, + build_filter_prompt, + llm_filter_candidates, +) +from bioscancast.filtering.models import ( + FilterDecision, + ForecastQuestion, + SearchResult, +) +from bioscancast.llm.base import LLMResponse +from bioscancast.llm.fake_client import FakeLLMClient + + +def _make_question() -> ForecastQuestion: + return ForecastQuestion( + id="q1", + text="Will US H5N1 herds exceed 1500 by June 2026?", + created_at=datetime(2026, 4, 1, tzinfo=timezone.utc), + pathogen="H5N1", + region="United States", + event_type="case_count", + ) + + +def _make_result(rid: str, url: str) -> SearchResult: + return SearchResult( + id=rid, question_id="q1", query_id="qx", engine="tavily", + url=url, canonical_url=url, domain="cdc.gov", + title=f"Result {rid}", snippet="data on H5N1 herds", rank=1, + retrieved_at=datetime(2026, 4, 1, 
tzinfo=timezone.utc), + ) + + +def _make_decision(rid: str) -> FilterDecision: + return FilterDecision( + result_id=rid, keep=None, stage="heuristic", + relevance_score=0.5, credibility_score=0.5, priority_score=0.5, + ) + + +def test_build_filter_prompt_returns_triple(): + """build_filter_prompt must return (system, user, schema) matching the + shared LLMClient.generate_json signature.""" + q = _make_question() + system, user, schema = build_filter_prompt(q, [{"result_id": "r1"}]) + assert isinstance(system, str) and system.startswith("You are filtering") + assert isinstance(user, str) + # User prompt is a JSON object containing the question and candidates + import json as _json + payload = _json.loads(user) + assert payload["question"]["id"] == "q1" + assert payload["candidates"] == [{"result_id": "r1"}] + assert schema is FILTER_OUTPUT_SCHEMA + + +def test_filter_output_schema_is_strict_json_schema(): + """The output schema must be a real JSON Schema (not an example) + with strict OpenAI-compatible properties.""" + schema = FILTER_OUTPUT_SCHEMA + assert schema["type"] == "object" + assert schema["additionalProperties"] is False + item = schema["properties"]["decisions"]["items"] + assert item["type"] == "object" + assert item["additionalProperties"] is False + # All listed properties must be required + assert set(item["required"]) == set(item["properties"].keys()) + + +def test_llm_filter_candidates_calls_new_protocol(): + """llm_filter_candidates must call generate_json with system/user/schema/ + model/max_tokens kwargs and read from response.content (LLMResponse).""" + q = _make_question() + sr = _make_result("r1", "https://cdc.gov/h5n1") + fd = _make_decision("r1") + + captured = {} + + class CapturingFake: + def generate_json(self, *, system, user, schema, model, max_tokens): + captured.update( + system=system, user=user, schema=schema, + model=model, max_tokens=max_tokens, + ) + return LLMResponse( + content={"decisions": [{ + "result_id": "r1", "keep": 
True, + "relevance_score": 0.9, "credibility_score": 0.85, + "final_score": 0.88, + "reason_codes": ["official_source"], + "notes": "cdc dashboard", + }]}, + input_tokens=100, output_tokens=20, + model="gpt-4o-mini", raw_text="{}", + ) + + def embed(self, texts, *, model): + raise NotImplementedError("filter doesn't embed") + + result = llm_filter_candidates(q, [fd], {"r1": sr}, CapturingFake()) + + # Protocol call + assert "system" in captured + assert "user" in captured + assert captured["schema"] is FILTER_OUTPUT_SCHEMA + assert captured["model"] == DEFAULT_FILTER_MODEL + assert captured["max_tokens"] == DEFAULT_FILTER_MAX_TOKENS + + # Decision parsing + assert len(result) == 1 + decision = result[0] + assert decision.keep is True + assert decision.stage == "llm" + assert decision.relevance_score == 0.9 + assert decision.credibility_score == 0.85 + assert decision.priority_score == 0.88 + assert decision.reason_codes == ["official_source"] + assert decision.notes == "cdc dashboard" + + +def test_llm_filter_candidates_handles_missing_decision(): + """If the LLM omits a decision for some candidate, that candidate + should be marked keep=False with a 'missing_llm_decision' reason.""" + q = _make_question() + sr_a = _make_result("ra", "https://a.com") + sr_b = _make_result("rb", "https://b.com") + fake = FakeLLMClient([LLMResponse( + content={"decisions": [{ + "result_id": "ra", "keep": True, + "relevance_score": 0.8, "credibility_score": 0.8, + "final_score": 0.8, "reason_codes": [], "notes": None, + }]}, # rb is omitted + input_tokens=100, output_tokens=20, + model="gpt-4o-mini", raw_text="{}", + )]) + + result = llm_filter_candidates( + q, + [_make_decision("ra"), _make_decision("rb")], + {"ra": sr_a, "rb": sr_b}, + fake, + ) + + by_id = {d.result_id: d for d in result} + assert by_id["ra"].keep is True + assert by_id["rb"].keep is False + assert "missing_llm_decision" in by_id["rb"].reason_codes + + +def test_llm_filter_candidates_empty_input_returns_empty(): + 
"""Empty candidate list should not call the LLM at all.""" + + class ExplodingFake: + def generate_json(self, **_): + raise AssertionError("generate_json must not be called on empty input") + def embed(self, *_, **__): + raise NotImplementedError + + q = _make_question() + assert llm_filter_candidates(q, [], {}, ExplodingFake()) == [] + + +def test_llm_filter_does_not_use_legacy_client_module(): + """The migration: the filter module must NOT import from + bioscancast.llm.client (the legacy single-positional protocol).""" + import bioscancast.filtering.llm_filter as mod + src = open(mod.__file__, encoding="utf-8").read() + assert "from bioscancast.llm.client" not in src + assert "bioscancast.llm.base" in src From e29103dc4d550167b1f9e0321322b115f1a1d665 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Tue, 26 May 2026 13:38:07 +0200 Subject: [PATCH 12/21] Parallelise per-chunk extraction with ThreadPoolExecutor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The insight pipeline's per-chunk LLM call loop was strictly sequential. Tests on the 6 real biosecurity documents showed each doc spending almost all its wall-clock in serial OpenAI request latency: WHO mpox ~30s, ECDC CDTR ~38s, MMWR ~11s for top-k=5 chunks. Since each extract_facts_from_chunk call is independent and the OpenAI sync client is thread-safe, this was leaving easy speedups on the floor. Changes: - `InsightPipeline.run` now dispatches per-chunk extractions to a `ThreadPoolExecutor` whose pool size is `min(chunk_workers, len(scored_chunks))`. With chunk_workers=1 or only one chunk, the code takes a sequential fallback path. Errors in one chunk are caught, logged, and don't abort the document. - Budget accounting still happens serially on the main thread after futures complete, so BudgetTracker stays simple (no locks needed). - New `chunk_workers: int = 6` field on `InsightConfig`. 
Six is a pragmatic default — matches the typical retrieval_top_k while staying well below OpenAI's per-minute rate limits for gpt-4o-mini. Setting to 1 reproduces the previous sequential behaviour. - `FakeLLMClient.generate_json` and `enqueue` now hold a `threading.Lock` around the response deque and counters so the test fakes stay deterministic under concurrent calls. New tests: - test_pipeline_parallel_chunk_extraction_produces_all_records: content-keyed fake confirms every chunk in the top-k is processed by parallel workers and records survive provenance checks. - test_pipeline_sequential_and_parallel_produce_same_record_count: chunk_workers=1 and chunk_workers=4 yield identical record counts and identical input-token totals on the same input. - test_pipeline_parallel_isolates_chunk_failures: a fake that throws on the 2nd chunk doesn't kill the doc — the other three still complete and the doc is marked processed. Live verification on the 6 real biosecurity documents: wall-clock per doc drops ~2× across the board (WHO mpox 24s→13s, WHO cholera 11s→3s, MMWR 11s→3s, ECDC CDTR 37s→15s). Record counts stable within LLM stochasticity. Co-Authored-By: Claude Opus 4.7 --- bioscancast/insight/config.py | 9 ++ bioscancast/insight/pipeline.py | 63 ++++++-- bioscancast/llm/fake_client.py | 33 ++-- bioscancast/tests/test_insight_pipeline.py | 178 +++++++++++++++++++++ 4 files changed, 262 insertions(+), 21 deletions(-) diff --git a/bioscancast/insight/config.py b/bioscancast/insight/config.py index 05cc4af..cec8b4f 100644 --- a/bioscancast/insight/config.py +++ b/bioscancast/insight/config.py @@ -17,6 +17,7 @@ "max_input_tokens_per_run": 500_000, "max_chunks_per_document": 12, "extraction_max_output_tokens": 4096, + "chunk_workers": 6, } @@ -38,6 +39,14 @@ class InsightConfig: 1024 ceiling in LLMClient.generate_json truncates dense pages (e.g. 
the ECDC CDTR) mid-JSON; 4096 leaves comfortable headroom.""" + chunk_workers: int = 6 + """Max parallel per-chunk extraction calls per document. Each call is + a single OpenAI chat-completions request; the SDK is thread-safe so a + ThreadPoolExecutor is used. Six matches typical retrieval_top_k while + staying well below OpenAI's per-minute rate limits for gpt-4o-mini. + Set to 1 for sequential execution (useful for debugging or rate- + limit-sensitive setups).""" + @classmethod def from_dict(cls, d: dict) -> InsightConfig: """Create an InsightConfig from a dict, ignoring unknown keys.""" diff --git a/bioscancast/insight/pipeline.py b/bioscancast/insight/pipeline.py index b8ebb80..58237de 100644 --- a/bioscancast/insight/pipeline.py +++ b/bioscancast/insight/pipeline.py @@ -20,6 +20,7 @@ from __future__ import annotations import logging +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from datetime import datetime from typing import Optional @@ -59,6 +60,29 @@ def __init__( self._llm = llm_client self._config = config or InsightConfig() + def _safe_extract(self, sc, doc, question, config): + """Wrap chunk extraction so a failure in one chunk doesn't kill + the document. Returns the (records, response) tuple or None. + + Called both serially and from ThreadPoolExecutor workers — must + not mutate ``self`` so concurrent calls stay safe. 
+ """ + try: + return extract_facts_from_chunk( + sc.chunk, + doc, + question, + self._llm, + model=config.cheap_model, + max_tokens=config.extraction_max_output_tokens, + ) + except Exception as exc: # noqa: BLE001 + logger.warning( + "Chunk extraction failed (chunk_id=%s): %s", + sc.chunk.chunk_id, exc, + ) + return None + def run( self, question: ForecastQuestion, @@ -112,16 +136,35 @@ def run( # Cap chunks per document scored_chunks = scored_chunks[: config.max_chunks_per_document] - # --- Per-chunk extraction --- - for sc in scored_chunks: - records, response = extract_facts_from_chunk( - sc.chunk, - doc, - question, - self._llm, - model=config.cheap_model, - max_tokens=config.extraction_max_output_tokens, - ) + # --- Per-chunk extraction (parallel within a doc) --- + # Live tests on real biosecurity documents show the per-doc + # wall-clock is almost entirely sequential OpenAI request + # latency. A ThreadPoolExecutor cuts WHO mpox (~30s) and ECDC + # CDTR (~38s) down by roughly chunk_workers× because each + # request is independent and the OpenAI sync client is + # thread-safe. Errors in one chunk don't kill the doc; + # budget accounting happens serially after futures complete + # so BudgetTracker doesn't need its own lock. 
+ workers = max(1, min(config.chunk_workers, len(scored_chunks))) + if workers == 1 or len(scored_chunks) <= 1: + chunk_results = [ + self._safe_extract(sc, doc, question, config) + for sc in scored_chunks + ] + else: + with ThreadPoolExecutor(max_workers=workers) as ex: + futures = [ + ex.submit( + self._safe_extract, sc, doc, question, config, + ) + for sc in scored_chunks + ] + chunk_results = [f.result() for f in futures] + + for outcome in chunk_results: + if outcome is None: + continue + records, response = outcome budget.record(response) all_records.extend(records) diff --git a/bioscancast/llm/fake_client.py b/bioscancast/llm/fake_client.py index 32f7a9c..f56302e 100644 --- a/bioscancast/llm/fake_client.py +++ b/bioscancast/llm/fake_client.py @@ -9,6 +9,7 @@ import hashlib import math +import threading from collections import deque from typing import Sequence @@ -20,6 +21,13 @@ class FakeLLMClient: Responses are consumed in FIFO order. If the queue is exhausted, a RuntimeError is raised — failing loudly beats returning empty dicts. + + ``generate_json`` is thread-safe under multiple concurrent callers + (e.g. when the insight pipeline's per-chunk ThreadPoolExecutor calls + it from several workers): a single ``threading.Lock`` protects the + response deque and the call-count / token totals. Without this, the + deque popleft + counter increment race and tests run under concurrency + can drop responses or under-count tokens. 
""" def __init__( @@ -33,10 +41,12 @@ def __init__( self.call_count = 0 self.total_input_tokens = 0 self.total_output_tokens = 0 + self._lock = threading.Lock() def enqueue(self, *responses: LLMResponse) -> None: """Add responses to the end of the queue.""" - self._responses.extend(responses) + with self._lock: + self._responses.extend(responses) def generate_json( self, @@ -47,16 +57,17 @@ def generate_json( model: str, max_tokens: int = 1024, ) -> LLMResponse: - if not self._responses: - raise RuntimeError( - f"FakeLLMClient: no scripted responses left " - f"(call #{self.call_count + 1}). " - f"Enqueue more responses before running the test." - ) - response = self._responses.popleft() - self.call_count += 1 - self.total_input_tokens += response.input_tokens - self.total_output_tokens += response.output_tokens + with self._lock: + if not self._responses: + raise RuntimeError( + f"FakeLLMClient: no scripted responses left " + f"(call #{self.call_count + 1}). " + f"Enqueue more responses before running the test." 
+ ) + response = self._responses.popleft() + self.call_count += 1 + self.total_input_tokens += response.input_tokens + self.total_output_tokens += response.output_tokens return response def embed(self, texts: list[str], *, model: str) -> list[list[float]]: diff --git a/bioscancast/tests/test_insight_pipeline.py b/bioscancast/tests/test_insight_pipeline.py index 9cdae02..59d252c 100644 --- a/bioscancast/tests/test_insight_pipeline.py +++ b/bioscancast/tests/test_insight_pipeline.py @@ -404,6 +404,184 @@ def test_dedup_does_not_merge_across_different_locations(): assert len(result) == 2 +# --------------------------------------------------------------------------- +# Per-chunk parallelism (ThreadPoolExecutor) +# --------------------------------------------------------------------------- + + +def _content_keyed_response(chunk_text_marker: str, quote: str) -> LLMResponse: + """Build an LLMResponse whose `quote` matches a chunk that contains + `chunk_text_marker`. Used to test that parallel chunk extraction + pairs responses with the right chunks regardless of which worker + processes which chunk first. + """ + return LLMResponse( + content={"facts": [{ + "event_type": "case_count", + "confidence": 0.8, + "location": None, + "pathogen": None, + "metric_name": "marker", + "metric_value": 1.0, + "metric_unit": "events", + "event_date": None, + "summary": chunk_text_marker, + "quote": quote, + }]}, + input_tokens=100, + output_tokens=20, + model="gpt-4o-mini", + raw_text='{"facts": [...]}', + ) + + +class _ContentKeyedFakeLLM: + """Smart fake that returns a different response per chunk based on + which chunk text appears in the user prompt. Order-independent — so + concurrent worker threads can hit any chunk in any order and still + get the right quote-matching response. 
+ """ + + def __init__(self) -> None: + import threading + self._lock = threading.Lock() + self.calls = 0 + + def generate_json(self, *, system, user, schema, model, max_tokens=1024): + import re as _re + with self._lock: + self.calls += 1 + # Find the chunk text inside the prompt + marker_match = _re.search(r"CHUNK TEXT:\n(.+?)$", user, _re.DOTALL) + if not marker_match: + return LLMResponse( + content={"facts": []}, input_tokens=50, output_tokens=5, + model=model, raw_text="{}", + ) + chunk_text = marker_match.group(1) + # Pull the first sentence as a verbatim quote + sentence = _re.search(r"[A-Z][^.\n]{20,150}\.", chunk_text) + if not sentence: + return LLMResponse( + content={"facts": []}, input_tokens=50, output_tokens=5, + model=model, raw_text="{}", + ) + quote = sentence.group(0) + return LLMResponse( + content={"facts": [{ + "event_type": "case_count", + "confidence": 0.7, + "location": None, + "pathogen": None, + "metric_name": "concurrent_marker", + "metric_value": 1.0, + "metric_unit": "events", + "event_date": None, + "summary": None, + "quote": quote, + }]}, + input_tokens=120, output_tokens=25, + model=model, raw_text='{"facts": [...]}', + ) + + def embed(self, texts, *, model): + # Reuse FakeLLMClient's hash-based embeddings + return FakeLLMClient(embedding_dim=32).embed(texts, model=model) + + +def test_pipeline_parallel_chunk_extraction_produces_all_records(): + """With chunk_workers > 1, every retrieved chunk should still be + processed and every accepted quote should produce a record. 
Uses a + content-keyed fake so response/chunk pairing is order-independent.""" + fake = _ContentKeyedFakeLLM() + config = InsightConfig( + retrieval_top_k=4, + max_chunks_per_document=4, + chunk_workers=4, + ) + pipeline = InsightPipeline(llm_client=fake, config=config) + result = pipeline.run(QUESTION_SUDAN, [DOC_WHO_SUDAN]) + + # Every chunk in the top-k should have been called + assert fake.calls == 4 + # The pipeline-level dedup may merge records, but at least the + # ones whose chunks contain a quotable sentence should produce one + # record each (mod dedup). + assert len(result.records) >= 1 + # All records must have valid provenance + for rec in result.records: + assert rec.sources + for s in rec.sources: + assert s.quote + + +def test_pipeline_sequential_and_parallel_produce_same_record_count(): + """chunk_workers=1 and chunk_workers=4 must produce the same number + of records when the fake LLM is content-keyed (so result depends on + chunk content, not worker order).""" + config_seq = InsightConfig( + retrieval_top_k=4, max_chunks_per_document=4, chunk_workers=1, + ) + config_par = InsightConfig( + retrieval_top_k=4, max_chunks_per_document=4, chunk_workers=4, + ) + + seq_pipeline = InsightPipeline( + llm_client=_ContentKeyedFakeLLM(), config=config_seq, + ) + par_pipeline = InsightPipeline( + llm_client=_ContentKeyedFakeLLM(), config=config_par, + ) + + seq_result = seq_pipeline.run(QUESTION_SUDAN, [DOC_WHO_SUDAN]) + par_result = par_pipeline.run(QUESTION_SUDAN, [DOC_WHO_SUDAN]) + + assert len(seq_result.records) == len(par_result.records) + assert ( + seq_result.budget_summary["total_input_tokens"] + == par_result.budget_summary["total_input_tokens"] + ) + + +def test_pipeline_parallel_isolates_chunk_failures(): + """If extract_facts_from_chunk raises on one chunk, the others + should still complete and the doc shouldn't blow up.""" + import threading + + class _IntermittentFake: + def __init__(self) -> None: + self._lock = threading.Lock() + self.calls 
= 0 + + def generate_json(self, *, system, user, schema, model, max_tokens=1024): + with self._lock: + self.calls += 1 + n = self.calls + if n == 2: + raise RuntimeError("simulated failure on second chunk") + # Return a benign empty response for others + return LLMResponse( + content={"facts": []}, input_tokens=80, output_tokens=10, + model=model, raw_text="{}", + ) + + def embed(self, texts, *, model): + return FakeLLMClient(embedding_dim=32).embed(texts, model=model) + + fake = _IntermittentFake() + config = InsightConfig( + retrieval_top_k=4, max_chunks_per_document=4, chunk_workers=4, + ) + pipeline = InsightPipeline(llm_client=fake, config=config) + # Must not raise — failed chunk is logged and skipped + result = pipeline.run(QUESTION_SUDAN, [DOC_WHO_SUDAN]) + + # The doc still counts as processed + assert result.documents_processed == 1 + # All four chunks were attempted (the failure didn't abort the rest) + assert fake.calls == 4 + + # --------------------------------------------------------------------------- # Multi-document end-to-end # --------------------------------------------------------------------------- From efe62c3ff27be1010c057b99496355f4c577642b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Tue, 26 May 2026 14:06:43 +0200 Subject: [PATCH 13/21] Add controlled vocabulary for metric_name in extraction prompt MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Live observation from item 6 (partial-date dedup) was that the model emitted lots of different metric_name strings for what was essentially the same metric — "confirmed cases" / "cases" / "reported cases" / "total cases" / "total_cases" / "cholera cases" / "new cases" / "Cases" all appeared in a single live run, and they prevented the dedup logic from merging facts about the same event. ECDC alone produced 9 distinct metric_name variants of "case count".
Tests show that listing a canonical snake_case vocabulary directly in the extraction prompt — with explicit guidance that qualifiers (sex, sub-region, time-period) belong in `summary` or `location` rather than in `metric_name` — collapses the diversity dramatically: 17 unique metric_names → 4–6 across the same 6 real biosecurity test documents, all drawn from the canonical list. The model can still invent a short snake_case label when none of the canonical values fit. The vocabulary covers the common biosecurity metrics observed in live tests (confirmed_cases, suspected_cases, probable_cases, confirmed_or_probable_cases, deaths, hospitalizations, recoveries, vaccinations_administered, vaccine_doses_distributed, affected_herds, affected_animals, new_outbreaks_declared, reproductive_number, case_fatality_ratio). This change works together with the value-aware dedup added in the follow-up commit — together they let real duplicates merge cleanly while preventing the merge from being too aggressive when the model misattributes locations. Co-Authored-By: Claude Opus 4.7 --- bioscancast/insight/extraction/prompts.py | 24 ++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/bioscancast/insight/extraction/prompts.py b/bioscancast/insight/extraction/prompts.py index 537d392..0f74b4c 100644 --- a/bioscancast/insight/extraction/prompts.py +++ b/bioscancast/insight/extraction/prompts.py @@ -37,7 +37,29 @@ ``YYYY-MM`` when only a month is given (e.g. "January 2026"), or \ ``YYYY`` when only a year is given. Do NOT invent a day-of-month when \ the chunk only mentions a month. -6. Be aware of cognitive biases that affect information processing: +6. 
For metric_name, use one of these canonical snake_case values when \ +applicable (this lets downstream dedup merge facts about the same \ +metric across sources): + - confirmed_cases (suspected, probable, possible all get \ +their own variants below) + - suspected_cases + - probable_cases + - confirmed_or_probable_cases + - deaths + - hospitalizations + - recoveries + - vaccinations_administered + - vaccine_doses_distributed + - affected_herds (animal disease — herds/farms affected) + - affected_animals + - new_outbreaks_declared + - reproductive_number (R0, Rt) + - case_fatality_ratio + If none of these fit, invent a short snake_case label. Do NOT put \ +qualifiers (sex, age, sub-region, time-period like "weekly") in \ +metric_name — capture those in `summary` or `location` instead. \ +"cases", "reported cases", "total cases" all map to confirmed_cases. +7. Be aware of cognitive biases that affect information processing: - Anchoring: do not over-weight the first number you encounter. - Availability: rare dramatic events are not necessarily more likely. - Overconfidence: if the chunk is ambiguous, lower your confidence. From 95e47383de7614601cdf12389876cf9115979402 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Tue, 26 May 2026 14:07:01 +0200 Subject: [PATCH 14/21] Value-aware dedup: refuse to merge when metric_values disagree MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The two-stage dedup added in item 6 groups records by (event_type, metric_name, normalized_location) and merges any whose date buckets overlap. But it didn't compare metric_value — so two records that share a dedup key with overlapping dates AND disagreeing values would still merge, silently dropping one value. This matters in practice because LLMs occasionally misattribute locations. 
Live test on the WHO cholera doc exposed exactly this: the model emitted "In January 2026, the African Region reported the highest number of cases (9782 cases; 13 countries)" with location=DRC and value=9782 (the African Region figure incorrectly tagged with DRC). Under the previous dedup logic this merged into the legitimate DRC 6543 records, hiding the attribution error and silently dropping the 9782 value from the dataframe. New `_values_compatible(v1, v2)` helper allows merging when: - Both values are None (no count claimed) - Either value is None (one source omitted the count) - Values are equal - Values are within 1% relative tolerance (accommodates rounding e.g. "about 6500" vs "6543" — same fact, different precision) Values further apart are treated as distinct facts and kept as separate records, surfacing the conflict to downstream consumers rather than burying it. Live verification on the 6 real biosecurity documents: WHO cholera went from 1 (over-merged, value lost) → 2 records (legitimate DRC merge + Africa Region preserved separately). Total record count 35 → 40, all 5 additional records are legitimate distinct facts the vocabulary-only configuration silently dropped. Three new tests cover: distinct-value rejection, within-tolerance merge (rounding), and one-value-None merge. Co-Authored-By: Claude Opus 4.7 --- bioscancast/insight/pipeline.py | 42 ++++++++++- bioscancast/tests/test_insight_pipeline.py | 83 ++++++++++++++++++++++ 2 files changed, 122 insertions(+), 3 deletions(-) diff --git a/bioscancast/insight/pipeline.py b/bioscancast/insight/pipeline.py index 58237de..d73d347 100644 --- a/bioscancast/insight/pipeline.py +++ b/bioscancast/insight/pipeline.py @@ -241,6 +241,31 @@ def _dates_overlap( return _date_bucket(d1, coarser) == _date_bucket(d2, coarser) +def _values_compatible( + v1: Optional[float], v2: Optional[float], *, rel_tol: float = 0.01 +) -> bool: + """Check whether two metric_values are close enough to be the same fact. 
+ + Both-None counts as compatible (the "no value" bucket). One-None is + compatible with a known value (one source happened to omit the + count). Two known values are compatible when their relative + difference is within ``rel_tol`` (default 1%) — this accommodates + rounding (e.g. "6500" vs "6543") without merging genuinely + different counts (e.g. African Region 9782 vs DRC 6543 — which + happens when the model misattributes a regional quote to a + specific country). + """ + if v1 is None or v2 is None: + return True + if v1 == v2: + return True + # Relative difference; guard against division by zero + denom = max(abs(v1), abs(v2)) + if denom == 0: + return v1 == v2 + return abs(v1 - v2) / denom <= rel_tol + + def _record_dedup_key(record: InsightRecord) -> tuple: """First-stage dedup key: groups records that *might* be duplicates. @@ -315,12 +340,23 @@ def _deduplicate_records(records: list[InsightRecord]) -> list[InsightRecord]: for record in group: target = None for existing in merged: - if _dates_overlap( + if not _dates_overlap( record.event_date, record.event_date_precision, existing.event_date, existing.event_date_precision, ): - target = existing - break + continue + if not _values_compatible( + record.metric_value, existing.metric_value, + ): + # Same dedup key + overlapping dates but different + # numeric values — almost always a model attribution + # error (e.g. regional total mistakenly tagged with + # a country location). Keep both records so the + # conflict is visible downstream rather than + # silently dropped. 
+ continue + target = existing + break if target is not None: _merge_record_into(target, record) else: diff --git a/bioscancast/tests/test_insight_pipeline.py b/bioscancast/tests/test_insight_pipeline.py index 59d252c..1d6b9ba 100644 --- a/bioscancast/tests/test_insight_pipeline.py +++ b/bioscancast/tests/test_insight_pipeline.py @@ -404,6 +404,89 @@ def test_dedup_does_not_merge_across_different_locations(): assert len(result) == 2 +# --------------------------------------------------------------------------- +# Value-aware dedup (catches model location-attribution errors) +# --------------------------------------------------------------------------- + + +def _record_with_value( + record_id: str, + metric_value: float, + *, + event_date=None, + event_date_precision=None, + document_id: str = "doc-test", +): + """Build a record with a specific metric_value for value-conflict tests.""" + from bioscancast.schemas import InsightRecord, ChunkReference + from datetime import datetime + if event_date is None: + event_date = datetime(2026, 1, 1) + return InsightRecord( + id=record_id, + question_id="q-test", + event_type="case_count", + confidence=0.8, + location="Democratic Republic of the Congo", + metric_name="confirmed_cases", + metric_value=metric_value, + event_date=event_date, + event_date_precision=event_date_precision or "month", + sources=[ChunkReference( + document_id=document_id, + chunk_id=f"chunk-{record_id}", + source_url=f"https://example.com/{document_id}", + quote="some quote", + )], + ) + + +def test_dedup_does_not_merge_when_metric_values_differ_significantly(): + """Two records sharing event_type, metric_name, location and date + bucket should still NOT merge when their metric_values disagree — + this is almost always a model location-attribution error (e.g. + regional 9782 cases mistakenly labelled with the DRC's location). 
+ Live tests on the WHO cholera doc exposed this exact case.""" + from bioscancast.insight.pipeline import _deduplicate_records + + drc_real = _record_with_value("drc", 6543.0) + africa_misattributed = _record_with_value( + "africa_misattributed", 9782.0, document_id="doc-other", + ) + result = _deduplicate_records([drc_real, africa_misattributed]) + assert len(result) == 2, ( + f"Expected separate records for different values, got {len(result)}" + ) + + +def test_dedup_merges_when_metric_values_are_close(): + """Rounding differences within 1% should still merge — one source + rounding to 6500 while another reports 6543 doesn't mean they're + different facts.""" + from bioscancast.insight.pipeline import _deduplicate_records + + exact = _record_with_value("exact", 6543.0) + rounded = _record_with_value( + "rounded", 6540.0, document_id="doc-other", # within 1% + ) + result = _deduplicate_records([exact, rounded]) + assert len(result) == 1 + + +def test_dedup_merges_when_one_value_is_none(): + """When one record has no metric_value (qualitative claim) and + another has a value, they should still merge if dedup keys and + dates align — the value-aware check treats None as compatible.""" + from bioscancast.insight.pipeline import _deduplicate_records + + valued = _record_with_value("valued", 6543.0) + no_value = _record_with_value("no_value", None, document_id="doc-other") + result = _deduplicate_records([valued, no_value]) + assert len(result) == 1 + # The merged record retains the actual value + assert result[0].metric_value == 6543.0 + + # --------------------------------------------------------------------------- # Per-chunk parallelism (ThreadPoolExecutor) # --------------------------------------------------------------------------- From ce3ba1141fc9cbc282bf69f520af3b4f1af40040 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Tue, 26 May 2026 14:15:41 +0200 Subject: [PATCH 15/21] Migrate search stage to shared LLMClient and remove legacy 
client MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Item 10 stopped short of full retirement: it migrated filtering to the shared `bioscancast.llm.base.LLMClient` protocol but left `bioscancast.llm.client` in place because the search stage still called it. This change finishes the migration and deletes the legacy module — one LLM protocol for the whole codebase. Search-stage changes: - `bioscancast/stages/search_stage/query_decomposition.py` now builds (system, user, schema) triples for both the question-type classifier and the sub-query decomposer. Two new JSON schemas (`CLASSIFY_SCHEMA`, `DECOMPOSE_SCHEMA`) constrain the model output to the existing QUESTION_TYPES and VALID_AXES enums. Calls switch from `generate_json(prompt) -> dict` to the kwargs form returning `LLMResponse`. - `bioscancast/stages/search_stage/pipeline.py` imports `LLMClient` from `bioscancast.llm.base` instead of the legacy module. - `scripts/run_search_stage.py` instantiates `OpenAILLMClient` (the production class for the shared protocol) instead of the legacy `OpenAIClient`. Test updates: - `bioscancast/tests/test_query_decomposition.py` rewritten: `FakeLLMClient` now implements the shared protocol (kwargs + `LLMResponse`), responses are wrapped via a small `_resp` helper, and two new tests assert the right schema is passed to each call. - `bioscancast/tests/test_filtering_llm.py` docstring updated to note the legacy module is gone; the regression check is kept around as a guard against anyone reintroducing it. Cleanup: - `bioscancast/llm/client.py` deleted. - `bioscancast/llm/__init__.py` no longer exports the now-defunct `FilteringLLMClient` / `OpenAIClient` aliases; nothing in the repo referenced them. The `InsightLLMClient` alias is dropped as well: the shared protocol is now exported simply as `LLMClient`, so any code importing `InsightLLMClient` from `bioscancast.llm` must switch to `LLMClient`. - The follow-up TODO line in `bioscancast/insight/README.md` is removed — there is no remaining migration debt. All 348 tests still pass.
Co-Authored-By: Claude Opus 4.7 --- bioscancast/insight/README.md | 1 - bioscancast/llm/__init__.py | 9 +- bioscancast/llm/client.py | 52 ----- bioscancast/stages/search_stage/pipeline.py | 2 +- .../search_stage/query_decomposition.py | 205 +++++++++++++----- bioscancast/tests/test_filtering_llm.py | 13 +- bioscancast/tests/test_query_decomposition.py | 116 +++++++--- scripts/run_search_stage.py | 4 +- 8 files changed, 251 insertions(+), 151 deletions(-) delete mode 100644 bioscancast/llm/client.py diff --git a/bioscancast/insight/README.md b/bioscancast/insight/README.md index e5dd040..d2dbd22 100644 --- a/bioscancast/insight/README.md +++ b/bioscancast/insight/README.md @@ -44,5 +44,4 @@ See `config.py` for all configurable values (`InsightConfig`). ## TODO - [ ] Cognitive bias mitigations belong primarily in the **forecasting** stage's prompts, not here. Insight extraction is neutral fact-finding. A brief reminder is included in the extraction prompt but full bias mitigation should be implemented in forecasting. -- [ ] Migrate `bioscancast/stages/search_stage/`'s use of the legacy `bioscancast/llm/client.py` to the shared `bioscancast/llm/base.py` protocol (filtering has already been migrated; the legacy client lingers only because the search stage still calls it). - [ ] Strong model refinement pass (behind `use_strong_model_refinement` config flag, currently a no-op). 
diff --git a/bioscancast/llm/__init__.py b/bioscancast/llm/__init__.py index a11bbff..d02c228 100644 --- a/bioscancast/llm/__init__.py +++ b/bioscancast/llm/__init__.py @@ -1,14 +1,9 @@ -from .client import LLMClient as FilteringLLMClient, OpenAIClient -from .base import LLMClient as InsightLLMClient, LLMResponse +from .base import LLMClient, LLMResponse from .fake_client import FakeLLMClient from .openai_client import OpenAILLMClient __all__ = [ - # Legacy (filtering stage) - "FilteringLLMClient", - "OpenAIClient", - # New (insight stage and beyond) - "InsightLLMClient", + "LLMClient", "LLMResponse", "FakeLLMClient", "OpenAILLMClient", diff --git a/bioscancast/llm/client.py b/bioscancast/llm/client.py deleted file mode 100644 index 5bb983d..0000000 --- a/bioscancast/llm/client.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Legacy LLM client kept for the search stage's transitional period. - -This module pre-dates ``bioscancast/llm/base.py`` and uses a simpler -``generate_json(prompt: str) -> dict`` signature. New code should use -``bioscancast.llm.base.LLMClient`` (with structured ``system/user/schema/ -model/max_tokens`` arguments returning an ``LLMResponse``) and the -production client at ``bioscancast/llm/openai_client.py``. - -Filtering and the insight stage already use the modern protocol. The -search stage (``bioscancast/stages/search_stage/``) is the only -remaining consumer of this module and will be migrated in a follow-up. -Do not add new callers here. -""" - -from __future__ import annotations - -import json -import os -from typing import Any, Optional, Protocol - - -class LLMClient(Protocol): - def generate_json(self, prompt: str) -> dict: ... 
- - -class OpenAIClient: - """Concrete LLM client using OpenAI's chat completions API.""" - - def __init__( - self, - api_key: Optional[str] = None, - model: str = "gpt-4o-mini", - temperature: float = 0.3, - seed: int = 42, - ) -> None: - import openai - - self._client = openai.OpenAI(api_key=api_key or os.environ["OPENAI_API_KEY"]) - self._model = model - self._temperature = temperature - self._seed = seed - - def generate_json(self, prompt: str) -> dict: - response = self._client.chat.completions.create( - model=self._model, - messages=[{"role": "user", "content": prompt}], - response_format={"type": "json_object"}, - temperature=self._temperature, - seed=self._seed, - ) - text = response.choices[0].message.content or "{}" - return json.loads(text) diff --git a/bioscancast/stages/search_stage/pipeline.py b/bioscancast/stages/search_stage/pipeline.py index 3012858..edc573f 100644 --- a/bioscancast/stages/search_stage/pipeline.py +++ b/bioscancast/stages/search_stage/pipeline.py @@ -14,7 +14,7 @@ from bioscancast.filtering.config import FILTER_CONFIG from bioscancast.filtering.models import ForecastQuestion, SearchResult -from bioscancast.llm.client import LLMClient +from bioscancast.llm.base import LLMClient from bioscancast.stages.search_stage.backends.base import RawSearchResult, SearchBackend from bioscancast.stages.search_stage.cache import SearchCache from bioscancast.stages.search_stage.dashboard_lookup import lookup_dashboards diff --git a/bioscancast/stages/search_stage/query_decomposition.py b/bioscancast/stages/search_stage/query_decomposition.py index a1b14d9..4d9bbdd 100644 --- a/bioscancast/stages/search_stage/query_decomposition.py +++ b/bioscancast/stages/search_stage/query_decomposition.py @@ -4,6 +4,9 @@ 1. Classifies the question type (outbreak_count, binary_event, etc.) 2. Decomposes it into 5-8 search-engine-optimised sub-queries 3. 
Validates sub-query word counts (2-8 words) post-hoc + +Uses the shared ``bioscancast.llm.base.LLMClient`` protocol (structured +system/user/schema/model/max_tokens, returns ``LLMResponse``). """ from __future__ import annotations @@ -15,10 +18,20 @@ from typing import List from bioscancast.filtering.models import ForecastQuestion -from bioscancast.llm.client import LLMClient +from bioscancast.llm.base import LLMClient logger = logging.getLogger(__name__) + +# Default model and output-token cap for search-stage LLM calls. Both +# classification and decomposition produce small JSON responses; these +# can be tuned via the optional ``model`` / ``max_tokens`` kwargs on the +# public functions. +DEFAULT_QUERY_MODEL = "gpt-4o-mini" +DEFAULT_CLASSIFY_MAX_TOKENS = 256 +DEFAULT_DECOMPOSE_MAX_TOKENS = 1024 + + VALID_AXES: set[str] = { "latest_data", "trend", @@ -44,6 +57,59 @@ } +# --------------------------------------------------------------------------- +# JSON schemas for the LLM responses +# --------------------------------------------------------------------------- + +CLASSIFY_SCHEMA: dict = { + "type": "object", + "properties": { + "question_type": { + "type": "string", + "enum": sorted(QUESTION_TYPES), + }, + }, + "required": ["question_type"], + "additionalProperties": False, +} + +DECOMPOSE_SCHEMA: dict = { + "type": "object", + "properties": { + "sub_queries": { + "type": "array", + "items": { + "type": "object", + "properties": { + "text": {"type": "string"}, + "axis": {"type": "string", "enum": sorted(VALID_AXES)}, + }, + "required": ["text", "axis"], + "additionalProperties": False, + }, + }, + }, + "required": ["sub_queries"], + "additionalProperties": False, +} + + +CLASSIFY_SYSTEM_PROMPT = ( + "You classify biosecurity forecast questions into exactly one of the " + "following types: outbreak_count (how many cases by date X), " + "binary_event (will event X occur by date Y), " + "mechanism_or_attribution (what caused X), " + "unknown (if none fit clearly). 
Return JSON matching the schema." +) + +DECOMPOSE_SYSTEM_PROMPT = ( + "You decompose biosecurity forecast questions into 5-8 search-engine-" + "optimised sub-queries. Each sub-query is 2-8 words and targets one " + "of the allowed information axes. Return JSON matching the schema. " + "No prose." +) + + @dataclass class SubQuery: id: str @@ -52,72 +118,84 @@ class SubQuery: axis: str -def classify_question_type(question: ForecastQuestion, llm_client: LLMClient) -> str: - """Classify a forecast question into one of the known question types. - - Design decision: uses an LLM call rather than keyword heuristics. - The LLM approach is more flexible for novel question phrasings, and the - cost is negligible (one small JSON call). Falls back to "unknown" on - any failure, which routes to all axes — safe but slightly wasteful. - Revisit if classification latency or cost becomes an issue. - """ - prompt = json.dumps( +def _build_classify_user_payload(question: ForecastQuestion) -> str: + return json.dumps( { - "task": ( - "Classify this biosecurity forecast question into exactly one type. " - "Return JSON: {\"question_type\": \"\"}. " - "Types: outbreak_count (how many cases by date X), " - "binary_event (will event X occur by date Y), " - "mechanism_or_attribution (what caused X), " - "unknown (if none fit clearly)." 
- ), "question": question.text, "pathogen": question.pathogen, "event_type": question.event_type, - } + }, + default=str, ) - try: - result = llm_client.generate_json(prompt) - qtype = result.get("question_type", "unknown") - if qtype not in QUESTION_TYPES: - logger.warning("LLM returned unknown question type '%s', falling back to 'unknown'", qtype) - return "unknown" - return qtype - except Exception: - logger.exception("Question classification failed, defaulting to 'unknown'") - return "unknown" -def _build_decomposition_prompt( +def _build_decompose_user_payload( question: ForecastQuestion, question_type: str, + *, historical_roleplay: bool = False, ) -> str: axes = AXES_BY_TYPE.get(question_type, list(VALID_AXES)) - task_lines = [ - "Decompose this biosecurity forecast question into 5-8 search-engine-optimised " - "sub-queries. Each sub-query should be 2-8 words and target a specific information " - "axis. Return strict JSON: {\"sub_queries\": [{\"text\": \"...\", \"axis\": \"...\"}]}. " - "No prose." - ] + payload: dict = { + "question": question.text, + "pathogen": question.pathogen, + "region": question.region, + "target_date": ( + question.target_date.isoformat() if question.target_date else None + ), + "allowed_axes": axes, + } if historical_roleplay and question.as_of_date is not None: - task_lines.append( - "IMPORTANT: Generate sub-queries as if today were " + # Benchmark-only: ask the LLM to query as if today were the + # ForecastQuestion's cutoff date. Carried as a JSON field in the + # user payload rather than concatenated into the task instruction + # because the system prompt is the structured-LLM-aware mechanism + # for instructions in the new protocol. Functionally equivalent. + payload["historical_roleplay_instruction"] = ( + "Generate sub-queries as if today were " f"{question.as_of_date.date().isoformat()}. Do not assume knowledge " "of events, named entities, or facts that you only learned about " "after that date. 
Phrase queries in terms a forecaster on that " "date would have used." ) - return json.dumps( - { - "task": " ".join(task_lines), - "question": question.text, - "pathogen": question.pathogen, - "region": question.region, - "target_date": question.target_date.isoformat() if question.target_date else None, - "allowed_axes": axes, - } - ) + return json.dumps(payload, default=str) + + +def classify_question_type( + question: ForecastQuestion, + llm_client: LLMClient, + *, + model: str = DEFAULT_QUERY_MODEL, + max_tokens: int = DEFAULT_CLASSIFY_MAX_TOKENS, +) -> str: + """Classify a forecast question into one of the known question types. + + Design decision: uses an LLM call rather than keyword heuristics. + The LLM approach is more flexible for novel question phrasings, and the + cost is negligible (one small JSON call). Falls back to "unknown" on + any failure, which routes to all axes — safe but slightly wasteful. + Revisit if classification latency or cost becomes an issue. + """ + try: + response = llm_client.generate_json( + system=CLASSIFY_SYSTEM_PROMPT, + user=_build_classify_user_payload(question), + schema=CLASSIFY_SCHEMA, + model=model, + max_tokens=max_tokens, + ) + except Exception: + logger.exception("Question classification failed, defaulting to 'unknown'") + return "unknown" + + qtype = response.content.get("question_type", "unknown") + if qtype not in QUESTION_TYPES: + logger.warning( + "LLM returned unknown question type '%s', falling back to 'unknown'", + qtype, + ) + return "unknown" + return qtype def _validate_word_count(text: str) -> str | None: @@ -175,29 +253,40 @@ def decompose_question( llm_client: LLMClient, *, historical_roleplay: bool = False, + model: str = DEFAULT_QUERY_MODEL, + classify_max_tokens: int = DEFAULT_CLASSIFY_MAX_TOKENS, + decompose_max_tokens: int = DEFAULT_DECOMPOSE_MAX_TOKENS, ) -> List[SubQuery]: """Decompose a forecast question into sub-queries using an LLM. 
Falls back to simple keyword-based sub-queries if the LLM fails. ``historical_roleplay`` is an opt-in benchmark-only flag. When True AND - ``question.as_of_date`` is set, the prompt is extended with an instruction - asking the LLM to query as if today were the cutoff date. This is gated - behind its own flag because prompt-level roleplay can have hard-to-predict - effects on query quality. + ``question.as_of_date`` is set, the user payload is extended with an + instruction asking the LLM to query as if today were the cutoff date. + This is gated behind its own flag because prompt-level roleplay can have + hard-to-predict effects on query quality. """ - question_type = classify_question_type(question, llm_client) - prompt = _build_decomposition_prompt( - question, question_type, historical_roleplay=historical_roleplay + question_type = classify_question_type( + question, llm_client, model=model, max_tokens=classify_max_tokens, ) try: - result = llm_client.generate_json(prompt) + response = llm_client.generate_json( + system=DECOMPOSE_SYSTEM_PROMPT, + user=_build_decompose_user_payload( + question, question_type, + historical_roleplay=historical_roleplay, + ), + schema=DECOMPOSE_SCHEMA, + model=model, + max_tokens=decompose_max_tokens, + ) except Exception: logger.exception("LLM decomposition failed, using fallback sub-queries") return _fallback_subqueries(question) - raw_queries = result.get("sub_queries", []) + raw_queries = response.content.get("sub_queries", []) if not isinstance(raw_queries, list): logger.warning("LLM returned non-list sub_queries, using fallback") return _fallback_subqueries(question) diff --git a/bioscancast/tests/test_filtering_llm.py b/bioscancast/tests/test_filtering_llm.py index e7891c2..805afd7 100644 --- a/bioscancast/tests/test_filtering_llm.py +++ b/bioscancast/tests/test_filtering_llm.py @@ -1,8 +1,9 @@ """Tests for the LLM-driven filtering stage. 
-The filter uses the shared ``bioscancast.llm.base.LLMClient`` protocol — -the older single-positional-argument ``bioscancast.llm.client.LLMClient`` -is no longer accepted here. +The filter uses the shared ``bioscancast.llm.base.LLMClient`` protocol. +The older single-positional-argument legacy LLMClient module has been +fully removed from the codebase; the regression test at the bottom of +the file blocks anyone from reintroducing it under the original path. """ from __future__ import annotations @@ -174,8 +175,10 @@ def embed(self, *_, **__): def test_llm_filter_does_not_use_legacy_client_module(): - """The migration: the filter module must NOT import from - bioscancast.llm.client (the legacy single-positional protocol).""" + """Regression check: the filter module must NOT import from + bioscancast.llm.client (the legacy single-positional protocol that + used to exist before the migration). The module is gone now, but + keep this check so nothing reintroduces it under that path.""" import bioscancast.filtering.llm_filter as mod src = open(mod.__file__, encoding="utf-8").read() assert "from bioscancast.llm.client" not in src diff --git a/bioscancast/tests/test_query_decomposition.py b/bioscancast/tests/test_query_decomposition.py index 97442d0..c5dbc33 100644 --- a/bioscancast/tests/test_query_decomposition.py +++ b/bioscancast/tests/test_query_decomposition.py @@ -1,7 +1,17 @@ +"""Tests for the LLM-driven search-stage query decomposition. + +Uses the shared ``bioscancast.llm.base.LLMClient`` protocol — the +``generate_json`` calls pass system/user/schema/model/max_tokens +kwargs and return ``LLMResponse`` objects. 
+""" + from datetime import datetime, timezone from bioscancast.filtering.models import ForecastQuestion +from bioscancast.llm.base import LLMResponse from bioscancast.stages.search_stage.query_decomposition import ( + CLASSIFY_SCHEMA, + DECOMPOSE_SCHEMA, VALID_AXES, classify_question_type, decompose_question, @@ -20,47 +30,87 @@ def _make_question(**overrides): return ForecastQuestion(**defaults) +def _resp(content: dict, input_tokens: int = 80, output_tokens: int = 20) -> LLMResponse: + """Build a minimal LLMResponse for tests.""" + return LLMResponse( + content=content, + input_tokens=input_tokens, + output_tokens=output_tokens, + model="gpt-4o-mini", + raw_text="{}", + ) + + class FakeLLMClient: - """Mock LLM client that returns canned responses.""" + """Mock LLM client implementing the shared + ``bioscancast.llm.base.LLMClient`` protocol — FIFO scripted responses + keyed by call order, not by content.""" def __init__(self, responses=None): - self._responses = responses or [] + self._responses = list(responses or []) self._call_count = 0 - - def generate_json(self, prompt: str) -> dict: + self.recorded_calls: list[dict] = [] + + def generate_json( + self, + *, + system: str, + user: str, + schema: dict, + model: str, + max_tokens: int = 1024, + ) -> LLMResponse: + self.recorded_calls.append({ + "system": system, "user": user, "schema": schema, + "model": model, "max_tokens": max_tokens, + }) if self._call_count < len(self._responses): resp = self._responses[self._call_count] self._call_count += 1 return resp self._call_count += 1 - return {} + return _resp({}) + + def embed(self, texts, *, model): + raise NotImplementedError("query decomposition doesn't embed") class TestClassifyQuestionType: def test_returns_outbreak_count(self): - llm = FakeLLMClient([{"question_type": "outbreak_count"}]) + llm = FakeLLMClient([_resp({"question_type": "outbreak_count"})]) assert classify_question_type(_make_question(), llm) == "outbreak_count" def 
test_returns_binary_event(self): - llm = FakeLLMClient([{"question_type": "binary_event"}]) + llm = FakeLLMClient([_resp({"question_type": "binary_event"})]) assert classify_question_type(_make_question(), llm) == "binary_event" def test_invalid_type_falls_back_to_unknown(self): - llm = FakeLLMClient([{"question_type": "nonsense"}]) + llm = FakeLLMClient([_resp({"question_type": "nonsense"})]) assert classify_question_type(_make_question(), llm) == "unknown" def test_llm_failure_falls_back_to_unknown(self): class FailingLLM: - def generate_json(self, prompt: str) -> dict: + def generate_json(self, **_): raise RuntimeError("LLM down") + def embed(self, *_, **__): + raise NotImplementedError assert classify_question_type(_make_question(), FailingLLM()) == "unknown" + def test_calls_new_protocol_with_classify_schema(self): + llm = FakeLLMClient([_resp({"question_type": "outbreak_count"})]) + classify_question_type(_make_question(), llm) + assert len(llm.recorded_calls) == 1 + call = llm.recorded_calls[0] + assert call["schema"] is CLASSIFY_SCHEMA + assert isinstance(call["system"], str) and "biosecurity" in call["system"] + assert isinstance(call["user"], str) + class TestDecomposeQuestion: def test_produces_valid_subqueries(self): - classify_resp = {"question_type": "outbreak_count"} - decompose_resp = { + classify_resp = _resp({"question_type": "outbreak_count"}) + decompose_resp = _resp({ "sub_queries": [ {"text": "H5N1 human cases US 2025", "axis": "latest_data"}, {"text": "H5N1 outbreak trend growth", "axis": "trend"}, @@ -68,7 +118,7 @@ def test_produces_valid_subqueries(self): {"text": "H5N1 historical human cases", "axis": "historical_analogy"}, {"text": "bird flu cases latest report", "axis": "latest_data"}, ] - } + }) llm = FakeLLMClient([classify_resp, decompose_resp]) result = decompose_question(_make_question(), llm) @@ -82,15 +132,15 @@ def test_produces_valid_subqueries(self): assert 2 <= word_count <= 8 def test_invalid_axis_dropped(self): - 
classify_resp = {"question_type": "unknown"} - decompose_resp = { + classify_resp = _resp({"question_type": "unknown"}) + decompose_resp = _resp({ "sub_queries": [ {"text": "H5N1 cases latest", "axis": "latest_data"}, {"text": "some bad axis query", "axis": "invalid_axis"}, {"text": "bird flu trend analysis", "axis": "trend"}, {"text": "avian influenza policy response", "axis": "policy"}, ] - } + }) llm = FakeLLMClient([classify_resp, decompose_resp]) result = decompose_question(_make_question(), llm) @@ -98,8 +148,8 @@ def test_invalid_axis_dropped(self): assert "invalid_axis" not in axes def test_too_long_query_truncated(self): - classify_resp = {"question_type": "unknown"} - decompose_resp = { + classify_resp = _resp({"question_type": "unknown"}) + decompose_resp = _resp({ "sub_queries": [ {"text": "one two three four five six seven eight nine ten", "axis": "latest_data"}, {"text": "H5N1 trend data", "axis": "trend"}, @@ -107,7 +157,7 @@ def test_too_long_query_truncated(self): {"text": "bird flu expert analysis view", "axis": "expert_opinion"}, {"text": "historical outbreak comparison data", "axis": "historical_analogy"}, ] - } + }) llm = FakeLLMClient([classify_resp, decompose_resp]) result = decompose_question(_make_question(), llm) @@ -119,32 +169,48 @@ class FailingLLM: def __init__(self): self.calls = 0 - def generate_json(self, prompt: str) -> dict: + def generate_json(self, **_): self.calls += 1 if self.calls == 1: - return {"question_type": "unknown"} + return _resp({"question_type": "unknown"}) raise RuntimeError("LLM decomposition failed") + def embed(self, *_, **__): + raise NotImplementedError + result = decompose_question(_make_question(), FailingLLM()) assert len(result) >= 1 for sq in result: assert sq.axis in VALID_AXES def test_malformed_response_uses_fallback(self): - classify_resp = {"question_type": "unknown"} - decompose_resp = {"sub_queries": "not a list"} + classify_resp = _resp({"question_type": "unknown"}) + decompose_resp = 
_resp({"sub_queries": "not a list"}) llm = FakeLLMClient([classify_resp, decompose_resp]) result = decompose_question(_make_question(), llm) assert len(result) >= 1 def test_caps_at_8(self): - classify_resp = {"question_type": "unknown"} - decompose_resp = { + classify_resp = _resp({"question_type": "unknown"}) + decompose_resp = _resp({ "sub_queries": [ {"text": f"query number {i} text", "axis": "latest_data"} for i in range(12) ] - } + }) llm = FakeLLMClient([classify_resp, decompose_resp]) result = decompose_question(_make_question(), llm) assert len(result) <= 8 + + def test_uses_decompose_schema_on_second_call(self): + classify_resp = _resp({"question_type": "outbreak_count"}) + decompose_resp = _resp({"sub_queries": [ + {"text": "H5N1 cases US 2025", "axis": "latest_data"}, + {"text": "avian flu trend", "axis": "trend"}, + {"text": "USDA policy avian", "axis": "policy"}, + ]}) + llm = FakeLLMClient([classify_resp, decompose_resp]) + decompose_question(_make_question(), llm) + assert len(llm.recorded_calls) == 2 + assert llm.recorded_calls[0]["schema"] is CLASSIFY_SCHEMA + assert llm.recorded_calls[1]["schema"] is DECOMPOSE_SCHEMA diff --git a/scripts/run_search_stage.py b/scripts/run_search_stage.py index ae45f62..d76707f 100644 --- a/scripts/run_search_stage.py +++ b/scripts/run_search_stage.py @@ -35,7 +35,7 @@ pass # python-dotenv not installed; keys must be in environment directly from bioscancast.filtering.models import ForecastQuestion -from bioscancast.llm.client import OpenAIClient +from bioscancast.llm.openai_client import OpenAILLMClient from bioscancast.stages.search_stage.backends.tavily_backend import TavilyBackend from bioscancast.stages.search_stage.cache import SearchCache from bioscancast.stages.search_stage.pipeline import SearchStagePipeline @@ -65,7 +65,7 @@ def main(): region=args.region, ) - llm_client = OpenAIClient() + llm_client = OpenAILLMClient() search_backend = TavilyBackend() cache = None if args.no_cache else SearchCache() From 
366aed83cf8bd34def5af056589f9e65d534fe0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Tue, 26 May 2026 14:17:54 +0200 Subject: [PATCH 16/21] Remove now-redundant legacy-import regression test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous commit deleted bioscancast/llm/client.py. The regression test that asserted filtering didn't import from that module is now dead weight: any attempt to import the deleted module would fail at import time with ModuleNotFoundError — much louder than a source-text grep assertion. Source-text grep tests are generally fragile too: they couple to implementation details (the literal `from bioscancast.llm.client` string) rather than behaviour. The new-protocol tests already in this file verify the actual filter calls work correctly, which is what we actually care about. Co-Authored-By: Claude Opus 4.7 --- bioscancast/tests/test_filtering_llm.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/bioscancast/tests/test_filtering_llm.py b/bioscancast/tests/test_filtering_llm.py index 805afd7..f1d61ca 100644 --- a/bioscancast/tests/test_filtering_llm.py +++ b/bioscancast/tests/test_filtering_llm.py @@ -1,9 +1,6 @@ """Tests for the LLM-driven filtering stage. The filter uses the shared ``bioscancast.llm.base.LLMClient`` protocol. -The older single-positional-argument legacy LLMClient module has been -fully removed from the codebase; the regression test at the bottom of -the file blocks anyone from reintroducing it under the original path. """ from __future__ import annotations @@ -172,14 +169,3 @@ def embed(self, *_, **__): q = _make_question() assert llm_filter_candidates(q, [], {}, ExplodingFake()) == [] - - -def test_llm_filter_does_not_use_legacy_client_module(): - """Regression check: the filter module must NOT import from - bioscancast.llm.client (the legacy single-positional protocol that - used to exist before the migration). 
The module is gone now, but - keep this check so nothing reintroduces it under that path.""" - import bioscancast.filtering.llm_filter as mod - src = open(mod.__file__, encoding="utf-8").read() - assert "from bioscancast.llm.client" not in src - assert "bioscancast.llm.base" in src From e324d1961066d6a565eb3a22f98f9f1167099b44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Tue, 26 May 2026 14:51:38 +0200 Subject: [PATCH 17/21] Use trafilatura structural extraction for HTML sections MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The HTML parser was already calling trafilatura, but only using the plain-text output as a `raw_text` fallback. Sections were rebuilt by walking the entire raw DOM body — which on pages like CIDRAP (whose body contains the target article plus three unrelated articles plus a "top reads" sidebar plus a footer) produced 18 chunks of mostly noise. The investigation showed this wasn't a deliberate design decision; the original code just used trafilatura.extract() at its default text-only setting and didn't discover the structured output modes. Switching to `trafilatura.extract(output_format='xml', include_tables=True)` gives us a cleaned `
<doc><main>...</main></doc>` tree with the article body properly isolated. The new `_extract_sections_from_trafilatura_xml` walks that tree the same way the DOM walker walks BeautifulSoup, emitting heading-stack-aware sections plus tables. The previous DOM walker is kept as a fallback: when trafilatura's output has less than 200 chars of body text (listing pages, error pages, or pages whose layout trafilatura's heuristics misjudge), the parser falls through to the original code path so we never silently extract nothing. Title, published_date, and language continue to come from raw DOM head/meta queries — those are reliable regardless of which body path runs. Verified end-to-end: - CIDRAP fixture: 18 chunks (4 unrelated articles + nav + footer) → 2 chunks (just the Utah measles article). The "602 cases" headline survives; the insight pipeline still produces a record citing it. - ProMED fixture: chunks dropped from 7 to 6 (header tagline gone); the 157-row outbreak table is fully preserved in `table_rows`. - PDF docs (WHO mpox/cholera, CDC MMWR, ECDC, Africa CDC): unchanged (PDFs don't go through this parser). - Live spot-checks on 5 fresh HTML sources (WHO DON article, CIDRAP homepage, CDC HAN alert, Reuters healthcare landing, a 404 page) all behave correctly. Articles get clean structural extraction; listings get a small but useful section; 404s correctly take the fallback path. Test changes: - CIDRAP's min-chunks floor in test_insight_real_docs_integration.py drops to 1 (was 5) — that's the new, lower, cleaner extraction output. Other fixtures unchanged. - New `TestTrafilaturaXmlExtraction` class in test_extraction_html.py with synthetic-HTML cases for sibling-article stripping, table preservation, DOM-walker fallback on thin pages, and metadata extraction independent of which body path runs. All 351 tests pass. 
Co-Authored-By: Claude Opus 4.7 --- bioscancast/extraction/parsers/html_parser.py | 205 +++++++++++++++++- bioscancast/tests/test_extraction_html.py | 129 +++++++++++ .../test_insight_real_docs_integration.py | 27 ++- 3 files changed, 346 insertions(+), 15 deletions(-) diff --git a/bioscancast/extraction/parsers/html_parser.py b/bioscancast/extraction/parsers/html_parser.py index 58ab6a1..d37eab2 100644 --- a/bioscancast/extraction/parsers/html_parser.py +++ b/bioscancast/extraction/parsers/html_parser.py @@ -1,8 +1,10 @@ from __future__ import annotations +import logging import re from datetime import datetime from typing import List, Optional +from xml.etree import ElementTree as ET from bs4 import BeautifulSoup, Tag @@ -13,9 +15,27 @@ except ImportError: trafilatura = None # type: ignore[assignment] +logger = logging.getLogger(__name__) + +# Minimum extractable content for the trafilatura-XML path to be accepted. +# When the XML output's body text is shorter than this, we treat the +# extraction as having failed (e.g. listing pages where trafilatura can +# only isolate a snippet) and fall back to the DOM walker. The threshold +# is intentionally small so single-paragraph articles still go down the +# preferred path. +_MIN_TRAFILATURA_BODY_CHARS = 200 + class HtmlParser: - """Extracts structured content from HTML documents.""" + """Extracts structured content from HTML documents. + + Strategy: prefer trafilatura's structured (XML) main-content extraction + which strips navigation, sidebars, "related articles" sections, and + other site boilerplate that a raw DOM walk would otherwise pull in. + Fall back to a DOM walker when trafilatura returns too little + content — usually because the page is a listing/landing page + trafilatura can't isolate, or because trafilatura is not installed. 
+ """ def can_parse(self, content_type: str, content: bytes) -> bool: if "html" in (content_type or ""): @@ -26,18 +46,46 @@ def can_parse(self, content_type: str, content: bytes) -> bool: def parse(self, content: bytes, *, source_url: str) -> ParsedContent: html_text = content.decode("utf-8", errors="replace") - # Use trafilatura for cleaned main-content text + # Use trafilatura for cleaned main-content text (plain) AND + # the structural XML output used for section extraction below. main_text = "" + main_xml = "" if trafilatura is not None: main_text = trafilatura.extract(html_text) or "" + main_xml = ( + trafilatura.extract( + html_text, + output_format="xml", + include_tables=True, + include_links=False, + ) + or "" + ) - # Parse with BeautifulSoup for structure + # Parse with BeautifulSoup for document-level metadata (title, + # date, language) — these are reliably in the raw DOM head/meta + # tags whether or not the body extraction succeeds. soup = BeautifulSoup(html_text, "html.parser") - title = self._extract_title(soup) published_date = self._extract_published_date(soup) language = self._extract_language(soup) - sections = self._extract_sections(soup) + + # Primary path: walk trafilatura's main-content XML. + sections = ( + self._extract_sections_from_trafilatura_xml(main_xml) + if main_xml + else [] + ) + # Fallback path: walk the full DOM when trafilatura's output is + # too thin to be trustworthy (listing pages, error pages, + # pages trafilatura's heuristics misjudge). 
+ if not sections: + logger.debug( + "trafilatura XML extraction yielded no usable sections for %s " + "(xml_chars=%d); falling back to DOM walker", + source_url, len(main_xml), + ) + sections = self._extract_sections(soup) raw_text = main_text or soup.get_text(separator="\n", strip=True) @@ -49,6 +97,112 @@ def parse(self, content: bytes, *, source_url: str) -> ParsedContent: published_date=published_date, ) + # ---------------------------------------------------------------- + # Trafilatura XML → sections + # ---------------------------------------------------------------- + + def _extract_sections_from_trafilatura_xml( + self, xml_text: str + ) -> List[SectionContent]: + """Walk trafilatura's XML output (``
<doc><main>...</main></doc>
``) + to produce ordered sections. + + Trafilatura emits headings as ``<head rend="hN">``, paragraphs as ``
<p>
``, and tables as ``<table>`` with ``<row>`` containing + ``<cell>`` (sometimes wrapping a ``
<p>
``). We rebuild a + heading-stack-aware section list with the same shape as the DOM + walker so downstream code doesn't care which path produced them. + + Returns an empty list when the XML body has less than + ``_MIN_TRAFILATURA_BODY_CHARS`` of body text — the caller treats + that as a signal to fall back to the DOM walker. + """ + try: + root = ET.fromstring(xml_text) + except ET.ParseError as exc: + logger.warning("Could not parse trafilatura XML: %s", exc) + return [] + + main_el = root.find("main") + if main_el is None: + return [] + + # Cheap quality gate: count printable body text. Headings + para + # text + cell text together must clear the threshold or we drop + # to the DOM walker. + body_chars = sum( + len(_collect_element_text(el)) + for el in main_el.iter() + if el.tag in ("head", "p", "cell") + ) + if body_chars < _MIN_TRAFILATURA_BODY_CHARS: + return [] + + sections: List[SectionContent] = [] + heading_stack: List[str] = [] + current_level = 0 + current_text_parts: List[str] = [] + + def flush_prose() -> None: + if not current_text_parts: + return + text = "\n".join(current_text_parts).strip() + if text: + sections.append( + SectionContent( + section_path=" > ".join(heading_stack) if heading_stack else None, + page_number=None, + text=text, + chunk_type="prose", + extractor="trafilatura", + ) + ) + current_text_parts.clear() + + for child in main_el: + tag = child.tag + if tag == "head": + flush_prose() + level = _heading_level(child.get("rend")) + heading_text = _collect_element_text(child) + if level <= current_level: + heading_stack = heading_stack[: level - 1] + heading_stack.append(heading_text) + current_level = level + + elif tag == "p" or tag == "quote": + p_text = _collect_element_text(child).strip() + if p_text: + current_text_parts.append(p_text) + + elif tag == "list": + # Render a list as one prose chunk with bullet-marked lines. 
+ items = [ + f"• {_collect_element_text(item).strip()}" + for item in child.findall("item") + if _collect_element_text(item).strip() + ] + if items: + current_text_parts.append("\n".join(items)) + + elif tag == "table": + flush_prose() + table_rows = _parse_trafilatura_table(child) + if table_rows: + sections.append( + SectionContent( + section_path=" > ".join(heading_stack) if heading_stack else None, + page_number=None, + text="", + chunk_type="table", + table_rows=table_rows, + extractor="trafilatura", + ) + ) + + flush_prose() + return sections + def _extract_title(self, soup: BeautifulSoup) -> Optional[str]: og_title = soup.find("meta", property="og:title") if og_title and og_title.get("content"): # type: ignore[union-attr] @@ -215,3 +369,44 @@ def _fallback_sections(self, text: str) -> List[SectionContent]: ) ) return sections + + +# --------------------------------------------------------------------------- +# Module-level helpers for the trafilatura XML walker +# --------------------------------------------------------------------------- + + +def _collect_element_text(el: ET.Element) -> str: + """Concatenate text from an XML element and all its descendants.""" + parts: List[str] = [] + if el.text: + parts.append(el.text) + for child in el: + parts.append(_collect_element_text(child)) + if child.tail: + parts.append(child.tail) + return "".join(parts) + + +def _heading_level(rend: Optional[str]) -> int: + """Parse trafilatura's ``rend="hN"`` heading marker. Defaults to h2 + when the marker is missing or unrecognised (matches the most common + article structure).""" + if rend and rend.startswith("h") and rend[1:].isdigit(): + return max(1, min(int(rend[1:]), 4)) + return 2 + + +def _parse_trafilatura_table(table_el: ET.Element) -> List[List[str]]: + """Convert a trafilatura ``

<table>`` element into row-major cells. + + Trafilatura emits ``<row>`` containing ``<cell>``, sometimes with a + nested ``
<p>
``. We collapse to plain text per cell and discard any + rows that came out empty. + """ + rows: List[List[str]] = [] + for row in table_el.findall("row"): + cells = [_collect_element_text(c).strip() for c in row.findall("cell")] + if any(cells): + rows.append(cells) + return rows diff --git a/bioscancast/tests/test_extraction_html.py b/bioscancast/tests/test_extraction_html.py index 3f1e846..bb2ac0a 100644 --- a/bioscancast/tests/test_extraction_html.py +++ b/bioscancast/tests/test_extraction_html.py @@ -118,3 +118,132 @@ def test_section_paths_contain_article_headings(self, html_parser, reuters_html) result = html_parser.parse(reuters_html, source_url="https://reuters.com/mpox") all_paths = " ".join(s.section_path or "" for s in result.sections) assert "Outbreak Spread" in all_paths or "International Response" in all_paths + + +# --------------------------------------------------------------------------- +# Trafilatura-XML extraction path +# --------------------------------------------------------------------------- + +class TestTrafilaturaXmlExtraction: + """The HTML parser prefers trafilatura's structural XML output (which + strips navigation, sidebars, "related articles" lists, and other + site boilerplate) and falls back to the DOM walker only when + trafilatura returns too little content to be trustworthy. + """ + + def test_xml_path_strips_unrelated_sibling_articles(self, html_parser): + """A page whose contains the target article AND multiple + unrelated articles should yield only the target's content via + the trafilatura path.""" + html = b""" + Target Article + +

+

Target Article Title

+

The target article body says fact one with enough text to + survive trafilatura's content-density heuristic and clear the + minimum-body-chars threshold easily.

+

A second paragraph of the target article continues + describing fact two in detail with concrete numbers like 602 + and 405 and other figures.

+
+ + + """ + result = html_parser.parse(html, source_url="https://example.com/a") + all_text = " ".join(s.text or "" for s in result.sections) + assert "fact one" in all_text + assert "fact two" in all_text + # The aside content must not appear in any extracted section + assert "Sidebar article A" not in all_text + assert "Sidebar article B" not in all_text + assert "Sidebar article C" not in all_text + # Every section should come from the trafilatura extractor + for s in result.sections: + assert s.extractor == "trafilatura" + + def test_xml_path_preserves_tables(self, html_parser): + """Tables inside the main content should make it through as + chunk_type=table with table_rows populated.""" + html = b""" + Table Page + +
+

Disease Surveillance Weekly

+

The following table summarises this week's outbreak alerts + with one row per event, including pathogen and country, in + enough text to exceed the minimum-body-chars threshold for + the trafilatura extraction path to be selected.

+
+ + + + + + +
PathogenCountry
H5N1United States
MeaslesUtah
MpoxComoros
+

+ + """ + result = html_parser.parse(html, source_url="https://example.com/b") + table_sections = [s for s in result.sections if s.chunk_type == "table"] + assert len(table_sections) == 1 + table = table_sections[0] + assert table.table_rows is not None + # Header + 3 data rows + assert len(table.table_rows) >= 3 + # The headers and at least one cell should appear in the rows + flat = [c for row in table.table_rows for c in row] + assert "H5N1" in flat + assert "Utah" in flat + + def test_falls_back_to_dom_walker_on_thin_xml(self, html_parser): + """When trafilatura can extract very little, the parser falls + back to the DOM walker so navigation-heavy pages still produce + SOMETHING rather than silently extracting nothing. + + The fallback path is also what keeps the parser working when + trafilatura is not installed at all. + """ + # Tiny page where trafilatura will find almost no content. + html = b""" + Tiny + + +

Hello.

+ + """ + result = html_parser.parse(html, source_url="https://example.com/c") + # The result is allowed to be empty here (no good content + # anywhere), but the call must not raise. + assert isinstance(result.sections, list) + + def test_metadata_extracted_independently_of_body_path(self, html_parser): + """Title, published_date, and language must come from the raw + DOM head/meta regardless of which body extraction path runs.""" + html = b""" + + + Real Title From Meta + + + + +
+

Body content long enough to satisfy the trafilatura path + threshold by clearing the minimum number of characters required + for structured extraction to be preferred over the fallback.

+

And a second paragraph for good measure.

+
+ + """ + result = html_parser.parse(html, source_url="https://example.com/d") + assert result.title == "Real Title From Meta" + assert result.published_date is not None + assert result.published_date.year == 2026 + assert result.language == "en" diff --git a/bioscancast/tests/test_insight_real_docs_integration.py b/bioscancast/tests/test_insight_real_docs_integration.py index e16549f..732a074 100644 --- a/bioscancast/tests/test_insight_real_docs_integration.py +++ b/bioscancast/tests/test_insight_real_docs_integration.py @@ -70,24 +70,31 @@ def test_extraction_africa_cdc_fails_with_requires_ocr(real_docs): @pytest.mark.parametrize( - "name", + # (name, min_chunks). CIDRAP has a lower floor because trafilatura's + # main-content extraction correctly isolates only the actual article + # body (~2 chunks) rather than the surrounding navigation, sidebars, + # and three additional unrelated articles that the raw DOM contains. + # Other sources are calibrated to today's behaviour. + "name,min_chunks", [ - "who_mpox_sitrep64", - "who_cholera_epi34", - "cdc_mmwr_nm_measles", - "ecdc_cdtr_week16", - "cidrap_utah_measles", - "promed_latest", + ("who_mpox_sitrep64", 5), + ("who_cholera_epi34", 5), + ("cdc_mmwr_nm_measles", 5), + ("ecdc_cdtr_week16", 5), + ("cidrap_utah_measles", 1), + ("promed_latest", 5), ], ) -def test_extraction_produces_chunks_for_text_extractable_sources(real_docs, name): +def test_extraction_produces_chunks_for_text_extractable_sources( + real_docs, name, min_chunks +): """Every source except Africa CDC must extract at least a few chunks.""" doc = real_docs[name] assert doc.status == "success", ( f"{name}: expected status=success, got {doc.status}" ) - assert len(doc.chunks) >= 5, ( - f"{name}: expected >= 5 chunks, got {len(doc.chunks)}" + assert len(doc.chunks) >= min_chunks, ( + f"{name}: expected >= {min_chunks} chunks, got {len(doc.chunks)}" ) From a0b6d9c4a4f34af86f3161b34ea8fc1c33a74646 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Tue, 26 May 2026 22:05:38 +0200 Subject: [PATCH 18/21] Broaden HTML pub_date extraction to JSON-LD, Dublin Core, and news platforms MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous date extractor checked only four patterns: ``article:published_time``, ``og:published_time``, ````, and ``