Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PHI] transpose2_grad op migration #46139

Merged
merged 22 commits into from Oct 10, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
ebd80bf
op migrated, Copy(OneDNNContext, ...) added
paulinagacek Sep 16, 2022
490e6dc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 16, 2022
addde63
mutable_data & op registration in fluid removed
paulinagacek Sep 19, 2022
35ec9b7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 19, 2022
28ffdb7
Merge branch 'transpose2_grad_op_migration' of https://github.com/Pau…
paulinagacek Sep 19, 2022
78f8af6
refactoring
paulinagacek Sep 20, 2022
8501e7f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 20, 2022
fbe2c5d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 20, 2022
bd49216
OneDNNGetDataType to uppercase
paulinagacek Sep 20, 2022
fa1369b
missing cpu check added, handler moved to .h file
paulinagacek Sep 21, 2022
d2603a1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 21, 2022
2cbe8f4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 22, 2022
d566a95
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 22, 2022
d33d261
name changed to transpose_grad
paulinagacek Sep 27, 2022
6d8de58
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 27, 2022
372b367
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 28, 2022
cf4eaf0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 28, 2022
131bdc3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 29, 2022
6a344c8
Copy changed back to TensorCopy
paulinagacek Sep 30, 2022
e70eb82
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Sep 30, 2022
9470a6c
Resizing corrected, Copy(OneDNNContext) removed
paulinagacek Oct 3, 2022
14881c2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
paulinagacek Oct 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 0 additions & 5 deletions paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
Expand Up @@ -224,8 +224,3 @@ REGISTER_OP_KERNEL(transpose_grad,
MKLDNN,
::paddle::platform::CPUPlace,
ops::TransposeMKLDNNGradOpKernel<float>);

// Registers ops::TransposeMKLDNNGradOpKernel<float> as the MKLDNN kernel of
// the fluid op `transpose2_grad` for ::paddle::platform::CPUPlace.
REGISTER_OP_KERNEL(transpose2_grad,
MKLDNN,
::paddle::platform::CPUPlace,
ops::TransposeMKLDNNGradOpKernel<float>);
8 changes: 8 additions & 0 deletions paddle/phi/core/tensor_utils.cc
Expand Up @@ -411,4 +411,12 @@ template void Copy(const CustomContext& dev_ctx,
bool blocking,
DenseTensor* dst);
#endif

// Explicit instantiation of the Copy function template for OneDNNContext.
// It is only compiled when PADDLE_WITH_MKLDNN is defined.
#ifdef PADDLE_WITH_MKLDNN
template void Copy(const OneDNNContext& dev_ctx,
const DenseTensor& src,
Place dst_place,
bool blocking,
DenseTensor* dst);
#endif
} // namespace phi
135 changes: 135 additions & 0 deletions paddle/phi/kernels/onednn/transpose2_grad.cc
@@ -0,0 +1,135 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/transpose_grad_kernel.h"

namespace phi {
namespace funcs {

template <typename T>
class TransposeOneDNNHandler {
public:
TransposeOneDNNHandler(const OneDNNContext& dev_ctx,
std::vector<int64_t>& dims, // NOLINT
std::vector<int>& axis, // NOLINT
dnnl::engine engine)
: dev_ctx_(dev_ctx),
dims_(dims),
axis_(axis),
logical_axis_(dims.size(), 0),
engine_(engine) {}

std::shared_ptr<dnnl::memory> AcquireSrcMemory(
const phi::funcs::OneDNNMemoryFormat& fmt, void* ptr) {
// Make memory descriptor using input format, unless it
// cannot be trusted (nchw) then make up memory fmt manually
for (size_t i = 0; i < this->logical_axis_.size(); ++i) {
this->logical_axis_[i] = i;
}

auto src_md = fmt != phi::funcs::OneDNNMemoryFormat::nchw
? OneDNNMemDesc(dims_, oneDNNGetDataType<T>(), fmt)
: Axis2MemoryDesc(dims_, logical_axis_);
return std::make_shared<dnnl::memory>(src_md, engine_, ptr);
}

std::shared_ptr<dnnl::memory> AcquireDstMemory(DenseTensor* output,
phi::Place place) {
auto dst_md = Axis2MemoryDesc(dims_, axis_);
output->Resize(make_ddim(dims_));
auto dst_data = dev_ctx_.Alloc<T>(output);

return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
}

std::shared_ptr<dnnl::reorder> AcquireTranspose(
std::shared_ptr<dnnl::memory> dst_memory_p,
std::shared_ptr<dnnl::memory> src_memory_p) {
return std::make_shared<dnnl::reorder>(*(src_memory_p), *(dst_memory_p));
}

protected:
dnnl::memory::desc Axis2MemoryDesc(std::vector<int64_t>& nchw_tz, // NOLINT
std::vector<int>& axis // NOLINT
) {
size_t ndims = axis.size();

std::vector<int64_t> strides(ndims);
unsigned int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
strides[axis[i]] = total_stride;
total_stride *= nchw_tz[axis[i]];
}
dnnl::memory::desc mem_d(
nchw_tz, phi::funcs::oneDNNGetDataType<T>(), strides);

return mem_d;
}

private:
const OneDNNContext& dev_ctx_;
std::vector<int64_t> dims_;
std::vector<int> axis_;
std::vector<int> logical_axis_;
dnnl::engine engine_;
};
} // namespace funcs

