Skip to content

Commit

Permalink
compilation optimization for kron_grad_kernel (PaddlePaddle#57822)
Browse files Browse the repository at this point in the history
  • Loading branch information
tianhaodongbd authored and Frida-a committed Oct 14, 2023
1 parent 3ff45a6 commit b20fa93
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions paddle/phi/kernels/impl/kron_grad_kernel_impl.h
Expand Up @@ -15,6 +15,7 @@
#pragma once

#include "paddle/phi/kernels/impl/kron_kernel_impl.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace phi {

Expand Down Expand Up @@ -234,12 +235,12 @@ struct KronGradOpFunctor {
#if defined(__NVCC__) || defined(__HIPCC__)
auto stream = dev_ctx.stream(); // it is a cuda device_context
if (dx) {
funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, dout_x, dx, kps::IdentityFunctor<T>(), {1});
phi::SumKernel<T, Context>(
dev_ctx, dout_x, {1}, dout_x.dtype(), false, dx);
}
if (dy) {
funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, dout_y, dy, kps::IdentityFunctor<T>(), {1});
phi::SumKernel<T, Context>(
dev_ctx, dout_y, {1}, dout_y.dtype(), false, dy);
}
#else
auto *place = dev_ctx.eigen_device();
Expand Down

0 comments on commit b20fa93

Please sign in to comment.