def generate_tpu_topologies(
    max_cubes: int, enforce_nondecreasing: bool = True
) -> list[str]:
  """Generates a list of unique TPU topologies formatted as strings "AxBxC".

  The list will contain all triplets (A, B, C) such that:
  - A, B and C are integers in range 4..256 (including 4 and 256)
  - A, B and C are divisible by 4
  - (A/4) * (B/4) * (C/4) <= max_cubes
  - if enforce_nondecreasing: A <= B <= C
  Additionally, the list will also contain the following triplets:
  2x2x1, 2x2x2, 2x2x4, 2x4x4

  Args:
    max_cubes: maximum number of cubes supported by a TPU platform
    enforce_nondecreasing: whether to enforce A <= B <= C or not

  Returns:
    The four sub-cube slices first, followed by the generated triplets in
    ascending (A, B, C) iteration order. If max_cubes is below 1, only the
    sub-cube slices are returned.
  """
  topologies = ['2x2x1', '2x2x2', '2x2x4', '2x4x4']
  cube_side = 4  # every generated dimension is a multiple of one cube side
  max_side_in_cubes = 256 // cube_side  # dimensions are capped at 256

  # Iterate in cube units. The cube count only grows along each axis, so as
  # soon as the budget is exceeded the rest of that axis can be skipped. This
  # keeps the work proportional to the number of results instead of always
  # scanning the full 64 x 64 x 64 grid.
  for a in range(1, max_side_in_cubes + 1):
    if a > max_cubes:
      break
    for b in range(a if enforce_nondecreasing else 1, max_side_in_cubes + 1):
      if a * b > max_cubes:
        break
      for c in range(
          b if enforce_nondecreasing else 1, max_side_in_cubes + 1
      ):
        if a * b * c > max_cubes:
          break
        topologies.append(
            f'{a * cube_side}x{b * cube_side}x{c * cube_side}'
        )
  return topologies