// oneDNN implementation of the transpose gradient.
//
// The gradient of a transpose by `axis` is a transpose of `out_grad` by the
// inverse permutation of `axis`. The result is written into `x_grad`; nothing
// is done when `x_grad` is null.
//
// dev_ctx  - device context providing the oneDNN engine and allocation.
// out_grad - gradient with respect to the forward output.
// axis     - permutation used by the forward transpose; negative entries are
//            interpreted relative to the rank (axis + rank).
// x_grad   - output tensor receiving the gradient with respect to the input.
template <typename T, typename Context>
void TransposeGradKernel(const Context& dev_ctx,
                         const DenseTensor& out_grad,
                         const std::vector<int>& axis,
                         DenseTensor* x_grad) {
  if (!x_grad) return;

  const int ndims = static_cast<int>(axis.size());
  if (ndims == 1) {
    // A one-dimensional transpose is the identity: just copy the data over.
    phi::Copy(dev_ctx, out_grad, out_grad.place(), false, x_grad);
    x_grad->set_format(out_grad.format());
    return;
  }

  // Build the inverse permutation. Negative axes are normalised first, since
  // using them directly as an index would write outside of `reversed_axis`.
  std::vector<int> reversed_axis(axis.size(), 0);
  for (int i = 0; i < ndims; ++i) {
    const int normalized_axis = axis[i] < 0 ? axis[i] + ndims : axis[i];
    reversed_axis[normalized_axis] = i;
  }

  const auto& onednn_engine = dev_ctx.GetEngine();
  const T* out_grad_data = out_grad.data<T>();
  dev_ctx.template Alloc<T>(x_grad);
  auto nchw_tz = phi::vectorize<int64_t>(out_grad.dims());

  phi::funcs::TransposeOneDNNHandler<T> handler(
      dev_ctx, nchw_tz, reversed_axis, onednn_engine);

  // The source memory wraps the out_grad buffer; the cast only strips const.
  auto transpose_src_memory_p = handler.AcquireSrcMemory(
      out_grad.format(), phi::funcs::to_void_cast<T>(out_grad_data));
  auto transpose_dst_memory_p =
      handler.AcquireDstMemory(x_grad, dev_ctx.GetPlace());
  auto transpose_p =
      handler.AcquireTranspose(transpose_dst_memory_p, transpose_src_memory_p);

  // Run the reorder on the thread-local stream and block until it finishes.
  auto& astream = phi::OneDNNContext::tls().get_stream();
  transpose_p->execute(
      astream, *transpose_src_memory_p, *transpose_dst_memory_p);
  astream.wait();
}

} // namespace phi

// Registers phi::TransposeGradKernel under the kernel name `transpose2_grad`
// for the OneDNN backend with ALL_LAYOUT and the float data type. The
// trailing braces are the (empty) body of the registration macro.
PD_REGISTER_KERNEL(
transpose2_grad, OneDNN, ALL_LAYOUT, phi::TransposeGradKernel, float) {}