From 4e448266ea924f95b5d66f990d599bcc0a3da366 Mon Sep 17 00:00:00 2001 From: Elias Bachaalany Date: Sun, 21 Dec 2025 11:31:04 -0800 Subject: [PATCH] feat: tasks, sampling tools loop, and sync parity updates - Implement in-flight tasks with cancel + TTL (SEP-1686) - Add task status push notifications (running, terminal states) - Add statusMessage streaming support - Implement sampling tools loop - Add mount tool_names, streamable redirects, prompt meta - Cover elicitation defaults + nullability in tests - Fix MSVC assert popups in mounting tests --- .github/workflows/ci.yml | 6 - CMakeLists.txt | 12 +- include/fastmcpp/app.hpp | 15 + include/fastmcpp/mcp/handler.hpp | 11 +- include/fastmcpp/mcp/tasks.hpp | 26 + include/fastmcpp/prompts/prompt.hpp | 10 + include/fastmcpp/server/sampling.hpp | 80 ++ include/fastmcpp/server/session.hpp | 23 +- src/app.cpp | 170 +++- src/client/transports.cpp | 97 ++- src/mcp/handler.cpp | 773 ++++++++++++++++--- src/mcp/tasks.cpp | 42 + src/server/elicitation.cpp | 2 + src/server/sampling.cpp | 256 ++++++ src/server/sse_server.cpp | 84 ++ src/server/stdio_server.cpp | 97 +++ tests/app/mounting.cpp | 103 +++ tests/client/tasks.cpp | 64 +- tests/server/sse_tasks_notifications.cpp | 472 +++++++++++ tests/server/streamable_http_integration.cpp | 66 ++ tests/server/test_elicitation_defaults.cpp | 106 +++ tests/server/test_sampling_tools.cpp | 144 ++++ tests/server/test_server_session.cpp | 11 +- 23 files changed, 2517 insertions(+), 153 deletions(-) create mode 100644 include/fastmcpp/mcp/tasks.hpp create mode 100644 include/fastmcpp/server/sampling.hpp create mode 100644 src/mcp/tasks.cpp create mode 100644 src/server/sampling.cpp create mode 100644 tests/server/sse_tasks_notifications.cpp create mode 100644 tests/server/test_sampling_tools.cpp diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f6eab31..c762df0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,12 +17,6 @@ jobs: matrix: os: [ubuntu-latest, windows-latest, macos-latest] build_type: [Debug, Release] - include: - # ARM64 Linux builds - - os: ubuntu-24.04-arm - build_type: Debug - - os: ubuntu-24.04-arm - build_type: Release runs-on: ${{ matrix.os }} name: ${{ matrix.os }} (${{ matrix.build_type }}) diff --git a/CMakeLists.txt b/CMakeLists.txt index e60f77b..b85f7a5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.16) -project(fastmcpp VERSION 2.14.0 LANGUAGES CXX) +project(fastmcpp VERSION 2.14.1 LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) @@ -18,6 +18,7 @@ add_library(fastmcpp_core src/app.cpp src/proxy.cpp src/mcp/handler.cpp + src/mcp/tasks.cpp src/resources/resource.cpp src/resources/manager.cpp src/resources/template.cpp @@ -30,6 +31,7 @@ add_library(fastmcpp_core src/server/context.cpp src/server/middleware.cpp src/server/security_middleware.cpp + src/server/sampling.cpp src/server/http_server.cpp src/server/stdio_server.cpp src/server/sse_server.cpp @@ -231,6 +233,10 @@ if(FASTMCPP_BUILD_TESTS) target_link_libraries(fastmcpp_sse_mcp_format PRIVATE fastmcpp_core) add_test(NAME fastmcpp_sse_mcp_format COMMAND fastmcpp_sse_mcp_format) + add_executable(fastmcpp_sse_tasks_notifications tests/server/sse_tasks_notifications.cpp) + target_link_libraries(fastmcpp_sse_tasks_notifications PRIVATE fastmcpp_core) + add_test(NAME fastmcpp_sse_tasks_notifications COMMAND fastmcpp_sse_tasks_notifications) + # Advanced test suites (Task 3.4) add_executable(fastmcpp_tools_validation tests/tools/validation.cpp) target_link_libraries(fastmcpp_tools_validation PRIVATE fastmcpp_core) @@ -288,6 +294,10 @@ if(FASTMCPP_BUILD_TESTS) target_link_libraries(fastmcpp_server_context_sampling PRIVATE fastmcpp_core) add_test(NAME fastmcpp_server_context_sampling COMMAND fastmcpp_server_context_sampling) + add_executable(fastmcpp_server_sampling_tools tests/server/test_sampling_tools.cpp) + target_link_libraries(fastmcpp_server_sampling_tools PRIVATE fastmcpp_core) + add_test(NAME fastmcpp_server_sampling_tools COMMAND fastmcpp_server_sampling_tools) + add_executable(fastmcpp_server_elicitation_defaults tests/server/test_elicitation_defaults.cpp) target_link_libraries(fastmcpp_server_elicitation_defaults PRIVATE fastmcpp_core) add_test(NAME fastmcpp_server_elicitation_defaults COMMAND fastmcpp_server_elicitation_defaults) diff --git a/include/fastmcpp/app.hpp b/include/fastmcpp/app.hpp index 59ad9e0..3881984 100644 --- a/include/fastmcpp/app.hpp +++ b/include/fastmcpp/app.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include namespace fastmcpp @@ -20,6 +21,8 @@ struct MountedApp { std::string prefix; // Prefix for tools/prompts (e.g., "weather") class FastMCP* app; // Non-owning pointer to mounted app + std::optional> + tool_names; // Optional tool name overrides }; /// Proxy-mounted app with prefix (proxy mode) @@ -27,6 +30,8 @@ struct ProxyMountedApp { std::string prefix; // Prefix for tools/prompts std::unique_ptr proxy; // Owning pointer to proxy wrapper + std::optional> + tool_names; // Optional tool name overrides }; /// MCP Application - bundles server metadata with managers @@ -125,8 +130,14 @@ class FastMCP /// @param app The app to mount (must outlive this app in direct mode) /// @param prefix Optional prefix (empty string = no prefix) /// @param as_proxy If true, mount in proxy mode (uses MCP handler for communication) + /// @param tool_names Optional mapping of original tool names to custom names. Keys are the + /// original tool names from the mounted server (after any nested prefixing). void mount(FastMCP& app, const std::string& prefix = "", bool as_proxy = false); + /// Mount another app with optional tool name overrides. + void mount(FastMCP& app, const std::string& prefix, bool as_proxy, + std::optional> tool_names); + /// Get list of directly mounted apps const std::vector& mounted() const { @@ -173,6 +184,10 @@ class FastMCP /// Get prompt messages by name (handles prefixed routing) std::vector get_prompt(const std::string& name, const Json& args) const; + /// Get prompt result by name (handles prefixed routing) + /// Includes description and optional _meta parity with Python SDK (fastmcp 2.14.1+). + prompts::PromptResult get_prompt_result(const std::string& name, const Json& args) const; + private: server::Server server_; tools::ToolManager tools_; diff --git a/include/fastmcpp/mcp/handler.hpp b/include/fastmcpp/mcp/handler.hpp index fefa482..c1b77bf 100644 --- a/include/fastmcpp/mcp/handler.hpp +++ b/include/fastmcpp/mcp/handler.hpp @@ -26,6 +26,9 @@ class SseServerWrapper; namespace fastmcpp::mcp { +/// Session accessor callback type - retrieves ServerSession for a session_id +using SessionAccessor = std::function(const std::string&)>; + // Factory that produces a JSON-RPC handler compatible with ClaudeOptions::sdk_mcp_handlers. // It supports a subset of MCP methods needed for in-process tools: // - "initialize" @@ -62,13 +65,15 @@ make_mcp_handler(const std::string& server_name, const std::string& version, // Uses app's aggregated lists and routing for mounted sub-apps std::function make_mcp_handler(const FastMCP& app); +// Overload: FastMCP handler with session access. +// Enables server-initiated features (e.g., task status push) keyed by params._meta.session_id. +std::function +make_mcp_handler(const FastMCP& app, SessionAccessor session_accessor); + // MCP handler from ProxyApp - supports proxying to backend server // Uses app's aggregated lists (local + remote) and routing std::function make_mcp_handler(const ProxyApp& app); -/// Session accessor callback type - retrieves ServerSession for a session_id -using SessionAccessor = std::function(const std::string&)>; - /// MCP handler with sampling support /// The session_accessor callback is used to get ServerSession for sampling requests. /// Session ID is extracted from params._meta.session_id (injected by SSE server). diff --git a/include/fastmcpp/mcp/tasks.hpp b/include/fastmcpp/mcp/tasks.hpp new file mode 100644 index 0000000..b39efba --- /dev/null +++ b/include/fastmcpp/mcp/tasks.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include + +namespace fastmcpp::mcp::tasks +{ + +/// Report a status message for the currently executing background task (SEP-1686). +/// +/// This sends best-effort `notifications/tasks/status` updates (via the transport/session) +/// when called from within a task execution context created by `mcp::make_mcp_handler(...)`. +/// +/// No-op if called outside a background task context. +void report_status_message(const std::string& message); + +namespace detail +{ +using StatusMessageFn = void (*)(void* ctx, const std::string& task_id, const std::string& message); + +// Internal: set/clear the task context for the current thread. +// Used by the MCP task execution runtime (TaskRegistry). +void set_current_task(void* ctx, StatusMessageFn fn, std::string task_id); +void clear_current_task(); +} // namespace detail + +} // namespace fastmcpp::mcp::tasks diff --git a/include/fastmcpp/prompts/prompt.hpp b/include/fastmcpp/prompts/prompt.hpp index 742fba3..03bef18 100644 --- a/include/fastmcpp/prompts/prompt.hpp +++ b/include/fastmcpp/prompts/prompt.hpp @@ -25,11 +25,21 @@ struct PromptMessage std::string content; // Message content }; +/// Result of prompts/get (prompt rendering) +struct PromptResult +{ + std::vector messages; + std::optional description; + std::optional meta; // Returned as _meta in MCP prompts/get +}; + /// MCP Prompt definition struct Prompt { std::string name; std::optional description; + std::optional + meta; // Optional prompt metadata (returned as _meta in prompts/get) std::vector arguments; std::function(const Json&)> generator; // Message generator fastmcpp::TaskSupport task_support{fastmcpp::TaskSupport::Forbidden}; // SEP-1686 task mode diff --git a/include/fastmcpp/server/sampling.hpp b/include/fastmcpp/server/sampling.hpp new file mode 100644 index 0000000..1e32304 --- /dev/null +++ b/include/fastmcpp/server/sampling.hpp @@ -0,0 +1,80 @@ +#pragma once +#include "fastmcpp/server/session.hpp" +#include "fastmcpp/types.hpp" + +#include +#include +#include +#include +#include +#include + +namespace fastmcpp::server::sampling +{ + +// --------------------------------------------------------------------------- +// SEP-1577 sampling-with-tools helpers (server-initiated sampling/createMessage) +// --------------------------------------------------------------------------- + +struct Tool +{ + std::string name; + std::optional description; + fastmcpp::Json input_schema{fastmcpp::Json::object()}; + std::function fn; +}; + +struct Message +{ + std::string role; // "user" or "assistant" + fastmcpp::Json content; // MCP SamplingMessageContentBlock or list thereof +}; + +inline Message make_text_message(const std::string& role, const std::string& text) +{ + return Message{role, fastmcpp::Json{{"type", "text"}, {"text", text}}}; +} + +struct Options +{ + std::optional system_prompt; + std::optional temperature; + int max_tokens{512}; + std::optional model_preferences; + std::optional> stop_sequences; + std::optional metadata; + + std::optional> tools; + // Simplified tool choice: "auto", "required", or "none" + std::optional tool_choice; + + bool execute_tools{true}; + bool mask_error_details{false}; + int max_iterations{10}; + std::chrono::milliseconds timeout{ServerSession::DEFAULT_TIMEOUT}; +}; + +struct Step +{ + fastmcpp::Json response; // CreateMessageResult(+WithTools) JSON + std::vector history; + + bool is_tool_use() const; + std::optional text() const; + std::vector tool_calls() const; +}; + +struct Result +{ + std::optional text; + fastmcpp::Json response; + std::vector history; +}; + +Step sample_step(std::shared_ptr session, const std::vector& messages, + const Options& options); + +Result sample(std::shared_ptr session, const std::vector& messages, + Options options); + +} // namespace fastmcpp::server::sampling diff --git a/include/fastmcpp/server/session.hpp b/include/fastmcpp/server/session.hpp index dcb359e..127976f 100644 --- a/include/fastmcpp/server/session.hpp +++ b/include/fastmcpp/server/session.hpp @@ -103,8 +103,17 @@ class ServerSession capabilities_ = capabilities; // Parse common capability flags + supports_sampling_ = false; + supports_sampling_tools_ = false; + supports_elicitation_ = false; + supports_roots_ = false; if (capabilities.contains("sampling") && capabilities["sampling"].is_object()) + { supports_sampling_ = true; + const auto& sampling = capabilities["sampling"]; + if (sampling.contains("tools") && sampling["tools"].is_object()) + supports_sampling_tools_ = true; + } if (capabilities.contains("elicitation") && capabilities["elicitation"].is_object()) supports_elicitation_ = true; if (capabilities.contains("roots") && capabilities["roots"].is_object()) @@ -118,6 +127,13 @@ class ServerSession return supports_sampling_; } + /// Check if client supports sampling with tools (sampling.tools capability) + bool supports_sampling_tools() const + { + std::lock_guard lock(cap_mutex_); + return supports_sampling_tools_; + } + /// Check if client supports elicitation bool supports_elicitation() const { @@ -262,10 +278,14 @@ class ServerSession * * @param method The JSON-RPC method name (e.g., "notifications/progress") * @param params Notification parameters + * @param meta Optional top-level _meta for the notification */ - void send_notification(const std::string& method, const Json& params = Json::object()) + void send_notification(const std::string& method, const Json& params = Json::object(), + const std::optional& meta = std::nullopt) { Json notification = {{"jsonrpc", "2.0"}, {"method", method}, {"params", params}}; + if (meta.has_value() && meta->is_object() && !meta->empty()) + notification["_meta"] = *meta; if (send_callback_) send_callback_(notification); @@ -331,6 +351,7 @@ class ServerSession bool supports_sampling_{false}; bool supports_elicitation_{false}; bool supports_roots_{false}; + bool supports_sampling_tools_{false}; // Pending requests std::mutex pending_mutex_; diff --git a/src/app.cpp b/src/app.cpp index f08770e..585247b 100644 --- a/src/app.cpp +++ b/src/app.cpp @@ -5,6 +5,8 @@ #include "fastmcpp/exceptions.hpp" #include "fastmcpp/mcp/handler.hpp" +#include + namespace fastmcpp { @@ -16,6 +18,43 @@ FastMCP::FastMCP(std::string name, std::string version, std::optional>& tool_names) +{ + if (!tool_names) + return; + + std::unordered_set seen; + seen.reserve(tool_names->size()); + for (const auto& [_, value] : *tool_names) + if (!seen.insert(value).second) + throw fastmcpp::ValidationError("tool_names values must be unique"); +} + +std::optional find_original_tool_name_for_override( + const std::optional>& tool_names, + const std::string& exposed_name) +{ + if (!tool_names) + return std::nullopt; + + for (const auto& [original_name, custom_name] : *tool_names) + if (custom_name == exposed_name) + return original_name; + return std::nullopt; +} +} // namespace + +void FastMCP::mount(FastMCP& app, const std::string& prefix, bool as_proxy, + std::optional> tool_names) +{ + validate_tool_name_overrides(tool_names); + if (as_proxy) { // Create MCP handler for the app @@ -28,11 +67,11 @@ void FastMCP::mount(FastMCP& app, const std::string& prefix, bool as_proxy) // Create ProxyApp wrapper auto proxy = std::make_unique(client_factory, app.name(), app.version()); - proxy_mounted_.push_back({prefix, std::move(proxy)}); + proxy_mounted_.push_back({prefix, std::move(proxy), std::move(tool_names)}); } else { - mounted_.push_back({prefix, &app}); + mounted_.push_back({prefix, &app, std::move(tool_names)}); } } @@ -128,7 +167,19 @@ std::vector> FastMCP::list_all_tools( for (const auto& [child_name, tool] : child_tools) { - std::string prefixed_name = add_prefix(child_name, mounted.prefix); + std::string prefixed_name = child_name; + if (mounted.tool_names) + { + auto override_it = mounted.tool_names->find(child_name); + if (override_it != mounted.tool_names->end()) + prefixed_name = override_it->second; + else + prefixed_name = add_prefix(child_name, mounted.prefix); + } + else + { + prefixed_name = add_prefix(child_name, mounted.prefix); + } result.emplace_back(prefixed_name, tool); } } @@ -143,7 +194,19 @@ std::vector> FastMCP::list_all_tools( for (const auto& tool_info : proxy_tools) { - std::string prefixed_name = add_prefix(tool_info.name, proxy_mount.prefix); + std::string prefixed_name = tool_info.name; + if (proxy_mount.tool_names) + { + auto override_it = proxy_mount.tool_names->find(tool_info.name); + if (override_it != proxy_mount.tool_names->end()) + prefixed_name = override_it->second; + else + prefixed_name = add_prefix(tool_info.name, proxy_mount.prefix); + } + else + { + prefixed_name = add_prefix(tool_info.name, proxy_mount.prefix); + } // We can't return a pointer for proxy tools, so we add a placeholder // This is a limitation - users should prefer list_all_tools_info() for full access result.emplace_back(prefixed_name, nullptr); @@ -183,7 +246,18 @@ std::vector FastMCP::list_all_tools_info() const for (auto& tool_info : child_tools) { - tool_info.name = add_prefix(tool_info.name, mounted.prefix); + if (mounted.tool_names) + { + auto override_it = mounted.tool_names->find(tool_info.name); + if (override_it != mounted.tool_names->end()) + tool_info.name = override_it->second; + else + tool_info.name = add_prefix(tool_info.name, mounted.prefix); + } + else + { + tool_info.name = add_prefix(tool_info.name, mounted.prefix); + } result.push_back(tool_info); } } @@ -196,7 +270,18 @@ std::vector FastMCP::list_all_tools_info() const for (auto& tool_info : proxy_tools) { - tool_info.name = add_prefix(tool_info.name, proxy_mount.prefix); + if (proxy_mount.tool_names) + { + auto override_it = proxy_mount.tool_names->find(tool_info.name); + if (override_it != proxy_mount.tool_names->end()) + tool_info.name = override_it->second; + else + tool_info.name = add_prefix(tool_info.name, proxy_mount.prefix); + } + else + { + tool_info.name = add_prefix(tool_info.name, proxy_mount.prefix); + } result.push_back(tool_info); } } @@ -356,16 +441,26 @@ Json FastMCP::invoke_tool(const std::string& name, const Json& args) const { const auto& mounted = *it; - std::string try_name = name; - if (!mounted.prefix.empty()) + std::optional overridden_original = + find_original_tool_name_for_override(mounted.tool_names, name); + std::string try_name; + if (overridden_original) { - // Check if name has the right prefix - std::string expected_prefix = mounted.prefix + "_"; - if (name.substr(0, expected_prefix.size()) != expected_prefix) - continue; + try_name = *overridden_original; + } + else + { + try_name = name; + if (!mounted.prefix.empty()) + { + // Check if name has the right prefix + std::string expected_prefix = mounted.prefix + "_"; + if (name.substr(0, expected_prefix.size()) != expected_prefix) + continue; - // Strip prefix for child lookup - try_name = name.substr(expected_prefix.size()); + // Strip prefix for child lookup + try_name = name.substr(expected_prefix.size()); + } } try @@ -383,13 +478,23 @@ Json FastMCP::invoke_tool(const std::string& name, const Json& args) const { const auto& proxy_mount = *it; - std::string try_name = name; - if (!proxy_mount.prefix.empty()) + std::optional overridden_original = + find_original_tool_name_for_override(proxy_mount.tool_names, name); + std::string try_name; + if (overridden_original) { - std::string expected_prefix = proxy_mount.prefix + "_"; - if (name.substr(0, expected_prefix.size()) != expected_prefix) - continue; - try_name = name.substr(expected_prefix.size()); + try_name = *overridden_original; + } + else + { + try_name = name; + if (!proxy_mount.prefix.empty()) + { + std::string expected_prefix = proxy_mount.prefix + "_"; + if (name.substr(0, expected_prefix.size()) != expected_prefix) + continue; + try_name = name.substr(expected_prefix.size()); + } } try @@ -568,11 +673,21 @@ resources::ResourceContent FastMCP::read_resource(const std::string& uri, const std::vector FastMCP::get_prompt(const std::string& name, const Json& args) const +{ + return get_prompt_result(name, args).messages; +} + +prompts::PromptResult FastMCP::get_prompt_result(const std::string& name, const Json& args) const { // Try local prompts first try { - return prompts_.render(name, args); + const auto& prompt = prompts_.get(name); + prompts::PromptResult out; + out.messages = prompts_.render(name, args); + out.description = prompt.description; + out.meta = prompt.meta; + return out; } catch (const NotFoundError&) { @@ -598,7 +713,7 @@ std::vector FastMCP::get_prompt(const std::string& name, try { - return mounted.app->get_prompt(try_name, args); + return mounted.app->get_prompt_result(try_name, args); } catch (const NotFoundError&) { @@ -624,23 +739,26 @@ std::vector FastMCP::get_prompt(const std::string& name, { auto result = proxy_mount.proxy->get_prompt(try_name, args); + prompts::PromptResult out; + out.description = result.description; + out.meta = result._meta; + // Convert GetPromptResult to vector - std::vector messages; for (const auto& pm : result.messages) { prompts::PromptMessage msg; msg.role = (pm.role == client::Role::Assistant) ? "assistant" : "user"; - // Extract text content + // Extract text content (best-effort) if (!pm.content.empty()) { if (auto* text = std::get_if(&pm.content[0])) msg.content = text->text; } - messages.push_back(msg); + out.messages.push_back(std::move(msg)); } - return messages; + return out; } catch (const NotFoundError&) { diff --git a/src/client/transports.cpp b/src/client/transports.cpp index 87c618d..616e0b4 100644 --- a/src/client/transports.cpp +++ b/src/client/transports.cpp @@ -29,6 +29,12 @@ struct ParsedUrl bool is_https; }; +struct ParsedUrlWithPath +{ + ParsedUrl base; + std::string path; // includes leading '/' +}; + ParsedUrl parse_url(const std::string& base) { ParsedUrl result; @@ -86,6 +92,54 @@ ParsedUrl parse_url(const std::string& base) return result; } + +ParsedUrlWithPath parse_url_with_path(const std::string& url) +{ + ParsedUrlWithPath out; + out.base = parse_url(url); + + // Extract path (if any) from the original string, preserving query/fragment (if present) + auto scheme_pos = url.find("://"); + size_t host_start = (scheme_pos == std::string::npos) ? 0 : (scheme_pos + 3); + auto path_pos = url.find('/', host_start); + if (path_pos == std::string::npos) + out.path = "/"; + else + out.path = url.substr(path_pos); + + if (out.path.empty() || out.path[0] != '/') + out.path.insert(out.path.begin(), '/'); + + return out; +} + +bool is_redirect_status(int status) +{ + return status == 301 || status == 302 || status == 303 || status == 307 || status == 308; +} + +std::pair resolve_redirect_target(const std::string& current_full_url, + const std::string& current_path, + const std::string& location) +{ + if (location.rfind("http://", 0) == 0 || location.rfind("https://", 0) == 0) + { + auto parsed = parse_url_with_path(location); + std::string full_url = + parsed.base.scheme + "://" + parsed.base.host + ":" + std::to_string(parsed.base.port); + return {std::move(full_url), parsed.path}; + } + + if (!location.empty() && location[0] == '/') + return {current_full_url, location}; + + // Relative redirect - resolve against current path + std::string base_dir = "/"; + auto last_slash = current_path.rfind('/'); + if (last_slash != std::string::npos) + base_dir = current_path.substr(0, last_slash + 1); + return {current_full_url, base_dir + location}; +} } // namespace fastmcpp::Json HttpTransport::request(const std::string& route, const fastmcpp::Json& payload) @@ -931,10 +985,6 @@ fastmcpp::Json StreamableHttpTransport::request(const std::string& route, // Create client std::string full_url = url.scheme + "://" + url.host + ":" + std::to_string(url.port); - httplib::Client cli(full_url.c_str()); - cli.set_connection_timeout(10, 0); - cli.set_read_timeout(60, 0); // Longer timeout for streaming - cli.set_keep_alive(true); // Build request headers httplib::Headers request_headers = {{"Accept", "application/json, text/event-stream"}, @@ -956,12 +1006,47 @@ fastmcpp::Json StreamableHttpTransport::request(const std::string& route, fastmcpp::Json rpc_request = { {"jsonrpc", "2.0"}, {"method", route}, {"params", payload}, {"id", id}}; - // Send request - auto res = cli.Post(mcp_path_.c_str(), request_headers, rpc_request.dump(), "application/json"); + std::string path = mcp_path_.empty() ? "/mcp" : mcp_path_; + if (!path.empty() && path[0] != '/') + path.insert(path.begin(), '/'); + + // Send request (follow redirects like the Python SDK's httpx follow_redirects=True) + httplib::Result res; + for (int redirects = 0; redirects <= 5; ++redirects) + { + httplib::Client cli(full_url.c_str()); + cli.set_connection_timeout(10, 0); + cli.set_read_timeout(60, 0); // Longer timeout for streaming + cli.set_keep_alive(true); + cli.set_follow_location(false); + + res = cli.Post(path.c_str(), request_headers, rpc_request.dump(), "application/json"); + if (!res) + throw fastmcpp::TransportError("StreamableHttp request failed: no response"); + + if (is_redirect_status(res->status)) + { + auto loc = res->headers.find("Location"); + if (loc == res->headers.end()) + loc = res->headers.find("location"); + if (loc == res->headers.end()) + throw fastmcpp::TransportError("StreamableHttp redirect without Location header"); + + auto next = resolve_redirect_target(full_url, path, loc->second); + full_url = std::move(next.first); + path = std::move(next.second); + continue; + } + + break; + } if (!res) throw fastmcpp::TransportError("StreamableHttp request failed: no response"); + if (is_redirect_status(res->status)) + throw fastmcpp::TransportError("StreamableHttp redirect limit exceeded"); + if (res->status < 200 || res->status >= 300) throw fastmcpp::TransportError("StreamableHttp error: " + std::to_string(res->status)); diff --git a/src/mcp/handler.cpp b/src/mcp/handler.cpp index 3b7a1eb..7aa2f2e 100644 --- a/src/mcp/handler.cpp +++ b/src/mcp/handler.cpp @@ -1,17 +1,23 @@ #include "fastmcpp/mcp/handler.hpp" #include "fastmcpp/app.hpp" +#include "fastmcpp/mcp/tasks.hpp" #include "fastmcpp/proxy.hpp" #include "fastmcpp/server/sse_server.hpp" #include #include +#include #include +#include +#include #include #include #include #include +#include #include +#include namespace fastmcpp::mcp { @@ -42,6 +48,15 @@ static bool schema_is_object(const fastmcpp::Json& schema) return false; } +// Extract session_id from request meta (injected by transports like SSE). +static std::string extract_session_id(const fastmcpp::Json& params) +{ + if (params.contains("_meta") && params["_meta"].is_object() && + params["_meta"].contains("session_id") && params["_meta"]["session_id"].is_string()) + return params["_meta"]["session_id"].get(); + return ""; +} + static fastmcpp::Json normalize_output_schema_for_mcp(const fastmcpp::Json& schema) { if (schema.is_null()) @@ -113,14 +128,22 @@ struct TaskInfo std::string task_id; std::string task_type; // e.g., "tool" std::string component_identifier; // tool name, prompt name, or resource URI - std::string status; // "working", "completed", "failed", "cancelled" + std::string status; // "queued", "running", "completed", "failed", "cancelled" std::string status_message; std::string created_at; // ISO8601 string (best-effort) std::string last_updated_at; // ISO8601 string (best-effort) int ttl_ms{60000}; - fastmcpp::Json result_payload; // Shape of method-specific result }; +inline std::string mcp_status_from_internal(const std::string& status) +{ + // Per SEP-1686 final spec: tasks MUST begin in "working". + // fastmcpp tracks "queued"/"running" internally; map both to "working" externally. + if (status == "queued" || status == "running") + return "working"; + return status; +} + inline std::string to_iso8601_now() { using clock = std::chrono::system_clock; @@ -141,54 +164,478 @@ inline std::string to_iso8601_now() class TaskRegistry { public: - std::string add_completed_task(const std::string& task_type, - const std::string& component_identifier, - fastmcpp::Json result_payload, int ttl_ms) + explicit TaskRegistry(SessionAccessor session_accessor = {}) + : session_accessor_(std::move(session_accessor)) { - TaskInfo info; - info.task_id = generate_task_id(); - info.task_type = task_type; - info.component_identifier = component_identifier; - info.status = "completed"; - info.ttl_ms = ttl_ms; - info.created_at = to_iso8601_now(); - info.last_updated_at = info.created_at; - info.result_payload = std::move(result_payload); + worker_ = std::thread([this]() { worker_loop(); }); + } - std::lock_guard lock(mutex_); - tasks_[info.task_id] = info; - return info.task_id; + ~TaskRegistry() + { + { + std::lock_guard lock(queue_mutex_); + stop_requested_ = true; + } + queue_cv_.notify_all(); + if (worker_.joinable()) + worker_.join(); + } + + struct CreateResult + { + std::string task_id; + std::string created_at; + }; + + CreateResult create_task(const std::string& task_type, const std::string& component_identifier, + int ttl_ms, std::string owner_session_id) + { + TaskEntry entry; + entry.info.task_id = generate_task_id(); + entry.info.task_type = task_type; + entry.info.component_identifier = component_identifier; + entry.info.status = "queued"; + entry.info.status_message = ""; + entry.info.ttl_ms = ttl_ms; + entry.info.created_at = to_iso8601_now(); + entry.info.last_updated_at = entry.info.created_at; + entry.created_tp = std::chrono::steady_clock::now(); + entry.last_updated_tp = entry.created_tp; + entry.owner_session_id = std::move(owner_session_id); + entry.cancel_requested = std::make_shared(false); + + std::string task_id = entry.info.task_id; + std::string created_at = entry.info.created_at; + auto notify = build_status_notification(entry, /*include_non_terminal=*/true); + { + std::lock_guard lock(mutex_); + tasks_[entry.info.task_id] = std::move(entry); + } + + if (notify) + send_status_notification(*notify); + + return {std::move(task_id), std::move(created_at)}; } - std::optional get_task(const std::string& task_id) const + void enqueue_task(const std::string& task_id, std::function work) { + { + std::lock_guard lock(mutex_); + auto it = tasks_.find(task_id); + if (it == tasks_.end()) + return; + it->second.work = std::move(work); + } + + { + std::lock_guard lock(queue_mutex_); + queue_.push_back(task_id); + } + queue_cv_.notify_one(); + } + + std::optional get_task(const std::string& task_id) + { + purge_expired_locked(); + std::lock_guard lock(mutex_); auto it = tasks_.find(task_id); if (it == tasks_.end()) return std::nullopt; - return it->second; + return it->second.info; } - std::vector list_tasks() const + std::vector list_tasks() { + purge_expired_locked(); + std::lock_guard lock(mutex_); std::vector result; result.reserve(tasks_.size()); for (const auto& kv : tasks_) - result.push_back(kv.second); + result.push_back(kv.second.info); return result; } + enum class ResultState + { + NotFound, + NotReady, + Completed, + Failed, + Cancelled, + }; + + struct ResultQuery + { + ResultState state{ResultState::NotFound}; + fastmcpp::Json payload; + std::string error_message; + }; + + ResultQuery get_result(const std::string& task_id) + { + purge_expired_locked(); + + std::lock_guard lock(mutex_); + auto it = tasks_.find(task_id); + if (it == tasks_.end()) + { + ResultQuery query; + query.state = ResultState::NotFound; + return query; + } + + const auto& info = it->second.info; + if (info.status == "completed") + { + ResultQuery query; + query.state = ResultState::Completed; + query.payload = it->second.result_payload; + return query; + } + if (info.status == "failed") + { + ResultQuery query; + query.state = ResultState::Failed; + query.error_message = it->second.error_message; + return query; + } + if (info.status == "cancelled") + { + ResultQuery query; + query.state = ResultState::Cancelled; + query.error_message = it->second.error_message; + return query; + } + + ResultQuery query; + query.state = ResultState::NotReady; + return query; + } + + bool cancel(const std::string& task_id) + { + purge_expired_locked(); + + std::optional notify; + { + std::lock_guard lock(mutex_); + auto it = tasks_.find(task_id); + if (it == tasks_.end()) + return false; + + auto& entry = it->second; + if (entry.cancel_requested) + entry.cancel_requested->store(true); + + if (entry.info.status == "queued" || entry.info.status == "running") + { + entry.info.status = "cancelled"; + entry.info.status_message = "Task cancelled"; + entry.error_message = "Task cancelled"; + entry.info.last_updated_at = to_iso8601_now(); + entry.last_updated_tp = std::chrono::steady_clock::now(); + notify = build_status_notification(entry, /*include_non_terminal=*/false); + } + } + + if (notify) + send_status_notification(*notify); + return true; + } + private: + bool set_status_message(const std::string& task_id, std::string message) + { + std::optional notify; + { + std::lock_guard lock(mutex_); + auto it = tasks_.find(task_id); + if (it == tasks_.end()) + return false; + + auto& entry = it->second; + bool terminal = (entry.info.status == "completed" || entry.info.status == "failed" || + entry.info.status == "cancelled"); + if (terminal) + return false; + + if (entry.info.status_message == message) + return true; + + entry.info.status_message = std::move(message); + entry.info.last_updated_at = to_iso8601_now(); + entry.last_updated_tp = std::chrono::steady_clock::now(); + notify = build_status_notification(entry, /*include_non_terminal=*/true); + } + + if (notify) + send_status_notification(*notify); + return true; + } + + static void tls_set_status_message(void* ctx, const std::string& task_id, + const std::string& message) + { + auto* self = static_cast(ctx); + (void)self->set_status_message(task_id, message); + } + + struct TaskEntry + { + TaskInfo info; + std::chrono::steady_clock::time_point created_tp{}; + std::chrono::steady_clock::time_point last_updated_tp{}; + std::string owner_session_id; + std::shared_ptr cancel_requested; + std::function work; + fastmcpp::Json result_payload; + std::string error_message; + }; + + struct StatusNotification + { + std::string owner_session_id; + fastmcpp::Json params; + }; + + std::optional build_status_notification(const TaskEntry& entry, + bool include_non_terminal) const + { + if (entry.owner_session_id.empty()) + return std::nullopt; + + const auto& info = entry.info; + bool terminal = + (info.status == "completed" || info.status == "failed" || info.status == "cancelled"); + if (!terminal && !include_non_terminal) + return std::nullopt; + + fastmcpp::Json status_params = { + {"taskId", info.task_id}, {"status", mcp_status_from_internal(info.status)}, + {"createdAt", info.created_at}, {"lastUpdatedAt", info.last_updated_at}, + {"ttl", info.ttl_ms}, {"pollInterval", 1000}, + }; + if (!info.status_message.empty()) + status_params["statusMessage"] = info.status_message; + + return StatusNotification{entry.owner_session_id, std::move(status_params)}; + } + + void send_status_notification(const StatusNotification& notification) const + { + if (!session_accessor_) + return; + auto session = session_accessor_(notification.owner_session_id); + if (!session) + return; + + session->send_notification("notifications/tasks/status", notification.params); + } + + void worker_loop() + { + while (true) + { + std::string task_id; + { + std::unique_lock lock(queue_mutex_); + queue_cv_.wait(lock, [&] { return stop_requested_ || !queue_.empty(); }); + if (stop_requested_ && queue_.empty()) + break; + task_id = std::move(queue_.front()); + queue_.pop_front(); + } + + execute_task(task_id); + } + } + + void execute_task(const std::string& task_id) + { + std::function work; + std::shared_ptr cancel_requested; + std::optional notify; + bool should_execute = false; + { + std::lock_guard lock(mutex_); + auto it = tasks_.find(task_id); + if (it == tasks_.end()) + return; + + auto& entry = it->second; + cancel_requested = entry.cancel_requested; + if (cancel_requested && cancel_requested->load() && entry.info.status == "queued") + { + entry.info.status = "cancelled"; + entry.info.status_message = "Task cancelled"; + entry.error_message = "Task cancelled"; + entry.info.last_updated_at = to_iso8601_now(); + entry.last_updated_tp = std::chrono::steady_clock::now(); + notify = build_status_notification(entry, /*include_non_terminal=*/false); + } + else if (entry.info.status == "queued") + { + entry.info.status = "running"; + entry.info.last_updated_at = to_iso8601_now(); + entry.last_updated_tp = std::chrono::steady_clock::now(); + work = entry.work; + should_execute = true; + notify = build_status_notification(entry, /*include_non_terminal=*/true); + } + // else: already terminal or running - nothing to do + } + + if (notify) + { + send_status_notification(*notify); + if (!should_execute) + return; + } + if (!should_execute) + return; + + if (!work) + { + { + std::lock_guard lock(mutex_); + auto it = tasks_.find(task_id); + if (it == tasks_.end()) + return; + auto& entry = it->second; + entry.info.status = "failed"; + entry.info.status_message = "Task has no work scheduled"; + entry.error_message = "Task has no work scheduled"; + entry.info.last_updated_at = to_iso8601_now(); + entry.last_updated_tp = std::chrono::steady_clock::now(); + notify = build_status_notification(entry, /*include_non_terminal=*/false); + } + if (notify) + send_status_notification(*notify); + return; + } + + bool ok = false; + fastmcpp::Json payload; + std::string error; + try + { + struct TaskTlsScope + { + explicit TaskTlsScope(TaskRegistry* registry, const std::string& task_id) + { + fastmcpp::mcp::tasks::detail::set_current_task( + registry, &TaskRegistry::tls_set_status_message, task_id); + } + ~TaskTlsScope() + { + fastmcpp::mcp::tasks::detail::clear_current_task(); + } + }; + + TaskTlsScope scope(this, task_id); + payload = work(); + ok = true; + } + catch (const std::exception& e) + { + ok = false; + error = e.what(); + } + catch (...) + { + ok = false; + error = "Unknown task error"; + } + + { + std::lock_guard lock(mutex_); + auto it = tasks_.find(task_id); + if (it == tasks_.end()) + return; + auto& entry = it->second; + + if (entry.cancel_requested && entry.cancel_requested->load()) + { + entry.info.status = "cancelled"; + entry.info.status_message = "Task cancelled"; + entry.error_message = "Task cancelled"; + entry.info.last_updated_at = to_iso8601_now(); + entry.last_updated_tp = std::chrono::steady_clock::now(); + } + else if (ok) + { + entry.result_payload = std::move(payload); + entry.info.status = "completed"; + entry.info.status_message = "Task completed successfully"; + } + else + { + entry.info.status = "failed"; + entry.info.status_message = "Task failed"; + entry.error_message = error.empty() ? "Task failed" : error; + } + + entry.info.last_updated_at = to_iso8601_now(); + entry.last_updated_tp = std::chrono::steady_clock::now(); + + notify = build_status_notification(entry, /*include_non_terminal=*/false); + } + + if (notify) + send_status_notification(*notify); + } + + void purge_expired_locked() + { + std::lock_guard lock(mutex_); + purge_expired_locked_no_lock(); + } + + void purge_expired_locked_no_lock() + { + auto now = std::chrono::steady_clock::now(); + for (auto it = tasks_.begin(); it != tasks_.end();) + { + const auto& entry = it->second; + const auto& info = entry.info; + bool terminal = (info.status == "completed" || info.status == "failed" || + info.status == "cancelled"); + if (!terminal) + { + ++it; + continue; + } + + auto age_ms = + std::chrono::duration_cast(now - entry.last_updated_tp) + .count(); + if (age_ms > info.ttl_ms) + it = tasks_.erase(it); + else + ++it; + } + } + std::string generate_task_id() { uint64_t id = next_id_.fetch_add(1, std::memory_order_relaxed) + 1; return "task-" + std::to_string(id); } - mutable std::mutex mutex_; - std::unordered_map tasks_; + std::mutex mutex_; + std::unordered_map tasks_; std::atomic next_id_{0}; + + std::mutex queue_mutex_; + std::condition_variable queue_cv_; + std::deque queue_; + bool stop_requested_{false}; + std::thread worker_; + + SessionAccessor session_accessor_; }; // Helper: convert a tool invocation JSON result into an MCP CallToolResult payload. @@ -1078,7 +1525,13 @@ make_mcp_handler(const std::string& server_name, const std::string& version, // FastMCP handler - supports mounted apps with aggregation std::function make_mcp_handler(const FastMCP& app) { - auto tasks = std::make_shared(); + return make_mcp_handler(app, SessionAccessor{}); +} + +std::function +make_mcp_handler(const FastMCP& app, SessionAccessor session_accessor) +{ + auto tasks = std::make_shared(std::move(session_accessor)); return [&app, tasks](const fastmcpp::Json& message) -> fastmcpp::Json { try @@ -1211,18 +1664,25 @@ std::function make_mcp_handler(const Fast if (has_task_meta) { - // Minimal server-side tasks support: execute synchronously but expose - // result via tasks/get and tasks/result. The initial stub reports a - // taskId and completed status so clients can poll if desired. - auto invoke_result = app.invoke_tool(name, args); - fastmcpp::Json result_payload = - build_fastmcp_tool_result(invoke_result, has_output_schema); + auto created = + tasks->create_task("tool", name, ttl_ms, extract_session_id(params)); + std::string task_id = created.task_id; - std::string task_id = tasks->add_completed_task( - "tool", name, std::move(result_payload), ttl_ms); + tasks->enqueue_task( + task_id, + [&app, name, args, has_output_schema]() -> fastmcpp::Json + { + auto invoke_result = app.invoke_tool(name, args); + return build_fastmcp_tool_result(invoke_result, has_output_schema); + }); fastmcpp::Json task_meta = { - {"taskId", task_id}, {"status", "completed"}, {"ttl", ttl_ms}}; + {"taskId", task_id}, + {"status", "working"}, + {"ttl", ttl_ms}, + {"createdAt", created.created_at}, + {"lastUpdatedAt", created.created_at}, + }; fastmcpp::Json response_result = { {"content", fastmcpp::Json::array()}, @@ -1268,7 +1728,7 @@ std::function make_mcp_handler(const Fast fastmcpp::Json status_json = { {"taskId", info->task_id}, - {"status", info->status}, + {"status", mcp_status_from_internal(info->status)}, }; if (!info->created_at.empty()) status_json["createdAt"] = info->created_at; @@ -1292,15 +1752,19 @@ std::function make_mcp_handler(const Fast if (task_id.empty()) return jsonrpc_error(id, -32602, "Missing taskId"); - auto info = tasks->get_task(task_id); - if (!info) + auto q = tasks->get_result(task_id); + if (q.state == TaskRegistry::ResultState::NotFound) return jsonrpc_error(id, -32602, "Invalid taskId"); + if (q.state == TaskRegistry::ResultState::NotReady) + return jsonrpc_error(id, -32602, "Task not completed"); + if (q.state == TaskRegistry::ResultState::Cancelled) + return jsonrpc_error( + id, -32603, q.error_message.empty() ? "Task cancelled" : q.error_message); + if (q.state == TaskRegistry::ResultState::Failed) + return jsonrpc_error(id, -32603, + q.error_message.empty() ? "Task failed" : q.error_message); - return fastmcpp::Json{ - {"jsonrpc", "2.0"}, - {"id", id}, - {"result", info->result_payload}, - }; + return fastmcpp::Json{{"jsonrpc", "2.0"}, {"id", id}, {"result", q.payload}}; } if (method == "tasks/list") @@ -1308,7 +1772,8 @@ std::function make_mcp_handler(const Fast fastmcpp::Json tasks_array = fastmcpp::Json::array(); for (const auto& t : tasks->list_tasks()) { - fastmcpp::Json t_json = {{"taskId", t.task_id}, {"status", t.status}}; + fastmcpp::Json t_json = {{"taskId", t.task_id}, + {"status", mcp_status_from_internal(t.status)}}; if (!t.created_at.empty()) t_json["createdAt"] = t.created_at; if (!t.last_updated_at.empty()) @@ -1334,20 +1799,25 @@ std::function make_mcp_handler(const Fast if (task_id.empty()) return jsonrpc_error(id, -32602, "Missing taskId"); + if (!tasks->cancel(task_id)) + return jsonrpc_error(id, -32602, "Invalid taskId"); + auto info = tasks->get_task(task_id); if (!info) return jsonrpc_error(id, -32602, "Invalid taskId"); fastmcpp::Json result = { {"taskId", info->task_id}, - {"status", "cancelled"}, + {"status", mcp_status_from_internal(info->status)}, }; if (!info->created_at.empty()) result["createdAt"] = info->created_at; - result["lastUpdatedAt"] = to_iso8601_now(); + if (!info->last_updated_at.empty()) + result["lastUpdatedAt"] = info->last_updated_at; result["ttl"] = info->ttl_ms; result["pollInterval"] = 1000; - result["statusMessage"] = "Cancellation acknowledged (no-op for completed tasks)"; + if (!info->status_message.empty()) + result["statusMessage"] = info->status_message; return fastmcpp::Json{ {"jsonrpc", "2.0"}, @@ -1416,6 +1886,81 @@ std::function make_mcp_handler(const Fast "Task execution required for resource: " + uri); } + if (as_task) + { + auto created = + tasks->create_task("resource", uri, ttl_ms, extract_session_id(params)); + std::string task_id = created.task_id; + + fastmcpp::Json params_for_task = params; + if (params_for_task.contains("_meta") && + params_for_task["_meta"].is_object()) + params_for_task["_meta"].erase("modelcontextprotocol.io/task"); + + tasks->enqueue_task( + task_id, + [&app, uri, params_for_task]() mutable -> fastmcpp::Json + { + auto content = app.read_resource(uri, params_for_task); + fastmcpp::Json content_json = {{"uri", content.uri}}; + if (content.mime_type) + content_json["mimeType"] = *content.mime_type; + + if (std::holds_alternative(content.data)) + { + content_json["text"] = std::get(content.data); + } + else + { + const auto& binary = + std::get>(content.data); + static const char* b64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz012345" + "6789+/"; + std::string b64; + b64.reserve((binary.size() + 2) / 3 * 4); + for (size_t i = 0; i < binary.size(); i += 3) + { + uint32_t n = binary[i] << 16; + if (i + 1 < binary.size()) + n |= binary[i + 1] << 8; + if (i + 2 < binary.size()) + n |= binary[i + 2]; + b64.push_back(b64_chars[(n >> 18) & 0x3F]); + b64.push_back(b64_chars[(n >> 12) & 0x3F]); + b64.push_back((i + 1 < binary.size()) + ? b64_chars[(n >> 6) & 0x3F] + : '='); + b64.push_back((i + 2 < binary.size()) ? b64_chars[n & 0x3F] + : '='); + } + content_json["blob"] = b64; + } + + return fastmcpp::Json{ + {"contents", fastmcpp::Json::array({content_json})}}; + }); + + fastmcpp::Json task_meta = { + {"taskId", task_id}, + {"status", "working"}, + {"ttl", ttl_ms}, + {"createdAt", created.created_at}, + {"lastUpdatedAt", created.created_at}, + }; + + fastmcpp::Json response_result = { + {"contents", fastmcpp::Json::array()}, + {"_meta", + fastmcpp::Json{ + {"modelcontextprotocol.io/task", task_meta}, + }}, + }; + + return fastmcpp::Json{ + {"jsonrpc", "2.0"}, {"id", id}, {"result", response_result}}; + } + auto content = app.read_resource(uri, params); fastmcpp::Json content_json = {{"uri", content.uri}}; if (content.mime_type) @@ -1451,26 +1996,6 @@ std::function make_mcp_handler(const Fast fastmcpp::Json result_payload = fastmcpp::Json{{"contents", fastmcpp::Json::array({content_json})}}; - if (as_task) - { - std::string task_id = tasks->add_completed_task( - "resource", uri, std::move(result_payload), ttl_ms); - - fastmcpp::Json task_meta = { - {"taskId", task_id}, {"status", "completed"}, {"ttl", ttl_ms}}; - - fastmcpp::Json response_result = { - {"contents", fastmcpp::Json::array()}, - {"_meta", - fastmcpp::Json{ - {"modelcontextprotocol.io/task", task_meta}, - }}, - }; - - return fastmcpp::Json{ - {"jsonrpc", "2.0"}, {"id", id}, {"result", response_result}}; - } - return fastmcpp::Json{ {"jsonrpc", "2.0"}, {"id", id}, {"result", result_payload}}; } @@ -1491,20 +2016,23 @@ std::function make_mcp_handler(const Fast for (const auto& [name, prompt] : app.list_all_prompts()) { fastmcpp::Json prompt_json = {{"name", name}}; - if (prompt->description) - prompt_json["description"] = *prompt->description; - if (!prompt->arguments.empty()) + if (prompt) { - fastmcpp::Json args_array = fastmcpp::Json::array(); - for (const auto& arg : prompt->arguments) + if (prompt->description) + prompt_json["description"] = *prompt->description; + if (!prompt->arguments.empty()) { - fastmcpp::Json arg_json = {{"name", arg.name}, - {"required", arg.required}}; - if (arg.description) - arg_json["description"] = *arg.description; - args_array.push_back(arg_json); + fastmcpp::Json args_array = fastmcpp::Json::array(); + for (const auto& arg : prompt->arguments) + { + fastmcpp::Json arg_json = {{"name", arg.name}, + {"required", arg.required}}; + if (arg.description) + arg_json["description"] = *arg.description; + args_array.push_back(arg_json); + } + prompt_json["arguments"] = args_array; } - prompt_json["arguments"] = args_array; } prompts_array.push_back(prompt_json); } @@ -1534,26 +2062,43 @@ std::function make_mcp_handler(const Fast "Task execution required for prompt: " + name); } - fastmcpp::Json args = params.value("arguments", fastmcpp::Json::object()); - auto messages = app.get_prompt(name, args); - - fastmcpp::Json messages_array = fastmcpp::Json::array(); - for (const auto& msg : messages) - { - messages_array.push_back( - {{"role", msg.role}, - {"content", fastmcpp::Json{{"type", "text"}, {"text", msg.content}}}}); - } - - fastmcpp::Json result_payload = fastmcpp::Json{{"messages", messages_array}}; - if (as_task) { - std::string task_id = tasks->add_completed_task( - "prompt", name, std::move(result_payload), ttl_ms); + auto created = + tasks->create_task("prompt", name, ttl_ms, extract_session_id(params)); + std::string task_id = created.task_id; + + fastmcpp::Json args_for_task = + params.value("arguments", fastmcpp::Json::object()); + tasks->enqueue_task( + task_id, + [&app, name, args_for_task]() -> fastmcpp::Json + { + auto prompt_result = app.get_prompt_result(name, args_for_task); + fastmcpp::Json messages_array = fastmcpp::Json::array(); + for (const auto& msg : prompt_result.messages) + { + messages_array.push_back( + {{"role", msg.role}, + {"content", fastmcpp::Json{{"type", "text"}, + {"text", msg.content}}}}); + } + + fastmcpp::Json result_payload = {{"messages", messages_array}}; + if (prompt_result.description) + result_payload["description"] = *prompt_result.description; + if (prompt_result.meta) + result_payload["_meta"] = *prompt_result.meta; + return result_payload; + }); fastmcpp::Json task_meta = { - {"taskId", task_id}, {"status", "completed"}, {"ttl", ttl_ms}}; + {"taskId", task_id}, + {"status", "working"}, + {"ttl", ttl_ms}, + {"createdAt", created.created_at}, + {"lastUpdatedAt", created.created_at}, + }; fastmcpp::Json response_result = { {"messages", fastmcpp::Json::array()}, @@ -1567,6 +2112,23 @@ std::function make_mcp_handler(const Fast {"jsonrpc", "2.0"}, {"id", id}, {"result", response_result}}; } + fastmcpp::Json args = params.value("arguments", fastmcpp::Json::object()); + auto prompt_result = app.get_prompt_result(name, args); + + fastmcpp::Json messages_array = fastmcpp::Json::array(); + for (const auto& msg : prompt_result.messages) + { + messages_array.push_back( + {{"role", msg.role}, + {"content", fastmcpp::Json{{"type", "text"}, {"text", msg.content}}}}); + } + + fastmcpp::Json result_payload = {{"messages", messages_array}}; + if (prompt_result.description) + result_payload["description"] = *prompt_result.description; + if (prompt_result.meta) + result_payload["_meta"] = *prompt_result.meta; + return fastmcpp::Json{ {"jsonrpc", "2.0"}, {"id", id}, {"result", result_payload}}; } @@ -1949,14 +2511,6 @@ make_sampling_callback(std::shared_ptr session) }; } -// Extract session_id from request meta -static std::string extract_session_id(const fastmcpp::Json& params) -{ - if (params.contains("_meta") && params["_meta"].contains("session_id")) - return params["_meta"]["session_id"].get(); - return ""; -} - std::function make_mcp_handler_with_sampling(const FastMCP& app, SessionAccessor session_accessor) { @@ -2237,19 +2791,28 @@ make_mcp_handler_with_sampling(const FastMCP& app, SessionAccessor session_acces try { fastmcpp::Json args = params.value("arguments", fastmcpp::Json::object()); - auto messages = app.get_prompt(prompt_name, args); + auto prompt_result = app.get_prompt_result(prompt_name, args); fastmcpp::Json messages_array = fastmcpp::Json::array(); - for (const auto& msg : messages) + for (const auto& msg : prompt_result.messages) { messages_array.push_back( {{"role", msg.role}, {"content", fastmcpp::Json{{"type", "text"}, {"text", msg.content}}}}); } - return fastmcpp::Json{{"jsonrpc", "2.0"}, - {"id", id}, - {"result", fastmcpp::Json{{"messages", messages_array}}}}; + return fastmcpp::Json{ + {"jsonrpc", "2.0"}, + {"id", id}, + {"result", [&]() + { + fastmcpp::Json result = {{"messages", messages_array}}; + if (prompt_result.description) + result["description"] = *prompt_result.description; + if (prompt_result.meta) + result["_meta"] = *prompt_result.meta; + return result; + }()}}; } catch (const NotFoundError& e) { diff --git a/src/mcp/tasks.cpp b/src/mcp/tasks.cpp new file mode 100644 index 0000000..d142a3d --- /dev/null +++ b/src/mcp/tasks.cpp @@ -0,0 +1,42 @@ +#include "fastmcpp/mcp/tasks.hpp" + +#include + +namespace fastmcpp::mcp::tasks +{ +namespace +{ +struct TaskContext +{ + void* ctx{nullptr}; + detail::StatusMessageFn fn{nullptr}; + std::string task_id; +}; + +thread_local TaskContext tls_task; +} // namespace + +void report_status_message(const std::string& message) +{ + if (!tls_task.fn || tls_task.task_id.empty()) + return; + tls_task.fn(tls_task.ctx, tls_task.task_id, message); +} + +namespace detail +{ +void set_current_task(void* ctx, StatusMessageFn fn, std::string task_id) +{ + tls_task.ctx = ctx; + tls_task.fn = fn; + tls_task.task_id = std::move(task_id); +} + +void clear_current_task() +{ + tls_task.ctx = nullptr; + tls_task.fn = nullptr; + tls_task.task_id.clear(); +} +} // namespace detail +} // namespace fastmcpp::mcp::tasks diff --git a/src/server/elicitation.cpp b/src/server/elicitation.cpp index 7aa0edd..4212a14 100644 --- a/src/server/elicitation.cpp +++ b/src/server/elicitation.cpp @@ -175,6 +175,8 @@ void validate_elicitation_json_schema(const fastmcpp::Json& schema) "' has union type with missing 'type' which is not allowed."); std::string union_type = branch["type"].get(); + if (union_type == "null") + continue; if (allowed.count(union_type) == 0) { throw fastmcpp::ValidationError( diff --git a/src/server/sampling.cpp b/src/server/sampling.cpp new file mode 100644 index 0000000..bed6d0f --- /dev/null +++ b/src/server/sampling.cpp @@ -0,0 +1,256 @@ +#include "fastmcpp/server/sampling.hpp" + +#include "fastmcpp/exceptions.hpp" + +#include +#include + +namespace fastmcpp::server::sampling +{ +namespace +{ +fastmcpp::Json message_to_json(const Message& msg) +{ + return fastmcpp::Json{{"role", msg.role}, {"content", msg.content}}; +} + +fastmcpp::Json to_json_tools(const std::vector& tools) +{ + fastmcpp::Json arr = fastmcpp::Json::array(); + for (const auto& t : tools) + { + fastmcpp::Json tool = {{"name", t.name}, {"inputSchema", t.input_schema}}; + if (t.description) + tool["description"] = *t.description; + arr.push_back(std::move(tool)); + } + return arr; +} + +std::vector normalize_content_to_array(const fastmcpp::Json& content) +{ + if (content.is_array()) + return content.get>(); + if (content.is_object()) + return {content}; + return {}; +} + +std::optional extract_first_text_block(const fastmcpp::Json& content) +{ + for (const auto& block : normalize_content_to_array(content)) + { + if (!block.is_object()) + continue; + if (block.value("type", "") != "text") + continue; + if (block.contains("text") && block["text"].is_string()) + return block["text"].get(); + } + return std::nullopt; +} + +std::vector extract_tool_use_blocks(const fastmcpp::Json& content) +{ + std::vector calls; + for (const auto& block : normalize_content_to_array(content)) + { + if (!block.is_object()) + continue; + if (block.value("type", "") != "tool_use") + continue; + calls.push_back(block); + } + return calls; +} + +fastmcpp::Json make_tool_result_block(const std::string& tool_use_id, const std::string& text, + bool is_error) +{ + fastmcpp::Json content = + fastmcpp::Json::array({fastmcpp::Json{{"type", "text"}, {"text", text}}}); + fastmcpp::Json block = { + {"type", "tool_result"}, {"toolUseId", tool_use_id}, {"content", content}}; + if (is_error) + block["isError"] = true; + return block; +} + +std::string stringify_tool_result(const fastmcpp::Json& value) +{ + if (value.is_string()) + return value.get(); + if (value.is_null()) + return "null"; + return value.dump(); +} +} // namespace + +bool Step::is_tool_use() const +{ + return response.is_object() && response.value("stopReason", std::string()) == "toolUse"; +} + +std::optional Step::text() const +{ + if (!response.is_object() || !response.contains("content")) + return std::nullopt; + return extract_first_text_block(response["content"]); +} + +std::vector Step::tool_calls() const +{ + if (!response.is_object() || !response.contains("content")) + return {}; + return extract_tool_use_blocks(response["content"]); +} + +Step sample_step(std::shared_ptr session, const std::vector& messages, + const Options& options) +{ + if (!session) + throw std::runtime_error("sampling::sample_step: session is null"); + if (!session->supports_sampling()) + throw SamplingNotSupportedError("Client does not support sampling"); + + bool needs_tools = options.tools.has_value() && !options.tools->empty(); + if (needs_tools && !session->supports_sampling_tools()) + throw SamplingNotSupportedError( + "Client does not support sampling with tools. The client must advertise the " + "sampling.tools capability."); + + fastmcpp::Json params = fastmcpp::Json::object(); + params["messages"] = fastmcpp::Json::array(); + for (const auto& msg : messages) + params["messages"].push_back(message_to_json(msg)); + + if (options.system_prompt) + params["systemPrompt"] = *options.system_prompt; + if (options.temperature) + params["temperature"] = *options.temperature; + params["maxTokens"] = options.max_tokens; + + if (options.model_preferences) + params["modelPreferences"] = *options.model_preferences; + if (options.stop_sequences) + params["stopSequences"] = *options.stop_sequences; + if (options.metadata) + params["metadata"] = *options.metadata; + + if (needs_tools) + params["tools"] = to_json_tools(*options.tools); + + if (options.tool_choice && !options.tool_choice->empty()) + params["toolChoice"] = fastmcpp::Json{{"mode", *options.tool_choice}}; + + fastmcpp::Json response = + session->send_request("sampling/createMessage", params, options.timeout); + + Step step; + step.response = response; + step.history = messages; + + // Always append assistant response to history. + if (response.is_object() && response.contains("content")) + step.history.push_back(Message{"assistant", response["content"]}); + + if (!step.is_tool_use()) + return step; + if (!options.execute_tools) + return step; + + // Execute tool calls and append tool results message to history. + std::unordered_map tool_map; + if (needs_tools) + for (const auto& t : *options.tools) + tool_map.emplace(t.name, t); + + fastmcpp::Json tool_results = fastmcpp::Json::array(); + for (const auto& tool_use : step.tool_calls()) + { + std::string tool_use_id = tool_use.value("id", ""); + std::string name = tool_use.value("name", ""); + fastmcpp::Json input = + tool_use.contains("input") ? tool_use["input"] : fastmcpp::Json::object(); + + if (tool_use_id.empty() || name.empty()) + continue; + + auto it = tool_map.find(name); + if (it == tool_map.end()) + { + tool_results.push_back( + make_tool_result_block(tool_use_id, "Error: Unknown tool '" + name + "'", true)); + continue; + } + + try + { + fastmcpp::Json out = it->second.fn(input); + tool_results.push_back( + make_tool_result_block(tool_use_id, stringify_tool_result(out), false)); + } + catch (const fastmcpp::Error& e) + { + tool_results.push_back(make_tool_result_block( + tool_use_id, + options.mask_error_details + ? ("Error executing tool '" + name + "'") + : ("Error executing tool '" + name + "': " + std::string(e.what())), + true)); + } + catch (const std::exception& e) + { + tool_results.push_back(make_tool_result_block( + tool_use_id, + options.mask_error_details + ? ("Error executing tool '" + name + "'") + : ("Error executing tool '" + name + "': " + std::string(e.what())), + true)); + } + catch (...) + { + tool_results.push_back(make_tool_result_block( + tool_use_id, "Error executing tool '" + name + "': unknown error", true)); + } + } + + if (!tool_results.empty()) + step.history.push_back(Message{"user", tool_results}); + + return step; +} + +Result sample(std::shared_ptr session, const std::vector& messages, + Options options) +{ + if (options.max_iterations <= 0) + throw std::runtime_error("sampling::sample: max_iterations must be > 0"); + + std::vector current_messages = messages; + std::optional initial_tool_choice = options.tool_choice; + + for (int i = 0; i < options.max_iterations; ++i) + { + options.tool_choice = initial_tool_choice; + Step step = sample_step(session, current_messages, options); + + if (!step.is_tool_use()) + { + Result result; + result.text = step.text(); + result.response = std::move(step.response); + result.history = std::move(step.history); + return result; + } + + current_messages = std::move(step.history); + + // After first iteration, reset tool choice to default ("auto") per Python behavior. + initial_tool_choice.reset(); + } + + throw std::runtime_error("Sampling exceeded maximum iterations"); +} + +} // namespace fastmcpp::server::sampling diff --git a/src/server/sse_server.cpp b/src/server/sse_server.cpp index a072847..e625d30 100644 --- a/src/server/sse_server.cpp +++ b/src/server/sse_server.cpp @@ -4,15 +4,77 @@ #include "fastmcpp/util/json.hpp" #include +#include #include #include #include +#include #include #include namespace fastmcpp::server { +namespace +{ +struct TaskNotificationInfo +{ + std::string task_id; + std::string status{"completed"}; + int ttl_ms{60000}; + std::string created_at; + std::string last_updated_at; +}; + +std::string to_iso8601_now() +{ + using clock = std::chrono::system_clock; + auto now = clock::now(); + std::time_t t = clock::to_time_t(now); +#ifdef _WIN32 + std::tm tm; + gmtime_s(&tm, &t); +#else + std::tm tm; + gmtime_r(&t, &tm); +#endif + std::ostringstream oss; + oss << std::put_time(&tm, "%Y-%m-%dT%H:%M:%SZ"); + return oss.str(); +} + +std::optional extract_task_notification_info(const fastmcpp::Json& response) +{ + if (!response.is_object() || !response.contains("result") || !response["result"].is_object()) + return std::nullopt; + + const auto& result = response["result"]; + if (!result.contains("_meta") || !result["_meta"].is_object()) + return std::nullopt; + + const auto& meta = result["_meta"]; + auto it = meta.find("modelcontextprotocol.io/task"); + if (it == meta.end() || !it->is_object()) + return std::nullopt; + + const auto& task = *it; + if (!task.contains("taskId") || !task["taskId"].is_string()) + return std::nullopt; + + TaskNotificationInfo info; + info.task_id = task["taskId"].get(); + if (task.contains("status") && task["status"].is_string()) + info.status = task["status"].get(); + if (task.contains("ttl") && task["ttl"].is_number_integer()) + info.ttl_ms = task["ttl"].get(); + if (task.contains("createdAt") && task["createdAt"].is_string()) + info.created_at = task["createdAt"].get(); + if (task.contains("lastUpdatedAt") && task["lastUpdatedAt"].is_string()) + info.last_updated_at = task["lastUpdatedAt"].get(); + return info; +} +} // namespace + SseServerWrapper::SseServerWrapper(McpHandler handler, std::string host, int port, std::string sse_path, std::string message_path, std::string auth_token, std::string cors_origin) @@ -404,6 +466,28 @@ bool SseServerWrapper::start() // Normal request - process with handler auto response = handler_(message); + if (auto info = extract_task_notification_info(response)) + { + if (auto session = get_session(session_id)) + { + fastmcpp::Json created_meta = {{"modelcontextprotocol.io/related-task", + fastmcpp::Json{{"taskId", info->task_id}}}}; + session->send_notification("notifications/tasks/created", + fastmcpp::Json::object(), created_meta); + + std::string created_at = + info->created_at.empty() ? to_iso8601_now() : info->created_at; + std::string last_updated_at = + info->last_updated_at.empty() ? created_at : info->last_updated_at; + fastmcpp::Json status_params = { + {"taskId", info->task_id}, {"status", info->status}, + {"createdAt", created_at}, {"lastUpdatedAt", last_updated_at}, + {"ttl", info->ttl_ms}, {"pollInterval", 1000}, + }; + session->send_notification("notifications/tasks/status", status_params); + } + } + // Send response only to the requesting session send_event_to_session(session_id, response); diff --git a/src/server/stdio_server.cpp b/src/server/stdio_server.cpp index 13748d0..7112104 100644 --- a/src/server/stdio_server.cpp +++ b/src/server/stdio_server.cpp @@ -3,12 +3,77 @@ #include "fastmcpp/exceptions.hpp" #include "fastmcpp/util/json.hpp" +#include +#include +#include #include +#include +#include #include namespace fastmcpp::server { +namespace +{ +struct TaskNotificationInfo +{ + std::string task_id; + std::string status{"completed"}; + int ttl_ms{60000}; + std::string created_at; + std::string last_updated_at; +}; + +std::string to_iso8601_now() +{ + using clock = std::chrono::system_clock; + auto now = clock::now(); + std::time_t t = clock::to_time_t(now); +#ifdef _WIN32 + std::tm tm; + gmtime_s(&tm, &t); +#else + std::tm tm; + gmtime_r(&t, &tm); +#endif + std::ostringstream oss; + oss << std::put_time(&tm, "%Y-%m-%dT%H:%M:%SZ"); + return oss.str(); +} + +std::optional extract_task_notification_info(const fastmcpp::Json& response) +{ + if (!response.is_object() || !response.contains("result") || !response["result"].is_object()) + return std::nullopt; + + const auto& result = response["result"]; + if (!result.contains("_meta") || !result["_meta"].is_object()) + return std::nullopt; + + const auto& meta = result["_meta"]; + auto it = meta.find("modelcontextprotocol.io/task"); + if (it == meta.end() || !it->is_object()) + return std::nullopt; + + const auto& task = *it; + if (!task.contains("taskId") || !task["taskId"].is_string()) + return std::nullopt; + + TaskNotificationInfo info; + info.task_id = task["taskId"].get(); + if (task.contains("status") && task["status"].is_string()) + info.status = task["status"].get(); + if (task.contains("ttl") && task["ttl"].is_number_integer()) + info.ttl_ms = task["ttl"].get(); + if (task.contains("createdAt") && task["createdAt"].is_string()) + info.created_at = task["createdAt"].get(); + if (task.contains("lastUpdatedAt") && task["lastUpdatedAt"].is_string()) + info.last_updated_at = task["lastUpdatedAt"].get(); + return info; +} +} // namespace + StdioServerWrapper::StdioServerWrapper(McpHandler handler) : handler_(std::move(handler)) {} StdioServerWrapper::~StdioServerWrapper() @@ -34,6 +99,38 @@ void StdioServerWrapper::run_loop() // Process with handler auto response = handler_(request); + if (auto info = extract_task_notification_info(response)) + { + fastmcpp::Json created_meta = {{"modelcontextprotocol.io/related-task", + fastmcpp::Json{{"taskId", info->task_id}}}}; + + fastmcpp::Json created_notification = { + {"jsonrpc", "2.0"}, + {"method", "notifications/tasks/created"}, + {"params", fastmcpp::Json::object()}, + {"_meta", created_meta}, + }; + + std::string created_at = + info->created_at.empty() ? to_iso8601_now() : info->created_at; + std::string last_updated_at = + info->last_updated_at.empty() ? created_at : info->last_updated_at; + fastmcpp::Json status_params = { + {"taskId", info->task_id}, {"status", info->status}, + {"createdAt", created_at}, {"lastUpdatedAt", last_updated_at}, + {"ttl", info->ttl_ms}, {"pollInterval", 1000}, + }; + + fastmcpp::Json status_notification = { + {"jsonrpc", "2.0"}, + {"method", "notifications/tasks/status"}, + {"params", status_params}, + }; + + std::cout << created_notification.dump() << std::endl; + std::cout << status_notification.dump() << std::endl; + } + // Write JSON-RPC response to stdout (line-delimited) std::cout << response.dump() << std::endl; std::cout.flush(); diff --git a/tests/app/mounting.cpp b/tests/app/mounting.cpp index a939aba..d13570b 100644 --- a/tests/app/mounting.cpp +++ b/tests/app/mounting.cpp @@ -5,6 +5,10 @@ #include #include +#include +#ifdef _MSC_VER +#include +#endif using namespace fastmcpp; @@ -371,6 +375,38 @@ void test_no_prefix_mounting() std::cout << " PASSED" << std::endl; } +void test_tool_name_overrides_direct() +{ + std::cout << "test_tool_name_overrides_direct..." << std::endl; + + FastMCP main_app("MainApp", "1.0.0"); + FastMCP child_app("ChildApp", "1.0.0"); + + child_app.tools().register_tool(make_echo_tool("echo")); + + main_app.mount(child_app, "child", false, + std::unordered_map{{"echo", "say"}}); + + auto all_tools = main_app.list_all_tools(); + + bool found_say = false; + bool found_child_echo = false; + for (const auto& [name, _] : all_tools) + { + if (name == "say") + found_say = true; + if (name == "child_echo") + found_child_echo = true; + } + assert(found_say); + assert(!found_child_echo); + + auto result = main_app.invoke_tool("say", Json{{"message", "override"}}); + assert(result.get() == "override"); + + std::cout << " PASSED" << std::endl; +} + void test_mcp_handler_integration() { std::cout << "test_mcp_handler_integration..." << std::endl; @@ -382,6 +418,14 @@ void test_mcp_handler_integration() main_app.tools().register_tool(make_add_tool()); child_app.tools().register_tool(make_echo_tool("echo")); + // Register a prompt with meta on the main app + { + auto p = make_prompt("demo", "Hello from prompt!"); + p.description = "Prompt description"; + p.meta = Json{{"source", "mounting_test"}}; + main_app.prompts().register_prompt(p); + } + // Mount child main_app.mount(child_app, "child"); @@ -416,6 +460,18 @@ void test_mcp_handler_integration() assert(call_response.contains("result")); assert(call_response["result"]["content"][0]["text"] == "hello via handler"); + // Test prompts/get - should include description and _meta + auto prompt_response = + handler(Json{{"jsonrpc", "2.0"}, + {"id", 4}, + {"method", "prompts/get"}, + {"params", Json{{"name", "demo"}, {"arguments", Json::object()}}}}); + assert(prompt_response.contains("result")); + assert(prompt_response["result"].contains("description")); + assert(prompt_response["result"]["description"] == "Prompt description"); + assert(prompt_response["result"].contains("_meta")); + assert(prompt_response["result"]["_meta"]["source"] == "mounting_test"); + std::cout << " PASSED" << std::endl; } @@ -532,6 +588,38 @@ void test_proxy_mode_tool_routing() std::cout << " PASSED" << std::endl; } +void test_tool_name_overrides_proxy() +{ + std::cout << "test_tool_name_overrides_proxy..." << std::endl; + + FastMCP main_app("MainApp", "1.0.0"); + FastMCP child_app("ChildApp", "1.0.0"); + + child_app.tools().register_tool(make_echo_tool("echo")); + + main_app.mount(child_app, "child", true, + std::unordered_map{{"echo", "say"}}); + + auto all_tools = main_app.list_all_tools(); + + bool found_say = false; + bool found_child_echo = false; + for (const auto& [name, _] : all_tools) + { + if (name == "say") + found_say = true; + if (name == "child_echo") + found_child_echo = true; + } + assert(found_say); + assert(!found_child_echo); + + auto result = main_app.invoke_tool("say", Json{{"message", "override via proxy"}}); + assert(result.get() == "override via proxy"); + + std::cout << " PASSED" << std::endl; +} + void test_proxy_mode_resource_aggregation() { std::cout << "test_proxy_mode_resource_aggregation..." << std::endl; @@ -727,6 +815,19 @@ void test_proxy_mode_mcp_handler() int main() { +#ifdef _MSC_VER +#ifdef _DEBUG + // Avoid modal "Abort/Retry/Ignore" dialogs on assertion failures when running tests directly. + // Route CRT reports to stderr so ctest/CI logs capture details. + _CrtSetReportMode(_CRT_ASSERT, _CRTDBG_MODE_FILE); + _CrtSetReportFile(_CRT_ASSERT, _CRTDBG_FILE_STDERR); + _CrtSetReportMode(_CRT_ERROR, _CRTDBG_MODE_FILE); + _CrtSetReportFile(_CRT_ERROR, _CRTDBG_FILE_STDERR); + _CrtSetReportMode(_CRT_WARN, _CRTDBG_MODE_FILE); + _CrtSetReportFile(_CRT_WARN, _CRTDBG_FILE_STDERR); +#endif +#endif + std::cout << "=== FastMCP Mounting Tests ===" << std::endl; test_basic_app(); @@ -739,6 +840,7 @@ int main() test_prompt_routing(); test_nested_mounting(); test_no_prefix_mounting(); + test_tool_name_overrides_direct(); test_mcp_handler_integration(); test_multiple_mounts(); @@ -747,6 +849,7 @@ int main() test_proxy_mode_basic(); test_proxy_mode_tool_aggregation(); test_proxy_mode_tool_routing(); + test_tool_name_overrides_proxy(); test_proxy_mode_resource_aggregation(); test_proxy_mode_resource_routing(); test_proxy_mode_prompt_aggregation(); diff --git a/tests/client/tasks.cpp b/tests/client/tasks.cpp index 55cd755..83477cf 100644 --- a/tests/client/tasks.cpp +++ b/tests/client/tasks.cpp @@ -5,6 +5,9 @@ #include "fastmcpp/mcp/handler.hpp" #include "test_helpers.hpp" +#include +#include + using namespace fastmcpp; void test_call_tool_task_immediate() @@ -90,7 +93,10 @@ void test_call_tool_task_with_server_tasks() auto status = task->status(); assert(!status.taskId.empty()); - // Our minimal implementation marks tasks as completed immediately + assert(status.status == "working" || status.status == "completed" || + status.status == "failed" || status.status == "cancelled"); + + status = task->wait(); assert(status.status == "completed"); auto result = task->result(); @@ -144,7 +150,7 @@ void test_prompt_and_resource_tasks_with_server_tasks() auto prompt_task = c.get_prompt_task("greeting", Json{{"name", "Alice"}}, 60000); assert(prompt_task); assert(!prompt_task->returned_immediately()); - auto prompt_status = prompt_task->status(); + auto prompt_status = prompt_task->wait(); assert(prompt_status.status == "completed"); auto prompt_result = prompt_task->result(); assert(!prompt_result.messages.empty()); @@ -153,7 +159,7 @@ void test_prompt_and_resource_tasks_with_server_tasks() auto resource_task = c.read_resource_task("mem://hello", 60000); assert(resource_task); assert(!resource_task->returned_immediately()); - auto resource_status = resource_task->status(); + auto resource_status = resource_task->wait(); assert(resource_status.status == "completed"); auto contents = resource_task->result(); assert(!contents.empty()); @@ -163,6 +169,55 @@ void test_prompt_and_resource_tasks_with_server_tasks() std::cout << " [PASS] Prompt and resource tasks work with FastMCP handler\n"; } +void test_cancel_task() +{ + std::cout << "Test 6: tasks/cancel cancels in-flight work...\n"; + + FastMCP app("task-cancel-app", "1.0.0"); + + Json input_schema = {{"type", "object"}, + {"properties", Json::object({{"ms", Json{{"type", "integer"}}}})}}; + + tools::Tool slow_tool{"slow", input_schema, Json{{"type", "string"}}, [](const Json& in) -> Json + { + int ms = in.value("ms", 1500); + std::this_thread::sleep_for(std::chrono::milliseconds(ms)); + return Json("done"); + }}; + slow_tool.set_task_support(TaskSupport::Optional); + app.tools().register_tool(slow_tool); + + auto handler = mcp::make_mcp_handler(app); + client::Client c(std::make_unique(std::move(handler))); + + auto task = c.call_tool_task("slow", Json{{"ms", 1500}}, 60000); + assert(task); + assert(!task->returned_immediately()); + + // Best-effort: give the background worker a moment to start. + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + auto cancelled = c.cancel_task(task->task_id()); + assert(cancelled.taskId == task->task_id()); + assert(cancelled.status == "cancelled" || cancelled.status == "completed"); + + if (cancelled.status == "cancelled") + { + bool threw = false; + try + { + (void)c.get_task_result_raw(task->task_id()); + } + catch (const fastmcpp::Error&) + { + threw = true; + } + assert(threw); + } + + std::cout << " [PASS] tasks/cancel reports cancelled for background task\n"; +} + void test_task_support_execution_and_capabilities() { std::cout << "Test 5: TaskSupport enforcement + execution/capabilities...\n"; @@ -374,7 +429,8 @@ int main() test_call_tool_task_with_server_tasks(); test_prompt_and_resource_tasks_with_server_tasks(); test_task_support_execution_and_capabilities(); - std::cout << "\n[OK] Client Task API tests passed! (5 tests)\n"; + test_cancel_task(); + std::cout << "\n[OK] Client Task API tests passed! (6 tests)\n"; return 0; } catch (const std::exception& e) diff --git a/tests/server/sse_tasks_notifications.cpp b/tests/server/sse_tasks_notifications.cpp new file mode 100644 index 0000000..0740ea8 --- /dev/null +++ b/tests/server/sse_tasks_notifications.cpp @@ -0,0 +1,472 @@ +// SSE task notifications test (SEP-1686 subset) +// +// Validates that when a client requests task execution via params._meta +// (modelcontextprotocol.io/task), the server emits: +// - notifications/tasks/created (with taskId in top-level _meta.related-task) +// - notifications/tasks/status (initial + terminal status in params) +// +// Transport emits created/initial status; handler emits terminal status when session access is +// configured. + +#include "fastmcpp/app.hpp" +#include "fastmcpp/mcp/handler.hpp" +#include "fastmcpp/mcp/tasks.hpp" +#include "fastmcpp/server/sse_server.hpp" +#include "fastmcpp/tools/manager.hpp" +#include "fastmcpp/util/json.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using fastmcpp::FastMCP; +using fastmcpp::Json; +using fastmcpp::TaskSupport; +using fastmcpp::server::SseServerWrapper; + +namespace +{ +struct Captured +{ + std::string session_id; + std::vector messages; +}; + +bool has_method(const Json& msg, const std::string& method) +{ + return msg.is_object() && msg.contains("method") && msg["method"].is_string() && + msg["method"].get() == method; +} + +std::optional find_first_by_method(const std::vector& messages, + const std::string& method) +{ + for (const auto& m : messages) + if (has_method(m, method)) + return m; + return std::nullopt; +} + +std::optional find_first_by_id(const std::vector& messages, int id) +{ + for (const auto& m : messages) + { + if (!m.is_object() || !m.contains("id")) + continue; + const auto& mid = m["id"]; + if (mid.is_number_integer() && mid.get() == id) + return m; + } + return std::nullopt; +} + +std::string extract_task_id_from_response(const Json& response) +{ + if (!response.contains("result") || !response["result"].is_object()) + return {}; + const auto& result = response["result"]; + if (!result.contains("_meta") || !result["_meta"].is_object()) + return {}; + const auto& meta = result["_meta"]; + auto it = meta.find("modelcontextprotocol.io/task"); + if (it == meta.end() || !it->is_object()) + return {}; + const auto& task = *it; + if (!task.contains("taskId") || !task["taskId"].is_string()) + return {}; + return task["taskId"].get(); +} + +std::optional find_task_status(const std::vector& messages, const std::string& task_id, + const std::string& status) +{ + for (const auto& m : messages) + { + if (!has_method(m, "notifications/tasks/status")) + continue; + if (!m.contains("params") || !m["params"].is_object()) + continue; + const auto& params = m["params"]; + if (!params.contains("taskId") || !params["taskId"].is_string()) + continue; + if (params["taskId"].get() != task_id) + continue; + if (!params.contains("status") || !params["status"].is_string()) + continue; + if (params["status"].get() != status) + continue; + return m; + } + return std::nullopt; +} + +std::optional find_task_status_message(const std::vector& messages, + const std::string& task_id, + const std::string& substring) +{ + for (const auto& m : messages) + { + if (!has_method(m, "notifications/tasks/status")) + continue; + if (!m.contains("params") || !m["params"].is_object()) + continue; + const auto& params = m["params"]; + if (!params.contains("taskId") || !params["taskId"].is_string()) + continue; + if (params["taskId"].get() != task_id) + continue; + if (!params.contains("statusMessage") || !params["statusMessage"].is_string()) + continue; + const auto msg = params["statusMessage"].get(); + if (msg.find(substring) == std::string::npos) + continue; + return m; + } + return std::nullopt; +} +} // namespace + +int main() +{ + std::cout << "=== SSE tasks notifications test ===\n\n"; + + // Build a FastMCP app with a task-capable tool + FastMCP app("tasks-notify-app", "1.0.0"); + Json input_schema = {{"type", "object"}, + {"properties", Json::object({{"a", Json{{"type", "number"}}}, + {"b", Json{{"type", "number"}}}})}}; + + fastmcpp::tools::Tool add_tool{"add", input_schema, Json{{"type", "number"}}, [](const Json& in) + { + fastmcpp::mcp::tasks::report_status_message("starting"); + double a = in.at("a").get(); + double b = in.at("b").get(); + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + fastmcpp::mcp::tasks::report_status_message("done"); + return Json(a + b); + }}; + add_tool.set_task_support(TaskSupport::Optional); + app.tools().register_tool(add_tool); + + SseServerWrapper* server_ptr = nullptr; + auto handler = fastmcpp::mcp::make_mcp_handler( + app, [&server_ptr](const std::string& session_id) + { return server_ptr ? server_ptr->get_session(session_id) : nullptr; }); + + // Start SSE server + const int port = 18109; + SseServerWrapper server(handler, "127.0.0.1", port, "/sse", "/messages"); + server_ptr = &server; + if (!server.start()) + { + std::cerr << "[FAIL] Failed to start SSE server\n"; + return 1; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + + std::atomic sse_connected{false}; + std::atomic stop_capturing{false}; + std::mutex m; + std::condition_variable cv; + Captured captured; + std::string buffer; + + // Start SSE connection in background thread + std::thread sse_thread( + [&]() + { + httplib::Client sse_client("127.0.0.1", port); + sse_client.set_read_timeout(std::chrono::seconds(20)); + sse_client.set_connection_timeout(std::chrono::seconds(5)); + sse_client.set_keep_alive(true); + sse_client.set_default_headers({{"Accept", "text/event-stream"}}); + + auto receiver = [&](const char* data, size_t len) + { + sse_connected = true; + buffer.append(data, len); + + // Parse SSE events separated by blank line + for (;;) + { + size_t sep = buffer.find("\n\n"); + if (sep == std::string::npos) + break; + + std::string event = buffer.substr(0, sep); + buffer.erase(0, sep + 2); + + std::string event_type; + std::string data_line; + + size_t pos = 0; + while (pos < event.size()) + { + size_t eol = event.find('\n', pos); + if (eol == std::string::npos) + eol = event.size(); + std::string line = event.substr(pos, eol - pos); + pos = (eol < event.size()) ? (eol + 1) : eol; + + if (line.rfind("event: ", 0) == 0) + event_type = line.substr(7); + else if (line.rfind("data: ", 0) == 0) + data_line = line.substr(6); + } + + // Endpoint event includes session_id in the data payload + if (event_type == "endpoint") + { + size_t sid_pos = data_line.find("session_id="); + if (sid_pos != std::string::npos) + { + std::string sid = data_line.substr(sid_pos + 11); + std::lock_guard lock(m); + captured.session_id = sid; + cv.notify_all(); + } + continue; + } + + if (!data_line.empty()) + { + try + { + Json msg = Json::parse(data_line); + std::lock_guard lock(m); + captured.messages.push_back(std::move(msg)); + cv.notify_all(); + } + catch (...) + { + // Ignore non-JSON SSE data lines. + } + } + } + + return !stop_capturing.load(); + }; + + auto res = sse_client.Get("/sse", receiver); + (void)res; + }); + + // Wait for SSE connection + for (int i = 0; i < 500 && !sse_connected; ++i) + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + if (!sse_connected) + { + std::cerr << "[FAIL] SSE connection failed to establish\n"; + stop_capturing = true; + server.stop(); + if (sse_thread.joinable()) + sse_thread.detach(); + return 1; + } + + // Wait for session_id + { + std::unique_lock lock(m); + cv.wait_for(lock, std::chrono::seconds(5), [&] { return !captured.session_id.empty(); }); + } + if (captured.session_id.empty()) + { + std::cerr << "[FAIL] Failed to extract session_id from SSE endpoint\n"; + stop_capturing = true; + server.stop(); + if (sse_thread.joinable()) + sse_thread.detach(); + return 1; + } + + // Send initialize + tools/call with task meta + httplib::Client post_client("127.0.0.1", port); + post_client.set_connection_timeout(std::chrono::seconds(5)); + post_client.set_read_timeout(std::chrono::seconds(5)); + + std::string post_url = "/messages?session_id=" + captured.session_id; + + Json init_request = {{"jsonrpc", "2.0"}, + {"id", 1}, + {"method", "initialize"}, + {"params", + {{"protocolVersion", "2024-11-05"}, + {"capabilities", Json::object()}, + {"clientInfo", {{"name", "test_client"}, {"version", "1.0.0"}}}}}}; + + auto init_res = post_client.Post(post_url, init_request.dump(), "application/json"); + if (!init_res || init_res->status != 200) + { + std::cerr << "[FAIL] initialize POST failed\n"; + stop_capturing = true; + server.stop(); + if (sse_thread.joinable()) + sse_thread.detach(); + return 1; + } + + Json call_request = {{"jsonrpc", "2.0"}, + {"id", 2}, + {"method", "tools/call"}, + {"params", + {{"name", "add"}, + {"arguments", {{"a", 2}, {"b", 3}}}, + {"_meta", {{"modelcontextprotocol.io/task", {{"ttl", 60000}}}}}}}}; + + auto call_res = post_client.Post(post_url, call_request.dump(), "application/json"); + if (!call_res || call_res->status != 200) + { + std::cerr << "[FAIL] tools/call POST failed\n"; + stop_capturing = true; + server.stop(); + if (sse_thread.joinable()) + sse_thread.detach(); + return 1; + } + + // Wait until we see the expected messages. + // We expect: created notification + initial status notification + tools/call response. + { + std::unique_lock lock(m); + cv.wait_for(lock, std::chrono::seconds(5), + [&] + { + auto created = + find_first_by_method(captured.messages, "notifications/tasks/created"); + auto status = + find_first_by_method(captured.messages, "notifications/tasks/status"); + auto resp = find_first_by_id(captured.messages, 2); + return created.has_value() && status.has_value() && resp.has_value(); + }); + } + + std::optional created; + std::optional status; + std::optional response; + { + std::lock_guard lock(m); + created = find_first_by_method(captured.messages, "notifications/tasks/created"); + status = find_first_by_method(captured.messages, "notifications/tasks/status"); + response = find_first_by_id(captured.messages, 2); + } + + if (!created || !status || !response) + { + std::cerr << "[FAIL] Missing expected task notifications/response\n"; + stop_capturing = true; + server.stop(); + if (sse_thread.joinable()) + sse_thread.join(); + return 1; + } + + std::string task_id = extract_task_id_from_response(*response); + if (task_id.empty()) + { + std::cerr << "[FAIL] tools/call response missing taskId in result._meta\n"; + stop_capturing = true; + server.stop(); + if (sse_thread.joinable()) + sse_thread.join(); + return 1; + } + + // Wait for a terminal status notification pushed by the handler. + { + std::unique_lock lock(m); + cv.wait_for( + lock, std::chrono::seconds(10), + [&] { return find_task_status(captured.messages, task_id, "completed").has_value(); }); + } + + // Best-effort: verify we saw at least one statusMessage update while working. + { + std::unique_lock lock(m); + cv.wait_for( + lock, std::chrono::seconds(10), + [&] + { + return find_task_status_message(captured.messages, task_id, "starting").has_value(); + }); + } + + stop_capturing = true; + server.stop(); + if (sse_thread.joinable()) + sse_thread.join(); + + // Validate created notification: taskId lives in top-level _meta.related-task + if (!created->contains("_meta") || !(*created)["_meta"].is_object()) + { + std::cerr << "[FAIL] notifications/tasks/created missing top-level _meta\n"; + return 1; + } + const auto& cmeta = (*created)["_meta"]; + if (!cmeta.contains("modelcontextprotocol.io/related-task") || + !cmeta["modelcontextprotocol.io/related-task"].is_object()) + { + std::cerr << "[FAIL] notifications/tasks/created missing related-task metadata\n"; + return 1; + } + const auto& related = cmeta["modelcontextprotocol.io/related-task"]; + if (!related.contains("taskId") || !related["taskId"].is_string() || + related["taskId"].get() != task_id) + { + std::cerr << "[FAIL] notifications/tasks/created taskId mismatch\n"; + return 1; + } + + // Validate status notification: taskId in params + if (!status->contains("params") || !(*status)["params"].is_object()) + { + std::cerr << "[FAIL] notifications/tasks/status missing params\n"; + return 1; + } + const auto& sparams = (*status)["params"]; + if (!sparams.contains("taskId") || !sparams["taskId"].is_string() || + sparams["taskId"].get() != task_id) + { + std::cerr << "[FAIL] notifications/tasks/status taskId mismatch\n"; + return 1; + } + if (!sparams.contains("status") || !sparams["status"].is_string()) + { + std::cerr << "[FAIL] notifications/tasks/status missing status\n"; + return 1; + } + + // Validate terminal status push + { + std::optional terminal; + std::lock_guard lock(m); + terminal = find_task_status(captured.messages, task_id, "completed"); + if (!terminal) + { + std::cerr << "[FAIL] Missing terminal notifications/tasks/status (completed)\n"; + return 1; + } + } + + // Validate we saw at least one statusMessage update while working + { + std::optional progress; + std::lock_guard lock(m); + progress = find_task_status_message(captured.messages, task_id, "starting"); + if (!progress) + { + std::cerr + << "[FAIL] Missing non-terminal notifications/tasks/status statusMessage update\n"; + return 1; + } + } + + std::cout << "[OK] tasks notifications emitted (created + status + completion push)\n"; + return 0; +} diff --git a/tests/server/streamable_http_integration.cpp b/tests/server/streamable_http_integration.cpp index 880fcaf..c18f26a 100644 --- a/tests/server/streamable_http_integration.cpp +++ b/tests/server/streamable_http_integration.cpp @@ -13,9 +13,11 @@ #include "fastmcpp/mcp/handler.hpp" #include "fastmcpp/server/streamable_http_server.hpp" #include "fastmcpp/tools/manager.hpp" +#include "fastmcpp/util/json.hpp" #include #include +#include #include #include @@ -107,6 +109,69 @@ void test_basic_request_response() server.stop(); } +void test_redirect_follow() +{ + std::cout << " test_redirect_follow... " << std::flush; + + const int port = 18354; + const std::string host = "127.0.0.1"; + + httplib::Server svr; + svr.Post("/mcp", + [&](const httplib::Request&, httplib::Response& res) + { + res.status = 307; + res.set_header("Location", "/real_mcp"); + }); + + svr.Post("/real_mcp", + [&](const httplib::Request& req, httplib::Response& res) + { + Json rpc_request = fastmcpp::util::json::parse(req.body); + Json id = rpc_request.value("id", Json()); + Json result = + Json{{"serverInfo", Json{{"name", "redirected"}, {"version", "1.0"}}}}; + Json rpc_response = Json{{"jsonrpc", "2.0"}, {"id", id}, {"result", result}}; + + res.status = 200; + res.set_header("Mcp-Session-Id", "redirect-session"); + res.set_content(rpc_response.dump(), "application/json"); + }); + + std::thread th([&]() { svr.listen(host, port); }); + svr.wait_until_ready(); + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + + try + { + client::StreamableHttpTransport transport("http://" + host + ":" + std::to_string(port)); + + Json init_params = {{"protocolVersion", "2024-11-05"}, + {"capabilities", Json::object()}, + {"clientInfo", {{"name", "test_client"}, {"version", "1.0.0"}}}}; + + auto init_result = transport.request("initialize", init_params); + assert(init_result.contains("serverInfo") && "Should have serverInfo"); + assert(init_result["serverInfo"]["name"] == "redirected"); + assert(transport.has_session() && "Should have session after initialize"); + assert(transport.session_id() == "redirect-session"); + + std::cout << "PASSED\n"; + } + catch (const std::exception& e) + { + std::cout << "FAILED: " << e.what() << "\n"; + svr.stop(); + if (th.joinable()) + th.join(); + throw; + } + + svr.stop(); + if (th.joinable()) + th.join(); +} + void test_session_management() { std::cout << " test_session_management... " << std::flush; @@ -314,6 +379,7 @@ int main() try { test_basic_request_response(); + test_redirect_follow(); test_session_management(); test_server_info(); test_error_handling(); diff --git a/tests/server/test_elicitation_defaults.cpp b/tests/server/test_elicitation_defaults.cpp index 8ee0451..6cc7d1f 100644 --- a/tests/server/test_elicitation_defaults.cpp +++ b/tests/server/test_elicitation_defaults.cpp @@ -174,6 +174,109 @@ static void test_mixed_defaults_and_required() std::cout << "PASSED\n"; } +static void test_nullable_fields_not_required() +{ + std::cout << " test_nullable_fields_not_required... " << std::flush; + + Json props = Json::object(); + props["maybe_name"] = Json{{"type", "string"}, {"nullable", true}}; + props["age"] = Json{{"type", "integer"}}; + + Json schema = Json{{"type", "object"}, {"properties", props}}; + Json result = get_elicitation_schema(schema); + + const Json& required = result.contains("required") && result["required"].is_array() + ? result["required"] + : Json::array(); + + bool has_age = false; + bool has_maybe_name = false; + for (const auto& v : required) + { + if (!v.is_string()) + continue; + std::string name = v.get(); + if (name == "age") + has_age = true; + if (name == "maybe_name") + has_maybe_name = true; + } + + assert(has_age); + assert(!has_maybe_name); + + std::cout << "PASSED\n"; +} + +static void test_type_array_allows_null_not_required() +{ + std::cout << " test_type_array_allows_null_not_required... " << std::flush; + + Json props = Json::object(); + props["nickname"] = Json{{"type", Json::array({"string", "null"})}}; + props["age"] = Json{{"type", "integer"}}; + + Json schema = Json{{"type", "object"}, {"properties", props}}; + Json result = get_elicitation_schema(schema); + + const Json& required = result.contains("required") && result["required"].is_array() + ? result["required"] + : Json::array(); + + bool has_age = false; + bool has_nickname = false; + for (const auto& v : required) + { + if (!v.is_string()) + continue; + std::string name = v.get(); + if (name == "age") + has_age = true; + if (name == "nickname") + has_nickname = true; + } + + assert(has_age); + assert(!has_nickname); + + std::cout << "PASSED\n"; +} + +static void test_anyof_null_not_required() +{ + std::cout << " test_anyof_null_not_required... " << std::flush; + + Json props = Json::object(); + props["maybe"] = + Json{{"anyOf", Json::array({Json{{"type", "string"}}, Json{{"type", "null"}}})}}; + props["age"] = Json{{"type", "integer"}}; + + Json schema = Json{{"type", "object"}, {"properties", props}}; + Json result = get_elicitation_schema(schema); + + const Json& required = result.contains("required") && result["required"].is_array() + ? result["required"] + : Json::array(); + + bool has_age = false; + bool has_maybe = false; + for (const auto& v : required) + { + if (!v.is_string()) + continue; + std::string name = v.get(); + if (name == "age") + has_age = true; + if (name == "maybe") + has_maybe = true; + } + + assert(has_age); + assert(!has_maybe); + + std::cout << "PASSED\n"; +} + static void test_compress_schema_preserves_defaults() { std::cout << " test_compress_schema_preserves_defaults... " << std::flush; @@ -267,6 +370,9 @@ int main() test_enum_default_preserved(); test_all_defaults_preserved_together(); test_mixed_defaults_and_required(); + test_nullable_fields_not_required(); + test_type_array_allows_null_not_required(); + test_anyof_null_not_required(); test_compress_schema_preserves_defaults(); test_context_elicit_uses_schema_helper(); diff --git a/tests/server/test_sampling_tools.cpp b/tests/server/test_sampling_tools.cpp new file mode 100644 index 0000000..6d236a8 --- /dev/null +++ b/tests/server/test_sampling_tools.cpp @@ -0,0 +1,144 @@ +/// @file test_sampling_tools.cpp +/// @brief Tests for SEP-1577 sampling-with-tools helpers + +#include "fastmcpp/server/sampling.hpp" +#include "fastmcpp/server/session.hpp" + +#include +#include +#include +#include +#include + +using fastmcpp::Json; +using fastmcpp::server::ServerSession; + +namespace sampling = fastmcpp::server::sampling; + +void test_sampling_tools_loop_executes_tool_and_returns_text() +{ + std::cout << " test_sampling_tools_loop_executes_tool_and_returns_text... " << std::flush; + + std::shared_ptr session; + std::shared_ptr* session_ptr = &session; + + int request_count = 0; + bool add_called = false; + int add_a = 0; + int add_b = 0; + + session = std::make_shared( + "sess_tools", + [&](const Json& request) + { + assert(ServerSession::is_request(request)); + assert(request.value("method", "") == "sampling/createMessage"); + assert(request.contains("id")); + assert(request.contains("params")); + + const auto& params = request["params"]; + assert(params.contains("messages")); + assert(params["messages"].is_array()); + + ++request_count; + + Json result; + if (request_count == 1) + { + // First request should include tools. + assert(params.contains("tools")); + assert(params["tools"].is_array()); + bool saw_add = false; + for (const auto& tool : params["tools"]) + if (tool.is_object() && tool.value("name", "") == "add") + saw_add = true; + assert(saw_add); + + result = + Json{{"role", "assistant"}, + {"model", "mock-model"}, + {"stopReason", "toolUse"}, + {"content", Json::array({Json{{"type", "tool_use"}, + {"id", "toolu_1"}, + {"name", "add"}, + {"input", Json{{"a", 2}, {"b", 3}}}}})}}; + } + else if (request_count == 2) + { + // Second request should include a tool_result message in history. + bool saw_tool_result = false; + for (const auto& msg : params["messages"]) + { + if (!msg.is_object() || msg.value("role", "") != "user") + continue; + if (!msg.contains("content")) + continue; + const auto& content = msg["content"]; + if (!content.is_array()) + continue; + for (const auto& block : content) + { + if (block.is_object() && block.value("type", "") == "tool_result" && + block.value("toolUseId", "") == "toolu_1") + { + saw_tool_result = true; + } + } + } + assert(saw_tool_result); + + result = Json{{"role", "assistant"}, + {"model", "mock-model"}, + {"stopReason", "endTurn"}, + {"content", Json{{"type", "text"}, {"text", "Result: 5"}}}}; + } + else + { + assert(false && "Unexpected sampling request count"); + } + + Json response = {{"jsonrpc", "2.0"}, {"id", request["id"]}, {"result", result}}; + (void)(*session_ptr)->handle_response(response); + }); + + session->set_capabilities(Json{{"sampling", Json{{"tools", Json::object()}}}}); + assert(session->supports_sampling()); + assert(session->supports_sampling_tools()); + + sampling::Tool add_tool; + add_tool.name = "add"; + add_tool.description = "Add two numbers"; + add_tool.input_schema = Json{ + {"type", "object"}, + {"properties", Json{{"a", Json{{"type", "integer"}}}, {"b", Json{{"type", "integer"}}}}}, + {"required", Json::array({"a", "b"})}}; + add_tool.fn = [&](const Json& input) -> Json + { + add_called = true; + add_a = input.value("a", 0); + add_b = input.value("b", 0); + return Json(add_a + add_b); + }; + + sampling::Options opts; + opts.max_tokens = 64; + opts.tools = std::vector{add_tool}; + opts.tool_choice = std::string("auto"); + + auto result = + sampling::sample(session, {sampling::make_text_message("user", "Compute 2+3")}, opts); + assert(add_called); + assert(add_a == 2); + assert(add_b == 3); + assert(result.text.has_value()); + assert(result.text->find("Result: 5") != std::string::npos); + + std::cout << "PASSED\n"; +} + +int main() +{ + std::cout << "=== sampling tools tests ===\n\n"; + test_sampling_tools_loop_executes_tool_and_returns_text(); + return 0; +} diff --git a/tests/server/test_server_session.cpp b/tests/server/test_server_session.cpp index d868f87..e199b95 100644 --- a/tests/server/test_server_session.cpp +++ b/tests/server/test_server_session.cpp @@ -21,6 +21,7 @@ void test_session_creation() assert(session.session_id() == "sess_123"); assert(!session.supports_sampling()); + assert(!session.supports_sampling_tools()); assert(!session.supports_elicitation()); assert(!session.supports_roots()); @@ -35,6 +36,7 @@ void test_set_capabilities() // No capabilities initially assert(!session.supports_sampling()); + assert(!session.supports_sampling_tools()); assert(!session.supports_elicitation()); // Set capabilities @@ -42,13 +44,20 @@ void test_set_capabilities() session.set_capabilities(caps); assert(session.supports_sampling()); + assert(!session.supports_sampling_tools()); assert(!session.supports_elicitation()); assert(session.supports_roots()); + // Enable sampling tools capability + Json caps_tools = {{"sampling", Json{{"tools", Json::object()}}}}; + session.set_capabilities(caps_tools); + assert(session.supports_sampling()); + assert(session.supports_sampling_tools()); + // Get raw capabilities auto raw = session.capabilities(); assert(raw.contains("sampling")); - assert(raw.contains("roots")); + // roots may be absent after set_capabilities(caps_tools) std::cout << "PASSED\n"; }