diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index 02feb55cb69d..e49fffc2effc 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -56,8 +56,7 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() { #endif } -#ifdef USE_ROCM -#define SKIP_SORTED_INDICES 32 +#if 0 template __global__ void indexing_backward_kernel( const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, @@ -142,7 +141,10 @@ __global__ void indexing_backward_kernel( } } } +#endif +#ifdef USE_ROCM +#define SKIP_SORTED_INDICES 32 template __global__ void indexing_backward_kernel_stride_1( const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, @@ -254,7 +256,8 @@ __global__ void indexing_backward_kernel_stride_1( } } } -#else +#endif + template __global__ void indexing_backward_kernel( const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, @@ -333,6 +336,7 @@ __global__ void indexing_backward_kernel( } } +#ifndef USE_ROCM template __global__ void indexing_backward_kernel_stride_1( const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, @@ -784,7 +788,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<<>>( + indexing_backward_kernel<<>>( sorted_indices.const_data_ptr(), orig_indices.const_data_ptr(), expandedValue.const_data_ptr(),