Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions src/sagemaker/hyperpod/cli/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pkgutil
import click
from typing import Callable, Optional, Mapping, Type
from .parser_utils import parse_complex_parameter


def load_schema_for_version(version: str, schema_pkg: str) -> dict:
Expand All @@ -22,14 +23,11 @@ def generate_click_command(
raise ValueError("You must pass a registry mapping version→Model")

def decorator(func: Callable) -> Callable:
# Parser for the single JSON‐dict env var flag
# Parser for dictionary parameters using the shared universal parser
def _parse_json_flag(ctx, param, value):
if value is None:
return None
try:
return json.loads(value)
except json.JSONDecodeError as e:
raise click.BadParameter(f"{param.name!r} must be valid JSON: {e}")
return parse_complex_parameter(ctx, param, value, expected_type=dict)

# 1) the wrapper click actually invokes
def wrapped_func(*args, **kwargs):
Expand All @@ -52,31 +50,31 @@ def wrapped_func(*args, **kwargs):
default=None,
help=(
"JSON object of environment variables, e.g. "
'\'{"VAR1":"foo","VAR2":"bar"}\''
'\'{"VAR1": "foo", "VAR2": "bar"}\''
),
metavar="JSON",
metavar="JSON"
)(wrapped_func)

wrapped_func = click.option(
"--dimensions",
callback=_parse_json_flag,
type=str,
default=None,
help=("JSON object of dimensions, e.g. " '\'{"VAR1":"foo","VAR2":"bar"}\''),
help=("JSON object of dimensions, e.g. " '\'{"VAR1": "foo", "VAR2": "bar"}\''),
metavar="JSON",
)(wrapped_func)

wrapped_func = click.option(
"--resources-limits",
callback=_parse_json_flag,
help='JSON object of resource limits, e.g. \'{"cpu":"2","memory":"4Gi"}\'',
help='JSON object of resource limits, e.g. \'{"cpu": "2", "memory": "4Gi"}\'',
metavar="JSON",
)(wrapped_func)

wrapped_func = click.option(
"--resources-requests",
callback=_parse_json_flag,
help='JSON object of resource requests, e.g. \'{"cpu":"1","memory":"2Gi"}\'',
help='JSON object of resource requests, e.g. \'{"cpu": "1", "memory": "2Gi"}\'',
metavar="JSON",
)(wrapped_func)

