From 77a47f8eea1b6978fcd29fc665891814411c69ff Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Sat, 2 Aug 2025 07:01:33 +0100 Subject: [PATCH 1/2] [release/2.7] [SWDEV-543214] Reland #2416 Fix warps runtime part 2 (#2442) https://github.com/ROCm/pytorch/pull/2421 didn't bring in all required changes to reland https://github.com/ROCm/pytorch/pull/2416 --- torch/_inductor/runtime/coordinate_descent_tuner.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index 4c2af613a04ca..b27d1ffec265a 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -70,9 +70,14 @@ def get_config_max(self, prefix: str) -> int: return min(max_block, size_hint) if size_hint is not None else max_block def get_warpsmax(self): - # Currently, CUDA has a maximum of 1024 threads, so 32 is the max - # number of warps. - return 1024 // 32 + # CUDA/ROCm has a maximum of 1024 threads per block + from torch.cuda import current_device, get_device_properties, is_available + + warp_size = ( + get_device_properties(current_device()).warp_size if is_available() else 32 + ) + + return 1024 // warp_size def cache_benchmark_result(self, config, timing): self.cached_benchmark_results[triton_config_to_hashable(config)] = timing From 92bf6bc4a6415a3b9f67aaac231d0818cd9838ab Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Mon, 4 Aug 2025 14:57:52 +0100 Subject: [PATCH 2/2] Update coordinate_descent_tuner.py --- torch/_inductor/runtime/coordinate_descent_tuner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index b27d1ffec265a..899cbb0fe417c 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -3,6 +3,7 @@ import itertools import logging from typing import Callable, Optional +from functools import lru_cache from .hints import TRITON_MAX_BLOCK from .runtime_utils import red_text, triton_config_to_hashable @@ -69,6 +70,7 @@ def get_config_max(self, prefix: str) -> int: size_hint = self.prefix_to_size_hint(prefix) return min(max_block, size_hint) if size_hint is not None else max_block + @lru_cache(maxsize=1) def get_warpsmax(self): # CUDA/ROCm has a maximum of 1024 threads per block from torch.cuda import current_device, get_device_properties, is_available