diff --git a/mcpgateway/config.py b/mcpgateway/config.py index ac830b26f..2779919f1 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -914,6 +914,7 @@ def parse_issuers(cls, v: Any) -> list[str]: # Validation Gateway URL gateway_validation_timeout: int = 5 # seconds + gateway_max_redirects: int = 5 filelock_name: str = "gateway_service_leader.lock" diff --git a/mcpgateway/routers/observability.py b/mcpgateway/routers/observability.py index 6e26a58f5..15881ac34 100644 --- a/mcpgateway/routers/observability.py +++ b/mcpgateway/routers/observability.py @@ -80,6 +80,27 @@ def list_traces( Returns: List[ObservabilityTraceRead]: List of traces matching filters + + Examples: + >>> import mcpgateway.routers.observability as obs + >>> class FakeTrace: + ... def __init__(self, trace_id='t1'): + ... self.trace_id = trace_id + ... self.name = 'n' + ... self.start_time = None + ... self.end_time = None + ... self.duration_ms = 100 + ... self.status = 'ok' + ... self.http_method = 'GET' + ... self.http_url = '/' + ... self.http_status_code = 200 + ... self.user_email = 'u' + >>> class FakeService: + ... def query_traces(self, **kwargs): + ... return [FakeTrace('t1')] + >>> obs.ObservabilityService = FakeService + >>> obs.list_traces(db=None)[0].trace_id + 't1' """ service = ObservabilityService() traces = service.query_traces( @@ -138,6 +159,27 @@ def query_traces_advanced( Raises: HTTPException: 400 error if request body is invalid + + Examples: + >>> from fastapi import HTTPException + >>> try: + ... query_traces_advanced({"start_time": "not-a-date"}, db=None) + ... except HTTPException as e: + ... (e.status_code, "Invalid request body" in str(e.detail)) + (400, True) + + >>> import mcpgateway.routers.observability as obs + >>> class FakeTrace: + ... def __init__(self): + ... self.trace_id = 'tx' + ... self.name = 'n' + + >>> class FakeService2: + ... def query_traces(self, **kwargs): + ... 
return [FakeTrace()] + >>> obs.ObservabilityService = FakeService2 + >>> obs.query_traces_advanced({}, db=None)[0].trace_id + 'tx' """ # Third-Party from pydantic import ValidationError @@ -199,6 +241,24 @@ def get_trace(trace_id: str, db: Session = Depends(get_db)): Raises: HTTPException: 404 if trace not found + + Examples: + >>> import mcpgateway.routers.observability as obs + >>> class FakeService: + ... def get_trace_with_spans(self, db, trace_id): + ... return None + >>> obs.ObservabilityService = FakeService + >>> try: + ... obs.get_trace('missing', db=None) + ... except obs.HTTPException as e: + ... e.status_code + 404 + >>> class FakeService2: + ... def get_trace_with_spans(self, db, trace_id): + ... return {'trace_id': trace_id} + >>> obs.ObservabilityService = FakeService2 + >>> obs.get_trace('found', db=None)['trace_id'] + 'found' """ service = ObservabilityService() trace = service.get_trace_with_spans(db, trace_id) @@ -235,6 +295,20 @@ def list_spans( Returns: List[ObservabilitySpanRead]: List of spans matching filters + + Examples: + >>> import mcpgateway.routers.observability as obs + >>> class FakeSpan: + ... def __init__(self): + ... self.span_id = 's1' + ... self.trace_id = 't1' + ... self.name = 'op' + >>> class FakeService: + ... def query_spans(self, **kwargs): + ... return [FakeSpan()] + >>> obs.ObservabilityService = FakeService + >>> obs.list_spans(db=None)[0].span_id + 's1' """ service = ObservabilityService() spans = service.query_spans( @@ -266,6 +340,16 @@ def cleanup_old_traces( Returns: dict: Number of deleted traces and cutoff time + + Examples: + >>> import mcpgateway.routers.observability as obs + >>> class FakeService: + ... def delete_old_traces(self, db, cutoff): + ... 
return 5 + >>> obs.ObservabilityService = FakeService + >>> res = obs.cleanup_old_traces(days=7, db=None) + >>> res['deleted'] + 5 """ service = ObservabilityService() cutoff_time = datetime.now() - timedelta(days=days) @@ -358,6 +442,41 @@ def export_traces( Raises: HTTPException: 400 error if format is invalid or export fails + + Examples: + >>> from fastapi import HTTPException + >>> try: + ... export_traces({}, format="xml", db=None) + ... except HTTPException as e: + ... (e.status_code, "format must be one of" in str(e.detail)) + (400, True) + >>> import mcpgateway.routers.observability as obs + >>> from datetime import datetime + >>> class FakeTrace: + ... def __init__(self): + ... self.trace_id = 'tx' + ... self.name = 'name' + ... self.start_time = datetime(2025,1,1) + ... self.end_time = None + ... self.duration_ms = 100 + ... self.status = 'ok' + ... self.http_method = 'GET' + ... self.http_url = '/' + ... self.http_status_code = 200 + ... self.user_email = 'u' + >>> class FakeService: + ... def query_traces(self, **kwargs): + ... return [FakeTrace()] + >>> obs.ObservabilityService = FakeService + >>> out = obs.export_traces({}, format='json', db=None) + >>> out[0]['trace_id'] + 'tx' + >>> resp = obs.export_traces({}, format='csv', db=None) + >>> hasattr(resp, 'media_type') and 'csv' in resp.media_type + True + >>> resp2 = obs.export_traces({}, format='ndjson', db=None) + >>> type(resp2).__name__ + 'StreamingResponse' """ # Standard import csv @@ -437,6 +556,13 @@ def export_traces( elif format == "ndjson": # Newline-delimited JSON (streaming) def generate(): + """Yield newline-delimited JSON strings for each trace. + + This nested generator is used to stream NDJSON responses. + + Yields: + str: A JSON-encoded line (with trailing newline) for a trace. 
+ """ for t in traces: # Standard import json @@ -475,7 +601,32 @@ def get_query_performance(hours: int = Query(24, ge=1, le=168, description="Time Returns: dict: Performance analytics + + Examples: + >>> import mcpgateway.routers.observability as obs + >>> class EmptyDB: + ... def query(self, *a, **k): + ... return self + ... def filter(self, *a, **k): + ... return self + ... def all(self): + ... return [] + >>> obs.get_query_performance(hours=1, db=EmptyDB())['total_traces'] + 0 + + >>> class SmallDB: + ... def query(self, *a, **k): + ... return self + ... def filter(self, *a, **k): + ... return self + ... def all(self): + ... return [(10,), (20,), (30,), (40,)] + >>> res = obs.get_query_performance(hours=1, db=SmallDB()) + >>> res['total_traces'] + 4 + """ + # Third-Party # First-Party diff --git a/mcpgateway/services/gateway_service.py b/mcpgateway/services/gateway_service.py index e2b24141c..5472e17ee 100644 --- a/mcpgateway/services/gateway_service.py +++ b/mcpgateway/services/gateway_service.py @@ -397,33 +397,43 @@ async def _validate_gateway_url(self, url: str, headers: dict, transport_type: s """ if timeout is None: timeout = settings.gateway_validation_timeout - validation_client = ResilientHttpClient(client_args={"timeout": settings.gateway_validation_timeout, "verify": not settings.skip_ssl_verify}) + validation_client = ResilientHttpClient( + client_args={ + "timeout": settings.gateway_validation_timeout, + "verify": not settings.skip_ssl_verify, + # Let httpx follow only proper HTTP redirects (3xx) and + # enforce a sensible redirect limit. + "follow_redirects": True, + "max_redirects": settings.gateway_max_redirects, + } + ) + try: + # Make a single request and let httpx follow valid redirects. 
async with validation_client.client.stream("GET", url, headers=headers, timeout=timeout) as response: response_headers = dict(response.headers) - location = response_headers.get("location") - content_type = response_headers.get("content-type") - if response.status_code in (401, 403): + content_type = response_headers.get("content-type", "") + logger.info(f"Validating gateway URL {url}, received status {response.status_code}, content_type: {content_type}") + + # Authentication failures (401/403) or a missing endpoint (404) mean the URL is not usable + if response.status_code in (401, 403, 404): logger.debug(f"Authentication failed for {url} with status {response.status_code}") return False + # STREAMABLEHTTP: expect an MCP session id and JSON content if transport_type == "STREAMABLEHTTP": - if location: - async with validation_client.client.stream("GET", location, headers=headers, timeout=timeout) as response_redirect: - response_headers = dict(response_redirect.headers) - mcp_session_id = response_headers.get("mcp-session-id") - content_type = response_headers.get("content-type") - if response_redirect.status_code in (401, 403): - logger.debug(f"Authentication failed at redirect location {location}") - return False - if mcp_session_id is not None and mcp_session_id != "": - if content_type is not None and content_type != "" and "application/json" in content_type: - return True - - elif transport_type == "SSE": - if content_type is not None and content_type != "" and "text/event-stream" in content_type: + mcp_session_id = response_headers.get("mcp-session-id") + if mcp_session_id is not None and mcp_session_id != "": + if content_type is not None and content_type != "" and "application/json" in content_type: + return True + + # SSE: expect text/event-stream + if transport_type == "SSE": + logger.info(f"Validating SSE gateway URL {url}") + if "text/event-stream" in content_type: return True - return False + + return False except httpx.UnsupportedProtocol as e: logger.debug(f"Gateway URL Unsupported 
Protocol for {url}: {str(e)}", exc_info=True) return False @@ -3312,8 +3322,8 @@ async def connect_to_streamablehttp_server(self, server_url: str, authentication """ if authentication is None: authentication = {} - # Use authentication directly instead + # Use authentication directly instead def get_httpx_client_factory( headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, @@ -3341,59 +3351,61 @@ def get_httpx_client_factory( auth=auth, ) - async with streamablehttp_client(url=server_url, headers=authentication, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session - response = await session.initialize() - capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True) - logger.debug(f"Server capabilities: {capabilities}") - - response = await session.list_tools() - tools = response.tools - tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools] - - tools = [ToolCreate.model_validate(tool) for tool in tools] - for tool in tools: - tool.request_type = "STREAMABLEHTTP" - if tools: - logger.info(f"Fetched {len(tools)} tools from gateway") - - # Fetch resources if supported - resources = [] - logger.debug(f"Checking for resources support: {capabilities.get('resources')}") - if capabilities.get("resources"): - try: - response = await session.list_resources() - raw_resources = response.resources - resources = [] - for resource in raw_resources: - resource_data = resource.model_dump(by_alias=True, exclude_none=True) - # Convert AnyUrl to string if present - if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"): - resource_data["uri"] = str(resource_data["uri"]) - # Add default content if not present - if "content" not in resource_data: - resource_data["content"] = "" - resources.append(ResourceCreate.model_validate(resource_data)) - 
logger.info(f"Fetched {len(resources)} resources from gateway") - except Exception as e: - logger.warning(f"Failed to fetch resources: {e}") + if await self._validate_gateway_url(url=server_url, headers=authentication, transport_type="STREAMABLEHTTP"): + async with streamablehttp_client(url=server_url, headers=authentication, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + response = await session.initialize() + capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True) + logger.debug(f"Server capabilities: {capabilities}") - # Fetch prompts if supported - prompts = [] - logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}") - if capabilities.get("prompts"): - try: - response = await session.list_prompts() - raw_prompts = response.prompts - prompts = [] - for prompt in raw_prompts: - prompt_data = prompt.model_dump(by_alias=True, exclude_none=True) - # Add default template if not present - if "template" not in prompt_data: - prompt_data["template"] = "" - prompts.append(PromptCreate.model_validate(prompt_data)) - except Exception as e: - logger.warning(f"Failed to fetch prompts: {e}") + response = await session.list_tools() + tools = response.tools + tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools] + + tools = [ToolCreate.model_validate(tool) for tool in tools] + for tool in tools: + tool.request_type = "STREAMABLEHTTP" + if tools: + logger.info(f"Fetched {len(tools)} tools from gateway") + + # Fetch resources if supported + resources = [] + logger.debug(f"Checking for resources support: {capabilities.get('resources')}") + if capabilities.get("resources"): + try: + response = await session.list_resources() + raw_resources = response.resources + resources = [] + for resource in raw_resources: + resource_data = resource.model_dump(by_alias=True, 
exclude_none=True) + # Convert AnyUrl to string if present + if "uri" in resource_data and hasattr(resource_data["uri"], "unicode_string"): + resource_data["uri"] = str(resource_data["uri"]) + # Add default content if not present + if "content" not in resource_data: + resource_data["content"] = "" + resources.append(ResourceCreate.model_validate(resource_data)) + logger.info(f"Fetched {len(resources)} resources from gateway") + except Exception as e: + logger.warning(f"Failed to fetch resources: {e}") + + # Fetch prompts if supported + prompts = [] + logger.debug(f"Checking for prompts support: {capabilities.get('prompts')}") + if capabilities.get("prompts"): + try: + response = await session.list_prompts() + raw_prompts = response.prompts + prompts = [] + for prompt in raw_prompts: + prompt_data = prompt.model_dump(by_alias=True, exclude_none=True) + # Add default template if not present + if "template" not in prompt_data: + prompt_data["template"] = "" + prompts.append(PromptCreate.model_validate(prompt_data)) + except Exception as e: + logger.warning(f"Failed to fetch prompts: {e}") - return capabilities, tools, resources, prompts + return capabilities, tools, resources, prompts + raise GatewayConnectionError(f"Failed to initialize gateway at {server_url}") diff --git a/tests/unit/mcpgateway/services/test_gateway_service_extended.py b/tests/unit/mcpgateway/services/test_gateway_service_extended.py index e2eb9cd98..7eb6cbfc6 100644 --- a/tests/unit/mcpgateway/services/test_gateway_service_extended.py +++ b/tests/unit/mcpgateway/services/test_gateway_service_extended.py @@ -144,6 +144,9 @@ async def test_initialize_gateway_streamablehttp_transport(self): mock_tools_response.tools = [mock_tool] mock_session_instance.list_tools.return_value = mock_tools_response + # Mock _validate_gateway_url to return True (same as SSE test) + service._validate_gateway_url = AsyncMock(return_value=True) + # Execute capabilities, tools, resources, prompts = await 
service._initialize_gateway("http://test.example.com", {"Authorization": "Bearer token"}, "streamablehttp") diff --git a/tests/unit/mcpgateway/services/test_gateway_validation_redirects.py b/tests/unit/mcpgateway/services/test_gateway_validation_redirects.py new file mode 100644 index 000000000..e643b55c2 --- /dev/null +++ b/tests/unit/mcpgateway/services/test_gateway_validation_redirects.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/services/test_gateway_validation_redirects.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti + +Unit tests for the GatewayService implementation. + +These tests validate gateway URL redirection behavior. They avoid +real network access and real databases by using httpx.MockTransport +and lightweight fakes (MagicMock / AsyncMock). Where the service +relies on Pydantic models or SQLAlchemy Result objects we monkey- +patch or provide small stand-ins to exercise only the code paths +under test. 
+""" + +import pytest +import httpx +from unittest.mock import patch +from mcpgateway.services.gateway_service import GatewayService +from mcpgateway.utils.retry_manager import ResilientHttpClient + +@pytest.mark.asyncio +async def test_streamablehttp_follows_3xx_redirects_and_validates(): + svc = GatewayService() + + # Mock transport behavior: + # 1) GET http://example/start -> 302 Location: /final + # 2) GET http://example/final -> 200 with mcp-session-id + application/json + async def mock_dispatch(request: httpx.Request) -> httpx.Response: + url = str(request.url) + if url.endswith("/start"): + return httpx.Response(302, headers={"location": "/final"}) + if url.endswith("/final"): + return httpx.Response(200, headers={"mcp-session-id": "abc", "content-type": "application/json"}) + return httpx.Response(404) + + transport = httpx.MockTransport(mock_dispatch) + # Build a ResilientHttpClient that uses this transport + client_args = {"transport": transport, "follow_redirects": True} + mock_resilient = ResilientHttpClient(client_args=client_args) + + # Patch ResilientHttpClient where gateway_service constructs it + class MockResilientFactory: + def __init__(self, *args, **kwargs): + # ignore args; use our prebuilt instance + self.client = mock_resilient.client + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def aclose(self): + await mock_resilient.aclose() + + # expose stream method used by gateway_service + def stream(self, method, url, **kwargs): + return mock_resilient.client.stream(method, url, **kwargs) + + with patch("mcpgateway.services.gateway_service.ResilientHttpClient", MockResilientFactory): + headers = {} + ok = await svc._validate_gateway_url("http://example/start", headers, transport_type="STREAMABLEHTTP") + assert ok is True + +@pytest.mark.asyncio +async def test_200_with_location_is_not_treated_as_redirect(): + svc = GatewayService() + + async def mock_dispatch(request: 
httpx.Request) -> httpx.Response: + url = str(request.url) + if url.endswith("/bad"): + # Non-standard: 200 with Location header. Should NOT be treated as redirect. + return httpx.Response(200, headers={"location": "/should-not-follow", "content-type": "text/plain"}) + if url.endswith("/should-not-follow"): + # If code incorrectly followed Location on 200, we'd reach here + return httpx.Response(200, headers={"mcp-session-id": "x", "content-type": "application/json"}) + return httpx.Response(404) + + transport = httpx.MockTransport(mock_dispatch) + client_args = {"transport": transport, "follow_redirects": True} + mock_resilient = ResilientHttpClient(client_args=client_args) + + class MockResilientFactory: + def __init__(self, *args, **kwargs): + self.client = mock_resilient.client + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def aclose(self): + await mock_resilient.aclose() + + def stream(self, method, url, **kwargs): + return mock_resilient.client.stream(method, url, **kwargs) + + with patch("mcpgateway.services.gateway_service.ResilientHttpClient", MockResilientFactory): + headers = {} + ok = await svc._validate_gateway_url("http://example/bad", headers, transport_type="STREAMABLEHTTP") + assert ok is False \ No newline at end of file