diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py
index 6626c88a1e0d..f58f4da06113 100644
--- a/torch/_inductor/runtime/coordinate_descent_tuner.py
+++ b/torch/_inductor/runtime/coordinate_descent_tuner.py
@@ -63,9 +63,14 @@ def get_config_max(self, prefix: str) -> int:
 
     @lru_cache(maxsize=1)
     def get_warpsmax(self):
-        # Currently, CUDA has a maximum of 1024 threads, so 32 is the max
-        # number of warps.
-        return 1024 // 32
+        # CUDA/ROCm has a maximum of 1024 threads per block
+        from torch.cuda import current_device, get_device_properties, is_available
+
+        warp_size = (
+            get_device_properties(current_device()).warp_size if is_available() else 32
+        )
+
+        return 1024 // warp_size
 
     def cache_benchmark_result(self, config, timing):
         self.cached_benchmark_results[triton_config_to_hashable(config)] = timing
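
For context, here is a minimal standalone sketch (not part of the patch) of what the updated `get_warpsmax` computes. It assumes the `warp_size` attribute on `torch.cuda.get_device_properties(...)` that the patch relies on: 32 threads per warp on NVIDIA GPUs (giving 32 warps) versus 64-thread wavefronts on most AMD CDNA GPUs (giving 16 warps), with a fallback of 32 when no GPU is available.

```python
# Standalone illustration of the new max-warps computation; function name
# and structure are for demonstration only, not taken from the patch.
import torch


def max_warps() -> int:
    # Both CUDA and ROCm cap a block/workgroup at 1024 threads.
    max_threads_per_block = 1024
    if not torch.cuda.is_available():
        # No GPU present: fall back to the CUDA warp size of 32.
        return max_threads_per_block // 32
    # warp_size is 32 on NVIDIA GPUs and 64 on most AMD CDNA GPUs,
    # so this yields 32 or 16 warps respectively.
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    return max_threads_per_block // props.warp_size


if __name__ == "__main__":
    print(max_warps())  # e.g. 32 on an NVIDIA GPU, 16 on a CDNA AMD GPU
```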