diff --git a/docs/handler-authoring.md b/docs/handler-authoring.md index 3514ff19..8f2b6396 100644 --- a/docs/handler-authoring.md +++ b/docs/handler-authoring.md @@ -240,15 +240,30 @@ inside the method still work: Pydantic's `model_validate` on an already-typed instance is a no-op (returns the same object; field validators are skipped — so a custom `@field_validator` layered on a params model won't fire twice, and won't fire again on the defensive -re-call inside the handler). +re-call inside the handler). Note: this no-op applies only when +re-validating an instance of the *same* type; the dispatch-boundary +re-validation described below uses `model_dump → model_validate` and +will fire subclass validators exactly once. **Custom models too.** You aren't restricted to the SDK's generated -request classes. Any `BaseModel` subclass declared on `params` -triggers typed dispatch — useful when you want to layer stricter -field constraints or business invariants on top of the spec shape. -Define the model at module top-level so forward-reference resolution -works (`from __future__ import annotations` stringifies all -annotations). +request classes. Any `BaseModel` subclass declared on the first +non-self, non-context parameter triggers typed dispatch — useful when +you want to layer stricter field constraints or business invariants on +top of the spec shape. Define the model at module top-level so +forward-reference resolution works (`from __future__ import +annotations` stringifies all annotations). + +This applies to both `ADCPHandler` and `DecisioningPlatform` subclasses. +For `DecisioningPlatform`, the framework detects when your platform +method's first parameter annotation is a stricter subclass of the +library's base request type and automatically re-validates the +already-deserialized params through your subclass before calling your +method. For example, a subclass with `extra="forbid"` will reject +unknown fields at the dispatch boundary, and a `@field_validator` that +narrows an enum will fire before your business logic runs. A +`pydantic.ValidationError` from this re-validation surfaces as +`INVALID_REQUEST / correctable` on the wire — not as an opaque +`INTERNAL_ERROR`. ## Authentication diff --git a/src/adcp/decisioning/dispatch.py b/src/adcp/decisioning/dispatch.py index 2902254f..8e254dde 100644 --- a/src/adcp/decisioning/dispatch.py +++ b/src/adcp/decisioning/dispatch.py @@ -38,8 +38,10 @@ import contextvars import difflib import functools +import inspect import logging import os +import typing import warnings from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Any @@ -1120,6 +1122,92 @@ def _build_request_context( # --------------------------------------------------------------------------- +def _coerce_params_to_platform_type(method: Any, params: Any, method_name: str) -> Any: + """Re-validate ``params`` through the platform method's own type annotation. + + The shim layer (``PlatformHandler``) deserialises the wire dict into + the library's request type (e.g. ``GetProductsRequest`` with + ``extra='allow'``). When the platform subclass overrides the method + with a *stricter* subclass annotation (e.g. ``extra='forbid'``, custom + field validators), re-validate so those rules fire at the dispatch + boundary — not silently bypassed. + + Decision logic: + + * Same type — no-op; avoid double-validation overhead. + * Strict subclass (``issubclass(annotation, type(params))``) — dump + + re-validate through the subclass. A ``ValidationError`` means the + wire request genuinely violated the subclass contract; raise as + ``AdcpError('INVALID_REQUEST')`` so the wire envelope carries a + spec-typed recovery hint rather than ``INTERNAL_ERROR``. + * No subclass relation, Union annotation, non-Pydantic annotation, or + ``get_type_hints`` failure — skip coercion and return ``params`` + unchanged. + + Only called when ``arg_projector is None`` (the projector path replaces + positional args entirely, so ``params`` is unused there). + + .. note:: + The ``model_dump(mode="python") → model_validate()`` roundtrip is + safe because generated library request types carry no mutating + ``field_validator`` or ``model_validator`` declarations today. If + that changes, a validator declared on the *base* type would fire + twice: once when the shim builds the library instance, and again + here. Revisit if generated types gain mutating validators. + """ + from pydantic import BaseModel, ValidationError + + if not isinstance(params, BaseModel): + return params + try: + hints = typing.get_type_hints(method) + except Exception: + return params + + sig = inspect.signature(method) + for name, param_obj in sig.parameters.items(): + if name in ("self", "ctx", "context"): + continue + if param_obj.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + # *args / **kwargs — not a typed request param; stop searching. + break + annotation = hints.get(name) + if annotation is None: + # Non-standard signature (e.g. unannotated first arg); skip + # coercion rather than guessing which param is the request. + break + if not (isinstance(annotation, type) and issubclass(annotation, BaseModel)): + break + if annotation is type(params): + return params # identical type — skip + if issubclass(annotation, type(params)): + try: + # mode="python" preserves native types (datetime, Decimal, + # UUID) so subclass validators receive them as-is, not as + # JSON-serialized strings. + return annotation.model_validate(params.model_dump(mode="python")) + except ValidationError as exc: + errors = exc.errors(include_input=False, include_context=False, include_url=False) + first: dict[str, Any] = dict(errors[0]) if errors else {} + field_path = ".".join(str(loc) for loc in first.get("loc", ())) + msg = first.get("msg", "validation failed") + raise AdcpError( + "INVALID_REQUEST", + message=( + f"Request validation failed for {method_name!r}: {msg}" + + (f" (field: {field_path!r})" if field_path else "") + ), + field=field_path or None, + recovery="correctable", + ) from exc + break + + return params + + async def _invoke_platform_method( platform: DecisioningPlatform, method_name: str, @@ -1180,6 +1268,21 @@ async def _invoke_platform_method( Hook errors are logged but never block exception propagation. """ method = getattr(platform, method_name) + # Re-validate through the platform method's own annotation when it's a + # stricter subclass of the shim's already-deserialized type. Skipped + # when arg_projector is set — that path replaces positional args entirely. + # + # Wrapped in its own try/except so on_failure fires when coercion raises + # AdcpError before the main try block — proposal-flow callers wire + # on_failure to release a reservation taken before _invoke_platform_method; + # if we raise outside the try block the reservation leaks until TTL. + if arg_projector is None: + try: + params = _coerce_params_to_platform_type(method, params, method_name) + except AdcpError as exc: + if on_failure is not None: + await _safe_on_failure_call(on_failure, exc, method_name) + raise try: if asyncio.iscoroutinefunction(method): diff --git a/tests/test_decisioning_dispatch.py b/tests/test_decisioning_dispatch.py index a9ebbbce..9ce8af80 100644 --- a/tests/test_decisioning_dispatch.py +++ b/tests/test_decisioning_dispatch.py @@ -29,6 +29,7 @@ REQUIRED_METHODS_PER_SPECIALISM, SPEC_SPECIALISM_ENUM, _build_request_context, + _coerce_params_to_platform_type, _invoke_platform_method, _project_handoff, compose_caller_identity, @@ -1282,3 +1283,265 @@ async def create_media_buy(self, req, ctx): assert isinstance(result, dict) assert result["status"] == "submitted" assert "task_type" not in result + + +# ---- _coerce_params_to_platform_type (issue #596) ---- + + +# _BaseRequest simulates the library's request type: extra="allow" so the +# shim's model_validate() accepts unknown wire fields (as the real library +# types do). _StrictSubRequest simulates the adopter's stricter subclass. +class _BaseRequest(BaseModel): + model_config = {"extra": "allow"} + known_field: str = "base" + + +class _StrictSubRequest(_BaseRequest): + model_config = {"extra": "forbid"} + + +@pytest.mark.asyncio +async def test_coerce_applies_extra_forbid_on_subclass_annotation( + executor: ThreadPoolExecutor, +) -> None: + """Platform method with extra='forbid' subclass annotation rejects unknown fields.""" + unknown_field_seen: list[bool] = [] + + class _StrictPlatform(DecisioningPlatform): + capabilities = DecisioningCapabilities() + accounts = SingletonAccounts(account_id="x") + + async def get_products(self, req: _StrictSubRequest, ctx): + unknown_field_seen.append(True) + return _ProductsResponse() + + # _BaseRequest has extra="allow" (simulating the library shim type), so + # model_validate accepts the unknown field and stores it. model_dump() + # then includes it, letting _StrictSubRequest's extra="forbid" fire. + base_params = _BaseRequest.model_validate({"known_field": "ok", "unknown_field": "bad"}) + + ctx = _build_request_context(ToolContext(), Account(id="x"), None) + with pytest.raises(AdcpError) as exc_info: + await _invoke_platform_method( + _StrictPlatform(), + "get_products", + base_params, + ctx, + executor=executor, + registry=InMemoryTaskRegistry(), + ) + assert exc_info.value.code == "INVALID_REQUEST" + assert exc_info.value.recovery == "correctable" + # Handler was never called — validation fired at the dispatch boundary. + assert not unknown_field_seen + + +@pytest.mark.asyncio +async def test_coerce_same_type_is_noop( + executor: ThreadPoolExecutor, +) -> None: + """When the platform method annotation matches the already-deserialized type exactly, + no re-validation occurs.""" + calls: list[Any] = [] + + class _ExactPlatform(DecisioningPlatform): + capabilities = DecisioningCapabilities() + accounts = SingletonAccounts(account_id="x") + + async def get_products(self, req: _BaseRequest, ctx): + calls.append(req) + return _ProductsResponse() + + base_params = _BaseRequest(known_field="hello") + ctx = _build_request_context(ToolContext(), Account(id="x"), None) + await _invoke_platform_method( + _ExactPlatform(), + "get_products", + base_params, + ctx, + executor=executor, + registry=InMemoryTaskRegistry(), + ) + assert len(calls) == 1 + assert calls[0].known_field == "hello" + assert type(calls[0]) is _BaseRequest + + +@pytest.mark.asyncio +async def test_coerce_subclass_annotation_passes_valid_data( + executor: ThreadPoolExecutor, +) -> None: + """Valid data passes through subclass re-validation and the method receives + a subclass instance.""" + received: list[Any] = [] + + class _SubPlatform(DecisioningPlatform): + capabilities = DecisioningCapabilities() + accounts = SingletonAccounts(account_id="x") + + async def get_products(self, req: _StrictSubRequest, ctx): + received.append(req) + return _ProductsResponse() + + # Only known_field — _StrictSubRequest allows this. + base_params = _BaseRequest(known_field="valid") + ctx = _build_request_context(ToolContext(), Account(id="x"), None) + await _invoke_platform_method( + _SubPlatform(), + "get_products", + base_params, + ctx, + executor=executor, + registry=InMemoryTaskRegistry(), + ) + assert len(received) == 1 + assert isinstance(received[0], _StrictSubRequest) + assert received[0].known_field == "valid" + + +@pytest.mark.asyncio +async def test_coerce_unrelated_annotation_is_noop( + executor: ThreadPoolExecutor, +) -> None: + """When the platform method annotation is not a subclass of the params type, + no coercion is attempted.""" + received: list[Any] = [] + + class _UnrelatedRequest(BaseModel): + other: int = 0 + + class _UnrelatedPlatform(DecisioningPlatform): + capabilities = DecisioningCapabilities() + accounts = SingletonAccounts(account_id="x") + + async def get_products(self, req: _UnrelatedRequest, ctx): + received.append(req) + return _ProductsResponse() + + base_params = _BaseRequest(known_field="x") + ctx = _build_request_context(ToolContext(), Account(id="x"), None) + # Should NOT raise — unrelated annotation skips coercion. + await _invoke_platform_method( + _UnrelatedPlatform(), + "get_products", + base_params, + ctx, + executor=executor, + registry=InMemoryTaskRegistry(), + ) + # Method received the original base_params unchanged. + assert len(received) == 1 + assert type(received[0]) is _BaseRequest + + +@pytest.mark.asyncio +async def test_coerce_param_name_agnostic( + executor: ThreadPoolExecutor, +) -> None: + """Coercion works regardless of the first parameter's name (req, params, request, etc.).""" + received: list[Any] = [] + + class _ReqNamedPlatform(DecisioningPlatform): + capabilities = DecisioningCapabilities() + accounts = SingletonAccounts(account_id="x") + + async def get_products(self, params: _StrictSubRequest, ctx): + received.append(params) + return _ProductsResponse() + + base_params = _BaseRequest(known_field="named") + ctx = _build_request_context(ToolContext(), Account(id="x"), None) + await _invoke_platform_method( + _ReqNamedPlatform(), + "get_products", + base_params, + ctx, + executor=executor, + registry=InMemoryTaskRegistry(), + ) + assert isinstance(received[0], _StrictSubRequest) + + +@pytest.mark.asyncio +async def test_coerce_get_type_hints_failure_passes_through( + executor: ThreadPoolExecutor, +) -> None: + """When get_type_hints() fails (e.g. TYPE_CHECKING-only annotation that + can't be resolved at runtime), coercion is skipped and the original + params are passed through unchanged.""" + received: list[Any] = [] + + class _ForwardRefPlatform(DecisioningPlatform): + capabilities = DecisioningCapabilities() + accounts = SingletonAccounts(account_id="x") + + # Annotated with a string that won't resolve — simulates an + # annotation declared under TYPE_CHECKING. + async def get_products(self, req: _NonExistentType, ctx): # type: ignore[name-defined] # noqa: F821 + received.append(req) + return _ProductsResponse() + + base_params = _BaseRequest(known_field="passthrough") + ctx = _build_request_context(ToolContext(), Account(id="x"), None) + # Should NOT raise — graceful degradation when get_type_hints fails. + await _invoke_platform_method( + _ForwardRefPlatform(), + "get_products", + base_params, + ctx, + executor=executor, + registry=InMemoryTaskRegistry(), + ) + # Original params passed through unmodified. + assert len(received) == 1 + assert type(received[0]) is _BaseRequest + assert received[0].known_field == "passthrough" + + +@pytest.mark.asyncio +async def test_coerce_fires_on_failure_hook_on_validation_error( + executor: ThreadPoolExecutor, +) -> None: + """When coercion raises AdcpError (extra='forbid' violation), the on_failure + hook must be called so proposal-flow callers can release reservations.""" + on_failure_calls: list[BaseException] = [] + + async def _on_failure(exc: BaseException) -> None: + on_failure_calls.append(exc) + + class _StrictPlatform(DecisioningPlatform): + capabilities = DecisioningCapabilities() + accounts = SingletonAccounts(account_id="x") + + async def get_products(self, req: _StrictSubRequest, ctx): + return _ProductsResponse() + + base_params = _BaseRequest.model_validate({"known_field": "ok", "unknown_field": "bad"}) + ctx = _build_request_context(ToolContext(), Account(id="x"), None) + with pytest.raises(AdcpError) as exc_info: + await _invoke_platform_method( + _StrictPlatform(), + "get_products", + base_params, + ctx, + executor=executor, + registry=InMemoryTaskRegistry(), + on_failure=_on_failure, + ) + assert exc_info.value.code == "INVALID_REQUEST" + # on_failure must fire — proposal-flow callers wire it to release reservations. + assert len(on_failure_calls) == 1 + assert on_failure_calls[0] is exc_info.value + + +def test_coerce_varargs_annotation_is_noop() -> None: + """Annotated *args should not trigger coercion — VAR_POSITIONAL guard fires.""" + + async def _varargs_method(self, *args: _StrictSubRequest, ctx): # type: ignore[name-defined] + pass + + # Extra field present — if coercion fired, it would raise. + base_params = _BaseRequest.model_validate({"known_field": "ok", "unknown_field": "bad"}) + result = _coerce_params_to_platform_type(_varargs_method, base_params, "test") + # VAR_POSITIONAL guard must prevent coercion — original object returned unchanged. + assert result is base_params