diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu
index e49fffc2effc..df43bfac16f7 100644
--- a/aten/src/ATen/native/cuda/Indexing.cu
+++ b/aten/src/ATen/native/cuda/Indexing.cu
@@ -56,9 +56,10 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
 #endif
 }
 
-#if 0
+#ifdef USE_ROCM
+#define SKIP_SORTED_INDICES 32
 template <typename scalar_t, int SZ>
-__global__ void indexing_backward_kernel(
+__global__ void indexing_backward_kernel_many_indices(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
   int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
   using opmath_t = at::opmath_type<scalar_t>;
@@ -141,10 +142,7 @@ __global__ void indexing_backward_kernel(
     }
   }
 }
-#endif
 
-#ifdef USE_ROCM
-#define SKIP_SORTED_INDICES 32
 template <typename scalar_t>
 __global__ void indexing_backward_kernel_stride_1(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -784,6 +782,38 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Tensor>>& indices, const Tensor & value, bool accumulate, bool unsafe) {
+#ifdef USE_ROCM
+      if (num_indices >= 200000)
+        AT_DISPATCH_V2(
+          expandedValue.scalar_type(),
+          "indexing_backward_many_indices",
+          AT_WRAP([&] {
+            indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
+              sorted_indices.const_data_ptr<int64_t>(),
+              orig_indices.const_data_ptr<int64_t>(),
+              expandedValue.const_data_ptr<scalar_t>(),
+              src_.mutable_data_ptr<scalar_t>(),
+              num_indices,
+              sliceSize,
+              strideBefore,
+              nElemBefore,
+              accumulate);
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
+          }),
+          AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+          // AT_EXPAND(AT_FLOAT8_TYPES),
+          // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+          // should not be supported here, then reenable AT_FLOAT8_DTYPES
+          kFloat8_e4m3fn,
+          kFloat8_e5m2,
+          kFloat8_e4m3fnuz,
+          kFloat8_e5m2fnuz,
+          kComplexHalf,
+          kHalf,
+          kBool,
+          kBFloat16);
+      else
+#endif
       AT_DISPATCH_V2(
         expandedValue.scalar_type(),
         "indexing_backward",