Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 55 additions & 89 deletions src/xpk/core/kueue_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
limitations under the License.
"""

from typing import Generator, TypeVar
import unittest
import yaml
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -339,18 +340,16 @@ def test_resource_update_for_large_cluster(self, mock_run_retry):
patch_call[0][0],
)

@patch("xpk.core.kueue_manager.KueueManager._KueueManager__apply_manifest")
@patch("xpk.core.kueue_manager.KueueManager._KueueManager__install")
@patch(
"xpk.core.kueue_manager.KueueManager._KueueManager__update_kueue_resources_if_necessary"
)
def test_configure_generates_correct_manifest(
self, mock_update_resources, mock_install, mock_apply_manifest
self, mock_update_resources, mock_install
):
"""Test that __configure generates the correct manifest content for TPUs."""
mock_install.return_value = 0
mock_update_resources.return_value = 0
mock_apply_manifest.return_value = 0
kueue_config = KueueConfig(
system=self.mock_system_chars,
total_chips=8,
Expand All @@ -360,51 +359,29 @@ def test_configure_generates_correct_manifest(
num_slices=2,
)

with patch.object(
self.kueue_manager, "_KueueManager__get_installed_kueue_version"
) as mock_get_version:
mock_get_version.return_value = (1, None) # Trigger install
with (
patch.object(self.kueue_manager, "_KueueManager__install_kueue_crs"),
patch.object(
self.kueue_manager, "_KueueManager__wait_for_kueue_available"
),
):
self.kueue_manager.install_or_upgrade(kueue_config)

mock_apply_manifest.assert_called_once()
rendered_manifest = mock_apply_manifest.call_args[0][0]
rendered_manifest = self._trigger_installation(kueue_config)

self.assertNotIn("kind: Topology", rendered_manifest)
manifest_docs = list(yaml.safe_load_all(rendered_manifest))
cluster_queue = next(
(doc for doc in manifest_docs if doc["kind"] == "ClusterQueue"), None
cluster_queue = _first(
doc for doc in manifest_docs if doc["kind"] == "ClusterQueue"
)
self.assertIsNotNone(cluster_queue)
self.assertEqual(
cluster_queue["spec"]["resourceGroups"][0]["flavors"][0]["name"],
"2xv5p-8",
)
resources = cluster_queue["spec"]["resourceGroups"][0]["flavors"][0][
"resources"
]
tpu_resource = next(
(r for r in resources if r["name"] == "google.com/tpu"), None
)
cpu_resource = next((r for r in resources if r["name"] == "cpu"), None)
memory_resource = next(
(r for r in resources if r["name"] == "memory"), None
)
self.assertIsNotNone(tpu_resource)
tpu_resource = _first(r for r in resources if r["name"] == "google.com/tpu")
cpu_resource = _first(r for r in resources if r["name"] == "cpu")
memory_resource = _first(r for r in resources if r["name"] == "memory")
self.assertEqual(tpu_resource["nominalQuota"], 8)
self.assertIsNotNone(cpu_resource)
self.assertEqual(cpu_resource["nominalQuota"], 100)
self.assertIsNotNone(memory_resource)
self.assertEqual(memory_resource["nominalQuota"], "100Gi")
resource_flavor = next(
(doc for doc in manifest_docs if doc["kind"] == "ResourceFlavor"), None
resource_flavor = _first(
doc for doc in manifest_docs if doc["kind"] == "ResourceFlavor"
)
self.assertIsNotNone(resource_flavor)
self.assertEqual(
resource_flavor["spec"]["nodeLabels"][
"cloud.google.com/gke-tpu-accelerator"
Expand All @@ -418,18 +395,16 @@ def test_configure_generates_correct_manifest(
"2x2x1",
)

@patch("xpk.core.kueue_manager.KueueManager._KueueManager__apply_manifest")
@patch("xpk.core.kueue_manager.KueueManager._KueueManager__install")
@patch(
"xpk.core.kueue_manager.KueueManager._KueueManager__update_kueue_resources_if_necessary"
)
def test_configure_generates_manifest_with_admission_checks_for_flex_single_slice(
self, mock_update_resources, mock_install, mock_apply_manifest
self, mock_update_resources, mock_install
):
"""Test that __configure generates the correct manifest with admission checks."""
mock_install.return_value = 0
mock_update_resources.return_value = 0
mock_apply_manifest.return_value = 0
kueue_config = KueueConfig(
system=self.mock_system_chars,
total_chips=8,
Expand All @@ -440,45 +415,29 @@ def test_configure_generates_manifest_with_admission_checks_for_flex_single_slic
flex=True,
)

with patch.object(
self.kueue_manager, "_KueueManager__get_installed_kueue_version"
) as mock_get_version:
mock_get_version.return_value = (1, None) # Trigger install
with (
patch.object(self.kueue_manager, "_KueueManager__install_kueue_crs"),
patch.object(
self.kueue_manager, "_KueueManager__wait_for_kueue_available"
),
):
self.kueue_manager.install_or_upgrade(kueue_config)

mock_apply_manifest.assert_called_once()
rendered_manifest = mock_apply_manifest.call_args[0][0]
rendered_manifest = self._trigger_installation(kueue_config)

self.assertNotIn("kind: Topology", rendered_manifest)
manifest_docs = list(yaml.safe_load_all(rendered_manifest))
cluster_queue = next(
(doc for doc in manifest_docs if doc["kind"] == "ClusterQueue"), None
cluster_queue = _first(
doc for doc in manifest_docs if doc["kind"] == "ClusterQueue"
)
self.assertIsNotNone(cluster_queue)
self.assertEqual(
cluster_queue["spec"]["resourceGroups"][0]["flavors"][0]["name"],
"1xv5p-8",
)
self.assertEqual(cluster_queue["spec"]["admissionChecks"][0], "dws-prov")

@patch("xpk.core.kueue_manager.KueueManager._KueueManager__apply_manifest")
@patch("xpk.core.kueue_manager.KueueManager._KueueManager__install")
@patch(
"xpk.core.kueue_manager.KueueManager._KueueManager__update_kueue_resources_if_necessary"
)
def test_configure_generates_correct_manifest_with_topology(
self, mock_update_resources, mock_install, mock_apply_manifest
self, mock_update_resources, mock_install
):
"""Test that __configure generates correct manifest for GPUs."""
mock_install.return_value = 0
mock_update_resources.return_value = 0
mock_apply_manifest.return_value = 0
kueue_config = KueueConfig(
system=self.mock_system_chars_gpu,
total_chips=16,
Expand All @@ -487,39 +446,30 @@ def test_configure_generates_correct_manifest_with_topology(
num_slices=2,
)

with patch.object(
self.kueue_manager, "_KueueManager__get_installed_kueue_version"
) as mock_get_version:
mock_get_version.return_value = (1, None) # Trigger install
self.kueue_manager.install_or_upgrade(kueue_config)
rendered_manifest = self._trigger_installation(kueue_config)

mock_apply_manifest.assert_called_once()
rendered_manifest = mock_apply_manifest.call_args[0][0]
self.assertIn("kind: Topology", rendered_manifest)
manifest_docs = list(yaml.safe_load_all(rendered_manifest))
resource_flavor = next(
(doc for doc in manifest_docs if doc["kind"] == "ResourceFlavor"), None
resource_flavor = _first(
doc for doc in manifest_docs if doc["kind"] == "ResourceFlavor"
)
self.assertIsNotNone(resource_flavor)
self.assertEqual(
resource_flavor["spec"]["nodeLabels"][
"cloud.google.com/gke-accelerator"
],
"h100-mega-80gb-8",
)

@patch("xpk.core.kueue_manager.KueueManager._KueueManager__apply_manifest")
@patch("xpk.core.kueue_manager.KueueManager._KueueManager__install")
@patch(
"xpk.core.kueue_manager.KueueManager._KueueManager__update_kueue_resources_if_necessary"
)
def test_configure_generates_correct_manifest_with_pathways(
self, mock_update_resources, mock_install, mock_apply_manifest
self, mock_update_resources, mock_install
):
"""Test that __configure generates the correct manifest with pathways enabled."""
mock_install.return_value = 0
mock_update_resources.return_value = 0
mock_apply_manifest.return_value = 0
kueue_config = KueueConfig(
system=self.mock_system_chars,
total_chips=8,
Expand All @@ -529,37 +479,25 @@ def test_configure_generates_correct_manifest_with_pathways(
num_slices=2,
)

with patch.object(
self.kueue_manager, "_KueueManager__get_installed_kueue_version"
) as mock_get_version:
mock_get_version.return_value = (1, None) # Trigger install
self.kueue_manager.install_or_upgrade(kueue_config)

mock_apply_manifest.assert_called_once()
rendered_manifest = mock_apply_manifest.call_args[0][0]
rendered_manifest = self._trigger_installation(kueue_config)
manifest_docs = list(yaml.safe_load_all(rendered_manifest))

# Check for the new "cpu-user" ResourceFlavor
cpu_user_flavor = next(
(
doc
for doc in manifest_docs
if doc["kind"] == "ResourceFlavor"
and doc["metadata"]["name"] == "cpu-user"
),
None,
cpu_user_flavor = _first(
doc
for doc in manifest_docs
if doc["kind"] == "ResourceFlavor"
and doc["metadata"]["name"] == "cpu-user"
)
self.assertIsNotNone(cpu_user_flavor)
self.assertEqual(
cpu_user_flavor["spec"]["nodeLabels"]["cloud.google.com/gke-nodepool"],
"cpu-np",
)

# Check that the ClusterQueue has the new resource group for pathways
cluster_queue = next(
(doc for doc in manifest_docs if doc["kind"] == "ClusterQueue"), None
cluster_queue = _first(
doc for doc in manifest_docs if doc["kind"] == "ClusterQueue"
)
self.assertIsNotNone(cluster_queue)
self.assertEqual(len(cluster_queue["spec"]["resourceGroups"]), 2)
pathways_rg = cluster_queue["spec"]["resourceGroups"][1]
self.assertEqual(pathways_rg["coveredResources"], ["cpu", "memory"])
Expand All @@ -571,6 +509,34 @@ def test_configure_generates_correct_manifest_with_pathways(
pathways_rg["flavors"][0]["resources"][1]["nominalQuota"], "2000G"
)

def _trigger_installation(self, kueue_config: KueueConfig) -> str:
"""Calls Kueue installation and returns the rendered manifest."""
with (
patch.object(
self.kueue_manager, "_KueueManager__get_installed_kueue_version"
) as mock_get_version,
patch.object(
self.kueue_manager, "_KueueManager__apply_manifest"
) as mock_apply_manifest,
):
mock_apply_manifest.return_value = 0
mock_get_version.return_value = (1, None)
self.kueue_manager.install_or_upgrade(kueue_config)

mock_apply_manifest.assert_called_once()
manifest = mock_apply_manifest.call_args[0][0]
assert isinstance(manifest, str)
return manifest


T = TypeVar("T")


def _first(generator: Generator[T, None, None]) -> T:
result = next(generator, None)
assert result is not None
return result


if __name__ == "__main__":
unittest.main()
Loading