From d414e2748d76e23a3092659c8b9bb2d32b066462 Mon Sep 17 00:00:00 2001
From: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Date: Sat, 2 Aug 2025 07:01:33 +0100
Subject: [PATCH 1/2] [release/2.7] [SWDEV-543214] Reland #2416 Fix warps runtime part 2 (#2442)

https://github.com/ROCm/pytorch/pull/2421 didn't bring in all required
changes to reland https://github.com/ROCm/pytorch/pull/2416
---
 torch/_inductor/runtime/coordinate_descent_tuner.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py
index 62a2abcea8d2d..7ce7dcbbffbbc 100644
--- a/torch/_inductor/runtime/coordinate_descent_tuner.py
+++ b/torch/_inductor/runtime/coordinate_descent_tuner.py
@@ -60,9 +60,14 @@ def get_config_max(self, prefix: str) -> int:
         return min(max_block, size_hint) if size_hint is not None else max_block
 
     def get_warpsmax(self):
-        # Currently, CUDA has a maximum of 1024 threads, so 32 is the max
-        # number of warps.
-        return 1024 // 32
+        # CUDA/ROCm has a maximum of 1024 threads per block
+        from torch.cuda import current_device, get_device_properties, is_available
+
+        warp_size = (
+            get_device_properties(current_device()).warp_size if is_available() else 32
+        )
+
+        return 1024 // warp_size
 
     def cache_benchmark_result(self, config, timing):
         self.cached_benchmark_results[triton_config_to_hashable(config)] = timing

From 535d4e0c18ce9c3dbf71668e812b9f133c90d66f Mon Sep 17 00:00:00 2001
From: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Date: Mon, 4 Aug 2025 14:57:15 +0100
Subject: [PATCH 2/2] Update coordinate_descent_tuner.py

---
 torch/_inductor/runtime/coordinate_descent_tuner.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py
index 7ce7dcbbffbbc..ac308d6e24e0c 100644
--- a/torch/_inductor/runtime/coordinate_descent_tuner.py
+++ b/torch/_inductor/runtime/coordinate_descent_tuner.py
@@ -3,6 +3,7 @@
 import itertools
 import logging
 from typing import Callable, Optional
+from functools import lru_cache
 
 from .hints import TRITON_MAX_BLOCK
 from .runtime_utils import red_text, triton_config_to_hashable
@@ -59,6 +60,7 @@ def get_config_max(self, prefix: str) -> int:
         size_hint = self.size_hints.get(prefix) if self.size_hints is not None else None
         return min(max_block, size_hint) if size_hint is not None else max_block
 
+    @lru_cache(maxsize=1)
     def get_warpsmax(self):
         # CUDA/ROCm has a maximum of 1024 threads per block
         from torch.cuda import current_device, get_device_properties, is_available
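
For context (not part of the patch series above): a minimal standalone sketch of the behaviour the relanded `get_warpsmax` implements, assuming a PyTorch build where `torch.cuda` is importable. The `max_warps` helper name is illustrative; the 32-warp fallback mirrors the patch's CPU-only path.

```python
# Hedged sketch of the warps-max computation, not the upstream implementation.
import torch


def max_warps() -> int:
    """Max warps per block = 1024 threads per block // hardware warp size."""
    if torch.cuda.is_available():
        warp_size = torch.cuda.get_device_properties(
            torch.cuda.current_device()
        ).warp_size
    else:
        warp_size = 32  # fall back to the CUDA default when no GPU is present
    return 1024 // warp_size


# NVIDIA GPUs report warp_size == 32, so 1024 // 32 == 32 warps.
# ROCm gfx9 GPUs report a wavefront size of 64, so 1024 // 64 == 16 warps,
# which is why hard-coding 32 overshot the limit on those devices.
print(max_warps())
```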