diff --git a/CMakeLists.txt b/CMakeLists.txt index ee8ac93..742f307 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -265,6 +265,7 @@ add_library(flapi-lib STATIC src/sql_utils.cpp src/mcp_server.cpp src/mcp_tool_handler.cpp + src/mcp_tool_rate_limiter.cpp src/mcp_route_handlers.cpp src/mcp_session_manager.cpp src/mcp_error_builder.cpp diff --git a/src/endpoint_config_parser.cpp b/src/endpoint_config_parser.cpp index 1a25f75..d29eace 100644 --- a/src/endpoint_config_parser.cpp +++ b/src/endpoint_config_parser.cpp @@ -212,6 +212,13 @@ void EndpointConfigParser::parseMcpToolFields( } } + // W2.5: per-tool rate limit. Absent block → enabled=false (no gate). + if (auto rl = mcp_tool_node["rate-limit"]; rl.IsDefined()) { + tool_info.rate_limit.enabled = config_manager_->safeGet(rl, "enabled", "mcp-tool.rate-limit.enabled", true); + tool_info.rate_limit.max = config_manager_->safeGet(rl, "max", "mcp-tool.rate-limit.max", 100); + tool_info.rate_limit.interval = config_manager_->safeGet(rl, "interval", "mcp-tool.rate-limit.interval", 60); + } + config.mcp_tool = tool_info; } diff --git a/src/include/config_manager.hpp b/src/include/config_manager.hpp index f1d5638..fdb091d 100644 --- a/src/include/config_manager.hpp +++ b/src/include/config_manager.hpp @@ -202,6 +202,11 @@ struct EndpointConfig { std::vector redact_columns; bool sample = false; } response; + + // W2.5: per-tool rate limit. `enabled: false` (default) leaves + // the tool ungated; otherwise `max` calls are permitted per + // `interval` seconds, scoped to (tool_name, principal). 
+ RateLimitConfig rate_limit; }; struct MCPResourceInfo { diff --git a/src/include/mcp_tool_handler.hpp b/src/include/mcp_tool_handler.hpp index fd75796..5e16f4c 100644 --- a/src/include/mcp_tool_handler.hpp +++ b/src/include/mcp_tool_handler.hpp @@ -10,6 +10,7 @@ #include "config_manager.hpp" #include "database_manager.hpp" #include "mcp_authorization_policy.hpp" +#include "mcp_tool_rate_limiter.hpp" #include "sql_template_processor.hpp" #include "request_validator.hpp" @@ -82,6 +83,7 @@ QueryResult executeQueryWithEndpoint(const EndpointConfig& endpoint_config, std::unique_ptr sql_processor; std::shared_ptr audit_logger; MCPAuthorizationPolicy authorization_policy; + MCPToolRateLimiter rate_limiter; }; } // namespace flapi
diff --git a/src/include/mcp_tool_rate_limiter.hpp b/src/include/mcp_tool_rate_limiter.hpp
new file mode 100644
index 0000000..4331c0b
--- /dev/null
+++ b/src/include/mcp_tool_rate_limiter.hpp
@@ -0,0 +1,69 @@
+#pragma once
+
+#include <chrono>
+#include <cstdint>
+#include <functional>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+
+namespace flapi {
+
+struct RateLimitConfig;
+
+// W2.5: Per-tool rate limiter for MCP tool calls. A separate enforcement
+// point from the REST `RateLimitMiddleware` because MCP tool calls all
+// land on the same HTTP path (`/mcp/jsonrpc`), so keying on the URL path
+// cannot tell two tools apart. This limiter keys on (tool_name, principal)
+// instead.
+//
+// Each key owns a fixed window: the first call of a window grants
+// `config.max` slots and schedules a reset `config.interval` seconds later.
+//
+// Thread-safe: bucket state is guarded by an internal mutex. A disabled
+// `RateLimitConfig` short-circuits to "allowed".
+//
+// NOTE(review): buckets are created on first use and never erased, so the
+// map grows with the number of distinct (tool_name, principal) pairs.
+class MCPToolRateLimiter {
+public:
+    // Outcome of a single tryAcquire() call.
+    struct AcquireDecision {
+        bool allowed = false;                  // true when the call may proceed
+        std::int64_t remaining = 0;            // slots left in the current window
+        std::int64_t retry_after_seconds = 0;  // seconds until reset; 0 when allowed
+    };
+
+    // Time source; injectable so tests can step the window deterministically.
+    using Clock = std::function<std::chrono::steady_clock::time_point()>;
+
+    // Uses std::chrono::steady_clock as the time source.
+    MCPToolRateLimiter();
+    // Uses the supplied clock instead of the steady clock.
+    explicit MCPToolRateLimiter(Clock clock);
+
+    // Try to consume one slot in the bucket identified by
+    // (tool_name, principal). Returns whether the call is allowed,
+    // how many slots remain in the current window, and (when denied)
+    // how many seconds until the window resets.
+    AcquireDecision tryAcquire(const std::string& tool_name,
+                               const std::string& principal,
+                               const RateLimitConfig& config);
+
+private:
+    // Per-key window state.
+    struct Bucket {
+        std::int64_t remaining = 0;                        // slots left in this window
+        std::chrono::steady_clock::time_point reset_time;  // when the window rolls over
+    };
+
+    // Builds the map key for a (tool_name, principal) pair.
+    static std::string keyFor(const std::string& tool_name,
+                              const std::string& principal);
+
+    Clock clock_;
+    std::mutex mutex_;  // guards buckets_
+    std::unordered_map<std::string, Bucket> buckets_;
+};
+
+} // namespace flapi
diff --git a/src/mcp_tool_handler.cpp b/src/mcp_tool_handler.cpp index d19c6d6..1d0e6f8 100644 --- a/src/mcp_tool_handler.cpp +++ b/src/mcp_tool_handler.cpp @@ -77,6 +77,35 @@ MCPToolExecutionResult MCPToolHandler::executeTool(const MCPToolCallRequest& req } } + // W2.5: per-tool rate limit. Runs before argument validation so a + // flooded caller never consumes DB or template I/O. The principal + // key falls back to a stable marker when the request is anonymous + // so anonymous floods share one bucket per tool. + if (endpoint_config->mcp_tool && endpoint_config->mcp_tool->rate_limit.enabled) { + std::string principal = "anonymous"; + auto ctx_it = request.context.find("auth.username"); + if (ctx_it != request.context.end() && !ctx_it->second.empty()) { + principal = ctx_it->second; + } + auto decision = rate_limiter.tryAcquire(request.tool_name, + principal, + endpoint_config->mcp_tool->rate_limit); + if (!decision.allowed) { + std::unordered_map metadata; + metadata["tool_name"] = request.tool_name; + metadata["rate_limited"] = "true"; + metadata["retry_after_seconds"] = std::to_string(decision.retry_after_seconds); + MCPToolExecutionResult result; + result.success = false; + result.error_message = "Rate limit exceeded for tool '" + request.tool_name + + "'. 
Retry after " + + std::to_string(decision.retry_after_seconds) + + " seconds."; + result.metadata = std::move(metadata); + return result; + } + } + // W2.2 dry-run: peel `_dryRun` off the arguments before validation so // the reserved key never reaches the unknown-parameter check. A copy // of the arguments is made because MCPToolCallRequest is const here.
diff --git a/src/mcp_tool_rate_limiter.cpp b/src/mcp_tool_rate_limiter.cpp
new file mode 100644
index 0000000..49d4ff0
--- /dev/null
+++ b/src/mcp_tool_rate_limiter.cpp
@@ -0,0 +1,76 @@
+#include "mcp_tool_rate_limiter.hpp"
+
+#include <algorithm>
+#include <utility>
+
+#include "config_manager.hpp"
+
+namespace flapi {
+
+// Default construction reads time from std::chrono::steady_clock.
+MCPToolRateLimiter::MCPToolRateLimiter()
+    : clock_([]() { return std::chrono::steady_clock::now(); }) {}
+
+// Injects a custom time source so callers can step the window themselves.
+MCPToolRateLimiter::MCPToolRateLimiter(Clock clock)
+    : clock_(std::move(clock)) {}
+
+// Builds the bucket key for (tool_name, principal). The tool name is
+// length-prefixed so distinct pairs can never produce the same key:
+// plain "tool|principal" concatenation maps ("a|b", "c") and
+// ("a", "b|c") to one bucket, letting unrelated callers drain each
+// other's quota.
+std::string MCPToolRateLimiter::keyFor(const std::string& tool_name,
+                                       const std::string& principal) {
+    return std::to_string(tool_name.size()) + ":" + tool_name + "|" + principal;
+}
+
+// Consumes one slot from the fixed window owned by (tool_name, principal).
+//
+// Returns {allowed, remaining, retry_after_seconds}:
+//   - `config.enabled` false: always allowed, no bucket is touched.
+//   - slot available: allowed; `remaining` is what is left after this call.
+//   - window exhausted: denied; `retry_after_seconds` is the time until the
+//     window resets, rounded up to whole seconds and never below 1.
+//
+// NOTE(review): buckets are never erased, so the map grows with the number
+// of distinct (tool_name, principal) pairs seen.
+MCPToolRateLimiter::AcquireDecision MCPToolRateLimiter::tryAcquire(
+    const std::string& tool_name,
+    const std::string& principal,
+    const RateLimitConfig& config) {
+
+    if (!config.enabled) {
+        return {true, /*remaining=*/0, /*retry_after=*/0};
+    }
+
+    const auto key = keyFor(tool_name, principal);
+    const auto now = clock_();
+
+    std::lock_guard<std::mutex> guard(mutex_);
+    auto& bucket = buckets_[key];
+
+    // Initialise or roll over an expired window. A non-positive interval
+    // would expire every window immediately and silently disable the limit,
+    // so the window length is clamped to at least one second.
+    if (now >= bucket.reset_time) {
+        const std::chrono::seconds window(
+            std::max<std::int64_t>(1, config.interval));
+        bucket.remaining = config.max;
+        bucket.reset_time = now + window;
+    }
+
+    if (bucket.remaining <= 0) {
+        // Round up: truncation would advertise a retry time up to one second
+        // before the window actually resets.
+        const auto until = std::chrono::ceil<std::chrono::seconds>(
+            bucket.reset_time - now).count();
+        return {false, /*remaining=*/0,
+                /*retry_after=*/std::max<std::int64_t>(1, until)};
+    }
+
+    bucket.remaining -= 1;
+    return {true, bucket.remaining, /*retry_after=*/0};
+}
+
+} // namespace flapi
diff --git a/test/cpp/CMakeLists.txt b/test/cpp/CMakeLists.txt index 0fc40ee..e773082 100644 --- a/test/cpp/CMakeLists.txt +++ b/test/cpp/CMakeLists.txt @@ -26,6 +26,7 @@ add_executable(flapi_tests mcp_prompt_handler_test.cpp mcp_request_validator_test.cpp mcp_response_shaper_test.cpp + mcp_tool_rate_limiter_test.cpp password_hasher_test.cpp query_executor_test.cpp rate_limit_key_builder_test.cpp diff --git a/test/cpp/endpoint_config_parser_test.cpp b/test/cpp/endpoint_config_parser_test.cpp index fcca2ff..a69943a 100644 --- a/test/cpp/endpoint_config_parser_test.cpp +++ b/test/cpp/endpoint_config_parser_test.cpp @@ -91,6 +91,38 @@ template-source: test.sql REQUIRE_FALSE(result.config.mcp_tool->response.max_rows.has_value()); REQUIRE(result.config.mcp_tool->response.redact_columns.empty()); REQUIRE_FALSE(result.config.mcp_tool->response.sample); + // Default: no per-tool rate limit configured. 
+ REQUIRE_FALSE(result.config.mcp_tool->rate_limit.enabled); + + fs::remove(yaml_file); + fs::remove(config_file); +} + +TEST_CASE("EndpointConfigParser: Parse MCP Tool with rate-limit", "[endpoint_parser][ratelimit]") { + std::string yaml_content = R"( +mcp-tool: + name: throttled_tool + description: Tool with a per-tool rate limit + rate-limit: + enabled: true + max: 5 + interval: 30 +template-source: test.sql +connection: + - test_db +)"; + + std::string yaml_file = createTempYamlFile(yaml_content); + std::string config_file = createMinimalFlapiConfig(); + + ConfigManager manager{fs::path(config_file)}; + EndpointConfigParser parser(manager.getYamlParser(), &manager); + auto result = parser.parseFromFile(yaml_file); + + REQUIRE(result.success == true); + REQUIRE(result.config.mcp_tool->rate_limit.enabled); + REQUIRE(result.config.mcp_tool->rate_limit.max == 5); + REQUIRE(result.config.mcp_tool->rate_limit.interval == 30); fs::remove(yaml_file); fs::remove(config_file); diff --git a/test/cpp/mcp_tool_rate_limiter_test.cpp b/test/cpp/mcp_tool_rate_limiter_test.cpp new file mode 100644 index 0000000..adfcac8 --- /dev/null +++ b/test/cpp/mcp_tool_rate_limiter_test.cpp @@ -0,0 +1,179 @@ +#include + +#include +#include +#include +#include + +#include "config_manager.hpp" +#include "mcp_tool_rate_limiter.hpp" + +namespace flapi { +namespace test { + +namespace { + +using namespace std::chrono_literals; +using time_point = std::chrono::steady_clock::time_point; + +RateLimitConfig makeLimit(int max, int interval) { + RateLimitConfig c; + c.enabled = true; + c.max = max; + c.interval = interval; + return c; +} + +class FakeClock { +public: + explicit FakeClock(time_point start = time_point{}) + : now_(start) {} + time_point now() const { return now_; } + void advance(std::chrono::seconds delta) { now_ += delta; } +private: + time_point now_; +}; + +} // namespace + +TEST_CASE("MCPToolRateLimiter: disabled config always allows", + "[security][mcp][ratelimit]") { + 
FakeClock clk; + MCPToolRateLimiter limiter([&clk]() { return clk.now(); }); + RateLimitConfig cfg; + cfg.enabled = false; + cfg.max = 1; + cfg.interval = 60; + + for (int i = 0; i < 50; ++i) { + auto decision = limiter.tryAcquire("tool_a", "principal_a", cfg); + REQUIRE(decision.allowed); + } +} + +TEST_CASE("MCPToolRateLimiter: max=N allows exactly N calls then denies", + "[security][mcp][ratelimit]") { + FakeClock clk; + MCPToolRateLimiter limiter([&clk]() { return clk.now(); }); + auto cfg = makeLimit(/*max=*/3, /*interval=*/60); + + REQUIRE(limiter.tryAcquire("t", "p", cfg).allowed); + REQUIRE(limiter.tryAcquire("t", "p", cfg).allowed); + REQUIRE(limiter.tryAcquire("t", "p", cfg).allowed); + auto denied = limiter.tryAcquire("t", "p", cfg); + REQUIRE_FALSE(denied.allowed); + REQUIRE(denied.retry_after_seconds > 0); +} + +TEST_CASE("MCPToolRateLimiter: bucket resets after the interval elapses", + "[security][mcp][ratelimit]") { + FakeClock clk; + MCPToolRateLimiter limiter([&clk]() { return clk.now(); }); + auto cfg = makeLimit(/*max=*/2, /*interval=*/30); + + REQUIRE(limiter.tryAcquire("t", "p", cfg).allowed); + REQUIRE(limiter.tryAcquire("t", "p", cfg).allowed); + REQUIRE_FALSE(limiter.tryAcquire("t", "p", cfg).allowed); + + // Step past the reset boundary. + clk.advance(31s); + REQUIRE(limiter.tryAcquire("t", "p", cfg).allowed); + REQUIRE(limiter.tryAcquire("t", "p", cfg).allowed); + REQUIRE_FALSE(limiter.tryAcquire("t", "p", cfg).allowed); +} + +TEST_CASE("MCPToolRateLimiter: two tools have independent buckets", + "[security][mcp][ratelimit]") { + // Per-tool rate limiting is the whole point of W2.5 — tool A's + // traffic must not consume tool B's quota. + FakeClock clk; + MCPToolRateLimiter limiter([&clk]() { return clk.now(); }); + auto cfg = makeLimit(/*max=*/1, /*interval=*/60); + + REQUIRE(limiter.tryAcquire("tool_a", "p", cfg).allowed); + REQUIRE_FALSE(limiter.tryAcquire("tool_a", "p", cfg).allowed); + + // tool_b still has a full bucket. 
+ REQUIRE(limiter.tryAcquire("tool_b", "p", cfg).allowed); + REQUIRE_FALSE(limiter.tryAcquire("tool_b", "p", cfg).allowed); +} + +TEST_CASE("MCPToolRateLimiter: two principals have independent buckets", + "[security][mcp][ratelimit]") { + FakeClock clk; + MCPToolRateLimiter limiter([&clk]() { return clk.now(); }); + auto cfg = makeLimit(/*max=*/2, /*interval=*/60); + + REQUIRE(limiter.tryAcquire("t", "alice", cfg).allowed); + REQUIRE(limiter.tryAcquire("t", "alice", cfg).allowed); + REQUIRE_FALSE(limiter.tryAcquire("t", "alice", cfg).allowed); + + // bob is a different principal on the same tool — full quota. + REQUIRE(limiter.tryAcquire("t", "bob", cfg).allowed); + REQUIRE(limiter.tryAcquire("t", "bob", cfg).allowed); + REQUIRE_FALSE(limiter.tryAcquire("t", "bob", cfg).allowed); +} + +TEST_CASE("MCPToolRateLimiter: retry_after equals seconds until the bucket resets", + "[security][mcp][ratelimit]") { + FakeClock clk; + MCPToolRateLimiter limiter([&clk]() { return clk.now(); }); + auto cfg = makeLimit(/*max=*/1, /*interval=*/45); + + REQUIRE(limiter.tryAcquire("t", "p", cfg).allowed); + clk.advance(10s); + auto denied = limiter.tryAcquire("t", "p", cfg); + REQUIRE_FALSE(denied.allowed); + // Reset at t=45, currently at t=10 → ~35s left. 
+ REQUIRE(denied.retry_after_seconds >= 34); + REQUIRE(denied.retry_after_seconds <= 35); +} + +TEST_CASE("MCPToolRateLimiter: remaining counts down on each successful acquire", + "[security][mcp][ratelimit]") { + FakeClock clk; + MCPToolRateLimiter limiter([&clk]() { return clk.now(); }); + auto cfg = makeLimit(/*max=*/3, /*interval=*/60); + + auto d1 = limiter.tryAcquire("t", "p", cfg); + REQUIRE(d1.allowed); + REQUIRE(d1.remaining == 2); + + auto d2 = limiter.tryAcquire("t", "p", cfg); + REQUIRE(d2.allowed); + REQUIRE(d2.remaining == 1); + + auto d3 = limiter.tryAcquire("t", "p", cfg); + REQUIRE(d3.allowed); + REQUIRE(d3.remaining == 0); +} + +TEST_CASE("MCPToolRateLimiter: concurrent acquires honour the cap exactly", + "[security][mcp][ratelimit][threading]") { + // Real-clock test: hammer the limiter from N threads with max=K + // and verify exactly K acquires succeed. + MCPToolRateLimiter limiter; // default clock + auto cfg = makeLimit(/*max=*/50, /*interval=*/60); + + constexpr int kThreads = 16; + constexpr int kAttemptsPerThread = 25; + std::atomic allowed_count{0}; + std::vector workers; + workers.reserve(kThreads); + for (int t = 0; t < kThreads; ++t) { + workers.emplace_back([&]() { + for (int i = 0; i < kAttemptsPerThread; ++i) { + if (limiter.tryAcquire("t", "p", cfg).allowed) { + allowed_count.fetch_add(1, std::memory_order_relaxed); + } + } + }); + } + for (auto& w : workers) { + w.join(); + } + REQUIRE(allowed_count.load() == 50); +} + +} // namespace test +} // namespace flapi diff --git a/test/integration/test_mcp_per_tool_rate_limit.py b/test/integration/test_mcp_per_tool_rate_limit.py new file mode 100644 index 0000000..8f6fc07 --- /dev/null +++ b/test/integration/test_mcp_per_tool_rate_limit.py @@ -0,0 +1,279 @@ +"""End-to-end tests for per-tool MCP rate limits (issue #24, W2.5). + +Boots flapi with two MCP tools that each declare different +`rate-limit` blocks. 
Hammers each independently and verifies: +- Tool A's traffic does not consume tool B's quota. +- Once a tool's bucket is empty, calls receive a Rate-limit error + response with `retry_after_seconds` in the JSON-RPC error payload. + +Marked `standalone_server` so the conftest autouse fixture does not +spin up the shared api_configuration server. Skips cleanly when flapi +cannot boot (local DuckDB extension-cache mismatch); CI runs against +fresh extensions. +""" + +import base64 +import hashlib +import hmac +import json +import os +import socket +import subprocess +import tempfile +import time +from typing import Dict, Iterator, List + +import pytest +import requests + + +JWT_SECRET = "tool-rate-test-secret" +JWT_ISSUER = "tool-rate-test-issuer" + + +def _repo_root() -> str: + return os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) + + +def _flapi_binary() -> str: + candidates: List[str] = [] + for build_type in ("release", "debug"): + path = os.path.join(_repo_root(), "build", build_type, "flapi") + if os.path.exists(path): + candidates.append(path) + if not candidates: + pytest.skip("flapi binary not found in build/release or build/debug") + candidates.sort(key=os.path.getmtime, reverse=True) + return candidates[0] + + +def _free_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def _b64url(data: bytes) -> str: + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("utf-8") + + +def _make_jwt(sub: str = "rate-test-user") -> str: + header = {"alg": "HS256", "typ": "JWT"} + now = int(time.time()) + payload = { + "iss": JWT_ISSUER, + "sub": sub, + "roles": ["analyst"], + "iat": now, + "exp": now + 3600, + } + h = _b64url(json.dumps(header, separators=(",", ":")).encode("utf-8")) + p = _b64url(json.dumps(payload, separators=(",", ":")).encode("utf-8")) + signature = hmac.new(JWT_SECRET.encode("utf-8"), f"{h}.{p}".encode("utf-8"), hashlib.sha256).digest() + return 
f"{h}.{p}.{_b64url(signature)}" + + +def _write_config(dirpath: str, port: int) -> str: + sqls = os.path.join(dirpath, "sqls") + os.makedirs(sqls) + + with open(os.path.join(dirpath, "flapi.yaml"), "w") as f: + f.write( + f"project-name: per-tool-rate-test\n" + f"project-description: Per-tool MCP rate limit E2E\n" + f"http-port: {port}\n" + f"template:\n" + f" path: ./sqls\n" + f"connections:\n" + f" inmem:\n" + f" properties:\n" + f" database: ':memory:'\n" + f"mcp:\n" + f" enabled: true\n" + f" auth:\n" + f" enabled: true\n" + f" type: bearer\n" + f" jwt-secret: {JWT_SECRET}\n" + f" jwt-issuer: {JWT_ISSUER}\n" + ) + + # tool_a: max 2 calls / 60s + with open(os.path.join(sqls, "tool_a.yaml"), "w") as f: + f.write(""" +template-source: tool_a.sql +connection: [inmem] +mcp-tool: + name: tool_a + description: Tool A, capped at 2 calls/minute + rate-limit: + enabled: true + max: 2 + interval: 60 +""") + with open(os.path.join(sqls, "tool_a.sql"), "w") as f: + f.write("SELECT 'tool_a' AS name\n") + + # tool_b: max 5 calls / 60s (independent bucket) + with open(os.path.join(sqls, "tool_b.yaml"), "w") as f: + f.write(""" +template-source: tool_b.sql +connection: [inmem] +mcp-tool: + name: tool_b + description: Tool B, capped at 5 calls/minute + rate-limit: + enabled: true + max: 5 + interval: 60 +""") + with open(os.path.join(sqls, "tool_b.sql"), "w") as f: + f.write("SELECT 'tool_b' AS name\n") + + return os.path.join(dirpath, "flapi.yaml") + +
+@pytest.fixture
+def rate_limit_server() -> Iterator[Dict[str, str]]:
+    """Boot flapi on a free port, yield {"base_url": ...}, stop it on teardown."""
+    binary = _flapi_binary()
+    port = _free_port()
+    with tempfile.TemporaryDirectory(prefix="flapi_tool_rl_") as tmpdir:
+        config_path = _write_config(tmpdir, port)
+        log_path = os.path.join(tmpdir, "server.log")
+        # Context manager: the log handle is released even if Popen raises.
+        with open(log_path, "w") as log_file:
+            proc = subprocess.Popen(
+                [binary, "-c", config_path, "--no-telemetry"],
+                cwd=tmpdir,
+                stdout=log_file,
+                stderr=subprocess.STDOUT,
+            )
+            try:
+                base_url = f"http://127.0.0.1:{port}"
+                deadline = time.time() + 30
+                up = False
+                while time.time() < deadline and proc.poll() is None:
+                    try:
+                        r = requests.get(f"{base_url}/mcp/health", timeout=1)
+                        if r.status_code < 500:
+                            up = True
+                            break
+                    except requests.exceptions.RequestException:
+                        pass
+                    # Back off after any failed probe (5xx included), not only on errors.
+                    time.sleep(0.5)
+                if not up:
+                    proc.terminate()
+                    try:
+                        proc.wait(timeout=5)
+                    except subprocess.TimeoutExpired:
+                        proc.kill()
+                    with open(log_path) as f:
+                        log_text = f.read()
+                    if "core_functions_duckdb_cpp_init" in log_text and "unique_ptr that is NULL" in log_text:
+                        pytest.skip(
+                            "flapi could not boot: local DuckDB extension cache is "
+                            "incompatible with the in-tree DuckDB submodule. CI exercises this path."
+                        )
+                    raise RuntimeError(f"flapi failed to start. Log:\n{log_text}")
+                yield {"base_url": base_url}
+            finally:
+                proc.terminate()
+                try:
+                    proc.wait(timeout=10)
+                except subprocess.TimeoutExpired:
+                    proc.kill()
+ + +def _open_session(base_url: str, token: str) -> str: + r = requests.post( + f"{base_url}/mcp/jsonrpc", + headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"}, + json={ + "jsonrpc": "2.0", + "id": "init-1", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "clientInfo": {"name": "rate-limit-test", "version": "0.1"}, + "capabilities": {}, + }, + }, + timeout=10, + ) + assert r.status_code == 200, r.text + sid = r.headers.get("Mcp-Session-Id") + assert sid, f"no session id: {dict(r.headers)}" + return sid + + +def _tools_call(base_url: str, token: str, sid: str, tool: str) -> requests.Response: + return requests.post( + f"{base_url}/mcp/jsonrpc", + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + "Mcp-Session-Id": sid, + }, + json={ + "jsonrpc": "2.0", + "id": f"call-{tool}", + "method": "tools/call", + "params": {"name": tool, "arguments": {}}, + }, + timeout=10, + ) + + +def _is_rate_limited(body: dict) -> bool: + err = body.get("error", "") + if isinstance(err, str): + return 
"Rate limit exceeded" in err + if isinstance(err, dict): + return "Rate limit exceeded" in err.get("message", "") + return False + + +@pytest.mark.standalone_server +class TestPerToolRateLimit: + """End-to-end coverage for `mcp-tool.rate-limit`.""" + + def test_tool_a_blocks_after_its_two_call_quota(self, rate_limit_server): + token = _make_jwt() + sid = _open_session(rate_limit_server["base_url"], token) + + # Two calls allowed, third must be rate-limited. + r1 = _tools_call(rate_limit_server["base_url"], token, sid, "tool_a") + r2 = _tools_call(rate_limit_server["base_url"], token, sid, "tool_a") + r3 = _tools_call(rate_limit_server["base_url"], token, sid, "tool_a") + + assert r1.status_code == 200 + assert r2.status_code == 200 + assert r3.status_code == 200 + assert not _is_rate_limited(r1.json()), r1.json() + assert not _is_rate_limited(r2.json()), r2.json() + assert _is_rate_limited(r3.json()), r3.json() + + def test_tool_b_quota_independent_of_tool_a(self, rate_limit_server): + # After tool_a is exhausted, tool_b still has its full bucket. + token = _make_jwt() + sid = _open_session(rate_limit_server["base_url"], token) + + for _ in range(3): + _tools_call(rate_limit_server["base_url"], token, sid, "tool_a") + + # tool_a should now be blocked; tool_b should still allow 5 calls. + a_after = _tools_call(rate_limit_server["base_url"], token, sid, "tool_a") + assert _is_rate_limited(a_after.json()), a_after.json() + + b_results = [ + _tools_call(rate_limit_server["base_url"], token, sid, "tool_b") + for _ in range(5) + ] + for i, r in enumerate(b_results): + assert r.status_code == 200, r.text + assert not _is_rate_limited(r.json()), f"tool_b call {i} unexpectedly limited: {r.json()}" + + # 6th tool_b call must be rate-limited. + b_blocked = _tools_call(rate_limit_server["base_url"], token, sid, "tool_b") + assert _is_rate_limited(b_blocked.json()), b_blocked.json()