Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions examples/benchmark_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Benchmark: measure SDK search latency with caching.

Runs fetch_tools, local (BM25+TF-IDF) search, and semantic search N times,
reports cold vs warm average latency and the speedup from caching.

Prerequisites:
- STACKONE_API_KEY environment variable
- STACKONE_ACCOUNT_ID environment variable

Run with:
uv run python examples/benchmark_search.py # default 100 iterations
uv run python examples/benchmark_search.py -n 50 # fewer for a quick check
"""

from __future__ import annotations

import argparse
import itertools
import os
import sys
import time

try:
from dotenv import load_dotenv

load_dotenv()
except ModuleNotFoundError:
pass

from stackone_ai import StackOneToolSet

# Natural-language queries consumed round-robin by the search benchmarks.
QUERIES = [
    "list events",
    "cancel a meeting",
    "send a message",
    "get current user",
    "list employees",
]


def bench(fn, n: int) -> tuple[float, float, list[float]]:
    """Run ``fn()`` ``n`` times and time each call.

    Args:
        fn: Zero-argument callable to benchmark.
        n: Number of iterations; must be >= 1.

    Returns:
        ``(cold, warm_avg, all_times)`` — the first call's duration, the
        mean of the remaining calls (falls back to ``cold`` when n == 1),
        and every per-call duration in seconds.

    Raises:
        ValueError: If ``n`` is less than 1. (Previously a non-positive
            ``n`` crashed later with an IndexError at ``times[0]``.)
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")

    times: list[float] = []
    for _ in range(n):
        t = time.perf_counter()
        fn()
        times.append(time.perf_counter() - t)

    cold = times[0]
    warm_times = times[1:]
    warm_avg = sum(warm_times) / len(warm_times) if warm_times else cold
    return cold, warm_avg, times


def fmt_ms(seconds: float) -> str:
    """Render a duration in seconds as a right-aligned millisecond string."""
    millis = seconds * 1000
    return "{:8.1f}ms".format(millis)


def main() -> int:
    """Run the three benchmarks and print a summary table.

    Returns:
        Process exit code: 0 on success, 1 when a required environment
        variable (STACKONE_API_KEY / STACKONE_ACCOUNT_ID) is missing.
        Exits via ``parser.error`` (SystemExit) on invalid --iterations.
    """
    parser = argparse.ArgumentParser(description="Benchmark SDK search latency")
    parser.add_argument(
        "--iterations", "-n", type=int, default=100, help="iterations per benchmark (default 100)"
    )
    args = parser.parse_args()
    n = args.iterations
    # bench() indexes times[0], so n must be a positive integer.
    if n < 1:
        parser.error("--iterations must be a positive integer")

    api_key = os.getenv("STACKONE_API_KEY")
    account_id = os.getenv("STACKONE_ACCOUNT_ID")

    if not api_key:
        print("Set STACKONE_API_KEY to run this benchmark.")
        return 1
    if not account_id:
        print("Set STACKONE_ACCOUNT_ID to run this benchmark.")
        return 1

    print(f"Benchmarking with account {account_id[:8]}..., {n} iterations each\n")

    ts = StackOneToolSet(
        api_key=api_key,
        account_id=account_id,
        search={"method": "auto", "top_k": 5},
    )

    results: list[tuple[str, float, float, float]] = []

    # --- 1. fetch_tools ---
    print(f"[1/3] fetch_tools x{n} ...")
    ts.clear_catalog_cache()
    cold, warm_avg, _ = bench(lambda: ts.fetch_tools(), n)
    speedup = cold / warm_avg if warm_avg > 0 else float("inf")
    results.append(("fetch_tools", cold, warm_avg, speedup))
    print(f" cold={fmt_ms(cold)} warm_avg={fmt_ms(warm_avg)} speedup={speedup:.0f}x")

    # --- 2. local search (BM25 + TF-IDF) ---
    print(f"[2/3] search_tools (local) x{n} ...")
    ts.clear_catalog_cache()
    # itertools.cycle replaces the previous nonlocal counter; a fresh
    # cycle per benchmark keeps the query sequence identical across runs.
    queries = itertools.cycle(QUERIES)
    cold, warm_avg, _ = bench(lambda: ts.search_tools(next(queries), search="local"), n)
    speedup = cold / warm_avg if warm_avg > 0 else float("inf")
    results.append(("search (local/BM25)", cold, warm_avg, speedup))
    print(f" cold={fmt_ms(cold)} warm_avg={fmt_ms(warm_avg)} speedup={speedup:.0f}x")

    # --- 3. semantic search (auto) ---
    print(f"[3/3] search_tools (semantic/auto) x{n} ...")
    ts.clear_catalog_cache()
    queries = itertools.cycle(QUERIES)
    cold, warm_avg, _ = bench(lambda: ts.search_tools(next(queries), search="auto"), n)
    speedup = cold / warm_avg if warm_avg > 0 else float("inf")
    results.append(("search (semantic)", cold, warm_avg, speedup))
    print(f" cold={fmt_ms(cold)} warm_avg={fmt_ms(warm_avg)} speedup={speedup:.0f}x")

    # --- Summary ---
    print("\n" + "=" * 65)
    print(f"{'Benchmark':<22} {'Cold':>10} {'Warm (avg)':>10} {'Speedup':>10}")
    print("-" * 65)
    for name, c, w, s in results:
        print(f"{name:<22} {fmt_ms(c):>10} {fmt_ms(w):>10} {s:>9.0f}x")
    print("=" * 65)

    print(f"\nWarm = average of {n - 1} calls after the first (cold) call.")
    print("Speedup = cold / warm_avg — shows the benefit of caching.\n")

    return 0


