Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions mcpgateway/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from mcpgateway.services.tool_service import ToolError, ToolNotFoundError, ToolService
from mcpgateway.utils.create_jwt_token import get_jwt_token
from mcpgateway.utils.error_formatter import ErrorFormatter
from mcpgateway.utils.passthrough_headers import PassthroughHeadersError
from mcpgateway.utils.retry_manager import ResilientHttpClient
from mcpgateway.utils.verify_credentials import require_auth, require_basic_auth

Expand Down Expand Up @@ -176,14 +177,12 @@ async def wrapper(*args, request: Request = None, **kwargs):
@admin_router.get("/config/passthrough-headers", response_model=GlobalConfigRead)
@rate_limit(requests_per_minute=30) # Lower limit for config endpoints
async def get_global_passthrough_headers(
request: Request, # pylint: disable=unused-argument
db: Session = Depends(get_db),
_user: str = Depends(require_auth),
) -> GlobalConfigRead:
"""Get the global passthrough headers configuration.

Args:
request: HTTP request object
db: Database session
_user: Authenticated user

Expand All @@ -201,9 +200,11 @@ async def get_global_passthrough_headers(
True
"""
config = db.query(GlobalConfig).first()
if not config:
config = GlobalConfig()
return GlobalConfigRead(passthrough_headers=config.passthrough_headers)
if config:
passthrough_headers = config.passthrough_headers
else:
passthrough_headers = []
return GlobalConfigRead(passthrough_headers=passthrough_headers)


