Skip to content

Commit

Permalink
remove inplace
Browse files Browse the repository at this point in the history
  • Loading branch information
AshburnLee committed Jan 12, 2022
1 parent 817d00d commit 6ab80a3
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions paddle/fluid/operators/elementwise/elementwise_max_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,20 @@ ElementwiseMaxGrad(const framework::ExecutionContext& ctx,
ctx.template device_context<platform::CUDADeviceContext>();
const auto place = ctx.GetPlace();
if (dx != nullptr && dy != nullptr) {
dx->mutable_data<T>(place);
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), place);
}
// dx->mutable_data<T>(place);
// if (dx->IsSharedBufferWith(*dout)) {
// dx->clear();
// dx->mutable_data<T>(x->dims(), place);
// }
std::vector<const framework::Tensor*> ins = {x, y, dout};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dx, dy, MaxGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
dx->mutable_data<T>(place);
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), place);
}
// dx->mutable_data<T>(place);
// if (dx->IsSharedBufferWith(*dout)) {
// dx->clear();
// dx->mutable_data<T>(x->dims(), place);
// }
std::vector<const framework::Tensor*> ins = {x, y, dout};
GetGradXOrYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dx, MaxGradXFunctor<T>());
Expand Down

0 comments on commit 6ab80a3

Please sign in to comment.