diff --git a/src/xpk/parser/workload.py b/src/xpk/parser/workload.py index 3ecb3d1a8..04de73f17 100644 --- a/src/xpk/parser/workload.py +++ b/src/xpk/parser/workload.py @@ -24,6 +24,7 @@ from ..core.docker_image import DEFAULT_DOCKER_IMAGE, DEFAULT_SCRIPT_DIR from .common import add_shared_arguments from .validators import directory_path_type, name_type +from ..utils.feature_flags import FeatureFlags def set_workload_parsers(workload_parser: ArgumentParser): @@ -658,6 +659,13 @@ def add_shared_workload_create_optional_arguments(args_parsers): ' the workload.' ), ) + if FeatureFlags.SUB_SLICING_ENABLED: + custom_parser.add_argument( + '--sub-slicing-topology', + type=str, + help='Sub-slicing topology to use.', + required=False, + ) def add_shared_workload_create_env_arguments(args_parsers): diff --git a/src/xpk/parser/workload_test.py b/src/xpk/parser/workload_test.py new file mode 100644 index 000000000..7e0d53d19 --- /dev/null +++ b/src/xpk/parser/workload_test.py @@ -0,0 +1,82 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import argparse +from xpk.parser.workload import set_workload_create_parser +from ..utils.feature_flags import FeatureFlags +import pytest + + +@pytest.fixture(autouse=True) +def with_sub_slicing_enabled(): + FeatureFlags.SUB_SLICING_ENABLED = True + + +def test_workload_create_sub_slicing_topology_is_hidden_with_flag_off(): + FeatureFlags.SUB_SLICING_ENABLED = False + parser = argparse.ArgumentParser() + + set_workload_create_parser(parser) + help_str = parser.format_help() + + assert "--sub-slicing" not in help_str + + +def test_workload_create_sub_slicing_topology_is_shown_with_flag_on(): + parser = argparse.ArgumentParser() + + set_workload_create_parser(parser) + help_str = parser.format_help() + + assert "--sub-slicing" in help_str + + +def test_workload_create_sub_slicing_topology_is_none_by_default(): + parser = argparse.ArgumentParser() + + set_workload_create_parser(parser) + args = parser.parse_args([ + "--cluster", + "test-cluster", + "--command", + "python3", + "--workload", + "test", + "--tpu-type", + "test-tpu", + ]) + + assert args.sub_slicing_topology is None + + +def test_workload_create_sub_slicing_topology_can_be_set(): + parser = argparse.ArgumentParser() + + set_workload_create_parser(parser) + args = parser.parse_args([ + "--cluster", + "test-cluster", + "--command", + "python3", + "--workload", + "test", + "--tpu-type", + "test-tpu", + "--sub-slicing-topology", + "2x2", + ]) + + assert args.sub_slicing_topology is "2x2"