diff --git a/README.md b/README.md index bd0a42d..495fad6 100644 --- a/README.md +++ b/README.md @@ -38,11 +38,14 @@ uv run mcp_proxy_for_aws/server.py docker build -t mcp-proxy-for-aws . ``` -## Configuration Parameters +## Configuration + +### Using Command Line Arguments | Parameter | Description | Default |Required | |----------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------|--- | -| `endpoint` | MCP endpoint URL (e.g., `https://your-service.us-east-1.amazonaws.com/mcp`) | N/A |Yes | +| `endpoint` | MCP endpoint URL (e.g., `https://your-service.us-east-1.amazonaws.com/mcp`) | N/A |Yes* | +| `--config` | Path to YAML configuration file | N/A |No | | --- | --- | --- |--- | | `--service` | AWS service name for SigV4 signing | Inferred from endpoint if not provided |No | | `--profile` | AWS profile for AWS credentials to use | Uses `AWS_PROFILE` environment variable if not set |No | @@ -55,6 +58,39 @@ docker build -t mcp-proxy-for-aws . | `--read-timeout` | Set desired read timeout in seconds | 120 |No | | `--write-timeout` | Set desired write timeout in seconds | 180 |No | +*Required unless using `--config` + +### Using Configuration File + +You can use a YAML configuration file instead of command line arguments: + +```bash +# Copy the example configuration +cp config.example.yaml config.yaml + +# Edit config.yaml with your settings +# Then run with the config file +uvx mcp-proxy-for-aws@latest --config config.yaml +``` + +Example `config.yaml`: + +```yaml +endpoint: "https://your-service.us-east-1.amazonaws.com/mcp" +service: "your-service" +profile: "default" +region: "us-east-1" +read_only: false +log_level: "INFO" +retries: 0 +timeout: 180.0 +connect_timeout: 60.0 +read_timeout: 120.0 +write_timeout: 180.0 +``` + +Command line arguments override configuration file settings. + ## Optional Environment Variables @@ -80,6 +116,8 @@ Add the following configuration to your MCP client config file (e.g., for Amazon ### Running from local - using uv +With command line arguments: + ``` { "mcpServers": { @@ -108,6 +146,28 @@ Add the following configuration to your MCP client config file (e.g., for Amazon } ``` +Or with a configuration file: + +``` +{ + "mcpServers": { + "": { + "disabled": false, + "type": "stdio", + "command": "uv", + "args": [ + "--directory", + "/path/to/mcp_proxy_for_aws", + "run", + "server.py", + "--config", + "/path/to/config.yaml" + ] + } + } +} +``` + ### Using Docker ``` diff --git a/config.example.yaml b/config.example.yaml new file mode 100644 index 0000000..0a350e3 --- /dev/null +++ b/config.example.yaml @@ -0,0 +1,38 @@ +# MCP Proxy for AWS Configuration Example +# Save this file as config.yaml and customize for your environment + +# Required: SigV4 MCP endpoint URL +endpoint: "https://your-service.us-east-1.amazonaws.com/mcp" + +# Optional: AWS service name for SigV4 signing +# If not provided, will be inferred from the endpoint +service: "your-service" + +# Optional: AWS profile to use +# If not provided, uses AWS_PROFILE environment variable +profile: "default" + +# Optional: AWS region to use +# If not provided, uses AWS_REGION environment variable or infers from endpoint +region: "us-east-1" + +# Optional: Disable tools which may require write permissions +# Default: false +read_only: false + +# Optional: Logging level +# Choices: DEBUG, INFO, WARNING, ERROR, CRITICAL +# Default: INFO +log_level: "INFO" + +# Optional: Number of retries when calling endpoint +# Range: 0-10, 0 disables retries +# Default: 0 +retries: 0 + +# Optional: Timeout settings (in seconds) +# Default values shown below +timeout: 180.0 +connect_timeout: 60.0 +read_timeout: 120.0 +write_timeout: 180.0 diff --git a/mcp_proxy_for_aws/cli.py b/mcp_proxy_for_aws/cli.py index db3a930..b99d5ca 100644 --- a/mcp_proxy_for_aws/cli.py +++ b/mcp_proxy_for_aws/cli.py @@ -39,14 +39,23 @@ def parse_args(): # Run with write permissions enabled mcp-proxy-for-aws --read-only + + # Run with configuration file + mcp-proxy-for-aws --config config.yaml """, ) parser.add_argument( 'endpoint', + nargs='?', help='SigV4 MCP endpoint URL', ) + parser.add_argument( + '--config', + help='Path to YAML configuration file', + ) + parser.add_argument( '--service', help='AWS service name for SigV4 signing (inferred from endpoint if not provided)', diff --git a/mcp_proxy_for_aws/config.py b/mcp_proxy_for_aws/config.py new file mode 100644 index 0000000..7d89376 --- /dev/null +++ b/mcp_proxy_for_aws/config.py @@ -0,0 +1,162 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration file loading for MCP Proxy for AWS.""" + +import logging +import os +import yaml +from pathlib import Path +from typing import Any, Dict, Optional + + +logger = logging.getLogger(__name__) + + +class Config: + """Configuration container for MCP Proxy.""" + + def __init__( + self, + endpoint: str, + service: Optional[str] = None, + profile: Optional[str] = None, + region: Optional[str] = None, + read_only: bool = False, + log_level: str = 'INFO', + retries: int = 0, + timeout: float = 180.0, + connect_timeout: float = 60.0, + read_timeout: float = 120.0, + write_timeout: float = 180.0, + ): + """Initialize configuration. + + Args: + endpoint: SigV4 MCP endpoint URL + service: AWS service name for SigV4 signing + profile: AWS profile to use + region: AWS region to use + read_only: Disable tools which may require write permissions + log_level: Logging level + retries: Number of retries when calling endpoint + timeout: Timeout when connecting to endpoint + connect_timeout: Connection timeout + read_timeout: Read timeout + write_timeout: Write timeout + """ + self.endpoint = endpoint + self.service = service + self.profile = profile + self.region = region + self.read_only = read_only + self.log_level = log_level + self.retries = retries + self.timeout = timeout + self.connect_timeout = connect_timeout + self.read_timeout = read_timeout + self.write_timeout = write_timeout + + +def load_config_file(config_path: str) -> Dict[str, Any]: + """Load configuration from YAML file. + + Args: + config_path: Path to YAML configuration file + + Returns: + Dictionary containing configuration values + + Raises: + FileNotFoundError: If config file doesn't exist + yaml.YAMLError: If config file is invalid YAML + ValueError: If config file has invalid structure + """ + path = Path(config_path).expanduser() + + if not path.exists(): + raise FileNotFoundError(f'Configuration file not found: {config_path}') + + logger.info('Loading configuration from: %s', config_path) + + with open(path, 'r') as f: + try: + config_data = yaml.safe_load(f) + except yaml.YAMLError as e: + raise yaml.YAMLError(f'Invalid YAML in configuration file: {e}') + + if not isinstance(config_data, dict): + raise ValueError('Configuration file must contain a YAML dictionary') + + return config_data + + +def merge_config(file_config: Optional[Dict[str, Any]], cli_args: Any) -> Config: + """Merge configuration from file and CLI arguments. + + CLI arguments take precedence over file configuration. + + Args: + file_config: Configuration loaded from file (or None) + cli_args: Parsed command-line arguments + + Returns: + Config object with merged configuration + """ + # Start with file config or empty dict + config_dict = file_config.copy() if file_config else {} + + # CLI args override file config (only if explicitly provided) + if hasattr(cli_args, 'endpoint') and cli_args.endpoint: + config_dict['endpoint'] = cli_args.endpoint + + if hasattr(cli_args, 'service') and cli_args.service: + config_dict['service'] = cli_args.service + + if hasattr(cli_args, 'profile') and cli_args.profile: + config_dict['profile'] = cli_args.profile + + if hasattr(cli_args, 'region') and cli_args.region: + config_dict['region'] = cli_args.region + + if hasattr(cli_args, 'read_only') and cli_args.read_only: + config_dict['read_only'] = cli_args.read_only + + if hasattr(cli_args, 'log_level') and cli_args.log_level: + config_dict['log_level'] = cli_args.log_level + + if hasattr(cli_args, 'retries') and cli_args.retries is not None: + config_dict['retries'] = cli_args.retries + + if hasattr(cli_args, 'timeout') and cli_args.timeout is not None: + config_dict['timeout'] = cli_args.timeout + + if hasattr(cli_args, 'connect_timeout') and cli_args.connect_timeout is not None: + config_dict['connect_timeout'] = cli_args.connect_timeout + + if hasattr(cli_args, 'read_timeout') and cli_args.read_timeout is not None: + config_dict['read_timeout'] = cli_args.read_timeout + + if hasattr(cli_args, 'write_timeout') and cli_args.write_timeout is not None: + config_dict['write_timeout'] = cli_args.write_timeout + + # Validate required fields + if 'endpoint' not in config_dict: + raise ValueError('endpoint is required (provide via CLI or config file)') + + # Apply environment variable defaults if not set + if 'profile' not in config_dict: + config_dict['profile'] = os.getenv('AWS_PROFILE') + + return Config(**config_dict) diff --git a/mcp_proxy_for_aws/server.py b/mcp_proxy_for_aws/server.py index 5eb51df..13cbba7 100644 --- a/mcp_proxy_for_aws/server.py +++ b/mcp_proxy_for_aws/server.py @@ -29,6 +29,7 @@ from fastmcp.server.middleware.logging import LoggingMiddleware from fastmcp.server.server import FastMCP from mcp_proxy_for_aws.cli import parse_args +from mcp_proxy_for_aws.config import load_config_file, merge_config from mcp_proxy_for_aws.logging_config import configure_logging from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware from mcp_proxy_for_aws.utils import ( @@ -126,7 +127,23 @@ def add_logging_middleware(mcp: FastMCP, log_level: str) -> None: def main(): """Run the MCP server.""" - args = parse_args() + cli_args = parse_args() + + # Load config file if specified + file_config = None + if cli_args.config: + try: + file_config = load_config_file(cli_args.config) + except Exception as e: + print(f'Error loading configuration file: {e}') + raise + + # Merge file config and CLI args + try: + args = merge_config(file_config, cli_args) + except ValueError as e: + print(f'Configuration error: {e}') + raise # Configure logging configure_logging(args.log_level) diff --git a/pyproject.toml b/pyproject.toml index c5e7344..cdbf428 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "fastmcp>=2.11.2", "boto3>=1.34.0", "botocore>=1.34.0", + "pyyaml>=6.0.0", ] license = {text = "Apache-2.0"} license-files = ["LICENSE", "NOTICE" ] diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000..b4bec94 --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,261 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for configuration file loading.""" + +import argparse +import os +import pytest +import tempfile +import yaml +from mcp_proxy_for_aws.config import Config, load_config_file, merge_config +from pathlib import Path + + +class TestLoadConfigFile: + """Tests for load_config_file function.""" + + def test_load_valid_config(self): + """Test loading a valid YAML configuration file.""" + config_data = { + 'endpoint': 'https://example.com/mcp', + 'service': 'test-service', + 'profile': 'test-profile', + 'region': 'us-west-2', + 'read_only': True, + 'log_level': 'DEBUG', + 'retries': 3, + 'timeout': 200.0, + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(config_data, f) + config_path = f.name + + try: + loaded_config = load_config_file(config_path) + assert loaded_config == config_data + finally: + os.unlink(config_path) + + def test_load_nonexistent_file(self): + """Test loading a non-existent configuration file.""" + with pytest.raises(FileNotFoundError): + load_config_file('/nonexistent/config.yaml') + + def test_load_invalid_yaml(self): + """Test loading an invalid YAML file.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + f.write('invalid: yaml: content:\n - broken') + config_path = f.name + + try: + with pytest.raises(yaml.YAMLError): + load_config_file(config_path) + finally: + os.unlink(config_path) + + def test_load_non_dict_yaml(self): + """Test loading a YAML file that doesn't contain a dictionary.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(['list', 'of', 'items'], f) + config_path = f.name + + try: + with pytest.raises(ValueError, match='must contain a YAML dictionary'): + load_config_file(config_path) + finally: + os.unlink(config_path) + + def test_load_with_tilde_expansion(self): + """Test loading a config file with ~ in path.""" + config_data = {'endpoint': 'https://example.com/mcp'} + + # Create temp file in home directory + home = Path.home() + config_path = home / '.mcp-proxy-test-config.yaml' + + try: + with open(config_path, 'w') as f: + yaml.dump(config_data, f) + + # Load using ~ notation + loaded_config = load_config_file(f'~/{config_path.name}') + assert loaded_config == config_data + finally: + if config_path.exists(): + config_path.unlink() + + +class TestMergeConfig: + """Tests for merge_config function.""" + + def test_merge_with_no_file_config(self): + """Test merging when no file config is provided.""" + cli_args = argparse.Namespace( + endpoint='https://cli.example.com/mcp', + service='cli-service', + profile='cli-profile', + region='us-east-1', + read_only=True, + log_level='INFO', + retries=2, + timeout=150.0, + connect_timeout=50.0, + read_timeout=100.0, + write_timeout=150.0, + ) + + config = merge_config(None, cli_args) + + assert config.endpoint == 'https://cli.example.com/mcp' + assert config.service == 'cli-service' + assert config.profile == 'cli-profile' + assert config.region == 'us-east-1' + assert config.read_only is True + assert config.log_level == 'INFO' + assert config.retries == 2 + + def test_merge_cli_overrides_file(self): + """Test that CLI arguments override file configuration.""" + file_config = { + 'endpoint': 'https://file.example.com/mcp', + 'service': 'file-service', + 'profile': 'file-profile', + 'region': 'us-west-2', + 'read_only': False, + 'log_level': 'WARNING', + 'retries': 1, + } + + cli_args = argparse.Namespace( + endpoint='https://cli.example.com/mcp', + service='cli-service', + profile=None, + region=None, + read_only=True, + log_level='DEBUG', + retries=None, + timeout=None, + connect_timeout=None, + read_timeout=None, + write_timeout=None, + ) + + config = merge_config(file_config, cli_args) + + # CLI overrides + assert config.endpoint == 'https://cli.example.com/mcp' + assert config.service == 'cli-service' + assert config.read_only is True + assert config.log_level == 'DEBUG' + + # File values preserved + assert config.profile == 'file-profile' + assert config.region == 'us-west-2' + assert config.retries == 1 + + def test_merge_with_environment_variables(self, monkeypatch): + """Test that environment variables are used as defaults.""" + monkeypatch.setenv('AWS_PROFILE', 'env-profile') + + file_config = { + 'endpoint': 'https://file.example.com/mcp', + } + + cli_args = argparse.Namespace( + endpoint=None, + service=None, + profile=None, + region=None, + read_only=False, + log_level='INFO', + retries=0, + timeout=180.0, + connect_timeout=60.0, + read_timeout=120.0, + write_timeout=180.0, + ) + + config = merge_config(file_config, cli_args) + + assert config.endpoint == 'https://file.example.com/mcp' + assert config.profile == 'env-profile' + + def test_merge_missing_required_endpoint(self): + """Test that missing endpoint raises ValueError.""" + cli_args = argparse.Namespace( + endpoint=None, + service=None, + profile=None, + region=None, + read_only=False, + log_level='INFO', + retries=0, + timeout=180.0, + connect_timeout=60.0, + read_timeout=120.0, + write_timeout=180.0, + ) + + with pytest.raises(ValueError, match='endpoint is required'): + merge_config({}, cli_args) + + +class TestConfig: + """Tests for Config class.""" + + def test_config_initialization(self): + """Test Config object initialization.""" + config = Config( + endpoint='https://example.com/mcp', + service='test-service', + profile='test-profile', + region='us-east-1', + read_only=True, + log_level='DEBUG', + retries=3, + timeout=200.0, + connect_timeout=70.0, + read_timeout=130.0, + write_timeout=200.0, + ) + + assert config.endpoint == 'https://example.com/mcp' + assert config.service == 'test-service' + assert config.profile == 'test-profile' + assert config.region == 'us-east-1' + assert config.read_only is True + assert config.log_level == 'DEBUG' + assert config.retries == 3 + assert config.timeout == 200.0 + assert config.connect_timeout == 70.0 + assert config.read_timeout == 130.0 + assert config.write_timeout == 200.0 + + def test_config_with_defaults(self): + """Test Config object with default values.""" + config = Config(endpoint='https://example.com/mcp') + + assert config.endpoint == 'https://example.com/mcp' + assert config.service is None + assert config.profile is None + assert config.region is None + assert config.read_only is False + assert config.log_level == 'INFO' + assert config.retries == 0 + assert config.timeout == 180.0 + assert config.connect_timeout == 60.0 + assert config.read_timeout == 120.0 + assert config.write_timeout == 180.0 diff --git a/uv.lock b/uv.lock index 4f46d8c..6d25aa6 100644 --- a/uv.lock +++ b/uv.lock @@ -1631,6 +1631,7 @@ dependencies = [ { name = "boto3" }, { name = "botocore" }, { name = "fastmcp" }, + { name = "pyyaml" }, ] [package.optional-dependencies] @@ -1661,6 +1662,7 @@ requires-dist = [ { name = "fastmcp", specifier = ">=2.11.2" }, { name = "isort", marker = "extra == 'dev'", specifier = ">=5.12.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.3.0" }, + { name = "pyyaml", specifier = ">=6.0.0" }, ] provides-extras = ["dev"]