Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions aten/src/ATen/native/cuda/Indexing.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
#endif
}

#ifdef USE_ROCM
#define SKIP_SORTED_INDICES 32
#if 0
template <typename scalar_t, int SZ>
__global__ void indexing_backward_kernel(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
Expand Down Expand Up @@ -142,7 +141,10 @@ __global__ void indexing_backward_kernel(
}
}
}
#endif

#ifdef USE_ROCM
#define SKIP_SORTED_INDICES 32
template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
Expand Down Expand Up @@ -254,7 +256,8 @@ __global__ void indexing_backward_kernel_stride_1(
}
}
}
#else
#endif

template <typename scalar_t, int SZ>
__global__ void indexing_backward_kernel(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
Expand Down Expand Up @@ -333,6 +336,7 @@ __global__ void indexing_backward_kernel(
}
}

#ifndef USE_ROCM
template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
Expand Down Expand Up @@ -784,7 +788,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
expandedValue.scalar_type(),
"indexing_backward",
AT_WRAP([&] {
indexing_backward_kernel<scalar_t, UNROLL><<<KERNEL_GRID, block, KERNEL_SMEM, stream>>>(
indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
sorted_indices.const_data_ptr<int64_t>(),
orig_indices.const_data_ptr<int64_t>(),
expandedValue.const_data_ptr<scalar_t>(),
Expand Down