Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions aten/src/ATen/native/cuda/Indexing.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
#endif
}

#if 0
#ifdef USE_ROCM
#define SKIP_SORTED_INDICES 32
template <typename scalar_t, int SZ>
__global__ void indexing_backward_kernel(
__global__ void indexing_backward_kernel_many_indices(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
using opmath_t = at::opmath_type<scalar_t>;
Expand Down Expand Up @@ -141,10 +142,7 @@ __global__ void indexing_backward_kernel(
}
}
}
#endif

#ifdef USE_ROCM
#define SKIP_SORTED_INDICES 32
template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
Expand Down Expand Up @@ -784,6 +782,38 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
kBool,
kBFloat16);
} else {
#ifdef USE_ROCM
if (num_indices >= 200000)
AT_DISPATCH_V2(
expandedValue.scalar_type(),
"indexing_backward_many_indices",
AT_WRAP([&] {
indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid, block, smem_dups_size, stream>>>(
sorted_indices.const_data_ptr<int64_t>(),
orig_indices.const_data_ptr<int64_t>(),
expandedValue.const_data_ptr<scalar_t>(),
src_.mutable_data_ptr<scalar_t>(),
num_indices,
sliceSize,
strideBefore,
nElemBefore,
accumulate);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}),
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
// AT_EXPAND(AT_FLOAT8_TYPES),
// TODO(#113663): clean up accumulation behavior in float8 dtypes; accumulate=True
// should not be supported here. Once that is fixed, re-enable AT_EXPAND(AT_FLOAT8_TYPES).
kFloat8_e4m3fn,
kFloat8_e5m2,
kFloat8_e4m3fnuz,
kFloat8_e5m2fnuz,
kComplexHalf,
kHalf,
kBool,
kBFloat16);
else
#endif
AT_DISPATCH_V2(
expandedValue.scalar_type(),
"indexing_backward",
Expand Down