[API/OP]Add a new API paddle.diagonal #33586

Merged
zhiqiu merged 19 commits into PaddlePaddle:develop from the dev/diagonal branch on Jun 22, 2021

@zhangbo9674 (Contributor) commented Jun 16, 2021

PR types

New features

PR changes

APIs

Describe

This OP computes the diagonals of the input tensor x.

If x is 2D, returns the diagonal.
If x has more than two dimensions, diagonals are taken from the 2D planes specified by axis1 and axis2.
By default, the 2D planes are formed by the first and second axes of the input tensor x.

The argument offset determines where diagonals are taken from input tensor x:
- If offset = 0, it is the main diagonal.
- If offset > 0, it is above the main diagonal.
- If offset < 0, it is below the main diagonal.
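The offset rule can be illustrated with a minimal pure-Python sketch for the 2D case. The helper name `diagonal_2d` is hypothetical and this is not the operator's implementation; it only mirrors the three cases listed above.

```python
def diagonal_2d(matrix, offset=0):
    """Return the diagonal of a 2-D list of lists, shifted by `offset`.

    Hypothetical helper for illustration; not part of paddle.
    """
    rows = len(matrix)
    cols = len(matrix[0]) if rows else 0
    # offset > 0 starts at column `offset` of row 0 (above the main diagonal);
    # offset < 0 starts at row `-offset` of column 0 (below the main diagonal).
    row_start = max(-offset, 0)
    col_start = max(offset, 0)
    # Clamp to 0 so an offset beyond the matrix yields an empty diagonal.
    length = max(min(rows - row_start, cols - col_start), 0)
    return [matrix[row_start + i][col_start + i] for i in range(length)]


m = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]
print(diagonal_2d(m))             # [1, 5, 9]
print(diagonal_2d(m, offset=1))   # [2, 6]
print(diagonal_2d(m, offset=-1))  # [4, 8]
print(diagonal_2d(m, offset=5))   # []
```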

Args:

x(Tensor): The input tensor x. Must be at least 2-dimensional. The input data type should be int32, int64, float32, float64.
offset(int, optional): Which diagonals in input tensor x will be taken. Default: 0 (main diagonals).
axis1(int, optional): The first axis of the 2D planes from which diagonals are taken. Default: 0.
axis2(int, optional): The second axis of the 2D planes from which diagonals are taken. Default: 1.
name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None.

If axis1 or axis2 is out of range for the dimensions of the input Tensor, an OutOfRange error is raised.
If offset lies outside the extent of the 2D plane formed by axis1 and axis2, the new (diagonal) dimension of the output has size 0.
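The two rules above, together with the output shapes shown in the examples, suggest the following shape computation. This is a sketch under stated assumptions: `diagonal_out_shape` is a hypothetical name, only non-negative axes are handled, and Python's `IndexError` stands in for the operator's OutOfRange error.

```python
def diagonal_out_shape(in_shape, offset=0, axis1=0, axis2=1):
    """Compute the output shape of a diagonal extraction.

    Hypothetical helper for illustration; only non-negative axes are handled.
    """
    ndim = len(in_shape)
    if ndim < 2:
        raise ValueError("input must be at least 2-dimensional")
    for name, axis in (("axis1", axis1), ("axis2", axis2)):
        if not 0 <= axis < ndim:
            # Stand-in for the OutOfRange error described above.
            raise IndexError(f"{name}={axis} is out of range for {ndim} dims")
    if axis1 == axis2:
        raise ValueError("axis1 and axis2 must differ")

    size1, size2 = in_shape[axis1], in_shape[axis2]
    if offset >= 0:
        diag_len = min(size1, size2 - offset)
    else:
        diag_len = min(size1 + offset, size2)
    diag_len = max(diag_len, 0)  # offset outside the plane gives size 0

    # Drop the two plane axes, then append the diagonal length as the last dim.
    out_shape = [d for i, d in enumerate(in_shape) if i not in (axis1, axis2)]
    out_shape.append(diag_len)
    return out_shape


print(diagonal_out_shape([2, 2, 3]))                            # [3, 2]
print(diagonal_out_shape([2, 2, 3], offset=1))                  # [3, 1]
print(diagonal_out_shape([2, 2, 3], offset=0, axis1=2, axis2=1))  # [2, 2]
print(diagonal_out_shape([2, 2, 3], offset=5))                  # [3, 0]
```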

Returns:
Tensor: the output data type is the same as input data type.

Examples:

import paddle

x = paddle.rand([2,2,3],'float32')
print(x)
# Tensor(shape=[2, 2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
#        [[[0.45661032, 0.03751532, 0.90191704],
#          [0.43760979, 0.86177313, 0.65221709]],

#         [[0.17020577, 0.00259554, 0.28954273],
#          [0.51795638, 0.27325270, 0.18117726]]])

out1 = paddle.diagonal(x)
print(out1)
#Tensor(shape=[3, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
#       [[0.45661032, 0.51795638],
#        [0.03751532, 0.27325270],
#        [0.90191704, 0.18117726]])

out2 = paddle.diagonal(x, offset=0, axis1=2, axis2=1)
print(out2)
#Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
#       [[0.45661032, 0.86177313],
#        [0.17020577, 0.27325270]])

out3 = paddle.diagonal(x, offset=1, axis1=0, axis2=1)
print(out3)
#Tensor(shape=[3, 1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
#       [[0.43760979],
#        [0.86177313],
#        [0.65221709]])

out4 = paddle.diagonal(x, offset=0, axis1=1, axis2=2)
print(out4)
#Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
#       [[0.45661032, 0.86177313],
#        [0.17020577, 0.27325270]])
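As a cross-check of the printed outputs above, here is a pure-Python sketch that extracts diagonals from nested lists. The helpers `shape_of`, `get_item` and `diagonal_nd` are hypothetical, handle only non-negative axes, and are not how the operator is implemented; the input values are copied from the example.

```python
def shape_of(nested):
    """Shape of a regular nested list."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return shape


def get_item(nested, index):
    """Index a nested list with a sequence of integers."""
    for i in index:
        nested = nested[i]
    return nested


def diagonal_nd(x, offset=0, axis1=0, axis2=1):
    """Diagonals of the (axis1, axis2) planes; the diagonal is the last output dim."""
    shape = shape_of(x)
    rest_axes = [a for a in range(len(shape)) if a not in (axis1, axis2)]
    start1, start2 = max(-offset, 0), max(offset, 0)
    length = max(min(shape[axis1] - start1, shape[axis2] - start2), 0)

    def build(prefix):
        if len(prefix) == len(rest_axes):
            row = []
            for i in range(length):
                index = [0] * len(shape)
                for axis, pos in zip(rest_axes, prefix):
                    index[axis] = pos
                index[axis1] = start1 + i
                index[axis2] = start2 + i
                row.append(get_item(x, index))
            return row
        size = shape[rest_axes[len(prefix)]]
        return [build(prefix + [pos]) for pos in range(size)]

    return build([])


x = [[[0.45661032, 0.03751532, 0.90191704],
      [0.43760979, 0.86177313, 0.65221709]],
     [[0.17020577, 0.00259554, 0.28954273],
      [0.51795638, 0.27325270, 0.18117726]]]

print(diagonal_nd(x))                             # same values as out1
print(diagonal_nd(x, offset=0, axis1=2, axis2=1))  # same values as out2
print(diagonal_nd(x, offset=1, axis1=0, axis2=1))  # same values as out3
```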

API doc:
[two images, not preserved]

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
Contributor

Better use OP_INOUT_CHECK

Contributor Author

Done, thanks!

platform::errors::NotFound("Output of DiagonalOp is not found."));

int offset_ = ctx->Attrs().Get<int>("offset");
int dim1 = ctx->Attrs().Get<int>("axis1");
Contributor

Maybe dim1->axis1 is better, to be consistent with op's attr.

Contributor Author

Done, thanks!

PADDLE_ENFORCE_NE(dim1_, dim2_,
platform::errors::InvalidArgument(
"The dimensions should not be identical "
"%ld vs %ld.",
Contributor

%d for int

Contributor Author

Done, thanks!

Comment on lines 65 to 68
auto dim1_size = out_dims[dim1_];
auto dim2_size = out_dims[dim2_];
out_dims.erase(out_dims.begin() + std::max(dim1_, dim2_));
out_dims.erase(out_dims.begin() + std::min(dim1_, dim2_));
Contributor

add some comments

Contributor Author

Done, thanks!

"(Tensor) The partial view of input with the its diagonal elements.");
AddAttr<int>(
"offset",
R"DOC((int, default 0), offset of the diagonal from the main diagonal. Can be both positive and negative. Defaults: 0.
Contributor

Defaults -> Default

Contributor Author

Done, thanks!

@@ -0,0 +1,362 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Contributor

2019 -> 2021

Contributor Author

Done, thanks!

Comment on lines 92 to 96
cudaMalloc(reinterpret_cast<void**>(&input_stride),
input_stride_size * sizeof(int64_t));
cudaMemcpy(reinterpret_cast<void*>(input_stride),
reinterpret_cast<void*>(host_input_stride),
input_stride_size * sizeof(int64_t), cudaMemcpyHostToDevice);
Contributor

Do not call cudaMemcpy/cudaMalloc directly in an op kernel. First, use as little memory as possible: do not allocate extra memory if it can be avoided by refining the algorithm. Second, if temporary memory is really needed, use a temporary Tensor and call tensor.mutable_data.

Contributor Author

Replaced cudaMalloc/cudaMemcpy with TensorFromVector in the CUDA kernel and the CUDA grad kernel.
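The snippet under review copies the input strides to the device. To show why a diagonal kernel would want them, here is a Python sketch of a stride-based gather over a flat row-major buffer. It is not the PR's CUDA kernel: `row_major_strides` and `diagonal_flat` are hypothetical names, only non-negative axes are handled, and the loop body is what one GPU thread per output element might compute.

```python
def row_major_strides(shape):
    """Element strides of a contiguous row-major array."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def diagonal_flat(data, shape, offset=0, axis1=0, axis2=1):
    """Gather diagonals from a flat buffer; returns (flat_output, out_shape)."""
    in_strides = row_major_strides(shape)
    start1, start2 = max(-offset, 0), max(offset, 0)
    diag_len = max(min(shape[axis1] - start1, shape[axis2] - start2), 0)
    rest_axes = [a for a in range(len(shape)) if a not in (axis1, axis2)]
    out_shape = [shape[a] for a in rest_axes] + [diag_len]

    # One step along the diagonal moves one element along axis1 AND axis2.
    gather_strides = [in_strides[a] for a in rest_axes]
    gather_strides.append(in_strides[axis1] + in_strides[axis2])
    # Input position of the first diagonal element.
    base = start1 * in_strides[axis1] + start2 * in_strides[axis2]

    out_strides = row_major_strides(out_shape)
    out_size = 1
    for dim in out_shape:
        out_size *= dim

    out = []
    for flat in range(out_size):  # one iteration per output element
        src, rem = base, flat
        for o_stride, g_stride in zip(out_strides, gather_strides):
            src += (rem // o_stride) * g_stride
            rem %= o_stride
        out.append(data[src])
    return out, out_shape


data = list(range(12))  # a [2, 2, 3] tensor holding 0..11
print(diagonal_flat(data, [2, 2, 3]))            # ([0, 9, 1, 10, 2, 11], [3, 2])
print(diagonal_flat(data, [2, 2, 3], offset=1))  # ([3, 4, 5], [3, 1])
```

Because the gather needs only a handful of integers per tensor, the strides can live in a small temporary tensor rather than in memory allocated by hand, which is the direction the review comment points in.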

@zhangbo9674 zhangbo9674 changed the title Add a new API paddle.diagonal [API/OP]Add a new API paddle.diagonal Jun 18, 2021
@@ -355,6 +356,8 @@
'shape',
'real',
'imag',
'digamma',
Contributor

This line should be deleted; it duplicates line 362.

Contributor Author

Done, thanks!

# [[0.50427347, 0.78351408, 0.00833563],
# [0.36932808, 0.83527362, 0.49408615]]])

out = paddle.diagonal(x)
Contributor

The following examples need to be added:

  • an example with axis1 = 1
  • an example with offset != 0 and axis1 = 1
  • an example with offset != 0, axis1 = 1, and axis2 != 1

Contributor Author

Done, thanks!

name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None.

Returns:
Tensor: the output data type is the same as input data type.
Contributor

This is missing the sentence "a partial view of the input Tensor in the specified 2D planes".

Contributor Author

Done, thanks!

if in_dygraph_mode():
return core.ops.diagonal(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2)

if not in_dygraph_mode():
Contributor

Is this line of code redundant?

Contributor Author

Deleted, thanks!


"""
inputs = {'Input': [x]}
attrs = {'offset': offset, 'axis1': axis1, 'axis2': axis2}
Contributor

Why are these two operations placed before the dynamic graph call?

Contributor Author

That code served no purpose and has been deleted, thanks!

- If offset < 0, it is below the main diagonal.

Args:
x(Tensor): The input tensor x. Must be at least 2-dimensional. The input data type should be int32, int64, float32, float64.
Contributor

Consider supporting more types, such as bool and fp16.

Contributor Author

Support added, thanks!

- If offset < 0, it is below the main diagonal.

Args:
x(Tensor): The input tensor x. Must be at least 2-dimensional. The input data type should be int32, int64, float32, float64, float16, bool.
Contributor

Suggested change
- x(Tensor): The input tensor x. Must be at least 2-dimensional. The input data type should be int32, int64, float32, float64, float16, bool.
+ x(Tensor): The input tensor x. Must be at least 2-dimensional. The input data type should be bool, int32, int64, float16, float32, float64.

Contributor Author

Done, thanks!

@lanxianghit (Contributor) previously approved these changes Jun 22, 2021

LGTM

@TCChenlong (Contributor) left a comment

LGTM

@zhiqiu (Contributor) left a comment

LGTM

@zhiqiu zhiqiu merged commit ad10629 into PaddlePaddle:develop Jun 22, 2021
@zhangbo9674 zhangbo9674 deleted the dev/diagonal branch September 14, 2022 02:22

4 participants