From 41b35787b4bc7d4878a5bca63d6b4c93dd01b03d Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Wed, 13 Aug 2025 20:08:09 +0530 Subject: [PATCH 1/5] Store passthrough config in db Signed-off-by: Madhav Kandukuri --- mcpgateway/admin.py | 41 ++++++++++++------ mcpgateway/main.py | 9 ++++ mcpgateway/utils/passthrough_headers.py | 57 +++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 13 deletions(-) diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index 9f88f304b..fb3966b8b 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -71,6 +71,7 @@ from mcpgateway.services.tool_service import ToolError, ToolNotFoundError, ToolService from mcpgateway.utils.create_jwt_token import get_jwt_token from mcpgateway.utils.error_formatter import ErrorFormatter +from mcpgateway.utils.passthrough_headers import PassthroughHeadersError from mcpgateway.utils.retry_manager import ResilientHttpClient from mcpgateway.utils.verify_credentials import require_auth, require_basic_auth @@ -176,14 +177,12 @@ async def wrapper(*args, request: Request = None, **kwargs): @admin_router.get("/config/passthrough-headers", response_model=GlobalConfigRead) @rate_limit(requests_per_minute=30) # Lower limit for config endpoints async def get_global_passthrough_headers( - request: Request, # pylint: disable=unused-argument db: Session = Depends(get_db), _user: str = Depends(require_auth), ) -> GlobalConfigRead: """Get the global passthrough headers configuration. Args: - request: HTTP request object db: Database session _user: Authenticated user @@ -201,9 +200,11 @@ async def get_global_passthrough_headers( True """ config = db.query(GlobalConfig).first() - if not config: - config = GlobalConfig() - return GlobalConfigRead(passthrough_headers=config.passthrough_headers) + if config: + passthrough_headers = config.passthrough_headers + else: + passthrough_headers = [] + return GlobalConfigRead(passthrough_headers=passthrough_headers) @admin_router.put("/config/passthrough-headers", response_model=GlobalConfigRead) @@ -222,6 +223,9 @@ async def update_global_passthrough_headers( db: Database session _user: Authenticated user + Raises: + HTTPException: If there is a conflict or validation error + Returns: GlobalConfigRead: The updated configuration @@ -235,14 +239,25 @@ async def update_global_passthrough_headers( >>> inspect.iscoroutinefunction(update_global_passthrough_headers) True """ - config = db.query(GlobalConfig).first() - if not config: - config = GlobalConfig(passthrough_headers=config_update.passthrough_headers) - db.add(config) - else: - config.passthrough_headers = config_update.passthrough_headers - db.commit() - return GlobalConfigRead(passthrough_headers=config.passthrough_headers) + try: + config = db.query(GlobalConfig).first() + if not config: + config = GlobalConfig(passthrough_headers=config_update.passthrough_headers) + db.add(config) + else: + config.passthrough_headers = config_update.passthrough_headers + db.commit() + return GlobalConfigRead(passthrough_headers=config.passthrough_headers) + except Exception as e: + if isinstance(e, IntegrityError): + db.rollback() + raise HTTPException(status_code=409, detail="Passthrough headers conflict") + if isinstance(e, ValidationError): + db.rollback() + raise HTTPException(status_code=422, detail="Invalid passthrough headers format") + if isinstance(e, PassthroughHeadersError): + db.rollback() + raise HTTPException(status_code=500, detail=str(e)) @admin_router.get("/servers", response_model=List[ServerRead]) diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 0f8e115fb..0a55f9fe4 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -97,6 +97,7 @@ from mcpgateway.transports.streamablehttp_transport import SessionManagerWrapper, streamable_http_auth from mcpgateway.utils.db_isready import wait_for_db_ready from mcpgateway.utils.error_formatter import ErrorFormatter +from mcpgateway.utils.passthrough_headers import set_global_passthrough_headers from mcpgateway.utils.redis_isready import wait_for_redis_ready from mcpgateway.utils.retry_manager import ResilientHttpClient from mcpgateway.utils.verify_credentials import require_auth, require_auth_override, verify_jwt_token @@ -191,6 +192,14 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: await plugin_manager.initialize() logger.info(f"Plugin manager initialized with {plugin_manager.plugin_count} plugins") + if settings.enable_header_passthrough: + db_gen = get_db() + db = next(db_gen) # pylint: disable=stop-iteration-return + try: + await set_global_passthrough_headers(db) + finally: + db.close() + await tool_service.initialize() await resource_service.initialize() await prompt_service.initialize() diff --git a/mcpgateway/utils/passthrough_headers.py b/mcpgateway/utils/passthrough_headers.py index 7e38d8fb8..ae268559d 100644 --- a/mcpgateway/utils/passthrough_headers.py +++ b/mcpgateway/utils/passthrough_headers.py @@ -50,6 +50,18 @@ MAX_HEADER_VALUE_LENGTH = 4096 +class PassthroughHeadersError(Exception): + """Base class for passthrough headers-related errors. + + Examples: + >>> error = PassthroughHeadersError("Test error") + >>> str(error) + 'Test error' + >>> isinstance(error, Exception) + True + """ + + def sanitize_header_value(value: str, max_length: int = MAX_HEADER_VALUE_LENGTH) -> str: """Sanitize header value for security. @@ -213,3 +225,48 @@ def get_passthrough_headers(request_headers: Dict[str, str], base_headers: Dict[ logger.debug(f"Final passthrough headers: {list(passthrough_headers.keys())}") return passthrough_headers + + +async def set_global_passthrough_headers(db: Session) -> None: + """Set global passthrough headers in the database if not already configured. + + This function checks if the global passthrough headers are already set in the + GlobalConfig table. If not, it initializes them with the default headers from + settings.default_passthrough_headers. + + Args: + db (Session): SQLAlchemy database session for querying and updating GlobalConfig. + + Raises: + PassthroughHeadersError: If unable to update passthrough headers in the database. + + Example: + >>> from unittest.mock import Mock + >>> mock_db = Mock() + >>> headers = set_global_passthrough_headers(mock_db) + >>> headers + {'X-Default-Header': 'default-value', ...} # Example default headers + + Note: + This function is typically called during application startup to ensure + global configuration is in place before any gateway operations. + """ + global_config = db.query(GlobalConfig).first() + + if not global_config: + config_headers = settings.default_passthrough_headers + if config_headers: + allowed_headers = [] + for header_name in config_headers: + # Validate header name + if not validate_header_name(header_name): + logger.warning(f"Invalid header name '{header_name}' - skipping (must match pattern: {HEADER_NAME_REGEX.pattern})") + continue + + allowed_headers.append(header_name) + try: + db.add(GlobalConfig(passthrough_headers=allowed_headers)) + db.commit() + except Exception as e: + db.rollback() + raise PassthroughHeadersError(f"Failed to update passthrough headers: {str(e)}") From 6d840a1eeec82e0b8873bb23edd96635b4df3173 Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Wed, 13 Aug 2025 20:28:31 +0530 Subject: [PATCH 2/5] Fix failing test Signed-off-by: Madhav Kandukuri --- .../unit/mcpgateway/utils/test_passthrough_headers_fixed.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/unit/mcpgateway/utils/test_passthrough_headers_fixed.py b/tests/unit/mcpgateway/utils/test_passthrough_headers_fixed.py index 5aa68f57b..380761979 100644 --- a/tests/unit/mcpgateway/utils/test_passthrough_headers_fixed.py +++ b/tests/unit/mcpgateway/utils/test_passthrough_headers_fixed.py @@ -121,8 +121,11 @@ def test_authorization_conflict_bearer_auth(self, mock_settings, caplog): # Check warning was logged assert any("Skipping Authorization header passthrough due to bearer auth" in record.message for record in caplog.records) - def test_feature_disabled_by_default(self): + @patch("mcpgateway.utils.passthrough_headers.settings") + def test_feature_disabled_by_default(self, mock_settings): """Test that feature is disabled by default.""" + mock_settings.enable_header_passthrough = False + mock_db = Mock() request_headers = {"x-tenant-id": "test"} base_headers = {"Content-Type": "application/json"} From 9c79e9b7767772a5bbac2a44d8ad10e66d51ccac Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Wed, 13 Aug 2025 21:15:49 +0530 Subject: [PATCH 3/5] Add tests and fix doctest Signed-off-by: Madhav Kandukuri --- mcpgateway/utils/passthrough_headers.py | 69 +++++++++++++++---- .../utils/test_passthrough_headers_fixed.py | 59 +++++++++++++++- 2 files changed, 114 insertions(+), 14 deletions(-) diff --git a/mcpgateway/utils/passthrough_headers.py b/mcpgateway/utils/passthrough_headers.py index ae268559d..53d6a9dc7 100644 --- a/mcpgateway/utils/passthrough_headers.py +++ b/mcpgateway/utils/passthrough_headers.py @@ -143,13 +143,16 @@ def get_passthrough_headers(request_headers: Dict[str, str], base_headers: Dict[ Examples: Feature disabled by default (secure by default): - >>> from unittest.mock import Mock - >>> mock_db = Mock() - >>> request_headers = {"x-tenant-id": "should-be-ignored"} - >>> base_headers = {"Content-Type": "application/json"} - >>> result = get_passthrough_headers(request_headers, base_headers, mock_db) - >>> result - {'Content-Type': 'application/json'} + >>> from unittest.mock import Mock, patch + >>> with patch("mcpgateway.utils.passthrough_headers.settings") as mock_settings: + ... mock_settings.enable_header_passthrough = False + ... mock_settings.default_passthrough_headers = ["X-Tenant-Id"] + ... mock_db = Mock() + ... mock_db.query.return_value.first.return_value = None + ... request_headers = {"x-tenant-id": "should-be-ignored"} + ... base_headers = {"Content-Type": "application/json"} + ... get_passthrough_headers(request_headers, base_headers, mock_db) + {'Content-Type': 'application/json', 'X-Tenant-Id': 'should-be-ignored'} See comprehensive unit tests in tests/unit/mcpgateway/utils/test_passthrough_headers*.py for detailed examples of enabled functionality, conflict detection, and security features. @@ -240,12 +243,52 @@ async def set_global_passthrough_headers(db: Session) -> None: Raises: PassthroughHeadersError: If unable to update passthrough headers in the database. - Example: - >>> from unittest.mock import Mock - >>> mock_db = Mock() - >>> headers = set_global_passthrough_headers(mock_db) - >>> headers - {'X-Default-Header': 'default-value', ...} # Example default headers + Examples: + Successful insert of default headers: + >>> import pytest + >>> from unittest.mock import Mock, patch + >>> @pytest.mark.asyncio + ... @patch("mcpgateway.utils.passthrough_headers.settings") + ... async def test_default_headers(mock_settings): + ... mock_settings.enable_header_passthrough = True + ... mock_settings.default_passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"] + ... mock_db = Mock() + ... mock_db.query.return_value.first.return_value = None + ... await set_global_passthrough_headers(mock_db) + ... mock_db.add.assert_called_once() + ... mock_db.commit.assert_called_once() + + Database write failure: + >>> import pytest + >>> from unittest.mock import Mock, patch + >>> from mcpgateway.utils.passthrough_headers import PassthroughHeadersError + >>> @pytest.mark.asyncio + ... @patch("mcpgateway.utils.passthrough_headers.settings") + ... async def test_db_write_failure(mock_settings): + ... mock_settings.enable_header_passthrough = True + ... mock_db = Mock() + ... mock_db.query.return_value.first.return_value = None + ... mock_db.commit.side_effect = Exception("DB write failed") + ... with pytest.raises(PassthroughHeadersError): + ... await set_global_passthrough_headers(mock_db) + ... mock_db.rollback.assert_called_once() + + Config already exists (no DB write): + >>> import pytest + >>> from unittest.mock import Mock, patch + >>> from mcpgateway.models import GlobalConfig + >>> @pytest.mark.asyncio + ... @patch("mcpgateway.utils.passthrough_headers.settings") + ... async def test_existing_config(mock_settings): + ... mock_settings.enable_header_passthrough = True + ... mock_db = Mock() + ... existing = Mock(spec=GlobalConfig) + ... existing.passthrough_headers = ["X-Tenant-ID", "Authorization"] + ... mock_db.query.return_value.first.return_value = existing + ... await set_global_passthrough_headers(mock_db) + ... mock_db.add.assert_not_called() + ... mock_db.commit.assert_not_called() + ... assert existing.passthrough_headers == ["X-Tenant-ID", "Authorization"] Note: This function is typically called during application startup to ensure diff --git a/tests/unit/mcpgateway/utils/test_passthrough_headers_fixed.py b/tests/unit/mcpgateway/utils/test_passthrough_headers_fixed.py index 380761979..7d725c153 100644 --- a/tests/unit/mcpgateway/utils/test_passthrough_headers_fixed.py +++ b/tests/unit/mcpgateway/utils/test_passthrough_headers_fixed.py @@ -13,11 +13,12 @@ # Standard import logging from unittest.mock import Mock, patch +import pytest # First-Party from mcpgateway.db import Gateway as DbGateway from mcpgateway.db import GlobalConfig -from mcpgateway.utils.passthrough_headers import get_passthrough_headers +from mcpgateway.utils.passthrough_headers import get_passthrough_headers, set_global_passthrough_headers, PassthroughHeadersError class TestPassthroughHeaders: @@ -157,3 +158,59 @@ def test_case_insensitive_header_matching(self, mock_settings): # Headers should preserve config case in output keys expected = {"X-Tenant-ID": "mixed-case-value", "Authorization": "bearer lowercase-header"} assert result == expected + + @pytest.mark.asyncio + @patch("mcpgateway.utils.passthrough_headers.settings") + async def test_set_global_passthrough_headers_default(self, mock_settings): + mock_settings.enable_header_passthrough = True + mock_settings.default_passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"] + + mock_db = Mock() + mock_db.query.return_value.first.return_value = None # Simulate no config in DB + + # Act + await set_global_passthrough_headers(mock_db) + + # Assert + mock_db.add.assert_called_once() + added_config = mock_db.add.call_args[0][0] + assert added_config.passthrough_headers == ["X-Tenant-Id", "X-Trace-Id"] + + mock_db.commit.assert_called_once() + + + @pytest.mark.asyncio + @patch("mcpgateway.utils.passthrough_headers.settings") + async def test_set_global_passthrough_headers_invalid_config(self, mock_settings): + """Should raise PassthroughHeadersError when config is invalid.""" + mock_settings.enable_header_passthrough = True + + mock_db = Mock() + mock_db.query.return_value.first.return_value = None + mock_db.commit.side_effect = Exception("DB write failed") + + with pytest.raises(PassthroughHeadersError) as exc_info: + await set_global_passthrough_headers(mock_db) + + assert "DB write failed" in str(exc_info.value) or str(exc_info.value) + mock_db.rollback.assert_called_once() + + @pytest.mark.asyncio + @patch("mcpgateway.utils.passthrough_headers.settings") + async def test_set_global_passthrough_headers_existing_config(self, mock_settings): + """Should raise PassthroughHeadersError when config is invalid.""" + mock_settings.enable_header_passthrough = True + + mock_db = Mock() + mock_global_config = Mock(spec=GlobalConfig) + mock_global_config.passthrough_headers = ["X-Tenant-ID", "Authorization"] + mock_db.query.return_value.first.return_value = mock_global_config + + await set_global_passthrough_headers(mock_db) + + mock_db.add.assert_not_called() + mock_db.commit.assert_not_called() + + # Ensure existing config is not modified + assert mock_global_config.passthrough_headers == ["X-Tenant-ID", "Authorization"] + mock_db.rollback.assert_not_called() From e2938998c53d6c7d7fe9aa7c90cbff92795c8920 Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Wed, 13 Aug 2025 21:27:17 +0530 Subject: [PATCH 4/5] Fix test Signed-off-by: Madhav Kandukuri --- mcpgateway/utils/passthrough_headers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcpgateway/utils/passthrough_headers.py b/mcpgateway/utils/passthrough_headers.py index 53d6a9dc7..ba2122b3d 100644 --- a/mcpgateway/utils/passthrough_headers.py +++ b/mcpgateway/utils/passthrough_headers.py @@ -152,7 +152,7 @@ def get_passthrough_headers(request_headers: Dict[str, str], base_headers: Dict[ ... request_headers = {"x-tenant-id": "should-be-ignored"} ... base_headers = {"Content-Type": "application/json"} ... get_passthrough_headers(request_headers, base_headers, mock_db) - {'Content-Type': 'application/json', 'X-Tenant-Id': 'should-be-ignored'} + {'Content-Type': 'application/json'} See comprehensive unit tests in tests/unit/mcpgateway/utils/test_passthrough_headers*.py for detailed examples of enabled functionality, conflict detection, and security features. From 38a9cdedf5de123716a0723b0607ef663a0ebaf8 Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Wed, 13 Aug 2025 22:50:50 +0530 Subject: [PATCH 5/5] Fix doctest Signed-off-by: Madhav Kandukuri --- mcpgateway/utils/passthrough_headers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcpgateway/utils/passthrough_headers.py b/mcpgateway/utils/passthrough_headers.py index ba2122b3d..91a84bbab 100644 --- a/mcpgateway/utils/passthrough_headers.py +++ b/mcpgateway/utils/passthrough_headers.py @@ -144,7 +144,7 @@ def get_passthrough_headers(request_headers: Dict[str, str], base_headers: Dict[ Examples: Feature disabled by default (secure by default): >>> from unittest.mock import Mock, patch - >>> with patch("mcpgateway.utils.passthrough_headers.settings") as mock_settings: + >>> with patch(__name__ + ".settings") as mock_settings: ... mock_settings.enable_header_passthrough = False ... mock_settings.default_passthrough_headers = ["X-Tenant-Id"] ... mock_db = Mock()