diff --git a/docs/docs/using/plugins/index.md b/docs/docs/using/plugins/index.md index fd1b32590..e618a19a9 100644 --- a/docs/docs/using/plugins/index.md +++ b/docs/docs/using/plugins/index.md @@ -863,7 +863,7 @@ class MyPlugin(Plugin): ### Plugin Context and State -Each hook function has a `context` object of type `PluginContext` which is designed to allow plugins to pass state between one another (across pre/post hook pairs) or for a plugin to pass state information to itself across pre/post hook pairs. The plugin context looks as follows: +Each hook function has a `context` object of type `PluginContext` which is designed to allow plugins to pass state between one another across all hook types in a request, or for a plugin to pass state information to itself across different hooks. The plugin context looks as follows: ```python class GlobalContext(BaseModel): @@ -900,10 +900,9 @@ class PluginContext(BaseModel): metadata: dict[str, Any] = Field(default_factory=dict) ``` -As can be seen, the `PluginContext` has both a `state` dictionary and a `global_context` object that also has a `state` dictionary. A single plugin can share state between pre/post hook pairs by using the -the `PluginContext` state dictionary. It can share state with other plugins using the `context.global_context.state` dictionary. Metadata for the specific hook site is passed in through the `metadata` dictionaries in the `context.global_context.metadata`. It is meant to be read-only. The `context.metadata` is plugin specific metadata and can be used to store metadata information such as timing information. +As can be seen, the `PluginContext` has both a `state` dictionary and a `global_context` object that also has a `state` dictionary. A single plugin can share state across all hooks in a request by using the `PluginContext` state dictionary. It can share state with other plugins using the `context.global_context.state` dictionary. Metadata for the specific hook site is passed in through the `metadata` dictionaries in the `context.global_context.metadata`. It is meant to be read-only. The `context.metadata` is plugin specific metadata and can be used to store metadata information such as timing information. -The following shows how plugins can maintain state between pre/post hooks: +The following shows how plugins can maintain state across different hooks: ```python async def prompt_pre_fetch(self, payload, context): @@ -926,7 +925,7 @@ async def prompt_post_fetch(self, payload, context): #### Tool and Gateway Metadata -Currently, the tool pre/post hooks have access to tool and gateway metadata through the global context metadata dictionary. They are accessible as follows: +Tool hooks have access to tool and gateway metadata through the global context metadata dictionary. They are accessible as follows: It can be accessed inside of the tool hooks through: diff --git a/mcpgateway/auth.py b/mcpgateway/auth.py index c538a7479..2eaa8872f 100644 --- a/mcpgateway/auth.py +++ b/mcpgateway/auth.py @@ -170,22 +170,25 @@ async def get_current_user( headers = dict(request.headers) # Get request ID from request state (set by middleware) or generate new one - request_id = None - if request and hasattr(request, "state") and hasattr(request.state, "request_id"): - request_id = request.state.request_id - else: + request_id = getattr(request.state, "request_id", None) if request else None + if not request_id: request_id = uuid.uuid4().hex - # Create global context - global_context = GlobalContext( - request_id=request_id, - server_id=None, - tenant_id=None, - ) + # Get plugin contexts from request state if available + global_context = getattr(request.state, "plugin_global_context", None) if request else None + if not global_context: + # Create global context + global_context = GlobalContext( + request_id=request_id, + server_id=None, + tenant_id=None, + ) + + context_table = getattr(request.state, "plugin_context_table", None) if request else None # Invoke custom auth resolution hook # violations_as_exceptions=True so PluginViolationError is raised for explicit denials - auth_result, _ = await plugin_manager.invoke_hook( + auth_result, context_table_result = await plugin_manager.invoke_hook( HttpHookType.HTTP_AUTH_RESOLVE_USER, payload=HttpAuthResolveUserPayload( credentials=credentials_dict, @@ -194,7 +197,7 @@ async def get_current_user( client_port=client_port, ), global_context=global_context, - local_contexts=None, + local_contexts=context_table, violations_as_exceptions=True, # Raise PluginViolationError for auth denials ) @@ -215,12 +218,17 @@ async def get_current_user( ) # Store auth_method in request.state so it can be accessed by RBAC middleware - if request and hasattr(request, "state") and auth_result.metadata: + if request and auth_result.metadata: auth_method = auth_result.metadata.get("auth_method") if auth_method: request.state.auth_method = auth_method logger.debug(f"Stored auth_method '{auth_method}' in request.state") + if request and context_table_result: + request.state.plugin_context_table = context_table_result + + if request and global_context: + request.state.plugin_global_context = global_context return user # If continue_processing=True (no payload), fall through to standard auth @@ -294,7 +302,7 @@ async def get_current_user( # Check team level token, if applicable. If public token, then will be defaulted to personal team. team_id = await get_team_from_token(payload, db) - if request and hasattr(request, "state"): + if request: request.state.team_id = team_id except HTTPException: diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 032e30c5d..bcb276958 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -2960,9 +2960,21 @@ async def read_resource(resource_id: str, request: Request, db: Session = Depend if cached := resource_cache.get(resource_id): return cached + # Get plugin contexts from request.state for cross-hook sharing + plugin_context_table = getattr(request.state, "plugin_context_table", None) + plugin_global_context = getattr(request.state, "plugin_global_context", None) + try: # Call service with context for plugin support - content = await resource_service.read_resource(db, resource_id=resource_id, request_id=request_id, user=user, server_id=server_id) + content = await resource_service.read_resource( + db, + resource_id=resource_id, + request_id=request_id, + user=user, + server_id=server_id, + plugin_context_table=plugin_context_table, + plugin_global_context=plugin_global_context, + ) except (ResourceNotFoundError, ResourceError) as exc: # Translate to FastAPI HTTP error raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc @@ -3289,6 +3301,7 @@ async def create_prompt( @prompt_router.post("/{prompt_id}") @require_permission("prompts.read") async def get_prompt( + request: Request, prompt_id: str, args: Dict[str, str] = Body({}), db: Session = Depends(get_db), @@ -3301,6 +3314,7 @@ async def get_prompt( Args: + request: FastAPI request object. prompt_id: ID of the prompt. args: Template arguments. db: Database session. @@ -3314,9 +3328,19 @@ async def get_prompt( """ logger.debug(f"User: {user} requested prompt: {prompt_id} with args={args}") + # Get plugin contexts from request.state for cross-hook sharing + plugin_context_table = getattr(request.state, "plugin_context_table", None) + plugin_global_context = getattr(request.state, "plugin_global_context", None) + try: PromptExecuteArgs(args=args) - result = await prompt_service.get_prompt(db, prompt_id, args) + result = await prompt_service.get_prompt( + db, + prompt_id, + args, + plugin_context_table=plugin_context_table, + plugin_global_context=plugin_global_context, + ) logger.debug(f"Prompt execution successful for '{prompt_id}'") except Exception as ex: logger.error(f"Could not retrieve prompt {prompt_id}: {ex}") @@ -3334,6 +3358,7 @@ async def get_prompt( @prompt_router.get("/{prompt_id}") @require_permission("prompts.read") async def get_prompt_no_args( + request: Request, prompt_id: str, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions), @@ -3343,6 +3368,7 @@ async def get_prompt_no_args( This endpoint is for convenience when no arguments are needed. Args: + request: FastAPI request object. prompt_id: The ID of the prompt to retrieve db: Database session user: Authenticated user @@ -3354,7 +3380,18 @@ async def get_prompt_no_args( Exception: Re-raised from prompt service. """ logger.debug(f"User: {user} requested prompt: {prompt_id} with no arguments") - return await prompt_service.get_prompt(db, prompt_id, {}) + + # Get plugin contexts from request.state for cross-hook sharing + plugin_context_table = getattr(request.state, "plugin_context_table", None) + plugin_global_context = getattr(request.state, "plugin_global_context", None) + + return await prompt_service.get_prompt( + db, + prompt_id, + {}, + plugin_context_table=plugin_context_table, + plugin_global_context=plugin_global_context, + ) @prompt_router.put("/{prompt_id}", response_model=PromptRead) @@ -3921,8 +3958,18 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen raise JSONRPCError(-32602, "Missing resource URI in parameters", params) # Get user email for OAuth token selection user_email = get_user_email(user) + # Get plugin contexts from request.state for cross-hook sharing + plugin_context_table = getattr(request.state, "plugin_context_table", None) + plugin_global_context = getattr(request.state, "plugin_global_context", None) try: - result = await resource_service.read_resource(db, uri, request_id=request_id, user=user_email) + result = await resource_service.read_resource( + db, + resource_uri=uri, + request_id=request_id, + user=user_email, + plugin_context_table=plugin_context_table, + plugin_global_context=plugin_global_context, + ) if hasattr(result, "model_dump"): result = {"contents": [result.model_dump(by_alias=True, exclude_none=True)]} else: @@ -3966,7 +4013,16 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen arguments = params.get("arguments", {}) if not name: raise JSONRPCError(-32602, "Missing prompt name in parameters", params) - result = await prompt_service.get_prompt(db, name, arguments) + # Get plugin contexts from request.state for cross-hook sharing + plugin_context_table = getattr(request.state, "plugin_context_table", None) + plugin_global_context = getattr(request.state, "plugin_global_context", None) + result = await prompt_service.get_prompt( + db, + name, + arguments, + plugin_context_table=plugin_context_table, + plugin_global_context=plugin_global_context, + ) if hasattr(result, "model_dump"): result = result.model_dump(by_alias=True, exclude_none=True) elif method == "ping": @@ -3981,8 +4037,19 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen raise JSONRPCError(-32602, "Missing tool name in parameters", params) # Get user email for OAuth token selection user_email = get_user_email(user) + # Get plugin contexts from request.state for cross-hook sharing + plugin_context_table = getattr(request.state, "plugin_context_table", None) + plugin_global_context = getattr(request.state, "plugin_global_context", None) try: - result = await tool_service.invoke_tool(db=db, name=name, arguments=arguments, request_headers=headers, app_user_email=user_email) + result = await tool_service.invoke_tool( + db=db, + name=name, + arguments=arguments, + request_headers=headers, + app_user_email=user_email, + plugin_context_table=plugin_context_table, + plugin_global_context=plugin_global_context, + ) if hasattr(result, "model_dump"): result = result.model_dump(by_alias=True, exclude_none=True) except ValueError: diff --git a/mcpgateway/middleware/http_auth_middleware.py b/mcpgateway/middleware/http_auth_middleware.py index 5b8940279..84058641f 100644 --- a/mcpgateway/middleware/http_auth_middleware.py +++ b/mcpgateway/middleware/http_auth_middleware.py @@ -95,6 +95,12 @@ async def dispatch(self, request: Request, call_next): violations_as_exceptions=False, # Don't block on pre-request violations ) + if context_table: + request.state.plugin_context_table = context_table + + if global_context: + request.state.plugin_global_context = global_context + # Apply modified headers if plugin returned them if pre_result.modified_payload: # Modify request headers by updating request.scope["headers"] diff --git a/mcpgateway/middleware/rbac.py b/mcpgateway/middleware/rbac.py index e8fc63080..6ea8f07d2 100644 --- a/mcpgateway/middleware/rbac.py +++ b/mcpgateway/middleware/rbac.py @@ -150,6 +150,11 @@ async def protected_route(user = Depends(get_current_user_with_permissions)): request_id = getattr(request.state, "request_id", None) team_id = getattr(request.state, "team_id", None) + # Read plugin context data from request.state for cross-hook context sharing + # (set by HttpAuthMiddleware for passing contexts between different hook types) + plugin_context_table = getattr(request.state, "plugin_context_table", None) + plugin_global_context = getattr(request.state, "plugin_global_context", None) + # Add request context for permission auditing return { "email": user.email, @@ -161,6 +166,8 @@ async def protected_route(user = Depends(get_current_user_with_permissions)): "auth_method": auth_method, # Include auth_method from plugin "request_id": request_id, # Include request_id from middleware "team_id": team_id, # Include team_id from token + "plugin_context_table": plugin_context_table, # Plugin contexts for cross-hook sharing + "plugin_global_context": plugin_global_context, # Global context for consistency } except Exception as e: logger.error(f"Authentication failed: {type(e).__name__}: {e}") @@ -256,18 +263,24 @@ async def wrapper(*args, **kwargs): plugin_manager = get_plugin_manager() if plugin_manager: - # Get request_id from user_context (passed from get_current_user_with_permissions) - # Generate a fallback if not present - request_id = user_context.get("request_id") or uuid.uuid4().hex - - # Create global context for plugin invocation - global_context = GlobalContext( - request_id=request_id, - server_id=None, - tenant_id=None, - ) + # Get plugin contexts from user_context (stored in request.state by HttpAuthMiddleware) + # These enable cross-hook context sharing between HTTP_PRE_REQUEST and HTTP_AUTH_CHECK_PERMISSION + plugin_context_table = user_context.get("plugin_context_table") + plugin_global_context = user_context.get("plugin_global_context") + + # Reuse existing global context from middleware if available for consistency + # Otherwise create a new one (fallback for cases where middleware didn't run) + if plugin_global_context: + global_context = plugin_global_context + else: + request_id = user_context.get("request_id") or uuid.uuid4().hex + global_context = GlobalContext( + request_id=request_id, + server_id=None, + tenant_id=None, + ) - # Invoke permission check hook + # Invoke permission check hook, passing plugin contexts from HTTP_PRE_REQUEST hook result, _ = await plugin_manager.invoke_hook( HttpHookType.HTTP_AUTH_CHECK_PERMISSION, payload=HttpAuthCheckPermissionPayload( @@ -281,6 +294,7 @@ async def wrapper(*args, **kwargs): user_agent=user_context.get("user_agent"), ), global_context=global_context, + local_contexts=plugin_context_table, # Pass context table for cross-hook state ) # If a plugin made a decision, respect it diff --git a/mcpgateway/plugins/framework/__init__.py b/mcpgateway/plugins/framework/__init__.py index 3a4286bb4..5c19f364e 100644 --- a/mcpgateway/plugins/framework/__init__.py +++ b/mcpgateway/plugins/framework/__init__.py @@ -54,6 +54,7 @@ PluginCondition, PluginConfig, PluginContext, + PluginContextTable, PluginErrorModel, PluginMode, PluginPayload, @@ -120,6 +121,7 @@ def get_plugin_manager() -> Optional[PluginManager]: "PluginCondition", "PluginConfig", "PluginContext", + "PluginContextTable", "PluginError", "PluginErrorModel", "PluginLoader", diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index 456ed2461..b0c0b8bad 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -36,7 +36,7 @@ from mcpgateway.db import Prompt as DbPrompt from mcpgateway.db import PromptMetric, server_prompt_association from mcpgateway.observability import create_span -from mcpgateway.plugins.framework import GlobalContext, PluginManager, PromptHookType, PromptPosthookPayload, PromptPrehookPayload +from mcpgateway.plugins.framework import GlobalContext, PluginContextTable, PluginManager, PromptHookType, PromptPosthookPayload, PromptPrehookPayload from mcpgateway.schemas import PromptCreate, PromptRead, PromptUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.observability_service import current_trace_id, ObservabilityService @@ -652,6 +652,8 @@ async def get_prompt( tenant_id: Optional[str] = None, server_id: Optional[str] = None, request_id: Optional[str] = None, + plugin_context_table: Optional[PluginContextTable] = None, + plugin_global_context: Optional[GlobalContext] = None, ) -> PromptResult: """Get a prompt template and optionally render it. @@ -663,6 +665,8 @@ async def get_prompt( tenant_id: Optional tenant identifier for plugin context server_id: Optional server identifier for plugin context request_id: Optional request ID, generated if not provided + plugin_context_table: Optional plugin context table from previous hooks for cross-hook state sharing. + plugin_global_context: Optional global context from middleware for consistency across hooks. Returns: Prompt result with rendered messages @@ -747,14 +751,30 @@ async def get_prompt( prompt_id_int = prompt_id if self._plugin_manager: - if not request_id: - request_id = uuid.uuid4().hex - global_context = GlobalContext(request_id=request_id, user=user, server_id=server_id, tenant_id=tenant_id) + # Use existing context_table from previous hooks if available + context_table = plugin_context_table + + # Reuse existing global_context from middleware or create new one + if plugin_global_context: + global_context = plugin_global_context + # Update fields with prompt-specific information + if user: + global_context.user = user + if server_id: + global_context.server_id = server_id + if tenant_id: + global_context.tenant_id = tenant_id + else: + # Create new context (fallback when middleware didn't run) + if not request_id: + request_id = uuid.uuid4().hex + global_context = GlobalContext(request_id=request_id, user=user, server_id=server_id, tenant_id=tenant_id) + pre_result, context_table = await self._plugin_manager.invoke_hook( PromptHookType.PROMPT_PRE_FETCH, payload=PromptPrehookPayload(prompt_id=str(prompt_id), args=arguments), global_context=global_context, - local_contexts=None, + local_contexts=context_table, # Pass context from previous hooks violations_as_exceptions=True, ) diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index 5bc52b450..5cb8670f4 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -59,7 +59,7 @@ # Plugin support imports (conditional) try: # First-Party - from mcpgateway.plugins.framework import GlobalContext, PluginManager, ResourceHookType, ResourcePostFetchPayload, ResourcePreFetchPayload + from mcpgateway.plugins.framework import GlobalContext, PluginContextTable, PluginManager, ResourceHookType, ResourcePostFetchPayload, ResourcePreFetchPayload PLUGINS_AVAILABLE = True except ImportError: @@ -685,6 +685,8 @@ async def read_resource( user: Optional[str] = None, server_id: Optional[str] = None, include_inactive: bool = False, + plugin_context_table: Optional[PluginContextTable] = None, + plugin_global_context: Optional[GlobalContext] = None, ) -> ResourceContent: """Read a resource's content with plugin hook support. @@ -696,6 +698,8 @@ async def read_resource( user: Optional user making the request. server_id: Optional server ID for context. include_inactive: Whether to include inactive resources. Defaults to False. + plugin_context_table: Optional plugin context table from previous hooks for cross-hook state sharing. + plugin_global_context: Optional global context from middleware for consistency across hooks. Returns: Resource content object @@ -811,12 +815,29 @@ async def read_resource( # Attempt to fallback to attribute access user_id = getattr(user, "email", None) - global_context = GlobalContext(request_id=request_id, user=user_id, server_id=server_id) + # Use existing global_context from middleware or create new one + if plugin_global_context: + global_context = plugin_global_context + # Update fields with resource-specific information + if user_id: + global_context.user = user_id + if server_id: + global_context.server_id = server_id + else: + # Create new context (fallback when middleware didn't run) + global_context = GlobalContext(request_id=request_id, user=user_id, server_id=server_id) + # Create pre-fetch payload pre_payload = ResourcePreFetchPayload(uri=uri, metadata={}) - # Execute pre-fetch hooks - pre_result, contexts = await self._plugin_manager.invoke_hook(ResourceHookType.RESOURCE_PRE_FETCH, pre_payload, global_context, violations_as_exceptions=True) + # Execute pre-fetch hooks with context from previous hooks + pre_result, contexts = await self._plugin_manager.invoke_hook( + ResourceHookType.RESOURCE_PRE_FETCH, + pre_payload, + global_context, + local_contexts=plugin_context_table, # Pass context from previous hooks + violations_as_exceptions=True, + ) # Use modified URI if plugin changed it if pre_result.modified_payload: uri = pre_result.modified_payload.uri diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 2314c10b5..ae8b9ce1f 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -51,7 +51,17 @@ from mcpgateway.db import Tool as DbTool from mcpgateway.db import ToolMetric from mcpgateway.observability import create_span -from mcpgateway.plugins.framework import GlobalContext, HttpHeaderPayload, PluginError, PluginManager, PluginViolationError, ToolHookType, ToolPostInvokePayload, ToolPreInvokePayload +from mcpgateway.plugins.framework import ( + GlobalContext, + HttpHeaderPayload, + PluginContextTable, + PluginError, + PluginManager, + PluginViolationError, + ToolHookType, + ToolPostInvokePayload, + ToolPreInvokePayload, +) from mcpgateway.plugins.framework.constants import GATEWAY_METADATA, TOOL_METADATA from mcpgateway.schemas import ToolCreate, ToolRead, ToolUpdate, TopPerformer from mcpgateway.services.logging_service import LoggingService @@ -1103,7 +1113,16 @@ async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, re db.rollback() raise ToolError(f"Failed to toggle tool status: {str(e)}") - async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], request_headers: Optional[Dict[str, str]] = None, app_user_email: Optional[str] = None) -> ToolResult: + async def invoke_tool( + self, + db: Session, + name: str, + arguments: Dict[str, Any], + request_headers: Optional[Dict[str, str]] = None, + app_user_email: Optional[str] = None, + plugin_context_table: Optional[PluginContextTable] = None, + plugin_global_context: Optional[GlobalContext] = None, + ) -> ToolResult: """ Invoke a registered tool and record execution metrics. @@ -1115,6 +1134,8 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r Defaults to None. app_user_email (Optional[str], optional): MCP Gateway user email for OAuth token retrieval. Required for OAuth-protected gateways. + plugin_context_table: Optional plugin context table from previous hooks for cross-hook state sharing. + plugin_global_context: Optional global context from middleware for consistency across hooks. Returns: Tool invocation result. @@ -1157,12 +1178,22 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r return await self._invoke_a2a_tool(db=db, tool=tool, arguments=arguments) # Plugin hook: tool pre-invoke - context_table = None - request_id = uuid.uuid4().hex - # Use gateway_id if available, otherwise use a generic server identifier - gateway_id = getattr(tool, "gateway_id", "unknown") - server_id = gateway_id if isinstance(gateway_id, str) else "unknown" - global_context = GlobalContext(request_id=request_id, server_id=server_id, tenant_id=None) + # Use existing context_table from previous hooks if available + context_table = plugin_context_table + + # Reuse existing global_context from middleware or create new one + if plugin_global_context: + global_context = plugin_global_context + # Update server_id if we have better information + gateway_id = getattr(tool, "gateway_id", None) + if gateway_id and isinstance(gateway_id, str): + global_context.server_id = gateway_id + else: + # Create new context (fallback when middleware didn't run) + request_id = uuid.uuid4().hex + gateway_id = getattr(tool, "gateway_id", "unknown") + server_id = gateway_id if isinstance(gateway_id, str) else "unknown" + global_context = GlobalContext(request_id=request_id, server_id=server_id, tenant_id=None, user=app_user_email) start_time = time.monotonic() success = False @@ -1210,7 +1241,7 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any], r ToolHookType.TOOL_PRE_INVOKE, payload=ToolPreInvokePayload(name=name, args=arguments, headers=HttpHeaderPayload(root=headers)), global_context=global_context, - local_contexts=None, + local_contexts=context_table, # Pass context from previous hooks violations_as_exceptions=True, ) if pre_result.modified_payload: diff --git a/plugins/README.md b/plugins/README.md index c6862cd68..09a93c119 100644 --- a/plugins/README.md +++ b/plugins/README.md @@ -521,7 +521,7 @@ async def tool_pre_invoke( user = context.global_context.user tenant_id = context.global_context.tenant_id - # Store plugin-specific state (persists across pre/post hooks) + # Store plugin-specific state (persists across all hooks in the request) context.state["invocation_count"] = context.state.get("invocation_count", 0) + 1 # Add metadata diff --git a/tests/integration/test_cross_hook_context_sharing.py b/tests/integration/test_cross_hook_context_sharing.py new file mode 100644 index 000000000..208955238 --- /dev/null +++ b/tests/integration/test_cross_hook_context_sharing.py @@ -0,0 +1,281 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/integration/test_cross_hook_context_sharing.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Integration tests for cross-hook context sharing functionality. + +These tests verify that plugin contexts are properly shared across different +hook types (HTTP → Tool, HTTP → Resource, HTTP → Prompt, RBAC hooks, etc.). +""" + +# Standard +import os +from pathlib import Path + +# Third-Party +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +# First-Party +from mcpgateway.db import Base +from mcpgateway.main import app +from mcpgateway.middleware.http_auth_middleware import HttpAuthMiddleware +from mcpgateway.plugins.framework import PluginManager + + +class TestCrossHookContextSharing: + """Integration tests for cross-hook context sharing. + + These tests verify that: + 1. Context stored in HTTP_PRE_REQUEST is accessible in HTTP_AUTH_CHECK_PERMISSION + 2. Context stored in HTTP hooks is accessible in MCP hooks (Tool, Resource, Prompt) + 3. GlobalContext is properly shared across all hooks + 4. Plugin state is isolated per plugin + """ + + @pytest.fixture + def test_db(self): + """Create a test database.""" + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + SessionLocal = sessionmaker(bind=engine) + db = SessionLocal() + yield db + db.close() + + @pytest.fixture + async def plugin_manager(self): + """Create plugin manager with cross-hook context test plugin.""" + # Get the path to the test plugin config + config_file = Path(__file__).parent.parent / "unit" / "mcpgateway" / "plugins" / "fixtures" / "configs" / "cross_hook_context.yaml" + + # Enable plugins for this test + with pytest.MonkeyPatch.context() as mp: + mp.setenv("PLUGINS_ENABLED", "true") + mp.setenv("PLUGIN_CONFIG_FILE", str(config_file)) + + # Create plugin manager + manager = PluginManager(str(config_file)) + await manager.initialize() + + yield manager + + # Cleanup + await manager.shutdown() + + @pytest.fixture + def test_client_with_plugins(self, plugin_manager): + """Create test client with plugin middleware enabled.""" + # Add the HttpAuthMiddleware with plugin manager + app.add_middleware(HttpAuthMiddleware, plugin_manager=plugin_manager) + + with TestClient(app) as client: + yield client + + @pytest.mark.asyncio + async def test_http_to_rbac_context_sharing(self, test_client_with_plugins, plugin_manager): + """Test context sharing from HTTP_PRE_REQUEST to HTTP_AUTH_CHECK_PERMISSION. + + This test verifies that: + 1. HTTP_PRE_REQUEST hook stores context data + 2. HTTP_AUTH_CHECK_PERMISSION hook can read that data + 3. The plugin doesn't raise any ValueError about missing context + """ + # Make a request that triggers both HTTP_PRE_REQUEST and HTTP_AUTH_CHECK_PERMISSION + response = test_client_with_plugins.get( + "/tools", + headers={"Authorization": "Bearer test-token"} + ) + + # If cross-hook context sharing works, the plugin won't raise ValueError + # and the request will succeed (or fail for other reasons like auth) + # The important thing is that we don't get a 500 error from the plugin + + # Note: This might return 401 if auth fails, but that's OK - + # we're testing that the plugin's cross-hook context access works + assert response.status_code in [200, 401], \ + "Plugin should not raise ValueError about missing context" + + @pytest.mark.asyncio + async def test_http_to_tool_context_sharing( + self, test_db, test_client_with_plugins, plugin_manager + ): + """Test context sharing from HTTP hooks to TOOL_PRE_INVOKE hook. + + This test verifies that: + 1. HTTP_PRE_REQUEST stores context + 2. TOOL_PRE_INVOKE can read HTTP context data + 3. TOOL_PRE_INVOKE can also read HTTP_AUTH_CHECK_PERMISSION data + """ + # First, set up a test tool + from mcpgateway.schemas import ToolCreate + from mcpgateway.services.tool_service import ToolService + + tool_service = ToolService() + + # Register a test tool + tool_data = ToolCreate( + name="test_cross_hook_tool", + description="Test tool for cross-hook context", + input_schema={"type": "object", "properties": {}}, + ) + + await tool_service.register_tool(test_db, tool_data) + + # Make a request to invoke the tool + response = test_client_with_plugins.post( + "/rpc/", + json={ + "jsonrpc": "2.0", + "id": "test-1", + "method": "tools/call", + "params": { + "name": "test_cross_hook_tool", + "arguments": {} + } + }, + headers={"Authorization": "Bearer test-token"} + ) + + # The plugin should successfully access context from HTTP hooks + # If it fails to find the context, it will raise ValueError and return 500 + assert response.status_code != 500, \ + "Cross-hook context sharing should work for HTTP → Tool" + + @pytest.mark.asyncio + async def test_http_to_resource_context_sharing( + self, test_db, test_client_with_plugins, plugin_manager + ): + """Test context sharing from HTTP hooks to RESOURCE_PRE_FETCH hook. + + This test verifies that context stored in HTTP_PRE_REQUEST is + accessible in the RESOURCE_PRE_FETCH hook. + """ + # First, set up a test resource + from mcpgateway.schemas import ResourceCreate + from mcpgateway.services.resource_service import ResourceService + + resource_service = ResourceService() + + # Register a test resource + resource_data = ResourceCreate( + uri="test://cross-hook-resource", + name="Cross-hook test resource", + content="Test content", + mime_type="text/plain", + ) + + created = await resource_service.register_resource(test_db, resource_data) + + # Make a request to read the resource + response = test_client_with_plugins.get( + f"/resources/{created.id}", + headers={"Authorization": "Bearer test-token"} + ) + + # The plugin should successfully access context from HTTP hooks + assert response.status_code != 500, \ + "Cross-hook context sharing should work for HTTP → Resource" + + @pytest.mark.asyncio + async def test_http_to_prompt_context_sharing( + self, test_db, test_client_with_plugins, plugin_manager + ): + """Test context sharing from HTTP hooks to PROMPT_PRE_FETCH hook. + + This test verifies that context stored in HTTP_PRE_REQUEST is + accessible in the PROMPT_PRE_FETCH hook. + """ + # First, set up a test prompt + from mcpgateway.schemas import PromptCreate + from mcpgateway.services.prompt_service import PromptService + + prompt_service = PromptService() + + # Register a test prompt + prompt_data = PromptCreate( + name="test_cross_hook_prompt", + template="Hello {name}!", + description="Test prompt for cross-hook context", + ) + + created = await prompt_service.register_prompt( + test_db, + prompt_data, + user_email="test@example.com" + ) + + # Make a request to get the prompt + response = test_client_with_plugins.get( + f"/prompts/{created.name}", + headers={"Authorization": "Bearer test-token"} + ) + + # The plugin should successfully access context from HTTP hooks + assert response.status_code != 500, \ + "Cross-hook context sharing should work for HTTP → Prompt" + + @pytest.mark.asyncio + async def test_global_context_consistency(self, test_client_with_plugins, plugin_manager): + """Test that GlobalContext is consistent across all hooks. + + This test verifies that the same GlobalContext instance (or at least + the same request_id) is used across all hooks in a single request. + """ + # Make a request that triggers multiple hooks + response = test_client_with_plugins.get( + "/tools", + headers={"Authorization": "Bearer test-token"} + ) + + # The plugin stores request_id in global context during HTTP_PRE_REQUEST + # and verifies it's present in subsequent hooks + # If the global context wasn't shared, the plugin would raise ValueError + + assert response.status_code in [200, 401], \ + "GlobalContext should be consistent across all hooks" + + @pytest.mark.asyncio + async def test_plugin_context_isolation(self, plugin_manager): + """Test that plugin contexts are properly isolated. + + This test verifies that each plugin gets its own isolated context + and cannot access other plugins' context data. + """ + from mcpgateway.plugins.framework import ( + GlobalContext, + HttpPreRequestPayload, + HttpHeaderPayload, + HttpHookType, + ) + + # Create a global context + global_context = GlobalContext(request_id="test-isolation-123") + + # Invoke HTTP_PRE_REQUEST hook (which stores data in context) + payload = HttpPreRequestPayload( + path="/test", + method="GET", + headers=HttpHeaderPayload(root={}), + ) + + result, context_table = await plugin_manager.invoke_hook( + HttpHookType.HTTP_PRE_REQUEST, + payload=payload, + global_context=global_context, + ) + + # Verify context table was created + assert context_table is not None + assert len(context_table) > 0 + + # Verify each plugin has its own isolated context + # The key format is: request_id + plugin_uuid + for key, context in context_table.items(): + assert key.startswith("test-isolation-123") + assert "http_timestamp" in context.state + # Each plugin should only see its own state + assert context.state["http_timestamp"] == "2025-01-01T00:00:00Z" diff --git a/tests/unit/mcpgateway/plugins/fixtures/configs/cross_hook_context.yaml b/tests/unit/mcpgateway/plugins/fixtures/configs/cross_hook_context.yaml new file mode 100644 index 000000000..4926eb043 --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/configs/cross_hook_context.yaml @@ -0,0 +1,22 @@ +plugins: + - name: "CrossHookContextPlugin" + kind: "tests.unit.mcpgateway.plugins.fixtures.plugins.cross_hook_context.CrossHookContextPlugin" + description: "Test plugin that demonstrates cross-hook context sharing" + version: "1.0.0" + author: "Test Author" + hooks: + - "http_pre_request" + - "http_auth_check_permission" + - "tool_pre_invoke" + - "resource_pre_fetch" + - "prompt_pre_fetch" + tags: ["test", "cross-hook", "context-sharing"] + mode: "enforce" + priority: 50 + config: {} + +plugin_settings: + parallel_execution_within_band: false + plugin_timeout: 30 + fail_on_plugin_error: true + plugin_health_check_interval: 60 diff --git a/tests/unit/mcpgateway/plugins/fixtures/plugins/cross_hook_context.py b/tests/unit/mcpgateway/plugins/fixtures/plugins/cross_hook_context.py new file mode 100644 index 000000000..7cfa79f3b --- /dev/null +++ b/tests/unit/mcpgateway/plugins/fixtures/plugins/cross_hook_context.py @@ -0,0 +1,252 @@ +# -*- coding: utf-8 -*- + +"""Location: ./tests/unit/mcpgateway/plugins/fixtures/plugins/cross_hook_context.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Cross-hook context sharing test plugin. + +This plugin demonstrates sharing context across different hook types: +- HTTP_PRE_REQUEST stores data +- HTTP_AUTH_CHECK_PERMISSION reads and verifies data +- TOOL_PRE_INVOKE reads and adds more data +- RESOURCE_PRE_FETCH reads and adds more data +- PROMPT_PRE_FETCH reads and adds more data +""" + +import logging + +from mcpgateway.plugins.framework import ( + HttpAuthCheckPermissionPayload, + HttpAuthCheckPermissionResult, + HttpPreRequestPayload, + HttpPreRequestResult, + Plugin, + PluginContext, + PromptPrehookPayload, + PromptPrehookResult, + ResourcePreFetchPayload, + ResourcePreFetchResult, + ToolPreInvokePayload, + ToolPreInvokeResult, +) + +logger = logging.getLogger("cross_hook_context_plugin") +logger.setLevel(logging.INFO) # Ensure INFO level logs are captured + + +class CrossHookContextPlugin(Plugin): + """Plugin that demonstrates cross-hook context sharing. + + This plugin stores context in HTTP_PRE_REQUEST and verifies it's accessible + in subsequent hooks like HTTP_AUTH_CHECK_PERMISSION, TOOL_PRE_INVOKE, + RESOURCE_PRE_FETCH, and PROMPT_PRE_FETCH. + """ + + async def http_pre_request( + self, payload: HttpPreRequestPayload, context: PluginContext + ) -> HttpPreRequestResult: + """Store initial context data in HTTP_PRE_REQUEST hook. + + Args: + payload: The HTTP request payload. + context: Plugin context for state storage. + + Returns: + Result allowing processing to continue. + """ + logger.info( + f"🔍 [CrossHookContextPlugin] HTTP_PRE_REQUEST executed - " + f"request_id={context.global_context.request_id}, " + f"path={payload.path}, method={payload.method}" + ) + + # Store data in plugin-specific state + context.state["http_timestamp"] = "2025-01-01T00:00:00Z" + context.state["http_request_path"] = payload.path + context.state["http_method"] = payload.method + + # Also store in global context to show it's shared + context.global_context.state["shared_request_id"] = context.global_context.request_id + + return HttpPreRequestResult(continue_processing=True) + + async def http_auth_check_permission( + self, payload: HttpAuthCheckPermissionPayload, context: PluginContext + ) -> HttpAuthCheckPermissionResult: + """Verify context from HTTP_PRE_REQUEST is accessible. + + Args: + payload: The permission check payload. + context: Plugin context that should contain data from HTTP_PRE_REQUEST. + + Returns: + Result with permission decision. + + Raises: + ValueError: If expected context data is missing. + """ + logger.info( + f"🔍 [CrossHookContextPlugin] HTTP_AUTH_CHECK_PERMISSION executed - " + f"request_id={context.global_context.request_id}, " + f"user_email={payload.user_email}" + ) + + # Verify we can read data stored in HTTP_PRE_REQUEST + if "http_timestamp" not in context.state: + raise ValueError("http_timestamp not found in context! Cross-hook sharing failed.") + + if "http_request_path" not in context.state: + raise ValueError("http_request_path not found in context!") + + # Verify global context is shared + if "shared_request_id" not in context.global_context.state: + raise ValueError("shared_request_id not found in global context!") + + # Verify request_id consistency + shared_request_id = context.global_context.state["shared_request_id"] + if shared_request_id != context.global_context.request_id: + raise ValueError( + f"Request ID mismatch! shared_request_id={shared_request_id}, " + f"global_context.request_id={context.global_context.request_id}" + ) + + logger.info( + f"✅ [CrossHookContextPlugin] Request ID verified: {context.global_context.request_id}" + ) + + # Add permission-specific data + context.state["permission_checked"] = True + context.state["user_email"] = payload.user_email + + return HttpAuthCheckPermissionResult(continue_processing=True) + + async def tool_pre_invoke( + self, payload: ToolPreInvokePayload, context: PluginContext + ) -> ToolPreInvokeResult: + """Verify context from HTTP hooks is accessible in tool hooks. + + Args: + payload: The tool invocation payload. + context: Plugin context that should contain data from HTTP hooks. + + Returns: + Result allowing tool invocation to continue. + + Raises: + ValueError: If expected context data is missing. + """ + logger.info( + f"🔍 [CrossHookContextPlugin] TOOL_PRE_INVOKE executed - " + f"request_id={context.global_context.request_id}, " + f"tool_name={payload.name}" + ) + + # Verify we can read data from HTTP_PRE_REQUEST + if "http_timestamp" not in context.state: + raise ValueError("http_timestamp not found in tool hook! Cross-hook sharing failed.") + + # Verify we can read data from HTTP_AUTH_CHECK_PERMISSION + if "permission_checked" not in context.state: + raise ValueError("permission_checked not found in tool hook!") + + # Verify request_id consistency + if "shared_request_id" in context.global_context.state: + shared_request_id = context.global_context.state["shared_request_id"] + if shared_request_id != context.global_context.request_id: + raise ValueError( + f"Request ID mismatch in tool hook! shared_request_id={shared_request_id}, " + f"global_context.request_id={context.global_context.request_id}" + ) + + # Add tool-specific data + context.state["tool_name"] = payload.name + context.state["tool_invoked_at"] = "2025-01-01T00:01:00Z" + + return ToolPreInvokeResult(continue_processing=True) + + async def resource_pre_fetch( + self, payload: ResourcePreFetchPayload, context: PluginContext + ) -> ResourcePreFetchResult: + """Verify context from HTTP hooks is accessible in resource hooks. + + Args: + payload: The resource fetch payload. + context: Plugin context that should contain data from HTTP hooks. + + Returns: + Result allowing resource fetch to continue. + + Raises: + ValueError: If expected context data is missing. + """ + logger.info( + f"🔍 [CrossHookContextPlugin] RESOURCE_PRE_FETCH executed - " + f"request_id={context.global_context.request_id}, " + f"resource_uri={payload.uri}" + ) + + # Verify we can read data from HTTP_PRE_REQUEST + if "http_timestamp" not in context.state: + raise ValueError("http_timestamp not found in resource hook! Cross-hook sharing failed.") + + # Verify global context is shared + if "shared_request_id" not in context.global_context.state: + raise ValueError("shared_request_id not found in resource hook!") + + # Verify request_id consistency + shared_request_id = context.global_context.state["shared_request_id"] + if shared_request_id != context.global_context.request_id: + raise ValueError( + f"Request ID mismatch in resource hook! shared_request_id={shared_request_id}, " + f"global_context.request_id={context.global_context.request_id}" + ) + + # Add resource-specific data + context.state["resource_uri"] = payload.uri + context.state["resource_fetched_at"] = "2025-01-01T00:02:00Z" + + return ResourcePreFetchResult(continue_processing=True) + + async def prompt_pre_fetch( + self, payload: PromptPrehookPayload, context: PluginContext + ) -> PromptPrehookResult: + """Verify context from HTTP hooks is accessible in prompt hooks. + + Args: + payload: The prompt fetch payload. + context: Plugin context that should contain data from HTTP hooks. + + Returns: + Result allowing prompt fetch to continue. + + Raises: + ValueError: If expected context data is missing. + """ + logger.info( + f"🔍 [CrossHookContextPlugin] PROMPT_PRE_FETCH executed - " + f"request_id={context.global_context.request_id}, " + f"prompt_id={payload.prompt_id}" + ) + + # Verify we can read data from HTTP_PRE_REQUEST + if "http_timestamp" not in context.state: + raise ValueError("http_timestamp not found in prompt hook! Cross-hook sharing failed.") + + # Verify global context is shared + if "shared_request_id" not in context.global_context.state: + raise ValueError("shared_request_id not found in prompt hook!") + + # Verify request_id consistency + shared_request_id = context.global_context.state["shared_request_id"] + if shared_request_id != context.global_context.request_id: + raise ValueError( + f"Request ID mismatch in prompt hook! shared_request_id={shared_request_id}, " + f"global_context.request_id={context.global_context.request_id}" + ) + + # Add prompt-specific data + context.state["prompt_id"] = payload.prompt_id + context.state["prompt_fetched_at"] = "2025-01-01T00:03:00Z" + + return PromptPrehookResult(continue_processing=True) diff --git a/tests/unit/mcpgateway/test_main.py b/tests/unit/mcpgateway/test_main.py index eab1fae2f..90101ef00 100644 --- a/tests/unit/mcpgateway/test_main.py +++ b/tests/unit/mcpgateway/test_main.py @@ -795,7 +795,7 @@ def test_get_prompt_no_args(self, mock_get, test_client, auth_headers): mock_get.return_value = {"name": "test", "template": "Hello"} response = test_client.get("/prompts/test", headers=auth_headers) assert response.status_code == 200 - mock_get.assert_called_once_with(ANY, "test", {}) + mock_get.assert_called_once_with(ANY, "test", {}, plugin_context_table=None, plugin_global_context=ANY) @patch("mcpgateway.main.prompt_service.update_prompt") def test_update_prompt_endpoint(self, mock_update, test_client, auth_headers): @@ -871,7 +871,7 @@ def test_get_prompt_no_args(self, mock_get, test_client, auth_headers): mock_get.return_value = {"name": "test", "template": "Hello"} response = test_client.get("/prompts/test", headers=auth_headers) assert response.status_code == 200 - mock_get.assert_called_once_with(ANY, "test", {}) + mock_get.assert_called_once_with(ANY, "test", {}, plugin_context_table=None, plugin_global_context=ANY) @patch("mcpgateway.main.prompt_service.update_prompt") def test_update_prompt_endpoint(self, mock_update, test_client, auth_headers): @@ -1113,7 +1113,15 @@ def test_rpc_tool_invocation(self, mock_invoke_tool, test_client, auth_headers): assert response.status_code == 200 body = response.json() assert body["result"]["content"][0]["text"] == "Tool response" - mock_invoke_tool.assert_called_once_with(db=ANY, name="test_tool", arguments={"param": "value"}, request_headers=ANY, app_user_email="test_user") + mock_invoke_tool.assert_called_once_with( + db=ANY, + name="test_tool", + arguments={"param": "value"}, + request_headers=ANY, + app_user_email="test_user", + plugin_context_table=None, + plugin_global_context=ANY, + ) @patch("mcpgateway.main.prompt_service.get_prompt") # @patch("mcpgateway.main.validate_request") @@ -1135,7 +1143,7 @@ def test_rpc_prompt_get(self, mock_get_prompt, test_client, auth_headers): assert response.status_code == 200 body = response.json() assert body["result"]["messages"][0]["content"]["text"] == "Rendered prompt" - mock_get_prompt.assert_called_once_with(ANY, "test_prompt", {"param": "value"}) + mock_get_prompt.assert_called_once_with(ANY, "test_prompt", {"param": "value"}, plugin_context_table=None, plugin_global_context=ANY) @patch("mcpgateway.main.tool_service.list_tools") # @patch("mcpgateway.main.validate_request")