def get_desired_node_pool_names(
    existing_node_pool_names: List[str],
    cluster_name: str,
    desired_node_pool_count: int,
) -> List[str]:
    """Pick the node pool names a cluster should end up with.

    Existing names that start with ``'<cluster_name>-np-'`` are reused first,
    in the order they were passed in. If there are fewer of them than
    requested, the remainder is filled with ``'<cluster_name>-np-<i>'`` for the
    lowest ``i >= 0`` whose name is not already taken.

    Args:
      existing_node_pool_names: names of the node pools that currently exist;
        names without the cluster's ``-np-`` prefix are ignored.
      cluster_name: cluster name used to build the ``-np-`` prefix.
      desired_node_pool_count: how many names to return.

    Returns:
      A list of exactly ``desired_node_pool_count`` unique names (an empty list
      when the count is zero or negative), in a deterministic order: reused
      existing names first, then newly generated ones by ascending index.
    """
    # A negative count would otherwise turn the slice below into
    # "all but the last N" and hand back names nobody asked for.
    if desired_node_pool_count <= 0:
        return []

    prefix = f'{cluster_name}-np-'
    # dict.fromkeys drops duplicates while keeping first-seen order, so a
    # repeated input name cannot consume two of the requested slots.
    owned_names = list(
        dict.fromkeys(
            name
            for name in existing_node_pool_names
            if name.startswith(prefix)
        )
    )

    desired_names = owned_names[:desired_node_pool_count]
    taken = set(desired_names)
    index = 0
    while len(desired_names) < desired_node_pool_count:
        candidate = f'{prefix}{index}'
        if candidate not in taken:
            taken.add(candidate)
            desired_names.append(candidate)
        index += 1
    return desired_names
from xpk.core.nodepool import get_desired_node_pool_names

CLUSTER_NAME = "running-cucumber"


def node_pool_name(number: int) -> str:
    """Return the name of node pool `number` in the test cluster."""
    return f"{CLUSTER_NAME}-np-{number}"


def _assert_same_names(result, expected_result):
    """Assert both lists hold the same names, ignoring order only.

    Sorted lists are compared instead of sets so that a duplicated name or a
    result of the wrong length fails the test instead of being collapsed away.
    """
    assert sorted(result) == sorted(expected_result)


def test_compute_desired_node_pool_names_with_desired_larger_than_existing():
    # One pool exists, two are wanted: the existing one is kept and one is added.
    result = get_desired_node_pool_names(
        existing_node_pool_names=[node_pool_name(0)],
        cluster_name=CLUSTER_NAME,
        desired_node_pool_count=2,
    )

    _assert_same_names(result, [node_pool_name(0), node_pool_name(1)])


def test_compute_desired_node_pool_names_with_desired_smaller_than_existing():
    # Two pools exist, one is wanted: only the first existing one remains.
    result = get_desired_node_pool_names(
        existing_node_pool_names=[node_pool_name(0), node_pool_name(1)],
        cluster_name=CLUSTER_NAME,
        desired_node_pool_count=1,
    )

    _assert_same_names(result, [node_pool_name(0)])


def test_compute_desired_node_pool_names_with_consecutive_numbers_missing():
    # Existing indices have a gap (0, 3); the extra pool takes the lowest
    # free index instead of renaming the existing ones.
    result = get_desired_node_pool_names(
        existing_node_pool_names=[node_pool_name(0), node_pool_name(3)],
        cluster_name=CLUSTER_NAME,
        desired_node_pool_count=3,
    )

    _assert_same_names(
        result, [node_pool_name(0), node_pool_name(1), node_pool_name(3)]
    )


def test_compute_desired_node_pool_names_with_consecutive_numbers_missing_and_desired_equal_to_existing():
    # Gap in the indices but the count already matches: nothing changes.
    result = get_desired_node_pool_names(
        existing_node_pool_names=[node_pool_name(0), node_pool_name(3)],
        cluster_name=CLUSTER_NAME,
        desired_node_pool_count=2,
    )

    _assert_same_names(result, [node_pool_name(0), node_pool_name(3)])


def test_compute_desired_node_pool_names_with_unknown_node_pools():
    # Names without the "<cluster>-np-" prefix are never part of the result.
    result = get_desired_node_pool_names(
        existing_node_pool_names=[
            "unknown-node-pool",
            node_pool_name(0),
            node_pool_name(3),
        ],
        cluster_name=CLUSTER_NAME,
        desired_node_pool_count=2,
    )

    _assert_same_names(result, [node_pool_name(0), node_pool_name(3)])


def test_compute_desired_node_pool_names_with_zero_desired():
    # Asking for no pools yields no names, whatever already exists.
    result = get_desired_node_pool_names(
        existing_node_pool_names=[node_pool_name(0), node_pool_name(1)],
        cluster_name=CLUSTER_NAME,
        desired_node_pool_count=0,
    )

    _assert_same_names(result, [])