diff --git a/paddle/operators/seq_expand_op.cc b/paddle/operators/seq_expand_op.cc new file mode 100644 index 0000000000000..08fda9b445642 --- /dev/null +++ b/paddle/operators/seq_expand_op.cc @@ -0,0 +1,153 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/seq_expand_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class SeqExpandOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X")); + PADDLE_ENFORCE(ctx->HasOutput("Out")); + PADDLE_ENFORCE(ctx->HasInput("Y")); + framework::DDim out_dim; + out_dim = ctx->GetInputDim("Y"); + ctx->ShareLoD("Y", "Out"); + ctx->SetOutputDim("Out", out_dim); + } +}; + +class SeqExpandOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SeqExpandOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(Tensor or LoDTensor) The input(X) of this operator can be a " + "LoDTensor or a base Tensor."); + AddInput("Y", + "(LoDTensor)The reference input(Y) of seq_expand op." + "It must be a LoDTensor with k-level(k>0)." + "The input(X) will be expanded according to LOD of input(Y)." + "The element numbers of last level in input(Y) " + "must be equal to dims[0] of input(X)."); + AddOutput("Out", + "(LodTensor)The output of seq_expand op." + "The lod of output will be as same as input(Y)'s lod."); + AddComment(R"DOC( +Expand input(X) according to LOD of input(Y). + +Case 1: + +Given 2-level a LoDTensor input(X) + X.lod = [[0, 2, 3], + [0, 1, 3, 4]] + X.data = [a, b, c, d] + X.dims = [4, 1] +and input(Y) + Y.lod = [[0, 2, 4], + [0, 3, 6, 7, 8]] +with condition len(Y.lod[-1]) -1 == X.dims[0] +then we get 2-level LoDTensor + Out.lod = [[0, 2, 4], + [0, 3, 6, 7, 8]] + Out.data = [a, a, a, b, b, b, c, d] + Out.dims = [8, 1] + +Case 2: + +Given a 0-level LoDTensor input(X) + X.data = [a, b, c] + X.lod = NULL + X.dims = [3, 1] +and input(Y) + Y.lod = [[0, 2, 3, 6]] +with condition len(Y.lod[-1]) -1 == X.dims[0] +then we get 1-level LoDTensor + Out.lod = [[0, 2, 3, 6]] + Out.data = [a, a, b, c, c, c] + Out.dims = [6, 1] + +Case 3: + +Given a 0-level LoDTensor input(X) + X.data = [[a, b], [c, d], [e, f]] + X.lod = NULL + X.dims = [3, 2] +and input(Y) + Y.lod = [[0, 2, 3, 6]] +with condition len(Y.lod[-1]) -1 == X.dims[0] +then we get 1-level LoDTensor + Out.lod = [[0, 2, 3, 6]] + Out.data = [[a,b], [a,b] [c,d], [e, f], [e, f], [e, f]] + Out.dims = [6, 2] + +Case 4: + +Given 2-level a LoDTensor input(X) + X.lod = [[0, 2, 3], + [0, 1, 3, 4]] + X.data = [a, b, c, d] + X.dims = [4, 1] +and input(Y) + Y.lod = [[0, 2, 4], + [0, 3, 6, 6, 8]] +with condition len(Y.lod[-1]) -1 == X.dims[0] +then we get 2-level LoDTensor + Out.lod = [[0, 2, 4], + [0, 3, 6, 6, 8]] + Out.data = [a, a, a, b, b, b, d, d] + Out.dims = [8, 1] + + +)DOC"); + } +}; + +class SeqExpandOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X")); + PADDLE_ENFORCE(ctx->HasInput("Out")); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "The input(Out@GRAD) should not be null"); + auto x_dims = ctx->GetInputDim("X"); + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker, + seq_expand_grad, ops::SeqExpandOpGrad); +REGISTER_OP_CPU_KERNEL(seq_expand, + ops::SeqExpandKernel); +REGISTER_OP_CPU_KERNEL( + seq_expand_grad, + ops::SeqExpandGradKernel); diff --git a/paddle/operators/seq_expand_op.cu b/paddle/operators/seq_expand_op.cu new file mode 100644 index 0000000000000..f1e4b82a76e62 --- /dev/null +++ b/paddle/operators/seq_expand_op.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/seq_expand_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(seq_expand, + ops::SeqExpandKernel); +REGISTER_OP_GPU_KERNEL( + seq_expand_grad, + ops::SeqExpandGradKernel); diff --git a/paddle/operators/seq_expand_op.h b/paddle/operators/seq_expand_op.h new file mode 100644 index 0000000000000..aa91e0f9296a7 --- /dev/null +++ b/paddle/operators/seq_expand_op.h @@ -0,0 +1,100 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/memory/memcpy.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; + +template +class SeqExpandKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + const T* x_data = x->data(); + auto x_dims = x->dims(); + auto* y = context.Input("Y"); + PADDLE_ENFORCE_EQ(x_dims[0], y->lod().back().size() - 1, + "The size of last lod level in Input(Y)" + "must be equal to dims[0] of Input(X)."); + out->set_lod(y->lod()); + auto place = context.GetEigenDevice(); + size_t element_len = framework::product(x_dims) / x_dims[0]; + T* out_data = out->mutable_data(context.GetPlace()); + auto out_starts = out->lod().back(); + + for (size_t i = 0; i < out_starts.size() - 1; i++) { + int scale = out_starts[i + 1] - out_starts[i]; + Eigen::TensorMap< + Eigen::Tensor> + x_t(x_data, 1, element_len); + Eigen::TensorMap> + out_t(out_data, scale, element_len); + Eigen::array cast({scale, 1}); + out_t.device(place) = x_t.broadcast(cast); + x_data += element_len; + out_data += element_len * scale; + } + } +}; + +/* + *Given Grad(Out) + * + * Grad(Out).lod = [[0, 2], + * [0, 3, 6]] + * Grad(Out).data = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + * Then + * Grad(X).data = [(0.1 + 0.2 + 0.3), (0.4 + 0.5 + 0.6)] + * = [0.6, 1.5] + * Grad(X).lod = Input(X).lod + * + * */ +template +class SeqExpandGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* d_out = context.Input(framework::GradVarName("Out")); + auto* x = context.Input("X"); + auto* out = context.Input("Out"); + auto* d_x = context.Output(framework::GradVarName("X")); + auto out_last_level = out->lod().back(); + d_x->set_lod(x->lod()); + const T* d_out_data = d_out->data(); + T* d_x_data = d_x->mutable_data(context.GetPlace()); + size_t element_len = d_out->numel() / d_out->dims()[0]; + for (size_t i = 0; i < out_last_level.size() - 1; ++i) { + size_t repeat = out_last_level[i + 1] - out_last_level[i]; + Eigen::TensorMap< + Eigen::Tensor> + d_out_t(d_out_data, static_cast(repeat), element_len); + Eigen::TensorMap> + d_x_t(d_x_data, static_cast(element_len)); + auto place = context.GetEigenDevice(); + d_x_t.device(place) = d_out_t.sum(Eigen::array({{0}})); + d_out_data += (repeat * element_len); + d_x_data += element_len; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/sequence_concat_op.cc b/paddle/operators/sequence_concat_op.cc index 1fce96cdfe20f..46f73e3c27983 100644 --- a/paddle/operators/sequence_concat_op.cc +++ b/paddle/operators/sequence_concat_op.cc @@ -68,12 +68,12 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker { "The level should be less than the level number of inputs.") .SetDefault(0); AddComment(R"DOC( - The sequence_concat operator concatenates multiple LoDTensors. - It only supports sequence (LoD Tensor with level number is 1) + The sequence_concat operator concatenates multiple LoDTensors. + It only supports sequence (LoD Tensor with level number is 1) or a nested sequence (LoD tensor with level number is 2) as its input. - Case1: If the axis is other than 0(here, axis is 1 and level is 1), - each input should have the same LoD information and the LoD + each input should have the same LoD information and the LoD information of the output keeps the same as the input. LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4) @@ -81,7 +81,7 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker { LoD(Out) = {{0,2,4}, {0,1,2,3,4}}; Dims(Out) = (4,7,4) - Case2: - If the axis is 0(here, leve is 0), the inputs are concatenated along + If the axis is 0(here, leve is 0), the inputs are concatenated along time steps, the LoD information of the output need to re-compute. LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4) @@ -94,7 +94,7 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker { LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4) LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (5,3,4) LoD(Out) = {{0,5,9}, {0,2,5,7,9}}; Dims(Out) = (9,3,4) - + NOTE: The levels of all the inputs should be the same. )DOC"); } diff --git a/python/paddle/v2/framework/tests/test_seq_expand.py b/python/paddle/v2/framework/tests/test_seq_expand.py new file mode 100644 index 0000000000000..ff17edd04bfd3 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_seq_expand.py @@ -0,0 +1,63 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestSeqExpand(OpTest): + def set_data(self): + x_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32') + y_data = np.random.uniform(0.1, 1, [8, 1]).astype('float32') + y_lod = [[0, 1, 4, 8]] + self.inputs = {'X': x_data, 'Y': (y_data, y_lod)} + + def compute(self): + x = self.inputs['X'] + x_data, x_lod = x if type(x) == tuple else (x, None) + n = 1 + x_data.shape[0] if not x_lod else len(x_lod[0]) + y_data, y_lod = self.inputs['Y'] + repeats = [((y_lod[-1][i + 1] - y_lod[-1][i])) + for i in range(len(y_lod[-1]) - 1)] + out = x_data.repeat(repeats, axis=0) + self.outputs = {'Out': out} + + def setUp(self): + self.op_type = 'seq_expand' + self.set_data() + self.compute() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestSeqExpandCase1(TestSeqExpand): + def set_data(self): + x_data = np.random.uniform(0.1, 1, [5, 1]).astype('float32') + x_lod = [[0, 2, 5]] + y_data = np.random.uniform(0.1, 1, [13, 1]).astype('float32') + y_lod = [[0, 2, 5], [0, 2, 4, 7, 10, 13]] + self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} + + +class TestSeqExpandCase2(TestSeqExpand): + def set_data(self): + x_data = np.random.uniform(0.1, 1, [1, 2, 2]).astype('float32') + x_lod = [[0, 1]] + y_data = np.random.uniform(0.1, 1, [2, 2, 2]).astype('float32') + y_lod = [[0, 2]] + self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} + + +class TestSeqExpandCase3(TestSeqExpand): + def set_data(self): + x_data = np.random.uniform(0.1, 1, [4, 1]).astype('float32') + x_lod = [[0, 1, 2, 3, 4]] + y_data = np.random.uniform(0.1, 1, [6, 1]).astype('float32') + y_lod = [[0, 2, 4, 4, 6]] + self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} + + +if __name__ == '__main__': + unittest.main()