Expand Down
78 changes: 78 additions & 0 deletions src/sagemaker/hyperpod/cli/parser_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""
Shared parser utilities for complex CLI parameters
Provides universal parsing for lists, dicts, and other Python literals
"""

import ast
import click


def parse_complex_parameter(ctx, param, value, expected_type=None, allow_multiple=False):
    """
    Universal parser for complex CLI parameters

    Each raw string is first parsed as a Python literal with
    ``ast.literal_eval`` (which never executes code). If that fails, the
    string is retried as JSON so that JSON-only spellings such as
    ``true``/``false``/``null`` are accepted as well, e.g.
    '{"readOnly": false}'.

    Handles parsing of:
    - Lists: '["item1", "item2"]'
    - Dictionaries: '{"key": "value"}'
    - Strings, numbers, booleans
    - Multiple values for repeated flags

    Args:
        ctx: Click context object (unused; present to match the Click
            callback signature)
        param: Click parameter object; only ``param.name`` is read, and only
            when building an error message
        value: Input value(s) to parse - can be string or list/tuple of strings
        expected_type: Expected Python type (dict, list, str, etc.) for validation
        allow_multiple: Whether to handle multiple values (for repeated flags)

    Returns:
        None if ``value`` is None. Otherwise the parsed Python object, or a
        list of parsed objects when ``allow_multiple`` is true.

    Raises:
        click.BadParameter: If parsing fails or type validation fails
    """
    if value is None:
        return None

    # Handle multiple values (like --volume used multiple times)
    if allow_multiple:
        # A repeated flag arrives as a tuple/list; tolerate a bare value too.
        if not isinstance(value, (list, tuple)):
            value = [value]
        return [
            _parse_single_value(param, item, expected_type, position)
            for position, item in enumerate(value, start=1)
        ]

    # Handle single value
    return _parse_single_value(param, value, expected_type)


def _parse_single_value(param, raw, expected_type=None, position=None):
    """Parse one raw CLI string and validate its type.

    Args:
        param: Click parameter object, used for ``param.name`` in errors
        raw: The raw string to parse
        expected_type: Optional type the parsed result must be an instance of
        position: 1-based index of this value among repeated flags, or None
            for a single-valued parameter; only affects error wording

    Returns:
        The parsed Python object.

    Raises:
        click.BadParameter: If ``raw`` is neither a Python literal nor JSON,
            or if the parsed object is not an ``expected_type`` instance
    """
    import json  # local import keeps this block self-contained

    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as literal_error:
        # literal_eval is documented to raise any of the above on malformed
        # input (e.g. TypeError for an unhashable dict key). Before giving
        # up, retry as JSON so true/false/null are understood.
        try:
            parsed = json.loads(raw)
        except (ValueError, TypeError, RecursionError):
            label = param.name if position is None else f"{param.name} item {position}"
            raise click.BadParameter(
                f"Invalid format for {label}: {raw}. "
                f"Expected Python literal (dict, list, string, etc.). Error: {literal_error}"
            ) from literal_error

    # Type validation
    if expected_type and not isinstance(parsed, expected_type):
        label = param.name if position is None else f"{param.name} item {position}"
        raise click.BadParameter(
            f"{label} must be {expected_type.__name__}, "
            f"got {type(parsed).__name__}: {raw}"
        )

    return parsed
72 changes: 25 additions & 47 deletions src/sagemaker/hyperpod/cli/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import click
from typing import Callable, Optional, Mapping, Type, Dict, Any
from pydantic import ValidationError
from .parser_utils import parse_complex_parameter


def load_schema_for_version(
Expand Down Expand Up @@ -41,42 +42,19 @@ def generate_click_command(
raise ValueError("You must pass a registry mapping version→Model")

def decorator(func: Callable) -> Callable:
# Parser for the single JSON‐dict env var flag
def _parse_json_flag(ctx, param, value):
if value is None:
return None
try:
return json.loads(value)
except json.JSONDecodeError as e:
raise click.BadParameter(f"{param.name!r} must be valid JSON: {e}")

# Parser for list flags
# Specific parser functions for different parameter types
def _parse_list_flag(ctx, param, value):
if value is None:
return None
# Remove brackets and split by comma
value = value.strip("[]")
return [item.strip() for item in value.split(",") if item.strip()]

"""Parse list parameters like --command and --args"""
return parse_complex_parameter(ctx, param, value, list)

def _parse_dict_flag(ctx, param, value):
"""Parse dictionary parameters like --environment and --label_selector"""
return parse_complex_parameter(ctx, param, value, dict)

def _parse_volume_param(ctx, param, value):
"""Parse volume parameters from command line format to dictionary format."""
volumes = []
for i, v in enumerate(value):
try:
# Split by comma and then by equals, with validation
parts = {}
for item in v.split(','):
if '=' not in item:
raise click.UsageError(f"Invalid volume format in volume {i+1}: '{item}' should be key=value")
key, val = item.split('=', 1) # Split only on first '=' to handle values with '='
parts[key.strip()] = val.strip()

volumes.append(parts)
except Exception as e:
raise click.UsageError(f"Error parsing volume {i+1}: {str(e)}")

# Note: Detailed validation will be handled by schema validation
return volumes
"""Parse volume parameters (multiple dictionaries)"""
return parse_complex_parameter(ctx, param, value, dict, allow_multiple=True)

# 1) the wrapper click will call
def wrapped_func(*args, **kwargs):
Expand Down Expand Up @@ -111,37 +89,37 @@ def wrapped_func(*args, **kwargs):

wrapped_func = click.option(
"--environment",
callback=_parse_json_flag,
callback=_parse_dict_flag,
type=str,
default=None,
help=(
"JSON object of environment variables, e.g. "
'\'{"VAR1":"foo","VAR2":"bar"}\''
"Dictionary of environment variables, e.g. "
'\'{"VAR1": "foo", "VAR2": "bar"}\''
),
metavar="JSON",
metavar="DICT",
)(wrapped_func)
wrapped_func = click.option(
"--label_selector",
callback=_parse_json_flag,
help='JSON object of resource limits, e.g. \'{"cpu":"2","memory":"4Gi"}\'',
metavar="JSON",
callback=_parse_dict_flag,
help='Dictionary of resource limits, e.g. \'{"cpu": "2", "memory": "4Gi"}\'',
metavar="DICT",
)(wrapped_func)

wrapped_func = click.option(
"--volume",
multiple=True,
callback=_parse_volume_param,
help="List of volume configurations. \
Command structure: --volume name=<volume_name>,type=<volume_type>,mount_path=<mount_path>,<type-specific options> \
For hostPath: --volume name=model-data,type=hostPath,mount_path=/data,path=/data \
For persistentVolumeClaim: --volume name=training-output,type=pvc,mount_path=/mnt/output,claim_name=training-output-pvc,read_only=false \
If multiple --volume flag if multiple volumes are needed.",
help='Volume configurations as dictionaries. \
Example: --volume \'{"name": "vol1", "type": "hostPath", "mountPath": "/data", "hostPath": "/host"}\' \
For hostPath: --volume \'{"name": "model-data", "type": "hostPath", "mountPath": "/data", "hostPath": "/data"}\' \
For persistentVolumeClaim: --volume \'{"name": "training-output", "type": "pvc", "mountPath": "/mnt/output", "claimName": "training-output-pvc", "readOnly": false}\' \
Use multiple --volume flags if multiple volumes are needed.',
)(wrapped_func)

# Add list options
list_params = {
"command": "List of command arguments",
"args": "List of script arguments, e.g. '[--batch-size, 32, --learning-rate, 0.001]'",
"command": 'List of command arguments, e.g. \'["python", "train.py"]\'',
"args": 'List of script arguments, e.g. \'["--batch-size", "32", "--learning-rate", "0.001"]\'',
}

for param_name, help_text in list_params.items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,4 +296,4 @@ def _style_dataframe(df):

def _get_table_layout(data_length):
"""Get appropriate table layout based on data size."""
return {} if data_length > 10 else {"topStart": None, "topEnd": "search"}
return {} if data_length > 10 else {"topStart": None, "topEnd": "search"}
4 changes: 2 additions & 2 deletions test/unit_tests/cli/test_inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ def cmd(namespace, version, domain):
out = json.loads(res_ok.output)
assert out == {'env': {'a':1}, 'dimensions': {'b':2}, 'limits': {'c':3}, 'reqs': {'d':4}}

# invalid JSON produces click error
# invalid format produces click error
res_err = self.runner.invoke(cmd, ['--env', 'notjson'])
assert res_err.exit_code == 2
assert 'must be valid JSON' in res_err.output
assert 'Invalid format for' in res_err.output and 'Expected Python literal' in res_err.output

@patch('sagemaker.hyperpod.cli.inference_utils.load_schema_for_version')
def test_type_mapping_and_defaults(self, mock_load_schema):
Expand Down
Loading