if __name__ == "__main__":
    sys.exit(main())  # propagate main()'s exit code to the shell
1 change: 1 addition & 0 deletions examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def get_example_files() -> list[str]:
"semantic_search_example.py": ["mcp"],
"mcp_server.py": ["mcp"],
"workday_integration.py": ["openai", "mcp"],
"benchmark_search.py": ["mcp"],
}


Expand Down
53 changes: 42 additions & 11 deletions stackone_ai/toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ class _ExecuteTool(StackOneTool):
"""LLM-callable tool that executes a StackOne tool by name."""

_toolset: Any = PrivateAttr(default=None)
_cached_tools: Any = PrivateAttr(default=None)

def execute(
self, arguments: str | JsonDict | None = None, *, options: JsonDict | None = None
Expand All @@ -185,10 +184,8 @@ def execute(
parsed = _ExecuteInput(**raw_params)
tool_name = parsed.tool_name

if self._cached_tools is None:
self._cached_tools = self._toolset.fetch_tools(account_ids=self._toolset._account_ids)

target = self._cached_tools.get_tool(parsed.tool_name)
tools = self._toolset.fetch_tools(account_ids=self._toolset._account_ids)
target = tools.get_tool(parsed.tool_name)

if target is None:
return {
Expand Down Expand Up @@ -602,6 +599,8 @@ def __init__(
execute_timeout = execute.get("timeout") if execute else None
self._timeout: float = timeout if timeout is not None else (execute_timeout or 60.0)
self._tools_cache: Tools | None = None
self._catalog_cache: dict[tuple[Any, ...], Tools] = {}
self._tool_index_cache: tuple[int, Any] | None = None

def set_accounts(self, account_ids: list[str]) -> StackOneToolSet:
"""Set account IDs for filtering tools
Expand All @@ -613,8 +612,18 @@ def set_accounts(self, account_ids: list[str]) -> StackOneToolSet:
This toolset instance for chaining
"""
self._account_ids = account_ids
self.clear_catalog_cache()
return self

def clear_catalog_cache(self) -> None:
"""Invalidate cached tool catalog and local search index.

Call when linked accounts change outside of ``set_accounts`` or when
you need to force a fresh fetch from the StackOne MCP endpoint.
"""
self._catalog_cache.clear()
self._tool_index_cache = None

def get_search_tool(self, *, search: SearchMode | None = None) -> SearchTool:
"""Get a callable search tool that returns Tools collections.

Expand Down Expand Up @@ -802,7 +811,10 @@ def _local_search(
if not available_connectors:
return Tools([])

index = ToolIndex(list(all_tools))
cache_key = id(all_tools)
if self._tool_index_cache is None or self._tool_index_cache[0] != cache_key:
self._tool_index_cache = (cache_key, ToolIndex(list(all_tools)))
index = self._tool_index_cache[1]
results = index.search(
query,
limit=top_k if top_k is not None else 5,
Expand Down Expand Up @@ -1171,22 +1183,41 @@ def fetch_tools(
else:
account_scope = [None]

cache_key = (
tuple(sorted(account_scope, key=lambda a: (a is None, a))),
tuple(sorted(p.lower() for p in providers)) if providers else None,
tuple(sorted(actions)) if actions else None,
)
cached = self._catalog_cache.get(cache_key)
if cached is not None:
return cached
Comment on lines +1191 to +1193
Copy link

Copilot AI Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On cache hits this returns the same cached Tools instance. Tools and StackOneTool are mutable (e.g., Tools.set_account_id() mutates the contained tools), so a consumer mutating the returned object will implicitly mutate the toolset’s cache and affect subsequent fetch_tools()/search_tools() calls. If you want memoization to be internal-only/no behavioral change, consider returning a defensive copy (or caching an immutable representation and constructing new tool objects per call) so external mutations can’t leak back into the cache.

Copilot uses AI. Check for mistakes.

endpoint = f"{self.base_url.rstrip('/')}/mcp"
all_tools: list[StackOneTool] = []

for account in account_scope:
def _fetch_for_account(account: str | None) -> list[StackOneTool]:
headers = self._build_mcp_headers(account)
catalog = _fetch_mcp_tools(endpoint, headers)
for tool_def in catalog:
all_tools.append(self._create_rpc_tool(tool_def, account))
return [self._create_rpc_tool(tool_def, account) for tool_def in catalog]

all_tools: list[StackOneTool] = []
if len(account_scope) == 1:
all_tools.extend(_fetch_for_account(account_scope[0]))
else:
max_workers = min(len(account_scope), 10)
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = [pool.submit(_fetch_for_account, acc) for acc in account_scope]
for future in futures:
all_tools.extend(future.result())
Comment on lines +1208 to +1210
Copy link

Copilot AI Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parallel fetch aggregates results via as_completed, which makes all_tools order nondeterministic across runs. Since Tools builds _tool_map = {tool.name: tool for tool in tools}, any duplicate tool names across accounts will be resolved by “last one wins” in completion order, potentially changing which account’s tool get_tool()/tool_execute selects compared to the previous deterministic per-account loop. Consider preserving a deterministic account ordering while still fetching in parallel (e.g., use executor.map over an ordered account list, or collect results keyed by account and extend in that order; ideally use the same ordering as the cache key).

Suggested change
futures = [pool.submit(_fetch_for_account, acc) for acc in account_scope]
for future in concurrent.futures.as_completed(futures):
all_tools.extend(future.result())
for account_tools in pool.map(_fetch_for_account, account_scope):
all_tools.extend(account_tools)

Copilot uses AI. Check for mistakes.

if providers:
all_tools = [tool for tool in all_tools if self._filter_by_provider(tool.name, providers)]

if actions:
all_tools = [tool for tool in all_tools if self._filter_by_action(tool.name, actions)]

return Tools(all_tools)
result = Tools(all_tools)
self._catalog_cache[cache_key] = result
return result

except ToolsetError:
raise
Expand Down
9 changes: 7 additions & 2 deletions tests/test_agent_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,11 @@ def test_invalid_json_returns_error_dict(self):

assert "error" in result

def test_caches_fetched_tools(self):
def test_delegates_catalog_lookup_to_toolset(self):
# _ExecuteTool no longer holds a local cache; the toolset's catalog
# cache (see StackOneToolSet._catalog_cache) is the single source of
# truth. Verify execute always defers to the toolset so it benefits
# from that shared cache.
toolset = MagicMock()
toolset.api_key = "test-key"
toolset._account_ids = []
Expand All @@ -286,7 +290,8 @@ def test_caches_fetched_tools(self):
execute.execute({"tool_name": "test_tool"})
execute.execute({"tool_name": "test_tool"})

toolset.fetch_tools.assert_called_once()
assert toolset.fetch_tools.call_count == 2
toolset.fetch_tools.assert_called_with(account_ids=[])

def test_passes_account_ids_from_toolset(self):
toolset = MagicMock()
Expand Down
Loading