diff --git a/mempalace/palace_graph.py b/mempalace/palace_graph.py index 71cad89ec..125ec0d4a 100644 --- a/mempalace/palace_graph.py +++ b/mempalace/palace_graph.py @@ -18,6 +18,8 @@ import hashlib import json import os +import threading +import time from collections import Counter, defaultdict from datetime import datetime, timezone @@ -25,6 +27,23 @@ from .palace import get_collection as _get_palace_collection from .palace import mine_lock +# Module-level graph cache with TTL and write-invalidation. +# Warm cache serves build_graph() in O(1); invalidate_graph_cache() clears on writes. +_graph_cache_lock = threading.Lock() +_graph_cache_nodes = None +_graph_cache_edges = None +_graph_cache_time = 0.0 +_GRAPH_CACHE_TTL = 60.0 # seconds — graph changes less often than metadata + + +def invalidate_graph_cache(): + """Clear the graph cache. Called from mcp_server.py on writes.""" + global _graph_cache_nodes, _graph_cache_edges, _graph_cache_time + with _graph_cache_lock: + _graph_cache_nodes = None + _graph_cache_edges = None + _graph_cache_time = 0.0 + def _get_collection(config=None): config = config or MempalaceConfig() @@ -42,10 +61,25 @@ def build_graph(col=None, config=None): """ Build the palace graph from ChromaDB metadata. + Returns cached result if fresh (within TTL). Cache is invalidated + on writes via invalidate_graph_cache(). Thread-safe via _graph_cache_lock. + + Note: warm cache ignores ``col`` and ``config`` arguments — this is + intentional for the MCP server's single-palace use case. Callers + switching collections should call ``invalidate_graph_cache()`` first. + Returns: nodes: dict of {room: {wings: set, halls: set, count: int}} edges: list of {room, wing_a, wing_b, hall} — one per tunnel crossing """ + global _graph_cache_nodes, _graph_cache_edges, _graph_cache_time + now = time.time() + # NOTE: warm cache ignores col/config args — intentional for the MCP server's + # single-palace use case. Callers switching collections must invalidate first. + with _graph_cache_lock: + if _graph_cache_nodes is not None and (now - _graph_cache_time) < _GRAPH_CACHE_TTL: + return _graph_cache_nodes, _graph_cache_edges + if col is None: col = _get_collection(config) if not col: @@ -101,6 +135,14 @@ def build_graph(col=None, config=None): "dates": sorted(data["dates"])[-5:] if data["dates"] else [], } + # Only cache non-empty graphs so new data is picked up immediately + # when the palace is first populated. + if nodes: + with _graph_cache_lock: + _graph_cache_nodes = nodes + _graph_cache_edges = edges + _graph_cache_time = time.time() + return nodes, edges diff --git a/tests/test_palace_graph.py b/tests/test_palace_graph.py index ddda2724e..7bc45e04b 100644 --- a/tests/test_palace_graph.py +++ b/tests/test_palace_graph.py @@ -30,6 +30,7 @@ def fake_get(limit=1000, offset=0, include=None): build_graph, find_tunnels, graph_stats, + invalidate_graph_cache, traverse, ) @@ -38,6 +39,9 @@ def fake_get(limit=1000, offset=0, include=None): class TestBuildGraph: + def setup_method(self): + invalidate_graph_cache() + def test_empty_collection(self): col = _make_fake_collection([]) nodes, edges = build_graph(col=col) @@ -114,11 +118,43 @@ def test_dates_capped_at_five(self): nodes, _ = build_graph(col=col) assert len(nodes["busy"]["dates"]) <= 5 + def test_cache_returns_same_result(self): + """Second call within TTL returns cached nodes without re-scanning. + + The cache intentionally ignores col/config args when warm — this is + correct for the MCP server's single-palace use case. Callers that + switch collections must call invalidate_graph_cache() first. + """ + col = _make_fake_collection( + [{"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}] + ) + nodes1, edges1 = build_graph(col=col) + # Second call with a *different* collection — should still return cached result + col2 = _make_fake_collection([]) + nodes2, edges2 = build_graph(col=col2) + assert nodes1 == nodes2 + assert edges1 == edges2 + + def test_invalidate_clears_cache(self): + """invalidate_graph_cache() forces a fresh scan on next call.""" + col = _make_fake_collection( + [{"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}] + ) + build_graph(col=col) + invalidate_graph_cache() + col_empty = _make_fake_collection([]) + nodes, edges = build_graph(col=col_empty) + assert nodes == {} + assert edges == [] + # --- traverse --- class TestTraverse: + def setup_method(self): + invalidate_graph_cache() + def _build_col(self): return _make_fake_collection( [ @@ -156,6 +192,9 @@ def test_traverse_max_hops(self): class TestFindTunnels: + def setup_method(self): + invalidate_graph_cache() + def _build_tunnel_col(self): return _make_fake_collection( [ @@ -192,6 +231,9 @@ def test_find_tunnels_both_wings(self): class TestGraphStats: + def setup_method(self): + invalidate_graph_cache() + def test_empty_graph(self): col = _make_fake_collection([]) stats = graph_stats(col=col)