diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh index 088aa517aa23a..5e6e59784ef39 100644 --- a/aten/src/ATen/native/cuda/Normalization.cuh +++ b/aten/src/ATen/native/cuda/Normalization.cuh @@ -23,7 +23,7 @@ namespace at::native { // The maximum number of threads in a block #if defined(USE_ROCM) -constexpr int MAX_BLOCK_SIZE = 256; +constexpr int MAX_BLOCK_SIZE = 1024; #else constexpr int MAX_BLOCK_SIZE = 512; #endif @@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u; // Number of threads in a block given an input size up to MAX_BLOCK_SIZE static int getNumThreads(int nElem) { #if defined(USE_ROCM) - int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE }; + int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE }; #else int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE }; #endif