diff --git a/src/xpk/core/kueue_manager_test.py b/src/xpk/core/kueue_manager_test.py index 7fbce71f6..37f05c30e 100644 --- a/src/xpk/core/kueue_manager_test.py +++ b/src/xpk/core/kueue_manager_test.py @@ -14,6 +14,7 @@ limitations under the License. """ +from typing import Generator, TypeVar import unittest import yaml from unittest.mock import MagicMock, patch @@ -339,18 +340,16 @@ def test_resource_update_for_large_cluster(self, mock_run_retry): patch_call[0][0], ) - @patch("xpk.core.kueue_manager.KueueManager._KueueManager__apply_manifest") @patch("xpk.core.kueue_manager.KueueManager._KueueManager__install") @patch( "xpk.core.kueue_manager.KueueManager._KueueManager__update_kueue_resources_if_necessary" ) def test_configure_generates_correct_manifest( - self, mock_update_resources, mock_install, mock_apply_manifest + self, mock_update_resources, mock_install ): """Test that __configure generates the correct manifest content for TPUs.""" mock_install.return_value = 0 mock_update_resources.return_value = 0 - mock_apply_manifest.return_value = 0 kueue_config = KueueConfig( system=self.mock_system_chars, total_chips=8, @@ -360,27 +359,13 @@ def test_configure_generates_correct_manifest( num_slices=2, ) - with patch.object( - self.kueue_manager, "_KueueManager__get_installed_kueue_version" - ) as mock_get_version: - mock_get_version.return_value = (1, None) # Trigger install - with ( - patch.object(self.kueue_manager, "_KueueManager__install_kueue_crs"), - patch.object( - self.kueue_manager, "_KueueManager__wait_for_kueue_available" - ), - ): - self.kueue_manager.install_or_upgrade(kueue_config) - - mock_apply_manifest.assert_called_once() - rendered_manifest = mock_apply_manifest.call_args[0][0] + rendered_manifest = self._trigger_installation(kueue_config) self.assertNotIn("kind: Topology", rendered_manifest) manifest_docs = list(yaml.safe_load_all(rendered_manifest)) - cluster_queue = next( - (doc for doc in manifest_docs if doc["kind"] == "ClusterQueue"), None + cluster_queue = _first( + doc for doc in manifest_docs if doc["kind"] == "ClusterQueue" ) - self.assertIsNotNone(cluster_queue) self.assertEqual( cluster_queue["spec"]["resourceGroups"][0]["flavors"][0]["name"], "2xv5p-8", @@ -388,23 +373,15 @@ def test_configure_generates_correct_manifest( resources = cluster_queue["spec"]["resourceGroups"][0]["flavors"][0][ "resources" ] - tpu_resource = next( - (r for r in resources if r["name"] == "google.com/tpu"), None - ) - cpu_resource = next((r for r in resources if r["name"] == "cpu"), None) - memory_resource = next( - (r for r in resources if r["name"] == "memory"), None - ) - self.assertIsNotNone(tpu_resource) + tpu_resource = _first(r for r in resources if r["name"] == "google.com/tpu") + cpu_resource = _first(r for r in resources if r["name"] == "cpu") + memory_resource = _first(r for r in resources if r["name"] == "memory") self.assertEqual(tpu_resource["nominalQuota"], 8) - self.assertIsNotNone(cpu_resource) self.assertEqual(cpu_resource["nominalQuota"], 100) - self.assertIsNotNone(memory_resource) self.assertEqual(memory_resource["nominalQuota"], "100Gi") - resource_flavor = next( - (doc for doc in manifest_docs if doc["kind"] == "ResourceFlavor"), None + resource_flavor = _first( + doc for doc in manifest_docs if doc["kind"] == "ResourceFlavor" ) - self.assertIsNotNone(resource_flavor) self.assertEqual( resource_flavor["spec"]["nodeLabels"][ "cloud.google.com/gke-tpu-accelerator" @@ -418,18 +395,16 @@ def test_configure_generates_correct_manifest( "2x2x1", ) - @patch("xpk.core.kueue_manager.KueueManager._KueueManager__apply_manifest") @patch("xpk.core.kueue_manager.KueueManager._KueueManager__install") @patch( "xpk.core.kueue_manager.KueueManager._KueueManager__update_kueue_resources_if_necessary" ) def test_configure_generates_manifest_with_admission_checks_for_flex_single_slice( - self, mock_update_resources, mock_install, mock_apply_manifest + self, mock_update_resources, mock_install ): """Test that __configure generates the correct manifest with admission checks.""" mock_install.return_value = 0 mock_update_resources.return_value = 0 - mock_apply_manifest.return_value = 0 kueue_config = KueueConfig( system=self.mock_system_chars, total_chips=8, @@ -440,45 +415,29 @@ def test_configure_generates_manifest_with_admission_checks_for_flex_single_slic flex=True, ) - with patch.object( - self.kueue_manager, "_KueueManager__get_installed_kueue_version" - ) as mock_get_version: - mock_get_version.return_value = (1, None) # Trigger install - with ( - patch.object(self.kueue_manager, "_KueueManager__install_kueue_crs"), - patch.object( - self.kueue_manager, "_KueueManager__wait_for_kueue_available" - ), - ): - self.kueue_manager.install_or_upgrade(kueue_config) - - mock_apply_manifest.assert_called_once() - rendered_manifest = mock_apply_manifest.call_args[0][0] + rendered_manifest = self._trigger_installation(kueue_config) self.assertNotIn("kind: Topology", rendered_manifest) manifest_docs = list(yaml.safe_load_all(rendered_manifest)) - cluster_queue = next( - (doc for doc in manifest_docs if doc["kind"] == "ClusterQueue"), None + cluster_queue = _first( + doc for doc in manifest_docs if doc["kind"] == "ClusterQueue" ) - self.assertIsNotNone(cluster_queue) self.assertEqual( cluster_queue["spec"]["resourceGroups"][0]["flavors"][0]["name"], "1xv5p-8", ) self.assertEqual(cluster_queue["spec"]["admissionChecks"][0], "dws-prov") - @patch("xpk.core.kueue_manager.KueueManager._KueueManager__apply_manifest") @patch("xpk.core.kueue_manager.KueueManager._KueueManager__install") @patch( "xpk.core.kueue_manager.KueueManager._KueueManager__update_kueue_resources_if_necessary" ) def test_configure_generates_correct_manifest_with_topology( - self, mock_update_resources, mock_install, mock_apply_manifest + self, mock_update_resources, mock_install ): """Test that __configure generates correct manifest for GPUs.""" mock_install.return_value = 0 mock_update_resources.return_value = 0 - mock_apply_manifest.return_value = 0 kueue_config = KueueConfig( system=self.mock_system_chars_gpu, total_chips=16, @@ -487,20 +446,13 @@ def test_configure_generates_correct_manifest_with_topology( num_slices=2, ) - with patch.object( - self.kueue_manager, "_KueueManager__get_installed_kueue_version" - ) as mock_get_version: - mock_get_version.return_value = (1, None) # Trigger install - self.kueue_manager.install_or_upgrade(kueue_config) + rendered_manifest = self._trigger_installation(kueue_config) - mock_apply_manifest.assert_called_once() - rendered_manifest = mock_apply_manifest.call_args[0][0] self.assertIn("kind: Topology", rendered_manifest) manifest_docs = list(yaml.safe_load_all(rendered_manifest)) - resource_flavor = next( - (doc for doc in manifest_docs if doc["kind"] == "ResourceFlavor"), None + resource_flavor = _first( + doc for doc in manifest_docs if doc["kind"] == "ResourceFlavor" ) - self.assertIsNotNone(resource_flavor) self.assertEqual( resource_flavor["spec"]["nodeLabels"][ "cloud.google.com/gke-accelerator" @@ -508,18 +460,16 @@ def test_configure_generates_correct_manifest_with_topology( "h100-mega-80gb-8", ) - @patch("xpk.core.kueue_manager.KueueManager._KueueManager__apply_manifest") @patch("xpk.core.kueue_manager.KueueManager._KueueManager__install") @patch( "xpk.core.kueue_manager.KueueManager._KueueManager__update_kueue_resources_if_necessary" ) def test_configure_generates_correct_manifest_with_pathways( - self, mock_update_resources, mock_install, mock_apply_manifest + self, mock_update_resources, mock_install ): """Test that __configure generates the correct manifest with pathways enabled.""" mock_install.return_value = 0 mock_update_resources.return_value = 0 - mock_apply_manifest.return_value = 0 kueue_config = KueueConfig( system=self.mock_system_chars, total_chips=8, @@ -529,37 +479,25 @@ def test_configure_generates_correct_manifest_with_pathways( num_slices=2, ) - with patch.object( - self.kueue_manager, "_KueueManager__get_installed_kueue_version" - ) as mock_get_version: - mock_get_version.return_value = (1, None) # Trigger install - self.kueue_manager.install_or_upgrade(kueue_config) - - mock_apply_manifest.assert_called_once() - rendered_manifest = mock_apply_manifest.call_args[0][0] + rendered_manifest = self._trigger_installation(kueue_config) manifest_docs = list(yaml.safe_load_all(rendered_manifest)) # Check for the new "cpu-user" ResourceFlavor - cpu_user_flavor = next( - ( - doc - for doc in manifest_docs - if doc["kind"] == "ResourceFlavor" - and doc["metadata"]["name"] == "cpu-user" - ), - None, + cpu_user_flavor = _first( + doc + for doc in manifest_docs + if doc["kind"] == "ResourceFlavor" + and doc["metadata"]["name"] == "cpu-user" ) - self.assertIsNotNone(cpu_user_flavor) self.assertEqual( cpu_user_flavor["spec"]["nodeLabels"]["cloud.google.com/gke-nodepool"], "cpu-np", ) # Check that the ClusterQueue has the new resource group for pathways - cluster_queue = next( - (doc for doc in manifest_docs if doc["kind"] == "ClusterQueue"), None + cluster_queue = _first( + doc for doc in manifest_docs if doc["kind"] == "ClusterQueue" ) - self.assertIsNotNone(cluster_queue) self.assertEqual(len(cluster_queue["spec"]["resourceGroups"]), 2) pathways_rg = cluster_queue["spec"]["resourceGroups"][1] self.assertEqual(pathways_rg["coveredResources"], ["cpu", "memory"]) @@ -571,6 +509,34 @@ def test_configure_generates_correct_manifest_with_pathways( pathways_rg["flavors"][0]["resources"][1]["nominalQuota"], "2000G" ) + def _trigger_installation(self, kueue_config: KueueConfig) -> str: + """Calls Kueue installation and returns the rendered manifest.""" + with ( + patch.object( + self.kueue_manager, "_KueueManager__get_installed_kueue_version" + ) as mock_get_version, + patch.object( + self.kueue_manager, "_KueueManager__apply_manifest" + ) as mock_apply_manifest, + ): + mock_apply_manifest.return_value = 0 + mock_get_version.return_value = (1, None) + self.kueue_manager.install_or_upgrade(kueue_config) + + mock_apply_manifest.assert_called_once() + manifest = mock_apply_manifest.call_args[0][0] + assert isinstance(manifest, str) + return manifest + + +T = TypeVar("T") + + +def _first(generator: Generator[T, None, None]) -> T: + result = next(generator, None) + assert result is not None + return result + if __name__ == "__main__": unittest.main()