diff --git a/src/sagemaker/hyperpod/cli/inference_utils.py b/src/sagemaker/hyperpod/cli/inference_utils.py index 4fd76193..79c7966b 100644 --- a/src/sagemaker/hyperpod/cli/inference_utils.py +++ b/src/sagemaker/hyperpod/cli/inference_utils.py @@ -2,6 +2,7 @@ import pkgutil import click from typing import Callable, Optional, Mapping, Type +from .parser_utils import parse_complex_parameter def load_schema_for_version(version: str, schema_pkg: str) -> dict: @@ -22,14 +23,11 @@ def generate_click_command( raise ValueError("You must pass a registry mapping version→Model") def decorator(func: Callable) -> Callable: - # Parser for the single JSON‐dict env var flag + # Parser for dictionary parameters using the shared universal parser def _parse_json_flag(ctx, param, value): if value is None: return None - try: - return json.loads(value) - except json.JSONDecodeError as e: - raise click.BadParameter(f"{param.name!r} must be valid JSON: {e}") + return parse_complex_parameter(ctx, param, value, expected_type=dict) # 1) the wrapper click actually invokes def wrapped_func(*args, **kwargs): @@ -52,9 +50,9 @@ def wrapped_func(*args, **kwargs): default=None, help=( "JSON object of environment variables, e.g. " - '\'{"VAR1":"foo","VAR2":"bar"}\'' + '\'{"VAR1": "foo", "VAR2": "bar"}\'' ), - metavar="JSON", + metavar="JSON" )(wrapped_func) wrapped_func = click.option( @@ -62,21 +60,21 @@ def wrapped_func(*args, **kwargs): callback=_parse_json_flag, type=str, default=None, - help=("JSON object of dimensions, e.g. " '\'{"VAR1":"foo","VAR2":"bar"}\''), + help=("JSON object of dimensions, e.g. " '\'{"VAR1": "foo", "VAR2": "bar"}\''), metavar="JSON", )(wrapped_func) wrapped_func = click.option( "--resources-limits", callback=_parse_json_flag, - help='JSON object of resource limits, e.g. \'{"cpu":"2","memory":"4Gi"}\'', + help='JSON object of resource limits, e.g. 
def parse_complex_parameter(ctx, param, value, expected_type=None, allow_multiple=False):
    """
    Universal parser for complex CLI parameters.

    Each raw string is parsed as a Python literal first; if that fails it is
    parsed as JSON, so JSON-only spellings such as ``true``, ``false`` and
    ``null`` (e.g. ``'{"readOnly": false}'``) are accepted as well as:
    - Lists: '["item1", "item2"]'
    - Dictionaries: '{"key": "value"}'
    - Strings, numbers, booleans
    - Multiple values for repeated flags

    Args:
        ctx: Click context object (unused; present to match the click
            callback signature).
        param: Click parameter object; only ``param.name`` is read, and only
            when building an error message.
        value: Input value(s) to parse - a string, or a list/tuple of strings
            when ``allow_multiple`` is set.
        expected_type: Expected Python type (dict, list, str, etc.) checked
            with ``isinstance``; ``None`` disables the check.
        allow_multiple: Whether to handle multiple values (for repeated flags).

    Returns:
        ``None`` when ``value`` is ``None``; otherwise the parsed object, or a
        list of parsed objects when ``allow_multiple`` is true.

    Raises:
        click.BadParameter: If parsing fails or type validation fails.
    """
    if value is None:
        return None

    if not allow_multiple:
        return _parse_single_literal(param, value, expected_type)

    # A repeated flag normally arrives as a tuple; wrap a lone value so both
    # shapes go through the same loop.
    if not isinstance(value, (list, tuple)):
        value = [value]

    # Items are numbered from 1 in error messages.
    return [
        _parse_single_literal(param, raw, expected_type, position=f" item {index}")
        for index, raw in enumerate(value, start=1)
    ]


def _parse_single_literal(param, raw, expected_type, position=""):
    """Parse one raw CLI string (Python literal, else JSON) and type-check it."""
    # Function-scope import so this helper does not depend on a module-level
    # ``import json`` being present.
    import json

    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as literal_err:
        # literal_eval rejects JSON's true/false/null; retry as JSON before
        # giving up so documented JSON examples keep working.
        try:
            parsed = json.loads(raw)
        except (ValueError, TypeError, RecursionError):
            # Report the literal_eval error so the message text is unchanged
            # for inputs that are invalid in both syntaxes.
            raise click.BadParameter(
                f"Invalid format for {param.name}{position}: {raw}. "
                f"Expected Python literal (dict, list, string, etc.). Error: {literal_err}"
            ) from literal_err

    if expected_type is not None and not isinstance(parsed, expected_type):
        raise click.BadParameter(
            f"{param.name}{position} must be {expected_type.__name__}, "
            f"got {type(parsed).__name__}: {raw}"
        )

    return parsed
volumes.append(parts) - except Exception as e: - raise click.UsageError(f"Error parsing volume {i+1}: {str(e)}") - - # Note: Detailed validation will be handled by schema validation - return volumes + """Parse volume parameters (multiple dictionaries)""" + return parse_complex_parameter(ctx, param, value, dict, allow_multiple=True) # 1) the wrapper click will call def wrapped_func(*args, **kwargs): @@ -111,37 +89,37 @@ def wrapped_func(*args, **kwargs): wrapped_func = click.option( "--environment", - callback=_parse_json_flag, + callback=_parse_dict_flag, type=str, default=None, help=( - "JSON object of environment variables, e.g. " - '\'{"VAR1":"foo","VAR2":"bar"}\'' + "Dictionary of environment variables, e.g. " + '\'{"VAR1": "foo", "VAR2": "bar"}\'' ), - metavar="JSON", + metavar="DICT", )(wrapped_func) wrapped_func = click.option( "--label_selector", - callback=_parse_json_flag, - help='JSON object of resource limits, e.g. \'{"cpu":"2","memory":"4Gi"}\'', - metavar="JSON", + callback=_parse_dict_flag, + help='Dictionary of resource limits, e.g. \'{"cpu": "2", "memory": "4Gi"}\'', + metavar="DICT", )(wrapped_func) wrapped_func = click.option( "--volume", multiple=True, callback=_parse_volume_param, - help="List of volume configurations. \ - Command structure: --volume name=,type=,mount_path=, \ - For hostPath: --volume name=model-data,type=hostPath,mount_path=/data,path=/data \ - For persistentVolumeClaim: --volume name=training-output,type=pvc,mount_path=/mnt/output,claim_name=training-output-pvc,read_only=false \ - If multiple --volume flag if multiple volumes are needed.", + help='Volume configurations as dictionaries. 
\ + Example: --volume \'{"name": "vol1", "type": "hostPath", "mountPath": "/data", "hostPath": "/host"}\' \ + For hostPath: --volume \'{"name": "model-data", "type": "hostPath", "mountPath": "/data", "hostPath": "/data"}\' \ + For persistentVolumeClaim: --volume \'{"name": "training-output", "type": "pvc", "mountPath": "/mnt/output", "claimName": "training-output-pvc", "readOnly": false}\' \ + Use multiple --volume flags if multiple volumes are needed.', )(wrapped_func) # Add list options list_params = { - "command": "List of command arguments", - "args": "List of script arguments, e.g. '[--batch-size, 32, --learning-rate, 0.001]'", + "command": 'List of command arguments, e.g. \'["python", "train.py"]\'', + "args": 'List of script arguments, e.g. \'["--batch-size", "32", "--learning-rate", "0.001"]\'', } for param_name, help_text in list_params.items(): diff --git a/src/sagemaker/hyperpod/inference/jumpstart_public_hub_visualization_utils.py b/src/sagemaker/hyperpod/inference/jumpstart_public_hub_visualization_utils.py index a3c1d63b..b6f6b967 100644 --- a/src/sagemaker/hyperpod/inference/jumpstart_public_hub_visualization_utils.py +++ b/src/sagemaker/hyperpod/inference/jumpstart_public_hub_visualization_utils.py @@ -296,4 +296,4 @@ def _style_dataframe(df): def _get_table_layout(data_length): """Get appropriate table layout based on data size.""" - return {} if data_length > 10 else {"topStart": None, "topEnd": "search"} + return {} if data_length > 10 else {"topStart": None, "topEnd": "search"} \ No newline at end of file diff --git a/test/unit_tests/cli/test_inference_utils.py b/test/unit_tests/cli/test_inference_utils.py index 94db7dd9..98861b6e 100644 --- a/test/unit_tests/cli/test_inference_utils.py +++ b/test/unit_tests/cli/test_inference_utils.py @@ -82,10 +82,10 @@ def cmd(namespace, version, domain): out = json.loads(res_ok.output) assert out == {'env': {'a':1}, 'dimensions': {'b':2}, 'limits': {'c':3}, 'reqs': {'d':4}} - # invalid JSON produces click 
error + # invalid format produces click error res_err = self.runner.invoke(cmd, ['--env', 'notjson']) assert res_err.exit_code == 2 - assert 'must be valid JSON' in res_err.output + assert 'Invalid format for' in res_err.output and 'Expected Python literal' in res_err.output @patch('sagemaker.hyperpod.cli.inference_utils.load_schema_for_version') def test_type_mapping_and_defaults(self, mock_load_schema): diff --git a/test/unit_tests/cli/test_training_utils.py b/test/unit_tests/cli/test_training_utils.py index 683280b4..9c7a4442 100644 --- a/test/unit_tests/cli/test_training_utils.py +++ b/test/unit_tests/cli/test_training_utils.py @@ -89,10 +89,10 @@ def cmd(version, debug, config): 'label_selector': {'key': 'value'} } - # Test invalid JSON input - result = self.runner.invoke(cmd, ['--environment', 'invalid']) + # Test invalid Python literal input + result = self.runner.invoke(cmd, ['--environment', 'invalid_python_syntax']) assert result.exit_code == 2 - assert 'must be valid JSON' in result.output + assert 'Invalid format' in result.output @patch('sagemaker.hyperpod.cli.training_utils.pkgutil.get_data') def test_list_parameters(self, mock_get_data): @@ -124,10 +124,10 @@ def cmd(version, debug, config): 'args': config.args })) - # Test list input + # Test list input - using consistent single quotes outside, double quotes inside result = self.runner.invoke(cmd, [ - '--command', '[python, train.py]', - '--args', '[--epochs, 10]' + '--command', '["python", "train.py"]', + '--args', '["--epochs", "10"]' ]) assert result.exit_code == 0 output = json.loads(result.output) @@ -256,7 +256,7 @@ def cmd(version, debug, config): # Test single hostPath volume result = self.runner.invoke(cmd, [ - '--volume', 'name=model-data,type=hostPath,mount_path=/data,path=/host/data' + '--volume', '{"name": "model-data", "type": "hostPath", "mount_path": "/data", "path": "/host/data"}' ]) assert result.exit_code == 0 output = json.loads(result.output) @@ -270,7 +270,7 @@ def cmd(version, 
debug, config): # Test single PVC volume result = self.runner.invoke(cmd, [ - '--volume', 'name=training-output,type=pvc,mount_path=/output,claim_name=my-pvc,read_only=false' + '--volume', '{"name": "training-output", "type": "pvc", "mount_path": "/output", "claim_name": "my-pvc", "read_only": "false"}' ]) assert result.exit_code == 0 output = json.loads(result.output) @@ -285,8 +285,8 @@ def cmd(version, debug, config): # Test multiple volumes result = self.runner.invoke(cmd, [ - '--volume', 'name=model-data,type=hostPath,mount_path=/data,path=/host/data', - '--volume', 'name=training-output,type=pvc,mount_path=/output,claim_name=my-pvc,read_only=true' + '--volume', '{"name": "model-data", "type": "hostPath", "mount_path": "/data", "path": "/host/data"}', + '--volume', '{"name": "training-output", "type": "pvc", "mount_path": "/output", "claim_name": "my-pvc", "read_only": "true"}' ]) assert result.exit_code == 0 output = json.loads(result.output) @@ -372,7 +372,7 @@ def cmd(version, debug, config): result = self.runner.invoke(cmd, [ '--job-name', 'test-job', '--image', 'test-image', - '--volume', 'name=model-data,type=hostPath,mount_path=/data,path=/host/data' + '--volume', '{"name": "model-data", "type": "hostPath", "mount_path": "/data", "path": "/host/data"}' ]) assert result.exit_code == 0 output = json.loads(result.output) @@ -383,7 +383,7 @@ def cmd(version, debug, config): result = self.runner.invoke(cmd, [ '--job-name', 'test-job', '--image', 'test-image', - '--volume', 'name=training-output,type=pvc,mount_path=/output,claim_name=my-pvc,read_only=true' + '--volume', '{"name": "training-output", "type": "pvc", "mount_path": "/output", "claim_name": "my-pvc", "read_only": "true"}' ]) assert result.exit_code == 0 output = json.loads(result.output) @@ -394,7 +394,7 @@ def cmd(version, debug, config): @patch('sagemaker.hyperpod.cli.training_utils.pkgutil.get_data') def test_volume_flag_parsing_errors(self, mock_get_data): - """Test volume flag parsing error 
handling""" + """Test volume flag parsing error handling with new format""" schema = { 'properties': { 'volume': { @@ -421,23 +421,23 @@ def to_domain(self): def cmd(version, debug, config): click.echo("success") - # Test invalid format (missing equals sign) + # Test invalid Python literal (old key=value format) result = self.runner.invoke(cmd, [ - '--volume', 'name=model-data,type=hostPath,mount_path,path=/host/data' + '--volume', 'name=model-data,type=hostPath,mount_path=/data' ]) assert result.exit_code == 2 - assert "should be key=value" in result.output + assert "Invalid format" in result.output - # Test empty volume parameter + # Test invalid JSON syntax result = self.runner.invoke(cmd, [ - '--volume', '' + '--volume', '{"name": "test", invalid}' ]) assert result.exit_code == 2 - assert "Error parsing volume" in result.output + assert "Invalid format" in result.output @patch('sagemaker.hyperpod.cli.training_utils.pkgutil.get_data') - def test_volume_flag_with_equals_in_value(self, mock_get_data): - """Test volume flag parsing with equals signs in values""" + def test_volume_flag_with_special_characters(self, mock_get_data): + """Test volume flag parsing with special characters in new format""" schema = { 'properties': { 'volume': { @@ -466,9 +466,9 @@ def cmd(version, debug, config): 'volume': config.volume if hasattr(config, 'volume') else None })) - # Test volume with equals sign in path value + # Test volume with special characters in path result = self.runner.invoke(cmd, [ - '--volume', 'name=model-data,type=hostPath,mount_path=/data,path=/host/data=special' + '--volume', '{"name": "model-data", "type": "hostPath", "mount_path": "/data", "path": "/host/data=special"}' ]) assert result.exit_code == 0 output = json.loads(result.output) @@ -478,4 +478,4 @@ def cmd(version, debug, config): 'mount_path': '/data', 'path': '/host/data=special' }] - assert output['volume'] == expected_volume \ No newline at end of file + assert output['volume'] == expected_volume