diff --git a/mcpgateway/bootstrap_db.py b/mcpgateway/bootstrap_db.py index 8b68dbdc0..df2f7a097 100644 --- a/mcpgateway/bootstrap_db.py +++ b/mcpgateway/bootstrap_db.py @@ -78,7 +78,7 @@ async def bootstrap_admin_user() -> None: logger.info(f"Creating platform admin user: {settings.platform_admin_email}") admin_user = await auth_service.create_user( email=settings.platform_admin_email, - password=settings.platform_admin_password, + password=settings.platform_admin_password.get_secret_value(), full_name=settings.platform_admin_full_name, is_admin=True, ) diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 84feacd52..87ba1c738 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -50,20 +50,16 @@ # Standard from functools import lru_cache from importlib.resources import files -import json +import json # TODO: consider typjson for type safety loading from configuration data. import logging import os from pathlib import Path import re import sys -from typing import Annotated, Any, ClassVar, Dict, List, Literal, Optional, Set, Union +from typing import Annotated, Any, ClassVar, Dict, List, Literal, NotRequired, Optional, Self, Set, TypedDict # Third-Party -from fastapi import HTTPException -import jq -from jsonpath_ng.ext import parse -from jsonpath_ng.jsonpath import JSONPath -from pydantic import Field, field_validator, HttpUrl, model_validator, PositiveInt, SecretStr +from pydantic import Field, field_validator, HttpUrl, model_validator, PositiveInt, SecretStr, ValidationInfo from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict # Only configure basic logging if no handlers exist yet @@ -156,10 +152,11 @@ class Settings(BaseSettings): port: PositiveInt = Field(default=4444, ge=1, le=65535) docs_allow_basic_auth: bool = False # Allow basic auth for docs database_url: str = "sqlite:///./mcp.db" - templates_dir: Path = Path("mcpgateway/templates") + # Absolute paths resolved at import-time (still override-able via env vars) - templates_dir: Path = files("mcpgateway") / "templates" - static_dir: Path = files("mcpgateway") / "static" + templates_dir: Path = Field(default_factory=lambda: Path(str(files("mcpgateway") / "templates"))) + static_dir: Path = Field(default_factory=lambda: Path(str(files("mcpgateway") / "static"))) + app_root_path: str = "" # Protocol @@ -167,9 +164,9 @@ class Settings(BaseSettings): # Authentication basic_auth_user: str = "admin" - basic_auth_password: str = "changeme" + basic_auth_password: SecretStr = Field(default=SecretStr("changeme")) jwt_algorithm: str = "HS256" - jwt_secret_key: SecretStr = Field(default="my-test-key") + jwt_secret_key: SecretStr = Field(default=SecretStr("my-test-key")) jwt_public_key_path: str = "" jwt_private_key_path: str = "" jwt_audience: str = "mcpgateway-api" @@ -184,27 +181,27 @@ class Settings(BaseSettings): sso_enabled: bool = Field(default=False, description="Enable Single Sign-On authentication") sso_github_enabled: bool = Field(default=False, description="Enable GitHub OAuth authentication") sso_github_client_id: Optional[str] = Field(default=None, description="GitHub OAuth client ID") - sso_github_client_secret: Optional[str] = Field(default=None, description="GitHub OAuth client secret") + sso_github_client_secret: Optional[SecretStr] = Field(default=None, description="GitHub OAuth client secret") sso_google_enabled: bool = Field(default=False, description="Enable Google OAuth authentication") sso_google_client_id: Optional[str] = Field(default=None, description="Google OAuth client ID") - sso_google_client_secret: Optional[str] = Field(default=None, description="Google OAuth client secret") + sso_google_client_secret: Optional[SecretStr] = Field(default=None, description="Google OAuth client secret") sso_ibm_verify_enabled: bool = Field(default=False, description="Enable IBM Security Verify OIDC authentication") sso_ibm_verify_client_id: Optional[str] = Field(default=None, description="IBM Security Verify client ID") - sso_ibm_verify_client_secret: Optional[str] = Field(default=None, description="IBM Security Verify client secret") + sso_ibm_verify_client_secret: Optional[SecretStr] = Field(default=None, description="IBM Security Verify client secret") sso_ibm_verify_issuer: Optional[str] = Field(default=None, description="IBM Security Verify OIDC issuer URL") sso_okta_enabled: bool = Field(default=False, description="Enable Okta OIDC authentication") sso_okta_client_id: Optional[str] = Field(default=None, description="Okta client ID") - sso_okta_client_secret: Optional[str] = Field(default=None, description="Okta client secret") + sso_okta_client_secret: Optional[SecretStr] = Field(default=None, description="Okta client secret") sso_okta_issuer: Optional[str] = Field(default=None, description="Okta issuer URL") sso_keycloak_enabled: bool = Field(default=False, description="Enable Keycloak OIDC authentication") sso_keycloak_base_url: Optional[str] = Field(default=None, description="Keycloak base URL (e.g., https://keycloak.example.com)") sso_keycloak_realm: str = Field(default="master", description="Keycloak realm name") sso_keycloak_client_id: Optional[str] = Field(default=None, description="Keycloak client ID") - sso_keycloak_client_secret: Optional[str] = Field(default=None, description="Keycloak client secret") + sso_keycloak_client_secret: Optional[SecretStr] = Field(default=None, description="Keycloak client secret") sso_keycloak_map_realm_roles: bool = Field(default=True, description="Map Keycloak realm roles to gateway teams") sso_keycloak_map_client_roles: bool = Field(default=False, description="Map Keycloak client roles to gateway RBAC") sso_keycloak_username_claim: str = Field(default="preferred_username", description="JWT claim for username") @@ -213,14 +210,14 @@ class Settings(BaseSettings): sso_entra_enabled: bool = Field(default=False, description="Enable Microsoft Entra ID OIDC authentication") sso_entra_client_id: Optional[str] = Field(default=None, description="Microsoft Entra ID client ID") - sso_entra_client_secret: Optional[str] = Field(default=None, description="Microsoft Entra ID client secret") + sso_entra_client_secret: Optional[SecretStr] = Field(default=None, description="Microsoft Entra ID client secret") sso_entra_tenant_id: Optional[str] = Field(default=None, description="Microsoft Entra ID tenant ID") sso_generic_enabled: bool = Field(default=False, description="Enable generic OIDC provider (Keycloak, Auth0, etc.)") sso_generic_provider_id: Optional[str] = Field(default=None, description="Provider ID (e.g., 'keycloak', 'auth0', 'authentik')") sso_generic_display_name: Optional[str] = Field(default=None, description="Display name shown on login page") sso_generic_client_id: Optional[str] = Field(default=None, description="Generic OIDC client ID") - sso_generic_client_secret: Optional[str] = Field(default=None, description="Generic OIDC client secret") + sso_generic_client_secret: Optional[SecretStr] = Field(default=None, description="Generic OIDC client secret") sso_generic_authorization_url: Optional[str] = Field(default=None, description="Authorization endpoint URL") sso_generic_token_url: Optional[str] = Field(default=None, description="Token endpoint URL") sso_generic_userinfo_url: Optional[str] = Field(default=None, description="Userinfo endpoint URL") @@ -290,7 +287,7 @@ class Settings(BaseSettings): # Email-Based Authentication email_auth_enabled: bool = Field(default=True, description="Enable email-based authentication") platform_admin_email: str = Field(default="admin@example.com", description="Platform administrator email address") - platform_admin_password: str = Field(default="changeme", description="Platform administrator password") + platform_admin_password: SecretStr = Field(default=SecretStr("changeme"), description="Platform administrator password") platform_admin_full_name: str = Field(default="Platform Administrator", description="Platform administrator full name") # Argon2id Password Hashing Configuration @@ -356,7 +353,7 @@ class Settings(BaseSettings): environment: Literal["development", "staging", "production"] = Field(default="development") # Domain configuration - app_domain: HttpUrl = Field(default="http://localhost:4444") + app_domain: HttpUrl = Field(default=HttpUrl("http://localhost:4444")) # Security settings secure_cookies: bool = Field(default=True) @@ -399,7 +396,7 @@ class Settings(BaseSettings): @field_validator("jwt_secret_key", "auth_encryption_secret") @classmethod - def validate_secrets(cls, v, info): + def validate_secrets(cls, v: Any, info: ValidationInfo) -> SecretStr: """ Validate that secret keys meet basic security requirements. @@ -418,7 +415,7 @@ def validate_secrets(cls, v, info): - The original value is returned as a `SecretStr` for safe handling. Args: - v (str | SecretStr): The secret value to validate. + v: The secret value to validate. info: Pydantic validation info object, used to get the field name. Returns: @@ -431,7 +428,7 @@ def validate_secrets(cls, v, info): if isinstance(v, SecretStr): value = v.get_secret_value() else: - value = v + value = str(v) # Check for default/weak secrets weak_secrets = ["my-test-key", "my-test-salt", "changeme", "secret", "password"] @@ -451,39 +448,46 @@ def validate_secrets(cls, v, info): @field_validator("basic_auth_password") @classmethod - def validate_admin_password(cls, v: str) -> str: + def validate_admin_password(cls, v: str | SecretStr) -> SecretStr: """Validate admin password meets security requirements. Args: v: The admin password value to validate. Returns: - str: The validated admin password value. + SecretStr: The validated admin password value, wrapped as SecretStr. """ - if v == "changeme": # nosec B105 - checking for default value + # Extract actual string value safely + if isinstance(v, SecretStr): + value = v.get_secret_value() + else: + value = v + + if value == "changeme": # nosec B105 - checking for default value logger.warning("🔓 SECURITY WARNING: Default admin password detected! Please change the BASIC_AUTH_PASSWORD immediately.") # Note: We can't access password_min_length here as it's not set yet during validation # Using default value of 8 to match the field default min_length = 8 # This matches the default in password_min_length field - if len(v) < min_length: - logger.warning(f"⚠️ SECURITY WARNING: Admin password should be at least {min_length} characters long. Current length: {len(v)}") + if len(value) < min_length: + logger.warning(f"⚠️ SECURITY WARNING: Admin password should be at least {min_length} characters long. Current length: {len(value)}") # Check password complexity - has_upper = any(c.isupper() for c in v) - has_lower = any(c.islower() for c in v) - has_digit = any(c.isdigit() for c in v) - has_special = bool(re.search(r'[!@#$%^&*(),.?":{}|<>]', v)) + has_upper = any(c.isupper() for c in value) + has_lower = any(c.islower() for c in value) + has_digit = any(c.isdigit() for c in value) + has_special = bool(re.search(r'[!@#$%^&*(),.?":{}|<>]', value)) complexity_score = sum([has_upper, has_lower, has_digit, has_special]) if complexity_score < 3: logger.warning("🔐 SECURITY WARNING: Admin password has low complexity. Should contain at least 3 of: uppercase, lowercase, digits, special characters") - return v + # Always return SecretStr to keep it secret-safe + return v if isinstance(v, SecretStr) else SecretStr(value) @field_validator("allowed_origins") @classmethod - def validate_cors_origins(cls, v: set) -> set: + def validate_cors_origins(cls, v: Any) -> set[str] | None: """Validate CORS allowed origins. Args: @@ -491,9 +495,14 @@ def validate_cors_origins(cls, v: set) -> set: Returns: set: The validated set of allowed origins. + + Raises: + ValueError: If allowed_origins is not a set or list of strings. """ - if not v: + if v is None: return v + if not isinstance(v, (set, list)): + raise ValueError("allowed_origins must be a set or list of strings") dangerous_origins = ["*", "null", ""] for origin in v: @@ -504,7 +513,7 @@ def validate_cors_origins(cls, v: set) -> set: if not origin.startswith(("http://", "https://")) and origin not in dangerous_origins: logger.warning(f"⚠️ SECURITY WARNING: Invalid origin format '{origin}'. Origins should start with http:// or https://") - return v + return set([str(origin) for origin in v]) @field_validator("database_url") @classmethod @@ -529,31 +538,27 @@ def validate_database_url(cls, v: str) -> str: return v @model_validator(mode="after") - @classmethod - def validate_security_combinations(cls, values): - """Validate security setting combinations. - - Args: - values: The Settings instance with all field values. + def validate_security_combinations(self) -> Self: + """Validate security setting combinations. Only logs warnings; no changes are made. Returns: - Settings: The validated Settings instance. + Itself. """ # Check for dangerous combinations - only log warnings, don't raise errors - if not values.auth_required and values.mcpgateway_ui_enabled: + if not self.auth_required and self.mcpgateway_ui_enabled: logger.warning("🔓 SECURITY WARNING: Admin UI is enabled without authentication. Consider setting AUTH_REQUIRED=true for production.") - if values.skip_ssl_verify and not values.dev_mode: + if self.skip_ssl_verify and not self.dev_mode: logger.warning("🔓 SECURITY WARNING: SSL verification is disabled in non-dev mode. This is a security risk! Set SKIP_SSL_VERIFY=false for production.") - if values.debug and not values.dev_mode: + if self.debug and not self.dev_mode: logger.warning("🐛 SECURITY WARNING: Debug mode is enabled in non-dev mode. This may leak sensitive information! Set DEBUG=false for production.") # Warn about federation without auth - if values.federation_enabled and not values.auth_required: + if self.federation_enabled and not self.auth_required: logger.warning("🌐 SECURITY WARNING: Federation is enabled without authentication. This may expose your gateway to unauthorized access.") - return values + return self def get_security_warnings(self) -> List[str]: """Get list of security warnings for current configuration. @@ -599,11 +604,23 @@ def get_security_warnings(self) -> List[str]: return warnings - def get_security_status(self) -> dict: + class SecurityStatus(TypedDict): + """TypedDict for comprehensive security status.""" + + secure_secrets: bool + auth_enabled: bool + ssl_verification: bool + debug_disabled: bool + cors_restricted: bool + ui_protected: bool + warnings: List[str] + security_score: int + + def get_security_status(self) -> SecurityStatus: """Get comprehensive security status. Returns: - dict: Dictionary containing security status information including score and warnings. + SecurityStatus: Dictionary containing security status information including score and warnings. """ # Compute a security score: 100 minus 10 for each warning @@ -628,7 +645,7 @@ def get_security_status(self) -> dict: @field_validator("allowed_origins", mode="before") @classmethod - def _parse_allowed_origins(cls, v): + def _parse_allowed_origins(cls, v: Any) -> Set[str]: """Parse allowed origins from environment variable or config value. Handles multiple input formats for the allowed_origins field: @@ -728,7 +745,7 @@ def validate_log_level(cls, v: str) -> str: @field_validator("federation_peers", mode="before") @classmethod - def _parse_federation_peers(cls, v): + def _parse_federation_peers(cls, v: Any) -> List[str]: """Parse federation peer URLs from environment variable or config value. Handles multiple input formats for the federation_peers field: @@ -770,7 +787,7 @@ def _parse_federation_peers(cls, v): peers = json.loads(v) except json.JSONDecodeError: peers = [s.strip() for s in v.split(",") if s.strip()] - return peers + return peers # type: ignore[no-any-return] # Convert other iterables to list return list(v) @@ -784,7 +801,7 @@ def _parse_federation_peers(cls, v): @field_validator("sso_issuers", mode="before") @classmethod - def parse_issuers(cls, v): + def parse_issuers(cls, v: Any) -> set[str]: """ Parse and validate the SSO issuers configuration value. @@ -793,7 +810,7 @@ def parse_issuers(cls, v): provide issuers as JSON while still supporting direct list assignment in code. Args: - v (str | list): The input value for SSO issuers, either a JSON array string + v: The input value for SSO issuers, either a JSON array string or a Python list. Returns: @@ -806,10 +823,10 @@ def parse_issuers(cls, v): # Accept either a JSON array string or actual list if isinstance(v, str): try: - return json.loads(v) + return json.loads(v) # type: ignore[no-any-return] except json.JSONDecodeError: raise ValueError(f"SSO_ISSUERS must be a JSON array of URLs, got: {v!r}") - return v + return v # type: ignore[no-any-return] # Resources resource_cache_size: int = 1000 @@ -977,7 +994,7 @@ def custom_well_known_files(self) -> Dict[str, str]: @field_validator("well_known_security_txt_enabled", mode="after") @classmethod - def _auto_enable_security_txt(cls, v, info): + def _auto_enable_security_txt(cls, v: Any, info: ValidationInfo) -> bool: """Auto-enable security.txt if content is provided. Args: @@ -989,7 +1006,7 @@ def _auto_enable_security_txt(cls, v, info): """ if info.data and "well_known_security_txt" in info.data: return bool(info.data["well_known_security_txt"].strip()) - return v + return bool(v) # ------------------------------- # Flexible list parsing for envs @@ -1002,7 +1019,7 @@ def _auto_enable_security_txt(cls, v, info): mode="before", ) @classmethod - def _parse_list_from_env(cls, v): # type: ignore[override] + def _parse_list_from_env(cls, v: None | str | list[str]) -> list[str]: """Parse list fields from environment values. Accepts either JSON arrays (e.g. '["a","b"]') or comma-separated @@ -1013,6 +1030,9 @@ def _parse_list_from_env(cls, v): # type: ignore[override] Returns: list: Parsed list of values. + + Raises: + ValueError: If the value type is invalid for list field parsing. """ if v is None: return [] @@ -1030,7 +1050,7 @@ def _parse_list_from_env(cls, v): # type: ignore[override] logger.warning("Invalid JSON list in env for list field; falling back to CSV parsing") # CSV fallback return [item.strip() for item in s.split(",") if item.strip()] - return v + raise ValueError("Invalid type for list field") @property def api_key(self) -> str: @@ -1049,7 +1069,7 @@ def api_key(self) -> str: >>> settings.api_key 'user123:pass456' """ - return f"{self.basic_auth_user}:{self.basic_auth_password}" + return f"{self.basic_auth_user}:{self.basic_auth_password.get_secret_value()}" @property def supports_http(self) -> bool: @@ -1111,13 +1131,22 @@ def supports_sse(self) -> bool: """ return self.transport_type in ["sse", "all"] + class DatabaseSettings(TypedDict): + """TypedDict for SQLAlchemy database settings.""" + + pool_size: int + max_overflow: int + pool_timeout: int + pool_recycle: int + connect_args: dict[str, Any] # TODO: more specific type if needed + @property - def database_settings(self) -> dict: + def database_settings(self) -> DatabaseSettings: """ Get SQLAlchemy database settings. Returns: - dict: Dictionary containing SQLAlchemy database configuration options. + DatabaseSettings: Dictionary containing SQLAlchemy database configuration options. Examples: >>> from mcpgateway.config import Settings @@ -1133,12 +1162,20 @@ def database_settings(self) -> dict: "connect_args": {"check_same_thread": False} if self.database_url.startswith("sqlite") else {}, } + class CORSSettings(TypedDict): + """TypedDict for CORS settings.""" + + allow_origins: NotRequired[List[str]] + allow_credentials: NotRequired[bool] + allow_methods: NotRequired[List[str]] + allow_headers: NotRequired[List[str]] + @property - def cors_settings(self) -> dict: + def cors_settings(self) -> CORSSettings: """Get CORS settings. Returns: - dict: Dictionary containing CORS configuration options. + CORSSettings: Dictionary containing CORS configuration options. Examples: >>> s = Settings(cors_enabled=True, allowed_origins={'http://localhost'}) @@ -1291,7 +1328,7 @@ def validate_database(self) -> None: # Base URL for pagination links (defaults to request URL) pagination_base_url: Optional[str] = Field(default=None, description="Base URL for pagination links") - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: """Initialize Settings with environment variable parsing. Args: @@ -1359,7 +1396,7 @@ def __init__(self, **kwargs): # Masking value for all sensitive data masked_auth_value: str = "*****" - def log_summary(self): + def log_summary(self) -> None: """ Log a summary of the application settings. @@ -1374,125 +1411,6 @@ def log_summary(self): logger.info(f"Application settings summary: {summary}") -def extract_using_jq(data, jq_filter=""): - """ - Extracts data from a given input (string, dict, or list) using a jq filter string. - - Args: - data (str, dict, list): The input JSON data. Can be a string, dict, or list. - jq_filter (str): The jq filter string to extract the desired data. - - Returns: - The result of applying the jq filter to the input data. - - Examples: - >>> extract_using_jq('{"a": 1, "b": 2}', '.a') - [1] - >>> extract_using_jq({'a': 1, 'b': 2}, '.b') - [2] - >>> extract_using_jq('[{"a": 1}, {"a": 2}]', '.[].a') - [1, 2] - >>> extract_using_jq('not a json', '.a') - ['Invalid JSON string provided.'] - >>> extract_using_jq({'a': 1}, '') - {'a': 1} - """ - if jq_filter == "": - return data - if isinstance(data, str): - # If the input is a string, parse it as JSON - try: - data = json.loads(data) - except json.JSONDecodeError: - return ["Invalid JSON string provided."] - - elif not isinstance(data, (dict, list)): - # If the input is not a string, dict, or list, raise an error - return ["Input data must be a JSON string, dictionary, or list."] - - # Apply the jq filter to the data - try: - # Pylint can't introspect C-extension modules, so it doesn't know that jq really does export an all() function. - # pylint: disable=c-extension-no-member - result = jq.all(jq_filter, data) # Use `jq.all` to get all matches (returns a list) - if result == [None]: - result = "Error applying jsonpath filter" - except Exception as e: - message = "Error applying jsonpath filter: " + str(e) - return message - - return result - - -def jsonpath_modifier(data: Any, jsonpath: str = "$[*]", mappings: Optional[Dict[str, str]] = None) -> Union[List, Dict]: - """ - Applies the given JSONPath expression and mappings to the data. - Only return data that is required by the user dynamically. - - Args: - data: The JSON data to query. - jsonpath: The JSONPath expression to apply. - mappings: Optional dictionary of mappings where keys are new field names - and values are JSONPath expressions. - - Returns: - Union[List, Dict]: A list (or mapped list) or a Dict of extracted data. - - Raises: - HTTPException: If there's an error parsing or executing the JSONPath expressions. - - Examples: - >>> jsonpath_modifier({'a': 1, 'b': 2}, '$.a') - [1] - >>> jsonpath_modifier([{'a': 1}, {'a': 2}], '$[*].a') - [1, 2] - >>> jsonpath_modifier({'a': {'b': 2}}, '$.a.b') - [2] - >>> jsonpath_modifier({'a': 1}, '$.b') - [] - """ - if not jsonpath: - jsonpath = "$[*]" - - try: - main_expr: JSONPath = parse(jsonpath) - except Exception as e: - raise HTTPException(status_code=400, detail=f"Invalid main JSONPath expression: {e}") - - try: - main_matches = main_expr.find(data) - except Exception as e: - raise HTTPException(status_code=400, detail=f"Error executing main JSONPath: {e}") - - results = [match.value for match in main_matches] - - if mappings: - mapped_results = [] - for item in results: - mapped_item = {} - for new_key, mapping_expr_str in mappings.items(): - try: - mapping_expr = parse(mapping_expr_str) - except Exception as e: - raise HTTPException(status_code=400, detail=f"Invalid mapping JSONPath for key '{new_key}': {e}") - try: - mapping_matches = mapping_expr.find(item) - except Exception as e: - raise HTTPException(status_code=400, detail=f"Error executing mapping JSONPath for key '{new_key}': {e}") - if not mapping_matches: - mapped_item[new_key] = None - elif len(mapping_matches) == 1: - mapped_item[new_key] = mapping_matches[0].value - else: - mapped_item[new_key] = [m.value for m in mapping_matches] - mapped_results.append(mapped_item) - results = mapped_results - - if len(results) == 1 and isinstance(results[0], dict): - return results[0] - return results - - @lru_cache() def get_settings() -> Settings: """Get cached settings instance. @@ -1521,7 +1439,7 @@ def get_settings() -> Settings: return cfg -def generate_settings_schema() -> dict: +def generate_settings_schema() -> dict[str, Any]: """ Return the JSON Schema describing the Settings model. @@ -1533,10 +1451,8 @@ def generate_settings_schema() -> dict: return Settings.model_json_schema(mode="validation") -# Create settings instance settings = get_settings() - if __name__ == "__main__": if "--schema" in sys.argv: schema = generate_settings_schema() diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 58b215237..7422dcb97 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -45,6 +45,8 @@ from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates +from jsonpath_ng.ext import parse +from jsonpath_ng.jsonpath import JSONPath from pydantic import ValidationError from sqlalchemy import text from sqlalchemy.exc import IntegrityError @@ -61,7 +63,7 @@ from mcpgateway.auth import get_current_user from mcpgateway.bootstrap_db import main as bootstrap_db from mcpgateway.cache import ResourceCache, SessionRegistry -from mcpgateway.config import jsonpath_modifier, settings +from mcpgateway.config import settings from mcpgateway.db import refresh_slugs_on_startup, SessionLocal from mcpgateway.db import Tool as DbTool from mcpgateway.handlers.sampling import SamplingHandler @@ -102,9 +104,8 @@ from mcpgateway.services.completion_service import CompletionService from mcpgateway.services.export_service import ExportError, ExportService from mcpgateway.services.gateway_service import GatewayConnectionError, GatewayError, GatewayNameConflictError, GatewayNotFoundError, GatewayService, GatewayUrlConflictError -from mcpgateway.services.import_service import ConflictStrategy, ImportConflictError +from mcpgateway.services.import_service import ConflictStrategy, ImportConflictError, ImportService, ImportValidationError from mcpgateway.services.import_service import ImportError as ImportServiceError -from mcpgateway.services.import_service import ImportService, ImportValidationError from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.prompt_service import PromptError, PromptNameConflictError, PromptNotFoundError, PromptService from mcpgateway.services.resource_service import ResourceError, ResourceNotFoundError, ResourceService, ResourceURIConflictError @@ -264,6 +265,75 @@ def get_user_email(user): resource_cache = ResourceCache(max_size=settings.resource_cache_size, ttl=settings.resource_cache_ttl) +def jsonpath_modifier(data: Any, jsonpath: str = "$[*]", mappings: Optional[Dict[str, str]] = None) -> Union[List, Dict]: + """ + Applies the given JSONPath expression and mappings to the data. + Only return data that is required by the user dynamically. + + Args: + data: The JSON data to query. + jsonpath: The JSONPath expression to apply. + mappings: Optional dictionary of mappings where keys are new field names + and values are JSONPath expressions. + + Returns: + Union[List, Dict]: A list (or mapped list) or a Dict of extracted data. + + Raises: + HTTPException: If there's an error parsing or executing the JSONPath expressions. + + Examples: + >>> jsonpath_modifier({'a': 1, 'b': 2}, '$.a') + [1] + >>> jsonpath_modifier([{'a': 1}, {'a': 2}], '$[*].a') + [1, 2] + >>> jsonpath_modifier({'a': {'b': 2}}, '$.a.b') + [2] + >>> jsonpath_modifier({'a': 1}, '$.b') + [] + """ + if not jsonpath: + jsonpath = "$[*]" + + try: + main_expr: JSONPath = parse(jsonpath) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Invalid main JSONPath expression: {e}") + + try: + main_matches = main_expr.find(data) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Error executing main JSONPath: {e}") + + results = [match.value for match in main_matches] + + if mappings: + mapped_results = [] + for item in results: + mapped_item = {} + for new_key, mapping_expr_str in mappings.items(): + try: + mapping_expr = parse(mapping_expr_str) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Invalid mapping JSONPath for key '{new_key}': {e}") + try: + mapping_matches = mapping_expr.find(item) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Error executing mapping JSONPath for key '{new_key}': {e}") + if not mapping_matches: + mapped_item[new_key] = None + elif len(mapping_matches) == 1: + mapped_item[new_key] = mapping_matches[0].value + else: + mapped_item[new_key] = [m.value for m in mapping_matches] + mapped_results.append(mapped_item) + results = mapped_results + + if len(results) == 1 and isinstance(results[0], dict): + return results[0] + return results + + #################### # Startup/Shutdown # #################### @@ -432,7 +502,7 @@ async def validate_security_configuration(): if settings.jwt_secret_key == "my-test-key" and not settings.dev_mode: # nosec B105 - checking for default value critical_issues.append("Using default JWT secret in non-dev mode. Set JWT_SECRET_KEY environment variable!") - if settings.basic_auth_password == "changeme" and settings.mcpgateway_ui_enabled: # nosec B105 - checking for default value + if settings.basic_auth_password.get_secret_value() == "changeme" and settings.mcpgateway_ui_enabled: # nosec B105 - checking for default value critical_issues.append("Admin UI enabled with default password. Set BASIC_AUTH_PASSWORD environment variable!") if not settings.auth_required and settings.federation_enabled and not settings.dev_mode: @@ -469,7 +539,7 @@ async def validate_security_configuration(): logger.info(" • Generate a strong JWT secret:") logger.info(" python3 -c 'import secrets; print(secrets.token_urlsafe(32))'") - if settings.basic_auth_password == "changeme": # nosec B105 - checking for default value + if settings.basic_auth_password.get_secret_value() == "changeme": # nosec B105 - checking for default value logger.info(" • Set a strong admin password in BASIC_AUTH_PASSWORD") if not settings.auth_required: @@ -1011,9 +1081,10 @@ def require_api_key(api_key: str) -> None: Examples: >>> from mcpgateway.config import settings + >>> from pydantic import SecretStr >>> settings.auth_required = True >>> settings.basic_auth_user = "admin" - >>> settings.basic_auth_password = "secret" + >>> settings.basic_auth_password = SecretStr("secret") >>> >>> # Valid API key >>> require_api_key("admin:secret") # Should not raise @@ -1026,7 +1097,7 @@ def require_api_key(api_key: str) -> None: 401 """ if settings.auth_required: - expected = f"{settings.basic_auth_user}:{settings.basic_auth_password}" + expected = f"{settings.basic_auth_user}:{settings.basic_auth_password.get_secret_value()}" if api_key != expected: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") @@ -2623,8 +2694,10 @@ async def read_resource(resource_id: str, request: Request, db: Session = Depend # Ensure a plain JSON-serializable structure try: # First-Party - from mcpgateway.models import ResourceContent # pylint: disable=import-outside-toplevel - from mcpgateway.models import TextContent # pylint: disable=import-outside-toplevel + from mcpgateway.models import ( + ResourceContent, # pylint: disable=import-outside-toplevel + TextContent, # pylint: disable=import-outside-toplevel + ) # If already a ResourceContent, serialize directly if isinstance(content, ResourceContent): diff --git a/mcpgateway/scripts/validate_env.py b/mcpgateway/scripts/validate_env.py index d56efaeb4..5a37d9c22 100644 --- a/mcpgateway/scripts/validate_env.py +++ b/mcpgateway/scripts/validate_env.py @@ -53,7 +53,8 @@ def get_security_warnings(settings: Settings) -> list[str]: warnings.append(f"PORT: Out of allowed range (1-65535). Got: {settings.port}") # --- PLATFORM_ADMIN_PASSWORD --- - pw = settings.platform_admin_password + pw = settings.platform_admin_password.get_secret_value() if isinstance(settings.platform_admin_password, SecretStr) else settings.platform_admin_password + if not pw or pw.lower() in ("changeme", "admin", "password"): warnings.append("Default admin password detected! Please change PLATFORM_ADMIN_PASSWORD immediately.") min_length = settings.password_min_length @@ -64,7 +65,7 @@ def get_security_warnings(settings: Settings) -> list[str]: warnings.append("Admin password has low complexity. Should contain at least 3 of: uppercase, lowercase, digits, special characters") # --- BASIC_AUTH_PASSWORD --- - basic_pw = settings.basic_auth_password + basic_pw = settings.basic_auth_password.get_secret_value() if isinstance(settings.basic_auth_password, SecretStr) else settings.basic_auth_password if not basic_pw or basic_pw.lower() in ("changeme", "password"): warnings.append("Default BASIC_AUTH_PASSWORD detected! Please change it immediately.") min_length = settings.password_min_length diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 998626760..179444be5 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -28,6 +28,7 @@ # Third-Party import httpx +import jq from mcp import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client @@ -62,14 +63,61 @@ from mcpgateway.utils.services_auth import decode_auth from mcpgateway.utils.sqlalchemy_modifier import json_contains_expr -# Local -from ..config import extract_using_jq - # Initialize logging service first logging_service = LoggingService() logger = logging_service.get_logger(__name__) +def extract_using_jq(data, jq_filter=""): + """ + Extracts data from a given input (string, dict, or list) using a jq filter string. + + Args: + data (str, dict, list): The input JSON data. Can be a string, dict, or list. + jq_filter (str): The jq filter string to extract the desired data. + + Returns: + The result of applying the jq filter to the input data. + + Examples: + >>> extract_using_jq('{"a": 1, "b": 2}', '.a') + [1] + >>> extract_using_jq({'a': 1, 'b': 2}, '.b') + [2] + >>> extract_using_jq('[{"a": 1}, {"a": 2}]', '.[].a') + [1, 2] + >>> extract_using_jq('not a json', '.a') + ['Invalid JSON string provided.'] + >>> extract_using_jq({'a': 1}, '') + {'a': 1} + """ + if jq_filter == "": + return data + if isinstance(data, str): + # If the input is a string, parse it as JSON + try: + data = json.loads(data) + except json.JSONDecodeError: + return ["Invalid JSON string provided."] + + elif not isinstance(data, (dict, list)): + # If the input is not a string, dict, or list, raise an error + return ["Input data must be a JSON string, dictionary, or list."] + + # Apply the jq filter to the data + try: + # Pylint can't introspect C-extension modules, so it doesn't know that jq really does export an all() function. + # pylint: disable=c-extension-no-member + result = jq.all(jq_filter, data) # Use `jq.all` to get all matches (returns a list) + if result == [None]: + result = "Error applying jsonpath filter" + except Exception as e: + message = "Error applying jsonpath filter: " + str(e) + return message + + return result + + class ToolError(Exception): """Base class for tool-related errors. diff --git a/mcpgateway/utils/sso_bootstrap.py b/mcpgateway/utils/sso_bootstrap.py index 2e91da2f4..ab1fcb537 100644 --- a/mcpgateway/utils/sso_bootstrap.py +++ b/mcpgateway/utils/sso_bootstrap.py @@ -117,7 +117,7 @@ def get_predefined_sso_providers() -> List[Dict]: "display_name": "GitHub", "provider_type": "oauth2", "client_id": settings.sso_github_client_id, - "client_secret": settings.sso_github_client_secret or "", + "client_secret": settings.sso_github_client_secret.get_secret_value() if settings.sso_github_client_secret else "", "authorization_url": "https://github.com/login/oauth/authorize", "token_url": "https://github.com/login/oauth/access_token", "userinfo_url": "https://api.github.com/user", @@ -137,7 +137,7 @@ def get_predefined_sso_providers() -> List[Dict]: "display_name": "Google", "provider_type": "oidc", "client_id": settings.sso_google_client_id, - "client_secret": settings.sso_google_client_secret or "", + "client_secret": settings.sso_google_client_secret.get_secret_value() if settings.sso_google_client_secret else "", "authorization_url": "https://accounts.google.com/o/oauth2/auth", "token_url": "https://oauth2.googleapis.com/token", "userinfo_url": "https://openidconnect.googleapis.com/v1/userinfo", @@ -159,7 +159,7 @@ def get_predefined_sso_providers() -> List[Dict]: "display_name": "IBM Security Verify", "provider_type": "oidc", "client_id": settings.sso_ibm_verify_client_id, - "client_secret": settings.sso_ibm_verify_client_secret or "", + "client_secret": settings.sso_ibm_verify_client_secret.get_secret_value() if settings.sso_ibm_verify_client_secret else "", "authorization_url": f"{base_url}/oidc/endpoint/default/authorize", "token_url": f"{base_url}/oidc/endpoint/default/token", "userinfo_url": f"{base_url}/oidc/endpoint/default/userinfo", @@ -181,7 +181,7 @@ def get_predefined_sso_providers() -> List[Dict]: "display_name": "Okta", "provider_type": "oidc", "client_id": settings.sso_okta_client_id, - "client_secret": settings.sso_okta_client_secret or "", + "client_secret": settings.sso_okta_client_secret.get_secret_value() if settings.sso_okta_client_secret else "", "authorization_url": f"{base_url}/oauth2/default/v1/authorize", "token_url": f"{base_url}/oauth2/default/v1/token", "userinfo_url": f"{base_url}/oauth2/default/v1/userinfo", @@ -204,7 +204,7 @@ def get_predefined_sso_providers() -> List[Dict]: "display_name": "Microsoft Entra ID", "provider_type": "oidc", "client_id": settings.sso_entra_client_id, - "client_secret": settings.sso_entra_client_secret or "", + "client_secret": settings.sso_entra_client_secret.get_secret_value() if settings.sso_entra_client_secret else "", "authorization_url": f"{base_url}/oauth2/v2.0/authorize", "token_url": f"{base_url}/oauth2/v2.0/token", "userinfo_url": "https://graph.microsoft.com/oidc/userinfo", @@ -232,7 +232,7 @@ def get_predefined_sso_providers() -> List[Dict]: "display_name": f"Keycloak ({settings.sso_keycloak_realm})", "provider_type": "oidc", "client_id": settings.sso_keycloak_client_id, - "client_secret": settings.sso_keycloak_client_secret or "", + "client_secret": settings.sso_keycloak_client_secret.get_secret_value() if settings.sso_keycloak_client_secret else "", "authorization_url": endpoints["authorization_url"], "token_url": endpoints["token_url"], "userinfo_url": endpoints["userinfo_url"], @@ -270,7 +270,7 @@ def get_predefined_sso_providers() -> List[Dict]: "display_name": display_name, "provider_type": "oidc", "client_id": settings.sso_generic_client_id, - "client_secret": settings.sso_generic_client_secret or "", + "client_secret": settings.sso_generic_client_secret.get_secret_value() if settings.sso_generic_client_secret else "", "authorization_url": settings.sso_generic_authorization_url, "token_url": settings.sso_generic_token_url, "userinfo_url": settings.sso_generic_userinfo_url, diff --git a/mcpgateway/utils/verify_credentials.py b/mcpgateway/utils/verify_credentials.py index 7a33f4c26..d4fa8dcad 100644 --- a/mcpgateway/utils/verify_credentials.py +++ b/mcpgateway/utils/verify_credentials.py @@ -10,6 +10,7 @@ headers and cookies. Examples: >>> from mcpgateway.utils import verify_credentials as vc + >>> from pydantic import SecretStr >>> class DummySettings: ... jwt_secret_key = 'secret' ... jwt_algorithm = 'HS256' @@ -17,7 +18,7 @@ ... jwt_issuer = 'mcpgateway' ... jwt_audience_verification = True ... basic_auth_user = 'user' - ... basic_auth_password = 'pass' + ... basic_auth_password = SecretStr('pass') ... auth_required = True ... require_token_expiration = False ... docs_allow_basic_auth = False @@ -153,6 +154,7 @@ async def verify_credentials(token: str) -> dict: Examples: >>> from mcpgateway.utils import verify_credentials as vc + >>> from pydantic import SecretStr >>> class DummySettings: ... jwt_secret_key = 'secret' ... jwt_algorithm = 'HS256' @@ -160,7 +162,7 @@ async def verify_credentials(token: str) -> dict: ... jwt_issuer = 'mcpgateway' ... jwt_audience_verification = True ... basic_auth_user = 'user' - ... basic_auth_password = 'pass' + ... basic_auth_password = SecretStr('pass') ... auth_required = True ... require_token_expiration = False ... docs_allow_basic_auth = False @@ -202,6 +204,7 @@ async def require_auth(request: Request, credentials: Optional[HTTPAuthorization Examples: >>> from mcpgateway.utils import verify_credentials as vc + >>> from pydantic import SecretStr >>> class DummySettings: ... jwt_secret_key = 'secret' ... jwt_algorithm = 'HS256' @@ -209,7 +212,7 @@ async def require_auth(request: Request, credentials: Optional[HTTPAuthorization ... jwt_issuer = 'mcpgateway' ... jwt_audience_verification = True ... basic_auth_user = 'user' - ... basic_auth_password = 'pass' + ... basic_auth_password = SecretStr('pass') ... auth_required = True ... mcp_client_auth_enabled = True ... trust_proxy_auth = False @@ -306,6 +309,7 @@ async def verify_basic_credentials(credentials: HTTPBasicCredentials) -> str: Examples: >>> from mcpgateway.utils import verify_credentials as vc + >>> from pydantic import SecretStr >>> class DummySettings: ... jwt_secret_key = 'secret' ... jwt_algorithm = 'HS256' @@ -313,7 +317,7 @@ async def verify_basic_credentials(credentials: HTTPBasicCredentials) -> str: ... jwt_issuer = 'mcpgateway' ... jwt_audience_verification = True ... basic_auth_user = 'user' - ... basic_auth_password = 'pass' + ... basic_auth_password = SecretStr('pass') ... auth_required = True ... docs_allow_basic_auth = False >>> vc.settings = DummySettings() @@ -330,7 +334,7 @@ async def verify_basic_credentials(credentials: HTTPBasicCredentials) -> str: error """ is_valid_user = credentials.username == settings.basic_auth_user - is_valid_pass = credentials.password == settings.basic_auth_password + is_valid_pass = credentials.password == settings.basic_auth_password.get_secret_value() if not (is_valid_user and is_valid_pass): raise HTTPException( @@ -359,6 +363,7 @@ async def require_basic_auth(credentials: HTTPBasicCredentials = Depends(basic_s Examples: >>> from mcpgateway.utils import verify_credentials as vc + >>> from pydantic import SecretStr >>> class DummySettings: ... jwt_secret_key = 'secret' ... jwt_algorithm = 'HS256' @@ -366,7 +371,7 @@ async def require_basic_auth(credentials: HTTPBasicCredentials = Depends(basic_s ... jwt_issuer = 'mcpgateway' ... jwt_audience_verification = True ... basic_auth_user = 'user' - ... basic_auth_password = 'pass' + ... basic_auth_password = SecretStr('pass') ... auth_required = True ... docs_allow_basic_auth = False >>> vc.settings = DummySettings() @@ -420,6 +425,7 @@ async def require_docs_basic_auth(auth_header: str) -> str: Examples: >>> from mcpgateway.utils import verify_credentials as vc + >>> from pydantic import SecretStr >>> class DummySettings: ... jwt_secret_key = 'secret' ... jwt_algorithm = 'HS256' @@ -427,7 +433,7 @@ async def require_docs_basic_auth(auth_header: str) -> str: ... jwt_issuer = 'mcpgateway' ... jwt_audience_verification = True ... basic_auth_user = 'user' - ... basic_auth_password = 'pass' + ... basic_auth_password = SecretStr('pass') ... auth_required = True ... require_token_expiration = False ... docs_allow_basic_auth = True @@ -633,6 +639,7 @@ async def require_auth_override( Examples: >>> from mcpgateway.utils import verify_credentials as vc + >>> from pydantic import SecretStr >>> class DummySettings: ... jwt_secret_key = 'secret' ... jwt_algorithm = 'HS256' @@ -640,7 +647,7 @@ async def require_auth_override( ... jwt_issuer = 'mcpgateway' ... jwt_audience_verification = True ... basic_auth_user = 'user' - ... basic_auth_password = 'pass' + ... basic_auth_password = SecretStr('pass') ... auth_required = True ... mcp_client_auth_enabled = True ... trust_proxy_auth = False diff --git a/tests/fuzz/fuzzers/fuzz_jsonpath.py b/tests/fuzz/fuzzers/fuzz_jsonpath.py index acc364a26..90ca4955c 100755 --- a/tests/fuzz/fuzzers/fuzz_jsonpath.py +++ b/tests/fuzz/fuzzers/fuzz_jsonpath.py @@ -24,7 +24,7 @@ from fastapi import HTTPException # First-Party - from mcpgateway.config import jsonpath_modifier + from mcpgateway.main import jsonpath_modifier except ImportError as e: print(f"Import error: {e}") sys.exit(1) diff --git a/tests/fuzz/test_jsonpath_fuzz.py b/tests/fuzz/test_jsonpath_fuzz.py index 507469e8a..38f696fbe 100644 --- a/tests/fuzz/test_jsonpath_fuzz.py +++ b/tests/fuzz/test_jsonpath_fuzz.py @@ -14,7 +14,7 @@ import pytest # First-Party -from mcpgateway.config import jsonpath_modifier +from mcpgateway.main import jsonpath_modifier class TestJSONPathFuzzing: diff --git a/tests/unit/mcpgateway/services/test_tool_service.py b/tests/unit/mcpgateway/services/test_tool_service.py index 61741e466..4c58a4d9e 100644 --- a/tests/unit/mcpgateway/services/test_tool_service.py +++ b/tests/unit/mcpgateway/services/test_tool_service.py @@ -24,6 +24,7 @@ from mcpgateway.plugins.framework import PluginManager from mcpgateway.schemas import AuthenticationValues, ToolCreate, ToolRead, ToolUpdate from mcpgateway.services.tool_service import ( + extract_using_jq, TextContent, ToolError, ToolInvocationError, @@ -1324,7 +1325,7 @@ async def test_invoke_tool_rest_post(self, tool_service, mock_tool, test_db): # Mock decode_auth to return empty dict when auth_value is None # Mock extract_using_jq to return the input unmodified when filter is empty - with patch("mcpgateway.services.tool_service.decode_auth", return_value={}), patch("mcpgateway.config.extract_using_jq", return_value={"result": "REST tool response"}): + with patch("mcpgateway.services.tool_service.decode_auth", return_value={}), patch("mcpgateway.services.tool_service.extract_using_jq", return_value={"result": "REST tool response"}): # Invoke tool result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) @@ -2046,7 +2047,7 @@ async def test_invoke_tool_rest_oauth_success(self, tool_service, mock_tool, tes # Mock metrics recording tool_service._record_tool_metric = AsyncMock() - with patch("mcpgateway.config.extract_using_jq", return_value={"result": "OAuth success"}): + with patch("mcpgateway.services.tool_service.extract_using_jq", return_value={"result": "OAuth success"}): # Invoke tool result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) @@ -2166,7 +2167,7 @@ def mock_passthrough(req_headers, tool_headers, db_session): with ( patch("mcpgateway.services.tool_service.decode_auth", return_value={}), patch("mcpgateway.services.tool_service.get_passthrough_headers", side_effect=mock_passthrough), - patch("mcpgateway.config.extract_using_jq", return_value={"result": "success with headers"}), + patch("mcpgateway.services.tool_service.extract_using_jq", return_value={"result": "success with headers"}), ): result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=request_headers) @@ -2259,7 +2260,7 @@ async def test_invoke_tool_with_plugin_post_invoke_success(self, tool_service, m with ( patch("mcpgateway.services.tool_service.decode_auth", return_value={}), - patch("mcpgateway.config.extract_using_jq", return_value={"result": "original response"}), + patch("mcpgateway.services.tool_service.extract_using_jq", return_value={"result": "original response"}), ): result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) @@ -2303,7 +2304,7 @@ async def test_invoke_tool_with_plugin_post_invoke_modified_payload(self, tool_s with ( patch("mcpgateway.services.tool_service.decode_auth", return_value={}), - patch("mcpgateway.config.extract_using_jq", return_value={"result": "original response"}), + patch("mcpgateway.services.tool_service.extract_using_jq", return_value={"result": "original response"}), ): result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) @@ -2347,7 +2348,7 @@ async def test_invoke_tool_with_plugin_post_invoke_invalid_modified_payload(self with ( patch("mcpgateway.services.tool_service.decode_auth", return_value={}), - patch("mcpgateway.config.extract_using_jq", return_value={"result": "original response"}), + patch("mcpgateway.services.tool_service.extract_using_jq", return_value={"result": "original response"}), ): result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) @@ -2393,7 +2394,7 @@ async def test_invoke_tool_with_plugin_post_invoke_error_fail_on_error(self, too with ( patch("mcpgateway.services.tool_service.decode_auth", return_value={}), - patch("mcpgateway.config.extract_using_jq", return_value={"result": "original response"}), + patch("mcpgateway.services.tool_service.extract_using_jq", return_value={"result": "original response"}), ): with pytest.raises(Exception) as exc_info: await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) @@ -2427,7 +2428,7 @@ async def test_invoke_tool_with_plugin_metadata_rest(self, tool_service, mock_to with ( patch("mcpgateway.services.tool_service.decode_auth", return_value={}), - patch("mcpgateway.config.extract_using_jq", return_value={"result": "original response"}), + patch("mcpgateway.services.tool_service.extract_using_jq", return_value={"result": "original response"}), ): result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) @@ -2484,3 +2485,27 @@ async def test_invoke_tool_with_plugin_metadata_sse(self, tool_service, mock_too await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) await tool_service._plugin_manager.shutdown() + + +# --------------------------------------------------------------------------- # +# extract_using_jq # +# --------------------------------------------------------------------------- # +def test_extract_using_jq_happy_path(): + data = {"a": 123} + + with patch("mcpgateway.services.tool_service.jq.all", return_value=[123]) as mock_jq: + out = extract_using_jq(data, ".a") + mock_jq.assert_called_once_with(".a", data) + assert out == [123] + + +def test_extract_using_jq_short_circuits_and_errors(): + # Empty filter returns data unmodified + orig = {"x": "y"} + assert extract_using_jq(orig) is orig + + # Non-JSON string + assert extract_using_jq("this isn't json", ".foo") == ["Invalid JSON string provided."] + + # Unsupported input type + assert extract_using_jq(42, ".foo") == ["Input data must be a JSON string, dictionary, or list."] diff --git a/tests/unit/mcpgateway/test_bootstrap_db.py b/tests/unit/mcpgateway/test_bootstrap_db.py index 9526f307d..e0c76b8ba 100644 --- a/tests/unit/mcpgateway/test_bootstrap_db.py +++ b/tests/unit/mcpgateway/test_bootstrap_db.py @@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch # Third-Party +from pydantic import SecretStr import pytest # First-Party @@ -30,7 +31,7 @@ def mock_settings(): settings = Mock() settings.email_auth_enabled = True settings.platform_admin_email = "admin@example.com" - settings.platform_admin_password = "secure_password" + settings.platform_admin_password = SecretStr("secure_password") settings.platform_admin_full_name = "Platform Admin" settings.auto_create_personal_teams = True settings.database_url = "sqlite:///:memory:" @@ -124,20 +125,25 @@ async def test_bootstrap_admin_user_success(self, mock_settings, mock_db_session mock_email_auth_service.get_user_by_email.return_value = None mock_email_auth_service.create_user.return_value = mock_admin_user - with patch("mcpgateway.bootstrap_db.settings", mock_settings): - with patch("mcpgateway.bootstrap_db.SessionLocal", return_value=mock_db_session): - with patch("mcpgateway.services.email_auth_service.EmailAuthService", return_value=mock_email_auth_service): - with patch("mcpgateway.db.utc_now") as mock_utc_now: - mock_utc_now.return_value = "2024-01-01T00:00:00Z" - with patch("mcpgateway.bootstrap_db.logger") as mock_logger: - await bootstrap_admin_user() - - mock_email_auth_service.create_user.assert_called_once_with( - email=mock_settings.platform_admin_email, password=mock_settings.platform_admin_password, full_name=mock_settings.platform_admin_full_name, is_admin=True - ) - assert mock_admin_user.email_verified_at == "2024-01-01T00:00:00Z" - assert mock_db_session.commit.call_count == 2 - mock_logger.info.assert_any_call(f"Platform admin user created successfully: {mock_settings.platform_admin_email}") + with ( + patch("mcpgateway.bootstrap_db.settings", mock_settings), + patch("mcpgateway.bootstrap_db.SessionLocal", return_value=mock_db_session), + patch("mcpgateway.services.email_auth_service.EmailAuthService", return_value=mock_email_auth_service), + patch("mcpgateway.db.utc_now") as mock_utc_now, + patch("mcpgateway.bootstrap_db.logger") as mock_logger, + ): + mock_utc_now.return_value = "2024-01-01T00:00:00Z" + await bootstrap_admin_user() + + mock_email_auth_service.create_user.assert_called_once_with( + email=mock_settings.platform_admin_email, + password=mock_settings.platform_admin_password.get_secret_value(), + full_name=mock_settings.platform_admin_full_name, + is_admin=True, + ) + assert mock_admin_user.email_verified_at == "2024-01-01T00:00:00Z" + assert mock_db_session.commit.call_count == 2 + mock_logger.info.assert_any_call(f"Platform admin user created successfully: {mock_settings.platform_admin_email}") @pytest.mark.asyncio async def test_bootstrap_admin_user_with_personal_team(self, mock_settings, mock_db_session, mock_email_auth_service, mock_admin_user): diff --git a/tests/unit/mcpgateway/test_config.py b/tests/unit/mcpgateway/test_config.py index 9bba7b989..8521e63bc 100644 --- a/tests/unit/mcpgateway/test_config.py +++ b/tests/unit/mcpgateway/test_config.py @@ -11,20 +11,17 @@ # Standard import os from pathlib import Path -from typing import Any, Dict, List from unittest.mock import MagicMock, patch -# Third-Party -from fastapi import HTTPException +from pydantic import SecretStr +# Third-Party # Third-party import pytest # First-Party from mcpgateway.config import ( - extract_using_jq, get_settings, - jsonpath_modifier, Settings, ) @@ -55,7 +52,7 @@ def test_parse_federation_peers_json_and_csv(): # --------------------------------------------------------------------------- # # database / CORS helpers # # --------------------------------------------------------------------------- # -def test_database_settings_sqlite_and_non_sqlite(tmp_path: Path): +def test_database_settings_sqlite_and_non_sqlite(tmp_path: Path) -> None: """connect_args differs for sqlite vs everything else.""" # sqlite -> check_same_thread flag present db_file = tmp_path / "foo" / "bar.db" @@ -68,7 +65,7 @@ def test_database_settings_sqlite_and_non_sqlite(tmp_path: Path): assert s_pg.database_settings["connect_args"] == {} -def test_validate_database_creates_missing_parent(tmp_path: Path): +def test_validate_database_creates_missing_parent(tmp_path: Path) -> None: db_file = tmp_path / "newdir" / "db.sqlite" url = f"sqlite:///{db_file}" s = Settings(database_url=url, _env_file=None) @@ -103,65 +100,6 @@ def test_cors_settings_branches(): assert result == {} # Empty dict when disabled -# --------------------------------------------------------------------------- # -# extract_using_jq # -# --------------------------------------------------------------------------- # -def test_extract_using_jq_happy_path(): - data = {"a": 123} - - with patch("mcpgateway.config.jq.all", return_value=[123]) as mock_jq: - out = extract_using_jq(data, ".a") - mock_jq.assert_called_once_with(".a", data) - assert out == [123] - - -def test_extract_using_jq_short_circuits_and_errors(): - # Empty filter returns data unmodified - orig = {"x": "y"} - assert extract_using_jq(orig) is orig - - # Non-JSON string - assert extract_using_jq("this isn't json", ".foo") == ["Invalid JSON string provided."] - - # Unsupported input type - assert extract_using_jq(42, ".foo") == ["Input data must be a JSON string, dictionary, or list."] - - -# --------------------------------------------------------------------------- # -# jsonpath_modifier # -# --------------------------------------------------------------------------- # -@pytest.fixture(scope="module") -def sample_people() -> List[Dict[str, Any]]: - return [ - {"name": "Ada", "id": 1}, - {"name": "Bob", "id": 2}, - ] - - -def test_jsonpath_modifier_basic_match(sample_people): - # Pull out names directly - names = jsonpath_modifier(sample_people, "$[*].name") - assert names == ["Ada", "Bob"] - - # Same query but with a mapping - mapped = jsonpath_modifier(sample_people, "$[*]", mappings={"n": "$.name"}) - assert mapped == [{"n": "Ada"}, {"n": "Bob"}] - - -def test_jsonpath_modifier_single_dict_collapse(): - person = {"name": "Zoe", "id": 10} - out = jsonpath_modifier(person, "$") - assert out == person # single-item dict collapses to dict, not list - - -def test_jsonpath_modifier_invalid_expressions(sample_people): - with pytest.raises(HTTPException): - jsonpath_modifier(sample_people, "$[") # invalid main expr - - with pytest.raises(HTTPException): - jsonpath_modifier(sample_people, "$[*]", mappings={"bad": "$["}) # invalid mapping expr - - # --------------------------------------------------------------------------- # # get_settings LRU cache # # --------------------------------------------------------------------------- # @@ -200,7 +138,7 @@ def test_settings_default_values(): assert settings.port == 4444 assert settings.database_url == "sqlite:///./mcp.db" assert settings.basic_auth_user == "admin" - assert settings.basic_auth_password == "changeme" + assert settings.basic_auth_password == SecretStr("changeme") assert settings.auth_required is True assert settings.jwt_secret_key.get_secret_value() == "x" * 32 assert settings.auth_encryption_secret.get_secret_value() == "dummy-secret" diff --git a/tests/unit/mcpgateway/test_coverage_push.py b/tests/unit/mcpgateway/test_coverage_push.py index 582a80d0b..f3bc7135b 100644 --- a/tests/unit/mcpgateway/test_coverage_push.py +++ b/tests/unit/mcpgateway/test_coverage_push.py @@ -13,6 +13,7 @@ # Third-Party from fastapi import HTTPException from fastapi.testclient import TestClient +from pydantic import SecretStr import pytest # First-Party @@ -36,14 +37,14 @@ def test_require_api_key_scenarios(): with patch("mcpgateway.main.settings") as mock_settings: mock_settings.auth_required = True mock_settings.basic_auth_user = "admin" - mock_settings.basic_auth_password = "secret" + mock_settings.basic_auth_password = SecretStr("secret") require_api_key("admin:secret") # Should not raise # Test with auth enabled and incorrect key with patch("mcpgateway.main.settings") as mock_settings: mock_settings.auth_required = True mock_settings.basic_auth_user = "admin" - mock_settings.basic_auth_password = "secret" + mock_settings.basic_auth_password = SecretStr("secret") with pytest.raises(HTTPException): require_api_key("wrong:key") diff --git a/tests/unit/mcpgateway/test_main.py b/tests/unit/mcpgateway/test_main.py index e3a3bff07..cc2ed736c 100644 --- a/tests/unit/mcpgateway/test_main.py +++ b/tests/unit/mcpgateway/test_main.py @@ -17,6 +17,7 @@ from unittest.mock import ANY, MagicMock, patch # Third-Party +from fastapi import HTTPException from fastapi.testclient import TestClient import jwt import pytest @@ -691,6 +692,7 @@ def test_read_resource_endpoint(self, mock_read_resource, test_client, auth_head """Test reading resource content.""" # Clear the resource cache to avoid stale/cached values from mcpgateway import main as mcpgateway_main + mcpgateway_main.resource_cache.clear() mock_read_resource.return_value = ResourceContent( @@ -1536,3 +1538,47 @@ def test_redoc_with_auth(self, test_client, auth_headers): """Test GET /redoc with authentication returns 200 or redirect.""" response = test_client.get("/redoc", headers=auth_headers) assert response.status_code == 200 + + +# --------------------------------------------------------------------------- # +# jsonpath_modifier # +# --------------------------------------------------------------------------- # +@pytest.fixture(scope="module") +def sample_people(): + return [ + {"name": "Ada", "id": 1}, + {"name": "Bob", "id": 2}, + ] + + +def test_jsonpath_modifier_basic_match(sample_people): + # First-Party + from mcpgateway.main import jsonpath_modifier + + # Pull out names directly + names = jsonpath_modifier(sample_people, "$[*].name") + assert names == ["Ada", "Bob"] + + # Same query but with a mapping + mapped = jsonpath_modifier(sample_people, "$[*]", mappings={"n": "$.name"}) + assert mapped == [{"n": "Ada"}, {"n": "Bob"}] + + +def test_jsonpath_modifier_single_dict_collapse(): + # First-Party + from mcpgateway.main import jsonpath_modifier + + person = {"name": "Zoe", "id": 10} + out = jsonpath_modifier(person, "$") + assert out == person # single-item dict collapses to dict, not list + + +def test_jsonpath_modifier_invalid_expressions(sample_people): + # First-Party + from mcpgateway.main import jsonpath_modifier + + with pytest.raises(HTTPException): + jsonpath_modifier(sample_people, "$[") # invalid main expr + + with pytest.raises(HTTPException): + jsonpath_modifier(sample_people, "$[*]", mappings={"bad": "$["}) # invalid mapping expr diff --git a/tests/unit/mcpgateway/test_validate_env.py b/tests/unit/mcpgateway/test_validate_env.py index 662ffaf0b..704a2d313 100644 --- a/tests/unit/mcpgateway/test_validate_env.py +++ b/tests/unit/mcpgateway/test_validate_env.py @@ -1,20 +1,21 @@ # -*- coding: utf-8 -*- # File: tests/unit/mcpgateway/test_validate_env.py -from pathlib import Path -import pytest import logging import os +from pathlib import Path from unittest.mock import patch -# Suppress mcpgateway.config logs during tests -logging.getLogger("mcpgateway.config").setLevel(logging.ERROR) +import pytest # Import the validate_env script directly from mcpgateway.scripts import validate_env as ve +# Suppress mcpgateway.config logs during tests +logging.getLogger("mcpgateway.config").setLevel(logging.ERROR) + @pytest.fixture -def valid_env(tmp_path: Path): +def valid_env(tmp_path: Path) -> Path: envfile = tmp_path / ".env" envfile.write_text( "APP_DOMAIN=http://localhost:8000\n" @@ -30,14 +31,14 @@ def valid_env(tmp_path: Path): @pytest.fixture -def invalid_env(tmp_path: Path): +def invalid_env(tmp_path: Path) -> Path: envfile = tmp_path / ".env" # Invalid URL + wrong log level + invalid port envfile.write_text("APP_DOMAIN=not-a-url\nPORT=-1\nLOG_LEVEL=wronglevel\n") return envfile -def test_validate_env_success_direct(valid_env: Path): +def test_validate_env_success_direct(valid_env: Path) -> None: """ Test a valid .env. Warnings will be printed but do NOT fail the test. """ @@ -57,7 +58,7 @@ def test_validate_env_success_direct(valid_env: Path): assert code == 0 -def test_validate_env_failure_direct(invalid_env: Path): +def test_validate_env_failure_direct(invalid_env: Path) -> None: """ Test an invalid .env. Should fail due to ValidationError. """ diff --git a/tests/unit/mcpgateway/utils/test_verify_credentials.py b/tests/unit/mcpgateway/utils/test_verify_credentials.py index 860beb42b..dabf49f63 100644 --- a/tests/unit/mcpgateway/utils/test_verify_credentials.py +++ b/tests/unit/mcpgateway/utils/test_verify_credentials.py @@ -34,6 +34,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBasicCredentials from fastapi.testclient import TestClient import jwt +from pydantic import SecretStr import pytest # First-Party @@ -157,7 +158,7 @@ async def test_require_auth_missing_token(monkeypatch): @pytest.mark.asyncio async def test_verify_basic_credentials_success(monkeypatch): monkeypatch.setattr(vc.settings, "basic_auth_user", "alice", raising=False) - monkeypatch.setattr(vc.settings, "basic_auth_password", "secret", raising=False) + monkeypatch.setattr(vc.settings, "basic_auth_password", SecretStr("secret"), raising=False) creds = HTTPBasicCredentials(username="alice", password="secret") assert await vc.verify_basic_credentials(creds) == "alice" @@ -166,7 +167,7 @@ async def test_verify_basic_credentials_success(monkeypatch): @pytest.mark.asyncio async def test_verify_basic_credentials_failure(monkeypatch): monkeypatch.setattr(vc.settings, "basic_auth_user", "alice", raising=False) - monkeypatch.setattr(vc.settings, "basic_auth_password", "secret", raising=False) + monkeypatch.setattr(vc.settings, "basic_auth_password", SecretStr("secret"), raising=False) creds = HTTPBasicCredentials(username="bob", password="wrong") with pytest.raises(HTTPException) as exc: @@ -237,7 +238,7 @@ async def test_require_auth_override_basic_auth_enabled_success(monkeypatch): monkeypatch.setattr(vc.settings, "docs_allow_basic_auth", True, raising=False) monkeypatch.setattr(vc.settings, "auth_required", True, raising=False) monkeypatch.setattr(vc.settings, "basic_auth_user", "alice", raising=False) - monkeypatch.setattr(vc.settings, "basic_auth_password", "secret", raising=False) + monkeypatch.setattr(vc.settings, "basic_auth_password", SecretStr("secret"), raising=False) basic_auth_header = f"Basic {base64.b64encode('alice:secret'.encode()).decode()}" result = await vc.require_auth_override(auth_header=basic_auth_header) assert result == vc.settings.basic_auth_user @@ -249,7 +250,7 @@ async def test_require_auth_override_basic_auth_enabled_failure(monkeypatch): monkeypatch.setattr(vc.settings, "docs_allow_basic_auth", True, raising=False) monkeypatch.setattr(vc.settings, "auth_required", True, raising=False) monkeypatch.setattr(vc.settings, "basic_auth_user", "alice", raising=False) - monkeypatch.setattr(vc.settings, "basic_auth_password", "secret", raising=False) + monkeypatch.setattr(vc.settings, "basic_auth_password", SecretStr("secret"), raising=False) # case1. format is wrong header = "Basic fakeAuth" @@ -313,7 +314,7 @@ async def test_docs_both_auth_methods_work_simultaneously(monkeypatch): monkeypatch.setattr(vc.settings, "auth_required", True, raising=False) monkeypatch.setattr(vc.settings, "docs_allow_basic_auth", True, raising=False) monkeypatch.setattr(vc.settings, "basic_auth_user", "admin", raising=False) - monkeypatch.setattr(vc.settings, "basic_auth_password", "secret", raising=False) + monkeypatch.setattr(vc.settings, "basic_auth_password", SecretStr("secret"), raising=False) monkeypatch.setattr(vc.settings, "jwt_secret_key", SECRET, raising=False) monkeypatch.setattr(vc.settings, "jwt_algorithm", ALGO, raising=False) monkeypatch.setattr(vc.settings, "jwt_audience", "mcpgateway-api", raising=False) @@ -335,7 +336,7 @@ async def test_docs_invalid_basic_auth_fails(monkeypatch): monkeypatch.setattr(vc.settings, "auth_required", True, raising=False) monkeypatch.setattr(vc.settings, "docs_allow_basic_auth", True, raising=False) monkeypatch.setattr(vc.settings, "basic_auth_user", "admin", raising=False) - monkeypatch.setattr(vc.settings, "basic_auth_password", "correct", raising=False) + monkeypatch.setattr(vc.settings, "basic_auth_password", SecretStr("correct"), raising=False) # Send wrong Basic Auth wrong_basic = f"Basic {base64.b64encode(b'admin:wrong').decode()}" with pytest.raises(HTTPException) as exc: @@ -349,7 +350,7 @@ async def test_integration_docs_endpoint_both_auth_methods(test_client, monkeypa """Integration test: /docs accepts both auth methods when enabled.""" monkeypatch.setattr("mcpgateway.config.settings.docs_allow_basic_auth", True) monkeypatch.setattr("mcpgateway.config.settings.basic_auth_user", "admin") - monkeypatch.setattr("mcpgateway.config.settings.basic_auth_password", "changeme") + monkeypatch.setattr("mcpgateway.config.settings.basic_auth_password", SecretStr("changeme")) monkeypatch.setattr("mcpgateway.config.settings.jwt_secret_key", SECRET) monkeypatch.setattr("mcpgateway.config.settings.jwt_algorithm", ALGO) monkeypatch.setattr("mcpgateway.config.settings.jwt_audience", "mcpgateway-api")