From cac9f630044c840e53b29f791fb4cdf0d028d8cd Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Thu, 30 Oct 2025 07:19:25 +0530 Subject: [PATCH 01/14] Removes redundant base64 Signed-off-by: Madhav Kandukuri --- mcpgateway/utils/oauth_encryption.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mcpgateway/utils/oauth_encryption.py b/mcpgateway/utils/oauth_encryption.py index bd58a49a6..a6a1959f7 100644 --- a/mcpgateway/utils/oauth_encryption.py +++ b/mcpgateway/utils/oauth_encryption.py @@ -83,7 +83,7 @@ def encrypt_secret(self, plaintext: str) -> str: try: fernet = self._get_fernet() encrypted = fernet.encrypt(plaintext.encode()) - return base64.urlsafe_b64encode(encrypted).decode() + return encrypted.decode() except Exception as e: logger.error(f"Failed to encrypt OAuth secret: {e}") raise @@ -99,8 +99,7 @@ def decrypt_secret(self, encrypted_text: str) -> Optional[str]: """ try: fernet = self._get_fernet() - encrypted_bytes = base64.urlsafe_b64decode(encrypted_text.encode()) - decrypted = fernet.decrypt(encrypted_bytes) + decrypted = fernet.decrypt(encrypted_text.encode()) return decrypted.decode() except Exception as e: logger.error(f"Failed to decrypt OAuth secret: {e}") From f53fcc5e4fbc62850aff0f22082a683fda4b906a Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Thu, 30 Oct 2025 07:45:08 +0530 Subject: [PATCH 02/14] Replace PBKDF2HMAC with Argon2Id encryption Signed-off-by: Madhav Kandukuri --- mcpgateway/utils/oauth_encryption.py | 71 +++++++++++++++++----------- 1 file changed, 43 insertions(+), 28 deletions(-) diff --git a/mcpgateway/utils/oauth_encryption.py b/mcpgateway/utils/oauth_encryption.py index a6a1959f7..de5295828 100644 --- a/mcpgateway/utils/oauth_encryption.py +++ b/mcpgateway/utils/oauth_encryption.py @@ -12,13 +12,17 @@ # Standard import base64 +import json import logging +import os from typing import Optional +# First-Party +from mcpgateway.config import settings + # Third-Party from cryptography.fernet import Fernet -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from argon2.low_level import hash_secret_raw, Type from pydantic import SecretStr logger = logging.getLogger(__name__) @@ -41,32 +45,31 @@ class OAuthEncryption: False """ - def __init__(self, encryption_secret: SecretStr): + def __init__(self, encryption_secret: SecretStr, time_cost: Optional[int] = None, memory_cost: Optional[int] = None, parallelism: Optional[int] = None, hash_len: int = 32, salt_len: int = 16): """Initialize the encryption handler. Args: encryption_secret: Secret key for encryption/decryption """ self.encryption_secret = encryption_secret.get_secret_value().encode() - self._fernet = None - - def _get_fernet(self) -> Fernet: - """Get or create Fernet instance for encryption. - - Returns: - Fernet instance for encryption/decryption - """ - if self._fernet is None: - # Derive a key from the encryption secret using PBKDF2 - kdf = PBKDF2HMAC( - algorithm=hashes.SHA256(), - length=32, - salt=b"mcp_gateway_oauth", # Fixed salt for consistency - iterations=100000, - ) - key = base64.urlsafe_b64encode(kdf.derive(self.encryption_secret)) - self._fernet = Fernet(key) - return self._fernet + self.time_cost = time_cost or getattr(settings, "argon2id_time_cost", 3) + self.memory_cost = memory_cost or getattr(settings, "argon2id_memory_cost", 65536) + self.parallelism = parallelism or getattr(settings, "argon2id_parallelism", 1) + self.hash_len = hash_len + self.salt_len = salt_len + + + def derive_key_argon2id(self, passphrase: bytes, salt: bytes, time_cost: int, memory_cost: int, parallelism: int) -> bytes: + raw = hash_secret_raw( + secret=passphrase, + salt=salt, + time_cost=time_cost, + memory_cost=memory_cost, # KiB + parallelism=parallelism, + hash_len=self.hash_len, + type=Type.ID, + ) + return base64.urlsafe_b64encode(raw) def encrypt_secret(self, plaintext: str) -> str: """Encrypt a plaintext secret. @@ -81,25 +84,37 @@ def encrypt_secret(self, plaintext: str) -> str: Exception: If encryption fails """ try: - fernet = self._get_fernet() + salt = os.urandom(16) + key = self.derive_key_argon2id(self.encryption_secret, salt, self.time_cost, self.memory_cost, self.parallelism) + fernet = Fernet(key) encrypted = fernet.encrypt(plaintext.encode()) - return encrypted.decode() + return json.dumps({ + "kdf": "argon2id", + "t": self.time_cost, + "m": self.memory_cost, + "p": self.parallelism, + "salt": base64.b64encode(salt).decode(), + "token": encrypted.decode(), + }) except Exception as e: logger.error(f"Failed to encrypt OAuth secret: {e}") raise - def decrypt_secret(self, encrypted_text: str) -> Optional[str]: + def decrypt_secret(self, bundle_json: str) -> Optional[str]: """Decrypt an encrypted secret. Args: - encrypted_text: Base64-encoded encrypted string + bundle_json: str: JSON string containing encryption metadata and token Returns: Decrypted secret string, or None if decryption fails """ try: - fernet = self._get_fernet() - decrypted = fernet.decrypt(encrypted_text.encode()) + b = json.loads(bundle_json) + salt = base64.b64decode(b["salt"]) + key = self.derive_key_argon2id(self.encryption_secret, salt, time_cost=b["t"], memory_cost=b["m"], parallelism=b["p"]) + fernet = Fernet(key) + decrypted = fernet.decrypt(b["token"].encode()) return decrypted.decode() except Exception as e: logger.error(f"Failed to decrypt OAuth secret: {e}") From 90e4a156833893f9fd1955dc794ede51bf4d72c2 Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Thu, 30 Oct 2025 10:14:43 +0530 Subject: [PATCH 03/14] Use Argon2id for key generation in fernet encryption Signed-off-by: Madhav Kandukuri --- mcpgateway/admin.py | 18 +-- mcpgateway/routers/oauth_router.py | 4 +- mcpgateway/services/dcr_service.py | 8 +- mcpgateway/services/oauth_manager.py | 10 +- mcpgateway/services/token_storage_service.py | 4 +- ...uth_encryption.py => fernet_encryption.py} | 58 ++++---- .../mcpgateway/services/test_dcr_service.py | 8 +- tests/unit/mcpgateway/test_admin.py | 6 +- tests/unit/mcpgateway/test_oauth_manager.py | 131 +++++++----------- 9 files changed, 110 insertions(+), 137 deletions(-) rename mcpgateway/utils/{oauth_encryption.py => fernet_encryption.py} (75%) diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index 5d859a2cb..f88e98a51 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -112,8 +112,8 @@ from mcpgateway.services.tool_service import ToolError, ToolNameConflictError, ToolNotFoundError, ToolService from mcpgateway.utils.create_jwt_token import create_jwt_token, get_jwt_token from mcpgateway.utils.error_formatter import ErrorFormatter +from mcpgateway.utils.fernet_encryption import get_fernet_encryption from mcpgateway.utils.metadata_capture import MetadataCapture -from mcpgateway.utils.oauth_encryption import get_oauth_encryption from mcpgateway.utils.pagination import generate_pagination_links from mcpgateway.utils.passthrough_headers import PassthroughHeadersError from mcpgateway.utils.retry_manager import ResilientHttpClient @@ -6194,7 +6194,7 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use oauth_config = json.loads(oauth_config_json) # Encrypt the client secret if present if oauth_config and "client_secret" in oauth_config: - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_fernet_encryption(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_config["client_secret"]) except (json.JSONDecodeError, ValueError) as e: LOGGER.error(f"Failed to parse OAuth config: {e}") @@ -6231,7 +6231,7 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use oauth_config["client_id"] = oauth_client_id if oauth_client_secret: # Encrypt the client secret - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_fernet_encryption(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_client_secret) # Add username and password for password grant type @@ -6503,7 +6503,7 @@ async def admin_edit_gateway( oauth_config = json.loads(oauth_config_json) # Encrypt the client secret if present and not empty if oauth_config and "client_secret" in oauth_config and oauth_config["client_secret"]: - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_fernet_encryption(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_config["client_secret"]) except (json.JSONDecodeError, ValueError) as e: LOGGER.error(f"Failed to parse OAuth config: {e}") @@ -6540,7 +6540,7 @@ async def admin_edit_gateway( oauth_config["client_id"] = oauth_client_id if oauth_client_secret: # Encrypt the client secret - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_fernet_encryption(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_client_secret) # Add username and password for password grant type @@ -9571,7 +9571,7 @@ async def admin_add_a2a_agent( oauth_config = json.loads(oauth_config_json) # Encrypt the client secret if present if oauth_config and "client_secret" in oauth_config: - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_fernet_encryption(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_config["client_secret"]) except (json.JSONDecodeError, ValueError) as e: LOGGER.error(f"Failed to parse OAuth config: {e}") @@ -9608,7 +9608,7 @@ async def admin_add_a2a_agent( oauth_config["client_id"] = oauth_client_id if oauth_client_secret: # Encrypt the client secret - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_fernet_encryption(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_client_secret) # Add username and password for password grant type @@ -9890,7 +9890,7 @@ async def admin_edit_a2a_agent( oauth_config = json.loads(oauth_config_json) # Encrypt the client secret if present and not empty if oauth_config and "client_secret" in oauth_config and oauth_config["client_secret"]: - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_fernet_encryption(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_config["client_secret"]) except (json.JSONDecodeError, ValueError) as e: LOGGER.error(f"Failed to parse OAuth config: {e}") @@ -9927,7 +9927,7 @@ async def admin_edit_a2a_agent( oauth_config["client_id"] = oauth_client_id if oauth_client_secret: # Encrypt the client secret - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_fernet_encryption(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_client_secret) # Add username and password for password grant type diff --git a/mcpgateway/routers/oauth_router.py b/mcpgateway/routers/oauth_router.py index e52ceb093..83b4c0cb7 100644 --- a/mcpgateway/routers/oauth_router.py +++ b/mcpgateway/routers/oauth_router.py @@ -113,9 +113,9 @@ async def initiate_oauth_flow( decrypted_secret = None if registered_client.client_secret_encrypted: # First-Party - from mcpgateway.utils.oauth_encryption import get_oauth_encryption + from mcpgateway.utils.fernet_encryption import get_fernet_encryption - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_fernet_encryption(settings.auth_encryption_secret) decrypted_secret = encryption.decrypt_secret(registered_client.client_secret_encrypted) # Update oauth_config with registered credentials diff --git a/mcpgateway/services/dcr_service.py b/mcpgateway/services/dcr_service.py index 88573b126..e300d3539 100644 --- a/mcpgateway/services/dcr_service.py +++ b/mcpgateway/services/dcr_service.py @@ -25,7 +25,7 @@ # First-Party from mcpgateway.config import get_settings from mcpgateway.db import RegisteredOAuthClient -from mcpgateway.utils.oauth_encryption import get_oauth_encryption +from mcpgateway.utils.fernet_encryption import get_fernet_encryption logger = logging.getLogger(__name__) @@ -168,7 +168,7 @@ async def register_client(self, gateway_id: str, gateway_name: str, issuer: str, raise DcrError(f"Failed to register client with {issuer}: {e}") # Encrypt secrets - encryption = get_oauth_encryption(self.settings.auth_encryption_secret) + encryption = get_fernet_encryption(self.settings.auth_encryption_secret) client_secret = registration_response.get("client_secret") client_secret_encrypted = encryption.encrypt_secret(client_secret) if client_secret else None @@ -260,7 +260,7 @@ async def update_client_registration(self, client_record: RegisteredOAuthClient, raise DcrError("Cannot update client: no registration_access_token available") # Decrypt registration access token - encryption = get_oauth_encryption(self.settings.auth_encryption_secret) + encryption = get_fernet_encryption(self.settings.auth_encryption_secret) registration_access_token = encryption.decrypt_secret(client_record.registration_access_token_encrypted) # Build update request @@ -313,7 +313,7 @@ async def delete_client_registration(self, client_record: RegisteredOAuthClient, return True # Consider it deleted locally # Decrypt registration access token - encryption = get_oauth_encryption(self.settings.auth_encryption_secret) + encryption = get_fernet_encryption(self.settings.auth_encryption_secret) registration_access_token = encryption.decrypt_secret(client_record.registration_access_token_encrypted) # Send delete request diff --git a/mcpgateway/services/oauth_manager.py b/mcpgateway/services/oauth_manager.py index 1e755af00..304bec339 100644 --- a/mcpgateway/services/oauth_manager.py +++ b/mcpgateway/services/oauth_manager.py @@ -28,7 +28,7 @@ # First-Party from mcpgateway.config import get_settings -from mcpgateway.utils.oauth_encryption import get_oauth_encryption +from mcpgateway.utils.fernet_encryption import get_fernet_encryption logger = logging.getLogger(__name__) @@ -222,7 +222,7 @@ async def _client_credentials_flow(self, credentials: Dict[str, Any]) -> str: if len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer try: settings = get_settings() - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_fernet_encryption(settings.auth_encryption_secret) decrypted_secret = encryption.decrypt_secret(client_secret) if decrypted_secret: client_secret = decrypted_secret @@ -313,7 +313,7 @@ async def _password_flow(self, credentials: Dict[str, Any]) -> str: if client_secret and len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer try: settings = get_settings() - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_fernet_encryption(settings.auth_encryption_secret) decrypted_secret = encryption.decrypt_secret(client_secret) if decrypted_secret: client_secret = decrypted_secret @@ -430,7 +430,7 @@ async def exchange_code_for_token(self, credentials: Dict[str, Any], code: str, if client_secret and len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer try: settings = get_settings() - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_fernet_encryption(settings.auth_encryption_secret) decrypted_secret = encryption.decrypt_secret(client_secret) if decrypted_secret: client_secret = decrypted_secret @@ -1007,7 +1007,7 @@ async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str if client_secret and len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer try: settings = get_settings() - encryption = get_oauth_encryption(settings.auth_encryption_secret) + encryption = get_fernet_encryption(settings.auth_encryption_secret) decrypted_secret = encryption.decrypt_secret(client_secret) if decrypted_secret: client_secret = decrypted_secret diff --git a/mcpgateway/services/token_storage_service.py b/mcpgateway/services/token_storage_service.py index da441c7b7..8c85de7df 100644 --- a/mcpgateway/services/token_storage_service.py +++ b/mcpgateway/services/token_storage_service.py @@ -23,7 +23,7 @@ from mcpgateway.config import get_settings from mcpgateway.db import OAuthToken from mcpgateway.services.oauth_manager import OAuthError -from mcpgateway.utils.oauth_encryption import get_oauth_encryption +from mcpgateway.utils.fernet_encryption import get_fernet_encryption logger = logging.getLogger(__name__) @@ -68,7 +68,7 @@ def __init__(self, db: Session): self.db = db try: settings = get_settings() - self.encryption = get_oauth_encryption(settings.auth_encryption_secret) + self.encryption = get_fernet_encryption(settings.auth_encryption_secret) except (ImportError, AttributeError): logger.warning("OAuth encryption not available, using plain text storage") self.encryption = None diff --git a/mcpgateway/utils/oauth_encryption.py b/mcpgateway/utils/fernet_encryption.py similarity index 75% rename from mcpgateway/utils/oauth_encryption.py rename to mcpgateway/utils/fernet_encryption.py index de5295828..470607c99 100644 --- a/mcpgateway/utils/oauth_encryption.py +++ b/mcpgateway/utils/fernet_encryption.py @@ -4,9 +4,9 @@ SPDX-License-Identifier: Apache-2.0 Authors: Mihai Criveti -OAuth Encryption Utilities. +Fernet Encryption Utilities. -This module provides encryption and decryption functions for OAuth client secrets +This module provides encryption and decryption functions for client secrets using the AUTH_ENCRYPTION_SECRET from configuration. """ @@ -17,23 +17,23 @@ import os from typing import Optional -# First-Party -from mcpgateway.config import settings - # Third-Party -from cryptography.fernet import Fernet from argon2.low_level import hash_secret_raw, Type +from cryptography.fernet import Fernet from pydantic import SecretStr +# First-Party +from mcpgateway.config import settings + logger = logging.getLogger(__name__) -class OAuthEncryption: - """Handles encryption and decryption of OAuth client secrets. +class FernetEncryption: + """Handles Fernet encryption and decryption of client secrets. Examples: Basic roundtrip: - >>> enc = OAuthEncryption(SecretStr('very-secret-key')) + >>> enc = FernetEncryption(SecretStr('very-secret-key')) >>> cipher = enc.encrypt_secret('hello') >>> isinstance(cipher, str) and enc.is_encrypted(cipher) True @@ -50,6 +50,11 @@ def __init__(self, encryption_secret: SecretStr, time_cost: Optional[int] = None Args: encryption_secret: Secret key for encryption/decryption + time_cost: Argon2id time cost parameter + memory_cost: Argon2id memory cost parameter (in KiB) + parallelism: Argon2id parallelism parameter + hash_len: Length of the derived key + salt_len: Length of the salt """ self.encryption_secret = encryption_secret.get_secret_value().encode() self.time_cost = time_cost or getattr(settings, "argon2id_time_cost", 3) @@ -57,7 +62,6 @@ def __init__(self, encryption_secret: SecretStr, time_cost: Optional[int] = None self.parallelism = parallelism or getattr(settings, "argon2id_parallelism", 1) self.hash_len = hash_len self.salt_len = salt_len - def derive_key_argon2id(self, passphrase: bytes, salt: bytes, time_cost: int, memory_cost: int, parallelism: int) -> bytes: raw = hash_secret_raw( @@ -88,16 +92,18 @@ def encrypt_secret(self, plaintext: str) -> str: key = self.derive_key_argon2id(self.encryption_secret, salt, self.time_cost, self.memory_cost, self.parallelism) fernet = Fernet(key) encrypted = fernet.encrypt(plaintext.encode()) - return json.dumps({ - "kdf": "argon2id", - "t": self.time_cost, - "m": self.memory_cost, - "p": self.parallelism, - "salt": base64.b64encode(salt).decode(), - "token": encrypted.decode(), - }) + return json.dumps( + { + "kdf": "argon2id", + "t": self.time_cost, + "m": self.memory_cost, + "p": self.parallelism, + "salt": base64.b64encode(salt).decode(), + "token": encrypted.decode(), + } + ) except Exception as e: - logger.error(f"Failed to encrypt OAuth secret: {e}") + logger.error(f"Failed to encrypt secret: {e}") raise def decrypt_secret(self, bundle_json: str) -> Optional[str]: @@ -117,7 +123,7 @@ def decrypt_secret(self, bundle_json: str) -> Optional[str]: decrypted = fernet.decrypt(b["token"].encode()) return decrypted.decode() except Exception as e: - logger.error(f"Failed to decrypt OAuth secret: {e}") + logger.error(f"Failed to decrypt secret: {e}") return None def is_encrypted(self, text: str) -> bool: @@ -138,18 +144,18 @@ def is_encrypted(self, text: str) -> bool: return False -def get_oauth_encryption(encryption_secret: SecretStr) -> OAuthEncryption: - """Get an OAuth encryption instance. +def get_fernet_encryption(encryption_secret: SecretStr) -> FernetEncryption: + """Get an Fernet encryption instance. Args: encryption_secret: Secret key for encryption/decryption Returns: - OAuthEncryption instance + FernetEncryption instance Examples: - >>> enc = get_oauth_encryption(SecretStr('k')) - >>> isinstance(enc, OAuthEncryption) + >>> enc = get_fernet_encryption(SecretStr('k')) + >>> isinstance(enc, FernetEncryption) True """ - return OAuthEncryption(encryption_secret) + return FernetEncryption(encryption_secret) diff --git a/tests/unit/mcpgateway/services/test_dcr_service.py b/tests/unit/mcpgateway/services/test_dcr_service.py index 9f493d103..379a0e558 100644 --- a/tests/unit/mcpgateway/services/test_dcr_service.py +++ b/tests/unit/mcpgateway/services/test_dcr_service.py @@ -429,7 +429,7 @@ class TestUpdateClientRegistration: @pytest.mark.asyncio async def test_update_client_registration_success(self, test_db): """Test successful client registration update.""" - from mcpgateway.utils.oauth_encryption import get_oauth_encryption + from mcpgateway.utils.fernet_encryption import get_fernet_encryption from mcpgateway.config import get_settings dcr_service = DcrService() @@ -442,7 +442,7 @@ async def test_update_client_registration_success(self, test_db): test_db.commit() # Encrypt the registration access token properly - encryption = get_oauth_encryption(get_settings().auth_encryption_secret) + encryption = get_fernet_encryption(get_settings().auth_encryption_secret) encrypted_token = encryption.encrypt_secret("registration-access-token") client_record = RegisteredOAuthClient( @@ -474,7 +474,7 @@ async def test_update_client_registration_success(self, test_db): @pytest.mark.asyncio async def test_update_client_registration_uses_access_token(self, test_db): """Test that update uses registration_access_token.""" - from mcpgateway.utils.oauth_encryption import get_oauth_encryption + from mcpgateway.utils.fernet_encryption import get_fernet_encryption from mcpgateway.config import get_settings dcr_service = DcrService() @@ -487,7 +487,7 @@ async def test_update_client_registration_uses_access_token(self, test_db): test_db.commit() # Encrypt the registration access token properly - encryption = get_oauth_encryption(get_settings().auth_encryption_secret) + encryption = get_fernet_encryption(get_settings().auth_encryption_secret) encrypted_token = encryption.encrypt_secret("registration-access-token") client_record = RegisteredOAuthClient( diff --git a/tests/unit/mcpgateway/test_admin.py b/tests/unit/mcpgateway/test_admin.py index 982206152..071cdab0d 100644 --- a/tests/unit/mcpgateway/test_admin.py +++ b/tests/unit/mcpgateway/test_admin.py @@ -2114,7 +2114,7 @@ async def test_admin_add_gateway_with_oauth_config(self, mock_register_gateway, mock_request.form = AsyncMock(return_value=form_data) # Mock OAuth encryption - with patch("mcpgateway.admin.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.admin.get_fernet_encryption") as mock_get_encryption: mock_encryption = MagicMock() mock_encryption.encrypt_secret.return_value = "encrypted-secret" mock_get_encryption.return_value = mock_encryption @@ -2175,7 +2175,7 @@ async def test_admin_edit_gateway_with_oauth_config(self, mock_update_gateway, m mock_request.form = AsyncMock(return_value=form_data) # Mock OAuth encryption - with patch("mcpgateway.admin.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.admin.get_fernet_encryption") as mock_get_encryption: mock_encryption = MagicMock() mock_encryption.encrypt_secret.return_value = "encrypted-edit-secret" mock_get_encryption.return_value = mock_encryption @@ -2204,7 +2204,7 @@ async def test_admin_edit_gateway_oauth_empty_client_secret(self, mock_update_ga mock_request.form = AsyncMock(return_value=form_data) # Mock OAuth encryption - should not be called for empty secret - with patch("mcpgateway.admin.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.admin.get_fernet_encryption") as mock_get_encryption: mock_encryption = MagicMock() mock_get_encryption.return_value = mock_encryption diff --git a/tests/unit/mcpgateway/test_oauth_manager.py b/tests/unit/mcpgateway/test_oauth_manager.py index 802227c2a..769b2d043 100644 --- a/tests/unit/mcpgateway/test_oauth_manager.py +++ b/tests/unit/mcpgateway/test_oauth_manager.py @@ -20,7 +20,7 @@ from mcpgateway.db import OAuthToken from mcpgateway.services.oauth_manager import OAuthError, OAuthManager from mcpgateway.services.token_storage_service import TokenStorageService -from mcpgateway.utils.oauth_encryption import OAuthEncryption +from mcpgateway.utils.fernet_encryption import FernetEncryption class TestOAuthManager: @@ -298,7 +298,7 @@ async def test_client_credentials_flow_with_encrypted_secret(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_fernet_encryption") as mock_get_encryption: mock_encryption = Mock() mock_encryption.decrypt_secret.return_value = "decrypted_secret" mock_get_encryption.return_value = mock_encryption @@ -371,7 +371,7 @@ async def test_client_credentials_flow_decryption_returns_none(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_fernet_encryption") as mock_get_encryption: mock_encryption = Mock() # Decryption returns None - line 108 mock_encryption.decrypt_secret.return_value = None @@ -890,7 +890,7 @@ async def test_exchange_code_for_tokens_decryption_returns_none(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_fernet_encryption") as mock_get_encryption: mock_encryption = Mock() # Decryption returns None - lines 438-439 mock_encryption.decrypt_secret.return_value = None @@ -981,7 +981,7 @@ async def test_exchange_code_for_token_decryption_returns_none(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_fernet_encryption") as mock_get_encryption: mock_encryption = Mock() # Decryption returns None - lines 216-217 mock_encryption.decrypt_secret.return_value = None @@ -1260,7 +1260,7 @@ async def test_exchange_code_for_tokens_decryption_success(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_fernet_encryption") as mock_get_encryption: mock_encryption = Mock() # Decryption succeeds - lines 435-437 mock_encryption.decrypt_secret.return_value = "decrypted_secret" @@ -1301,7 +1301,7 @@ async def test_exchange_code_for_tokens_decryption_exception(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_fernet_encryption") as mock_get_encryption: mock_encryption = Mock() # Decryption throws exception - lines 440-441 mock_encryption.decrypt_secret.side_effect = ValueError("Decryption failed") @@ -1454,7 +1454,7 @@ async def test_exchange_code_for_token_decryption_success(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_fernet_encryption") as mock_get_encryption: mock_encryption = Mock() # Decryption succeeds - lines 213-215 mock_encryption.decrypt_secret.return_value = "decrypted_secret" @@ -1495,7 +1495,7 @@ async def test_exchange_code_for_token_decryption_exception(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_oauth_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_fernet_encryption") as mock_get_encryption: mock_encryption = Mock() # Decryption throws exception - lines 218-219 mock_encryption.decrypt_secret.side_effect = ValueError("Decryption failed") @@ -1661,7 +1661,7 @@ def test_init_with_encryption(self): mock_settings.auth_encryption_secret = "test_secret_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.token_storage_service.get_oauth_encryption") as mock_get_enc: + with patch("mcpgateway.services.token_storage_service.get_fernet_encryption") as mock_get_enc: mock_encryption = Mock() mock_get_enc.return_value = mock_encryption @@ -1709,7 +1709,7 @@ async def test_store_tokens_new_record_with_encryption(self): mock_settings.auth_encryption_secret = "test_secret" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.token_storage_service.get_oauth_encryption") as mock_get_enc: + with patch("mcpgateway.services.token_storage_service.get_fernet_encryption") as mock_get_enc: mock_get_enc.return_value = mock_encryption service = TokenStorageService(mock_db) @@ -1810,7 +1810,7 @@ async def test_store_tokens_update_existing_record(self): mock_settings.auth_encryption_secret = "test_secret" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.token_storage_service.get_oauth_encryption") as mock_get_enc: + with patch("mcpgateway.services.token_storage_service.get_fernet_encryption") as mock_get_enc: mock_get_enc.return_value = mock_encryption service = TokenStorageService(mock_db) @@ -1854,7 +1854,7 @@ async def test_store_tokens_without_refresh_token(self): mock_settings.auth_encryption_secret = "test_secret" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.token_storage_service.get_oauth_encryption") as mock_get_enc: + with patch("mcpgateway.services.token_storage_service.get_fernet_encryption") as mock_get_enc: mock_get_enc.return_value = mock_encryption service = TokenStorageService(mock_db) @@ -1911,7 +1911,7 @@ async def test_get_valid_token_success_with_encryption(self): mock_settings.auth_encryption_secret = "test_secret" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.token_storage_service.get_oauth_encryption") as mock_get_enc: + with patch("mcpgateway.services.token_storage_service.get_fernet_encryption") as mock_get_enc: mock_get_enc.return_value = mock_encryption service = TokenStorageService(mock_db) @@ -2442,29 +2442,17 @@ async def test_cleanup_expired_tokens_exception(self): mock_db.rollback.assert_called_once() -class TestOAuthEncryption: - """Test cases for OAuthEncryption class.""" +class TestFernetEncryption: + """Test cases for FernetEncryption class.""" def test_init(self): - """Test OAuthEncryption initialization.""" - encryption = OAuthEncryption(SecretStr("test_secret_key")) + """Test FernetEncryption initialization.""" + encryption = FernetEncryption(SecretStr("test_secret_key")) assert encryption.encryption_secret == b"test_secret_key" - assert encryption._fernet is None - - def test_get_fernet_creates_instance(self): - """Test _get_fernet creates Fernet instance on first call.""" - encryption = OAuthEncryption(SecretStr("test_secret_key")) - - fernet1 = encryption._get_fernet() - fernet2 = encryption._get_fernet() - - # Should return same instance (cached) - assert fernet1 is fernet2 - assert encryption._fernet is not None def test_encrypt_secret_success(self): """Test successful secret encryption.""" - encryption = OAuthEncryption(SecretStr("test_secret_key")) + encryption = FernetEncryption(SecretStr("test_secret_key")) plaintext = "my_secret_token_123" encrypted = encryption.encrypt_secret(plaintext) @@ -2479,8 +2467,8 @@ def test_encrypt_secret_success(self): def test_encrypt_secret_different_keys_different_output(self): """Test that different keys produce different encrypted output.""" - encryption1 = OAuthEncryption(SecretStr("key1")) - encryption2 = OAuthEncryption(SecretStr("key2")) + encryption1 = FernetEncryption(SecretStr("key1")) + encryption2 = FernetEncryption(SecretStr("key2")) plaintext = "same_secret" encrypted1 = encryption1.encrypt_secret(plaintext) @@ -2491,7 +2479,7 @@ def test_encrypt_secret_different_keys_different_output(self): def test_encrypt_secret_same_key_different_output(self): """Test that same key produces different encrypted output due to nonce.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = FernetEncryption(SecretStr("test_key")) plaintext = "same_secret" encrypted1 = encryption.encrypt_secret(plaintext) @@ -2506,7 +2494,7 @@ def test_encrypt_secret_same_key_different_output(self): def test_encrypt_secret_empty_string(self): """Test encrypting empty string.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = FernetEncryption(SecretStr("test_key")) encrypted = encryption.encrypt_secret("") decrypted = encryption.decrypt_secret(encrypted) @@ -2515,7 +2503,7 @@ def test_encrypt_secret_empty_string(self): def test_encrypt_secret_unicode_characters(self): """Test encrypting string with unicode characters.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = FernetEncryption(SecretStr("test_key")) plaintext = "šŸ” secret with Ć©mojis and spĆ©ciĆ l chars Ʊ" encrypted = encryption.encrypt_secret(plaintext) @@ -2525,20 +2513,15 @@ def test_encrypt_secret_unicode_characters(self): def test_encrypt_secret_exception_handling(self): """Test exception handling in encrypt_secret.""" - encryption = OAuthEncryption(SecretStr("test_key")) - - # Mock the Fernet instance to raise an exception - with patch.object(encryption, "_get_fernet") as mock_get_fernet: - mock_fernet = Mock() - mock_fernet.encrypt.side_effect = Exception("Encryption failed") - mock_get_fernet.return_value = mock_fernet + encryption = FernetEncryption(SecretStr("test_key")) + with patch.object(encryption, "derive_key_argon2id", side_effect=Exception("Encryption failed")): with pytest.raises(Exception, match="Encryption failed"): encryption.encrypt_secret("test") def test_decrypt_secret_success(self): """Test successful secret decryption.""" - encryption = OAuthEncryption(SecretStr("test_secret_key")) + encryption = FernetEncryption(SecretStr("test_secret_key")) plaintext = "original_secret" # First encrypt @@ -2551,7 +2534,7 @@ def test_decrypt_secret_success(self): def test_decrypt_secret_invalid_data(self): """Test decryption with invalid encrypted data.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = FernetEncryption(SecretStr("test_key")) result = encryption.decrypt_secret("invalid_encrypted_data") @@ -2559,8 +2542,8 @@ def test_decrypt_secret_invalid_data(self): def test_decrypt_secret_wrong_key(self): """Test decryption with wrong key.""" - encryption1 = OAuthEncryption(SecretStr("key1")) - encryption2 = OAuthEncryption(SecretStr("key2")) + encryption1 = FernetEncryption(SecretStr("key1")) + encryption2 = FernetEncryption(SecretStr("key2")) # Encrypt with one key encrypted = encryption1.encrypt_secret("secret") @@ -2572,7 +2555,7 @@ def test_decrypt_secret_wrong_key(self): def test_decrypt_secret_corrupted_data(self): """Test decryption with corrupted base64 data.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = FernetEncryption(SecretStr("test_key")) # Create valid encrypted data then corrupt it encrypted = encryption.encrypt_secret("test") @@ -2584,7 +2567,7 @@ def test_decrypt_secret_corrupted_data(self): def test_decrypt_secret_malformed_base64(self): """Test decryption with malformed base64.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = FernetEncryption(SecretStr("test_key")) result = encryption.decrypt_secret("not_valid_base64!@#") @@ -2592,7 +2575,7 @@ def test_decrypt_secret_malformed_base64(self): def test_decrypt_secret_empty_string(self): """Test decryption with empty string.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = FernetEncryption(SecretStr("test_key")) result = encryption.decrypt_secret("") @@ -2600,7 +2583,7 @@ def test_decrypt_secret_empty_string(self): def test_is_encrypted_valid_encrypted_data(self): """Test is_encrypted with valid encrypted data.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = FernetEncryption(SecretStr("test_key")) encrypted = encryption.encrypt_secret("test_data") @@ -2608,14 +2591,14 @@ def test_is_encrypted_valid_encrypted_data(self): def test_is_encrypted_plain_text(self): """Test is_encrypted with plain text.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = FernetEncryption(SecretStr("test_key")) assert encryption.is_encrypted("plain_text_secret") is False assert encryption.is_encrypted("another_plain_string") is False def test_is_encrypted_short_data(self): """Test is_encrypted with short data.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = FernetEncryption(SecretStr("test_key")) # Fernet encrypted data should be at least 32 bytes short_data = "dGVzdA==" # "test" in base64 (only 4 bytes when decoded) @@ -2624,7 +2607,7 @@ def test_is_encrypted_short_data(self): def test_is_encrypted_valid_base64_but_not_encrypted(self): """Test is_encrypted with valid base64 that's not encrypted data.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = FernetEncryption(SecretStr("test_key")) # Create base64 data that's long enough but not encrypted # Standard @@ -2641,32 +2624,32 @@ def test_is_encrypted_valid_base64_but_not_encrypted(self): def test_is_encrypted_invalid_base64(self): """Test is_encrypted with invalid base64.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = FernetEncryption(SecretStr("test_key")) assert encryption.is_encrypted("not_base64!@#$%") is False def test_is_encrypted_exception_handling(self): """Test exception handling in is_encrypted.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = FernetEncryption(SecretStr("test_key")) # Test with None (should handle gracefully) with patch("base64.urlsafe_b64decode", side_effect=Exception("Base64 error")): result = encryption.is_encrypted("any_string") assert result is False - def test_get_oauth_encryption_function(self): - """Test the get_oauth_encryption utility function.""" + def test_get_fernet_encryption_function(self): + """Test the get_fernet_encryption utility function.""" # First-Party - from mcpgateway.utils.oauth_encryption import get_oauth_encryption + from mcpgateway.utils.fernet_encryption import get_fernet_encryption - encryption = get_oauth_encryption(SecretStr("test_secret")) + encryption = get_fernet_encryption(SecretStr("test_secret")) - assert isinstance(encryption, OAuthEncryption) + assert isinstance(encryption, FernetEncryption) assert encryption.encryption_secret == b"test_secret" def test_encryption_roundtrip_multiple_values(self): """Test encryption/decryption roundtrip with multiple values.""" - encryption = OAuthEncryption(SecretStr("test_key")) + encryption = FernetEncryption(SecretStr("test_key")) test_values = [ "simple_token", @@ -2687,8 +2670,8 @@ def test_encryption_roundtrip_multiple_values(self): def test_encryption_key_derivation_consistency(self): """Test that key derivation is consistent across instances.""" # Create two instances with same key - encryption1 = OAuthEncryption(SecretStr("same_key")) - encryption2 = OAuthEncryption(SecretStr("same_key")) + encryption1 = FernetEncryption(SecretStr("same_key")) + encryption2 = FernetEncryption(SecretStr("same_key")) # Encrypt with first instance plaintext = "test_consistency" @@ -2702,7 +2685,7 @@ def test_encryption_key_derivation_consistency(self): def test_encryption_with_long_key(self): """Test encryption with very long key.""" long_key = SecretStr("a" * 1000) # Very long key - encryption = OAuthEncryption(long_key) + encryption = FernetEncryption(long_key) encrypted = encryption.encrypt_secret("test_data") decrypted = encryption.decrypt_secret(encrypted) @@ -2712,25 +2695,9 @@ def test_encryption_with_long_key(self): def test_encryption_with_special_char_key(self): """Test encryption with key containing special characters.""" special_key = SecretStr("key_with_special_chars!@#$%^&*()_+-={}[]|\\:;\"'<>?,./") - encryption = OAuthEncryption(special_key) + encryption = FernetEncryption(special_key) encrypted = encryption.encrypt_secret("test_data") decrypted = encryption.decrypt_secret(encrypted) assert decrypted == "test_data" - - def test_fernet_instance_caching(self): - """Test that Fernet instance is properly cached.""" - encryption = OAuthEncryption(SecretStr("test_key")) - - # First call should create instance - assert encryption._fernet is None - fernet1 = encryption._get_fernet() - assert encryption._fernet is not None - - # Subsequent calls should return cached instance - fernet2 = encryption._get_fernet() - fernet3 = encryption._get_fernet() - - assert fernet1 is fernet2 - assert fernet2 is fernet3 From 13e01d3bf6d308bbbd815ad7229f0c7ac8eb6244 Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Thu, 30 Oct 2025 11:28:01 +0530 Subject: [PATCH 04/14] Add docstring Signed-off-by: Madhav Kandukuri --- mcpgateway/utils/fernet_encryption.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mcpgateway/utils/fernet_encryption.py b/mcpgateway/utils/fernet_encryption.py index 470607c99..f139e054a 100644 --- a/mcpgateway/utils/fernet_encryption.py +++ b/mcpgateway/utils/fernet_encryption.py @@ -64,6 +64,18 @@ def __init__(self, encryption_secret: SecretStr, time_cost: Optional[int] = None self.salt_len = salt_len def derive_key_argon2id(self, passphrase: bytes, salt: bytes, time_cost: int, memory_cost: int, parallelism: int) -> bytes: + """Derive a key from a passphrase using Argon2id. + + Args: + passphrase: The passphrase to derive the key from + salt: The salt to use in key derivation + time_cost: Argon2id time cost parameter + memory_cost: Argon2id memory cost parameter (in KiB) + parallelism: Argon2id parallelism parameter + + Returns: + The derived key + """ raw = hash_secret_raw( secret=passphrase, salt=salt, From 7c3cccc8a123629d61f157147de681138b081a8e Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Thu, 30 Oct 2025 16:09:51 +0530 Subject: [PATCH 05/14] Make sso_service use fernet_encryption utl Signed-off-by: Madhav Kandukuri --- ...3320c56_use_argon2id_for_encryption_key.py | 90 +++++++++++++++++++ mcpgateway/services/sso_service.py | 51 ++--------- 2 files changed, 99 insertions(+), 42 deletions(-) create mode 100644 mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py diff --git a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py new file mode 100644 index 000000000..1675ca7c0 --- /dev/null +++ b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py @@ -0,0 +1,90 @@ +"""Use Argon2id for encryption key + +Revision ID: a706a3320c56 +Revises: 3c89a45f32e5 +Create Date: 2025-10-30 15:31:25.115536 + +""" +import base64 +import json +import os +from typing import Sequence, Union + +from mcpgateway.config import settings + +from alembic import op +import sqlalchemy as sa +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from argon2.low_level import hash_secret_raw, Type + +def reencrypt_with_argon2id(encrypted_text: str) -> str: + """Re-encrypts an existing encrypted text using Argon2id KDF. + + Args: + encrypted_text: The original encrypted text using PBKDF2HMAC. + + Returns: + A JSON string containing the Argon2id re-encrypted token and parameters. + """ + encryption_secret = settings.auth_encryption_secret.get_secret_value().encode() + original_kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=b"mcp_gateway_oauth", # Fixed salt for consistency + iterations=100000, + ) + original_key = base64.urlsafe_b64encode(original_kdf.derive(encryption_secret)) + original_fernet = Fernet(original_key) + original_encrypted_bytes = base64.urlsafe_b64decode(encrypted_text.encode()) + original_decrypted_bytes = original_fernet.decrypt(original_encrypted_bytes) + + time_cost = getattr(settings, "argon2id_time_cost", 3) + memory_cost = getattr(settings, "argon2id_memory_cost", 65536) + parallelism = getattr(settings, "argon2id_parallelism", 1) + hash_len = 32 + + salt = os.urandom(16) + argon2id_raw = hash_secret_raw( + secret=encryption_secret, + salt=salt, + time_cost=time_cost, + memory_cost=memory_cost, # KiB + parallelism=parallelism, + hash_len=hash_len, + type=Type.ID, + ) + argon2id_key = base64.urlsafe_b64encode(argon2id_raw) + argon2id_fernet = Fernet(argon2id_key) + argon2id_encrypted_bytes = argon2id_fernet.encrypt(original_decrypted_bytes) + return json.dumps( + { + "kdf": "argon2id", + "t": time_cost, + "m": memory_cost, + "p": parallelism, + "salt": base64.b64encode(salt).decode(), + "token": argon2id_encrypted_bytes.decode(), + } + ) + +# revision identifiers, used by Alembic. +revision: str = 'a706a3320c56' +down_revision: Union[str, Sequence[str], None] = '3c89a45f32e5' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### diff --git a/mcpgateway/services/sso_service.py b/mcpgateway/services/sso_service.py index 032b22bf7..d9400f4c8 100644 --- a/mcpgateway/services/sso_service.py +++ b/mcpgateway/services/sso_service.py @@ -22,11 +22,7 @@ import urllib.parse # Third-Party -from cryptography.fernet import Fernet -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC import httpx -from pydantic import SecretStr from sqlalchemy import and_, select from sqlalchemy.orm import Session @@ -35,6 +31,7 @@ from mcpgateway.db import PendingUserApproval, SSOAuthSession, SSOProvider, utc_now from mcpgateway.services.email_auth_service import EmailAuthService from mcpgateway.utils.create_jwt_token import create_jwt_token +from mcpgateway.utils.fernet_encryption import get_fernet_encryption # Logger logger = logging.getLogger(__name__) @@ -64,39 +61,7 @@ def __init__(self, db: Session): """ self.db = db self.auth_service = EmailAuthService(db) - self._encryption_key = self._get_or_create_encryption_key() - - def _get_or_create_encryption_key(self) -> bytes: - """Get or create encryption key for client secrets. - - Returns: - Encryption key bytes - """ - # Use the same encryption secret as the auth service - key = settings.auth_encryption_secret - - if not key: - # Generate a new key - in production, this should be persisted - key = Fernet.generate_key() - # Derive a proper Fernet key from the secret - - # Unwrap SecretStr if necessary - if isinstance(key, SecretStr): - key = key.get_secret_value() - - # Convert string to bytes - if isinstance(key, str): - key = key.encode("utf-8") - - # Derive a 32-byte key using PBKDF2 - kdf = PBKDF2HMAC( - algorithm=hashes.SHA256(), - length=32, - salt=b"sso_salt", # Static salt for consistency - iterations=100000, - ) - derived_key = base64.urlsafe_b64encode(kdf.derive(key)) - return derived_key + self._encryption = get_fernet_encryption(settings.auth_encryption_secret) def _encrypt_secret(self, secret: str) -> str: """Encrypt a client secret for secure storage. @@ -107,10 +72,9 @@ def _encrypt_secret(self, secret: str) -> str: Returns: Encrypted secret string """ - fernet = Fernet(self._encryption_key) - return fernet.encrypt(secret.encode()).decode() + return self._encryption.encrypt_secret(secret) - def _decrypt_secret(self, encrypted_secret: str) -> str: + def _decrypt_secret(self, encrypted_secret: str) -> Optional[str]: """Decrypt a client secret for use. Args: @@ -119,8 +83,11 @@ def _decrypt_secret(self, encrypted_secret: str) -> str: Returns: Plain text client secret """ - fernet = Fernet(self._encryption_key) - return fernet.decrypt(encrypted_secret.encode()).decode() + decrypted: str | None = self._encryption.decrypt_secret(encrypted_secret) + if decrypted: + return decrypted + + return None def list_enabled_providers(self) -> List[SSOProvider]: """Get list of enabled SSO providers. From e0829ca7c308b38d52b3fce9c7f1873b8384064f Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Thu, 30 Oct 2025 16:24:22 +0530 Subject: [PATCH 06/14] wip migration script Signed-off-by: Madhav Kandukuri --- ...3320c56_use_argon2id_for_encryption_key.py | 326 +++++++++++++++++- 1 file changed, 312 insertions(+), 14 deletions(-) diff --git a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py index 1675ca7c0..7e471bf17 100644 --- a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py +++ b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py @@ -7,8 +7,9 @@ """ import base64 import json +import logging import os -from typing import Sequence, Union +from typing import Sequence, Union, Optional from mcpgateway.config import settings @@ -19,6 +20,14 @@ from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from argon2.low_level import hash_secret_raw, Type +logger = logging.getLogger(__name__) + +# revision identifiers, used by Alembic. +revision: str = 'a706a3320c56' +down_revision: Union[str, Sequence[str], None] = '3c89a45f32e5' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + def reencrypt_with_argon2id(encrypted_text: str) -> str: """Re-encrypts an existing encrypted text using Argon2id KDF. @@ -69,22 +78,311 @@ def reencrypt_with_argon2id(encrypted_text: str) -> str: } ) -# revision identifiers, used by Alembic. -revision: str = 'a706a3320c56' -down_revision: Union[str, Sequence[str], None] = '3c89a45f32e5' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +def reencrypt_with_pbkdf2hmac(argon2id_bundle: str) -> Optional[str]: + """Re-encrypts an Argon2id encrypted bundle back to PBKDF2HMAC. + + Args: + argon2id_bundle: The JSON string containing Argon2id encrypted data. + + Returns: + A PBKDF2HMAC re-encrypted token. + """ + try: + argon2id_data = json.loads(argon2id_bundle) + if argon2id_data.get("kdf") != "argon2id": + raise ValueError("Not an Argon2id bundle") + + encryption_secret = settings.auth_encryption_secret.get_secret_value().encode() + salt = base64.b64decode(argon2id_data["salt"]) + time_cost = argon2id_data["t"] + memory_cost = argon2id_data["m"] + parallelism = argon2id_data["p"] + argon2id_raw = hash_secret_raw( + secret=encryption_secret, + salt=salt, + time_cost=time_cost, + memory_cost=memory_cost, # KiB + parallelism=parallelism, + hash_len=32, + type=Type.ID, + ) + argon2id_key = base64.urlsafe_b64encode(argon2id_raw) + argon2id_fernet = Fernet(argon2id_key) + argon2id_encrypted_bytes = argon2id_data["token"].encode() + decrypted_bytes = argon2id_fernet.decrypt(argon2id_encrypted_bytes) + + original_kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=b"mcp_gateway_oauth", # Fixed salt for consistency + iterations=100000, + ) + original_key = base64.urlsafe_b64encode(original_kdf.derive(encryption_secret)) + original_fernet = Fernet(original_key) + original_encrypted_bytes = original_fernet.encrypt(decrypted_bytes) + return base64.urlsafe_b64encode(original_encrypted_bytes).decode() + except Exception as e: + raise ValueError("Invalid Argon2id bundle") from e + +def _looks_argon2_bundle(val: Optional[str]) -> bool: + if not val: + return False + # Fast path: Fernet tokens usually start with 'gAAAAA'; Argon2 bundle is JSON + if val and val[:1] in ('{', '['): + try: + obj = json.loads(val) + return isinstance(obj, dict) and obj.get("kdf") == "argon2id" + except Exception: + return False + return False + +def _looks_legacy_pbkdf2_token(val: Optional[str]) -> bool: + """Heuristic for legacy PBKDF2 format (base64-wrapped Fernet token string, not JSON).""" + if not val or not isinstance(val, str): + return False + # Legacy column stored base64(urlsafe) of the Fernet token (which is itself base64 bytes), + # so it's NOT JSON and usually not starting with '{' + return not val.startswith("{") + +def _upgrade_value(old: Optional[str]) -> Optional[str]: + """PBKDF2 -> Argon2id bundle, when needed.""" + if not old: + return None + if _looks_argon2_bundle(old): + return None # already migrated + if not _looks_legacy_pbkdf2_token(old): + return None # unknown format; skip + try: + return reencrypt_with_argon2id(old) + except Exception as e: + logger.warning("Upgrade skip (cannot re-encrypt PBKDF2 value): %s", e) + return None + + +def _downgrade_value(old: Optional[str]) -> Optional[str]: + """Argon2id bundle -> PBKDF2 legacy, when needed.""" + if not old: + return None + if not _looks_argon2_bundle(old): + return None # not an Argon2 bundle + try: + return reencrypt_with_pbkdf2hmac(old) + except Exception as e: + logger.warning("Downgrade skip (cannot re-encrypt Argon2 bundle): %s", e) + return None + + +def _upgrade_json_client_secret(bind, table: str) -> None: + rows = bind.execute(text(f""" + SELECT id, oauth_config + FROM {table} + WHERE oauth_config IS NOT NULL + """)).mappings().all() + + for r in rows: + rid = r["id"] + cfg_raw = r["oauth_config"] + try: + cfg = cfg_raw if isinstance(cfg_raw, dict) else json.loads(cfg_raw) + except Exception: + logger.warning("%s.id=%s: oauth_config not JSON, skipping", table, rid) + continue + + secret = cfg.get("client_secret") + new_secret = _upgrade_value(secret) + if new_secret: + cfg["client_secret"] = new_secret + bind.execute( + text(f"UPDATE {table} SET oauth_config = :cfg WHERE id = :id"), + {"cfg": json.dumps(cfg), "id": rid}, + ) + + +def _downgrade_json_client_secret(bind, table: str) -> None: + rows = bind.execute(text(f""" + SELECT id, oauth_config + FROM {table} + WHERE oauth_config IS NOT NULL + """)).mappings().all() + + for r in rows: + rid = r["id"] + cfg_raw = r["oauth_config"] + try: + cfg = cfg_raw if isinstance(cfg_raw, dict) else json.loads(cfg_raw) + except Exception: + logger.warning("%s.id=%s: oauth_config not JSON, skipping", table, rid) + continue + + secret = cfg.get("client_secret") + new_secret = _downgrade_value(secret) + if new_secret: + cfg["client_secret"] = new_secret + bind.execute( + text(f"UPDATE {table} SET oauth_config = :cfg WHERE id = :id"), + {"cfg": json.dumps(cfg), "id": rid}, + ) def upgrade() -> None: - """Upgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - pass - # ### end Alembic commands ### + bind = op.get_bind() + + # JSON: gateways.oauth_config.client_secret + _upgrade_json_client_secret(bind, "gateways") + + # JSON: a2a_agents.oauth_config.client_secret + _upgrade_json_client_secret(bind, "a2a_agents") + + # oauth_tokens: access_token, refresh_token + rows = bind.execute(text(""" + SELECT id, access_token, refresh_token + FROM oauth_tokens + WHERE (access_token IS NOT NULL OR refresh_token IS NOT NULL) + """)).mappings().all() + + for r in rows: + tid = r["id"] + at = r["access_token"] + rt = r["refresh_token"] + nat = _upgrade_value(at) + nrt = _upgrade_value(rt) + if nat or nrt: + bind.execute( + text(""" + UPDATE oauth_tokens + SET access_token = COALESCE(:nat, access_token), + refresh_token = COALESCE(:nrt, refresh_token) + WHERE id = :id + """), + {"nat": nat, "nrt": nrt, "id": tid}, + ) + + # registered_oauth_clients: client_secret_encrypted, registration_access_token_encrypted + rows = bind.execute(text(""" + SELECT id, client_secret_encrypted, registration_access_token_encrypted + FROM registered_oauth_clients + WHERE client_secret_encrypted IS NOT NULL + OR registration_access_token_encrypted IS NOT NULL + """)).mappings().all() + + for r in rows: + rid = r["id"] + cs = r["client_secret_encrypted"] + rat = r["registration_access_token_encrypted"] + ncs = _upgrade_value(cs) + nrat = _upgrade_value(rat) + if ncs or nrat: + bind.execute( + text(""" + UPDATE registered_oauth_clients + SET client_secret_encrypted = COALESCE(:ncs, client_secret_encrypted), + registration_access_token_encrypted = COALESCE(:nrat, registration_access_token_encrypted) + WHERE id = :id + """), + {"ncs": ncs, "nrat": nrat, "id": rid}, + ) + + # sso_providers: client_secret_encrypted + rows = bind.execute(text(""" + SELECT id, client_secret_encrypted + FROM sso_providers + WHERE client_secret_encrypted IS NOT NULL + """)).mappings().all() + + for r in rows: + sid = r["id"] + cs = r["client_secret_encrypted"] + ncs = _upgrade_value(cs) + if ncs: + bind.execute( + text(""" + UPDATE sso_providers + SET client_secret_encrypted = :ncs + WHERE id = :id + """), + {"ncs": ncs, "id": sid}, + ) + + logger.info("Upgrade complete: PBKDF2 -> Argon2id bundle re-encryption.") def downgrade() -> None: - """Downgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - pass - # ### end Alembic commands ### + bind = op.get_bind() + + # JSON: gateways.oauth_config.client_secret + _downgrade_json_client_secret(bind, "gateways") + + # JSON: a2a_agents.oauth_config.client_secret + _downgrade_json_client_secret(bind, "a2a_agents") + + # oauth_tokens: access_token, refresh_token + rows = bind.execute(text(""" + SELECT id, access_token, refresh_token + FROM oauth_tokens + WHERE (access_token IS NOT NULL OR refresh_token IS NOT NULL) + """)).mappings().all() + + for r in rows: + tid = r["id"] + at = r["access_token"] + rt = r["refresh_token"] + nat = _downgrade_value(at) + nrt = _downgrade_value(rt) + if nat or nrt: + bind.execute( + text(""" + UPDATE oauth_tokens + SET access_token = COALESCE(:nat, access_token), + refresh_token = COALESCE(:nrt, refresh_token) + WHERE id = :id + """), + {"nat": nat, "nrt": nrt, "id": tid}, + ) + + # registered_oauth_clients: client_secret_encrypted, registration_access_token_encrypted + rows = bind.execute(text(""" + SELECT id, client_secret_encrypted, registration_access_token_encrypted + FROM registered_oauth_clients + WHERE client_secret_encrypted IS NOT NULL + OR registration_access_token_encrypted IS NOT NULL + """)).mappings().all() + + for r in rows: + rid = r["id"] + cs = r["client_secret_encrypted"] + rat = r["registration_access_token_encrypted"] + ncs = _downgrade_value(cs) + nrat = _downgrade_value(rat) + if ncs or nrat: + bind.execute( + text(""" + UPDATE registered_oauth_clients + SET client_secret_encrypted = COALESCE(:ncs, client_secret_encrypted), + registration_access_token_encrypted = COALESCE(:nrat, registration_access_token_encrypted) + WHERE id = :id + """), + {"ncs": ncs, "nrat": nrat, "id": rid}, + ) + + # sso_providers: client_secret_encrypted + rows = bind.execute(text(""" + SELECT id, client_secret_encrypted + FROM sso_providers + WHERE client_secret_encrypted IS NOT NULL + """)).mappings().all() + + for r in rows: + sid = r["id"] + cs = r["client_secret_encrypted"] + ncs = _downgrade_value(cs) + if ncs: + bind.execute( + text(""" + UPDATE sso_providers + SET client_secret_encrypted = :ncs + WHERE id = :id + """), + {"ncs": ncs, "id": sid}, + ) + + logger.info("Downgrade complete: Argon2id bundle -> PBKDF2 legacy re-encryption.") \ No newline at end of file From 5bcbd4ee06bdcdbbb2280142e4384148293a2fe7 Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Thu, 30 Oct 2025 16:30:11 +0530 Subject: [PATCH 07/14] Fix import in alembic script Signed-off-by: Madhav Kandukuri --- .../versions/a706a3320c56_use_argon2id_for_encryption_key.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py index 7e471bf17..a610a34cc 100644 --- a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py +++ b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py @@ -14,7 +14,7 @@ from mcpgateway.config import settings from alembic import op -import sqlalchemy as sa +import sqlalchemy as text from cryptography.fernet import Fernet from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC From afc5c6da999491e6ac3d75d3609594f2532bd820 Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Thu, 30 Oct 2025 18:01:03 +0530 Subject: [PATCH 08/14] Linting fixes Signed-off-by: Madhav Kandukuri --- ...3320c56_use_argon2id_for_encryption_key.py | 201 ++++++++++++++---- mcpgateway/services/sso_service.py | 2 +- 2 files changed, 160 insertions(+), 43 deletions(-) diff --git a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py index a610a34cc..5fcafe185 100644 --- a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py +++ b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py @@ -5,32 +5,37 @@ Create Date: 2025-10-30 15:31:25.115536 """ + +# Standard import base64 import json import logging import os -from typing import Sequence, Union, Optional - -from mcpgateway.config import settings +from typing import Optional, Sequence, Union +# Third-Party from alembic import op -import sqlalchemy as text +from argon2.low_level import hash_secret_raw, Type from cryptography.fernet import Fernet from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC -from argon2.low_level import hash_secret_raw, Type +from sqlalchemy import text + +# First-Party +from mcpgateway.config import settings logger = logging.getLogger(__name__) # revision identifiers, used by Alembic. -revision: str = 'a706a3320c56' -down_revision: Union[str, Sequence[str], None] = '3c89a45f32e5' +revision: str = "a706a3320c56" +down_revision: Union[str, Sequence[str], None] = "3c89a45f32e5" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None + def reencrypt_with_argon2id(encrypted_text: str) -> str: """Re-encrypts an existing encrypted text using Argon2id KDF. - + Args: encrypted_text: The original encrypted text using PBKDF2HMAC. @@ -87,12 +92,15 @@ def reencrypt_with_pbkdf2hmac(argon2id_bundle: str) -> Optional[str]: Returns: A PBKDF2HMAC re-encrypted token. + + Raises: + ValueError: If the input is not a valid Argon2id bundle. """ try: argon2id_data = json.loads(argon2id_bundle) if argon2id_data.get("kdf") != "argon2id": raise ValueError("Not an Argon2id bundle") - + encryption_secret = settings.auth_encryption_secret.get_secret_value().encode() salt = base64.b64decode(argon2id_data["salt"]) time_cost = argon2id_data["t"] @@ -125,11 +133,20 @@ def reencrypt_with_pbkdf2hmac(argon2id_bundle: str) -> Optional[str]: except Exception as e: raise ValueError("Invalid Argon2id bundle") from e + def _looks_argon2_bundle(val: Optional[str]) -> bool: + """Heuristic for Argon2id bundle format (JSON with kdf=argon2id). + + Args: + val: The encrypted value. + + Returns: + True if it looks like an Argon2id encrypted token. + """ if not val: return False # Fast path: Fernet tokens usually start with 'gAAAAA'; Argon2 bundle is JSON - if val and val[:1] in ('{', '['): + if val and val[:1] in ("{", "["): try: obj = json.loads(val) return isinstance(obj, dict) and obj.get("kdf") == "argon2id" @@ -137,16 +154,32 @@ def _looks_argon2_bundle(val: Optional[str]) -> bool: return False return False + def _looks_legacy_pbkdf2_token(val: Optional[str]) -> bool: - """Heuristic for legacy PBKDF2 format (base64-wrapped Fernet token string, not JSON).""" + """Heuristic for legacy PBKDF2 format (base64-wrapped Fernet token string, not JSON). + + Args: + val: The encrypted value. + + Returns: + True if it looks like a legacy PBKDF2 encrypted token. + """ if not val or not isinstance(val, str): return False # Legacy column stored base64(urlsafe) of the Fernet token (which is itself base64 bytes), # so it's NOT JSON and usually not starting with '{' return not val.startswith("{") + def _upgrade_value(old: Optional[str]) -> Optional[str]: - """PBKDF2 -> Argon2id bundle, when needed.""" + """PBKDF2 -> Argon2id bundle, when needed. + + Args: + old: The existing encrypted value. + + Returns: + The re-encrypted value using Argon2id, or None if no change is needed. + """ if not old: return None if _looks_argon2_bundle(old): @@ -161,7 +194,14 @@ def _upgrade_value(old: Optional[str]) -> Optional[str]: def _downgrade_value(old: Optional[str]) -> Optional[str]: - """Argon2id bundle -> PBKDF2 legacy, when needed.""" + """Argon2id bundle -> PBKDF2 legacy, when needed. + + Args: + old: The existing encrypted value. + + Returns: + The re-encrypted value using PBKDF2HMAC, or None if no change is needed. + """ if not old: return None if not _looks_argon2_bundle(old): @@ -174,11 +214,19 @@ def _downgrade_value(old: Optional[str]) -> Optional[str]: def _upgrade_json_client_secret(bind, table: str) -> None: - rows = bind.execute(text(f""" + rows = ( + bind.execute( + text( + f""" SELECT id, oauth_config FROM {table} WHERE oauth_config IS NOT NULL - """)).mappings().all() + """ + ) + ) + .mappings() + .all() + ) for r in rows: rid = r["id"] @@ -200,11 +248,19 @@ def _upgrade_json_client_secret(bind, table: str) -> None: def _downgrade_json_client_secret(bind, table: str) -> None: - rows = bind.execute(text(f""" + rows = ( + bind.execute( + text( + f""" SELECT id, oauth_config FROM {table} WHERE oauth_config IS NOT NULL - """)).mappings().all() + """ + ) + ) + .mappings() + .all() + ) for r in rows: rid = r["id"] @@ -224,6 +280,7 @@ def _downgrade_json_client_secret(bind, table: str) -> None: {"cfg": json.dumps(cfg), "id": rid}, ) + def upgrade() -> None: bind = op.get_bind() @@ -234,11 +291,19 @@ def upgrade() -> None: _upgrade_json_client_secret(bind, "a2a_agents") # oauth_tokens: access_token, refresh_token - rows = bind.execute(text(""" + rows = ( + bind.execute( + text( + """ SELECT id, access_token, refresh_token FROM oauth_tokens WHERE (access_token IS NOT NULL OR refresh_token IS NOT NULL) - """)).mappings().all() + """ + ) + ) + .mappings() + .all() + ) for r in rows: tid = r["id"] @@ -248,22 +313,32 @@ def upgrade() -> None: nrt = _upgrade_value(rt) if nat or nrt: bind.execute( - text(""" + text( + """ UPDATE oauth_tokens SET access_token = COALESCE(:nat, access_token), refresh_token = COALESCE(:nrt, refresh_token) WHERE id = :id - """), + """ + ), {"nat": nat, "nrt": nrt, "id": tid}, ) # registered_oauth_clients: client_secret_encrypted, registration_access_token_encrypted - rows = bind.execute(text(""" + rows = ( + bind.execute( + text( + """ SELECT id, client_secret_encrypted, registration_access_token_encrypted FROM registered_oauth_clients WHERE client_secret_encrypted IS NOT NULL OR registration_access_token_encrypted IS NOT NULL - """)).mappings().all() + """ + ) + ) + .mappings() + .all() + ) for r in rows: rid = r["id"] @@ -273,21 +348,31 @@ def upgrade() -> None: nrat = _upgrade_value(rat) if ncs or nrat: bind.execute( - text(""" + text( + """ UPDATE registered_oauth_clients SET client_secret_encrypted = COALESCE(:ncs, client_secret_encrypted), registration_access_token_encrypted = COALESCE(:nrat, registration_access_token_encrypted) WHERE id = :id - """), + """ + ), {"ncs": ncs, "nrat": nrat, "id": rid}, ) # sso_providers: client_secret_encrypted - rows = bind.execute(text(""" + rows = ( + bind.execute( + text( + """ SELECT id, client_secret_encrypted FROM sso_providers WHERE client_secret_encrypted IS NOT NULL - """)).mappings().all() + """ + ) + ) + .mappings() + .all() + ) for r in rows: sid = r["id"] @@ -295,11 +380,13 @@ def upgrade() -> None: ncs = _upgrade_value(cs) if ncs: bind.execute( - text(""" + text( + """ UPDATE sso_providers SET client_secret_encrypted = :ncs WHERE id = :id - """), + """ + ), {"ncs": ncs, "id": sid}, ) @@ -316,11 +403,19 @@ def downgrade() -> None: _downgrade_json_client_secret(bind, "a2a_agents") # oauth_tokens: access_token, refresh_token - rows = bind.execute(text(""" + rows = ( + bind.execute( + text( + """ SELECT id, access_token, refresh_token FROM oauth_tokens WHERE (access_token IS NOT NULL OR refresh_token IS NOT NULL) - """)).mappings().all() + """ + ) + ) + .mappings() + .all() + ) for r in rows: tid = r["id"] @@ -330,22 +425,32 @@ def downgrade() -> None: nrt = _downgrade_value(rt) if nat or nrt: bind.execute( - text(""" + text( + """ UPDATE oauth_tokens SET access_token = COALESCE(:nat, access_token), refresh_token = COALESCE(:nrt, refresh_token) WHERE id = :id - """), + """ + ), {"nat": nat, "nrt": nrt, "id": tid}, ) # registered_oauth_clients: client_secret_encrypted, registration_access_token_encrypted - rows = bind.execute(text(""" + rows = ( + bind.execute( + text( + """ SELECT id, client_secret_encrypted, registration_access_token_encrypted FROM registered_oauth_clients WHERE client_secret_encrypted IS NOT NULL OR registration_access_token_encrypted IS NOT NULL - """)).mappings().all() + """ + ) + ) + .mappings() + .all() + ) for r in rows: rid = r["id"] @@ -355,21 +460,31 @@ def downgrade() -> None: nrat = _downgrade_value(rat) if ncs or nrat: bind.execute( - text(""" + text( + """ UPDATE registered_oauth_clients SET client_secret_encrypted = COALESCE(:ncs, client_secret_encrypted), registration_access_token_encrypted = COALESCE(:nrat, registration_access_token_encrypted) WHERE id = :id - """), + """ + ), {"ncs": ncs, "nrat": nrat, "id": rid}, ) # sso_providers: client_secret_encrypted - rows = bind.execute(text(""" + rows = ( + bind.execute( + text( + """ SELECT id, client_secret_encrypted FROM sso_providers WHERE client_secret_encrypted IS NOT NULL - """)).mappings().all() + """ + ) + ) + .mappings() + .all() + ) for r in rows: sid = r["id"] @@ -377,12 +492,14 @@ def downgrade() -> None: ncs = _downgrade_value(cs) if ncs: bind.execute( - text(""" + text( + """ UPDATE sso_providers SET client_secret_encrypted = :ncs WHERE id = :id - """), + """ + ), {"ncs": ncs, "id": sid}, ) - logger.info("Downgrade complete: Argon2id bundle -> PBKDF2 legacy re-encryption.") \ No newline at end of file + logger.info("Downgrade complete: Argon2id bundle -> PBKDF2 legacy re-encryption.") diff --git a/mcpgateway/services/sso_service.py b/mcpgateway/services/sso_service.py index d9400f4c8..89149032f 100644 --- a/mcpgateway/services/sso_service.py +++ b/mcpgateway/services/sso_service.py @@ -86,7 +86,7 @@ def _decrypt_secret(self, encrypted_secret: str) -> Optional[str]: decrypted: str | None = self._encryption.decrypt_secret(encrypted_secret) if decrypted: return decrypted - + return None def list_enabled_providers(self) -> List[SSOProvider]: From 91d9d39bc150b30c5e067e8735b123dd1e7d41c1 Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Thu, 30 Oct 2025 19:09:33 +0530 Subject: [PATCH 09/14] Move encryption from util to service Signed-off-by: Madhav Kandukuri --- mcpgateway/admin.py | 18 +-- ...3320c56_use_argon2id_for_encryption_key.py | 127 ++++++++---------- mcpgateway/routers/oauth_router.py | 4 +- mcpgateway/services/dcr_service.py | 8 +- .../encryption_service.py} | 26 ++-- mcpgateway/services/oauth_manager.py | 10 +- mcpgateway/services/sso_service.py | 4 +- mcpgateway/services/token_storage_service.py | 4 +- .../mcpgateway/services/test_dcr_service.py | 8 +- tests/unit/mcpgateway/test_admin.py | 6 +- tests/unit/mcpgateway/test_oauth_manager.py | 96 ++++++------- 11 files changed, 151 insertions(+), 160 deletions(-) rename mcpgateway/{utils/fernet_encryption.py => services/encryption_service.py} (88%) diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index f88e98a51..490d59b73 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -95,6 +95,7 @@ ) from mcpgateway.services.a2a_service import A2AAgentError, A2AAgentNameConflictError, A2AAgentNotFoundError, A2AAgentService from mcpgateway.services.catalog_service import catalog_service +from mcpgateway.services.encryption_service import get_encryption_service from mcpgateway.services.export_service import ExportError, ExportService from mcpgateway.services.gateway_service import GatewayConnectionError, GatewayNameConflictError, GatewayNotFoundError, GatewayService, GatewayUrlConflictError from mcpgateway.services.import_service import ConflictStrategy @@ -112,7 +113,6 @@ from mcpgateway.services.tool_service import ToolError, ToolNameConflictError, ToolNotFoundError, ToolService from mcpgateway.utils.create_jwt_token import create_jwt_token, get_jwt_token from mcpgateway.utils.error_formatter import ErrorFormatter -from mcpgateway.utils.fernet_encryption import get_fernet_encryption from mcpgateway.utils.metadata_capture import MetadataCapture from mcpgateway.utils.pagination import generate_pagination_links from mcpgateway.utils.passthrough_headers import PassthroughHeadersError @@ -6194,7 +6194,7 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use oauth_config = json.loads(oauth_config_json) # Encrypt the client secret if present if oauth_config and "client_secret" in oauth_config: - encryption = get_fernet_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_config["client_secret"]) except (json.JSONDecodeError, ValueError) as e: LOGGER.error(f"Failed to parse OAuth config: {e}") @@ -6231,7 +6231,7 @@ async def admin_add_gateway(request: Request, db: Session = Depends(get_db), use oauth_config["client_id"] = oauth_client_id if oauth_client_secret: # Encrypt the client secret - encryption = get_fernet_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_client_secret) # Add username and password for password grant type @@ -6503,7 +6503,7 @@ async def admin_edit_gateway( oauth_config = json.loads(oauth_config_json) # Encrypt the client secret if present and not empty if oauth_config and "client_secret" in oauth_config and oauth_config["client_secret"]: - encryption = get_fernet_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_config["client_secret"]) except (json.JSONDecodeError, ValueError) as e: LOGGER.error(f"Failed to parse OAuth config: {e}") @@ -6540,7 +6540,7 @@ async def admin_edit_gateway( oauth_config["client_id"] = oauth_client_id if oauth_client_secret: # Encrypt the client secret - encryption = get_fernet_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_client_secret) # Add username and password for password grant type @@ -9571,7 +9571,7 @@ async def admin_add_a2a_agent( oauth_config = json.loads(oauth_config_json) # Encrypt the client secret if present if oauth_config and "client_secret" in oauth_config: - encryption = get_fernet_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_config["client_secret"]) except (json.JSONDecodeError, ValueError) as e: LOGGER.error(f"Failed to parse OAuth config: {e}") @@ -9608,7 +9608,7 @@ async def admin_add_a2a_agent( oauth_config["client_id"] = oauth_client_id if oauth_client_secret: # Encrypt the client secret - encryption = get_fernet_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_client_secret) # Add username and password for password grant type @@ -9890,7 +9890,7 @@ async def admin_edit_a2a_agent( oauth_config = json.loads(oauth_config_json) # Encrypt the client secret if present and not empty if oauth_config and "client_secret" in oauth_config and oauth_config["client_secret"]: - encryption = get_fernet_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_config["client_secret"]) except (json.JSONDecodeError, ValueError) as e: LOGGER.error(f"Failed to parse OAuth config: {e}") @@ -9927,7 +9927,7 @@ async def admin_edit_a2a_agent( oauth_config["client_id"] = oauth_client_id if oauth_client_secret: # Encrypt the client secret - encryption = get_fernet_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) oauth_config["client_secret"] = encryption.encrypt_secret(oauth_client_secret) # Add username and password for password grant type diff --git a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py index 5fcafe185..02d9a8952 100644 --- a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py +++ b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py @@ -133,6 +133,14 @@ def reencrypt_with_pbkdf2hmac(argon2id_bundle: str) -> Optional[str]: except Exception as e: raise ValueError("Invalid Argon2id bundle") from e +def _reflect(conn): + md = sa.MetaData() + gateways = sa.Table("gateways", md, autoload_with=conn) + a2a_agents = sa.Table("a2a_agents", md, autoload_with=conn) + return {"gateways": gateways, "a2a_agents": a2a_agents} + +def _is_json(col): + return isinstance(col.type, sa.JSON) def _looks_argon2_bundle(val: Optional[str]) -> bool: """Heuristic for Argon2id bundle format (JSON with kdf=argon2id). @@ -213,82 +221,65 @@ def _downgrade_value(old: Optional[str]) -> Optional[str]: return None -def _upgrade_json_client_secret(bind, table: str) -> None: - rows = ( - bind.execute( - text( - f""" - SELECT id, oauth_config - FROM {table} - WHERE oauth_config IS NOT NULL - """ - ) - ) - .mappings() - .all() - ) - - for r in rows: - rid = r["id"] - cfg_raw = r["oauth_config"] - try: - cfg = cfg_raw if isinstance(cfg_raw, dict) else json.loads(cfg_raw) - except Exception: - logger.warning("%s.id=%s: oauth_config not JSON, skipping", table, rid) - continue - - secret = cfg.get("client_secret") - new_secret = _upgrade_value(secret) - if new_secret: - cfg["client_secret"] = new_secret - bind.execute( - text(f"UPDATE {table} SET oauth_config = :cfg WHERE id = :id"), - {"cfg": json.dumps(cfg), "id": rid}, - ) - - -def _downgrade_json_client_secret(bind, table: str) -> None: - rows = ( - bind.execute( - text( - f""" - SELECT id, oauth_config - FROM {table} - WHERE oauth_config IS NOT NULL - """ - ) - ) - .mappings() - .all() - ) - - for r in rows: - rid = r["id"] - cfg_raw = r["oauth_config"] - try: - cfg = cfg_raw if isinstance(cfg_raw, dict) else json.loads(cfg_raw) - except Exception: - logger.warning("%s.id=%s: oauth_config not JSON, skipping", table, rid) - continue - - secret = cfg.get("client_secret") - new_secret = _downgrade_value(secret) - if new_secret: - cfg["client_secret"] = new_secret - bind.execute( - text(f"UPDATE {table} SET oauth_config = :cfg WHERE id = :id"), - {"cfg": json.dumps(cfg), "id": rid}, - ) +def _upgrade_json_client_secret(conn, table): + t = table + sel = sa.select(t.c.id, t.c.oauth_config).where(t.c.oauth_config.isnot(None)) + for row in conn.execute(sel).mappings(): + rid = row["id"] + cfg = row["oauth_config"] + if isinstance(cfg, str): + try: + cfg = json.loads(cfg) + except json.JSONDecodeError as e: + logger.warning("Skipping %s.id=%s: invalid JSON (%s)", table, rid, e) + continue + if not isinstance(cfg, dict): continue + + old = cfg.get("client_secret") + new = _upgrade_value(old) # your helper + if not new: continue + + cfg["client_secret"] = new + value = cfg if _is_json(t.c.oauth_config) else json.dumps(cfg) + upd = sa.update(t).where(t.c.id == rid).values(oauth_config=value) + conn.execute(upd) + + +def _downgrade_json_client_secret(conn, table): + t = table + sel = sa.select(t.c.id, t.c.oauth_config).where(t.c.oauth_config.isnot(None)) + for row in conn.execute(sel).mappings(): + rid = row["id"] + cfg = row["oauth_config"] + if isinstance(cfg, str): + try: + cfg = json.loads(cfg) + except json.JSONDecodeError as e: + logger.warning("Skipping %s.id=%s: invalid JSON (%s)", table, rid, e) + continue + if not isinstance(cfg, dict): continue + + old = cfg.get("client_secret") + new = _downgrade_value(old) # your helper + if not new: continue + + cfg["client_secret"] = new + value = cfg if _is_json(t.c.oauth_config) else json.dumps(cfg) + upd = sa.update(t).where(t.c.id == rid).values(oauth_config=value) + conn.execute(upd) def upgrade() -> None: bind = op.get_bind() + conn = op.get_bind() + t = _reflect(conn) + # JSON: gateways.oauth_config.client_secret - _upgrade_json_client_secret(bind, "gateways") + _upgrade_json_client_secret(conn, t["gateways"]) # JSON: a2a_agents.oauth_config.client_secret - _upgrade_json_client_secret(bind, "a2a_agents") + _upgrade_json_client_secret(conn, t["a2a_agents"]) # oauth_tokens: access_token, refresh_token rows = ( diff --git a/mcpgateway/routers/oauth_router.py b/mcpgateway/routers/oauth_router.py index 83b4c0cb7..f8a1c17d3 100644 --- a/mcpgateway/routers/oauth_router.py +++ b/mcpgateway/routers/oauth_router.py @@ -113,9 +113,9 @@ async def initiate_oauth_flow( decrypted_secret = None if registered_client.client_secret_encrypted: # First-Party - from mcpgateway.utils.fernet_encryption import get_fernet_encryption + from mcpgateway.services.encryption_service import get_encryption_service - encryption = get_fernet_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) decrypted_secret = encryption.decrypt_secret(registered_client.client_secret_encrypted) # Update oauth_config with registered credentials diff --git a/mcpgateway/services/dcr_service.py b/mcpgateway/services/dcr_service.py index e300d3539..3e95d1a52 100644 --- a/mcpgateway/services/dcr_service.py +++ b/mcpgateway/services/dcr_service.py @@ -25,7 +25,7 @@ # First-Party from mcpgateway.config import get_settings from mcpgateway.db import RegisteredOAuthClient -from mcpgateway.utils.fernet_encryption import get_fernet_encryption +from mcpgateway.services.encryption_service import get_encryption_service logger = logging.getLogger(__name__) @@ -168,7 +168,7 @@ async def register_client(self, gateway_id: str, gateway_name: str, issuer: str, raise DcrError(f"Failed to register client with {issuer}: {e}") # Encrypt secrets - encryption = get_fernet_encryption(self.settings.auth_encryption_secret) + encryption = get_encryption_service(self.settings.auth_encryption_secret) client_secret = registration_response.get("client_secret") client_secret_encrypted = encryption.encrypt_secret(client_secret) if client_secret else None @@ -260,7 +260,7 @@ async def update_client_registration(self, client_record: RegisteredOAuthClient, raise DcrError("Cannot update client: no registration_access_token available") # Decrypt registration access token - encryption = get_fernet_encryption(self.settings.auth_encryption_secret) + encryption = get_encryption_service(self.settings.auth_encryption_secret) registration_access_token = encryption.decrypt_secret(client_record.registration_access_token_encrypted) # Build update request @@ -313,7 +313,7 @@ async def delete_client_registration(self, client_record: RegisteredOAuthClient, return True # Consider it deleted locally # Decrypt registration access token - encryption = get_fernet_encryption(self.settings.auth_encryption_secret) + encryption = get_encryption_service(self.settings.auth_encryption_secret) registration_access_token = encryption.decrypt_secret(client_record.registration_access_token_encrypted) # Send delete request diff --git a/mcpgateway/utils/fernet_encryption.py b/mcpgateway/services/encryption_service.py similarity index 88% rename from mcpgateway/utils/fernet_encryption.py rename to mcpgateway/services/encryption_service.py index f139e054a..369a13644 100644 --- a/mcpgateway/utils/fernet_encryption.py +++ b/mcpgateway/services/encryption_service.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- -"""Location: ./mcpgateway/utils/oauth_encryption.py +"""Location: ./mcpgateway/services/encryption_service.py Copyright 2025 SPDX-License-Identifier: Apache-2.0 -Authors: Mihai Criveti +Authors: Mihai Criveti, Madhav Kandukuri -Fernet Encryption Utilities. +Encryption Service. -This module provides encryption and decryption functions for client secrets +This service provides encryption and decryption functions for client secrets using the AUTH_ENCRYPTION_SECRET from configuration. """ @@ -28,12 +28,12 @@ logger = logging.getLogger(__name__) -class FernetEncryption: - """Handles Fernet encryption and decryption of client secrets. +class EncryptionService: + """Handles encryption and decryption of client secrets. Examples: Basic roundtrip: - >>> enc = FernetEncryption(SecretStr('very-secret-key')) + >>> enc = EncryptionService(SecretStr('very-secret-key')) >>> cipher = enc.encrypt_secret('hello') >>> isinstance(cipher, str) and enc.is_encrypted(cipher) True @@ -156,18 +156,18 @@ def is_encrypted(self, text: str) -> bool: return False -def get_fernet_encryption(encryption_secret: SecretStr) -> FernetEncryption: - """Get an Fernet encryption instance. +def get_encryption_service(encryption_secret: SecretStr) -> EncryptionService: + """Get an EncryptionService instance. Args: encryption_secret: Secret key for encryption/decryption Returns: - FernetEncryption instance + EncryptionService instance Examples: - >>> enc = get_fernet_encryption(SecretStr('k')) - >>> isinstance(enc, FernetEncryption) + >>> enc = get_encryption_service(SecretStr('k')) + >>> isinstance(enc, EncryptionService) True """ - return FernetEncryption(encryption_secret) + return EncryptionService(encryption_secret) diff --git a/mcpgateway/services/oauth_manager.py b/mcpgateway/services/oauth_manager.py index 304bec339..610ed8213 100644 --- a/mcpgateway/services/oauth_manager.py +++ b/mcpgateway/services/oauth_manager.py @@ -28,7 +28,7 @@ # First-Party from mcpgateway.config import get_settings -from mcpgateway.utils.fernet_encryption import get_fernet_encryption +from mcpgateway.services.encryption_service import get_encryption_service logger = logging.getLogger(__name__) @@ -222,7 +222,7 @@ async def _client_credentials_flow(self, credentials: Dict[str, Any]) -> str: if len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer try: settings = get_settings() - encryption = get_fernet_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) decrypted_secret = encryption.decrypt_secret(client_secret) if decrypted_secret: client_secret = decrypted_secret @@ -313,7 +313,7 @@ async def _password_flow(self, credentials: Dict[str, Any]) -> str: if client_secret and len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer try: settings = get_settings() - encryption = get_fernet_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) decrypted_secret = encryption.decrypt_secret(client_secret) if decrypted_secret: client_secret = decrypted_secret @@ -430,7 +430,7 @@ async def exchange_code_for_token(self, credentials: Dict[str, Any], code: str, if client_secret and len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer try: settings = get_settings() - encryption = get_fernet_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) decrypted_secret = encryption.decrypt_secret(client_secret) if decrypted_secret: client_secret = decrypted_secret @@ -1007,7 +1007,7 @@ async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str if client_secret and len(client_secret) > 50: # Simple heuristic: encrypted secrets are longer try: settings = get_settings() - encryption = get_fernet_encryption(settings.auth_encryption_secret) + encryption = get_encryption_service(settings.auth_encryption_secret) decrypted_secret = encryption.decrypt_secret(client_secret) if decrypted_secret: client_secret = decrypted_secret diff --git a/mcpgateway/services/sso_service.py b/mcpgateway/services/sso_service.py index 89149032f..fb3bbdc78 100644 --- a/mcpgateway/services/sso_service.py +++ b/mcpgateway/services/sso_service.py @@ -30,8 +30,8 @@ from mcpgateway.config import settings from mcpgateway.db import PendingUserApproval, SSOAuthSession, SSOProvider, utc_now from mcpgateway.services.email_auth_service import EmailAuthService +from mcpgateway.services.encryption_service import get_encryption_service from mcpgateway.utils.create_jwt_token import create_jwt_token -from mcpgateway.utils.fernet_encryption import get_fernet_encryption # Logger logger = logging.getLogger(__name__) @@ -61,7 +61,7 @@ def __init__(self, db: Session): """ self.db = db self.auth_service = EmailAuthService(db) - self._encryption = get_fernet_encryption(settings.auth_encryption_secret) + self._encryption = get_encryption_service(settings.auth_encryption_secret) def _encrypt_secret(self, secret: str) -> str: """Encrypt a client secret for secure storage. diff --git a/mcpgateway/services/token_storage_service.py b/mcpgateway/services/token_storage_service.py index 8c85de7df..ef6b18e66 100644 --- a/mcpgateway/services/token_storage_service.py +++ b/mcpgateway/services/token_storage_service.py @@ -22,8 +22,8 @@ # First-Party from mcpgateway.config import get_settings from mcpgateway.db import OAuthToken +from mcpgateway.services.encryption_service import get_encryption_service from mcpgateway.services.oauth_manager import OAuthError -from mcpgateway.utils.fernet_encryption import get_fernet_encryption logger = logging.getLogger(__name__) @@ -68,7 +68,7 @@ def __init__(self, db: Session): self.db = db try: settings = get_settings() - self.encryption = get_fernet_encryption(settings.auth_encryption_secret) + self.encryption = get_encryption_service(settings.auth_encryption_secret) except (ImportError, AttributeError): logger.warning("OAuth encryption not available, using plain text storage") self.encryption = None diff --git a/tests/unit/mcpgateway/services/test_dcr_service.py b/tests/unit/mcpgateway/services/test_dcr_service.py index 379a0e558..a9b0a6987 100644 --- a/tests/unit/mcpgateway/services/test_dcr_service.py +++ b/tests/unit/mcpgateway/services/test_dcr_service.py @@ -429,7 +429,7 @@ class TestUpdateClientRegistration: @pytest.mark.asyncio async def test_update_client_registration_success(self, test_db): """Test successful client registration update.""" - from mcpgateway.utils.fernet_encryption import get_fernet_encryption + from mcpgateway.services.encryption_service import get_encryption_service from mcpgateway.config import get_settings dcr_service = DcrService() @@ -442,7 +442,7 @@ async def test_update_client_registration_success(self, test_db): test_db.commit() # Encrypt the registration access token properly - encryption = get_fernet_encryption(get_settings().auth_encryption_secret) + encryption = get_encryption_service(get_settings().auth_encryption_secret) encrypted_token = encryption.encrypt_secret("registration-access-token") client_record = RegisteredOAuthClient( @@ -474,7 +474,7 @@ async def test_update_client_registration_success(self, test_db): @pytest.mark.asyncio async def test_update_client_registration_uses_access_token(self, test_db): """Test that update uses registration_access_token.""" - from mcpgateway.utils.fernet_encryption import get_fernet_encryption + from mcpgateway.services.encryption_service import get_encryption_service from mcpgateway.config import get_settings dcr_service = DcrService() @@ -487,7 +487,7 @@ async def test_update_client_registration_uses_access_token(self, test_db): test_db.commit() # Encrypt the registration access token properly - encryption = get_fernet_encryption(get_settings().auth_encryption_secret) + encryption = get_encryption_service(get_settings().auth_encryption_secret) encrypted_token = encryption.encrypt_secret("registration-access-token") client_record = RegisteredOAuthClient( diff --git a/tests/unit/mcpgateway/test_admin.py b/tests/unit/mcpgateway/test_admin.py index 071cdab0d..568c35477 100644 --- a/tests/unit/mcpgateway/test_admin.py +++ b/tests/unit/mcpgateway/test_admin.py @@ -2114,7 +2114,7 @@ async def test_admin_add_gateway_with_oauth_config(self, mock_register_gateway, mock_request.form = AsyncMock(return_value=form_data) # Mock OAuth encryption - with patch("mcpgateway.admin.get_fernet_encryption") as mock_get_encryption: + with patch("mcpgateway.admin.get_encryption_service") as mock_get_encryption: mock_encryption = MagicMock() mock_encryption.encrypt_secret.return_value = "encrypted-secret" mock_get_encryption.return_value = mock_encryption @@ -2175,7 +2175,7 @@ async def test_admin_edit_gateway_with_oauth_config(self, mock_update_gateway, m mock_request.form = AsyncMock(return_value=form_data) # Mock OAuth encryption - with patch("mcpgateway.admin.get_fernet_encryption") as mock_get_encryption: + with patch("mcpgateway.admin.get_encryption_service") as mock_get_encryption: mock_encryption = MagicMock() mock_encryption.encrypt_secret.return_value = "encrypted-edit-secret" mock_get_encryption.return_value = mock_encryption @@ -2204,7 +2204,7 @@ async def test_admin_edit_gateway_oauth_empty_client_secret(self, mock_update_ga mock_request.form = AsyncMock(return_value=form_data) # Mock OAuth encryption - should not be called for empty secret - with patch("mcpgateway.admin.get_fernet_encryption") as mock_get_encryption: + with patch("mcpgateway.admin.get_encryption_service") as mock_get_encryption: mock_encryption = MagicMock() mock_get_encryption.return_value = mock_encryption diff --git a/tests/unit/mcpgateway/test_oauth_manager.py b/tests/unit/mcpgateway/test_oauth_manager.py index 769b2d043..3431119a6 100644 --- a/tests/unit/mcpgateway/test_oauth_manager.py +++ b/tests/unit/mcpgateway/test_oauth_manager.py @@ -20,7 +20,7 @@ from mcpgateway.db import OAuthToken from mcpgateway.services.oauth_manager import OAuthError, OAuthManager from mcpgateway.services.token_storage_service import TokenStorageService -from mcpgateway.utils.fernet_encryption import FernetEncryption +from mcpgateway.services.encryption_service import EncryptionService class TestOAuthManager: @@ -298,7 +298,7 @@ async def test_client_credentials_flow_with_encrypted_secret(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_fernet_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_encryption_service") as mock_get_encryption: mock_encryption = Mock() mock_encryption.decrypt_secret.return_value = "decrypted_secret" mock_get_encryption.return_value = mock_encryption @@ -371,7 +371,7 @@ async def test_client_credentials_flow_decryption_returns_none(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_fernet_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_encryption_service") as mock_get_encryption: mock_encryption = Mock() # Decryption returns None - line 108 mock_encryption.decrypt_secret.return_value = None @@ -890,7 +890,7 @@ async def test_exchange_code_for_tokens_decryption_returns_none(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_fernet_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_encryption_service") as mock_get_encryption: mock_encryption = Mock() # Decryption returns None - lines 438-439 mock_encryption.decrypt_secret.return_value = None @@ -981,7 +981,7 @@ async def test_exchange_code_for_token_decryption_returns_none(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_fernet_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_encryption_service") as mock_get_encryption: mock_encryption = Mock() # Decryption returns None - lines 216-217 mock_encryption.decrypt_secret.return_value = None @@ -1260,7 +1260,7 @@ async def test_exchange_code_for_tokens_decryption_success(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_fernet_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_encryption_service") as mock_get_encryption: mock_encryption = Mock() # Decryption succeeds - lines 435-437 mock_encryption.decrypt_secret.return_value = "decrypted_secret" @@ -1301,7 +1301,7 @@ async def test_exchange_code_for_tokens_decryption_exception(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_fernet_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_encryption_service") as mock_get_encryption: mock_encryption = Mock() # Decryption throws exception - lines 440-441 mock_encryption.decrypt_secret.side_effect = ValueError("Decryption failed") @@ -1454,7 +1454,7 @@ async def test_exchange_code_for_token_decryption_success(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_fernet_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_encryption_service") as mock_get_encryption: mock_encryption = Mock() # Decryption succeeds - lines 213-215 mock_encryption.decrypt_secret.return_value = "decrypted_secret" @@ -1495,7 +1495,7 @@ async def test_exchange_code_for_token_decryption_exception(self): mock_settings.auth_encryption_secret = "test_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.oauth_manager.get_fernet_encryption") as mock_get_encryption: + with patch("mcpgateway.services.oauth_manager.get_encryption_service") as mock_get_encryption: mock_encryption = Mock() # Decryption throws exception - lines 218-219 mock_encryption.decrypt_secret.side_effect = ValueError("Decryption failed") @@ -1661,7 +1661,7 @@ def test_init_with_encryption(self): mock_settings.auth_encryption_secret = "test_secret_key" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.token_storage_service.get_fernet_encryption") as mock_get_enc: + with patch("mcpgateway.services.token_storage_service.get_encryption_service") as mock_get_enc: mock_encryption = Mock() mock_get_enc.return_value = mock_encryption @@ -1709,7 +1709,7 @@ async def test_store_tokens_new_record_with_encryption(self): mock_settings.auth_encryption_secret = "test_secret" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.token_storage_service.get_fernet_encryption") as mock_get_enc: + with patch("mcpgateway.services.token_storage_service.get_encryption_service") as mock_get_enc: mock_get_enc.return_value = mock_encryption service = TokenStorageService(mock_db) @@ -1810,7 +1810,7 @@ async def test_store_tokens_update_existing_record(self): mock_settings.auth_encryption_secret = "test_secret" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.token_storage_service.get_fernet_encryption") as mock_get_enc: + with patch("mcpgateway.services.token_storage_service.get_encryption_service") as mock_get_enc: mock_get_enc.return_value = mock_encryption service = TokenStorageService(mock_db) @@ -1854,7 +1854,7 @@ async def test_store_tokens_without_refresh_token(self): mock_settings.auth_encryption_secret = "test_secret" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.token_storage_service.get_fernet_encryption") as mock_get_enc: + with patch("mcpgateway.services.token_storage_service.get_encryption_service") as mock_get_enc: mock_get_enc.return_value = mock_encryption service = TokenStorageService(mock_db) @@ -1911,7 +1911,7 @@ async def test_get_valid_token_success_with_encryption(self): mock_settings.auth_encryption_secret = "test_secret" mock_get_settings.return_value = mock_settings - with patch("mcpgateway.services.token_storage_service.get_fernet_encryption") as mock_get_enc: + with patch("mcpgateway.services.token_storage_service.get_encryption_service") as mock_get_enc: mock_get_enc.return_value = mock_encryption service = TokenStorageService(mock_db) @@ -2442,17 +2442,17 @@ async def test_cleanup_expired_tokens_exception(self): mock_db.rollback.assert_called_once() -class TestFernetEncryption: - """Test cases for FernetEncryption class.""" +class TestEncryptionService: + """Test cases for EncryptionService class.""" def test_init(self): - """Test FernetEncryption initialization.""" - encryption = FernetEncryption(SecretStr("test_secret_key")) + """Test EncryptionService initialization.""" + encryption = EncryptionService(SecretStr("test_secret_key")) assert encryption.encryption_secret == b"test_secret_key" def test_encrypt_secret_success(self): """Test successful secret encryption.""" - encryption = FernetEncryption(SecretStr("test_secret_key")) + encryption = EncryptionService(SecretStr("test_secret_key")) plaintext = "my_secret_token_123" encrypted = encryption.encrypt_secret(plaintext) @@ -2467,8 +2467,8 @@ def test_encrypt_secret_success(self): def test_encrypt_secret_different_keys_different_output(self): """Test that different keys produce different encrypted output.""" - encryption1 = FernetEncryption(SecretStr("key1")) - encryption2 = FernetEncryption(SecretStr("key2")) + encryption1 = EncryptionService(SecretStr("key1")) + encryption2 = EncryptionService(SecretStr("key2")) plaintext = "same_secret" encrypted1 = encryption1.encrypt_secret(plaintext) @@ -2479,7 +2479,7 @@ def test_encrypt_secret_different_keys_different_output(self): def test_encrypt_secret_same_key_different_output(self): """Test that same key produces different encrypted output due to nonce.""" - encryption = FernetEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) plaintext = "same_secret" encrypted1 = encryption.encrypt_secret(plaintext) @@ -2494,7 +2494,7 @@ def test_encrypt_secret_same_key_different_output(self): def test_encrypt_secret_empty_string(self): """Test encrypting empty string.""" - encryption = FernetEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) encrypted = encryption.encrypt_secret("") decrypted = encryption.decrypt_secret(encrypted) @@ -2503,7 +2503,7 @@ def test_encrypt_secret_empty_string(self): def test_encrypt_secret_unicode_characters(self): """Test encrypting string with unicode characters.""" - encryption = FernetEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) plaintext = "šŸ” secret with Ć©mojis and spĆ©ciĆ l chars Ʊ" encrypted = encryption.encrypt_secret(plaintext) @@ -2513,7 +2513,7 @@ def test_encrypt_secret_unicode_characters(self): def test_encrypt_secret_exception_handling(self): """Test exception handling in encrypt_secret.""" - encryption = FernetEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) with patch.object(encryption, "derive_key_argon2id", side_effect=Exception("Encryption failed")): with pytest.raises(Exception, match="Encryption failed"): @@ -2521,7 +2521,7 @@ def test_encrypt_secret_exception_handling(self): def test_decrypt_secret_success(self): """Test successful secret decryption.""" - encryption = FernetEncryption(SecretStr("test_secret_key")) + encryption = EncryptionService(SecretStr("test_secret_key")) plaintext = "original_secret" # First encrypt @@ -2534,7 +2534,7 @@ def test_decrypt_secret_success(self): def test_decrypt_secret_invalid_data(self): """Test decryption with invalid encrypted data.""" - encryption = FernetEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) result = encryption.decrypt_secret("invalid_encrypted_data") @@ -2542,8 +2542,8 @@ def test_decrypt_secret_invalid_data(self): def test_decrypt_secret_wrong_key(self): """Test decryption with wrong key.""" - encryption1 = FernetEncryption(SecretStr("key1")) - encryption2 = FernetEncryption(SecretStr("key2")) + encryption1 = EncryptionService(SecretStr("key1")) + encryption2 = EncryptionService(SecretStr("key2")) # Encrypt with one key encrypted = encryption1.encrypt_secret("secret") @@ -2555,7 +2555,7 @@ def test_decrypt_secret_wrong_key(self): def test_decrypt_secret_corrupted_data(self): """Test decryption with corrupted base64 data.""" - encryption = FernetEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) # Create valid encrypted data then corrupt it encrypted = encryption.encrypt_secret("test") @@ -2567,7 +2567,7 @@ def test_decrypt_secret_corrupted_data(self): def test_decrypt_secret_malformed_base64(self): """Test decryption with malformed base64.""" - encryption = FernetEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) result = encryption.decrypt_secret("not_valid_base64!@#") @@ -2575,7 +2575,7 @@ def test_decrypt_secret_malformed_base64(self): def test_decrypt_secret_empty_string(self): """Test decryption with empty string.""" - encryption = FernetEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) result = encryption.decrypt_secret("") @@ -2583,7 +2583,7 @@ def test_decrypt_secret_empty_string(self): def test_is_encrypted_valid_encrypted_data(self): """Test is_encrypted with valid encrypted data.""" - encryption = FernetEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) encrypted = encryption.encrypt_secret("test_data") @@ -2591,14 +2591,14 @@ def test_is_encrypted_valid_encrypted_data(self): def test_is_encrypted_plain_text(self): """Test is_encrypted with plain text.""" - encryption = FernetEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) assert encryption.is_encrypted("plain_text_secret") is False assert encryption.is_encrypted("another_plain_string") is False def test_is_encrypted_short_data(self): """Test is_encrypted with short data.""" - encryption = FernetEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) # Fernet encrypted data should be at least 32 bytes short_data = "dGVzdA==" # "test" in base64 (only 4 bytes when decoded) @@ -2607,7 +2607,7 @@ def test_is_encrypted_short_data(self): def test_is_encrypted_valid_base64_but_not_encrypted(self): """Test is_encrypted with valid base64 that's not encrypted data.""" - encryption = FernetEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) # Create base64 data that's long enough but not encrypted # Standard @@ -2624,32 +2624,32 @@ def test_is_encrypted_valid_base64_but_not_encrypted(self): def test_is_encrypted_invalid_base64(self): """Test is_encrypted with invalid base64.""" - encryption = FernetEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) assert encryption.is_encrypted("not_base64!@#$%") is False def test_is_encrypted_exception_handling(self): """Test exception handling in is_encrypted.""" - encryption = FernetEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) # Test with None (should handle gracefully) with patch("base64.urlsafe_b64decode", side_effect=Exception("Base64 error")): result = encryption.is_encrypted("any_string") assert result is False - def test_get_fernet_encryption_function(self): - """Test the get_fernet_encryption utility function.""" + def test_get_encryption_service_function(self): + """Test the get_encryption_service utility function.""" # First-Party - from mcpgateway.utils.fernet_encryption import get_fernet_encryption + from mcpgateway.services.encryption_service import get_encryption_service - encryption = get_fernet_encryption(SecretStr("test_secret")) + encryption = get_encryption_service(SecretStr("test_secret")) - assert isinstance(encryption, FernetEncryption) + assert isinstance(encryption, EncryptionService) assert encryption.encryption_secret == b"test_secret" def test_encryption_roundtrip_multiple_values(self): """Test encryption/decryption roundtrip with multiple values.""" - encryption = FernetEncryption(SecretStr("test_key")) + encryption = EncryptionService(SecretStr("test_key")) test_values = [ "simple_token", @@ -2670,8 +2670,8 @@ def test_encryption_roundtrip_multiple_values(self): def test_encryption_key_derivation_consistency(self): """Test that key derivation is consistent across instances.""" # Create two instances with same key - encryption1 = FernetEncryption(SecretStr("same_key")) - encryption2 = FernetEncryption(SecretStr("same_key")) + encryption1 = EncryptionService(SecretStr("same_key")) + encryption2 = EncryptionService(SecretStr("same_key")) # Encrypt with first instance plaintext = "test_consistency" @@ -2685,7 +2685,7 @@ def test_encryption_key_derivation_consistency(self): def test_encryption_with_long_key(self): """Test encryption with very long key.""" long_key = SecretStr("a" * 1000) # Very long key - encryption = FernetEncryption(long_key) + encryption = EncryptionService(long_key) encrypted = encryption.encrypt_secret("test_data") decrypted = encryption.decrypt_secret(encrypted) @@ -2695,7 +2695,7 @@ def test_encryption_with_long_key(self): def test_encryption_with_special_char_key(self): """Test encryption with key containing special characters.""" special_key = SecretStr("key_with_special_chars!@#$%^&*()_+-={}[]|\\:;\"'<>?,./") - encryption = FernetEncryption(special_key) + encryption = EncryptionService(special_key) encrypted = encryption.encrypt_secret("test_data") decrypted = encryption.decrypt_secret(encrypted) From ed9d3212d3305cb7d76cdb28c32ec919c72532a8 Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Thu, 30 Oct 2025 19:14:17 +0530 Subject: [PATCH 10/14] Add missing docstrings Signed-off-by: Madhav Kandukuri --- ...3320c56_use_argon2id_for_encryption_key.py | 30 +++++++++++++++++++ mcpgateway/services/metrics.py | 1 + 2 files changed, 31 insertions(+) diff --git a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py index 02d9a8952..e4c453a1a 100644 --- a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py +++ b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py @@ -134,12 +134,28 @@ def reencrypt_with_pbkdf2hmac(argon2id_bundle: str) -> Optional[str]: raise ValueError("Invalid Argon2id bundle") from e def _reflect(conn): + """Reflect relevant tables. + + Args: + conn: The database connection. + + Returns: + A dict of reflected tables. + """ md = sa.MetaData() gateways = sa.Table("gateways", md, autoload_with=conn) a2a_agents = sa.Table("a2a_agents", md, autoload_with=conn) return {"gateways": gateways, "a2a_agents": a2a_agents} def _is_json(col): + """Check if a column is of JSON type. + + Args: + col: The column to check. + + Returns: + True if the column is of JSON type. + """ return isinstance(col.type, sa.JSON) def _looks_argon2_bundle(val: Optional[str]) -> bool: @@ -222,6 +238,12 @@ def _downgrade_value(old: Optional[str]) -> Optional[str]: def _upgrade_json_client_secret(conn, table): + """Upgrade JSON client_secret fields in the given table. + + Args: + conn: The database connection. + table: The table to upgrade. + """ t = table sel = sa.select(t.c.id, t.c.oauth_config).where(t.c.oauth_config.isnot(None)) for row in conn.execute(sel).mappings(): @@ -246,6 +268,12 @@ def _upgrade_json_client_secret(conn, table): def _downgrade_json_client_secret(conn, table): + """Downgrade JSON client_secret fields in the given table. + + Args: + conn: The database connection. + table: The table to downgrade. + """ t = table sel = sa.select(t.c.id, t.c.oauth_config).where(t.c.oauth_config.isnot(None)) for row in conn.execute(sel).mappings(): @@ -270,6 +298,7 @@ def _downgrade_json_client_secret(conn, table): def upgrade() -> None: + """Use Argon2id KDF for encryption key re-encryption.""" bind = op.get_bind() conn = op.get_bind() @@ -385,6 +414,7 @@ def upgrade() -> None: def downgrade() -> None: + """Revert to PBKDF2HMAC KDF for encryption key re-encryption.""" bind = op.get_bind() # JSON: gateways.oauth_config.client_secret diff --git a/mcpgateway/services/metrics.py b/mcpgateway/services/metrics.py index f339d0b25..8d649b012 100644 --- a/mcpgateway/services/metrics.py +++ b/mcpgateway/services/metrics.py @@ -118,4 +118,5 @@ def setup_metrics(app): @app.get("/metrics/prometheus") async def metrics_disabled(): + """Returns metrics response when metrics collection is disabled.""" return Response(content='{"error": "Metrics collection is disabled"}', media_type="application/json", status_code=status.HTTP_503_SERVICE_UNAVAILABLE) From a2f0ca4ca7d9d5616bc135a689ac56969998f4fa Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Thu, 30 Oct 2025 19:36:00 +0530 Subject: [PATCH 11/14] flake8 fixes Signed-off-by: Madhav Kandukuri --- ...06a3320c56_use_argon2id_for_encryption_key.py | 16 ++++++++++++---- mcpgateway/services/metrics.py | 6 +++++- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py index e4c453a1a..39ce84036 100644 --- a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py +++ b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py @@ -20,6 +20,7 @@ from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from sqlalchemy import text +import sqlalchemy as sa # First-Party from mcpgateway.config import settings @@ -133,6 +134,7 @@ def reencrypt_with_pbkdf2hmac(argon2id_bundle: str) -> Optional[str]: except Exception as e: raise ValueError("Invalid Argon2id bundle") from e + def _reflect(conn): """Reflect relevant tables. @@ -147,6 +149,7 @@ def _reflect(conn): a2a_agents = sa.Table("a2a_agents", md, autoload_with=conn) return {"gateways": gateways, "a2a_agents": a2a_agents} + def _is_json(col): """Check if a column is of JSON type. @@ -158,6 +161,7 @@ def _is_json(col): """ return isinstance(col.type, sa.JSON) + def _looks_argon2_bundle(val: Optional[str]) -> bool: """Heuristic for Argon2id bundle format (JSON with kdf=argon2id). @@ -255,11 +259,13 @@ def _upgrade_json_client_secret(conn, table): except json.JSONDecodeError as e: logger.warning("Skipping %s.id=%s: invalid JSON (%s)", table, rid, e) continue - if not isinstance(cfg, dict): continue + if not isinstance(cfg, dict): + continue old = cfg.get("client_secret") new = _upgrade_value(old) # your helper - if not new: continue + if not new: + continue cfg["client_secret"] = new value = cfg if _is_json(t.c.oauth_config) else json.dumps(cfg) @@ -285,11 +291,13 @@ def _downgrade_json_client_secret(conn, table): except json.JSONDecodeError as e: logger.warning("Skipping %s.id=%s: invalid JSON (%s)", table, rid, e) continue - if not isinstance(cfg, dict): continue + if not isinstance(cfg, dict): + continue old = cfg.get("client_secret") new = _downgrade_value(old) # your helper - if not new: continue + if not new: + continue cfg["client_secret"] = new value = cfg if _is_json(t.c.oauth_config) else json.dumps(cfg) diff --git a/mcpgateway/services/metrics.py b/mcpgateway/services/metrics.py index 8d649b012..b3276ba23 100644 --- a/mcpgateway/services/metrics.py +++ b/mcpgateway/services/metrics.py @@ -118,5 +118,9 @@ def setup_metrics(app): @app.get("/metrics/prometheus") async def metrics_disabled(): - """Returns metrics response when metrics collection is disabled.""" + """Returns metrics response when metrics collection is disabled. + + Returns: + Response: HTTP 503 response indicating metrics are disabled. + """ return Response(content='{"error": "Metrics collection is disabled"}', media_type="application/json", status_code=status.HTTP_503_SERVICE_UNAVAILABLE) From 17c8ca6c8a6839f260e259a24178143ce5dbaa40 Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Thu, 30 Oct 2025 22:29:08 +0530 Subject: [PATCH 12/14] Handle str inputs for encryption_secret Signed-off-by: Madhav Kandukuri --- Containerfile.lite | 2 +- Makefile | 2 +- mcpgateway/services/encryption_service.py | 18 +++++++++++++----- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/Containerfile.lite b/Containerfile.lite index 94b8fa0f0..92f50fc48 100644 --- a/Containerfile.lite +++ b/Containerfile.lite @@ -76,7 +76,7 @@ SHELL ["/bin/bash", "-euo", "pipefail", "-c"] ARG PYTHON_VERSION ARG ROOTFS_PATH -ARG TARGETPLATFORM +ARG TARGETPLATFORM=linux/amd64 ARG GRPC_PYTHON_BUILD_SYSTEM_OPENSSL='False' # ---------------------------------------------------------------------------- diff --git a/Makefile b/Makefile index d8ced2b75..3331906c8 100644 --- a/Makefile +++ b/Makefile @@ -2658,7 +2658,7 @@ docker-dev: @$(MAKE) container-build CONTAINER_RUNTIME=docker CONTAINER_FILE=Containerfile docker: - @$(MAKE) container-build CONTAINER_RUNTIME=docker CONTAINER_FILE=Containerfile + @$(MAKE) container-build CONTAINER_RUNTIME=docker CONTAINER_FILE=Containerfile.lite docker-prod: @DOCKER_CONTENT_TRUST=1 $(MAKE) container-build CONTAINER_RUNTIME=docker CONTAINER_FILE=Containerfile.lite diff --git a/mcpgateway/services/encryption_service.py b/mcpgateway/services/encryption_service.py index 369a13644..715d9a174 100644 --- a/mcpgateway/services/encryption_service.py +++ b/mcpgateway/services/encryption_service.py @@ -15,7 +15,7 @@ import json import logging import os -from typing import Optional +from typing import Optional, Union # Third-Party from argon2.low_level import hash_secret_raw, Type @@ -45,7 +45,7 @@ class EncryptionService: False """ - def __init__(self, encryption_secret: SecretStr, time_cost: Optional[int] = None, memory_cost: Optional[int] = None, parallelism: Optional[int] = None, hash_len: int = 32, salt_len: int = 16): + def __init__(self, encryption_secret: Union[SecretStr, str], time_cost: Optional[int] = None, memory_cost: Optional[int] = None, parallelism: Optional[int] = None, hash_len: int = 32, salt_len: int = 16): """Initialize the encryption handler. Args: @@ -56,7 +56,12 @@ def __init__(self, encryption_secret: SecretStr, time_cost: Optional[int] = None hash_len: Length of the derived key salt_len: Length of the salt """ - self.encryption_secret = encryption_secret.get_secret_value().encode() + # Handle both SecretStr and plain string for backwards compatibility + if isinstance(encryption_secret, SecretStr): + self.encryption_secret = encryption_secret.get_secret_value().encode() + else: + # If a plain string is passed, use it directly (for testing/legacy code) + self.encryption_secret = str(encryption_secret).encode() self.time_cost = time_cost or getattr(settings, "argon2id_time_cost", 3) self.memory_cost = memory_cost or getattr(settings, "argon2id_memory_cost", 65536) self.parallelism = parallelism or getattr(settings, "argon2id_parallelism", 1) @@ -156,11 +161,11 @@ def is_encrypted(self, text: str) -> bool: return False -def get_encryption_service(encryption_secret: SecretStr) -> EncryptionService: +def get_encryption_service(encryption_secret: Union[SecretStr, str]) -> EncryptionService: """Get an EncryptionService instance. Args: - encryption_secret: Secret key for encryption/decryption + encryption_secret: Secret key for encryption/decryption (SecretStr or plain string) Returns: EncryptionService instance @@ -169,5 +174,8 @@ def get_encryption_service(encryption_secret: SecretStr) -> EncryptionService: >>> enc = get_encryption_service(SecretStr('k')) >>> isinstance(enc, EncryptionService) True + >>> enc2 = get_encryption_service('plain-key') + >>> isinstance(enc2, EncryptionService) + True """ return EncryptionService(encryption_secret) From 277ceb116ea675ddc10b6d34050f91a3ad3d9a3c Mon Sep 17 00:00:00 2001 From: Madhav Kandukuri Date: Fri, 31 Oct 2025 13:34:56 +0530 Subject: [PATCH 13/14] Update alembic down revision number Signed-off-by: Madhav Kandukuri --- .../versions/a706a3320c56_use_argon2id_for_encryption_key.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py index 39ce84036..67463b19a 100644 --- a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py +++ b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py @@ -1,7 +1,7 @@ """Use Argon2id for encryption key Revision ID: a706a3320c56 -Revises: 3c89a45f32e5 +Revises: h2b3c4d5e6f7 Create Date: 2025-10-30 15:31:25.115536 """ @@ -29,7 +29,7 @@ # revision identifiers, used by Alembic. revision: str = "a706a3320c56" -down_revision: Union[str, Sequence[str], None] = "3c89a45f32e5" +down_revision: Union[str, Sequence[str], None] = "h2b3c4d5e6f7" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None From 8fccacf9e4a9fdcd5aaabfc53ed98037da3e37f8 Mon Sep 17 00:00:00 2001 From: Mihai Criveti Date: Wed, 5 Nov 2025 00:33:18 +0000 Subject: [PATCH 14/14] Fix bandit --- ...3320c56_use_argon2id_for_encryption_key.py | 7 +++--- mcpgateway/services/encryption_service.py | 24 +++++++++++++++++-- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py index 67463b19a..09f3075a4 100644 --- a/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py +++ b/mcpgateway/alembic/versions/a706a3320c56_use_argon2id_for_encryption_key.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- """Use Argon2id for encryption key Revision ID: a706a3320c56 @@ -19,8 +20,8 @@ from cryptography.fernet import Fernet from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC -from sqlalchemy import text import sqlalchemy as sa +from sqlalchemy import text # First-Party from mcpgateway.config import settings @@ -263,7 +264,7 @@ def _upgrade_json_client_secret(conn, table): continue old = cfg.get("client_secret") - new = _upgrade_value(old) # your helper + new = _upgrade_value(old) # your helper if not new: continue @@ -295,7 +296,7 @@ def _downgrade_json_client_secret(conn, table): continue old = cfg.get("client_secret") - new = _downgrade_value(old) # your helper + new = _downgrade_value(old) # your helper if not new: continue diff --git a/mcpgateway/services/encryption_service.py b/mcpgateway/services/encryption_service.py index 715d9a174..b5c926851 100644 --- a/mcpgateway/services/encryption_service.py +++ b/mcpgateway/services/encryption_service.py @@ -45,7 +45,9 @@ class EncryptionService: False """ - def __init__(self, encryption_secret: Union[SecretStr, str], time_cost: Optional[int] = None, memory_cost: Optional[int] = None, parallelism: Optional[int] = None, hash_len: int = 32, salt_len: int = 16): + def __init__( + self, encryption_secret: Union[SecretStr, str], time_cost: Optional[int] = None, memory_cost: Optional[int] = None, parallelism: Optional[int] = None, hash_len: int = 32, salt_len: int = 16 + ): """Initialize the encryption handler. Args: @@ -151,9 +153,27 @@ def is_encrypted(self, text: str) -> bool: Returns: True if the string appears to be encrypted + + Note: + Supports both legacy PBKDF2 (base64-wrapped Fernet) and new Argon2id + (JSON bundle) formats. Checks JSON format first, then falls back to + base64 check for legacy format. """ + if not text: + return False + + # Check for new Argon2id JSON bundle format + if text.startswith("{"): + try: + obj = json.loads(text) + if isinstance(obj, dict) and obj.get("kdf") == "argon2id": + return True + except (json.JSONDecodeError, ValueError, KeyError): + # Not valid JSON or missing expected structure - continue to legacy check + pass + + # Check for legacy PBKDF2 base64-wrapped Fernet format try: - # Try to decode as base64 and check if it looks like encrypted data decoded = base64.urlsafe_b64decode(text.encode()) # Encrypted data should be at least 32 bytes (Fernet minimum) return len(decoded) >= 32