'4x12x76', - '4x12x92', - '4x20x20', - '4x20x28', - '4x20x44', - '4x20x52', - '4x20x68', - '4x20x76', - '4x28x28', - '4x28x44', - '4x28x52', - '4x4x116', - '4x4x12', - '4x4x124', - '4x4x148', - '4x4x164', - '4x4x172', - '4x4x188', - '4x4x20', - '4x4x212', - '4x4x236', - '4x4x244', - '4x4x28', - '4x4x4', - '4x4x44', - '4x4x52', - '4x4x68', - '4x4x76', - '4x4x8', - '4x4x92', - '4x8x116', - '4x8x12', - '4x8x124', - '4x8x148', - '4x8x164', - '4x8x172', - '4x8x188', - '4x8x20', - '4x8x28', - '4x8x44', - '4x8x52', - '4x8x68', - '4x8x76', - '4x8x8', - '4x8x92', - '8x12x12', - '8x12x16', - '8x12x20', - '8x12x28', - '8x12x44', - '8x12x52', - '8x16x16', - '8x16x20', - '8x16x28', - '8x16x44', - '8x20x20', - '8x20x28', - '8x8x12', - '8x8x16', - '8x8x20', - '8x8x28', - '8x8x44', - '8x8x52', - '8x8x68', - '8x8x76', - '8x8x8', - '8x8x92', - ], + supported_topologies=generate_tpu_topologies(max_cubes=144), ), **get_tpu_system_characteristics_map( prefix='v6e', @@ -476,104 +404,7 @@ def compute_vms_per_slice(topology: str) -> int: gke_accelerator='tpu-v5p-slice', machine_type='ct5p-hightpu-4t', supports_sub_slicing=False, - supported_topologies=[ - '2x2x1', - '2x2x2', - '2x2x4', - '2x4x4', - '4x4x4', - '4x4x8', - '4x4x12', - '4x8x8', - '4x4x20', - '4x8x12', - '4x4x28', - '8x8x8', - '4x12x12', - '4x8x20', - '4x4x44', - '8x8x12', - '4x4x52', - '4x8x28', - '4x12x20', - '8x8x16', - '4x4x68', - '8x12x12', - '4x4x76', - '8x8x20', - '4x12x28', - '4x8x44', - '4x4x92', - '8x12x16', - '4x20x20', - '4x8x52', - '12x12x12', - '8x8x28', - '4x4x116', - '8x12x20', - '4x4x124', - '8x16x16', - '4x12x44', - '4x8x68', - '4x20x28', - '12x12x16', - '4x4x148', - '4x8x76', - '4x12x52', - '8x16x20', - '4x4x164', - '8x12x28', - '4x4x172', - '8x8x44', - '12x12x20', - '4x8x92', - '4x4x188', - '12x16x16', - '4x28x28', - '8x20x20', - '4x12x68', - '8x8x52', - '4x4x212', - '12x12x24', - '4x20x44', - '8x16x28', - '4x12x76', - '4x8x116', - '4x4x236', - '12x16x20', - '4x4x244', - '4x8x124', - '12x12x28', - 
'16x16x16', - '4x20x52', - '8x12x44', - '8x8x68', - '4x12x92', - '8x20x28', - '12x16x24', - '4x8x148', - '12x20x20', - '8x8x76', - '4x28x44', - '8x12x52', - '16x16x20', - '12x12x36', - '4x8x164', - '12x16x28', - '4x20x68', - '4x8x172', - '4x12x116', - '8x16x44', - '12x20x24', - '4x28x52', - '8x8x92', - '4x12x124', - '4x8x188', - '4x20x76', - '16x16x24', - '12x24x24', - '16x20x28', - ], + supported_topologies=generate_tpu_topologies(max_cubes=140), ), **get_tpu_system_characteristics_map( prefix='v5litepod', @@ -589,19 +420,9 @@ def compute_vms_per_slice(topology: str) -> int: gke_accelerator='tpu-v4-podslice', machine_type='ct4p-hightpu-4t', supports_sub_slicing=False, - supported_topologies=[ - '2x2x1', - '2x2x2', - '2x2x4', - '2x4x4', - '4x4x4', - '4x4x8', - '4x8x8', - '8x8x8', - '8x8x12', - '8x8x16', - '8x16x16', - ], + supported_topologies=generate_tpu_topologies( + max_cubes=64, enforce_nondecreasing=False + ), ), # CPU system characteristics. # Note that chips_per_vm is actually the number of vCPUs in that CPU. diff --git a/src/xpk/core/system_characteristics_test.py b/src/xpk/core/system_characteristics_test.py index f884f7a01..981bf28ba 100644 --- a/src/xpk/core/system_characteristics_test.py +++ b/src/xpk/core/system_characteristics_test.py @@ -14,7 +14,7 @@ limitations under the License. 
def test_generate_tpu_topologies_returns_correct_number_of_values_for_TPU_platforms():
  """Pins how many topologies are generated for the v4, v5p and tpu7x settings."""
  v4_topologies = generate_tpu_topologies(
      max_cubes=64, enforce_nondecreasing=False
  )
  v5p_topologies = generate_tpu_topologies(max_cubes=140)
  tpu7x_topologies = generate_tpu_topologies(max_cubes=144)

  v4_count = len(v4_topologies)
  v5p_count = len(v5p_topologies)
  tpu7x_count = len(tpu7x_topologies)

  assert v4_count == 800
  assert v5p_count == 414
  assert tpu7x_count == 432


def test_generate_tpu_topologies_respects_constraints():
  """Checks the ordering flag and the cube budget with a budget of 6 cubes."""
  sorted_only = generate_tpu_topologies(
      max_cubes=6, enforce_nondecreasing=True
  )
  any_order = generate_tpu_topologies(
      max_cubes=6, enforce_nondecreasing=False
  )

  # A decreasing triplet is emitted only when ordering is not enforced.
  assert "8x4x4" not in sorted_only
  assert "8x4x4" in any_order

  # 1 * 2 * 3 == 6 cubes: exactly at the budget, accepted in both modes.
  assert "4x8x12" in sorted_only
  assert "4x8x12" in any_order

  # 1 * 2 * 4 == 8 cubes: over the budget, rejected in both modes.
  assert "4x8x16" not in sorted_only
  assert "4x8x16" not in any_order


def test_generate_tpu_topologies_contains_sub_cube_slices():
  """With a one-cube budget only the sub-cube slices and 4x4x4 are listed."""
  expected = ["2x2x1", "2x2x2", "2x2x4", "2x4x4", "4x4x4"]

  assert generate_tpu_topologies(max_cubes=1) == expected