Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add paddle.diag API, diag_v2 OP and CUDA kernel #26414

Merged
merged 2 commits into from
Aug 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions paddle/fluid/operators/diag_v2_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/diag_v2_op.h"
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

class DiagV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "diag_v2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "diag_v2");

auto x_dims = ctx->GetInputDim("X");
auto offset = ctx->Attrs().Get<int>("offset");

if (x_dims.size() == 1UL) {
int64_t size = x_dims[0] + std::abs(offset);
ctx->SetOutputDim("Out", {size, size});
} else if (x_dims.size() == 2UL) {
int64_t size;
if (offset >= 0) {
size = std::min(x_dims[0], x_dims[1] - offset);
} else {
size = std::min(x_dims[0] + offset, x_dims[1]);
}
ctx->SetOutputDim("Out", {size});
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The input tensor X's dimensions of DiagV2Op should be either 1 or "
"2, but received %d.",
x_dims.size()));
}
}
};

class DiagV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor. Its shape is either 1-D or 2-D.");
AddOutput("Out", "The output tensor. A square matrix or a vector.");
AddAttr<int>("offset",
"The diagonal offset. A positive value represents "
"superdiagonal, 0 represents the main diagonal, and a "
"negative value represents subdiagonal.")
.SetDefault(0);
AddAttr<float>("padding_value",
"Use this value to fill the area outside the specified "
"diagonal band. Only takes effect when the input is a 1-D "
"Tensor. The default value is 0.")
.SetDefault(0.0f);
AddComment(R"DOC(
If ``x`` is a vector (1-D tensor), a 2-D square tensor whth the elements of ``x`` as the diagonal is returned.

If ``x`` is a matrix (2-D tensor), a 1-D tensor with the diagonal elements of ``x`` is returned.

The argument ``offset`` controls the diagonal offset:

If ``offset`` = 0, it is the main diagonal.

If ``offset`` > 0, it is superdiagonal.

If ``offset`` < 0, it is subdiagonal.
)DOC");
}
};

template <typename DeviceContext, typename T>
class DiagV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* x_data = X->data<T>();
auto x_dims = X->dims();
int offset = context.Attr<int>("offset");
auto* out = context.Output<framework::Tensor>("Out");
T* out_data = out->mutable_data<T>(context.GetPlace());
auto out_dims = out->dims();

int64_t i;
if (x_dims.size() == 1) {
float padding_value = context.Attr<float>("padding_value");
math::SetConstant<DeviceContext, T> set_padding_value;
auto& dev_ctx = context.template device_context<DeviceContext>();
set_padding_value(dev_ctx, out, static_cast<T>(padding_value));

auto x_length = x_dims[0];
const int& x_stride = ComputeStride(0, x_dims);

auto out_stride_0 = ComputeStride(0, out_dims);
auto out_stride_1 = ComputeStride(1, out_dims);
out_data +=
(offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0);

for (i = 0; i < x_length; i++) {
out_data[i * (out_stride_0 + out_stride_1)] = x_data[i * x_stride];
}
} else {
auto out_length = out_dims[0];
const int& x_stride_0 = ComputeStride(0, x_dims);
const int& x_stride_1 = ComputeStride(1, x_dims);

auto out_stride_0 = ComputeStride(0, out_dims);
x_data += (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0);
for (i = 0; i < out_length; i++) {
out_data[i * out_stride_0] = x_data[i * (x_stride_0 + x_stride_1)];
}
}
}
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(
diag_v2, ops::DiagV2Op, ops::DiagV2OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
diag_v2, ops::DiagV2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::DiagV2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::DiagV2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::DiagV2Kernel<paddle::platform::CPUDeviceContext, int64_t>);
122 changes: 122 additions & 0 deletions paddle/fluid/operators/diag_v2_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/diag_v2_op.h"

