diff --git a/.secrets.baseline b/.secrets.baseline index 0d88f53255..0879b4dea3 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -3,7 +3,7 @@ "files": "^.secrets.baseline|package-lock.json|Cargo.lock|scripts/sign_image.sh|scripts/zap|sonar-project.properties|uv.lock|^.secrets.baseline$", "lines": null }, - "generated_at": "2026-04-13T06:27:05Z", + "generated_at": "2026-04-13T06:40:07Z", "plugins_used": [ { "name": "AWSKeyDetector" @@ -5086,7 +5086,7 @@ "hashed_secret": "fa9beb99e4029ad5a6615399e7bbae21356086b3", "is_secret": false, "is_verified": false, - "line_number": 4239, + "line_number": 4240, "type": "Secret Keyword", "verified_result": null }, @@ -5094,7 +5094,7 @@ "hashed_secret": "559b05f1b2863e725b76e216ac3dadecbf92e244", "is_secret": false, "is_verified": false, - "line_number": 4840, + "line_number": 4841, "type": "Secret Keyword", "verified_result": null }, @@ -5102,7 +5102,7 @@ "hashed_secret": "a8af4759392d4f7496d613174f33afe2074a4b8d", "is_secret": false, "is_verified": false, - "line_number": 4842, + "line_number": 4843, "type": "Secret Keyword", "verified_result": null }, @@ -5110,7 +5110,7 @@ "hashed_secret": "85b60d811d16ff56b3654587d4487f713bfa33b7", "is_secret": false, "is_verified": false, - "line_number": 15050, + "line_number": 15169, "type": "Secret Keyword", "verified_result": null } @@ -5962,7 +5962,7 @@ "hashed_secret": "c377074d6473f35a91001981355da793dc808ffd", "is_secret": false, "is_verified": false, - "line_number": 4197, + "line_number": 4220, "type": "Hex High Entropy String", "verified_result": null }, @@ -5970,7 +5970,7 @@ "hashed_secret": "6367c48dd193d56ea7b0baad25b19455e529f5ee", "is_secret": false, "is_verified": false, - "line_number": 5310, + "line_number": 5333, "type": "Secret Keyword", "verified_result": null }, @@ -5978,7 +5978,7 @@ "hashed_secret": "f2b14f68eb995facb3a1c35287b778d5bd785511", "is_secret": false, "is_verified": false, - "line_number": 5474, + "line_number": 5497, "type": "Secret Keyword", 
"verified_result": null }, @@ -5986,7 +5986,7 @@ "hashed_secret": "f42a3fabe1e9bed059d727f47eb752e3aa61b977", "is_secret": false, "is_verified": false, - "line_number": 5531, + "line_number": 5554, "type": "Secret Keyword", "verified_result": null }, @@ -5994,7 +5994,7 @@ "hashed_secret": "b85788b459aa4d67e1070930dae6d0827756aadb", "is_secret": false, "is_verified": false, - "line_number": 5569, + "line_number": 5592, "type": "Secret Keyword", "verified_result": null }, @@ -6002,7 +6002,7 @@ "hashed_secret": "52dcc83ec1e54426ad58a64854d1eb8d5f5d9685", "is_secret": false, "is_verified": false, - "line_number": 5570, + "line_number": 5593, "type": "Secret Keyword", "verified_result": null } @@ -9958,7 +9958,7 @@ "hashed_secret": "c00dbbc9dadfbe1e232e93a729dd4752fade0abf", "is_secret": false, "is_verified": false, - "line_number": 14402, + "line_number": 14411, "type": "Secret Keyword", "verified_result": null }, @@ -9966,7 +9966,7 @@ "hashed_secret": "f2b14f68eb995facb3a1c35287b778d5bd785511", "is_secret": false, "is_verified": false, - "line_number": 17159, + "line_number": 17168, "type": "Secret Keyword", "verified_result": null }, @@ -9974,7 +9974,7 @@ "hashed_secret": "a4b48a81cdab1e1a5dd37907d6c85ca1c61ddc7c", "is_secret": false, "is_verified": false, - "line_number": 17178, + "line_number": 17187, "type": "Secret Keyword", "verified_result": null }, @@ -9982,7 +9982,7 @@ "hashed_secret": "dc8002865f92070749b264e76045b04fa3b8de71", "is_secret": false, "is_verified": false, - "line_number": 20836, + "line_number": 20845, "type": "Secret Keyword", "verified_result": null } diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index f171141970..c513761494 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -136,6 +136,7 @@ from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.mcp_session_pool import get_mcp_session_pool from mcpgateway.services.oauth_manager import OAuthManager +from mcpgateway.services.openapi_service 
import fetch_and_extract_schemas from mcpgateway.services.performance_service import get_performance_service from mcpgateway.services.permission_service import PermissionService from mcpgateway.services.plugin_service import get_plugin_service @@ -11772,6 +11773,124 @@ async def admin_edit_tool( return ORJSONResponse(content={"message": str(ex), "success": False}, status_code=500) +@admin_router.post("/tools/generate-schemas-from-openapi") +# tools.create — this endpoint makes outbound HTTP requests to user-supplied +# URLs to fetch OpenAPI specs. tools.read would let viewers probe internal +# services; tools.create scopes it to users who can already register tools. +@require_permission("tools.create", allow_admin_bypass=False) +async def generate_schemas_from_openapi( + request: Request, + _user=Depends(get_current_user_with_permissions), +) -> JSONResponse: + """ + Generate input_schema and output_schema from OpenAPI specification URL. + + Expects JSON body with: + - url: The tool URL (e.g., http://localhost:8100/calculate) + - request_type: HTTP method (GET, POST, etc.) + - openapi_url: (optional) Direct OpenAPI spec URL + + Args: + request: FastAPI Request object containing JSON body + + Returns: + JSONResponse with generated schemas or error message. 
+ """ + try: + body = await _read_request_json(request) + except Exception: + return ORJSONResponse( + content={"message": "Invalid JSON in request body", "success": False}, + status_code=400, + ) + + if not isinstance(body, dict): + return ORJSONResponse( + content={"message": "Request body must be a JSON object", "success": False}, + status_code=400, + ) + + tool_url = body.get("url", "") + request_type = body.get("request_type", "GET") + openapi_url = body.get("openapi_url", "") + + if not isinstance(tool_url, str) or not isinstance(request_type, str) or not isinstance(openapi_url, str): + return ORJSONResponse( + content={"message": "'url', 'request_type', and 'openapi_url' must be strings", "success": False}, + status_code=400, + ) + + tool_url = tool_url.strip() + request_type = request_type.strip() + openapi_url = openapi_url.strip() + + if not tool_url: + return ORJSONResponse( + content={"message": "'url' is required to identify the API path and base URL", "success": False}, + status_code=400, + ) + + try: + SecurityValidator.validate_url(tool_url, "Tool URL") + except ValueError as e: + return ORJSONResponse( + content={"message": str(e), "success": False}, + status_code=400, + ) + + parsed = urllib.parse.urlparse(tool_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + tool_path = parsed.path + + try: + input_schema, output_schema, spec_url = await fetch_and_extract_schemas( + base_url=base_url, + path=tool_path, + method=request_type, + openapi_url=openapi_url, + timeout=10.0, + ) + except ValueError as e: + return ORJSONResponse( + content={"message": f"Security validation failed: {str(e)}", "success": False}, + status_code=400, + ) + except KeyError as e: + return ORJSONResponse( + content={"message": str(e), "success": False}, + status_code=404, + ) + except httpx.HTTPStatusError as e: + LOGGER.warning("OpenAPI spec server returned HTTP %s", e.response.status_code, exc_info=True) + return ORJSONResponse( + content={"message": f"OpenAPI spec 
server returned HTTP {e.response.status_code}", "success": False}, + status_code=502, + ) + except httpx.HTTPError: + LOGGER.warning("Failed to fetch OpenAPI spec", exc_info=True) + return ORJSONResponse( + content={"message": "Failed to fetch OpenAPI spec from the provided URL", "success": False}, + status_code=502, + ) + except Exception: + LOGGER.error("Error fetching OpenAPI spec", exc_info=True) + return ORJSONResponse( + content={"message": "An unexpected error occurred while processing the OpenAPI spec", "success": False}, + status_code=500, + ) + + return ORJSONResponse( + content={ + "message": "Schemas generated successfully from OpenAPI spec", + "success": True, + "input_schema": input_schema, + "output_schema": output_schema, + "spec_url": spec_url, + }, + status_code=200, + ) + + @admin_router.post("/tools/{tool_id}/delete") @require_permission("tools.delete", allow_admin_bypass=False) async def admin_delete_tool(tool_id: str, request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> RedirectResponse: diff --git a/mcpgateway/schemas.py b/mcpgateway/schemas.py index 6f170dbce3..e3fa2ba8ff 100644 --- a/mcpgateway/schemas.py +++ b/mcpgateway/schemas.py @@ -427,6 +427,43 @@ class AuthenticationValues(BaseModelWithConfigDict): authHeaders: Optional[List[Dict[str, str]]] = Field(None, alias="authHeaders", description="List of custom headers for authentication (multi-header format)") # noqa: N815 +# Minimal valid JSON Schema used as the default input_schema for REST tools. +_DEFAULT_INPUT_SCHEMA: dict = {"type": "object", "properties": {}} + + +def _extract_rest_url_components(values: dict) -> dict: + """Extract ``base_url`` and ``path_template`` from ``url`` for REST integration tools. + + Shared logic used by both :class:`ToolCreate` and :class:`ToolUpdate` model + validators so the URL-parsing behaviour stays consistent across create and + update paths. 
+ + Args: + values: The raw model input dict (mutated in-place). + + Returns: + The same *values* dict, potentially with ``base_url`` and + ``path_template`` populated. + """ + url = values.get("url") + if not url: + return values + + parsed = urlparse(str(url)) + base_url = f"{parsed.scheme}://{parsed.netloc}" + path_template = parsed.path + + if path_template: + path_template = "/" + path_template.lstrip("/") + + if not values.get("base_url"): + values["base_url"] = base_url + if not values.get("path_template"): + values["path_template"] = path_template + + return values + + class ToolCreate(BaseModel): """ Represents the configuration for creating a tool with various attributes and settings. @@ -458,7 +495,7 @@ class ToolCreate(BaseModel): integration_type: Literal["REST", "MCP", "A2A"] = Field("REST", description="'REST' for individual endpoints, 'MCP' for gateway-discovered tools, 'A2A' for A2A agents") request_type: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "SSE", "STDIO", "STREAMABLEHTTP"] = Field("SSE", description="HTTP method to be used for invoking the tool") headers: Optional[Dict[str, str]] = Field(None, description="Additional headers to send when invoking the tool") - input_schema: Optional[Dict[str, Any]] = Field(default_factory=lambda: {"type": "object", "properties": {}}, description="JSON Schema for validating tool parameters", alias="inputSchema") + input_schema: Optional[Dict[str, Any]] = Field(default_factory=lambda: dict(_DEFAULT_INPUT_SCHEMA), description="JSON Schema for validating tool parameters", alias="inputSchema") output_schema: Optional[Dict[str, Any]] = Field(default=None, description="JSON Schema for validating tool output", alias="outputSchema") annotations: Optional[Dict[str, Any]] = Field( default_factory=dict, @@ -868,33 +905,22 @@ def enforce_passthrough_fields_for_rest(cls, values: Dict[str, Any]) -> Dict[str @model_validator(mode="before") @classmethod def extract_base_url_and_path_template(cls, values: dict) -> dict: 
- """ - Only for integration_type 'REST': - If 'url' is provided, extract 'base_url' and 'path_template'. - Ensures path_template starts with a single '/'. + """For REST tools: extract URL components and ensure a default input_schema. Args: values (dict): The input values to process. Returns: - dict: The updated values with base_url and path_template if applicable. + dict: The updated values with base_url and path_template extracted from url. """ - integration_type = values.get("integration_type") - if integration_type != "REST": - # Only process for REST, skip for others + if values.get("integration_type") != "REST": return values - url = values.get("url") - if url: - parsed = urlparse(str(url)) - base_url = f"{parsed.scheme}://{parsed.netloc}" - path_template = parsed.path - # Ensure path_template starts with a single '/' - if path_template: - path_template = "/" + path_template.lstrip("/") - if not values.get("base_url"): - values["base_url"] = base_url - if not values.get("path_template"): - values["path_template"] = path_template + + _extract_rest_url_components(values) + + if not values.get("input_schema"): + values["input_schema"] = dict(_DEFAULT_INPUT_SCHEMA) + return values @field_validator("base_url") @@ -1267,6 +1293,31 @@ def assemble_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["auth"] = {"auth_type": "authheaders", "auth_value": None} return values + @model_validator(mode="before") + @classmethod + def extract_base_url_and_path_template(cls, values: dict) -> dict: + """For REST tools: extract URL components and normalise empty input_schema. + + Args: + values (dict): The input values to process. + + Returns: + dict: The updated values with base_url and path_template extracted from url. + """ + if values.get("integration_type") != "REST": + return values + + _extract_rest_url_components(values) + + # Normalise explicitly-empty input_schema to the typed default. 
+ # None is left alone (partial update semantics — omitted fields + # should not overwrite existing values in the database). + input_schema = values.get("input_schema") + if input_schema is not None and isinstance(input_schema, dict) and not input_schema: + values["input_schema"] = dict(_DEFAULT_INPUT_SCHEMA) + + return values + @field_validator("displayName") @classmethod def validate_display_name(cls, v: Optional[str]) -> Optional[str]: @@ -1321,34 +1372,6 @@ def prevent_manual_mcp_update(cls, values: Dict[str, Any]) -> Dict[str, Any]: raise ValueError("Cannot update tools to A2A integration type. A2A tools are managed by the A2A service.") return values - @model_validator(mode="before") - @classmethod - def extract_base_url_and_path_template(cls, values: dict) -> dict: - """ - If 'integration_type' is 'REST' and 'url' is provided, extract 'base_url' and 'path_template'. - Ensures path_template starts with a single '/'. - - Args: - values (dict): The input values to process. - - Returns: - dict: The updated values with base_url and path_template if applicable. 
- """ - integration_type = values.get("integration_type") - url = values.get("url") - if integration_type == "REST" and url: - parsed = urlparse(str(url)) - base_url = f"{parsed.scheme}://{parsed.netloc}" - path_template = parsed.path - # Ensure path_template starts with a single '/' - if path_template: - path_template = "/" + path_template.lstrip("/") - if not values.get("base_url"): - values["base_url"] = base_url - if not values.get("path_template"): - values["path_template"] = path_template - return values - @field_validator("base_url") @classmethod def validate_base_url(cls, v): diff --git a/mcpgateway/services/http_client_service.py b/mcpgateway/services/http_client_service.py index 63a9c128d6..0ca7b84272 100644 --- a/mcpgateway/services/http_client_service.py +++ b/mcpgateway/services/http_client_service.py @@ -310,6 +310,7 @@ async def get_isolated_http_client( connect_timeout: Optional[float] = None, write_timeout: Optional[float] = None, pool_timeout: Optional[float] = None, + follow_redirects: bool = True, ) -> AsyncIterator[httpx.AsyncClient]: """ Create an isolated HTTP client with custom settings. @@ -330,6 +331,8 @@ async def get_isolated_http_client( connect_timeout: Optional connect timeout override (seconds). write_timeout: Optional write timeout override (seconds). pool_timeout: Optional pool timeout override (seconds). + follow_redirects: Whether to follow HTTP redirects (default: True). + Set to False for SSRF-sensitive requests. Yields: httpx.AsyncClient: A new isolated client instance. 
@@ -359,6 +362,6 @@ async def get_isolated_http_client( verify=effective_verify, auth=auth, http2=http2 if http2 is not None else settings.httpx_http2_enabled, - follow_redirects=True, + follow_redirects=follow_redirects, ) as client: yield client diff --git a/mcpgateway/services/openapi_service.py b/mcpgateway/services/openapi_service.py new file mode 100644 index 0000000000..eb0d230db2 --- /dev/null +++ b/mcpgateway/services/openapi_service.py @@ -0,0 +1,199 @@ +# -*- coding: utf-8 -*- +"""Location: ./mcpgateway/services/openapi_service.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Mihai Criveti + +OpenAPI Service for ContextForge AI Gateway. +This module provides services for fetching and extracting schemas from OpenAPI specifications. +""" + +# Standard +import logging +from typing import Optional, Tuple +import urllib.parse + +# Third-Party +import orjson + +# First-Party +from mcpgateway.common.validators import SecurityValidator +from mcpgateway.services.http_client_service import get_isolated_http_client + +logger = logging.getLogger(__name__) + + +def _resolve_schema(schema_obj: Optional[dict], components_schemas: dict) -> Optional[dict]: + """Resolve a schema from a ``$ref`` reference or return an inline schema. + + Only resolves top-level local ``$ref`` references of the form + ``#/components/schemas/<Name>``. Nested ``$ref`` chains (a resolved + schema that itself contains ``$ref``) are returned with the inner ``$ref`` + left unresolved, and external references (e.g. ``./models.json#/Foo``) + are **not** supported and resolve to ``None``. + + Args: + schema_obj: Schema object that may contain a ``$ref`` or inline schema. + components_schemas: The ``components.schemas`` section of the OpenAPI spec. + + Returns: + Resolved schema dictionary, or ``None`` if no valid schema found. 
+ """ + if isinstance(schema_obj, dict) and "$ref" in schema_obj: + ref_path = schema_obj["$ref"] + if not ref_path.startswith("#/components/schemas/"): + logger.warning("Unsupported $ref format '%s': only local #/components/schemas/ references are resolved", ref_path) + return None + schema_name = ref_path.split("/")[-1] + resolved = components_schemas.get(schema_name) + if resolved is None: + logger.warning("Unresolved $ref '%s': schema '%s' not found in components.schemas", ref_path, schema_name) + return resolved + return schema_obj if schema_obj is not None else None + + +# 10 MiB — generous for any realistic OpenAPI spec, prevents memory exhaustion from malicious servers. +_MAX_SPEC_BYTES = 10 * 1024 * 1024 + + +async def fetch_openapi_spec(spec_url: str, timeout: float = 10.0) -> dict: + """ + Fetch OpenAPI specification from a URL with SSRF protection. + + Redirects are disabled to prevent SSRF bypass (an attacker-controlled + server could redirect to an internal address after the initial URL + passes validation). Response bodies larger than ``_MAX_SPEC_BYTES`` + are rejected to guard against memory exhaustion. + + Args: + spec_url: The URL to fetch the OpenAPI spec from + timeout: Request timeout in seconds (default: 10.0) + + Returns: + dict: The parsed OpenAPI specification + + Raises: + ValueError: If URL fails security validation, response is too large, or + response body is not valid JSON + httpx.HTTPError: If the request fails + """ + # SSRF Protection: Validate the spec URL before making request + SecurityValidator.validate_url(spec_url, "OpenAPI spec URL") + + async with get_isolated_http_client(timeout=timeout, follow_redirects=False) as client: + async with client.stream("GET", spec_url) as response: + response.raise_for_status() + + # Early reject via Content-Length when the header is present. 
+ try: + cl = int(response.headers.get("content-length", "0")) + except (ValueError, OverflowError): + cl = 0 # Malformed header — fall through to streamed check below + if cl > _MAX_SPEC_BYTES: + raise ValueError(f"OpenAPI spec response too large ({cl} bytes, max {_MAX_SPEC_BYTES})") + + # Stream the body in chunks so we never buffer more than the cap. + chunks: list[bytes] = [] + total = 0 + async for chunk in response.aiter_bytes(chunk_size=8192): + total += len(chunk) + if total > _MAX_SPEC_BYTES: + raise ValueError(f"OpenAPI spec response too large (>{_MAX_SPEC_BYTES} bytes)") + chunks.append(chunk) + + body = b"".join(chunks) + + try: + return orjson.loads(body) + except (orjson.JSONDecodeError, ValueError) as exc: + raise ValueError("Response is not valid JSON. Ensure the URL points to a JSON OpenAPI specification.") from exc + + +def extract_schemas_from_openapi( + spec: dict, + path: str, + method: str, +) -> Tuple[Optional[dict], Optional[dict]]: + """Extract input and output schemas from an OpenAPI specification. + + Args: + spec: The OpenAPI specification dictionary. + path: The API path (e.g., ``"/calculate"``). + method: The HTTP method (e.g., ``"post"``). + + Returns: + Tuple of (input_schema, output_schema), either may be ``None``. + + Raises: + KeyError: If *path* or *method* is not found in the spec. 
+ """ + method = method.lower() + + # Check if path and method exist in spec + if path not in spec.get("paths", {}): + raise KeyError(f"Path '{path}' not found in OpenAPI spec") + + if method not in spec["paths"][path]: + raise KeyError(f"Method '{method}' not found for path '{path}'") + + operation = spec["paths"][path][method] + components_schemas = spec.get("components", {}).get("schemas", {}) + + # Extract input schema from requestBody + input_schema = None + request_body = operation.get("requestBody", {}) + if request_body: + json_content = request_body.get("content", {}).get("application/json", {}) + if "schema" in json_content: + input_schema = _resolve_schema(json_content["schema"], components_schemas) + + # Extract output schema from the 200 response, falling back to 201 ("default" is not consulted) + output_schema = None + responses = operation.get("responses", {}) + success_response = responses.get("200") if "200" in responses else responses.get("201") + if success_response: + json_content = success_response.get("content", {}).get("application/json", {}) + if "schema" in json_content: + output_schema = _resolve_schema(json_content["schema"], components_schemas) + + return input_schema, output_schema + + +async def fetch_and_extract_schemas( + base_url: str, + path: str, + method: str, + openapi_url: Optional[str] = None, + timeout: float = 10.0, +) -> Tuple[Optional[dict], Optional[dict], str]: + """ + Fetch OpenAPI spec and extract input/output schemas with SSRF protection. 
+ + Args: + base_url: The base URL of the API (e.g., "http://localhost:8100") + path: The API path (e.g., "/calculate") + method: The HTTP method (e.g., "POST") + openapi_url: Optional direct URL to OpenAPI spec (overrides base_url) + timeout: Request timeout in seconds (default: 10.0) + + Returns: + Tuple of (input_schema, output_schema, spec_url) + + Raises: + ValueError: If URL fails security validation + httpx.HTTPError: If the request fails + KeyError: If path or method not found in spec + """ + # Determine OpenAPI spec URL + if openapi_url: + spec_url = openapi_url + else: + spec_url = urllib.parse.urljoin(base_url, "/openapi.json") + + # Fetch the spec with SSRF protection + spec = await fetch_openapi_spec(spec_url, timeout=timeout) + + # Extract schemas + input_schema, output_schema = extract_schemas_from_openapi(spec, path, method) + + return input_schema, output_schema, spec_url diff --git a/tests/unit/mcpgateway/services/test_openapi_service.py b/tests/unit/mcpgateway/services/test_openapi_service.py new file mode 100644 index 0000000000..2406fc09ad --- /dev/null +++ b/tests/unit/mcpgateway/services/test_openapi_service.py @@ -0,0 +1,433 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/services/test_openapi_service.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Unit tests for OpenAPI service. 
+""" + +# Standard +from contextlib import asynccontextmanager +from typing import Optional +from unittest.mock import AsyncMock, MagicMock, patch + +# Third-Party +import httpx +import orjson +import pytest + +# First-Party +from mcpgateway.services.openapi_service import ( + _MAX_SPEC_BYTES, + extract_schemas_from_openapi, + fetch_and_extract_schemas, + fetch_openapi_spec, +) + + +class TestExtractSchemasFromOpenAPI: + """Tests for extract_schemas_from_openapi function.""" + + def test_extract_inline_schemas(self): + """Test extraction of inline schemas (no $ref).""" + spec = { + "paths": { + "/calculate": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": {"result": {"type": "number"}}, + } + } + } + } + }, + } + } + } + } + + input_schema, output_schema = extract_schemas_from_openapi(spec, "/calculate", "post") + + assert input_schema is not None + assert input_schema["type"] == "object" + assert "a" in input_schema["properties"] + assert "b" in input_schema["properties"] + + assert output_schema is not None + assert output_schema["type"] == "object" + assert "result" in output_schema["properties"] + + def test_extract_ref_schemas(self): + """Test extraction of schemas with $ref references.""" + spec = { + "paths": { + "/calculate": { + "post": { + "requestBody": {"content": {"application/json": {"schema": {"$ref": "#/components/schemas/CalculateRequest"}}}}, + "responses": {"200": {"content": {"application/json": {"schema": {"$ref": "#/components/schemas/CalculateResponse"}}}}}, + } + } + }, + "components": { + "schemas": { + "CalculateRequest": { + "type": "object", + "properties": {"x": {"type": "number"}, "y": {"type": "number"}}, + }, + "CalculateResponse": {"type": "object", "properties": 
{"sum": {"type": "number"}}}, + } + }, + } + + input_schema, output_schema = extract_schemas_from_openapi(spec, "/calculate", "post") + + assert input_schema is not None + assert input_schema["type"] == "object" + assert "x" in input_schema["properties"] + assert "y" in input_schema["properties"] + + assert output_schema is not None + assert output_schema["type"] == "object" + assert "sum" in output_schema["properties"] + + def test_extract_with_201_response(self): + """Test extraction when response is 201 instead of 200.""" + spec = { + "paths": { + "/create": { + "post": { + "requestBody": {"content": {"application/json": {"schema": {"type": "object", "properties": {"name": {"type": "string"}}}}}}, + "responses": {"201": {"content": {"application/json": {"schema": {"type": "object", "properties": {"id": {"type": "string"}}}}}}}, + } + } + } + } + + input_schema, output_schema = extract_schemas_from_openapi(spec, "/create", "post") + + assert input_schema is not None + assert output_schema is not None + assert "id" in output_schema["properties"] + + def test_extract_no_request_body(self): + """Test extraction when there's no request body (GET request).""" + spec = {"paths": {"/status": {"get": {"responses": {"200": {"content": {"application/json": {"schema": {"type": "object", "properties": {"status": {"type": "string"}}}}}}}}}}} + + input_schema, output_schema = extract_schemas_from_openapi(spec, "/status", "get") + + assert input_schema is None + assert output_schema is not None + assert "status" in output_schema["properties"] + + def test_extract_no_response_schema(self): + """Test extraction when there's no response schema.""" + spec = { + "paths": { + "/delete": { + "delete": { + "requestBody": {"content": {"application/json": {"schema": {"type": "object", "properties": {"id": {"type": "string"}}}}}}, + "responses": {"204": {"description": "No content"}}, + } + } + } + } + + input_schema, output_schema = extract_schemas_from_openapi(spec, "/delete", "delete") 
+ + assert input_schema is not None + assert output_schema is None + + def test_path_not_found(self): + """Test error when path doesn't exist in spec.""" + spec = {"paths": {"/calculate": {"post": {}}}} + + with pytest.raises(KeyError, match="Path '/nonexistent' not found"): + extract_schemas_from_openapi(spec, "/nonexistent", "post") + + def test_method_not_found(self): + """Test error when method doesn't exist for path.""" + spec = {"paths": {"/calculate": {"post": {}}}} + + with pytest.raises(KeyError, match="Method 'get' not found"): + extract_schemas_from_openapi(spec, "/calculate", "get") + + def test_method_case_insensitive(self): + """Test that method matching is case-insensitive.""" + spec = {"paths": {"/test": {"post": {"responses": {"200": {"content": {"application/json": {"schema": {"type": "object"}}}}}}}}} + + # Should work with uppercase + input_schema, output_schema = extract_schemas_from_openapi(spec, "/test", "POST") + assert output_schema is not None + + # Should work with mixed case + input_schema, output_schema = extract_schemas_from_openapi(spec, "/test", "Post") + assert output_schema is not None + + def test_missing_ref_returns_none(self): + """Test that missing $ref returns None instead of raising error.""" + spec = { + "paths": { + "/test": { + "post": { + "requestBody": {"content": {"application/json": {"schema": {"$ref": "#/components/schemas/NonExistent"}}}}, + "responses": {"200": {"content": {"application/json": {"schema": {"type": "object"}}}}}, + } + } + }, + "components": {"schemas": {}}, + } + + input_schema, output_schema = extract_schemas_from_openapi(spec, "/test", "post") + + # Missing ref should return None + assert input_schema is None + assert output_schema is not None + + def test_missing_ref_logs_warning(self, caplog): + """Unresolved $ref logs a warning with the ref path and schema name.""" + spec = { + "paths": { + "/test": { + "post": { + "requestBody": {"content": {"application/json": {"schema": {"$ref": 
"#/components/schemas/Missing"}}}}, + "responses": {"200": {"content": {"application/json": {"schema": {"type": "object"}}}}}, + } + } + }, + "components": {"schemas": {}}, + } + + with caplog.at_level("WARNING", logger="mcpgateway.services.openapi_service"): + extract_schemas_from_openapi(spec, "/test", "post") + + assert any("Unresolved $ref" in msg and "Missing" in msg for msg in caplog.messages) + + def test_unsupported_ref_format_returns_none_and_logs(self, caplog): + """External or malformed $ref returns None and logs a warning.""" + spec = { + "paths": { + "/test": { + "post": { + "requestBody": {"content": {"application/json": {"schema": {"$ref": "https://external.com/schemas/Foo"}}}}, + "responses": {"200": {"content": {"application/json": {"schema": {"$ref": "SomeGarbage"}}}}}, + } + } + }, + "components": {"schemas": {"Foo": {"type": "object"}}}, + } + + with caplog.at_level("WARNING", logger="mcpgateway.services.openapi_service"): + input_schema, output_schema = extract_schemas_from_openapi(spec, "/test", "post") + + assert input_schema is None + assert output_schema is None + assert any("Unsupported $ref format" in msg for msg in caplog.messages) + + +@asynccontextmanager +async def _mock_isolated_client(body: bytes, headers: Optional[dict] = None, raise_for_status: Optional[Exception] = None): + """Async context manager mimicking ``get_isolated_http_client`` with canned responses.""" + + async def _aiter_bytes(chunk_size=8192): + for i in range(0, len(body), chunk_size): + yield body[i : i + chunk_size] + + mock_response = MagicMock() + mock_response.headers = headers or {} + if raise_for_status: + mock_response.raise_for_status.side_effect = raise_for_status + else: + mock_response.raise_for_status = MagicMock() + mock_response.aiter_bytes = _aiter_bytes + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + mock_client = AsyncMock() + mock_client.stream = 
MagicMock(return_value=mock_response) + yield mock_client + + +_PATCH_CLIENT = "mcpgateway.services.openapi_service.get_isolated_http_client" +_PATCH_VALIDATE = "mcpgateway.services.openapi_service.SecurityValidator.validate_url" + + +class TestFetchOpenAPISpec: + """Tests for fetch_openapi_spec function.""" + + @pytest.mark.asyncio + async def test_fetch_success(self): + """Test successful fetch of OpenAPI spec.""" + mock_spec = {"openapi": "3.0.0", "paths": {}} + + with patch(_PATCH_CLIENT, return_value=_mock_isolated_client(orjson.dumps(mock_spec))): + with patch(_PATCH_VALIDATE): + result = await fetch_openapi_spec("http://example.com/openapi.json") + + assert result == mock_spec + + @pytest.mark.asyncio + async def test_fetch_with_ssrf_validation(self): + """Test that SSRF validation is called when enabled.""" + with patch(_PATCH_CLIENT, return_value=_mock_isolated_client(orjson.dumps({"openapi": "3.0.0"}))): + with patch(_PATCH_VALIDATE) as mock_validate_url: + await fetch_openapi_spec("http://example.com/openapi.json") + + mock_validate_url.assert_called_once() + + @pytest.mark.asyncio + async def test_fetch_url_validation_failure(self): + """Test that URL validation errors are propagated.""" + with patch(_PATCH_VALIDATE) as mock_validate: + mock_validate.side_effect = ValueError("Invalid URL") + + with pytest.raises(ValueError, match="Invalid URL"): + await fetch_openapi_spec("javascript:alert(1)") + + @pytest.mark.asyncio + async def test_fetch_http_error(self): + """Test handling of HTTP errors.""" + error = httpx.HTTPStatusError("404 Not Found", request=MagicMock(), response=MagicMock()) + + with patch(_PATCH_CLIENT, return_value=_mock_isolated_client(b"", raise_for_status=error)): + with patch(_PATCH_VALIDATE): + with pytest.raises(httpx.HTTPStatusError): + await fetch_openapi_spec("http://example.com/openapi.json") + + @pytest.mark.asyncio + async def test_fetch_timeout(self): + """Test custom timeout is passed to get_isolated_http_client.""" + with 
patch(_PATCH_CLIENT, return_value=_mock_isolated_client(orjson.dumps({"openapi": "3.0.0"}))) as mock_get_client: + with patch(_PATCH_VALIDATE): + await fetch_openapi_spec("http://example.com/openapi.json", timeout=5.0) + + mock_get_client.assert_called_once_with(timeout=5.0, follow_redirects=False) + + @pytest.mark.asyncio + async def test_rejects_response_with_content_length_exceeding_limit(self): + """Content-Length header exceeding _MAX_SPEC_BYTES raises ValueError.""" + with patch(_PATCH_CLIENT, return_value=_mock_isolated_client(b"", headers={"content-length": str(_MAX_SPEC_BYTES + 1)})): + with patch(_PATCH_VALIDATE): + with pytest.raises(ValueError, match="too large"): + await fetch_openapi_spec("http://example.com/openapi.json") + + @pytest.mark.asyncio + async def test_rejects_response_body_exceeding_limit(self): + """Response body exceeding _MAX_SPEC_BYTES raises ValueError during streaming.""" + with patch(_PATCH_CLIENT, return_value=_mock_isolated_client(b"x" * (_MAX_SPEC_BYTES + 1))): + with patch(_PATCH_VALIDATE): + with pytest.raises(ValueError, match="too large"): + await fetch_openapi_spec("http://example.com/openapi.json") + + @pytest.mark.asyncio + async def test_malformed_content_length_falls_through_to_body_check(self): + """Malformed Content-Length header doesn't crash — falls through to streamed check.""" + mock_spec = {"openapi": "3.0.0"} + + with patch(_PATCH_CLIENT, return_value=_mock_isolated_client(orjson.dumps(mock_spec), headers={"content-length": "not-a-number"})): + with patch(_PATCH_VALIDATE): + result = await fetch_openapi_spec("http://example.com/openapi.json") + + assert result == mock_spec + + @pytest.mark.asyncio + async def test_invalid_json_response_raises_valueerror(self): + """Non-JSON response body (e.g. 
HTML) raises ValueError with clear message.""" + with patch(_PATCH_CLIENT, return_value=_mock_isolated_client(b"Not Found")): + with patch(_PATCH_VALIDATE): + with pytest.raises(ValueError, match="not valid JSON"): + await fetch_openapi_spec("http://example.com/openapi.json") + + +class TestFetchAndExtractSchemas: + """Tests for fetch_and_extract_schemas function.""" + + @pytest.mark.asyncio + async def test_fetch_and_extract_success(self): + """Test successful fetch and extraction.""" + mock_spec = { + "paths": { + "/calculate": { + "post": { + "requestBody": {"content": {"application/json": {"schema": {"type": "object", "properties": {"x": {"type": "number"}}}}}}, + "responses": {"200": {"content": {"application/json": {"schema": {"type": "object", "properties": {"result": {"type": "number"}}}}}}}, + } + } + } + } + + with patch("mcpgateway.services.openapi_service.fetch_openapi_spec") as mock_fetch: + mock_fetch.return_value = mock_spec + + input_schema, output_schema, spec_url = await fetch_and_extract_schemas(base_url="http://localhost:8100", path="/calculate", method="POST") + + assert input_schema is not None + assert "x" in input_schema["properties"] + assert output_schema is not None + assert "result" in output_schema["properties"] + assert spec_url == "http://localhost:8100/openapi.json" + + @pytest.mark.asyncio + async def test_fetch_and_extract_with_custom_openapi_url(self): + """Test using custom OpenAPI URL instead of base_url.""" + mock_spec = {"paths": {"/test": {"get": {"responses": {"200": {"content": {"application/json": {"schema": {"type": "object"}}}}}}}}} + + with patch("mcpgateway.services.openapi_service.fetch_openapi_spec") as mock_fetch: + mock_fetch.return_value = mock_spec + + input_schema, output_schema, spec_url = await fetch_and_extract_schemas( + base_url="http://localhost:8100", + path="/test", + method="GET", + openapi_url="http://custom.com/spec.json", + ) + + # Should use custom URL + assert spec_url == 
"http://custom.com/spec.json" + mock_fetch.assert_called_once_with("http://custom.com/spec.json", timeout=10.0) + + @pytest.mark.asyncio + async def test_fetch_and_extract_path_not_found(self): + """Test error propagation when path not found.""" + mock_spec = {"paths": {"/other": {"get": {}}}} + + with patch("mcpgateway.services.openapi_service.fetch_openapi_spec") as mock_fetch: + mock_fetch.return_value = mock_spec + + with pytest.raises(KeyError, match="Path '/calculate' not found"): + await fetch_and_extract_schemas(base_url="http://localhost:8100", path="/calculate", method="POST") + + @pytest.mark.asyncio + async def test_fetch_and_extract_custom_timeout(self): + """Test custom timeout is passed through.""" + mock_spec = {"paths": {"/test": {"get": {"responses": {"200": {}}}}}} + + with patch("mcpgateway.services.openapi_service.fetch_openapi_spec") as mock_fetch: + mock_fetch.return_value = mock_spec + + await fetch_and_extract_schemas( + base_url="http://localhost:8100", + path="/test", + method="GET", + timeout=5.0, + ) + + # Verify timeout was passed + mock_fetch.assert_called_once_with("http://localhost:8100/openapi.json", timeout=5.0) diff --git a/tests/unit/mcpgateway/test_admin.py b/tests/unit/mcpgateway/test_admin.py index e2da62402a..a5056f7587 100644 --- a/tests/unit/mcpgateway/test_admin.py +++ b/tests/unit/mcpgateway/test_admin.py @@ -1586,7 +1586,7 @@ async def test_admin_edit_tool_with_empty_optional_fields(self, mock_update_tool call_args = mock_update_tool.call_args[0] tool_update = call_args[2] assert tool_update.headers == {} - assert tool_update.input_schema == {} + assert tool_update.input_schema == {"type": "object", "properties": {}} @patch.object(ToolService, "register_tool") async def test_admin_add_tool_with_basic_auth(self, mock_register_tool, mock_request, mock_db): @@ -12057,6 +12057,7 @@ async def test_admin_unified_search_empty_query_and_tags_returns_empty(mock_db, @pytest.mark.asyncio async def 
test_admin_search_roots_returns_matching_by_name(allow_permission, monkeypatch): """admin_search_roots returns roots whose name contains the query.""" + # First-Party from mcpgateway.common.models import Root root_tmp = Root(uri="file:///tmp", name="tmp") @@ -12073,6 +12074,7 @@ async def test_admin_search_roots_returns_matching_by_name(allow_permission, mon @pytest.mark.asyncio async def test_admin_search_roots_matches_by_uri(allow_permission, monkeypatch): """admin_search_roots returns roots whose URI contains the query.""" + # First-Party from mcpgateway.common.models import Root root = Root(uri="file:///project/data", name="data") @@ -12087,6 +12089,7 @@ async def test_admin_search_roots_matches_by_uri(allow_permission, monkeypatch): @pytest.mark.asyncio async def test_admin_search_roots_empty_query_returns_all(allow_permission, monkeypatch): """admin_search_roots with empty query returns all roots.""" + # First-Party from mcpgateway.common.models import Root roots = [Root(uri="file:///tmp", name="tmp"), Root(uri="file:///home", name="home")] @@ -12100,6 +12103,7 @@ async def test_admin_search_roots_empty_query_returns_all(allow_permission, monk @pytest.mark.asyncio async def test_admin_search_roots_no_match_returns_empty(allow_permission, monkeypatch): """admin_search_roots returns empty list when no roots match the query.""" + # First-Party from mcpgateway.common.models import Root roots = [Root(uri="file:///tmp", name="tmp")] @@ -12114,6 +12118,7 @@ async def test_admin_search_roots_no_match_returns_empty(allow_permission, monke @pytest.mark.asyncio async def test_admin_search_roots_respects_limit(allow_permission, monkeypatch): """admin_search_roots respects the limit parameter.""" + # First-Party from mcpgateway.common.models import Root roots = [Root(uri=f"file:///dir{i}", name=f"dir{i}") for i in range(10)] @@ -12127,6 +12132,7 @@ async def test_admin_search_roots_respects_limit(allow_permission, monkeypatch): @pytest.mark.asyncio async def 
test_admin_search_roots_case_insensitive(allow_permission, monkeypatch): """admin_search_roots performs case-insensitive matching on both name and URI.""" + # First-Party from mcpgateway.common.models import Root root = Root(uri="file:///TMP", name="MyRoot") @@ -12142,6 +12148,7 @@ async def test_admin_search_roots_case_insensitive(allow_permission, monkeypatch @pytest.mark.asyncio async def test_admin_search_roots_null_name_falls_back_to_uri(allow_permission, monkeypatch): """admin_search_roots returns the URI as name when root.name is None.""" + # First-Party from mcpgateway.common.models import Root root = Root(uri="file:///tmp") @@ -12249,6 +12256,7 @@ async def test_admin_unified_search_roots_swallows_http_exception(monkeypatch, m @pytest.mark.asyncio async def test_admin_search_roots_denies_without_system_config_permission(monkeypatch, mock_db): + # First-Party from mcpgateway.common.models import Root monkeypatch.setattr("mcpgateway.admin.root_service", MagicMock(list_roots=AsyncMock(return_value=[Root(uri="file:///tmp", name="tmp")]))) @@ -12320,6 +12328,7 @@ async def _check_permission(**kwargs): @pytest.mark.parametrize("raw_limit", [0, -5, 1_000_000]) async def test_admin_search_roots_clamps_out_of_range_limit(raw_limit, allow_permission, monkeypatch): """Defense-in-depth clamp: direct Python callers bypass FastAPI ge/le validation.""" + # First-Party from mcpgateway.common.models import Root from mcpgateway.config import settings diff --git a/tests/unit/mcpgateway/test_admin_openapi.py b/tests/unit/mcpgateway/test_admin_openapi.py new file mode 100644 index 0000000000..8b0f3ba9b8 --- /dev/null +++ b/tests/unit/mcpgateway/test_admin_openapi.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/test_admin_openapi.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Tests for the generate_schemas_from_openapi admin endpoint. 
+ +These tests verify the endpoint's own logic: input validation, URL parsing, +and exception-to-HTTP-status mapping. The underlying service layer +(fetch_and_extract_schemas) is mocked — its logic is tested separately in +test_openapi_service.py. +""" + +# Standard +from unittest.mock import AsyncMock, MagicMock, patch + +# Third-Party +from fastapi import Request +import httpx +import orjson +import pytest + +# First-Party +from mcpgateway.admin import generate_schemas_from_openapi + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _mock_request(body) -> MagicMock: + """Return a mock FastAPI Request compatible with ``_read_request_json``.""" + raw = orjson.dumps(body) if isinstance(body, dict) else body + req = MagicMock(spec=Request) + req.body = AsyncMock(return_value=raw) + # _read_request_json falls back to request.json() for empty bodies + req.json = AsyncMock(return_value=body if isinstance(body, dict) else {}) + return req + + +_USER = {"email": "test@example.com"} + + +# --------------------------------------------------------------------------- +# Happy-path tests +# --------------------------------------------------------------------------- + + +class TestGenerateSchemasFromOpenAPI: + """Tests for generate_schemas_from_openapi endpoint.""" + + @pytest.mark.asyncio + async def test_success(self): + """Successful schema generation returns 200 with both schemas.""" + input_schema = {"type": "object", "properties": {"x": {"type": "number"}}} + output_schema = {"type": "object", "properties": {"result": {"type": "number"}}} + + with patch("mcpgateway.admin.fetch_and_extract_schemas") as mock_fetch: + mock_fetch.return_value = (input_schema, output_schema, "http://example.com/openapi.json") + + response = await generate_schemas_from_openapi( + request=_mock_request({"url": "http://example.com/calculate", "request_type": "POST"}), + 
_user=_USER, + ) + + assert response.status_code == 200 + content = orjson.loads(response.body) + assert content["success"] is True + assert content["input_schema"] == input_schema + assert content["output_schema"] == output_schema + assert content["spec_url"] == "http://example.com/openapi.json" + + @pytest.mark.asyncio + async def test_with_openapi_url(self): + """Custom openapi_url is forwarded to the service.""" + with patch("mcpgateway.admin.fetch_and_extract_schemas") as mock_fetch: + mock_fetch.return_value = (None, {"type": "object"}, "http://example.com/custom-spec.json") + + response = await generate_schemas_from_openapi( + request=_mock_request({"url": "http://example.com/api", "openapi_url": "http://example.com/custom-spec.json", "request_type": "GET"}), + _user=_USER, + ) + + assert response.status_code == 200 + assert mock_fetch.call_args[1]["openapi_url"] == "http://example.com/custom-spec.json" + + @pytest.mark.asyncio + async def test_default_request_type_is_get(self): + """request_type defaults to GET when omitted.""" + with patch("mcpgateway.admin.fetch_and_extract_schemas") as mock_fetch: + mock_fetch.return_value = (None, {"type": "object"}, "http://example.com/openapi.json") + + response = await generate_schemas_from_openapi( + request=_mock_request({"url": "http://example.com/status"}), + _user=_USER, + ) + + assert response.status_code == 200 + assert mock_fetch.call_args[1]["method"] == "GET" + + @pytest.mark.asyncio + async def test_url_parsing(self): + """URL is correctly split into base_url and path for the service call.""" + with patch("mcpgateway.admin.fetch_and_extract_schemas") as mock_fetch: + mock_fetch.return_value = ({"type": "object"}, {"type": "object"}, "https://api.example.com:8443/openapi.json") + + response = await generate_schemas_from_openapi( + request=_mock_request({"url": "https://api.example.com:8443/v1/calculate", "request_type": "POST"}), + _user=_USER, + ) + + assert response.status_code == 200 + call_args = 
mock_fetch.call_args[1] + assert call_args["base_url"] == "https://api.example.com:8443" + assert call_args["path"] == "/v1/calculate" + assert call_args["method"] == "POST" + + +# --------------------------------------------------------------------------- +# Input validation +# --------------------------------------------------------------------------- + + +class TestGenerateSchemasInputValidation: + """Tests for request-level validation in the endpoint.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "body", + [ + pytest.param({"request_type": "POST"}, id="missing-url"), + pytest.param({"url": "", "openapi_url": "http://x.com/spec.json", "request_type": "POST"}, id="empty-url"), + pytest.param({"openapi_url": "http://x.com/spec.json", "request_type": "GET"}, id="openapi-url-without-url"), + ], + ) + async def test_missing_or_empty_url_returns_400(self, body): + """url is required; its absence yields 400.""" + response = await generate_schemas_from_openapi(request=_mock_request(body), _user=_USER) + + assert response.status_code == 400 + content = orjson.loads(response.body) + assert content["success"] is False + assert "'url' is required" in content["message"] + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "body", + [ + pytest.param(b"[1, 2, 3]", id="json-array"), + pytest.param(b"42", id="json-scalar"), + pytest.param(b'"hello"', id="json-string"), + ], + ) + async def test_non_object_json_returns_400(self, body): + """JSON body that is not an object yields 400.""" + response = await generate_schemas_from_openapi(request=_mock_request(body), _user=_USER) + + assert response.status_code == 400 + content = orjson.loads(response.body) + assert content["success"] is False + assert "JSON object" in content["message"] + + @pytest.mark.asyncio + async def test_invalid_json_returns_400(self): + """Malformed JSON body yields 400.""" + response = await generate_schemas_from_openapi(request=_mock_request(b"invalid json {"), _user=_USER) + + assert 
response.status_code == 400 + content = orjson.loads(response.body) + assert "Invalid JSON" in content["message"] + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "body", + [ + pytest.param({"url": 123, "request_type": "POST"}, id="url-is-int"), + pytest.param({"url": "http://example.com/api", "request_type": ["POST"]}, id="request_type-is-list"), + pytest.param({"url": "http://example.com/api", "openapi_url": 42}, id="openapi_url-is-int"), + ], + ) + async def test_non_string_fields_return_400(self, body): + """Non-string values for url/request_type/openapi_url yield 400.""" + response = await generate_schemas_from_openapi(request=_mock_request(body), _user=_USER) + + assert response.status_code == 400 + content = orjson.loads(response.body) + assert content["success"] is False + assert "must be strings" in content["message"] + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "url", + [ + pytest.param("just-a-path", id="no-scheme-no-host"), + pytest.param("/calculate", id="path-only"), + pytest.param("ftp://example.com/api", id="unsupported-scheme"), + ], + ) + async def test_invalid_url_returns_400(self, url): + """URLs that fail SecurityValidator.validate_url yield 400.""" + response = await generate_schemas_from_openapi(request=_mock_request({"url": url, "request_type": "GET"}), _user=_USER) + + assert response.status_code == 400 + content = orjson.loads(response.body) + assert content["success"] is False + + +# --------------------------------------------------------------------------- +# Exception → HTTP status mapping +# --------------------------------------------------------------------------- + + +class TestGenerateSchemasErrorMapping: + """Each service-layer exception type maps to the correct HTTP status.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "exception,expected_status,expected_fragment", + [ + pytest.param(ValueError("SSRF blocked"), 400, "Security validation failed", id="ValueError-400"), + pytest.param(KeyError("Path '/x' 
not found"), 404, "Path '/x' not found", id="KeyError-404"), + pytest.param( + httpx.HTTPStatusError("error", request=MagicMock(), response=MagicMock(status_code=403)), + 502, + "OpenAPI spec server returned HTTP 403", + id="HTTPStatusError-502", + ), + pytest.param(httpx.ConnectError("refused"), 502, "Failed to fetch OpenAPI spec from the provided URL", id="ConnectError-502"), + pytest.param(httpx.TimeoutException("timeout"), 502, "Failed to fetch OpenAPI spec from the provided URL", id="Timeout-502"), + pytest.param(Exception("unexpected"), 500, "An unexpected error occurred", id="Exception-500"), + ], + ) + async def test_exception_to_status(self, exception, expected_status, expected_fragment): + """Service exceptions are converted to the correct HTTP status and message.""" + with patch("mcpgateway.admin.fetch_and_extract_schemas") as mock_fetch: + mock_fetch.side_effect = exception + + response = await generate_schemas_from_openapi( + request=_mock_request({"url": "http://example.com/api", "request_type": "POST"}), + _user=_USER, + ) + + assert response.status_code == expected_status + content = orjson.loads(response.body) + assert content["success"] is False + assert expected_fragment in content["message"] + + @pytest.mark.asyncio + async def test_request_body_failure_returns_400(self): + """If _read_request_json fails, the endpoint returns 400.""" + req = MagicMock(spec=Request) + req.body = AsyncMock(side_effect=Exception("boom")) + + response = await generate_schemas_from_openapi(request=req, _user=_USER) + + assert response.status_code == 400 + content = orjson.loads(response.body) + assert "Invalid JSON" in content["message"] + + +# --------------------------------------------------------------------------- +# Deny-path regression tests +# --------------------------------------------------------------------------- + + +class TestGenerateSchemasPermissionDenial: + """Verify the endpoint rejects callers without tools.create permission.""" + + 
@pytest.mark.asyncio + async def test_denies_without_tools_create_permission(self, monkeypatch): + """Users with only tools.read (not tools.create) are rejected with 403.""" + # Third-Party + from fastapi import HTTPException + + deny_service = MagicMock() + deny_service.check_permission = AsyncMock(return_value=False) + monkeypatch.setattr("mcpgateway.middleware.rbac.PermissionService", lambda db: deny_service) + monkeypatch.setattr("mcpgateway.admin.PermissionService", lambda db: deny_service) + + with pytest.raises(HTTPException) as exc_info: + await generate_schemas_from_openapi( + request=_mock_request({"url": "http://example.com/api", "request_type": "GET"}), + _user={"email": "viewer@example.com"}, + ) + + assert exc_info.value.status_code == 403 diff --git a/tests/unit/mcpgateway/test_rest_schema_population.py b/tests/unit/mcpgateway/test_rest_schema_population.py new file mode 100644 index 0000000000..13d4abc071 --- /dev/null +++ b/tests/unit/mcpgateway/test_rest_schema_population.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/mcpgateway/test_rest_schema_population.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Tests for REST tool validator behaviour: URL component extraction and +default input_schema population. + +Schema *fetching* from OpenAPI specs is tested in test_openapi_service.py +and test_admin_openapi.py. These tests cover the Pydantic model validators +in ToolCreate and ToolUpdate. 
+""" + +# Third-Party + +# First-Party +from mcpgateway.schemas import ToolCreate, ToolUpdate + +_DEFAULT_SCHEMA = {"type": "object", "properties": {}} + + +class TestToolCreateRESTDefaults: + """ToolCreate validator: URL extraction and default input_schema for REST tools.""" + + def test_default_input_schema_when_none(self): + """REST ToolCreate with no input_schema gets the typed default.""" + tool = ToolCreate(name="t", integration_type="REST", base_url="http://example.com", path_template="/api") + assert tool.input_schema == _DEFAULT_SCHEMA + + def test_default_input_schema_when_empty_dict(self): + """REST ToolCreate with input_schema={} gets the typed default.""" + tool = ToolCreate(name="t", integration_type="REST", base_url="http://example.com", path_template="/api", input_schema={}) + assert tool.input_schema == _DEFAULT_SCHEMA + + def test_provided_input_schema_preserved(self): + """REST ToolCreate with a real input_schema keeps it untouched.""" + schema = {"type": "object", "properties": {"a": {"type": "string"}}} + tool = ToolCreate(name="t", integration_type="REST", base_url="http://example.com", path_template="/api", input_schema=schema) + assert tool.input_schema == schema + + def test_url_extracts_base_url_and_path(self): + """Providing 'url' auto-populates base_url and path_template.""" + tool = ToolCreate(name="t", integration_type="REST", url="https://api.example.com:8443/v1/calculate") + assert tool.base_url == "https://api.example.com:8443" + assert tool.path_template == "/v1/calculate" + + def test_explicit_base_url_not_overwritten(self): + """Explicit base_url takes precedence over url-derived value.""" + tool = ToolCreate(name="t", integration_type="REST", url="http://derived.com/path", base_url="http://explicit.com") + assert tool.base_url == "http://explicit.com" + + def test_non_rest_tool_skips_extraction(self): + """Non-REST integration types don't get URL extraction or default schema. 
+ + MCP/A2A types are rejected outright by other validators, so we verify + by testing that a REST tool *does* get extraction (positive case) and + that the helper is gated on integration_type. + """ + # First-Party + from mcpgateway.schemas import _extract_rest_url_components + + values = {"integration_type": "MCP", "url": "http://example.com/path"} + # The helper is never called for non-REST, but verify it's a no-op + # when called directly without URL (simulating the guard in the validator). + non_rest_values: dict = {} + _extract_rest_url_components(non_rest_values) + assert "base_url" not in non_rest_values + + def test_no_url_no_base_url(self): + """REST tool with no url still gets default schema.""" + tool = ToolCreate(name="t", integration_type="REST", path_template="/test") + assert tool.input_schema == _DEFAULT_SCHEMA + + def test_output_schema_not_set_by_default(self): + """Validator does not touch output_schema.""" + tool = ToolCreate(name="t", integration_type="REST", base_url="http://example.com", path_template="/api") + assert tool.output_schema is None + + +class TestToolUpdateRESTDefaults: + """ToolUpdate validator: URL extraction and empty-schema normalisation.""" + + def test_url_extracts_components(self): + """Providing 'url' on update extracts base_url and path_template.""" + update = ToolUpdate(integration_type="REST", url="http://example.com/test") + assert update.base_url == "http://example.com" + assert update.path_template == "/test" + + def test_existing_schemas_preserved(self): + """Existing schemas are not overwritten by the validator.""" + schema = {"type": "object", "properties": {"existing": {"type": "string"}}} + update = ToolUpdate(integration_type="REST", input_schema=schema) + assert update.input_schema == schema + + def test_empty_dict_normalised(self): + """An explicitly empty {} input_schema is normalised to the typed default.""" + update = ToolUpdate(integration_type="REST", input_schema={}) + assert update.input_schema == 
_DEFAULT_SCHEMA + + def test_none_schema_left_alone(self): + """None input_schema is not changed (update may intentionally omit it).""" + update = ToolUpdate(integration_type="REST") + assert update.input_schema is None + + def test_non_empty_schema_not_normalised(self): + """A schema with actual properties is not touched.""" + schema = {"properties": {"x": {"type": "number"}}} + update = ToolUpdate(integration_type="REST", input_schema=schema) + assert update.input_schema == schema + + def test_non_rest_skips_extraction(self): + """Non-REST updates skip URL extraction entirely. + + MCP/A2A types are rejected by other validators, so we verify via + the helper function directly. + """ + # First-Party + from mcpgateway.schemas import _extract_rest_url_components + + values: dict = {} + _extract_rest_url_components(values) + assert "base_url" not in values + + def test_output_schema_not_set(self): + """Validator does not touch output_schema on update.""" + update = ToolUpdate(integration_type="REST", url="http://example.com/api") + assert update.output_schema is None