-
Notifications
You must be signed in to change notification settings - Fork 0
feat(search-optimization): cache tool catalog and parallelize per-account MCP fetches #173
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
68faae6
6a52804
982165d
1c834b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,134 @@ | ||
| """Benchmark: measure SDK search latency with caching. | ||
|
|
||
| Runs fetch_tools, local (BM25+TF-IDF) search, and semantic search N times, | ||
| reports cold vs warm average latency and the speedup from caching. | ||
|
|
||
| Prerequisites: | ||
| - STACKONE_API_KEY environment variable | ||
| - STACKONE_ACCOUNT_ID environment variable | ||
|
|
||
| Run with: | ||
| uv run python examples/benchmark_search.py # default 100 iterations | ||
| uv run python examples/benchmark_search.py -n 50 # fewer for a quick check | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import argparse | ||
| import os | ||
| import sys | ||
| import time | ||
|
|
||
| try: | ||
| from dotenv import load_dotenv | ||
|
|
||
| load_dotenv() | ||
| except ModuleNotFoundError: | ||
| pass | ||
|
|
||
| from stackone_ai import StackOneToolSet | ||
|
|
||
| QUERIES = [ | ||
| "list events", | ||
| "cancel a meeting", | ||
| "send a message", | ||
| "get current user", | ||
| "list employees", | ||
| ] | ||
|
|
||
|
|
||
def bench(fn, n: int) -> tuple[float, float, list[float]]:
    """Run fn() n times and time each call.

    Args:
        fn: Zero-argument callable to benchmark.
        n: Number of iterations; must be >= 1.

    Returns:
        (cold, warm_avg, all_times) where cold is the first call's duration,
        warm_avg is the mean of the remaining calls (falls back to cold when
        n == 1), and all_times lists every duration in seconds.

    Raises:
        ValueError: if n < 1 (previously this crashed with IndexError at times[0]).
    """
    if n < 1:
        raise ValueError("n must be a positive integer")
    times: list[float] = []
    for _ in range(n):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)

    cold = times[0]
    warm_times = times[1:]
    warm_avg = sum(warm_times) / len(warm_times) if warm_times else cold
    return cold, warm_avg, times
|
|
||
|
|
||
def fmt_ms(seconds: float) -> str:
    """Render a duration in seconds as right-aligned milliseconds, e.g. '   100.0ms'."""
    millis = seconds * 1000
    return f"{millis:8.1f}ms"
|
|
||
|
|
||
def _positive_int(value: str) -> int:
    """argparse type: parse a strictly positive integer or raise a clear CLI error."""
    n = int(value)
    if n < 1:
        raise argparse.ArgumentTypeError("iterations must be a positive integer")
    return n


def _query_cycle():
    """Yield benchmark queries forever, cycling through QUERIES in order.

    A fresh generator per benchmark replaces the previous module-level
    query_idx counter, so no shared mutable state is needed.
    """
    while True:
        yield from QUERIES


def main() -> int:
    """Run the three benchmarks (fetch, local search, semantic search) and print a summary.

    Returns:
        Process exit code: 0 on success, 1 when a required environment
        variable (STACKONE_API_KEY / STACKONE_ACCOUNT_ID) is missing.
    """
    parser = argparse.ArgumentParser(description="Benchmark SDK search latency")
    parser.add_argument(
        "--iterations",
        "-n",
        type=_positive_int,  # reject 0/negative up front instead of crashing in bench()
        default=100,
        help="iterations per benchmark (default 100)",
    )
    args = parser.parse_args()
    n = args.iterations

    api_key = os.getenv("STACKONE_API_KEY")
    account_id = os.getenv("STACKONE_ACCOUNT_ID")

    if not api_key:
        print("Set STACKONE_API_KEY to run this benchmark.")
        return 1
    if not account_id:
        print("Set STACKONE_ACCOUNT_ID to run this benchmark.")
        return 1

    print(f"Benchmarking with account {account_id[:8]}..., {n} iterations each\n")

    ts = StackOneToolSet(
        api_key=api_key,
        account_id=account_id,
        search={"method": "auto", "top_k": 5},
    )

    results: list[tuple[str, float, float, float]] = []

    # --- 1. fetch_tools ---
    print(f"[1/3] fetch_tools x{n} ...")
    ts.clear_catalog_cache()  # force a cold first call so cold/warm is meaningful
    cold, warm_avg, _ = bench(lambda: ts.fetch_tools(), n)
    speedup = cold / warm_avg if warm_avg > 0 else float("inf")
    results.append(("fetch_tools", cold, warm_avg, speedup))
    print(f" cold={fmt_ms(cold)} warm_avg={fmt_ms(warm_avg)} speedup={speedup:.0f}x")

    # --- 2. local search (BM25 + TF-IDF) ---
    print(f"[2/3] search_tools (local) x{n} ...")
    ts.clear_catalog_cache()
    queries = _query_cycle()  # fresh cycle per benchmark, no shared counter to reset
    cold, warm_avg, _ = bench(lambda: ts.search_tools(next(queries), search="local"), n)
    speedup = cold / warm_avg if warm_avg > 0 else float("inf")
    results.append(("search (local/BM25)", cold, warm_avg, speedup))
    print(f" cold={fmt_ms(cold)} warm_avg={fmt_ms(warm_avg)} speedup={speedup:.0f}x")

    # --- 3. semantic search (auto) ---
    print(f"[3/3] search_tools (semantic/auto) x{n} ...")
    ts.clear_catalog_cache()
    queries = _query_cycle()
    cold, warm_avg, _ = bench(lambda: ts.search_tools(next(queries), search="auto"), n)
    speedup = cold / warm_avg if warm_avg > 0 else float("inf")
    results.append(("search (semantic)", cold, warm_avg, speedup))
    print(f" cold={fmt_ms(cold)} warm_avg={fmt_ms(warm_avg)} speedup={speedup:.0f}x")

    # --- Summary ---
    print("\n" + "=" * 65)
    print(f"{'Benchmark':<22} {'Cold':>10} {'Warm (avg)':>10} {'Speedup':>10}")
    print("-" * 65)
    for name, c, w, s in results:
        print(f"{name:<22} {fmt_ms(c):>10} {fmt_ms(w):>10} {s:>9.0f}x")
    print("=" * 65)

    print(f"\nWarm = average of {n - 1} calls after the first (cold) call.")
    print("Speedup = cold / warm_avg — shows the benefit of caching.\n")

    return 0


if __name__ == "__main__":
    sys.exit(main())
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -170,7 +170,6 @@ class _ExecuteTool(StackOneTool): | |||||||||||
| """LLM-callable tool that executes a StackOne tool by name.""" | ||||||||||||
|
|
||||||||||||
| _toolset: Any = PrivateAttr(default=None) | ||||||||||||
| _cached_tools: Any = PrivateAttr(default=None) | ||||||||||||
|
|
||||||||||||
| def execute( | ||||||||||||
| self, arguments: str | JsonDict | None = None, *, options: JsonDict | None = None | ||||||||||||
|
|
@@ -185,10 +184,8 @@ def execute( | |||||||||||
| parsed = _ExecuteInput(**raw_params) | ||||||||||||
| tool_name = parsed.tool_name | ||||||||||||
|
|
||||||||||||
| if self._cached_tools is None: | ||||||||||||
| self._cached_tools = self._toolset.fetch_tools(account_ids=self._toolset._account_ids) | ||||||||||||
|
|
||||||||||||
| target = self._cached_tools.get_tool(parsed.tool_name) | ||||||||||||
| tools = self._toolset.fetch_tools(account_ids=self._toolset._account_ids) | ||||||||||||
| target = tools.get_tool(parsed.tool_name) | ||||||||||||
|
|
||||||||||||
| if target is None: | ||||||||||||
| return { | ||||||||||||
|
|
@@ -602,6 +599,8 @@ def __init__( | |||||||||||
| execute_timeout = execute.get("timeout") if execute else None | ||||||||||||
| self._timeout: float = timeout if timeout is not None else (execute_timeout or 60.0) | ||||||||||||
| self._tools_cache: Tools | None = None | ||||||||||||
| self._catalog_cache: dict[tuple[Any, ...], Tools] = {} | ||||||||||||
| self._tool_index_cache: tuple[int, Any] | None = None | ||||||||||||
|
|
||||||||||||
| def set_accounts(self, account_ids: list[str]) -> StackOneToolSet: | ||||||||||||
| """Set account IDs for filtering tools | ||||||||||||
|
|
@@ -613,8 +612,18 @@ def set_accounts(self, account_ids: list[str]) -> StackOneToolSet: | |||||||||||
| This toolset instance for chaining | ||||||||||||
| """ | ||||||||||||
| self._account_ids = account_ids | ||||||||||||
| self.clear_catalog_cache() | ||||||||||||
| return self | ||||||||||||
|
|
||||||||||||
| def clear_catalog_cache(self) -> None: | ||||||||||||
| """Invalidate cached tool catalog and local search index. | ||||||||||||
|
|
||||||||||||
| Call when linked accounts change outside of ``set_accounts`` or when | ||||||||||||
| you need to force a fresh fetch from the StackOne MCP endpoint. | ||||||||||||
| """ | ||||||||||||
| self._catalog_cache.clear() | ||||||||||||
| self._tool_index_cache = None | ||||||||||||
|
|
||||||||||||
| def get_search_tool(self, *, search: SearchMode | None = None) -> SearchTool: | ||||||||||||
| """Get a callable search tool that returns Tools collections. | ||||||||||||
|
|
||||||||||||
|
|
@@ -802,7 +811,10 @@ def _local_search( | |||||||||||
| if not available_connectors: | ||||||||||||
| return Tools([]) | ||||||||||||
|
|
||||||||||||
| index = ToolIndex(list(all_tools)) | ||||||||||||
| cache_key = id(all_tools) | ||||||||||||
| if self._tool_index_cache is None or self._tool_index_cache[0] != cache_key: | ||||||||||||
| self._tool_index_cache = (cache_key, ToolIndex(list(all_tools))) | ||||||||||||
| index = self._tool_index_cache[1] | ||||||||||||
| results = index.search( | ||||||||||||
| query, | ||||||||||||
| limit=top_k if top_k is not None else 5, | ||||||||||||
|
|
@@ -1171,22 +1183,41 @@ def fetch_tools( | |||||||||||
| else: | ||||||||||||
| account_scope = [None] | ||||||||||||
|
|
||||||||||||
| cache_key = ( | ||||||||||||
| tuple(sorted(account_scope, key=lambda a: (a is None, a))), | ||||||||||||
| tuple(sorted(p.lower() for p in providers)) if providers else None, | ||||||||||||
| tuple(sorted(actions)) if actions else None, | ||||||||||||
| ) | ||||||||||||
| cached = self._catalog_cache.get(cache_key) | ||||||||||||
| if cached is not None: | ||||||||||||
| return cached | ||||||||||||
|
Comment on lines
+1191
to
+1193
|
||||||||||||
|
|
||||||||||||
| endpoint = f"{self.base_url.rstrip('/')}/mcp" | ||||||||||||
| all_tools: list[StackOneTool] = [] | ||||||||||||
|
|
||||||||||||
| for account in account_scope: | ||||||||||||
| def _fetch_for_account(account: str | None) -> list[StackOneTool]: | ||||||||||||
| headers = self._build_mcp_headers(account) | ||||||||||||
| catalog = _fetch_mcp_tools(endpoint, headers) | ||||||||||||
| for tool_def in catalog: | ||||||||||||
| all_tools.append(self._create_rpc_tool(tool_def, account)) | ||||||||||||
| return [self._create_rpc_tool(tool_def, account) for tool_def in catalog] | ||||||||||||
|
|
||||||||||||
| all_tools: list[StackOneTool] = [] | ||||||||||||
| if len(account_scope) == 1: | ||||||||||||
| all_tools.extend(_fetch_for_account(account_scope[0])) | ||||||||||||
| else: | ||||||||||||
| max_workers = min(len(account_scope), 10) | ||||||||||||
| with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool: | ||||||||||||
| futures = [pool.submit(_fetch_for_account, acc) for acc in account_scope] | ||||||||||||
| for future in futures: | ||||||||||||
| all_tools.extend(future.result()) | ||||||||||||
|
Comment on lines
+1208
to
+1210
|
||||||||||||
| futures = [pool.submit(_fetch_for_account, acc) for acc in account_scope] | |
| for future in concurrent.futures.as_completed(futures): | |
| all_tools.extend(future.result()) | |
| for account_tools in pool.map(_fetch_for_account, account_scope): | |
| all_tools.extend(account_tools) |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
P2: Validate `--iterations` to require a positive integer; `0` or negative values currently crash at `times[0]` in `bench()`.

Prompt for AI agents