class _NoopMemoryOSClient:
    """Memory client stand-in whose operations store nothing and return empty results."""

    def remember(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """Accept any arguments and report that nothing was stored."""
        return dict(stored=False)

    def recall(self, *args: Any, **kwargs: Any) -> list:
        """Accept any arguments and return an empty list."""
        return list()

    def recent(self, *args: Any, **kwargs: Any) -> list:
        """Accept any arguments and return an empty list."""
        return list()


def _get_memory_os_service(*, with_memory: bool = False):
    """Build a ``MemoryOSService`` rooted at the Dhee runtime directory.

    The runtime root is ``$DHEE_DATA_DIR`` when that variable is set to a
    non-empty value, otherwise ``~/.dhee``. The ``memory_os`` subdirectory is
    created if it does not exist.

    Args:
        with_memory: When true, wrap ``_get_memory()`` in a ``DheeMemoryClient``;
            otherwise use the no-op client defined above.

    Returns:
        The constructed ``MemoryOSService``.
    """
    # Imports stay local so they are only paid for by commands that need them.
    from pathlib import Path

    from dhee.world_memory.capture_store import CaptureStore
    from dhee.world_memory.causal_graph import CausalGraphProjection
    from dhee.world_memory.service import DheeMemoryClient, MemoryOSService
    from dhee.world_memory.session_graph import SessionGraphStore
    from dhee.world_memory.store import WorldMemoryStore

    configured_root = os.environ.get("DHEE_DATA_DIR")
    if configured_root:
        runtime_root = Path(configured_root).expanduser()
    else:
        runtime_root = Path(Path.home() / ".dhee").expanduser()

    memory_os_dir = runtime_root / "memory_os"
    memory_os_dir.mkdir(parents=True, exist_ok=True)

    if with_memory:
        memory_client = DheeMemoryClient(_get_memory())
    else:
        memory_client = _NoopMemoryOSClient()

    # Stores are created in the same order as the keyword arguments below.
    capture_store = CaptureStore(str(memory_os_dir / "capture.db"))
    world_store = WorldMemoryStore(str(memory_os_dir / "world_memory.db"))
    graph_store = SessionGraphStore(str(runtime_root / "capture" / "sessions"))
    graph_projection = CausalGraphProjection(str(memory_os_dir / "causal_scene.kuzu"))
    return MemoryOSService(
        capture_store=capture_store,
        world_store=world_store,
        graph_store=graph_store,
        memory_client=memory_client,
        graph_projection=graph_projection,
    )
"localize"} + if subaction not in brain_actions: + extra = [subaction, *extra] + subaction = "localize" + repo = args.repo or os.getcwd() + if subaction == "index": + goal = " ".join(str(item) for item in extra).strip() + brain = repo_intelligence.build_repo_brain( + repo, + goal=goal, + must_run=getattr(args, "must_run", None), + persist=True, + ) + result = { + "format": "dhee_repo_brain_index.v1", + "repo_intelligence": repo_intelligence.repo_brain_summary(brain), + "localization": repo_intelligence.localize_issue(goal, brain) if goal else None, + } + if args.json: + _json_out(result) + return + summary = result["repo_intelligence"] + print(f"Repo brain indexed: {summary.get('ref')}") + print(f" path {summary.get('path')}") + print(f" files {summary.get('file_count')} ({summary.get('indexed_file_count')} indexed)") + print(f" symbols {summary.get('symbol_count')}") + print(f" tests {summary.get('test_count')}") + print(f" calls {summary.get('call_edge_count')}") + if result.get("localization"): + loc = result["localization"] + print(f" localized {loc.get('status')} confidence={loc.get('confidence')}") + return + if subaction in {"show", "get"}: + ref = extra[0] if extra else None + result = repo_intelligence.load_repo_brain(repo, ref=ref) + brain = result.get("brain") if isinstance(result.get("brain"), dict) else None + if brain: + result["repo_intelligence"] = repo_intelligence.repo_brain_summary(brain) + if not args.json: + result["brain"] = None + if args.json: + _json_out(result) + return + if not brain: + print("Repo brain not found. 
def cmd_graph(args: argparse.Namespace) -> None:
    """Manage the Kuzu causal-scene graph projection.

    Dispatches on ``args.graph_action`` to the projection, the service, or the
    capture store, then prints the result: JSON when ``args.json`` is set, a
    short text summary for results carrying ``ok`` or ``counts``, and JSON for
    everything else.

    Raises:
        RuntimeError: If the service has no graph projection configured.
        ValueError: If a ``show-*``/``explain-retrieval`` action is missing its
            target id, or the action is unknown.
    """
    service = _get_memory_os_service(with_memory=False)
    projection = service.graph_projection
    if projection is None:
        raise RuntimeError("Causal graph projection is not configured")

    action = args.graph_action
    user_id = getattr(args, "user_id", "default")
    if action == "sync":
        result = projection.sync(service.capture_store, user_id=user_id)
    elif action == "rebuild":
        result = projection.rebuild(service.capture_store, user_id=user_id)
    elif action == "verify":
        result = projection.verify(service.capture_store, user_id=user_id)
    elif action == "show-event":
        if not args.target:
            raise ValueError("show-event requires an event_id")
        result = projection.show_event(args.target)
    elif action == "show-cone":
        if not args.target:
            raise ValueError("show-cone requires an event_id")
        result = projection.show_cone(
            args.target,
            direction=getattr(args, "direction", "backward"),
            depth=int(getattr(args, "depth", 3) or 3),
        )
    elif action == "show-scene":
        if not args.target:
            raise ValueError("show-scene requires a scene_id")
        result = projection.show_scene(args.target)
    elif action == "show-thread":
        if not args.target:
            raise ValueError("show-thread requires a thread_id")
        result = projection.show_thread(args.target)
    elif action == "explain-retrieval":
        if not args.target:
            raise ValueError("explain-retrieval requires a query_id")
        result = service.explain_causal_retrieval(args.target, user_id=user_id)
    elif action == "prune-traces":
        result = service.capture_store.prune_retrieval_traces(
            user_id=user_id,
            older_than_days=getattr(args, "older_than_days", None),
            # ``or 0`` (not ``or 1000``) so an explicit --keep-latest 0 is honoured.
            keep_latest=int(getattr(args, "keep_latest", 1000) or 0),
            dry_run=bool(getattr(args, "dry_run", False)),
        )
    else:
        raise ValueError(f"Unknown graph action: {action}")

    if args.json:
        _json_out(result)
        return
    if isinstance(result, dict) and "ok" in result:
        print(f"graph {action}: {'ok' if result.get('ok') else 'failed'}")
        for error in result.get("errors") or []:
            print(f" - {error}")
        return
    if isinstance(result, dict) and "counts" in result:
        print(f"graph {action}: {result.get('status', 'ok')} ({result.get('backend')})")
        for key, value in sorted((result.get("counts") or {}).items()):
            print(f" {key}: {value}")
        return
    _json_out(result)


def cmd_causal(args: argparse.Namespace) -> None:
    """Run causal-scene checkpoint and retrieval commands.

    Dispatches on ``args.causal_action`` and prints the result as JSON. Only
    the ``checkpoint`` action builds the service with a real memory client.

    Raises:
        ValueError: If an action that needs ``args.target`` is called without
            one, or the action is unknown.
    """
    action = args.causal_action
    service = _get_memory_os_service(with_memory=(action == "checkpoint"))
    user_id = getattr(args, "user_id", "default")
    scope = getattr(args, "scope", "global")

    if action == "checkpoint":
        result = service.compile_causal_checkpoint(
            session_id=getattr(args, "session", None),
            user_id=user_id,
            time_window_start=getattr(args, "start", None),
            time_window_end=getattr(args, "end", None),
        )
    elif action == "frontier":
        result = service.get_active_frontier(user_id=user_id, scope=scope)
    elif action == "why":
        result = service.causal_why(
            event_id=getattr(args, "event_id", None),
            query=getattr(args, "query", "") or "",
            user_id=user_id,
            scope=scope,
        )
    elif action == "what-happened":
        if not args.target:
            raise ValueError("what-happened requires a scene, episode, or thread id")
        result = service.causal_what_happened(target_id=args.target, user_id=user_id, scope=scope)
    elif action == "handoff":
        result = service.causal_handoff(user_id=user_id, scope=scope)
    elif action == "preference":
        result = service.causal_preference(query=getattr(args, "query", "") or "", user_id=user_id, scope=scope)
    elif action == "gems":
        result = service.causal_gems(
            user_id=user_id,
            scope=scope,
            kind=getattr(args, "kind", None),
            limit=int(getattr(args, "gem_limit", 50) or 50),
        )
    elif action == "show-gem":
        if not args.target:
            raise ValueError("show-gem requires a gem id or gem RawEvent id")
        result = service.causal_show_gem(args.target, user_id=user_id, scope=scope)
    elif action == "submit-gem":
        if not args.target:
            raise ValueError("submit-gem requires a gem id or gem RawEvent id")
        from dhee.core.learnings import LearningExchange

        result = service.causal_submit_gem(
            args.target,
            learning_exchange=LearningExchange(),
            user_id=user_id,
            scope=scope,
            repo=getattr(args, "repo", None),
            status=getattr(args, "learning_status", "candidate") or "candidate",
        )
    elif action == "extract-gems":
        from dhee.core.learnings import LearningExchange
        from dhee.world_memory.gem_extractor import (
            extract_memory_gems,
            submit_gem_learning_candidates,
            summarize_gems,
            write_gem_raw_events,
        )

        dry_run = bool(getattr(args, "dry_run", False))
        # Fall back to the default only when the option is absent: a plain
        # ``value or 0.62`` would silently replace an explicit ``--min-score 0``.
        raw_min_score = getattr(args, "min_score", None)
        min_score = 0.62 if raw_min_score is None else float(raw_min_score)

        db = _get_db()
        memories = db.get_all_memories(
            user_id=user_id,
            limit=int(getattr(args, "limit", 500) or 500),
            include_tombstoned=False,
        )
        gems = extract_memory_gems(
            memories,
            user_id=user_id,
            limit=int(getattr(args, "gem_limit", 50) or 50),
            min_score=min_score,
        )
        write_report = {"written": [], "skipped_existing": []}
        learning_report = {"submitted": [], "rejected": []}
        if not dry_run:
            write_report = write_gem_raw_events(service.capture_store, gems)
            if service.graph_projection:
                service.graph_projection.sync(service.capture_store, user_id=user_id)
            if getattr(args, "submit_learnings", False):
                learning_report = submit_gem_learning_candidates(
                    LearningExchange(),
                    gems,
                    repo=getattr(args, "repo", None),
                )
        result = {
            "status": "dry_run" if dry_run else "extracted",
            "scanned_memories": len(memories),
            "min_score": min_score,
            "summary": summarize_gems(gems),
            "raw_events": write_report,
            "learning_candidates": learning_report,
        }
    else:
        raise ValueError(f"Unknown causal action: {action}")

    # There is no text renderer for causal results, so the output is JSON
    # whether or not --json was passed.
    _json_out(result)
choices=["list", "show", "delete", "refresh", "check", "status", "state", "checkpoint", "rollover", "provision", "debt", "capsule", "repo-brain", "task"], default="list", help="Subcommand (default: list)", ) @@ -3571,6 +3857,73 @@ def build_parser() -> argparse.ArgumentParser: p_assets.add_argument("--user-id", default="default", help="User ID") p_assets.add_argument("--json", action="store_true", help="JSON output") + # graph — Kuzu causal-scene projection lifecycle and inspection + p_graph = sub.add_parser("graph", help="Manage the Kuzu causal-scene graph projection") + p_graph.add_argument( + "graph_action", + choices=[ + "sync", + "rebuild", + "verify", + "show-event", + "show-cone", + "show-scene", + "show-thread", + "explain-retrieval", + "prune-traces", + ], + help="Graph subcommand", + ) + p_graph.add_argument("target", nargs="?", help="Event, scene, thread, or retrieval id") + p_graph.add_argument("--direction", choices=["backward", "forward"], default="backward", help="For show-cone") + p_graph.add_argument("--depth", type=int, default=3, help="For show-cone") + p_graph.add_argument("--older-than-days", type=int, help="For prune-traces: delete traces older than this many days") + p_graph.add_argument("--keep-latest", type=int, default=1000, help="For prune-traces: always preserve this many newest traces") + p_graph.add_argument("--dry-run", action="store_true", help="For prune-traces: report candidates without deleting") + p_graph.add_argument("--user-id", default="default", help="User ID") + p_graph.add_argument("--json", action="store_true", help="JSON output") + + # causal — causal-scene checkpoint and retrieval modes + p_causal = sub.add_parser("causal", help="Run causal-scene checkpoint and retrieval modes") + p_causal.add_argument( + "causal_action", + choices=[ + "checkpoint", + "frontier", + "why", + "what-happened", + "handoff", + "preference", + "gems", + "show-gem", + "submit-gem", + "extract-gems", + ], + help="Causal subcommand", + ) + 
p_causal.add_argument("target", nargs="?", help="Target id for what-happened/show-gem/submit-gem") + p_causal.add_argument("--session", help="Session id for checkpoint") + p_causal.add_argument("--start", help="Start timestamp for checkpoint") + p_causal.add_argument("--end", help="End timestamp for checkpoint") + p_causal.add_argument("--event-id", help="Target event id for why") + p_causal.add_argument("--query", default="", help="Query for why/preference") + p_causal.add_argument("--user-id", default="default", help="User ID") + p_causal.add_argument("--scope", default="global", help="Privacy scope for retrieval") + p_causal.add_argument("--kind", help="For gems: filter by gem kind") + p_causal.add_argument("--limit", type=int, default=500, help="For extract-gems: memories to scan") + p_causal.add_argument("--gem-limit", type=int, default=50, help="For extract-gems: max gems to extract") + p_causal.add_argument("--min-score", type=float, default=0.62, help="For extract-gems: minimum gem score") + p_causal.add_argument("--dry-run", action="store_true", help="For extract-gems: score without writing RawEvents") + p_causal.add_argument("--submit-learnings", action="store_true", help="For extract-gems: also submit learning candidates") + p_causal.add_argument( + "--learning-status", + choices=["candidate", "promoted", "rejected", "archived"], + default="candidate", + help="For submit-gem: learning lifecycle status", + ) + p_causal.add_argument("--repo", help="For gem learning candidates: repo path") + p_causal.add_argument("--json", action="store_true", help="JSON output") + # router p_router = sub.add_parser("router", help="Context router (enable/disable/stats)") p_router.add_argument( @@ -3676,6 +4029,8 @@ def build_parser() -> argparse.ArgumentParser: "ingest": cmd_ingest, "docs": cmd_docs, "assets": cmd_assets, + "graph": cmd_graph, + "causal": cmd_causal, "replay-corpus": cmd_replay_corpus, "portability-eval": cmd_portability_eval, "decades-eval": 
cmd_decades_eval, diff --git a/dhee/world_memory/__init__.py b/dhee/world_memory/__init__.py index 1cd24a5..d564f45 100644 --- a/dhee/world_memory/__init__.py +++ b/dhee/world_memory/__init__.py @@ -13,8 +13,22 @@ create_default_encoder, ) from .predictor import ActionConditionedPredictor, compute_surprise +from .causal_graph import CausalGraphProjection +from .gem_extractor import ( + GEM_SCHEMA_VERSION, + MemoryGem, + extract_memory_gems, + score_memory_gem, + submit_gem_learning_candidates, + submit_projected_gem_learning_candidate, + summarize_gems, + write_gem_raw_events, +) from .schema import ( + CAUSAL_PROJECTION_VERSION, + CAUSAL_SCHEMA_VERSION, ActionTransition, + CausalEdge, CaptureAction, CaptureEvent, CaptureLink, @@ -23,7 +37,11 @@ CapturedArtifact, CapturedObservation, CapturedSurface, + CheckpointReport, EvidenceChunk, + EventFrame, + RawEvent, + RetrievalTrace, TransitionMatch, WorldState, ) @@ -34,6 +52,11 @@ __all__ = [ "ActionConditionedPredictor", "ActionTransition", + "CAUSAL_PROJECTION_VERSION", + "CAUSAL_SCHEMA_VERSION", + "CausalEdge", + "CausalGraphProjection", + "GEM_SCHEMA_VERSION", "CaptureAction", "CaptureEvent", "CaptureLink", @@ -43,15 +66,26 @@ "CapturedArtifact", "CapturedObservation", "CapturedSurface", + "CheckpointReport", "ContentAwareFrameEncoder", "DeterministicFrameEncoder", "EvidenceChunk", + "EventFrame", + "MemoryGem", "MemoryOSService", "NvidiaVLFrameEncoder", "SessionGraphStore", + "RawEvent", + "RetrievalTrace", "TransitionMatch", "WorldMemoryStore", "WorldState", "compute_surprise", "create_default_encoder", + "extract_memory_gems", + "score_memory_gem", + "submit_gem_learning_candidates", + "submit_projected_gem_learning_candidate", + "summarize_gems", + "write_gem_raw_events", ] diff --git a/dhee/world_memory/capture_store.py b/dhee/world_memory/capture_store.py index 4b0eaf6..10f3a38 100644 --- a/dhee/world_memory/capture_store.py +++ b/dhee/world_memory/capture_store.py @@ -5,10 +5,23 @@ import sqlite3 import uuid 
from contextlib import contextmanager -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import Any, Dict, Iterable, List, Optional -from .schema import CaptureEvent, CapturePolicy, CaptureSession +from .schema import ( + AUTOMATIC_CAUSAL_EDGE_TYPES, + CAUSAL_EDGE_STATUSES, + CAUSAL_SCHEMA_VERSION, + CHECKPOINT_CAUSAL_EDGE_TYPES, + CausalEdge, + CaptureEvent, + CapturePolicy, + CaptureSession, + CheckpointReport, + EventFrame, + RawEvent, + RetrievalTrace, +) def _utcnow() -> str: @@ -84,6 +97,115 @@ def _init_db(self) -> None: metadata_json TEXT NOT NULL, updated_at TEXT NOT NULL ); + + CREATE TABLE IF NOT EXISTS raw_events ( + id TEXT PRIMARY KEY, + schema_version TEXT NOT NULL, + user_id TEXT NOT NULL, + session_id TEXT, + source_app TEXT NOT NULL, + namespace TEXT NOT NULL, + event_type TEXT NOT NULL, + timestamp TEXT NOT NULL, + content_ref TEXT, + content_hash TEXT, + privacy_scope TEXT NOT NULL, + metadata_json TEXT NOT NULL, + deleted_at TEXT, + redacted_at TEXT, + redaction_reason TEXT, + created_at TEXT NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_raw_events_user_timestamp + ON raw_events(user_id, timestamp DESC); + CREATE INDEX IF NOT EXISTS idx_raw_events_session_timestamp + ON raw_events(session_id, timestamp DESC); + CREATE INDEX IF NOT EXISTS idx_raw_events_scope + ON raw_events(user_id, privacy_scope); + + CREATE TABLE IF NOT EXISTS event_frames ( + id TEXT PRIMARY KEY, + schema_version TEXT NOT NULL, + user_id TEXT NOT NULL, + frame_type TEXT NOT NULL, + summary TEXT NOT NULL, + source_event_ids_json TEXT NOT NULL, + confidence REAL NOT NULL, + privacy_scope TEXT NOT NULL, + created_at TEXT NOT NULL, + deleted_at TEXT, + redacted_at TEXT, + redaction_reason TEXT, + metadata_json TEXT NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_event_frames_user_created + ON event_frames(user_id, created_at DESC); + + CREATE TABLE IF NOT EXISTS causal_edges ( + id TEXT PRIMARY KEY, + schema_version TEXT NOT 
NULL, + user_id TEXT NOT NULL, + source_id TEXT NOT NULL, + target_id TEXT NOT NULL, + edge_type TEXT NOT NULL, + confidence REAL NOT NULL, + status TEXT NOT NULL, + evidence_event_ids_json TEXT NOT NULL, + inferred_by TEXT NOT NULL, + explanation TEXT NOT NULL, + privacy_scope TEXT NOT NULL, + created_at TEXT NOT NULL, + deleted_at TEXT, + redacted_at TEXT, + redaction_reason TEXT, + metadata_json TEXT NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_causal_edges_user_type + ON causal_edges(user_id, edge_type); + CREATE INDEX IF NOT EXISTS idx_causal_edges_source + ON causal_edges(source_id); + CREATE INDEX IF NOT EXISTS idx_causal_edges_target + ON causal_edges(target_id); + + CREATE TABLE IF NOT EXISTS causal_checkpoint_reports ( + id TEXT PRIMARY KEY, + schema_version TEXT NOT NULL, + user_id TEXT NOT NULL, + session_id TEXT, + time_window_start TEXT, + time_window_end TEXT, + status TEXT NOT NULL, + event_frame_ids_json TEXT NOT NULL, + causal_edge_ids_json TEXT NOT NULL, + summary_memory_id TEXT, + report_json TEXT NOT NULL, + created_at TEXT NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_checkpoint_reports_user_created + ON causal_checkpoint_reports(user_id, created_at DESC); + + CREATE TABLE IF NOT EXISTS retrieval_traces ( + id TEXT PRIMARY KEY, + schema_version TEXT NOT NULL, + user_id TEXT NOT NULL, + mode TEXT NOT NULL, + scope TEXT NOT NULL, + query TEXT NOT NULL, + target_id TEXT NOT NULL, + retrieval_path_json TEXT NOT NULL, + evidence_json TEXT NOT NULL, + result_json TEXT NOT NULL, + privacy_scope TEXT NOT NULL, + metadata_json TEXT NOT NULL, + created_at TEXT NOT NULL, + deleted_at TEXT, + redacted_at TEXT, + redaction_reason TEXT + ); + CREATE INDEX IF NOT EXISTS idx_retrieval_traces_user_created + ON retrieval_traces(user_id, created_at DESC); + CREATE INDEX IF NOT EXISTS idx_retrieval_traces_mode + ON retrieval_traces(user_id, mode); """ ) @@ -273,6 +395,417 @@ def get_policy(self, source_app: str) -> Optional[CapturePolicy]: ).fetchone() return 
_row_to_policy(row) if row else None + def record_raw_event(self, event: RawEvent) -> RawEvent: + """Append one immutable raw event to the SQLite truth layer.""" + with self._tx() as conn: + conn.execute( + """ + INSERT INTO raw_events ( + id, schema_version, user_id, session_id, source_app, namespace, + event_type, timestamp, content_ref, content_hash, privacy_scope, + metadata_json, deleted_at, redacted_at, redaction_reason, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + event.id, + event.schema_version or CAUSAL_SCHEMA_VERSION, + event.user_id, + event.session_id, + event.source_app, + event.namespace, + event.event_type, + event.timestamp, + event.content_ref, + event.content_hash, + event.privacy_scope, + json.dumps(event.metadata), + event.deleted_at, + event.redacted_at, + event.redaction_reason, + _utcnow(), + ), + ) + return event + + def get_raw_event(self, event_id: str) -> Optional[RawEvent]: + with self._tx() as conn: + row = conn.execute( + "SELECT * FROM raw_events WHERE id = ? LIMIT 1", + (event_id,), + ).fetchone() + return _row_to_raw_event(row) if row else None + + def list_raw_events( + self, + *, + user_id: str = "default", + session_id: Optional[str] = None, + source_app: Optional[str] = None, + privacy_scopes: Optional[Iterable[str]] = None, + limit: int = 100, + include_deleted: bool = False, + include_redacted: bool = False, + order: str = "desc", + ) -> List[RawEvent]: + query = "SELECT * FROM raw_events WHERE user_id = ?" + params: List[Any] = [user_id] + if session_id: + query += " AND session_id = ?" + params.append(session_id) + if source_app: + query += " AND source_app = ?" + params.append(source_app) + scopes = [str(scope) for scope in (privacy_scopes or []) if str(scope).strip()] + if scopes: + query += f" AND privacy_scope IN ({','.join('?' 
for _ in scopes)})" + params.extend(scopes) + if not include_deleted: + query += " AND deleted_at IS NULL" + if not include_redacted: + query += " AND redacted_at IS NULL" + direction = "ASC" if str(order).lower() == "asc" else "DESC" + query += f" ORDER BY timestamp {direction} LIMIT ?" + params.append(int(limit)) + with self._tx() as conn: + rows = conn.execute(query, params).fetchall() + return [_row_to_raw_event(row) for row in rows] + + def redact_raw_event( + self, + event_id: str, + *, + redacted_at: Optional[str] = None, + reason: str = "", + delete: bool = False, + ) -> Optional[RawEvent]: + existing = self.get_raw_event(event_id) + if not existing: + return None + now = redacted_at or _utcnow() + with self._tx() as conn: + conn.execute( + """ + UPDATE raw_events + SET redacted_at = ?, redaction_reason = ?, deleted_at = COALESCE(?, deleted_at) + WHERE id = ? + """, + (now, reason, now if delete else None, event_id), + ) + conn.execute( + """ + UPDATE event_frames + SET redacted_at = COALESCE(redacted_at, ?), + redaction_reason = COALESCE(NULLIF(redaction_reason, ''), ?) + WHERE source_event_ids_json LIKE ? + """, + (now, reason, f'%"{event_id}"%'), + ) + conn.execute( + """ + UPDATE causal_edges + SET redacted_at = COALESCE(redacted_at, ?), + redaction_reason = COALESCE(NULLIF(redaction_reason, ''), ?) + WHERE evidence_event_ids_json LIKE ? + """, + (now, reason, f'%"{event_id}"%'), + ) + conn.execute( + """ + UPDATE retrieval_traces + SET redacted_at = COALESCE(redacted_at, ?), + redaction_reason = COALESCE(NULLIF(redaction_reason, ''), ?), + deleted_at = COALESCE(?, deleted_at) + WHERE target_id = ? + OR retrieval_path_json LIKE ? + OR evidence_json LIKE ? + OR result_json LIKE ? 
+ """, + ( + now, + reason, + now if delete else None, + event_id, + f'%"{event_id}"%', + f'%"{event_id}"%', + f'%"{event_id}"%', + ), + ) + return self.get_raw_event(event_id) + + def add_event_frame(self, frame: EventFrame) -> EventFrame: + with self._tx() as conn: + conn.execute( + """ + INSERT INTO event_frames ( + id, schema_version, user_id, frame_type, summary, + source_event_ids_json, confidence, privacy_scope, created_at, + deleted_at, redacted_at, redaction_reason, metadata_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + summary = excluded.summary, + source_event_ids_json = excluded.source_event_ids_json, + confidence = excluded.confidence, + privacy_scope = excluded.privacy_scope, + deleted_at = excluded.deleted_at, + redacted_at = excluded.redacted_at, + redaction_reason = excluded.redaction_reason, + metadata_json = excluded.metadata_json + """, + ( + frame.id, + frame.schema_version or CAUSAL_SCHEMA_VERSION, + frame.user_id, + frame.frame_type, + frame.summary, + json.dumps(frame.source_event_ids), + float(frame.confidence), + frame.privacy_scope, + frame.created_at, + frame.deleted_at, + frame.redacted_at, + frame.redaction_reason, + json.dumps(frame.metadata), + ), + ) + return frame + + def list_event_frames( + self, + *, + user_id: str = "default", + limit: int = 100, + include_deleted: bool = False, + include_redacted: bool = False, + ) -> List[EventFrame]: + query = "SELECT * FROM event_frames WHERE user_id = ?" + params: List[Any] = [user_id] + if not include_deleted: + query += " AND deleted_at IS NULL" + if not include_redacted: + query += " AND redacted_at IS NULL" + query += " ORDER BY created_at DESC LIMIT ?" 
+ params.append(int(limit)) + with self._tx() as conn: + rows = conn.execute(query, params).fetchall() + return [_row_to_event_frame(row) for row in rows] + + def add_causal_edge(self, edge: CausalEdge) -> CausalEdge: + allowed_types = AUTOMATIC_CAUSAL_EDGE_TYPES | CHECKPOINT_CAUSAL_EDGE_TYPES + if edge.edge_type not in allowed_types: + raise ValueError(f"Unsupported causal edge type: {edge.edge_type}") + if edge.status not in CAUSAL_EDGE_STATUSES: + raise ValueError(f"Unsupported causal edge status: {edge.status}") + if edge.edge_type == "CAUSED" and not edge.evidence_event_ids: + raise ValueError("CAUSED edges require evidence_event_ids") + if not edge.evidence_event_ids and edge.edge_type not in AUTOMATIC_CAUSAL_EDGE_TYPES: + raise ValueError(f"{edge.edge_type} edges require evidence_event_ids") + with self._tx() as conn: + conn.execute( + """ + INSERT INTO causal_edges ( + id, schema_version, user_id, source_id, target_id, edge_type, + confidence, status, evidence_event_ids_json, inferred_by, + explanation, privacy_scope, created_at, deleted_at, redacted_at, + redaction_reason, metadata_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(id) DO UPDATE SET + confidence = excluded.confidence, + status = excluded.status, + evidence_event_ids_json = excluded.evidence_event_ids_json, + inferred_by = excluded.inferred_by, + explanation = excluded.explanation, + privacy_scope = excluded.privacy_scope, + deleted_at = excluded.deleted_at, + redacted_at = excluded.redacted_at, + redaction_reason = excluded.redaction_reason, + metadata_json = excluded.metadata_json + """, + ( + edge.id, + edge.schema_version or CAUSAL_SCHEMA_VERSION, + edge.user_id, + edge.source_id, + edge.target_id, + edge.edge_type, + float(edge.confidence), + edge.status, + json.dumps(edge.evidence_event_ids), + edge.inferred_by, + edge.explanation, + edge.privacy_scope, + edge.created_at, + edge.deleted_at, + edge.redacted_at, + edge.redaction_reason, + json.dumps(edge.metadata), + ), + ) + return edge + + def list_causal_edges( + self, + *, + user_id: str = "default", + edge_type: Optional[str] = None, + limit: int = 200, + include_deleted: bool = False, + include_redacted: bool = False, + ) -> List[CausalEdge]: + query = "SELECT * FROM causal_edges WHERE user_id = ?" + params: List[Any] = [user_id] + if edge_type: + query += " AND edge_type = ?" + params.append(edge_type) + if not include_deleted: + query += " AND deleted_at IS NULL" + if not include_redacted: + query += " AND redacted_at IS NULL" + query += " ORDER BY created_at DESC LIMIT ?" + params.append(int(limit)) + with self._tx() as conn: + rows = conn.execute(query, params).fetchall() + return [_row_to_causal_edge(row) for row in rows] + + def add_checkpoint_report(self, report: CheckpointReport) -> CheckpointReport: + with self._tx() as conn: + conn.execute( + """ + INSERT INTO causal_checkpoint_reports ( + id, schema_version, user_id, session_id, time_window_start, + time_window_end, status, event_frame_ids_json, + causal_edge_ids_json, summary_memory_id, report_json, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + report.id, + report.schema_version or CAUSAL_SCHEMA_VERSION, + report.user_id, + report.session_id, + report.time_window_start, + report.time_window_end, + report.status, + json.dumps(report.event_frame_ids), + json.dumps(report.causal_edge_ids), + report.summary_memory_id, + json.dumps(report.report), + report.created_at, + ), + ) + return report + + def add_retrieval_trace(self, trace: RetrievalTrace) -> RetrievalTrace: + with self._tx() as conn: + conn.execute( + """ + INSERT INTO retrieval_traces ( + id, schema_version, user_id, mode, scope, query, target_id, + retrieval_path_json, evidence_json, result_json, privacy_scope, + metadata_json, created_at, deleted_at, redacted_at, redaction_reason + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + retrieval_path_json = excluded.retrieval_path_json, + evidence_json = excluded.evidence_json, + result_json = excluded.result_json, + privacy_scope = excluded.privacy_scope, + metadata_json = excluded.metadata_json, + deleted_at = excluded.deleted_at, + redacted_at = excluded.redacted_at, + redaction_reason = excluded.redaction_reason + """, + ( + trace.id, + trace.schema_version or CAUSAL_SCHEMA_VERSION, + trace.user_id, + trace.mode, + trace.scope, + trace.query, + trace.target_id, + json.dumps(trace.retrieval_path), + json.dumps(trace.evidence), + json.dumps(trace.result), + trace.privacy_scope, + json.dumps(trace.metadata), + trace.created_at, + trace.deleted_at, + trace.redacted_at, + trace.redaction_reason, + ), + ) + return trace + + def get_retrieval_trace( + self, + trace_id: str, + *, + include_deleted: bool = False, + include_redacted: bool = False, + ) -> Optional[RetrievalTrace]: + query = "SELECT * FROM retrieval_traces WHERE id = ?" 
def prune_retrieval_traces(
    self,
    *,
    user_id: str = "default",
    older_than_days: Optional[int] = None,
    keep_latest: int = 1000,
    dry_run: bool = False,
) -> Dict[str, Any]:
    """Delete old retrieval traces for one user and report what was removed.

    The ``keep_latest`` newest traces (by ``created_at``) are always kept.
    Of the rest, every trace is a candidate when ``older_than_days`` is
    ``None``; otherwise only traces whose ``created_at`` sorts before the
    cutoff are candidates.

    Args:
        user_id: Owner whose traces are considered.
        older_than_days: Minimum age in days; negative values are clamped
            to 0. ``None`` disables the age filter.
        keep_latest: Number of newest traces to protect. Negative or falsy
            values are treated as 0 (nothing is protected).
        dry_run: When true, compute the candidates but delete nothing.

    Returns:
        A summary dict with the inputs, ``total_traces``,
        ``protected_latest``, ``candidate_count``, ``pruned_count`` (0 on a
        dry run) and the first 50 ``candidate_ids``.
    """
    keep_count = max(0, int(keep_latest or 0))
    cutoff: Optional[str] = None
    if older_than_days is not None:
        max_age = timedelta(days=max(0, int(older_than_days)))
        cutoff = (datetime.now(timezone.utc) - max_age).isoformat()

    # One bound parameter is used per deleted id, so deletes are issued in
    # batches to stay below SQLite's bound-variable limit (999 on builds
    # older than 3.32) no matter how many traces have piled up.
    delete_batch_size = 500

    with self._tx() as conn:
        rows = conn.execute(
            """
            SELECT id, created_at, redacted_at, deleted_at
            FROM retrieval_traces
            WHERE user_id = ?
            ORDER BY created_at DESC
            """,
            (user_id,),
        ).fetchall()
        # Rows arrive newest-first, so everything past the first
        # ``keep_count`` rows is unprotected. ``id`` is the table's conflict
        # target, hence unique, so slicing matches an id-set exclusion.
        # NOTE(review): the age test compares ISO strings lexicographically;
        # this assumes created_at is stored as UTC isoformat text — confirm.
        candidates: List[str] = [
            str(row["id"])
            for row in rows[keep_count:]
            if cutoff is None or str(row["created_at"] or "") < cutoff
        ]
        if not dry_run:
            for start in range(0, len(candidates), delete_batch_size):
                batch = candidates[start:start + delete_batch_size]
                placeholders = ",".join("?" for _ in batch)
                conn.execute(
                    f"DELETE FROM retrieval_traces WHERE id IN ({placeholders})",
                    batch,
                )

    return {
        "user_id": user_id,
        "dry_run": bool(dry_run),
        "older_than_days": older_than_days,
        "keep_latest": keep_count,
        "total_traces": len(rows),
        "protected_latest": min(keep_count, len(rows)),
        "candidate_count": len(candidates),
        "pruned_count": 0 if dry_run else len(candidates),
        "candidate_ids": candidates[:50],
    }
status=row["status"], + evidence_event_ids=[str(item) for item in _loads_list(row["evidence_event_ids_json"])], + inferred_by=row["inferred_by"], + explanation=row["explanation"], + privacy_scope=row["privacy_scope"], + created_at=row["created_at"], + deleted_at=row["deleted_at"], + redacted_at=row["redacted_at"], + redaction_reason=row["redaction_reason"], + metadata=_loads_dict(row["metadata_json"]), + ) + + +def _row_to_retrieval_trace(row: sqlite3.Row) -> RetrievalTrace: + return RetrievalTrace( + id=row["id"], + schema_version=row["schema_version"], + user_id=row["user_id"], + mode=row["mode"], + scope=row["scope"], + query=row["query"], + target_id=row["target_id"], + retrieval_path=[ + item for item in _loads_list(row["retrieval_path_json"]) if isinstance(item, dict) + ], + evidence=[ + item for item in _loads_list(row["evidence_json"]) if isinstance(item, dict) + ], + result=_loads_dict(row["result_json"]), + privacy_scope=row["privacy_scope"], + metadata=_loads_dict(row["metadata_json"]), + created_at=row["created_at"], + deleted_at=row["deleted_at"], + redacted_at=row["redacted_at"], + redaction_reason=row["redaction_reason"], + ) + + def _loads_dict(raw: str) -> Dict[str, Any]: value = json.loads(raw or "{}") return value if isinstance(value, dict) else {} + + +def _loads_list(raw: str) -> List[Any]: + value = json.loads(raw or "[]") + return value if isinstance(value, list) else [] diff --git a/dhee/world_memory/causal_graph.py b/dhee/world_memory/causal_graph.py new file mode 100644 index 0000000..3ebfb2f --- /dev/null +++ b/dhee/world_memory/causal_graph.py @@ -0,0 +1,1921 @@ +from __future__ import annotations + +import json +import re +import shutil +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +import kuzu + +from .capture_store import CaptureStore +from .schema import ( + CAUSAL_PROJECTION_VERSION, + CAUSAL_SCHEMA_VERSION, + CHECKPOINT_CAUSAL_EDGE_TYPES, + CausalEdge, + EventFrame, + RawEvent, 
def rebuild(self, capture_store: CaptureStore, *, user_id: str = "default") -> Dict[str, Any]:
    """Discard the Kuzu projection and re-derive it from ``capture_store``.

    The existing projection files are deleted first, the schema is
    recreated, and the user's raw events, event frames and causal edges are
    read from the capture store and projected into the graph.

    Returns a status dict with the backend name, database path, schema and
    projection versions, and the counts reported by ``_project``.
    """
    self.delete()
    conn = self._connect()
    try:
        self._create_schema(conn)
        # Every read uses the same filter: this user's rows, excluding
        # deleted and redacted ones, capped at 100_000 per table.
        live_rows = {
            "user_id": user_id,
            "limit": 100_000,
            "include_deleted": False,
            "include_redacted": False,
        }
        raw_events = capture_store.list_raw_events(order="asc", **live_rows)
        event_frames = capture_store.list_event_frames(**live_rows)
        causal_edges = capture_store.list_causal_edges(**live_rows)
        projected = self._project(
            conn,
            events=raw_events,
            frames=event_frames,
            edges=causal_edges,
        )
        summary: Dict[str, Any] = {
            "status": "rebuilt",
            "backend": "kuzu",
            "db_path": str(self.db_path),
            "schema_version": CAUSAL_SCHEMA_VERSION,
            "projection_version": CAUSAL_PROJECTION_VERSION,
            "counts": projected,
        }
        return summary
    finally:
        self._close(conn)
def delete(self) -> None:
    """Remove the projection database and its ``.wal`` sibling from disk.

    Each path may be a directory (removed recursively) or a plain file
    (unlinked); paths that do not exist are skipped.
    """
    wal_path = Path(f"{self.db_path}.wal")
    for target in (self.db_path, wal_path):
        if target.is_dir():
            shutil.rmtree(target)
            continue
        if target.exists():
            target.unlink()
{"bad_nodes": bad_projection, "ok": not bad_projection} + if bad_projection: + errors.append("projection-version mismatch in RawEvent projection") + + active_redactions = self._single_column( + conn, + """ + MATCH (e:RawEvent) + WHERE e.user_id = $user_id + AND (e.deleted_at <> '' OR e.redacted_at <> '') + RETURN e.id + """, + {"user_id": user_id}, + ) + checks["privacy"] = {"redacted_or_deleted_active_nodes": active_redactions, "ok": not active_redactions} + if active_redactions: + errors.append("redacted/deleted RawEvent projected as active") + + duplicate_rows = self._rows( + conn, + """ + MATCH (e:RawEvent) + WHERE e.user_id = $user_id + RETURN e.id, count(*) AS c + """, + {"user_id": user_id}, + ) + duplicates = [row[0] for row in duplicate_rows if len(row) > 1 and int(row[1] or 0) > 1] + checks["duplicate_nodes"] = {"duplicates": duplicates, "ok": not duplicates} + if duplicates: + errors.append("duplicate RawEvent projection nodes") + + orphan_errors = self._orphan_errors(conn) + checks["orphan_edges"] = {"errors": orphan_errors, "ok": not orphan_errors} + errors.extend(orphan_errors) + + source_ref_missing = self._single_column( + conn, + """ + MATCH (e:RawEvent) + WHERE e.user_id = $user_id AND e.sqlite_id = '' + RETURN e.id + """, + {"user_id": user_id}, + ) + checks["source_refs"] = {"missing": source_ref_missing, "ok": not source_ref_missing} + if source_ref_missing: + errors.append("RawEvent projection missing SQLite source refs") + + return { + "ok": not errors, + "backend": "kuzu", + "db_path": str(self.db_path), + "schema_version": CAUSAL_SCHEMA_VERSION, + "projection_version": CAUSAL_PROJECTION_VERSION, + "checks": checks, + "errors": errors, + } + finally: + self._close(conn) + + def show_event(self, event_id: str) -> Dict[str, Any]: + conn = self._connect() + try: + rows = self._rows( + conn, + """ + MATCH (e:RawEvent {id: $event_id}) + RETURN e.id, e.schema_version, e.projection_version, e.sqlite_id, + e.user_id, e.session_id, e.source_app, 
def show_cone(self, event_id: str, *, direction: str = "backward", depth: int = 3) -> Dict[str, Any]:
    """Walk temporal relations outward from ``event_id``, level by level.

    At least one hop is always walked, even when ``depth`` is below 1.
    With ``direction == "backward"`` the walk follows each relation's
    ``"from"`` end; any other direction follows its ``"to"`` end.

    Returns a dict echoing ``event_id``, ``direction`` and the requested
    ``depth``, plus ``causal_path`` (every relation seen, tagged with its
    1-based ``hop``) and ``evidence`` (all visited event ids, sorted).
    """
    conn = self._connect()
    try:
        hop_limit = max(int(depth), 1)
        visited = {event_id}
        wave: List[str] = [event_id]
        path: List[Dict[str, Any]] = []
        level = 0
        while wave and level < hop_limit:
            level += 1
            upcoming: List[str] = []
            for node in wave:
                for relation in self._temporal_relations(conn, node, direction):
                    if direction == "backward":
                        neighbour = relation["from"]
                    else:
                        neighbour = relation["to"]
                    if neighbour not in visited:
                        visited.add(neighbour)
                        upcoming.append(neighbour)
                    # Every relation is recorded, including ones that lead
                    # back to an event that was already visited.
                    path.append({"hop": level, **relation})
            wave = upcoming
        return {
            "event_id": event_id,
            "direction": direction,
            "depth": depth,
            "causal_path": path,
            "evidence": [{"event_id": seen_id} for seen_id in sorted(visited)],
        }
    finally:
        self._close(conn)
def explain_retrieval(self, query_id: str) -> Dict[str, Any]:
    """Return a placeholder explanation for ``query_id``.

    The result always has status ``"not_recorded"`` with an empty
    traversal and empty evidence; no lookup is performed.
    """
    explanation: Dict[str, Any] = {"query_id": query_id}
    explanation["status"] = "not_recorded"
    explanation["traversal"] = []
    explanation["evidence"] = []
    return explanation
str = "default", scope: str = "global") -> Dict[str, Any]: + conn = self._connect() + try: + allowed = _allowed_scopes(scope) + thread_scope_clause, thread_scope_params = _scope_where("t", allowed) + threads = self._rows( + conn, + f""" + MATCH (t:MemoryThread) + WHERE t.user_id = $user_id + {thread_scope_clause} + RETURN t.id, t.thread_type, t.title, t.status, t.summary, t.privacy_scope + ORDER BY t.updated_at DESC + LIMIT 12 + """, + {"user_id": user_id, **thread_scope_params}, + ) + active_threads = [ + { + "thread_id": row[0], + "thread_type": row[1], + "title": row[2], + "status": row[3], + "summary": row[4], + "privacy_scope": row[5], + "evidence": self._thread_evidence(conn, row[0], allowed), + } + for row in threads + ] + event_scope_clause, event_scope_params = _scope_where("e", allowed) + recent = self._rows( + conn, + f""" + MATCH (e:RawEvent) + WHERE e.user_id = $user_id + {event_scope_clause} + RETURN e.id, e.timestamp, e.source_app, e.event_type, e.privacy_scope + ORDER BY e.timestamp DESC + LIMIT 1 + """, + {"user_id": user_id, **event_scope_params}, + ) + recent_scene = None + if recent: + scene_rows = self._rows( + conn, + """ + MATCH (e:RawEvent {id: $event_id})-[r:BELONGS_TO]->(s:Scene) + RETURN s.id, s.title, s.summary, s.privacy_scope + LIMIT 1 + """, + {"event_id": recent[0][0]}, + ) + if scene_rows and scene_rows[0][3] in allowed: + recent_scene = { + "scene_id": scene_rows[0][0], + "title": scene_rows[0][1], + "summary": scene_rows[0][2], + "privacy_scope": scene_rows[0][3], + } + preference_gems = sorted( + self._list_gems(conn, user_id=user_id, scope=scope, kind="preference", limit=8).get("gems", []), + key=lambda gem: _rank_preference_gem(gem, ""), + reverse=True, + ) + preference_signals = [_preference_signal_from_gem(gem) for gem in preference_gems] + if not preference_signals: + preference_signals = [ + thread for thread in active_threads if thread.get("thread_type") == "preference" + ] + return { + "active_threads": active_threads, + 
def why(self, *, event_id: Optional[str] = None, query: str = "", user_id: str = "default", scope: str = "global") -> Dict[str, Any]:
    """Explain an event by listing what precedes it in its backward cone.

    The target is ``event_id`` when given; otherwise it is looked up from
    ``query`` via ``_find_event_for_query``. When no target can be found,
    an empty explanation with confidence ``0.0`` is returned.

    Otherwise the backward cone (depth 3) is fetched and each step becomes
    a likely cause. The overall confidence is the highest confidence among
    the likely causes, or ``0.0`` when there are none.
    """
    target = event_id
    if not target:
        target = self._find_event_for_query(query=query, user_id=user_id, scope=scope)
    if not target:
        return {
            "target_event_id": "",
            "likely_causes": [],
            "rejected_causes": [],
            "confidence": 0.0,
            "causal_path": [],
            "evidence": [],
        }

    cone = self.show_cone(target, direction="backward", depth=3)
    path = cone.get("causal_path", [])
    likely: List[Dict[str, Any]] = []
    for step in path:
        likely.append(
            {
                "event_id": step["from"],
                "relation": step["edge_type"],
                "confidence": step.get("confidence", 0.0),
                "evidence": step.get("evidence", []),
            }
        )
    scores = [cause.get("confidence", 0.0) for cause in likely]
    return {
        "target_event_id": target,
        "likely_causes": likely,
        "rejected_causes": [],
        "confidence": max(scores) if scores else 0.0,
        "causal_path": path,
        "evidence": cone.get("evidence", []),
    }
def handoff(self, *, user_id: str = "default", scope: str = "global") -> Dict[str, Any]:
    """Build a handoff summary from the user's active frontier.

    The frontier from ``get_active_frontier`` is embedded whole, and its
    ``last_verified_state`` and ``evidence`` are also surfaced at the top
    level. ``blockers`` and ``next_action_candidates`` are always empty.
    """
    frontier = self.get_active_frontier(user_id=user_id, scope=scope)
    verified_state = frontier.get("last_verified_state", "")
    supporting = frontier.get("evidence", [])
    return {
        "active_causal_frontier": frontier,
        "blockers": [],
        "last_verified_state": verified_state,
        "next_action_candidates": [],
        "evidence": supporting,
    }
def list_gems(
    self,
    *,
    user_id: str = "default",
    scope: str = "global",
    kind: Optional[str] = None,
    limit: int = 50,
) -> Dict[str, Any]:
    """List memory gems for a user using a connection opened for this call.

    Opens a connection, delegates to ``_list_gems`` with the same
    arguments, and closes the connection whether or not the query raised.
    """
    connection = self._connect()
    try:
        listing = self._list_gems(
            connection,
            user_id=user_id,
            scope=scope,
            kind=kind,
            limit=limit,
        )
        return listing
    finally:
        self._close(connection)
metadata.get("summary") or "", + "score": metadata.get("score"), + "confidence": metadata.get("confidence"), + "timestamp": row[4], + "event_type": row[5], + "privacy_scope": row[6], + "content_ref": row[7], + } + source_memory_id = str(metadata.get("source_memory_id") or "").strip() + source_event_id = str(metadata.get("source_event_id") or "").strip() + evidence = metadata.get("evidence") or [] + threads = self._threads_for_event(conn, row[0]) + return { + "status": "ok", + "target": target, + "gem": gem, + "source_memory": { + "memory_id": source_memory_id, + "content_ref": row[7], + "source_event_id": source_event_id, + "source_app": metadata.get("source_app") or "", + "memory_type": metadata.get("memory_type") or "", + "categories": metadata.get("categories") or [], + }, + "supporting_events": [ + { + "event_id": row[0], + "source_memory_id": source_memory_id, + "source_event_id": source_event_id, + "content_ref": row[7], + "evidence": evidence, + } + ], + "threads": threads, + "derived_summaries": [], + "retrieval_path": [ + { + "mode": "show_gem", + "step": "match_scoped_memory_gem", + "matched": True, + "event_id": row[0], + }, + { + "mode": "show_gem", + "step": "load_thread_memberships", + "thread_count": len(threads), + }, + ], + } + finally: + self._close(conn) + + def _list_gems( + self, + conn: kuzu.Connection, + *, + user_id: str, + scope: str, + kind: Optional[str], + limit: int, + ) -> Dict[str, Any]: + allowed = _allowed_scopes(scope) + scope_clause, scope_params = _scope_where("e", allowed) + kind_clause = "" + kind_filter = str(kind or "").strip().lower() or None + params: Dict[str, Any] = { + "user_id": user_id, + "limit": max(1, int(limit or 50)), + **scope_params, + } + if kind_filter: + kind_clause = "AND e.event_type = $event_type" + params["event_type"] = f"gem_{kind_filter}" + rows = self._rows( + conn, + f""" + MATCH (e:RawEvent) + WHERE e.user_id = $user_id + AND e.source_app = 'memory-gem' + {scope_clause} + {kind_clause} + RETURN 
def _connect(self) -> kuzu.Connection:
    """Open the Kuzu database at ``self.db_path`` and return a new connection."""
    database = kuzu.Database(str(self.db_path))
    connection = kuzu.Connection(database)
    # The most recently opened Database handle is kept on the instance.
    # NOTE(review): presumably this keeps the handle referenced while the
    # connection is in use — confirm; nothing visible here closes it.
    self._last_db = database
    return connection
+ source_app STRING, + namespace STRING, + event_type STRING, + timestamp STRING, + content_ref STRING, + content_hash STRING, + privacy_scope STRING, + deleted_at STRING, + redacted_at STRING, + redaction_reason STRING, + metadata_json STRING, + PRIMARY KEY(id) + ); + """ + ) + conn.execute( + """ + CREATE NODE TABLE IF NOT EXISTS EventFrame( + id STRING, + schema_version STRING, + projection_version STRING, + sqlite_id STRING, + user_id STRING, + frame_type STRING, + summary STRING, + source_event_ids_json STRING, + confidence DOUBLE, + privacy_scope STRING, + deleted_at STRING, + redacted_at STRING, + redaction_reason STRING, + metadata_json STRING, + PRIMARY KEY(id) + ); + """ + ) + conn.execute( + f""" + CREATE NODE TABLE IF NOT EXISTS Scene( + {common_fields}, + title STRING, + status STRING, + summary STRING, + updated_at STRING, + PRIMARY KEY(id) + ); + """ + ) + conn.execute( + f""" + CREATE NODE TABLE IF NOT EXISTS Episode( + {common_fields}, + title STRING, + status STRING, + summary STRING, + updated_at STRING, + PRIMARY KEY(id) + ); + """ + ) + conn.execute( + f""" + CREATE NODE TABLE IF NOT EXISTS MemoryThread( + {common_fields}, + thread_type STRING, + title STRING, + status STRING, + summary STRING, + updated_at STRING, + PRIMARY KEY(id) + ); + """ + ) + for table in ["Memory", "Artifact", "Actor", "Entity", "Surface"]: + conn.execute(f"CREATE NODE TABLE IF NOT EXISTS {table}({common});") + rel_cols = """ + id STRING, + schema_version STRING, + projection_version STRING, + sqlite_id STRING, + source_sqlite_id STRING, + user_id STRING, + privacy_scope STRING, + confidence DOUBLE, + status STRING, + evidence_json STRING, + explanation STRING, + created_at STRING + """ + for rel, (source, target) in AUTOMATIC_REL_TABLES.items(): + conn.execute(f"CREATE REL TABLE IF NOT EXISTS {rel}(FROM {source} TO {target}, {rel_cols});") + for rel in CHECKPOINT_CAUSAL_EDGE_TYPES: + conn.execute(f"CREATE REL TABLE IF NOT EXISTS {rel}(FROM EventFrame TO EventFrame, 
{rel_cols});") + + def _project( + self, + conn: kuzu.Connection, + *, + events: Sequence[RawEvent], + frames: Sequence[EventFrame], + edges: Sequence[CausalEdge], + ) -> Dict[str, int]: + counts = { + "RawEvent": 0, + "EventFrame": 0, + "MemoryThread": 0, + "Scene": 0, + "Episode": 0, + "Surface": 0, + "Entity": 0, + "Artifact": 0, + "Memory": 0, + "automatic_edges": 0, + "causal_edges": 0, + } + for event in events: + self._insert_raw_event(conn, event) + counts["RawEvent"] += 1 + counts["Surface"] += self._project_surface(conn, event) + counts["Entity"] += self._project_entities(conn, event) + artifact_count, memory_count = self._project_artifacts_and_memories(conn, event) + counts["Artifact"] += artifact_count + counts["Memory"] += memory_count + scene_count, episode_count, scene_edge_count = self._project_scene_episode(conn, event) + counts["Scene"] += scene_count + counts["Episode"] += episode_count + counts["automatic_edges"] += scene_edge_count + thread_count, thread_edge_count = self._project_threads(conn, event) + counts["MemoryThread"] += thread_count + counts["automatic_edges"] += thread_edge_count + + for previous, current in zip(events, events[1:]): + if previous.user_id != current.user_id: + continue + self._insert_rel( + conn, + "TEMPORAL_NEXT", + "RawEvent", + "RawEvent", + previous.id, + current.id, + rel_id=f"temporal:{previous.id}:{current.id}", + user_id=current.user_id, + privacy_scope=_strictest_scope([previous.privacy_scope, current.privacy_scope]), + confidence=1.0, + status="observed", + evidence=[previous.id, current.id], + explanation="Events were adjacent in SQLite RawEvent timestamp order.", + source_sqlite_id=current.id, + created_at=current.timestamp, + ) + counts["automatic_edges"] += 1 + + for frame in frames: + self._insert_event_frame(conn, frame) + counts["EventFrame"] += 1 + + frame_ids = {frame.id for frame in frames} + for edge in edges: + if edge.edge_type not in CHECKPOINT_CAUSAL_EDGE_TYPES: + continue + if edge.source_id 
not in frame_ids or edge.target_id not in frame_ids: + continue + self._insert_rel( + conn, + edge.edge_type, + "EventFrame", + "EventFrame", + edge.source_id, + edge.target_id, + rel_id=edge.id, + user_id=edge.user_id, + privacy_scope=edge.privacy_scope, + confidence=edge.confidence, + status=edge.status, + evidence=edge.evidence_event_ids, + explanation=edge.explanation, + source_sqlite_id=edge.id, + created_at=edge.created_at, + ) + counts["causal_edges"] += 1 + return self._projection_counts(conn) + + def _insert_raw_event(self, conn: kuzu.Connection, event: RawEvent) -> None: + conn.execute( + """ + MERGE (e:RawEvent {id: $id}) + SET e.schema_version = $schema_version, + e.projection_version = $projection_version, + e.sqlite_id = $sqlite_id, + e.user_id = $user_id, + e.session_id = $session_id, + e.source_app = $source_app, + e.namespace = $namespace, + e.event_type = $event_type, + e.timestamp = $timestamp, + e.content_ref = $content_ref, + e.content_hash = $content_hash, + e.privacy_scope = $privacy_scope, + e.deleted_at = $deleted_at, + e.redacted_at = $redacted_at, + e.redaction_reason = $redaction_reason, + e.metadata_json = $metadata_json + """, + _clean_params( + { + "id": event.id, + "schema_version": event.schema_version, + "projection_version": CAUSAL_PROJECTION_VERSION, + "sqlite_id": event.id, + "user_id": event.user_id, + "session_id": event.session_id, + "source_app": event.source_app, + "namespace": event.namespace, + "event_type": event.event_type, + "timestamp": event.timestamp, + "content_ref": event.content_ref, + "content_hash": event.content_hash, + "privacy_scope": event.privacy_scope, + "deleted_at": event.deleted_at, + "redacted_at": event.redacted_at, + "redaction_reason": event.redaction_reason, + "metadata_json": json.dumps(event.metadata, sort_keys=True), + } + ), + ) + + def _insert_event_frame(self, conn: kuzu.Connection, frame: EventFrame) -> None: + conn.execute( + """ + MERGE (f:EventFrame {id: $id}) + SET f.schema_version = 
$schema_version, + f.projection_version = $projection_version, + f.sqlite_id = $sqlite_id, + f.user_id = $user_id, + f.frame_type = $frame_type, + f.summary = $summary, + f.source_event_ids_json = $source_event_ids_json, + f.confidence = $confidence, + f.privacy_scope = $privacy_scope, + f.deleted_at = $deleted_at, + f.redacted_at = $redacted_at, + f.redaction_reason = $redaction_reason, + f.metadata_json = $metadata_json + """, + _clean_params( + { + "id": frame.id, + "schema_version": frame.schema_version, + "projection_version": CAUSAL_PROJECTION_VERSION, + "sqlite_id": frame.id, + "user_id": frame.user_id, + "frame_type": frame.frame_type, + "summary": frame.summary, + "source_event_ids_json": json.dumps(frame.source_event_ids), + "confidence": float(frame.confidence), + "privacy_scope": frame.privacy_scope, + "deleted_at": frame.deleted_at, + "redacted_at": frame.redacted_at, + "redaction_reason": frame.redaction_reason, + "metadata_json": json.dumps(frame.metadata, sort_keys=True), + } + ), + ) + + def _insert_generic_node( + self, + conn: kuzu.Connection, + table: str, + node_id: str, + *, + user_id: str, + privacy_scope: str, + sqlite_id: str, + metadata: Optional[Dict[str, Any]] = None, + deleted_at: Optional[str] = None, + redacted_at: Optional[str] = None, + redaction_reason: Optional[str] = None, + ) -> None: + conn.execute( + f""" + MERGE (n:{table} {{id: $id}}) + SET n.schema_version = $schema_version, + n.projection_version = $projection_version, + n.sqlite_id = $sqlite_id, + n.user_id = $user_id, + n.privacy_scope = $privacy_scope, + n.deleted_at = $deleted_at, + n.redacted_at = $redacted_at, + n.redaction_reason = $redaction_reason, + n.metadata_json = $metadata_json + """, + _clean_params( + { + "id": node_id, + "schema_version": CAUSAL_SCHEMA_VERSION, + "projection_version": CAUSAL_PROJECTION_VERSION, + "sqlite_id": sqlite_id, + "user_id": user_id, + "privacy_scope": privacy_scope, + "deleted_at": deleted_at, + "redacted_at": redacted_at, + 
"redaction_reason": redaction_reason, + "metadata_json": json.dumps(metadata or {}, sort_keys=True), + } + ), + ) + + def _insert_scene( + self, + conn: kuzu.Connection, + node_id: str, + *, + user_id: str, + privacy_scope: str, + sqlite_id: str, + title: str, + status: str, + summary: str, + updated_at: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + conn.execute( + """ + MERGE (n:Scene {id: $id}) + SET n.schema_version = $schema_version, + n.projection_version = $projection_version, + n.sqlite_id = $sqlite_id, + n.user_id = $user_id, + n.privacy_scope = $privacy_scope, + n.deleted_at = '', + n.redacted_at = '', + n.redaction_reason = '', + n.metadata_json = $metadata_json, + n.title = $title, + n.status = $status, + n.summary = $summary, + n.updated_at = $updated_at + """, + _clean_params( + { + "id": node_id, + "schema_version": CAUSAL_SCHEMA_VERSION, + "projection_version": CAUSAL_PROJECTION_VERSION, + "sqlite_id": sqlite_id, + "user_id": user_id, + "privacy_scope": privacy_scope, + "metadata_json": json.dumps(metadata or {}, sort_keys=True), + "title": title, + "status": status, + "summary": summary, + "updated_at": updated_at, + } + ), + ) + + def _insert_episode( + self, + conn: kuzu.Connection, + node_id: str, + *, + user_id: str, + privacy_scope: str, + sqlite_id: str, + title: str, + status: str, + summary: str, + updated_at: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + conn.execute( + """ + MERGE (n:Episode {id: $id}) + SET n.schema_version = $schema_version, + n.projection_version = $projection_version, + n.sqlite_id = $sqlite_id, + n.user_id = $user_id, + n.privacy_scope = $privacy_scope, + n.deleted_at = '', + n.redacted_at = '', + n.redaction_reason = '', + n.metadata_json = $metadata_json, + n.title = $title, + n.status = $status, + n.summary = $summary, + n.updated_at = $updated_at + """, + _clean_params( + { + "id": node_id, + "schema_version": CAUSAL_SCHEMA_VERSION, + "projection_version": 
CAUSAL_PROJECTION_VERSION, + "sqlite_id": sqlite_id, + "user_id": user_id, + "privacy_scope": privacy_scope, + "metadata_json": json.dumps(metadata or {}, sort_keys=True), + "title": title, + "status": status, + "summary": summary, + "updated_at": updated_at, + } + ), + ) + + def _project_surface(self, conn: kuzu.Connection, event: RawEvent) -> int: + source = event.source_app or "unknown" + node_id = f"surface:{source}" + self._insert_generic_node( + conn, + "Surface", + node_id, + user_id=event.user_id, + privacy_scope=event.privacy_scope, + sqlite_id=event.id, + metadata={"source_app": source, "surface_type": "app"}, + ) + self._insert_rel( + conn, + "OBSERVED_ON", + "RawEvent", + "Surface", + event.id, + node_id, + rel_id=f"observed:{event.id}:{node_id}", + user_id=event.user_id, + privacy_scope=event.privacy_scope, + confidence=1.0, + status="observed", + evidence=[event.id], + explanation="RawEvent source_app was deterministically projected as the observed surface.", + source_sqlite_id=event.id, + created_at=event.timestamp, + ) + return 1 + + def _project_entities(self, conn: kuzu.Connection, event: RawEvent) -> int: + entities = _metadata_list(event.metadata, "entities") + if not entities and event.metadata.get("entity"): + entities = [str(event.metadata["entity"])] + count = 0 + for entity in entities: + node_id = f"entity:{_slug(entity)}" + self._insert_generic_node( + conn, + "Entity", + node_id, + user_id=event.user_id, + privacy_scope=event.privacy_scope, + sqlite_id=event.id, + metadata={"name": entity}, + ) + self._insert_rel( + conn, + "MENTIONS", + "RawEvent", + "Entity", + event.id, + node_id, + rel_id=f"mentions:{event.id}:{node_id}", + user_id=event.user_id, + privacy_scope=event.privacy_scope, + confidence=1.0, + status="observed", + evidence=[event.id], + explanation="Entity was extracted from deterministic RawEvent metadata.", + source_sqlite_id=event.id, + created_at=event.timestamp, + ) + count += 1 + return count + + def 
_project_artifacts_and_memories(self, conn: kuzu.Connection, event: RawEvent) -> Tuple[int, int]: + artifact_count = 0 + memory_count = 0 + artifact_id = str(event.metadata.get("artifact_id") or "").strip() + if artifact_id: + node_id = f"artifact:{artifact_id}" + self._insert_generic_node( + conn, + "Artifact", + node_id, + user_id=event.user_id, + privacy_scope=event.privacy_scope, + sqlite_id=artifact_id, + metadata={"artifact_id": artifact_id}, + ) + self._insert_rel( + conn, + "CREATED", + "RawEvent", + "Artifact", + event.id, + node_id, + rel_id=f"created:{event.id}:{node_id}", + user_id=event.user_id, + privacy_scope=event.privacy_scope, + confidence=1.0, + status="observed", + evidence=[event.id], + explanation="RawEvent metadata points at a created artifact.", + source_sqlite_id=event.id, + created_at=event.timestamp, + ) + artifact_count += 1 + memory_id = str(event.metadata.get("memory_id") or "").strip() + if memory_id: + node_id = f"memory:{memory_id}" + self._insert_generic_node( + conn, + "Memory", + node_id, + user_id=event.user_id, + privacy_scope=event.privacy_scope, + sqlite_id=memory_id, + metadata={"memory_id": memory_id}, + ) + self._insert_rel( + conn, + "UPDATED", + "RawEvent", + "Memory", + event.id, + node_id, + rel_id=f"updated:{event.id}:{node_id}", + user_id=event.user_id, + privacy_scope=event.privacy_scope, + confidence=1.0, + status="observed", + evidence=[event.id], + explanation="RawEvent metadata points at an updated Dhee memory.", + source_sqlite_id=event.id, + created_at=event.timestamp, + ) + memory_count += 1 + return artifact_count, memory_count + + def _project_scene_episode(self, conn: kuzu.Connection, event: RawEvent) -> Tuple[int, int, int]: + if not event.session_id: + return 0, 0, 0 + scene_id = f"scene:{event.session_id}" + episode_id = f"episode:{event.session_id}" + self._insert_scene( + conn, + scene_id, + user_id=event.user_id, + privacy_scope=event.privacy_scope, + sqlite_id=event.session_id, + 
title=f"{event.source_app or 'app'} session", + status="active", + summary="Session-local raw event cluster.", + updated_at=event.timestamp, + metadata={ + "episode_id": episode_id, + }, + ) + self._insert_episode( + conn, + episode_id, + user_id=event.user_id, + privacy_scope=event.privacy_scope, + sqlite_id=event.session_id, + title=f"{event.source_app or 'app'} episode", + status="active", + summary="Session-local episode projection.", + updated_at=event.timestamp, + metadata={"scene_id": scene_id}, + ) + self._insert_rel( + conn, + "BELONGS_TO", + "RawEvent", + "Scene", + event.id, + scene_id, + rel_id=f"belongs:{event.id}:{scene_id}", + user_id=event.user_id, + privacy_scope=event.privacy_scope, + confidence=1.0, + status="observed", + evidence=[event.id], + explanation="RawEvent session_id deterministically groups it into a session scene.", + source_sqlite_id=event.id, + created_at=event.timestamp, + ) + return 1, 1, 1 + + def _project_threads(self, conn: kuzu.Connection, event: RawEvent) -> Tuple[int, int]: + gem_kind = _gem_kind(event) + thread_specs = [ + ("source", f"source:{event.source_app or 'unknown'}", f"{event.source_app or 'unknown'} events"), + ("event_type", f"event_type:{event.event_type or 'unknown'}", f"{event.event_type or 'unknown'} events"), + ] + for raw_thread in _metadata_list(event.metadata, "threads"): + thread_specs.append(("custom", f"thread:{_slug(raw_thread)}", raw_thread)) + if event.source_app in {"gmail", "mail"}: + thread_specs.append(("gmail", "gmail", "Gmail")) + if event.source_app in {"chrome", "arc", "firefox", "safari", "browser"}: + thread_specs.append(("browser", "browser", "Browser")) + if event.event_type in {"preference", "user_correction", "correction"}: + thread_specs.append(("preference", "preference", "User preference signals")) + if event.event_type in {"artifact", "artifact_created"} or event.metadata.get("artifact_id"): + thread_specs.append(("artifact", "artifact", "Artifacts")) + if 
event.metadata.get("project"): + thread_specs.append(("project", f"project:{_slug(str(event.metadata['project']))}", str(event.metadata["project"]))) + if event.metadata.get("contact"): + thread_specs.append(("contact", f"contact:{_slug(str(event.metadata['contact']))}", str(event.metadata["contact"]))) + if event.event_type in {"learning", "lesson", "skill"}: + thread_specs.append(("learning", "learning", "Learning")) + if gem_kind: + thread_specs.append(("gem", "gems", "Memory gems")) + thread_specs.append(("gem_kind", f"gem:{gem_kind}", f"{gem_kind.title()} gems")) + if gem_kind == "preference": + thread_specs.append(("preference", "preference", "User preference signals")) + elif gem_kind == "decision": + thread_specs.append(("decision", "decision", "Decisions")) + elif gem_kind == "learning": + thread_specs.append(("learning", "learning", "Learning")) + elif gem_kind == "task": + thread_specs.append(("task", "task", "Tasks")) + elif gem_kind == "artifact": + thread_specs.append(("artifact", "artifact", "Artifacts")) + source_memory_id = str(event.metadata.get("source_memory_id") or "").strip() + if source_memory_id: + thread_specs.append(("source_memory", f"memory:{source_memory_id}", f"Source memory {source_memory_id[:8]}")) + seen: set[str] = set() + node_count = 0 + edge_count = 0 + for thread_type, thread_key, title in thread_specs: + node_id = f"thread:{thread_key}" + if node_id in seen: + continue + seen.add(node_id) + self._insert_memory_thread( + conn, + node_id, + user_id=event.user_id, + privacy_scope=event.privacy_scope, + sqlite_id=event.id, + thread_type=thread_type, + title=title, + updated_at=event.timestamp, + ) + self._insert_rel( + conn, + "PROJECTED_INTO", + "RawEvent", + "MemoryThread", + event.id, + node_id, + rel_id=f"projected:{event.id}:{node_id}", + user_id=event.user_id, + privacy_scope=event.privacy_scope, + confidence=1.0, + status="observed", + evidence=[event.id], + explanation="Thread projection was deterministically derived from 
RawEvent metadata/source/type.", + source_sqlite_id=event.id, + created_at=event.timestamp, + ) + node_count += 1 + edge_count += 1 + return node_count, edge_count + + def _insert_memory_thread( + self, + conn: kuzu.Connection, + node_id: str, + *, + user_id: str, + privacy_scope: str, + sqlite_id: str, + thread_type: str, + title: str, + updated_at: str, + ) -> None: + metadata = { + "thread_type": thread_type, + "title": title, + "status": "active", + "summary": title, + "updated_at": updated_at, + } + conn.execute( + """ + MERGE (n:MemoryThread {id: $id}) + SET n.schema_version = $schema_version, + n.projection_version = $projection_version, + n.sqlite_id = $sqlite_id, + n.user_id = $user_id, + n.privacy_scope = $privacy_scope, + n.deleted_at = '', + n.redacted_at = '', + n.redaction_reason = '', + n.metadata_json = $metadata_json, + n.thread_type = $thread_type, + n.title = $title, + n.status = 'active', + n.summary = $summary, + n.updated_at = $updated_at + """, + _clean_params( + { + "id": node_id, + "schema_version": CAUSAL_SCHEMA_VERSION, + "projection_version": CAUSAL_PROJECTION_VERSION, + "sqlite_id": sqlite_id, + "user_id": user_id, + "privacy_scope": privacy_scope, + "metadata_json": json.dumps(metadata, sort_keys=True), + "thread_type": thread_type, + "title": title, + "summary": title, + "updated_at": updated_at, + } + ), + ) + + def _insert_rel( + self, + conn: kuzu.Connection, + rel: str, + source_table: str, + target_table: str, + source_id: str, + target_id: str, + *, + rel_id: str, + user_id: str, + privacy_scope: str, + confidence: float, + status: str, + evidence: Sequence[str], + explanation: str, + source_sqlite_id: str, + created_at: str, + ) -> None: + conn.execute( + f""" + MATCH (source:{source_table} {{id: $source_id}}), (target:{target_table} {{id: $target_id}}) + CREATE (source)-[:{rel} {{ + id: $id, + schema_version: $schema_version, + projection_version: $projection_version, + sqlite_id: $sqlite_id, + source_sqlite_id: 
$source_sqlite_id, + user_id: $user_id, + privacy_scope: $privacy_scope, + confidence: $confidence, + status: $status, + evidence_json: $evidence_json, + explanation: $explanation, + created_at: $created_at + }}]->(target) + """, + _clean_params( + { + "source_id": source_id, + "target_id": target_id, + "id": rel_id, + "schema_version": CAUSAL_SCHEMA_VERSION, + "projection_version": CAUSAL_PROJECTION_VERSION, + "sqlite_id": rel_id, + "source_sqlite_id": source_sqlite_id, + "user_id": user_id, + "privacy_scope": privacy_scope, + "confidence": float(confidence), + "status": status, + "evidence_json": json.dumps(list(evidence)), + "explanation": explanation, + "created_at": created_at, + } + ), + ) + + def _rows(self, conn: kuzu.Connection, query: str, params: Optional[Dict[str, Any]] = None) -> List[List[Any]]: + result = conn.execute(query, _clean_params(params or {})) + return result.get_all() + + def _single_column(self, conn: kuzu.Connection, query: str, params: Optional[Dict[str, Any]] = None) -> List[Any]: + return [row[0] for row in self._rows(conn, query, params)] + + def _projection_counts(self, conn: kuzu.Connection) -> Dict[str, int]: + counts: Dict[str, int] = {} + for table in [ + "RawEvent", + "EventFrame", + "MemoryThread", + "Scene", + "Episode", + "Surface", + "Entity", + "Artifact", + "Memory", + ]: + counts[table] = self._count(conn, f"MATCH (n:{table}) RETURN count(n)") + counts["automatic_edges"] = sum( + self._count(conn, f"MATCH ()-[r:{rel}]->() RETURN count(r)") + for rel in AUTOMATIC_REL_TABLES + ) + counts["causal_edges"] = sum( + self._count(conn, f"MATCH ()-[r:{rel}]->() RETURN count(r)") + for rel in CHECKPOINT_CAUSAL_EDGE_TYPES + ) + return counts + + def _count(self, conn: kuzu.Connection, query: str) -> int: + rows = self._rows(conn, query) + return int(rows[0][0] or 0) if rows else 0 + + def _orphan_errors(self, conn: kuzu.Connection) -> List[str]: + errors: List[str] = [] + for rel, (source, target) in AUTOMATIC_REL_TABLES.items(): + 
try: + self._rows(conn, f"MATCH (a:{source})-[r:{rel}]->(b:{target}) RETURN r.id LIMIT 1") + except Exception as exc: + errors.append(f"{rel} orphan/schema check failed: {exc}") + for rel in CHECKPOINT_CAUSAL_EDGE_TYPES: + try: + self._rows(conn, f"MATCH (a:EventFrame)-[r:{rel}]->(b:EventFrame) RETURN r.id LIMIT 1") + except Exception as exc: + errors.append(f"{rel} orphan/schema check failed: {exc}") + return errors + + def _event_relations(self, conn: kuzu.Connection, event_id: str, direction: str) -> List[Dict[str, Any]]: + relations: List[Dict[str, Any]] = [] + for rel, (source, target) in AUTOMATIC_REL_TABLES.items(): + if source != "RawEvent" and target != "RawEvent": + continue + if direction == "incoming": + query = f"MATCH (a:{source})-[r:{rel}]->(b:{target} {{id: $event_id}}) RETURN a.id, r.id, r.confidence, r.evidence_json, r.explanation" + else: + query = f"MATCH (a:{source} {{id: $event_id}})-[r:{rel}]->(b:{target}) RETURN b.id, r.id, r.confidence, r.evidence_json, r.explanation" + for row in self._rows(conn, query, {"event_id": event_id}): + relations.append( + { + "edge_type": rel, + "other_id": row[0], + "edge_id": row[1], + "confidence": row[2], + "evidence": _loads_list(row[3]), + "explanation": row[4], + } + ) + return relations + + def _temporal_relations(self, conn: kuzu.Connection, event_id: str, direction: str) -> List[Dict[str, Any]]: + if direction == "forward": + rows = self._rows( + conn, + """ + MATCH (a:RawEvent {id: $event_id})-[r:TEMPORAL_NEXT]->(b:RawEvent) + RETURN a.id, b.id, r.id, r.confidence, r.evidence_json, r.explanation + """, + {"event_id": event_id}, + ) + else: + rows = self._rows( + conn, + """ + MATCH (a:RawEvent)-[r:TEMPORAL_NEXT]->(b:RawEvent {id: $event_id}) + RETURN a.id, b.id, r.id, r.confidence, r.evidence_json, r.explanation + """, + {"event_id": event_id}, + ) + return [ + { + "from": row[0], + "to": row[1], + "edge_id": row[2], + "edge_type": "TEMPORAL_NEXT", + "confidence": row[3], + "evidence": 
_loads_list(row[4]), + "explanation": row[5], + } + for row in rows + ] + + def _threads_for_event(self, conn: kuzu.Connection, event_id: str) -> List[Dict[str, Any]]: + rows = self._rows( + conn, + """ + MATCH (e:RawEvent {id: $event_id})-[r:PROJECTED_INTO]->(t:MemoryThread) + RETURN t.id, t.metadata_json, r.evidence_json + """, + {"event_id": event_id}, + ) + return [ + { + "thread_id": row[0], + **_loads_dict(row[1]), + "evidence": _loads_list(row[2]), + } + for row in rows + ] + + def _thread_evidence(self, conn: kuzu.Connection, thread_id: str, allowed_scopes: Sequence[str]) -> List[Dict[str, Any]]: + scope_clause, scope_params = _scope_where("e", allowed_scopes) + rows = self._rows( + conn, + f""" + MATCH (e:RawEvent)-[r:PROJECTED_INTO]->(t:MemoryThread) + WHERE t.id = $thread_id + {scope_clause} + RETURN e.id, e.privacy_scope, r.evidence_json + ORDER BY e.timestamp DESC + LIMIT 5 + """, + {"thread_id": thread_id, **scope_params}, + ) + return [ + {"event_id": row[0], "evidence": _loads_list(row[2])} + for row in rows + ] + + def _find_event_for_query(self, *, query: str, user_id: str, scope: str) -> str: + conn = self._connect() + try: + allowed = _allowed_scopes(scope) + scope_clause, scope_params = _scope_where("e", allowed) + rows = self._rows( + conn, + f""" + MATCH (e:RawEvent) + WHERE e.user_id = $user_id + {scope_clause} + RETURN e.id, e.event_type, e.source_app, e.metadata_json, e.privacy_scope + ORDER BY e.timestamp DESC + LIMIT 100 + """, + {"user_id": user_id, **scope_params}, + ) + q = query.lower() + for row in rows: + haystack = " ".join([str(row[1]), str(row[2]), row[3] or ""]).lower() + if q and q in haystack: + return str(row[0]) + return str(rows[0][0]) if rows else "" + finally: + self._close(conn) + + +def _clean_params(params: Dict[str, Any]) -> Dict[str, Any]: + clean: Dict[str, Any] = {} + for key, value in params.items(): + if value is None: + clean[key] = "" + elif isinstance(value, bool): + clean[key] = bool(value) + elif 
isinstance(value, (int, float, str)): + clean[key] = value + else: + clean[key] = json.dumps(value, sort_keys=True) + return clean + + +def _rank_preference_gem(gem: Dict[str, Any], query: str) -> float: + score = _float(gem.get("score"), 0.0) + score += _float(gem.get("confidence"), 0.0) * 0.25 + tokens = _query_tokens(query) + if not tokens: + return score + + haystack = " ".join( + [ + str(gem.get("title") or ""), + str(gem.get("summary") or ""), + str(gem.get("source_memory_id") or ""), + str(gem.get("content_ref") or ""), + ] + ).lower() + phrase = str(query or "").strip().lower() + if phrase and phrase in haystack: + score += 0.75 + matches = sum(1 for token in tokens if token in haystack) + score += matches / max(len(tokens), 1) + return score + + +def _preference_signal_from_gem(gem: Dict[str, Any]) -> Dict[str, Any]: + return { + "signal_type": "memory_gem", + "thread_type": "preference", + "event_id": gem.get("event_id"), + "gem_id": gem.get("gem_id"), + "title": gem.get("title") or "", + "summary": gem.get("summary") or "", + "score": gem.get("score"), + "confidence": gem.get("confidence"), + "source_memory_id": gem.get("source_memory_id"), + "content_ref": gem.get("content_ref"), + "privacy_scope": gem.get("privacy_scope"), + "evidence": gem.get("evidence") or [], + } + + +def _supporting_event_from_gem(gem: Dict[str, Any]) -> Dict[str, Any]: + return { + "event_id": gem.get("event_id"), + "source_memory_id": gem.get("source_memory_id"), + "content_ref": gem.get("content_ref"), + "evidence": gem.get("evidence") or [], + } + + +def _normalize_gem_event_id(target: str) -> str: + value = str(target or "").strip() + if value.startswith("gem:"): + return value + return f"gem:{value}" + + +def _query_tokens(query: str) -> List[str]: + return [ + token + for token in re.findall(r"[a-z0-9][a-z0-9_-]*", str(query or "").lower()) + if len(token) > 2 + ][:12] + + +def _float(value: Any, default: float) -> float: + try: + return float(value) + except (TypeError, 
ValueError): + return float(default) + + +def _loads_dict(raw: Any) -> Dict[str, Any]: + try: + value = json.loads(raw or "{}") + except Exception: + return {} + return value if isinstance(value, dict) else {} + + +def _loads_list(raw: Any) -> List[Any]: + try: + value = json.loads(raw or "[]") + except Exception: + return [] + return value if isinstance(value, list) else [] + + +def _metadata_list(metadata: Dict[str, Any], key: str) -> List[str]: + value = metadata.get(key) + if isinstance(value, list): + return [str(item).strip() for item in value if str(item).strip()] + if isinstance(value, str) and value.strip(): + return [value.strip()] + return [] + + +def _gem_kind(event: RawEvent) -> str: + if event.source_app != "memory-gem" and not str(event.event_type or "").startswith("gem_"): + return "" + kind = str(event.metadata.get("kind") or "").strip().lower() + if not kind and str(event.event_type or "").startswith("gem_"): + kind = str(event.event_type).replace("gem_", "", 1).strip().lower() + return kind + + +def _slug(value: str) -> str: + return "".join(ch.lower() if ch.isalnum() else "-" for ch in str(value)).strip("-") or "unknown" + + +def _strictest_scope(scopes: Iterable[str]) -> str: + order = ["public", "project", "connector", "global", "private"] + ranked = {scope: idx for idx, scope in enumerate(order)} + return max((str(scope or "global") for scope in scopes), key=lambda item: ranked.get(item, 3)) + + +def _scope_where(alias: str, scopes: Sequence[str]) -> Tuple[str, Dict[str, str]]: + safe_alias = "".join(ch for ch in str(alias or "n") if ch.isalnum() or ch == "_") or "n" + params: Dict[str, str] = {} + clauses: List[str] = [] + for idx, scope in enumerate(scopes or ["global"]): + key = f"scope_{idx}" + params[key] = str(scope) + clauses.append(f"{safe_alias}.privacy_scope = ${key}") + return f"AND ({' OR '.join(clauses)})", params + + +def _allowed_scopes(scope: str) -> List[str]: + requested = str(scope or "global") + if requested == "private": 
+ return ["public", "project", "connector", "global", "private"] + if requested == "global": + return ["public", "project", "connector", "global"] + if requested == "connector": + return ["public", "project", "connector"] + if requested == "project": + return ["public", "project"] + return [requested] + + +def _last_event_state(row: Sequence[Any]) -> str: + return f"Last event {row[0]} from {row[2]}:{row[3]} at {row[1]}" + + +def _timeline_summary(timeline: Sequence[Dict[str, Any]]) -> str: + if not timeline: + return "" + first = timeline[0] + last = timeline[-1] + return ( + f"{len(timeline)} event(s) from {first.get('timestamp')} to {last.get('timestamp')}; " + f"latest event type: {last.get('event_type')}." + ) diff --git a/dhee/world_memory/gem_extractor.py b/dhee/world_memory/gem_extractor.py new file mode 100644 index 0000000..aa24b65 --- /dev/null +++ b/dhee/world_memory/gem_extractor.py @@ -0,0 +1,498 @@ +from __future__ import annotations + +import hashlib +import re +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +from dhee.core.learnings import LearningExchange + +from .capture_store import CaptureStore +from .schema import CAUSAL_SCHEMA_VERSION, RawEvent + + +GEM_SCHEMA_VERSION = "memory_gem.v1" + +GEM_KINDS = { + "preference", + "decision", + "learning", + "task", + "artifact", + "context", + "fact", +} + +_PREFERENCE_TERMS = { + "prefer", + "preference", + "like", + "want", + "style", + "tone", + "always", + "never", + "recommended", +} +_DECISION_TERMS = { + "decided", + "decision", + "choose", + "chosen", + "instead", + "architecture", + "invariant", + "law", + "source of truth", +} +_LEARNING_TERMS = { + "learned", + "lesson", + "pattern", + "pitfall", + "works", + "failed", + "fix", + "regression", + "test", + "verify", + "checkpoint", +} +_TASK_TERMS = { + "todo", + "task", + "blocked", + "next", + "follow up", + "implement", + "build", +} +_ARTIFACT_TERMS = 
{"artifact", "file", "document", "pdf", "screenshot", "attachment"} +_PASSIVE_NOISE_PREFIXES = ( + "chotu observed useful visible screen activity", + "edited /", + "opened ", + "viewed ", +) +_NOISE_KINDS = {"file_touched", "artifact_chunk", "test_fixture", "fixture"} +_PRIVACY_MAP = { + "public": "public", + "shareable": "public", + "repo": "project", + "project": "project", + "workspace": "project", + "connector": "connector", + "global": "global", + "work": "global", + "personal": "private", + "private": "private", + "secret": "private", + "restricted": "private", +} + + +@dataclass +class MemoryGem: + id: str + source_memory_id: str + user_id: str + kind: str + title: str + summary: str + score: float + confidence: float + privacy_scope: str + source_app: str = "" + memory_type: str = "" + categories: List[str] = field(default_factory=list) + source_event_id: Optional[str] = None + timestamp: str = "" + evidence: List[Dict[str, Any]] = field(default_factory=list) + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + data = asdict(self) + data["score"] = round(float(self.score), 4) + data["confidence"] = round(float(self.confidence), 4) + return data + + def to_raw_event(self) -> RawEvent: + return RawEvent( + id=f"gem:{self.id}", + schema_version=CAUSAL_SCHEMA_VERSION, + user_id=self.user_id, + source_app="memory-gem", + namespace="memory.gems", + event_type=f"gem_{self.kind}", + timestamp=self.timestamp, + content_ref=f"memory:{self.source_memory_id}", + content_hash=_stable_hash([self.source_memory_id, self.summary]), + privacy_scope=self.privacy_scope, + metadata={ + "gem_schema_version": GEM_SCHEMA_VERSION, + "gem_id": self.id, + "kind": self.kind, + "score": round(float(self.score), 4), + "confidence": round(float(self.confidence), 4), + "title": self.title, + "summary": self.summary, + "source_memory_id": self.source_memory_id, + "source_event_id": self.source_event_id, + "source_app": self.source_app, + 
def extract_memory_gems(
    memories: Iterable[Dict[str, Any]],
    *,
    user_id: str = "default",
    limit: int = 50,
    min_score: float = 0.62,
) -> List[MemoryGem]:
    """Score *memories* and return the best-ranked gems.

    Each memory is passed to ``score_memory_gem`` with *user_id* as the
    default owner. Results that are falsy or score below *min_score* are
    discarded. The rest are sorted in descending order of
    ``(score, confidence, timestamp)`` and truncated to
    ``max(1, int(limit or 50))`` entries, so a falsy *limit* means 50 and the
    cap is never below one.
    """
    scored = (score_memory_gem(memory, default_user_id=user_id) for memory in memories)
    # The threshold is converted lazily, only once a gem has actually been scored.
    kept = [gem for gem in scored if gem and gem.score >= float(min_score)]
    ranked = sorted(
        kept,
        key=lambda gem: (gem.score, gem.confidence, gem.timestamp),
        reverse=True,
    )
    cap = max(1, int(limit or 50))
    return ranked[:cap]
0.03 + if _looks_like_noise(lowered, metadata, categories): + score -= 0.28 + if str(memory.get("tombstone") or "0") not in {"0", "False", "false", ""}: + return None + + score = max(0.0, min(1.0, score)) + confidence = max(0.2, min(1.0, 0.48 + kind_score + _float(memory.get("strength"), 0.5) * 0.22)) + title = _title_for(text, kind) + timestamp = str(memory.get("updated_at") or memory.get("created_at") or metadata.get("event_time") or "") + privacy_scope = _privacy_scope(memory, metadata) + evidence = [ + { + "kind": "memory", + "memory_id": source_memory_id, + "source_event_id": memory.get("source_event_id") or metadata.get("source_event_id"), + "source_app": source_app, + "memory_type": memory_type, + "categories": categories, + "strength": memory.get("strength"), + "importance": memory.get("importance") or metadata.get("importance"), + } + ] + gem_id = "memgem_" + _stable_hash([source_memory_id, kind, title])[:16] + return MemoryGem( + id=gem_id, + source_memory_id=source_memory_id, + user_id=str(memory.get("user_id") or metadata.get("user_id") or default_user_id), + kind=kind, + title=title, + summary=_clip(text, 900), + score=score, + confidence=confidence, + privacy_scope=privacy_scope, + source_app=source_app, + memory_type=memory_type, + categories=categories, + source_event_id=memory.get("source_event_id") or metadata.get("source_event_id"), + timestamp=timestamp, + evidence=evidence, + metadata={ + "extractor": "deterministic", + "source_namespace": memory.get("namespace") or metadata.get("namespace"), + }, + ) + + +def write_gem_raw_events( + capture_store: CaptureStore, + gems: Sequence[MemoryGem], + *, + overwrite_existing: bool = False, +) -> Dict[str, Any]: + written: List[str] = [] + skipped: List[str] = [] + for gem in gems: + event = gem.to_raw_event() + if capture_store.get_raw_event(event.id): + if not overwrite_existing: + skipped.append(event.id) + continue + try: + capture_store.record_raw_event(event) + written.append(event.id) + except 
def submit_gem_learning_candidates(
    exchange: LearningExchange,
    gems: Sequence[MemoryGem],
    *,
    repo: Optional[str] = None,
    source_agent_id: str = "memory-gem-extractor",
    status: str = "candidate",
) -> Dict[str, Any]:
    """Submit promotable gems to *exchange* as learning candidates.

    Only gems whose kind is ``learning``, ``preference`` or ``decision`` are
    submitted; other kinds are skipped without being reported. Submission is
    best-effort per gem: any exception raised while building or sending one
    request is recorded and the loop carries on with the next gem.

    Returns:
        A dict with ``submitted`` (ids of the candidates returned by
        ``exchange.submit``) and ``rejected`` (one ``{"gem_id", "reason"}``
        entry per failed gem).
    """
    accepted_ids: List[str] = []
    failures: List[Dict[str, Any]] = []
    for gem in gems:
        if gem.kind in {"learning", "preference", "decision"}:
            try:
                request = {
                    "title": gem.title,
                    "body": gem.summary,
                    "kind": _learning_kind(gem.kind),
                    "source_agent_id": source_agent_id,
                    "source_harness": "dhee",
                    "task_type": f"memory_gem_{gem.kind}",
                    "repo": repo,
                    "scope": "personal",
                    "confidence": gem.confidence,
                    "utility": gem.score,
                    "evidence": gem.evidence,
                    "metadata": {
                        "gem_id": gem.id,
                        "source_memory_id": gem.source_memory_id,
                        "privacy_scope": gem.privacy_scope,
                        "schema_version": GEM_SCHEMA_VERSION,
                    },
                    "status": status,
                    # The learning id is built from the last 16 characters of the gem id.
                    "learning_id": "lrn_" + gem.id[-16:],
                }
                accepted_ids.append(exchange.submit(**request).id)
            except Exception as error:
                failures.append({"gem_id": gem.id, "reason": str(error)})
    return {"submitted": accepted_ids, "rejected": failures}
def summarize_gems(gems: Sequence[MemoryGem]) -> Dict[str, Any]:
    """Build an aggregate report for *gems*.

    Returns:
        A dict with ``count`` (number of gems), ``by_kind`` and
        ``by_privacy_scope`` (occurrence counts keyed by ``gem.kind`` and
        ``gem.privacy_scope``, in first-seen order) and ``top`` (the
        ``to_dict()`` form of the first ten gems in the given order).
    """

    def tally(labels):
        # Count occurrences while preserving the order labels first appear in.
        totals: Dict[str, int] = {}
        for label in labels:
            if label in totals:
                totals[label] += 1
            else:
                totals[label] = 1
        return totals

    kind_totals = tally(gem.kind for gem in gems)
    scope_totals = tally(gem.privacy_scope for gem in gems)
    leading = [gem.to_dict() for gem in gems[:10]]
    return {
        "count": len(gems),
        "by_kind": kind_totals,
        "by_privacy_scope": scope_totals,
        "top": leading,
    }
categories: Sequence[str], + memory_type: str, +) -> Tuple[str, float]: + haystack = " ".join([text, " ".join(categories), memory_type, str(metadata.get("memory_type") or "")]).lower() + scores = { + "preference": _term_score(haystack, _PREFERENCE_TERMS, 0.24), + "decision": _term_score(haystack, _DECISION_TERMS, 0.23), + "learning": _term_score(haystack, _LEARNING_TERMS, 0.24), + "task": _term_score(haystack, _TASK_TERMS, 0.18), + "artifact": _term_score(haystack, _ARTIFACT_TERMS, 0.16), + } + if memory_type == "task": + scores["task"] += 0.12 + if metadata.get("source_event_id"): + scores["context"] = 0.08 + kind = max(scores, key=scores.get) + if scores[kind] <= 0.02: + return "fact", 0.06 + return kind if kind in GEM_KINDS else "fact", scores[kind] + + +def _term_score(text: str, terms: Sequence[str], cap: float) -> float: + hits = sum(1 for term in terms if term in text) + return min(cap, hits * (cap / 4.0)) + + +def _looks_like_noise(text: str, metadata: Dict[str, Any], categories: Sequence[str]) -> bool: + if any(text.startswith(prefix) for prefix in _PASSIVE_NOISE_PREFIXES): + useful_terms = _PREFERENCE_TERMS | _DECISION_TERMS | _LEARNING_TERMS | _TASK_TERMS + return not any(term in text for term in useful_terms) + if str(metadata.get("kind") or "").lower() in _NOISE_KINDS: + return True + if "artifact_chunk" in {str(item).lower() for item in categories}: + return True + return False + + +def _memory_text(memory: Dict[str, Any]) -> str: + return str(memory.get("memory") or memory.get("content") or memory.get("text") or "").strip() + + +def _title_for(text: str, kind: str) -> str: + first = re.split(r"[\n.?!]", text.strip(), maxsplit=1)[0].strip() + if not first: + first = text.strip() + return f"{kind.title()}: {_clip(first, 92)}" + + +def _privacy_scope(memory: Dict[str, Any], metadata: Dict[str, Any]) -> str: + raw = str( + metadata.get("privacy_scope") + or metadata.get("scope") + or memory.get("confidentiality_scope") + or 
metadata.get("confidentiality_scope") + or "global" + ).strip().lower() + return _PRIVACY_MAP.get(raw, "global") + + +def _learning_kind(gem_kind: str) -> str: + return { + "preference": "policy", + "decision": "policy", + "learning": "heuristic", + }.get(gem_kind, "memory") + + +def _clip(text: str, max_chars: int) -> str: + value = " ".join(str(text or "").split()) + if len(value) <= max_chars: + return value + return value[: max_chars - 1].rstrip() + "..." + + +def _dict(value: Any) -> Dict[str, Any]: + return value if isinstance(value, dict) else {} + + +def _list(value: Any) -> List[str]: + if isinstance(value, list): + return [str(item) for item in value if str(item).strip()] + if isinstance(value, str) and value.strip(): + return [value.strip()] + return [] + + +def _float(value: Any, default: float) -> float: + try: + return float(value) + except (TypeError, ValueError): + return default + + +def _stable_hash(parts: Any) -> str: + return hashlib.sha256(repr(parts).encode("utf-8")).hexdigest() diff --git a/dhee/world_memory/schema.py b/dhee/world_memory/schema.py index 40ead8f..8243337 100644 --- a/dhee/world_memory/schema.py +++ b/dhee/world_memory/schema.py @@ -4,6 +4,135 @@ from typing import Any, Dict, List, Optional +CAUSAL_SCHEMA_VERSION = "csm.v1" +CAUSAL_PROJECTION_VERSION = "csm.projection.v1" + +AUTOMATIC_CAUSAL_EDGE_TYPES = { + "TEMPORAL_NEXT", + "OBSERVED_ON", + "MENTIONS", + "CREATED", + "UPDATED", + "BELONGS_TO", + "PROJECTED_INTO", +} + +CHECKPOINT_CAUSAL_EDGE_TYPES = { + "CAUSED", + "ENABLED", + "BLOCKED", + "RESOLVED", + "CONTRADICTED", + "REFINED", + "SUPPORTED", + "INVALIDATED", + "DERIVED_FACT", + "SKILL_EXTRACTED_FROM", +} + +CAUSAL_EDGE_STATUSES = { + "inferred", + "verified", + "rejected", + "stale", + "superseded", + "archived", +} + + +@dataclass +class RawEvent: + id: str + user_id: str + source_app: str + event_type: str + timestamp: str + schema_version: str = CAUSAL_SCHEMA_VERSION + session_id: Optional[str] = None + namespace: 
@dataclass
class EventFrame:
    """A summarised frame derived from one or more raw events.

    The service's checkpoint compiler creates one frame per raw event with
    the id ``frame:<raw event id>``.
    """

    id: str
    user_id: str
    frame_type: str
    summary: str
    # Ids of the raw events this frame was built from.
    source_event_ids: List[str]
    confidence: float
    created_at: str
    schema_version: str = CAUSAL_SCHEMA_VERSION
    privacy_scope: str = "global"
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Soft-delete / redaction markers; None while the frame is active.
    deleted_at: Optional[str] = None
    redacted_at: Optional[str] = None
    redaction_reason: Optional[str] = None


@dataclass
class CausalEdge:
    """A typed, evidence-backed relation between two nodes of the causal scene."""

    id: str
    source_id: str
    target_id: str
    # NOTE(review): presumably one of AUTOMATIC_CAUSAL_EDGE_TYPES or
    # CHECKPOINT_CAUSAL_EDGE_TYPES - confirm where this is validated.
    edge_type: str
    confidence: float
    # NOTE(review): presumably one of CAUSAL_EDGE_STATUSES - confirm.
    status: str
    # Raw event ids backing the edge. The test suite expects the capture
    # store to reject a CAUSED edge whose evidence list is empty.
    evidence_event_ids: List[str]
    inferred_by: str
    explanation: str
    created_at: str
    schema_version: str = CAUSAL_SCHEMA_VERSION
    user_id: str = "default"
    privacy_scope: str = "global"
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Soft-delete / redaction markers; None while the edge is active.
    deleted_at: Optional[str] = None
    redacted_at: Optional[str] = None
    redaction_reason: Optional[str] = None


@dataclass
class CheckpointReport:
    """Record of one checkpoint compilation: what it covered and what it produced."""

    id: str
    user_id: str
    status: str
    # Free-form report payload (the service stores counts and a summary here).
    report: Dict[str, Any]
    created_at: str
    schema_version: str = CAUSAL_SCHEMA_VERSION
    session_id: Optional[str] = None
    # Optional bounds of the time window the checkpoint was compiled for.
    time_window_start: Optional[str] = None
    time_window_end: Optional[str] = None
    # Ids of the frames and edges written by this checkpoint.
    event_frame_ids: List[str] = field(default_factory=list)
    causal_edge_ids: List[str] = field(default_factory=list)
    # Id of the summary memory written for the checkpoint, if any.
    summary_memory_id: Optional[str] = None
deleted_at: Optional[str] = None + redacted_at: Optional[str] = None + redaction_reason: Optional[str] = None + + @dataclass class WorldState: id: str diff --git a/dhee/world_memory/service.py b/dhee/world_memory/service.py index ff0db6a..8c0766e 100644 --- a/dhee/world_memory/service.py +++ b/dhee/world_memory/service.py @@ -18,15 +18,22 @@ BeautifulSoup = None from .capture_store import CaptureStore +from .causal_graph import CausalGraphProjection from .encoder import DeterministicFrameEncoder, create_default_encoder from .predictor import ActionConditionedPredictor, compute_surprise from .schema import ( + CAUSAL_SCHEMA_VERSION, + CausalEdge, CaptureAction, CaptureEvent, CaptureLink, CapturedArtifact, CapturedObservation, CapturedSurface, + CheckpointReport, + EventFrame, + RawEvent, + RetrievalTrace, ) from .session_graph import SessionGraphStore from .store import WorldMemoryStore, asdict_chunk @@ -139,6 +146,7 @@ def __init__( world_store: WorldMemoryStore, graph_store: SessionGraphStore, memory_client: Any, + graph_projection: CausalGraphProjection | None = None, encoder: Any | None = None, predictor: Any | None = None, ): @@ -146,6 +154,7 @@ def __init__( self.world_store = world_store self.graph_store = graph_store self.memory_client = memory_client + self.graph_projection = graph_projection self.encoder = encoder or create_default_encoder() self.predictor = predictor or ActionConditionedPredictor() self._ensure_default_policies() @@ -161,6 +170,7 @@ def from_default_runtime(cls, *, memory: Any, data_dir: Optional[str] = None) -> world_store=WorldMemoryStore(str(memory_os_dir / "world_memory.db")), graph_store=SessionGraphStore(str(runtime_root / "capture" / "sessions")), memory_client=DheeMemoryClient(memory), + graph_projection=CausalGraphProjection(str(memory_os_dir / "causal_scene.kuzu")), ) def _ensure_default_policies(self) -> None: @@ -357,6 +367,7 @@ def record_action(self, payload: Dict[str, Any]) -> Dict[str, Any]: metadata=metadata, ) ) + 
self._record_raw_from_capture_event(event) return { "action": asdict(action), "surface": asdict(surface), @@ -439,6 +450,7 @@ def record_observation(self, payload: Dict[str, Any]) -> Dict[str, Any]: }, ) ) + self._record_raw_from_capture_event(event) return { "observation": asdict(observation), "surface": asdict(surface), @@ -492,7 +504,7 @@ def record_artifact(self, payload: Dict[str, Any]) -> Dict[str, Any]: active_surface_id=surface.id, artifact_bytes=self.graph_store.artifact_bytes(session.id), ) - self.capture_store.record_event( + event = self.capture_store.record_event( CaptureEvent( id=str(uuid.uuid4()), session_id=session.id, @@ -518,6 +530,7 @@ def record_artifact(self, payload: Dict[str, Any]) -> Dict[str, Any]: metadata=artifact.metadata, ) ) + self._record_raw_from_capture_event(event) if artifact.action_id: self.graph_store.append_link( CaptureLink( @@ -550,6 +563,309 @@ def cleanup_expired_artifacts(self) -> Dict[str, Any]: self.graph_store.patch_manifest(session_id, artifact_bytes=self.graph_store.artifact_bytes(session_id)) return {"checked": checked, "removed": removed} + def record_raw_event(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """Write a RawEvent to SQLite truth first, then rebuild-backed Kuzu projection.""" + metadata = dict(payload.get("metadata") or {}) + text = str(payload.get("text") or payload.get("text_payload") or "") + content_hash = str(payload.get("content_hash") or "") + if not content_hash and text: + content_hash = hashlib.sha256(text.encode("utf-8")).hexdigest() + event = RawEvent( + id=str(payload.get("id") or uuid.uuid4()), + schema_version=str(payload.get("schema_version") or CAUSAL_SCHEMA_VERSION), + user_id=str(payload.get("user_id") or "default"), + session_id=str(payload.get("session_id") or "") or None, + source_app=str(payload.get("source_app") or "unknown").strip().lower(), + namespace=str(payload.get("namespace") or ""), + event_type=str(payload.get("event_type") or "raw_event"), + 
timestamp=str(payload.get("timestamp") or payload.get("created_at") or _now_iso()), + content_ref=str(payload.get("content_ref") or "") or None, + content_hash=content_hash or None, + privacy_scope=str(payload.get("privacy_scope") or payload.get("scope") or "global"), + metadata=metadata, + deleted_at=str(payload.get("deleted_at") or "") or None, + redacted_at=str(payload.get("redacted_at") or "") or None, + redaction_reason=str(payload.get("redaction_reason") or "") or None, + ) + self.capture_store.record_raw_event(event) + projection = self._sync_causal_projection(user_id=event.user_id) + return {"event": asdict(event), "projection": projection} + + def compile_causal_checkpoint( + self, + *, + session_id: Optional[str] = None, + user_id: str = "default", + time_window_start: Optional[str] = None, + time_window_end: Optional[str] = None, + ) -> Dict[str, Any]: + raw_events = self.capture_store.list_raw_events( + user_id=user_id, + session_id=session_id, + limit=1000, + include_deleted=False, + include_redacted=False, + order="asc", + ) + if time_window_start: + raw_events = [event for event in raw_events if event.timestamp >= time_window_start] + if time_window_end: + raw_events = [event for event in raw_events if event.timestamp <= time_window_end] + + frames: List[EventFrame] = [] + for event in raw_events: + frame = EventFrame( + id=f"frame:{event.id}", + schema_version=CAUSAL_SCHEMA_VERSION, + user_id=event.user_id, + frame_type=event.event_type or "event", + summary=_raw_event_summary(event), + source_event_ids=[event.id], + confidence=0.8, + privacy_scope=event.privacy_scope, + created_at=_now_iso(), + metadata={ + "source_app": event.source_app, + "session_id": event.session_id, + "source_event_id": event.id, + }, + ) + self.capture_store.add_event_frame(frame) + frames.append(frame) + + edges: List[CausalEdge] = [] + for previous, current in zip(frames, frames[1:]): + prev_event_id = previous.source_event_ids[0] + current_event_id = 
current.source_event_ids[0] + edge = CausalEdge( + id=f"edge:supported:{previous.id}:{current.id}", + schema_version=CAUSAL_SCHEMA_VERSION, + user_id=user_id, + source_id=previous.id, + target_id=current.id, + edge_type="SUPPORTED", + confidence=0.55, + status="inferred", + evidence_event_ids=[prev_event_id, current_event_id], + inferred_by="rule", + explanation="Adjacent checkpoint frames support the local scene progression; no causal claim is made.", + privacy_scope=_strictest_privacy([previous.privacy_scope, current.privacy_scope]), + created_at=_now_iso(), + metadata={"checkpoint_v": "0"}, + ) + self.capture_store.add_causal_edge(edge) + edges.append(edge) + + summary_text = _checkpoint_summary(raw_events) + summary_memory_id = None + if raw_events and summary_text: + remembered = self.memory_client.remember( + summary_text, + user_id=user_id, + namespace=(raw_events[0].namespace if raw_events else "default"), + source_app="causal-checkpoint", + scope=_strictest_privacy([event.privacy_scope for event in raw_events]), + categories=["causal_scene_memory", "checkpoint_summary"], + metadata={ + "memory_type": "causal_checkpoint_summary", + "session_id": session_id, + "source_event_ids": [event.id for event in raw_events], + "schema_version": CAUSAL_SCHEMA_VERSION, + }, + ) + summary_memory_id = remembered.get("id") if isinstance(remembered, dict) else None + + report = CheckpointReport( + id=str(uuid.uuid4()), + schema_version=CAUSAL_SCHEMA_VERSION, + user_id=user_id, + session_id=session_id, + time_window_start=time_window_start, + time_window_end=time_window_end, + status="completed", + event_frame_ids=[frame.id for frame in frames], + causal_edge_ids=[edge.id for edge in edges], + summary_memory_id=summary_memory_id, + report={ + "raw_event_count": len(raw_events), + "event_frame_count": len(frames), + "causal_edge_count": len(edges), + "summary": summary_text, + }, + created_at=_now_iso(), + ) + self.capture_store.add_checkpoint_report(report) + projection = 
self._sync_causal_projection(user_id=user_id) + return { + "report": asdict(report), + "eventFrames": [asdict(frame) for frame in frames], + "causalEdges": [asdict(edge) for edge in edges], + "projection": projection, + } + + def get_active_frontier(self, *, user_id: str = "default", scope: str = "global") -> Dict[str, Any]: + projection = self._require_causal_projection() + result = projection.get_active_frontier(user_id=user_id, scope=scope) + return self._record_retrieval_trace(result, mode="frontier", user_id=user_id, scope=scope) + + def causal_why( + self, + *, + event_id: Optional[str] = None, + query: str = "", + user_id: str = "default", + scope: str = "global", + ) -> Dict[str, Any]: + projection = self._require_causal_projection() + result = projection.why(event_id=event_id, query=query, user_id=user_id, scope=scope) + return self._record_retrieval_trace( + result, + mode="why", + user_id=user_id, + scope=scope, + query=query, + target_id=event_id or result.get("target_event_id", ""), + ) + + def causal_what_happened(self, *, target_id: str, user_id: str = "default", scope: str = "global") -> Dict[str, Any]: + projection = self._require_causal_projection() + result = projection.what_happened(target_id=target_id, user_id=user_id, scope=scope) + return self._record_retrieval_trace( + result, + mode="what_happened", + user_id=user_id, + scope=scope, + target_id=target_id, + ) + + def causal_handoff(self, *, user_id: str = "default", scope: str = "global") -> Dict[str, Any]: + projection = self._require_causal_projection() + result = projection.handoff(user_id=user_id, scope=scope) + return self._record_retrieval_trace(result, mode="handoff", user_id=user_id, scope=scope) + + def causal_preference(self, *, query: str = "", user_id: str = "default", scope: str = "global") -> Dict[str, Any]: + projection = self._require_causal_projection() + result = projection.preference(query=query, user_id=user_id, scope=scope) + return self._record_retrieval_trace( + 
result, + mode="preference", + user_id=user_id, + scope=scope, + query=query, + target_id=(result.get("preference_signal") or {}).get("event_id", ""), + ) + + def causal_gems( + self, + *, + user_id: str = "default", + scope: str = "global", + kind: Optional[str] = None, + limit: int = 50, + ) -> Dict[str, Any]: + projection = self._require_causal_projection() + result = projection.list_gems(user_id=user_id, scope=scope, kind=kind, limit=limit) + return self._record_retrieval_trace( + result, + mode="gems", + user_id=user_id, + scope=scope, + query=kind or "", + metadata={"kind": kind, "limit": limit}, + ) + + def causal_show_gem(self, target: str, *, user_id: str = "default", scope: str = "global") -> Dict[str, Any]: + projection = self._require_causal_projection() + result = projection.show_gem(target, user_id=user_id, scope=scope) + return self._record_retrieval_trace( + result, + mode="show_gem", + user_id=user_id, + scope=scope, + target_id=(result.get("gem") or {}).get("event_id", target), + ) + + def explain_causal_retrieval(self, retrieval_id: str, *, user_id: str = "default") -> Dict[str, Any]: + trace = self.capture_store.get_retrieval_trace(retrieval_id) + if not trace or trace.user_id != user_id: + return { + "query_id": retrieval_id, + "status": "not_found", + "traversal": [], + "evidence": [], + } + return { + "query_id": trace.id, + "status": "ok", + "schema_version": trace.schema_version, + "mode": trace.mode, + "scope": trace.scope, + "query": trace.query, + "target_id": trace.target_id, + "privacy_scope": trace.privacy_scope, + "created_at": trace.created_at, + "traversal": trace.retrieval_path, + "evidence": trace.evidence, + "result": trace.result, + "metadata": trace.metadata, + } + + def causal_submit_gem( + self, + target: str, + *, + learning_exchange: Any, + user_id: str = "default", + scope: str = "global", + repo: Optional[str] = None, + status: str = "candidate", + ) -> Dict[str, Any]: + from .gem_extractor import 
submit_projected_gem_learning_candidate + + projected = self.causal_show_gem(target, user_id=user_id, scope=scope) + result = submit_projected_gem_learning_candidate( + learning_exchange, + projected, + repo=repo, + status=status, + ) + result["gem"] = projected + return result + + def _record_retrieval_trace( + self, + result: Dict[str, Any], + *, + mode: str, + user_id: str, + scope: str, + query: str = "", + target_id: str = "", + metadata: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + if not isinstance(result, dict): + return result + clean_result = json.loads(json.dumps(result)) + retrieval_id = "retr_" + uuid.uuid4().hex[:16] + trace = RetrievalTrace( + id=retrieval_id, + user_id=user_id, + mode=mode, + scope=scope, + query=str(query or ""), + target_id=str(target_id or ""), + retrieval_path=_as_dict_list(result.get("retrieval_path")), + evidence=_retrieval_evidence(result), + result=clean_result, + privacy_scope=scope, + metadata=dict(metadata or {}), + created_at=_now_iso(), + ) + self.capture_store.add_retrieval_trace(trace) + out = dict(result) + out["retrieval_id"] = retrieval_id + return out + def record_capture_event(self, payload: Dict[str, Any]) -> Dict[str, Any]: session_id = str(payload.get("session_id") or "").strip() session = self.capture_store.get_session(session_id) @@ -637,6 +953,7 @@ def record_capture_event(self, payload: Dict[str, Any]) -> Dict[str, Any]: metadata=metadata, ) ) + self._record_raw_from_capture_event(event) return { "event": asdict(event), "worldTransition": world_record, @@ -945,6 +1262,50 @@ def timeline(self, *, user_id: str = "default", source_app: Optional[str] = None items.sort(key=lambda item: item.get("timestamp") or "", reverse=True) return {"items": items[:limit]} + def _record_raw_from_capture_event(self, event: CaptureEvent) -> RawEvent: + metadata = { + **event.metadata, + "capture_event_id": event.id, + "memory_id": event.memory_id, + "world_ptr": event.world_ptr, + "window_title": 
event.window_title, + "url": event.url, + "structured_payload": event.structured_payload, + "action_type": event.action_type, + } + text = event.text_payload or _summarize_structured_payload(event.structured_payload) + raw = RawEvent( + id=event.id, + schema_version=CAUSAL_SCHEMA_VERSION, + user_id=event.user_id, + session_id=event.session_id, + source_app=event.source_app, + namespace=event.namespace, + event_type=event.event_type, + timestamp=event.created_at, + content_ref=event.world_ptr or (f"memory:{event.memory_id}" if event.memory_id else None), + content_hash=hashlib.sha256(text.encode("utf-8")).hexdigest() if text else None, + privacy_scope=str(event.metadata.get("privacy_scope") or event.metadata.get("scope") or "global"), + metadata=metadata, + ) + try: + self.capture_store.record_raw_event(raw) + except Exception as exc: + if "UNIQUE constraint failed" not in str(exc): + raise + self._sync_causal_projection(user_id=raw.user_id) + return raw + + def _sync_causal_projection(self, *, user_id: str = "default") -> Optional[Dict[str, Any]]: + if not self.graph_projection: + return None + return self.graph_projection.sync(self.capture_store, user_id=user_id) + + def _require_causal_projection(self) -> CausalGraphProjection: + if not self.graph_projection: + raise RuntimeError("Causal graph projection is not configured") + return self.graph_projection + def _require_session(self, session_id: str) -> Any: session = self.capture_store.get_session(session_id) if not session: @@ -1182,6 +1543,42 @@ def _now_iso() -> str: return datetime.now(timezone.utc).isoformat() +def _raw_event_summary(event: RawEvent) -> str: + metadata = event.metadata or {} + text = str( + metadata.get("text") + or metadata.get("text_payload") + or metadata.get("window_title") + or metadata.get("url") + or metadata.get("action_type") + or "" + ).strip() + if text: + return f"{event.source_app}:{event.event_type} - {text[:180]}" + return f"{event.source_app}:{event.event_type} at 
{event.timestamp}" + + +def _checkpoint_summary(events: List[RawEvent]) -> str: + if not events: + return "" + first = events[0] + last = events[-1] + apps = sorted({event.source_app for event in events if event.source_app}) + types = sorted({event.event_type for event in events if event.event_type}) + return ( + f"Causal checkpoint over {len(events)} raw event(s) from {first.timestamp} " + f"to {last.timestamp}. Apps: {', '.join(apps) or 'unknown'}. " + f"Event types: {', '.join(types) or 'unknown'}." + ) + + +def _strictest_privacy(scopes: Iterable[str]) -> str: + order = ["public", "project", "connector", "global", "private"] + ranked = {scope: idx for idx, scope in enumerate(order)} + values = [str(scope or "global") for scope in scopes] + return max(values or ["global"], key=lambda item: ranked.get(item, 3)) + + def _coerce_iso(value: Any) -> Optional[datetime]: if value in (None, ""): return None @@ -1589,6 +1986,31 @@ def _collect_focus_targets(matches: List[Any], limit: int = 6) -> List[Dict[str, return targets +def _as_dict_list(value: Any) -> List[Dict[str, Any]]: + if not isinstance(value, list): + return [] + return [dict(item) for item in value if isinstance(item, dict)] + + +def _retrieval_evidence(result: Dict[str, Any]) -> List[Dict[str, Any]]: + evidence: List[Dict[str, Any]] = [] + for key in ("evidence", "supporting_events", "source_events"): + evidence.extend(_as_dict_list(result.get(key))) + for item in _as_dict_list(result.get("causal_path")): + event_id = item.get("from") or item.get("to") or item.get("event_id") + if event_id: + evidence.append({"event_id": event_id, "edge_id": item.get("edge_id"), "edge_type": item.get("edge_type")}) + seen = set() + deduped: List[Dict[str, Any]] = [] + for item in evidence: + key = json.dumps(item, sort_keys=True) + if key in seen: + continue + seen.add(key) + deduped.append(item) + return deduped + + def _build_current_page_skim(chunks: List[Dict[str, Any]], limit: int = 6) -> List[Dict[str, Any]]: skim: 
class FakeMemoryClient:
    """In-memory stand-in for the memory client; keeps remembered rows in ``rows``."""

    def __init__(self) -> None:
        self.rows = []

    def remember(self, content: str, **kwargs):
        """Store *content* as a new row with a sequential ``mem-N`` id and return the row."""
        next_number = len(self.rows) + 1
        row = {
            "id": f"mem-{next_number}",
            "memory": content,
            "metadata": dict(kwargs.get("metadata") or {}),
            "source_app": kwargs.get("source_app"),
        }
        self.rows.append(row)
        return row

    def recall(self, *args, **kwargs):
        """Never finds anything."""
        return []

    def recent(self, *args, **kwargs):
        """Return the newest rows first, up to ``limit`` (default 12)."""
        limit = kwargs.get("limit", 12)
        newest_first = self.rows[::-1]
        return newest_first[:limit]
def test_causal_edge_requires_evidence_for_caused(tmp_path: Path) -> None:
    """A CAUSED edge without evidence event ids must be rejected by the store."""
    store = CaptureStore(str(tmp_path / "capture.db"))
    unsupported_edge = CausalEdge(
        id="edge-1",
        source_id="frame-a",
        target_id="frame-b",
        edge_type="CAUSED",
        confidence=0.8,
        status="inferred",
        evidence_event_ids=[],
        inferred_by="rule",
        explanation="invalid",
        created_at="2026-05-20T00:00:00+00:00",
    )

    with pytest.raises(ValueError, match="CAUSED edges require evidence_event_ids"):
        store.add_causal_edge(unsupported_edge)


def test_checkpoint_creates_event_frames_edges_and_report(tmp_path: Path) -> None:
    """Two session events compile into two frames, one SUPPORTED edge and a summary memory."""
    service, memory = _service(tmp_path)
    seeded_events = (
        ("evt-a", "gmail", "email_received", "2026-05-20T00:00:00+00:00"),
        ("evt-b", "chrome", "task_created", "2026-05-20T00:01:00+00:00"),
    )
    for event_id, source_app, event_type, timestamp in seeded_events:
        service.record_raw_event(
            {
                "id": event_id,
                "user_id": "default",
                "source_app": source_app,
                "event_type": event_type,
                "timestamp": timestamp,
                "session_id": "session-2",
            }
        )

    checkpoint = service.compile_causal_checkpoint(session_id="session-2")

    frame_ids = [frame["id"] for frame in checkpoint["eventFrames"]]
    first_edge = checkpoint["causalEdges"][0]
    assert frame_ids == ["frame:evt-a", "frame:evt-b"]
    assert first_edge["edge_type"] == "SUPPORTED"
    assert first_edge["evidence_event_ids"] == ["evt-a", "evt-b"]
    assert checkpoint["report"]["summary_memory_id"] == "mem-1"
    assert memory.rows[0]["metadata"]["source_event_ids"] == ["evt-a", "evt-b"]
def test_memory_gem_extractor_writes_provenance_raw_events(tmp_path: Path) -> None:
    """Only the preference memory becomes a gem, and writing it twice is idempotent."""
    store = CaptureStore(str(tmp_path / "capture.db"))
    preference_memory = {
        "id": "mem-pref",
        "user_id": "default",
        "memory": "User prefers brutally honest architectural critique with concrete tradeoffs.",
        "strength": 0.9,
        "importance": 0.9,
        "memory_type": "semantic",
        "categories": ["preference"],
        "source_app": "codex",
        "created_at": "2026-05-20T00:00:00+00:00",
        "metadata": {"scope": "personal"},
    }
    noise_memory = {
        "id": "mem-noise",
        "user_id": "default",
        "memory": "Chotu observed useful visible screen activity. App: Chrome",
        "strength": 0.5,
        "importance": 0.2,
        "memory_type": "episodic",
        "categories": [],
        "source_app": "chotu",
        "created_at": "2026-05-20T00:01:00+00:00",
        "metadata": {},
    }

    gems = extract_memory_gems([preference_memory, noise_memory], user_id="default", min_score=0.5)
    assert len(gems) == 1
    gem = gems[0]
    assert gem.kind == "preference"
    assert gem.privacy_scope == "private"

    # Raw events written for gems are keyed "gem:<gem id>".
    event_id = f"gem:{gem.id}"
    first_write = write_gem_raw_events(store, gems)
    assert first_write["written"] == [event_id]
    stored = store.get_raw_event(event_id)
    assert stored is not None
    assert stored.content_ref == "memory:mem-pref"
    assert stored.metadata["source_memory_id"] == "mem-pref"

    second_write = write_gem_raw_events(store, gems)
    assert second_write["written"] == []
    assert second_write["skipped_existing"] == [event_id]
{"scope": "personal"}, + }, + ] + gems = extract_memory_gems(memories, user_id="default", min_score=0.5) + write_gem_raw_events(service.capture_store, gems) + assert service.graph_projection is not None + service.graph_projection.rebuild(service.capture_store) + + product_gem = next(gem for gem in gems if gem.source_memory_id == "mem-pref") + event_id = f"gem:{product_gem.id}" + shown = service.graph_projection.show_event(event_id) + thread_types = {row.get("thread_type") for row in shown["threads"]} + assert {"gem", "gem_kind", "preference", "source_memory"}.issubset(thread_types) + + shown_gem = service.causal_show_gem(event_id, user_id="default", scope="global") + assert shown_gem["status"] == "ok" + assert shown_gem["gem"]["event_id"] == event_id + assert shown_gem["gem"]["projection_version"] + assert shown_gem["source_memory"]["memory_id"] == "mem-pref" + assert shown_gem["supporting_events"][0]["source_memory_id"] == "mem-pref" + assert {"preference", "source_memory"}.issubset({row.get("thread_type") for row in shown_gem["threads"]}) + assert shown_gem["retrieval_path"][0]["step"] == "match_scoped_memory_gem" + + shown_by_gem_id = service.causal_show_gem(product_gem.id, user_id="default", scope="global") + assert shown_by_gem_id["gem"]["event_id"] == event_id + + exchange = LearningExchange(tmp_path / "learnings") + submitted = service.causal_submit_gem( + product_gem.id, + learning_exchange=exchange, + user_id="default", + scope="global", + repo=str(tmp_path), + ) + assert submitted["submitted"] == [f"lrn_{product_gem.id[-16:]}"] + candidate = exchange.get(submitted["submitted"][0]) + assert candidate is not None + assert candidate.kind == "policy" + assert candidate.metadata["gem_id"] == product_gem.id + assert candidate.metadata["source_memory_id"] == "mem-pref" + assert candidate.evidence + + private_gem = next(gem for gem in gems if gem.source_memory_id == "mem-pref-private") + private_hidden = service.causal_show_gem(private_gem.id, user_id="default", 
scope="global") + assert private_hidden["status"] == "not_found" + private_submit = service.causal_submit_gem( + private_gem.id, + learning_exchange=exchange, + user_id="default", + scope="global", + ) + assert private_submit["submitted"] == [] + assert private_submit["rejected"][0]["reason"] == "not_found" + private_visible = service.causal_show_gem(private_gem.id, user_id="default", scope="private") + assert private_visible["status"] == "ok" + + preferences = service.causal_preference(query="product-first architecture", user_id="default", scope="global") + assert preferences["preference_signal"]["signal_type"] == "memory_gem" + assert preferences["preference_signal"]["thread_type"] == "preference" + assert preferences["preference_signal"]["source_memory_id"] == "mem-pref" + assert preferences["supporting_events"][0]["event_id"] == event_id + assert preferences["retrieval_path"][0]["step"] == "list_gems" + assert preferences["retrieval_id"].startswith("retr_") + explained = service.explain_causal_retrieval(preferences["retrieval_id"], user_id="default") + assert explained["status"] == "ok" + assert explained["mode"] == "preference" + assert explained["query"] == "product-first architecture" + assert explained["traversal"][0]["step"] == "list_gems" + assert explained["evidence"][0]["event_id"] == event_id + assert explained["result"]["preference_signal"]["source_memory_id"] == "mem-pref" + assert service.explain_causal_retrieval(preferences["retrieval_id"], user_id="other")["status"] == "not_found" + + frontier = service.get_active_frontier(user_id="default", scope="global") + assert frontier["retrieval_id"].startswith("retr_") + frontier_preferences = frontier["high_confidence_preferences"] + assert frontier_preferences[0]["signal_type"] == "memory_gem" + assert frontier_preferences[0]["thread_type"] == "preference" + assert { + preference["source_memory_id"] for preference in frontier_preferences + } == {"mem-pref", "mem-pref-visual"} + + listed = 
service.causal_gems(user_id="default", scope="global", kind="preference") + assert listed["count"] == 2 + assert listed["by_kind"] == {"preference": 2} + assert {gem["source_memory_id"] for gem in listed["gems"]} == {"mem-pref", "mem-pref-visual"} + + private_listed = service.causal_gems(user_id="default", scope="private", kind="preference") + assert private_listed["count"] == 3 + private_frontier = service.get_active_frontier(user_id="default", scope="private") + assert { + preference["source_memory_id"] for preference in private_frontier["high_confidence_preferences"] + } == {"mem-pref", "mem-pref-visual", "mem-pref-private"} + + service.capture_store.redact_raw_event(event_id, reason="user requested trace redaction") + assert service.explain_causal_retrieval(preferences["retrieval_id"], user_id="default")["status"] == "not_found" + redacted_trace = service.capture_store.get_retrieval_trace( + preferences["retrieval_id"], + include_redacted=True, + ) + assert redacted_trace is not None + assert redacted_trace.redacted_at + assert redacted_trace.redaction_reason == "user requested trace redaction" + + +def test_retrieval_trace_pruning_preserves_latest_and_supports_dry_run(tmp_path: Path) -> None: + store = CaptureStore(str(tmp_path / "capture.db")) + for index in range(5): + store.add_retrieval_trace( + RetrievalTrace( + id=f"retr-{index}", + user_id="default", + mode="preference", + scope="global", + query="", + target_id=f"evt-{index}", + retrieval_path=[], + evidence=[], + result={"index": index}, + privacy_scope="global", + created_at=f"2026-05-20T00:0{index}:00+00:00", + ) + ) + store.add_retrieval_trace( + RetrievalTrace( + id="retr-other", + user_id="other", + mode="preference", + scope="global", + query="", + target_id="evt-other", + retrieval_path=[], + evidence=[], + result={}, + privacy_scope="global", + created_at="2026-05-20T00:10:00+00:00", + ) + ) + + dry_run = store.prune_retrieval_traces(user_id="default", keep_latest=2, dry_run=True) + assert 
dry_run["candidate_count"] == 3 + assert dry_run["pruned_count"] == 0 + assert store.get_retrieval_trace("retr-0") is not None + + pruned = store.prune_retrieval_traces(user_id="default", keep_latest=2) + assert pruned["pruned_count"] == 3 + assert store.get_retrieval_trace("retr-0") is None + assert store.get_retrieval_trace("retr-1") is None + assert store.get_retrieval_trace("retr-2") is None + assert store.get_retrieval_trace("retr-3") is not None + assert store.get_retrieval_trace("retr-4") is not None + assert store.get_retrieval_trace("retr-other") is not None