Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PHI] transpose2_grad op migration #46139

Merged
merged 22 commits into from Oct 10, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
ebd80bf
op migrated, Copy(OneDNNContext, ...) added
paulinagacek Sep 16, 2022
490e6dc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 16, 2022
addde63
mutable_data & op registration in fluid removed
paulinagacek Sep 19, 2022
35ec9b7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 19, 2022
28ffdb7
Merge branch 'transpose2_grad_op_migration' of https://github.com/Pau…
paulinagacek Sep 19, 2022
78f8af6
refactoring
paulinagacek Sep 20, 2022
8501e7f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 20, 2022
fbe2c5d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 20, 2022
bd49216
OneDNNGetDataType to uppercase
paulinagacek Sep 20, 2022
fa1369b
missing cpu check added, handler moved to .h file
paulinagacek Sep 21, 2022
d2603a1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 21, 2022
2cbe8f4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 22, 2022
d566a95
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 22, 2022
d33d261
name changed to transpose_grad
paulinagacek Sep 27, 2022
6d8de58
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 27, 2022
372b367
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 28, 2022
cf4eaf0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 28, 2022
131bdc3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 29, 2022
6a344c8
Copy changed back to TensorCopy
paulinagacek Sep 30, 2022
e70eb82
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 30, 2022
9470a6c
Resizing corrected, Copy(OneDNNContext) removed
paulinagacek Oct 3, 2022
14881c2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Oct 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 0 additions & 5 deletions paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
Expand Up @@ -223,8 +223,3 @@ REGISTER_OP_KERNEL(transpose_grad,
MKLDNN,
::paddle::platform::CPUPlace,
ops::TransposeMKLDNNGradOpKernel<float>);

REGISTER_OP_KERNEL(transpose2_grad,
MKLDNN,
::paddle::platform::CPUPlace,
ops::TransposeMKLDNNGradOpKernel<float>);
66 changes: 66 additions & 0 deletions paddle/phi/backends/onednn/onednn_reuse.h
Expand Up @@ -1046,5 +1046,71 @@ class ClipOneDNNHandler
to_void_cast<T>(input_data));
}
};
template <typename T>
// Builds the oneDNN memory objects needed to run a transpose as a strided
// reorder: the source uses the tensor's native format, while the destination
// descriptor carries strides permuted according to `axis`, so executing the
// reorder materializes the transposed layout.
class TransposeOneDNNHandler {
 public:
  // dims: logical (nchw-ordered) dimensions of the tensor being transposed.
  // axis: permutation to apply; axis[i] is the source dim of output dim i.
  TransposeOneDNNHandler(const OneDNNContext& dev_ctx,
                         const std::vector<int64_t>& dims,
                         const std::vector<int>& axis,
                         dnnl::engine engine)
      : dev_ctx_(dev_ctx),
        dims_(dims),
        axis_(axis),
        logical_axis_(dims.size(), 0),
        engine_(engine) {}

  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const OneDNNMemoryFormat& fmt,
                                                 void* ptr) {
    // Make memory descriptor using input format, unless it
    // cannot be trusted (nchw) then make up memory fmt manually
    // (identity permutation => plain row-major strides).
    for (size_t i = 0; i < logical_axis_.size(); ++i) {
      logical_axis_[i] = static_cast<int>(i);
    }

    auto src_md = fmt != OneDNNMemoryFormat::nchw
                      ? OneDNNMemDesc(dims_, OneDNNGetDataType<T>(), fmt)
                      : Axis2MemoryDesc(dims_, logical_axis_);
    return std::make_shared<dnnl::memory>(src_md, engine_, ptr);
  }

  // Resizes `output` to the logical dims, allocates it on `place`, and wraps
  // it in a memory object whose strides encode the transposition.
  std::shared_ptr<dnnl::memory> AcquireDstMemory(DenseTensor* output,
                                                 Place place) {
    auto dst_md = Axis2MemoryDesc(dims_, axis_);
    output->Resize(make_ddim(dims_));
    auto dst_data = dev_ctx_.Alloc<T>(output);

    return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
  }

  std::shared_ptr<dnnl::reorder> AcquireTranspose(
      std::shared_ptr<dnnl::memory> dst_memory_p,
      std::shared_ptr<dnnl::memory> src_memory_p) {
    return std::make_shared<dnnl::reorder>(*src_memory_p, *dst_memory_p);
  }

 protected:
  // Translates a permutation into a strided memory descriptor: output dim
  // axis[i] gets the stride of a dense layout traversed in `axis` order.
  dnnl::memory::desc Axis2MemoryDesc(const std::vector<int64_t>& nchw_tz,
                                     const std::vector<int>& axis) {
    const size_t ndims = axis.size();

    std::vector<int64_t> strides(ndims);
    // int64_t, not unsigned int: products of dims overflow 32 bits for
    // tensors with more than 2^32 elements.
    int64_t total_stride = 1;
    for (int i = static_cast<int>(ndims) - 1; i >= 0; --i) {
      strides[axis[i]] = total_stride;
      total_stride *= nchw_tz[axis[i]];
    }
    return dnnl::memory::desc(nchw_tz, OneDNNGetDataType<T>(), strides);
  }

 private:
  const OneDNNContext& dev_ctx_;
  std::vector<int64_t> dims_;
  std::vector<int> axis_;
  std::vector<int> logical_axis_;
  dnnl::engine engine_;
};
} // namespace funcs
} // namespace phi
8 changes: 8 additions & 0 deletions paddle/phi/core/tensor_utils.cc
Expand Up @@ -411,4 +411,12 @@ template void Copy(const CustomContext& dev_ctx,
bool blocking,
DenseTensor* dst);
#endif

