From 4605223dc72909e0bf0ddedb4dc67700b971de55 Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Fri, 10 Apr 2026 12:03:05 -0700 Subject: [PATCH 01/13] feat(P0.1): add semantic_similarity to EmbeddingClient Co-Authored-By: Gradata --- src/gradata/integrations/embeddings.py | 27 +++++++++++++++++++++++--- tests/test_embeddings.py | 26 +++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) create mode 100644 tests/test_embeddings.py diff --git a/src/gradata/integrations/embeddings.py b/src/gradata/integrations/embeddings.py index 3d466d06..f0250ac8 100644 --- a/src/gradata/integrations/embeddings.py +++ b/src/gradata/integrations/embeddings.py @@ -43,10 +43,13 @@ def embed(self, text): def _is_trusted_url(url): """SSRF protection: only allow HTTPS to gradata.ai or localhost.""" from urllib.parse import urlparse + parsed = urlparse(url) if parsed.hostname in ("localhost", "127.0.0.1"): return True - return bool(parsed.scheme == "https" and parsed.hostname and parsed.hostname.endswith("gradata.ai")) + return bool( + parsed.scheme == "https" and parsed.hostname and parsed.hostname.endswith("gradata.ai") + ) def _embed_api(self, text): if not self._is_trusted_url(self.api_url): @@ -74,6 +77,16 @@ def _embed_local(self, text): vec = [v / norm for v in vec] return vec + def semantic_similarity(self, text_a: str, text_b: str) -> float: + """Cosine similarity between embeddings of two texts. + Returns 0.0 if either text is empty, otherwise [0.0, 1.0]. + """ + if not text_a or not text_b: + return 0.0 + vec_a = self.embed(text_a) + vec_b = self.embed(text_b) + return cosine_similarity(vec_a, vec_b) + _default_client = None @@ -155,6 +168,14 @@ def _embed_and_cache(payload, *keys): except Exception: logger.warning("Failed to embed payload", exc_info=True) - bus.on("correction.created", lambda p: _embed_and_cache(p, "description", "text"), async_handler=True) + bus.on( + "correction.created", + lambda p: _embed_and_cache(p, "description", "text"), + async_handler=True, + ) bus.on("lesson.graduated", lambda p: _embed_and_cache(p, "description"), async_handler=True) - bus.on("meta_rule.created", lambda p: _embed_and_cache(p, "description", "rule"), async_handler=True) + bus.on( + "meta_rule.created", + lambda p: _embed_and_cache(p, "description", "rule"), + async_handler=True, + ) diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py new file mode 100644 index 00000000..f44a8514 --- /dev/null +++ b/tests/test_embeddings.py @@ -0,0 +1,26 @@ +"""Tests for EmbeddingClient.semantic_similarity().""" + +import pytest + +from gradata.integrations.embeddings import EmbeddingClient + + +class TestSemanticSimilarity: + def test_identical_texts_similarity_is_one(self): + client = EmbeddingClient() + sim = client.semantic_similarity("hello world", "hello world") + assert sim == pytest.approx(1.0, abs=0.01) + + def test_different_texts_similarity_below_one(self): + client = EmbeddingClient() + sim = client.semantic_similarity( + "Please send the invoice by Friday", + "The weather is sunny today", + ) + assert sim < 0.8 + + def test_empty_text_returns_zero(self): + client = EmbeddingClient() + assert client.semantic_similarity("", "hello") == 0.0 + assert client.semantic_similarity("hello", "") == 0.0 + assert client.semantic_similarity("", "") == 0.0 From 6492c20281647164d89c320a740cd09e3fd681f4 Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Fri, 10 Apr 2026 12:03:43 -0700 Subject: [PATCH 02/13] feat(P1.1): add alpha/beta_param fields to Lesson for Bayesian confidence Co-Authored-By: Gradata --- src/gradata/_types.py | 103 +++++++++++++++++++++++------------------- tests/test_types.py | 29 ++++++++++++ 2 files changed, 85 insertions(+), 47 deletions(-) create mode 100644 tests/test_types.py diff --git a/src/gradata/_types.py b/src/gradata/_types.py index 1b0df042..3574e873 100644 --- a/src/gradata/_types.py +++ b/src/gradata/_types.py @@ -19,42 +19,44 @@ class CorrectionType(Enum): This enables smarter graduation: behavioral rules graduate faster than factual corrections because they generalize better. """ - BEHAVIORAL = "behavioral" # Style, tone, format, process (generalizes well) - FACTUAL = "factual" # Wrong data, incorrect API, deprecated method (specific) - PROCEDURAL = "procedural" # Skipped steps, wrong order, missing verification - PREFERENCE = "preference" # User taste (em dashes, bold, date format) - DOMAIN = "domain" # Industry-specific rules (pricing, compliance, ICP) + + BEHAVIORAL = "behavioral" # Style, tone, format, process (generalizes well) + FACTUAL = "factual" # Wrong data, incorrect API, deprecated method (specific) + PROCEDURAL = "procedural" # Skipped steps, wrong order, missing verification + PREFERENCE = "preference" # User taste (em dashes, bold, date format) + DOMAIN = "domain" # Industry-specific rules (pricing, compliance, ICP) class RuleTransferScope(Enum): """How transferable a rule is across users/teams.""" - PERSONAL = "personal" # One user's style (email tone, formatting preference) - TEAM = "team" # Org-specific workflow (CRM naming, tool-specific format) + + PERSONAL = "personal" # One user's style (email tone, formatting preference) + TEAM = "team" # Org-specific workflow (CRM naming, tool-specific format) UNIVERSAL = "universal" # Everyone benefits (no AI tells, verify data, don't fabricate) class CorrectionScope(Enum): """Hierarchical scope for corrections — controls graduation ceiling.""" - UNIVERSAL = "universal" # Applies everywhere - DOMAIN = "domain" # Default: applies within this domain - PROJECT = "project" # Applies only to this project - ONE_OFF = "one_off" # Never graduates past INSTINCT + + UNIVERSAL = "universal" # Applies everywhere + DOMAIN = "domain" # Default: applies within this domain + PROJECT = "project" # Applies only to this project + ONE_OFF = "one_off" # Never graduates past INSTINCT class LessonState(Enum): """Maturity tiers for a learned lesson.""" - INSTINCT = "INSTINCT" # 0.00 - 0.59 - PATTERN = "PATTERN" # 0.60 - 0.89 - RULE = "RULE" # 0.90+ - UNTESTABLE = "UNTESTABLE" # 20+ sessions with 0 fires - ARCHIVED = "ARCHIVED" # graduated out (moved to lessons-archive.md) - KILLED = "KILLED" # failed ablation or untestable 20+ sessions + + INSTINCT = "INSTINCT" # 0.00 - 0.59 + PATTERN = "PATTERN" # 0.60 - 0.89 + RULE = "RULE" # 0.90+ + UNTESTABLE = "UNTESTABLE" # 20+ sessions with 0 fires + ARCHIVED = "ARCHIVED" # graduated out (moved to lessons-archive.md) + KILLED = "KILLED" # failed ablation or untestable 20+ sessions # States eligible for rule injection (shared across rule_engine and meta_rules). -ELIGIBLE_STATES: frozenset[LessonState] = frozenset( - {LessonState.RULE, LessonState.PATTERN} -) +ELIGIBLE_STATES: frozenset[LessonState] = frozenset({LessonState.RULE, LessonState.PATTERN}) # --------------------------------------------------------------------------- @@ -63,24 +65,24 @@ class LessonState(Enum): TRANSITIONS: dict[LessonState, dict[str, LessonState]] = { LessonState.INSTINCT: { - "promote": LessonState.PATTERN, # confidence >= 0.60 - "kill": LessonState.KILLED, # confidence <= 0.0 or untestable 20+ sessions + "promote": LessonState.PATTERN, # confidence >= 0.60 + "kill": LessonState.KILLED, # confidence <= 0.0 or untestable 20+ sessions }, LessonState.PATTERN: { - "promote": LessonState.RULE, # confidence >= 0.90 - "demote": LessonState.INSTINCT, # confidence < 0.60 after penalty - "kill": LessonState.KILLED, # confidence <= 0.0 + "promote": LessonState.RULE, # confidence >= 0.90 + "demote": LessonState.INSTINCT, # confidence < 0.60 after penalty + "kill": LessonState.KILLED, # confidence <= 0.0 }, LessonState.RULE: { - "archive": LessonState.ARCHIVED, # moved to lessons-archive.md - "demote": LessonState.PATTERN, # failed ablation + "archive": LessonState.ARCHIVED, # moved to lessons-archive.md + "demote": LessonState.PATTERN, # failed ablation }, LessonState.UNTESTABLE: { - "kill": LessonState.KILLED, # untestable 20+ sessions - "promote": LessonState.INSTINCT, # got testable evidence + "kill": LessonState.KILLED, # untestable 20+ sessions + "promote": LessonState.INSTINCT, # got testable evidence }, - LessonState.ARCHIVED: {}, # terminal state - LessonState.KILLED: {}, # terminal state + LessonState.ARCHIVED: {}, # terminal state + LessonState.KILLED: {}, # terminal state } @@ -114,6 +116,7 @@ class RuleMetadata: Tracks provenance (who, what, why, when, where, how) and dual scores that let the injection pipeline balance usefulness vs safety. """ + what: str = "" why: str = "" who: str = "" @@ -135,29 +138,35 @@ def to_dict(self) -> dict: @dataclass class Lesson: """A single learned lesson with confidence tracking.""" - date: str # ISO date when the lesson was created - state: LessonState # Current maturity tier - confidence: float # 0.00 - 1.00 - category: str # e.g. DRAFTING, ACCURACY, PROCESS - description: str # Full text after "CATEGORY: " - root_cause: str = "" # Root cause analysis (after "Root cause:") - fire_count: int = 0 # Times the lesson was triggered/applied - sessions_since_fire: int = 0 # Sessions since last application - misfire_count: int = 0 # Times applied but made output worse - scope_json: str = "" # JSON-serialized RuleScope; empty = universal + + date: str # ISO date when the lesson was created + state: LessonState # Current maturity tier + confidence: float # 0.00 - 1.00 + category: str # e.g. DRAFTING, ACCURACY, PROCESS + description: str # Full text after "CATEGORY: " + root_cause: str = "" # Root cause analysis (after "Root cause:") + fire_count: int = 0 # Times the lesson was triggered/applied + sessions_since_fire: int = 0 # Sessions since last application + misfire_count: int = 0 # Times applied but made output worse + scope_json: str = "" # JSON-serialized RuleScope; empty = universal transfer_scope: RuleTransferScope = RuleTransferScope.PERSONAL # How transferable correction_type: CorrectionType = CorrectionType.BEHAVIORAL # What kind of correction - example_draft: str | None = None # Before: what the AI produced + example_draft: str | None = None # Before: what the AI produced example_corrected: str | None = None # After: what the user changed it to - agent_type: str = "" # Agent that produced this lesson (e.g. "researcher", "email-drafter") - kill_reason: str = "" # Why this lesson was killed (e.g. "zero_confidence", "untestable_idle", "manual_rollback") - correction_event_ids: list[str] = field(default_factory=list) # Event IDs that created/updated this lesson (proof chain) + agent_type: str = "" # Agent that produced this lesson (e.g. "researcher", "email-drafter") + kill_reason: str = "" # Why this lesson was killed (e.g. "zero_confidence", "untestable_idle", "manual_rollback") + correction_event_ids: list[str] = field( + default_factory=list + ) # Event IDs that created/updated this lesson (proof chain) pending_approval: bool = False # True = awaiting human review before graduation parent_meta_rule_id: str | None = None # Meta-rule this lesson contributed to memory_ids: list[str] = field(default_factory=list) # Linked memory IDs - domain_scores: dict[str, dict[str, int]] = field(default_factory=dict) # Per-domain fire/misfire tracking + domain_scores: dict[str, dict[str, int]] = field( + default_factory=dict + ) # Per-domain fire/misfire tracking + alpha: float = 1.0 # Beta distribution α (prior + successes) + beta_param: float = 1.0 # Beta distribution β (prior + failures) metadata: RuleMetadata = field(default_factory=RuleMetadata) # 5W1H + dual scores def __post_init__(self) -> None: self.confidence = round(max(0.0, min(1.0, self.confidence)), 2) - diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 00000000..eb45f5d3 --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,29 @@ +"""Tests for _types.py dataclass fields.""" + +from gradata._types import Lesson, LessonState + + +class TestLessonBetaFields: + def test_lesson_has_alpha_beta(self): + lesson = Lesson( + date="2026-04-10", + state=LessonState.INSTINCT, + confidence=0.40, + category="TONE", + description="Be direct", + ) + assert lesson.alpha == 1.0 + assert lesson.beta_param == 1.0 + + def test_lesson_custom_alpha_beta(self): + lesson = Lesson( + date="2026-04-10", + state=LessonState.PATTERN, + confidence=0.72, + category="TONE", + description="Be direct", + alpha=8.0, + beta_param=3.0, + ) + assert lesson.alpha == 8.0 + assert lesson.beta_param == 3.0 From 6d2016f5def1b3e66a002764f495dab276c61d9e Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Fri, 10 Apr 2026 12:05:10 -0700 Subject: [PATCH 03/13] feat(P2.1): cross-correction recurring pattern detection Co-Authored-By: Gradata --- .../enhancements/behavioral_extractor.py | 395 +++++++++++++----- tests/test_behavioral_extractor.py | 61 +++ 2 files changed, 357 insertions(+), 99 deletions(-) create mode 100644 tests/test_behavioral_extractor.py diff --git a/src/gradata/enhancements/behavioral_extractor.py b/src/gradata/enhancements/behavioral_extractor.py index a142ad84..f4003f23 100644 --- a/src/gradata/enhancements/behavioral_extractor.py +++ b/src/gradata/enhancements/behavioral_extractor.py @@ -17,6 +17,7 @@ Informed by MiroFish sim 26 research + prior art from Grammarly, Duolingo, Google Smart Compose (sentence-level diffs > word-level for behavioral extraction). """ + from __future__ import annotations import logging @@ -26,6 +27,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from gradata._types import Lesson from gradata.enhancements.edit_classifier import EditClassification _log = logging.getLogger(__name__) @@ -41,18 +43,19 @@ # Archetype Taxonomy (12 correction types) # --------------------------------------------------------------------------- + class Archetype(Enum): - PREFIX_INSTRUCTION = auto() # Already an instruction ("User corrected: ...") - ADDITION_STEP = auto() # New workflow step inserted - REMOVAL_HEDGING = auto() # Hedging/weak language removed - REMOVAL_CONTENT = auto() # Whole content block removed - REPLACEMENT_WORD = auto() # Specific word/phrase swap - REPLACEMENT_TONE = auto() # Formality/tone shift + PREFIX_INSTRUCTION = auto() # Already an instruction ("User corrected: ...") + ADDITION_STEP = auto() # New workflow step inserted + REMOVAL_HEDGING = auto() # Hedging/weak language removed + REMOVAL_CONTENT = auto() # Whole content block removed + REPLACEMENT_WORD = auto() # Specific word/phrase swap + REPLACEMENT_TONE = auto() # Formality/tone shift REPLACEMENT_FACTUAL = auto() # Numbers/dates/URLs corrected - REORDER = auto() # Same content, different order - FORMAT_CHANGE = auto() # Structure changed (prose→list, etc.) - TRUNCATION = auto() # Significantly shortened (>35% reduction) - EXPANSION = auto() # Significantly expanded (>50% growth) + REORDER = auto() # Same content, different order + FORMAT_CHANGE = auto() # Structure changed (prose→list, etc.) + TRUNCATION = auto() # Significantly shortened (>35% reduction) + EXPANSION = auto() # Significantly expanded (>50% growth) CONSTRAINT_ADDITION = auto() # New always/never/must rule @@ -60,7 +63,7 @@ class Archetype(Enum): class ArchetypeMatch: archetype: Archetype confidence: float # 0.0–1.0 - context: dict # archetype-specific extracted data + context: dict # archetype-specific extracted data # --------------------------------------------------------------------------- @@ -70,28 +73,85 @@ class ArchetypeMatch: # Multi-word hedge phrases for substring matching in raw text. # Single-word hedges overlap with edit_classifier._TONE_WORDS (used for # word-set classification). Both lists must be updated together. -_HEDGE_PHRASES = frozenset({ - "no rush", "no pressure", "if you want", "if you'd like", - "when you get a chance", "at your convenience", "just", - "maybe", "perhaps", "possibly", "i think", "i believe", - "might", "could potentially", "sort of", "kind of", - "not sure but", "feel free to", "no worries if not", - "totally understand if", "let me know if", -}) - -_CONSTRAINT_WORDS = frozenset({ - "always", "never", "must", "don't", "do not", - "required", "mandatory", "prohibited", -}) - -_ACTION_VERBS = frozenset({ - "check", "verify", "validate", "review", "audit", "test", - "pull", "push", "run", "execute", "load", "import", "export", - "create", "build", "deploy", "send", "submit", "approve", - "research", "analyze", "compare", "confirm", "refresh", - "open", "close", "start", "stop", "enable", "disable", - "use", "avoid", "include", "exclude", "add", "remove", -}) +_HEDGE_PHRASES = frozenset( + { + "no rush", + "no pressure", + "if you want", + "if you'd like", + "when you get a chance", + "at your convenience", + "just", + "maybe", + "perhaps", + "possibly", + "i think", + "i believe", + "might", + "could potentially", + "sort of", + "kind of", + "not sure but", + "feel free to", + "no worries if not", + "totally understand if", + "let me know if", + } +) + +_CONSTRAINT_WORDS = frozenset( + { + "always", + "never", + "must", + "don't", + "do not", + "required", + "mandatory", + "prohibited", + } +) + +_ACTION_VERBS = frozenset( + { + "check", + "verify", + "validate", + "review", + "audit", + "test", + "pull", + "push", + "run", + "execute", + "load", + "import", + "export", + "create", + "build", + "deploy", + "send", + "submit", + "approve", + "research", + "analyze", + "compare", + "confirm", + "refresh", + "open", + "close", + "start", + "stop", + "enable", + "disable", + "use", + "avoid", + "include", + "exclude", + "add", + "remove", + } +) _PREFIX_PATTERNS = [ re.compile(r"^User corrected:\s*(.+)", re.IGNORECASE), @@ -106,9 +166,10 @@ class ArchetypeMatch: # Sentence-level helpers # --------------------------------------------------------------------------- + def _split_sentences(text: str) -> list[str]: """Split text into sentences at period/question/exclamation boundaries.""" - parts = re.split(r'(?<=[.!?])\s+', text.strip()) + parts = re.split(r"(?<=[.!?])\s+", text.strip()) return [p.strip() for p in parts if p.strip()] @@ -134,8 +195,9 @@ def _extract_topic(sentences: list[str]) -> str: if not sentences: return "content" text = sentences[0].strip() - text = re.sub(r'^(the|a|an|this|that|we|i|you|it|please|also|then)\s+', - '', text, flags=re.IGNORECASE) + text = re.sub( + r"^(the|a|an|this|that|we|i|you|it|please|also|then)\s+", "", text, flags=re.IGNORECASE + ) words = text.split()[:6] return " ".join(words).rstrip(".,;:") if words else "content" @@ -173,6 +235,7 @@ def _to_imperative(sentence: str) -> str: # Archetype Detection # --------------------------------------------------------------------------- + def detect_archetype( draft: str, final: str, @@ -187,8 +250,7 @@ def detect_archetype( m = pattern.match(final.strip()) if m: return ArchetypeMatch( - Archetype.PREFIX_INSTRUCTION, 1.0, - {"instruction": m.group(1).strip()} + Archetype.PREFIX_INSTRUCTION, 1.0, {"instruction": m.group(1).strip()} ) draft_sents = _split_sentences(draft) @@ -201,65 +263,70 @@ def detect_archetype( # Precompute word sets for sentence overlap (avoids re-splitting per pair) draft_sent_sets = [set(s.lower().split()) for s in draft_sents] final_sent_sets = [set(s.lower().split()) for s in final_sents] - added_sents = [s for s, ws in zip(final_sents, final_sent_sets, strict=True) - if not any(_sentence_overlap(ws, ds) > 0.5 for ds in draft_sent_sets)] + added_sents = [ + s + for s, ws in zip(final_sents, final_sent_sets, strict=True) + if not any(_sentence_overlap(ws, ds) > 0.5 for ds in draft_sent_sets) + ] # 2. REMOVAL_HEDGING (check BEFORE length — hedging removal shortens text) draft_lower = draft.lower() final_lower = final.lower() - removed_hedges = [h for h in _HEDGE_PHRASES - if re.search(r'\b' + re.escape(h) + r'\b', draft_lower) - and not re.search(r'\b' + re.escape(h) + r'\b', final_lower)] + removed_hedges = [ + h + for h in _HEDGE_PHRASES + if re.search(r"\b" + re.escape(h) + r"\b", draft_lower) + and not re.search(r"\b" + re.escape(h) + r"\b", final_lower) + ] if removed_hedges: - return ArchetypeMatch( - Archetype.REMOVAL_HEDGING, 0.90, - {"removed_phrases": removed_hedges} - ) + return ArchetypeMatch(Archetype.REMOVAL_HEDGING, 0.90, {"removed_phrases": removed_hedges}) # 3. CONSTRAINT_ADDITION (check BEFORE length — constraints lengthen text) - new_constraints = [w for w in _CONSTRAINT_WORDS - if re.search(r'\b' + re.escape(w) + r'\b', final_lower) - and not re.search(r'\b' + re.escape(w) + r'\b', draft_lower)] + new_constraints = [ + w + for w in _CONSTRAINT_WORDS + if re.search(r"\b" + re.escape(w) + r"\b", final_lower) + and not re.search(r"\b" + re.escape(w) + r"\b", draft_lower) + ] if new_constraints: constraint_sent = _find_sentence_containing(final, new_constraints[0]) return ArchetypeMatch( - Archetype.CONSTRAINT_ADDITION, 0.85, - {"constraint_word": new_constraints[0], - "constraint_sentence": constraint_sent} + Archetype.CONSTRAINT_ADDITION, + 0.85, + {"constraint_word": new_constraints[0], "constraint_sentence": constraint_sent}, ) # 4. ADDITION_STEP: new sentences with action verbs (before length check) if added_sents and _contains_action_verb(added_sents[0]): - return ArchetypeMatch( - Archetype.ADDITION_STEP, 0.80, - {"added_step": added_sents[0]} - ) + return ArchetypeMatch(Archetype.ADDITION_STEP, 0.80, {"added_step": added_sents[0]}) # 5. REPLACEMENT_TONE (uses classifier output, before length check) if classification: desc_lower = classification.description.lower() if "casualized" in desc_lower or "formalized" in desc_lower: direction = "casual" if "casualized" in desc_lower else "formal" - return ArchetypeMatch( - Archetype.REPLACEMENT_TONE, 0.85, - {"direction": direction} - ) + return ArchetypeMatch(Archetype.REPLACEMENT_TONE, 0.85, {"direction": direction}) # 6. TRUNCATION / EXPANSION (generic length change — after specific checks) len_ratio = len(final) / max(len(draft), 1) if len_ratio < 0.65: - removed_sents = [s for s, ws in zip(draft_sents, draft_sent_sets, strict=True) - if not any(_sentence_overlap(ws, fs) > 0.5 for fs in final_sent_sets)] + removed_sents = [ + s + for s, ws in zip(draft_sents, draft_sent_sets, strict=True) + if not any(_sentence_overlap(ws, fs) > 0.5 for fs in final_sent_sets) + ] topic = _extract_topic(removed_sents) if removed_sents else "content" return ArchetypeMatch( - Archetype.TRUNCATION, 0.85, - {"reduction_pct": round((1 - len_ratio) * 100), "removed_topic": topic} + Archetype.TRUNCATION, + 0.85, + {"reduction_pct": round((1 - len_ratio) * 100), "removed_topic": topic}, ) if len_ratio > 1.5: topic = _extract_topic(added_sents) if added_sents else "detail" return ArchetypeMatch( - Archetype.EXPANSION, 0.80, - {"growth_pct": round((len_ratio - 1) * 100), "added_topic": topic} + Archetype.EXPANSION, + 0.80, + {"growth_pct": round((len_ratio - 1) * 100), "added_topic": topic}, ) # 7. REORDER: same words, different arrangement @@ -271,50 +338,51 @@ def detect_archetype( new_facts = set(_FACTUAL_RE.findall(final)) if old_facts != new_facts and (old_facts or new_facts): return ArchetypeMatch( - Archetype.REPLACEMENT_FACTUAL, 0.85, - {"old_facts": list(old_facts), "new_facts": list(new_facts)} + Archetype.REPLACEMENT_FACTUAL, + 0.85, + {"old_facts": list(old_facts), "new_facts": list(new_facts)}, ) # 9. FORMAT_CHANGE - old_has_list = bool(re.search(r'^\s*[-*+]\s', draft, re.MULTILINE)) - new_has_list = bool(re.search(r'^\s*[-*+]\s', final, re.MULTILINE)) + old_has_list = bool(re.search(r"^\s*[-*+]\s", draft, re.MULTILINE)) + new_has_list = bool(re.search(r"^\s*[-*+]\s", final, re.MULTILINE)) if old_has_list != new_has_list: return ArchetypeMatch( - Archetype.FORMAT_CHANGE, 0.80, - {"old_format": "list" if old_has_list else "prose", - "new_format": "list" if new_has_list else "prose"} + Archetype.FORMAT_CHANGE, + 0.80, + { + "old_format": "list" if old_has_list else "prose", + "new_format": "list" if new_has_list else "prose", + }, ) # 10. REMOVAL_CONTENT - removed_sents = [s for s, ws in zip(draft_sents, draft_sent_sets, strict=True) - if not any(_sentence_overlap(ws, fs) > 0.5 for fs in final_sent_sets)] + removed_sents = [ + s + for s, ws in zip(draft_sents, draft_sent_sets, strict=True) + if not any(_sentence_overlap(ws, fs) > 0.5 for fs in final_sent_sets) + ] if removed_sents and not added_sents: topic = _extract_topic(removed_sents) - return ArchetypeMatch( - Archetype.REMOVAL_CONTENT, 0.75, - {"removed_topic": topic} - ) + return ArchetypeMatch(Archetype.REMOVAL_CONTENT, 0.75, {"removed_topic": topic}) # 11. REPLACEMENT_WORD if len(added_words) <= 3 and len(removed_words) <= 3 and added_words and removed_words: return ArchetypeMatch( - Archetype.REPLACEMENT_WORD, 0.80, - {"old_words": sorted(removed_words)[:3], - "new_words": sorted(added_words)[:3]} + Archetype.REPLACEMENT_WORD, + 0.80, + {"old_words": sorted(removed_words)[:3], "new_words": sorted(added_words)[:3]}, ) # 12. Fallback: any new sentences if added_sents: - return ArchetypeMatch( - Archetype.ADDITION_STEP, 0.60, - {"added_step": added_sents[0]} - ) + return ArchetypeMatch(Archetype.ADDITION_STEP, 0.60, {"added_step": added_sents[0]}) # Ultimate fallback return ArchetypeMatch( - Archetype.REPLACEMENT_WORD, 0.40, - {"old_words": sorted(removed_words)[:3], - "new_words": sorted(added_words)[:3]} + Archetype.REPLACEMENT_WORD, + 0.40, + {"old_words": sorted(removed_words)[:3], "new_words": sorted(added_words)[:3]}, ) @@ -322,6 +390,7 @@ def detect_archetype( # Template Generation # --------------------------------------------------------------------------- + def generate_instruction(match: ArchetypeMatch) -> str: """Generate an imperative behavioral instruction from an archetype match.""" ctx = match.context @@ -396,15 +465,49 @@ def generate_instruction(match: ArchetypeMatch) -> str: # Quality Gate # --------------------------------------------------------------------------- -_IMPERATIVE_STARTERS = frozenset({ - "use", "don", "always", "never", "include", "exclude", - "check", "verify", "write", "keep", "cut", "add", - "remove", "present", "be", "avoid", "ensure", "start", - "lead", "break", "replace", "run", "test", "audit", - "research", "validate", "pull", "load", "revise", - "prioritize", "emphasize", "highlight", "reduce", - "increase", "prefer", "limit", "focus", "simplify", -}) +_IMPERATIVE_STARTERS = frozenset( + { + "use", + "don", + "always", + "never", + "include", + "exclude", + "check", + "verify", + "write", + "keep", + "cut", + "add", + "remove", + "present", + "be", + "avoid", + "ensure", + "start", + "lead", + "break", + "replace", + "run", + "test", + "audit", + "research", + "validate", + "pull", + "load", + "revise", + "prioritize", + "emphasize", + "highlight", + "reduce", + "increase", + "prefer", + "limit", + "focus", + "simplify", + } +) + def _is_actionable(instruction: str) -> bool: if not instruction or len(instruction) < 5: @@ -450,6 +553,7 @@ def _try_llm_extract(llm_provider, draft: str, final: str, classification) -> st # Public API # --------------------------------------------------------------------------- + def extract_instruction( draft: str, final: str, @@ -496,3 +600,96 @@ def extract_instruction( # Generic fallback — return None so callers can try their own fallback # paths (e.g. keyword templates) before resorting to generic strings. return None + + +# --------------------------------------------------------------------------- +# Cross-Correction Recurring Patterns +# --------------------------------------------------------------------------- + + +@dataclass +class RecurringPattern: + """A pattern detected across multiple related corrections.""" + + category: str + correction_count: int + descriptions: list[str] + summary: str + confidence: float + + +def _extract_common_verbs(descriptions: list[str]) -> list[str]: + verb_prefixes = [ + "use", + "remove", + "add", + "avoid", + "be", + "check", + "verify", + "include", + "exclude", + "always", + "never", + "ensure", + "keep", + "start", + "stop", + "replace", + "shorten", + "expand", + ] + counts: dict[str, int] = {} + for desc in descriptions: + words = desc.lower().split() + for word in words: + for prefix in verb_prefixes: + if word.startswith(prefix): + counts[prefix] = counts.get(prefix, 0) + 1 + break + return sorted(counts, key=counts.get, reverse=True)[:3] + + +def _synthesize_summary(category: str, descriptions: list[str]) -> str: + common_verbs = _extract_common_verbs(descriptions) + verb_str = ", ".join(common_verbs[:2]) if common_verbs else "adjust" + return ( + f"Recurring {category} pattern across {len(descriptions)} corrections: " + f"consistently {verb_str} in this category" + ) + + +def detect_recurring_patterns( + lessons: list["Lesson"], + min_corrections: int = 3, +) -> list[RecurringPattern]: + """Detect patterns spanning multiple corrections in the same category.""" + if not lessons: + return [] + + by_cat: dict[str, list["Lesson"]] = {} + for lesson in lessons: + cat = lesson.category.upper() + by_cat.setdefault(cat, []).append(lesson) + + patterns: list[RecurringPattern] = [] + for cat, cat_lessons in by_cat.items(): + if len(cat_lessons) < min_corrections: + continue + + descriptions = [l.description for l in cat_lessons] + total_fires = sum(l.fire_count for l in cat_lessons) + confidence = min(0.95, 0.5 + 0.05 * len(cat_lessons) + 0.01 * total_fires) + + patterns.append( + RecurringPattern( + category=cat, + correction_count=len(cat_lessons), + descriptions=descriptions, + summary=_synthesize_summary(cat, descriptions), + confidence=round(confidence, 2), + ) + ) + + patterns.sort(key=lambda p: p.correction_count, reverse=True) + return patterns diff --git a/tests/test_behavioral_extractor.py b/tests/test_behavioral_extractor.py new file mode 100644 index 00000000..ffb3fd82 --- /dev/null +++ b/tests/test_behavioral_extractor.py @@ -0,0 +1,61 @@ +"""Tests for behavioral_extractor — recurring pattern detection.""" + +from gradata.enhancements.behavioral_extractor import detect_recurring_patterns, RecurringPattern +from gradata._types import Lesson, LessonState + + +class TestRecurringPatterns: + def _make_lesson(self, category, description, fire_count=3): + return Lesson( + date="2026-04-10", + state=LessonState.PATTERN, + confidence=0.65, + category=category, + description=description, + fire_count=fire_count, + ) + + def test_groups_related_corrections_by_category(self): + lessons = [ + self._make_lesson("TONE", "Use active voice in emails"), + self._make_lesson("TONE", "Remove hedging from subject lines"), + self._make_lesson("TONE", "Be direct in opening sentences"), + ] + patterns = detect_recurring_patterns(lessons) + assert len(patterns) >= 1 + assert patterns[0].category == "TONE" + assert patterns[0].correction_count >= 3 + + def test_no_pattern_with_fewer_than_3(self): + lessons = [ + self._make_lesson("TONE", "Use active voice"), + self._make_lesson("TONE", "Be direct"), + ] + patterns = detect_recurring_patterns(lessons, min_corrections=3) + assert len(patterns) == 0 + + def test_multiple_category_patterns(self): + lessons = [ + self._make_lesson("TONE", "Use active voice"), + self._make_lesson("TONE", "Remove hedging"), + self._make_lesson("TONE", "Be concise"), + self._make_lesson("PROCESS", "Verify data before sending"), + self._make_lesson("PROCESS", "Check CRM before drafting"), + self._make_lesson("PROCESS", "Always confirm meeting time"), + ] + patterns = detect_recurring_patterns(lessons) + categories = {p.category for p in patterns} + assert "TONE" in categories + assert "PROCESS" in categories + + def test_pattern_includes_summary(self): + lessons = [ + self._make_lesson("TONE", "Use active voice"), + self._make_lesson("TONE", "Remove hedging words"), + self._make_lesson("TONE", "Be direct and concise"), + ] + patterns = detect_recurring_patterns(lessons) + assert patterns[0].summary + + def test_empty_returns_empty(self): + assert detect_recurring_patterns([]) == [] From f4f893e5f8f77730838c636141e590ec2f29bbf7 Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Fri, 10 Apr 2026 12:17:46 -0700 Subject: [PATCH 04/13] feat(P0.2): add log_suppression() Adds RULE_SUPPRESSION event emission to rule_tracker.py for tracking why rules are filtered out during apply_rules pipeline. Co-Authored-By: Gradata --- src/gradata/rules/rule_tracker.py | 49 ++++++++++++++++++- tests/test_rule_tracker_suppression.py | 65 ++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 2 deletions(-) create mode 100644 tests/test_rule_tracker_suppression.py diff --git a/src/gradata/rules/rule_tracker.py b/src/gradata/rules/rule_tracker.py index 8f05384d..43687b7f 100644 --- a/src/gradata/rules/rule_tracker.py +++ b/src/gradata/rules/rule_tracker.py @@ -58,17 +58,58 @@ def log_application( try: from gradata._events import emit + return emit("RULE_APPLICATION", source, data, tags, session) except Exception as e: - logging.getLogger("gradata.rule_tracker").warning("Failed to log rule application for %s: %s", rule_id, e) + logging.getLogger("gradata.rule_tracker").warning( + "Failed to log rule application for %s: %s", rule_id, e + ) return None +VALID_SUPPRESSION_REASONS = { + "relevance_threshold", + "domain_disabled", + "conflict", + "assumption_invalid", +} + + +def log_suppression( + rule_id: str, + reason: str, + relevance: float, + competing_rule_ids: list[str] | None = None, + ctx=None, + session: int | None = None, +) -> None: + """Emit a RULE_SUPPRESSION event when a rule is filtered out. + + Args: + rule_id: The rule that was suppressed. + reason: One of VALID_SUPPRESSION_REASONS. + relevance: The relevance score at time of suppression. + competing_rule_ids: Other rules that caused suppression (for conflicts). + ctx: Optional BrainContext. + session: Optional session number. + """ + from gradata._events import emit + + data = { + "rule_id": rule_id, + "reason": reason, + "relevance": round(relevance, 4), + "competing_rules": competing_rule_ids or [], + } + tags = [f"rule:{rule_id}", f"suppression:{reason}"] + emit("RULE_SUPPRESSION", source="rule_engine", data=data, tags=tags, session=session, ctx=ctx) + def get_session_applications(db_path: Path, session: int) -> list[dict]: """All RULE_APPLICATION events for a given session.""" try: from gradata._events import query + events = query(event_type="RULE_APPLICATION", session=session, limit=500) return [ { @@ -90,6 +131,7 @@ def get_session_applications(db_path: Path, session: int) -> list[dict]: @dataclass class RuleApplication: """Typed record of a single rule application event.""" + rule_id: str session: int | None = None accepted: bool = False @@ -103,6 +145,7 @@ def get_rule_history(db_path: Path, rule_id: str, limit: int = 20) -> list[dict] """Recent RULE_APPLICATION events for a specific rule.""" try: from gradata._events import query + events = query(event_type="RULE_APPLICATION", limit=500) matching = [e for e in events if e.get("data", {}).get("rule_id") == rule_id] return [ @@ -117,5 +160,7 @@ def get_rule_history(db_path: Path, rule_id: str, limit: int = 20) -> list[dict] for e in matching[:limit] ] except Exception as e: - logging.getLogger("gradata.rule_tracker").warning("get_rule_history failed for %s: %s", rule_id, e) + logging.getLogger("gradata.rule_tracker").warning( + "get_rule_history failed for %s: %s", rule_id, e + ) return [] diff --git a/tests/test_rule_tracker_suppression.py b/tests/test_rule_tracker_suppression.py new file mode 100644 index 00000000..a1ac7106 --- /dev/null +++ b/tests/test_rule_tracker_suppression.py @@ -0,0 +1,65 @@ +"""Tests for rule suppression tracking.""" + +from unittest.mock import patch + +import pytest + +from gradata.rules.rule_tracker import VALID_SUPPRESSION_REASONS, log_suppression + +_EMIT_PATH = "gradata._events.emit" + + +class TestLogSuppression: + def test_emits_rule_suppression_event(self): + with patch(_EMIT_PATH) as mock_emit: + log_suppression(rule_id="TONE:abc123", reason="relevance_threshold", relevance=0.25) + mock_emit.assert_called_once() + args = mock_emit.call_args[0] + assert args[0] == "RULE_SUPPRESSION" + + def test_data_contains_rule_id_and_reason(self): + with patch(_EMIT_PATH) as mock_emit: + log_suppression(rule_id="TONE:abc123", reason="domain_disabled", relevance=0.0) + kwargs = mock_emit.call_args[1] + # emit is called with keyword args from log_suppression + call_args = mock_emit.call_args + # Get data from positional or keyword args + data = call_args[1].get("data", call_args[0][2] if len(call_args[0]) > 2 else None) + assert data["rule_id"] == "TONE:abc123" + assert data["reason"] == "domain_disabled" + assert data["relevance"] == 0.0 + assert data["competing_rules"] == [] + + def test_competing_rules_passed_through(self): + with patch(_EMIT_PATH) as mock_emit: + log_suppression( + rule_id="TONE:abc123", + reason="conflict", + relevance=0.5, + competing_rule_ids=["STYLE:def456"], + ) + call_args = mock_emit.call_args + data = call_args[1].get("data", call_args[0][2] if len(call_args[0]) > 2 else None) + assert data["competing_rules"] == ["STYLE:def456"] + + def test_tags_include_rule_and_reason(self): + with patch(_EMIT_PATH) as mock_emit: + log_suppression(rule_id="TONE:abc123", reason="assumption_invalid", relevance=0.3) + call_args = mock_emit.call_args + tags = call_args[1].get("tags", call_args[0][3] if len(call_args[0]) > 3 else None) + assert "rule:TONE:abc123" in tags + assert "suppression:assumption_invalid" in tags + + def test_relevance_rounded(self): + with patch(_EMIT_PATH) as mock_emit: + log_suppression(rule_id="X:1", reason="relevance_threshold", relevance=0.123456789) + call_args = mock_emit.call_args + data = call_args[1].get("data", call_args[0][2] if len(call_args[0]) > 2 else None) + assert data["relevance"] == 0.1235 + + def test_valid_suppression_reasons_complete(self): + assert "relevance_threshold" in VALID_SUPPRESSION_REASONS + assert "domain_disabled" in VALID_SUPPRESSION_REASONS + assert "conflict" in VALID_SUPPRESSION_REASONS + assert "assumption_invalid" in VALID_SUPPRESSION_REASONS + assert len(VALID_SUPPRESSION_REASONS) == 4 From a2b1677fa1876dd019fea667b0ba09d62d17259e Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Fri, 10 Apr 2026 12:22:38 -0700 Subject: [PATCH 05/13] =?UTF-8?q?feat(P2.2):=20temporal=20scope=20?= =?UTF-8?q?=E2=80=94=20decay=20curves=20+=20idle=20limits?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Gradata --- src/gradata/_scope.py | 59 +++++++++++++++++++++++++++++++++++-------- tests/test_scope.py | 49 +++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 10 deletions(-) create mode 100644 tests/test_scope.py diff --git a/src/gradata/_scope.py b/src/gradata/_scope.py index d66db6f4..fb253beb 100644 --- a/src/gradata/_scope.py +++ b/src/gradata/_scope.py @@ -46,7 +46,35 @@ class RuleScope: channel: str = "" stakes: str = "normal" agent_type: str = "" # Agent type for scoped rule injection (e.g. "researcher", "reviewer") - namespace: str = "" # Scope tag for per-context rules (e.g. "api-endpoint", "onboarding") + namespace: str = "" # Scope tag for per-context rules (e.g. "api-endpoint", "onboarding") + temporal_relevance: str = "" # "evergreen", "seasonal", "recent", or "" (wildcard) + max_idle_sessions: int = 0 # Auto-suppress after N idle sessions (0 = never) + created_session: int = 0 # Session number when this scope was first assigned + + +def temporal_decay( + sessions_since_fire: int, + max_idle: int, + floor: float = 0.05, + steepness: float = 3.0, +) -> float: + """Compute temporal decay multiplier for rule confidence. + + Uses exponential decay: exp(-steepness * ratio^2) where + ratio = sessions_since_fire / max_idle. + + Returns decay multiplier in [floor, 1.0]. If max_idle=0, returns 1.0 (evergreen). + """ + if max_idle <= 0: + return 1.0 + if sessions_since_fire <= 0: + return 1.0 + + import math + + ratio = sessions_since_fire / max_idle + decay = math.exp(-steepness * ratio * ratio) + return max(floor, round(decay, 4)) # --------------------------------------------------------------------------- @@ -56,7 +84,7 @@ class RuleScope: # task_type → keywords (any keyword triggers the type; first match wins) _TASK_TYPE_KEYWORDS: list[tuple[str, list[str]]] = [ ("email_draft", ["email", "draft", "write", "compose", "reply", "follow-up", "followup"]), - ("demo_prep", ["demo", "call", "meeting", "prep", "presentation"]), + ("demo_prep", ["demo", "call", "meeting", "prep", "presentation"]), ("prospecting", ["prospect", "lead", "find", "enrich", "list", "sweep"]), ("code_review", ["review", "code", "pr", "pull request"]), ("documentation", ["doc", "readme", "guide", "spec"]), @@ -65,24 +93,24 @@ class RuleScope: # audience → title keywords (checked case-insensitively; first match wins) _AUDIENCE_KEYWORDS: list[tuple[str, list[str]]] = [ ("c_suite", ["ceo", "cto", "coo", "cfo", "cro", "cmo", "chief"]), - ("vp", ["vp", "vice president", "head of"]), + ("vp", ["vp", "vice president", "head of"]), ("director", ["director"]), - ("manager", ["manager"]), + ("manager", ["manager"]), # "ic" is the fallback for anything unmatched ] # channel inference: task keywords → channel _CHANNEL_KEYWORDS: list[tuple[str, list[str]]] = [ - ("email", ["email", "draft", "compose", "reply", "follow-up", "followup"]), - ("slack", ["slack", "message", "dm"]), + ("email", ["email", "draft", "compose", "reply", "follow-up", "followup"]), + ("slack", ["slack", "message", "dm"]), ("document", ["doc", "readme", "guide", "spec", "report"]), - ("call", ["call", "meeting", "demo", "presentation"]), + ("call", ["call", "meeting", "demo", "presentation"]), ] # stakes inference: task keywords → stakes override _STAKES_KEYWORDS: list[tuple[str, list[str]]] = [ - ("high", ["demo", "call", "meeting", "presentation", "proposal", "critical"]), - ("low", ["internal", "draft", "note"]), + ("high", ["demo", "call", "meeting", "presentation", "proposal", "critical"]), + ("low", ["internal", "draft", "note"]), ] @@ -294,5 +322,16 @@ def scope_from_dict(data: dict[str, str]) -> RuleScope: channel='', stakes='high') """ valid_fields = {f for f in RuleScope.__dataclass_fields__} # type: ignore[attr-defined] - filtered = {k: v for k, v in data.items() if k in valid_fields} + # Coerce int fields that were stringified by scope_to_dict + _INT_FIELDS = {"max_idle_sessions", "created_session"} + filtered = {} + for k, v in data.items(): + if k not in valid_fields: + continue + if k in _INT_FIELDS: + try: + v = int(v) + except (ValueError, TypeError): + v = 0 + filtered[k] = v return RuleScope(**filtered) diff --git a/tests/test_scope.py b/tests/test_scope.py new file mode 100644 index 00000000..ca45723a --- /dev/null +++ b/tests/test_scope.py @@ -0,0 +1,49 @@ +"""Tests for _scope — temporal scope fields and decay.""" + +from gradata._scope import RuleScope, temporal_decay + + +class TestTemporalScope: + def test_default_temporal_fields(self): + scope = RuleScope() + assert scope.temporal_relevance == "" + assert scope.max_idle_sessions == 0 + assert scope.created_session == 0 + + def test_custom_temporal_fields(self): + scope = RuleScope( + domain="sales", temporal_relevance="seasonal", max_idle_sessions=15, created_session=42 + ) + assert scope.temporal_relevance == "seasonal" + assert scope.max_idle_sessions == 15 + assert scope.created_session == 42 + + def test_decay_fresh_rule(self): + assert temporal_decay(sessions_since_fire=0, max_idle=20) == 1.0 + + def test_decay_midpoint(self): + factor = temporal_decay(sessions_since_fire=10, max_idle=20) + assert 0.4 < factor < 0.8 + + def test_decay_at_limit(self): + factor = temporal_decay(sessions_since_fire=20, max_idle=20) + assert factor < 0.15 + + def test_decay_beyond_limit(self): + factor = temporal_decay(sessions_since_fire=30, max_idle=20) + assert factor == 0.05 + + def test_decay_zero_max_idle(self): + assert temporal_decay(sessions_since_fire=100, max_idle=0) == 1.0 + + def test_scope_serialization_roundtrip(self): + from gradata._scope import scope_to_dict, scope_from_dict + + scope = RuleScope( + domain="sales", temporal_relevance="recent", max_idle_sessions=10, created_session=5 + ) + d = scope_to_dict(scope) + restored = scope_from_dict(d) + assert restored.temporal_relevance == "recent" + assert restored.max_idle_sessions == 10 + assert restored.created_session == 5 From 66f4134e1e930a3dcd0e62920a684d73883df41f Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Fri, 10 Apr 2026 12:22:52 -0700 Subject: [PATCH 06/13] feat(P2.3): cognitive load taxonomy tag Co-Authored-By: Gradata --- src/gradata/_tag_taxonomy.py | 167 +++++++++++++++++++++++++++-------- tests/test_tag_taxonomy.py | 46 ++++++++++ 2 files changed, 174 insertions(+), 39 deletions(-) create mode 100644 tests/test_tag_taxonomy.py diff --git a/src/gradata/_tag_taxonomy.py b/src/gradata/_tag_taxonomy.py index fc20d360..df93b314 100644 --- a/src/gradata/_tag_taxonomy.py +++ b/src/gradata/_tag_taxonomy.py @@ -29,11 +29,22 @@ "mode": "closed", "values": { # Core categories (from edit classifier: content/factual/tone/structure/style) - "CONTENT", "FACTUAL", "TONE", "STRUCTURE", "STYLE", + "CONTENT", + "FACTUAL", + "TONE", + "STRUCTURE", + "STYLE", # Learning system categories - "DRAFTING", "ACCURACY", "PROCESS", "ARCHITECTURE", - "COMMUNICATION", "CONTEXT", "CONSTRAINT", - "DATA_INTEGRITY", "THOROUGHNESS", "COST", + "DRAFTING", + "ACCURACY", + "PROCESS", + "ARCHITECTURE", + "COMMUNICATION", + "CONTEXT", + "CONSTRAINT", + "DATA_INTEGRITY", + "THOROUGHNESS", + "COST", }, "required_on": ["CORRECTION"], }, @@ -54,6 +65,12 @@ "values": {"direct", "casual", "consultative", "formal", "urgent"}, "required_on": [], }, + "cognitive_load": { + "desc": "Cognitive load type (Sweller's CLT)", + "mode": "closed", + "values": {"intrinsic", "extraneous", "germane"}, + "required_on": [], + }, } # ── Domain-Specific Defaults (Sales) ────────────────────────────────── @@ -75,54 +92,88 @@ "output": { "desc": "Output type produced", "mode": "closed", - "values": {"email", "cheat_sheet", "research", "cold_call_script", - "linkedin_message", "crm_note", "proposal", "sequence", - "report", "system_artifact"}, + "values": { + "email", + "cheat_sheet", + "research", + "cold_call_script", + "linkedin_message", + "crm_note", + "proposal", + "sequence", + "report", + "system_artifact", + }, "required_on": ["OUTPUT"], }, "angle": { "desc": "Communication angle used in messaging", "mode": "closed", - "values": {"direct", "pain-point", "time-savings", "roi", - "competitor-displacement", "social-proof", "curiosity", - "event-triggered", "referral", "break-up", "custom"}, + "values": { + "direct", + "pain-point", + "time-savings", + "roi", + "competitor-displacement", + "social-proof", + "curiosity", + "event-triggered", + "referral", + "break-up", + "custom", + }, "required_on": [], }, "persona": { "desc": "Buyer persona type", "mode": "closed", - "values": {"agency-owner", "founder", "ecom-director", "cmo", - "growth-lead", "fractional-cmo", "marketing-manager", - "vp-marketing", "other"}, + "values": { + "agency-owner", + "founder", + "ecom-director", + "cmo", + "growth-lead", + "fractional-cmo", + "marketing-manager", + "vp-marketing", + "other", + }, "required_on": [], }, "framework": { "desc": "Sales/email framework applied", "mode": "closed", - "values": {"ccq", "spin", "gap", "jolt", "challenger", - "great-demo", "sandler", "custom"}, + "values": {"ccq", "spin", "gap", "jolt", "challenger", "great-demo", "sandler", "custom"}, "required_on": [], }, "channel": { "desc": "Communication channel", "mode": "closed", - "values": {"email", "phone", "linkedin", "meeting", "demo", - "slack", "text"}, + "values": {"email", "phone", "linkedin", "meeting", "demo", "slack", "text"}, "required_on": ["DELTA_TAG"], }, "outcome": { "desc": "Interaction outcome", "mode": "closed", - "values": {"reply", "no-reply", "positive-reply", "negative-reply", - "meeting-booked", "demo-completed", "deal-advanced", - "deal-lost", "ghosted", "objection-raised", "pending"}, + "values": { + "reply", + "no-reply", + "positive-reply", + "negative-reply", + "meeting-booked", + "demo-completed", + "deal-advanced", + "deal-lost", + "ghosted", + "objection-raised", + "pending", + }, "required_on": [], }, } # Sales-specific correction categories (extend core set) -_DOMAIN_CATEGORIES = {"CRM", "STRATEGY", "ENTITIES", "POSITIONING", - "PRESENTATION", "STARTUP"} +_DOMAIN_CATEGORIES = {"CRM", "STRATEGY", "ENTITIES", "POSITIONING", "PRESENTATION", "STARTUP"} def _load_taxonomy() -> dict: @@ -136,7 +187,9 @@ def _load_taxonomy() -> dict: taxonomy = dict(_CORE_TAXONOMY) # Try loading domain config from brain directory - taxonomy_path = _p.BRAIN_DIR / "taxonomy.json" if hasattr(_p, 'BRAIN_DIR') and _p.BRAIN_DIR else None + taxonomy_path = ( + _p.BRAIN_DIR / "taxonomy.json" if hasattr(_p, "BRAIN_DIR") and _p.BRAIN_DIR else None + ) if taxonomy_path and taxonomy_path.exists(): try: with open(taxonomy_path, encoding="utf-8") as f: @@ -158,8 +211,8 @@ def _load_taxonomy() -> dict: taxonomy[key] = spec # Extend core category values if domain provides extra categories if "extra_categories" in domain_taxonomy: - taxonomy["category"]["values"] = ( - taxonomy["category"]["values"] | set(domain_taxonomy["extra_categories"]) + taxonomy["category"]["values"] = taxonomy["category"]["values"] | set( + domain_taxonomy["extra_categories"] ) return taxonomy except Exception: @@ -184,6 +237,7 @@ def reload_taxonomy() -> None: # ── Validation ───────────────────────────────────────────────────────── + def _get_entity_names() -> set[str]: """Get valid entity names from brain's entity directory. @@ -193,7 +247,7 @@ def _get_entity_names() -> set[str]: names = set() # Try common entity directory names for dirname in ("prospects", "candidates", "customers", "entities"): - entity_dir = _p.BRAIN_DIR / dirname if hasattr(_p, 'BRAIN_DIR') and _p.BRAIN_DIR else None + entity_dir = _p.BRAIN_DIR / dirname if hasattr(_p, "BRAIN_DIR") and _p.BRAIN_DIR else None if entity_dir and entity_dir.exists(): for f in entity_dir.glob("*.md"): if f.name.startswith("_"): @@ -201,7 +255,7 @@ def _get_entity_names() -> set[str]: parts = re.split(r"\s*[—–-]\s*", f.stem, maxsplit=1) names.add(parts[0].strip()) # Legacy fallback - if not names and hasattr(_p, 'PROSPECTS_DIR'): + if not names and hasattr(_p, "PROSPECTS_DIR"): prospects_dir = _p.PROSPECTS_DIR if prospects_dir.exists(): for f in prospects_dir.glob("*.md"): @@ -236,8 +290,9 @@ def validate_tag(tag: str, strict: bool = False) -> tuple[bool, str]: return True, "ok" -def validate_tags(tags: list[str], event_type: str | None = None, - strict: bool = False) -> list[str]: +def validate_tags( + tags: list[str], event_type: str | None = None, strict: bool = False +) -> list[str]: issues = [] for tag in tags: valid, msg = validate_tag(tag, strict=strict) @@ -247,12 +302,13 @@ def validate_tags(tags: list[str], event_type: str | None = None, present_prefixes = {t.split(":")[0] for t in tags if ":" in t} for prefix, spec in TAXONOMY.items(): if event_type in spec.get("required_on", []) and prefix not in present_prefixes: - issues.append(f"Missing required tag '{prefix}:' for {event_type} events") + issues.append(f"Missing required tag '{prefix}:' for {event_type} events") return issues -def enrich_tags(tags: list[str], event_type: str | None = None, - data: dict | None = None) -> list[str]: +def enrich_tags( + tags: list[str], event_type: str | None = None, data: dict | None = None +) -> list[str]: enriched = list(tags) prefixes = {t.split(":")[0] for t in enriched if ":" in t} data = data or {} @@ -266,10 +322,15 @@ def enrich_tags(tags: list[str], event_type: str | None = None, if ot: # Normalize common variants to taxonomy values output_map = { - "email_draft": "email", "email_reply": "email", "follow_up": "email", - "call_script": "cold_call_script", "call_prep": "cold_call_script", - "demo_prep": "cheat_sheet", "meeting_prep": "cheat_sheet", - "linkedin_dm": "linkedin_message", "linkedin_connect": "linkedin_message", + "email_draft": "email", + "email_reply": "email", + "follow_up": "email", + "call_script": "cold_call_script", + "call_prep": "cold_call_script", + "demo_prep": "cheat_sheet", + "meeting_prep": "cheat_sheet", + "linkedin_dm": "linkedin_message", + "linkedin_connect": "linkedin_message", } normalized = output_map.get(ot, ot) enriched.append(f"output:{normalized}") @@ -280,12 +341,40 @@ def enrich_tags(tags: list[str], event_type: str | None = None, if event_type == "DELTA_TAG" and "channel" not in prefixes: activity = data.get("activity_type", "") channel_map = { - "email_sent": "email", "email_received": "email", - "call": "phone", "meeting": "meeting", - "demo": "demo", "linkedin": "linkedin", + "email_sent": "email", + "email_received": "email", + "call": "phone", + "meeting": "meeting", + "demo": "demo", + "linkedin": "linkedin", } if activity in channel_map: enriched.append(f"channel:{channel_map[activity]}") + if event_type == "CORRECTION" and "cognitive_load" not in prefixes: + cat = data.get("category", "").upper() + _COGNITIVE_LOAD_MAP = { + "FACTUAL": "intrinsic", + "CONTENT": "intrinsic", + "DATA_INTEGRITY": "intrinsic", + "ENTITIES": "intrinsic", + "CONTEXT": "intrinsic", + "STYLE": "extraneous", + "TONE": "extraneous", + "COMMUNICATION": "extraneous", + "PRESENTATION": "extraneous", + "COST": "extraneous", + "PROCESS": "germane", + "DRAFTING": "germane", + "ACCURACY": "germane", + "THOROUGHNESS": "germane", + "ARCHITECTURE": "germane", + "CONSTRAINT": "germane", + "STRATEGY": "germane", + "POSITIONING": "germane", + } + cl = _COGNITIVE_LOAD_MAP.get(cat) + if cl: + enriched.append(f"cognitive_load:{cl}") return enriched diff --git a/tests/test_tag_taxonomy.py b/tests/test_tag_taxonomy.py new file mode 100644 index 00000000..687a6939 --- /dev/null +++ b/tests/test_tag_taxonomy.py @@ -0,0 +1,46 @@ +"""Tests for _tag_taxonomy — cognitive load tag and enrichment.""" + +from gradata._tag_taxonomy import TAXONOMY, validate_tag, enrich_tags + + +class TestCognitiveLoadTaxonomy: + def test_cognitive_load_in_taxonomy(self): + assert "cognitive_load" in TAXONOMY + spec = TAXONOMY["cognitive_load"] + assert spec["mode"] == "closed" + assert spec["values"] == {"intrinsic", "extraneous", "germane"} + + def test_validate_valid(self): + valid, msg = validate_tag("cognitive_load:intrinsic") + assert valid + valid, msg = validate_tag("cognitive_load:extraneous") + assert valid + valid, msg = validate_tag("cognitive_load:germane") + assert valid + + def test_validate_invalid(self): + valid, msg = validate_tag("cognitive_load:bogus") + assert not valid + assert "Invalid cognitive_load value" in msg + + def test_enrich_factual_is_intrinsic(self): + tags = enrich_tags([], event_type="CORRECTION", data={"category": "FACTUAL"}) + assert "cognitive_load:intrinsic" in tags + + def test_enrich_style_is_extraneous(self): + tags = enrich_tags([], event_type="CORRECTION", data={"category": "STYLE"}) + assert "cognitive_load:extraneous" in tags + + def test_enrich_process_is_germane(self): + tags = enrich_tags([], event_type="CORRECTION", data={"category": "PROCESS"}) + assert "cognitive_load:germane" in tags + + def test_enrich_does_not_override_explicit(self): + tags = enrich_tags( + ["cognitive_load:extraneous"], + event_type="CORRECTION", + data={"category": "FACTUAL"}, + ) + # Should keep explicit value, not add intrinsic + assert tags.count("cognitive_load:extraneous") == 1 + assert "cognitive_load:intrinsic" not in tags From 78a70eb2187da8227d1492e748cdb7f82a8b2db5 Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Fri, 10 Apr 2026 12:28:26 -0700 Subject: [PATCH 07/13] feat(P1.1): integrate beta posterior into confidence pipeline Add alpha/beta_param persistence in parse_lessons/format_lessons for round-trip serialization. Bayesian blending already wired in Task 1. Co-Authored-By: Gradata --- src/gradata/enhancements/self_improvement.py | 235 ++++++++++++++----- tests/test_bayesian_confidence.py | 84 +++++++ 2 files changed, 255 insertions(+), 64 deletions(-) create mode 100644 tests/test_bayesian_confidence.py diff --git a/src/gradata/enhancements/self_improvement.py b/src/gradata/enhancements/self_improvement.py index 85feadca..c2bf4581 100644 --- a/src/gradata/enhancements/self_improvement.py +++ b/src/gradata/enhancements/self_improvement.py @@ -51,20 +51,20 @@ # SPEC Section 1: maturity-aware kill switches (relevant cycles only) KILL_LIMITS: dict[str, int] = { - "INFANT": 8, # 0-50 sessions — unproven, die fast - "ADOLESCENT": 12, # 50-100 sessions — some evidence, moderate grace - "MATURE": 15, # 100-200 sessions — proven useful, longer grace - "STABLE": 20, # 200+ sessions — battle-tested, longest grace + "INFANT": 8, # 0-50 sessions — unproven, die fast + "ADOLESCENT": 12, # 50-100 sessions — some evidence, moderate grace + "MATURE": 15, # 100-200 sessions — proven useful, longer grace + "STABLE": 20, # 200+ sessions — battle-tested, longest grace } # Severity multipliers for contradiction penalty # Typo fix barely dents confidence; rewrite hits hard SEVERITY_WEIGHTS: dict[str, float] = { - "trivial": 0.20, # typo fix: -0.10 * 0.20 = -0.02 - "minor": 0.50, # word swap: -0.10 * 0.50 = -0.05 - "moderate": 0.80, # sentence rewrite: -0.10 * 0.80 = -0.08 - "major": 1.00, # significant change: -0.10 * 1.00 = -0.10 - "rewrite": 1.30, # complete rewrite: -0.10 * 1.30 = -0.13 + "trivial": 0.20, # typo fix: -0.10 * 0.20 = -0.02 + "minor": 0.50, # word swap: -0.10 * 0.50 = -0.05 + "moderate": 0.80, # sentence rewrite: -0.10 * 0.80 = -0.08 + "major": 1.00, # significant change: -0.10 * 1.00 = -0.10 + "rewrite": 1.30, # complete rewrite: -0.10 * 1.30 = -0.13 } # Legacy survival severity weights (v2.3: survival is flat 0.08, these kept @@ -100,7 +100,7 @@ } # Graduation gate thresholds -_GRADUATION_DEDUP_THRESHOLD = 0.85 # Near-duplicate rule detection +_GRADUATION_DEDUP_THRESHOLD = 0.85 # Near-duplicate rule detection # Auto-detection threshold: if a session has more than this many corrections, # it's likely machine-driven. Set high enough that intensive human sessions @@ -124,6 +124,7 @@ def detect_correction_poisoning(corrections: list[dict]) -> list[str]: These categories should be treated with reduced confidence updates. """ from collections import defaultdict + by_category: dict[str, list[str]] = defaultdict(list) for c in corrections: cat = c.get("category", "").upper() @@ -146,8 +147,13 @@ def detect_correction_poisoning(corrections: list[dict]) -> list[str]: contradictions += 1 if comparisons > 0 and contradictions / comparisons >= _POISONING_CONTRADICTION_RATE: poisoned.append(cat) - _log.warning("Poisoning detected in %s: %.0f%% contradictions (%d/%d)", - cat, contradictions / comparisons * 100, contradictions, comparisons) + _log.warning( + "Poisoning detected in %s: %.0f%% contradictions (%d/%d)", + cat, + contradictions / comparisons * 100, + contradictions, + comparisons, + ) return poisoned @@ -165,9 +171,11 @@ def _detect_machine_context( is_machine = len(corrections) > _MACHINE_CORRECTION_THRESHOLD if is_machine: import logging + logging.getLogger(__name__).info( "Auto-detected machine context: %d corrections (threshold: %d)", - len(corrections), _MACHINE_CORRECTION_THRESHOLD, + len(corrections), + _MACHINE_CORRECTION_THRESHOLD, ) return is_machine @@ -181,14 +189,29 @@ def _detect_machine_context( # Session-type → categories that are testable in that session type # DRAFTING is immune during system sessions; ARCHITECTURE is immune during sales CATEGORY_SESSION_MAP: dict[str, frozenset[str]] = { - "full": frozenset(), # empty = all categories testable - "systems": frozenset({ - "ARCHITECTURE", "PROCESS", "TOOL", "THOROUGHNESS", "CONTEXT", - }), - "sales": frozenset({ - "DRAFTING", "LEADS", "PRICING", "DEMO_PREP", "POSITIONING", - "COMMUNICATION", "TONE", "ACCURACY", "DATA_INTEGRITY", - }), + "full": frozenset(), # empty = all categories testable + "systems": frozenset( + { + "ARCHITECTURE", + "PROCESS", + "TOOL", + "THOROUGHNESS", + "CONTEXT", + } + ), + "sales": frozenset( + { + "DRAFTING", + "LEADS", + "PRICING", + "DEMO_PREP", + "POSITIONING", + "COMMUNICATION", + "TONE", + "ACCURACY", + "DATA_INTEGRITY", + } + ), } @@ -286,38 +309,58 @@ def parse_lessons(text: str) -> list[Lesson]: memory_ids: list[str] = [] scope_json: str = "" domain_scores: dict = {} + alpha = 1.0 + beta_param_val = 1.0 j = i + 1 while j < len(lines) and lines[j].startswith(" "): meta_line = lines[j].strip() if meta_line.startswith("Root cause:") and not root_cause: - root_cause = meta_line[len("Root cause:"):].strip() + root_cause = meta_line[len("Root cause:") :].strip() elif meta_line.startswith("Agent:"): - agent_type = meta_line[len("Agent:"):].strip() + agent_type = meta_line[len("Agent:") :].strip() elif meta_line.startswith("Kill reason:"): - kill_reason = meta_line[len("Kill reason:"):].strip() + kill_reason = meta_line[len("Kill reason:") :].strip() elif meta_line.startswith("Corrections:"): - ids_str = meta_line[len("Corrections:"):].strip() + ids_str = meta_line[len("Corrections:") :].strip() correction_event_ids = [x.strip() for x in ids_str.split(",") if x.strip()] elif meta_line.startswith("Pending approval:"): - pending_approval = meta_line[len("Pending approval:"):].strip().lower() == "yes" + pending_approval = meta_line[len("Pending approval:") :].strip().lower() == "yes" elif meta_line.startswith("Parent meta-rule:"): - parent_meta_rule_id = meta_line[len("Parent meta-rule:"):].strip() or None + parent_meta_rule_id = meta_line[len("Parent meta-rule:") :].strip() or None elif meta_line.startswith("Memory links:"): - memory_ids = [x.strip() for x in meta_line[len("Memory links:"):].strip().split(",") if x.strip()] + memory_ids = [ + x.strip() + for x in meta_line[len("Memory links:") :].strip().split(",") + if x.strip() + ] elif meta_line.startswith("Scope:"): - scope_json = meta_line[len("Scope:"):].strip() + scope_json = meta_line[len("Scope:") :].strip() + elif meta_line.startswith("Beta params:"): + import json as _json_bp + + try: + bp = _json_bp.loads(meta_line[len("Beta params:") :].strip()) + alpha = bp.get("alpha", 1.0) + beta_param_val = bp.get("beta", 1.0) + except _json_bp.JSONDecodeError: + pass elif meta_line.startswith("Domain scores:"): import json as _json + try: - domain_scores = _json.loads(meta_line[len("Domain scores:"):].strip()) + domain_scores = _json.loads(meta_line[len("Domain scores:") :].strip()) except _json.JSONDecodeError: domain_scores = {} elif meta_line.startswith("Metadata:"): import json as _json_md + try: - _md_dict = _json_md.loads(meta_line[len("Metadata:"):].strip()) + _md_dict = _json_md.loads(meta_line[len("Metadata:") :].strip()) from gradata._types import RuleMetadata as _RM - metadata_obj = _RM(**{k: v for k, v in _md_dict.items() if k in _RM.__dataclass_fields__}) + + metadata_obj = _RM( + **{k: v for k, v in _md_dict.items() if k in _RM.__dataclass_fields__} + ) except (ValueError, TypeError, _json_md.JSONDecodeError): metadata_obj = None meta_m = _META_RE.search(meta_line) @@ -345,6 +388,8 @@ def parse_lessons(text: str) -> list[Lesson]: parent_meta_rule_id=parent_meta_rule_id, memory_ids=memory_ids, domain_scores=domain_scores, + alpha=alpha, + beta_param=beta_param_val, ) if metadata_obj is not None: _lesson.metadata = metadata_obj @@ -390,6 +435,45 @@ def fsrs_penalty(confidence: float, *, machine: bool = False) -> float: return round(base * (0.5 + confidence * 0.5), 4) +# --------------------------------------------------------------------------- +# Bayesian Confidence Blending +# --------------------------------------------------------------------------- + +_BAYESIAN_BLEND_MIN_OBS = 5 +_BAYESIAN_BLEND_MAX_OBS = 20 + + +def _bayesian_blend_weight(total_observations: int) -> float: + """Return weight for Bayesian component [0.0, 0.7].""" + if total_observations < _BAYESIAN_BLEND_MIN_OBS: + return 0.0 + if total_observations >= _BAYESIAN_BLEND_MAX_OBS: + return 0.7 + return ( + 0.7 + * (total_observations - _BAYESIAN_BLEND_MIN_OBS) + / (_BAYESIAN_BLEND_MAX_OBS - _BAYESIAN_BLEND_MIN_OBS) + ) + + +def _bayesian_confidence(lesson: "Lesson") -> float: + """Compute blended confidence from beta posterior + FSRS.""" + from gradata._stats import beta_posterior + + total_obs = int(lesson.alpha + lesson.beta_param - 2) + blend_w = _bayesian_blend_weight(total_obs) + + if blend_w == 0.0: + return lesson.confidence + + post = beta_posterior( + successes=max(0, int(lesson.alpha - 1)), + trials=max(0, total_obs), + ) + bayesian_conf = post["posterior_mean"] + return round(blend_w * bayesian_conf + (1.0 - blend_w) * lesson.confidence, 2) + + # --------------------------------------------------------------------------- # Correction Direction Detection # --------------------------------------------------------------------------- @@ -571,6 +655,7 @@ def update_confidence( if direction == "REINFORCING": # Reinforcing: correction aligns with lesson direction → BONUS + lesson.alpha += 1.0 base_bonus = fsrs_bonus(lesson.confidence, machine=is_machine) if cat in severity_data: raw_severity = severity_data[cat] @@ -579,12 +664,12 @@ def update_confidence( bonus = base_bonus * weight else: bonus = base_bonus - lesson.confidence = round( - max(0.0, min(1.0, lesson.confidence + bonus)), 2 - ) + lesson.confidence = round(max(0.0, min(1.0, lesson.confidence + bonus)), 2) + lesson.confidence = max(0.0, min(1.0, _bayesian_confidence(lesson))) lesson.fire_count += 1 else: # CONTRADICTING or UNKNOWN: FSRS penalty + lesson.beta_param += 1.0 base_penalty = fsrs_penalty(lesson.confidence, machine=is_machine) if cat in severity_data: raw_severity = severity_data[cat] @@ -596,9 +681,8 @@ def update_confidence( # Preferences decay slower — stable user taste signals if lesson.correction_type == CorrectionType.PREFERENCE: penalty *= PREFERENCE_DECAY_DAMPER - lesson.confidence = round( - max(0.0, min(1.0, lesson.confidence - penalty)), 2 - ) + lesson.confidence = round(max(0.0, min(1.0, lesson.confidence - penalty)), 2) + lesson.confidence = max(0.0, min(1.0, _bayesian_confidence(lesson))) else: # Survived: category was testable, corrections exist elsewhere base_bonus = fsrs_bonus(lesson.confidence, machine=is_machine) @@ -620,9 +704,7 @@ def update_confidence( bonus = base_bonus else: bonus = base_bonus - lesson.confidence = round( - max(0.0, min(1.0, lesson.confidence + bonus)), 2 - ) + lesson.confidence = round(max(0.0, min(1.0, lesson.confidence + bonus)), 2) # Track sessions since fire lesson.sessions_since_fire += 1 @@ -631,9 +713,8 @@ def update_confidence( if ( lesson.fire_count == 0 and lesson.sessions_since_fire >= kill_limit - and lesson.state not in ( - LessonState.UNTESTABLE, LessonState.KILLED, LessonState.ARCHIVED - ) + and lesson.state + not in (LessonState.UNTESTABLE, LessonState.KILLED, LessonState.ARCHIVED) ): lesson.state = LessonState.UNTESTABLE @@ -643,6 +724,7 @@ def update_confidence( # Use salted thresholds consistent with graduate() if salt: from gradata.security.brain_salt import salt_threshold + _uc_pattern_thr = salt_threshold(PATTERN_THRESHOLD, salt, "PATTERN") _uc_rule_thr = salt_threshold(RULE_THRESHOLD, salt, "RULE") else: @@ -664,10 +746,7 @@ def update_confidence( ): lesson.state = transition(lesson.state, "promote") # Demote PATTERN -> INSTINCT - elif ( - lesson.state == LessonState.PATTERN - and lesson.confidence < _uc_pattern_thr - ): + elif lesson.state == LessonState.PATTERN and lesson.confidence < _uc_pattern_thr: lesson.state = transition(lesson.state, "demote") return lessons @@ -709,6 +788,7 @@ def graduate( # Compute effective thresholds (salted or default) if salt: from gradata.security.brain_salt import salt_threshold + eff_pattern_threshold = salt_threshold(PATTERN_THRESHOLD, salt, "PATTERN") eff_rule_threshold = salt_threshold(RULE_THRESHOLD, salt, "RULE") else: @@ -716,7 +796,9 @@ def graduate( eff_rule_threshold = RULE_THRESHOLD if renter: active = [l for l in lessons if l.state in (LessonState.INSTINCT, LessonState.PATTERN)] - graduated = [l for l in lessons if l.state not in (LessonState.INSTINCT, LessonState.PATTERN)] + graduated = [ + l for l in lessons if l.state not in (LessonState.INSTINCT, LessonState.PATTERN) + ] return active, graduated eff_kill_limits = MACHINE_KILL_LIMITS if machine_mode else KILL_LIMITS @@ -733,13 +815,16 @@ def graduate( # Safety assertion: warn if a single session caused an unusually # large confidence jump (possible runaway boosting). - pre_conf = getattr(lesson, '_pre_session_confidence', None) + pre_conf = getattr(lesson, "_pre_session_confidence", None) if isinstance(pre_conf, (int, float)): jump = lesson.confidence - pre_conf if jump > eff_pattern_threshold: _log.warning( "Safety assertion: confidence jump %.2f exceeds PATTERN_THRESHOLD %.2f for %s: %s", - jump, eff_pattern_threshold, lesson.category, lesson.description[:60], + jump, + eff_pattern_threshold, + lesson.category, + lesson.description[:60], ) if lesson.pending_approval: @@ -750,8 +835,12 @@ def graduate( if lesson.scope_json: try: import json as _json + _scope = _json.loads(lesson.scope_json) - if _scope.get("correction_scope") == "one_off" and lesson.state in (LessonState.INSTINCT, LessonState.PATTERN): + if _scope.get("correction_scope") == "one_off" and lesson.state in ( + LessonState.INSTINCT, + LessonState.PATTERN, + ): block_promotion = True except (ValueError, TypeError): pass @@ -803,12 +892,16 @@ def graduate( # outside the lesson loop to avoid redundant tokenization (O(n*m) → O(n+m)). try: from gradata.enhancements.similarity import semantic_similarity + for existing in existing_rules: sim = semantic_similarity(lesson.description, existing.description) if sim > _GRADUATION_DEDUP_THRESHOLD: blocked = True - _log.debug("Graduation blocked (duplicate): %.2f sim with '%s'", - sim, existing.description[:60]) + _log.debug( + "Graduation blocked (duplicate): %.2f sim with '%s'", + sim, + existing.description[:60], + ) break except ImportError: pass @@ -817,12 +910,14 @@ def graduate( if not blocked: try: direction = _classify_correction_direction( - lesson.description, existing_rule_summary, + lesson.description, + existing_rule_summary, ) if direction == "CONTRADICTING": blocked = True - _log.debug("Graduation blocked (contradiction): '%s'", - lesson.description[:60]) + _log.debug( + "Graduation blocked (contradiction): '%s'", lesson.description[:60] + ) except Exception: pass @@ -832,6 +927,7 @@ def graduate( if not blocked: try: from gradata.enhancements.similarity import semantic_similarity + words = lesson.description.split() if len(words) >= 4: mid = len(words) // 2 @@ -839,8 +935,11 @@ def graduate( sim = semantic_similarity(lesson.description, paraphrase) if sim < 0.25: blocked = True - _log.debug("Graduation blocked (fragile wording): sim=%.2f '%s'", - sim, lesson.description[:60]) + _log.debug( + "Graduation blocked (fragile wording): sim=%.2f '%s'", + sim, + lesson.description[:60], + ) except ImportError: pass @@ -850,6 +949,7 @@ def graduate( # Refine rule wording before graduation try: from gradata.contrib.patterns.tree_of_thoughts import evaluate_rule_candidates + result = evaluate_rule_candidates(lesson.description, existing_rule_descs) if result.best.content and result.best.score > 0.3: lesson.description = result.best.content @@ -869,10 +969,7 @@ def graduate( continue # Demote PATTERN -> INSTINCT - if ( - lesson.state == LessonState.PATTERN - and lesson.confidence < eff_pattern_threshold - ): + if lesson.state == LessonState.PATTERN and lesson.confidence < eff_pattern_threshold: lesson.state = transition(lesson.state, "demote") continue @@ -970,10 +1067,19 @@ def format_lessons(lessons: list[Lesson]) -> str: if lesson.domain_scores: import json as _json + lines.append(f" Domain scores: {_json.dumps(lesson.domain_scores)}") + if lesson.alpha != 1.0 or lesson.beta_param != 1.0: + import json as _json_bp + + lines.append( + f" Beta params: {_json_bp.dumps({'alpha': lesson.alpha, 'beta': lesson.beta_param})}" + ) + if hasattr(lesson, "metadata") and lesson.metadata is not None: import json as _json_meta + md = lesson.metadata.to_dict() if hasattr(lesson.metadata, "to_dict") else {} if any(v for v in md.values() if v and v != 0.5): # only write non-default lines.append(f" Metadata: {_json_meta.dumps(md)}") @@ -1021,7 +1127,8 @@ def compute_learning_velocity( rule_lessons = [l for l in lessons if l.state == LessonState.RULE] avg_time = ( sum(l.fire_count + l.sessions_since_fire for l in rule_lessons) / len(rule_lessons) - if rule_lessons else 0.0 + if rule_lessons + else 0.0 ) return { diff --git a/tests/test_bayesian_confidence.py b/tests/test_bayesian_confidence.py new file mode 100644 index 00000000..71a5c7f9 --- /dev/null +++ b/tests/test_bayesian_confidence.py @@ -0,0 +1,84 @@ +"""Tests for Bayesian confidence integration in self_improvement.""" + +from gradata._types import Lesson, LessonState +from gradata.enhancements.self_improvement import update_confidence, format_lessons, parse_lessons + + +class TestBayesianConfidence: + def test_reinforcing_increments_alpha(self): + lesson = Lesson( + date="2026-04-10", + state=LessonState.PATTERN, + confidence=0.65, + category="TONE", + description="Be direct", + alpha=3.0, + beta_param=1.0, + ) + corrections = [{"category": "TONE", "direction": "REINFORCING"}] + update_confidence([lesson], corrections) + assert lesson.alpha > 3.0 + assert lesson.beta_param == 1.0 + + def test_contradicting_increments_beta(self): + lesson = Lesson( + date="2026-04-10", + state=LessonState.PATTERN, + confidence=0.65, + category="TONE", + description="Be direct", + alpha=3.0, + beta_param=1.0, + ) + corrections = [{"category": "TONE", "direction": "CONTRADICTING"}] + update_confidence([lesson], corrections) + assert lesson.alpha == 3.0 + assert lesson.beta_param > 1.0 + + def test_bayesian_blending_produces_valid_confidence(self): + lesson = Lesson( + date="2026-04-10", + state=LessonState.INSTINCT, + confidence=0.40, + category="TONE", + description="Be direct", + alpha=1.0, + beta_param=1.0, + ) + for _ in range(5): + corrections = [{"category": "TONE", "direction": "REINFORCING"}] + update_confidence([lesson], corrections) + assert 0.0 <= lesson.confidence <= 1.0 + assert lesson.alpha > 1.0 + + def test_survival_does_not_touch_beta_params(self): + lesson = Lesson( + date="2026-04-10", + state=LessonState.PATTERN, + confidence=0.65, + category="TONE", + description="Be direct", + alpha=3.0, + beta_param=1.0, + ) + corrections = [{"category": "STYLE"}] + update_confidence([lesson], corrections) + assert lesson.alpha == 3.0 + assert lesson.beta_param == 1.0 + + def test_alpha_beta_roundtrip(self): + lesson = Lesson( + date="2026-04-10", + state=LessonState.PATTERN, + confidence=0.72, + category="TONE", + description="Be direct", + alpha=8.5, + beta_param=2.0, + fire_count=7, + ) + text = format_lessons([lesson]) + restored = parse_lessons(text) + assert len(restored) == 1 + assert restored[0].alpha == 8.5 + assert restored[0].beta_param == 2.0 From 584b44cc5f9fc29e23eb534540bcb21cca521b70 Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Fri, 10 Apr 2026 12:28:41 -0700 Subject: [PATCH 08/13] feat(P1.2): positive signal detection + OUTPUT_ACCEPTED Add APPROVAL_PATTERNS to implicit_feedback hook for detecting user approval signals (looks good, perfect, ship it, etc). Emit OUTPUT_ACCEPTED event when approval detected. Co-Authored-By: Gradata --- src/gradata/hooks/implicit_feedback.py | 36 +++++++++++++++++--- tests/test_approval_signals.py | 46 ++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 5 deletions(-) create mode 100644 tests/test_approval_signals.py diff --git a/src/gradata/hooks/implicit_feedback.py b/src/gradata/hooks/implicit_feedback.py index f5f8d03f..2db735e1 100644 --- a/src/gradata/hooks/implicit_feedback.py +++ b/src/gradata/hooks/implicit_feedback.py @@ -1,4 +1,5 @@ """UserPromptSubmit hook: detect implicit feedback signals in user messages.""" + from __future__ import annotations import logging @@ -44,10 +45,21 @@ re.compile(r"\bwhy (did|would|are) you\b", re.I), ] +APPROVAL_PATTERNS = [ + re.compile(r"\blooks? good\b", re.I), + re.compile(r"\bperfect\b", re.I), + re.compile(r"\bexactly( what)?\b", re.I), + re.compile(r"\bthat'?s (right|correct|great|perfect)\b", re.I), + re.compile(r"\byes[,.]?\s+(exactly|perfect|great|right|that)\b", re.I), + re.compile(r"\bship it\b", re.I), + re.compile(r"\bnailed it\b", re.I), +] + SIGNAL_MAP = { "negation": NEGATION_PATTERNS, "reminder": REMINDER_PATTERNS, "challenge": CHALLENGE_PATTERNS, + "approval": APPROVAL_PATTERNS, } @@ -60,11 +72,13 @@ def _detect_signals(text: str) -> list[dict]: start = max(0, match.start() - 20) end = min(len(text), match.end() + 40) snippet = text[start:end].strip() - signals.append({ - "type": signal_type, - "match": match.group(), - "snippet": snippet, - }) + signals.append( + { + "type": signal_type, + "match": match.group(), + "snippet": snippet, + } + ) break # One match per category is enough return signals @@ -86,6 +100,7 @@ def main(data: dict) -> dict | None: try: from gradata._events import emit from gradata._paths import BrainContext + ctx = BrainContext.from_brain_dir(brain_dir) emit( "IMPLICIT_FEEDBACK", @@ -97,6 +112,17 @@ def main(data: dict) -> dict | None: }, ctx=ctx, ) + # Emit OUTPUT_ACCEPTED for approval signals + if any(s["type"] == "approval" for s in signals): + emit( + "OUTPUT_ACCEPTED", + source="hook:implicit_feedback", + data={ + "snippets": [s["snippet"] for s in signals if s["type"] == "approval"], + "message_preview": message[:200], + }, + ctx=ctx, + ) except Exception as exc: _log.debug("implicit_feedback emit failed: %s", exc) diff --git a/tests/test_approval_signals.py b/tests/test_approval_signals.py new file mode 100644 index 00000000..a8499f29 --- /dev/null +++ b/tests/test_approval_signals.py @@ -0,0 +1,46 @@ +"""Tests for approval signal detection in implicit_feedback hook.""" + +from gradata.hooks.implicit_feedback import _detect_signals + + +class TestApprovalSignals: + def test_looks_good(self): + signals = _detect_signals("Looks good, ship it.") + types = [s["type"] for s in signals] + assert "approval" in types + + def test_perfect(self): + signals = _detect_signals("Perfect, that's what I needed.") + types = [s["type"] for s in signals] + assert "approval" in types + + def test_exactly(self): + signals = _detect_signals("Yes, exactly what I wanted.") + types = [s["type"] for s in signals] + assert "approval" in types + + def test_thats_correct(self): + signals = _detect_signals("That's correct.") + types = [s["type"] for s in signals] + assert "approval" in types + + def test_nailed_it(self): + signals = _detect_signals("You nailed it!") + types = [s["type"] for s in signals] + assert "approval" in types + + def test_ship_it(self): + signals = _detect_signals("Ship it!") + types = [s["type"] for s in signals] + assert "approval" in types + + def test_no_false_positive_on_neutral(self): + signals = _detect_signals("Can you update the database schema?") + types = [s["type"] for s in signals] + assert "approval" not in types + + def test_negation_not_approval(self): + signals = _detect_signals("No, that's wrong.") + types = [s["type"] for s in signals] + assert "negation" in types + assert "approval" not in types From 3f6ee5dc1ea7983e827729aa2c69e7aa0f2b3843 Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Fri, 10 Apr 2026 12:28:58 -0700 Subject: [PATCH 09/13] feat(P1.3): dual-layer intent classifier Regex-based classifier that determines correction intent: factual, compliance, formatting, clarity, conciseness, completeness, tone_shift, preference, or unknown. Returns CorrectionIntent with confidence and evidence. Co-Authored-By: Gradata --- src/gradata/detection/__init__.py | 3 + src/gradata/detection/intent_classifier.py | 175 +++++++++++++++++++++ tests/test_intent_classifier.py | 74 +++++++++ 3 files changed, 252 insertions(+) create mode 100644 src/gradata/detection/intent_classifier.py create mode 100644 tests/test_intent_classifier.py diff --git a/src/gradata/detection/__init__.py b/src/gradata/detection/__init__.py index e0c979f4..c5c804bf 100644 --- a/src/gradata/detection/__init__.py +++ b/src/gradata/detection/__init__.py @@ -12,13 +12,16 @@ detect_conflict, extract_diff_tokens, ) +from gradata.detection.intent_classifier import CorrectionIntent, classify_intent from gradata.detection.mode_classifier import MODE_CATEGORY_MAP, classify_mode __all__ = [ "MODE_CATEGORY_MAP", "AdditionTracker", "ConflictTracker", + "CorrectionIntent", "classify_addition", + "classify_intent", "classify_mode", "detect_conflict", "extract_diff_tokens", diff --git a/src/gradata/detection/intent_classifier.py b/src/gradata/detection/intent_classifier.py new file mode 100644 index 00000000..5f7ff6b6 --- /dev/null +++ b/src/gradata/detection/intent_classifier.py @@ -0,0 +1,175 @@ +"""Dual-layer intent classifier for corrections. + +Classifies corrections by intent using regex heuristics to determine +*why* the user corrected the output, not just *what* changed. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass + +# --------------------------------------------------------------------------- +# Intent categories +# --------------------------------------------------------------------------- + +INTENTS = [ + "factual_correction", + "compliance", + "formatting", + "clarity", + "conciseness", + "completeness", + "tone_shift", + "preference", + "unknown", +] + +# --------------------------------------------------------------------------- +# Regex heuristics (layer 1) +# --------------------------------------------------------------------------- + +_FACTUAL_RE = [ + re.compile(r"\b\d{4}[-/]\d{2}[-/]\d{2}\b"), # dates + re.compile(r"\b\d+(\.\d+)?%"), # percentages + re.compile(r"\$\d+"), # dollar amounts + re.compile(r"\bhttps?://\S+"), # URLs + re.compile(r"\b\d{3,}\b"), # large numbers (3+ digits) +] + +_COMPLIANCE_RE = [ + re.compile(r"\bGDPR\b", re.I), + re.compile(r"\bSOC\s*2\b", re.I), + re.compile(r"\bcomplian(ce|t)\b", re.I), + re.compile(r"\bregulat(ion|ory)\b", re.I), + re.compile(r"\blegal(ly)?\b", re.I), + re.compile(r"\bprivacy\b", re.I), + re.compile(r"\bdisclaimer\b", re.I), +] + +_HEDGE_WORDS = re.compile(r"\b(perhaps|maybe|might|could be|it seems|arguably|somewhat)\b", re.I) + +_LIST_RE = re.compile(r"^[\s]*[-*\d]+[.)]\s", re.MULTILINE) + +_CONCISE_RE = [ + re.compile(r"\btoo (long|verbose|wordy)\b", re.I), + re.compile(r"\bshorten\b", re.I), + re.compile(r"\bbrevity\b", re.I), + re.compile(r"\bconcise\b", re.I), + re.compile(r"\btrim\b", re.I), +] + +_COMPLETE_RE = [ + re.compile(r"\bmissing\b", re.I), + re.compile(r"\bincomplete\b", re.I), + re.compile(r"\badd (more|the|a)\b", re.I), + re.compile(r"\bdon'?t forget\b", re.I), + re.compile(r"\binclude\b", re.I), +] + +_TONE_RE = [ + re.compile(r"\btoo (formal|casual|aggressive|harsh|soft)\b", re.I), + re.compile(r"\btone\b", re.I), + re.compile(r"\bfriendl(y|ier)\b", re.I), + re.compile(r"\bprofessional\b", re.I), + re.compile(r"\bwarm(er)?\b", re.I), +] + + +@dataclass +class CorrectionIntent: + """Result of intent classification.""" + + intent: str + confidence: float + evidence: str + + +def classify_intent( + correction_text: str, + original_text: str = "", +) -> CorrectionIntent: + """Classify correction intent using regex heuristics. + + Args: + correction_text: The correction description or user message. + original_text: The original output that was corrected (optional). + + Returns: + CorrectionIntent with intent label, confidence, and evidence. + """ + combined = f"{correction_text} {original_text}" + + # Check each intent category in priority order + for pat in _FACTUAL_RE: + m = pat.search(correction_text) + if m: + return CorrectionIntent( + intent="factual_correction", + confidence=0.85, + evidence=m.group(), + ) + + for pat in _COMPLIANCE_RE: + m = pat.search(combined) + if m: + return CorrectionIntent( + intent="compliance", + confidence=0.80, + evidence=m.group(), + ) + + for pat in _TONE_RE: + m = pat.search(combined) + if m: + return CorrectionIntent( + intent="tone_shift", + confidence=0.75, + evidence=m.group(), + ) + + for pat in _CONCISE_RE: + m = pat.search(combined) + if m: + return CorrectionIntent( + intent="conciseness", + confidence=0.75, + evidence=m.group(), + ) + + for pat in _COMPLETE_RE: + m = pat.search(combined) + if m: + return CorrectionIntent( + intent="completeness", + confidence=0.70, + evidence=m.group(), + ) + + if _HEDGE_WORDS.search(correction_text): + return CorrectionIntent( + intent="clarity", + confidence=0.60, + evidence=_HEDGE_WORDS.search(correction_text).group(), # type: ignore[union-attr] + ) + + if _LIST_RE.search(correction_text): + return CorrectionIntent( + intent="formatting", + confidence=0.60, + evidence="list/bullet formatting detected", + ) + + # Fallback: if correction is short and doesn't match anything, likely preference + if len(correction_text.split()) <= 10 and correction_text.strip(): + return CorrectionIntent( + intent="preference", + confidence=0.40, + evidence="short correction without specific pattern", + ) + + return CorrectionIntent( + intent="unknown", + confidence=0.20, + evidence="no pattern matched", + ) diff --git a/tests/test_intent_classifier.py b/tests/test_intent_classifier.py new file mode 100644 index 00000000..0f5ac459 --- /dev/null +++ b/tests/test_intent_classifier.py @@ -0,0 +1,74 @@ +"""Tests for the dual-layer intent classifier.""" + +from gradata.detection.intent_classifier import CorrectionIntent, classify_intent + + +class TestIntentClassifier: + def test_factual_correction_date(self): + result = classify_intent("The date should be 2026-04-10, not 2026-03-10") + assert result.intent == "factual_correction" + assert result.confidence >= 0.8 + + def test_factual_correction_url(self): + result = classify_intent("Use https://example.com instead") + assert result.intent == "factual_correction" + + def test_factual_correction_percentage(self): + result = classify_intent("It's 15.5% not 20%") + assert result.intent == "factual_correction" + + def test_compliance_gdpr(self): + result = classify_intent("We need GDPR compliance here") + assert result.intent == "compliance" + + def test_compliance_legal(self): + result = classify_intent("Add a privacy disclaimer") + assert result.intent == "compliance" + + def test_tone_shift(self): + result = classify_intent("Make it more professional") + assert result.intent == "tone_shift" + + def test_tone_too_casual(self): + result = classify_intent("Too casual for this audience") + assert result.intent == "tone_shift" + + def test_conciseness(self): + result = classify_intent("This is too verbose, shorten it") + assert result.intent == "conciseness" + + def test_completeness(self): + result = classify_intent("You're missing the error handling section") + assert result.intent == "completeness" + + def test_clarity_hedge(self): + result = classify_intent("Perhaps this could be clearer") + assert result.intent == "clarity" + + def test_formatting_list(self): + result = classify_intent("1. First item\n2. Second item") + assert result.intent == "formatting" + + def test_preference_short(self): + result = classify_intent("Use blue instead") + assert result.intent == "preference" + + def test_unknown_long_text(self): + result = classify_intent( + "I think there are some general issues with the overall approach " + "that we should discuss in more detail at some point in the future" + ) + assert result.intent == "unknown" + + def test_returns_dataclass(self): + result = classify_intent("Fix the date to 2026-01-01") + assert isinstance(result, CorrectionIntent) + assert isinstance(result.evidence, str) + assert 0.0 <= result.confidence <= 1.0 + + def test_original_text_used_for_compliance(self): + result = classify_intent( + "Fix this section", + original_text="The GDPR requirements state that...", + ) + assert result.intent == "compliance" From d455c0d0b9fa68e56151cec04039845c1b26c762 Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Fri, 10 Apr 2026 12:41:10 -0700 Subject: [PATCH 10/13] =?UTF-8?q?feat(P0.1):=20semantic=20severity=20adjus?= =?UTF-8?q?tment=20=E2=80=94=20downgrade=20when=20meaning=20unchanged?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Gradata --- src/gradata/enhancements/diff_engine.py | 41 +++++++++++++++++-- tests/test_diff_engine.py | 52 +++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 4 deletions(-) create mode 100644 tests/test_diff_engine.py diff --git a/src/gradata/enhancements/diff_engine.py b/src/gradata/enhancements/diff_engine.py index ee3de11e..2ff2f097 100644 --- a/src/gradata/enhancements/diff_engine.py +++ b/src/gradata/enhancements/diff_engine.py @@ -75,6 +75,7 @@ class DiffResult: changed_sections: list[ChangedSection] severity: str summary_stats: dict[str, int] + semantic_similarity: float | None = None # --------------------------------------------------------------------------- @@ -85,7 +86,7 @@ class DiffResult: (0.02, "as-is"), (0.10, "minor"), (0.30, "moderate"), - (0.80, "major"), # Was 0.60 — too many real corrections were "discarded" + (0.80, "major"), # Was 0.60 — too many real corrections were "discarded" ] @@ -105,6 +106,40 @@ def _classify_severity(edit_distance: float) -> str: return "discarded" +# --------------------------------------------------------------------------- +# Semantic severity adjustment +# --------------------------------------------------------------------------- + +_SEVERITY_DOWNGRADE: dict[str, str] = { + "discarded": "major", + "major": "moderate", + "moderate": "minor", + "minor": "as-is", +} + +_SEMANTIC_DOWNGRADE_THRESHOLD = 0.85 + + +def adjust_severity_by_semantics( + result: DiffResult, + semantic_similarity: float, + threshold: float = _SEMANTIC_DOWNGRADE_THRESHOLD, +) -> DiffResult: + """Downgrade severity by one level when semantic similarity is high.""" + new_severity = result.severity + if semantic_similarity >= threshold and result.severity in _SEVERITY_DOWNGRADE: + new_severity = _SEVERITY_DOWNGRADE[result.severity] + + return DiffResult( + edit_distance=result.edit_distance, + compression_distance=result.compression_distance, + changed_sections=result.changed_sections, + severity=new_severity, + summary_stats=result.summary_stats, + semantic_similarity=semantic_similarity, + ) + + # --------------------------------------------------------------------------- # Section extraction # --------------------------------------------------------------------------- @@ -268,9 +303,7 @@ def compute_diff(draft: str, final: str) -> DiffResult: # Normalised similarity ratio: 1.0 = identical, 0.0 = nothing in common. # We invert to get edit_distance where 0.0 = identical. - similarity = difflib.SequenceMatcher( - None, draft, final, autojunk=False - ).ratio() + similarity = difflib.SequenceMatcher(None, draft, final, autojunk=False).ratio() edit_distance = round(1.0 - similarity, 6) # Compression distance: better correlated with human editing effort. diff --git a/tests/test_diff_engine.py b/tests/test_diff_engine.py new file mode 100644 index 00000000..daab2ad2 --- /dev/null +++ b/tests/test_diff_engine.py @@ -0,0 +1,52 @@ +"""Tests for diff engine semantic severity adjustment.""" + +import pytest + +from gradata.enhancements.diff_engine import DiffResult, adjust_severity_by_semantics + + +class TestSemanticSeverityAdjustment: + def test_high_similarity_downgrades_moderate_to_minor(self): + result = DiffResult( + edit_distance=0.25, + compression_distance=0.20, + changed_sections=[], + severity="moderate", + summary_stats={"lines_added": 2, "lines_removed": 1, "lines_changed": 1}, + ) + adjusted = adjust_severity_by_semantics(result, semantic_similarity=0.90) + assert adjusted.severity == "minor" + assert adjusted.semantic_similarity == 0.90 + + def test_low_similarity_no_change(self): + result = DiffResult( + edit_distance=0.25, + compression_distance=0.20, + changed_sections=[], + severity="moderate", + summary_stats={"lines_added": 2, "lines_removed": 1, "lines_changed": 1}, + ) + adjusted = adjust_severity_by_semantics(result, semantic_similarity=0.50) + assert adjusted.severity == "moderate" + + def test_as_is_not_downgraded(self): + result = DiffResult( + edit_distance=0.01, + compression_distance=0.01, + changed_sections=[], + severity="as-is", + summary_stats={"lines_added": 0, "lines_removed": 0, "lines_changed": 0}, + ) + adjusted = adjust_severity_by_semantics(result, semantic_similarity=0.95) + assert adjusted.severity == "as-is" + + def test_discarded_downgrades_to_major(self): + result = DiffResult( + edit_distance=0.85, + compression_distance=0.82, + changed_sections=[], + severity="discarded", + summary_stats={"lines_added": 10, "lines_removed": 8, "lines_changed": 8}, + ) + adjusted = adjust_severity_by_semantics(result, semantic_similarity=0.92) + assert adjusted.severity == "major" From 7377f3504ce8009ba0bb2380f2f088cf11f48e3b Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Fri, 10 Apr 2026 12:41:40 -0700 Subject: [PATCH 11/13] feat(P0.2): wire RULE_SUPPRESSION events into apply_rules filter stages Co-Authored-By: Gradata --- src/gradata/rules/rule_engine.py | 108 ++++++++++++++++++++++--------- 1 file changed, 77 insertions(+), 31 deletions(-) diff --git a/src/gradata/rules/rule_engine.py b/src/gradata/rules/rule_engine.py index b3d8dfcc..3c606705 100644 --- a/src/gradata/rules/rule_engine.py +++ b/src/gradata/rules/rule_engine.py @@ -120,22 +120,29 @@ class AppliedRule: # Keywords signaling universal scope (AI quality issues, not personal style) _UNIVERSAL_SIGNALS: list[str] = [ - "em dash", "em dashes", - "verify", "verification", - "fabricat", "hallucin", + "em dash", + "em dashes", + "verify", + "verification", + "fabricat", + "hallucin", "bold mid-paragraph", "rule of three", "promotional language", "never skip", - "don't assume", "never assume", - "check before", "verify before", + "don't assume", + "never assume", + "check before", + "verify before", "superficial analysis", ] # Keywords signaling team/org scope (tool or company specific). # Override via brain config to add your team's tooling keywords. _TEAM_SIGNALS: list[str] = [ - "brain/", ".carl/", "domain/", + "brain/", + ".carl/", + "domain/", ] @@ -304,9 +311,7 @@ def _make_rule_id(lesson: Lesson) -> str: Returns: Rule identifier string. """ - digest = hashlib.sha256( - f"{lesson.category}:{lesson.description}".encode() - ).hexdigest()[:8] + digest = hashlib.sha256(f"{lesson.category}:{lesson.description}".encode()).hexdigest()[:8] return f"{lesson.category}:{digest}" @@ -351,8 +356,7 @@ def validate_assumptions( if lesson.confidence < floor: return ( False, - f"confidence {lesson.confidence:.2f} below {lesson.state.value} " - f"floor {floor:.2f}", + f"confidence {lesson.confidence:.2f} below {lesson.state.value} floor {floor:.2f}", ) # (b) Category contradiction @@ -372,8 +376,7 @@ def validate_assumptions( if rule_tt and rule_tt != current_tt: return ( False, - f"rule scoped to '{rule_tt}' but current task is " - f"'{current_tt}'", + f"rule scoped to '{rule_tt}' but current task is '{current_tt}'", ) except Exception: pass # malformed scope_json — don't block on it @@ -417,7 +420,9 @@ def filter_by_scope( if lesson.scope_json: try: scope_dict = json.loads(lesson.scope_json) - lesson_scope = RuleScope(**{k: v for k, v in scope_dict.items() if k in RuleScope.__dataclass_fields__}) + lesson_scope = RuleScope( + **{k: v for k, v in scope_dict.items() if k in RuleScope.__dataclass_fields__} + ) except Exception: lesson_scope = RuleScope() else: @@ -434,6 +439,7 @@ def _beta_ppf_05(alpha: float, beta_param: float) -> float: Uses normal approximation. For tiny samples, returns conservative estimate. """ import math + if alpha <= 0 or beta_param <= 0: return 0.0 total = alpha + beta_param @@ -503,6 +509,7 @@ def apply_rules( _context: str = "", bus: EventBus | None = None, graph: RuleGraph | None = None, + _ctx=None, ) -> list[AppliedRule]: """Select and rank lessons relevant to the given scope. @@ -559,12 +566,24 @@ def apply_rules( if is_rule_disabled_for_domain(lesson, current_domain): if bus: scores = lesson.domain_scores.get(current_domain, {}) - bus.emit("rule_scoped_out", { - "lesson_category": lesson.category, - "lesson_description": lesson.description[:80], - "domain": current_domain, - "misfire_rate": scores.get("misfires", 0) / max(1, scores.get("fires", 1)), - }) + bus.emit( + "rule_scoped_out", + { + "lesson_category": lesson.category, + "lesson_description": lesson.description[:80], + "domain": current_domain, + "misfire_rate": scores.get("misfires", 0) + / max(1, scores.get("fires", 1)), + }, + ) + from gradata.rules.rule_tracker import log_suppression + + log_suppression( + rule_id=_make_rule_id(lesson), + reason="domain_disabled", + relevance=0.0, + ctx=_ctx, + ) else: filtered.append(lesson) eligible = filtered @@ -576,7 +595,9 @@ def apply_rules( if lesson.scope_json: try: scope_dict = json.loads(lesson.scope_json) - lesson_scope = RuleScope(**{k: v for k, v in scope_dict.items() if k in RuleScope.__dataclass_fields__}) + lesson_scope = RuleScope( + **{k: v for k, v in scope_dict.items() if k in RuleScope.__dataclass_fields__} + ) except Exception: lesson_scope = RuleScope() else: @@ -586,6 +607,15 @@ def apply_rules( relevance *= _CT_BOOST.get(lesson.correction_type, 1.0) if relevance >= 0.3: scored.append((lesson, relevance)) + elif _ctx: + from gradata.rules.rule_tracker import log_suppression + + log_suppression( + rule_id=_make_rule_id(lesson), + reason="relevance_threshold", + relevance=relevance, + ctx=_ctx, + ) # Step 3.5 — assumption invalidation (dynamic runtime checks) if _context: @@ -606,8 +636,18 @@ def apply_rules( else: _log.debug( "Skipping rule %s: assumption invalid — %s", - lesson.category, reason, + lesson.category, + reason, ) + if _ctx: + from gradata.rules.rule_tracker import log_suppression + + log_suppression( + rule_id=_make_rule_id(lesson), + reason="assumption_invalid", + relevance=relevance, + ctx=_ctx, + ) scored = validated # Step 4 — compute difficulty per rule @@ -626,7 +666,9 @@ def _effective_conf(lesson: Lesson, domain: str) -> float: key=lambda t: ( _STATE_PRIORITY[t[0].state], # Difficulty: use event history if available, else lesson counters - compute_rule_difficulty(t[0].category, events) if events else _difficulty_from_lesson(t[0]), + compute_rule_difficulty(t[0].category, events) + if events + else _difficulty_from_lesson(t[0]), t[1], _effective_conf(t[0], current_domain), ), @@ -648,6 +690,16 @@ def _effective_conf(lesson: Lesson, domain: str) -> float: if not dominated: filtered_scored.append((lesson, relevance)) selected_ids.add(rule_id) + elif _ctx: + from gradata.rules.rule_tracker import log_suppression + + log_suppression( + rule_id=rule_id, + reason="conflict", + relevance=relevance, + competing_rule_ids=list(selected_ids), + ctx=_ctx, + ) scored = filtered_scored # Step 6 — assemble AppliedRule objects, capped at max_rules @@ -655,10 +707,7 @@ def _effective_conf(lesson: Lesson, domain: str) -> float: for lesson, relevance in scored[:max_rules]: rule_id = _make_rule_id(lesson) tier_label = _tier_label(lesson) - instruction = ( - f"[{tier_label}]" - f" {lesson.category}: {lesson.description}" - ) + instruction = f"[{tier_label}] {lesson.category}: {lesson.description}" applied.append( AppliedRule( rule_id=rule_id, @@ -782,10 +831,7 @@ def format_rules_for_prompt( # Include few-shot examples for rules that need reinforcement lesson = rule.lesson - needs_reinforcement = ( - lesson.confidence < 0.80 - or getattr(lesson, "misfire_count", 0) > 0 - ) + needs_reinforcement = lesson.confidence < 0.80 or getattr(lesson, "misfire_count", 0) > 0 if ( needs_reinforcement and getattr(lesson, "example_draft", None) is not None From d13a46ebc60e28a31348cd6cb60bc49684ad4413 Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Fri, 10 Apr 2026 12:41:56 -0700 Subject: [PATCH 12/13] =?UTF-8?q?feat(P0.3):=20Rosch=206-category=20taxono?= =?UTF-8?q?my=20=E2=80=94=20hierarchical=20parents=20with=20backward=20com?= =?UTF-8?q?pat?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Gradata --- src/gradata/_tag_taxonomy.py | 174 +++++++++++++++++++++++++++-------- 1 file changed, 135 insertions(+), 39 deletions(-) diff --git a/src/gradata/_tag_taxonomy.py b/src/gradata/_tag_taxonomy.py index fc20d360..8353a70d 100644 --- a/src/gradata/_tag_taxonomy.py +++ b/src/gradata/_tag_taxonomy.py @@ -29,11 +29,25 @@ "mode": "closed", "values": { # Core categories (from edit classifier: content/factual/tone/structure/style) - "CONTENT", "FACTUAL", "TONE", "STRUCTURE", "STYLE", + "CONTENT", + "FACTUAL", + "TONE", + "STRUCTURE", + "STYLE", # Learning system categories - "DRAFTING", "ACCURACY", "PROCESS", "ARCHITECTURE", - "COMMUNICATION", "CONTEXT", "CONSTRAINT", - "DATA_INTEGRITY", "THOROUGHNESS", "COST", + "DRAFTING", + "ACCURACY", + "PROCESS", + "ARCHITECTURE", + "COMMUNICATION", + "CONTEXT", + "CONSTRAINT", + "DATA_INTEGRITY", + "THOROUGHNESS", + "COST", + # Rosch superordinate categories + "VOICE", + "QUALITY", }, "required_on": ["CORRECTION"], }, @@ -56,6 +70,41 @@ }, } +# ── Rosch 6-Category Hierarchy ─────────────────────────────────────── +# Superordinate categories that group fine-grained correction categories +# into broader cognitive buckets (inspired by Rosch's prototype theory). +# A subordinate may appear under multiple parents (e.g. ACCURACY is both +# PROCESS and QUALITY). + +ROSCH_HIERARCHY: dict[str, set[str]] = { + "CONTENT": {"FACTUAL", "DATA_INTEGRITY", "ENTITIES"}, + "VOICE": {"TONE", "STYLE", "COMMUNICATION", "PRESENTATION"}, + "STRUCTURE": {"ARCHITECTURE", "CONTEXT"}, + "PROCESS": {"DRAFTING", "ACCURACY", "THOROUGHNESS"}, + "CONSTRAINT": {"CRM", "STRATEGY", "POSITIONING", "COST", "STARTUP"}, + "QUALITY": {"DATA_INTEGRITY", "ACCURACY", "THOROUGHNESS"}, +} + +ROSCH_PARENTS: set[str] = set(ROSCH_HIERARCHY.keys()) + +_SUB_TO_PARENT: dict[str, str] = {} +for _parent, _subs in ROSCH_HIERARCHY.items(): + for _sub in _subs: + _SUB_TO_PARENT.setdefault(_sub, _parent) +for _p in ROSCH_PARENTS: + _SUB_TO_PARENT[_p] = _p + + +def parent_category(category: str) -> str | None: + """Return the Rosch superordinate parent for a category, or None.""" + return _SUB_TO_PARENT.get(category.upper()) + + +def subordinates_of(parent: str) -> set[str]: + """Return the set of subordinate categories under a Rosch parent.""" + return set(ROSCH_HIERARCHY.get(parent.upper(), set())) + + # ── Domain-Specific Defaults (Sales) ────────────────────────────────── # These ship as defaults because sales is the first domain. # Other domains override via brain_dir/taxonomy.json. @@ -75,54 +124,88 @@ "output": { "desc": "Output type produced", "mode": "closed", - "values": {"email", "cheat_sheet", "research", "cold_call_script", - "linkedin_message", "crm_note", "proposal", "sequence", - "report", "system_artifact"}, + "values": { + "email", + "cheat_sheet", + "research", + "cold_call_script", + "linkedin_message", + "crm_note", + "proposal", + "sequence", + "report", + "system_artifact", + }, "required_on": ["OUTPUT"], }, "angle": { "desc": "Communication angle used in messaging", "mode": "closed", - "values": {"direct", "pain-point", "time-savings", "roi", - "competitor-displacement", "social-proof", "curiosity", - "event-triggered", "referral", "break-up", "custom"}, + "values": { + "direct", + "pain-point", + "time-savings", + "roi", + "competitor-displacement", + "social-proof", + "curiosity", + "event-triggered", + "referral", + "break-up", + "custom", + }, "required_on": [], }, "persona": { "desc": "Buyer persona type", "mode": "closed", - "values": {"agency-owner", "founder", "ecom-director", "cmo", - "growth-lead", "fractional-cmo", "marketing-manager", - "vp-marketing", "other"}, + "values": { + "agency-owner", + "founder", + "ecom-director", + "cmo", + "growth-lead", + "fractional-cmo", + "marketing-manager", + "vp-marketing", + "other", + }, "required_on": [], }, "framework": { "desc": "Sales/email framework applied", "mode": "closed", - "values": {"ccq", "spin", "gap", "jolt", "challenger", - "great-demo", "sandler", "custom"}, + "values": {"ccq", "spin", "gap", "jolt", "challenger", "great-demo", "sandler", "custom"}, "required_on": [], }, "channel": { "desc": "Communication channel", "mode": "closed", - "values": {"email", "phone", "linkedin", "meeting", "demo", - "slack", "text"}, + "values": {"email", "phone", "linkedin", "meeting", "demo", "slack", "text"}, "required_on": ["DELTA_TAG"], }, "outcome": { "desc": "Interaction outcome", "mode": "closed", - "values": {"reply", "no-reply", "positive-reply", "negative-reply", - "meeting-booked", "demo-completed", "deal-advanced", - "deal-lost", "ghosted", "objection-raised", "pending"}, + "values": { + "reply", + "no-reply", + "positive-reply", + "negative-reply", + "meeting-booked", + "demo-completed", + "deal-advanced", + "deal-lost", + "ghosted", + "objection-raised", + "pending", + }, "required_on": [], }, } # Sales-specific correction categories (extend core set) -_DOMAIN_CATEGORIES = {"CRM", "STRATEGY", "ENTITIES", "POSITIONING", - "PRESENTATION", "STARTUP"} +_DOMAIN_CATEGORIES = {"CRM", "STRATEGY", "ENTITIES", "POSITIONING", "PRESENTATION", "STARTUP"} def _load_taxonomy() -> dict: @@ -136,7 +219,9 @@ def _load_taxonomy() -> dict: taxonomy = dict(_CORE_TAXONOMY) # Try loading domain config from brain directory - taxonomy_path = _p.BRAIN_DIR / "taxonomy.json" if hasattr(_p, 'BRAIN_DIR') and _p.BRAIN_DIR else None + taxonomy_path = ( + _p.BRAIN_DIR / "taxonomy.json" if hasattr(_p, "BRAIN_DIR") and _p.BRAIN_DIR else None + ) if taxonomy_path and taxonomy_path.exists(): try: with open(taxonomy_path, encoding="utf-8") as f: @@ -158,8 +243,8 @@ def _load_taxonomy() -> dict: taxonomy[key] = spec # Extend core category values if domain provides extra categories if "extra_categories" in domain_taxonomy: - taxonomy["category"]["values"] = ( - taxonomy["category"]["values"] | set(domain_taxonomy["extra_categories"]) + taxonomy["category"]["values"] = taxonomy["category"]["values"] | set( + domain_taxonomy["extra_categories"] ) return taxonomy except Exception: @@ -184,6 +269,7 @@ def reload_taxonomy() -> None: # ── Validation ───────────────────────────────────────────────────────── + def _get_entity_names() -> set[str]: """Get valid entity names from brain's entity directory. @@ -193,7 +279,7 @@ def _get_entity_names() -> set[str]: names = set() # Try common entity directory names for dirname in ("prospects", "candidates", "customers", "entities"): - entity_dir = _p.BRAIN_DIR / dirname if hasattr(_p, 'BRAIN_DIR') and _p.BRAIN_DIR else None + entity_dir = _p.BRAIN_DIR / dirname if hasattr(_p, "BRAIN_DIR") and _p.BRAIN_DIR else None if entity_dir and entity_dir.exists(): for f in entity_dir.glob("*.md"): if f.name.startswith("_"): @@ -201,7 +287,7 @@ def _get_entity_names() -> set[str]: parts = re.split(r"\s*[—–-]\s*", f.stem, maxsplit=1) names.add(parts[0].strip()) # Legacy fallback - if not names and hasattr(_p, 'PROSPECTS_DIR'): + if not names and hasattr(_p, "PROSPECTS_DIR"): prospects_dir = _p.PROSPECTS_DIR if prospects_dir.exists(): for f in prospects_dir.glob("*.md"): @@ -236,8 +322,9 @@ def validate_tag(tag: str, strict: bool = False) -> tuple[bool, str]: return True, "ok" -def validate_tags(tags: list[str], event_type: str | None = None, - strict: bool = False) -> list[str]: +def validate_tags( + tags: list[str], event_type: str | None = None, strict: bool = False +) -> list[str]: issues = [] for tag in tags: valid, msg = validate_tag(tag, strict=strict) @@ -247,12 +334,13 @@ def validate_tags(tags: list[str], event_type: str | None = None, present_prefixes = {t.split(":")[0] for t in tags if ":" in t} for prefix, spec in TAXONOMY.items(): if event_type in spec.get("required_on", []) and prefix not in present_prefixes: - issues.append(f"Missing required tag '{prefix}:' for {event_type} events") + issues.append(f"Missing required tag '{prefix}:' for {event_type} events") return issues -def enrich_tags(tags: list[str], event_type: str | None = None, - data: dict | None = None) -> list[str]: +def enrich_tags( + tags: list[str], event_type: str | None = None, data: dict | None = None +) -> list[str]: enriched = list(tags) prefixes = {t.split(":")[0] for t in enriched if ":" in t} data = data or {} @@ -266,10 +354,15 @@ def enrich_tags(tags: list[str], event_type: str | None = None, if ot: # Normalize common variants to taxonomy values output_map = { - "email_draft": "email", "email_reply": "email", "follow_up": "email", - "call_script": "cold_call_script", "call_prep": "cold_call_script", - "demo_prep": "cheat_sheet", "meeting_prep": "cheat_sheet", - "linkedin_dm": "linkedin_message", "linkedin_connect": "linkedin_message", + "email_draft": "email", + "email_reply": "email", + "follow_up": "email", + "call_script": "cold_call_script", + "call_prep": "cold_call_script", + "demo_prep": "cheat_sheet", + "meeting_prep": "cheat_sheet", + "linkedin_dm": "linkedin_message", + "linkedin_connect": "linkedin_message", } normalized = output_map.get(ot, ot) enriched.append(f"output:{normalized}") @@ -280,9 +373,12 @@ def enrich_tags(tags: list[str], event_type: str | None = None, if event_type == "DELTA_TAG" and "channel" not in prefixes: activity = data.get("activity_type", "") channel_map = { - "email_sent": "email", "email_received": "email", - "call": "phone", "meeting": "meeting", - "demo": "demo", "linkedin": "linkedin", + "email_sent": "email", + "email_received": "email", + "call": "phone", + "meeting": "meeting", + "demo": "demo", + "linkedin": "linkedin", } if activity in channel_map: enriched.append(f"channel:{channel_map[activity]}") From 55a4ea3799bb31f0fa0ea419f0c3017060823779 Mon Sep 17 00:00:00 2001 From: Oliver Le Date: Fri, 10 Apr 2026 14:26:09 -0700 Subject: [PATCH 13/13] =?UTF-8?q?fix:=20Rosch=20loop=20variable=20=5Fp=20s?= =?UTF-8?q?hadowed=20=5Fpaths=20import=20=E2=80=94=20broke=20taxonomy=20re?= =?UTF-8?q?load?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The module-level `for _p in ROSCH_PARENTS` overwrote the `import gradata._paths as _p` alias, causing `_load_taxonomy()` to fail silently (hasattr check on a string). Renamed loop var to `_rp`. Co-Authored-By: Gradata --- src/gradata/_tag_taxonomy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gradata/_tag_taxonomy.py b/src/gradata/_tag_taxonomy.py index 275d2329..429e7a05 100644 --- a/src/gradata/_tag_taxonomy.py +++ b/src/gradata/_tag_taxonomy.py @@ -97,8 +97,8 @@ for _parent, _subs in ROSCH_HIERARCHY.items(): for _sub in _subs: _SUB_TO_PARENT.setdefault(_sub, _parent) -for _p in ROSCH_PARENTS: - _SUB_TO_PARENT[_p] = _p +for _rp in ROSCH_PARENTS: + _SUB_TO_PARENT[_rp] = _rp def parent_category(category: str) -> str | None: