Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions docs/docs/using/plugins/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,7 @@ class MyPlugin(Plugin):

### Plugin Context and State

Each hook function has a `context` object of type `PluginContext` which is designed to allow plugins to pass state between one another (across pre/post hook pairs) or for a plugin to pass state information to itself across pre/post hook pairs. The plugin context looks as follows:
Each hook function has a `context` object of type `PluginContext` which is designed to allow plugins to pass state between one another across all hook types in a request, or for a plugin to pass state information to itself across different hooks. The plugin context looks as follows:

```python
class GlobalContext(BaseModel):
Expand Down Expand Up @@ -900,10 +900,9 @@ class PluginContext(BaseModel):
metadata: dict[str, Any] = Field(default_factory=dict)
```

As can be seen, the `PluginContext` has both a `state` dictionary and a `global_context` object that also has a `state` dictionary. A single plugin can share state between pre/post hook pairs by using the
the `PluginContext` state dictionary. It can share state with other plugins using the `context.global_context.state` dictionary. Metadata for the specific hook site is passed in through the `metadata` dictionaries in the `context.global_context.metadata`. It is meant to be read-only. The `context.metadata` is plugin specific metadata and can be used to store metadata information such as timing information.
As can be seen, the `PluginContext` has both a `state` dictionary and a `global_context` object that also has a `state` dictionary. A single plugin can share state across all hooks in a request by using the `PluginContext` state dictionary. It can share state with other plugins using the `context.global_context.state` dictionary. Metadata for the specific hook site is passed in through the `metadata` dictionaries in the `context.global_context.metadata`. It is meant to be read-only. The `context.metadata` is plugin specific metadata and can be used to store metadata information such as timing information.

The following shows how plugins can maintain state between pre/post hooks:
The following shows how plugins can maintain state across different hooks:

```python
async def prompt_pre_fetch(self, payload, context):
Expand All @@ -926,7 +925,7 @@ async def prompt_post_fetch(self, payload, context):

#### Tool and Gateway Metadata

Currently, the tool pre/post hooks have access to tool and gateway metadata through the global context metadata dictionary. They are accessible as follows:
Tool hooks have access to tool and gateway metadata through the global context metadata dictionary. They are accessible as follows:

It can be accessed inside of the tool hooks through:

Expand Down
36 changes: 22 additions & 14 deletions mcpgateway/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,22 +170,25 @@ async def get_current_user(
headers = dict(request.headers)

# Get request ID from request state (set by middleware) or generate new one
request_id = None
if request and hasattr(request, "state") and hasattr(request.state, "request_id"):
request_id = request.state.request_id
else:
request_id = getattr(request.state, "request_id", None) if request else None
if not request_id:
request_id = uuid.uuid4().hex

# Create global context
global_context = GlobalContext(
request_id=request_id,
server_id=None,
tenant_id=None,
)
# Get plugin contexts from request state if available
global_context = getattr(request.state, "plugin_global_context", None) if request else None
if not global_context:
# Create global context
global_context = GlobalContext(
request_id=request_id,
server_id=None,
tenant_id=None,
)

context_table = getattr(request.state, "plugin_context_table", None) if request else None

# Invoke custom auth resolution hook
# violations_as_exceptions=True so PluginViolationError is raised for explicit denials
auth_result, _ = await plugin_manager.invoke_hook(
auth_result, context_table_result = await plugin_manager.invoke_hook(
HttpHookType.HTTP_AUTH_RESOLVE_USER,
payload=HttpAuthResolveUserPayload(
credentials=credentials_dict,
Expand All @@ -194,7 +197,7 @@ async def get_current_user(
client_port=client_port,
),
global_context=global_context,
local_contexts=None,
local_contexts=context_table,
violations_as_exceptions=True, # Raise PluginViolationError for auth denials
)

Expand All @@ -215,12 +218,17 @@ async def get_current_user(
)

# Store auth_method in request.state so it can be accessed by RBAC middleware
if request and hasattr(request, "state") and auth_result.metadata:
if request and auth_result.metadata:
auth_method = auth_result.metadata.get("auth_method")
if auth_method:
request.state.auth_method = auth_method
logger.debug(f"Stored auth_method '{auth_method}' in request.state")

if request and context_table_result:
request.state.plugin_context_table = context_table_result

if request and global_context:
request.state.plugin_global_context = global_context
return user
# If continue_processing=True (no payload), fall through to standard auth

Expand Down Expand Up @@ -294,7 +302,7 @@ async def get_current_user(

# Check team level token, if applicable. If public token, then will be defaulted to personal team.
team_id = await get_team_from_token(payload, db)
if request and hasattr(request, "state"):
if request:
request.state.team_id = team_id

except HTTPException:
Expand Down
79 changes: 73 additions & 6 deletions mcpgateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2960,9 +2960,21 @@ async def read_resource(resource_id: str, request: Request, db: Session = Depend
if cached := resource_cache.get(resource_id):
return cached

# Get plugin contexts from request.state for cross-hook sharing
plugin_context_table = getattr(request.state, "plugin_context_table", None)
plugin_global_context = getattr(request.state, "plugin_global_context", None)

try:
# Call service with context for plugin support
content = await resource_service.read_resource(db, resource_id=resource_id, request_id=request_id, user=user, server_id=server_id)
content = await resource_service.read_resource(
db,
resource_id=resource_id,
request_id=request_id,
user=user,
server_id=server_id,
plugin_context_table=plugin_context_table,
plugin_global_context=plugin_global_context,
)
except (ResourceNotFoundError, ResourceError) as exc:
# Translate to FastAPI HTTP error
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
Expand Down Expand Up @@ -3289,6 +3301,7 @@ async def create_prompt(
@prompt_router.post("/{prompt_id}")
@require_permission("prompts.read")
async def get_prompt(
request: Request,
prompt_id: str,
args: Dict[str, str] = Body({}),
db: Session = Depends(get_db),
Expand All @@ -3301,6 +3314,7 @@ async def get_prompt(


Args:
request: FastAPI request object.
prompt_id: ID of the prompt.
args: Template arguments.
db: Database session.
Expand All @@ -3314,9 +3328,19 @@ async def get_prompt(
"""
logger.debug(f"User: {user} requested prompt: {prompt_id} with args={args}")

# Get plugin contexts from request.state for cross-hook sharing
plugin_context_table = getattr(request.state, "plugin_context_table", None)
plugin_global_context = getattr(request.state, "plugin_global_context", None)

try:
PromptExecuteArgs(args=args)
result = await prompt_service.get_prompt(db, prompt_id, args)
result = await prompt_service.get_prompt(
db,
prompt_id,
args,
plugin_context_table=plugin_context_table,
plugin_global_context=plugin_global_context,
)
logger.debug(f"Prompt execution successful for '{prompt_id}'")
except Exception as ex:
logger.error(f"Could not retrieve prompt {prompt_id}: {ex}")
Expand All @@ -3334,6 +3358,7 @@ async def get_prompt(
@prompt_router.get("/{prompt_id}")
@require_permission("prompts.read")
async def get_prompt_no_args(
request: Request,
prompt_id: str,
db: Session = Depends(get_db),
user=Depends(get_current_user_with_permissions),
Expand All @@ -3343,6 +3368,7 @@ async def get_prompt_no_args(
This endpoint is for convenience when no arguments are needed.

Args:
request: FastAPI request object.
prompt_id: The ID of the prompt to retrieve
db: Database session
user: Authenticated user
Expand All @@ -3354,7 +3380,18 @@ async def get_prompt_no_args(
Exception: Re-raised from prompt service.
"""
logger.debug(f"User: {user} requested prompt: {prompt_id} with no arguments")
return await prompt_service.get_prompt(db, prompt_id, {})

# Get plugin contexts from request.state for cross-hook sharing
plugin_context_table = getattr(request.state, "plugin_context_table", None)
plugin_global_context = getattr(request.state, "plugin_global_context", None)

return await prompt_service.get_prompt(
db,
prompt_id,
{},
plugin_context_table=plugin_context_table,
plugin_global_context=plugin_global_context,
)


@prompt_router.put("/{prompt_id}", response_model=PromptRead)
Expand Down Expand Up @@ -3921,8 +3958,18 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen
raise JSONRPCError(-32602, "Missing resource URI in parameters", params)
# Get user email for OAuth token selection
user_email = get_user_email(user)
# Get plugin contexts from request.state for cross-hook sharing
plugin_context_table = getattr(request.state, "plugin_context_table", None)
plugin_global_context = getattr(request.state, "plugin_global_context", None)
try:
result = await resource_service.read_resource(db, uri, request_id=request_id, user=user_email)
result = await resource_service.read_resource(
db,
resource_uri=uri,
request_id=request_id,
user=user_email,
plugin_context_table=plugin_context_table,
plugin_global_context=plugin_global_context,
)
if hasattr(result, "model_dump"):
result = {"contents": [result.model_dump(by_alias=True, exclude_none=True)]}
else:
Expand Down Expand Up @@ -3966,7 +4013,16 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen
arguments = params.get("arguments", {})
if not name:
raise JSONRPCError(-32602, "Missing prompt name in parameters", params)
result = await prompt_service.get_prompt(db, name, arguments)
# Get plugin contexts from request.state for cross-hook sharing
plugin_context_table = getattr(request.state, "plugin_context_table", None)
plugin_global_context = getattr(request.state, "plugin_global_context", None)
result = await prompt_service.get_prompt(
db,
name,
arguments,
plugin_context_table=plugin_context_table,
plugin_global_context=plugin_global_context,
)
if hasattr(result, "model_dump"):
result = result.model_dump(by_alias=True, exclude_none=True)
elif method == "ping":
Expand All @@ -3981,8 +4037,19 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen
raise JSONRPCError(-32602, "Missing tool name in parameters", params)
# Get user email for OAuth token selection
user_email = get_user_email(user)
# Get plugin contexts from request.state for cross-hook sharing
plugin_context_table = getattr(request.state, "plugin_context_table", None)
plugin_global_context = getattr(request.state, "plugin_global_context", None)
try:
result = await tool_service.invoke_tool(db=db, name=name, arguments=arguments, request_headers=headers, app_user_email=user_email)
result = await tool_service.invoke_tool(
db=db,
name=name,
arguments=arguments,
request_headers=headers,
app_user_email=user_email,
plugin_context_table=plugin_context_table,
plugin_global_context=plugin_global_context,
)
if hasattr(result, "model_dump"):
result = result.model_dump(by_alias=True, exclude_none=True)
except ValueError:
Expand Down
6 changes: 6 additions & 0 deletions mcpgateway/middleware/http_auth_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ async def dispatch(self, request: Request, call_next):
violations_as_exceptions=False, # Don't block on pre-request violations
)

if context_table:
request.state.plugin_context_table = context_table

if global_context:
request.state.plugin_global_context = global_context

# Apply modified headers if plugin returned them
if pre_result.modified_payload:
# Modify request headers by updating request.scope["headers"]
Expand Down
36 changes: 25 additions & 11 deletions mcpgateway/middleware/rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ async def protected_route(user = Depends(get_current_user_with_permissions)):
request_id = getattr(request.state, "request_id", None)
team_id = getattr(request.state, "team_id", None)

# Read plugin context data from request.state for cross-hook context sharing
# (set by HttpAuthMiddleware for passing contexts between different hook types)
plugin_context_table = getattr(request.state, "plugin_context_table", None)
plugin_global_context = getattr(request.state, "plugin_global_context", None)

# Add request context for permission auditing
return {
"email": user.email,
Expand All @@ -161,6 +166,8 @@ async def protected_route(user = Depends(get_current_user_with_permissions)):
"auth_method": auth_method, # Include auth_method from plugin
"request_id": request_id, # Include request_id from middleware
"team_id": team_id, # Include team_id from token
"plugin_context_table": plugin_context_table, # Plugin contexts for cross-hook sharing
"plugin_global_context": plugin_global_context, # Global context for consistency
}
except Exception as e:
logger.error(f"Authentication failed: {type(e).__name__}: {e}")
Expand Down Expand Up @@ -256,18 +263,24 @@ async def wrapper(*args, **kwargs):

plugin_manager = get_plugin_manager()
if plugin_manager:
# Get request_id from user_context (passed from get_current_user_with_permissions)
# Generate a fallback if not present
request_id = user_context.get("request_id") or uuid.uuid4().hex

# Create global context for plugin invocation
global_context = GlobalContext(
request_id=request_id,
server_id=None,
tenant_id=None,
)
# Get plugin contexts from user_context (stored in request.state by HttpAuthMiddleware)
# These enable cross-hook context sharing between HTTP_PRE_REQUEST and HTTP_AUTH_CHECK_PERMISSION
plugin_context_table = user_context.get("plugin_context_table")
plugin_global_context = user_context.get("plugin_global_context")

# Reuse existing global context from middleware if available for consistency
# Otherwise create a new one (fallback for cases where middleware didn't run)
if plugin_global_context:
global_context = plugin_global_context
else:
request_id = user_context.get("request_id") or uuid.uuid4().hex
global_context = GlobalContext(
request_id=request_id,
server_id=None,
tenant_id=None,
)

# Invoke permission check hook
# Invoke permission check hook, passing plugin contexts from HTTP_PRE_REQUEST hook
result, _ = await plugin_manager.invoke_hook(
HttpHookType.HTTP_AUTH_CHECK_PERMISSION,
payload=HttpAuthCheckPermissionPayload(
Expand All @@ -281,6 +294,7 @@ async def wrapper(*args, **kwargs):
user_agent=user_context.get("user_agent"),
),
global_context=global_context,
local_contexts=plugin_context_table, # Pass context table for cross-hook state
)

# If a plugin made a decision, respect it
Expand Down
2 changes: 2 additions & 0 deletions mcpgateway/plugins/framework/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
PluginCondition,
PluginConfig,
PluginContext,
PluginContextTable,
PluginErrorModel,
PluginMode,
PluginPayload,
Expand Down Expand Up @@ -120,6 +121,7 @@ def get_plugin_manager() -> Optional[PluginManager]:
"PluginCondition",
"PluginConfig",
"PluginContext",
"PluginContextTable",
"PluginError",
"PluginErrorModel",
"PluginLoader",
Expand Down
Loading
Loading