diff --git a/docs/handler-authoring.md b/docs/handler-authoring.md index f22434f1..a827019a 100644 --- a/docs/handler-authoring.md +++ b/docs/handler-authoring.md @@ -89,6 +89,72 @@ def _resolve_identity(ctx: ToolContext | None) -> ResolvedIdentity: ) ``` +## Typed handler params + +Handler methods may declare their `params` as a Pydantic model instead +of `dict[str, Any]`. The dispatcher reads the annotation and +deserialises the incoming request before calling your method — you +get IDE autocomplete, Pydantic validation at the handler boundary, and +typed attribute access in exchange for a one-line signature change. + +```python +from adcp.server import ADCPHandler, ToolContext +from adcp.types import GetProductsRequest, GetProductsResponse, Product + + +class MySeller(ADCPHandler): + async def get_products( + self, + params: GetProductsRequest, + context: ToolContext | None = None, + ) -> GetProductsResponse: + # params.buying_mode, params.promoted_offering, params.brief — + # typed, validated, autocompleted. No params.get(...) anywhere. + if params.buying_mode.value == "refine": + ... + return GetProductsResponse(products=[...]) +``` + +**Validation errors surface as `INVALID_REQUEST`.** A Pydantic +`ValidationError` at the boundary is converted to a structured AdCP +error with the field path and validation detail — callers see the +spec-typed recovery classification (`correctable`), not a stack trace. +The raw offending value is stripped from the error (SDK sends +`include_input=False` to Pydantic) so mistyped secrets don't echo +back to multi-hop intermediaries. + +> **Custom validator caveat.** If you layer `@field_validator` or +> `@model_validator` on a custom params model, **don't f-string the +> offending value into the `ValueError` message** +> (`raise ValueError(f"bad token {v}")`). The message text flows into +> the client-visible error — `include_input=False` only suppresses +> Pydantic's default echo, not your own. Stick to describing the +> constraint (`raise ValueError("token must match pk_… pattern")`). + +**Back-compat is automatic.** Handlers that keep `params: dict[str, Any]` +work unchanged. The dispatcher falls back to the dict path when no +Pydantic model is in the annotation — migrate incrementally, one +method at a time. Sibling methods with mixed typed/dict signatures +coexist on the same handler. + +**Unions with dict are supported.** `params: GetProductsRequest | dict[str, Any]` +(the shape the specialized SDK bases use internally) works — the +dispatcher picks the first Pydantic branch and deserialises. Existing +handlers that do defensive `GetProductsRequest.model_validate(params)` +inside the method still work: Pydantic's `model_validate` on an +already-typed instance is a no-op (returns the same object; field +validators are skipped — so a custom `@field_validator` layered on a +params model won't fire twice, and won't fire again on the defensive +re-call inside the handler). + +**Custom models too.** You aren't restricted to the SDK's generated +request classes. Any `BaseModel` subclass declared on `params` +triggers typed dispatch — useful when you want to layer stricter +field constraints or business invariants on top of the spec shape. +Define the model at module top-level so forward-reference resolution +works (`from __future__ import annotations` stringifies all +annotations). + ## Authentication The SDK does not enforce authentication. There are two supported diff --git a/examples/typed_handler_demo.py b/examples/typed_handler_demo.py new file mode 100644 index 00000000..9462f5d1 --- /dev/null +++ b/examples/typed_handler_demo.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +"""Typed handler params — minimal demonstration (#214). + +Shows a handler that declares its ``params`` as a Pydantic model +rather than ``dict[str, Any]``. The dispatcher validates and +deserialises the request at the boundary, so the handler body works +with typed attributes instead of ``params.get(...)``. + +Mixing typed and dict signatures on the same handler is supported — +useful for migrating a large seller one method at a time. + +Run:: + + python examples/typed_handler_demo.py + +Then call ``get_products`` from any MCP client. A request missing +``buying_mode`` returns a structured ``INVALID_REQUEST`` AdCP error. +""" + +from __future__ import annotations + +from typing import Any + +from adcp.server import ADCPHandler, ToolContext, serve +from adcp.types import ( + GetAdcpCapabilitiesResponse, + GetProductsRequest, + GetProductsResponse, + Product, + PublisherPropertiesAll, +) + + +class TypedSeller(ADCPHandler): + """Minimal handler demonstrating typed dispatch. + + Only two methods are overridden: + + - ``get_adcp_capabilities`` — required by the AdCP spec. + - ``get_products`` — typed ``params: GetProductsRequest``. The + dispatcher deserialises before calling, so ``params.buying_mode`` + is a typed enum attribute, not a dict lookup. + """ + + _agent_type = "typed-demo-seller" + + async def get_adcp_capabilities( + self, params: dict[str, Any], context: ToolContext | None = None + ) -> dict[str, Any]: + return GetAdcpCapabilitiesResponse( + adcp={"major_versions": [3]}, + supported_protocols=["media_buy"], + ).model_dump(mode="json", exclude_none=True) + + async def get_products( + self, + params: GetProductsRequest, + context: ToolContext | None = None, + ) -> dict[str, Any]: + # Typed attribute access — no params.get("buying_mode") anywhere. + # Pydantic already validated the shape; the handler focuses on + # business logic. + requested_mode = params.buying_mode.value + + products: list[Product] = [ + Product( + product_id="demo-product", + name=f"Demo — {requested_mode} mode", + description="A demonstration product for the typed-dispatch example.", + publisher_properties=[ + PublisherPropertiesAll( + publisher_domain="example.com", + selection_type="all", + ) + ], + format_ids=[], + delivery_type="non_guaranteed", + pricing_options=[], + ) + ] + + return GetProductsResponse(products=products).model_dump(mode="json", exclude_none=True) + + +if __name__ == "__main__": + # Demo only — ``serve()`` defaults to binding 0.0.0.0 with no auth. + # For production, wrap with an auth middleware (see + # ``examples/mcp_with_auth_middleware.py``) and restrict the host + # via reverse-proxy config or the ``port=`` / bind-host hooks. + serve(TypedSeller(), name="typed-demo-seller", transport="streamable-http") diff --git a/pyproject.toml b/pyproject.toml index b0d98f47..5a7ef330 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,14 @@ dependencies = [ "httpcore>=1.0,<2.0", "pydantic>=2.0.0", "typing-extensions>=4.5.0", - "a2a-sdk>=0.3.0", + # Cap at <1.0 — a2a-sdk 1.0.0 (released 2026-04-20) is a breaking + # rewrite that moves types to a2a.types.a2a_pb2, renames + # DefaultRequestHandler, removes ServerError from a2a.utils.errors, + # and changes Part/Message construction away from ``root=`` kwargs. + # Migration is non-trivial (28+ mypy errors across webhooks, client, + # protocols/a2a, server/a2a_server, server/translate). Tracked as a + # separate compat PR. + "a2a-sdk>=0.3.0,<1.0", "mcp>=1.23.2", "email-validator>=2.0.0", "cryptography>=41.0.0", diff --git a/src/adcp/server/mcp_tools.py b/src/adcp/server/mcp_tools.py index 8ed9e6ff..369a4891 100644 --- a/src/adcp/server/mcp_tools.py +++ b/src/adcp/server/mcp_tools.py @@ -1388,6 +1388,58 @@ def get_tools_for_handler( ] +def _resolve_params_pydantic_model(method: Any) -> type[Any] | None: + """Resolve the Pydantic model the handler expects for ``params``. + + Inspects the method's ``params`` annotation. Returns the Pydantic + class when the annotation is: + + - A direct ``BaseModel`` subclass (``params: GetProductsRequest``). + - A Union / Optional whose first member is a ``BaseModel`` subclass + (``params: GetProductsRequest | dict[str, Any]``). This shape is + what the specialized SDK handler bases declare — typed-dispatch + treats the first Pydantic branch as the authoritative shape, so + existing ``params: Request | dict`` annotations keep working. + + Returns ``None`` for ``dict``, missing annotation, or forward + references that fail to resolve — the dispatcher then falls back + to the legacy dict path. + + Cached per method object via the returned value being computed once + at ``create_tool_caller`` setup time. + """ + import typing + from types import UnionType + + from pydantic import BaseModel + + try: + hints = typing.get_type_hints(method) + except Exception as exc: # forward-ref failure, missing import, etc. + # Log at debug so an author whose typed annotation silently + # failed to resolve (typo in the class name, import not at + # module top-level, PEP 563 name bound in a local scope) can + # find out why their handler is dispatching via the dict path. + logger.debug( + "typed params annotation failed to resolve for %r: %s; " + "falling back to dict dispatch", + method, + exc, + ) + return None + annotation = hints.get("params") + if annotation is None: + return None + if isinstance(annotation, type) and issubclass(annotation, BaseModel): + return annotation + origin = typing.get_origin(annotation) + if origin is typing.Union or origin is UnionType: + for arg in typing.get_args(annotation): + if isinstance(arg, type) and issubclass(arg, BaseModel): + return arg + return None + + def create_tool_caller( handler: ADCPHandler[Any], method_name: str, @@ -1398,6 +1450,17 @@ def create_tool_caller( ``context`` field, it is echoed back in the response (ADCP requirement). Handlers no longer need to call ``inject_context()`` manually. + **Typed params (closes #214).** When the handler method declares its + ``params`` parameter as a Pydantic model (e.g. + ``params: GetProductsRequest``), the dispatcher deserialises the raw + dict into the model before calling the handler — giving authors + IDE autocomplete, Pydantic validation at the boundary, and typed + attribute access instead of ``params.get(...)``. Handlers still + declaring ``params: dict[str, Any]`` keep working unchanged. A + Pydantic ``ValidationError`` surfaces as a structured + ``INVALID_REQUEST`` AdCP error so callers see a spec-typed recovery + classification rather than a raw stack trace. + Args: handler: The ADCP handler instance method_name: Name of the method to call @@ -1411,19 +1474,68 @@ def create_tool_caller( per-principal scoping, audit logging) gets the real principal. When no context is supplied, a bare :class:`ToolContext` is used. """ + from pydantic import ValidationError + + from adcp.exceptions import ADCPTaskError from adcp.server.helpers import inject_context + from adcp.types import Error method = getattr(handler, method_name) + params_model = _resolve_params_pydantic_model(method) async def call_tool(params: dict[str, Any], context: ToolContext | None = None) -> Any: ctx = context if context is not None else ToolContext() - result = await method(params, ctx) + raw_params = params # Preserve the original dict for context echo. + call_params: Any = params + if params_model is not None and isinstance(params, dict): + try: + call_params = params_model.model_validate(params) + except ValidationError as exc: + # Surface as a structured AdCP error so MCP clients see + # INVALID_REQUEST with a field-level pointer instead of + # a raw Pydantic traceback. translate_error maps this + # to ToolError (MCP) / ServerError (A2A) per transport. + # + # Strip ``input``/``ctx``/``url`` from the Pydantic error + # details — they echo the raw offending value verbatim + # (``input`` in particular). In multi-hop agent chains the + # response flows through intermediaries, so echoing the + # user-supplied value is a PII/secret-leak vector: a + # mistyped API key or secret-shaped idempotency_key could + # land in the broker's logs. The field path in + # ``Error.field`` is all clients need to programmatically + # locate the bad value in their own request. + errors_list = exc.errors( + include_input=False, include_context=False, include_url=False + ) + first: dict[str, Any] = dict(errors_list[0]) if errors_list else {} + field_path = ".".join(str(loc) for loc in first.get("loc", ())) + message = first.get("msg", "validation failed") + suggestion = ( + f"Invalid value for field {field_path!r}: {message}" + if field_path + else f"Request validation failed: {message}" + ) + raise ADCPTaskError( + operation=method_name, + errors=[ + Error( + code="INVALID_REQUEST", + field=field_path or None, + message=suggestion, + details={"validation_errors": errors_list}, + ) + ], + ) from exc + result = await method(call_params, ctx) # Convert Pydantic models to JSON-safe dicts for MCP serialization if hasattr(result, "model_dump"): result = result.model_dump(mode="json", exclude_none=True) - # ADCP requires echoing context from request to response + # ADCP requires echoing context from request to response — read + # from the raw dict the transport sent, not from the validated + # model (which won't carry the wire ``context`` field). if isinstance(result, dict): - inject_context(params, result) + inject_context(raw_params, result) return result return call_tool @@ -1482,9 +1594,7 @@ def get_tool_names(self) -> list[str]: return list(self._tools.keys()) -def create_mcp_tools( - handler: ADCPHandler[Any], *, advertise_all: bool = False -) -> MCPToolSet: +def create_mcp_tools(handler: ADCPHandler[Any], *, advertise_all: bool = False) -> MCPToolSet: """Create MCP tools from an ADCP handler. This is the main entry point for MCP server integration. diff --git a/src/adcp/server/translate.py b/src/adcp/server/translate.py index 4173f0eb..e2b5f89b 100644 --- a/src/adcp/server/translate.py +++ b/src/adcp/server/translate.py @@ -142,6 +142,7 @@ def translate_error( raise ValueError(f"protocol must be 'mcp' or 'a2a', got {protocol!r}") # Extract structured fields from the input + field: str | None = None if isinstance(exc, Error): code = exc.code message = exc.message @@ -149,6 +150,7 @@ def translate_error( details = exc.details recovery = _recovery_for_code(code) errors = None + field = exc.field elif isinstance(exc, ADCPError): code = _error_code_for_exception(exc) message = exc.message @@ -156,11 +158,18 @@ def translate_error( recovery = _recovery_for_code(code) details = None errors = getattr(exc, "errors", None) + # ADCPTaskError carries a list of Error objects — lift the first + # error's ``field`` so MCP clients see the field path too (A2A + # already surfaces it inside ``data.errors[i].field`` via the + # structured error passthrough). + if errors: + first = errors[0] + field = getattr(first, "field", None) else: raise TypeError(f"Expected ADCPError or Error, got {type(exc).__name__}") if proto == "mcp": - return _to_mcp(code, message, suggestion=suggestion) + return _to_mcp(code, message, suggestion=suggestion, field=field) return _to_a2a( code, message, @@ -176,9 +185,23 @@ def _to_mcp( message: str, *, suggestion: str | None = None, + field: str | None = None, ) -> ToolError: - """Format error as a ToolError for MCP servers.""" - text = f"{code}: {message}" + """Format error as a ToolError for MCP servers. + + MCP's ``ToolError`` is a flat text payload — there's no structured + ``data`` channel equivalent to A2A's. To give MCP clients a + programmatic handle on the offending field, the field path is + embedded in the code prefix: ``INVALID_REQUEST[packages[0].budget]: + …``. Clients can parse the bracketed form with a simple regex + (``^([A-Z_]+)(?:\\[([^\\]]+)\\])?:``) to recover both the AdCP code + and the field path — same shape the spec suggests for the JS + client. + """ + if field: + text = f"{code}[{field}]: {message}" + else: + text = f"{code}: {message}" if suggestion: text += f"\nSuggestion: {suggestion}" return ToolError(text) diff --git a/tests/test_typed_handler_params.py b/tests/test_typed_handler_params.py new file mode 100644 index 00000000..26887175 --- /dev/null +++ b/tests/test_typed_handler_params.py @@ -0,0 +1,435 @@ +"""Typed handler params — closes #214. + +Before this PR, handlers took ``params: dict[str, Any]`` and wrote +``params.get("buying_mode")`` everywhere. Rounds 4–7 of DX validation +flagged this as the biggest structural boilerplate complaint: no IDE +autocomplete, no Pydantic validation at the handler boundary, typos +land silently as ``None`` at runtime. + +The dispatcher now inspects the handler override's ``params`` +annotation. When it's a Pydantic model, the raw dict is +``model_validate``'d before the handler runs — the handler receives a +typed instance with autocomplete and validation. Invalid payloads +surface as a structured ``INVALID_REQUEST`` AdCP error (spec-typed +recovery classification) instead of a raw Pydantic traceback. + +Legacy ``params: dict[str, Any]`` handlers keep working — the +dispatcher falls back to the dict path when no Pydantic model is in +the annotation. This is a pure DX upgrade, not a breaking change. +""" + +from __future__ import annotations + +from typing import Any + +import pytest +from pydantic import BaseModel + +from adcp.exceptions import ADCPTaskError +from adcp.server import ADCPHandler, ToolContext +from adcp.server.mcp_tools import ( + _resolve_params_pydantic_model, + create_tool_caller, +) +from adcp.types import GetProductsRequest, ListCreativeFormatsRequest + +# --------------------------------------------------------------------------- +# _resolve_params_pydantic_model — the signature inspection helper +# --------------------------------------------------------------------------- + + +def test_resolves_direct_pydantic_annotation(): + """``params: GetProductsRequest`` — the primary target shape.""" + + async def fn(self, params: GetProductsRequest, context: ToolContext | None = None) -> Any: + return {} + + assert _resolve_params_pydantic_model(fn) is GetProductsRequest + + +def test_resolves_union_with_pydantic_and_dict(): + """``params: GetProductsRequest | dict[str, Any]`` is the shape the + specialized SDK bases already declare. The helper picks the Pydantic + branch so existing specialized-base subclasses get typed dispatch + without code changes.""" + + async def fn( + self, + params: GetProductsRequest | dict[str, Any], + context: ToolContext | None = None, + ) -> Any: + return {} + + assert _resolve_params_pydantic_model(fn) is GetProductsRequest + + +def test_returns_none_for_dict_annotation(): + """``params: dict[str, Any]`` — legacy signature. No deserialization + happens; dispatcher passes the dict through.""" + + async def fn(self, params: dict[str, Any], context: ToolContext | None = None) -> Any: + return {} + + assert _resolve_params_pydantic_model(fn) is None + + +def test_returns_none_for_missing_annotation(): + """``params`` with no annotation — legacy pattern. Pass through.""" + + async def fn(self, params, context=None) -> Any: + return {} + + assert _resolve_params_pydantic_model(fn) is None + + +def test_returns_none_for_non_pydantic_class(): + """A class that isn't a Pydantic model doesn't trigger typed + dispatch — we're not going to synthesize validation logic for + arbitrary user types.""" + + class _NotPydantic: + pass + + async def fn(self, params: _NotPydantic, context: ToolContext | None = None) -> Any: + return {} + + assert _resolve_params_pydantic_model(fn) is None + + +# --------------------------------------------------------------------------- +# Dispatcher hands typed instance to handler +# --------------------------------------------------------------------------- + + +async def test_typed_handler_receives_pydantic_instance(): + """The primary #214 promise. Author writes + ``async def get_products(self, params: GetProductsRequest, ...)`` + and the handler gets a typed instance — with attribute access, + autocomplete, and validation already done.""" + received: list[Any] = [] + + class _Agent(ADCPHandler): + async def get_adcp_capabilities(self, params, context=None): + return {"adcp": {"major_versions": [3]}} + + async def get_products( + self, + params: GetProductsRequest, + context: ToolContext | None = None, + ) -> Any: + received.append(params) + return {"products": []} + + caller = create_tool_caller(_Agent(), "get_products") + await caller({"buying_mode": "brief", "promoted_offering": "test"}) + + assert len(received) == 1 + # Typed instance — attribute access works (would fail on a dict). + assert isinstance(received[0], GetProductsRequest) + assert received[0].buying_mode.value == "brief" + assert received[0].promoted_offering == "test" + + +async def test_legacy_dict_handler_still_works(): + """Backward-compat. Pre-#214 handlers with + ``params: dict[str, Any]`` keep getting dicts — the dispatcher + sees no Pydantic annotation and passes through.""" + received: list[Any] = [] + + class _Agent(ADCPHandler): + async def get_adcp_capabilities(self, params, context=None): + return {"adcp": {"major_versions": [3]}} + + async def get_products( + self, params: dict[str, Any], context: ToolContext | None = None + ) -> Any: + received.append(params) + return {"products": []} + + caller = create_tool_caller(_Agent(), "get_products") + await caller({"buying_mode": "brief"}) + + assert len(received) == 1 + assert isinstance(received[0], dict) + assert received[0]["buying_mode"] == "brief" + + +async def test_validation_error_surfaces_as_invalid_request(): + """A Pydantic ValidationError at the dispatcher boundary must NOT + propagate as a raw traceback. It surfaces as a structured + ADCPTaskError with code ``INVALID_REQUEST`` so ``translate_error`` + maps it to the right MCP/A2A error shape and clients can programmatic- + handle recovery.""" + + class _Agent(ADCPHandler): + async def get_adcp_capabilities(self, params, context=None): + return {"adcp": {"major_versions": [3]}} + + async def get_products( + self, + params: GetProductsRequest, + context: ToolContext | None = None, + ) -> Any: + return {"products": []} + + caller = create_tool_caller(_Agent(), "get_products") + + with pytest.raises(ADCPTaskError) as exc_info: + # Missing required field `buying_mode`. + await caller({"promoted_offering": "test"}) + + err = exc_info.value + assert "INVALID_REQUEST" in err.error_codes + # The error carries the Pydantic validation details so downstream + # can inspect programmatically. + assert err.errors[0].details is not None + assert "validation_errors" in err.errors[0].details + # The field path is lifted onto Error.field — the spec's dedicated + # field for programmatic client handling (vs. parsing the message). + assert err.errors[0].field == "buying_mode" + + +async def test_validation_error_strips_input_value(): + """**PII/secret-leak regression guard**. Pydantic's ``errors()`` + echoes the raw offending input under ``input`` (and ``ctx``/``url``). + In multi-hop agent chains the error flows through broker + intermediaries — echoing a mistyped bearer token or secret-shaped + value exposes it. The dispatcher strips ``input``/``ctx``/``url`` + before wrapping in ADCPTaskError. Regression here would silently + reintroduce the leak (security review of PR #238).""" + + class _Agent(ADCPHandler): + async def get_adcp_capabilities(self, params, context=None): + return {"adcp": {"major_versions": [3]}} + + async def get_products( + self, + params: GetProductsRequest, + context: ToolContext | None = None, + ) -> Any: + return {"products": []} + + caller = create_tool_caller(_Agent(), "get_products") + # Submit a value the caller might regret broadcasting — a + # secret-shaped string for a field with the wrong type + # constraint. The error must NOT echo it back. + sensitive = "sk_live_SUPER_SECRET_VALUE_xyz" + with pytest.raises(ADCPTaskError) as exc_info: + await caller({"buying_mode": sensitive}) + + err = exc_info.value + # The raw sensitive string must not appear anywhere in the error. + details_serialised = str(err.errors[0].details) + assert sensitive not in details_serialised + assert sensitive not in err.errors[0].message + # Structural details still carry loc/msg/type — client debuggability + # is preserved via the field path. + validation_errors = err.errors[0].details["validation_errors"] + assert validation_errors + assert "loc" in validation_errors[0] + assert "msg" in validation_errors[0] + # And explicitly the stripped keys are gone. + assert "input" not in validation_errors[0] + assert "url" not in validation_errors[0] + + +def test_mcp_error_translation_embeds_field_path(): + """``translate_error`` for MCP previously dropped ``Error.field`` + because MCP's ToolError has no structured ``data`` channel. The + fix embeds the field path in the code prefix: ``INVALID_REQUEST[field]: + message``. A2A already carries ``field`` structurally via the data + passthrough. Regression guard — dropping ``field`` on the MCP side + leaves clients stuck parsing free-form English to find what went + wrong.""" + from adcp.server.translate import translate_error + from adcp.types import Error + + err = Error( + code="INVALID_REQUEST", + field="packages[0].budget", + message="Value should be positive", + ) + mcp_error = translate_error(err, protocol="mcp") + # ToolError's text — the only channel MCP has. + text = str(mcp_error) + assert "INVALID_REQUEST[packages[0].budget]" in text + assert "Value should be positive" in text + + +async def test_mixed_typed_and_legacy_handlers_coexist(): + """Sellers migrate incrementally — some handlers typed, others + still dict. Both must route correctly on the same handler + instance.""" + typed_received: list[Any] = [] + dict_received: list[Any] = [] + + class _Agent(ADCPHandler): + async def get_adcp_capabilities(self, params, context=None): + return {"adcp": {"major_versions": [3]}} + + async def get_products( + self, params: GetProductsRequest, context: ToolContext | None = None + ) -> Any: + typed_received.append(params) + return {"products": []} + + # Still legacy-style. + async def sync_creatives( + self, params: dict[str, Any], context: ToolContext | None = None + ) -> Any: + dict_received.append(params) + return {"results": []} + + agent = _Agent() + typed_caller = create_tool_caller(agent, "get_products") + dict_caller = create_tool_caller(agent, "sync_creatives") + + await typed_caller({"buying_mode": "brief"}) + await dict_caller({"creatives": []}) + + assert isinstance(typed_received[0], GetProductsRequest) + assert isinstance(dict_received[0], dict) + + +async def test_context_echo_uses_raw_dict_not_validated_model(): + """ADCP requires the server to echo the ``context`` field from the + request into the response. The wire ``context`` field isn't part + of typed request models — the dispatcher reads it from the raw + dict, not from the validated instance, so context echo still + works under typed dispatch.""" + + class _Agent(ADCPHandler): + async def get_adcp_capabilities(self, params, context=None): + return {"adcp": {"major_versions": [3]}} + + async def get_products( + self, params: GetProductsRequest, context: ToolContext | None = None + ) -> Any: + return {"products": []} + + caller = create_tool_caller(_Agent(), "get_products") + result = await caller( + { + "buying_mode": "brief", + "context": {"conversation_id": "c-1"}, + } + ) + # inject_context copied the request.context into the response. + assert result.get("context") == {"conversation_id": "c-1"} + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +async def test_handler_returning_already_typed_params_no_double_validation(): + """When the handler calls ``Model.model_validate(params)`` itself + (the specialized SDK bases still do this today), the typed + dispatch passing a typed instance must NOT break it. Pydantic + ``model_validate`` on an already-typed instance is a no-op — + returns the same object, validators are skipped. Verify the + existing specialized-base pattern is unaffected.""" + from adcp.types import GetProductsResponse + + received_types: list[type] = [] + + class _Agent(ADCPHandler): + async def get_adcp_capabilities(self, params, context=None): + return {"adcp": {"major_versions": [3]}} + + async def get_products( + self, + params: GetProductsRequest | dict[str, Any], + context: ToolContext | None = None, + ) -> Any: + # Specialized-base pattern: defensively re-validate. + req = GetProductsRequest.model_validate(params) + received_types.append(type(req)) + return GetProductsResponse(products=[]) + + caller = create_tool_caller(_Agent(), "get_products") + await caller({"buying_mode": "brief"}) + + # Dispatch handed the method a typed instance; the method's + # defensive model_validate was a no-op pass-through. No crash, + # no error — the existing pattern keeps working. + assert received_types == [GetProductsRequest] + + +# --------------------------------------------------------------------------- +# Custom Pydantic model — not limited to the generated request types +# --------------------------------------------------------------------------- + + +class _StrictGetProductsRequest(BaseModel): + """Module-level custom model. + + Defined at module scope because ``typing.get_type_hints`` needs to + resolve the forward reference string (``from __future__ import + annotations`` stringifies all annotations) against a reachable + namespace — the handler module globals. Models defined inside a + function body live in a local namespace that the dispatcher can't + see. Production handlers define their params models at module + top-level, so this limitation matches real usage. + """ + + buying_mode: str + promoted_offering: str + + +async def test_custom_pydantic_model_also_works(): + """Authors aren't restricted to the SDK's generated request classes. + Any Pydantic model declared on the ``params`` annotation triggers + typed dispatch. Useful for sellers who want to layer additional + validation (stricter field constraints, invariants) on top of the + spec shape.""" + received: list[Any] = [] + + class _Agent(ADCPHandler): + async def get_adcp_capabilities(self, params, context=None): + return {"adcp": {"major_versions": [3]}} + + async def get_products( + self, + params: _StrictGetProductsRequest, + context: ToolContext | None = None, + ) -> Any: + received.append(params) + return {"products": []} + + caller = create_tool_caller(_Agent(), "get_products") + await caller({"buying_mode": "brief", "promoted_offering": "test"}) + + assert isinstance(received[0], _StrictGetProductsRequest) + assert received[0].buying_mode == "brief" + + +# --------------------------------------------------------------------------- +# A second tool — prove the plumbing is tool-agnostic +# --------------------------------------------------------------------------- + + +async def test_typed_dispatch_on_second_tool(): + """Coverage for a second tool to prove the typed-dispatch plumbing + isn't ``get_products``-specific — the signature inspection walks + every handler method the same way.""" + received: list[Any] = [] + + class _Agent(ADCPHandler): + async def get_adcp_capabilities(self, params, context=None): + return {"adcp": {"major_versions": [3]}} + + async def list_creative_formats( + self, + params: ListCreativeFormatsRequest, + context: ToolContext | None = None, + ) -> Any: + received.append(params) + return {"formats": []} + + caller = create_tool_caller(_Agent(), "list_creative_formats") + await caller({}) + + assert len(received) == 1 + assert isinstance(received[0], ListCreativeFormatsRequest)