From 0ad12770c8a0a040e0bab3e8f8fa00abcf9e182a Mon Sep 17 00:00:00 2001
From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com>
Date: Wed, 22 Oct 2025 06:53:31 -0700
Subject: [PATCH] [ROCm] [Normalization] Update block size (#2738)

cherry-pick of https://github.com/pytorch/pytorch/commit/9f82535c5a8e0139f88bb64be9dc6b4a61be2947

Fixes #SWDEV-561122
---
 aten/src/ATen/native/cuda/Normalization.cuh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh
index 088aa517aa23a..5e6e59784ef39 100644
--- a/aten/src/ATen/native/cuda/Normalization.cuh
+++ b/aten/src/ATen/native/cuda/Normalization.cuh
@@ -23,7 +23,7 @@ namespace at::native {
 
 // The maximum number of threads in a block
 #if defined(USE_ROCM)
-constexpr int MAX_BLOCK_SIZE = 256;
+constexpr int MAX_BLOCK_SIZE = 1024;
 #else
 constexpr int MAX_BLOCK_SIZE = 512;
 #endif
@@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
 // Number of threads in a block given an input size up to MAX_BLOCK_SIZE
 static int getNumThreads(int nElem) {
 #if defined(USE_ROCM)
-  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
+  int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
 #else
   int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
 #endif