@admin_router.put("/config/passthrough-headers", response_model=GlobalConfigRead)
Expand All @@ -222,6 +223,9 @@ async def update_global_passthrough_headers(
db: Database session
_user: Authenticated user

Raises:
HTTPException: If there is a conflict or validation error

Returns:
GlobalConfigRead: The updated configuration

Expand All @@ -235,14 +239,25 @@ async def update_global_passthrough_headers(
>>> inspect.iscoroutinefunction(update_global_passthrough_headers)
True
"""
config = db.query(GlobalConfig).first()
if not config:
config = GlobalConfig(passthrough_headers=config_update.passthrough_headers)
db.add(config)
else:
config.passthrough_headers = config_update.passthrough_headers
db.commit()
return GlobalConfigRead(passthrough_headers=config.passthrough_headers)
try:
config = db.query(GlobalConfig).first()
if not config:
config = GlobalConfig(passthrough_headers=config_update.passthrough_headers)
db.add(config)
else:
config.passthrough_headers = config_update.passthrough_headers
db.commit()
return GlobalConfigRead(passthrough_headers=config.passthrough_headers)
except Exception as e:
if isinstance(e, IntegrityError):
db.rollback()
raise HTTPException(status_code=409, detail="Passthrough headers conflict")
if isinstance(e, ValidationError):
db.rollback()
raise HTTPException(status_code=422, detail="Invalid passthrough headers format")
if isinstance(e, PassthroughHeadersError):
db.rollback()
raise HTTPException(status_code=500, detail=str(e))


@admin_router.get("/servers", response_model=List[ServerRead])
Expand Down
9 changes: 9 additions & 0 deletions mcpgateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
from mcpgateway.transports.streamablehttp_transport import SessionManagerWrapper, streamable_http_auth
from mcpgateway.utils.db_isready import wait_for_db_ready
from mcpgateway.utils.error_formatter import ErrorFormatter
from mcpgateway.utils.passthrough_headers import set_global_passthrough_headers
from mcpgateway.utils.redis_isready import wait_for_redis_ready
from mcpgateway.utils.retry_manager import ResilientHttpClient
from mcpgateway.utils.verify_credentials import require_auth, require_auth_override, verify_jwt_token
Expand Down Expand Up @@ -191,6 +192,14 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
await plugin_manager.initialize()
logger.info(f"Plugin manager initialized with {plugin_manager.plugin_count} plugins")

if settings.enable_header_passthrough:
db_gen = get_db()
db = next(db_gen) # pylint: disable=stop-iteration-return
try:
await set_global_passthrough_headers(db)
finally:
db.close()

await tool_service.initialize()
await resource_service.initialize()
await prompt_service.initialize()
Expand Down
112 changes: 106 additions & 6 deletions mcpgateway/utils/passthrough_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@
MAX_HEADER_VALUE_LENGTH = 4096


class PassthroughHeadersError(Exception):
"""Base class for passthrough headers-related errors.

Examples:
>>> error = PassthroughHeadersError("Test error")
>>> str(error)
'Test error'
>>> isinstance(error, Exception)
True
"""


def sanitize_header_value(value: str, max_length: int = MAX_HEADER_VALUE_LENGTH) -> str:
"""Sanitize header value for security.

Expand Down Expand Up @@ -131,12 +143,15 @@ def get_passthrough_headers(request_headers: Dict[str, str], base_headers: Dict[

Examples:
Feature disabled by default (secure by default):
>>> from unittest.mock import Mock
>>> mock_db = Mock()
>>> request_headers = {"x-tenant-id": "should-be-ignored"}
>>> base_headers = {"Content-Type": "application/json"}
>>> result = get_passthrough_headers(request_headers, base_headers, mock_db)
>>> result
>>> from unittest.mock import Mock, patch
>>> with patch(__name__ + ".settings") as mock_settings:
... mock_settings.enable_header_passthrough = False
... mock_settings.default_passthrough_headers = ["X-Tenant-Id"]
... mock_db = Mock()
... mock_db.query.return_value.first.return_value = None
... request_headers = {"x-tenant-id": "should-be-ignored"}
... base_headers = {"Content-Type": "application/json"}
... get_passthrough_headers(request_headers, base_headers, mock_db)
{'Content-Type': 'application/json'}

See comprehensive unit tests in tests/unit/mcpgateway/utils/test_passthrough_headers*.py
Expand Down Expand Up @@ -213,3 +228,88 @@ def get_passthrough_headers(request_headers: Dict[str, str], base_headers: Dict[

logger.debug(f"Final passthrough headers: {list(passthrough_headers.keys())}")
return passthrough_headers


async def set_global_passthrough_headers(db: Session) -> None:
"""Set global passthrough headers in the database if not already configured.

This function checks if the global passthrough headers are already set in the
GlobalConfig table. If not, it initializes them with the default headers from
settings.default_passthrough_headers.

Args:
db (Session): SQLAlchemy database session for querying and updating GlobalConfig.

Raises:
PassthroughHeadersError: If unable to update passthrough headers in the database.

Examples:
Successful insert of default headers:
>>> import pytest
>>> from unittest.mock import Mock, patch
>>> @pytest.mark.asyncio
... @patch("mcpgateway.utils.passthrough_headers.settings")
... async def test_default_headers(mock_settings):
... mock_settings.enable_header_passthrough = True
... mock_settings.default_passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"]
... mock_db = Mock()
... mock_db.query.return_value.first.return_value = None
... await set_global_passthrough_headers(mock_db)
... mock_db.add.assert_called_once()
... mock_db.commit.assert_called_once()

Database write failure:
>>> import pytest
>>> from unittest.mock import Mock, patch
>>> from mcpgateway.utils.passthrough_headers import PassthroughHeadersError
>>> @pytest.mark.asyncio
... @patch("mcpgateway.utils.passthrough_headers.settings")
... async def test_db_write_failure(mock_settings):
... mock_settings.enable_header_passthrough = True
... mock_db = Mock()
... mock_db.query.return_value.first.return_value = None
... mock_db.commit.side_effect = Exception("DB write failed")
... with pytest.raises(PassthroughHeadersError):
... await set_global_passthrough_headers(mock_db)
... mock_db.rollback.assert_called_once()

Config already exists (no DB write):
>>> import pytest
>>> from unittest.mock import Mock, patch
>>> from mcpgateway.models import GlobalConfig
>>> @pytest.mark.asyncio
... @patch("mcpgateway.utils.passthrough_headers.settings")
... async def test_existing_config(mock_settings):
... mock_settings.enable_header_passthrough = True
... mock_db = Mock()
... existing = Mock(spec=GlobalConfig)
... existing.passthrough_headers = ["X-Tenant-ID", "Authorization"]
... mock_db.query.return_value.first.return_value = existing
... await set_global_passthrough_headers(mock_db)
... mock_db.add.assert_not_called()
... mock_db.commit.assert_not_called()
... assert existing.passthrough_headers == ["X-Tenant-ID", "Authorization"]

Note:
This function is typically called during application startup to ensure
global configuration is in place before any gateway operations.
"""
global_config = db.query(GlobalConfig).first()

if not global_config:
config_headers = settings.default_passthrough_headers
if config_headers:
allowed_headers = []
for header_name in config_headers:
# Validate header name
if not validate_header_name(header_name):
logger.warning(f"Invalid header name '{header_name}' - skipping (must match pattern: {HEADER_NAME_REGEX.pattern})")
continue

allowed_headers.append(header_name)
try:
db.add(GlobalConfig(passthrough_headers=allowed_headers))
db.commit()
except Exception as e:
db.rollback()
raise PassthroughHeadersError(f"Failed to update passthrough headers: {str(e)}")
64 changes: 62 additions & 2 deletions tests/unit/mcpgateway/utils/test_passthrough_headers_fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# Standard
import logging
from unittest.mock import Mock, patch
import pytest

# First-Party
from mcpgateway.db import Gateway as DbGateway
from mcpgateway.db import GlobalConfig
from mcpgateway.utils.passthrough_headers import get_passthrough_headers
from mcpgateway.utils.passthrough_headers import get_passthrough_headers, set_global_passthrough_headers, PassthroughHeadersError


class TestPassthroughHeaders:
Expand Down Expand Up @@ -121,8 +122,11 @@ def test_authorization_conflict_bearer_auth(self, mock_settings, caplog):
# Check warning was logged
assert any("Skipping Authorization header passthrough due to bearer auth" in record.message for record in caplog.records)

def test_feature_disabled_by_default(self):
@patch("mcpgateway.utils.passthrough_headers.settings")
def test_feature_disabled_by_default(self, mock_settings):
"""Test that feature is disabled by default."""
mock_settings.enable_header_passthrough = False

mock_db = Mock()
request_headers = {"x-tenant-id": "test"}
base_headers = {"Content-Type": "application/json"}
Expand Down Expand Up @@ -154,3 +158,59 @@ def test_case_insensitive_header_matching(self, mock_settings):
# Headers should preserve config case in output keys
expected = {"X-Tenant-ID": "mixed-case-value", "Authorization": "bearer lowercase-header"}
assert result == expected

@pytest.mark.asyncio
@patch("mcpgateway.utils.passthrough_headers.settings")
async def test_set_global_passthrough_headers_default(self, mock_settings):
mock_settings.enable_header_passthrough = True
mock_settings.default_passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"]

mock_db = Mock()
mock_db.query.return_value.first.return_value = None # Simulate no config in DB

# Act
await set_global_passthrough_headers(mock_db)

# Assert
mock_db.add.assert_called_once()
added_config = mock_db.add.call_args[0][0]
assert added_config.passthrough_headers == ["X-Tenant-Id", "X-Trace-Id"]

mock_db.commit.assert_called_once()


@pytest.mark.asyncio
@patch("mcpgateway.utils.passthrough_headers.settings")
async def test_set_global_passthrough_headers_invalid_config(self, mock_settings):
"""Should raise PassthroughHeadersError when config is invalid."""
mock_settings.enable_header_passthrough = True

mock_db = Mock()
mock_db.query.return_value.first.return_value = None
mock_db.commit.side_effect = Exception("DB write failed")

with pytest.raises(PassthroughHeadersError) as exc_info:
await set_global_passthrough_headers(mock_db)

assert "DB write failed" in str(exc_info.value) or str(exc_info.value)
mock_db.rollback.assert_called_once()

@pytest.mark.asyncio
@patch("mcpgateway.utils.passthrough_headers.settings")
async def test_set_global_passthrough_headers_existing_config(self, mock_settings):
"""Should raise PassthroughHeadersError when config is invalid."""
mock_settings.enable_header_passthrough = True

mock_db = Mock()
mock_global_config = Mock(spec=GlobalConfig)
mock_global_config.passthrough_headers = ["X-Tenant-ID", "Authorization"]
mock_db.query.return_value.first.return_value = mock_global_config

await set_global_passthrough_headers(mock_db)

mock_db.add.assert_not_called()
mock_db.commit.assert_not_called()

# Ensure existing config is not modified
assert mock_global_config.passthrough_headers == ["X-Tenant-ID", "Authorization"]
mock_db.rollback.assert_not_called()
Loading