add cumsum prim backward (#50565)
* add cumsum prim backward

* skip axis=None test case

* fix op generate error

* fix static test error

* remove unused code

* fix static test error

* skip cpu float16 test case

* skip eager cpu cumsum float16 test case

* add cinn test

* reshape flatten out

* Disable cinn single test

* remove cinn test

* reformat todo

* add prim in cumsum op test

* remove old test

* fix typo

* fix typo

* fix typo

* pass axis=None test case

* remove forward prim test

* remove same name axis
GGBond8488 committed Feb 28, 2023
1 parent 16a1b4a commit ca2b609
Showing 6 changed files with 102 additions and 10 deletions.
25 changes: 25 additions & 0 deletions paddle/fluid/operators/cum_op.cc
@@ -15,6 +15,9 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"

@@ -100,6 +103,27 @@ class CumsumGradMaker : public framework::SingleGradOpMaker<T> {
  }
};

// Composite (primitive-op) backward for cumsum, used by the static graph when
// prim is enabled.
class CumsumCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

 public:
  void Apply() override {
    paddle::experimental::Tensor x = this->GetSingleForwardInput("X");
    paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
    paddle::experimental::Tensor dx = this->GetSingleInputGrad("X");
    auto* dx_ptr = this->GetOutputPtr(&dx);
    std::string dx_name = this->GetOutputName(dx);
    int axis = static_cast<int>(this->Attr<int>("axis"));
    bool flatten = static_cast<bool>(this->Attr<bool>("flatten"));
    bool exclusive = static_cast<bool>(this->Attr<bool>("exclusive"));
    bool reverse = static_cast<bool>(this->Attr<bool>("reverse"));
    VLOG(6) << "Running cumsum_grad composite func";
    prim::cumsum_grad<prim::DescTensor>(
        x, out_grad, axis, flatten, exclusive, reverse, dx_ptr);
    this->RecoverOutputName(dx, dx_name);
  }
};

class LogcumsumexpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
@@ -182,6 +206,7 @@ DECLARE_INFER_SHAPE_FUNCTOR(logcumsumexp,
REGISTER_OPERATOR(cumsum,
                  ops::CumOp,
                  ops::CumsumOpMaker,
                  ops::CumsumCompositeGradOpMaker,
                  ops::CumsumGradMaker<paddle::framework::OpDesc>,
                  ops::CumsumGradMaker<paddle::imperative::OpBase>,
                  CumsumInferShapeFunctor);
1 change: 1 addition & 0 deletions paddle/fluid/prim/api/api.yaml
@@ -25,3 +25,4 @@
- tile
- transpose
- pad
- cumsum
15 changes: 15 additions & 0 deletions paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -414,5 +414,20 @@ void slice_grad(const Tensor& input,
  }
}

template <typename T>
void cumsum_grad(const Tensor& x,
                 const Tensor& out_grad,
                 const Scalar& axis,
                 bool flatten,
                 bool exclusive,
                 bool reverse,
                 Tensor* x_grad) {
  if (x_grad) {
    // The gradient of cumsum is a cumsum of the upstream gradient taken in
    // the opposite direction, hence `!reverse`; the reshape restores the
    // original shape when `flatten` collapsed the input to 1-D.
    auto grad = cumsum<T>(out_grad, axis, flatten, exclusive, !reverse);
    grad = reshape<T>(grad, x.shape());
    set_output<T>(grad, x_grad);
  }
}

} // namespace prim
} // namespace paddle
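
For intuition, here is a minimal NumPy sketch (illustration only, not part of this change) of the identity the composite rule above relies on: for an inclusive, non-reversed cumsum, the vector-Jacobian product is the cumsum of the upstream gradient taken in the opposite direction, which is why cumsum_grad flips `reverse` and then reshapes the result back to x's shape.

import numpy as np

n = 5
g = np.random.rand(n)            # upstream gradient w.r.t. y = cumsum(x)

# Explicit Jacobian of y = cumsum(x): J[i, j] = 1 if j <= i else 0.
J = np.tril(np.ones((n, n)))
vjp_explicit = J.T @ g           # dL/dx = J^T g

# Same quantity via a reversed cumsum, mirroring cumsum<T>(out_grad, ..., !reverse).
vjp_reversed = np.flip(np.cumsum(np.flip(g)))

assert np.allclose(vjp_explicit, vjp_reversed)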
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -313,6 +313,7 @@
  kernel :
    func : cumsum_grad
    data_type: x
  composite: cumsum_grad(x, out_grad, axis, flatten, exclusive, reverse, x_grad)

- backward_op : deformable_conv_grad
  forward : deformable_conv(Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) -> Tensor(out)
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -365,6 +365,17 @@
  outputs :
    out : Out

- op : cumsum
  backward: cumsum_grad
  inputs :
    x : X
  outputs :
    out : Out
  scalar:
    axis:
      data_type : int
      tensor_name: AxisTensor

- op : data_norm
  backward : data_norm_grad
  extra :
59 changes: 49 additions & 10 deletions python/paddle/fluid/tests/unittests/test_cumsum_op.py
@@ -115,6 +115,9 @@ def test_name(self):
class TestSumOp1(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.prim_op_type = "prim"
        self.python_api = paddle.cumsum
        self.enable_cinn = False
        self.attrs = {'axis': 2}
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
        self.outputs = {'Out': self.inputs['X'].cumsum(axis=2)}
@@ -123,12 +126,15 @@ def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
        self.check_grad(['X'], 'Out', check_prim=True)


class TestSumOp2(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.prim_op_type = "prim"
        self.python_api = paddle.cumsum
        self.enable_cinn = False
        self.attrs = {'axis': -1, 'reverse': True}
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
        self.outputs = {
@@ -141,12 +147,15 @@ def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
        self.check_grad(['X'], 'Out', check_prim=True)


class TestSumOp3(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.prim_op_type = "prim"
        self.python_api = paddle.cumsum
        self.enable_cinn = False
        self.attrs = {'axis': 1}
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
        self.outputs = {'Out': self.inputs['X'].cumsum(axis=1)}
@@ -155,12 +164,15 @@ def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
        self.check_grad(['X'], 'Out', check_prim=True)


class TestSumOp4(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.prim_op_type = "prim"
        self.python_api = paddle.cumsum
        self.enable_cinn = False
        self.attrs = {'axis': 0}
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
        self.outputs = {'Out': self.inputs['X'].cumsum(axis=0)}
@@ -175,6 +187,9 @@ def test_check_grad(self):
class TestSumOp5(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.prim_op_type = "prim"
        self.python_api = paddle.cumsum
        self.enable_cinn = False
        self.inputs = {'X': np.random.random((5, 20)).astype("float64")}
        self.outputs = {'Out': self.inputs['X'].cumsum(axis=1)}

@@ -188,6 +203,9 @@ def test_check_grad(self):
class TestSumOp7(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.prim_op_type = "prim"
        self.python_api = paddle.cumsum
        self.enable_cinn = False
        self.inputs = {'X': np.random.random((100)).astype("float64")}
        self.outputs = {'Out': self.inputs['X'].cumsum(axis=0)}

@@ -226,6 +244,9 @@ def test_main(self):
class TestSumOpExclusive1(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.prim_op_type = "prim"
        self.python_api = paddle.cumsum
        self.enable_cinn = False
        self.attrs = {'axis': 2, "exclusive": True}
        a = np.random.random((4, 5, 20)).astype("float64")
        self.inputs = {'X': a}
@@ -243,12 +264,15 @@ def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
        self.check_grad(['X'], 'Out', check_prim=True)


class TestSumOpExclusive2(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.prim_op_type = "prim"
        self.python_api = paddle.cumsum
        self.enable_cinn = False
        self.attrs = {'axis': 2, "exclusive": True}
        a = np.random.random((1, 1, 100)).astype("float64")
        self.inputs = {'X': a}
@@ -266,12 +290,15 @@ def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
        self.check_grad(['X'], 'Out', check_prim=True)


class TestSumOpExclusive3(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.prim_op_type = "prim"
        self.python_api = paddle.cumsum
        self.enable_cinn = False
        self.attrs = {'axis': 2, "exclusive": True}
        a = np.random.random((4, 5, 20)).astype("float64")
        self.inputs = {'X': a}
@@ -289,12 +316,15 @@ def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
        self.check_grad(['X'], 'Out', check_prim=True)


class TestSumOpExclusive4(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.prim_op_type = "prim"
        self.python_api = paddle.cumsum
        self.enable_cinn = False
        self.attrs = {'axis': 2, "exclusive": True}
        a = np.random.random((1, 1, 100)).astype("float64")
        self.inputs = {'X': a}
@@ -312,12 +342,15 @@ def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
        self.check_grad(['X'], 'Out', check_prim=True)


class TestSumOpExclusive5(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.prim_op_type = "prim"
        self.python_api = paddle.cumsum
        self.enable_cinn = False
        self.attrs = {'axis': 2, "exclusive": True}
        a = np.random.random((4, 5, 40)).astype("float64")
        self.inputs = {'X': a}
@@ -335,12 +368,15 @@ def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
        self.check_grad(['X'], 'Out', check_prim=True)


class TestSumOpExclusiveFP16(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.prim_op_type = "prim"
        self.python_api = paddle.cumsum
        self.enable_cinn = False
        self.attrs = {'axis': 2, "exclusive": True, "dtype": "float16"}
        a = np.random.random((4, 5, 20)).astype("float64")
        self.inputs = {'X': a}
@@ -358,12 +394,15 @@ def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
        self.check_grad(['X'], 'Out', check_prim=True)


class TestSumOpReverseExclusive(OpTest):
    def setUp(self):
        self.op_type = "cumsum"
        self.prim_op_type = "prim"
        self.python_api = paddle.cumsum
        self.enable_cinn = False
        self.attrs = {'axis': 2, 'reverse': True, "exclusive": True}
        a = np.random.random((4, 5, 6)).astype("float64")
        self.inputs = {'X': a}
@@ -382,7 +421,7 @@ def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
        self.check_grad(['X'], 'Out', check_prim=True)


class BadInputTest(unittest.TestCase):
