diff --git a/src/xpk/core/system_characteristics.py b/src/xpk/core/system_characteristics.py index 400dc632a..4d3e5aa72 100644 --- a/src/xpk/core/system_characteristics.py +++ b/src/xpk/core/system_characteristics.py @@ -166,12 +166,16 @@ def get_tpu_system_characteristics_map( supported_topologies: list[str], supports_sub_slicing: bool, tpu_type_requires_workload_policy: bool = False, + default_topologies: set[str] | None = None, ) -> dict[str, SystemCharacteristics]: system_characteristics_map = {} + if default_topologies is None: + default_topologies = set() for topology in supported_topologies: chips_per_vm = compute_chips_per_vm(topology) vms_per_slice = compute_vms_per_slice(topology) num_tensorcores = compute_num_tensorcores(tensorcores_per_chip, topology) + device_type = f'{prefix}-{num_tensorcores}' system = SystemCharacteristics( topology=topology, vms_per_slice=vms_per_slice, @@ -179,13 +183,17 @@ def get_tpu_system_characteristics_map( gce_machine_type=machine_type, chips_per_vm=chips_per_vm, accelerator_type=AcceleratorType.TPU, - device_type=f'{prefix}-{num_tensorcores}', + device_type=device_type, requires_workload_policy=tpu_type_requires_workload_policy and vms_per_slice > 1, supports_sub_slicing=supports_sub_slicing, ) system_characteristics_map[f'{prefix}-{topology}'] = system - system_characteristics_map[f'{prefix}-{num_tensorcores}'] = system + if ( + topology in default_topologies + or device_type not in system_characteristics_map + ): + system_characteristics_map[device_type] = system return system_characteristics_map @@ -373,6 +381,106 @@ def compute_vms_per_slice(topology: str) -> int: tpu_type_requires_workload_policy=True, supports_sub_slicing=False, supported_topologies=generate_tpu_topologies(max_cubes=144), + default_topologies=set([ + '12x12x12', + '12x12x16', + '12x12x20', + '12x12x24', + '12x12x28', + '12x12x36', + '12x16x16', + '12x16x20', + '12x16x24', + '12x16x28', + '12x20x20', + '12x20x24', + '12x24x24', + '16x16x16', + 
'16x16x20', + '16x16x24', + '16x16x32', + '16x20x28', + '16x24x24', + '2x2x1', + '2x2x2', + '2x2x4', + '2x4x4', + '4x12x116', + '4x12x12', + '4x12x124', + '4x12x20', + '4x12x28', + '4x12x44', + '4x12x52', + '4x12x68', + '4x12x76', + '4x12x92', + '4x20x20', + '4x20x28', + '4x20x44', + '4x20x52', + '4x20x68', + '4x20x76', + '4x28x28', + '4x28x44', + '4x28x52', + '4x4x116', + '4x4x12', + '4x4x124', + '4x4x148', + '4x4x164', + '4x4x172', + '4x4x188', + '4x4x20', + '4x4x212', + '4x4x236', + '4x4x244', + '4x4x28', + '4x4x4', + '4x4x44', + '4x4x52', + '4x4x68', + '4x4x76', + '4x4x8', + '4x4x92', + '4x8x116', + '4x8x12', + '4x8x124', + '4x8x148', + '4x8x164', + '4x8x172', + '4x8x188', + '4x8x20', + '4x8x28', + '4x8x44', + '4x8x52', + '4x8x68', + '4x8x76', + '4x8x8', + '4x8x92', + '8x12x12', + '8x12x16', + '8x12x20', + '8x12x28', + '8x12x44', + '8x12x52', + '8x16x16', + '8x16x20', + '8x16x28', + '8x16x44', + '8x20x20', + '8x20x28', + '8x8x12', + '8x8x16', + '8x8x20', + '8x8x28', + '8x8x44', + '8x8x52', + '8x8x68', + '8x8x76', + '8x8x8', + '8x8x92', + ]), ), **get_tpu_system_characteristics_map( prefix='v6e', @@ -405,6 +513,104 @@ def compute_vms_per_slice(topology: str) -> int: machine_type='ct5p-hightpu-4t', supports_sub_slicing=False, supported_topologies=generate_tpu_topologies(max_cubes=140), + default_topologies=set([ + '2x2x1', + '2x2x2', + '2x2x4', + '2x4x4', + '4x4x4', + '4x4x8', + '4x4x12', + '4x8x8', + '4x4x20', + '4x8x12', + '4x4x28', + '8x8x8', + '4x12x12', + '4x8x20', + '4x4x44', + '8x8x12', + '4x4x52', + '4x8x28', + '4x12x20', + '8x8x16', + '4x4x68', + '8x12x12', + '4x4x76', + '8x8x20', + '4x12x28', + '4x8x44', + '4x4x92', + '8x12x16', + '4x20x20', + '4x8x52', + '12x12x12', + '8x8x28', + '4x4x116', + '8x12x20', + '4x4x124', + '8x16x16', + '4x12x44', + '4x8x68', + '4x20x28', + '12x12x16', + '4x4x148', + '4x8x76', + '4x12x52', + '8x16x20', + '4x4x164', + '8x12x28', + '4x4x172', + '8x8x44', + '12x12x20', + '4x8x92', + '4x4x188', + '12x16x16', + '4x28x28', + 
'8x20x20', + '4x12x68', + '8x8x52', + '4x4x212', + '12x12x24', + '4x20x44', + '8x16x28', + '4x12x76', + '4x8x116', + '4x4x236', + '12x16x20', + '4x4x244', + '4x8x124', + '12x12x28', + '16x16x16', + '4x20x52', + '8x12x44', + '8x8x68', + '4x12x92', + '8x20x28', + '12x16x24', + '4x8x148', + '12x20x20', + '8x8x76', + '4x28x44', + '8x12x52', + '16x16x20', + '12x12x36', + '4x8x164', + '12x16x28', + '4x20x68', + '4x8x172', + '4x12x116', + '8x16x44', + '12x20x24', + '4x28x52', + '8x8x92', + '4x12x124', + '4x8x188', + '4x20x76', + '16x16x24', + '12x24x24', + '16x20x28', + ]), ), **get_tpu_system_characteristics_map( prefix='v5litepod', @@ -423,6 +629,19 @@ def compute_vms_per_slice(topology: str) -> int: supported_topologies=generate_tpu_topologies( max_cubes=64, enforce_nondecreasing=False ), + default_topologies=set([ + '2x2x1', + '2x2x2', + '2x2x4', + '2x4x4', + '4x4x4', + '4x4x8', + '4x8x8', + '8x8x8', + '8x8x12', + '8x8x16', + '8x16x16', + ]), ), # CPU system characteristics. # Note that chips_per_vm is actually the number of vCPUs in that CPU. 
diff --git a/src/xpk/core/system_characteristics_test.py b/src/xpk/core/system_characteristics_test.py
index 981bf28ba..befe283d1 100644
--- a/src/xpk/core/system_characteristics_test.py
+++ b/src/xpk/core/system_characteristics_test.py
@@ -101,6 +101,29 @@ def test_get_tpu_system_characteristics_map_returns_correct_values_for_2x2x2_top
   }
 
 
+def test_get_tpu_system_characteristics_map_prefers_default_topologies():
+  """Checks which topology the bare device-type key resolves to."""
+  supported_topologies = ["4x4x4", "4x4x32", "4x8x16", "8x8x8"]
+
+  result = get_tpu_system_characteristics_map(
+      prefix="test",
+      tensorcores_per_chip=2,
+      gke_accelerator="test",
+      machine_type="test",
+      supported_topologies=supported_topologies,
+      supports_sub_slicing=False,
+      default_topologies={"4x8x16"},
+  )
+
+  # No default maps to "test-128", so the first topology producing it is kept.
+  assert result["test-128"].topology == "4x4x4"
+  # The default wins over 4x4x32 (listed before it) and 8x8x8 (listed after
+  # it), which presumably resolve to the same device type.
+  assert result["test-1024"].topology == "4x8x16"
+  # Each topology remains reachable via its "<prefix>-<topology>" key.
+  for topology in supported_topologies:
+    assert result[f"test-{topology}"].topology == topology
+
+
 def test_generate_tpu_topologies_returns_correct_number_of_values_for_TPU_platforms():
   v4 = generate_tpu_topologies(max_cubes=64, enforce_nondecreasing=False)
   v5p = generate_tpu_topologies(max_cubes=140)