diff --git a/src/gradata/_manifest_helpers.py b/src/gradata/_manifest_helpers.py index 8b6b4e9f..66cdd49e 100644 --- a/src/gradata/_manifest_helpers.py +++ b/src/gradata/_manifest_helpers.py @@ -122,7 +122,6 @@ def _sdk_capabilities() -> dict: ("auto_correct_hook", "gradata.hooks.auto_correct", "gradata"), ("reporting", "gradata.enhancements.reporting", "fest.build-inspired+gradata"), ("quality_monitoring", "gradata.enhancements.quality_monitoring", "jarvis-inspired+gradata"), - ("rule_evolution", "gradata.enhancements.rule_evolution", "jarvis-inspired+gradata"), ] all_modules = _paul_modules + _ruflo_modules + _deerflow_modules + _ecc_modules + _everos_modules + _core_modules diff --git a/src/gradata/benchmarks/swe_bench.py b/src/gradata/benchmarks/swe_bench.py deleted file mode 100644 index 7443e3dd..00000000 --- a/src/gradata/benchmarks/swe_bench.py +++ /dev/null @@ -1,516 +0,0 @@ -""" -SWE-bench Harness — Prove Gradata improves AI coding agents. -============================================================== - -Runs SWE-bench instances through a Gradata-enhanced agent, captures -failed patches as corrections, accumulates brain, measures improvement. - -Two modes: - 1. Offline (no Docker): compare agent patches to gold patches via - diff similarity. Fast, cheap, sufficient to prove accumulation. - 2. Online (Docker/Modal): run actual tests for ground-truth pass/fail. - -The experiment: - Run A: baseline agent (no brain) → X% resolved - Run B: same agent + Gradata brain → Y% resolved - If Y > X, that's the paper. - -Usage:: - - from gradata.benchmarks.swe_bench import ( - SWEBenchHarness, load_swe_bench_lite, RunConfig, - ) - - harness = SWEBenchHarness(brain_dir="./swe-brain") - instances = load_swe_bench_lite() - results = harness.run(instances, agent_fn=my_agent) - print(results.summary()) - -Requires: pip install datasets (for loading SWE-bench data) -""" - -from __future__ import annotations - -import json -import logging -import time -from collections.abc import Callable -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any - -_log = logging.getLogger(__name__) - -__all__ = [ - "PatchResult", - "RunConfig", - "RunResults", - "SWEBenchHarness", - "SWEInstance", - "compare_patches", - "load_swe_bench_lite", - "load_swe_bench_verified", -] - - -# --------------------------------------------------------------------------- -# Data structures -# --------------------------------------------------------------------------- - -@dataclass -class SWEInstance: - """A single SWE-bench instance. - - Attributes: - instance_id: Unique ID (e.g. "django__django-11099"). - repo: GitHub repo (e.g. "django/django"). - problem_statement: The issue text. - gold_patch: The correct fix (diff format). - fail_to_pass: Test names that must flip to passing. - pass_to_pass: Test names that must stay passing. - base_commit: Commit SHA to start from. - hints: Optional hints from issue comments. - version: Package version string. - """ - instance_id: str - repo: str - problem_statement: str - gold_patch: str = "" - fail_to_pass: list[str] = field(default_factory=list) - pass_to_pass: list[str] = field(default_factory=list) - base_commit: str = "" - hints: str = "" - version: str = "" - - -@dataclass -class PatchResult: - """Result of an agent's attempt at fixing an instance. - - Attributes: - instance_id: Which instance was attempted. - agent_patch: The patch the agent produced. - gold_patch: The correct patch (for comparison). - resolved: Whether the patch resolves the issue. 
- similarity: Diff similarity to gold patch (0.0-1.0). - correction_captured: Whether brain.correct() was called. - lesson_created: Whether a new lesson was created. - attempt_number: Which attempt this was (1-indexed). - duration_ms: How long the agent took. - error: Error message if agent crashed. - """ - instance_id: str - agent_patch: str = "" - gold_patch: str = "" - resolved: bool = False - similarity: float = 0.0 - correction_captured: bool = False - lesson_created: bool = False - attempt_number: int = 1 - duration_ms: int = 0 - error: str = "" - - -@dataclass -class RunConfig: - """Configuration for a SWE-bench run. - - Attributes: - run_id: Identifier for this run. - use_brain: Whether to inject brain rules. - batch_size: How many instances before measuring. - max_instances: Cap on total instances to process. - similarity_threshold: Min similarity to count as "resolved" in offline mode. - """ - run_id: str = "default" - use_brain: bool = True - batch_size: int = 50 - max_instances: int = 300 - similarity_threshold: float = 0.85 - - -@dataclass -class BatchStats: - """Statistics for a batch of instances.""" - batch_number: int - instances_in_batch: int - resolved: int - resolve_rate: float - avg_similarity: float - lessons_total: int - corrections_total: int - - -@dataclass -class RunResults: - """Aggregate results from a full SWE-bench run. - - Attributes: - config: The run configuration. - results: Per-instance results. - batch_stats: Per-batch statistics (shows learning curve). - total_resolved: Total instances resolved. - total_attempted: Total instances attempted. - resolve_rate: Overall resolve rate. - brain_lessons_created: Total lessons created during run. - duration_seconds: Total run time. - """ - config: RunConfig - results: list[PatchResult] = field(default_factory=list) - batch_stats: list[BatchStats] = field(default_factory=list) - total_resolved: int = 0 - total_attempted: int = 0 - resolve_rate: float = 0.0 - brain_lessons_created: int = 0 - duration_seconds: float = 0.0 - - def summary(self) -> str: - """Human-readable summary of the run.""" - lines = [ - f"SWE-bench Run: {self.config.run_id}", - f"Brain: {'enabled' if self.config.use_brain else 'DISABLED (baseline)'}", - f"Resolved: {self.total_resolved}/{self.total_attempted} ({self.resolve_rate:.1%})", - f"Lessons created: {self.brain_lessons_created}", - f"Duration: {self.duration_seconds:.0f}s", - "", - "Learning curve (resolve rate per batch):", - ] - for bs in self.batch_stats: - bar = "#" * int(bs.resolve_rate * 20) - lines.append( - f" Batch {bs.batch_number}: {bs.resolve_rate:.1%} " - f"[{bar:<20}] (lessons: {bs.lessons_total})" - ) - return "\n".join(lines) - - def to_dict(self) -> dict[str, Any]: - """Serialize for JSON export.""" - return { - "run_id": self.config.run_id, - "use_brain": self.config.use_brain, - "total_resolved": self.total_resolved, - "total_attempted": self.total_attempted, - "resolve_rate": round(self.resolve_rate, 4), - "brain_lessons_created": self.brain_lessons_created, - "duration_seconds": round(self.duration_seconds, 1), - "batch_stats": [ - { - "batch": bs.batch_number, - "resolve_rate": round(bs.resolve_rate, 4), - "lessons_total": bs.lessons_total, - } - for bs in self.batch_stats - ], - } - - -# --------------------------------------------------------------------------- -# Patch comparison (offline mode) -# --------------------------------------------------------------------------- - -def compare_patches(agent_patch: str, gold_patch: str) -> float: - """Compare two patches 
and return similarity score. - - Uses line-level set overlap (Jaccard on meaningful diff lines). - Strips whitespace and comment-only changes. - - Args: - agent_patch: The agent's proposed fix. - gold_patch: The correct fix. - - Returns: - Similarity score in [0.0, 1.0]. 1.0 = identical patches. - """ - if not agent_patch and not gold_patch: - return 1.0 - if not agent_patch or not gold_patch: - return 0.0 - - def _meaningful_lines(patch: str) -> set[str]: - """Extract meaningful diff lines (additions/removals only).""" - lines = set() - for line in patch.splitlines(): - stripped = line.strip() - # Only count actual code changes, not headers - if stripped.startswith(("+", "-")) and not stripped.startswith(("+++", "---", "@@")): - # Normalize whitespace - normalized = " ".join(stripped[1:].split()) - if normalized and not normalized.startswith("#"): - lines.add(normalized) - return lines - - agent_lines = _meaningful_lines(agent_patch) - gold_lines = _meaningful_lines(gold_patch) - - if not agent_lines and not gold_lines: - return 0.5 # Both patches have no meaningful changes - - union = agent_lines | gold_lines - if not union: - return 0.0 - - intersection = agent_lines & gold_lines - return len(intersection) / len(union) - - -# --------------------------------------------------------------------------- -# Data loading -# --------------------------------------------------------------------------- - -def _load_dataset(dataset_name: str, split: str = "test") -> list[SWEInstance]: - """Load SWE-bench instances from HuggingFace datasets. - - Requires: pip install datasets - """ - try: - from datasets import load_dataset - except ImportError as e: - raise ImportError( - "SWE-bench data loading requires the 'datasets' package.\n" - "Install with: pip install datasets" - ) from e - - ds = load_dataset(dataset_name, split=split) - instances = [] - for raw_item in ds: - item: dict = dict(raw_item) # type: ignore[arg-type] - instances.append(SWEInstance( - instance_id=item["instance_id"], - repo=item["repo"], - problem_statement=item["problem_statement"], - gold_patch=item.get("patch", ""), - fail_to_pass=json.loads(item.get("FAIL_TO_PASS", "[]")), - pass_to_pass=json.loads(item.get("PASS_TO_PASS", "[]")), - base_commit=item.get("base_commit", ""), - hints=item.get("hints_text", ""), - version=item.get("version", ""), - )) - return instances - - -def load_swe_bench_lite() -> list[SWEInstance]: - """Load SWE-bench Lite (300 test instances).""" - return _load_dataset("princeton-nlp/SWE-bench_Lite", split="test") - - -def load_swe_bench_verified() -> list[SWEInstance]: - """Load SWE-bench Verified (500 human-verified instances).""" - return _load_dataset("princeton-nlp/SWE-bench_Verified", split="test") - - -# --------------------------------------------------------------------------- -# Agent function type -# --------------------------------------------------------------------------- - -# An agent function takes (instance, brain_rules) and returns a patch string. -# brain_rules is "" when use_brain=False (baseline), or the injected rules -# when use_brain=True. -AgentFn = Callable[[SWEInstance, str], str] - - -# --------------------------------------------------------------------------- -# Harness -# --------------------------------------------------------------------------- - -class SWEBenchHarness: - """Runs SWE-bench with Gradata brain accumulation. - - The harness: - 1. Feeds each instance to the agent function - 2. Compares the agent's patch to the gold patch - 3. 
If wrong: calls brain.correct(agent_patch, gold_patch) - 4. Tracks resolve rate per batch as the brain accumulates - 5. Injects brain rules into subsequent agent calls (if use_brain=True) - - Args: - brain_dir: Directory for the Gradata brain. - brain: Optional pre-existing Brain instance. - """ - - def __init__( - self, - brain_dir: str | Path | None = None, - brain: Any = None, - ) -> None: - self.brain = brain - self.brain_dir = Path(brain_dir) if brain_dir else None - - if self.brain is None and self.brain_dir: - try: - from gradata.brain import Brain - if self.brain_dir.exists(): - self.brain = Brain(self.brain_dir) - else: - self.brain = Brain.init(self.brain_dir, domain="SWE-bench") - except ImportError: - _log.warning("Brain not available, running without learning") - - def run( - self, - instances: list[SWEInstance], - agent_fn: AgentFn, - config: RunConfig | None = None, - ) -> RunResults: - """Run the benchmark. - - Args: - instances: SWE-bench instances to attempt. - agent_fn: Callable(instance, brain_rules) -> patch_string. - config: Run configuration. - - Returns: - RunResults with per-instance and per-batch statistics. - """ - config = config or RunConfig() - instances = instances[:config.max_instances] - - run_results = RunResults(config=config) - start_time = time.time() - - batch_results: list[PatchResult] = [] - batch_number = 0 - lessons_total = 0 - - for i, instance in enumerate(instances): - # Get brain rules for injection - brain_rules = "" - if config.use_brain and self.brain: - try: - brain_rules = self.brain.apply_brain_rules( - f"Fix bug in {instance.repo}: {instance.problem_statement[:200]}" - ) - except Exception: - brain_rules = "" - - # Run the agent - t0 = time.time() - try: - agent_patch = agent_fn(instance, brain_rules) - except Exception as e: - agent_patch = "" - _log.warning("Agent failed on %s: %s", instance.instance_id, e) - - duration_ms = int((time.time() - t0) * 1000) - - # Compare to gold patch - similarity = compare_patches(agent_patch, instance.gold_patch) - resolved = similarity >= config.similarity_threshold - - # Capture correction if wrong - correction_captured = False - lesson_created = False - if not resolved and self.brain and config.use_brain and agent_patch and instance.gold_patch: - try: - event = self.brain.correct( - draft=agent_patch[:5000], - final=instance.gold_patch[:5000], - category="CODE", - context={ - "task_type": "swe_bench_fix", - "repo": instance.repo, - "instance_id": instance.instance_id, - }, - ) - correction_captured = True - if event.get("lessons_created", 0) > 0: - lesson_created = True - lessons_total += 1 - except Exception as e: - _log.warning("Correction capture failed: %s", e) - - result = PatchResult( - instance_id=instance.instance_id, - agent_patch=agent_patch[:1000], - gold_patch=instance.gold_patch[:1000], - resolved=resolved, - similarity=round(similarity, 4), - correction_captured=correction_captured, - lesson_created=lesson_created, - attempt_number=1, - duration_ms=duration_ms, - ) - run_results.results.append(result) - batch_results.append(result) - - # Batch checkpoint - if len(batch_results) >= config.batch_size or i == len(instances) - 1: - batch_number += 1 - batch_resolved = sum(1 for r in batch_results if r.resolved) - batch_rate = batch_resolved / len(batch_results) if batch_results else 0 - avg_sim = ( - sum(r.similarity for r in batch_results) / len(batch_results) - if batch_results else 0 - ) - - run_results.batch_stats.append(BatchStats( - batch_number=batch_number, - 
instances_in_batch=len(batch_results), - resolved=batch_resolved, - resolve_rate=round(batch_rate, 4), - avg_similarity=round(avg_sim, 4), - lessons_total=lessons_total, - corrections_total=sum( - 1 for r in run_results.results if r.correction_captured - ), - )) - - _log.info( - "Batch %d: %d/%d resolved (%.1f%%), %d lessons total", - batch_number, batch_resolved, len(batch_results), - batch_rate * 100, lessons_total, - ) - batch_results = [] - - # Final stats - run_results.total_attempted = len(run_results.results) - run_results.total_resolved = sum(1 for r in run_results.results if r.resolved) - run_results.resolve_rate = ( - run_results.total_resolved / run_results.total_attempted - if run_results.total_attempted else 0 - ) - run_results.brain_lessons_created = lessons_total - run_results.duration_seconds = time.time() - start_time - - return run_results - - def compare_runs( - self, - baseline: RunResults, - enhanced: RunResults, - ) -> dict[str, Any]: - """Compare a baseline run (no brain) vs enhanced run (with brain). - - Returns a summary dict suitable for a paper or blog post. - """ - improvement = enhanced.resolve_rate - baseline.resolve_rate - improvement_pct = ( - improvement / baseline.resolve_rate * 100 - if baseline.resolve_rate > 0 else 0 - ) - - # Per-batch learning curve comparison - curve = [] - for i, (b, e) in enumerate(zip(baseline.batch_stats, enhanced.batch_stats, strict=False)): - curve.append({ - "batch": i + 1, - "baseline_rate": b.resolve_rate, - "enhanced_rate": e.resolve_rate, - "delta": round(e.resolve_rate - b.resolve_rate, 4), - "lessons_at_batch": e.lessons_total, - }) - - return { - "baseline_resolve_rate": round(baseline.resolve_rate, 4), - "enhanced_resolve_rate": round(enhanced.resolve_rate, 4), - "absolute_improvement": round(improvement, 4), - "relative_improvement_pct": round(improvement_pct, 1), - "lessons_created": enhanced.brain_lessons_created, - "baseline_instances": baseline.total_attempted, - "enhanced_instances": enhanced.total_attempted, - "learning_curve": curve, - "verdict": ( - f"Gradata improved SWE-bench resolve rate from " - f"{baseline.resolve_rate:.1%} to {enhanced.resolve_rate:.1%} " - f"(+{improvement:.1%} absolute, +{improvement_pct:.1f}% relative)" - ), - } diff --git a/src/gradata/cloud/__init__.py b/src/gradata/cloud/__init__.py index 99b192ef..968ac875 100644 --- a/src/gradata/cloud/__init__.py +++ b/src/gradata/cloud/__init__.py @@ -30,11 +30,4 @@ from gradata.cloud.client import CloudClient -__all__ = ["CloudClient", "WikiStore"] - - -def __getattr__(name: str): - if name == "WikiStore": - from gradata.cloud.wiki_store import WikiStore - return WikiStore - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") +__all__ = ["CloudClient"] diff --git a/src/gradata/cloud/wiki_store.py b/src/gradata/cloud/wiki_store.py deleted file mode 100644 index 39a77922..00000000 --- a/src/gradata/cloud/wiki_store.py +++ /dev/null @@ -1,451 +0,0 @@ -"""Supabase wiki store — pgvector-backed wiki for cloud users. - -Replaces local qmd (BM25 full-text) with pgvector (semantic similarity) -for rule injection and wiki search. Same API surface so the injection -hook works with either backend. - -Requires: supabase-py + pgvector extension enabled on the Supabase project. 
- -Schema (auto-created via ensure_schema()): - - wiki_pages: page content + pgvector embedding - - wiki_sources: raw source tracking (what was ingested) - -Usage: - store = WikiStore(supabase_url="...", supabase_key="...") - store.ensure_schema() - store.upsert_page({"title": "Rule: CODE", "category": "CODE", ...}) - results = store.search("code implementation", limit=5) -""" -from __future__ import annotations - -import hashlib -import json -import logging -from dataclasses import dataclass -from datetime import UTC, datetime -from pathlib import Path -from typing import Any - -_log = logging.getLogger(__name__) - -EMBEDDING_DIM = 384 # all-MiniLM-L6-v2 output dimension (fixed by model choice) - -# SQL for schema creation (run via Supabase SQL editor or RPC) -SCHEMA_SQL = f""" --- Enable pgvector extension -create extension if not exists vector; - --- Wiki pages: content + vector embedding for semantic search -create table if not exists wiki_pages ( - id text primary key, - brain_id text not null, - title text not null, - category text, - page_type text not null default 'concept', - content text not null, - content_hash text not null, - embedding vector({EMBEDDING_DIM}), - tags jsonb default '[]'::jsonb, - source_file text, - created_at timestamptz not null default now(), - updated_at timestamptz not null default now() -); - --- Sources: track what raw documents were ingested -create table if not exists wiki_sources ( - id text primary key, - brain_id text not null, - title text not null, - source_type text not null default 'document', - source_url text, - file_path text, - content_hash text not null, - ingested_at timestamptz not null default now(), - metadata jsonb default '{{}}'::jsonb -); - --- Indexes for common queries -create index if not exists idx_wiki_pages_brain on wiki_pages(brain_id); -create index if not exists idx_wiki_pages_category on wiki_pages(category); -create index if not exists idx_wiki_sources_brain on wiki_sources(brain_id); - --- Vector similarity index (IVFFlat for fast approximate search) --- Only create if embedding column is populated --- create index if not exists idx_wiki_pages_embedding --- on wiki_pages using ivfflat (embedding vector_cosine_ops) with (lists = 100); -""" - -# RLS policies (users can only access their own brain's data) -RLS_SQL = """ -alter table wiki_pages enable row level security; -alter table wiki_sources enable row level security; - -create policy if not exists "wiki_pages_brain_isolation" - on wiki_pages for all - using (brain_id = current_setting('app.brain_id', true)); - -create policy if not exists "wiki_sources_brain_isolation" - on wiki_sources for all - using (brain_id = current_setting('app.brain_id', true)); -""" - - -@dataclass -class WikiPage: - """A wiki page retrieved from the store.""" - - id: str - title: str - category: str | None - content: str - page_type: str - tags: list[str] - similarity: float = 0.0 - - -class WikiStore: - """Supabase-backed wiki store with pgvector semantic search. - - Drop-in replacement for local qmd search in the rule injection hook. - """ - - def __init__( - self, - supabase_url: str, - supabase_key: str, - brain_id: str, - ) -> None: - try: - from supabase import create_client - except ImportError: - raise ImportError( - "supabase-py required for cloud wiki. 
" - "Install with: pip install gradata[cloud-wiki]" - ) from None - self.client = create_client(supabase_url, supabase_key) - self.brain_id = brain_id - self._embedder: Any = None - - def _get_embedder(self) -> Any: - """Lazy-load sentence-transformers model.""" - if self._embedder is None: - try: - from sentence_transformers import SentenceTransformer - self._embedder = SentenceTransformer("all-MiniLM-L6-v2") - except ImportError: - raise ImportError( - "sentence-transformers required for embeddings. " - "Install with: pip install gradata[embeddings]" - ) from None - return self._embedder - - def _embed(self, text: str) -> list[float]: - """Generate embedding vector for text.""" - model = self._get_embedder() - return model.encode(text).tolist() - - @staticmethod - def _content_hash(content: str) -> str: - return hashlib.sha256(content.encode()).hexdigest()[:16] - - @staticmethod - def _page_id(brain_id: str, title: str) -> str: - digest = hashlib.sha256(f"{brain_id}:{title}".encode()).hexdigest()[:12] - return f"wp_{digest}" - - @staticmethod - def _source_id(brain_id: str, path_or_url: str) -> str: - digest = hashlib.sha256(f"{brain_id}:{path_or_url}".encode()).hexdigest()[:12] - return f"ws_{digest}" - - # ── Schema management ───────────────────────────────────────────── - - def ensure_schema(self) -> None: - """Create wiki tables if they don't exist. - - Attempts to use the ``exec_sql`` Supabase RPC function. This is NOT - a standard Supabase function — you must create it manually first, or - run the SQL in ``SCHEMA_SQL`` directly via the Supabase SQL editor. - """ - try: - self.client.rpc("exec_sql", {"sql": SCHEMA_SQL}).execute() - _log.info("Wiki schema ensured for brain %s", self.brain_id) - except Exception as e: - _log.warning( - "Schema creation via RPC failed. Run SCHEMA_SQL and " - "SEARCH_RPC_SQL manually in the Supabase SQL editor: %s", e, - ) - - # ── Page CRUD ───────────────────────────────────────────────────── - - def upsert_page( - self, - title: str, - content: str, - category: str | None = None, - page_type: str = "concept", - tags: list[str] | None = None, - source_file: str | None = None, - embed: bool = True, - ) -> str: - """Insert or update a wiki page. 
Returns page ID.""" - page_id = self._page_id(self.brain_id, title) - now = datetime.now(UTC).isoformat() - - row: dict[str, Any] = { - "id": page_id, - "brain_id": self.brain_id, - "title": title, - "category": category, - "page_type": page_type, - "content": content, - "content_hash": self._content_hash(content), - "tags": json.dumps(tags or []), - "source_file": source_file, - "updated_at": now, - } - - if embed: - try: - row["embedding"] = self._embed(f"{title}\n{content[:500]}") - except ImportError: - _log.debug("Embeddings unavailable, storing page without vector") - - self.client.table("wiki_pages").upsert(row).execute() - return page_id - - def get_page(self, title: str) -> WikiPage | None: - """Get a page by title.""" - page_id = self._page_id(self.brain_id, title) - resp = ( - self.client.table("wiki_pages") - .select("*") - .eq("id", page_id) - .maybe_single() - .execute() - ) - if not resp.data: - return None - return self._row_to_page(resp.data) - - def delete_page(self, title: str) -> bool: - """Delete a page by title.""" - page_id = self._page_id(self.brain_id, title) - self.client.table("wiki_pages").delete().eq("id", page_id).execute() - return True - - # ── Search ──────────────────────────────────────────────────────── - - def search(self, query: str, limit: int = 5) -> list[WikiPage]: - """Semantic search via pgvector cosine similarity. - - This is the cloud replacement for ``qmd search``. - Returns pages ranked by embedding similarity. - """ - try: - query_vec = self._embed(query) - except ImportError: - return self._text_search(query, limit) - - # Use Supabase RPC for vector similarity search - resp = self.client.rpc("wiki_search", { - "query_embedding": query_vec, - "match_brain_id": self.brain_id, - "match_count": limit, - }).execute() - - if not resp.data: - return [] - - pages = [] - for row in resp.data: - page = self._row_to_page(row) - page.similarity = row.get("similarity", 0.0) - pages.append(page) - return pages - - def search_by_category(self, category: str) -> list[WikiPage]: - """Get all pages in a category (for rule injection).""" - resp = ( - self.client.table("wiki_pages") - .select("*") - .eq("brain_id", self.brain_id) - .eq("category", category.upper()) - .execute() - ) - return [self._row_to_page(row) for row in (resp.data or [])] - - def search_categories(self, query: str, limit: int = 10) -> set[str]: - """Semantic search returning matched categories (for rule injection hook). - - Drop-in replacement for _wiki_categories() in inject_brain_rules.py. 
- """ - pages = self.search(query, limit=limit) - return {p.category.upper() for p in pages if p.category} - - def _text_search(self, query: str, limit: int) -> list[WikiPage]: - """Fallback text search when embeddings unavailable.""" - resp = ( - self.client.table("wiki_pages") - .select("*") - .eq("brain_id", self.brain_id) - .ilike("content", "%{}%".format(query.replace("%", r"\%").replace("_", r"\_"))) - .limit(limit) - .execute() - ) - return [self._row_to_page(row) for row in (resp.data or [])] - - # ── Source tracking ─────────────────────────────────────────────── - - def add_source( - self, - title: str, - source_type: str = "document", - source_url: str | None = None, - file_path: str | None = None, - content_hash: str = "", - metadata: dict | None = None, - ) -> str: - """Track an ingested source document.""" - source_id = self._source_id(self.brain_id, source_url or file_path or title) - self.client.table("wiki_sources").upsert({ - "id": source_id, - "brain_id": self.brain_id, - "title": title, - "source_type": source_type, - "source_url": source_url, - "file_path": file_path, - "content_hash": content_hash, - "metadata": json.dumps(metadata or {}), - }).execute() - return source_id - - def list_sources(self) -> list[dict]: - """List all tracked sources for this brain.""" - resp = ( - self.client.table("wiki_sources") - .select("*") - .eq("brain_id", self.brain_id) - .order("ingested_at", desc=True) - .execute() - ) - return resp.data or [] - - # ── Sync from local wiki ────────────────────────────────────────── - - def sync_from_local(self, wiki_dir: str | Path) -> dict[str, int]: - """Upload local wiki pages to Supabase. - - Reads all .md files from wiki_dir, compares content hashes, - and upserts changed pages. Returns counts. - """ - wiki_path = Path(wiki_dir) - if not wiki_path.is_dir(): - return {"uploaded": 0, "skipped": 0, "errors": 0} - - uploaded = skipped = errors = 0 - - for md_file in wiki_path.rglob("*.md"): - try: - content = md_file.read_text(encoding="utf-8") - title = md_file.stem.replace("-", " ").title() - - # Parse frontmatter for category/type - category = None - page_type = "concept" - tags: list[str] = [] - if content.startswith("---"): - parts = content.split("---", 2) - if len(parts) >= 3: - for line in parts[1].splitlines(): - line = line.strip() - if line.startswith("category:"): - category = line.split(":", 1)[1].strip().upper() - elif line.startswith("type:"): - page_type = line.split(":", 1)[1].strip() - elif line.startswith("title:"): - title = line.split(":", 1)[1].strip().strip('"') - - # Check if content changed - content_hash = self._content_hash(content) - page_id = self._page_id(self.brain_id, title) - existing = ( - self.client.table("wiki_pages") - .select("content_hash") - .eq("id", page_id) - .maybe_single() - .execute() - ) - if existing.data and existing.data.get("content_hash") == content_hash: - skipped += 1 - continue - - self.upsert_page( - title=title, - content=content, - category=category, - page_type=page_type, - tags=tags, - source_file=str(md_file.relative_to(wiki_path)), - ) - uploaded += 1 - except Exception as e: - _log.debug("Failed to sync %s: %s", md_file, e) - errors += 1 - - return {"uploaded": uploaded, "skipped": skipped, "errors": errors} - - # ── Helpers ──────────────────────────────────────────────────────── - - @staticmethod - def _row_to_page(row: dict) -> WikiPage: - tags = row.get("tags", []) - if isinstance(tags, str): - tags = json.loads(tags) - return WikiPage( - id=row["id"], - title=row["title"], - 
category=row.get("category"), - content=row.get("content", ""), - page_type=row.get("page_type", "concept"), - tags=tags, - ) - - -# ── Supabase RPC function for vector search ────────────────────────── - -SEARCH_RPC_SQL = f""" -create or replace function wiki_search( - query_embedding vector({EMBEDDING_DIM}), - match_brain_id text, - match_count int default 5 -) -returns table ( - id text, - title text, - category text, - content text, - page_type text, - tags jsonb, - similarity float -) -language plpgsql -as $$ -begin - return query - select - wp.id, - wp.title, - wp.category, - wp.content, - wp.page_type, - wp.tags, - 1 - (wp.embedding <=> query_embedding) as similarity - from wiki_pages wp - where wp.brain_id = match_brain_id - and wp.embedding is not null - order by wp.embedding <=> query_embedding - limit match_count; -end; -$$; -""" diff --git a/src/gradata/contrib/enhancements/install_manifest.py b/src/gradata/contrib/enhancements/install_manifest.py index 949cfa33..e212ede1 100644 --- a/src/gradata/contrib/enhancements/install_manifest.py +++ b/src/gradata/contrib/enhancements/install_manifest.py @@ -271,8 +271,7 @@ def is_installed(self, module_id: str) -> bool: kind="enhancement", components=[ "enhancements.rule_integrity", "enhancements.contradiction_detector", - "enhancements.rule_verifier", "enhancements.rule_conflicts", - "enhancements.rule_canary", + "enhancements.rule_conflicts", "enhancements.rule_canary", ], dependencies=["learning-pipeline"], cost=ModuleCost.MEDIUM, diff --git a/src/gradata/contrib/enhancements/outcome_feedback.py b/src/gradata/contrib/enhancements/outcome_feedback.py deleted file mode 100644 index 3784b158..00000000 --- a/src/gradata/contrib/enhancements/outcome_feedback.py +++ /dev/null @@ -1 +0,0 @@ -"""Outcome Feedback -- External signal to confidence feedback loop.""" diff --git a/src/gradata/enhancements/__init__.py b/src/gradata/enhancements/__init__.py index fe44b363..c7a47120 100644 --- a/src/gradata/enhancements/__init__.py +++ b/src/gradata/enhancements/__init__.py @@ -20,7 +20,6 @@ success_conditions -- 6-condition validation meta_rules -- Emergent meta-rule discovery (compound procedural memory) rule_integrity -- HMAC signing for tamper detection - rule_verifier -- Output verification against rules rule_canary -- Rule regression detection rule_conflicts -- Contradiction detection contradiction_detector -- Semantic contradiction detection before graduation diff --git a/src/gradata/enhancements/meta_rules_storage.py b/src/gradata/enhancements/meta_rules_storage.py index ded90fb6..4d420820 100644 --- a/src/gradata/enhancements/meta_rules_storage.py +++ b/src/gradata/enhancements/meta_rules_storage.py @@ -2,7 +2,8 @@ Meta-Rule SQLite Persistence — load/save for meta_rules and super_meta_rules tables. ===================================================================================== All database I/O for meta-rules lives here. Core logic and discovery live in -``meta_rules.py``; tier-2/3 super-meta-rule logic lives in ``super_meta_rules.py``. +``meta_rules.py`` (including the :class:`SuperMetaRule` dataclass used for +tier-2/3 rows). 
Also exposes a *differential-privacy scaffold* (:class:`DPConfig`, :func:`apply_dp_to_export_row`) used by the cloud export path when meta-rules diff --git a/src/gradata/enhancements/pubsub_pipeline.py b/src/gradata/enhancements/pubsub_pipeline.py deleted file mode 100644 index 72c51dbe..00000000 --- a/src/gradata/enhancements/pubsub_pipeline.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Pub/sub event pipeline for decoupled correction processing. - -Unlike learning_pipeline.py (sequential chain), this is a pub/sub system -where stages subscribe to event types independently. Stage failures don't -block other stages. Used for async/background processing patterns. -""" - -from __future__ import annotations - -import logging -from collections.abc import Callable -from typing import Any - -_log = logging.getLogger(__name__) - - -class PubSubPipeline: - """Lightweight pub/sub pipeline for correction processing. - - Each stage subscribes to an event type. When an event fires, - all subscribers for that type run in registration order. - Stage failures are logged but don't block other stages. - """ - - def __init__(self): - self._subscribers: dict[str, list[Callable]] = {} - self._event_log: list[dict] = [] - - def subscribe(self, event_type: str, handler: Callable): - if event_type not in self._subscribers: - self._subscribers[event_type] = [] - self._subscribers[event_type].append(handler) - - def emit(self, event_type: str, data: dict[str, Any] | None = None) -> list[dict]: - """Emit an event. Returns list of stage results.""" - results = [] - self._event_log.append({"type": event_type, "data": data}) - for handler in self._subscribers.get(event_type, []): - try: - result = handler(data or {}) - results.append({"handler": handler.__name__, "status": "ok", "result": result}) - except Exception as e: - _log.warning("Pipeline stage %s failed: %s", handler.__name__, e) - results.append({"handler": handler.__name__, "status": "error", "error": str(e)}) - return results - - @property - def event_log(self) -> list[dict]: - return list(self._event_log) diff --git a/src/gradata/enhancements/rule_evolution.py b/src/gradata/enhancements/rule_evolution.py deleted file mode 100644 index 0431b666..00000000 --- a/src/gradata/enhancements/rule_evolution.py +++ /dev/null @@ -1,434 +0,0 @@ -""" -Rule Evolution — A/B testing + conflict detection for rule lifecycle. -===================================================================== -Merged from: rule_ab_testing.py + rule_conflicts.py (S79 consolidation) - -Two concerns: - 1. A/B Testing: Statistical comparison of rule variants with Wilson scores. - 2. Conflict Detection: Updates/Extends/Derives relationship classification. 
- -Usage:: - - from gradata.enhancements.rule_evolution import ( - RuleExperiment, ExperimentResult, wilson_score_interval, - detect_rule_conflict, RuleRelation, classify_all_relations, - ) -""" - -from __future__ import annotations - -import math -import random -import re -from dataclasses import dataclass -from enum import Enum -from typing import Any - -from gradata._types import ELIGIBLE_STATES, Lesson -from gradata.enhancements.diff_engine import compute_diff - -__all__ = [ - "ExperimentManager", - "ExperimentResult", - "RuleExperiment", - # Conflict Detection - "RuleRelation", - "classify_all_relations", - "detect_rule_conflict", - # A/B Testing - "wilson_score_interval", -] - - -# ═══════════════════════════════════════════════════════════════════════ -# A/B Testing (Wilson Score Confidence Intervals) -# ═══════════════════════════════════════════════════════════════════════ - - -def wilson_score_interval(successes: int, trials: int, z: float = 1.96) -> tuple[float, float]: - """Compute Wilson score confidence interval for small samples.""" - if trials == 0: - return (0.0, 0.0) - p_hat = successes / trials - z2 = z * z - denominator = 1 + z2 / trials - center = (p_hat + z2 / (2 * trials)) / denominator - spread = (z / denominator) * math.sqrt((p_hat * (1 - p_hat) + z2 / (4 * trials)) / trials) - return (round(max(0.0, center - spread), 4), round(min(1.0, center + spread), 4)) - - -@dataclass -class ExperimentResult: - """Result of evaluating an A/B experiment.""" - - winner: str | None = None - confidence: float = 0.0 - a_success_rate: float = 0.0 - b_success_rate: float = 0.0 - a_interval: tuple[float, float] = (0.0, 0.0) - b_interval: tuple[float, float] = (0.0, 0.0) - a_trials: int = 0 - b_trials: int = 0 - sufficient_data: bool = False - margin: float = 0.0 - - @property - def is_conclusive(self) -> bool: - return self.winner is not None and self.sufficient_data - - -@dataclass -class RuleExperiment: - """A/B test between two rule variants.""" - - experiment_id: str - variant_a: str - variant_b: str - category: str = "" - min_trials: int = 20 - min_margin: float = 0.10 - _a_successes: int = 0 - _a_trials: int = 0 - _b_successes: int = 0 - _b_trials: int = 0 - - def assign(self) -> str: - return random.choice(["a", "b"]) - - def record(self, variant: str, success: bool) -> None: - if variant == "a": - self._a_trials += 1 - if success: - self._a_successes += 1 - elif variant == "b": - self._b_trials += 1 - if success: - self._b_successes += 1 - else: - raise ValueError(f"variant must be 'a' or 'b', got {variant!r}") - - def evaluate(self) -> ExperimentResult: - a_rate = self._a_successes / self._a_trials if self._a_trials else 0.0 - b_rate = self._b_successes / self._b_trials if self._b_trials else 0.0 - a_interval = wilson_score_interval(self._a_successes, self._a_trials) - b_interval = wilson_score_interval(self._b_successes, self._b_trials) - sufficient = self._a_trials >= self.min_trials and self._b_trials >= self.min_trials - margin = abs(a_rate - b_rate) - winner = None - confidence = 0.0 - if sufficient and margin >= self.min_margin: - if a_interval[0] > b_interval[1]: - winner = "a" - confidence = min(1.0, margin / 0.5) - elif b_interval[0] > a_interval[1]: - winner = "b" - confidence = min(1.0, margin / 0.5) - elif margin >= self.min_margin * 2: - winner = "a" if a_rate > b_rate else "b" - confidence = min(1.0, margin / 0.5) * 0.7 - return ExperimentResult( - winner=winner, - confidence=round(confidence, 4), - a_success_rate=round(a_rate, 4), - 
b_success_rate=round(b_rate, 4), - a_interval=a_interval, - b_interval=b_interval, - a_trials=self._a_trials, - b_trials=self._b_trials, - sufficient_data=sufficient, - margin=round(margin, 4), - ) - - @property - def total_trials(self) -> int: - return self._a_trials + self._b_trials - - def to_dict(self) -> dict[str, Any]: - return { - "experiment_id": self.experiment_id, - "variant_a": self.variant_a, - "variant_b": self.variant_b, - "category": self.category, - "a_successes": self._a_successes, - "a_trials": self._a_trials, - "b_successes": self._b_successes, - "b_trials": self._b_trials, - } - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> RuleExperiment: - exp = cls( - experiment_id=data["experiment_id"], - variant_a=data["variant_a"], - variant_b=data["variant_b"], - category=data.get("category", ""), - ) - exp._a_successes = data.get("a_successes", 0) - exp._a_trials = data.get("a_trials", 0) - exp._b_successes = data.get("b_successes", 0) - exp._b_trials = data.get("b_trials", 0) - return exp - - -class ExperimentManager: - """Manages multiple concurrent A/B experiments.""" - - def __init__(self) -> None: - self._experiments: dict[str, RuleExperiment] = {} - self._completed: list[dict[str, Any]] = [] - - def create( - self, experiment_id: str, variant_a: str, variant_b: str, category: str = "" - ) -> RuleExperiment: - exp = RuleExperiment( - experiment_id=experiment_id, variant_a=variant_a, variant_b=variant_b, category=category - ) - self._experiments[experiment_id] = exp - return exp - - def get(self, experiment_id: str) -> RuleExperiment | None: - return self._experiments.get(experiment_id) - - def evaluate_all(self) -> list[ExperimentResult]: - results = [] - completed_ids = [] - for exp_id, exp in self._experiments.items(): - result = exp.evaluate() - results.append(result) - if result.is_conclusive: - self._completed.append( - { - "experiment": exp.to_dict(), - "result": { - "winner": result.winner, - "confidence": result.confidence, - "margin": result.margin, - }, - } - ) - completed_ids.append(exp_id) - for exp_id in completed_ids: - del self._experiments[exp_id] - return results - - @property - def active_count(self) -> int: - return len(self._experiments) - - @property - def completed_count(self) -> int: - return len(self._completed) - - def stats(self) -> dict[str, Any]: - return { - "active_experiments": self.active_count, - "completed_experiments": self.completed_count, - "total_trials": sum(e.total_trials for e in self._experiments.values()), - } - - -# ═══════════════════════════════════════════════════════════════════════ -# Conflict Detection (Updates/Extends/Derives relationships) -# ═══════════════════════════════════════════════════════════════════════ - - -class RuleRelation(Enum): - UPDATES = "updates" - EXTENDS = "extends" - DERIVES = "derives" - INDEPENDENT = "independent" - - -def _text_similarity(a: str, b: str) -> float: - if not a or not b: - return 0.0 - if a.strip().lower() == b.strip().lower(): - return 1.0 - diff = compute_diff(a, b) - return round(1.0 - diff.edit_distance, 4) - - -def _extract_keywords(text: str) -> set[str]: - stopwords = { - "a", - "an", - "the", - "is", - "are", - "was", - "were", - "be", - "been", - "have", - "has", - "had", - "do", - "does", - "did", - "will", - "would", - "could", - "should", - "may", - "might", - "can", - "to", - "of", - "in", - "for", - "on", - "with", - "at", - "by", - "from", - "as", - "it", - "its", - "this", - "that", - "and", - "but", - "or", - "not", - "no", - "if", - "then", - "when", - 
"so", - "than", - "too", - "very", - "just", - "also", - "all", - "each", - "every", - "any", - "some", - "only", - "i", - "we", - "you", - "they", - "he", - "she", - "my", - "your", - "our", - "their", - } - words = set(re.sub(r"[^\w\s]", " ", text.lower()).split()) - return words - stopwords - - -def _detect_opposite_direction(a_desc: str, b_desc: str) -> bool: - a_lower = a_desc.lower() - b_lower = b_desc.lower() - opposites = [ - ("use", "avoid"), - ("use", "don't use"), - ("use", "do not use"), - ("include", "exclude"), - ("include", "remove"), - ("include", "omit"), - ("add", "remove"), - ("add", "don't add"), - ("add", "do not add"), - ("always", "never"), - ("must", "must not"), - ("keep", "remove"), - ("prefer", "avoid"), - ("enable", "disable"), - ("before", "after"), - ("first", "last"), - ] - for pos, neg in opposites: - if (pos in a_lower and neg in b_lower) or (neg in a_lower and pos in b_lower): - return True - return False - - -def replace_contradicted_rule( - old_rule: Lesson, - new_lesson: Lesson, -) -> Lesson: - """Replace a contradicted rule with the new lesson's direction. - - Transfers the old rule's metadata (fire_count, sessions_since_fire) - to preserve history, but resets confidence and description to the - new lesson. This enables instant preference reversal. - - Returns the old_rule (mutated in place). - """ - old_rule.description = new_lesson.description - old_rule.confidence = new_lesson.confidence - old_rule.root_cause = f"replaced: was '{old_rule.description[:60]}'" - # Reset contradiction tracking - if hasattr(old_rule, "_contradiction_streak"): - old_rule._contradiction_streak = 0 - return old_rule - - -def detect_rule_conflict( - new_lesson: Lesson, - existing_rules: list[Lesson], - *, - update_threshold: float = 0.80, - extend_threshold: float = 0.60, - derive_min_cluster: int = 3, -) -> tuple[RuleRelation, Lesson | None]: - if not existing_rules: - return (RuleRelation.INDEPENDENT, None) - new_desc = new_lesson.description - new_keywords = _extract_keywords(new_desc) - best_similarity = 0.0 - best_rule: Lesson | None = None - category_cluster: list[Lesson] = [] - for rule in existing_rules: - if rule.state not in ELIGIBLE_STATES: - continue - similarity = _text_similarity(new_desc, rule.description) - if similarity > best_similarity: - best_similarity = similarity - best_rule = rule - if rule.category == new_lesson.category: - rule_keywords = _extract_keywords(rule.description) - if new_keywords & rule_keywords: - category_cluster.append(rule) - if ( - best_rule is not None - and best_similarity > update_threshold - and _detect_opposite_direction(new_desc, best_rule.description) - ): - return (RuleRelation.UPDATES, best_rule) - if ( - best_rule is not None - and best_similarity > extend_threshold - and not _detect_opposite_direction(new_desc, best_rule.description) - ): - return (RuleRelation.EXTENDS, best_rule) - if len(category_cluster) >= derive_min_cluster: - return (RuleRelation.DERIVES, None) - return (RuleRelation.INDEPENDENT, None) - - -def classify_all_relations( - new_lesson: Lesson, - existing_rules: list[Lesson], -) -> list[tuple[RuleRelation, Lesson, float]]: - results: list[tuple[RuleRelation, Lesson, float]] = [] - new_desc = new_lesson.description - for rule in existing_rules: - similarity = _text_similarity(new_desc, rule.description) - if similarity < 0.3: - continue - is_opposite = _detect_opposite_direction(new_desc, rule.description) - if similarity > 0.80 and is_opposite: - relation = RuleRelation.UPDATES - elif similarity > 0.60 
and not is_opposite:
-            relation = RuleRelation.EXTENDS
-        else:
-            relation = RuleRelation.INDEPENDENT
-        results.append((relation, rule, similarity))
-    results.sort(key=lambda x: x[2], reverse=True)
-    return results
diff --git a/src/gradata/enhancements/rule_verifier.py b/src/gradata/enhancements/rule_verifier.py
deleted file mode 100644
index 92d6bef1..00000000
--- a/src/gradata/enhancements/rule_verifier.py
+++ /dev/null
@@ -1,243 +0,0 @@
-"""Rule verification: pre-execution filtering and post-hoc output checking.
-
-Pre-execution: TOOL_RULE_MATRIX maps tool/task types to relevant rule
-categories so irrelevant rules are skipped before verification runs.
-
-Post-hoc: scans output text for checkable patterns (em dashes, pricing,
-links) and reports violations. Feeds results back into confidence scoring.
-"""
-
-from __future__ import annotations
-
-import re
-import sqlite3
-from dataclasses import dataclass
-from datetime import UTC, datetime
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from pathlib import Path
-
-# ---------------------------------------------------------------------------
-# Pre-execution decision tree
-# ---------------------------------------------------------------------------
-
-# Maps tool/task types to the rule categories that are relevant for them.
-# If a tool_type is not in the matrix, all categories are checked (safe default).
-# Extend this dict at runtime via update_tool_rule_matrix() — no code changes needed.
-TOOL_RULE_MATRIX: dict[str, list[str]] = {
-    "Write": ["DRAFTING", "ARCHITECTURE", "IP_PROTECTION", "ACCURACY"],
-    "Edit": ["DRAFTING", "ARCHITECTURE", "ACCURACY"],
-    "Bash": ["PROCESS", "VERIFICATION", "CONSTRAINT"],
-    "email_draft": ["DRAFTING", "COMMUNICATION", "POSITIONING", "PRICING"],
-    "demo_prep": ["DEMO_PREP", "ACCURACY", "PRESENTATION"],
-    "prospecting": ["LEADS", "CONSTRAINT", "DATA_INTEGRITY"],
-    "code": ["ARCHITECTURE", "THOROUGHNESS", "VERIFICATION"],
-}
-
-
-def should_verify(tool_type: str, rule_category: str) -> bool:
-    """Pre-execution gate: is this rule relevant for this tool/task?
-
-    If *tool_type* is not in TOOL_RULE_MATRIX, returns ``True`` (verify
-    everything by default — safe fallback for unknown tools).
-
-    Args:
-        tool_type: The current tool or task type (e.g. "email_draft", "Bash").
-        rule_category: The category of the rule being considered.
-
-    Returns:
-        ``True`` if the rule should be checked, ``False`` to skip it.
-    """
-    relevant = TOOL_RULE_MATRIX.get(tool_type)
-    if relevant is None:
-        return True  # unknown tool -> check everything
-    return rule_category.upper() in (c.upper() for c in relevant)
-
-
-def get_relevant_rules(tool_type: str, all_rules: list[dict]) -> list[dict]:
-    """Filter rules to only those relevant for the current tool/task.
-
-    Each rule dict must have a ``"category"`` key. Rules whose category is
-    not relevant for *tool_type* (per TOOL_RULE_MATRIX) are dropped.
-
-    Args:
-        tool_type: The current tool or task type.
-        all_rules: Full list of rule dicts (each with at least ``"category"``).
-
-    Returns:
-        Filtered list of rule dicts relevant to the tool type.
-    """
-    return [
-        rule for rule in all_rules
-        if should_verify(tool_type, rule.get("category", "UNKNOWN"))
-    ]
-
-
-# ---------------------------------------------------------------------------
-# Verification pattern registry
-# ---------------------------------------------------------------------------
-
-# (keyword_in_rule, regex_pattern, should_be_absent, description)
-_PATTERNS: list[tuple[str, str, bool, str]] = [
-    ("em dash", r"\u2014|--", True, "contains em dash or double dash"),
-    ("em dashes", r"\u2014|--", True, "contains em dash or double dash"),
-    ("pricing", r"\$\d+", True, "contains dollar amount"),
-    ("dollar", r"\$\d+", True, "contains dollar amount"),
-    ("booking link", r"https?://\S+/\S+", False, "missing booking link"),
-    ("hyperlink", r"<a\s+href", False, "missing hyperlink"),
-]
-
-
-@dataclass
-class RuleVerification:
-    """Result of checking one applied rule against a generated output."""
-
-    rule_category: str
-    rule_description: str
-    passed: bool
-    violation_detail: str = ""
-    output_snippet: str = ""
-
-
-def auto_detect_verification(rule_description: str) -> list[tuple[re.Pattern, bool, str]]:
-    """Scan rule description for checkable patterns.
-
-    Returns list of (compiled_regex, should_be_absent, violation_description).
-    """
-    desc_lower = rule_description.lower()
-    checks = []
-    seen = set()
-    for keyword, pattern, absent, desc in _PATTERNS:
-        if keyword in desc_lower and pattern not in seen:
-            checks.append((re.compile(pattern, re.IGNORECASE), absent, desc))
-            seen.add(pattern)
-    return checks
-
-
-def verify_rules(
-    output: str,
-    applied_rules: list[dict],
-    context: dict | None = None,
-) -> list[RuleVerification]:
-    """Check output against applied rules for verifiable violations.
-
-    When *context* contains a ``"tool_type"`` key, pre-execution filtering
-    via :func:`should_verify` is applied first — rules whose category is
-    irrelevant for the tool are skipped entirely, making verification both
-    faster and less prone to false positives from mismatched rules.
-
-    Args:
-        output: The AI-generated text to check.
-        applied_rules: List of dicts with at least 'category' and 'description'.
-        context: Optional context dict. Recognized keys:
-        - ``tool_type``: enables pre-execution category filtering.
-
-    Returns:
-        List of RuleVerification results (one per checkable rule).
- """ - # Pre-execution filter: skip rules irrelevant to the current tool - tool_type = (context or {}).get("tool_type", "") - if tool_type: - applied_rules = get_relevant_rules(tool_type, applied_rules) - - results = [] - for rule in applied_rules: - desc = rule.get("description", "") - cat = rule.get("category", "UNKNOWN") - checks = auto_detect_verification(desc) - if not checks: - continue - - for regex, should_be_absent, violation_desc in checks: - match = regex.search(output) - if should_be_absent and match: - results.append(RuleVerification( - rule_category=cat, - rule_description=desc[:200], - passed=False, - violation_detail=violation_desc, - output_snippet=output[max(0, match.start() - 30):match.end() + 30][:200], - )) - elif not should_be_absent and not match: - results.append(RuleVerification( - rule_category=cat, - rule_description=desc[:200], - passed=False, - violation_detail=violation_desc, - output_snippet=output[:200], - )) - else: - results.append(RuleVerification( - rule_category=cat, - rule_description=desc[:200], - passed=True, - )) - return results - - -# --------------------------------------------------------------------------- -# Persistence -# --------------------------------------------------------------------------- - -_CREATE_TABLE = """ -CREATE TABLE IF NOT EXISTS rule_verifications ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - session INTEGER, - rule_category TEXT, - rule_description TEXT, - passed BOOLEAN, - violation_detail TEXT, - output_snippet TEXT, - timestamp TEXT DEFAULT CURRENT_TIMESTAMP -) -""" - - -def ensure_table(db_path: Path) -> None: - from gradata._db import ensure_table as _ensure - from gradata._db import get_connection - conn = get_connection(db_path) - _ensure(conn, _CREATE_TABLE) - conn.close() - - -def log_verification( - session: int, - results: list[RuleVerification], - db_path: Path, -) -> None: - """Write verification results to SQLite.""" - ensure_table(db_path) - now = datetime.now(UTC).isoformat() - with sqlite3.connect(str(db_path)) as conn: - for r in results: - conn.execute( - "INSERT INTO rule_verifications " - "(session, rule_category, rule_description, passed, violation_detail, output_snippet, timestamp) " - "VALUES (?, ?, ?, ?, ?, ?, ?)", - (session, r.rule_category, r.rule_description, r.passed, - r.violation_detail, r.output_snippet, now), - ) - - -def get_verification_stats(db_path: Path) -> dict: - """Return summary stats from rule_verifications table.""" - ensure_table(db_path) - with sqlite3.connect(str(db_path)) as conn: - total = conn.execute("SELECT COUNT(*) FROM rule_verifications").fetchone()[0] - passed = conn.execute("SELECT COUNT(*) FROM rule_verifications WHERE passed = 1").fetchone()[0] - violations = conn.execute( - "SELECT rule_category, COUNT(*) FROM rule_verifications " - "WHERE passed = 0 GROUP BY rule_category ORDER BY COUNT(*) DESC" - ).fetchall() - - return { - "total_checks": total, - "passed": passed, - "pass_rate": passed / total if total > 0 else 1.0, - "violations_by_category": {cat: count for cat, count in violations}, - } diff --git a/src/gradata/enhancements/super_meta_rules.py b/src/gradata/enhancements/super_meta_rules.py deleted file mode 100644 index e1278b3d..00000000 --- a/src/gradata/enhancements/super_meta_rules.py +++ /dev/null @@ -1,197 +0,0 @@ -""" -Super-Meta-Rule Logic — tier-2 and tier-3 principle emergence. -============================================================== -Super-meta-rule discovery requires Gradata Cloud. 
The open-source SDK -preserves the data model and formatting API; discovery and refresh are -no-ops that return empty results. - -All SQLite persistence lives in ``meta_rules_storage.py``; core -meta-rule logic lives in ``meta_rules.py``. -""" - -from __future__ import annotations - -import logging - -from gradata.enhancements.meta_rules import ( - TIER_SUPER_META, - TIER_UNIVERSAL, - MetaRule, - SuperMetaRule, - evaluate_conditions, -) - -_log = logging.getLogger(__name__) - - -# --------------------------------------------------------------------------- -# Discovery (requires Gradata Cloud) -# --------------------------------------------------------------------------- - - -def detect_super_meta_rules( - meta_rules: list[MetaRule], - min_group_size: int = 3, - current_session: int = 0, -) -> list[SuperMetaRule]: - """Discover tier-2 super-meta-rules from groups of related meta-rules. - - Requires Gradata Cloud. Returns empty list in open-source build. - - Args: - meta_rules: All currently active meta-rules. - min_group_size: Minimum group size to form a super-meta-rule. - current_session: Current session number. - - Returns: - Empty list (discovery requires Gradata Cloud). - """ - _log.info("Super-meta-rule discovery requires Gradata Cloud") - return [] - - -def detect_universal_rules( - super_metas: list[SuperMetaRule], - min_group_size: int = 3, - current_session: int = 0, -) -> list[SuperMetaRule]: - """Discover tier-3 universal principles from super-meta-rules. - - Requires Gradata Cloud. Returns empty list in open-source build. - - Args: - super_metas: All current tier-2 super-meta-rules. - min_group_size: Minimum group size. - current_session: Current session number. - - Returns: - Empty list (discovery requires Gradata Cloud). - """ - _log.info("Universal rule discovery requires Gradata Cloud") - return [] - - -def validate_super_meta_rule( - smeta: SuperMetaRule, - current_meta_rules: list[MetaRule], -) -> bool: - """Check if a super-meta-rule is still valid. - - A super-meta-rule is invalid when fewer than 2 of its source - meta-rules still exist (AGM contraction). - - Args: - smeta: The super-meta-rule to validate. - current_meta_rules: Currently active meta-rules. - - Returns: - ``True`` if still supported by enough source meta-rules. - """ - current_ids = {m.id for m in current_meta_rules} - surviving = sum(1 for sid in smeta.source_meta_rule_ids if sid in current_ids) - return surviving >= 2 - - -def refresh_super_meta_rules( - meta_rules: list[MetaRule], - existing_supers: list[SuperMetaRule], - current_session: int = 0, - min_group_size: int = 3, -) -> list[SuperMetaRule]: - """Re-discover super-meta-rules, keeping valid existing ones. - - In the open-source build, this validates existing super-meta-rules - but does not discover new ones. - - Args: - meta_rules: All currently active meta-rules. - existing_supers: Previously discovered super-meta-rules. - current_session: Current session number. - min_group_size: Minimum group size (unused in open-source build). - - Returns: - Validated subset of *existing_supers*. 
- """ - _log.info("Super-meta-rule refresh requires Gradata Cloud") - valid: list[SuperMetaRule] = [] - for smeta in existing_supers: - if validate_super_meta_rule(smeta, meta_rules): - smeta.last_validated_session = current_session - valid.append(smeta) - valid.sort(key=lambda s: s.confidence, reverse=True) - return valid - - -# --------------------------------------------------------------------------- -# Formatting -# --------------------------------------------------------------------------- - - -def format_super_meta_rules( - supers: list[SuperMetaRule], - context: str | None = None, - condition_context: dict[str, object] | None = None, -) -> str: - """Format super-meta-rules for injection into LLM context. - - Super-meta-rules go FIRST in the prompt (primacy positioning) as - they represent the highest-priority generalised principles. - - Args: - supers: Super-meta-rules to format (tier 2 and 3). - context: Optional task-context label for re-ranking. - condition_context: Optional dict for precondition/anti-condition - filtering. - - Returns: - Formatted string block, or ``""`` if *supers* is empty. - """ - if not supers: - return "" - - if condition_context is not None: - supers = [s for s in supers if evaluate_conditions(s, condition_context)] - - if not supers: - return "" - - if context: - ctx = context - weighted: list[tuple[SuperMetaRule, float]] = [] - for s in supers: - w = s.context_weights.get(ctx, s.context_weights.get("default", 1.0)) - weighted.append((s, s.confidence * w)) - weighted.sort(key=lambda t: t[1], reverse=True) - supers = [s for s, _ in weighted] - - universals = [s for s in supers if s.tier >= TIER_UNIVERSAL] - tier2 = [s for s in supers if s.tier == TIER_SUPER_META] - - lines: list[str] = [] - - if universals: - lines.append("## Universal Principles (highest priority)") - for i, u in enumerate(universals, start=1): - n = len(u.source_meta_rule_ids) - cats = ", ".join(u.source_categories[:5]) - lines.append( - f"{i}. [UNIV:{u.confidence:.2f}|{n} super-rules|{cats}] " - f"{u.abstraction}" - ) - for ex in u.examples: - lines.append(f" - {ex}") - - if tier2: - lines.append("") - lines.append("## Super-Meta-Rules (compound meta-principles)") - for i, s in enumerate(tier2, start=1): - n = len(s.source_meta_rule_ids) - cats = ", ".join(s.source_categories[:5]) - lines.append( - f"{i}. 
[SMETA:{s.confidence:.2f}|{n} meta-rules|{cats}] " - f"{s.abstraction}" - ) - for ex in s.examples: - lines.append(f" - {ex}") - - return "\n".join(lines) diff --git a/src/gradata/rules/budget.py b/src/gradata/rules/budget.py deleted file mode 100644 index 1186b7fa..00000000 --- a/src/gradata/rules/budget.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Context-budget-aware rule injection compression.""" - -from __future__ import annotations - -from enum import IntEnum -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from gradata._types import Lesson - - -class ContextBudget(IntEnum): - EMERGENCY = 1 # Single highest-confidence rule, bare text - MINIMAL = 2 # Top 2 rules, description only - COMPACT = 3 # RULE-state only, max 3, compressed format - STANDARD = 4 # max_rules rules, full XML, no examples - FULL = 5 # All rules, full formatting, examples included - - -def filter_by_budget(lessons: list[Lesson], budget: int = 5, max_rules: int = 5) -> list[Lesson]: - """Filter and limit lessons based on context budget level.""" - from gradata._types import ELIGIBLE_STATES, LessonState - - if budget <= 3: # EMERGENCY / MINIMAL / COMPACT - eligible = [l for l in lessons if l.state == LessonState.RULE] - eligible.sort(key=lambda l: -l.confidence) - return eligible[:budget] - else: # STANDARD / FULL - eligible = [l for l in lessons if l.state in ELIGIBLE_STATES] - eligible.sort(key=lambda l: -l.confidence) - return eligible[:max_rules] - - -def format_by_budget(lesson: Lesson, budget: int = 5) -> str: - """Format a single rule's injection text based on budget.""" - if budget <= 1: - return lesson.description - elif budget <= 2: - return f"{lesson.category}: {lesson.description}" - elif budget <= 3: - return f"{lesson.category}: {lesson.description}" - else: - return f'{lesson.category}: {lesson.description}' diff --git a/src/gradata/rules/rw_lock.py b/src/gradata/rules/rw_lock.py deleted file mode 100644 index 0444796f..00000000 --- a/src/gradata/rules/rw_lock.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Reader-writer lock for concurrent brain access.""" - -from __future__ import annotations - -import threading -from contextlib import contextmanager - - -class RWLock: - """Multiple readers OR one writer. No starvation.""" - - def __init__(self): - self._lock = threading.Lock() - self._readers = 0 - self._writer_event = threading.Event() - self._writer_event.set() # no writer initially - - def acquire_read(self): - self._writer_event.wait() - with self._lock: - self._readers += 1 - - def release_read(self): - with self._lock: - self._readers -= 1 - if self._readers == 0: - self._writer_event.set() - - def acquire_write(self): - self._writer_event.clear() - # Wait for readers to finish - while True: - with self._lock: - if self._readers == 0: - return - - def release_write(self): - self._writer_event.set() - - @contextmanager - def read_lock(self): - self.acquire_read() - try: - yield - finally: - self.release_read() - - @contextmanager - def write_lock(self): - self.acquire_write() - try: - yield - finally: - self.release_write() diff --git a/src/gradata/security/privacy_model.py b/src/gradata/security/privacy_model.py deleted file mode 100644 index 423877d3..00000000 --- a/src/gradata/security/privacy_model.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Privacy model — differential privacy + sanitization for cloud sharing. - -Implements three privacy primitives for the Gradata sharing pipeline: - -1. 
**Laplace noise** — calibrated DP noise on usage statistics so that - individual fire/misfire counts cannot reveal exact user behavior. -2. **Sanitization** — strips PII-risk fields (drafts, corrections, event IDs) - before any lesson leaves the local brain. -3. **k-anonymity gate** — a rule must exist in k+ brains before it can - appear in the marketplace, preventing unique-rule re-identification. - -Text-level re-identification (inferring a user from rule description text) -is explicitly out of scope for v1. See THREAT_MODEL.md for details. -""" - -from __future__ import annotations - -import math -import random -from typing import Any - -MIN_K_ANONYMITY = 5 # Rule must exist in 5+ brains before marketplace listing - - -def add_laplace_noise( - value: float, - sensitivity: float = 1.0, - epsilon: float = 1.0, -) -> float: - """Add calibrated Laplace noise for differential privacy. - - Higher epsilon = less noise (less privacy, more utility). - Lower epsilon = more noise (more privacy, less utility). - - The Laplace mechanism satisfies epsilon-differential privacy when - ``scale = sensitivity / epsilon``. - - Args: - value: The true numeric value. - sensitivity: Maximum change in value from one individual's data. - epsilon: Privacy budget parameter (> 0). - - Returns: - The value with added Laplace noise. - """ - scale = sensitivity / epsilon - # Inverse CDF sampling of Laplace(0, scale) - u = random.random() - 0.5 - laplace = -scale * math.copysign(1, u) * math.log(1 - 2 * abs(u)) - return value + laplace - - -def sanitize_for_sharing( - lesson_dict: dict[str, Any], - epsilon: float = 1.0, -) -> dict[str, Any]: - """Prepare a lesson for cloud sharing with privacy protections. - - Pipeline: - 1. Add Laplace noise to: fire_count, misfire_count, sessions_since_fire - 2. Strip PII-risk fields: example_draft, example_corrected, - correction_event_ids, memory_ids, agent_type - 3. Keep: description, category, confidence, state, path - (needed for tree structure and matching) - - NOTE: Text-level re-identification (inferring user from rule descriptions) - is explicitly OUT OF SCOPE for v1. The description field is shared as-is. - Future work: LLM-based text redaction before export. - - Args: - lesson_dict: Raw lesson dictionary from the local brain. - epsilon: Privacy budget for Laplace noise on statistics. - - Returns: - A new dictionary safe for cloud transmission. - """ - sanitized = dict(lesson_dict) - - # Add noise to statistics - for field in ("fire_count", "misfire_count", "sessions_since_fire"): - if field in sanitized and isinstance(sanitized[field], (int, float)): - sanitized[field] = max( - 0, - round(add_laplace_noise(float(sanitized[field]), epsilon=epsilon)), - ) - - # Strip PII-risk fields - for field in ( - "example_draft", - "example_corrected", - "correction_event_ids", - "memory_ids", - "agent_type", - ): - sanitized.pop(field, None) - - return sanitized - - -def check_k_anonymity(rule_count_across_brains: int) -> bool: - """Check if a rule meets the k-anonymity threshold for marketplace listing. - - A rule that exists in fewer than MIN_K_ANONYMITY brains could be used - to re-identify the user who created it. This gate prevents listing - until sufficient adoption dilutes that signal. - - Args: - rule_count_across_brains: Number of distinct brains containing this rule. - - Returns: - True if the rule is safe to list in the marketplace. 
- """ - return rule_count_across_brains >= MIN_K_ANONYMITY diff --git a/tests/test_adaptations.py b/tests/test_adaptations.py index ca8cf444..a466c333 100644 --- a/tests/test_adaptations.py +++ b/tests/test_adaptations.py @@ -2259,7 +2259,6 @@ def test_all_new_modules_in_capabilities(self): expected = [ "loop_detection", "middleware_chain", "git_backfill", "auto_correct_hook", "reporting", "quality_monitoring", - "rule_evolution", ] for name in expected: assert name in modules, f"Missing: {name}" @@ -2432,119 +2431,6 @@ def test_detection_has_replacement_hint(self): assert r.replacement_hint != "" -# --------------------------------------------------------------------------- -# 25. Rule A/B Testing -# --------------------------------------------------------------------------- - -from gradata.enhancements.rule_evolution import ( - ExperimentManager, - ExperimentResult, - RuleExperiment, - wilson_score_interval, -) - - -class TestRuleABTesting: - def test_wilson_score_zero_trials(self): - lower, upper = wilson_score_interval(0, 0) - assert lower == 0.0 - assert upper == 0.0 - - def test_wilson_score_all_success(self): - lower, upper = wilson_score_interval(100, 100) - assert lower > 0.9 - assert upper == 1.0 - - def test_wilson_score_half(self): - lower, upper = wilson_score_interval(50, 100) - assert 0.3 < lower < 0.5 - assert 0.5 < upper < 0.7 - - def test_experiment_basic(self): - exp = RuleExperiment( - experiment_id="test", - variant_a="Rule A", - variant_b="Rule B", - ) - assert exp.total_trials == 0 - - def test_experiment_record(self): - exp = RuleExperiment( - experiment_id="test", - variant_a="A", variant_b="B", - ) - exp.record("a", success=True) - exp.record("b", success=False) - assert exp.total_trials == 2 - - def test_experiment_invalid_variant(self): - exp = RuleExperiment(experiment_id="test", variant_a="A", variant_b="B") - with pytest.raises(ValueError, match="'a' or 'b'"): - exp.record("c", success=True) - - def test_experiment_inconclusive_few_trials(self): - exp = RuleExperiment( - experiment_id="test", variant_a="A", variant_b="B", - min_trials=20, - ) - for _ in range(5): - exp.record("a", success=True) - exp.record("b", success=False) - result = exp.evaluate() - assert not result.sufficient_data - assert result.winner is None - - def test_experiment_conclusive(self): - exp = RuleExperiment( - experiment_id="test", variant_a="A", variant_b="B", - min_trials=20, min_margin=0.10, - ) - # A wins 90%, B wins 50% - for _ in range(25): - exp.record("a", success=True) - for _ in range(5): - exp.record("a", success=False) - for _ in range(15): - exp.record("b", success=True) - for _ in range(15): - exp.record("b", success=False) - result = exp.evaluate() - assert result.sufficient_data - assert result.winner == "a" - assert result.confidence > 0 - - def test_experiment_serialization(self): - exp = RuleExperiment(experiment_id="test", variant_a="A", variant_b="B") - exp.record("a", success=True) - exp.record("b", success=False) - d = exp.to_dict() - restored = RuleExperiment.from_dict(d) - assert restored.total_trials == 2 - - def test_experiment_assign(self): - exp = RuleExperiment(experiment_id="test", variant_a="A", variant_b="B") - variants = {exp.assign() for _ in range(20)} - assert "a" in variants - assert "b" in variants - - def test_manager_create(self): - mgr = ExperimentManager() - exp = mgr.create("test", "A", "B", category="TONE") - assert mgr.active_count == 1 - assert mgr.get("test") is exp - - def test_manager_evaluate_all(self): - mgr = ExperimentManager() - 
mgr.create("test", "A", "B") - results = mgr.evaluate_all() - assert len(results) == 1 - - def test_manager_stats(self): - mgr = ExperimentManager() - mgr.create("test", "A", "B") - stats = mgr.stats() - assert stats["active_experiments"] == 1 - # =========================================================================== # Tree of Thoughts diff --git a/tests/test_budget_injection.py b/tests/test_budget_injection.py deleted file mode 100644 index 9933c674..00000000 --- a/tests/test_budget_injection.py +++ /dev/null @@ -1,92 +0,0 @@ -"""Tests for budget-aware rule injection.""" - -from gradata._types import Lesson, LessonState -from gradata.rules.budget import ContextBudget, filter_by_budget, format_by_budget - - -def _make_lessons(): - return [ - Lesson( - date="2026-01-01", - state=LessonState.RULE, - confidence=0.95, - category="TONE", - description="Be direct", - ), - Lesson( - date="2026-01-02", - state=LessonState.RULE, - confidence=0.91, - category="ACCURACY", - description="Cite sources", - ), - Lesson( - date="2026-01-03", - state=LessonState.PATTERN, - confidence=0.70, - category="TONE", - description="Match energy", - ), - Lesson( - date="2026-01-04", - state=LessonState.RULE, - confidence=0.92, - category="STRUCTURE", - description="Lead with answer", - ), - Lesson( - date="2026-01-05", - state=LessonState.INSTINCT, - confidence=0.45, - category="FORMAT", - description="No em dashes", - ), - ] - - -class TestFilterByBudget: - def test_emergency_returns_one(self): - result = filter_by_budget(_make_lessons(), budget=1) - assert len(result) == 1 - assert result[0].confidence == 0.95 - - def test_minimal_returns_two_rules_only(self): - result = filter_by_budget(_make_lessons(), budget=2) - assert len(result) == 2 - assert all(l.state == LessonState.RULE for l in result) - - def test_compact_returns_three_rules_only(self): - result = filter_by_budget(_make_lessons(), budget=3) - assert len(result) == 3 - assert all(l.state == LessonState.RULE for l in result) - - def test_standard_includes_patterns(self): - result = filter_by_budget(_make_lessons(), budget=4) - states = {l.state for l in result} - assert LessonState.RULE in states - - def test_full_same_as_standard(self): - r4 = filter_by_budget(_make_lessons(), budget=4) - r5 = filter_by_budget(_make_lessons(), budget=5) - assert len(r4) == len(r5) - - def test_instinct_excluded_at_all_levels(self): - for budget in range(1, 6): - result = filter_by_budget(_make_lessons(), budget=budget) - assert all(l.state != LessonState.INSTINCT for l in result) - - -class TestFormatByBudget: - def test_emergency_bare_text(self): - l = _make_lessons()[0] - assert format_by_budget(l, budget=1) == "Be direct" - - def test_minimal_category_prefix(self): - l = _make_lessons()[0] - assert format_by_budget(l, budget=2) == "TONE: Be direct" - - def test_standard_has_xml_and_confidence(self): - l = _make_lessons()[0] - result = format_by_budget(l, budget=4) - assert " smaller variance (less privacy, more utility).""" - original = 100.0 - n = 2000 - low_eps = [add_laplace_noise(original, epsilon=0.1) for _ in range(n)] - high_eps = [add_laplace_noise(original, epsilon=10.0) for _ in range(n)] - var_low = statistics.variance(low_eps) - var_high = statistics.variance(high_eps) - assert var_low > var_high, ( - f"Low-epsilon variance ({var_low:.1f}) should exceed " - f"high-epsilon variance ({var_high:.1f})" - ) - - def test_sensitivity_scales_noise(self): - """Higher sensitivity -> more noise.""" - original = 100.0 - n = 2000 - low_sens = 
[add_laplace_noise(original, sensitivity=0.1) for _ in range(n)] - high_sens = [add_laplace_noise(original, sensitivity=10.0) for _ in range(n)] - var_low = statistics.variance(low_sens) - var_high = statistics.variance(high_sens) - assert var_high > var_low - - def test_zero_value(self): - """Noise works on zero input.""" - result = add_laplace_noise(0.0) - assert isinstance(result, float) - - -# --------------------------------------------------------------------------- -# Sanitization -# --------------------------------------------------------------------------- -class TestSanitize: - @pytest.fixture() - def sample_lesson(self) -> dict: - return { - "description": "Always use Oxford comma", - "category": "style", - "confidence": 0.85, - "state": "RULE", - "path": "style.punctuation", - "fire_count": 42, - "misfire_count": 3, - "sessions_since_fire": 7, - "example_draft": "I like cats dogs and birds", - "example_corrected": "I like cats, dogs, and birds", - "correction_event_ids": ["evt_001", "evt_002"], - "memory_ids": ["mem_abc"], - "agent_type": "sales", - } - - def test_strips_pii_fields(self, sample_lesson: dict): - result = sanitize_for_sharing(sample_lesson) - for field in ( - "example_draft", - "example_corrected", - "correction_event_ids", - "memory_ids", - "agent_type", - ): - assert field not in result - - def test_keeps_functional_fields(self, sample_lesson: dict): - result = sanitize_for_sharing(sample_lesson) - for field in ("description", "category", "confidence", "state", "path"): - assert field in result - assert result[field] == sample_lesson[field] - - def test_noises_statistics(self, sample_lesson: dict): - """At least one stat field should differ after sanitization.""" - results = [sanitize_for_sharing(sample_lesson) for _ in range(20)] - stat_fields = ("fire_count", "misfire_count", "sessions_since_fire") - changed = False - for r in results: - for f in stat_fields: - if r[f] != sample_lesson[f]: - changed = True - break - assert changed, "Statistics should be noised" - - def test_statistics_non_negative(self, sample_lesson: dict): - """Noised values must be clamped to >= 0.""" - # Use a small original so noise could push negative - lesson = {**sample_lesson, "fire_count": 1, "misfire_count": 0, "sessions_since_fire": 0} - for _ in range(100): - result = sanitize_for_sharing(lesson, epsilon=0.1) - assert result["fire_count"] >= 0 - assert result["misfire_count"] >= 0 - assert result["sessions_since_fire"] >= 0 - - def test_does_not_mutate_original(self, sample_lesson: dict): - original_copy = dict(sample_lesson) - sanitize_for_sharing(sample_lesson) - assert sample_lesson == original_copy - - def test_missing_optional_fields_ok(self): - """Sanitize handles lessons without optional fields gracefully.""" - minimal = {"description": "Test rule", "confidence": 0.5} - result = sanitize_for_sharing(minimal) - assert result["description"] == "Test rule" - assert result["confidence"] == 0.5 - - def test_epsilon_propagates_to_noise(self, sample_lesson: dict): - """High epsilon should produce values closer to originals on average.""" - n = 500 - high_eps = [sanitize_for_sharing(sample_lesson, epsilon=100.0) for _ in range(n)] - deviations = [abs(r["fire_count"] - sample_lesson["fire_count"]) for r in high_eps] - mean_dev = statistics.mean(deviations) - # With epsilon=100, noise scale is tiny; mean deviation should be < 2 - assert mean_dev < 2.0, f"Mean deviation {mean_dev} too high for epsilon=100" - - -# 
--------------------------------------------------------------------------- -# k-anonymity -# --------------------------------------------------------------------------- -class TestKAnonymity: - def test_below_threshold_fails(self): - assert check_k_anonymity(0) is False - assert check_k_anonymity(3) is False - assert check_k_anonymity(4) is False - - def test_at_threshold_passes(self): - assert check_k_anonymity(5) is True - - def test_above_threshold_passes(self): - assert check_k_anonymity(100) is True - - def test_threshold_value(self): - assert MIN_K_ANONYMITY == 5 diff --git a/tests/test_pubsub_pipeline.py b/tests/test_pubsub_pipeline.py deleted file mode 100644 index 30a717f9..00000000 --- a/tests/test_pubsub_pipeline.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Tests for pub/sub event pipeline.""" - -from gradata.enhancements.pubsub_pipeline import PubSubPipeline - - -class TestPubSubPipeline: - def test_subscribe_and_emit(self): - pipe = PubSubPipeline() - received = [] - pipe.subscribe("CORRECTION", lambda d: received.append(d)) - pipe.emit("CORRECTION", {"text": "hello"}) - assert len(received) == 1 - assert received[0]["text"] == "hello" - - def test_multiple_subscribers(self): - pipe = PubSubPipeline() - results = [] - pipe.subscribe("X", lambda d: results.append("a")) - pipe.subscribe("X", lambda d: results.append("b")) - pipe.emit("X") - assert results == ["a", "b"] - - def test_stage_failure_doesnt_block(self): - pipe = PubSubPipeline() - results = [] - pipe.subscribe("X", lambda d: 1 / 0) # will raise - pipe.subscribe("X", lambda d: results.append("ok")) - pipe.emit("X") - assert results == ["ok"] - - def test_unsubscribed_event_noop(self): - pipe = PubSubPipeline() - results = pipe.emit("UNKNOWN") - assert results == [] - - def test_event_log_tracked(self): - pipe = PubSubPipeline() - pipe.emit("A", {"x": 1}) - pipe.emit("B", {"y": 2}) - assert len(pipe.event_log) == 2 - assert pipe.event_log[0]["type"] == "A" - - def test_ordering_preserved(self): - pipe = PubSubPipeline() - order = [] - pipe.subscribe("X", lambda d: order.append(1)) - pipe.subscribe("X", lambda d: order.append(2)) - pipe.subscribe("X", lambda d: order.append(3)) - pipe.emit("X") - assert order == [1, 2, 3] diff --git a/tests/test_rule_verifier.py b/tests/test_rule_verifier.py deleted file mode 100644 index 6a6834ea..00000000 --- a/tests/test_rule_verifier.py +++ /dev/null @@ -1,260 +0,0 @@ -""" -Tests for enhancements/rule_verifier.py (S71 module, zero prior coverage). 
- -Covers: -- auto_detect_verification() pattern matching -- verify_rules() violation detection -- log_verification() + get_verification_stats() SQLite persistence -- RuleVerification dataclass - -Run: cd sdk && python -m pytest tests/test_rule_verifier.py -v -""" - -from __future__ import annotations - -import tempfile -from pathlib import Path - -import pytest - -from gradata.enhancements.rule_verifier import ( - RuleVerification, - auto_detect_verification, - ensure_table, - get_verification_stats, - log_verification, - verify_rules, -) - - -# =========================================================================== -# auto_detect_verification -# =========================================================================== - -class TestAutoDetectVerification: - """auto_detect_verification() scans rule descriptions for checkable patterns.""" - - def test_em_dash_rule(self): - checks = auto_detect_verification("Never use em dashes in emails") - assert len(checks) >= 1 - # Should detect em dash pattern as should_be_absent=True - regex, absent, desc = checks[0] - assert absent is True - assert "em dash" in desc - - def test_pricing_rule(self): - checks = auto_detect_verification("Do not include pricing or dollar amounts") - assert len(checks) >= 1 - regexes_absent = [(r, a) for r, a, _ in checks] - assert any(a is True for _, a in regexes_absent) - - def test_booking_link_rule(self): - checks = auto_detect_verification("Always include a booking link in outreach") - assert len(checks) >= 1 - # Booking link should be should_be_absent=False (must be present) - regex, absent, desc = checks[0] - assert absent is False - assert "booking" in desc - - def test_bold_rule(self): - checks = auto_detect_verification("No bold mid-paragraph text") - assert len(checks) >= 1 - regex, absent, desc = checks[0] - assert absent is True - - def test_no_match_returns_empty(self): - checks = auto_detect_verification("Use a professional tone in all emails") - assert len(checks) == 0 - - def test_multiple_patterns_in_one_rule(self): - """A rule mentioning both em dashes and pricing should return multiple checks.""" - checks = auto_detect_verification("Never use em dashes or include dollar amounts") - assert len(checks) >= 2 - - def test_annual_pricing_detection(self): - checks = auto_detect_verification("Do not reference annual pricing in emails") - keywords = [desc for _, _, desc in checks] - assert any("annual" in d for d in keywords) - - -# =========================================================================== -# verify_rules -# =========================================================================== - -class TestVerifyRules: - """verify_rules() checks AI output against applied rules.""" - - def test_clean_output_passes(self): - rules = [{"category": "DRAFTING", "description": "Never use em dashes"}] - output = "This is a clean sentence with no dashes." - results = verify_rules(output, rules) - assert len(results) >= 1 - assert all(r.passed for r in results) - - def test_em_dash_violation_detected(self): - rules = [{"category": "DRAFTING", "description": "Never use em dashes"}] - output = "This has an em dash \u2014 right here." 
- results = verify_rules(output, rules) - violations = [r for r in results if not r.passed] - assert len(violations) >= 1 - assert violations[0].rule_category == "DRAFTING" - assert "em dash" in violations[0].violation_detail - - def test_double_dash_also_caught(self): - rules = [{"category": "DRAFTING", "description": "Never use em dashes"}] - output = "This has a double dash -- right here." - results = verify_rules(output, rules) - violations = [r for r in results if not r.passed] - assert len(violations) >= 1 - - def test_pricing_violation(self): - rules = [{"category": "PRICING", "description": "Do not include dollar pricing"}] - output = "Our starter plan is $60/month." - results = verify_rules(output, rules) - violations = [r for r in results if not r.passed] - assert len(violations) >= 1 - - def test_missing_booking_link_violation(self): - rules = [{"category": "DRAFTING", "description": "Always include a booking link hyperlink"}] - output = "Let me know if you want to chat." - results = verify_rules(output, rules) - violations = [r for r in results if not r.passed] - assert len(violations) >= 1 - assert "booking" in violations[0].violation_detail - - def test_booking_link_present_passes(self): - rules = [{"category": "DRAFTING", "description": "Always include a booking link"}] - output = 'Book a time here: Link' - results = verify_rules(output, rules) - # The hyperlink check should pass since = 1 - assert len(violations[0].output_snippet) <= 200 - - def test_description_truncated(self): - long_desc = "x" * 300 - rules = [{"category": "DRAFTING", "description": f"Never use em dashes. {long_desc}"}] - results = verify_rules("text \u2014 here", rules) - for r in results: - assert len(r.rule_description) <= 200 - - def test_context_parameter_accepted(self): - """Context param should be accepted even if not currently used.""" - rules = [{"category": "DRAFTING", "description": "Never use em dashes"}] - results = verify_rules("clean text", rules, context={"task_type": "email"}) - assert isinstance(results, list) - - def test_multiple_rules_checked(self): - rules = [ - {"category": "DRAFTING", "description": "Never use em dashes"}, - {"category": "PRICING", "description": "Do not include dollar amounts"}, - ] - output = "Text with \u2014 dash and $100 price." - results = verify_rules(output, rules) - violations = [r for r in results if not r.passed] - assert len(violations) >= 2 - - def test_bold_violation(self): - rules = [{"category": "DRAFTING", "description": "No bold formatting in emails"}] - output = "This has **bold text** in it." 
- results = verify_rules(output, rules) - violations = [r for r in results if not r.passed] - assert len(violations) >= 1 - - -# =========================================================================== -# RuleVerification dataclass -# =========================================================================== - -class TestRuleVerification: - def test_defaults(self): - rv = RuleVerification( - rule_category="TEST", rule_description="desc", passed=True - ) - assert rv.violation_detail == "" - assert rv.output_snippet == "" - - def test_all_fields(self): - rv = RuleVerification( - rule_category="DRAFTING", - rule_description="No em dashes", - passed=False, - violation_detail="contains em dash", - output_snippet="text \u2014 here", - ) - assert not rv.passed - assert rv.violation_detail == "contains em dash" - - -# =========================================================================== -# SQLite persistence -# =========================================================================== - -class TestVerificationPersistence: - """log_verification() and get_verification_stats() SQLite roundtrip.""" - - def test_log_and_retrieve_stats(self): - with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: - db_path = Path(f.name) - - try: - results = [ - RuleVerification("DRAFTING", "No em dashes", True), - RuleVerification("DRAFTING", "No em dashes", False, "violation", "snippet"), - RuleVerification("PRICING", "No dollar amounts", False, "violation", "snippet"), - ] - log_verification(session=71, results=results, db_path=db_path) - stats = get_verification_stats(db_path) - - assert stats["total_checks"] == 3 - assert stats["passed"] == 1 - assert stats["pass_rate"] == pytest.approx(1 / 3) - assert "DRAFTING" in stats["violations_by_category"] - assert "PRICING" in stats["violations_by_category"] - assert stats["violations_by_category"]["DRAFTING"] == 1 - assert stats["violations_by_category"]["PRICING"] == 1 - finally: - import gc; gc.collect() # release SQLite connections on Windows - db_path.unlink(missing_ok=True) - - def test_empty_db_stats(self): - with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: - db_path = Path(f.name) - - try: - ensure_table(db_path) - stats = get_verification_stats(db_path) - assert stats["total_checks"] == 0 - assert stats["pass_rate"] == 1.0 - assert stats["violations_by_category"] == {} - finally: - import gc; gc.collect() - db_path.unlink(missing_ok=True) - - def test_multiple_sessions(self): - with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: - db_path = Path(f.name) - - try: - r1 = [RuleVerification("DRAFTING", "Rule A", True)] - r2 = [RuleVerification("DRAFTING", "Rule A", False, "v", "s")] - log_verification(session=70, results=r1, db_path=db_path) - log_verification(session=71, results=r2, db_path=db_path) - stats = get_verification_stats(db_path) - assert stats["total_checks"] == 2 - assert stats["passed"] == 1 - finally: - import gc; gc.collect() - db_path.unlink(missing_ok=True) diff --git a/tests/test_rw_lock.py b/tests/test_rw_lock.py deleted file mode 100644 index 4b1b6946..00000000 --- a/tests/test_rw_lock.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Tests for reader-writer lock.""" - -import threading -import time - -from gradata.rules.rw_lock import RWLock - - -class TestRWLock: - def test_read_lock_context_manager(self): - lock = RWLock() - with lock.read_lock(): - assert lock._readers == 1 - assert lock._readers == 0 - - def test_multiple_readers(self): - lock = RWLock() - lock.acquire_read() - 
lock.acquire_read() - assert lock._readers == 2 - lock.release_read() - lock.release_read() - - def test_write_lock_context_manager(self): - lock = RWLock() - with lock.write_lock(): - assert not lock._writer_event.is_set() - assert lock._writer_event.is_set() - - def test_write_excludes_new_reads(self): - lock = RWLock() - results = [] - lock.acquire_write() - - def try_read(): - lock.acquire_read() - results.append("read") - lock.release_read() - - t = threading.Thread(target=try_read) - t.start() - time.sleep(0.1) - assert results == [] # read blocked - lock.release_write() - t.join(timeout=1) - assert results == ["read"] diff --git a/tests/test_spec_compliance.py b/tests/test_spec_compliance.py index cc73cc66..e3f0fe54 100644 --- a/tests/test_spec_compliance.py +++ b/tests/test_spec_compliance.py @@ -90,7 +90,6 @@ def test_all_enhancements_importable(self, module): "truth_protocol", "eval_benchmark", "install_manifest", - "outcome_feedback", ], ) def test_contrib_enhancements_importable(self, module): diff --git a/tests/test_steals.py b/tests/test_steals.py index d5c51479..fb805f71 100644 --- a/tests/test_steals.py +++ b/tests/test_steals.py @@ -356,110 +356,6 @@ def test_draft_comparison(self): assert ctx.is_correction is True -# ========================================================================= -# STEAL 5: Rule Conflict Detection -# ========================================================================= - - -class TestRuleConflicts: - """Test rule relationship classification.""" - - def _make_lesson(self, desc: str, *, category: str = "DRAFTING", - state: LessonState = LessonState.RULE, - confidence: float = 0.90) -> Lesson: - return Lesson( - date="2026-03-01", - state=state, - confidence=confidence, - category=category, - description=desc, - ) - - def test_updates_detection(self): - from gradata.enhancements.rule_evolution import detect_rule_conflict, RuleRelation - - new = self._make_lesson("Always avoid using formal tone in emails") - existing = [ - self._make_lesson("Always use formal tone in emails"), - ] - - relation, target = detect_rule_conflict(new, existing) - assert relation == RuleRelation.UPDATES - assert target is not None - - def test_extends_detection(self): - from gradata.enhancements.rule_evolution import detect_rule_conflict, RuleRelation - - new = self._make_lesson("Use colons instead of em dashes in email subject lines") - existing = [ - self._make_lesson("Use colons instead of em dashes in email body text"), - ] - - relation, target = detect_rule_conflict(new, existing) - assert relation == RuleRelation.EXTENDS - assert target is not None - - def test_derives_detection(self): - from gradata.enhancements.rule_evolution import detect_rule_conflict, RuleRelation - - new = self._make_lesson("Use colons instead of em dashes in email paragraphs") - existing = [ - self._make_lesson("Use colons instead of em dashes in email subject lines"), - self._make_lesson("Use colons instead of em dashes in email bullet points"), - self._make_lesson("Use colons instead of em dashes in email headers"), - ] - - relation, target = detect_rule_conflict(new, existing) - # All same category with high keyword overlap -> DERIVES or EXTENDS - assert relation in (RuleRelation.DERIVES, RuleRelation.EXTENDS) - - def test_independent_detection(self): - from gradata.enhancements.rule_evolution import detect_rule_conflict, RuleRelation - - new = self._make_lesson("Always verify prospect company size before outreach") - existing = [ - self._make_lesson("Use colons not em dashes in 
emails"), - ] - - relation, target = detect_rule_conflict(new, existing) - assert relation == RuleRelation.INDEPENDENT - assert target is None - - def test_empty_existing_rules(self): - from gradata.enhancements.rule_evolution import detect_rule_conflict, RuleRelation - - new = self._make_lesson("Some new rule") - relation, target = detect_rule_conflict(new, []) - assert relation == RuleRelation.INDEPENDENT - - def test_classify_all_relations(self): - from gradata.enhancements.rule_evolution import classify_all_relations, RuleRelation - - new = self._make_lesson("Use short email subject lines for cold outreach") - existing = [ - self._make_lesson("Keep email subject lines under fifty characters"), - self._make_lesson("Always verify prospect identity before drafting"), - self._make_lesson("Use formal tone in cold outreach emails"), - ] - - results = classify_all_relations(new, existing) - assert isinstance(results, list) - # Should have at least one result (the subject line rule is similar) - # Results are sorted by similarity descending - - -class TestRuleRelationEnum: - """Test RuleRelation enum values.""" - - def test_all_values(self): - from gradata.enhancements.rule_evolution import RuleRelation - - assert RuleRelation.UPDATES.value == "updates" - assert RuleRelation.EXTENDS.value == "extends" - assert RuleRelation.DERIVES.value == "derives" - assert RuleRelation.INDEPENDENT.value == "independent" - - # ========================================================================= # STEAL 6: Learning Graph # ========================================================================= @@ -615,7 +511,6 @@ class TestIntegration: def test_all_imports(self): from gradata.mcp_tools import correct, recall, manifest from gradata.correction_detector import detect_correction, extract_correction_context - from gradata.enhancements.rule_evolution import detect_rule_conflict, RuleRelation from gradata.graph import build_learning_graph, GraphNode, GraphEdge, to_json def test_correction_to_graph_flow(self): @@ -648,27 +543,3 @@ def test_correction_to_graph_flow(self): assert len(nodes) == 1 assert nodes[0].category == "FORMATTING" - def test_conflict_detection_to_graph_flow(self): - """Simulate: new correction -> check conflicts -> build graph with edges.""" - from gradata.enhancements.rule_evolution import classify_all_relations, RuleRelation - from gradata.graph import build_learning_graph, GraphEdge - - existing = [ - Lesson(date="2026-03-01", state=LessonState.RULE, confidence=0.95, - category="DRAFTING", description="Always keep emails concise and short"), - Lesson(date="2026-03-01", state=LessonState.RULE, confidence=0.90, - category="DRAFTING", description="Use direct subject lines in cold emails"), - ] - - new_lesson = Lesson( - date="2026-03-27", state=LessonState.INSTINCT, confidence=0.30, - category="DRAFTING", description="Keep cold emails under three sentences", - ) - - # Check relations - relations = classify_all_relations(new_lesson, existing) - - # Build graph - all_lessons = existing + [new_lesson] - nodes, edges = build_learning_graph(all_lessons) - assert len(nodes) == 3 diff --git a/tests/test_swe_bench.py b/tests/test_swe_bench.py deleted file mode 100644 index fa203e82..00000000 --- a/tests/test_swe_bench.py +++ /dev/null @@ -1,224 +0,0 @@ -""" -Tests for SWE-bench harness (no Docker, no HuggingFace download required). 
-=========================================================================== -Uses mock instances to verify the full flow: - agent attempts fix → compare to gold → brain.correct() on failure → lessons accumulate -""" - -from __future__ import annotations - -import tempfile -from pathlib import Path - -import pytest - -from gradata.benchmarks.swe_bench import ( - PatchResult, - RunConfig, - RunResults, - SWEBenchHarness, - SWEInstance, - compare_patches, -) - - -# --------------------------------------------------------------------------- -# Mock data -# --------------------------------------------------------------------------- - -def _mock_instances(n: int = 10) -> list[SWEInstance]: - """Generate mock SWE-bench instances for testing.""" - instances = [] - for i in range(n): - instances.append(SWEInstance( - instance_id=f"test__repo-{i}", - repo="test/repo", - problem_statement=f"Bug #{i}: function returns wrong value", - gold_patch=f"--- a/src/foo.py\n+++ b/src/foo.py\n@@ -1 +1 @@\n-return {i}\n+return {i + 1}", - fail_to_pass=[f"test_bug_{i}"], - )) - return instances - - -def _perfect_agent(instance: SWEInstance, brain_rules: str) -> str: - """Agent that always returns the gold patch (100% resolve rate).""" - return instance.gold_patch - - -def _bad_agent(instance: SWEInstance, brain_rules: str) -> str: - """Agent that always returns wrong patch (0% resolve rate).""" - return "--- a/wrong.py\n+++ b/wrong.py\n@@ -1 +1 @@\n-wrong\n+still wrong" - - -def _improving_agent_factory(improve_after: int = 5): - """Agent that starts bad but gets better (simulates learning).""" - call_count = [0] - - def agent(instance: SWEInstance, brain_rules: str) -> str: - call_count[0] += 1 - if call_count[0] > improve_after and brain_rules: - return instance.gold_patch # "Learned" from brain rules - return "--- a/bad.py\n+++ b/bad.py\n@@ -1 +1 @@\n-bad\n+still bad" - - return agent - - -# --------------------------------------------------------------------------- -# Patch comparison tests -# --------------------------------------------------------------------------- - - -class TestPatchComparison: - def test_identical_patches(self): - patch = "--- a/f.py\n+++ b/f.py\n@@ -1 +1 @@\n-old\n+new" - assert compare_patches(patch, patch) == 1.0 - - def test_completely_different(self): - a = "--- a/f.py\n+++ b/f.py\n@@ -1 +1 @@\n-alpha\n+beta" - b = "--- a/g.py\n+++ b/g.py\n@@ -1 +1 @@\n-gamma\n+delta" - assert compare_patches(a, b) == 0.0 - - def test_partial_overlap(self): - a = "--- a/f.py\n+++ b/f.py\n@@ -1 +1 @@\n-old\n+new\n+extra" - b = "--- a/f.py\n+++ b/f.py\n@@ -1 +1 @@\n-old\n+new" - sim = compare_patches(a, b) - assert 0.0 < sim < 1.0 - - def test_empty_patches(self): - assert compare_patches("", "") == 1.0 - assert compare_patches("some patch", "") == 0.0 - assert compare_patches("", "some patch") == 0.0 - - def test_whitespace_normalized(self): - a = "--- a/f.py\n+++ b/f.py\n@@ -1 +1 @@\n+return x + 1" - b = "--- a/f.py\n+++ b/f.py\n@@ -1 +1 @@\n+return x + 1" - assert compare_patches(a, b) == 1.0 - - -# --------------------------------------------------------------------------- -# Harness tests (no external deps) -# --------------------------------------------------------------------------- - - -class TestSWEBenchHarness: - def test_perfect_agent_100_percent(self): - with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir: - harness = SWEBenchHarness(brain_dir=tmpdir) - instances = _mock_instances(5) - config = RunConfig(run_id="perfect", batch_size=5, use_brain=False) - results = 
harness.run(instances, _perfect_agent, config) - assert results.resolve_rate == 1.0 - assert results.total_resolved == 5 - - def test_bad_agent_0_percent(self): - with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir: - harness = SWEBenchHarness(brain_dir=tmpdir) - instances = _mock_instances(5) - config = RunConfig(run_id="bad", batch_size=5, use_brain=True) - results = harness.run(instances, _bad_agent, config) - assert results.resolve_rate == 0.0 - assert results.total_resolved == 0 - - def test_corrections_captured_on_failure(self): - with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir: - harness = SWEBenchHarness(brain_dir=tmpdir) - instances = _mock_instances(3) - config = RunConfig(run_id="capture", batch_size=3, use_brain=True) - results = harness.run(instances, _bad_agent, config) - captured = sum(1 for r in results.results if r.correction_captured) - assert captured == 3 # All failures captured - - def test_lessons_created_from_corrections(self): - with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir: - harness = SWEBenchHarness(brain_dir=tmpdir) - instances = _mock_instances(5) - config = RunConfig(run_id="lessons", batch_size=5, use_brain=True) - results = harness.run(instances, _bad_agent, config) - # At least some lessons should be created - assert results.brain_lessons_created > 0 - - def test_batch_stats_computed(self): - with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir: - harness = SWEBenchHarness(brain_dir=tmpdir) - instances = _mock_instances(10) - config = RunConfig(run_id="batches", batch_size=5, use_brain=False) - results = harness.run(instances, _bad_agent, config) - assert len(results.batch_stats) == 2 # 10 instances / 5 per batch - - def test_summary_output(self): - with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir: - harness = SWEBenchHarness(brain_dir=tmpdir) - instances = _mock_instances(3) - config = RunConfig(run_id="summary", batch_size=3, use_brain=False) - results = harness.run(instances, _perfect_agent, config) - summary = results.summary() - assert "summary" in summary.lower() or "SWE-bench" in summary - - def test_to_dict_serializable(self): - with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir: - harness = SWEBenchHarness(brain_dir=tmpdir) - instances = _mock_instances(3) - results = harness.run(instances, _perfect_agent, RunConfig(batch_size=3)) - d = results.to_dict() - assert "resolve_rate" in d - assert "batch_stats" in d - # Should be JSON-serializable - import json - json.dumps(d) - - def test_compare_runs(self): - with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir: - harness = SWEBenchHarness(brain_dir=tmpdir) - instances = _mock_instances(5) - - baseline = harness.run( - instances, _bad_agent, - RunConfig(run_id="baseline", batch_size=5, use_brain=False), - ) - enhanced = harness.run( - instances, _perfect_agent, - RunConfig(run_id="enhanced", batch_size=5, use_brain=True), - ) - - comparison = harness.compare_runs(baseline, enhanced) - assert comparison["enhanced_resolve_rate"] > comparison["baseline_resolve_rate"] - assert "verdict" in comparison - assert "improved" in comparison["verdict"].lower() - - def test_max_instances_cap(self): - with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir: - harness = SWEBenchHarness(brain_dir=tmpdir) - instances = _mock_instances(20) - config = RunConfig(max_instances=5, batch_size=5, use_brain=False) - results = harness.run(instances, _perfect_agent, 
config) - assert results.total_attempted == 5 - - def test_agent_crash_handled(self): - def crashing_agent(instance, rules): - raise RuntimeError("Agent exploded") - - with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir: - harness = SWEBenchHarness(brain_dir=tmpdir) - instances = _mock_instances(2) - config = RunConfig(batch_size=2, use_brain=False) - results = harness.run(instances, crashing_agent, config) - assert results.total_attempted == 2 - assert results.total_resolved == 0 - - def test_no_brain_dir_works(self): - harness = SWEBenchHarness() # No brain - instances = _mock_instances(3) - config = RunConfig(batch_size=3, use_brain=False) - results = harness.run(instances, _perfect_agent, config) - assert results.total_resolved == 3 - - -class TestSWEInstance: - def test_instance_fields(self): - inst = SWEInstance( - instance_id="django__django-123", - repo="django/django", - problem_statement="Bug in ORM", - ) - assert inst.repo == "django/django" - assert inst.fail_to_pass == [] diff --git a/tests/test_wiki_store.py b/tests/test_wiki_store.py deleted file mode 100644 index af2c2478..00000000 --- a/tests/test_wiki_store.py +++ /dev/null @@ -1,207 +0,0 @@ -"""Tests for cloud wiki store (wiki_store.py). - -Tests the WikiStore logic with a mocked Supabase client since -we can't connect to a real Supabase instance in CI. -""" -from __future__ import annotations - -import json -from dataclasses import dataclass -from pathlib import Path -from unittest.mock import MagicMock, patch - -import pytest - - -@dataclass -class MockResponse: - data: list | dict | None = None - - -@pytest.fixture -def mock_supabase(): - """Create a mock Supabase client.""" - with patch("gradata.cloud.wiki_store.WikiStore.__init__", return_value=None) as _: - from gradata.cloud.wiki_store import WikiStore - store = object.__new__(WikiStore) - store.client = MagicMock() - store.brain_id = "test-brain" - store._embedder = None - yield store - - -def test_page_id_deterministic(): - from gradata.cloud.wiki_store import WikiStore - id1 = WikiStore._page_id("brain1", "My Title") - id2 = WikiStore._page_id("brain1", "My Title") - id3 = WikiStore._page_id("brain1", "Other Title") - assert id1 == id2 - assert id1 != id3 - assert id1.startswith("wp_") - - -def test_source_id_deterministic(): - from gradata.cloud.wiki_store import WikiStore - id1 = WikiStore._source_id("brain1", "https://example.com") - id2 = WikiStore._source_id("brain1", "https://example.com") - assert id1 == id2 - assert id1.startswith("ws_") - - -def test_content_hash(): - from gradata.cloud.wiki_store import WikiStore - h1 = WikiStore._content_hash("hello world") - h2 = WikiStore._content_hash("hello world") - h3 = WikiStore._content_hash("different") - assert h1 == h2 - assert h1 != h3 - assert len(h1) == 16 - - -def test_upsert_page(mock_supabase): - mock_table = MagicMock() - mock_table.upsert.return_value.execute.return_value = MockResponse() - mock_supabase.client.table.return_value = mock_table - - page_id = mock_supabase.upsert_page( - title="Test Page", - content="Some content", - category="CODE", - embed=False, - ) - - assert page_id.startswith("wp_") - mock_supabase.client.table.assert_called_with("wiki_pages") - mock_table.upsert.assert_called_once() - row = mock_table.upsert.call_args[0][0] - assert row["brain_id"] == "test-brain" - assert row["title"] == "Test Page" - assert row["category"] == "CODE" - - -def test_get_page_found(mock_supabase): - mock_chain = MagicMock() - mock_chain.select.return_value = mock_chain - 
mock_chain.eq.return_value = mock_chain - mock_chain.maybe_single.return_value = mock_chain - mock_chain.execute.return_value = MockResponse(data={ - "id": "wp_abc", "title": "Test", "category": "CODE", - "content": "hello", "page_type": "concept", "tags": "[]", - }) - mock_supabase.client.table.return_value = mock_chain - - page = mock_supabase.get_page("Test") - assert page is not None - assert page.title == "Test" - assert page.category == "CODE" - - -def test_get_page_not_found(mock_supabase): - mock_chain = MagicMock() - mock_chain.select.return_value = mock_chain - mock_chain.eq.return_value = mock_chain - mock_chain.maybe_single.return_value = mock_chain - mock_chain.execute.return_value = MockResponse(data=None) - mock_supabase.client.table.return_value = mock_chain - - page = mock_supabase.get_page("Nonexistent") - assert page is None - - -def test_search_categories(mock_supabase): - mock_supabase.client.rpc.return_value.execute.return_value = MockResponse(data=[ - {"id": "wp_1", "title": "Rule: CODE", "category": "CODE", - "content": "...", "page_type": "concept", "tags": "[]", "similarity": 0.9}, - {"id": "wp_2", "title": "Rule: TONE", "category": "TONE", - "content": "...", "page_type": "concept", "tags": "[]", "similarity": 0.7}, - ]) - # Mock _embed to return a dummy vector - mock_supabase._embed = lambda text: [0.0] * 384 - - cats = mock_supabase.search_categories("code implementation") - assert "CODE" in cats - assert "TONE" in cats - - -def test_text_search_fallback(mock_supabase): - mock_chain = MagicMock() - mock_chain.select.return_value = mock_chain - mock_chain.eq.return_value = mock_chain - mock_chain.ilike.return_value = mock_chain - mock_chain.limit.return_value = mock_chain - mock_chain.execute.return_value = MockResponse(data=[ - {"id": "wp_1", "title": "Test", "category": "CODE", - "content": "code stuff", "page_type": "concept", "tags": []}, - ]) - mock_supabase.client.table.return_value = mock_chain - - results = mock_supabase._text_search("code", limit=5) - assert len(results) == 1 - assert results[0].category == "CODE" - - -def test_add_source(mock_supabase): - mock_table = MagicMock() - mock_table.upsert.return_value.execute.return_value = MockResponse() - mock_supabase.client.table.return_value = mock_table - - source_id = mock_supabase.add_source( - title="Karpathy blog", - source_type="article", - source_url="https://example.com/blog", - content_hash="abc123", - ) - - assert source_id.startswith("ws_") - mock_supabase.client.table.assert_called_with("wiki_sources") - - -def test_sync_from_local(mock_supabase, tmp_path): - # Create test wiki pages - concepts = tmp_path / "concepts" - concepts.mkdir() - (concepts / "rule-code.md").write_text( - "---\ntitle: Graduated Rules: CODE\ncategory: CODE\ntype: concept\n---\nContent here", - encoding="utf-8", - ) - (concepts / "rule-tone.md").write_text( - "---\ntitle: Graduated Rules: TONE\ncategory: TONE\n---\nTone rules", - encoding="utf-8", - ) - - # Mock: no existing pages (all new) - mock_chain = MagicMock() - mock_chain.select.return_value = mock_chain - mock_chain.eq.return_value = mock_chain - mock_chain.maybe_single.return_value = mock_chain - mock_chain.execute.return_value = MockResponse(data=None) - - mock_upsert = MagicMock() - mock_upsert.upsert.return_value.execute.return_value = MockResponse() - - def table_router(name): - if name == "wiki_pages": - return MagicMock( - select=lambda *a: mock_chain, - upsert=lambda row: mock_upsert.upsert(row), - ) - return MagicMock() - - mock_supabase.client.table 
= table_router - # Disable embedding for sync test - mock_supabase.upsert_page = MagicMock(return_value="wp_test") - - result = mock_supabase.sync_from_local(tmp_path) - assert result["uploaded"] == 2 - assert result["errors"] == 0 - - -def test_schema_sql_valid(): - """Schema SQL should be well-formed (basic check).""" - from gradata.cloud.wiki_store import SCHEMA_SQL, SEARCH_RPC_SQL - assert "create table" in SCHEMA_SQL.lower() - assert "wiki_pages" in SCHEMA_SQL - assert "wiki_sources" in SCHEMA_SQL - assert "vector" in SCHEMA_SQL - assert "wiki_search" in SEARCH_RPC_SQL - assert "vector" in SEARCH_RPC_SQL
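
Note on the removed privacy primitives: the sharing pipeline deleted above composes three small steps (Laplace noise on usage counters, PII-field stripping, and a k-anonymity gate). The sketch below restates that flow as a self-contained Python example so it can still be run after this removal; it mirrors the signatures, field names, and the k=5 threshold shown in the deleted privacy_model.py, while the example lesson dict and the local helper name `laplace_noise` are illustrative only:

    import math
    import random

    MIN_K_ANONYMITY = 5  # same marketplace threshold the deleted module used

    def laplace_noise(value: float, sensitivity: float = 1.0, epsilon: float = 1.0) -> float:
        # Laplace mechanism: scale = sensitivity / epsilon, sampled via the inverse CDF.
        scale = sensitivity / epsilon
        u = random.random() - 0.5
        return value - scale * math.copysign(1.0, u) * math.log(1.0 - 2.0 * abs(u))

    def sanitize_for_sharing(lesson: dict, epsilon: float = 1.0) -> dict:
        shared = dict(lesson)
        # 1. Noise the usage statistics, clamped to non-negative integers.
        for f in ("fire_count", "misfire_count", "sessions_since_fire"):
            if isinstance(shared.get(f), (int, float)):
                shared[f] = max(0, round(laplace_noise(float(shared[f]), epsilon=epsilon)))
        # 2. Strip PII-risk fields before anything leaves the local brain.
        for f in ("example_draft", "example_corrected", "correction_event_ids",
                  "memory_ids", "agent_type"):
            shared.pop(f, None)
        return shared

    # Hypothetical lesson; description/category/confidence stay as-is, stats get noised.
    lesson = {
        "description": "Always use Oxford comma",
        "category": "style",
        "confidence": 0.85,
        "fire_count": 42,
        "example_draft": "I like cats dogs and birds",
    }
    payload = sanitize_for_sharing(lesson, epsilon=1.0)
    can_list = 7 >= MIN_K_ANONYMITY  # rule observed in 7 distinct brains -> listable
    print(payload, can_list)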