diff --git a/src/xpk/core/system_characteristics.py b/src/xpk/core/system_characteristics.py index edb9e75f9..b81224fa6 100644 --- a/src/xpk/core/system_characteristics.py +++ b/src/xpk/core/system_characteristics.py @@ -135,10 +135,9 @@ def get_tpu_system_characteristics_map( ) -> dict[str, SystemCharacteristics]: system_characteristics_map = {} for topology in supported_topologies: - total_chips = get_topology_product(topology) - num_tensorcores = total_chips * tensorcores_per_chip - chips_per_vm = 1 if total_chips == 1 else 4 - vms_per_slice = total_chips // chips_per_vm + chips_per_vm = compute_chips_per_vm(topology) + vms_per_slice = compute_vms_per_slice(topology) + num_tensorcores = compute_num_tensorcores(tensorcores_per_chip, topology) system = SystemCharacteristics( topology=topology, vms_per_slice=vms_per_slice, @@ -156,6 +155,19 @@ def get_tpu_system_characteristics_map( return system_characteristics_map +def compute_chips_per_vm(topology: str) -> int: + return 1 if get_topology_product(topology) == 1 else 4 + + +def compute_num_tensorcores(tensorcores_per_chip: int, topology: str) -> int: + return get_topology_product(topology) * tensorcores_per_chip + + +def compute_vms_per_slice(topology: str) -> int: + chips_per_vm = compute_chips_per_vm(topology) + return get_topology_product(topology) // chips_per_vm + + ################### Subcommand Helper Functions ############################# """ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! IF YOU MODIFY THE BELOW UserFacingNameToSystemCharacteristics MAP YOU SHOULD diff --git a/src/xpk/core/system_characteristics_test.py b/src/xpk/core/system_characteristics_test.py new file mode 100644 index 000000000..f0933f0e0 --- /dev/null +++ b/src/xpk/core/system_characteristics_test.py @@ -0,0 +1,73 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from .system_characteristics import get_tpu_system_characteristics_map, SystemCharacteristics + + +def test_get_tpu_system_characteristics_map_returns_correct_values_for_1x1_topology(): + result = get_tpu_system_characteristics_map( + prefix="test", + tensorcores_per_chip=1, + gke_accelerator="test", + machine_type="test", + supported_topologies=["1x1"], + supports_sub_slicing=False, + requires_workload_policy=True, + ) + + expected_system_characteristics = SystemCharacteristics( + topology="1x1", + vms_per_slice=1, + gke_accelerator="test", + gce_machine_type="test", + chips_per_vm=1, + accelerator_type=1, + device_type="test-1", + supports_sub_slicing=False, + requires_workload_policy=True, + ) + assert result == { + "test-1": expected_system_characteristics, + "test-1x1": expected_system_characteristics, + } + + +def test_get_tpu_system_characteristics_map_returns_correct_values_for_2x2_topology(): + result = get_tpu_system_characteristics_map( + prefix="test", + tensorcores_per_chip=2, + gke_accelerator="test", + machine_type="test", + supported_topologies=["2x2"], + supports_sub_slicing=False, + requires_workload_policy=True, + ) + + expected_system_characteristics = SystemCharacteristics( + topology="2x2", + vms_per_slice=1, + gke_accelerator="test", + gce_machine_type="test", + chips_per_vm=4, + accelerator_type=1, + device_type="test-8", + supports_sub_slicing=False, + requires_workload_policy=True, + ) + assert result == { + "test-8": expected_system_characteristics, + "test-2x2": expected_system_characteristics, + }