diff --git a/src/xpk/commands/workload.py b/src/xpk/commands/workload.py index d7bb86446..59a4b7c29 100644 --- a/src/xpk/commands/workload.py +++ b/src/xpk/commands/workload.py @@ -32,7 +32,7 @@ get_main_container_docker_image, get_user_workload_container, ) -from ..core.kueue_manager import has_sub_slicing_enabled, get_installed_kueue_version, SUB_SLICING_TOPOLOGIES, LOCAL_QUEUE_NAME +from ..core.kueue_manager import has_sub_slicing_enabled, get_installed_kueue_version, LOCAL_QUEUE_NAME from ..core.docker_resources import get_volumes, parse_env_config from ..core.gcloud_context import add_zone_and_project from ..core.monitoring import get_gke_outlier_dashboard @@ -78,6 +78,7 @@ get_storages_to_mount, ) from ..core.system_characteristics import ( + SUB_SLICING_TOPOLOGIES, AcceleratorType, get_system_characteristics, compute_vms_per_slice, diff --git a/src/xpk/core/kueue_manager.py b/src/xpk/core/kueue_manager.py index 5dfb2f89c..ef7738e39 100644 --- a/src/xpk/core/kueue_manager.py +++ b/src/xpk/core/kueue_manager.py @@ -21,16 +21,17 @@ import json from jinja2 import Environment, FileSystemLoader +from ..utils.topology import get_slice_topology_level, get_topology_product, is_topology_contained from ..utils.execution_context import is_dry_run from ..utils.kueue import is_queued_cluster from kubernetes.utils import parse_quantity - from .capacity import B200_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE from .scheduling import ( create_accelerator_label, create_machine_label, ) from .system_characteristics import ( + SUB_SLICING_TOPOLOGIES, AcceleratorTypeToAcceleratorCharacteristics, SystemCharacteristics, ) @@ -56,7 +57,6 @@ KUEUE_SUB_SLICING_TOPOLOGY_JINJA_FILE = "kueue_sub_slicing_topology.yaml.j2" MEMORY_SIZE_PER_VM = 1.2 MIN_MEMORY_LIMIT_SIZE = 4096 -SUB_SLICING_TOPOLOGIES = ["2x2", "2x4", "4x4", "4x8", "8x8", "8x16", "16x16"] @dataclass(frozen=True) @@ -413,12 +413,25 @@ def __get_topology_name_and_yaml( ).render(), ) elif configure_sub_slicing: + 
sorted_topologies = sorted( + SUB_SLICING_TOPOLOGIES, key=get_topology_product, reverse=True + ) + levels = [ + get_slice_topology_level(topology) + for topology in sorted_topologies + if is_topology_contained( + contained=topology, container=system.topology + ) + ] + levels.append("kubernetes.io/hostname") + return _NameAndYaml( name=SUB_SLICE_TOPOLOGY_NAME, yaml=self.template_env.get_template( KUEUE_SUB_SLICING_TOPOLOGY_JINJA_FILE ).render({ "sub_slice_topology_name": SUB_SLICE_TOPOLOGY_NAME, + "levels": levels, }), ) else: diff --git a/src/xpk/core/kueue_manager_test.py b/src/xpk/core/kueue_manager_test.py index 5de53bd0c..edcfcb886 100644 --- a/src/xpk/core/kueue_manager_test.py +++ b/src/xpk/core/kueue_manager_test.py @@ -22,7 +22,7 @@ from unittest.mock import MagicMock, patch from xpk.core.kueue_manager import KueueConfig, KueueManager, has_sub_slicing_enabled -from xpk.core.system_characteristics import AcceleratorType, SystemCharacteristics +from xpk.core.system_characteristics import AcceleratorType, SystemCharacteristics, UserFacingNameToSystemCharacteristics from xpk.core.testing.commands_tester import CommandsTester from packaging.version import Version @@ -435,6 +435,7 @@ def test_configure_generates_correct_manifest_with_sub_slicing( kueue_config = dataclasses.replace( KUEUE_CONFIG, configure_sub_slicing=True, + system=UserFacingNameToSystemCharacteristics["v6e-8x8"], ) kueue_manager.install_or_upgrade(kueue_config) @@ -447,6 +448,15 @@ def test_configure_generates_correct_manifest_with_sub_slicing( assert resource_flavor["spec"]["topologyName"] == "sub-slice-topology" topology = _first(doc for doc in manifest_docs if doc["kind"] == "Topology") assert topology["metadata"]["name"] == "sub-slice-topology" + expected_levels = [ + "cloud.google.com/gke-tpu-slice-8x8-id", + "cloud.google.com/gke-tpu-slice-4x8-id", + "cloud.google.com/gke-tpu-slice-4x4-id", + "cloud.google.com/gke-tpu-slice-2x4-id", + "kubernetes.io/hostname", + ] + actual_levels = 
[level["nodeLabel"] for level in topology["spec"]["levels"]] + assert actual_levels == expected_levels @patch("xpk.core.kueue_manager.write_tmp_file") diff --git a/src/xpk/core/scheduling.py b/src/xpk/core/scheduling.py index d87871d48..399609865 100644 --- a/src/xpk/core/scheduling.py +++ b/src/xpk/core/scheduling.py @@ -14,6 +14,7 @@ limitations under the License. """ +from ..utils.topology import get_slice_topology_level from ..utils.console import xpk_print from ..utils.topology import is_topology_valid from ..utils.execution_context import is_dry_run @@ -300,7 +301,7 @@ def create_sub_slicing_annotations(sub_slicing_topology: str) -> list[str]: return [ ( 'kueue.x-k8s.io/podset-required-topology:' - f' "google.com/gke-tpu-slice-{sub_slicing_topology}-id"' + f' "{get_slice_topology_level(sub_slicing_topology)}"' ), f'cloud.google.com/gke-tpu-slice-topology: {sub_slicing_topology}', ] diff --git a/src/xpk/core/scheduling_test.py b/src/xpk/core/scheduling_test.py index eb129df21..3af4a21e0 100644 --- a/src/xpk/core/scheduling_test.py +++ b/src/xpk/core/scheduling_test.py @@ -19,16 +19,14 @@ def test_create_sub_slicing_annotations_returns_valid_annotations(): - subslicing_topology = '2x2' - - result = create_sub_slicing_annotations(subslicing_topology) + result = create_sub_slicing_annotations(sub_slicing_topology='2x4') assert result == [ ( 'kueue.x-k8s.io/podset-required-topology:' - ' "google.com/gke-tpu-slice-2x2-id"' + ' "cloud.google.com/gke-tpu-slice-2x4-id"' ), - 'cloud.google.com/gke-tpu-slice-topology: 2x2', + 'cloud.google.com/gke-tpu-slice-topology: 2x4', ] diff --git a/src/xpk/core/system_characteristics.py b/src/xpk/core/system_characteristics.py index 4d3e5aa72..06e2996b3 100644 --- a/src/xpk/core/system_characteristics.py +++ b/src/xpk/core/system_characteristics.py @@ -18,6 +18,8 @@ from ..utils.topology import get_topology_product from enum import Enum +SUB_SLICING_TOPOLOGIES = ['2x4', '4x4', '4x8', '8x8', '8x16', '16x16'] + class 
AcceleratorType(Enum): TPU = 1 @@ -495,17 +497,19 @@ def compute_vms_per_slice(topology: str) -> int: tensorcores_per_chip=1, gke_accelerator='tpu-v6e-slice', machine_type='ct6e-standard-4t', - supports_sub_slicing=True, + supports_sub_slicing=False, supported_topologies=[ '2x2', - '2x4', - '4x4', - '4x8', - '8x8', - '8x16', - '16x16', ], ), + **get_tpu_system_characteristics_map( + prefix='v6e', + tensorcores_per_chip=1, + gke_accelerator='tpu-v6e-slice', + machine_type='ct6e-standard-4t', + supports_sub_slicing=True, + supported_topologies=SUB_SLICING_TOPOLOGIES, + ), **get_tpu_system_characteristics_map( prefix='v5p', tensorcores_per_chip=2, diff --git a/src/xpk/parser/workload.py b/src/xpk/parser/workload.py index 7e73a600e..ab6e1d0b1 100644 --- a/src/xpk/parser/workload.py +++ b/src/xpk/parser/workload.py @@ -25,8 +25,7 @@ from .common import add_shared_arguments from .validators import directory_path_type, name_type from ..utils.feature_flags import FeatureFlags -from ..core.kueue_manager import SUB_SLICING_TOPOLOGIES -from ..core.system_characteristics import get_system_characteristics_keys_by_accelerator_type, AcceleratorType +from ..core.system_characteristics import get_system_characteristics_keys_by_accelerator_type, AcceleratorType, SUB_SLICING_TOPOLOGIES def set_workload_parsers(workload_parser: ArgumentParser): diff --git a/src/xpk/parser/workload_test.py b/src/xpk/parser/workload_test.py index 870273303..513b4055b 100644 --- a/src/xpk/parser/workload_test.py +++ b/src/xpk/parser/workload_test.py @@ -74,9 +74,9 @@ def test_workload_create_sub_slicing_topology_can_be_set(): "--workload", "test", "--tpu-type", - "tpu7x-2", + "tpu7x-8", "--sub-slicing-topology", - "2x2", + "2x4", ]) - assert args.sub_slicing_topology is "2x2" + assert args.sub_slicing_topology == "2x4" diff --git a/src/xpk/templates/kueue_sub_slicing_topology.yaml.j2 b/src/xpk/templates/kueue_sub_slicing_topology.yaml.j2 index 87e874e57..54eb795cc 100644 --- 
a/src/xpk/templates/kueue_sub_slicing_topology.yaml.j2 +++ b/src/xpk/templates/kueue_sub_slicing_topology.yaml.j2 @@ -4,11 +4,6 @@ metadata: name: {{ sub_slice_topology_name }} spec: levels: - - nodeLabel: "cloud.google.com/gke-tpu-slice-16x16-id" - - nodeLabel: "cloud.google.com/gke-tpu-slice-8x16-id" - - nodeLabel: "cloud.google.com/gke-tpu-slice-8x8-id" - - nodeLabel: "cloud.google.com/gke-tpu-slice-4x8-id" - - nodeLabel: "cloud.google.com/gke-tpu-slice-4x4-id" - - nodeLabel: "cloud.google.com/gke-tpu-slice-2x4-id" - - nodeLabel: "cloud.google.com/gke-tpu-slice-2x2-id" - - nodeLabel: "kubernetes.io/hostname" + {% for level in levels %} + - nodeLabel: "{{level}}" + {% endfor %} diff --git a/src/xpk/utils/topology.py b/src/xpk/utils/topology.py index f65d1722e..178f94858 100644 --- a/src/xpk/utils/topology.py +++ b/src/xpk/utils/topology.py @@ -44,3 +44,7 @@ def is_topology_contained(contained: str, container: str) -> bool: contained <= container for contained, container in zip(contained_parsed, container_parsed) ) + + +def get_slice_topology_level(topology: str) -> str: + return f"cloud.google.com/gke-tpu-slice-{topology}-id"