diff --git a/src/xpk/commands/cluster.py b/src/xpk/commands/cluster.py index f21ba929a..e033f9ac1 100644 --- a/src/xpk/commands/cluster.py +++ b/src/xpk/commands/cluster.py @@ -1217,6 +1217,7 @@ def install_kueue(args, system: SystemCharacteristics, autoprovisioning_config): total_chips=total_chips, autoprovisioning_enabled=autoprovisioning_enabled, num_slices=args.num_slices, + flex=args.flex, memory_limit=args.memory_limit, cpu_limit=args.cpu_limit, is_pathways_cluster=args.enable_pathways, diff --git a/src/xpk/core/capacity.py b/src/xpk/core/capacity.py index 6d2b07f55..b14028cf5 100644 --- a/src/xpk/core/capacity.py +++ b/src/xpk/core/capacity.py @@ -17,6 +17,7 @@ import enum from ..utils.console import xpk_print, xpk_exit +from ..utils.kueue import is_queued_cluster from .commands import run_command_with_updates, run_command_for_value AUTOPROVISIONING_CONFIG_VALUE = 'AUTOPROVISION' @@ -199,7 +200,7 @@ def get_capacity_arguments_from_capacity_type( ' --location-policy=ANY --reservation-affinity=none' f' --no-enable-autorepair --max-nodes={max_nodes}' ) - if args.num_slices <= 1: + if is_queued_cluster(args.num_slices): capacity_args += ' --enable-queued-provisioning' case CapacityType.RESERVATION: capacity_args = ( diff --git a/src/xpk/core/kueue_manager.py b/src/xpk/core/kueue_manager.py index 298a3aefe..31810a3a8 100644 --- a/src/xpk/core/kueue_manager.py +++ b/src/xpk/core/kueue_manager.py @@ -15,11 +15,13 @@ """ import math +import textwrap from dataclasses import dataclass from typing import Optional, List, Dict, Any import json from jinja2 import Environment, FileSystemLoader from ..utils.execution_context import is_dry_run +from ..utils.kueue import is_queued_cluster from .capacity import B200_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE from .scheduling import ( @@ -311,11 +313,11 @@ def __build_template_context( }], }) - if flex: - admission_checks = """ + if flex and is_queued_cluster(num_slices): + admission_checks = textwrap.dedent(""" 
admissionChecks: - dws-prov - """ + """) else: admission_checks = "" diff --git a/src/xpk/core/kueue_manager_test.py b/src/xpk/core/kueue_manager_test.py index 67ee820d6..5fa1e5003 100644 --- a/src/xpk/core/kueue_manager_test.py +++ b/src/xpk/core/kueue_manager_test.py @@ -416,6 +416,86 @@ def test_configure_generates_correct_manifest( "2x2x1", ) + @patch("xpk.core.kueue_manager.KueueManager._KueueManager__apply_manifest") + @patch("xpk.core.kueue_manager.KueueManager._KueueManager__install") + @patch( + "xpk.core.kueue_manager.KueueManager._KueueManager__update_kueue_resources_if_necessary" + ) + def test_configure_generates_correct_manifest_with_admission_checks( + self, mock_update_resources, mock_install, mock_apply_manifest + ): + """Test that __configure generates the correct manifest content for TPUs.""" + mock_install.return_value = 0 + mock_update_resources.return_value = 0 + mock_apply_manifest.return_value = 0 + kueue_config = KueueConfig( + system=self.mock_system_chars, + total_chips=8, + cpu_limit=100, + memory_limit="100Gi", + autoprovisioning_enabled=False, + num_slices=1, + flex=True, + ) + + with patch.object( + self.kueue_manager, "_KueueManager__get_installed_kueue_version" + ) as mock_get_version: + mock_get_version.return_value = (1, None) # Trigger install + with ( + patch.object(self.kueue_manager, "_KueueManager__install_kueue_crs"), + patch.object( + self.kueue_manager, "_KueueManager__wait_for_kueue_available" + ), + ): + self.kueue_manager.install_or_upgrade(kueue_config) + + mock_apply_manifest.assert_called_once() + rendered_manifest = mock_apply_manifest.call_args[0][0] + + self.assertNotIn("kind: Topology", rendered_manifest) + manifest_docs = list(yaml.safe_load_all(rendered_manifest)) + cluster_queue = next( + (doc for doc in manifest_docs if doc["kind"] == "ClusterQueue"), None + ) + self.assertIsNotNone(cluster_queue) + self.assertEqual( + cluster_queue["spec"]["resourceGroups"][0]["flavors"][0]["name"], + "1xv5p-8", + ) + 
resources = cluster_queue["spec"]["resourceGroups"][0]["flavors"][0][ + "resources" + ] + tpu_resource = next( + (r for r in resources if r["name"] == "google.com/tpu"), None + ) + cpu_resource = next((r for r in resources if r["name"] == "cpu"), None) + memory_resource = next( + (r for r in resources if r["name"] == "memory"), None + ) + self.assertIsNotNone(tpu_resource) + self.assertEqual(tpu_resource["nominalQuota"], 8) + self.assertIsNotNone(cpu_resource) + self.assertEqual(cpu_resource["nominalQuota"], 100) + self.assertIsNotNone(memory_resource) + self.assertEqual(memory_resource["nominalQuota"], "100Gi") + resource_flavor = next( + (doc for doc in manifest_docs if doc["kind"] == "ResourceFlavor"), None + ) + self.assertIsNotNone(resource_flavor) + self.assertEqual( + resource_flavor["spec"]["nodeLabels"][ + "cloud.google.com/gke-tpu-accelerator" + ], + "test-accelerator", + ) + self.assertEqual( + resource_flavor["spec"]["nodeLabels"][ + "cloud.google.com/gke-tpu-topology" + ], + "2x2x1", + ) + @patch("xpk.core.kueue_manager.KueueManager._KueueManager__apply_manifest") @patch("xpk.core.kueue_manager.KueueManager._KueueManager__install") @patch( diff --git a/src/xpk/utils/kueue.py b/src/xpk/utils/kueue.py new file mode 100644 index 000000000..3b9175c51 --- /dev/null +++ b/src/xpk/utils/kueue.py @@ -0,0 +1,20 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
def is_queued_cluster(num_slices: int) -> bool:
  """Reports whether a cluster with the given slice count is a queued cluster.

  The slice count is the only input to this decision. Callers combine the
  result with their own conditions (for example a flex flag) before enabling
  queued provisioning or Kueue admission checks.

  Args:
    num_slices: number of slices requested for the cluster.

  Returns:
    True when num_slices is at most one, False otherwise.
  """
  # Only clusters of at most a single slice count as queued.
  single_slice_limit = 1
  queued = num_slices <= single_slice_limit
  return queued