From e1eb92350e9872cfadbff4bff6a274a49af59c83 Mon Sep 17 00:00:00 2001 From: Brian O'Kelley Date: Mon, 20 Apr 2026 09:23:38 -0400 Subject: [PATCH 1/2] feat(server): TypeVar-bound ADCPHandler for typed ToolContext subclasses (closes #223) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Roadmap PR-Q + expert-review followup from #219. Multi-tenant agents routinely subclass ToolContext to carry typed tenant/adapter/testing fields the base doesn't name — before this PR those fields required casts at every handler method. Now ADCPHandler is Generic[TContext] bound to ToolContext; downstream writes ``class MyAgent(ADCPHandler[MyContext])`` and every handler method signature propagates the subclass type. API - New TContext = TypeVar("TContext", bound="ToolContext") exported from adcp.server. Docstring at the declaration site explains the pattern. - ADCPHandler now inherits Generic[TContext]. All 57 method signatures in base.py rewritten to take context: TContext | None. - Protocol handlers (BrandHandler, ComplianceHandler, ContentStandardsHandler, GovernanceHandler, SponsoredIntelligenceHandler, TmpHandler) propagate TContext via class X(ADCPHandler[TContext], Generic[TContext]) so downstream can write class MyBrand(BrandHandler[MyContext]). - Internal SDK annotations (mcp_tools.py, serve.py, a2a_server.py, builder.py) use ADCPHandler[Any] where the SDK doesn't care about the TContext — the decorator-builder path and the transport executors don't thread a specific subclass. Backward compat - class MyAgent(ADCPHandler) without a TypeVar argument still works at runtime. Existing subclasses keep working without edits. - No runtime behavior change. The TypeVar is purely type-system narrowing; handler dispatch paths are unchanged. Tests — tests/test_handler_typevar.py (9 new, 1544 total) - Unparameterised subclass still works (backward compat). - Parameterised ADCPHandler[MyContext] constructs cleanly. - Protocol handlers propagate the TypeVar (BrandHandler[MyContext]). - Handler methods receive the subclass at dispatch time — the runtime isinstance(context, MyContext) is true. - BrandHandler[MyContext] propagation tested end-to-end. - TContext.__bound__ is ToolContext. - A2A executor dispatches a typed handler without issue. - Method signature structure preserved (self, params, context positions). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/adcp/server/__init__.py | 2 + src/adcp/server/a2a_server.py | 6 +- src/adcp/server/base.py | 148 +++++++----- src/adcp/server/brand.py | 6 +- src/adcp/server/builder.py | 23 +- src/adcp/server/compliance.py | 6 +- src/adcp/server/content_standards.py | 36 +-- src/adcp/server/governance.py | 48 ++-- src/adcp/server/mcp_tools.py | 10 +- src/adcp/server/serve.py | 10 +- src/adcp/server/sponsored_intelligence.py | 24 +- src/adcp/server/tmp.py | 6 +- tests/test_handler_typevar.py | 278 ++++++++++++++++++++++ 13 files changed, 458 insertions(+), 145 deletions(-) create mode 100644 tests/test_handler_typevar.py diff --git a/src/adcp/server/__init__.py b/src/adcp/server/__init__.py index a35f868e..443a7de6 100644 --- a/src/adcp/server/__init__.py +++ b/src/adcp/server/__init__.py @@ -57,6 +57,7 @@ async def get_products(params, context=None): from adcp.server.base import ( ADCPHandler, NotImplementedResponse, + TContext, ToolContext, not_supported, ) @@ -129,6 +130,7 @@ async def get_products(params, context=None): "ADCPHandler", "BrandHandler", "ComplianceHandler", + "TContext", "TmpHandler", "ToolContext", "NotImplementedResponse", diff --git a/src/adcp/server/a2a_server.py b/src/adcp/server/a2a_server.py index 36d24631..60148e19 100644 --- a/src/adcp/server/a2a_server.py +++ b/src/adcp/server/a2a_server.py @@ -67,7 +67,7 @@ class ADCPAgentExecutor(AgentExecutor): def __init__( self, - handler: ADCPHandler, + handler: ADCPHandler[Any], test_controller: TestControllerStore | None = None, *, context_factory: ContextFactory | None = None, @@ -435,7 +435,7 @@ def _make_task( def _build_agent_card( - handler: ADCPHandler, + handler: ADCPHandler[Any], *, name: str, port: int, @@ -481,7 +481,7 @@ def _build_agent_card( def create_a2a_server( - handler: ADCPHandler, + handler: ADCPHandler[Any], *, name: str = "adcp-agent", port: int | None = None, diff --git a/src/adcp/server/base.py b/src/adcp/server/base.py index 9546f79c..7eca795b 100644 --- a/src/adcp/server/base.py +++ b/src/adcp/server/base.py @@ -8,7 +8,7 @@ from abc import ABC from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generic, TypeVar from pydantic import BaseModel @@ -151,13 +151,43 @@ def not_supported( ) -class ADCPHandler(ABC): +TContext = TypeVar("TContext", bound="ToolContext") +"""TypeVar bound to :class:`ToolContext` for parameterising +:class:`ADCPHandler` over a caller-defined context subclass. + +Multi-tenant agents typically subclass :class:`ToolContext` to carry +typed tenant/adapter/testing fields the base doesn't name. Declaring +``class MyAgent(ADCPHandler[MyContext])`` makes that subclass visible to +every handler method signature — callers get the typed subclass on the +``context`` parameter without casting:: + + @dataclass + class MyContext(ToolContext): + adapter: MyPlatformAdapter + + class MyAgent(ADCPHandler[MyContext]): + async def get_products(self, params, context: MyContext | None = None): + if context is not None: + adapter = context.adapter # typed, no cast + +Handlers that don't subclass ``ToolContext`` can still write +``class MyAgent(ADCPHandler)`` — unparameterised Generic resolves to +``ADCPHandler[ToolContext]`` at runtime (the ``TypeVar`` bound), so +existing subclasses keep working without edits. +""" + + +class ADCPHandler(ABC, Generic[TContext]): """Base class for ADCP operation handlers. Subclass this to implement ADCP operations. All operations have default implementations that return 'not supported', allowing you to implement only the operations your agent supports. + Parameterise over a :class:`ToolContext` subclass — ``class MyAgent(ADCPHandler[MyContext])`` + — to get typed ``context`` arguments on every method signature. See + :data:`TContext` for the pattern. + For protocol-specific handlers, use: - ContentStandardsHandler: For content standards agents - SponsoredIntelligenceHandler: For sponsored intelligence agents @@ -175,7 +205,7 @@ def _not_supported(self, operation: str) -> NotImplementedResponse: # ======================================================================== async def get_products( - self, params: GetProductsRequest | dict[str, Any], context: ToolContext | None = None + self, params: GetProductsRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Get advertising products. @@ -186,7 +216,7 @@ async def get_products( async def list_creative_formats( self, params: ListCreativeFormatsRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """List supported creative formats. @@ -199,7 +229,7 @@ async def list_creative_formats( # ======================================================================== async def sync_creatives( - self, params: SyncCreativesRequest | dict[str, Any], context: ToolContext | None = None + self, params: SyncCreativesRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Sync creatives. @@ -208,7 +238,7 @@ async def sync_creatives( return self._not_supported("sync_creatives") async def list_creatives( - self, params: ListCreativesRequest | dict[str, Any], context: ToolContext | None = None + self, params: ListCreativesRequest | dict[str, Any], context: TContext | None = None ) -> Any: """List creatives. @@ -217,7 +247,7 @@ async def list_creatives( return self._not_supported("list_creatives") async def build_creative( - self, params: BuildCreativeRequest | dict[str, Any], context: ToolContext | None = None + self, params: BuildCreativeRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Build a creative. @@ -226,7 +256,7 @@ async def build_creative( return self._not_supported("build_creative") async def preview_creative( - self, params: PreviewCreativeRequest | dict[str, Any], context: ToolContext | None = None + self, params: PreviewCreativeRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Preview a creative rendering. @@ -237,7 +267,7 @@ async def preview_creative( async def get_creative_delivery( self, params: GetCreativeDeliveryRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """Get creative delivery metrics. @@ -250,7 +280,7 @@ async def get_creative_delivery( # ======================================================================== async def create_media_buy( - self, params: CreateMediaBuyRequest | dict[str, Any], context: ToolContext | None = None + self, params: CreateMediaBuyRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Create a media buy. @@ -259,7 +289,7 @@ async def create_media_buy( return self._not_supported("create_media_buy") async def update_media_buy( - self, params: UpdateMediaBuyRequest | dict[str, Any], context: ToolContext | None = None + self, params: UpdateMediaBuyRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Update a media buy. @@ -270,7 +300,7 @@ async def update_media_buy( async def get_media_buy_delivery( self, params: GetMediaBuyDeliveryRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """Get media buy delivery metrics. @@ -279,7 +309,7 @@ async def get_media_buy_delivery( return self._not_supported("get_media_buy_delivery") async def get_media_buys( - self, params: GetMediaBuysRequest | dict[str, Any], context: ToolContext | None = None + self, params: GetMediaBuysRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Get media buys with status and optional delivery snapshots. @@ -292,7 +322,7 @@ async def get_media_buys( # ======================================================================== async def get_signals( - self, params: GetSignalsRequest | dict[str, Any], context: ToolContext | None = None + self, params: GetSignalsRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Get available signals. @@ -301,7 +331,7 @@ async def get_signals( return self._not_supported("get_signals") async def activate_signal( - self, params: ActivateSignalRequest | dict[str, Any], context: ToolContext | None = None + self, params: ActivateSignalRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Activate a signal. @@ -316,7 +346,7 @@ async def activate_signal( async def provide_performance_feedback( self, params: ProvidePerformanceFeedbackRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """Provide performance feedback. @@ -329,7 +359,7 @@ async def provide_performance_feedback( # ======================================================================== async def list_accounts( - self, params: ListAccountsRequest | dict[str, Any], context: ToolContext | None = None + self, params: ListAccountsRequest | dict[str, Any], context: TContext | None = None ) -> Any: """List accounts. @@ -338,7 +368,7 @@ async def list_accounts( return self._not_supported("list_accounts") async def sync_accounts( - self, params: SyncAccountsRequest | dict[str, Any], context: ToolContext | None = None + self, params: SyncAccountsRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Sync accounts. @@ -349,7 +379,7 @@ async def sync_accounts( async def get_account_financials( self, params: GetAccountFinancialsRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """Get account financials. @@ -358,7 +388,7 @@ async def get_account_financials( return self._not_supported("get_account_financials") async def report_usage( - self, params: ReportUsageRequest | dict[str, Any], context: ToolContext | None = None + self, params: ReportUsageRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Report account usage. @@ -371,7 +401,7 @@ async def report_usage( # ======================================================================== async def log_event( - self, params: LogEventRequest | dict[str, Any], context: ToolContext | None = None + self, params: LogEventRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Log event. @@ -380,7 +410,7 @@ async def log_event( return self._not_supported("log_event") async def sync_event_sources( - self, params: SyncEventSourcesRequest | dict[str, Any], context: ToolContext | None = None + self, params: SyncEventSourcesRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Sync event sources. @@ -389,7 +419,7 @@ async def sync_event_sources( return self._not_supported("sync_event_sources") async def sync_audiences( - self, params: SyncAudiencesRequest | dict[str, Any], context: ToolContext | None = None + self, params: SyncAudiencesRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Sync audiences. @@ -398,7 +428,7 @@ async def sync_audiences( return self._not_supported("sync_audiences") async def sync_governance( - self, params: SyncGovernanceRequest | dict[str, Any], context: ToolContext | None = None + self, params: SyncGovernanceRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Sync governance agents for accounts. @@ -407,7 +437,7 @@ async def sync_governance( return self._not_supported("sync_governance") async def sync_catalogs( - self, params: SyncCatalogsRequest | dict[str, Any], context: ToolContext | None = None + self, params: SyncCatalogsRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Sync catalogs. @@ -422,7 +452,7 @@ async def sync_catalogs( async def get_adcp_capabilities( self, params: GetAdcpCapabilitiesRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """Get ADCP capabilities. @@ -437,7 +467,7 @@ async def get_adcp_capabilities( async def create_content_standards( self, params: CreateContentStandardsRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """Create content standards configuration. @@ -448,7 +478,7 @@ async def create_content_standards( async def get_content_standards( self, params: GetContentStandardsRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """Get content standards configuration. @@ -459,7 +489,7 @@ async def get_content_standards( async def list_content_standards( self, params: ListContentStandardsRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """List content standards configurations. @@ -470,7 +500,7 @@ async def list_content_standards( async def update_content_standards( self, params: UpdateContentStandardsRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """Update content standards configuration. @@ -479,7 +509,7 @@ async def update_content_standards( return self._not_supported("update_content_standards") async def calibrate_content( - self, params: CalibrateContentRequest | dict[str, Any], context: ToolContext | None = None + self, params: CalibrateContentRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Calibrate content against standards. @@ -490,7 +520,7 @@ async def calibrate_content( async def validate_content_delivery( self, params: ValidateContentDeliveryRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """Validate content delivery against standards. @@ -501,7 +531,7 @@ async def validate_content_delivery( async def get_media_buy_artifacts( self, params: GetMediaBuyArtifactsRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """Get artifacts associated with a media buy. @@ -514,7 +544,7 @@ async def get_media_buy_artifacts( # ======================================================================== async def si_get_offering( - self, params: SiGetOfferingRequest | dict[str, Any], context: ToolContext | None = None + self, params: SiGetOfferingRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Get sponsored intelligence offering. @@ -523,7 +553,7 @@ async def si_get_offering( return self._not_supported("si_get_offering") async def si_initiate_session( - self, params: SiInitiateSessionRequest | dict[str, Any], context: ToolContext | None = None + self, params: SiInitiateSessionRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Initiate sponsored intelligence session. @@ -532,7 +562,7 @@ async def si_initiate_session( return self._not_supported("si_initiate_session") async def si_send_message( - self, params: SiSendMessageRequest | dict[str, Any], context: ToolContext | None = None + self, params: SiSendMessageRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Send message in sponsored intelligence session. @@ -541,7 +571,7 @@ async def si_send_message( return self._not_supported("si_send_message") async def si_terminate_session( - self, params: SiTerminateSessionRequest | dict[str, Any], context: ToolContext | None = None + self, params: SiTerminateSessionRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Terminate sponsored intelligence session. @@ -556,7 +586,7 @@ async def si_terminate_session( async def get_creative_features( self, params: GetCreativeFeaturesRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """Evaluate governance features for a creative. @@ -565,7 +595,7 @@ async def get_creative_features( return self._not_supported("get_creative_features") async def sync_plans( - self, params: SyncPlansRequest | dict[str, Any], context: ToolContext | None = None + self, params: SyncPlansRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Sync campaign governance plans. @@ -574,7 +604,7 @@ async def sync_plans( return self._not_supported("sync_plans") async def check_governance( - self, params: CheckGovernanceRequest | dict[str, Any], context: ToolContext | None = None + self, params: CheckGovernanceRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Check an action against campaign governance. @@ -583,7 +613,7 @@ async def check_governance( return self._not_supported("check_governance") async def report_plan_outcome( - self, params: ReportPlanOutcomeRequest | dict[str, Any], context: ToolContext | None = None + self, params: ReportPlanOutcomeRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Report the outcome of a governed action. @@ -592,7 +622,7 @@ async def report_plan_outcome( return self._not_supported("report_plan_outcome") async def get_plan_audit_logs( - self, params: GetPlanAuditLogsRequest | dict[str, Any], context: ToolContext | None = None + self, params: GetPlanAuditLogsRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Retrieve governance audit logs for plans. @@ -601,7 +631,7 @@ async def get_plan_audit_logs( return self._not_supported("get_plan_audit_logs") async def create_property_list( - self, params: CreatePropertyListRequest | dict[str, Any], context: ToolContext | None = None + self, params: CreatePropertyListRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Create a property list for governance filtering. @@ -610,7 +640,7 @@ async def create_property_list( return self._not_supported("create_property_list") async def get_property_list( - self, params: GetPropertyListRequest | dict[str, Any], context: ToolContext | None = None + self, params: GetPropertyListRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Get a property list with optional resolution. @@ -619,7 +649,7 @@ async def get_property_list( return self._not_supported("get_property_list") async def list_property_lists( - self, params: ListPropertyListsRequest | dict[str, Any], context: ToolContext | None = None + self, params: ListPropertyListsRequest | dict[str, Any], context: TContext | None = None ) -> Any: """List property lists. @@ -628,7 +658,7 @@ async def list_property_lists( return self._not_supported("list_property_lists") async def update_property_list( - self, params: UpdatePropertyListRequest | dict[str, Any], context: ToolContext | None = None + self, params: UpdatePropertyListRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Update a property list. @@ -637,7 +667,7 @@ async def update_property_list( return self._not_supported("update_property_list") async def delete_property_list( - self, params: DeletePropertyListRequest | dict[str, Any], context: ToolContext | None = None + self, params: DeletePropertyListRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Delete a property list. @@ -652,7 +682,7 @@ async def delete_property_list( async def create_collection_list( self, params: CreateCollectionListRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """Create a collection list for governance filtering. @@ -661,7 +691,7 @@ async def create_collection_list( return self._not_supported("create_collection_list") async def get_collection_list( - self, params: GetCollectionListRequest | dict[str, Any], context: ToolContext | None = None + self, params: GetCollectionListRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Get a collection list with optional resolution. @@ -672,7 +702,7 @@ async def get_collection_list( async def list_collection_lists( self, params: ListCollectionListsRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """List collection lists. @@ -683,7 +713,7 @@ async def list_collection_lists( async def update_collection_list( self, params: UpdateCollectionListRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """Update a collection list. @@ -694,7 +724,7 @@ async def update_collection_list( async def delete_collection_list( self, params: DeleteCollectionListRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """Delete a collection list. @@ -707,7 +737,7 @@ async def delete_collection_list( # ======================================================================== async def context_match( - self, params: ContextMatchRequest | dict[str, Any], context: ToolContext | None = None + self, params: ContextMatchRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Match ad context to buyer packages. @@ -716,7 +746,7 @@ async def context_match( return self._not_supported("context_match") async def identity_match( - self, params: IdentityMatchRequest | dict[str, Any], context: ToolContext | None = None + self, params: IdentityMatchRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Match user identity for package eligibility. @@ -729,7 +759,7 @@ async def identity_match( # ======================================================================== async def get_brand_identity( - self, params: GetBrandIdentityRequest | dict[str, Any], context: ToolContext | None = None + self, params: GetBrandIdentityRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Get brand identity information. @@ -738,7 +768,7 @@ async def get_brand_identity( return self._not_supported("get_brand_identity") async def get_rights( - self, params: GetRightsRequest | dict[str, Any], context: ToolContext | None = None + self, params: GetRightsRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Get available rights for licensing. @@ -747,7 +777,7 @@ async def get_rights( return self._not_supported("get_rights") async def acquire_rights( - self, params: AcquireRightsRequest | dict[str, Any], context: ToolContext | None = None + self, params: AcquireRightsRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Acquire rights for brand content usage. @@ -756,7 +786,7 @@ async def acquire_rights( return self._not_supported("acquire_rights") async def update_rights( - self, params: UpdateRightsRequest | dict[str, Any], context: ToolContext | None = None + self, params: UpdateRightsRequest | dict[str, Any], context: TContext | None = None ) -> Any: """Update terms of an existing rights acquisition. @@ -782,7 +812,7 @@ async def update_rights( async def comply_test_controller( self, params: ComplyTestControllerRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> Any: """Compliance test controller (sandbox only). diff --git a/src/adcp/server/brand.py b/src/adcp/server/brand.py index 5dc52fa3..a3ab690e 100644 --- a/src/adcp/server/brand.py +++ b/src/adcp/server/brand.py @@ -2,10 +2,12 @@ from __future__ import annotations -from adcp.server.base import ADCPHandler +from typing import Generic +from adcp.server.base import ADCPHandler, TContext -class BrandHandler(ADCPHandler): + +class BrandHandler(ADCPHandler[TContext], Generic[TContext]): """Handler for brand rights operations. Subclass this to implement brand identity and rights management. diff --git a/src/adcp/server/builder.py b/src/adcp/server/builder.py index 4a96b979..f8704eb5 100644 --- a/src/adcp/server/builder.py +++ b/src/adcp/server/builder.py @@ -126,14 +126,8 @@ def __getattr__(self, task_name: str) -> Callable[..., Any]: raise AttributeError(task_name) def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: - if ( - task_name not in HANDLER_TO_DOMAIN - and task_name != "get_adcp_capabilities" - ): - raise ValueError( - f"'{task_name}' is not a known ADCP task. " - f"Check for typos." - ) + if task_name not in HANDLER_TO_DOMAIN and task_name != "get_adcp_capabilities": + raise ValueError(f"'{task_name}' is not a known ADCP task. " f"Check for typos.") self._handlers[task_name] = fn return fn @@ -148,7 +142,7 @@ def _detect_domains(self) -> list[str]: domains.add(domain) return sorted(domains) - def build_handler(self) -> ADCPHandler: + def build_handler(self) -> ADCPHandler[Any]: """Build an ADCPHandler from registered decorators. If ``get_adcp_capabilities`` is not registered, it will be @@ -162,15 +156,16 @@ def build_handler(self) -> ADCPHandler: if domains: from adcp.server.responses import capabilities_response - async def auto_capabilities( - params: Any, context: Any = None - ) -> dict[str, Any]: + async def auto_capabilities(params: Any, context: Any = None) -> dict[str, Any]: return capabilities_response(domains) handlers["get_adcp_capabilities"] = auto_capabilities - # Create a dynamic subclass - class DynamicHandler(ADCPHandler): + # Create a dynamic subclass. ``ADCPHandler[Any]`` because the + # decorator-builder path doesn't thread a specific ToolContext + # subclass — callers who want typed context go through the + # class-based ``ADCPHandler[MyContext]`` route instead. + class DynamicHandler(ADCPHandler[Any]): pass for task_name, fn in handlers.items(): diff --git a/src/adcp/server/compliance.py b/src/adcp/server/compliance.py index 6319dfcc..d55b3b6d 100644 --- a/src/adcp/server/compliance.py +++ b/src/adcp/server/compliance.py @@ -2,10 +2,12 @@ from __future__ import annotations -from adcp.server.base import ADCPHandler +from typing import Generic +from adcp.server.base import ADCPHandler, TContext -class ComplianceHandler(ADCPHandler): + +class ComplianceHandler(ADCPHandler[TContext], Generic[TContext]): """Handler for compliance test operations. Subclass this to implement compliance sandbox testing. diff --git a/src/adcp/server/content_standards.py b/src/adcp/server/content_standards.py index c647a313..eb9fa7d6 100644 --- a/src/adcp/server/content_standards.py +++ b/src/adcp/server/content_standards.py @@ -7,11 +7,11 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any +from typing import Any, Generic from pydantic import ValidationError -from adcp.server.base import ADCPHandler, NotImplementedResponse, ToolContext +from adcp.server.base import ADCPHandler, NotImplementedResponse, TContext from adcp.types import ( CalibrateContentRequest, CalibrateContentResponse, @@ -31,7 +31,7 @@ ) -class ContentStandardsHandler(ADCPHandler): +class ContentStandardsHandler(ADCPHandler[TContext], Generic[TContext]): """Handler for Content Standards protocol. Subclass this to implement a Content Standards agent. All Content Standards @@ -47,7 +47,7 @@ class MyContentStandardsHandler(ContentStandardsHandler): async def handle_create_content_standards( self, request: CreateContentStandardsRequest, - context: ToolContext | None = None + context: TContext | None = None ) -> CreateContentStandardsResponse: # Your implementation return CreateContentStandardsResponse(...) @@ -62,7 +62,7 @@ async def handle_create_content_standards( async def create_content_standards( self, params: CreateContentStandardsRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> CreateContentStandardsResponse | NotImplementedResponse: """Create content standards configuration. @@ -81,7 +81,7 @@ async def create_content_standards( async def get_content_standards( self, params: GetContentStandardsRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> GetContentStandardsResponse | NotImplementedResponse: """Get content standards configuration. @@ -100,7 +100,7 @@ async def get_content_standards( async def list_content_standards( self, params: ListContentStandardsRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> ListContentStandardsResponse | NotImplementedResponse: """List content standards configurations. @@ -119,7 +119,7 @@ async def list_content_standards( async def update_content_standards( self, params: UpdateContentStandardsRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> UpdateContentStandardsResponse | NotImplementedResponse: """Update content standards configuration. @@ -138,7 +138,7 @@ async def update_content_standards( async def calibrate_content( self, params: CalibrateContentRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> CalibrateContentResponse | NotImplementedResponse: """Calibrate content against standards. @@ -157,7 +157,7 @@ async def calibrate_content( async def validate_content_delivery( self, params: ValidateContentDeliveryRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> ValidateContentDeliveryResponse | NotImplementedResponse: """Validate content delivery against standards. @@ -176,7 +176,7 @@ async def validate_content_delivery( async def get_media_buy_artifacts( self, params: GetMediaBuyArtifactsRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> GetMediaBuyArtifactsResponse | NotImplementedResponse: """Get artifacts associated with a media buy. @@ -200,7 +200,7 @@ async def get_media_buy_artifacts( async def handle_create_content_standards( self, request: CreateContentStandardsRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> CreateContentStandardsResponse: """Handle create content standards request.""" ... @@ -209,7 +209,7 @@ async def handle_create_content_standards( async def handle_get_content_standards( self, request: GetContentStandardsRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> GetContentStandardsResponse: """Handle get content standards request.""" ... @@ -218,7 +218,7 @@ async def handle_get_content_standards( async def handle_list_content_standards( self, request: ListContentStandardsRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> ListContentStandardsResponse: """Handle list content standards request.""" ... @@ -227,7 +227,7 @@ async def handle_list_content_standards( async def handle_update_content_standards( self, request: UpdateContentStandardsRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> UpdateContentStandardsResponse: """Handle update content standards request.""" ... @@ -236,7 +236,7 @@ async def handle_update_content_standards( async def handle_calibrate_content( self, request: CalibrateContentRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> CalibrateContentResponse: """Handle calibrate content request.""" ... @@ -245,7 +245,7 @@ async def handle_calibrate_content( async def handle_validate_content_delivery( self, request: ValidateContentDeliveryRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> ValidateContentDeliveryResponse: """Handle validate content delivery request.""" ... @@ -254,7 +254,7 @@ async def handle_validate_content_delivery( async def handle_get_media_buy_artifacts( self, request: GetMediaBuyArtifactsRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> GetMediaBuyArtifactsResponse: """Handle get media buy artifacts request.""" ... diff --git a/src/adcp/server/governance.py b/src/adcp/server/governance.py index 1cc540a4..d448f7ac 100644 --- a/src/adcp/server/governance.py +++ b/src/adcp/server/governance.py @@ -7,11 +7,11 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any +from typing import Any, Generic from pydantic import ValidationError -from adcp.server.base import ADCPHandler, NotImplementedResponse, ToolContext +from adcp.server.base import ADCPHandler, NotImplementedResponse, TContext from adcp.types import ( CheckGovernanceRequest, CheckGovernanceResponse, @@ -37,7 +37,7 @@ ) -class GovernanceHandler(ADCPHandler): +class GovernanceHandler(ADCPHandler[TContext], Generic[TContext]): """Handler for Governance protocol (Property Lists). Subclass this to implement a Governance agent that manages property lists @@ -55,7 +55,7 @@ class MyGovernanceHandler(GovernanceHandler): async def handle_create_property_list( self, request: CreatePropertyListRequest, - context: ToolContext | None = None + context: TContext | None = None ) -> CreatePropertyListResponse: # Store the list definition list_id = generate_id() @@ -72,7 +72,7 @@ async def handle_create_property_list( async def get_creative_features( self, params: GetCreativeFeaturesRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> GetCreativeFeaturesResponse | NotImplementedResponse: """Evaluate governance features for a creative manifest.""" try: @@ -88,7 +88,7 @@ async def get_creative_features( async def sync_plans( self, params: SyncPlansRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> SyncPlansResponse | NotImplementedResponse: """Sync campaign governance plans to the agent.""" try: @@ -104,7 +104,7 @@ async def sync_plans( async def check_governance( self, params: CheckGovernanceRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> CheckGovernanceResponse | NotImplementedResponse: """Check whether a proposed or committed action complies with plan governance.""" try: @@ -120,7 +120,7 @@ async def check_governance( async def report_plan_outcome( self, params: ReportPlanOutcomeRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> ReportPlanOutcomeResponse | NotImplementedResponse: """Report the outcome of a previously governed action.""" try: @@ -136,7 +136,7 @@ async def report_plan_outcome( async def get_plan_audit_logs( self, params: GetPlanAuditLogsRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> GetPlanAuditLogsResponse | NotImplementedResponse: """Retrieve governance audit logs for one or more plans.""" try: @@ -152,7 +152,7 @@ async def get_plan_audit_logs( async def create_property_list( self, params: CreatePropertyListRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> CreatePropertyListResponse | NotImplementedResponse: """Create a property list for governance filtering. @@ -171,7 +171,7 @@ async def create_property_list( async def get_property_list( self, params: GetPropertyListRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> GetPropertyListResponse | NotImplementedResponse: """Get a property list with optional resolution. @@ -190,7 +190,7 @@ async def get_property_list( async def list_property_lists( self, params: ListPropertyListsRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> ListPropertyListsResponse | NotImplementedResponse: """List property lists. @@ -209,7 +209,7 @@ async def list_property_lists( async def update_property_list( self, params: UpdatePropertyListRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> UpdatePropertyListResponse | NotImplementedResponse: """Update a property list. @@ -228,7 +228,7 @@ async def update_property_list( async def delete_property_list( self, params: DeletePropertyListRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> DeletePropertyListResponse | NotImplementedResponse: """Delete a property list. @@ -252,7 +252,7 @@ async def delete_property_list( async def handle_get_creative_features( self, request: GetCreativeFeaturesRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> GetCreativeFeaturesResponse: """Handle creative feature evaluation.""" ... @@ -261,7 +261,7 @@ async def handle_get_creative_features( async def handle_sync_plans( self, request: SyncPlansRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> SyncPlansResponse: """Handle campaign governance plan sync.""" ... @@ -270,7 +270,7 @@ async def handle_sync_plans( async def handle_check_governance( self, request: CheckGovernanceRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> CheckGovernanceResponse: """Handle a governance check request.""" ... @@ -279,7 +279,7 @@ async def handle_check_governance( async def handle_report_plan_outcome( self, request: ReportPlanOutcomeRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> ReportPlanOutcomeResponse: """Handle reporting of a governed action outcome.""" ... @@ -288,7 +288,7 @@ async def handle_report_plan_outcome( async def handle_get_plan_audit_logs( self, request: GetPlanAuditLogsRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> GetPlanAuditLogsResponse: """Handle retrieval of governance audit logs.""" ... @@ -297,7 +297,7 @@ async def handle_get_plan_audit_logs( async def handle_create_property_list( self, request: CreatePropertyListRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> CreatePropertyListResponse: """Handle create property list request.""" ... @@ -306,7 +306,7 @@ async def handle_create_property_list( async def handle_get_property_list( self, request: GetPropertyListRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> GetPropertyListResponse: """Handle get property list request.""" ... @@ -315,7 +315,7 @@ async def handle_get_property_list( async def handle_list_property_lists( self, request: ListPropertyListsRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> ListPropertyListsResponse: """Handle list property lists request.""" ... @@ -324,7 +324,7 @@ async def handle_list_property_lists( async def handle_update_property_list( self, request: UpdatePropertyListRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> UpdatePropertyListResponse: """Handle update property list request.""" ... @@ -333,7 +333,7 @@ async def handle_update_property_list( async def handle_delete_property_list( self, request: DeletePropertyListRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> DeletePropertyListResponse: """Handle delete property list request.""" ... diff --git a/src/adcp/server/mcp_tools.py b/src/adcp/server/mcp_tools.py index d2e1b9c4..7a958e12 100644 --- a/src/adcp/server/mcp_tools.py +++ b/src/adcp/server/mcp_tools.py @@ -1187,7 +1187,9 @@ def _apply_pydantic_schemas() -> None: _apply_pydantic_schemas() -def get_tools_for_handler(handler: ADCPHandler | type[ADCPHandler]) -> list[dict[str, Any]]: +def get_tools_for_handler( + handler: ADCPHandler[Any] | type[ADCPHandler[Any]], +) -> list[dict[str, Any]]: """Return tool definitions filtered by handler type. Walks the MRO to find the matching handler base class, so subclasses @@ -1211,7 +1213,7 @@ def get_tools_for_handler(handler: ADCPHandler | type[ADCPHandler]) -> list[dict def create_tool_caller( - handler: ADCPHandler, + handler: ADCPHandler[Any], method_name: str, ) -> Callable[..., Any]: """Create a tool caller function for an ADCP handler method. @@ -1257,7 +1259,7 @@ class MCPToolSet: Provides tool definitions and handlers for registering with an MCP server. """ - def __init__(self, handler: ADCPHandler): + def __init__(self, handler: ADCPHandler[Any]): """Create tool set from handler. Args: @@ -1299,7 +1301,7 @@ def get_tool_names(self) -> list[str]: return list(self._tools.keys()) -def create_mcp_tools(handler: ADCPHandler) -> MCPToolSet: +def create_mcp_tools(handler: ADCPHandler[Any]) -> MCPToolSet: """Create MCP tools from an ADCP handler. This is the main entry point for MCP server integration. diff --git a/src/adcp/server/serve.py b/src/adcp/server/serve.py index de336f71..3eb70c78 100644 --- a/src/adcp/server/serve.py +++ b/src/adcp/server/serve.py @@ -216,7 +216,7 @@ def build_context(meta: RequestMetadata) -> ToolContext: def serve( - handler: ADCPHandler | Any, + handler: ADCPHandler[Any] | Any, *, name: str = "adcp-agent", port: int | None = None, @@ -363,7 +363,7 @@ def _bind_reusable_socket(host: str, port: int) -> Any: def _serve_mcp( - handler: ADCPHandler, + handler: ADCPHandler[Any], *, name: str, port: int | None, @@ -429,7 +429,7 @@ async def _serve() -> None: def _serve_a2a( - handler: ADCPHandler, + handler: ADCPHandler[Any], *, name: str, port: int | None, @@ -471,7 +471,7 @@ async def _serve() -> None: def create_mcp_server( - handler: ADCPHandler, + handler: ADCPHandler[Any], *, name: str = "adcp-agent", port: int | None = None, @@ -571,7 +571,7 @@ def create_mcp_server( def _register_handler_tools( mcp: Any, - handler: ADCPHandler, + handler: ADCPHandler[Any], *, include_test_controller: bool = False, context_factory: ContextFactory | None = None, diff --git a/src/adcp/server/sponsored_intelligence.py b/src/adcp/server/sponsored_intelligence.py index 611a53d1..313e39ec 100644 --- a/src/adcp/server/sponsored_intelligence.py +++ b/src/adcp/server/sponsored_intelligence.py @@ -7,11 +7,11 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any +from typing import Any, Generic from pydantic import ValidationError -from adcp.server.base import ADCPHandler, NotImplementedResponse, ToolContext +from adcp.server.base import ADCPHandler, NotImplementedResponse, TContext from adcp.types import ( Error, SiGetOfferingRequest, @@ -25,7 +25,7 @@ ) -class SponsoredIntelligenceHandler(ADCPHandler): +class SponsoredIntelligenceHandler(ADCPHandler[TContext], Generic[TContext]): """Handler for Sponsored Intelligence protocol. Subclass this to implement a Sponsored Intelligence agent. All SI @@ -41,7 +41,7 @@ class MySIHandler(SponsoredIntelligenceHandler): async def handle_si_get_offering( self, request: SiGetOfferingRequest, - context: ToolContext | None = None + context: TContext | None = None ) -> SiGetOfferingResponse: # Your implementation return SiGetOfferingResponse(...) @@ -56,7 +56,7 @@ async def handle_si_get_offering( async def si_get_offering( self, params: SiGetOfferingRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> SiGetOfferingResponse | NotImplementedResponse: """Get sponsored intelligence offering. @@ -75,7 +75,7 @@ async def si_get_offering( async def si_initiate_session( self, params: SiInitiateSessionRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> SiInitiateSessionResponse | NotImplementedResponse: """Initiate sponsored intelligence session. @@ -94,7 +94,7 @@ async def si_initiate_session( async def si_send_message( self, params: SiSendMessageRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> SiSendMessageResponse | NotImplementedResponse: """Send message in sponsored intelligence session. @@ -113,7 +113,7 @@ async def si_send_message( async def si_terminate_session( self, params: SiTerminateSessionRequest | dict[str, Any], - context: ToolContext | None = None, + context: TContext | None = None, ) -> SiTerminateSessionResponse | NotImplementedResponse: """Terminate sponsored intelligence session. @@ -137,7 +137,7 @@ async def si_terminate_session( async def handle_si_get_offering( self, request: SiGetOfferingRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> SiGetOfferingResponse: """Handle get offering request.""" ... @@ -146,7 +146,7 @@ async def handle_si_get_offering( async def handle_si_initiate_session( self, request: SiInitiateSessionRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> SiInitiateSessionResponse: """Handle initiate session request.""" ... @@ -155,7 +155,7 @@ async def handle_si_initiate_session( async def handle_si_send_message( self, request: SiSendMessageRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> SiSendMessageResponse: """Handle send message request.""" ... @@ -164,7 +164,7 @@ async def handle_si_send_message( async def handle_si_terminate_session( self, request: SiTerminateSessionRequest, - context: ToolContext | None = None, + context: TContext | None = None, ) -> SiTerminateSessionResponse: """Handle terminate session request.""" ... diff --git a/src/adcp/server/tmp.py b/src/adcp/server/tmp.py index 77d0a5b4..4b05c4b3 100644 --- a/src/adcp/server/tmp.py +++ b/src/adcp/server/tmp.py @@ -2,10 +2,12 @@ from __future__ import annotations -from adcp.server.base import ADCPHandler +from typing import Generic +from adcp.server.base import ADCPHandler, TContext -class TmpHandler(ADCPHandler): + +class TmpHandler(ADCPHandler[TContext], Generic[TContext]): """Handler for Temporal Matching Protocol operations. Subclass this to implement context matching and identity matching. diff --git a/tests/test_handler_typevar.py b/tests/test_handler_typevar.py new file mode 100644 index 00000000..99daff5e --- /dev/null +++ b/tests/test_handler_typevar.py @@ -0,0 +1,278 @@ +"""Runtime coverage for ``ADCPHandler[TContext]`` — closes #223. + +The TypeVar work is a typing-level refactor (mypy-visible), but the +contract it promises has runtime consequences too: + +1. Existing ``class MyAgent(ADCPHandler)`` code keeps working without + edits — unparameterised subclasses must not break. +2. Parameterising with a ``ToolContext`` subclass is a legal Generic + subscription — ``ADCPHandler[MyContext]`` resolves at class-body + time. +3. Protocol handlers (``BrandHandler``, ``ContentStandardsHandler`` etc.) + propagate the same TypeVar — downstream can write + ``class MyBrand(BrandHandler[MyContext])``. +4. At dispatch time, the handler method receives whatever ``ToolContext`` + subclass the transport hands it — no isinstance check loses the + subclass type. + +These tests are behavioural, not type-system assertions — they verify +the TypeVar machinery doesn't impose a runtime cost and that the +subclass flows through the A2A/MCP invocation paths. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import pytest + +from adcp.server import ( + ADCPHandler, + BrandHandler, + ComplianceHandler, + TmpHandler, + ToolContext, +) +from adcp.server.base import TContext # noqa: F401 — imported to pin the export + + +@dataclass +class _PlatformAdapter: + """Stand-in for a real platform adapter — the typed field a downstream + would carry on their ToolContext subclass.""" + + name: str + + +@dataclass +class _TypedContext(ToolContext): + """Demonstrates the multi-tenant pattern: handlers need typed access + to tenant + adapter fields beyond what ToolContext names.""" + + adapter: _PlatformAdapter | None = None + + +# --------------------------------------------------------------------------- +# Unparameterised subclasses — existing pattern must keep working +# --------------------------------------------------------------------------- + + +def test_unparameterised_subclass_still_works(): + """``class MyAgent(ADCPHandler)`` with no TypeVar argument must + keep working for backward compat. The bulk of existing adopters + aren't ready to introduce typed context subclasses yet.""" + + class _MyAgent(ADCPHandler): + _agent_type = "test" + + async def get_adcp_capabilities(self, params, context=None): + return {"adcp": {"major_versions": [3]}} + + agent = _MyAgent() + assert agent._agent_type == "test" + + +def test_unparameterised_protocol_handler_still_works(): + """Same backward-compat check for the non-abstract protocol handler + bases — ``BrandHandler``, ``ComplianceHandler``, ``TmpHandler`` + don't declare additional abstract methods, so they can be + subclassed directly. ``ContentStandardsHandler``, ``GovernanceHandler``, + and ``SponsoredIntelligenceHandler`` have ``handle_*`` abstracts + (predating this PR) that subclasses must implement — covered + separately by the typed-subclass tests below.""" + for cls in (BrandHandler, ComplianceHandler, TmpHandler): + + class _Concrete(cls): # type: ignore[valid-type,misc] + _agent_type = "test" + + async def get_adcp_capabilities(self, params, context=None): + return {"adcp": {"major_versions": [3]}} + + instance = _Concrete() + assert instance._agent_type == "test" + + +# --------------------------------------------------------------------------- +# Parameterised subclasses — the new capability +# --------------------------------------------------------------------------- + + +def test_parameterised_adcphandler_subclass_resolves(): + """``class MyAgent(ADCPHandler[MyContext])`` must construct without + error — the Generic subscription is the core promise of #223.""" + + class _TypedAgent(ADCPHandler[_TypedContext]): + _agent_type = "typed" + + async def get_adcp_capabilities(self, params, context=None): + return {"adcp": {"major_versions": [3]}} + + agent = _TypedAgent() + # __class_getitem__ returned something sensible — we can subclass it + # and instantiate the subclass. + assert agent._agent_type == "typed" + + +def test_protocol_handler_propagates_typevar(): + """``BrandHandler[MyContext]`` must work the same way. Without + this the TypeVar on the base is useless for the specialised + handler classes.""" + + class _TypedBrand(BrandHandler[_TypedContext]): + _agent_type = "typed brand" + + agent = _TypedBrand() + assert agent._agent_type == "typed brand" + + +def test_handler_receives_subclass_at_dispatch_time(): + """The TypeVar is static-type narrowing, but the runtime path must + preserve the subclass identity on the ``context`` argument — a + handler that does ``context.adapter`` at runtime needs the subclass + to survive the dispatch.""" + received: list[Any] = [] + + class _TypedAgent(ADCPHandler[_TypedContext]): + _agent_type = "adapter-reader" + + async def get_adcp_capabilities(self, params, context=None): + received.append(context) + return {"adcp": {"major_versions": [3]}} + + import asyncio + + agent = _TypedAgent() + ctx = _TypedContext( + caller_identity="p-1", + tenant_id="t-1", + adapter=_PlatformAdapter(name="demo"), + ) + asyncio.run(agent.get_adcp_capabilities({}, ctx)) + + assert len(received) == 1 + got = received[0] + assert isinstance(got, _TypedContext) + assert got.adapter is not None + assert got.adapter.name == "demo" + assert got.caller_identity == "p-1" + assert got.tenant_id == "t-1" + + +def test_protocol_handler_subclass_receives_typed_context(): + """End-to-end for a specialised handler: BrandHandler[MyContext] + subclass's methods receive the typed subclass at dispatch.""" + received: list[Any] = [] + + class _TypedBrand(BrandHandler[_TypedContext]): + _agent_type = "typed-brand" + + async def get_adcp_capabilities(self, params, context=None): + received.append(context) + return {"adcp": {"major_versions": [3]}} + + import asyncio + + agent = _TypedBrand() + ctx = _TypedContext( + caller_identity="brand-p", + adapter=_PlatformAdapter(name="brand-adapter"), + ) + asyncio.run(agent.get_adcp_capabilities({}, ctx)) + + assert isinstance(received[0], _TypedContext) + assert received[0].adapter is not None + assert received[0].adapter.name == "brand-adapter" + + +# --------------------------------------------------------------------------- +# Negative case: the TypeVar has a bound +# --------------------------------------------------------------------------- + + +def test_typevar_is_bound_to_toolcontext(): + """The TypeVar bound prevents parameterising with an unrelated + class. At runtime Python doesn't enforce the bound (only mypy + does), so this test just asserts the bound attribute — the static + check is mypy's job and is covered by the CI mypy step.""" + from adcp.server.base import TContext as _TContext + + # TypeVar has __bound__ (forward ref or evaluated class). + bound = _TContext.__bound__ + # Forward ref evaluates to the string; evaluated binding to the class. + assert bound is ToolContext or ( + hasattr(bound, "__forward_arg__") and bound.__forward_arg__ == "ToolContext" + ) + + +# --------------------------------------------------------------------------- +# ADCPAgentExecutor integration — the subclass still flows through +# --------------------------------------------------------------------------- + + +async def test_typed_handler_works_under_a2a_executor(): + """A handler parameterised with a custom ToolContext subclass must + still dispatch correctly under the A2A executor. Runtime doesn't + touch the TypeVar directly (the executor passes whatever context + the context_factory returned), but this pins the no-regression + promise: adding the TypeVar didn't break the A2A dispatch path.""" + from a2a.server.agent_execution.context import RequestContext + from a2a.server.events.event_queue import EventQueue + from a2a.types import DataPart, Message, MessageSendParams, Part, Role, Task + + from adcp.server.a2a_server import ADCPAgentExecutor + + class _TypedAgent(ADCPHandler[_TypedContext]): + _agent_type = "typed-executor-test" + + async def get_adcp_capabilities(self, params, context=None): + return {"adcp": {"major_versions": [3]}, "supported_protocols": ["media_buy"]} + + executor = ADCPAgentExecutor(_TypedAgent()) + msg = Message( + message_id="m-1", + role=Role.user, + parts=[Part(root=DataPart(data={"skill": "get_adcp_capabilities", "parameters": {}}))], + ) + ctx = RequestContext(request=MessageSendParams(message=msg)) + queue = EventQueue() + await executor.execute(ctx, queue) + + event = await queue.dequeue_event(no_wait=True) + assert isinstance(event, Task) + assert event.status.state == "completed" + + +# --------------------------------------------------------------------------- +# Handler method signature annotations survive the TypeVar +# --------------------------------------------------------------------------- + + +def test_handler_method_signatures_accept_subclass_positionally(): + """A sanity check that handler methods accept a ``ToolContext`` + subclass positionally — the rewrite of 57 method sigs from + ``context: ToolContext | None`` to ``context: TContext | None`` + must not have shifted any parameter positions.""" + + class _Agent(ADCPHandler): + async def get_adcp_capabilities(self, params, context=None): + return {"adcp": {}} + + async def get_products(self, params, context=None): + return {"products": []} + + import inspect + + for method_name in ("get_adcp_capabilities", "get_products", "create_media_buy"): + method = getattr(ADCPHandler, method_name, None) + assert method is not None, f"{method_name} missing from ADCPHandler" + sig = inspect.signature(method) + params = list(sig.parameters.keys()) + # self, params, context — in that order, minimum. + assert params[0] == "self" + assert params[1] == "params" + assert "context" in params + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From df5898ca7dead3acb25fc7d5571ae1f88c55a52e Mon Sep 17 00:00:00 2001 From: Brian O'Kelley Date: Mon, 20 Apr 2026 10:06:08 -0400 Subject: [PATCH 2/2] =?UTF-8?q?fix(server):=20PR=20#234=20expert-review=20?= =?UTF-8?q?followups=20=E2=80=94=20TypeVar=20tests=20&=20docs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - test_unparameterised_protocol_handler_still_works now covers ContentStandardsHandler, GovernanceHandler, and SponsoredIntelligenceHandler by dynamically building concrete subclasses that stub every abstract handle_. Proves the TypeVar refactor didn't accidentally add a new abstract on the base. - test_typevar_is_bound_to_toolcontext now forces forward-ref resolution via typing.get_type_hints so a typo in the bound (e.g. "ToolContect") would fail the test. Previously the unresolved forward-ref string was accepted as proof enough. - test_handler_method_signatures_accept_subclass_positionally renamed to test_handler_method_signatures_preserve_parameter_order — reflects what it actually checks. - Documented ADCPHandler[Any] choice in mcp_tools.py and a2a_server.py module docstrings: these modules dispatch by tool name and never read typed context fields, so Any is correct and avoids cascading the TypeVar through plumbing. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/adcp/server/a2a_server.py | 9 +++ src/adcp/server/mcp_tools.py | 12 ++++ tests/test_handler_typevar.py | 113 ++++++++++++++++++++++++++-------- 3 files changed, 107 insertions(+), 27 deletions(-) diff --git a/src/adcp/server/a2a_server.py b/src/adcp/server/a2a_server.py index 60148e19..45b5c585 100644 --- a/src/adcp/server/a2a_server.py +++ b/src/adcp/server/a2a_server.py @@ -5,6 +5,15 @@ from adcp.server import ADCPHandler, serve serve(MyHandler(), name="my-agent", transport="a2a") + +.. note:: + Function signatures here use ``ADCPHandler[Any]`` rather than a + propagated ``TContext`` TypeVar. This module dispatches by tool + name and never reads typed fields off the context, so ``Any`` is + both correct and keeps the call sites tidy — downstream code that + needs typed context (their own handler subclass) keeps the TypeVar + all the way to dispatch via :class:`ADCPHandler`. See the matching + note in :mod:`adcp.server.mcp_tools`. """ from __future__ import annotations diff --git a/src/adcp/server/mcp_tools.py b/src/adcp/server/mcp_tools.py index 7a958e12..39f5293f 100644 --- a/src/adcp/server/mcp_tools.py +++ b/src/adcp/server/mcp_tools.py @@ -2,6 +2,18 @@ """MCP server integration helpers. Provides utilities for registering ADCP handlers with MCP servers. + +.. note:: + Function signatures in this module use ``ADCPHandler[Any]`` rather + than a propagated ``TContext`` TypeVar. The rationale: these + functions (``get_tools_for_handler``, ``create_mcp_tools``, etc.) + treat the handler opaquely — they walk the MRO and dispatch by tool + name without ever touching the ``context`` argument's typed fields. + Binding a TypeVar here would force callers to narrow at the call + site for no runtime benefit, and cascade the TypeVar through every + plumbing function in :mod:`adcp.server.serve`. ``Any`` keeps the + plumbing honest: the static type says "this code works with any + ``ToolContext`` subclass," which is exactly true. """ from __future__ import annotations diff --git a/tests/test_handler_typevar.py b/tests/test_handler_typevar.py index 99daff5e..5e408ca3 100644 --- a/tests/test_handler_typevar.py +++ b/tests/test_handler_typevar.py @@ -31,6 +31,9 @@ ADCPHandler, BrandHandler, ComplianceHandler, + ContentStandardsHandler, + GovernanceHandler, + SponsoredIntelligenceHandler, TmpHandler, ToolContext, ) @@ -74,13 +77,18 @@ async def get_adcp_capabilities(self, params, context=None): def test_unparameterised_protocol_handler_still_works(): - """Same backward-compat check for the non-abstract protocol handler - bases — ``BrandHandler``, ``ComplianceHandler``, ``TmpHandler`` - don't declare additional abstract methods, so they can be - subclassed directly. ``ContentStandardsHandler``, ``GovernanceHandler``, - and ``SponsoredIntelligenceHandler`` have ``handle_*`` abstracts - (predating this PR) that subclasses must implement — covered - separately by the typed-subclass tests below.""" + """Backward-compat check for every protocol handler base — + unparameterised subclasses must keep instantiating. + + ``BrandHandler``, ``ComplianceHandler``, ``TmpHandler`` are + non-abstract and subclass directly. + + ``ContentStandardsHandler``, ``GovernanceHandler``, and + ``SponsoredIntelligenceHandler`` declare ``handle_`` abstract + methods (predating this PR). We build minimal concrete subclasses + that stub every abstract so we can prove the TypeVar refactor + didn't accidentally add a new unimplementable abstract on the base. + """ for cls in (BrandHandler, ComplianceHandler, TmpHandler): class _Concrete(cls): # type: ignore[valid-type,misc] @@ -92,6 +100,40 @@ async def get_adcp_capabilities(self, params, context=None): instance = _Concrete() assert instance._agent_type == "test" + # Abstract bases — build concrete via a type() call so every + # abstract handle_ gets a stub in the class namespace at + # creation time (ABC machinery freezes __abstractmethods__ there). + for abstract_base in ( + ContentStandardsHandler, + GovernanceHandler, + SponsoredIntelligenceHandler, + ): + abstracts = { + name + for name in dir(abstract_base) + if name.startswith("handle_") + and getattr(getattr(abstract_base, name, None), "__isabstractmethod__", False) + } + + async def _capabilities(self, params, context=None): # noqa: ARG001 + return {"adcp": {"major_versions": [3]}} + + async def _stub(self, request, context=None): # noqa: ARG001 + return {} + + namespace: dict[str, Any] = { + "_agent_type": "test", + "get_adcp_capabilities": _capabilities, + } + for _name in abstracts: + namespace[_name] = _stub + + concrete = type(f"_{abstract_base.__name__}Concrete", (abstract_base,), namespace) + instance = concrete() + assert ( + instance._agent_type == "test" + ), f"{abstract_base.__name__} unparameterised subclass failed to instantiate" + # --------------------------------------------------------------------------- # Parameterised subclasses — the new capability @@ -193,16 +235,39 @@ async def get_adcp_capabilities(self, params, context=None): def test_typevar_is_bound_to_toolcontext(): """The TypeVar bound prevents parameterising with an unrelated class. At runtime Python doesn't enforce the bound (only mypy - does), so this test just asserts the bound attribute — the static - check is mypy's job and is covered by the CI mypy step.""" + does), so this test asserts the bound resolves to ``ToolContext`` — + not just that *some* bound exists. Previously this accepted the + unresolved forward-reference string as proof enough, which meant a + typo in the bound (e.g. ``ToolContect``) would have silently + passed.""" + import typing + + from adcp.server import base as _base from adcp.server.base import TContext as _TContext - # TypeVar has __bound__ (forward ref or evaluated class). bound = _TContext.__bound__ - # Forward ref evaluates to the string; evaluated binding to the class. - assert bound is ToolContext or ( - hasattr(bound, "__forward_arg__") and bound.__forward_arg__ == "ToolContext" - ) + if bound is None: + pytest.fail("TContext has no bound") + + # Force forward-ref resolution against the module namespace TContext + # lives in. typing.get_type_hints is the blessed API for this; it + # walks the annotation through typing._eval_type and returns the + # actual class the string resolves to. + if hasattr(bound, "__forward_arg__"): + resolved = typing.get_type_hints( + type( + "_Probe", + (), + {"__annotations__": {"x": bound}, "__module__": _base.__name__}, + ), + globalns=vars(_base), + )["x"] + else: + resolved = bound + + assert ( + resolved is ToolContext + ), f"TContext bound did not resolve to ToolContext; got {resolved!r}" # --------------------------------------------------------------------------- @@ -248,19 +313,13 @@ async def get_adcp_capabilities(self, params, context=None): # --------------------------------------------------------------------------- -def test_handler_method_signatures_accept_subclass_positionally(): - """A sanity check that handler methods accept a ``ToolContext`` - subclass positionally — the rewrite of 57 method sigs from - ``context: ToolContext | None`` to ``context: TContext | None`` - must not have shifted any parameter positions.""" - - class _Agent(ADCPHandler): - async def get_adcp_capabilities(self, params, context=None): - return {"adcp": {}} - - async def get_products(self, params, context=None): - return {"products": []} - +def test_handler_method_signatures_preserve_parameter_order(): + """Sanity check on the mechanical rewrite of 57 method sigs — the + ``context: ToolContext | None`` → ``context: TContext | None`` + change is a single-word swap in the annotation and must not have + shifted parameter positions or renamed ``params``. Failure here + typically means a stray sed corrupted a signature. + """ import inspect for method_name in ("get_adcp_capabilities", "get_products", "create_media_buy"):