From 901b041196f006cd1fc4775a87849e6e716b6c62 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 11 Oct 2017 23:09:45 +0800
Subject: [PATCH 01/12] Add seq_expand op

1. Add unit test
2. Add SeqExpandOpKernel
---
 paddle/operators/seq_expand_op.cc                    | 125 ++++++++++++++++++
 paddle/operators/seq_expand_op.cu                    |  23 ++++
 paddle/operators/seq_expand_op.h                     |  83 ++++++++++++
 .../v2/framework/tests/test_seq_expand.py            |  61 +++++++++
 4 files changed, 292 insertions(+)
 create mode 100644 paddle/operators/seq_expand_op.cc
 create mode 100644 paddle/operators/seq_expand_op.cu
 create mode 100644 paddle/operators/seq_expand_op.h
 create mode 100644 python/paddle/v2/framework/tests/test_seq_expand.py

diff --git a/paddle/operators/seq_expand_op.cc b/paddle/operators/seq_expand_op.cc
new file mode 100644
index 0000000000000..894ba3f6b70f5
--- /dev/null
+++ b/paddle/operators/seq_expand_op.cc
@@ -0,0 +1,125 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/seq_expand_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class SeqExpandOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SeqExpandOp should not be null.");
+    int repeat = ctx->Attrs().Get<int>("repeat");
+    DDim out_dim;
+    if (repeat == 0) {
+      PADDLE_ENFORCE(
+          ctx->HasInput("Y"),
+          "Input(Y) of SeqExpandOp should not be null while repeat == 0.");
+      out_dim = ctx->GetInputDim("Y");
+      ctx->ShareLoD("Y", "Out");
+    } else {
+      out_dim = ctx->GetInputDim("X");
+      out_dim[0] = out_dim[0] * repeat;
+      ctx->SetOutputDim("Out", y_dim);
+    }
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of PadOp should not be null.");
+    ctx->SetOutputDim("Out", out_dim);
+  }
+};
+
+class SeqExpandOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SeqExpandOpMaker(framework::OpProto* proto,
+                   framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    // TODO(wanghaoshuang): Add more comments
+    AddInput("X", "The input('X') of seq_expand op.");
+    AddInput("Y", "The reference input('Y') of seq_expand op.");
+    AddOutput("Out", "The output of seq_expand op.");
+    AddAttr<int>("repeat", "repeat times").SetDefault(0);
+    AddComment(R"DOC(
+As an example:
+
+Given:
+
+X = [1, 2 , 3]
+
+and
+
+repeat = 2
+
+
+then we get
+
+Out.data = [1, 1, 2, 2, 3, 3]
+Out.lod = [[0, 2, 4, 6]]
+
+)DOC");
+  }
+};
+
+class SeqExpandOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null");
+    auto x_dims = ctx->GetInputDim("X");
+    auto x_grad_name = framework::GradVarName("X");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
+  }
+};
+
+class SeqExpandOpGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDescBind> Apply() const override {
+    auto* bind = new framework::OpDescBind();
+    bind->SetInput("X", Input("X"));
+    bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    bind->SetAttrMap(Attrs());
+    bind->SetType("seq_expand_grad");
+    return std::unique_ptr<framework::OpDescBind>(bind);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker,
+                  ops::SeqExpandOpGradMaker);
+REGISTER_OPERATOR(seq_expand_grad, ops::SeqExpandOpGrad);
+REGISTER_OP_CPU_KERNEL(seq_expand,
+                       ops::SeqExpandKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    seq_expand_grad,
+    ops::SeqExpandGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/seq_expand_op.cu b/paddle/operators/seq_expand_op.cu
new file mode 100644
index 0000000000000..f1e4b82a76e62
--- /dev/null
+++ b/paddle/operators/seq_expand_op.cu
@@ -0,0 +1,23 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/seq_expand_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(seq_expand,
+                       ops::SeqExpandKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    seq_expand_grad,
+    ops::SeqExpandGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/seq_expand_op.h b/paddle/operators/seq_expand_op.h
new file mode 100644
index 0000000000000..80076dc35fe60
--- /dev/null
+++ b/paddle/operators/seq_expand_op.h
@@ -0,0 +1,83 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+
+#include "hl_cuda.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+using LoD = paddle::framework::LoD;
+
+template <typename Place, typename T>
+class SeqExpandKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<LoDTensor>("X");
+    auto* out = context.Output<LoDTensor>("Out");
+    const T* x_data = x->data<T>();
+    T* out_data = out->mutable_data<T>(context.GetPlace());
+    size_t repeat = static_cast<size_t>(context.Attr<int>("repeat"));
+
+    if (repeat != 0) {
+      if (x->lod().size() == 0) {
+        std::vector<size_t> level0(x->dims()[0]);
+        for (size_t i = 0; i <= x->dims()[0]; i++) {
+          level0.push_back(i * repeat);
+        }
+        const LoD out_lod;
+        out_lod.push_back(level0);
+        out->set_lod(out_lod);
+      }
+    }
+    auto out_dim = out->dims();
+    size_t element_len = framework::product(out_dim) / out_dim[0];
+    std::vector<int> cpy_map(out_dim[0]);
+    if (x->lod().size() == 0) {
+      auto lod = out->lod();
+      for (int i = 0; i < lod.size() - 1; ++i) {
+        for (int j = lod[0][i]; i < lod[0][i + 1]; ++j) {
+          cpy_map[j] = i;
+        }
+      }
+    }
+    if (paddle::platform::CPUPlace() == Place) {
+      for (int i = 0; i < out_dim[0]; ++i) {
+        memcpy(out_data + element_len * i, x_data + element_len * cpy_map[i],
+               sizeof(T) * element_len);
+      }
+    } else {
+      for (int i = 0; i < out_dim[0]; ++i) {
+        hl_memcpy(out_data + element_len * i, x_data + element_len * cpy_map[i],
+                  sizeof(T) * element_len);
+      }
+    }
+  }
+};
+
+template <typename Place, typename T>
+class SeqExpandGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    // auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
+    // auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
+    // d_x->mutable_data<T>(context.GetPlace());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_seq_expand.py b/python/paddle/v2/framework/tests/test_seq_expand.py
new file mode 100644
index 0000000000000..4608d3c3bd6f3
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_seq_expand.py
@@ -0,0 +1,61 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestSeqExpand(OpTest):
+    #class TestSeqExpand():
+    def set_data(self):
+        self.op_type = 'seq_expand'
+        x = np.random.uniform(0.1, 1, [3, 2, 2]).astype('float32')
+        y = np.zeros((6, 2, 2)).astype('float32')
+        lod = [[0, 2, 3, 6]]
+        print "x = %s" % x
+        self.inputs = {'X': x, 'Y': (y, lod)}
+        self.repeat = None
+
+    def compute(self):
+        x = self.inputs['X']
+        cpy_map = {}
+        lod = []
+        out_shape = []
+        if self.repeat:
+            level0 = []
+            for i in range(x.shape[0] + 1):
+                level0.append(i * self.repeat)
+            lod.append(level0)
+
+            for i in x.shape:
+                out_shape.append(i)
+            out_shape[0] = out_shape[0] * self.repeat
+        else:
+            y, lod = self.inputs['Y']
+            out_shape = y.shape
+        out = np.zeros(out_shape).astype('float32')
+
+        start = 0
+
+        for i in range(len(lod[0]) - 1):
+            for j in range(lod[0][i], lod[0][i + 1]):
+                cpy_map[j] = i
+        print "cpy_map = %s" % cpy_map
+        for i in range(len(out)):
+            out[i] = x[cpy_map[i]]
+
+        print "out = %s" % out
+        self.outputs = {'Out': (out, lod)}
+
+    def setUp(self):
+        self.set_data()
+        self.compute()
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+
+if __name__ == '__main__':
+    unittest.main()
+# TestSeqExpand().setUp()
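The forward behaviour this first commit aims at can be checked against a few lines of plain Python (an illustrative sketch of the intended semantics with the `repeat` attribute only, matching the DOC example above; it is not the kernel's actual code path):

    def seq_expand_repeat(x, repeat):
        # Every row of x is copied `repeat` times in place; the output LoD
        # records one segment of `repeat` rows per original row.
        out = [row for row in x for _ in range(repeat)]
        lod = [[i * repeat for i in range(len(x) + 1)]]
        return out, lod

    # seq_expand_repeat([1, 2, 3], 2) == ([1, 1, 2, 2, 3, 3], [[0, 2, 4, 6]])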
From acd1aaea49e749a8d402bd6f744f2ca5f3de6020 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Thu, 12 Oct 2017 00:21:41 +0800
Subject: [PATCH 02/12] Fix issues

---
 paddle/operators/seq_expand_op.cc |  3 +--
 paddle/operators/seq_expand_op.h  | 10 +++++-----
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/paddle/operators/seq_expand_op.cc b/paddle/operators/seq_expand_op.cc
index 894ba3f6b70f5..63b17a10f59d0 100644
--- a/paddle/operators/seq_expand_op.cc
+++ b/paddle/operators/seq_expand_op.cc
@@ -28,7 +28,7 @@ class SeqExpandOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of SeqExpandOp should not be null.");
     int repeat = ctx->Attrs().Get<int>("repeat");
-    DDim out_dim;
+    framework::DDim out_dim;
     if (repeat == 0) {
       PADDLE_ENFORCE(
           ctx->HasInput("Y"),
@@ -38,7 +38,6 @@ class SeqExpandOp : public framework::OperatorWithKernel {
     } else {
       out_dim = ctx->GetInputDim("X");
       out_dim[0] = out_dim[0] * repeat;
-      ctx->SetOutputDim("Out", y_dim);
     }
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of PadOp should not be null.");
diff --git a/paddle/operators/seq_expand_op.h b/paddle/operators/seq_expand_op.h
index 80076dc35fe60..0c399fe196ef5 100644
--- a/paddle/operators/seq_expand_op.h
+++ b/paddle/operators/seq_expand_op.h
@@ -21,7 +21,6 @@ namespace paddle {
 namespace operators {
 
 using LoDTensor = framework::LoDTensor;
-using LoD = paddle::framework::LoD;
 
 template <typename Place, typename T>
 class SeqExpandKernel : public framework::OpKernel<T> {
@@ -35,11 +34,11 @@ class SeqExpandKernel : public framework::OpKernel<T> {
 
     if (repeat != 0) {
       if (x->lod().size() == 0) {
-        std::vector<size_t> level0(x->dims()[0]);
+        std::vector<size_t> level0;
         for (size_t i = 0; i <= x->dims()[0]; i++) {
           level0.push_back(i * repeat);
         }
-        const LoD out_lod;
+        framework::LoD out_lod;
         out_lod.push_back(level0);
         out->set_lod(out_lod);
       }
@@ -55,14 +54,15 @@ class SeqExpandKernel : public framework::OpKernel<T> {
       }
     }
-    if (paddle::platform::CPUPlace() == Place) {
+    if (platform::is_cpu_place(context.GetPlace())) {
       for (int i = 0; i < out_dim[0]; ++i) {
         memcpy(out_data + element_len * i, x_data + element_len * cpy_map[i],
                sizeof(T) * element_len);
       }
     } else {
       for (int i = 0; i < out_dim[0]; ++i) {
-        hl_memcpy(out_data + element_len * i, x_data + element_len * cpy_map[i],
+        hl_memcpy(out_data + element_len * i,
+                  const_cast<T*>(x_data) + element_len * cpy_map[i],
                   sizeof(T) * element_len);
       }
     }

From 23701ffaf07840013295bb2ec14a484e263cdab9 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 18 Oct 2017 11:32:55 +0800
Subject: [PATCH 03/12] Refine op

---
 paddle/operators/seq_expand_op.h              | 119 +++++++++++-----
 python/paddle/v2/framework/tests/op_test.py   |   4 +-
 .../v2/framework/tests/test_seq_expand.py     | 128 +++++++++++++-----
 3 files changed, 185 insertions(+), 66 deletions(-)

diff --git a/paddle/operators/seq_expand_op.h b/paddle/operators/seq_expand_op.h
index 0c399fe196ef5..cd1182c4f087f 100644
--- a/paddle/operators/seq_expand_op.h
+++ b/paddle/operators/seq_expand_op.h
@@ -14,14 +14,62 @@
 
 #pragma once
 
-#include "hl_cuda.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/memory/memcpy.h"
 
 namespace paddle {
 namespace operators {
 
 using LoDTensor = framework::LoDTensor;
 
+template <typename T>
+using vector = framework::Vector<T>;
+
+vector<size_t> repeat_lod(vector<size_t> data, vector<size_t> starts,
+                          vector<size_t> times, bool is_first) {
+  vector<size_t> result;
+  result.push_back(data[0]);
+  size_t p = 0, start = 0, end = 0;
+  if (is_first == true) {
+    for (size_t i = 0; i < times.size(); ++i) {
+      result.push_back(data.back() + times[i] * (data[i + 1] - data[i]));
+    }
+  } else {
+    for (size_t i = 0; i < times.size(); ++i) {
+      while (starts[i] != data[p] && p < data.size()) {
+        ++p;
+      }
+      start = p;
+      while (starts[i + 1] != data[p] && p < data.size()) {
+        ++p;
+      }
+      end = p + 1;
+      for (size_t j = 0; j < times[i]; ++j) {
+        for (size_t index = start; index < end - 1; ++index) {
+          result.push_back(result.back() + data[index + 1] - data[index]);
+        }
+      }
+    }
+  }
+  return result;
+}
+
+template <typename T, typename Place>
+void repeat_data(const T* src, T* dst, size_t size, vector<size_t> starts,
+                 vector<size_t> times, Place place) {
+  const T* src_p = src;
+  T* dst_p = dst;
+  size_t count = 0;
+  for (size_t i = 0; i < times.size(); ++i) {
+    count = size * (starts[i + 1] - starts[i]);
+    for (size_t j = 0; j < times[i]; ++j) {
+      memory::Copy(place, dst_p, place, src_p, sizeof(T) * count);
+      dst_p += count;
+    }
+    src_p += count;
+  }
+}
+
 template <typename Place, typename T>
 class SeqExpandKernel : public framework::OpKernel<T> {
  public:
@@ -29,43 +77,52 @@ class SeqExpandKernel : public framework::OpKernel<T> {
     auto* x = context.Input<LoDTensor>("X");
     auto* out = context.Output<LoDTensor>("Out");
     const T* x_data = x->data<T>();
-    T* out_data = out->mutable_data<T>(context.GetPlace());
-    size_t repeat = static_cast<size_t>(context.Attr<int>("repeat"));
+    auto x_dims = x->dims();
+    auto x_lod = x->lod();
 
-    if (repeat != 0) {
-      if (x->lod().size() == 0) {
-        std::vector<size_t> level0;
-        for (size_t i = 0; i <= x->dims()[0]; i++) {
-          level0.push_back(i * repeat);
-        }
-        framework::LoD out_lod;
-        out_lod.push_back(level0);
-        out->set_lod(out_lod);
-      }
+    if (x_lod.size() == 0) {
+      vector<size_t> level;
+      for (int i = 0; i < x->dims()[0] + 1; ++i) {
+        level.push_back(i);
      }
+      x_lod.push_back(level);
+    } else {
+      x_lod.insert(x_lod.begin(), x_lod[0]);
     }
-    auto out_dim = out->dims();
-    size_t element_len = framework::product(out_dim) / out_dim[0];
-    std::vector<int> cpy_map(out_dim[0]);
-    if (x->lod().size() == 0) {
-      auto lod = out->lod();
-      for (int i = 0; i < lod.size() - 1; ++i) {
-        for (int j = lod[0][i]; i < lod[0][i + 1]; ++j) {
-          cpy_map[j] = i;
-        }
+
+    size_t repeat = static_cast<size_t>(context.Attr<int>("repeat"));
+    vector<size_t> repeats;
+    if (repeat != 0) {
+      for (int i = 0; i < x_lod[0].size() - 1; ++i) {
+        repeats.push_back(repeat);
       }
+      std::vector<int64_t> dims = framework::vectorize(x->dims());
+      dims[0] = dims[0] * repeat;
+      auto out_dims = framework::make_ddim(dims);
+      out->Resize(out_dims);
     } else {
-    if (platform::is_cpu_place(context.GetPlace())) {
-      for (int i = 0; i < out_dim[0]; ++i) {
-        memcpy(out_data + element_len * i, x_data + element_len * cpy_map[i],
-               sizeof(T) * element_len);
+      auto* y = context.Input<LoDTensor>("Y");
+      auto y_lod = y->lod();
+      for (int i = 0; i < y_lod[0].size() - 1; ++i) {
+        repeats.push_back((y_lod[0][i + 1] - y_lod[0][i]) /
+                          (x_lod[0][i + 1] - x_lod[0][i]));
       }
+      out->Resize(x_dims);
     }
-    } else {
-      for (int i = 0; i < out_dim[0]; ++i) {
-        hl_memcpy(out_data + element_len * i,
-                  const_cast<T*>(x_data) + element_len * cpy_map[i],
-                  sizeof(T) * element_len);
+
+    framework::LoD out_lod;
+    auto level0 = repeat_lod(x_lod[0], x_lod[0], repeats, true);
+    out_lod.push_back(level0);
+    for (int i = 1; i < x_lod.size(); ++i) {
+      out_lod.push_back(repeat_lod(x_lod[i], x_lod[0], repeats, false));
     }
-    }
+
+    size_t element_len = framework::product(x_dims) / x_dims[0];
+    T* out_data = out->mutable_data<T>(context.GetPlace());
+    Place place = boost::get<Place>(context.GetPlace());
+    repeat_data<T, Place>(x_data, out_data, element_len, x_lod[0], repeats,
+                          place);
+    out->set_lod(out_lod);
   }
 };
diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py
index 81067f38bbf64..0b0de78caf9a0 100644
--- a/python/paddle/v2/framework/tests/op_test.py
+++ b/python/paddle/v2/framework/tests/op_test.py
@@ -246,7 +246,9 @@ def check_output_with_place(self, place, atol):
             else:
                 actual = np.array(self.scope.find_var(out_name).get_tensor())
                 expect = self.outputs[out_name]
-
+                print "out_name: %s" % out_name
+                print "actual: %s" % actual
+                print "expcept: %s" % expect
                 self.assertTrue(
                     np.allclose(
                         actual, expect, atol=atol),
diff --git a/python/paddle/v2/framework/tests/test_seq_expand.py b/python/paddle/v2/framework/tests/test_seq_expand.py
index 4608d3c3bd6f3..854148a8f1fe9 100644
--- a/python/paddle/v2/framework/tests/test_seq_expand.py
+++ b/python/paddle/v2/framework/tests/test_seq_expand.py
@@ -3,59 +3,119 @@
 from op_test import OpTest
 
 
+def repeat(list, starts, times, is_first):
+    newlist = [list[0]]
+    if is_first:
+        for i, time in enumerate(times):
+            size = list[i + 1] - list[i]
+            newlist.append(newlist[-1] + size * time)
+    else:
+        for i, time in enumerate(times):
+            start = list.index(starts[i])
+            end = list.index(starts[i + 1]) + 1
+            for t in range(time):
+                for index in range(start, end - 1):
+                    newlist.append(newlist[-1] + list[index + 1] - list[index])
+    return newlist
+
+
+def repeat_array(array, starts, times):
+    newlist = []
+    for i, time in enumerate(times):
+        for t in range(time):
+            newlist.extend(array[starts[i]:starts[i + 1]])
+    return newlist
+
+
 class TestSeqExpand(OpTest):
-    #class TestSeqExpand():
     def set_data(self):
         self.op_type = 'seq_expand'
         x = np.random.uniform(0.1, 1, [3, 2, 2]).astype('float32')
         y = np.zeros((6, 2, 2)).astype('float32')
-        lod = [[0, 2, 3, 6]]
-        print "x = %s" % x
-        self.inputs = {'X': x, 'Y': (y, lod)}
-        self.repeat = None
+        y_lod = [[0, 2, 3, 6]]
+        self.inputs = {'X': (x, None), 'Y': (y, y_lod)}
+        self.repeat = 2
 
     def compute(self):
-        x = self.inputs['X']
-        cpy_map = {}
-        lod = []
-        out_shape = []
+        x_data, x_lod = self.inputs['X']
+        print "x_data: %s" % x_data
+        print "x_lod: %s" % x_lod
+        if not x_lod:
+            x_lod = [[i for i in range(1 + x_data.shape[0])]]
+        else:
+            x_lod = [x_lod[0]] + x_lod
         if self.repeat:
-            level0 = []
-            for i in range(x.shape[0] + 1):
-                level0.append(i * self.repeat)
-            lod.append(level0)
-
-            for i in x.shape:
-                out_shape.append(i)
-            out_shape[0] = out_shape[0] * self.repeat
+            self.attrs = {'repeat': self.repeat}
+            repeats = (len(x_lod[0]) - 1) * [self.repeat]
+            # get out shape
+            # out_shape = np.copy(x_data.shape)
+            # out_shape[0] = out_shape[0] * self.repeat
         else:
-            y, lod = self.inputs['Y']
-            out_shape = y.shape
-        out = np.zeros(out_shape).astype('float32')
-
-        start = 0
-
-        for i in range(len(lod[0]) - 1):
-            for j in range(lod[0][i], lod[0][i + 1]):
-                cpy_map[j] = i
-        print "cpy_map = %s" % cpy_map
-        for i in range(len(out)):
-            out[i] = x[cpy_map[i]]
-
-        print "out = %s" % out
-        self.outputs = {'Out': (out, lod)}
+            y_data, y_lod = self.inputs['Y']
+            print "y_lod: %s" % y_lod
+            #print "y_lod: %s" % y_lod
+            # get repeats
+            repeats = [((y_lod[0][i + 1] - y_lod[0][i]) /
+                        (x_lod[0][i + 1] - x_lod[0][i]))
+                       for i in range(len(y_lod[0]) - 1)]
+            # get out shape
+            # out_shape = y_data.shape
+        # get out lod
+
+        out_lod = [repeat(x_lod[0], x_lod[0], repeats, True)] + [
+            repeat(lod, x_lod[0], repeats, False) for lod in x_lod[1:]
+        ]
+        # copy data
+        out = repeat_array(x_data.tolist(), x_lod[0], repeats)
+        self.outputs = {'Out': (out, out_lod)}
+        print "outputs: %s" % self.outputs
 
     def setUp(self):
+        self.op_type = 'seq_expand'
         self.set_data()
         self.compute()
 
     def test_check_output(self):
         self.check_output()
 
-    def test_check_grad(self):
-        self.check_grad(["X"], "Out")
+
+# def test_check_grad(self):
+#     self.check_grad(["X"], "Out")
+
+
+class TestSeqExpandCase1(TestSeqExpand):
+    def set_data(self):
+        x_data = np.random.uniform(0.1, 1, [7, 1]).astype('float32')
+        x_lod = [[0, 5, 7], [0, 2, 5, 7]]
+        self.inputs = {'X': (x_data, x_lod)}
+        self.repeat = 2
+
+
+class TestSeqExpandCase2(TestSeqExpand):
+    def set_data(self):
+        x_data = np.random.uniform(0.1, 1, [4, 1]).astype('float32')
+        self.inputs = {'X': (x_data, None)}
+        self.repeat = 2
+
+
+class TestSeqExpandCase3(TestSeqExpand):
+    def set_data(self):
+        x_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32')
+        y_data = np.random.uniform(0.1, 1, [8, 1]).astype('float32')
+        y_lod = [[0, 1, 4, 8]]
+        self.inputs = {'X': (x_data, None), 'Y': (y_data, y_lod)}
+        self.repeat = None
+
+
+class TestSeqExpandCase4(TestSeqExpand):
+    def set_data(self):
+        x_data = np.random.uniform(0.1, 1, [5, 1]).astype('float32')
+        x_lod = [[0, 2, 5]]
+        y_data = np.random.uniform(0.1, 1, [13, 1]).astype('float32')
+        y_lod = [[0, 4, 13], [0, 2, 4, 7, 10, 13]]
+        self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
+        self.repeat = None
 
 
 if __name__ == '__main__':
     unittest.main()
-# TestSeqExpand().setUp()
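The LoD arithmetic introduced by repeat_lod above has two modes: for the top level each segment's length is simply multiplied by its repeat count, while for lower levels the internal segment boundaries are replayed once per repeat. A minimal check of the top-level case (plain Python, assumed semantics mirroring the `repeat` helper in the test file):

    level = [0, 2, 5]   # two sequences, of lengths 2 and 3
    times = [2, 2]      # repeat each sequence twice
    top = [level[0]]
    for i, t in enumerate(times):
        top.append(top[-1] + t * (level[i + 1] - level[i]))
    assert top == [0, 4, 10]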
From 8de04be786fe21a72b9be91dab963f5d7520885b Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 18 Oct 2017 17:14:38 +0800
Subject: [PATCH 04/12] Fix unit test

---
 paddle/framework/lod_tensor.cc                | 29 +++++++
 paddle/framework/lod_tensor.h                 |  7 ++
 paddle/operators/seq_expand_op.h              | 79 +++++--------------
 .../v2/framework/tests/test_seq_expand.py     | 30 ++-----
 4 files changed, 64 insertions(+), 81 deletions(-)

diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc
index 5b7badf89c171..1247daafc5771 100644
--- a/paddle/framework/lod_tensor.cc
+++ b/paddle/framework/lod_tensor.cc
@@ -103,5 +103,34 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin,
   lod_ = new_lod;
 }
 
+Vector<size_t> repeat_lod(Vector<size_t> data, Vector<size_t> starts,
+                          Vector<size_t> times, bool is_first) {
+  Vector<size_t> result;
+  result.push_back(data[0]);
+  size_t p = 0, start = 0, end = 0;
+  if (is_first == true) {
+    for (size_t i = 0; i < times.size(); ++i) {
+      result.push_back(data.back() + times[i] * (data[i + 1] - data[i]));
+    }
+  } else {
+    for (size_t i = 0; i < times.size(); ++i) {
+      while (starts[i] != data[p] && p < data.size()) {
+        ++p;
+      }
+      start = p;
+      while (starts[i + 1] != data[p] && p < data.size()) {
+        ++p;
+      }
+      end = p + 1;
+      for (size_t j = 0; j < times[i]; ++j) {
+        for (size_t index = start; index < end - 1; ++index) {
+          result.push_back(result.back() + data[index + 1] - data[index]);
+        }
+      }
+    }
+  }
+  return result;
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index 4db36ee76609a..41c83a11649a9 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -15,6 +15,9 @@
 #pragma once
 
 #include <memory>
+#include "paddle/memory/memcpy.h"
+#include "paddle/platform/device_context.h"
+#include "paddle/platform/place.h"
 #ifdef PADDLE_WITH_CUDA
 #include <thrust/device_vector.h>
 #include <thrust/host_vector.h>
@@ -122,5 +125,9 @@ class LoDTensor : public Tensor {
  private:
   LoD lod_;
 };
+
+Vector<size_t> repeat_lod(Vector<size_t> data, Vector<size_t> starts,
+                          Vector<size_t> times, bool is_first);
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/operators/seq_expand_op.h b/paddle/operators/seq_expand_op.h
index cd1182c4f087f..221393f909384 100644
--- a/paddle/operators/seq_expand_op.h
+++ b/paddle/operators/seq_expand_op.h
@@ -22,54 +22,6 @@ namespace operators {
 
 using LoDTensor = framework::LoDTensor;
 
-template <typename T>
-using vector = framework::Vector<T>;
-
-vector<size_t> repeat_lod(vector<size_t> data, vector<size_t> starts,
-                          vector<size_t> times, bool is_first) {
-  vector<size_t> result;
-  result.push_back(data[0]);
-  size_t p = 0, start = 0, end = 0;
-  if (is_first == true) {
-    for (size_t i = 0; i < times.size(); ++i) {
-      result.push_back(data.back() + times[i] * (data[i + 1] - data[i]));
-    }
-  } else {
-    for (size_t i = 0; i < times.size(); ++i) {
-      while (starts[i] != data[p] && p < data.size()) {
-        ++p;
-      }
-      start = p;
-      while (starts[i + 1] != data[p] && p < data.size()) {
-        ++p;
-      }
-      end = p + 1;
-      for (size_t j = 0; j < times[i]; ++j) {
-        for (size_t index = start; index < end - 1; ++index) {
-          result.push_back(result.back() + data[index + 1] - data[index]);
-        }
-      }
-    }
-  }
-  return result;
-}
-
-template <typename T, typename Place>
-void repeat_data(const T* src, T* dst, size_t size, vector<size_t> starts,
-                 vector<size_t> times, Place place) {
-  const T* src_p = src;
-  T* dst_p = dst;
-  size_t count = 0;
-  for (size_t i = 0; i < times.size(); ++i) {
-    count = size * (starts[i + 1] - starts[i]);
-    for (size_t j = 0; j < times[i]; ++j) {
-      memory::Copy(place, dst_p, place, src_p, sizeof(T) * count);
-      dst_p += count;
-    }
-    src_p += count;
-  }
-}
-
 template <typename Place, typename T>
 class SeqExpandKernel : public framework::OpKernel<T> {
  public:
@@ -81,7 +33,7 @@ class SeqExpandKernel : public framework::OpKernel<T> {
     auto x_lod = x->lod();
 
     if (x_lod.size() == 0) {
-      vector<size_t> level;
+      framework::Vector<size_t> level;
       for (int i = 0; i < x->dims()[0] + 1; ++i) {
         level.push_back(i);
       }
@@ -91,7 +43,7 @@ class SeqExpandKernel : public framework::OpKernel<T> {
     }
 
     size_t repeat = static_cast<size_t>(context.Attr<int>("repeat"));
-    vector<size_t> repeats;
+    framework::Vector<size_t> repeats;
     if (repeat != 0) {
       for (int i = 0; i < x_lod[0].size() - 1; ++i) {
         repeats.push_back(repeat);
@@ -107,21 +59,32 @@ class SeqExpandKernel : public framework::OpKernel<T> {
         repeats.push_back((y_lod[0][i + 1] - y_lod[0][i]) /
                           (x_lod[0][i + 1] - x_lod[0][i]));
       }
-      out->Resize(x_dims);
+      out->Resize(y->dims());
     }
 
     framework::LoD out_lod;
-    auto level0 = repeat_lod(x_lod[0], x_lod[0], repeats, true);
+    auto level0 = framework::repeat_lod(x_lod[0], x_lod[0], repeats, true);
     out_lod.push_back(level0);
     for (int i = 1; i < x_lod.size(); ++i) {
-      out_lod.push_back(repeat_lod(x_lod[i], x_lod[0], repeats, false));
+      out_lod.push_back(
+          framework::repeat_lod(x_lod[i], x_lod[0], repeats, false));
     }
 
     size_t element_len = framework::product(x_dims) / x_dims[0];
     T* out_data = out->mutable_data<T>(context.GetPlace());
+
+    // copy data
     Place place = boost::get<Place>(context.GetPlace());
-    repeat_data<T, Place>(x_data, out_data, element_len, x_lod[0], repeats,
-                          place);
+    size_t count = 0;
+    for (size_t i = 0; i < repeats.size(); ++i) {
+      count = element_len * (x_lod[0][i + 1] - x_lod[0][i]);
+      for (size_t j = 0; j < repeats[i]; ++j) {
+        memory::Copy(place, out_data, place, x_data, sizeof(T) * count);
+        out_data += count;
+      }
+      x_data += count;
+    }
+
     out->set_lod(out_lod);
   }
 };
@@ -130,9 +93,9 @@ template <typename Place, typename T>
 class SeqExpandGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    // auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
-    // auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
-    // d_x->mutable_data<T>(context.GetPlace());
+    auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
+    d_x->mutable_data<T>(context.GetPlace());
   }
 };
 
diff --git a/python/paddle/v2/framework/tests/test_seq_expand.py b/python/paddle/v2/framework/tests/test_seq_expand.py
index 854148a8f1fe9..2b9509413e30f 100644
--- a/python/paddle/v2/framework/tests/test_seq_expand.py
+++ b/python/paddle/v2/framework/tests/test_seq_expand.py
@@ -29,17 +29,13 @@ def repeat_array(array, starts, times):
 
 class TestSeqExpand(OpTest):
     def set_data(self):
-        self.op_type = 'seq_expand'
-        x = np.random.uniform(0.1, 1, [3, 2, 2]).astype('float32')
-        y = np.zeros((6, 2, 2)).astype('float32')
-        y_lod = [[0, 2, 3, 6]]
-        self.inputs = {'X': (x, None), 'Y': (y, y_lod)}
+        x_data = np.random.uniform(0.1, 1, [4, 1]).astype('float32')
+        self.inputs = {'X': x_data}
         self.repeat = 2
 
     def compute(self):
-        x_data, x_lod = self.inputs['X']
-        print "x_data: %s" % x_data
-        print "x_lod: %s" % x_lod
+        x = self.inputs['X']
+        x_data, x_lod = x if type(x) == tuple else (x, None)
         if not x_lod:
             x_lod = [[i for i in range(1 + x_data.shape[0])]]
         else:
@@ -47,28 +43,16 @@ def compute(self):
         if self.repeat:
             self.attrs = {'repeat': self.repeat}
             repeats = (len(x_lod[0]) - 1) * [self.repeat]
-            # get out shape
-            # out_shape = np.copy(x_data.shape)
-            # out_shape[0] = out_shape[0] * self.repeat
         else:
             y_data, y_lod = self.inputs['Y']
-            print "y_lod: %s" % y_lod
-            #print "y_lod: %s" % y_lod
-            # get repeats
             repeats = [((y_lod[0][i + 1] - y_lod[0][i]) /
                         (x_lod[0][i + 1] - x_lod[0][i]))
                        for i in range(len(y_lod[0]) - 1)]
-            # get out shape
-            # out_shape = y_data.shape
-        # get out lod
-
         out_lod = [repeat(x_lod[0], x_lod[0], repeats, True)] + [
             repeat(lod, x_lod[0], repeats, False) for lod in x_lod[1:]
         ]
-        # copy data
         out = repeat_array(x_data.tolist(), x_lod[0], repeats)
-        self.outputs = {'Out': (out, out_lod)}
-        print "outputs: %s" % self.outputs
+        self.outputs = {'Out': out}
 
     def setUp(self):
         self.op_type = 'seq_expand'
@@ -94,7 +78,7 @@ def set_data(self):
 class TestSeqExpandCase2(TestSeqExpand):
     def set_data(self):
         x_data = np.random.uniform(0.1, 1, [4, 1]).astype('float32')
-        self.inputs = {'X': (x_data, None)}
+        self.inputs = {'X': x_data}
         self.repeat = 2
 
 
@@ -103,7 +87,7 @@ def set_data(self):
         x_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32')
         y_data = np.random.uniform(0.1, 1, [8, 1]).astype('float32')
         y_lod = [[0, 1, 4, 8]]
-        self.inputs = {'X': (x_data, None), 'Y': (y_data, y_lod)}
+        self.inputs = {'X': x_data, 'Y': (y_data, y_lod)}
         self.repeat = None
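With repeat_lod moved into the framework, the kernel itself reduces to a segment-wise copy. The data movement can be sketched as follows (a simplified pure-Python model, assuming x is flattened to element_len values per row; expand_rows is a hypothetical helper for illustration, not a framework API):

    def expand_rows(x_flat, starts, times, element_len):
        out = []
        for i, t in enumerate(times):
            # Take the i-th segment of rows and emit it t times in a row.
            seg = x_flat[starts[i] * element_len:starts[i + 1] * element_len]
            out.extend(seg * t)
        return out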
From 31531ab581f7d726d410c2181ac79ed41a32b3ef Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Thu, 19 Oct 2017 01:18:20 +0800
Subject: [PATCH 05/12] Add backward kernel

---
 paddle/framework/lod_tensor.cc              |  2 +-
 paddle/operators/seq_expand_op.cc           | 30 +++++--------
 paddle/operators/seq_expand_op.h            | 27 +++++++++++--
 paddle/operators/sequence_concat_op.cc      | 10 ++---
 python/paddle/v2/framework/tests/op_test.py |  3 --
 .../v2/framework/tests/test_seq_expand.py   |  5 ++-
 6 files changed, 39 insertions(+), 38 deletions(-)

diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc
index 1247daafc5771..e4a2f5765a10c 100644
--- a/paddle/framework/lod_tensor.cc
+++ b/paddle/framework/lod_tensor.cc
@@ -110,7 +110,7 @@ Vector<size_t> repeat_lod(Vector<size_t> data, Vector<size_t> starts,
   size_t p = 0, start = 0, end = 0;
   if (is_first == true) {
     for (size_t i = 0; i < times.size(); ++i) {
-      result.push_back(data.back() + times[i] * (data[i + 1] - data[i]));
+      result.push_back(result.back() + times[i] * (data[i + 1] - data[i]));
     }
   } else {
     for (size_t i = 0; i < times.size(); ++i) {
diff --git a/paddle/operators/seq_expand_op.cc b/paddle/operators/seq_expand_op.cc
index 63b17a10f59d0..59d713548930e 100644
--- a/paddle/operators/seq_expand_op.cc
+++ b/paddle/operators/seq_expand_op.cc
@@ -60,7 +60,8 @@ As an example:
 
 Given:
 
-X = [1, 2 , 3]
+X.data = [1, 2 , 3, 4]
+X.lod = [[0, 3, 4], [0, 1, 3, 4]]
 
 and
 
@@ -69,8 +70,8 @@ repeat = 2
 
 then we get
 
-Out.data = [1, 1, 2, 2, 3, 3]
-Out.lod = [[0, 2, 4, 6]]
+Out.data = [1, 2, 3, 1, 2, 3, 4, 4]
+Out.lod = [[0, 6, 8], [0, 3, 6, 7, 8], [0, 1, 3, 4, 6, 7, 8]]
 
 )DOC");
   }
@@ -83,6 +84,7 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Input(Out@GRAD) should not be null");
     auto x_dims = ctx->GetInputDim("X");
@@ -93,30 +95,12 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
   }
 };
 
-class SeqExpandOpGradMaker : public framework::SingleGradOpDescMaker {
- public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
-
- protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto* bind = new framework::OpDescBind();
-    bind->SetInput("X", Input("X"));
-    bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
-    bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    bind->SetAttrMap(Attrs());
-    bind->SetType("seq_expand_grad");
-    return std::unique_ptr<framework::OpDescBind>(bind);
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-
-REGISTER_OPERATOR(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker,
-                  ops::SeqExpandOpGradMaker);
-REGISTER_OPERATOR(seq_expand_grad, ops::SeqExpandOpGrad);
+REGISTER_OP(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker,
+            seq_expand_grad, ops::SeqExpandOpGrad);
 REGISTER_OP_CPU_KERNEL(seq_expand,
                        ops::SeqExpandKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
diff --git a/paddle/operators/seq_expand_op.h b/paddle/operators/seq_expand_op.h
index 221393f909384..8b7bda54c05d1 100644
--- a/paddle/operators/seq_expand_op.h
+++ b/paddle/operators/seq_expand_op.h
@@ -16,6 +16,7 @@
 
 #include "paddle/framework/op_registry.h"
 #include "paddle/memory/memcpy.h"
+#include "unsupported/Eigen/CXX11/Tensor"
 
 namespace paddle {
 namespace operators {
@@ -93,9 +94,29 @@ template <typename Place, typename T>
 class SeqExpandGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
-    auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
-    d_x->mutable_data<T>(context.GetPlace());
+    auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
+    auto* x = context.Input<LoDTensor>("X");
+    auto* out = context.Input<LoDTensor>("Out");
+    auto out_lod = out->lod();
+    d_x->set_lod(x->lod());
+    const T* d_out_data = d_out->data<T>();
+    auto d_out_dims = d_out->dims();
+    T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
+    size_t element_len = framework::product(d_out_dims) / d_out_dims[0];
+    for (size_t i = 0; i < out->NumElements(); ++i) {
+      size_t ele_count = out_lod[0][i + 1] - out_lod[0][i];
+      size_t repeat = out->NumElements(0, i);
+      Eigen::TensorMap<Eigen::Tensor<const T, 2>> d_out_t(
+          d_out_data, static_cast<int>(repeat),
+          static_cast<int>((ele_count * element_len) / repeat));
+      Eigen::TensorMap<Eigen::Tensor<T, 1>> d_x_t(
+          d_x_data, static_cast<int>((ele_count * element_len) / repeat));
+      auto place = context.GetEigenDevice<Place>();
+      d_x_t.device(place) = d_out_t.sum(Eigen::array<int, 1>({0}));
+      d_out_data += (ele_count * element_len);
+      d_x_data += ((ele_count * element_len) / repeat);
+    }
   }
 };
 
diff --git a/paddle/operators/sequence_concat_op.cc b/paddle/operators/sequence_concat_op.cc
index 1fce96cdfe20f..46f73e3c27983 100644
--- a/paddle/operators/sequence_concat_op.cc
+++ b/paddle/operators/sequence_concat_op.cc
@@ -68,12 +68,12 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
              "The level should be less than the level number of inputs.")
         .SetDefault(0);
     AddComment(R"DOC(
-    The sequence_concat operator concatenates multiple LoDTensors. 
-    It only supports sequence (LoD Tensor with level number is 1) 
+    The sequence_concat operator concatenates multiple LoDTensors.
+    It only supports sequence (LoD Tensor with level number is 1)
     or a nested sequence (LoD tensor with level number is 2) as its input.
     - Case1:
       If the axis is other than 0(here, axis is 1 and level is 1),
-      each input should have the same LoD information and the LoD 
+      each input should have the same LoD information and the LoD
       information of the output keeps the same as the input.
 
       LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
@@ -81,7 +81,7 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
       LoD(Out) = {{0,2,4}, {0,1,2,3,4}}; Dims(Out) = (4,7,4)
 
     - Case2:
-      If the axis is 0(here, leve is 0), the inputs are concatenated along 
+      If the axis is 0(here, level is 0), the inputs are concatenated along
       time steps, the LoD information of the output need to re-compute.
 
       LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
@@ -94,7 +94,7 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
       LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
       LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (5,3,4)
       LoD(Out) = {{0,5,9}, {0,2,5,7,9}}; Dims(Out) = (9,3,4)
-      
+
     NOTE: The levels of all the inputs should be the same.
     )DOC");
   }
diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py
index 3ef8ec3164b52..a88e9f0bb8213 100644
--- a/python/paddle/v2/framework/tests/op_test.py
+++ b/python/paddle/v2/framework/tests/op_test.py
@@ -246,9 +246,6 @@ def check_output_with_place(self, place, atol):
             else:
                 actual = np.array(self.scope.find_var(out_name).get_tensor())
                 expect = self.outputs[out_name]
-                print "out_name: %s" % out_name
-                print "actual: %s" % actual
-                print "expcept: %s" % expect
                 self.assertTrue(
                     np.allclose(
                         actual, expect, atol=atol),
diff --git a/python/paddle/v2/framework/tests/test_seq_expand.py b/python/paddle/v2/framework/tests/test_seq_expand.py
index 2b9509413e30f..87e39d72bf5b4 100644
--- a/python/paddle/v2/framework/tests/test_seq_expand.py
+++ b/python/paddle/v2/framework/tests/test_seq_expand.py
@@ -62,9 +62,8 @@ def setUp(self):
     def test_check_output(self):
         self.check_output()
 
-
-# def test_check_grad(self):
-#     self.check_grad(["X"], "Out")
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
 
 
 class TestSeqExpandCase1(TestSeqExpand):
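The backward pass added here follows directly from the forward copy: every input row is duplicated into several output rows, so its gradient is the sum over those duplicates. In NumPy, simplified to the case where each LoD segment is a single row (a sketch of the rule only, not the Eigen kernel above):

    import numpy as np

    def seq_expand_grad(d_out, repeats):
        # d_out: gradient w.r.t. the expanded output, one row per output row.
        grads, p = [], 0
        for r in repeats:
            grads.append(d_out[p:p + r].sum(axis=0))  # reduce the r copies
            p += r
        return np.stack(grads)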
From a94b3dd9a7422fdc02795e73e3e5b4168b0fff45 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Thu, 19 Oct 2017 16:59:43 +0800
Subject: [PATCH 06/12] Refine comments and function name

1. Add more comments and examples
2. Rename repeat_lod to expand_lod
3. Remove unused header file
---
 paddle/framework/lod_tensor.cc    | 22 ++++-----
 paddle/framework/lod_tensor.h     |  7 +--
 paddle/operators/seq_expand_op.cc | 76 +++++++++++++++++++++++--------
 paddle/operators/seq_expand_op.h  | 18 ++++----
 4 files changed, 80 insertions(+), 43 deletions(-)

diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc
index e4a2f5765a10c..49d9e56689246 100644
--- a/paddle/framework/lod_tensor.cc
+++ b/paddle/framework/lod_tensor.cc
@@ -103,28 +103,28 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin,
   lod_ = new_lod;
 }
 
-Vector<size_t> repeat_lod(Vector<size_t> data, Vector<size_t> starts,
-                          Vector<size_t> times, bool is_first) {
+Vector<size_t> expand_lod(Vector<size_t> level, Vector<size_t> starts,
+                          Vector<size_t> scales, bool repeat) {
   Vector<size_t> result;
-  result.push_back(data[0]);
+  result.push_back(level[0]);
   size_t p = 0, start = 0, end = 0;
-  if (is_first == true) {
-    for (size_t i = 0; i < times.size(); ++i) {
-      result.push_back(result.back() + times[i] * (data[i + 1] - data[i]));
+  if (!repeat) {
+    for (size_t i = 0; i < scales.size(); ++i) {
+      result.push_back(result.back() + scales[i] * (level[i + 1] - level[i]));
     }
   } else {
-    for (size_t i = 0; i < times.size(); ++i) {
-      while (starts[i] != data[p] && p < data.size()) {
+    for (size_t i = 0; i < scales.size(); ++i) {
+      while (starts[i] != level[p] && p < level.size()) {
         ++p;
       }
       start = p;
-      while (starts[i + 1] != data[p] && p < data.size()) {
+      while (starts[i + 1] != level[p] && p < level.size()) {
         ++p;
       }
       end = p + 1;
-      for (size_t j = 0; j < times[i]; ++j) {
+      for (size_t j = 0; j < scales[i]; ++j) {
         for (size_t index = start; index < end - 1; ++index) {
-          result.push_back(result.back() + data[index + 1] - data[index]);
+          result.push_back(result.back() + level[index + 1] - level[index]);
         }
       }
     }
diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index 41c83a11649a9..c64ee94405b79 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -15,9 +15,6 @@
 #pragma once
 
 #include <memory>
-#include "paddle/memory/memcpy.h"
-#include "paddle/platform/device_context.h"
-#include "paddle/platform/place.h"
 #ifdef PADDLE_WITH_CUDA
 #include <thrust/device_vector.h>
 #include <thrust/host_vector.h>
@@ -126,8 +123,8 @@ class LoDTensor : public Tensor {
   LoD lod_;
 };
 
-Vector<size_t> repeat_lod(Vector<size_t> data, Vector<size_t> starts,
-                          Vector<size_t> times, bool is_first);
+Vector<size_t> expand_lod(Vector<size_t> level, Vector<size_t> starts,
+                          Vector<size_t> scales, bool repeat);
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/operators/seq_expand_op.cc b/paddle/operators/seq_expand_op.cc
index 59d713548930e..b9633721e296c 100644
--- a/paddle/operators/seq_expand_op.cc
+++ b/paddle/operators/seq_expand_op.cc
@@ -50,28 +50,68 @@ class SeqExpandOpMaker : public framework::OpProtoAndCheckerMaker {
   SeqExpandOpMaker(framework::OpProto* proto,
                    framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    // TODO(wanghaoshuang): Add more comments
-    AddInput("X", "The input('X') of seq_expand op.");
-    AddInput("Y", "The reference input('Y') of seq_expand op.");
-    AddOutput("Out", "The output of seq_expand op.");
-    AddAttr<int>("repeat", "repeat times").SetDefault(0);
+    AddInput(
+        "X",
+        "The input('X') of seq_expand op. It can be LoDTensor or base Tensor.");
+    AddInput(
+        "Y",
+        "The reference input('Y') of seq_expand op."
+        "It must be a LoDTensor with k-level(k>0)."
+        "This reference input is essential if 'repeat' attribute is not "
+        "configured."
+        "Input(X) will be expanded by LoD of input(Y) while repeat == 0.");
+    AddOutput("Out",
+              "The output of seq_expand op."
+              "The output is a (k+1)-level LoDTensor"
+              "while input(X) being k-level LoDTensor."
+              "(Given base tensor is 0-level LoDTensor.)");
+    AddAttr<int>("repeat",
+                 "(type:int; default value: 0)"
+                 "Repeating times of each element while expanding input(X)."
+                 "It works while input(Y) is not configured.")
+        .SetDefault(0);
     AddComment(R"DOC(
-As an example:
+Expand k-level LoDTensor to (k+1)-level LoDTensor
+by lod of input(Y) or 'repeat' attribute.
 
-Given:
-
-X.data = [1, 2 , 3, 4]
-X.lod = [[0, 3, 4], [0, 1, 3, 4]]
+Case 1:
 
+Given a 2-level LoDTensor X:
+    X.data = [1, 2 , 3, 4]
+    X.lod = [[0, 3, 4], [0, 1, 3, 4]]
 and
-
-repeat = 2
-
-
-then we get
-
-Out.data = [1, 2, 3, 1, 2, 3, 4, 4]
-Out.lod = [[0, 6, 8], [0, 3, 6, 7, 8], [0, 1, 3, 4, 6, 7, 8]]
+    repeat = 2
+then we get 3-level LoDTensor
+    Out.data = [1, 2, 3, 1, 2, 3, 4, 4]
+    Out.lod = [[0, 6, 8],
+               [0, 3, 6, 7, 8],
+               [0, 1, 3, 4, 6, 7, 8]]
+
+Case 2:
+
+Given 2-level a LoDTensor X
+    X.data = [1, 2, 3, 4]
+    X.lod = [[0, 3, 4], [0, 1, 3, 4]]
+and
+    Y.lod = [[0, 6, 8],
+             [0, 3, 6, 7, 8],
+             [0,1,3,4,6,7,8]]
+then we get 3-level LoDTensor
+    Out.data = [1, 2, 3, 1, 2, 3, 4, 4]
+    Out.lod = [[0, 6, 8],
+               [0, 3, 6, 7, 8],
+               [0, 1, 3, 4, 6, 7, 8]]
+
+Case 3:
+
+Given a 0-level LoDTensor X
+    X.data = [1, 2, 3, 4]
+    X.lod = NULL
+and
+    repeat = 2
+then we get 1-level LoDTensor
+    Out.data = [1, 1, 2, 2, 3, 3, 4, 4]
+    Out.lod = [[0, 2, 4, 6, 8]]
 
 )DOC");
   }
diff --git a/paddle/operators/seq_expand_op.h b/paddle/operators/seq_expand_op.h
index 8b7bda54c05d1..e990f125127e5 100644
--- a/paddle/operators/seq_expand_op.h
+++ b/paddle/operators/seq_expand_op.h
@@ -44,10 +44,10 @@ class SeqExpandKernel : public framework::OpKernel<T> {
     }
 
     size_t repeat = static_cast<size_t>(context.Attr<int>("repeat"));
-    framework::Vector<size_t> repeats;
+    framework::Vector<size_t> scales;
     if (repeat != 0) {
       for (int i = 0; i < x_lod[0].size() - 1; ++i) {
-        repeats.push_back(repeat);
+        scales.push_back(repeat);
       }
       std::vector<int64_t> dims = framework::vectorize(x->dims());
       dims[0] = dims[0] * repeat;
@@ -57,18 +57,18 @@ class SeqExpandKernel : public framework::OpKernel<T> {
       auto* y = context.Input<LoDTensor>("Y");
       auto y_lod = y->lod();
       for (int i = 0; i < y_lod[0].size() - 1; ++i) {
-        repeats.push_back((y_lod[0][i + 1] - y_lod[0][i]) /
-                          (x_lod[0][i + 1] - x_lod[0][i]));
+        scales.push_back((y_lod[0][i + 1] - y_lod[0][i]) /
+                         (x_lod[0][i + 1] - x_lod[0][i]));
       }
       out->Resize(y->dims());
     }
 
     framework::LoD out_lod;
-    auto level0 = framework::repeat_lod(x_lod[0], x_lod[0], repeats, true);
+    auto level0 = framework::expand_lod(x_lod[0], x_lod[0], scales, false);
     out_lod.push_back(level0);
     for (int i = 1; i < x_lod.size(); ++i) {
       out_lod.push_back(
-          framework::repeat_lod(x_lod[i], x_lod[0], repeats, false));
+          framework::expand_lod(x_lod[i], x_lod[0], scales, true));
     }
 
     size_t element_len = framework::product(x_dims) / x_dims[0];
@@ -77,9 +77,9 @@ class SeqExpandKernel : public framework::OpKernel<T> {
     // copy data
     Place place = boost::get<Place>(context.GetPlace());
     size_t count = 0;
-    for (size_t i = 0; i < repeats.size(); ++i) {
+    for (size_t i = 0; i < scales.size(); ++i) {
       count = element_len * (x_lod[0][i + 1] - x_lod[0][i]);
-      for (size_t j = 0; j < repeats[i]; ++j) {
+      for (size_t j = 0; j < scales[i]; ++j) {
         memory::Copy(place, out_data, place, x_data, sizeof(T) * count);
         out_data += count;
       }
@@ -95,9 +95,9 @@ class SeqExpandGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
-    auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
     auto* x = context.Input<LoDTensor>("X");
     auto* out = context.Input<LoDTensor>("Out");
+    auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
     auto out_lod = out->lod();
     d_x->set_lod(x->lod());
     const T* d_out_data = d_out->data<T>();

From 00ad7512cf21b35df7658011a2d5b680cd3d1f19 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Fri, 20 Oct 2017 15:23:48 +0800
Subject: [PATCH 07/12] Use stream while memory::Copy in GPU mode

---
 paddle/operators/seq_expand_op.cc |  2 +-
 paddle/operators/seq_expand_op.h  | 38 ++++++++++++++++++++++++-------
 2 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/paddle/operators/seq_expand_op.cc b/paddle/operators/seq_expand_op.cc
index b9633721e296c..7add3d60f6f42 100644
--- a/paddle/operators/seq_expand_op.cc
+++ b/paddle/operators/seq_expand_op.cc
@@ -40,7 +40,7 @@ class SeqExpandOp : public framework::OperatorWithKernel {
       out_dim[0] = out_dim[0] * repeat;
     }
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of PadOp should not be null.");
+                   "Output(Out) of SeqExpandOp should not be null.");
     ctx->SetOutputDim("Out", out_dim);
   }
 };
diff --git a/paddle/operators/seq_expand_op.h b/paddle/operators/seq_expand_op.h
index e990f125127e5..d1dcc97920735 100644
--- a/paddle/operators/seq_expand_op.h
+++ b/paddle/operators/seq_expand_op.h
@@ -75,15 +75,37 @@ class SeqExpandKernel : public framework::OpKernel<T> {
     T* out_data = out->mutable_data<T>(context.GetPlace());
 
     // copy data
-    Place place = boost::get<Place>(context.GetPlace());
+    auto place = context.GetPlace();
     size_t count = 0;
-    for (size_t i = 0; i < scales.size(); ++i) {
-      count = element_len * (x_lod[0][i + 1] - x_lod[0][i]);
-      for (size_t j = 0; j < scales[i]; ++j) {
-        memory::Copy(place, out_data, place, x_data, sizeof(T) * count);
-        out_data += count;
+    if (platform::is_cpu_place(place)) {
+      auto& cpu_place = boost::get<platform::CPUPlace>(place);
+      for (size_t i = 0; i < scales.size(); ++i) {
+        count = element_len * (x_lod[0][i + 1] - x_lod[0][i]);
+        for (size_t j = 0; j < scales[i]; ++j) {
+          memory::Copy(cpu_place, out_data, cpu_place, x_data,
+                       sizeof(T) * count);
+          out_data += count;
+        }
+        x_data += count;
       }
-      x_data += count;
+    } else {
+#ifdef PADDLE_WITH_CUDA
+      auto& gpu_place = boost::get<platform::GPUPlace>(place);
+      auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
+                        context.device_context())
+                        .stream();
+      for (size_t i = 0; i < scales.size(); ++i) {
+        count = element_len * (x_lod[0][i + 1] - x_lod[0][i]);
+        for (size_t j = 0; j < scales[i]; ++j) {
+          memory::Copy(gpu_place, out_data, gpu_place, x_data,
+                       sizeof(T) * count, stream);
+          out_data += count;
+        }
+        x_data += count;
+      }
+#else
+      PADDLE_THROW("Paddle is not compiled with GPU");
+#endif
     }
 
     out->set_lod(out_lod);
@@ -113,7 +135,7 @@ class SeqExpandGradKernel : public framework::OpKernel<T> {
       Eigen::TensorMap<Eigen::Tensor<T, 1>> d_x_t(
           d_x_data, static_cast<int>((ele_count * element_len) / repeat));
       auto place = context.GetEigenDevice<Place>();
-      d_x_t.device(place) = d_out_t.sum(Eigen::array<int, 1>({0}));
+      d_x_t.device(place) = d_out_t.sum(Eigen::array<int, 1>({{0}}));
       d_out_data += (ele_count * element_len);
       d_x_data += ((ele_count * element_len) / repeat);
     }
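Note that the CPU and GPU branches above issue exactly the same sequence of copies; only the copy primitive differs (a synchronous memcpy versus a stream-ordered device copy behind memory::Copy). The shared copy plan can be sketched as offset/count triples (plain Python, offsets in elements, assumed row-major layout):

    def copy_plan(lod0, scales, element_len):
        # One (src, dst, count) triple per issued copy.
        plan, dst = [], 0
        for i, s in enumerate(scales):
            count = element_len * (lod0[i + 1] - lod0[i])
            src = element_len * lod0[i]
            for _ in range(s):
                plan.append((src, dst, count))
                dst += count
        return plan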
From d697b6a3497dc7d72f29f0696f23d2d38e349581 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Mon, 23 Oct 2017 14:17:15 +0800
Subject: [PATCH 08/12] Modified code using LoDTensor

---
 paddle/framework/lod_tensor.cc              | 14 ++----
 paddle/framework/lod_tensor.h               |  2 +-
 paddle/operators/seq_expand_op.cc           | 10 ++---
 paddle/operators/seq_expand_op.h            | 45 ++++++++++++-------
 python/paddle/v2/framework/tests/op_test.py |  2 +
 .../v2/framework/tests/test_seq_expand.py   | 38 ++++++++++------
 6 files changed, 65 insertions(+), 46 deletions(-)

diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc
index 49d9e56689246..6f1e1b870bcd2 100644
--- a/paddle/framework/lod_tensor.cc
+++ b/paddle/framework/lod_tensor.cc
@@ -103,25 +103,19 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin,
   lod_ = new_lod;
 }
 
-Vector<size_t> expand_lod(Vector<size_t> level, Vector<size_t> starts,
+Vector<size_t> expand_lod(Vector<size_t> level, Vector<size_t> indexes,
                           Vector<size_t> scales, bool repeat) {
   Vector<size_t> result;
   result.push_back(level[0]);
-  size_t p = 0, start = 0, end = 0;
+  size_t start = 0, end = 0;
   if (!repeat) {
     for (size_t i = 0; i < scales.size(); ++i) {
       result.push_back(result.back() + scales[i] * (level[i + 1] - level[i]));
     }
   } else {
    for (size_t i = 0; i < scales.size(); ++i) {
-      while (starts[i] != level[p] && p < level.size()) {
-        ++p;
-      }
-      start = p;
-      while (starts[i + 1] != level[p] && p < level.size()) {
-        ++p;
-      }
-      end = p + 1;
+      start = indexes[i];
+      end = indexes[i + 1];
       for (size_t j = 0; j < scales[i]; ++j) {
         for (size_t index = start; index < end - 1; ++index) {
           result.push_back(result.back() + level[index + 1] - level[index]);
diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index af5e9f8abc1d0..4d1ec29f6001c 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -123,7 +123,7 @@ class LoDTensor : public Tensor {
   LoD lod_;
 };
 
-Vector<size_t> expand_lod(Vector<size_t> level, Vector<size_t> starts,
+Vector<size_t> expand_lod(Vector<size_t> level, Vector<size_t> indexes,
                           Vector<size_t> scales, bool repeat);
 
 }  // namespace framework
diff --git a/paddle/operators/seq_expand_op.cc b/paddle/operators/seq_expand_op.cc
index 7add3d60f6f42..d02a94d1645ea 100644
--- a/paddle/operators/seq_expand_op.cc
+++ b/paddle/operators/seq_expand_op.cc
@@ -77,15 +77,15 @@ by lod of input(Y) or 'repeat' attribute.
 Case 1:
 
 Given a 2-level LoDTensor X:
-    X.data = [1, 2 , 3, 4]
+    X.data = [a, b , c, d]
     X.lod = [[0, 3, 4], [0, 1, 3, 4]]
 and
     repeat = 2
 then we get 3-level LoDTensor
-    Out.data = [1, 2, 3, 1, 2, 3, 4, 4]
-    Out.lod = [[0, 6, 8],
-               [0, 3, 6, 7, 8],
-               [0, 1, 3, 4, 6, 7, 8]]
+    Out.lod = [[0, 6, 8],
+               [0, 3, 6, 7, 8],
+               [0, 1, 3, 4, 6, 7, 8]]
+    Out.data = [a, b, c, a, b, c, d, d]
 
 Case 2:
 
diff --git a/paddle/operators/seq_expand_op.h b/paddle/operators/seq_expand_op.h
index d1dcc97920735..e31f60db495de 100644
--- a/paddle/operators/seq_expand_op.h
+++ b/paddle/operators/seq_expand_op.h
@@ -33,15 +33,12 @@ class SeqExpandKernel : public framework::OpKernel<T> {
     auto x_dims = x->dims();
     auto x_lod = x->lod();
 
-    if (x_lod.size() == 0) {
-      framework::Vector<size_t> level;
-      for (int i = 0; i < x->dims()[0] + 1; ++i) {
-        level.push_back(i);
-      }
-      x_lod.push_back(level);
-    } else {
-      x_lod.insert(x_lod.begin(), x_lod[0]);
+    framework::Vector<size_t> level;
+    size_t num = (x_lod.size() == 0) ? (x->dims()[0] + 1) : x_lod[0].size();
+    for (int i = 0; i < num; ++i) {
+      level.push_back(i);
     }
+    x_lod.push_back(level);
 
     size_t repeat = static_cast<size_t>(context.Attr<int>("repeat"));
     framework::Vector<size_t> scales;
@@ -56,19 +53,27 @@ class SeqExpandKernel : public framework::OpKernel<T> {
     } else {
       auto* y = context.Input<LoDTensor>("Y");
       auto y_lod = y->lod();
-      for (int i = 0; i < y_lod[0].size() - 1; ++i) {
-        scales.push_back((y_lod[0][i + 1] - y_lod[0][i]) /
-                         (x_lod[0][i + 1] - x_lod[0][i]));
+      auto y_abs_lod = y_lod.ToAbsOffset();
+      auto x_abs_lod = x_lod.ToAbsOffset();
+      for (int i = 0; i < y_abs_lod[0].size() - 1; ++i) {
+        scales.push_back((y_abs_lod[0][i + 1] - y_abs_lod[0][i]) /
+                         (x_abs_lod[0][i + 1] - x_abs_lod[0][i]));
       }
       out->Resize(y->dims());
     }
 
+    framework::Vector<size_t> indexes;
+    for (int size_t i = 0; i < x_lod[0]; ++i) {
+      indexes[i] = x_lod[0];
+    }
     framework::LoD out_lod;
-    auto level0 = framework::expand_lod(x_lod[0], x_lod[0], scales, false);
+    auto level0 = framework::expand_lod(indexes, x_lod[0], scales, false);
     out_lod.push_back(level0);
     for (int i = 1; i < x_lod.size(); ++i) {
-      out_lod.push_back(
-          framework::expand_lod(x_lod[i], x_lod[0], scales, true));
+      for (int j = 0; j < indexes.size(); ++j) {
+        indexes[j] = x_lod[i - 1][indexes[j]];
+      }
+      out_lod.push_back(framework::expand_lod(x_lod[i], indexes, scales, true));
     }
 
     size_t element_len = framework::product(x_dims) / x_dims[0];
@@ -80,7 +85,7 @@ class SeqExpandKernel : public framework::OpKernel<T> {
     if (platform::is_cpu_place(place)) {
       auto& cpu_place = boost::get<platform::CPUPlace>(place);
       for (size_t i = 0; i < scales.size(); ++i) {
-        count = element_len * (x_lod[0][i + 1] - x_lod[0][i]);
+        count = element_len * (x_abs_lod[0][i + 1] - x_abs_lod[0][i]);
         for (size_t j = 0; j < scales[i]; ++j) {
           memory::Copy(cpu_place, out_data, cpu_place, x_data,
                        sizeof(T) * count);
@@ -95,7 +100,7 @@ class SeqExpandKernel : public framework::OpKernel<T> {
                         context.device_context())
                         .stream();
       for (size_t i = 0; i < scales.size(); ++i) {
-        count = element_len * (x_lod[0][i + 1] - x_lod[0][i]);
+        count = element_len * (x_abs_lod[0][i + 1] - x_abs_lod[0][i]);
         for (size_t j = 0; j < scales[i]; ++j) {
           memory::Copy(gpu_place, out_data, gpu_place, x_data,
                        sizeof(T) * count, stream);
@@ -109,6 +114,11 @@ class SeqExpandKernel : public framework::OpKernel<T> {
     }
 
     out->set_lod(out_lod);
+    for (size_t i = 0; i < lod.size; i++) {
+      for (size_t j = 0; j < lod[i].size(); j++) {
+        LOG(INFO) << "lod[" << i << "][" << j "] = " << lod[i][j];
+      }
+    }
   }
 };
 
@@ -121,13 +131,14 @@ class SeqExpandGradKernel : public framework::OpKernel<T> {
     auto* out = context.Input<LoDTensor>("Out");
     auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
     auto out_lod = out->lod();
+    auto out_abs_lod = out_lod.ToAbsOffset();
     d_x->set_lod(x->lod());
     const T* d_out_data = d_out->data<T>();
     auto d_out_dims = d_out->dims();
     T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
     size_t element_len = framework::product(d_out_dims) / d_out_dims[0];
     for (size_t i = 0; i < out->NumElements(); ++i) {
-      size_t ele_count = out_lod[0][i + 1] - out_lod[0][i];
+      size_t ele_count = out_abs_lod[0][i + 1] - out_abs_lod[0][i];
       size_t repeat = out->NumElements(0, i);
       Eigen::TensorMap<Eigen::Tensor<const T, 2>> d_out_t(
           d_out_data, static_cast<int>(repeat),
diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py
index a88e9f0bb8213..f3108d5108af9 100644
--- a/python/paddle/v2/framework/tests/op_test.py
+++ b/python/paddle/v2/framework/tests/op_test.py
@@ -246,6 +246,8 @@ def check_output_with_place(self, place, atol):
             else:
                 actual = np.array(self.scope.find_var(out_name).get_tensor())
                 expect = self.outputs[out_name]
+                print "actual= %s" % actual
+                print "expect = %s" % expect
                 self.assertTrue(
                     np.allclose(
                         actual, expect, atol=atol),
diff --git a/python/paddle/v2/framework/tests/test_seq_expand.py b/python/paddle/v2/framework/tests/test_seq_expand.py
index 87e39d72bf5b4..2910af6b78a80 100644
--- a/python/paddle/v2/framework/tests/test_seq_expand.py
+++ b/python/paddle/v2/framework/tests/test_seq_expand.py
@@ -27,7 +27,15 @@ def repeat_array(array, starts, times):
     return newlist
 
 
+def toAbsOffset(lod):
+    for i in range(len(lod) - 2, -1, -1):
+        for j in range(len(lod[i])):
+            lod[i][j] = lod[i + 1][lod[i][j]]
+    return lod
+
+
 class TestSeqExpand(OpTest):
+    #class TestSeqExpand():
     def set_data(self):
         x_data = np.random.uniform(0.1, 1, [4, 1]).astype('float32')
         self.inputs = {'X': x_data}
@@ -35,23 +43,26 @@ def set_data(self):
 
     def compute(self):
         x = self.inputs['X']
+        print "x= %s" % x
         x_data, x_lod = x if type(x) == tuple else (x, None)
-        if not x_lod:
-            x_lod = [[i for i in range(1 + x_data.shape[0])]]
-        else:
-            x_lod = [x_lod[0]] + x_lod
+        n = 1 + x_data.shape[0] if not x_lod else len(x_lod[0])
+        x_lod = [[i for i in range(n)]] + x_lod
+        x_abs_lod = toAbsOffset(x_lod)
         if self.repeat:
+            print "repeat= %s" % self.repeat
             self.attrs = {'repeat': self.repeat}
             repeats = (len(x_lod[0]) - 1) * [self.repeat]
         else:
             y_data, y_lod = self.inputs['Y']
-            repeats = [((y_lod[0][i + 1] - y_lod[0][i]) /
-                        (x_lod[0][i + 1] - x_lod[0][i]))
-                       for i in range(len(y_lod[0]) - 1)]
-        out_lod = [repeat(x_lod[0], x_lod[0], repeats, True)] + [
-            repeat(lod, x_lod[0], repeats, False) for lod in x_lod[1:]
-        ]
-        out = repeat_array(x_data.tolist(), x_lod[0], repeats)
+            print "y_lod: %s" % y_lod
+            y_abs_lod = toAbsOffset(y_lod)
+            repeats = [((y_abs_lod[0][i + 1] - y_abs_lod[0][i]) /
+                        (x_abs_lod[0][i + 1] - x_abs_lod[0][i]))
+                       for i in range(len(y_abs_lod[0]) - 1)]
+        #out_lod = [repeat(x_lod[0], x_lod[0], repeats, True)] + [
+        #    repeat(lod, x_lod[0], repeats, False) for lod in x_lod[1:]
+        #]
+        out = repeat_array(x_data.tolist(), x_abs_lod[0], repeats)
         self.outputs = {'Out': out}
 
     def setUp(self):
@@ -69,7 +80,7 @@ def test_check_grad(self):
 class TestSeqExpandCase1(TestSeqExpand):
     def set_data(self):
         x_data = np.random.uniform(0.1, 1, [7, 1]).astype('float32')
-        x_lod = [[0, 5, 7], [0, 2, 5, 7]]
+        x_lod = [[0, 2, 3], [0, 2, 5, 7]]
         self.inputs = {'X': (x_data, x_lod)}
         self.repeat = 2
 
@@ -94,7 +105,7 @@ def set_data(self):
         x_data = np.random.uniform(0.1, 1, [5, 1]).astype('float32')
         x_lod = [[0, 2, 5]]
         y_data = np.random.uniform(0.1, 1, [13, 1]).astype('float32')
-        y_lod = [[0, 4, 13], [0, 2, 4, 7, 10, 13]]
+        y_lod = [[0, 2, 5], [0, 2, 4, 7, 10, 13]]
         self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
         self.repeat = None
 
 
 if __name__ == '__main__':
     unittest.main()
+# TestSeqExpandCase4().setUp()
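The ToAbsOffset conversion this commit switches to is the key to mixing LoD levels: an upper level stores indexes into the next level, and only the absolute form points directly at data rows. A mirror of the test helper above, with a worked example (same assumed semantics):

    def to_abs_offset(lod):
        # Resolve each level's entries through the level below it,
        # bottom-up, so every level ends as absolute row offsets.
        lod = [level[:] for level in lod]
        for i in range(len(lod) - 2, -1, -1):
            for j in range(len(lod[i])):
                lod[i][j] = lod[i + 1][lod[i][j]]
        return lod

    # to_abs_offset([[0, 2, 3], [0, 2, 5, 7]]) == [[0, 5, 7], [0, 2, 5, 7]]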
diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc
index a7b2b5b1ec8c1..7c0ea0df78298 100644
--- a/paddle/framework/lod_tensor.cc
+++ b/paddle/framework/lod_tensor.cc
@@ -112,28 +112,5 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin,
   lod_ = new_lod;
 }
 
-Vector<size_t> expand_lod(Vector<size_t> level, Vector<size_t> indexes,
-                          Vector<size_t> scales, bool repeat) {
-  Vector<size_t> result;
-  result.push_back(level[0]);
-  size_t start = 0, end = 0;
-  if (!repeat) {
-    for (size_t i = 0; i < scales.size(); ++i) {
-      result.push_back(result.back() + scales[i] * (level[i + 1] - level[i]));
-    }
-  } else {
-    for (size_t i = 0; i < scales.size(); ++i) {
-      start = indexes[i];
-      end = indexes[i + 1];
-      for (size_t j = 0; j < scales[i]; ++j) {
-        for (size_t index = start; index < end - 1; ++index) {
-          result.push_back(result.back() + level[index + 1] - level[index]);
-        }
-      }
-    }
-  }
-  return result;
-}
-
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index ec0b34878b01e..3895d3cb83bbe 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -136,8 +136,5 @@ class LoDTensor : public Tensor {
   LoD lod_;
 };
 
-Vector<size_t> expand_lod(Vector<size_t> level, Vector<size_t> indexes,
-                          Vector<size_t> scales, bool repeat);
-
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/operators/seq_expand_op.cc b/paddle/operators/seq_expand_op.cc
index d02a94d1645ea..660e86e9ccca0 100644
--- a/paddle/operators/seq_expand_op.cc
+++ b/paddle/operators/seq_expand_op.cc
@@ -27,20 +27,14 @@ class SeqExpandOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of SeqExpandOp should not be null.");
-    int repeat = ctx->Attrs().Get<int>("repeat");
-    framework::DDim out_dim;
-    if (repeat == 0) {
-      PADDLE_ENFORCE(
-          ctx->HasInput("Y"),
-          "Input(Y) of SeqExpandOp should not be null while repeat == 0.");
-      out_dim = ctx->GetInputDim("Y");
-      ctx->ShareLoD("Y", "Out");
-    } else {
-      out_dim = ctx->GetInputDim("X");
-      out_dim[0] = out_dim[0] * repeat;
-    }
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of SeqExpandOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasInput("Y"),
+        "Input(Y) of SeqExpandOp should not be null while repeat == 0.");
+    framework::DDim out_dim;
+    out_dim = ctx->GetInputDim("Y");
+    ctx->ShareLoD("Y", "Out");
     ctx->SetOutputDim("Out", out_dim);
   }
 };
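A rough Python rendering of the contract after this rewrite (illustrative only: InferShape itself just mirrors Y's shape and shares its LoD, while the row-count check lives in the kernel; the helper name infer_shape is not part of the patch):

    def infer_shape(x_dims, y_dims, y_lod):
        # One Y sequence (in the last LoD level) per row of X; Out takes
        # Y's dims, and ShareLoD gives Out Y's LoD.
        assert len(y_lod[-1]) - 1 == x_dims[0]
        return y_dims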
@@ -50,68 +44,63 @@ class SeqExpandOpMaker : public framework::OpProtoAndCheckerMaker {
   SeqExpandOpMaker(framework::OpProto* proto,
                    framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput(
-        "X",
-        "The input('X') of seq_expand op. It can be LoDTensor or base Tensor.");
-    AddInput(
-        "Y",
-        "The reference input('Y') of seq_expand op."
-        "It must be a LoDTensor with k-level(k>0)."
-        "This reference input is essential if 'repeat' attribute is not "
-        "configured."
-        "Input(X) will be expanded by LoD of input(Y) while repeat == 0.");
+    AddInput("X",
+             "(Tensor or LoDTensor) The input('X') of this operator can be a "
+             "LoDTensor or a base Tensor.");
+    AddInput("Y",
+             "(LoDTensor) The reference input('Y') of seq_expand op. "
+             "It must be a LoDTensor with k levels (k > 0). "
+             "Input(X) will be expanded according to the LoD of input(Y). "
+             "The element count of the last LoD level in input('Y') "
+             "must equal dims[0] of input('X').");
     AddOutput("Out",
               "The output of seq_expand op."
-              "The output is a (k+1)-level LoDTensor"
-              "while input(X) being k-level LoDTensor."
-              "(Given base tensor is 0-level LoDTensor.)");
-    AddAttr<int>("repeat",
-                 "(type:int; default value: 0)"
-                 "Repeatting times of each element while expanding input(X)."
-                 "It works while input(Y) is not configured.")
-        .SetDefault(0);
+              "The LoD of the output will be the same as input(Y)'s LoD.");
     AddComment(R"DOC(
-Expand k-level LoDTensor to (k+1)-level LoDTensor
-by lod of input(Y) or 'repeat' attribute.
+Expand input(X) according to the LoD of input(Y).
 
 Case 1:
 
-Given a 2-level LoDTensor X:
-    X.data = [a, b , c, d]
-    X.lod = [[0, 3, 4], [0, 1, 3, 4]]
-and
-    repeat = 2
-then we get 3-level LoDTensor
-    Out.lod = [[0, 6, 8],
-               [0, 3, 6, 7, 8],
-               [0, 1, 3, 4, 6, 7, 8]]
-    Out.data = [a, b, c, a, b, c, d, d]
+Given a 2-level LoDTensor input(X)
+    X.lod = [[0, 2, 3],
+             [0, 1, 3, 4]]
+    X.data = [a, b, c, d]
+    X.dims = [4, 1]
+and input(Y)
+    Y.lod = [[0, 2, 4],
+             [0, 3, 6, 7, 8]]
+then we get a 2-level LoDTensor
+    Out.lod = [[0, 2, 4],
+               [0, 3, 6, 7, 8]]
+    Out.data = [a, a, a, b, b, b, c, d]
+    Out.dims = [8, 1]
 
 Case 2:
 
-Given 2-level a LoDTensor X
-    X.data = [1, 2, 3, 4]
-    X.lod = [[0, 3, 4], [0, 1, 3, 4]]
-and
-    Y.lod = [[0, 6, 8],
-             [0, 3, 6, 7, 8],
-             [0,1,3,4,6,7,8]]
-then we get 3-level LoDTensor
-    Out.data = [1, 2, 3, 1, 2, 3, 4, 4]
-    Out.lod = [[0, 6, 8],
-               [0, 3, 6, 7, 8],
-               [0, 1, 3, 4, 6, 7, 8]]
+Given a 0-level LoDTensor input(X)
+    X.data = [a, b, c]
+    X.lod = NULL
+    X.dims = [3, 1]
+and input(Y)
+    Y.lod = [[0, 2, 3, 6]]
+then we get a 1-level LoDTensor
+    Out.lod = [[0, 2, 3, 6]]
+    Out.data = [a, a, b, c, c, c]
+    Out.dims = [6, 1]
 
 Case 3:
 
-Given a 0-level LoDTensor X
-    X.data = [1, 2, 3, 4]
+Given a 0-level LoDTensor input(X)
+    X.data = [[a, b], [c, d], [e, f]]
     X.lod = NULL
-and
-    repeat = 2
+    X.dims = [3, 2]
+and input(Y)
+    Y.lod = [[0, 2, 3, 6]]
 then we get a 1-level LoDTensor
-    Out.data = [1, 1, 2, 2, 3, 3, 4, 4]
-    Out.lod = [[0, 2, 4, 6, 8]]
+    Out.lod = [[0, 2, 3, 6]]
+    Out.data = [[a, b], [a, b], [c, d], [e, f], [e, f], [e, f]]
+    Out.dims = [6, 2]
+
 
 )DOC");
   }
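Case 1 above can be checked numerically: the repeat count for each row of X is the length of the corresponding segment in Y's last LoD level. With numeric stand-ins a=1 ... d=4 (illustrative values, mirroring the DOC):

    import numpy as np

    x = np.array([[1], [2], [3], [4]], dtype='float32')
    y_last_level = [0, 3, 6, 7, 8]
    repeats = np.diff(y_last_level)            # [3, 3, 1, 1]
    out = np.repeat(x, repeats, axis=0)
    assert out.ravel().tolist() == [1, 1, 1, 2, 2, 2, 3, 4]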
diff --git a/paddle/operators/seq_expand_op.h b/paddle/operators/seq_expand_op.h
index e31f60db495de..ad3f42116d598 100644
--- a/paddle/operators/seq_expand_op.h
+++ b/paddle/operators/seq_expand_op.h
@@ -31,93 +31,28 @@ class SeqExpandKernel : public framework::OpKernel<T> {
     auto* out = context.Output<LoDTensor>("Out");
     const T* x_data = x->data<T>();
     auto x_dims = x->dims();
-    auto x_lod = x->lod();
-
-    framework::Vector<size_t> level;
-    size_t num = (x_lod.size() == 0) ? (x->dims()[0] + 1) : x_lod[0].size();
-    for (int i = 0; i < num; ++i) {
-      level.push_back(i);
-    }
-    x_lod.push_back(level);
-
-    size_t repeat = static_cast<size_t>(context.Attr<int>("repeat"));
-    framework::Vector<size_t> scales;
-    if (repeat != 0) {
-      for (int i = 0; i < x_lod[0].size() - 1; ++i) {
-        scales.push_back(repeat);
-      }
-      std::vector<int64_t> dims = framework::vectorize(x->dims());
-      dims[0] = dims[0] * repeat;
-      auto out_dims = framework::make_ddim(dims);
-      out->Resize(out_dims);
-    } else {
-      auto* y = context.Input<LoDTensor>("Y");
-      auto y_lod = y->lod();
-      auto y_abs_lod = y_lod.ToAbsOffset();
-      auto x_abs_lod = x_lod.ToAbsOffset();
-      for (int i = 0; i < y_abs_lod[0].size() - 1; ++i) {
-        scales.push_back((y_abs_lod[0][i + 1] - y_abs_lod[0][i]) /
-                         (x_abs_lod[0][i + 1] - x_abs_lod[0][i]));
-      }
-      out->Resize(y->dims());
-    }
-
-    framework::Vector<size_t> indexes;
-    for (size_t i = 0; i < x_lod[0].size(); ++i) {
-      indexes.push_back(x_lod[0][i]);
-    }
-    framework::LoD out_lod;
-    auto level0 = framework::expand_lod(indexes, x_lod[0], scales, false);
-    out_lod.push_back(level0);
-    for (int i = 1; i < x_lod.size(); ++i) {
-      for (int j = 0; j < indexes.size(); ++j) {
-        indexes[j] = x_lod[i - 1][indexes[j]];
-      }
-      out_lod.push_back(framework::expand_lod(x_lod[i], indexes, scales, true));
-    }
-
+    auto* y = context.Input<LoDTensor>("Y");
+    PADDLE_ENFORCE_EQ(x_dims[0], y->lod().back().size() - 1,
+                      "The size of last lod level in Input(Y) "
+                      "must be equal to dims[0] of Input(X).");
+    out->set_lod(y->lod());
+    out->Resize(y->dims());
+    auto place = context.GetEigenDevice<Place>();
     size_t element_len = framework::product(x_dims) / x_dims[0];
     T* out_data = out->mutable_data<T>(context.GetPlace());
-
-    // copy data
-    auto place = context.GetPlace();
-    size_t count = 0;
-    if (platform::is_cpu_place(place)) {
-      auto& cpu_place = boost::get<platform::CPUPlace>(place);
-      for (size_t i = 0; i < scales.size(); ++i) {
-        count = element_len * (x_abs_lod[0][i + 1] - x_abs_lod[0][i]);
-        for (size_t j = 0; j < scales[i]; ++j) {
-          memory::Copy(cpu_place, out_data, cpu_place, x_data,
-                       sizeof(T) * count);
-          out_data += count;
-        }
-        x_data += count;
-      }
-    } else {
-#ifdef PADDLE_WITH_CUDA
-      auto& gpu_place = boost::get<platform::GPUPlace>(place);
-      auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
-                        context.device_context())
-                        .stream();
-      for (size_t i = 0; i < scales.size(); ++i) {
-        count = element_len * (x_abs_lod[0][i + 1] - x_abs_lod[0][i]);
-        for (size_t j = 0; j < scales[i]; ++j) {
-          memory::Copy(gpu_place, out_data, gpu_place, x_data,
-                       sizeof(T) * count, stream);
-          out_data += count;
-        }
-        x_data += count;
-      }
-#else
-      PADDLE_THROW("Paddle is not compiled with GPU");
-#endif
-    }
-
-    out->set_lod(out_lod);
-    for (size_t i = 0; i < out_lod.size(); i++) {
-      for (size_t j = 0; j < out_lod[i].size(); j++) {
-        LOG(INFO) << "lod[" << i << "][" << j << "] = " << out_lod[i][j];
-      }
+    auto out_starts = out->lod().back();
+
+    for (size_t i = 0; i < out_starts.size() - 1; i++) {
+      int scale = out_starts[i + 1] - out_starts[i];
+      Eigen::TensorMap<
+          Eigen::Tensor<const T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
+          x_t(x_data, 1, element_len);
+      Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
+          out_t(out_data, scale, element_len);
+      Eigen::array<int, 2> cast({{scale, 1}});
+      out_t.device(place) = x_t.broadcast(cast);
+      x_data += element_len;
+      out_data += element_len * scale;
     }
   }
 };
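The Eigen loop above expands row i by broadcasting a (1, element_len) view to (scale, element_len). The same idea in numpy, as a rough sketch (expand_rows is an illustrative name, not part of the patch):

    import numpy as np

    def expand_rows(x, out_starts):
        # out_starts are absolute offsets of the output segments; each
        # row of x is viewed as 'scale' copies, like x_t.broadcast(cast).
        pieces = []
        for i in range(len(out_starts) - 1):
            scale = out_starts[i + 1] - out_starts[i]
            pieces.append(np.broadcast_to(x[i:i + 1], (scale, x.shape[1])))
        return np.concatenate(pieces, axis=0)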
@@ -130,25 +65,24 @@ class SeqExpandGradKernel : public framework::OpKernel<T> {
     auto* x = context.Input<LoDTensor>("X");
     auto* out = context.Input<LoDTensor>("Out");
     auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
-    auto out_lod = out->lod();
-    auto out_abs_lod = out_lod.ToAbsOffset();
+    auto out_last_level = out->lod().back();
     d_x->set_lod(x->lod());
     const T* d_out_data = d_out->data<T>();
     auto d_out_dims = d_out->dims();
     T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
     size_t element_len = framework::product(d_out_dims) / d_out_dims[0];
-    for (size_t i = 0; i < out->NumElements(); ++i) {
-      size_t ele_count = out_abs_lod[0][i + 1] - out_abs_lod[0][i];
-      size_t repeat = out->NumElements(0, i);
-      Eigen::TensorMap<Eigen::Tensor<T, 2>> d_out_t(
-          d_out_data, static_cast<int>(repeat),
-          static_cast<int>((ele_count * element_len) / repeat));
-      Eigen::TensorMap<Eigen::Tensor<T, 1>> d_x_t(
-          d_x_data, static_cast<int>((ele_count * element_len) / repeat));
+
+    for (size_t i = 0; i < out_last_level.size() - 1; ++i) {
+      size_t repeat = out_last_level[i + 1] - out_last_level[i];
+      Eigen::TensorMap<
+          Eigen::Tensor<const T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
+          d_out_t(d_out_data, static_cast<int>(repeat), element_len);
+      Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>>
+          d_x_t(d_x_data, static_cast<int>(element_len));
       auto place = context.GetEigenDevice<Place>();
       d_x_t.device(place) = d_out_t.sum(Eigen::array<int, 1>({{0}}));
-      d_out_data += (ele_count * element_len);
-      d_x_data += ((ele_count * element_len) / repeat);
+      d_out_data += (repeat * element_len);
+      d_x_data += element_len;
     }
   }
 };
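The backward pass above is the transpose of the expansion: each original row's gradient is the column-wise sum of the d_out rows it produced. A numpy sketch under the same assumptions (seq_expand_grad is an illustrative name, not part of the patch):

    import numpy as np

    def seq_expand_grad(d_out, out_starts):
        d_x = []
        for i in range(len(out_starts) - 1):
            seg = d_out[out_starts[i]:out_starts[i + 1]]  # 'repeat' rows
            d_x.append(seg.sum(axis=0))                   # like d_out_t.sum(dim 0)
        return np.stack(d_x)

An empty segment (repeat == 0) sums to a zero row, which is exactly the gradient a dropped input row should receive.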
diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py
index f3108d5108af9..a88e9f0bb8213 100644
--- a/python/paddle/v2/framework/tests/op_test.py
+++ b/python/paddle/v2/framework/tests/op_test.py
@@ -246,8 +246,6 @@ def check_output_with_place(self, place, atol):
         else:
             actual = np.array(self.scope.find_var(out_name).get_tensor())
             expect = self.outputs[out_name]
-            print "actual = %s" % actual
-            print "expect = %s" % expect
             self.assertTrue(
                 np.allclose(
                     actual, expect, atol=atol),
diff --git a/python/paddle/v2/framework/tests/test_seq_expand.py b/python/paddle/v2/framework/tests/test_seq_expand.py
index 2910af6b78a80..901102802b129 100644
--- a/python/paddle/v2/framework/tests/test_seq_expand.py
+++ b/python/paddle/v2/framework/tests/test_seq_expand.py
@@ -3,66 +3,21 @@
 from op_test import OpTest
 
 
-def repeat(list, starts, times, is_first):
-    newlist = [list[0]]
-    if is_first:
-        for i, time in enumerate(times):
-            size = list[i + 1] - list[i]
-            newlist.append(newlist[-1] + size * time)
-    else:
-        for i, time in enumerate(times):
-            start = list.index(starts[i])
-            end = list.index(starts[i + 1]) + 1
-            for t in range(time):
-                for index in range(start, end - 1):
-                    newlist.append(newlist[-1] + list[index + 1] - list[index])
-    return newlist
-
-
-def repeat_array(array, starts, times):
-    newlist = []
-    for i, time in enumerate(times):
-        for t in range(time):
-            newlist.extend(array[starts[i]:starts[i + 1]])
-    return newlist
-
-
-def toAbsOffset(lod):
-    for i in range(len(lod) - 2, -1, -1):
-        for j in range(len(lod[i])):
-            lod[i][j] = lod[i + 1][lod[i][j]]
-    return lod
-
-
 class TestSeqExpand(OpTest):
-#class TestSeqExpand():
     def set_data(self):
-        x_data = np.random.uniform(0.1, 1, [4, 1]).astype('float32')
-        self.inputs = {'X': x_data}
-        self.repeat = 2
+        x_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32')
+        y_data = np.random.uniform(0.1, 1, [8, 1]).astype('float32')
+        y_lod = [[0, 1, 4, 8]]
+        self.inputs = {'X': x_data, 'Y': (y_data, y_lod)}
 
     def compute(self):
         x = self.inputs['X']
-        print "x = %s" % x
         x_data, x_lod = x if type(x) == tuple else (x, None)
         n = 1 + x_data.shape[0] if not x_lod else len(x_lod[0])
-        x_lod = [[i for i in range(n)]] + (x_lod or [])
-        x_abs_lod = toAbsOffset(x_lod)
-        if self.repeat:
-            print "repeat = %s" % self.repeat
-            self.attrs = {'repeat': self.repeat}
-            repeats = (len(x_lod[0]) - 1) * [self.repeat]
-        else:
-            y_data, y_lod = self.inputs['Y']
-            print "y_lod: %s" % y_lod
-            y_abs_lod = toAbsOffset(y_lod)
-            repeats = [((y_abs_lod[0][i + 1] - y_abs_lod[0][i]) /
-                        (x_abs_lod[0][i + 1] - x_abs_lod[0][i]))
-                       for i in range(len(y_abs_lod[0]) - 1)]
-        #out_lod = [repeat(x_lod[0], x_lod[0], repeats, True)] + [
-        #    repeat(lod, x_lod[0], repeats, False) for lod in x_lod[1:]
-        #]
-        out = repeat_array(x_data.tolist(), x_abs_lod[0], repeats)
+        y_data, y_lod = self.inputs['Y']
+        repeats = [((y_lod[-1][i + 1] - y_lod[-1][i]))
+                   for i in range(len(y_lod[-1]) - 1)]
+        out = x_data.repeat(repeats, axis=0)
         self.outputs = {'Out': out}
 
     def setUp(self):
@@ -78,39 +33,22 @@ def test_check_grad(self):
 
 
 class TestSeqExpandCase1(TestSeqExpand):
-    def set_data(self):
-        x_data = np.random.uniform(0.1, 1, [7, 1]).astype('float32')
-        x_lod = [[0, 2, 3], [0, 2, 5, 7]]
-        self.inputs = {'X': (x_data, x_lod)}
-        self.repeat = 2
-
-
-class TestSeqExpandCase2(TestSeqExpand):
-    def set_data(self):
-        x_data = np.random.uniform(0.1, 1, [4, 1]).astype('float32')
-        self.inputs = {'X': x_data}
-        self.repeat = 2
-
-
-class TestSeqExpandCase3(TestSeqExpand):
-    def set_data(self):
-        x_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32')
-        y_data = np.random.uniform(0.1, 1, [8, 1]).astype('float32')
-        y_lod = [[0, 1, 4, 8]]
-        self.inputs = {'X': x_data, 'Y': (y_data, y_lod)}
-        self.repeat = None
-
-
-class TestSeqExpandCase4(TestSeqExpand):
     def set_data(self):
         x_data = np.random.uniform(0.1, 1, [5, 1]).astype('float32')
         x_lod = [[0, 2, 5]]
         y_data = np.random.uniform(0.1, 1, [13, 1]).astype('float32')
         y_lod = [[0, 2, 5], [0, 2, 4, 7, 10, 13]]
         self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
-        self.repeat = None
+
+
+class TestSeqExpandCase2(TestSeqExpand):
+    def set_data(self):
+        x_data = np.random.uniform(0.1, 1, [1, 2, 2]).astype('float32')
+        x_lod = [[0, 1]]
+        y_data = np.random.uniform(0.1, 1, [2, 2, 2]).astype('float32')
+        y_lod = [[0, 2]]
+        self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
 
 
 if __name__ == '__main__':
     unittest.main()
-# TestSeqExpandCase4().setUp()

From fab6f30ff62a14332903660a404f6b0d5f08be1c Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Mon, 30 Oct 2017 09:51:08 +0800
Subject: [PATCH 10/12] Add empty sequence case in unit test

---
 python/paddle/v2/framework/tests/test_seq_expand.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/python/paddle/v2/framework/tests/test_seq_expand.py b/python/paddle/v2/framework/tests/test_seq_expand.py
index 901102802b129..ff17edd04bfd3 100644
--- a/python/paddle/v2/framework/tests/test_seq_expand.py
+++ b/python/paddle/v2/framework/tests/test_seq_expand.py
@@ -50,5 +50,14 @@ def set_data(self):
         self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
 
 
+class TestSeqExpandCase3(TestSeqExpand):
+    def set_data(self):
+        x_data = np.random.uniform(0.1, 1, [4, 1]).astype('float32')
+        x_lod = [[0, 1, 2, 3, 4]]
+        y_data = np.random.uniform(0.1, 1, [6, 1]).astype('float32')
+        y_lod = [[0, 2, 4, 4, 6]]
+        self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
+
+
 if __name__ == '__main__':
     unittest.main()
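The new TestSeqExpandCase3 covers an empty sequence: y_lod = [[0, 2, 4, 4, 6]] yields repeats [2, 2, 0, 2], so the third row of X contributes nothing to the output. A quick numeric check of that expectation (illustrative, not part of the test file):

    import numpy as np

    x = np.array([[1.], [2.], [3.], [4.]], dtype='float32')
    repeats = np.diff([0, 2, 4, 4, 6])        # [2, 2, 0, 2]
    out = np.repeat(x, repeats, axis=0)       # rows: 1, 1, 2, 2, 4, 4
    assert out.shape == (6, 1)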
From 8d4e2d4cb37b190c16fbc35e2528f6caa536d53f Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Mon, 30 Oct 2017 11:46:47 +0800
Subject: [PATCH 11/12] 1. Add unit test for empty sequence case 2. Fix
 comments and PADDLE_ENFORCE checks

---
 paddle/operators/seq_expand_op.cc | 32 ++++++++++++++++++++++++-------
 paddle/operators/seq_expand_op.h  | 17 ++++++++++++----
 2 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/paddle/operators/seq_expand_op.cc b/paddle/operators/seq_expand_op.cc
index 660e86e9ccca0..def5efa0e885d 100644
--- a/paddle/operators/seq_expand_op.cc
+++ b/paddle/operators/seq_expand_op.cc
@@ -25,10 +25,8 @@ class SeqExpandOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of SeqExpandOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of SeqExpandOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"));
+    PADDLE_ENFORCE(ctx->HasOutput("Out"));
     PADDLE_ENFORCE(
         ctx->HasInput("Y"),
         "Input(Y) of SeqExpandOp should not be null while repeat == 0.");
@@ -54,7 +52,7 @@ class SeqExpandOpMaker : public framework::OpProtoAndCheckerMaker {
              "The element count of the last LoD level in input('Y') "
              "must equal dims[0] of input('X').");
     AddOutput("Out",
-              "The output of seq_expand op."
+              "(LoDTensor) The output of seq_expand op."
               "The LoD of the output will be the same as input(Y)'s LoD.");
     AddComment(R"DOC(
 Expand input(X) according to the LoD of input(Y).
@@ -69,6 +67,7 @@ Given a 2-level LoDTensor input(X)
 and input(Y)
     Y.lod = [[0, 2, 4],
              [0, 3, 6, 7, 8]]
+with condition len(Y.lod[-1]) - 1 == X.dims[0]
 then we get a 2-level LoDTensor
     Out.lod = [[0, 2, 4],
                [0, 3, 6, 7, 8]]
@@ -83,6 +82,7 @@ Given a 0-level LoDTensor input(X)
     X.dims = [3, 1]
 and input(Y)
     Y.lod = [[0, 2, 3, 6]]
+with condition len(Y.lod[-1]) - 1 == X.dims[0]
 then we get a 1-level LoDTensor
     Out.lod = [[0, 2, 3, 6]]
     Out.data = [a, a, b, c, c, c]
@@ -96,11 +96,29 @@ Given a 0-level LoDTensor input(X)
     X.dims = [3, 2]
 and input(Y)
     Y.lod = [[0, 2, 3, 6]]
+with condition len(Y.lod[-1]) - 1 == X.dims[0]
 then we get a 1-level LoDTensor
     Out.lod = [[0, 2, 3, 6]]
     Out.data = [[a, b], [a, b], [c, d], [e, f], [e, f], [e, f]]
     Out.dims = [6, 2]
 
+Case 4:
+
+Given a 2-level LoDTensor input(X)
+    X.lod = [[0, 2, 3],
+             [0, 1, 3, 4]]
+    X.data = [a, b, c, d]
+    X.dims = [4, 1]
+and input(Y)
+    Y.lod = [[0, 2, 4],
+             [0, 3, 6, 6, 8]]
+with condition len(Y.lod[-1]) - 1 == X.dims[0]
+then we get a 2-level LoDTensor
+    Out.lod = [[0, 2, 4],
+               [0, 3, 6, 6, 8]]
+    Out.data = [a, a, a, b, b, b, d, d]
+    Out.dims = [8, 1]
+
 )DOC");
   }
 
@@ -112,8 +130,8 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("X"));
+    PADDLE_ENFORCE(ctx->HasInput("Out"));
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Input(Out@GRAD) should not be null");
     auto x_dims = ctx->GetInputDim("X");
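The condition spelled out in each DOC case, len(Y.lod[-1]) - 1 == X.dims[0], holds even when Y contains empty sequences, as in Case 4 above. A small illustrative check (check_expand_inputs is not part of the patch):

    def check_expand_inputs(x_rows, y_lod):
        # One Y sequence (possibly empty) per row of X.
        if len(y_lod[-1]) - 1 != x_rows:
            raise ValueError("len(Y.lod[-1]) - 1 must equal X.dims[0]")

    check_expand_inputs(4, [[0, 2, 4], [0, 3, 6, 6, 8]])  # Case 4: ok, one empty sequence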
diff --git a/paddle/operators/seq_expand_op.h b/paddle/operators/seq_expand_op.h
index ad3f42116d598..aa91e0f9296a7 100644
--- a/paddle/operators/seq_expand_op.h
+++ b/paddle/operators/seq_expand_op.h
@@ -36,7 +36,6 @@ class SeqExpandKernel : public framework::OpKernel<T> {
                       "The size of last lod level in Input(Y) "
                       "must be equal to dims[0] of Input(X).");
     out->set_lod(y->lod());
-    out->Resize(y->dims());
     auto place = context.GetEigenDevice<Place>();
     size_t element_len = framework::product(x_dims) / x_dims[0];
     T* out_data = out->mutable_data<T>(context.GetPlace());
@@ -57,6 +56,18 @@ class SeqExpandKernel : public framework::OpKernel<T> {
   }
 };
 
+/*
+ * Given Grad(Out)
+ *
+ *   Grad(Out).lod = [[0, 2],
+ *                    [0, 3, 6]]
+ *   Grad(Out).data = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
+ * Then
+ *   Grad(X).data = [(0.1 + 0.2 + 0.3), (0.4 + 0.5 + 0.6)]
+ *                = [0.6, 1.5]
+ *   Grad(X).lod = Input(X).lod
+ *
+ */
 template <typename Place, typename T>
 class SeqExpandGradKernel : public framework::OpKernel<T> {
  public:
@@ -68,10 +79,8 @@ class SeqExpandGradKernel : public framework::OpKernel<T> {
     auto out_last_level = out->lod().back();
     d_x->set_lod(x->lod());
     const T* d_out_data = d_out->data<T>();
-    auto d_out_dims = d_out->dims();
     T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
-    size_t element_len = framework::product(d_out_dims) / d_out_dims[0];
-
+    size_t element_len = d_out->numel() / d_out->dims()[0];
     for (size_t i = 0; i < out_last_level.size() - 1; ++i) {
       size_t repeat = out_last_level[i + 1] - out_last_level[i];
       Eigen::TensorMap<

From 84f471b42e7e8681c95453a01b0f7a1db0fd5125 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Mon, 30 Oct 2017 13:44:26 +0800
Subject: [PATCH 12/12] Fix comments

---
 paddle/operators/seq_expand_op.cc | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/paddle/operators/seq_expand_op.cc b/paddle/operators/seq_expand_op.cc
index def5efa0e885d..08fda9b445642 100644
--- a/paddle/operators/seq_expand_op.cc
+++ b/paddle/operators/seq_expand_op.cc
@@ -27,9 +27,7 @@ class SeqExpandOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"));
     PADDLE_ENFORCE(ctx->HasOutput("Out"));
-    PADDLE_ENFORCE(
-        ctx->HasInput("Y"),
-        "Input(Y) of SeqExpandOp should not be null while repeat == 0.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"));
     framework::DDim out_dim;
     out_dim = ctx->GetInputDim("Y");
     ctx->ShareLoD("Y", "Out");
@@ -43,14 +41,14 @@ class SeqExpandOpMaker : public framework::OpProtoAndCheckerMaker {
                    framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
-             "(Tensor or LoDTensor) The input('X') of this operator can be a "
+             "(Tensor or LoDTensor) The input(X) of this operator can be a "
              "LoDTensor or a base Tensor.");
     AddInput("Y",
-             "(LoDTensor) The reference input('Y') of seq_expand op. "
+             "(LoDTensor) The reference input(Y) of seq_expand op. "
              "It must be a LoDTensor with k levels (k > 0). "
-             "Input(X) will be expanded according to the LoD of input(Y). "
-             "The element count of the last LoD level in input('Y') "
-             "must equal dims[0] of input('X').");
+             "The input(X) will be expanded according to the LoD of input(Y). "
+             "The element count of the last LoD level in input(Y) "
+             "must equal dims[0] of input(X).");
     AddOutput("Out",
               "(LoDTensor) The output of seq_expand op."
               "The LoD of the output will be the same as input(Y)'s LoD.");
@@ -133,7 +131,7 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("X"));
     PADDLE_ENFORCE(ctx->HasInput("Out"));
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@GRAD) should not be null");
+                   "The input(Out@GRAD) should not be null");
     auto x_dims = ctx->GetInputDim("X");
     auto x_grad_name = framework::GradVarName("X");
     if (ctx->HasOutput(x_grad_name)) {