From 611ee68b7888c8680b1c8ee967ad964d3c1e7f4c Mon Sep 17 00:00:00 2001 From: peterzhang2029 Date: Mon, 23 Oct 2017 17:33:23 +0800 Subject: [PATCH 1/7] add bilinear tensor product op --- .../operators/bilinear_tensor_product_op.cc | 153 +++++++++++++++ .../operators/bilinear_tensor_product_op.cu | 24 +++ paddle/operators/bilinear_tensor_product_op.h | 176 ++++++++++++++++++ .../tests/test_bilinear_tensor_product_op.py | 30 +++ 4 files changed, 383 insertions(+) create mode 100644 paddle/operators/bilinear_tensor_product_op.cc create mode 100644 paddle/operators/bilinear_tensor_product_op.cu create mode 100644 paddle/operators/bilinear_tensor_product_op.h create mode 100644 python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py diff --git a/paddle/operators/bilinear_tensor_product_op.cc b/paddle/operators/bilinear_tensor_product_op.cc new file mode 100644 index 0000000000000..64569e5fe77bb --- /dev/null +++ b/paddle/operators/bilinear_tensor_product_op.cc @@ -0,0 +1,153 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/bilinear_tensor_product_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class BilinearTensorProductOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(Weight) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto weight_dims = ctx->GetInputDim("Weight"); + + PADDLE_ENFORCE_EQ(x_dims.size(), 1, "The input X must be a vector."); + PADDLE_ENFORCE_EQ(y_dims.size(), 1, "The input Y must be a vector."); + PADDLE_ENFORCE_EQ(weight_dims.size(), 3, + "The input Weight must be a 3D tensor."); + PADDLE_ENFORCE_GT(weight_dims[0], 0, + "The first dimension of Weight must be larger than 0."); + PADDLE_ENFORCE_GT(weight_dims[1], 0, + "The second dimension of Weight must be larger than 0."); + PADDLE_ENFORCE_GT(weight_dims[2], 0, + "The third dimension of Weight must be larger than 0."); + PADDLE_ENFORCE_EQ(x_dims[0], weight_dims[1], + "The dimension of X must be equal with the second " + "dimension of the Weight."); + PADDLE_ENFORCE_EQ(y_dims[0], weight_dims[2], + "The dimension of Y must be equal with the third " + "dimension of the Weight."); + + auto bias = Input("Bias"); + if (bias != framework::kEmptyVarName) { + auto bias_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ(bias_dims.size(), 1, + "The input Bias must be a vector."); + PADDLE_ENFORCE_EQ(bias_dims[0], weight_dims[0], + "The dimension of Bias must be equal with the first " + "dimension of the Weight."); + } + + ctx->SetOutputDim("Out", 
{weight_dims[0]}); + } +}; + +class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker { + public: + BilinearTensorProductOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The first input of tensor op"); + AddInput("Y", "The second input of tensor op"); + AddInput("Weight", "The input weight of tensor op"); + AddInput("Bias", "The input bias of tensor op"); + AddOutput("Out", "The output of tensor op"); + AddComment(R"DOC( +Bilinear Tensor Product operator. +Given input X and Y, a 3D tensor weight, and bias. Each entry of the output is +computed by one slice i = 1, . . . , k of the tensor: Out_i = X*W_i*Y + Bias_i . + +The equation of this operator is: + + Out = \sum_{i} X*W_i*Y + Bias + +)DOC"); + } +}; + +class BilinearTensorProductOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("Weight"), "Input(Weight) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input (Out@GRAD) should not be null"); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto weight_dims = ctx->GetInputDim("Weight"); + auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); + + PADDLE_ENFORCE_EQ(out_dims.size(), 1, "The Out@GRAD must be a vector."); + PADDLE_ENFORCE_EQ( + weight_dims[0], out_dims[0], + "The dimension of Out@GRAD must be equal with the third dimension of " + "the Weight."); + + auto bias = Input("Bias"); + if (bias != framework::kEmptyVarName) { + auto bias_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ(bias_dims.size(), 1, "Input Bias must be a vector."); + PADDLE_ENFORCE_EQ( + bias_dims[0], out_dims[0], + "The dimension of Bias must be equal with the Out@GRAD "); + auto bias_grad_name = framework::GradVarName("Bias"); + if (ctx->HasOutput(bias_grad_name)) + ctx->SetOutputDim(bias_grad_name, bias_dims); + } + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + auto weight_grad_name = framework::GradVarName("Weight"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } + if (ctx->HasOutput(weight_grad_name)) { + ctx->SetOutputDim(weight_grad_name, weight_dims); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(bilinear_tensor_product, ops::BilinearTensorProductOp, + ops::BilinearTensorProductOpMaker, bilinear_tensor_product_grad, + ops::BilinearTensorProductOpGrad); +REGISTER_OP_CPU_KERNEL( + bilinear_tensor_product, + ops::BilinearTensorProductKernel); +REGISTER_OP_CPU_KERNEL( + bilinear_tensor_product_grad, + ops::BilinearTensorProductGradKernel); diff --git a/paddle/operators/bilinear_tensor_product_op.cu b/paddle/operators/bilinear_tensor_product_op.cu new file mode 100644 index 0000000000000..a212460560e79 --- /dev/null +++ b/paddle/operators/bilinear_tensor_product_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/bilinear_tensor_product_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + bilinear_tensor_product, + ops::BilinearTensorProductKernel); +REGISTER_OP_GPU_KERNEL( + bilinear_tensor_product_grad, + ops::BilinearTensorProductGradKernel); diff --git a/paddle/operators/bilinear_tensor_product_op.h b/paddle/operators/bilinear_tensor_product_op.h new file mode 100644 index 0000000000000..b816d6d7c210d --- /dev/null +++ b/paddle/operators/bilinear_tensor_product_op.h @@ -0,0 +1,176 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/platform/transform.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using platform::Transform; + +template +class BilinearTensorProductKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* weight = ctx.Input("Weight"); + auto* bias = ctx.Input("Bias"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + auto weight_dims = weight->dims(); + Tensor left_mul_vec; + left_mul_vec.mutable_data(framework::make_ddim({weight_dims[2]}), + ctx.GetPlace()); + if (bias) { + out->CopyFrom(*bias, ctx.GetPlace(), ctx.device_context()); + } + for (int i = 0; i < weight_dims[0]; ++i) { + Tensor weight_mat = weight->Slice(i, i + 1).Resize( + framework::make_ddim({weight_dims[1], weight_dims[2]})); + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, 1, + weight_dims[2], weight_dims[1], 1, x->data(), + weight_mat.data(), 0, left_mul_vec.data()); + if (bias) { + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, + 1, 1, weight_dims[2], 1, left_mul_vec.data(), + y->data(), 1, &(out->data()[i])); + } else { + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, + 1, 1, weight_dims[2], 1, left_mul_vec.data(), + y->data(), 0, &(out->data()[i])); + } + } + } +}; + +template +class ScaleFunctor { + public: + explicit ScaleFunctor(const T* scale) : scale_(scale) {} + + HOSTDEVICE T operator()(const T& x) const { return x * (*scale_); } + + private: + const T* scale_; +}; + +template +class BilinearTensorProductGradKernel : public framework::OpKernel { + public: + void Compute(const 
framework::ExecutionContext& ctx) const override { + const Tensor* x = ctx.Input("X"); + const Tensor* y = ctx.Input("Y"); + const Tensor* weight = ctx.Input("Weight"); + Tensor* d_x = ctx.Output(framework::GradVarName("X")); + Tensor* d_y = ctx.Output(framework::GradVarName("Y")); + Tensor* d_weight = ctx.Output(framework::GradVarName("Weight")); + Tensor* d_bias = ctx.Output(framework::GradVarName("Bias")); + const Tensor* d_out = ctx.Input(framework::GradVarName("Out")); + auto* d_out_ptr = d_out->data(); + auto weight_dims = weight->dims(); + + // Get the first matrix of Weight. + Tensor weight_mat_0 = weight->Slice(0, 1).Resize( + framework::make_ddim({weight_dims[1], weight_dims[2]})); + + // Create the intermediate variable for gradient. + int numel_x = x->numel(); + int numel_y = y->numel(); + const T* x_ptr = x->data(); + const T* y_ptr = y->data(); + Tensor x_scale; + T* x_scale_ptr = x_scale.mutable_data( + framework::make_ddim({weight_dims[1]}), ctx.GetPlace()); + Tensor y_scale; + T* y_scale_ptr = y_scale.mutable_data( + framework::make_ddim({weight_dims[2]}), ctx.GetPlace()); + Transform trans; + + // Caculate the gradient of X according to the first matrix of Weight. + if (d_x) { + d_x->mutable_data(ctx.GetPlace()); + trans(ctx.device_context(), y_ptr, y_ptr + numel_y, y_scale_ptr, + ScaleFunctor(&d_out_ptr[0])); + math::gemm(ctx.device_context(), CblasNoTrans, CblasTrans, 1, + weight_dims[1], weight_dims[2], 1, y_scale.data(), + weight_mat_0.data(), 0, d_x->data()); + } + + // Caculate the gradient of Y according to the first matrix of Weight. + if (d_y) { + d_y->mutable_data(ctx.GetPlace()); + trans(ctx.device_context(), x_ptr, x_ptr + numel_x, x_scale_ptr, + ScaleFunctor(&d_out_ptr[0])); + math::gemm(ctx.device_context(), CblasTrans, CblasNoTrans, + weight_dims[2], 1, weight_dims[1], 1, + weight_mat_0.data(), x_scale.data(), 0, + d_y->data()); + } + + // Caculate the gradient of X and Y completly. + if (d_x || d_y) { + for (int i = 1; i < weight_dims[0]; ++i) { + Tensor weight_mat = weight->Slice(i, i + 1).Resize( + framework::make_ddim({weight_dims[1], weight_dims[2]})); + if (d_x) { + trans(ctx.device_context(), y_ptr, y_ptr + numel_y, y_scale_ptr, + ScaleFunctor(&d_out_ptr[i])); + math::gemm(ctx.device_context(), CblasNoTrans, CblasTrans, + 1, weight_dims[1], weight_dims[2], 1, + y_scale.data(), weight_mat.data(), 1, + d_x->data()); + } + if (d_y) { + trans(ctx.device_context(), x_ptr, x_ptr + numel_x, x_scale_ptr, + ScaleFunctor(&d_out_ptr[i])); + math::gemm(ctx.device_context(), CblasTrans, CblasNoTrans, + weight_dims[2], 1, weight_dims[1], 1, + weight_mat.data(), x_scale.data(), 1, + d_y->data()); + } + } + } + + // Caculate the gradient of Weight. + if (d_weight) { + d_weight->mutable_data(ctx.GetPlace()); + for (int i = 0; i < weight_dims[0]; ++i) { + Tensor d_weight_mat = d_weight->Slice(i, i + 1).Resize( + framework::make_ddim({weight_dims[1], weight_dims[2]})); + trans(ctx.device_context(), x_ptr, x_ptr + numel_x, x_scale_ptr, + ScaleFunctor(&d_out_ptr[i])); + math::gemm(ctx.device_context(), CblasTrans, CblasNoTrans, + weight_dims[1], weight_dims[2], 1, 1, + x_scale.data(), y->data(), 0, + d_weight_mat.data()); + } + } + + // Caculate the gradient of Bias. 
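+ // In this version Out and Bias are both vectors of length weight_dims[0] and dOut/dBias is the identity, so Bias@Grad is simply a copy of Out@GRAD.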
+ if (d_bias) { + d_bias->mutable_data(ctx.GetPlace()); + d_bias->CopyFrom(*d_out, ctx.GetPlace(), ctx.device_context()); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py b/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py new file mode 100644 index 0000000000000..10d90a9f0f9f8 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py @@ -0,0 +1,30 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestBilinearTensorProductOp(OpTest): + def setUp(self): + self.op_type = "bilinear_tensor_product" + self.inputs = { + 'X': np.random.random(3).astype("float32"), + 'Y': np.random.random(4).astype("float32"), + 'Weight': np.random.random((5, 3, 4)).astype("float32"), + 'Bias': np.random.random(5).astype("float32") + } + self.outputs = { + 'Out': np.matmul( + np.matmul(self.inputs['Weight'], self.inputs['Y']), + self.inputs['X']) + self.inputs['Bias'] + } + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y', 'Weight', 'Bias'], 'Out', max_relative_error=0.5) + + +if __name__ == "__main__": + unittest.main() From 3ae14242da3e32350790711b6339b07787a231ea Mon Sep 17 00:00:00 2001 From: peterzhang2029 Date: Tue, 7 Nov 2017 20:26:04 +0800 Subject: [PATCH 2/7] update for mini-batch --- .../operators/bilinear_tensor_product_op.cc | 78 +++++---- .../operators/bilinear_tensor_product_op.cu | 79 ++++++++- paddle/operators/bilinear_tensor_product_op.h | 165 ++++++++++-------- .../tests/test_bilinear_tensor_product_op.py | 81 +++++++-- 4 files changed, 279 insertions(+), 124 deletions(-) diff --git a/paddle/operators/bilinear_tensor_product_op.cc b/paddle/operators/bilinear_tensor_product_op.cc index 64569e5fe77bb..3bd2d40cd284a 100644 --- a/paddle/operators/bilinear_tensor_product_op.cc +++ b/paddle/operators/bilinear_tensor_product_op.cc @@ -34,8 +34,8 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel { auto y_dims = ctx->GetInputDim("Y"); auto weight_dims = ctx->GetInputDim("Weight"); - PADDLE_ENFORCE_EQ(x_dims.size(), 1, "The input X must be a vector."); - PADDLE_ENFORCE_EQ(y_dims.size(), 1, "The input Y must be a vector."); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The input X must be a 2D Tensor."); + PADDLE_ENFORCE_EQ(y_dims.size(), 2, "The input Y must be a 2D Tensor."); PADDLE_ENFORCE_EQ(weight_dims.size(), 3, "The input Weight must be a 3D tensor."); PADDLE_ENFORCE_GT(weight_dims[0], 0, @@ -44,24 +44,29 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel { "The second dimension of Weight must be larger than 0."); PADDLE_ENFORCE_GT(weight_dims[2], 0, "The third dimension of Weight must be larger than 0."); - PADDLE_ENFORCE_EQ(x_dims[0], weight_dims[1], - "The dimension of X must be equal with the second " + PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], + "The first dimension(batch_size) of X must be " + "equal with the first dimension of the Y."); + PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1], + "The second dimension of X must be equal with the second " "dimension of the Weight."); - PADDLE_ENFORCE_EQ(y_dims[0], weight_dims[2], - "The dimension of Y must be equal with the third " + PADDLE_ENFORCE_EQ(y_dims[1], weight_dims[2], + "The second dimension of Y must be equal with the third " "dimension of the Weight."); - auto bias = Input("Bias"); - if (bias != framework::kEmptyVarName) { + if (ctx->HasInput("Bias")) { auto bias_dims = 
ctx->GetInputDim("Bias"); - PADDLE_ENFORCE_EQ(bias_dims.size(), 1, - "The input Bias must be a vector."); - PADDLE_ENFORCE_EQ(bias_dims[0], weight_dims[0], - "The dimension of Bias must be equal with the first " - "dimension of the Weight."); + PADDLE_ENFORCE_EQ(bias_dims.size(), 2, + "The input Bias must have 2 dimensions."); + PADDLE_ENFORCE_EQ(bias_dims[0], 1, + "The first dimention of input Bias must be 1."); + PADDLE_ENFORCE_EQ(bias_dims[1], weight_dims[0], + "The second dimension of Bias must be equal with the " + "first dimension of the Weight."); } - ctx->SetOutputDim("Out", {weight_dims[0]}); + ctx->SetOutputDim("Out", {x_dims[0], weight_dims[0]}); + ctx->ShareLoD("X", /*->*/ "Out"); } }; @@ -70,19 +75,19 @@ class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker { BilinearTensorProductOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The first input of tensor op"); - AddInput("Y", "The second input of tensor op"); - AddInput("Weight", "The input weight of tensor op"); - AddInput("Bias", "The input bias of tensor op"); - AddOutput("Out", "The output of tensor op"); + AddInput("X", "The first input of BilinearTensorProduct op"); + AddInput("Y", "The second input of BilinearTensorProduct op"); + AddInput("Weight", "The input weight of BilinearTensorProduct op"); + AddInput("Bias", "The input bias of BilinearTensorProduct op") + .AsDispensable(); + AddOutput("Out", "The output of BilinearTensorProduct op"); AddComment(R"DOC( Bilinear Tensor Product operator. -Given input X and Y, a 3D tensor weight, and bias. Each entry of the output is -computed by one slice i = 1, . . . , k of the tensor: Out_i = X*W_i*Y + Bias_i . +Given input X and Y, a 3D tensor weight, and bias. Each column of the +output is computed by one slice i = 1, . . . 
, k of the tensor: -The equation of this operator is: - - Out = \sum_{i} X*W_i*Y + Bias + M = (X W_i) \cdot Y + Out_i = \sum_i {M_i} + Bias_i )DOC"); } @@ -104,19 +109,20 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel { auto weight_dims = ctx->GetInputDim("Weight"); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); - PADDLE_ENFORCE_EQ(out_dims.size(), 1, "The Out@GRAD must be a vector."); + PADDLE_ENFORCE_EQ(out_dims.size(), 2, "The Out@GRAD must be a 2D Tensor."); PADDLE_ENFORCE_EQ( - weight_dims[0], out_dims[0], - "The dimension of Out@GRAD must be equal with the third dimension of " - "the Weight."); - - auto bias = Input("Bias"); - if (bias != framework::kEmptyVarName) { + x_dims[0], out_dims[0], + "The first dimension(batch_size) of Out@GRAD must be equal with " + "the first dimension of the X."); + PADDLE_ENFORCE_EQ(weight_dims[0], out_dims[1], + "The second dimension of Out@GRAD must be equal with " + "the third dimension of the Weight."); + + if (ctx->HasInput("Bias")) { auto bias_dims = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE_EQ(bias_dims.size(), 1, "Input Bias must be a vector."); - PADDLE_ENFORCE_EQ( - bias_dims[0], out_dims[0], - "The dimension of Bias must be equal with the Out@GRAD "); + PADDLE_ENFORCE_EQ(bias_dims[1], out_dims[1], + "The second dimension of Bias must be equal with " + "the second dimension of the Out@GRAD."); auto bias_grad_name = framework::GradVarName("Bias"); if (ctx->HasOutput(bias_grad_name)) ctx->SetOutputDim(bias_grad_name, bias_dims); @@ -150,4 +156,4 @@ REGISTER_OP_CPU_KERNEL( ops::BilinearTensorProductKernel); REGISTER_OP_CPU_KERNEL( bilinear_tensor_product_grad, - ops::BilinearTensorProductGradKernel); + ops::BilinearTensorProductGradKernel); \ No newline at end of file diff --git a/paddle/operators/bilinear_tensor_product_op.cu b/paddle/operators/bilinear_tensor_product_op.cu index a212460560e79..1d65c17f8c1a6 100644 --- a/paddle/operators/bilinear_tensor_product_op.cu +++ b/paddle/operators/bilinear_tensor_product_op.cu @@ -15,10 +15,85 @@ #define EIGEN_USE_GPU #include "paddle/operators/bilinear_tensor_product_op.h" +namespace paddle { +namespace operators { + +template +class BilinearTensorProductCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* weight = ctx.Input("Weight"); + auto* bias = ctx.Input("Bias"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + auto y_mat = EigenMatrix::From(*y); + auto batch_size = x->dims()[0]; + auto weight_dims = weight->dims(); + + auto place = ctx.GetEigenDevice(); + auto cpu_place = ctx.GetEigenDevice(); + + // Copy the output to cpu. + Tensor output_cpu; + output_cpu.CopyFrom(*out, platform::CPUPlace(), ctx.device_context()); + auto* output_cpu_ptr = output_cpu.data(); + auto output_cpu_mat = EigenMatrix::From(output_cpu); + + // Create the temporary variables. 
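+ // left_mul holds the batched product X * W_i, with shape (batch_size, weight_dims[2]).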
+ Tensor left_mul; + left_mul.mutable_data(framework::make_ddim({batch_size, weight_dims[2]}), + ctx.GetPlace()); + auto left_mul_mat = EigenMatrix::From(left_mul); + Tensor output_col; + output_col.mutable_data(framework::make_ddim({batch_size}), + ctx.GetPlace()); + auto output_col_vec = EigenVector::From(output_col); + + for (size_t i = 0; i < weight_dims[0]; ++i) { + Tensor weight_mat = weight->Slice(i, i + 1).Resize( + framework::make_ddim({weight_dims[1], weight_dims[2]})); + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, + batch_size, weight_dims[2], weight_dims[1], 1, + x->data(), weight_mat.data(), 0, + left_mul.data()); + output_col_vec.device(place) = + (left_mul_mat * y_mat).sum(Eigen::DSizes(1)); + + // Copy the output_col to cpu. + Tensor output_col_cpu; + output_col_cpu.CopyFrom(output_col, platform::CPUPlace(), + ctx.device_context()); + auto* output_col_ptr = output_col_cpu.data(); + + for (size_t j = 0; j < batch_size; ++j) { + output_cpu_ptr[i + j * weight_dims[0]] = output_col_ptr[j]; + } + } + + if (bias) { + // Copy the bias to cpu. + Tensor bias_cpu; + bias_cpu.CopyFrom(*bias, platform::CPUPlace(), ctx.device_context()); + auto bias_vec = EigenMatrix::From(bias_cpu); + Eigen::DSizes bcast(batch_size, 1); + output_cpu_mat.device(cpu_place) = + bias_vec.broadcast(bcast) + output_cpu_mat; + } + + // Copy the output to gpu. + out->CopyFrom(output_cpu, platform::GPUPlace(), ctx.device_context()); + } +}; +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( bilinear_tensor_product, - ops::BilinearTensorProductKernel); + ops::BilinearTensorProductCUDAKernel); REGISTER_OP_GPU_KERNEL( bilinear_tensor_product_grad, - ops::BilinearTensorProductGradKernel); + ops::BilinearTensorProductGradKernel); \ No newline at end of file diff --git a/paddle/operators/bilinear_tensor_product_op.h b/paddle/operators/bilinear_tensor_product_op.h index b816d6d7c210d..238d1d7749694 100644 --- a/paddle/operators/bilinear_tensor_product_op.h +++ b/paddle/operators/bilinear_tensor_product_op.h @@ -14,15 +14,22 @@ #pragma once +#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/operators/math/math_function.h" -#include "paddle/platform/transform.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -using platform::Transform; + +template +using EigenMatrix = framework::EigenMatrix; + +template +using EigenVector = framework::EigenVector; template class BilinearTensorProductKernel : public framework::OpKernel { @@ -35,43 +42,45 @@ class BilinearTensorProductKernel : public framework::OpKernel { auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); + auto y_mat = EigenMatrix::From(*y); + auto output_mat = EigenMatrix::From(*out); + + auto batch_size = x->dims()[0]; auto weight_dims = weight->dims(); - Tensor left_mul_vec; - left_mul_vec.mutable_data(framework::make_ddim({weight_dims[2]}), - ctx.GetPlace()); - if (bias) { - out->CopyFrom(*bias, ctx.GetPlace(), ctx.device_context()); - } - for (int i = 0; i < weight_dims[0]; ++i) { + auto place = ctx.GetEigenDevice(); + + // Create the temporary variables. 
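+ // left_mul = X * W_i; its row-wise dot product with Y fills column i of Out.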
+ Tensor left_mul; + left_mul.mutable_data(framework::make_ddim({batch_size, weight_dims[2]}), + ctx.GetPlace()); + auto left_mul_mat = EigenMatrix::From(left_mul); + Tensor output_col; + output_col.mutable_data(framework::make_ddim({weight_dims[0]}), + ctx.GetPlace()); + auto output_col_vec = EigenVector::From(output_col); + + for (size_t i = 0; i < weight_dims[0]; ++i) { Tensor weight_mat = weight->Slice(i, i + 1).Resize( framework::make_ddim({weight_dims[1], weight_dims[2]})); - math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, 1, - weight_dims[2], weight_dims[1], 1, x->data(), - weight_mat.data(), 0, left_mul_vec.data()); - if (bias) { - math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, - 1, 1, weight_dims[2], 1, left_mul_vec.data(), - y->data(), 1, &(out->data()[i])); - } else { - math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, - 1, 1, weight_dims[2], 1, left_mul_vec.data(), - y->data(), 0, &(out->data()[i])); + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, + batch_size, weight_dims[2], weight_dims[1], 1, + x->data(), weight_mat.data(), 0, + left_mul.data()); + output_col_vec = (left_mul_mat * y_mat).sum(Eigen::DSizes(1)); + for (size_t j = 0; j < batch_size; ++j) { + output_mat(j, i) = output_col_vec(j); } } + if (bias) { + auto bias_vec = EigenMatrix::From(*bias); + Eigen::DSizes bcast(batch_size, 1); + output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat; + } else { + output_mat.device(place) = output_mat; + } } }; -template -class ScaleFunctor { - public: - explicit ScaleFunctor(const T* scale) : scale_(scale) {} - - HOSTDEVICE T operator()(const T& x) const { return x * (*scale_); } - - private: - const T* scale_; -}; - template class BilinearTensorProductGradKernel : public framework::OpKernel { public: @@ -84,66 +93,65 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { Tensor* d_weight = ctx.Output(framework::GradVarName("Weight")); Tensor* d_bias = ctx.Output(framework::GradVarName("Bias")); const Tensor* d_out = ctx.Input(framework::GradVarName("Out")); - auto* d_out_ptr = d_out->data(); + + auto batch_size = x->dims()[0]; auto weight_dims = weight->dims(); - // Get the first matrix of Weight. - Tensor weight_mat_0 = weight->Slice(0, 1).Resize( - framework::make_ddim({weight_dims[1], weight_dims[2]})); + auto x_mat = EigenMatrix::From(*x); + auto y_mat = EigenMatrix::From(*y); + auto d_out_mat = EigenMatrix::From(*d_out); + auto place = ctx.GetEigenDevice(); - // Create the intermediate variable for gradient. - int numel_x = x->numel(); - int numel_y = y->numel(); - const T* x_ptr = x->data(); - const T* y_ptr = y->data(); + // Create the temporary variables for gradient. Tensor x_scale; - T* x_scale_ptr = x_scale.mutable_data( - framework::make_ddim({weight_dims[1]}), ctx.GetPlace()); + x_scale.mutable_data(framework::make_ddim({batch_size, weight_dims[1]}), + ctx.GetPlace()); + auto x_scale_mat = EigenMatrix::From(x_scale); Tensor y_scale; - T* y_scale_ptr = y_scale.mutable_data( - framework::make_ddim({weight_dims[2]}), ctx.GetPlace()); - Transform trans; + y_scale.mutable_data(framework::make_ddim({batch_size, weight_dims[2]}), + ctx.GetPlace()); + auto y_scale_mat = EigenMatrix::From(y_scale); + + math::SetConstant set_zero; - // Caculate the gradient of X according to the first matrix of Weight. + // Set X@Grad be zero at first. 
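+ // X@Grad and Y@Grad are accumulated across all slices of Weight (the gemm calls below use beta = 1), so both must be zero-initialized first.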
if (d_x) { d_x->mutable_data(ctx.GetPlace()); - trans(ctx.device_context(), y_ptr, y_ptr + numel_y, y_scale_ptr, - ScaleFunctor(&d_out_ptr[0])); - math::gemm(ctx.device_context(), CblasNoTrans, CblasTrans, 1, - weight_dims[1], weight_dims[2], 1, y_scale.data(), - weight_mat_0.data(), 0, d_x->data()); + set_zero(ctx.device_context(), d_x, static_cast(0)); } - // Caculate the gradient of Y according to the first matrix of Weight. + // Set Y@Grad be zero at first. if (d_y) { d_y->mutable_data(ctx.GetPlace()); - trans(ctx.device_context(), x_ptr, x_ptr + numel_x, x_scale_ptr, - ScaleFunctor(&d_out_ptr[0])); - math::gemm(ctx.device_context(), CblasTrans, CblasNoTrans, - weight_dims[2], 1, weight_dims[1], 1, - weight_mat_0.data(), x_scale.data(), 0, - d_y->data()); + set_zero(ctx.device_context(), d_y, static_cast(0)); } - // Caculate the gradient of X and Y completly. + // Caculate the X@Grad and Y@Grad. if (d_x || d_y) { - for (int i = 1; i < weight_dims[0]; ++i) { - Tensor weight_mat = weight->Slice(i, i + 1).Resize( + Eigen::DSizes bcast_for_x(1, weight_dims[2]); + Eigen::DSizes bcast_for_y(1, weight_dims[1]); + for (int i = 0; i < weight_dims[0]; ++i) { + Tensor weight_i = weight->Slice(i, i + 1).Resize( framework::make_ddim({weight_dims[1], weight_dims[2]})); + auto output_vec = d_out_mat.chip(i, 1); if (d_x) { - trans(ctx.device_context(), y_ptr, y_ptr + numel_y, y_scale_ptr, - ScaleFunctor(&d_out_ptr[i])); + y_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_x) * + y_mat; math::gemm(ctx.device_context(), CblasNoTrans, CblasTrans, - 1, weight_dims[1], weight_dims[2], 1, - y_scale.data(), weight_mat.data(), 1, + batch_size, weight_dims[1], weight_dims[2], 1, + y_scale.data(), weight_i.data(), 1, d_x->data()); } if (d_y) { - trans(ctx.device_context(), x_ptr, x_ptr + numel_x, x_scale_ptr, - ScaleFunctor(&d_out_ptr[i])); - math::gemm(ctx.device_context(), CblasTrans, CblasNoTrans, - weight_dims[2], 1, weight_dims[1], 1, - weight_mat.data(), x_scale.data(), 1, + x_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_y) * + x_mat; + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, + batch_size, weight_dims[2], weight_dims[1], 1, + x_scale.data(), weight_i.data(), 1, d_y->data()); } } @@ -152,22 +160,27 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { // Caculate the gradient of Weight. if (d_weight) { d_weight->mutable_data(ctx.GetPlace()); + Eigen::DSizes bcast_for_weight(1, weight_dims[1]); for (int i = 0; i < weight_dims[0]; ++i) { - Tensor d_weight_mat = d_weight->Slice(i, i + 1).Resize( + Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize( framework::make_ddim({weight_dims[1], weight_dims[2]})); - trans(ctx.device_context(), x_ptr, x_ptr + numel_x, x_scale_ptr, - ScaleFunctor(&d_out_ptr[i])); + auto output_vec = d_out_mat.chip(i, 1); + x_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_weight) * + x_mat; math::gemm(ctx.device_context(), CblasTrans, CblasNoTrans, - weight_dims[1], weight_dims[2], 1, 1, + weight_dims[1], weight_dims[2], batch_size, 1, x_scale.data(), y->data(), 0, - d_weight_mat.data()); + d_weight_i.data()); } } // Caculate the gradient of Bias. 
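+ // With a batch dimension, Bias@Grad is Out@GRAD summed over dimension 0 (the batch).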
if (d_bias) { d_bias->mutable_data(ctx.GetPlace()); - d_bias->CopyFrom(*d_out, ctx.GetPlace(), ctx.device_context()); + auto d_bias_mat = EigenMatrix::From(*d_bias); + d_bias_mat.device(place) = d_out_mat.sum(Eigen::DSizes(0)); } } }; diff --git a/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py b/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py index 10d90a9f0f9f8..1c1f388098065 100644 --- a/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py +++ b/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py @@ -6,24 +6,85 @@ class TestBilinearTensorProductOp(OpTest): def setUp(self): self.op_type = "bilinear_tensor_product" + batch_size = 6 + size0 = 3 + size1 = 4 + size2 = 5 + a = np.random.random((batch_size, size0)).astype("float32") + b = np.random.random((batch_size, size1)).astype("float32") + w = np.random.random((size2, size0, size1)).astype("float32") + bias = np.random.random((1, size2)).astype("float32") + output = np.zeros((batch_size, size2)).astype("float32") + for i in range(size2): + w_i = w[i, :, :] + output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1) self.inputs = { - 'X': np.random.random(3).astype("float32"), - 'Y': np.random.random(4).astype("float32"), - 'Weight': np.random.random((5, 3, 4)).astype("float32"), - 'Bias': np.random.random(5).astype("float32") + 'X': a, + 'Y': b, + 'Weight': w, + 'Bias': bias, } - self.outputs = { - 'Out': np.matmul( - np.matmul(self.inputs['Weight'], self.inputs['Y']), - self.inputs['X']) + self.inputs['Bias'] + self.outputs = {'Out': output + bias} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y', 'Weight', 'Bias'], 'Out') + + +class TestBilinearTensorProductOp2(TestBilinearTensorProductOp): + def setUp(self): + self.op_type = "bilinear_tensor_product" + batch_size = 1 + size0 = 1 + size1 = 1 + size2 = 1 + a = np.random.random((batch_size, size0)).astype("float32") + b = np.random.random((batch_size, size1)).astype("float32") + w = np.random.random((size2, size0, size1)).astype("float32") + bias = np.random.random((1, size2)).astype("float32") + output = np.zeros((batch_size, size2)).astype("float32") + for i in range(size2): + w_i = w[i, :, :] + output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1) + self.inputs = { + 'X': a, + 'Y': b, + 'Weight': w, + 'Bias': bias, } + self.outputs = {'Out': output + bias} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y', 'Weight', 'Bias'], 'Out') + + +class TestBilinearTensorProductOp3(TestBilinearTensorProductOp): + def setUp(self): + self.op_type = "bilinear_tensor_product" + batch_size = 7 + size0 = 4 + size1 = 5 + size2 = 6 + a = np.random.random((batch_size, size0)).astype("float32") + b = np.random.random((batch_size, size1)).astype("float32") + w = np.random.random((size2, size0, size1)).astype("float32") + output = np.zeros((batch_size, size2)).astype("float32") + for i in range(size2): + w_i = w[i, :, :] + output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1) + self.inputs = {'X': a, 'Y': b, 'Weight': w} + self.outputs = {'Out': output} def test_check_output(self): self.check_output() def test_check_grad_normal(self): - self.check_grad( - ['X', 'Y', 'Weight', 'Bias'], 'Out', max_relative_error=0.5) + self.check_grad(['X', 'Y', 'Weight'], 'Out') if __name__ == "__main__": From 47269273ff15afc0156939de46f800a15def609c Mon Sep 17 00:00:00 2001 From: peterzhang2029 Date: Wed, 8 Nov 2017 
14:53:21 +0800 Subject: [PATCH 3/7] refine memory transform --- .../operators/bilinear_tensor_product_op.cc | 64 +++++++------ .../operators/bilinear_tensor_product_op.cu | 95 ++----------------- paddle/operators/bilinear_tensor_product_op.h | 37 +++----- 3 files changed, 58 insertions(+), 138 deletions(-) diff --git a/paddle/operators/bilinear_tensor_product_op.cc b/paddle/operators/bilinear_tensor_product_op.cc index afb9678b64943..dc02e5811e833 100644 --- a/paddle/operators/bilinear_tensor_product_op.cc +++ b/paddle/operators/bilinear_tensor_product_op.cc @@ -34,34 +34,34 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel { auto y_dims = ctx->GetInputDim("Y"); auto weight_dims = ctx->GetInputDim("Weight"); - PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The input X must be a 2D Tensor."); - PADDLE_ENFORCE_EQ(y_dims.size(), 2, "The input Y must be a 2D Tensor."); - PADDLE_ENFORCE_EQ(weight_dims.size(), 3, + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The input X must be a 2D Tensor."); + PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The input Y must be a 2D Tensor."); + PADDLE_ENFORCE_EQ(weight_dims.size(), 3UL, "The input Weight must be a 3D tensor."); - PADDLE_ENFORCE_GT(weight_dims[0], 0, - "The first dimension of Weight must be larger than 0."); - PADDLE_ENFORCE_GT(weight_dims[1], 0, - "The second dimension of Weight must be larger than 0."); - PADDLE_ENFORCE_GT(weight_dims[2], 0, - "The third dimension of Weight must be larger than 0."); + PADDLE_ENFORCE(weight_dims[0], + "The first dimension of Weight must be larger than 0."); + PADDLE_ENFORCE(weight_dims[1], + "The second dimension of Weight must be larger than 0."); + PADDLE_ENFORCE(weight_dims[2], + "The third dimension of Weight must be larger than 0."); PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], "The first dimension(batch_size) of X must be " - "equal with the first dimension of the Y."); + "equal to the first dimension of the Y."); PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1], - "The second dimension of X must be equal with the second " + "The second dimension of X must be equal to the second " "dimension of the Weight."); PADDLE_ENFORCE_EQ(y_dims[1], weight_dims[2], - "The second dimension of Y must be equal with the third " + "The second dimension of Y must be equal to the third " "dimension of the Weight."); if (ctx->HasInput("Bias")) { auto bias_dims = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE_EQ(bias_dims.size(), 2, + PADDLE_ENFORCE_EQ(bias_dims.size(), 2UL, "The input Bias must have 2 dimensions."); - PADDLE_ENFORCE_EQ(bias_dims[0], 1, + PADDLE_ENFORCE_EQ(bias_dims[0], 1UL, "The first dimention of input Bias must be 1."); PADDLE_ENFORCE_EQ(bias_dims[1], weight_dims[0], - "The second dimension of Bias must be equal with the " + "The second dimension of Bias must be equal to the " "first dimension of the Weight."); } @@ -75,12 +75,12 @@ class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker { BilinearTensorProductOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The first input of BilinearTensorProduct op"); - AddInput("Y", "The second input of BilinearTensorProduct op"); - AddInput("Weight", "The input weight of BilinearTensorProduct op"); - AddInput("Bias", "The input bias of BilinearTensorProduct op") + AddInput("X", "The first input of BilinearTensorProduct op."); + AddInput("Y", "The second input of BilinearTensorProduct op."); + AddInput("Weight", "The input weight of BilinearTensorProduct op."); + 
AddInput("Bias", "The input bias of BilinearTensorProduct op.") .AsDispensable(); - AddOutput("Out", "The output of BilinearTensorProduct op"); + AddOutput("Out", "The output of BilinearTensorProduct op."); AddComment(R"DOC( Bilinear Tensor Product operator. Given input X and Y, a 3D tensor weight, and bias. Each column of the @@ -99,30 +99,32 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); - PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); - PADDLE_ENFORCE(ctx->HasInput("Weight"), "Input(Weight) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(Weight) should not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "Input (Out@GRAD) should not be null"); + "Input (Out@GRAD) should not be null."); auto x_dims = ctx->GetInputDim("X"); auto y_dims = ctx->GetInputDim("Y"); auto weight_dims = ctx->GetInputDim("Weight"); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); - PADDLE_ENFORCE_EQ(out_dims.size(), 2, "The Out@GRAD must be a 2D Tensor."); + PADDLE_ENFORCE_EQ(out_dims.size(), 2UL, + "The Out@GRAD must be a 2D Tensor."); PADDLE_ENFORCE_EQ( x_dims[0], out_dims[0], - "The first dimension(batch_size) of Out@GRAD must be equal with " - "the first dimension of the X."); + "The first dimension(batch_size) of Out@GRAD must be equal to " + "the first dimension of the Input(X)."); PADDLE_ENFORCE_EQ(weight_dims[0], out_dims[1], - "The second dimension of Out@GRAD must be equal with " - "the third dimension of the Weight."); + "The second dimension of Out@GRAD must be equal to " + "the third dimension of the Input(Weight)."); if (ctx->HasInput("Bias")) { auto bias_dims = ctx->GetInputDim("Bias"); PADDLE_ENFORCE_EQ(bias_dims[1], out_dims[1], - "The second dimension of Bias must be equal with " - "the second dimension of the Out@GRAD."); + "The second dimension of Out@GRAD must be equal to " + "the second dimension of the Input(Bias)."); auto bias_grad_name = framework::GradVarName("Bias"); if (ctx->HasOutput(bias_grad_name)) ctx->SetOutputDim(bias_grad_name, bias_dims); diff --git a/paddle/operators/bilinear_tensor_product_op.cu b/paddle/operators/bilinear_tensor_product_op.cu index 1afdfe4b110d3..0f28a01c87e65 100644 --- a/paddle/operators/bilinear_tensor_product_op.cu +++ b/paddle/operators/bilinear_tensor_product_op.cu @@ -1,99 +1,24 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
*/ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/operators/bilinear_tensor_product_op.h" -namespace paddle { -namespace operators { - -template -class BilinearTensorProductCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* weight = ctx.Input("Weight"); - auto* bias = ctx.Input("Bias"); - auto* out = ctx.Output("Out"); - out->mutable_data(ctx.GetPlace()); - - auto y_mat = EigenMatrix::From(*y); - auto batch_size = x->dims()[0]; - auto weight_dims = weight->dims(); - - auto place = ctx.GetEigenDevice(); - auto cpu_place = ctx.GetEigenDevice(); - - // Copy the output to cpu. - Tensor output_cpu; - output_cpu.CopyFrom(*out, platform::CPUPlace(), ctx.device_context()); - auto* output_cpu_ptr = output_cpu.data(); - auto output_cpu_mat = EigenMatrix::From(output_cpu); - - // Create the temporary variables. - Tensor left_mul; - left_mul.mutable_data(framework::make_ddim({batch_size, weight_dims[2]}), - ctx.GetPlace()); - auto left_mul_mat = EigenMatrix::From(left_mul); - Tensor output_col; - output_col.mutable_data(framework::make_ddim({batch_size}), - ctx.GetPlace()); - auto output_col_vec = EigenVector::From(output_col); - - for (size_t i = 0; i < weight_dims[0]; ++i) { - Tensor weight_mat = weight->Slice(i, i + 1).Resize( - framework::make_ddim({weight_dims[1], weight_dims[2]})); - math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, - batch_size, weight_dims[2], weight_dims[1], 1, - x->data(), weight_mat.data(), 0, - left_mul.data()); - output_col_vec.device(place) = - (left_mul_mat * y_mat).sum(Eigen::DSizes(1)); - - // Copy the output_col to cpu. - Tensor output_col_cpu; - output_col_cpu.CopyFrom(output_col, platform::CPUPlace(), - ctx.device_context()); - auto* output_col_ptr = output_col_cpu.data(); - - for (size_t j = 0; j < batch_size; ++j) { - output_cpu_ptr[i + j * weight_dims[0]] = output_col_ptr[j]; - } - } - - if (bias) { - // Copy the bias to cpu. - Tensor bias_cpu; - bias_cpu.CopyFrom(*bias, platform::CPUPlace(), ctx.device_context()); - auto bias_vec = EigenMatrix::From(bias_cpu); - Eigen::DSizes bcast(batch_size, 1); - output_cpu_mat.device(cpu_place) = - bias_vec.broadcast(bcast) + output_cpu_mat; - } - - // Copy the output to gpu. - out->CopyFrom(output_cpu, platform::GPUPlace(), ctx.device_context()); - } -}; -} // namespace operators -} // namespace paddle - namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( bilinear_tensor_product, - ops::BilinearTensorProductCUDAKernel); + ops::BilinearTensorProductKernel); REGISTER_OP_GPU_KERNEL( bilinear_tensor_product_grad, ops::BilinearTensorProductGradKernel); diff --git a/paddle/operators/bilinear_tensor_product_op.h b/paddle/operators/bilinear_tensor_product_op.h index 238d1d7749694..6b40f77c4205a 100644 --- a/paddle/operators/bilinear_tensor_product_op.h +++ b/paddle/operators/bilinear_tensor_product_op.h @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - You may not use this file except in compliance with the License. 
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #pragma once @@ -21,7 +21,7 @@ namespace paddle { namespace operators { -using Tensor = framework::Tensor; +using framework::Tensor; template @@ -49,34 +49,27 @@ class BilinearTensorProductKernel : public framework::OpKernel { auto weight_dims = weight->dims(); auto place = ctx.GetEigenDevice(); - // Create the temporary variables. + // Create the intermediate variables. Tensor left_mul; left_mul.mutable_data(framework::make_ddim({batch_size, weight_dims[2]}), ctx.GetPlace()); auto left_mul_mat = EigenMatrix::From(left_mul); - Tensor output_col; - output_col.mutable_data(framework::make_ddim({weight_dims[0]}), - ctx.GetPlace()); - auto output_col_vec = EigenVector::From(output_col); for (size_t i = 0; i < weight_dims[0]; ++i) { + auto output_col_vec = output_mat.chip(i, 1); Tensor weight_mat = weight->Slice(i, i + 1).Resize( framework::make_ddim({weight_dims[1], weight_dims[2]})); math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, batch_size, weight_dims[2], weight_dims[1], 1, x->data(), weight_mat.data(), 0, left_mul.data()); - output_col_vec = (left_mul_mat * y_mat).sum(Eigen::DSizes(1)); - for (size_t j = 0; j < batch_size; ++j) { - output_mat(j, i) = output_col_vec(j); - } + output_col_vec.device(place) = + (left_mul_mat * y_mat).sum(Eigen::DSizes(1)); } if (bias) { auto bias_vec = EigenMatrix::From(*bias); Eigen::DSizes bcast(batch_size, 1); output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat; - } else { - output_mat.device(place) = output_mat; } } }; @@ -102,7 +95,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { auto d_out_mat = EigenMatrix::From(*d_out); auto place = ctx.GetEigenDevice(); - // Create the temporary variables for gradient. + // Create the intermediate variables for gradient. 
Tensor x_scale; x_scale.mutable_data(framework::make_ddim({batch_size, weight_dims[1]}), ctx.GetPlace()); From 5cf8204171bbe11de9bff1eb6b6e59f2ad1a5263 Mon Sep 17 00:00:00 2001 From: peterzhang2029 Date: Thu, 9 Nov 2017 17:30:12 +0800 Subject: [PATCH 4/7] refine docString --- .../operators/bilinear_tensor_product_op.cc | 74 +++++++++---------- .../operators/bilinear_tensor_product_op.cu | 6 +- .../tests/test_bilinear_tensor_product_op.py | 54 -------------- 3 files changed, 40 insertions(+), 94 deletions(-) diff --git a/paddle/operators/bilinear_tensor_product_op.cc b/paddle/operators/bilinear_tensor_product_op.cc index dc02e5811e833..c65ba7eb262f3 100644 --- a/paddle/operators/bilinear_tensor_product_op.cc +++ b/paddle/operators/bilinear_tensor_product_op.cc @@ -34,35 +34,28 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel { auto y_dims = ctx->GetInputDim("Y"); auto weight_dims = ctx->GetInputDim("Weight"); - PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The input X must be a 2D Tensor."); - PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The input Y must be a 2D Tensor."); + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The input(X) must be a 2D Tensor."); + PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The input(Y) must be a 2D Tensor."); PADDLE_ENFORCE_EQ(weight_dims.size(), 3UL, - "The input Weight must be a 3D tensor."); - PADDLE_ENFORCE(weight_dims[0], - "The first dimension of Weight must be larger than 0."); - PADDLE_ENFORCE(weight_dims[1], - "The second dimension of Weight must be larger than 0."); - PADDLE_ENFORCE(weight_dims[2], - "The third dimension of Weight must be larger than 0."); + "The input(Weight) must be a 3D tensor."); PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], - "The first dimension(batch_size) of X must be " - "equal to the first dimension of the Y."); + "The first dimension(batch_size) of input(X) must be " + "equal to the first dimension of the input(Y)."); PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1], - "The second dimension of X must be equal to the second " - "dimension of the Weight."); + "The second dimension of input(X) must be equal to " + "the second dimension of the input(Weight)."); PADDLE_ENFORCE_EQ(y_dims[1], weight_dims[2], - "The second dimension of Y must be equal to the third " - "dimension of the Weight."); + "The second dimension of input(Y) must be equal to " + "the third dimension of the input(Weight)."); if (ctx->HasInput("Bias")) { auto bias_dims = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE_EQ(bias_dims.size(), 2UL, - "The input Bias must have 2 dimensions."); - PADDLE_ENFORCE_EQ(bias_dims[0], 1UL, - "The first dimention of input Bias must be 1."); + PADDLE_ENFORCE(bias_dims.size() == 2UL && bias_dims[0] == 1UL, + "The Input(Bias) must be a 2-D tensor with " + "the 2nd dimension fixed to 1 (a row vector)."); PADDLE_ENFORCE_EQ(bias_dims[1], weight_dims[0], - "The second dimension of Bias must be equal to the " - "first dimension of the Weight."); + "The second dimension of input(Bias) must be equal " + "to the first dimension of the input(Weight)."); } ctx->SetOutputDim("Out", {x_dims[0], weight_dims[0]}); @@ -75,12 +68,13 @@ class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker { BilinearTensorProductOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The first input of BilinearTensorProduct op."); - AddInput("Y", "The second input of BilinearTensorProduct op."); - AddInput("Weight", "The input weight of BilinearTensorProduct op."); - 
AddInput("Bias", "The input bias of BilinearTensorProduct op.") + AddInput("X", "The first input of bilinear_tensor_product operator."); + AddInput("Y", "The second input of bilinear_tensor_product operator."); + AddInput("Weight", + "The learnable parameters of bilinear_tensor_product operator."); + AddInput("Bias", "The learnable bias of bilinear_tensor_product operator.") .AsDispensable(); - AddOutput("Out", "The output of BilinearTensorProduct op."); + AddOutput("Out", "The output of bilinear_tensor_product operator."); AddComment(R"DOC( Bilinear Tensor Product operator. Given input X and Y, a 3D tensor weight, and bias. Each column of the @@ -104,27 +98,29 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("Weight"), "Input(Weight) should not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "Input (Out@GRAD) should not be null."); + "Input(Out@GRAD) should not be null."); auto x_dims = ctx->GetInputDim("X"); auto y_dims = ctx->GetInputDim("Y"); auto weight_dims = ctx->GetInputDim("Weight"); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); PADDLE_ENFORCE_EQ(out_dims.size(), 2UL, - "The Out@GRAD must be a 2D Tensor."); + "The input(Out@GRAD) must be a 2D Tensor."); PADDLE_ENFORCE_EQ( x_dims[0], out_dims[0], - "The first dimension(batch_size) of Out@GRAD must be equal to " - "the first dimension of the Input(X)."); - PADDLE_ENFORCE_EQ(weight_dims[0], out_dims[1], - "The second dimension of Out@GRAD must be equal to " - "the third dimension of the Input(Weight)."); + "The first dimension(batch_size) of input(Out@GRAD) must be " + "equal to the first dimension of the Input(X)."); + PADDLE_ENFORCE_EQ( + weight_dims[0], out_dims[1], + "The second dimension of input(Out@GRAD) must be equal to " + "the third dimension of the Input(Weight)."); if (ctx->HasInput("Bias")) { auto bias_dims = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE_EQ(bias_dims[1], out_dims[1], - "The second dimension of Out@GRAD must be equal to " - "the second dimension of the Input(Bias)."); + PADDLE_ENFORCE_EQ( + bias_dims[1], out_dims[1], + "The second dimension of input(Out@GRAD) must be equal to " + "the second dimension of the Input(Bias)."); auto bias_grad_name = framework::GradVarName("Bias"); if (ctx->HasOutput(bias_grad_name)) ctx->SetOutputDim(bias_grad_name, bias_dims); @@ -155,7 +151,9 @@ REGISTER_OP(bilinear_tensor_product, ops::BilinearTensorProductOp, ops::BilinearTensorProductOpGrad); REGISTER_OP_CPU_KERNEL( bilinear_tensor_product, - ops::BilinearTensorProductKernel); + ops::BilinearTensorProductKernel, + ops::BilinearTensorProductKernel); REGISTER_OP_CPU_KERNEL( bilinear_tensor_product_grad, - ops::BilinearTensorProductGradKernel); + ops::BilinearTensorProductGradKernel, + ops::BilinearTensorProductGradKernel); diff --git a/paddle/operators/bilinear_tensor_product_op.cu b/paddle/operators/bilinear_tensor_product_op.cu index 0f28a01c87e65..858d2668d0137 100644 --- a/paddle/operators/bilinear_tensor_product_op.cu +++ b/paddle/operators/bilinear_tensor_product_op.cu @@ -18,7 +18,9 @@ limitations under the License. 
*/ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( bilinear_tensor_product, - ops::BilinearTensorProductKernel); + ops::BilinearTensorProductKernel, + ops::BilinearTensorProductKernel); REGISTER_OP_GPU_KERNEL( bilinear_tensor_product_grad, - ops::BilinearTensorProductGradKernel); + ops::BilinearTensorProductGradKernel, + ops::BilinearTensorProductGradKernel); diff --git a/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py b/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py index 1c1f388098065..080ca43b8269e 100644 --- a/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py +++ b/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py @@ -33,59 +33,5 @@ def test_check_grad_normal(self): self.check_grad(['X', 'Y', 'Weight', 'Bias'], 'Out') -class TestBilinearTensorProductOp2(TestBilinearTensorProductOp): - def setUp(self): - self.op_type = "bilinear_tensor_product" - batch_size = 1 - size0 = 1 - size1 = 1 - size2 = 1 - a = np.random.random((batch_size, size0)).astype("float32") - b = np.random.random((batch_size, size1)).astype("float32") - w = np.random.random((size2, size0, size1)).astype("float32") - bias = np.random.random((1, size2)).astype("float32") - output = np.zeros((batch_size, size2)).astype("float32") - for i in range(size2): - w_i = w[i, :, :] - output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1) - self.inputs = { - 'X': a, - 'Y': b, - 'Weight': w, - 'Bias': bias, - } - self.outputs = {'Out': output + bias} - - def test_check_output(self): - self.check_output() - - def test_check_grad_normal(self): - self.check_grad(['X', 'Y', 'Weight', 'Bias'], 'Out') - - -class TestBilinearTensorProductOp3(TestBilinearTensorProductOp): - def setUp(self): - self.op_type = "bilinear_tensor_product" - batch_size = 7 - size0 = 4 - size1 = 5 - size2 = 6 - a = np.random.random((batch_size, size0)).astype("float32") - b = np.random.random((batch_size, size1)).astype("float32") - w = np.random.random((size2, size0, size1)).astype("float32") - output = np.zeros((batch_size, size2)).astype("float32") - for i in range(size2): - w_i = w[i, :, :] - output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1) - self.inputs = {'X': a, 'Y': b, 'Weight': w} - self.outputs = {'Out': output} - - def test_check_output(self): - self.check_output() - - def test_check_grad_normal(self): - self.check_grad(['X', 'Y', 'Weight'], 'Out') - - if __name__ == "__main__": unittest.main() From 5f99ae908b5fac433df28cc806d5514a6054b26c Mon Sep 17 00:00:00 2001 From: peterzhang2029 Date: Mon, 13 Nov 2017 13:44:12 +0800 Subject: [PATCH 5/7] refine notation in bilinear_tensor_product_op.h --- paddle/operators/bilinear_tensor_product_op.h | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/paddle/operators/bilinear_tensor_product_op.h b/paddle/operators/bilinear_tensor_product_op.h index 6b40f77c4205a..29da5f4d2a658 100644 --- a/paddle/operators/bilinear_tensor_product_op.h +++ b/paddle/operators/bilinear_tensor_product_op.h @@ -27,10 +27,6 @@ template using EigenMatrix = framework::EigenMatrix; -template -using EigenVector = framework::EigenVector; - template class BilinearTensorProductKernel : public framework::OpKernel { public: @@ -49,7 +45,9 @@ class BilinearTensorProductKernel : public framework::OpKernel { auto weight_dims = weight->dims(); auto place = ctx.GetEigenDevice(); - // Create the intermediate variables. 
+ // Create the intermediate variable to caculate the result of + // Input(X) multiplied by Input(Weight_i), the formula is: + // left_mul = X Weight_i. Tensor left_mul; left_mul.mutable_data(framework::make_ddim({batch_size, weight_dims[2]}), ctx.GetPlace()); @@ -95,11 +93,13 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { auto d_out_mat = EigenMatrix::From(*d_out); auto place = ctx.GetEigenDevice(); - // Create the intermediate variables for gradient. + // Create the intermediate variable to caculate the Output(Y@Grad). Tensor x_scale; x_scale.mutable_data(framework::make_ddim({batch_size, weight_dims[1]}), ctx.GetPlace()); auto x_scale_mat = EigenMatrix::From(x_scale); + + // Create the intermediate variable to caculate the Output(X@Grad). Tensor y_scale; y_scale.mutable_data(framework::make_ddim({batch_size, weight_dims[2]}), ctx.GetPlace()); @@ -107,19 +107,19 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { math::SetConstant set_zero; - // Set X@Grad be zero at first. + // Set Output(X@Grad) be zero. if (d_x) { d_x->mutable_data(ctx.GetPlace()); set_zero(ctx.device_context(), d_x, static_cast(0)); } - // Set Y@Grad be zero at first. + // Set Output(Y@Grad) be zero. if (d_y) { d_y->mutable_data(ctx.GetPlace()); set_zero(ctx.device_context(), d_y, static_cast(0)); } - // Caculate the X@Grad and Y@Grad. + // Caculate the Output(X@Grad) and Output(Y@Grad). if (d_x || d_y) { Eigen::DSizes bcast_for_x(1, weight_dims[2]); Eigen::DSizes bcast_for_y(1, weight_dims[1]); @@ -150,7 +150,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { } } - // Caculate the gradient of Weight. + // Caculate the gradient of Input(Weight). if (d_weight) { d_weight->mutable_data(ctx.GetPlace()); Eigen::DSizes bcast_for_weight(1, weight_dims[1]); @@ -169,7 +169,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { } } - // Caculate the gradient of Bias. + // Caculate the gradient of Input(Bias). if (d_bias) { d_bias->mutable_data(ctx.GetPlace()); auto d_bias_mat = EigenMatrix::From(*d_bias); From 0a6262d550c784548ee78719a46b748d89adc0bd Mon Sep 17 00:00:00 2001 From: peterzhang2029 Date: Mon, 13 Nov 2017 18:45:43 +0800 Subject: [PATCH 6/7] fix warning --- paddle/operators/bilinear_tensor_product_op.h | 52 ++++++++++--------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/paddle/operators/bilinear_tensor_product_op.h b/paddle/operators/bilinear_tensor_product_op.h index 29da5f4d2a658..984e7abdfb143 100644 --- a/paddle/operators/bilinear_tensor_product_op.h +++ b/paddle/operators/bilinear_tensor_product_op.h @@ -43,24 +43,26 @@ class BilinearTensorProductKernel : public framework::OpKernel { auto batch_size = x->dims()[0]; auto weight_dims = weight->dims(); + int Out_dim = weight_dims[0]; + int X_dim = weight_dims[1]; + int Y_dim = weight_dims[2]; auto place = ctx.GetEigenDevice(); // Create the intermediate variable to caculate the result of // Input(X) multiplied by Input(Weight_i), the formula is: // left_mul = X Weight_i. 
Tensor left_mul; - left_mul.mutable_data(framework::make_ddim({batch_size, weight_dims[2]}), + left_mul.mutable_data(framework::make_ddim({batch_size, Y_dim}), ctx.GetPlace()); auto left_mul_mat = EigenMatrix::From(left_mul); - for (size_t i = 0; i < weight_dims[0]; ++i) { + for (int i = 0; i < Out_dim; ++i) { auto output_col_vec = output_mat.chip(i, 1); - Tensor weight_mat = weight->Slice(i, i + 1).Resize( - framework::make_ddim({weight_dims[1], weight_dims[2]})); + Tensor weight_mat = + weight->Slice(i, i + 1).Resize(framework::make_ddim({X_dim, Y_dim})); math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, - batch_size, weight_dims[2], weight_dims[1], 1, - x->data(), weight_mat.data(), 0, - left_mul.data()); + batch_size, Y_dim, X_dim, 1, x->data(), + weight_mat.data(), 0, left_mul.data()); output_col_vec.device(place) = (left_mul_mat * y_mat).sum(Eigen::DSizes(1)); } @@ -87,6 +89,9 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { auto batch_size = x->dims()[0]; auto weight_dims = weight->dims(); + int Out_dim = weight_dims[0]; + int X_dim = weight_dims[1]; + int Y_dim = weight_dims[2]; auto x_mat = EigenMatrix::From(*x); auto y_mat = EigenMatrix::From(*y); @@ -95,13 +100,13 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { // Create the intermediate variable to caculate the Output(Y@Grad). Tensor x_scale; - x_scale.mutable_data(framework::make_ddim({batch_size, weight_dims[1]}), + x_scale.mutable_data(framework::make_ddim({batch_size, X_dim}), ctx.GetPlace()); auto x_scale_mat = EigenMatrix::From(x_scale); // Create the intermediate variable to caculate the Output(X@Grad). Tensor y_scale; - y_scale.mutable_data(framework::make_ddim({batch_size, weight_dims[2]}), + y_scale.mutable_data(framework::make_ddim({batch_size, Y_dim}), ctx.GetPlace()); auto y_scale_mat = EigenMatrix::From(y_scale); @@ -121,11 +126,11 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { // Caculate the Output(X@Grad) and Output(Y@Grad). if (d_x || d_y) { - Eigen::DSizes bcast_for_x(1, weight_dims[2]); - Eigen::DSizes bcast_for_y(1, weight_dims[1]); - for (int i = 0; i < weight_dims[0]; ++i) { + Eigen::DSizes bcast_for_x(1, Y_dim); + Eigen::DSizes bcast_for_y(1, X_dim); + for (int i = 0; i < Out_dim; ++i) { Tensor weight_i = weight->Slice(i, i + 1).Resize( - framework::make_ddim({weight_dims[1], weight_dims[2]})); + framework::make_ddim({X_dim, Y_dim})); auto output_vec = d_out_mat.chip(i, 1); if (d_x) { y_scale_mat.device(place) = @@ -133,9 +138,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { .broadcast(bcast_for_x) * y_mat; math::gemm(ctx.device_context(), CblasNoTrans, CblasTrans, - batch_size, weight_dims[1], weight_dims[2], 1, - y_scale.data(), weight_i.data(), 1, - d_x->data()); + batch_size, X_dim, Y_dim, 1, y_scale.data(), + weight_i.data(), 1, d_x->data()); } if (d_y) { x_scale_mat.device(place) = @@ -143,9 +147,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { .broadcast(bcast_for_y) * x_mat; math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, - batch_size, weight_dims[2], weight_dims[1], 1, - x_scale.data(), weight_i.data(), 1, - d_y->data()); + batch_size, Y_dim, X_dim, 1, x_scale.data(), + weight_i.data(), 1, d_y->data()); } } } @@ -153,19 +156,18 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { // Caculate the gradient of Input(Weight). 
if (d_weight) { d_weight->mutable_data(ctx.GetPlace()); - Eigen::DSizes bcast_for_weight(1, weight_dims[1]); - for (int i = 0; i < weight_dims[0]; ++i) { + Eigen::DSizes bcast_for_weight(1, X_dim); + for (int i = 0; i < Out_dim; ++i) { Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize( - framework::make_ddim({weight_dims[1], weight_dims[2]})); + framework::make_ddim({X_dim, Y_dim})); auto output_vec = d_out_mat.chip(i, 1); x_scale_mat.device(place) = output_vec.reshape(Eigen::DSizes(batch_size, 1)) .broadcast(bcast_for_weight) * x_mat; math::gemm(ctx.device_context(), CblasTrans, CblasNoTrans, - weight_dims[1], weight_dims[2], batch_size, 1, - x_scale.data(), y->data(), 0, - d_weight_i.data()); + X_dim, Y_dim, batch_size, 1, x_scale.data(), + y->data(), 0, d_weight_i.data()); } } From c5d7107767a1a42f46e7d0bf42ef26279fd562db Mon Sep 17 00:00:00 2001 From: peterzhang2029 Date: Mon, 13 Nov 2017 20:17:38 +0800 Subject: [PATCH 7/7] refine var name --- paddle/operators/bilinear_tensor_product_op.h | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/paddle/operators/bilinear_tensor_product_op.h b/paddle/operators/bilinear_tensor_product_op.h index 984e7abdfb143..ffa4f43a32741 100644 --- a/paddle/operators/bilinear_tensor_product_op.h +++ b/paddle/operators/bilinear_tensor_product_op.h @@ -43,25 +43,25 @@ class BilinearTensorProductKernel : public framework::OpKernel { auto batch_size = x->dims()[0]; auto weight_dims = weight->dims(); - int Out_dim = weight_dims[0]; - int X_dim = weight_dims[1]; - int Y_dim = weight_dims[2]; + int out_dim = weight_dims[0]; + auto x_dim = weight_dims[1]; + auto y_dim = weight_dims[2]; auto place = ctx.GetEigenDevice(); // Create the intermediate variable to caculate the result of // Input(X) multiplied by Input(Weight_i), the formula is: // left_mul = X Weight_i. Tensor left_mul; - left_mul.mutable_data(framework::make_ddim({batch_size, Y_dim}), + left_mul.mutable_data(framework::make_ddim({batch_size, y_dim}), ctx.GetPlace()); auto left_mul_mat = EigenMatrix::From(left_mul); - for (int i = 0; i < Out_dim; ++i) { + for (int i = 0; i < out_dim; ++i) { auto output_col_vec = output_mat.chip(i, 1); Tensor weight_mat = - weight->Slice(i, i + 1).Resize(framework::make_ddim({X_dim, Y_dim})); + weight->Slice(i, i + 1).Resize(framework::make_ddim({x_dim, y_dim})); math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, - batch_size, Y_dim, X_dim, 1, x->data(), + batch_size, y_dim, x_dim, 1, x->data(), weight_mat.data(), 0, left_mul.data()); output_col_vec.device(place) = (left_mul_mat * y_mat).sum(Eigen::DSizes(1)); @@ -89,9 +89,9 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { auto batch_size = x->dims()[0]; auto weight_dims = weight->dims(); - int Out_dim = weight_dims[0]; - int X_dim = weight_dims[1]; - int Y_dim = weight_dims[2]; + int out_dim = weight_dims[0]; + auto x_dim = weight_dims[1]; + auto y_dim = weight_dims[2]; auto x_mat = EigenMatrix::From(*x); auto y_mat = EigenMatrix::From(*y); @@ -100,13 +100,13 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { // Create the intermediate variable to caculate the Output(Y@Grad). Tensor x_scale; - x_scale.mutable_data(framework::make_ddim({batch_size, X_dim}), + x_scale.mutable_data(framework::make_ddim({batch_size, x_dim}), ctx.GetPlace()); auto x_scale_mat = EigenMatrix::From(x_scale); // Create the intermediate variable to caculate the Output(X@Grad). 
Tensor y_scale; - y_scale.mutable_data(framework::make_ddim({batch_size, Y_dim}), + y_scale.mutable_data(framework::make_ddim({batch_size, y_dim}), ctx.GetPlace()); auto y_scale_mat = EigenMatrix::From(y_scale); @@ -126,11 +126,11 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { // Caculate the Output(X@Grad) and Output(Y@Grad). if (d_x || d_y) { - Eigen::DSizes bcast_for_x(1, Y_dim); - Eigen::DSizes bcast_for_y(1, X_dim); - for (int i = 0; i < Out_dim; ++i) { + Eigen::DSizes bcast_for_x(1, y_dim); + Eigen::DSizes bcast_for_y(1, x_dim); + for (int i = 0; i < out_dim; ++i) { Tensor weight_i = weight->Slice(i, i + 1).Resize( - framework::make_ddim({X_dim, Y_dim})); + framework::make_ddim({x_dim, y_dim})); auto output_vec = d_out_mat.chip(i, 1); if (d_x) { y_scale_mat.device(place) = @@ -138,7 +138,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { .broadcast(bcast_for_x) * y_mat; math::gemm(ctx.device_context(), CblasNoTrans, CblasTrans, - batch_size, X_dim, Y_dim, 1, y_scale.data(), + batch_size, x_dim, y_dim, 1, y_scale.data(), weight_i.data(), 1, d_x->data()); } if (d_y) { @@ -147,7 +147,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { .broadcast(bcast_for_y) * x_mat; math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, - batch_size, Y_dim, X_dim, 1, x_scale.data(), + batch_size, y_dim, x_dim, 1, x_scale.data(), weight_i.data(), 1, d_y->data()); } } @@ -156,17 +156,17 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { // Caculate the gradient of Input(Weight). if (d_weight) { d_weight->mutable_data(ctx.GetPlace()); - Eigen::DSizes bcast_for_weight(1, X_dim); - for (int i = 0; i < Out_dim; ++i) { + Eigen::DSizes bcast_for_weight(1, x_dim); + for (int i = 0; i < out_dim; ++i) { Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize( - framework::make_ddim({X_dim, Y_dim})); + framework::make_ddim({x_dim, y_dim})); auto output_vec = d_out_mat.chip(i, 1); x_scale_mat.device(place) = output_vec.reshape(Eigen::DSizes(batch_size, 1)) .broadcast(bcast_for_weight) * x_mat; math::gemm(ctx.device_context(), CblasTrans, CblasNoTrans, - X_dim, Y_dim, batch_size, 1, x_scale.data(), + x_dim, y_dim, batch_size, 1, x_scale.data(), y->data(), 0, d_weight_i.data()); } }