Skip to content

Commit

Permalink
fix conv3d backward (#42502)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangkaihuo authored May 6, 2022
1 parent d73eb38 commit 503569a
Showing 1 changed file with 8 additions and 27 deletions.
35 changes: 8 additions & 27 deletions paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ limitations under the License. */
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"

Expand Down Expand Up @@ -203,38 +203,19 @@ void Conv3dGradGPUKernel(const GPUContext& dev_ctx,
}

// 4. scatter
// x_grad->ResizeAndAllocate(x.non_zero_elements().dims());
DenseTensorMeta index_meta(DataType::INT32, {rulebook_len}, DataLayout::NCHW);
DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta));
DenseTensor unique_key = phi::Empty(
dev_ctx,
DenseTensorMeta(paddle::experimental::CppTypeToDataType<IntT>::Type(),
{rulebook_len},
DataLayout::NCHW));
DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta));

SortedAndUniqueIndex<GPUContext, IntT>(dev_ctx,
rulebook_ptr + rulebook_len,
rulebook_len,
&out_index,
&unique_key,
&unique_value);

config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, rulebook_len * in_channels, 1);

phi::funcs::sparse::ScatterKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(
phi::funcs::ScatterCUDAKernel<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(
d_x_features_ptr,
unique_value.data<int>(),
out_index.data<int>(),
x.nnz(),
rulebook_ptr + rulebook_len,
x_grad_values_ptr,
rulebook_len,
in_channels,
x_grad_values_ptr,
subm);
false);
}

template <typename T, typename Context>
Expand Down

0 comments on commit 503569a

Please sign in to comment.