#ifdef PADDLE_WITH_MKLDNN
// Explicit instantiation of phi::Copy for the oneDNN device context, so
// oneDNN (MKL-DNN) kernels built in separate TUs can link against it.
template void Copy(const OneDNNContext& dev_ctx,
                   const DenseTensor& src,
                   Place dst_place,
                   bool blocking,
                   DenseTensor* dst);
#endif
} // namespace phi
68 changes: 68 additions & 0 deletions paddle/phi/kernels/onednn/transpose_grad_kernel.cc
@@ -0,0 +1,68 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/transpose_grad_kernel.h"

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
template <typename T, typename Context>
void TransposeGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const std::vector<int>& axis,
DenseTensor* x_grad) {
PADDLE_ENFORCE_EQ(dev_ctx.GetPlace().GetType() == phi::AllocationType::CPU,
true,
errors::PreconditionNotMet(
"Operator DNNL TransposeGrad must use CPUPlace"));
if (!x_grad) return;

const auto& onednn_engine = dev_ctx.GetEngine();
std::vector<int> reversed_axis(axis);
int ndims = axis.size();
if (ndims == 1) {
Copy(dev_ctx, out_grad, out_grad.place(), false, x_grad);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found that phi::Copy and framework::TensorCopy behave inconsistently in the mkldnn scenario, you can still use framework::TensorCopy for the time being, the problem of phi::Copy I will solve

x_grad->set_format(out_grad.format());
return;
}

for (size_t i = 0; i < axis.size(); i++) {
reversed_axis[axis[i]] = i;
}

const T* out_grad_data = out_grad.data<T>();
dev_ctx.template Alloc<T>(x_grad);
auto nchw_tz = vectorize<int64_t>(out_grad.dims());

funcs::TransposeOneDNNHandler<T> handler(
dev_ctx, nchw_tz, reversed_axis, onednn_engine);

auto transpose_src_memory_p = handler.AcquireSrcMemory(
out_grad.format(), funcs::to_void_cast<T>(out_grad_data));
auto transpose_dst_memory_p =
handler.AcquireDstMemory(x_grad, dev_ctx.GetPlace());
auto transpose_p =
handler.AcquireTranspose(transpose_dst_memory_p, transpose_src_memory_p);

auto& astream = OneDNNContext::tls().get_stream();
transpose_p->execute(
astream, *transpose_src_memory_p, *transpose_dst_memory_p);
astream.wait();
}

} // namespace phi

// Register the oneDNN transpose_grad kernel (float only, any layout).
PD_REGISTER_KERNEL(
    transpose_grad, OneDNN, ALL_LAYOUT, phi::TransposeGradKernel, float) {}