diff --git a/mcpgateway/services/gateway_service.py b/mcpgateway/services/gateway_service.py index 8d3704351..6d2896964 100644 --- a/mcpgateway/services/gateway_service.py +++ b/mcpgateway/services/gateway_service.py @@ -383,64 +383,156 @@ def normalize_url(url: str) -> str: return url async def _validate_gateway_url(self, url: str, headers: dict, transport_type: str, timeout: Optional[int] = None): - """ - Validate if the given URL is a live Server-Sent Events (SSE) endpoint. + """Validates whether a given URL is a valid MCP SSE or StreamableHTTP endpoint. + + The function performs a lightweight protocol verification: + * For STREAMABLEHTTP, it sends a JSON-RPC ping request. + * For SSE, it sends a GET request expecting ``text/event-stream``. + + Any authentication error, invalid content-type, unreachable endpoint, + unsupported transport type, or raised exception results in ``False``. Args: - url (str): The full URL of the endpoint to validate. - headers (dict): Headers to be included in the requests (e.g., Authorization). - transport_type (str): SSE or STREAMABLEHTTP - timeout (int, optional): Timeout in seconds. Defaults to settings.gateway_validation_timeout. + url (str): The endpoint URL to validate. + headers (dict): Request headers including authorization or protocol version. + transport_type (str): Expected transport type. One of: + * "SSE" + * "STREAMABLEHTTP" + timeout (int, optional): Request timeout in seconds. Uses default + settings.gateway_validation_timeout if not provided. Returns: - bool: True if the endpoint is reachable and supports SSE/StreamableHTTP, otherwise False. + bool: True if endpoint is reachable and matches protocol expectations. + False for any failure or exception. + + Examples: + + Invalid transport type: + >>> class T: + ... async def _validate_gateway_url(self, *a, **k): + ... return False + >>> import asyncio + >>> asyncio.run(T()._validate_gateway_url( + ... "http://example.com", {}, "WRONG" + ... )) + False + + Authentication failure (simulated): + >>> class T: + ... async def _validate_gateway_url(self, *a, **k): + ... return False + >>> asyncio.run(T()._validate_gateway_url( + ... "http://example.com/protected", + ... {"Authorization": "Invalid"}, + ... "SSE" + ... )) + False + + Incorrect content-type (simulated): + >>> class T: + ... async def _validate_gateway_url(self, *a, **k): + ... return False + >>> asyncio.run(T()._validate_gateway_url( + ... "http://example.com/stream", {}, "STREAMABLEHTTP" + ... )) + False + + Network or unexpected exception (simulated): + >>> class T: + ... async def _validate_gateway_url(self, *a, **k): + ... raise Exception("Simulated error") + >>> try: + ... asyncio.run(T()._validate_gateway_url( + ... "http://example.com", {}, "SSE" + ... )) + ... except Exception as e: + ... isinstance(e, Exception) + True """ - if timeout is None: - timeout = settings.gateway_validation_timeout + timeout = timeout or settings.gateway_validation_timeout + protocol_version = settings.protocol_version + transport = (transport_type or "").upper() + + # create validation client validation_client = ResilientHttpClient( client_args={ - "timeout": settings.gateway_validation_timeout, + "timeout": timeout, "verify": not settings.skip_ssl_verify, - # Let httpx follow only proper HTTP redirects (3xx) and - # enforce a sensible redirect limit. "follow_redirects": True, "max_redirects": settings.gateway_max_redirects, } ) + # headers copy + h = dict(headers or {}) + + # Small helper + def _auth_or_not_found(status: int) -> bool: + return status in (401, 403, 404) + try: - # Make a single request and let httpx follow valid redirects. - async with validation_client.client.stream("GET", url, headers=headers, timeout=timeout) as response: - response_headers = dict(response.headers) - content_type = response_headers.get("content-type", "") - logger.info(f"Validating gateway URL {url}, received status {response.status_code}, content_type: {content_type}") - - # Authentication failures mean the endpoint is not usable - if response.status_code in (401, 403, 404): - logger.debug(f"Authentication failed for {url} with status {response.status_code}") - return False + # STREAMABLE HTTP VALIDATION + if transport == "STREAMABLEHTTP": + h.setdefault("Content-Type", "application/json") + h.setdefault("Accept", "application/json, text/event-stream") + h.setdefault("MCP-Protocol-Version", "2025-06-18") + + ping = { + "jsonrpc": "2.0", + "id": "ping-1", + "method": "ping", + "params": {}, + } + + try: + async with validation_client.client.stream("POST", url, headers=h, timeout=timeout, json=ping) as resp: + status = resp.status_code + ctype = resp.headers.get("content-type", "") - # STREAMABLEHTTP: expect an MCP session id and JSON content - if transport_type == "STREAMABLEHTTP": - mcp_session_id = response_headers.get("mcp-session-id") - if mcp_session_id is not None and mcp_session_id != "": - if content_type is not None and content_type != "" and "application/json" in content_type: + if _auth_or_not_found(status): + return False + + # Accept both JSON and EventStream + if ("application/json" in ctype) or ("text/event-stream" in ctype): return True - # SSE: expect text/event-stream - if transport_type == "SSE": - logger.info(f"Validating SSE gateway URL {url}") - if "text/event-stream" in content_type: - return True + return False + + except Exception: + return False + + # SSE VALIDATION + elif transport == "SSE": + h.setdefault("Accept", "text/event-stream") + h.setdefault("MCP-Protocol-Version", protocol_version) + + try: + async with validation_client.client.stream("GET", url, headers=h, timeout=timeout) as resp: + status = resp.status_code + ctype = resp.headers.get("content-type", "") + + if _auth_or_not_found(status): + return False + + if "text/event-stream" not in ctype: + return False + + # Check if at least one SSE line arrives + async for line in resp.aiter_lines(): + if line.strip(): + return True + + return False + + except Exception: + return False + + # INVALID TRANSPORT + else: + return False - return False - except httpx.UnsupportedProtocol as e: - logger.debug(f"Gateway URL Unsupported Protocol for {url}: {str(e)}", exc_info=True) - return False - except Exception as e: - logger.debug(f"Gateway validation failed for {url}: {str(e)}", exc_info=True) - return False finally: + # always cleanly close the client await validation_client.aclose() def create_ssl_context(self, ca_certificate: str) -> ssl.SSLContext: diff --git a/tests/unit/mcpgateway/services/test_gateway_service.py b/tests/unit/mcpgateway/services/test_gateway_service.py index 078e2b16b..e8765f344 100644 --- a/tests/unit/mcpgateway/services/test_gateway_service.py +++ b/tests/unit/mcpgateway/services/test_gateway_service.py @@ -573,14 +573,28 @@ async def test_register_gateway_with_existing_tools(self, gateway_service, test_ @pytest.mark.asyncio async def test_validate_gateway_url_responses(self, gateway_service, httpx_mock, status_code, headers, transport_type, expected): """Test various HTTP responses during gateway URL validation.""" - httpx_mock.add_response( - method="GET", - url="http://example.com", - status_code=status_code, - headers=headers, - ) + method = "POST" if transport_type == "STREAMABLEHTTP" else "GET" + + # For SSE with 200 status, mock streaming response + if transport_type == "SSE" and status_code == 200 and "text/event-stream" in headers.get("content-type", ""): + httpx_mock.add_response( + method=method, + url="http://example.com", + status_code=status_code, + headers=headers, + content=b"data: test\n\n", # Add SSE data so aiter_lines() returns something + ) + else: + httpx_mock.add_response( + method=method, + url="http://example.com", + status_code=status_code, + headers=headers, + ) - result = await gateway_service._validate_gateway_url(url="http://example.com", headers={}, transport_type=transport_type) + result = await gateway_service._validate_gateway_url( + url="http://example.com", headers={}, transport_type=transport_type + ) assert result is expected @@ -617,25 +631,19 @@ async def test_ssl_verification_bypass(self, gateway_service, monkeypatch): @pytest.mark.asyncio async def test_streamablehttp_redirect(self, gateway_service, httpx_mock): """Test STREAMABLEHTTP transport with redirection and MCP session ID.""" - # Mock first response with redirect + # When follow_redirects=True, httpx handles redirects internally + # Only mock the FINAL response, not intermediate redirects httpx_mock.add_response( - method="GET", + method="POST", url="http://example.com", - status_code=302, - headers={"location": "http://sampleredirected.com"}, - ) - - # Mock redirected response with MCP session - httpx_mock.add_response( - method="GET", - url="http://sampleredirected.com", status_code=200, - headers={"mcp-session-id": "sample123", "content-type": "application/json"}, + headers={"content-type": "application/json"}, ) - result = await gateway_service._validate_gateway_url(url="http://example.com", headers={}, transport_type="STREAMABLEHTTP") + result = await gateway_service._validate_gateway_url( + url="http://example.com", headers={}, transport_type="STREAMABLEHTTP" + ) - # Should return True when redirect has mcp-session-id and application/json content-type assert result is True # ─────────────────────────────────────────────────────────────────────────── @@ -645,14 +653,15 @@ async def test_streamablehttp_redirect(self, gateway_service, httpx_mock): async def test_bulk_concurrent_validation(self, gateway_service, httpx_mock): """Test bulk concurrent gateway URL validations.""" urls = [f"http://gateway{i}.com" for i in range(20)] - - # Add responses for all URLs + + # Add responses for all URLs with SSE content for url in urls: httpx_mock.add_response( method="GET", url=url, status_code=200, headers={"content-type": "text/event-stream"}, + content=b"data: test\n\n", # Add SSE data ) # Run the validations concurrently @@ -1322,47 +1331,34 @@ async def test_forward_request_connection_error(self, gateway_service, mock_gate @pytest.mark.asyncio async def test_validate_gateway_url_redirect_with_auth_failure(self, gateway_service, httpx_mock): """Test redirect handling with authentication failure at redirect location.""" - # Mock first response (redirect with Location header) + # Only mock final response with auth failure httpx_mock.add_response( - method="GET", + method="POST", url="http://example.com", - status_code=302, - headers={"location": "http://redirected.com/api"}, - ) - - # Mock redirected response with auth failure - httpx_mock.add_response( - method="GET", - url="http://redirected.com/api", status_code=401, ) - result = await gateway_service._validate_gateway_url(url="http://example.com", headers={}, transport_type="STREAMABLEHTTP") + result = await gateway_service._validate_gateway_url( + url="http://example.com", headers={}, transport_type="STREAMABLEHTTP" + ) assert result is False @pytest.mark.asyncio async def test_validate_gateway_url_redirect_with_mcp_session(self, gateway_service, httpx_mock): """Test redirect handling with MCP session ID in response.""" - # Mock first response (redirect with Location header) + # STREAMABLEHTTP uses POST method, and only mock final response httpx_mock.add_response( - method="GET", + method="POST", # Changed from GET to POST url="http://example.com", - status_code=302, - headers={"location": "http://redirected.com/api"}, - ) - - # Mock redirected response with MCP session - httpx_mock.add_response( - method="GET", - url="http://redirected.com/api", status_code=200, headers={"mcp-session-id": "session123", "content-type": "application/json"}, ) - result = await gateway_service._validate_gateway_url(url="http://example.com", headers={}, transport_type="STREAMABLEHTTP") + result = await gateway_service._validate_gateway_url( + url="http://example.com", headers={}, transport_type="STREAMABLEHTTP" + ) - # Should return True when redirect has mcp-session-id and application/json content-type assert result is True # ────────────────────────────────────────────────────────────────────