Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/xpk/commands/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
from ..utils.execution_context import is_dry_run
from ..utils.validation import validate_dependencies_list, SystemDependency, should_validate_dependencies
from . import cluster_gcluster
from .common import set_cluster_command
from .common import set_cluster_command, validate_sub_slicing_system
from jinja2 import Environment, FileSystemLoader
from ..utils.templates import TEMPLATE_PATH
import shutil
Expand Down Expand Up @@ -201,6 +201,11 @@ def cluster_adapt(args) -> None:
xpk_exit(0)


def _validate_cluster_create_args(args, system: SystemCharacteristics):
if FeatureFlags.SUB_SLICING_ENABLED and args.sub_slicing:
validate_sub_slicing_system(system)


def cluster_create(args) -> None:
"""Function around cluster creation.

Expand All @@ -213,12 +218,14 @@ def cluster_create(args) -> None:
SystemDependency.KJOB,
SystemDependency.GCLOUD,
])
system, return_code = get_system_characteristics(args)

system, return_code = get_system_characteristics(args)
if return_code > 0 or system is None:
xpk_print('Fetching system characteristics failed!')
xpk_exit(return_code)

_validate_cluster_create_args(args, system)

xpk_print(f'Starting cluster create for cluster {args.cluster}:', flush=True)
add_zone_and_project(args)

Expand Down
92 changes: 92 additions & 0 deletions src/xpk/commands/cluster_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from argparse import Namespace
from dataclasses import dataclass
from unittest.mock import MagicMock
import pytest

from xpk.commands.cluster import _validate_cluster_create_args
from xpk.core.system_characteristics import SystemCharacteristics, UserFacingNameToSystemCharacteristics
from xpk.utils.feature_flags import FeatureFlags


@dataclass
class _Mocks:
common_print_mock: MagicMock
common_exit_mock: MagicMock


@pytest.fixture
def mock_common_print_and_exit(mocker):
common_print_mock = mocker.patch(
'xpk.commands.common.xpk_print',
return_value=None,
)
common_exit_mock = mocker.patch(
'xpk.commands.common.xpk_exit',
return_value=None,
)
return _Mocks(
common_print_mock=common_print_mock, common_exit_mock=common_exit_mock
)


DEFAULT_TEST_SYSTEM: SystemCharacteristics = (
UserFacingNameToSystemCharacteristics['l4-1']
)
SUB_SLICING_SYSTEM: SystemCharacteristics = (
UserFacingNameToSystemCharacteristics['v6e-4x4']
)


def test_validate_cluster_create_args_for_correct_args_pass(
mock_common_print_and_exit: _Mocks,
):
args = Namespace()

_validate_cluster_create_args(args, DEFAULT_TEST_SYSTEM)

assert mock_common_print_and_exit.common_print_mock.call_count == 0
assert mock_common_print_and_exit.common_exit_mock.call_count == 0


def test_validate_cluster_create_args_for_correct_sub_slicing_args_pass(
mock_common_print_and_exit: _Mocks,
):
FeatureFlags.SUB_SLICING_ENABLED = True
args = Namespace(sub_slicing=True)

_validate_cluster_create_args(args, SUB_SLICING_SYSTEM)

assert mock_common_print_and_exit.common_print_mock.call_count == 0
assert mock_common_print_and_exit.common_exit_mock.call_count == 0


def test_validate_cluster_create_args_for_not_supported_system_throws(
mock_common_print_and_exit: _Mocks,
):
FeatureFlags.SUB_SLICING_ENABLED = True
args = Namespace(sub_slicing=True)

_validate_cluster_create_args(args, DEFAULT_TEST_SYSTEM)

assert mock_common_print_and_exit.common_print_mock.call_count == 1
assert (
mock_common_print_and_exit.common_print_mock.call_args[0][0]
== 'Error: l4-1 does not support Sub-slicing.'
)
assert mock_common_print_and_exit.common_exit_mock.call_count == 1
6 changes: 6 additions & 0 deletions src/xpk/commands/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,9 @@ def is_TAS_possible(
system_characteristics.device_type != H100_MEGA_DEVICE_TYPE
or capacity_type == CapacityType.RESERVATION
)


def validate_sub_slicing_system(system: SystemCharacteristics):
if not system.supports_sub_slicing:
xpk_print(f'Error: {system.device_type} does not support Sub-slicing.')
xpk_exit(1)
Loading