namespace paddle {
namespace operators {

// Extract the diagonal of a matrix 'x' to a vector 'out'.
template <typename T>
__global__ void ExtractDiagonalKernel(T* out, const T* x, std::ptrdiff_t start,
std::ptrdiff_t size,
const std::ptrdiff_t sumStride,
const std::ptrdiff_t outStride) {
for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
const std::ptrdiff_t xOffset = start + sumStride * idx;
out[outStride * idx] = x[xOffset];
}
}

// Paste a vector 'x' to the diagonal of a matrix 'out'
template <typename T>
__global__ void PasteDiagonalKernel(T* out, const T* x, std::ptrdiff_t start,
std::ptrdiff_t x_length,
const std::ptrdiff_t sumStride,
const std::ptrdiff_t xStride) {
for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < x_length; idx += gridDim.x * blockDim.x) {
const std::ptrdiff_t outOffset = start + sumStride * idx;
out[outOffset] = x[xStride * idx];
}
}

template <typename DeviceContext, typename T>
class DiagV2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* x_data = X->data<T>();
auto x_dims = X->dims();
int offset = context.Attr<int>("offset");
auto* out = context.Output<framework::Tensor>("Out");
T* out_data = out->mutable_data<T>(context.GetPlace());
auto out_dims = out->dims();
auto& dev_ctx = context.template device_context<DeviceContext>();

if (x_dims.size() == 1) {
float padding_value = context.Attr<float>("padding_value");
math::SetConstant<DeviceContext, T> set_padding_value;
set_padding_value(dev_ctx, out, static_cast<T>(padding_value));

auto x_length = x_dims[0];
auto size = (offset > 0) ? x_length + offset : x_length - offset;
const int& x_stride = ComputeStride(0, x_dims);
if (size > 0) {
const int block_num = std::min(static_cast<int>(size),
dev_ctx.GetMaxPhysicalThreadCount());
int size_ = static_cast<int>(size);
int block_num_ = static_cast<int>(block_num);
const int grid_num =
std::min(1024, (size_ + block_num_ - 1) / block_num_);
const auto& out_stride_0 = ComputeStride(0, out_dims);
const auto& out_stride_1 = ComputeStride(1, out_dims);
auto start =
(offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0);

PasteDiagonalKernel<T><<<grid_num, block_num, 0, dev_ctx.stream()>>>(
out_data, x_data, start, x_length, out_stride_0 + out_stride_1,
x_stride);
}
} else {
const int& x_stride_0 = ComputeStride(0, x_dims);
const int& x_stride_1 = ComputeStride(1, x_dims);

int size;
if (offset > 0) {
size = std::min(x_dims[0], x_dims[1] - offset);
} else {
size = std::min(x_dims[0] + offset, x_dims[1]);
}

if (size > 0) {
const int block_num = std::min(static_cast<int>(size),
dev_ctx.GetMaxPhysicalThreadCount());
int size_ = static_cast<int>(size);
int block_num_ = static_cast<int>(block_num);
const int grid_num =
std::min(1024, (size_ + block_num_ - 1) / block_num_);
auto start = (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0);
const auto& out_stride_0 = ComputeStride(0, out_dims);

ExtractDiagonalKernel<T><<<grid_num, block_num, 0, dev_ctx.stream()>>>(
out_data, x_data, start, size, x_stride_0 + x_stride_1,
out_stride_0);
}
}
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
diag_v2, ops::DiagV2CUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::DiagV2CUDAKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::DiagV2CUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::DiagV2CUDAKernel<paddle::platform::CUDADeviceContext, double>);
34 changes: 34 additions & 0 deletions paddle/fluid/operators/diag_v2_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using DDim = framework::DDim;

static inline int ComputeStride(int axis, DDim dims) {
int size = 1;
for (int i = axis + 1; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}

} // namespace operators
} // namespace paddle
2 changes: 2 additions & 0 deletions python/paddle/fluid/layers/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .layer_function_generator import templatedoc
from . import utils
from ..data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
from paddle.utils import deprecated
import numpy
import warnings

Expand Down Expand Up @@ -1537,6 +1538,7 @@ def zeros_like(x, out=None):
return out


@deprecated(since="2.0.0", update_to="paddle.diag")
def diag(diagonal):
"""
:alias_main: paddle.diag
Expand Down
Loading