From 40b84adfca29a7f444010cee2846792b98c1dfeb Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Mon, 13 Apr 2026 13:42:34 -0700 Subject: [PATCH 1/9] =?UTF-8?q?feat(sdk):=20middleware=20core=20=E2=80=94?= =?UTF-8?q?=20RuleSource,=20injection,=20enforcement?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the shared primitives for runtime middleware adapters: - RuleSource reads lessons.md from a brain directory (same path Claude Code hooks use) and selects top-N RULE/PATTERN lessons. - build_brain_rules_block() renders the same XML the SessionStart hook emits, for consistency across environments. - check_output() scans text against RULE-tier regex patterns derived from rule_to_hook.classify_rule (reuses existing classifier). - RuleViolation exception + GRADATA_BYPASS=1 kill switch env var. Zero changes to existing rule_engine / rule_to_hook — middleware is additive. --- src/gradata/middleware/__init__.py | 82 +++++++++ src/gradata/middleware/_core.py | 264 +++++++++++++++++++++++++++++ tests/test_middleware_core.py | 135 +++++++++++++++ 3 files changed, 481 insertions(+) create mode 100644 src/gradata/middleware/__init__.py create mode 100644 src/gradata/middleware/_core.py create mode 100644 tests/test_middleware_core.py diff --git a/src/gradata/middleware/__init__.py b/src/gradata/middleware/__init__.py new file mode 100644 index 00000000..0596084f --- /dev/null +++ b/src/gradata/middleware/__init__.py @@ -0,0 +1,82 @@ +"""Runtime middleware adapters for non-Claude-Code environments. + +Gradata's hooks only fire inside Claude Code. For direct-SDK agents +(raw OpenAI SDK, raw Anthropic SDK, LangChain, CrewAI) this subpackage +provides runtime wrappers that inject learned rules into system prompts +and enforce RULE-tier patterns on outputs. + +Quick start: + + from anthropic import Anthropic + from gradata.middleware import wrap_anthropic + + client = wrap_anthropic(Anthropic(), brain_path="./brain") + # All client.messages.create(...) calls now get rules injected. + +The adapters share a common :class:`RuleSource` that reads from the same +``lessons.md`` + brain database that Claude Code hooks use, so behaviour +is consistent across environments. + +Environment overrides: + GRADATA_BYPASS=1 — disables all injection and enforcement (emergency kill switch). + +Optional deps: + - AnthropicMiddleware / wrap_anthropic -> ``anthropic`` + - OpenAIMiddleware / wrap_openai -> ``openai`` + - LangChainCallback -> ``langchain-core`` + - CrewAIGuard -> works with plain CrewAI guardrails + +Importing an adapter without its optional dep raises a clear ImportError +with the install hint. +""" + +from __future__ import annotations + +from gradata.middleware._core import ( + RuleSource, + RuleViolation, + build_brain_rules_block, + check_output, + is_bypassed, +) + +# Adapters are exposed via lazy __getattr__ so importing the package +# doesn't require anthropic / openai / langchain / crewai to be installed. 
+ +__all__ = [ # noqa: RUF022 — logical grouping (core -> adapters) over alphabetical + "RuleSource", + "RuleViolation", + "build_brain_rules_block", + "check_output", + "is_bypassed", + # Lazy exports — see __getattr__ + "AnthropicMiddleware", + "OpenAIMiddleware", + "LangChainCallback", + "CrewAIGuard", + "wrap_anthropic", + "wrap_openai", +] + + +def __getattr__(name: str): # pragma: no cover - trivial dispatch + if name in ("AnthropicMiddleware", "wrap_anthropic"): + from gradata.middleware.anthropic_adapter import ( + AnthropicMiddleware, + wrap_anthropic, + ) + + return {"AnthropicMiddleware": AnthropicMiddleware, "wrap_anthropic": wrap_anthropic}[name] + if name in ("OpenAIMiddleware", "wrap_openai"): + from gradata.middleware.openai_adapter import OpenAIMiddleware, wrap_openai + + return {"OpenAIMiddleware": OpenAIMiddleware, "wrap_openai": wrap_openai}[name] + if name == "LangChainCallback": + from gradata.middleware.langchain_adapter import LangChainCallback + + return LangChainCallback + if name == "CrewAIGuard": + from gradata.middleware.crewai_adapter import CrewAIGuard + + return CrewAIGuard + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/gradata/middleware/_core.py b/src/gradata/middleware/_core.py new file mode 100644 index 00000000..f340ef67 --- /dev/null +++ b/src/gradata/middleware/_core.py @@ -0,0 +1,264 @@ +"""Shared core for runtime middleware adapters. + +This module is SDK-layer: it provides the :class:`RuleSource` (which reads +``lessons.md`` from a brain directory) and the injection / enforcement +primitives the per-framework adapters compose. It must not import any +third-party LLM SDK. +""" + +from __future__ import annotations + +import logging +import os +import re +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +from gradata.enhancements.rule_to_hook import DeterminismCheck, classify_rule + +if TYPE_CHECKING: # pragma: no cover + from gradata._types import Lesson + +_log = logging.getLogger(__name__) + +# Default cap matches the existing SessionStart hook (inject_brain_rules.py) +DEFAULT_MAX_RULES = 10 +DEFAULT_MIN_CONFIDENCE = 0.60 + +# Regex block patterns derived from RULE-tier deterministic classifications. +# Each entry: (hook_template, compiled_regex, friendly_name) +# These apply post-call to detect rule violations in model output. +_BLOCK_PATTERNS: dict[str, tuple[re.Pattern[str], str]] = { + "regex_replace": (re.compile(r"[\u2014\u2013]"), "em-dash"), +} + + +class RuleViolation(Exception): # noqa: N818 — public API name specified in spec + """Raised when an LLM output violates a RULE-tier deterministic pattern. + + Attributes: + rule_description: The source rule's description text. + pattern_name: Short label for which check fired (e.g. ``"em-dash"``). + output: The offending model output text. 
+    """
+
+    def __init__(self, rule_description: str, pattern_name: str, output: str) -> None:
+        self.rule_description = rule_description
+        self.pattern_name = pattern_name
+        self.output = output
+        super().__init__(
+            f"RuleViolation: output matched '{pattern_name}' "
+            f"(rule: {rule_description!r})"
+        )
+
+
+def is_bypassed() -> bool:
+    """Return True if GRADATA_BYPASS=1 is set (kill switch for middleware)."""
+    return os.environ.get("GRADATA_BYPASS", "").strip() == "1"
+
+
+# ---------------------------------------------------------------------------
+# RuleSource
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class _ScoredLesson:
+    category: str
+    description: str
+    state: str  # "RULE" or "PATTERN"
+    confidence: float
+
+
+class RuleSource:
+    """Reads lessons from the same brain directory Claude Code hooks use.
+
+    The source loads ``<brain_path>/lessons.md`` on demand and returns the
+    top-N highest-priority lessons. Parsing delegates to the same
+    :func:`gradata.enhancements.self_improvement.parse_lessons` the
+    SessionStart hook uses, so behaviour is identical across environments.
+
+    A :class:`RuleSource` can also be constructed directly from a list of
+    lesson dicts — useful for tests and for callers that source rules from
+    somewhere other than the default lessons file.
+    """
+
+    def __init__(
+        self,
+        brain_path: str | Path | None = None,
+        *,
+        lessons: list[dict] | None = None,
+        max_rules: int = DEFAULT_MAX_RULES,
+        min_confidence: float = DEFAULT_MIN_CONFIDENCE,
+    ) -> None:
+        self._brain_path = Path(brain_path) if brain_path else None
+        self._static_lessons = lessons
+        self.max_rules = max_rules
+        self.min_confidence = min_confidence
+
+    # -- loading ----------------------------------------------------------
+
+    def _load_from_brain(self) -> list[_ScoredLesson]:
+        if self._brain_path is None:
+            return []
+        path = self._brain_path / "lessons.md"
+        if not path.is_file():
+            return []
+        try:
+            from gradata.enhancements.self_improvement import parse_lessons
+        except ImportError:  # pragma: no cover
+            _log.debug("parse_lessons unavailable; returning no rules")
+            return []
+        try:
+            text = path.read_text(encoding="utf-8")
+        except OSError as exc:  # pragma: no cover - filesystem edge
+            _log.warning("Could not read %s: %s", path, exc)
+            return []
+        parsed = parse_lessons(text)
+        return [_lesson_to_scored(lesson) for lesson in parsed]
+
+    def _load_from_dicts(self) -> list[_ScoredLesson]:
+        out: list[_ScoredLesson] = []
+        for lesson in self._static_lessons or []:
+            state = str(lesson.get("state") or lesson.get("status") or "").upper()
+            conf = float(lesson.get("confidence", 0.0) or 0.0)
+            category = str(lesson.get("category", "") or "")
+            description = str(lesson.get("description", "") or "")
+            if not description:
+                continue
+            out.append(
+                _ScoredLesson(
+                    category=category,
+                    description=description,
+                    state=state,
+                    confidence=conf,
+                )
+            )
+        return out
+
+    def load(self) -> list[_ScoredLesson]:
+        """Return eligible lessons (RULE/PATTERN only, above min_confidence)."""
+        lessons = (
+            self._load_from_dicts() if self._static_lessons is not None
+            else self._load_from_brain()
+        )
+        return [
+            l for l in lessons
+            if l.state in ("RULE", "PATTERN") and l.confidence >= self.min_confidence
+        ]
+
+    # -- selection --------------------------------------------------------
+
+    def select(self) -> list[_ScoredLesson]:
+        """Return up to ``max_rules`` lessons ranked for injection.
+
+        RULE beats PATTERN, ties broken by confidence descending. This matches
+        the priority scheme used by ``inject_brain_rules.py``.
+        """
+        lessons = self.load()
+        lessons.sort(
+            key=lambda l: (1 if l.state == "RULE" else 0, l.confidence),
+            reverse=True,
+        )
+        return lessons[: self.max_rules]
+
+    # -- enforcement ------------------------------------------------------
+
+    def rule_tier_blockers(self) -> list[tuple[_ScoredLesson, re.Pattern[str], str]]:
+        """Return (lesson, compiled_pattern, name) tuples for RULE-tier blockers.
+
+        Uses :func:`gradata.enhancements.rule_to_hook.classify_rule` to find
+        rules whose descriptions map to a deterministic regex template, then
+        resolves that template to a compiled pattern. PATTERN-tier lessons are
+        skipped — only RULE-tier rules (confidence >= 0.90) are enforced.
+        """
+        out: list[tuple[_ScoredLesson, re.Pattern[str], str]] = []
+        for lesson in self.load():
+            if lesson.state != "RULE":
+                continue
+            try:
+                candidate = classify_rule(lesson.description, lesson.confidence)
+            except ValueError:
+                continue
+            if candidate.determinism == DeterminismCheck.NOT_DETERMINISTIC:
+                continue
+            spec = _BLOCK_PATTERNS.get(candidate.hook_template)
+            if spec is None:
+                continue
+            pattern, name = spec
+            out.append((lesson, pattern, name))
+        return out
+
+
+def _lesson_to_scored(lesson: Lesson) -> _ScoredLesson:
+    state_name = lesson.state.name if hasattr(lesson.state, "name") else str(lesson.state)
+    return _ScoredLesson(
+        category=lesson.category,
+        description=lesson.description,
+        state=state_name,
+        confidence=float(lesson.confidence),
+    )
+
+
+# ---------------------------------------------------------------------------
+# Injection
+# ---------------------------------------------------------------------------
+
+
+def build_brain_rules_block(source: RuleSource) -> str:
+    """Render the ``<brain_rules>`` XML block for a given :class:`RuleSource`.
+
+    Matches the format produced by :mod:`gradata.hooks.inject_brain_rules`
+    so injection is identical across Claude Code and direct-SDK agents.
+    Returns ``""`` when no rules are eligible.
+    """
+    if is_bypassed():
+        return ""
+    selected = source.select()
+    if not selected:
+        return ""
+    lines = [
+        f"[{l.state}:{l.confidence:.2f}] {l.category}: {l.description}"
+        for l in selected
+    ]
+    return "<brain_rules>\n" + "\n".join(lines) + "\n</brain_rules>"
+
+
+def inject_into_system(system: str | None, block: str) -> str:
+    """Append the ``<brain_rules>`` block to an existing system prompt."""
+    if not block:
+        return system or ""
+    if not system:
+        return block
+    return f"{system}\n\n{block}"
+
+
+# ---------------------------------------------------------------------------
+# Enforcement
+# ---------------------------------------------------------------------------
+
+
+def check_output(source: RuleSource, text: str, *, strict: bool = False) -> list[RuleViolation]:
+    """Scan ``text`` for RULE-tier pattern violations.
+
+    When ``strict`` is True, the first violation is raised. Otherwise the
+    full list of violations is returned (empty if clean).
+    """
+    if is_bypassed() or not text:
+        return []
+    violations: list[RuleViolation] = []
+    for lesson, pattern, name in source.rule_tier_blockers():
+        if pattern.search(text):
+            v = RuleViolation(
+                rule_description=lesson.description,
+                pattern_name=name,
+                output=text,
+            )
+            if strict:
+                raise v
+            _log.warning(
+                "Gradata rule violation (%s): %s", name, lesson.description,
+            )
+            violations.append(v)
+    return violations
diff --git a/tests/test_middleware_core.py b/tests/test_middleware_core.py
new file mode 100644
index 00000000..d7c9a300
--- /dev/null
+++ b/tests/test_middleware_core.py
@@ -0,0 +1,135 @@
+"""Tests for the middleware core (rule source, injection, enforcement)."""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import pytest
+
+from gradata.middleware import (
+    RuleSource,
+    RuleViolation,
+    build_brain_rules_block,
+    check_output,
+    is_bypassed,
+)
+
+
+def test_rule_source_from_static_lessons_selects_rule_and_pattern():
+    src = RuleSource(
+        lessons=[
+            {"state": "RULE", "confidence": 0.95, "category": "TONE",
+             "description": "Never use em dashes"},
+            {"state": "PATTERN", "confidence": 0.70, "category": "STRUCTURE",
+             "description": "Lead with the answer"},
+            {"state": "INSTINCT", "confidence": 0.55, "category": "DRAFTING",
+             "description": "Avoid padding"},
+        ],
+    )
+    selected = src.select()
+    assert len(selected) == 2
+    # RULE comes first (priority bucket), then PATTERN.
+    assert selected[0].state == "RULE"
+    assert selected[1].state == "PATTERN"
+
+
+def test_build_brain_rules_block_wraps_in_xml():
+    src = RuleSource(
+        lessons=[
+            {"state": "RULE", "confidence": 0.95, "category": "TONE",
+             "description": "Never use em dashes"},
+        ],
+    )
+    block = build_brain_rules_block(src)
+    assert block.startswith("<brain_rules>")
+    assert block.endswith("</brain_rules>")
+    assert "[RULE:0.95]" in block
+    assert "TONE" in block
+
+
+def test_build_brain_rules_block_respects_max_rules():
+    lessons = [
+        {"state": "RULE", "confidence": 0.90 + i / 100, "category": f"C{i}",
+         "description": f"desc {i}"}
+        for i in range(20)
+    ]
+    src = RuleSource(lessons=lessons, max_rules=5)
+    block = build_brain_rules_block(src)
+    assert block.count("[RULE:") == 5
+
+
+def test_check_output_finds_em_dash_violation():
+    src = RuleSource(
+        lessons=[
+            {"state": "RULE", "confidence": 0.95, "category": "TONE",
+             "description": "Never use em dashes"},
+        ],
+    )
+    violations = check_output(src, "no good \u2014 here", strict=False)
+    assert len(violations) == 1
+    assert violations[0].pattern_name == "em-dash"
+
+
+def test_check_output_strict_raises():
+    src = RuleSource(
+        lessons=[
+            {"state": "RULE", "confidence": 0.95, "category": "TONE",
+             "description": "Never use em dashes"},
+        ],
+    )
+    with pytest.raises(RuleViolation):
+        check_output(src, "bad \u2014 text", strict=True)
+
+
+def test_check_output_ignores_non_rule_tier():
+    src = RuleSource(
+        lessons=[
+            {"state": "PATTERN", "confidence": 0.80, "category": "TONE",
+             "description": "Never use em dashes"},
+        ],
+    )
+    # PATTERN-tier is injected but not enforced
+    assert check_output(src, "bad \u2014 text", strict=False) == []
+
+
+def test_is_bypassed_env(monkeypatch):
+    monkeypatch.setenv("GRADATA_BYPASS", "1")
+    assert is_bypassed() is True
+    monkeypatch.setenv("GRADATA_BYPASS", "0")
+    assert is_bypassed() is False
+    monkeypatch.delenv("GRADATA_BYPASS", raising=False)
+    assert is_bypassed() is False
+
+
+def test_bypass_disables_block_and_check(monkeypatch):
+    monkeypatch.setenv("GRADATA_BYPASS", "1")
+    src = RuleSource(
+        lessons=[
+            {"state": "RULE", "confidence": 0.95, "category": "TONE",
+             "description": "Never use em dashes"},
+        ],
+    )
+    assert build_brain_rules_block(src) == ""
+    assert check_output(src, "bad \u2014 text", strict=True) == []
+
+
+def test_rule_source_from_brain_path(tmp_path: Path):
+    brain = tmp_path / "brain"
+    brain.mkdir()
+    (brain / "lessons.md").write_text(
+        "[2026-04-13] [RULE:0.95] TONE: Never use em dashes in prose\n"
+        "[2026-04-13] [PATTERN:0.70] STRUCTURE: Lead with the answer\n",
+        encoding="utf-8",
+    )
+    src = RuleSource(brain_path=brain)
+    selected = src.select()
+    assert len(selected) == 2
+    cats = {l.category for l in selected}
+    assert "TONE" in cats
+    assert "STRUCTURE" in cats
+
+
+def test_rule_source_missing_brain_returns_empty(tmp_path: Path):
+    src = RuleSource(brain_path=tmp_path / "does-not-exist")
+    assert src.select() == []
+    assert build_brain_rules_block(src) == ""

From 0a9a961a83ec4d0a427f733ea8ba1893bb87d808 Mon Sep 17 00:00:00 2001
From: Oliver Le
Date: Mon, 13 Apr 2026 13:43:14 -0700
Subject: [PATCH 2/9] =?UTF-8?q?feat(sdk):=20AnthropicMiddleware=20?=
 =?UTF-8?q?=E2=80=94=20wrap=5Fanthropic()=20adapter?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Wraps anthropic.Anthropic() so every client.messages.create() call
gets Gradata's <brain_rules> appended to the system prompt and its
response text post-checked against RULE-tier patterns.

- strict=False (default) logs violations; strict=True raises RuleViolation
- Handles both string and content-block-list system prompts
- Lazy ImportError with 'pip install anthropic' hint when dep missing
- All other client attributes delegate to the underlying client
---
 src/gradata/middleware/anthropic_adapter.py | 132 ++++++++++++
 tests/test_middleware_anthropic.py          | 215 ++++++++++++++++++++
 2 files changed, 347 insertions(+)
 create mode 100644 src/gradata/middleware/anthropic_adapter.py
 create mode 100644 tests/test_middleware_anthropic.py

diff --git a/src/gradata/middleware/anthropic_adapter.py b/src/gradata/middleware/anthropic_adapter.py
new file mode 100644
index 00000000..3ff4f606
--- /dev/null
+++ b/src/gradata/middleware/anthropic_adapter.py
@@ -0,0 +1,132 @@
+"""Anthropic SDK middleware adapter.
+
+Wraps an ``anthropic.Anthropic()`` client so every
+``client.messages.create(...)`` call gets Gradata rules injected into the
+system prompt and its response optionally checked against RULE-tier regex
+patterns.
+
+Usage::
+
+    from anthropic import Anthropic
+    from gradata.middleware import wrap_anthropic
+
+    client = wrap_anthropic(Anthropic(), brain_path="./brain")
+    resp = client.messages.create(
+        model="claude-sonnet-4-5",
+        messages=[{"role": "user", "content": "Hi"}],
+        max_tokens=128,
+    )
+
+The wrapper preserves the original ``messages`` object shape; only the
+``system`` kwarg is mutated on the way in, and the response is only
+inspected on the way out.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+from gradata.middleware._core import (
+    RuleSource,
+    build_brain_rules_block,
+    check_output,
+    inject_into_system,
+)
+
+
+def _require_anthropic() -> None:
+    try:
+        import anthropic  # noqa: F401
+    except ImportError as exc:  # pragma: no cover - import guard
+        raise ImportError(
+            "AnthropicMiddleware requires the 'anthropic' package. 
" + "Install with: pip install anthropic" + ) from exc + + +def _extract_text(response: Any) -> str: + """Best-effort extraction of the assistant text from an Anthropic response.""" + content = getattr(response, "content", None) + if content is None and isinstance(response, dict): + content = response.get("content") + if not content: + return "" + parts: list[str] = [] + for block in content: + # SDK object: block.type == 'text', block.text == '...' + block_type = getattr(block, "type", None) + if block_type is None and isinstance(block, dict): + block_type = block.get("type") + if block_type == "text": + text = getattr(block, "text", None) + if text is None and isinstance(block, dict): + text = block.get("text", "") + if text: + parts.append(str(text)) + return "\n".join(parts) + + +class AnthropicMiddleware: + """Wraps an Anthropic client with Gradata rule injection + enforcement.""" + + def __init__( + self, + client: Any, + *, + brain_path: str | Path | None = None, + source: RuleSource | None = None, + strict: bool = False, + ) -> None: + _require_anthropic() + self._client = client + self._source = source or RuleSource(brain_path=brain_path) + self._strict = strict + # Replace the messages namespace with a wrapper that intercepts create + self._orig_messages = client.messages + self.messages = _MessagesProxy(self) + + # Delegate everything else to the underlying client + def __getattr__(self, name: str) -> Any: + return getattr(self._client, name) + + +class _MessagesProxy: + """Thin proxy over ``client.messages`` that intercepts ``create``.""" + + def __init__(self, mw: AnthropicMiddleware) -> None: + self._mw = mw + + def __getattr__(self, name: str) -> Any: + return getattr(self._mw._orig_messages, name) + + def create(self, *args: Any, **kwargs: Any) -> Any: + block = build_brain_rules_block(self._mw._source) + if block: + system = kwargs.get("system") + # Anthropic accepts either a string or a list of content blocks. + # For lists we append a new text block; for strings we concatenate. + if isinstance(system, list): + kwargs["system"] = [*system, {"type": "text", "text": block}] + else: + kwargs["system"] = inject_into_system(system, block) + + response = self._mw._orig_messages.create(*args, **kwargs) + + text = _extract_text(response) + if text: + check_output(self._mw._source, text, strict=self._mw._strict) + return response + + +def wrap_anthropic( + client: Any, + *, + brain_path: str | Path | None = None, + source: RuleSource | None = None, + strict: bool = False, +) -> AnthropicMiddleware: + """Convenience constructor — see :class:`AnthropicMiddleware`.""" + return AnthropicMiddleware( + client, brain_path=brain_path, source=source, strict=strict, + ) diff --git a/tests/test_middleware_anthropic.py b/tests/test_middleware_anthropic.py new file mode 100644 index 00000000..cd8c37c2 --- /dev/null +++ b/tests/test_middleware_anthropic.py @@ -0,0 +1,215 @@ +"""Tests for gradata.middleware.anthropic_adapter. + +These tests mock the Anthropic SDK client — no real API calls are made. +""" + +from __future__ import annotations + +import sys +import types +from pathlib import Path + +import pytest + +# --------------------------------------------------------------------------- +# Minimal stub for the `anthropic` package so the adapter imports cleanly +# in CI without the real SDK installed. 
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture(autouse=True)
+def _stub_anthropic(monkeypatch):
+    if "anthropic" not in sys.modules:
+        stub = types.ModuleType("anthropic")
+        stub.Anthropic = object  # type: ignore[attr-defined]
+        monkeypatch.setitem(sys.modules, "anthropic", stub)
+    yield
+
+
+# ---------------------------------------------------------------------------
+# Fakes mimicking the parts of the Anthropic SDK the adapter touches.
+# ---------------------------------------------------------------------------
+
+
+class _FakeTextBlock:
+    def __init__(self, text: str) -> None:
+        self.type = "text"
+        self.text = text
+
+
+class _FakeResponse:
+    def __init__(self, text: str) -> None:
+        self.content = [_FakeTextBlock(text)]
+
+
+class _FakeMessages:
+    def __init__(self, reply: str = "hello world") -> None:
+        self.reply = reply
+        self.last_kwargs: dict = {}
+
+    def create(self, **kwargs):
+        self.last_kwargs = kwargs
+        return _FakeResponse(self.reply)
+
+
+class _FakeClient:
+    def __init__(self, reply: str = "hello world") -> None:
+        self.messages = _FakeMessages(reply)
+        self.meta = "keep-me"
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture
+def brain_with_em_dash_rule(tmp_path: Path) -> Path:
+    brain = tmp_path / "brain"
+    brain.mkdir()
+    (brain / "lessons.md").write_text(
+        "[2026-04-13] [RULE:0.95] TONE: Never use em dashes in prose\n",
+        encoding="utf-8",
+    )
+    return brain
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+def test_wrap_anthropic_injects_rules_into_system(brain_with_em_dash_rule: Path):
+    from gradata.middleware import wrap_anthropic
+
+    client = _FakeClient()
+    wrapped = wrap_anthropic(client, brain_path=brain_with_em_dash_rule)
+
+    wrapped.messages.create(
+        model="claude-sonnet-4-5",
+        messages=[{"role": "user", "content": "hi"}],
+        max_tokens=16,
+    )
+
+    system = client.messages.last_kwargs.get("system", "")
+    assert "<brain_rules>" in system
+    assert "TONE" in system
+    assert "em dashes" in system
+
+
+def test_wrap_anthropic_preserves_existing_system_prompt(brain_with_em_dash_rule: Path):
+    from gradata.middleware import wrap_anthropic
+
+    client = _FakeClient()
+    wrapped = wrap_anthropic(client, brain_path=brain_with_em_dash_rule)
+
+    wrapped.messages.create(
+        model="claude-sonnet-4-5",
+        system="You are a helpful assistant.",
+        messages=[{"role": "user", "content": "hi"}],
+        max_tokens=16,
+    )
+
+    system = client.messages.last_kwargs["system"]
+    assert system.startswith("You are a helpful assistant.")
+    assert "<brain_rules>" in system
+
+
+def test_wrap_anthropic_strict_raises_on_violation(brain_with_em_dash_rule: Path):
+    from gradata.middleware import RuleViolation, wrap_anthropic
+
+    client = _FakeClient(reply="this response has an em dash \u2014 right here")
+    wrapped = wrap_anthropic(client, brain_path=brain_with_em_dash_rule, strict=True)
+
+    with pytest.raises(RuleViolation):
+        wrapped.messages.create(
+            model="claude-sonnet-4-5",
+            messages=[{"role": "user", "content": "hi"}],
+            max_tokens=16,
+        )
+
+
+def test_wrap_anthropic_non_strict_logs_but_does_not_raise(
+    brain_with_em_dash_rule: Path, caplog
+):
+    from gradata.middleware import wrap_anthropic
+
+    client = _FakeClient(reply="em dash here \u2014 nope")
+    wrapped = wrap_anthropic(client,
brain_path=brain_with_em_dash_rule, strict=False) + + with caplog.at_level("WARNING", logger="gradata.middleware._core"): + resp = wrapped.messages.create( + model="claude-sonnet-4-5", + messages=[{"role": "user", "content": "hi"}], + max_tokens=16, + ) + assert resp is not None # did not raise + assert any("rule violation" in rec.message.lower() for rec in caplog.records) + + +def test_wrap_anthropic_bypass_env_disables_injection( + brain_with_em_dash_rule: Path, monkeypatch +): + from gradata.middleware import wrap_anthropic + + monkeypatch.setenv("GRADATA_BYPASS", "1") + client = _FakeClient() + wrapped = wrap_anthropic(client, brain_path=brain_with_em_dash_rule) + + wrapped.messages.create( + model="claude-sonnet-4-5", + messages=[{"role": "user", "content": "hi"}], + max_tokens=16, + ) + assert "system" not in client.messages.last_kwargs + + +def test_wrap_anthropic_no_brain_is_noop(tmp_path: Path): + from gradata.middleware import wrap_anthropic + + client = _FakeClient() + wrapped = wrap_anthropic(client, brain_path=tmp_path / "missing") + + wrapped.messages.create( + model="claude-sonnet-4-5", + messages=[{"role": "user", "content": "hi"}], + max_tokens=16, + ) + # No brain -> no system injected + assert "system" not in client.messages.last_kwargs + + +def test_wrap_anthropic_delegates_other_attrs(brain_with_em_dash_rule: Path): + from gradata.middleware import wrap_anthropic + + client = _FakeClient() + wrapped = wrap_anthropic(client, brain_path=brain_with_em_dash_rule) + assert wrapped.meta == "keep-me" + + +def test_anthropic_import_error_has_install_hint(monkeypatch): + import importlib + + # Remove both anthropic stub and cached adapter so the import check runs. + monkeypatch.delitem(sys.modules, "anthropic", raising=False) + monkeypatch.delitem( + sys.modules, "gradata.middleware.anthropic_adapter", raising=False, + ) + + # Force ImportError on `import anthropic` by installing a finder that + # rejects it. + import builtins + + real_import = builtins.__import__ + + def _no_anthropic(name, *args, **kwargs): + if name == "anthropic": + raise ImportError("no anthropic") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", _no_anthropic) + + mod = importlib.import_module("gradata.middleware.anthropic_adapter") + with pytest.raises(ImportError) as exc: + mod.AnthropicMiddleware(object()) + assert "pip install anthropic" in str(exc.value) From 13a609da7b982c0dde4d9fe2cd2e8381a3dbacf2 Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Mon, 13 Apr 2026 13:43:15 -0700 Subject: [PATCH 3/9] =?UTF-8?q?feat(sdk):=20OpenAIMiddleware=20=E2=80=94?= =?UTF-8?q?=20wrap=5Fopenai()=20adapter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wraps openai.OpenAI() so every client.chat.completions.create() call gets a system message prepended (or merged into an existing system message) and its response text post-checked. 
- Same strict / bypass semantics as AnthropicMiddleware - Lazy ImportError with 'pip install openai' hint when dep missing --- src/gradata/middleware/openai_adapter.py | 146 ++++++++++++++++++++ tests/test_middleware_openai.py | 168 +++++++++++++++++++++++ 2 files changed, 314 insertions(+) create mode 100644 src/gradata/middleware/openai_adapter.py create mode 100644 tests/test_middleware_openai.py diff --git a/src/gradata/middleware/openai_adapter.py b/src/gradata/middleware/openai_adapter.py new file mode 100644 index 00000000..7dbbc64a --- /dev/null +++ b/src/gradata/middleware/openai_adapter.py @@ -0,0 +1,146 @@ +"""OpenAI SDK middleware adapter. + +Wraps an ``openai.OpenAI()`` client so every +``client.chat.completions.create(...)`` call gets Gradata rules injected +into / prepended to the ``messages`` list as a system message, and its +response optionally checked against RULE-tier regex patterns. + +Usage:: + + from openai import OpenAI + from gradata.middleware import wrap_openai + + client = wrap_openai(OpenAI(), brain_path="./brain") + resp = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Hi"}], + ) +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from gradata.middleware._core import ( + RuleSource, + build_brain_rules_block, + check_output, + inject_into_system, +) + + +def _require_openai() -> None: + try: + import openai # noqa: F401 + except ImportError as exc: # pragma: no cover - import guard + raise ImportError( + "OpenAIMiddleware requires the 'openai' package. " + "Install with: pip install openai" + ) from exc + + +def _extract_text(response: Any) -> str: + """Best-effort text extraction from an OpenAI chat.completions response.""" + choices = getattr(response, "choices", None) + if choices is None and isinstance(response, dict): + choices = response.get("choices") + if not choices: + return "" + parts: list[str] = [] + for choice in choices: + message = getattr(choice, "message", None) + if message is None and isinstance(choice, dict): + message = choice.get("message") + if message is None: + continue + content = getattr(message, "content", None) + if content is None and isinstance(message, dict): + content = message.get("content") + if content: + parts.append(str(content)) + return "\n".join(parts) + + +def _inject_into_messages(messages: list[Any], block: str) -> list[Any]: + """Return a new messages list with rules folded into the system message. + + If a leading system message exists, its ``content`` is extended with the + block; otherwise a new system message is prepended. 
+ """ + if not block: + return list(messages) + out = [dict(m) if isinstance(m, dict) else m for m in messages] + if out and isinstance(out[0], dict) and out[0].get("role") == "system": + existing = out[0].get("content") or "" + out[0]["content"] = inject_into_system( + existing if isinstance(existing, str) else str(existing), + block, + ) + else: + out.insert(0, {"role": "system", "content": block}) + return out + + +class OpenAIMiddleware: + """Wraps an OpenAI client with Gradata rule injection + enforcement.""" + + def __init__( + self, + client: Any, + *, + brain_path: str | Path | None = None, + source: RuleSource | None = None, + strict: bool = False, + ) -> None: + _require_openai() + self._client = client + self._source = source or RuleSource(brain_path=brain_path) + self._strict = strict + self._orig_chat = client.chat + self.chat = _ChatProxy(self) + + def __getattr__(self, name: str) -> Any: + return getattr(self._client, name) + + +class _ChatProxy: + def __init__(self, mw: OpenAIMiddleware) -> None: + self._mw = mw + self.completions = _CompletionsProxy(mw) + + def __getattr__(self, name: str) -> Any: + return getattr(self._mw._orig_chat, name) + + +class _CompletionsProxy: + def __init__(self, mw: OpenAIMiddleware) -> None: + self._mw = mw + + def __getattr__(self, name: str) -> Any: + return getattr(self._mw._orig_chat.completions, name) + + def create(self, *args: Any, **kwargs: Any) -> Any: + block = build_brain_rules_block(self._mw._source) + if block: + messages = kwargs.get("messages") or [] + kwargs["messages"] = _inject_into_messages(list(messages), block) + + response = self._mw._orig_chat.completions.create(*args, **kwargs) + text = _extract_text(response) + if text: + check_output(self._mw._source, text, strict=self._mw._strict) + return response + + +def wrap_openai( + client: Any, + *, + brain_path: str | Path | None = None, + source: RuleSource | None = None, + strict: bool = False, +) -> OpenAIMiddleware: + """Convenience constructor — see :class:`OpenAIMiddleware`.""" + return OpenAIMiddleware( + client, brain_path=brain_path, source=source, strict=strict, + ) diff --git a/tests/test_middleware_openai.py b/tests/test_middleware_openai.py new file mode 100644 index 00000000..6eb75d5c --- /dev/null +++ b/tests/test_middleware_openai.py @@ -0,0 +1,168 @@ +"""Tests for gradata.middleware.openai_adapter (mocked; no real API calls).""" + +from __future__ import annotations + +import sys +import types +from pathlib import Path + +import pytest + + +@pytest.fixture(autouse=True) +def _stub_openai(monkeypatch): + if "openai" not in sys.modules: + stub = types.ModuleType("openai") + stub.OpenAI = object # type: ignore[attr-defined] + monkeypatch.setitem(sys.modules, "openai", stub) + yield + + +class _FakeMessage: + def __init__(self, content: str) -> None: + self.content = content + + +class _FakeChoice: + def __init__(self, content: str) -> None: + self.message = _FakeMessage(content) + + +class _FakeResponse: + def __init__(self, content: str) -> None: + self.choices = [_FakeChoice(content)] + + +class _FakeCompletions: + def __init__(self, reply: str = "ok") -> None: + self.reply = reply + self.last_kwargs: dict = {} + + def create(self, **kwargs): + self.last_kwargs = kwargs + return _FakeResponse(self.reply) + + +class _FakeChat: + def __init__(self, reply: str = "ok") -> None: + self.completions = _FakeCompletions(reply) + + +class _FakeClient: + def __init__(self, reply: str = "ok") -> None: + self.chat = _FakeChat(reply) + self.meta = "delegate" + + 
+@pytest.fixture
+def brain_with_em_dash_rule(tmp_path: Path) -> Path:
+    brain = tmp_path / "brain"
+    brain.mkdir()
+    (brain / "lessons.md").write_text(
+        "[2026-04-13] [RULE:0.95] TONE: Never use em dashes in prose\n",
+        encoding="utf-8",
+    )
+    return brain
+
+
+def test_wrap_openai_prepends_system_message(brain_with_em_dash_rule: Path):
+    from gradata.middleware import wrap_openai
+
+    client = _FakeClient()
+    wrapped = wrap_openai(client, brain_path=brain_with_em_dash_rule)
+
+    wrapped.chat.completions.create(
+        model="gpt-4o-mini",
+        messages=[{"role": "user", "content": "hi"}],
+    )
+    sent = client.chat.completions.last_kwargs["messages"]
+    assert sent[0]["role"] == "system"
+    assert "<brain_rules>" in sent[0]["content"]
+    assert sent[1]["role"] == "user"
+
+
+def test_wrap_openai_extends_existing_system(brain_with_em_dash_rule: Path):
+    from gradata.middleware import wrap_openai
+
+    client = _FakeClient()
+    wrapped = wrap_openai(client, brain_path=brain_with_em_dash_rule)
+    wrapped.chat.completions.create(
+        model="gpt-4o-mini",
+        messages=[
+            {"role": "system", "content": "Be terse."},
+            {"role": "user", "content": "hi"},
+        ],
+    )
+    sent = client.chat.completions.last_kwargs["messages"]
+    assert sent[0]["role"] == "system"
+    assert sent[0]["content"].startswith("Be terse.")
+    assert "<brain_rules>" in sent[0]["content"]
+
+
+def test_wrap_openai_strict_raises_on_violation(brain_with_em_dash_rule: Path):
+    from gradata.middleware import RuleViolation, wrap_openai
+
+    client = _FakeClient(reply="bad \u2014 output")
+    wrapped = wrap_openai(client, brain_path=brain_with_em_dash_rule, strict=True)
+    with pytest.raises(RuleViolation):
+        wrapped.chat.completions.create(
+            model="gpt-4o-mini",
+            messages=[{"role": "user", "content": "hi"}],
+        )
+
+
+def test_wrap_openai_non_strict_does_not_raise(brain_with_em_dash_rule: Path):
+    from gradata.middleware import wrap_openai
+
+    client = _FakeClient(reply="bad \u2014 output")
+    wrapped = wrap_openai(client, brain_path=brain_with_em_dash_rule, strict=False)
+    # Must not raise
+    resp = wrapped.chat.completions.create(
+        model="gpt-4o-mini",
+        messages=[{"role": "user", "content": "hi"}],
+    )
+    assert resp is not None
+
+
+def test_wrap_openai_bypass_env(brain_with_em_dash_rule: Path, monkeypatch):
+    from gradata.middleware import wrap_openai
+
+    monkeypatch.setenv("GRADATA_BYPASS", "1")
+    client = _FakeClient()
+    wrapped = wrap_openai(client, brain_path=brain_with_em_dash_rule)
+    wrapped.chat.completions.create(
+        model="gpt-4o-mini",
+        messages=[{"role": "user", "content": "hi"}],
+    )
+    sent = client.chat.completions.last_kwargs["messages"]
+    # Unchanged — no system message prepended
+    assert sent[0]["role"] == "user"
+
+
+def test_wrap_openai_delegates_other_attrs(brain_with_em_dash_rule: Path):
+    from gradata.middleware import wrap_openai
+
+    client = _FakeClient()
+    wrapped = wrap_openai(client, brain_path=brain_with_em_dash_rule)
+    assert wrapped.meta == "delegate"
+
+
+def test_openai_import_error_has_install_hint(monkeypatch):
+    import builtins
+    import importlib
+
+    monkeypatch.delitem(sys.modules, "openai", raising=False)
+    monkeypatch.delitem(sys.modules, "gradata.middleware.openai_adapter", raising=False)
+
+    real_import = builtins.__import__
+
+    def _no_openai(name, *args, **kwargs):
+        if name == "openai":
+            raise ImportError("no openai")
+        return real_import(name, *args, **kwargs)
+
+    monkeypatch.setattr(builtins, "__import__", _no_openai)
+    mod = importlib.import_module("gradata.middleware.openai_adapter")
+    with pytest.raises(ImportError) as exc:
+        mod.OpenAIMiddleware(object())
+    assert "pip install openai" in str(exc.value)

From 7353601461ee5dfdedc256bcebe53dd3d4f07bf7 Mon Sep 17 00:00:00 2001
From: Oliver Le
Date: Mon, 13 Apr 2026 13:43:37 -0700
Subject: [PATCH 4/9] =?UTF-8?q?feat(sdk):=20LangChainCallback=20=E2=80=94?=
 =?UTF-8?q?=20BaseCallbackHandler=20adapter?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Implements LangChain's BaseCallbackHandler:
- on_llm_start: prepends <brain_rules> to the first prompt
- on_chat_model_start: extends/inserts a SystemMessage with the block
- on_llm_end: post-checks the LLMResult text against RULE-tier patterns

Gracefully handles missing langchain-core — raises ImportError with
'pip install langchain-core' hint at instantiation time via __new__.
---
 src/gradata/middleware/langchain_adapter.py | 126 +++++++++++++++
 tests/test_middleware_langchain.py          | 163 ++++++++++++++++++++
 2 files changed, 289 insertions(+)
 create mode 100644 src/gradata/middleware/langchain_adapter.py
 create mode 100644 tests/test_middleware_langchain.py

diff --git a/src/gradata/middleware/langchain_adapter.py b/src/gradata/middleware/langchain_adapter.py
new file mode 100644
index 00000000..d446a722
--- /dev/null
+++ b/src/gradata/middleware/langchain_adapter.py
@@ -0,0 +1,126 @@
+"""LangChain middleware adapter.
+
+Provides :class:`LangChainCallback`, a ``BaseCallbackHandler`` that:
+
+- Injects the Gradata ``<brain_rules>`` block into prompts at
+  ``on_llm_start`` / ``on_chat_model_start``.
+- Checks the LLM output against RULE-tier regex patterns at ``on_llm_end``.
+
+Usage::
+
+    from langchain_openai import ChatOpenAI
+    from gradata.middleware import LangChainCallback
+
+    llm = ChatOpenAI(callbacks=[LangChainCallback(brain_path="./brain")])
+    llm.invoke("Write a short greeting")
+
+Because LangChain callbacks mutate internal prompt buffers in-place, the
+injection is done best-effort on the first prompt only. For stricter
+control, prefer the :class:`gradata.middleware.OpenAIMiddleware` wrapper
+over the underlying client.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+from gradata.middleware._core import (
+    RuleSource,
+    build_brain_rules_block,
+    check_output,
+)
+
+try:
+    from langchain_core.callbacks import BaseCallbackHandler as _BaseCallbackHandler
+    _LANGCHAIN_AVAILABLE = True
+except ImportError:
+    _BaseCallbackHandler = object  # type: ignore[assignment,misc]
+    _LANGCHAIN_AVAILABLE = False
+
+
+class LangChainCallback(_BaseCallbackHandler):  # type: ignore[misc,valid-type]
+    """LangChain callback that injects Gradata rules and enforces them."""
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> LangChainCallback:
+        if not _LANGCHAIN_AVAILABLE:
+            raise ImportError(
+                "LangChainCallback requires 'langchain-core'. "
+                "Install with: pip install langchain-core"
+            )
+        return super().__new__(cls)
+
+    def __init__(
+        self,
+        *,
+        brain_path: str | Path | None = None,
+        source: RuleSource | None = None,
+        strict: bool = False,
+    ) -> None:
+        super().__init__()
+        self._source = source or RuleSource(brain_path=brain_path)
+        self._strict = strict
+
+    # -- injection --------------------------------------------------------
+
+    def on_llm_start(
+        self,
+        serialized: dict[str, Any],
+        prompts: list[str],
+        **kwargs: Any,
+    ) -> None:
+        block = build_brain_rules_block(self._source)
+        if not block or not prompts:
+            return
+        # Prepend block to the first prompt. LangChain uses the list in-place.
+ prompts[0] = f"{block}\n\n{prompts[0]}" + + def on_chat_model_start( + self, + serialized: dict[str, Any], + messages: list[list[Any]], + **kwargs: Any, + ) -> None: + block = build_brain_rules_block(self._source) + if not block or not messages or not messages[0]: + return + first_batch = messages[0] + # If the first message is a system-style message, extend its content. + first = first_batch[0] + content = getattr(first, "content", None) + msg_type = getattr(first, "type", "") + if content is not None and msg_type == "system": + first.content = f"{content}\n\n{block}" + return + # Otherwise, prepend a SystemMessage if langchain_core is available. + try: + from langchain_core.messages import SystemMessage + except ImportError: # pragma: no cover + return + first_batch.insert(0, SystemMessage(content=block)) + + # -- enforcement ------------------------------------------------------ + + def on_llm_end(self, response: Any, **kwargs: Any) -> None: + text = _extract_llm_text(response) + if not text: + return + check_output(self._source, text, strict=self._strict) + + +def _extract_llm_text(response: Any) -> str: + """Best-effort text extraction from a LangChain ``LLMResult``.""" + generations = getattr(response, "generations", None) + if generations is None and isinstance(response, dict): + generations = response.get("generations") + if not generations: + return "" + parts: list[str] = [] + for batch in generations: + for gen in batch: + text = getattr(gen, "text", None) + if text is None and isinstance(gen, dict): + text = gen.get("text", "") + if text: + parts.append(str(text)) + return "\n".join(parts) diff --git a/tests/test_middleware_langchain.py b/tests/test_middleware_langchain.py new file mode 100644 index 00000000..00c5f09a --- /dev/null +++ b/tests/test_middleware_langchain.py @@ -0,0 +1,163 @@ +"""Tests for gradata.middleware.langchain_adapter (no real LLM calls).""" + +from __future__ import annotations + +import importlib +import sys +import types +from pathlib import Path + +import pytest + +# --------------------------------------------------------------------------- +# Stub out langchain_core so tests run in CI without the real package. +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _stub_langchain(monkeypatch): + if "langchain_core" not in sys.modules: + lc = types.ModuleType("langchain_core") + callbacks_mod = types.ModuleType("langchain_core.callbacks") + messages_mod = types.ModuleType("langchain_core.messages") + + class _BaseCallbackHandler: + def __init__(self) -> None: + pass + + class _SystemMessage: + def __init__(self, content: str) -> None: + self.content = content + self.type = "system" + + callbacks_mod.BaseCallbackHandler = _BaseCallbackHandler + messages_mod.SystemMessage = _SystemMessage + lc.callbacks = callbacks_mod + lc.messages = messages_mod + + monkeypatch.setitem(sys.modules, "langchain_core", lc) + monkeypatch.setitem(sys.modules, "langchain_core.callbacks", callbacks_mod) + monkeypatch.setitem(sys.modules, "langchain_core.messages", messages_mod) + + # Force a fresh import of the adapter so it picks up the stub. 
+    monkeypatch.delitem(
+        sys.modules, "gradata.middleware.langchain_adapter", raising=False,
+    )
+    yield
+
+
+@pytest.fixture
+def brain_with_em_dash_rule(tmp_path: Path) -> Path:
+    brain = tmp_path / "brain"
+    brain.mkdir()
+    (brain / "lessons.md").write_text(
+        "[2026-04-13] [RULE:0.95] TONE: Never use em dashes in prose\n",
+        encoding="utf-8",
+    )
+    return brain
+
+
+# Simple fakes for LangChain message + generation types
+class _FakeMessage:
+    def __init__(self, content: str, type_: str = "human") -> None:
+        self.content = content
+        self.type = type_
+
+
+class _FakeGeneration:
+    def __init__(self, text: str) -> None:
+        self.text = text
+
+
+class _FakeLLMResult:
+    def __init__(self, text: str) -> None:
+        self.generations = [[_FakeGeneration(text)]]
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+def test_on_llm_start_prepends_block_to_first_prompt(brain_with_em_dash_rule: Path):
+    from gradata.middleware.langchain_adapter import LangChainCallback
+
+    cb = LangChainCallback(brain_path=brain_with_em_dash_rule)
+    prompts = ["User: hi"]
+    cb.on_llm_start({}, prompts)
+    assert prompts[0].startswith("<brain_rules>")
+    assert "User: hi" in prompts[0]
+
+
+def test_on_chat_model_start_inserts_system_message(brain_with_em_dash_rule: Path):
+    from gradata.middleware.langchain_adapter import LangChainCallback
+
+    cb = LangChainCallback(brain_path=brain_with_em_dash_rule)
+    batches = [[_FakeMessage("hi", "human")]]
+    cb.on_chat_model_start({}, batches)
+    first = batches[0][0]
+    assert first.type == "system"
+    assert "<brain_rules>" in first.content
+
+
+def test_on_chat_model_start_extends_existing_system(brain_with_em_dash_rule: Path):
+    from gradata.middleware.langchain_adapter import LangChainCallback
+
+    cb = LangChainCallback(brain_path=brain_with_em_dash_rule)
+    batches = [[_FakeMessage("You are kind.", "system"), _FakeMessage("hi", "human")]]
+    cb.on_chat_model_start({}, batches)
+    assert batches[0][0].type == "system"
+    assert batches[0][0].content.startswith("You are kind.")
+    assert "<brain_rules>" in batches[0][0].content
+
+
+def test_on_llm_end_strict_raises_on_violation(brain_with_em_dash_rule: Path):
+    from gradata.middleware import RuleViolation
+    from gradata.middleware.langchain_adapter import LangChainCallback
+
+    cb = LangChainCallback(brain_path=brain_with_em_dash_rule, strict=True)
+    result = _FakeLLMResult("bad \u2014 output")
+    with pytest.raises(RuleViolation):
+        cb.on_llm_end(result)
+
+
+def test_on_llm_end_non_strict_does_not_raise(brain_with_em_dash_rule: Path):
+    from gradata.middleware.langchain_adapter import LangChainCallback
+
+    cb = LangChainCallback(brain_path=brain_with_em_dash_rule, strict=False)
+    cb.on_llm_end(_FakeLLMResult("bad \u2014 output"))  # must not raise
+
+
+def test_bypass_env_skips_injection(brain_with_em_dash_rule: Path, monkeypatch):
+    from gradata.middleware.langchain_adapter import LangChainCallback
+
+    monkeypatch.setenv("GRADATA_BYPASS", "1")
+    cb = LangChainCallback(brain_path=brain_with_em_dash_rule)
+    prompts = ["User: hi"]
+    cb.on_llm_start({}, prompts)
+    assert prompts[0] == "User: hi"  # unchanged
+
+
+def test_langchain_import_error_has_install_hint(monkeypatch):
+    # Drop the stub + cached adapter; force-fail the import.
+    monkeypatch.delitem(sys.modules, "langchain_core", raising=False)
+    monkeypatch.delitem(sys.modules, "langchain_core.callbacks", raising=False)
+    monkeypatch.delitem(sys.modules, "langchain_core.messages", raising=False)
+    monkeypatch.delitem(
+        sys.modules, "gradata.middleware.langchain_adapter", raising=False,
+    )
+
+    import builtins
+
+    real_import = builtins.__import__
+
+    def _no_langchain(name, *args, **kwargs):
+        if name.startswith("langchain_core"):
+            raise ImportError("no langchain_core")
+        return real_import(name, *args, **kwargs)
+
+    monkeypatch.setattr(builtins, "__import__", _no_langchain)
+    mod = importlib.import_module("gradata.middleware.langchain_adapter")
+    with pytest.raises(ImportError) as exc:
+        mod.LangChainCallback()
+    assert "langchain-core" in str(exc.value)

From a2053760b4d9ae9e982b9b27d48c19323241420f Mon Sep 17 00:00:00 2001
From: Oliver Le
Date: Mon, 13 Apr 2026 13:43:38 -0700
Subject: [PATCH 5/9] =?UTF-8?q?feat(sdk):=20CrewAIGuard=20=E2=80=94=20guar?=
 =?UTF-8?q?drails=20callable=20adapter?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Callable that CrewAI agents register in their guardrails=[...] list.
Returns (True, output) for clean text; when strict=True (default),
returns (False, 'Gradata rule violation(s): ...') so CrewAI can retry.

Text-coercion handles CrewAI output objects (raw/output/text/content
attrs) as well as plain strings and dicts. No hard crewai dependency.
---
 src/gradata/middleware/crewai_adapter.py | 90 ++++++++++++++++++++++++
 tests/test_middleware_crewai.py          | 76 ++++++++++++++++++++
 2 files changed, 166 insertions(+)
 create mode 100644 src/gradata/middleware/crewai_adapter.py
 create mode 100644 tests/test_middleware_crewai.py

diff --git a/src/gradata/middleware/crewai_adapter.py b/src/gradata/middleware/crewai_adapter.py
new file mode 100644
index 00000000..a84327e0
--- /dev/null
+++ b/src/gradata/middleware/crewai_adapter.py
@@ -0,0 +1,90 @@
+"""CrewAI middleware adapter.
+
+Provides :class:`CrewAIGuard`, a callable that CrewAI agents can register
+in their ``guardrails=[...]`` list. The guard runs on the agent's output
+and returns the CrewAI-expected ``(valid, result_or_error)`` tuple.
+
+Usage::
+
+    from crewai import Agent
+    from gradata.middleware import CrewAIGuard
+
+    guard = CrewAIGuard(brain_path="./brain")
+    agent = Agent(
+        role="Writer",
+        goal="Draft clean prose",
+        backstory="...",
+        guardrails=[guard],
+    )
+
+This adapter has no hard dependency on ``crewai`` — it only implements
+the guard callable shape CrewAI expects, so tests can exercise it with a
+plain Python call.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+from gradata.middleware._core import (
+    RuleSource,
+    RuleViolation,
+    check_output,
+)
+
+
+class CrewAIGuard:
+    """A CrewAI-compatible guardrail that enforces Gradata RULE-tier rules.
+
+    CrewAI guardrails are callables that take the agent output and return
+    ``(is_valid, result_or_error_message)``. When ``strict`` is False
+    (default) the guard always returns ``(True, output)`` but logs the
+    violations for observability. When ``strict`` is True, a violation
+    returns ``(False, "<error message>")`` so CrewAI can retry.
+ """ + + def __init__( + self, + *, + brain_path: str | Path | None = None, + source: RuleSource | None = None, + strict: bool = True, + ) -> None: + self._source = source or RuleSource(brain_path=brain_path) + self._strict = strict + + def __call__(self, output: Any) -> tuple[bool, Any]: + text = _coerce_text(output) + if not text: + return True, output + try: + violations = check_output(self._source, text, strict=False) + except RuleViolation as v: # pragma: no cover - strict=False above + return False, str(v) + if not violations: + return True, output + if self._strict: + message = "; ".join( + f"{v.pattern_name}: {v.rule_description}" for v in violations + ) + return False, f"Gradata rule violation(s): {message}" + return True, output + + +def _coerce_text(output: Any) -> str: + """Best-effort text extraction for CrewAI agent outputs.""" + if output is None: + return "" + if isinstance(output, str): + return output + for attr in ("raw", "output", "text", "content"): + val = getattr(output, attr, None) + if isinstance(val, str) and val: + return val + if isinstance(output, dict): + for key in ("raw", "output", "text", "content"): + val = output.get(key) + if isinstance(val, str) and val: + return val + return str(output) diff --git a/tests/test_middleware_crewai.py b/tests/test_middleware_crewai.py new file mode 100644 index 00000000..7192dc9a --- /dev/null +++ b/tests/test_middleware_crewai.py @@ -0,0 +1,76 @@ +"""Tests for gradata.middleware.crewai_adapter (no real CrewAI calls).""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + + +@pytest.fixture +def brain_with_em_dash_rule(tmp_path: Path) -> Path: + brain = tmp_path / "brain" + brain.mkdir() + (brain / "lessons.md").write_text( + "[2026-04-13] [RULE:0.95] TONE: Never use em dashes in prose\n", + encoding="utf-8", + ) + return brain + + +def test_crewai_guard_passes_clean_output(brain_with_em_dash_rule: Path): + from gradata.middleware import CrewAIGuard + + guard = CrewAIGuard(brain_path=brain_with_em_dash_rule) + ok, result = guard("A perfectly clean string.") + assert ok is True + assert result == "A perfectly clean string." 
+
+
+def test_crewai_guard_blocks_violation_when_strict(brain_with_em_dash_rule: Path):
+    from gradata.middleware import CrewAIGuard
+
+    guard = CrewAIGuard(brain_path=brain_with_em_dash_rule, strict=True)
+    ok, result = guard("has em dash \u2014 here")
+    assert ok is False
+    assert "em-dash" in result or "em dash" in result.lower()
+
+
+def test_crewai_guard_non_strict_allows_violation(brain_with_em_dash_rule: Path):
+    from gradata.middleware import CrewAIGuard
+
+    guard = CrewAIGuard(brain_path=brain_with_em_dash_rule, strict=False)
+    text = "has em dash \u2014 here"
+    ok, result = guard(text)
+    assert ok is True
+    assert result == text
+
+
+def test_crewai_guard_extracts_text_from_object(brain_with_em_dash_rule: Path):
+    from gradata.middleware import CrewAIGuard
+
+    class FakeOutput:
+        def __init__(self, raw: str) -> None:
+            self.raw = raw
+
+    guard = CrewAIGuard(brain_path=brain_with_em_dash_rule, strict=True)
+    ok, _ = guard(FakeOutput("bad \u2014 output"))
+    assert ok is False
+
+
+def test_crewai_guard_bypass_env(brain_with_em_dash_rule: Path, monkeypatch):
+    from gradata.middleware import CrewAIGuard
+
+    monkeypatch.setenv("GRADATA_BYPASS", "1")
+    guard = CrewAIGuard(brain_path=brain_with_em_dash_rule, strict=True)
+    ok, _ = guard("bad \u2014 output")
+    assert ok is True  # bypass disables enforcement
+
+
+def test_crewai_guard_empty_output_passes(brain_with_em_dash_rule: Path):
+    from gradata.middleware import CrewAIGuard
+
+    guard = CrewAIGuard(brain_path=brain_with_em_dash_rule, strict=True)
+    ok, result = guard("")
+    assert ok is True
+    assert result == ""

From d362e749950213b5749dcb1f1657d4dc4a5f7dce Mon Sep 17 00:00:00 2001
From: Oliver Le
Date: Mon, 13 Apr 2026 13:43:38 -0700
Subject: [PATCH 6/9] docs(sdk): add docs/middleware.md with one example per
 adapter

Covers common behavior (rule source, strict mode, GRADATA_BYPASS kill
switch, optional-deps ImportError contract) and per-adapter usage for
Anthropic, OpenAI, LangChain, CrewAI, plus the advanced custom
RuleSource pattern.
---
 docs/middleware.md | 108 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 108 insertions(+)
 create mode 100644 docs/middleware.md

diff --git a/docs/middleware.md b/docs/middleware.md
new file mode 100644
index 00000000..949f6a89
--- /dev/null
+++ b/docs/middleware.md
@@ -0,0 +1,108 @@
+# Runtime Middleware Adapters
+
+Gradata's hooks only fire inside Claude Code. For direct-SDK agents
+(raw OpenAI SDK, raw Anthropic SDK, LangChain, CrewAI), the
+`gradata.middleware` subpackage provides runtime wrappers that inject
+learned rules into system prompts and optionally enforce RULE-tier regex
+patterns on outputs.
+
+## Common behavior
+
+All adapters share one rule source: the same `lessons.md` + brain
+database Claude Code hooks use. Selection, confidence floor, and the
+`<brain_rules>` XML format match `gradata.hooks.inject_brain_rules`.
+
+- **Cap**: 10 rules per call (configurable via `RuleSource(max_rules=N)`).
+- **Priority**: RULE > PATTERN, ties broken by confidence descending.
+- **Strict mode**: `strict=False` (default) logs violations; `strict=True`
+  raises `gradata.middleware.RuleViolation` so callers can retry, as shown
+  in the sketch after this list.
+- **Kill switch**: set `GRADATA_BYPASS=1` to disable all injection and
+  enforcement.
+- **Optional deps**: importing `AnthropicMiddleware`, `OpenAIMiddleware`,
+  `LangChainCallback`, or `CrewAIGuard` without their respective third-party
+  package raises a clear `ImportError` with an install hint.
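+
+A minimal strict-mode sketch (it assumes the wrapped Anthropic client from
+the next section; `RuleViolation` exposes the `pattern_name`,
+`rule_description`, and `output` attributes documented in the core module):
+
+```python
+from gradata.middleware import RuleViolation
+
+try:
+    resp = client.messages.create(
+        model="claude-sonnet-4-5",
+        messages=[{"role": "user", "content": "Write a short greeting"}],
+        max_tokens=128,
+    )
+except RuleViolation as exc:
+    # exc.pattern_name names the check that fired (e.g. "em-dash");
+    # exc.output holds the offending text, useful for a retry prompt.
+    print(f"Blocked by {exc.pattern_name}: {exc.rule_description}")
+```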
+
+## Anthropic
+
+```python
+from anthropic import Anthropic
+from gradata.middleware import wrap_anthropic
+
+client = wrap_anthropic(Anthropic(), brain_path="./brain")
+# ... all client.messages.create(...) calls now get rules injected
+resp = client.messages.create(
+    model="claude-sonnet-4-5",
+    messages=[{"role": "user", "content": "Write a short greeting"}],
+    max_tokens=128,
+)
+```
+
+The wrapper mutates only the `system` kwarg (string or content-block
+list) and post-checks the response's text blocks.
+
+## OpenAI
+
+```python
+from openai import OpenAI
+from gradata.middleware import wrap_openai
+
+client = wrap_openai(OpenAI(), brain_path="./brain")
+resp = client.chat.completions.create(
+    model="gpt-4o-mini",
+    messages=[{"role": "user", "content": "Write a short greeting"}],
+)
+```
+
+Rules land in the leading system message — extending it if present,
+prepending a new one otherwise.
+
+## LangChain
+
+```python
+from langchain_openai import ChatOpenAI
+from gradata.middleware import LangChainCallback
+
+llm = ChatOpenAI(callbacks=[LangChainCallback(brain_path="./brain")])
+llm.invoke("Write a short greeting")
+```
+
+Implements `BaseCallbackHandler` with hooks on
+`on_llm_start` / `on_chat_model_start` for injection and `on_llm_end` for
+enforcement.
+
+## CrewAI
+
+```python
+from crewai import Agent
+from gradata.middleware import CrewAIGuard
+
+guard = CrewAIGuard(brain_path="./brain", strict=True)
+agent = Agent(
+    role="Writer",
+    goal="Draft clean prose",
+    backstory="...",
+    guardrails=[guard],
+)
+```
+
+The guard returns `(True, output)` when clean and
+`(False, "Gradata rule violation(s): ...")` when strict and a RULE-tier
+pattern matches — CrewAI then retries.
+
+## Advanced: custom rule source
+
+If your lessons live somewhere other than `<brain>/lessons.md`,
+construct a `RuleSource` directly:
+
+```python
+from gradata.middleware import RuleSource, wrap_anthropic
+from anthropic import Anthropic
+
+source = RuleSource(
+    lessons=[
+        {"state": "RULE", "confidence": 0.95, "category": "TONE",
+         "description": "Never use em dashes in prose"},
+    ],
+)
+client = wrap_anthropic(Anthropic(), source=source, strict=True)
+```

From e57267dd368028d21ecbf963a820f5083b660895 Mon Sep 17 00:00:00 2001
From: Oliver Le
Date: Mon, 13 Apr 2026 14:56:58 -0700
Subject: [PATCH 7/9] chore(sdk): address CodeRabbit feedback on PR #32

- _core.py: clamp lesson confidence to [0.0, 1.0] in both intake paths
- anthropic_adapter: handle string-shaped response.content (no char-iteration)
- crewai_adapter: default strict=False to match docstring/pass-through contract
- langchain_adapter: inject rules into every prompt/batch, not just index 0
- langchain_adapter: preserve list/multimodal system content (no stringify)
- openai_adapter: preserve structured system content (prepend fresh sys msg)
- tests: parametrize openai/langchain strict boundary; cap synthetic confidences
- tests: assert concrete response content, add multimodal + batch regressions
---
 src/gradata/middleware/_core.py             | 15 +++++-
 src/gradata/middleware/anthropic_adapter.py |  4 ++
 src/gradata/middleware/crewai_adapter.py    |  2 +-
 src/gradata/middleware/langchain_adapter.py | 46 +++++++++++-------
 src/gradata/middleware/openai_adapter.py    | 12 +++--
 tests/test_middleware_core.py               |  2 +-
 tests/test_middleware_langchain.py          | 52 ++++++++++++++-----
 tests/test_middleware_openai.py             | 44 ++++++++++-----
 8 files changed, 129 insertions(+), 48 deletions(-)

diff --git a/src/gradata/middleware/_core.py b/src/gradata/middleware/_core.py
index
f340ef67..4d3159fc 100644 --- a/src/gradata/middleware/_core.py +++ b/src/gradata/middleware/_core.py @@ -71,6 +71,17 @@ class _ScoredLesson: confidence: float +def _clamp_confidence(value: float) -> float: + """Clamp a confidence value into the [0.0, 1.0] range. + + Out-of-range inputs are logged at debug level and clamped rather than + raised — middleware must not fail on malformed lesson inputs. + """ + if value < 0.0 or value > 1.0: + _log.debug("Confidence %s out of [0.0, 1.0]; clamping", value) + return max(0.0, min(value, 1.0)) + + class RuleSource: """Reads lessons from the same brain directory Claude Code hooks use. @@ -122,7 +133,7 @@ def _load_from_dicts(self) -> list[_ScoredLesson]: out: list[_ScoredLesson] = [] for lesson in self._static_lessons or []: state = str(lesson.get("state") or lesson.get("status") or "").upper() - conf = float(lesson.get("confidence", 0.0) or 0.0) + conf = _clamp_confidence(float(lesson.get("confidence", 0.0) or 0.0)) category = str(lesson.get("category", "") or "") description = str(lesson.get("description", "") or "") if not description: @@ -197,7 +208,7 @@ def _lesson_to_scored(lesson: Lesson) -> _ScoredLesson: category=lesson.category, description=lesson.description, state=state_name, - confidence=float(lesson.confidence), + confidence=_clamp_confidence(float(lesson.confidence)), ) diff --git a/src/gradata/middleware/anthropic_adapter.py b/src/gradata/middleware/anthropic_adapter.py index 3ff4f606..487e578c 100644 --- a/src/gradata/middleware/anthropic_adapter.py +++ b/src/gradata/middleware/anthropic_adapter.py @@ -52,6 +52,10 @@ def _extract_text(response: Any) -> str: content = response.get("content") if not content: return "" + # Anthropic responses may expose content as a plain string (older SDKs, + # dict-shaped responses) or as a list of typed content blocks. + if isinstance(content, str): + return content parts: list[str] = [] for block in content: # SDK object: block.type == 'text', block.text == '...' diff --git a/src/gradata/middleware/crewai_adapter.py b/src/gradata/middleware/crewai_adapter.py index a84327e0..d2ee35f2 100644 --- a/src/gradata/middleware/crewai_adapter.py +++ b/src/gradata/middleware/crewai_adapter.py @@ -49,7 +49,7 @@ def __init__( *, brain_path: str | Path | None = None, source: RuleSource | None = None, - strict: bool = True, + strict: bool = False, ) -> None: self._source = source or RuleSource(brain_path=brain_path) self._strict = strict diff --git a/src/gradata/middleware/langchain_adapter.py b/src/gradata/middleware/langchain_adapter.py index d446a722..58ef49bd 100644 --- a/src/gradata/middleware/langchain_adapter.py +++ b/src/gradata/middleware/langchain_adapter.py @@ -72,8 +72,10 @@ def on_llm_start( block = build_brain_rules_block(self._source) if not block or not prompts: return - # Prepend block to the first prompt. LangChain uses the list in-place. - prompts[0] = f"{block}\n\n{prompts[0]}" + # Prepend block to every prompt in the batch (LangChain uses the list + # in-place and a batch call can contain multiple prompts). + for i, prompt in enumerate(prompts): + prompts[i] = f"{block}\n\n{prompt}" def on_chat_model_start( self, @@ -82,22 +84,32 @@ def on_chat_model_start( **kwargs: Any, ) -> None: block = build_brain_rules_block(self._source) - if not block or not messages or not messages[0]: + if not block or not messages: return - first_batch = messages[0] - # If the first message is a system-style message, extend its content. 
- first = first_batch[0] - content = getattr(first, "content", None) - msg_type = getattr(first, "type", "") - if content is not None and msg_type == "system": - first.content = f"{content}\n\n{block}" - return - # Otherwise, prepend a SystemMessage if langchain_core is available. - try: - from langchain_core.messages import SystemMessage - except ImportError: # pragma: no cover - return - first_batch.insert(0, SystemMessage(content=block)) + system_cls = None + for batch in messages: + if not batch: + continue + first = batch[0] + content = getattr(first, "content", None) + msg_type = getattr(first, "type", "") + if content is not None and msg_type == "system": + # BaseMessage.content may be a str or a list of content blocks + # (multimodal). Preserve structure in both cases. + if isinstance(content, str): + first.content = f"{content}\n\n{block}" + elif isinstance(content, list): + first.content = [*content, {"type": "text", "text": block}] + else: + first.content = [content, {"type": "text", "text": block}] + continue + if system_cls is None: + try: + from langchain_core.messages import SystemMessage + except ImportError: # pragma: no cover + return + system_cls = SystemMessage + batch.insert(0, system_cls(content=block)) # -- enforcement ------------------------------------------------------ diff --git a/src/gradata/middleware/openai_adapter.py b/src/gradata/middleware/openai_adapter.py index 7dbbc64a..7f82877b 100644 --- a/src/gradata/middleware/openai_adapter.py +++ b/src/gradata/middleware/openai_adapter.py @@ -72,11 +72,13 @@ def _inject_into_messages(messages: list[Any], block: str) -> list[Any]: return list(messages) out = [dict(m) if isinstance(m, dict) else m for m in messages] if out and isinstance(out[0], dict) and out[0].get("role") == "system": - existing = out[0].get("content") or "" - out[0]["content"] = inject_into_system( - existing if isinstance(existing, str) else str(existing), - block, - ) + existing = out[0].get("content") + if isinstance(existing, str) or existing is None: + out[0]["content"] = inject_into_system(existing, block) + else: + # Structured (e.g. multimodal list) content — don't stringify it; + # prepend a fresh system message so the original payload is preserved. 
+                out.insert(0, {"role": "system", "content": block})
     else:
         out.insert(0, {"role": "system", "content": block})
     return out
diff --git a/tests/test_middleware_core.py b/tests/test_middleware_core.py
index d7c9a300..03fb43d3 100644
--- a/tests/test_middleware_core.py
+++ b/tests/test_middleware_core.py
@@ -49,7 +49,7 @@ def test_build_brain_rules_block_wraps_in_xml():
 
 def test_build_brain_rules_block_respects_max_rules():
     lessons = [
-        {"state": "RULE", "confidence": 0.90 + i / 100, "category": f"C{i}",
+        {"state": "RULE", "confidence": min(1.0, 0.90 + i / 200), "category": f"C{i}",
          "description": f"desc {i}"}
         for i in range(20)
     ]
diff --git a/tests/test_middleware_langchain.py b/tests/test_middleware_langchain.py
index 00c5f09a..a3ffbcc8 100644
--- a/tests/test_middleware_langchain.py
+++ b/tests/test_middleware_langchain.py
@@ -100,6 +100,39 @@ def test_on_chat_model_start_inserts_system_message(brain_with_em_dash_rule: Pat
     assert "<brain_rules>" in first.content
 
+
+def test_on_llm_start_prepends_block_to_every_prompt_in_batch(
+    brain_with_em_dash_rule: Path,
+):
+    from gradata.middleware.langchain_adapter import LangChainCallback
+
+    cb = LangChainCallback(brain_path=brain_with_em_dash_rule)
+    prompts = ["User: first", "User: second", "User: third"]
+    cb.on_llm_start({}, prompts)
+    for p in prompts:
+        assert p.startswith("<brain_rules>")
+
+
+def test_on_chat_model_start_preserves_multimodal_list_system(
+    brain_with_em_dash_rule: Path,
+):
+    from gradata.middleware.langchain_adapter import LangChainCallback
+
+    cb = LangChainCallback(brain_path=brain_with_em_dash_rule)
+    original_blocks = [{"type": "text", "text": "You are kind."}]
+    sys_msg = _FakeMessage.__new__(_FakeMessage)
+    sys_msg.content = original_blocks
+    sys_msg.type = "system"
+    batches = [[sys_msg, _FakeMessage("hi", "human")]]
+    cb.on_chat_model_start({}, batches)
+    # List structure preserved; new block appended, not stringified.
+    assert isinstance(sys_msg.content, list)
+    assert sys_msg.content[0] == {"type": "text", "text": "You are kind."}
+    assert any(
+        isinstance(b, dict) and "<brain_rules>" in str(b.get("text", ""))
+        for b in sys_msg.content
+    )
+
+
 def test_on_chat_model_start_extends_existing_system(brain_with_em_dash_rule: Path):
     from gradata.middleware.langchain_adapter import LangChainCallback
@@ -111,21 +144,18 @@ def test_on_chat_model_start_extends_existing_system(brain_with_em_dash_rule: Pa
     assert "<brain_rules>" in batches[0][0].content
 
 
-def test_on_llm_end_strict_raises_on_violation(brain_with_em_dash_rule: Path):
+@pytest.mark.parametrize("strict", [True, False])
+def test_on_llm_end_strictness(brain_with_em_dash_rule: Path, strict: bool):
     from gradata.middleware import RuleViolation
     from gradata.middleware.langchain_adapter import LangChainCallback
 
-    cb = LangChainCallback(brain_path=brain_with_em_dash_rule, strict=True)
+    cb = LangChainCallback(brain_path=brain_with_em_dash_rule, strict=strict)
     result = _FakeLLMResult("bad \u2014 output")
-    with pytest.raises(RuleViolation):
-        cb.on_llm_end(result)
-
-
-def test_on_llm_end_non_strict_does_not_raise(brain_with_em_dash_rule: Path):
-    from gradata.middleware.langchain_adapter import LangChainCallback
-
-    cb = LangChainCallback(brain_path=brain_with_em_dash_rule, strict=False)
-    cb.on_llm_end(_FakeLLMResult("bad \u2014 output"))  # must not raise
+    if strict:
+        with pytest.raises(RuleViolation):
+            cb.on_llm_end(result)
+    else:
+        cb.on_llm_end(result)  # must not raise
 
 
 def test_bypass_env_skips_injection(brain_with_em_dash_rule: Path, monkeypatch):
diff --git a/tests/test_middleware_openai.py b/tests/test_middleware_openai.py
index 6eb75d5c..946f4bce 100644
--- a/tests/test_middleware_openai.py
+++ b/tests/test_middleware_openai.py
@@ -99,29 +99,51 @@ def test_wrap_openai_extends_existing_system(brain_with_em_dash_rule: Path):
     assert "<brain_rules>" in sent[0]["content"]
 
 
-def test_wrap_openai_strict_raises_on_violation(brain_with_em_dash_rule: Path):
+@pytest.mark.parametrize("strict", [True, False])
+def test_wrap_openai_strictness(brain_with_em_dash_rule: Path, strict: bool):
     from gradata.middleware import RuleViolation, wrap_openai
 
     client = _FakeClient(reply="bad \u2014 output")
-    wrapped = wrap_openai(client, brain_path=brain_with_em_dash_rule, strict=True)
-    with pytest.raises(RuleViolation):
-        wrapped.chat.completions.create(
+    wrapped = wrap_openai(client, brain_path=brain_with_em_dash_rule, strict=strict)
+    if strict:
+        with pytest.raises(RuleViolation):
+            wrapped.chat.completions.create(
+                model="gpt-4o-mini",
+                messages=[{"role": "user", "content": "hi"}],
+            )
+    else:
+        resp = wrapped.chat.completions.create(
             model="gpt-4o-mini",
             messages=[{"role": "user", "content": "hi"}],
         )
+        # Non-strict: response passes through unchanged.
+        assert resp.choices[0].message.content == "bad \u2014 output"
 
 
-def test_wrap_openai_non_strict_does_not_raise(brain_with_em_dash_rule: Path):
+def test_wrap_openai_preserves_multimodal_system_content(brain_with_em_dash_rule: Path):
     from gradata.middleware import wrap_openai
 
-    client = _FakeClient(reply="bad \u2014 output")
-    wrapped = wrap_openai(client, brain_path=brain_with_em_dash_rule, strict=False)
-    # Must not raise
-    resp = wrapped.chat.completions.create(
+    client = _FakeClient()
+    wrapped = wrap_openai(client, brain_path=brain_with_em_dash_rule)
+    multimodal_content = [
+        {"type": "text", "text": "Base system instructions."},
+        {"type": "image_url", "image_url": {"url": "https://example.com/x.png"}},
+    ]
+    wrapped.chat.completions.create(
         model="gpt-4o-mini",
-        messages=[{"role": "user", "content": "hi"}],
+        messages=[
+            {"role": "system", "content": multimodal_content},
+            {"role": "user", "content": "hi"},
+        ],
     )
-    assert resp is not None
+    sent = client.chat.completions.last_kwargs["messages"]
+    # Original structured system message must be preserved unchanged.
+    assert sent[1]["role"] == "system"
+    assert sent[1]["content"] == multimodal_content
+    # A new string-content system message carrying the rules was prepended.
+    assert sent[0]["role"] == "system"
+    assert isinstance(sent[0]["content"], str)
+    assert "<brain_rules>" in sent[0]["content"]
 
 
 def test_wrap_openai_bypass_env(brain_with_em_dash_rule: Path, monkeypatch):

From 8e9c10a73054f04c57e76d52909a5ed8923fa8f3 Mon Sep 17 00:00:00 2001
From: Oliver Le
Date: Mon, 13 Apr 2026 20:43:59 -0700
Subject: [PATCH 8/9] refactor(sdk): simplification pass on middleware adapters
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Consolidate the "getattr-or-dict-key" response-probing pattern into a
single `_get` helper in `_core.py`. All four adapters
(openai/anthropic/langchain/crewai) were re-implementing the same
attr-then-fallback-to-dict lookup around response fields and content
blocks; the helper shrinks each extractor by roughly 30%.

Also clean up adjacent smells:

- openai: drop the dead `if not block: return list(messages)` branch
  (guarded upstream) and the double `list(messages)` copy at the call
  site.
- openai `_inject_into_messages`: flatten the nested if/else so the
  string/None case is the one explicit branch and every non-string
  content (multimodal list, unexpected shape) falls through to
  "prepend a fresh system message".
- crewai: pull the output text-key tuple into a module constant so it's
  defined once rather than repeated for attr-vs-dict passes.
- __init__: replace the chained if-branch lazy dispatch with a single
  _LAZY_EXPORTS map + importlib.

Behaviour is unchanged — all CR-motivated fixes (multimodal list/string
handling, confidence clamping, batch-prompt iteration, strict default
flip) are preserved. Full suite: 2111 passed / 23 skipped, ruff clean,
pyright 0 errors (same 8 pre-existing warnings).
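
A representative collapse (sketch; both shapes appear verbatim in the
anthropic_adapter hunk below):

    # before: repeated for every field an adapter touches
    content = getattr(response, "content", None)
    if content is None and isinstance(response, dict):
        content = response.get("content")

    # after
    content = _get(response, "content")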
Co-Authored-By: Gradata --- src/gradata/middleware/__init__.py | 41 ++++++++++---------- src/gradata/middleware/_core.py | 17 +++++++- src/gradata/middleware/anthropic_adapter.py | 28 ++++++-------- src/gradata/middleware/crewai_adapter.py | 12 +++--- src/gradata/middleware/langchain_adapter.py | 11 ++---- src/gradata/middleware/openai_adapter.py | 43 +++++++++------------ 6 files changed, 75 insertions(+), 77 deletions(-) diff --git a/src/gradata/middleware/__init__.py b/src/gradata/middleware/__init__.py index 0596084f..1c46b9fa 100644 --- a/src/gradata/middleware/__init__.py +++ b/src/gradata/middleware/__init__.py @@ -59,24 +59,25 @@ ] +# name -> (submodule, attribute) for lazy adapter loading. +_LAZY_EXPORTS = { + "AnthropicMiddleware": ("anthropic_adapter", "AnthropicMiddleware"), + "wrap_anthropic": ("anthropic_adapter", "wrap_anthropic"), + "OpenAIMiddleware": ("openai_adapter", "OpenAIMiddleware"), + "wrap_openai": ("openai_adapter", "wrap_openai"), + "LangChainCallback": ("langchain_adapter", "LangChainCallback"), + "CrewAIGuard": ("crewai_adapter", "CrewAIGuard"), +} + + def __getattr__(name: str): # pragma: no cover - trivial dispatch - if name in ("AnthropicMiddleware", "wrap_anthropic"): - from gradata.middleware.anthropic_adapter import ( - AnthropicMiddleware, - wrap_anthropic, - ) - - return {"AnthropicMiddleware": AnthropicMiddleware, "wrap_anthropic": wrap_anthropic}[name] - if name in ("OpenAIMiddleware", "wrap_openai"): - from gradata.middleware.openai_adapter import OpenAIMiddleware, wrap_openai - - return {"OpenAIMiddleware": OpenAIMiddleware, "wrap_openai": wrap_openai}[name] - if name == "LangChainCallback": - from gradata.middleware.langchain_adapter import LangChainCallback - - return LangChainCallback - if name == "CrewAIGuard": - from gradata.middleware.crewai_adapter import CrewAIGuard - - return CrewAIGuard - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + try: + module_name, attr_name = _LAZY_EXPORTS[name] + except KeyError: + raise AttributeError( + f"module {__name__!r} has no attribute {name!r}", + ) from None + import importlib + + module = importlib.import_module(f"{__name__}.{module_name}") + return getattr(module, attr_name) diff --git a/src/gradata/middleware/_core.py b/src/gradata/middleware/_core.py index 4d3159fc..324d373f 100644 --- a/src/gradata/middleware/_core.py +++ b/src/gradata/middleware/_core.py @@ -13,7 +13,7 @@ import re from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from gradata.enhancements.rule_to_hook import DeterminismCheck, classify_rule @@ -58,6 +58,21 @@ def is_bypassed() -> bool: return os.environ.get("GRADATA_BYPASS", "").strip() == "1" +def _get(obj: Any, key: str, default: Any = None) -> Any: + """Fetch ``key`` from a response-like object using attr-then-dict lookup. + + LLM SDK responses are inconsistently typed across versions — modern + clients expose pydantic objects while older ones (and cassette fixtures) + return plain dicts. Adapters would otherwise repeat + ``getattr(x, k) or (x.get(k) if isinstance(x, dict) else None)`` + for every field they touch. 
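+
+    Illustration (hypothetical inputs)::
+
+        >>> _get({"role": "system"}, "role")
+        'system'
+        >>> _get(object(), "missing", default="fallback")
+        'fallback'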
+ """ + val = getattr(obj, key, None) + if val is None and isinstance(obj, dict): + return obj.get(key, default) + return val if val is not None else default + + # --------------------------------------------------------------------------- # RuleSource # --------------------------------------------------------------------------- diff --git a/src/gradata/middleware/anthropic_adapter.py b/src/gradata/middleware/anthropic_adapter.py index 487e578c..2ef033a5 100644 --- a/src/gradata/middleware/anthropic_adapter.py +++ b/src/gradata/middleware/anthropic_adapter.py @@ -29,6 +29,7 @@ from gradata.middleware._core import ( RuleSource, + _get, build_brain_rules_block, check_output, inject_into_system, @@ -46,28 +47,23 @@ def _require_anthropic() -> None: def _extract_text(response: Any) -> str: - """Best-effort extraction of the assistant text from an Anthropic response.""" - content = getattr(response, "content", None) - if content is None and isinstance(response, dict): - content = response.get("content") + """Best-effort extraction of the assistant text from an Anthropic response. + + Anthropic responses expose ``content`` either as a plain string (older + SDKs, dict-shaped responses) or as a list of typed content blocks. + """ + content = _get(response, "content") if not content: return "" - # Anthropic responses may expose content as a plain string (older SDKs, - # dict-shaped responses) or as a list of typed content blocks. if isinstance(content, str): return content parts: list[str] = [] for block in content: - # SDK object: block.type == 'text', block.text == '...' - block_type = getattr(block, "type", None) - if block_type is None and isinstance(block, dict): - block_type = block.get("type") - if block_type == "text": - text = getattr(block, "text", None) - if text is None and isinstance(block, dict): - text = block.get("text", "") - if text: - parts.append(str(text)) + if _get(block, "type") != "text": + continue + text = _get(block, "text") + if text: + parts.append(str(text)) return "\n".join(parts) diff --git a/src/gradata/middleware/crewai_adapter.py b/src/gradata/middleware/crewai_adapter.py index d2ee35f2..55053315 100644 --- a/src/gradata/middleware/crewai_adapter.py +++ b/src/gradata/middleware/crewai_adapter.py @@ -30,9 +30,12 @@ from gradata.middleware._core import ( RuleSource, RuleViolation, + _get, check_output, ) +_OUTPUT_TEXT_KEYS = ("raw", "output", "text", "content") + class CrewAIGuard: """A CrewAI-compatible guardrail that enforces Gradata RULE-tier rules. 
@@ -78,13 +81,8 @@ def _coerce_text(output: Any) -> str: return "" if isinstance(output, str): return output - for attr in ("raw", "output", "text", "content"): - val = getattr(output, attr, None) + for key in _OUTPUT_TEXT_KEYS: + val = _get(output, key) if isinstance(val, str) and val: return val - if isinstance(output, dict): - for key in ("raw", "output", "text", "content"): - val = output.get(key) - if isinstance(val, str) and val: - return val return str(output) diff --git a/src/gradata/middleware/langchain_adapter.py b/src/gradata/middleware/langchain_adapter.py index 58ef49bd..83c2ed2e 100644 --- a/src/gradata/middleware/langchain_adapter.py +++ b/src/gradata/middleware/langchain_adapter.py @@ -27,6 +27,7 @@ from gradata.middleware._core import ( RuleSource, + _get, build_brain_rules_block, check_output, ) @@ -122,17 +123,11 @@ def on_llm_end(self, response: Any, **kwargs: Any) -> None: def _extract_llm_text(response: Any) -> str: """Best-effort text extraction from a LangChain ``LLMResult``.""" - generations = getattr(response, "generations", None) - if generations is None and isinstance(response, dict): - generations = response.get("generations") - if not generations: - return "" + generations = _get(response, "generations") or [] parts: list[str] = [] for batch in generations: for gen in batch: - text = getattr(gen, "text", None) - if text is None and isinstance(gen, dict): - text = gen.get("text", "") + text = _get(gen, "text") if text: parts.append(str(text)) return "\n".join(parts) diff --git a/src/gradata/middleware/openai_adapter.py b/src/gradata/middleware/openai_adapter.py index 7f82877b..81472b1b 100644 --- a/src/gradata/middleware/openai_adapter.py +++ b/src/gradata/middleware/openai_adapter.py @@ -24,6 +24,7 @@ from gradata.middleware._core import ( RuleSource, + _get, build_brain_rules_block, check_output, inject_into_system, @@ -42,21 +43,13 @@ def _require_openai() -> None: def _extract_text(response: Any) -> str: """Best-effort text extraction from an OpenAI chat.completions response.""" - choices = getattr(response, "choices", None) - if choices is None and isinstance(response, dict): - choices = response.get("choices") - if not choices: - return "" + choices = _get(response, "choices") or [] parts: list[str] = [] for choice in choices: - message = getattr(choice, "message", None) - if message is None and isinstance(choice, dict): - message = choice.get("message") + message = _get(choice, "message") if message is None: continue - content = getattr(message, "content", None) - if content is None and isinstance(message, dict): - content = message.get("content") + content = _get(message, "content") if content: parts.append(str(content)) return "\n".join(parts) @@ -65,20 +58,19 @@ def _extract_text(response: Any) -> str: def _inject_into_messages(messages: list[Any], block: str) -> list[Any]: """Return a new messages list with rules folded into the system message. - If a leading system message exists, its ``content`` is extended with the - block; otherwise a new system message is prepended. + If a leading system message with string content exists, its content is + extended with the block. In every other case (no system message, or a + system message whose content is a structured multimodal list) a fresh + system message is prepended so the original payload is preserved. 
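+
+    Illustration (hypothetical messages; shapes per the rules above)::
+
+        >>> _inject_into_messages([{"role": "user", "content": "hi"}], "<r/>")
+        [{'role': 'system', 'content': '<r/>'}, {'role': 'user', 'content': 'hi'}]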
""" - if not block: - return list(messages) out = [dict(m) if isinstance(m, dict) else m for m in messages] - if out and isinstance(out[0], dict) and out[0].get("role") == "system": - existing = out[0].get("content") - if isinstance(existing, str) or existing is None: - out[0]["content"] = inject_into_system(existing, block) - else: - # Structured (e.g. multimodal list) content — don't stringify it; - # prepend a fresh system message so the original payload is preserved. - out.insert(0, {"role": "system", "content": block}) + head = out[0] if out else None + if ( + isinstance(head, dict) + and head.get("role") == "system" + and isinstance(head.get("content"), (str, type(None))) + ): + head["content"] = inject_into_system(head.get("content"), block) else: out.insert(0, {"role": "system", "content": block}) return out @@ -125,8 +117,9 @@ def __getattr__(self, name: str) -> Any: def create(self, *args: Any, **kwargs: Any) -> Any: block = build_brain_rules_block(self._mw._source) if block: - messages = kwargs.get("messages") or [] - kwargs["messages"] = _inject_into_messages(list(messages), block) + kwargs["messages"] = _inject_into_messages( + kwargs.get("messages") or [], block, + ) response = self._mw._orig_chat.completions.create(*args, **kwargs) text = _extract_text(response) From 701d27ad03edffca595d7591b1846ac10daa109b Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Tue, 14 Apr 2026 10:21:45 -0700 Subject: [PATCH 9/9] fix(review): address CodeRabbit feedback round 3 on middleware adapters --- src/gradata/middleware/_core.py | 11 +++++- src/gradata/middleware/crewai_adapter.py | 18 +++++++--- src/gradata/middleware/langchain_adapter.py | 9 +++-- tests/test_middleware_core.py | 15 ++++++++ tests/test_middleware_crewai.py | 38 ++++++++++++++------- 5 files changed, 70 insertions(+), 21 deletions(-) diff --git a/src/gradata/middleware/_core.py b/src/gradata/middleware/_core.py index 324d373f..48d370b4 100644 --- a/src/gradata/middleware/_core.py +++ b/src/gradata/middleware/_core.py @@ -148,7 +148,16 @@ def _load_from_dicts(self) -> list[_ScoredLesson]: out: list[_ScoredLesson] = [] for lesson in self._static_lessons or []: state = str(lesson.get("state") or lesson.get("status") or "").upper() - conf = _clamp_confidence(float(lesson.get("confidence", 0.0) or 0.0)) + raw_conf = lesson.get("confidence", 0.0) + try: + conf = _clamp_confidence(float(raw_conf) if raw_conf is not None else 0.0) + except (TypeError, ValueError): + # Malformed caller-supplied lessons (e.g. confidence="high") + # must not abort the whole injection/enforcement path. + _log.debug( + "Skipping lesson with non-numeric confidence %r", raw_conf, + ) + continue category = str(lesson.get("category", "") or "") description = str(lesson.get("description", "") or "") if not description: diff --git a/src/gradata/middleware/crewai_adapter.py b/src/gradata/middleware/crewai_adapter.py index 55053315..a3e88d12 100644 --- a/src/gradata/middleware/crewai_adapter.py +++ b/src/gradata/middleware/crewai_adapter.py @@ -30,7 +30,6 @@ from gradata.middleware._core import ( RuleSource, RuleViolation, - _get, check_output, ) @@ -76,13 +75,24 @@ def __call__(self, output: Any) -> tuple[bool, Any]: def _coerce_text(output: Any) -> str: - """Best-effort text extraction for CrewAI agent outputs.""" + """Best-effort text extraction for CrewAI agent outputs. + + Preserves explicitly empty string fields (``raw=""``, ``output=""``, ...) 
+ so the guard's empty-output fast path still applies instead of falling + through to ``str(output)`` and producing an object repr. + """ if output is None: return "" if isinstance(output, str): return output + # Prefer attribute lookup first (typed CrewAI outputs), then dict lookup. for key in _OUTPUT_TEXT_KEYS: - val = _get(output, key) - if isinstance(val, str) and val: + val = getattr(output, key, None) + if isinstance(val, str): return val + if isinstance(output, dict): + for key in _OUTPUT_TEXT_KEYS: + val = output.get(key) + if isinstance(val, str): + return val return str(output) diff --git a/src/gradata/middleware/langchain_adapter.py b/src/gradata/middleware/langchain_adapter.py index 83c2ed2e..69919274 100644 --- a/src/gradata/middleware/langchain_adapter.py +++ b/src/gradata/middleware/langchain_adapter.py @@ -15,9 +15,12 @@ llm.invoke("Write a short greeting") Because LangChain callbacks mutate internal prompt buffers in-place, the -injection is done best-effort on the first prompt only. For stricter -control, prefer the :class:`gradata.middleware.OpenAIMiddleware` wrapper -over the underlying client. +injection covers every prompt / message batch entry in a single callback +invocation: ``on_llm_start`` prepends the block to every prompt in the +list, and ``on_chat_model_start`` injects a system message into every +batch. For stricter control (e.g. structured responses), prefer the +:class:`gradata.middleware.OpenAIMiddleware` wrapper over the underlying +client. """ from __future__ import annotations diff --git a/tests/test_middleware_core.py b/tests/test_middleware_core.py index 03fb43d3..a1339110 100644 --- a/tests/test_middleware_core.py +++ b/tests/test_middleware_core.py @@ -133,3 +133,18 @@ def test_rule_source_missing_brain_returns_empty(tmp_path: Path): src = RuleSource(brain_path=tmp_path / "does-not-exist") assert src.select() == [] assert build_brain_rules_block(src) == "" + + +def test_rule_source_skips_non_numeric_confidence(): + # Malformed caller-supplied lessons must not abort the injection path. + src = RuleSource( + lessons=[ + {"state": "RULE", "confidence": "high", "category": "TONE", + "description": "malformed"}, + {"state": "RULE", "confidence": 0.95, "category": "TONE", + "description": "Never use em dashes"}, + ], + ) + selected = src.select() + assert len(selected) == 1 + assert selected[0].description == "Never use em dashes" diff --git a/tests/test_middleware_crewai.py b/tests/test_middleware_crewai.py index 7192dc9a..f0e96beb 100644 --- a/tests/test_middleware_crewai.py +++ b/tests/test_middleware_crewai.py @@ -27,23 +27,19 @@ def test_crewai_guard_passes_clean_output(brain_with_em_dash_rule: Path): assert result == "A perfectly clean string." 
-def test_crewai_guard_blocks_violation_when_strict(brain_with_em_dash_rule: Path): +@pytest.mark.parametrize("strict", [True, False]) +def test_crewai_guard_strictness(brain_with_em_dash_rule: Path, strict: bool): from gradata.middleware import CrewAIGuard - guard = CrewAIGuard(brain_path=brain_with_em_dash_rule, strict=True) - ok, result = guard("has em dash \u2014 here") - assert ok is False - assert "em-dash" in result or "em dash" in result.lower() - - -def test_crewai_guard_non_strict_allows_violation(brain_with_em_dash_rule: Path): - from gradata.middleware import CrewAIGuard - - guard = CrewAIGuard(brain_path=brain_with_em_dash_rule, strict=False) + guard = CrewAIGuard(brain_path=brain_with_em_dash_rule, strict=strict) text = "has em dash \u2014 here" ok, result = guard(text) - assert ok is True - assert result == text + if strict: + assert ok is False + assert "em-dash" in result or "em dash" in result.lower() + else: + assert ok is True + assert result == text def test_crewai_guard_extracts_text_from_object(brain_with_em_dash_rule: Path): @@ -74,3 +70,19 @@ def test_crewai_guard_empty_output_passes(brain_with_em_dash_rule: Path): ok, result = guard("") assert ok is True assert result == "" + + +def test_crewai_guard_preserves_empty_raw_field(brain_with_em_dash_rule: Path): + """An explicitly empty ``raw`` must hit the empty-output fast path, + not fall through to ``str(output)`` (object repr).""" + from gradata.middleware import CrewAIGuard + + class FakeOutput: + def __init__(self) -> None: + self.raw = "" + + guard = CrewAIGuard(brain_path=brain_with_em_dash_rule, strict=True) + ok, result = guard(FakeOutput()) + assert ok is True + # Output passes through unchanged (not stringified). + assert isinstance(result, FakeOutput)
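+
+
+def test_crewai_guard_extracts_text_from_dict(brain_with_em_dash_rule: Path):
+    # Sketch: dict-shaped outputs use the same key preference order as
+    # typed outputs, so a violating "raw" value is still caught.
+    from gradata.middleware import CrewAIGuard
+
+    guard = CrewAIGuard(brain_path=brain_with_em_dash_rule, strict=True)
+    ok, _ = guard({"raw": "bad \u2014 output"})
+    assert ok is False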