diff --git a/src/xpk/commands/cluster.py b/src/xpk/commands/cluster.py index 5ff9ce277..ef2715061 100644 --- a/src/xpk/commands/cluster.py +++ b/src/xpk/commands/cluster.py @@ -76,7 +76,7 @@ from ..utils.execution_context import is_dry_run from ..utils.validation import validate_dependencies_list, SystemDependency, should_validate_dependencies from . import cluster_gcluster -from .common import set_cluster_command +from .common import set_cluster_command, validate_sub_slicing_system from jinja2 import Environment, FileSystemLoader from ..utils.templates import TEMPLATE_PATH import shutil @@ -201,6 +201,11 @@ def cluster_adapt(args) -> None: xpk_exit(0) +def _validate_cluster_create_args(args, system: SystemCharacteristics): + if FeatureFlags.SUB_SLICING_ENABLED and args.sub_slicing: + validate_sub_slicing_system(system) + + def cluster_create(args) -> None: """Function around cluster creation. @@ -213,12 +218,14 @@ def cluster_create(args) -> None: SystemDependency.KJOB, SystemDependency.GCLOUD, ]) - system, return_code = get_system_characteristics(args) + system, return_code = get_system_characteristics(args) if return_code > 0 or system is None: xpk_print('Fetching system characteristics failed!') xpk_exit(return_code) + _validate_cluster_create_args(args, system) + xpk_print(f'Starting cluster create for cluster {args.cluster}:', flush=True) add_zone_and_project(args) diff --git a/src/xpk/commands/cluster_test.py b/src/xpk/commands/cluster_test.py new file mode 100644 index 000000000..77010d3d3 --- /dev/null +++ b/src/xpk/commands/cluster_test.py @@ -0,0 +1,92 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from argparse import Namespace +from dataclasses import dataclass +from unittest.mock import MagicMock +import pytest + +from xpk.commands.cluster import _validate_cluster_create_args +from xpk.core.system_characteristics import SystemCharacteristics, UserFacingNameToSystemCharacteristics +from xpk.utils.feature_flags import FeatureFlags + + +@dataclass +class _Mocks: + common_print_mock: MagicMock + common_exit_mock: MagicMock + + +@pytest.fixture +def mock_common_print_and_exit(mocker): + common_print_mock = mocker.patch( + 'xpk.commands.common.xpk_print', + return_value=None, + ) + common_exit_mock = mocker.patch( + 'xpk.commands.common.xpk_exit', + return_value=None, + ) + return _Mocks( + common_print_mock=common_print_mock, common_exit_mock=common_exit_mock + ) + + +DEFAULT_TEST_SYSTEM: SystemCharacteristics = ( + UserFacingNameToSystemCharacteristics['l4-1'] +) +SUB_SLICING_SYSTEM: SystemCharacteristics = ( + UserFacingNameToSystemCharacteristics['v6e-4x4'] +) + + +def test_validate_cluster_create_args_for_correct_args_pass( + mock_common_print_and_exit: _Mocks, +): + args = Namespace() + + _validate_cluster_create_args(args, DEFAULT_TEST_SYSTEM) + + assert mock_common_print_and_exit.common_print_mock.call_count == 0 + assert mock_common_print_and_exit.common_exit_mock.call_count == 0 + + +def test_validate_cluster_create_args_for_correct_sub_slicing_args_pass( + mock_common_print_and_exit: _Mocks, +): + FeatureFlags.SUB_SLICING_ENABLED = True + args = Namespace(sub_slicing=True) + + _validate_cluster_create_args(args, SUB_SLICING_SYSTEM) + + assert mock_common_print_and_exit.common_print_mock.call_count == 0 + assert mock_common_print_and_exit.common_exit_mock.call_count == 0 + + +def test_validate_cluster_create_args_for_not_supported_system_throws( + mock_common_print_and_exit: _Mocks, +): + FeatureFlags.SUB_SLICING_ENABLED = True + args = Namespace(sub_slicing=True) + + _validate_cluster_create_args(args, DEFAULT_TEST_SYSTEM) + + assert mock_common_print_and_exit.common_print_mock.call_count == 1 + assert ( + mock_common_print_and_exit.common_print_mock.call_args[0][0] + == 'Error: l4-1 does not support Sub-slicing.' + ) + assert mock_common_print_and_exit.common_exit_mock.call_count == 1 diff --git a/src/xpk/commands/common.py b/src/xpk/commands/common.py index ca2c1be69..e4113936c 100644 --- a/src/xpk/commands/common.py +++ b/src/xpk/commands/common.py @@ -67,3 +67,9 @@ def is_TAS_possible( system_characteristics.device_type != H100_MEGA_DEVICE_TYPE or capacity_type == CapacityType.RESERVATION ) + + +def validate_sub_slicing_system(system: SystemCharacteristics): + if not system.supports_sub_slicing: + xpk_print(f'Error: {system.device_type} does not support Sub-slicing.') + xpk_exit(1)