diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake index 25c6b4ef52d3f..9686df0021900 100644 --- a/cmake/external/mkldnn.cmake +++ b/cmake/external/mkldnn.cmake @@ -51,7 +51,7 @@ ExternalProject_Add( ${EXTERNAL_PROJECT_LOG_ARGS} DEPENDS ${MKLDNN_DEPENDS} GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git" - GIT_TAG "v0.9" + GIT_TAG "v0.10" PREFIX ${MKLDNN_SOURCES_DIR} UPDATE_COMMAND "" CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR} diff --git a/cmake/external/mklml.cmake b/cmake/external/mklml.cmake index e9fd3d4bedc98..74f3279831357 100644 --- a/cmake/external/mklml.cmake +++ b/cmake/external/mklml.cmake @@ -28,7 +28,7 @@ INCLUDE(ExternalProject) SET(MKLML_PROJECT "extern_mklml") SET(MKLML_VER "mklml_lnx_2018.0.20170720") -SET(MKLML_URL "https://github.com/01org/mkl-dnn/releases/download/v0.9/${MKLML_VER}.tgz") +SET(MKLML_URL "https://github.com/01org/mkl-dnn/releases/download/v0.10/${MKLML_VER}.tgz") SET(MKLML_SOURCE_DIR "${THIRD_PARTY_PATH}/mklml") SET(MKLML_DOWNLOAD_DIR "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}") SET(MKLML_DST_DIR "mklml") @@ -54,7 +54,8 @@ ExternalProject_Add( ${EXTERNAL_PROJECT_LOG_ARGS} PREFIX ${MKLML_SOURCE_DIR} DOWNLOAD_DIR ${MKLML_DOWNLOAD_DIR} - DOWNLOAD_COMMAND wget --no-check-certificate -qO- ${MKLML_URL} | tar xz -C ${MKLML_DOWNLOAD_DIR} + DOWNLOAD_COMMAND wget --no-check-certificate ${MKLML_URL} -c -q -O ${MKLML_VER}.tgz + && tar zxf ${MKLML_VER}.tgz DOWNLOAD_NO_PROGRESS 1 UPDATE_COMMAND "" CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLML_INSTALL_ROOT} diff --git a/doc/about/index_cn.md b/doc/about/index_cn.md deleted file mode 100644 index 3bf030004d4de..0000000000000 --- a/doc/about/index_cn.md +++ /dev/null @@ -1,11 +0,0 @@ -关于PaddlePaddle -================ - -PaddlePaddle是一个最早由百度科学家和工程师共同研发的并行分布式深度学习平台,兼备易用性、高效性、灵活性和可扩展性,目前已被百度内部多个产品线广泛使用。 -PaddlePaddle目前已经开放源码, 但是远未完善,我们希望能在这个基础上不断的改进、扩展和延伸。 -同时我们希望广大开发者积极提供反馈和贡献源代码,建立一个活跃的开源社区。 - -致谢 --------- - -在此,特别感谢PaddlePaddle的[所有贡献者](https://github.com/PaddlePaddle/Paddle/graphs/contributors)。 diff --git a/doc/about/index_en.rst b/doc/about/index_en.rst deleted file mode 100644 index 065c430cdea80..0000000000000 --- a/doc/about/index_en.rst +++ /dev/null @@ -1,14 +0,0 @@ -ABOUT -======= - -PaddlPaddle is an easy-to-use, efficient, flexible and scalable deep learning platform, -which is originally developed by Baidu scientists and engineers for the purpose of applying deep learning to many products at Baidu. - -PaddlePaddle is now open source but far from complete, which is intended to be built upon, improved, scaled, and extended. -We hope to build an active open source community both by providing feedback and by actively contributing to the source code. - - -Credits --------- - -We owe many thanks to `all contributors and developers `_ of PaddlePaddle! diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index 2273c8e8698c1..1329b77bb44f5 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -419,9 +419,14 @@ multi_binary_label_cross_entropy_cost .. autoclass:: paddle.v2.layer.multi_binary_label_cross_entropy_cost :noindex: -huber_cost ----------- -.. autoclass:: paddle.v2.layer.huber_cost +huber_regression_cost +------------------------- +.. autoclass:: paddle.v2.layer.huber_regression_cost + :noindex: + +huber_classification_cost +------------------------- +.. autoclass:: paddle.v2.layer.huber_classification_cost :noindex: lambda_cost diff --git a/doc/getstarted/build_and_install/index_cn.rst b/doc/getstarted/build_and_install/index_cn.rst index a24df6c518fad..dd9923697ab85 100644 --- a/doc/getstarted/build_and_install/index_cn.rst +++ b/doc/getstarted/build_and_install/index_cn.rst @@ -6,14 +6,12 @@ 安装流程 ++++++++ -PaddlePaddle提供数个预编译的二进制来进行安装,包括Docker镜像,ubuntu的deb安装包等。我们推荐使用Docker镜像来部署环境,同时欢迎贡献更多的安装包。 +PaddlePaddle提供Docker镜像来部署环境。 .. toctree:: :maxdepth: 1 docker_install_cn.rst - ubuntu_install_cn.rst - 编译流程 diff --git a/doc/getstarted/build_and_install/index_en.rst b/doc/getstarted/build_and_install/index_en.rst index 1bfd4f75c0b9b..8a53588e0439d 100644 --- a/doc/getstarted/build_and_install/index_en.rst +++ b/doc/getstarted/build_and_install/index_en.rst @@ -8,14 +8,13 @@ Install PaddlePaddle :maxdepth: 1 docker_install_en.rst - ubuntu_install_en.rst Build from Source ----------------- .. warning:: - Please use :code:`deb` package or :code:`docker` image to install paddle. The building guide is used for hacking or contributing PaddlePaddle source code. + Please use :code:`docker` image to install paddle. The building guide is used for hacking or contributing PaddlePaddle source code. .. toctree:: :maxdepth: 1 diff --git a/doc/getstarted/build_and_install/ubuntu_install_cn.rst b/doc/getstarted/build_and_install/ubuntu_install_cn.rst deleted file mode 100644 index 9e39ccb00f5d5..0000000000000 --- a/doc/getstarted/build_and_install/ubuntu_install_cn.rst +++ /dev/null @@ -1,71 +0,0 @@ -Ubuntu部署PaddlePaddle -=================================== - -PaddlePaddle提供了ubuntu 14.04 deb安装包。 - -安装 ------- - -安装包的下载地址是\: https://github.com/PaddlePaddle/Paddle/releases - -它包含四个版本\: - -* cpu版本: 支持主流x86处理器平台, 使用了avx指令集。 - -* cpu-noavx版本:支持主流x86处理器平台,没有使用avx指令集。 - -* gpu版本:支持主流x86处理器平台,支持nvidia cuda平台,使用了avx指令集。 - -* gpu-noavx版本:支持主流x86处理器平台,支持nvidia cuda平台,没有使用avx指令集。 - -下载完相关安装包后,执行: - -.. code-block:: shell - - sudo apt-get install gdebi - gdebi paddle-*-cpu.deb - -或者: - -.. code-block:: shell - - dpkg -i paddle-*-cpu.deb - apt-get install -f - - -在 :code:`dpkg -i` 的时候如果报一些依赖未找到的错误是正常的, -在 :code:`apt-get install -f` 里会继续安装 PaddlePaddle。 - -安装完成后,可以使用命令 :code:`paddle version` 查看安装后的paddle 版本: - -.. code-block:: shell - - PaddlePaddle 0.8.0b1, compiled with - with_avx: ON - with_gpu: OFF - with_double: OFF - with_python: ON - with_rdma: OFF - with_timer: OFF - with_predict_sdk: - - -可能遇到的问题 --------------- - -libcudart.so/libcudnn.so找不到 -++++++++++++++++++++++++++++++ - -安装完成后,运行 :code:`paddle train` 报错\: - -.. code-block:: shell - - 0831 12:36:04.151525 1085 hl_dso_loader.cc:70] Check failed: nullptr != *dso_handle For Gpu version of PaddlePaddle, it couldn't find CUDA library: libcudart.so Please make sure you already specify its path.Note: for training data on Cpu using Gpu version of PaddlePaddle,you must specify libcudart.so via LD_LIBRARY_PATH. - -原因是未设置cuda运行时环境变量。 如果使用GPU版本的PaddlePaddle,请安装CUDA 7.5 和CUDNN 5到本地环境中,并设置: - -.. code-block:: shell - - export LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda/lib:$LD_LIBRARY_PATH - export PATH=/usr/local/cuda/bin:$PATH - diff --git a/doc/getstarted/build_and_install/ubuntu_install_en.rst b/doc/getstarted/build_and_install/ubuntu_install_en.rst deleted file mode 100644 index ea8042085bf45..0000000000000 --- a/doc/getstarted/build_and_install/ubuntu_install_en.rst +++ /dev/null @@ -1,25 +0,0 @@ -Debian Package installation guide -================================= - -PaddlePaddle supports :code:`deb` pacakge. The installation of this :code:`deb` package is tested in ubuntu 14.04, but it should be support other debian based linux, too. - -There are four versions of debian package, :code:`cpu`, :code:`gpu`, :code:`cpu-noavx`, :code:`gpu-noavx`. And :code:`noavx` version is used to support CPU which does not contain :code:`AVX` instructions. The download url of :code:`deb` package is \: https://github.com/baidu/Paddle/releases/ - - -After downloading PaddlePaddle deb packages, you can use :code:`gdebi` install. - -.. code-block:: bash - - gdebi paddle-*.deb - -If :code:`gdebi` is not installed, you can use :code:`sudo apt-get install gdebi` to install it. - -Or you can use following commands to install PaddlePaddle. - -.. code-block:: bash - - dpkg -i paddle-*.deb - apt-get install -f - -And if you use GPU version deb package, you need to install CUDA toolkit and cuDNN, and set related environment variables(such as LD_LIBRARY_PATH) first. It is normal when `dpkg -i` get errors. `apt-get install -f` will continue install paddle, and install dependences. - diff --git a/doc/howto/dev/new_op_cn.md b/doc/howto/dev/new_op_cn.md index ebd2cf3ff0456..7f8da2da5a0d4 100644 --- a/doc/howto/dev/new_op_cn.md +++ b/doc/howto/dev/new_op_cn.md @@ -5,12 +5,13 @@ - [定义ProtoMaker类](#定义ProtoMaker类) - [定义Operator类](#定义Operator类) - [定义OpKernel类](#定义OpKernel类) - - [注册类](#注册类) + - [注册Operator](#注册Operator) - [编译](#编译) - [绑定Python](#绑定Python) - [实现单元测试](#实现单元测试) - [前向Operator单测](#前向Operator单测) - [反向Operator单测](#反向Operator单测) + - [编译和执行](#编译和执行) ## 概念简介 @@ -22,19 +23,17 @@ - `framework::OperatorWithKernel`:继承自OperatorBase,Op有计算函数,称作有Kernel。 - `class OpProtoAndCheckerMaker`:描述该Op的输入、输出、属性、注释,主要用于Python API接口生成 -依据是否包含kernel,将Op分为两种:包含Kernel的Op和不包含kernel的Op,前者Op的定义继承自`OperatorBase`,后者继承自`OperatorWithKernel`。本教程主要介绍带Kernel的Op如何写,简单总结如下: +依据是否包含kernel,将Op分为两种:包含Kernel的Op和不包含kernel的Op,前者Op的定义继承自`OperatorBase`,后者继承自`OperatorWithKernel`。本教程主要介绍带Kernel的Op如何写,简单总结Op需要包含的内容如下: -Forward Op需要包含: - - - OpProtoMake定义 - - Op定义 - - Kernel实现 + + 内容 | 定义位置 +-------------- | :---------------------- +OpProtoMake定义 | `.cc`文件,Backward Op不需要定义OpProtoMake +Op定义 | `.cc`文件 +Kernel实现 | CPU、GPU共享Kernel在`.h`文件,否则,CPU可以在`.cc`文件,GPU可在`.cu`文件。 +注册Op | Op注册在`.cc`文件;Kernel注册CPU在`.cc`文件,GPU在`.cu`文件 + -与之对应的Backward Op包含: - - - Op定义 - - Kernel实现 - 下面以矩阵乘操作,即[MulOp](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc)为例来介绍如何写带Kernel的Operator。 @@ -137,8 +136,9 @@ MulOp(const std::string &type, const framework::VariableNameMap &inputs, ``` 还需要重写`InferShape`接口。`InferShape`为const函数,不能修改Op的成员变量,参数为`const framework::InferShapeContext &ctx`,通过该参数可获取到输入输出以及属性。它的功能是: - - 1). 做检查, 尽早报错:检查输入数据维度、类型等是否合法 - - 2). 设置输出Tensor的形状 + + - 1). 做检查, 尽早报错:检查输入数据维度、类型等是否合法。 + - 2). 设置输出Tensor的形状。 通常`OpProtoMaker`和`Op`类的定义写在`.cc`文件中,和要讲到的注册函数一起放在`.cc`中 @@ -172,7 +172,7 @@ class MulKernel : public framework::OpKernel { 到此前向Op实现完成,需要在`.cc`文件中注册该op和kernel。反向Op类的定义和Kernel定义与前向Op类似,这里不再重复。但注意,反向Op没有`ProtoMaker`。 -### 4. 注册类 +### 4. 注册Operator 在`.cc`文件中注册前向、反向Op类,注册CPU Kernel。 @@ -297,4 +297,28 @@ class TestMulOp(unittest.TestCase): - 调用`create_op("mul")`创建反向Op对应的前向Op。 - 定义输入`inputs`。 - 调用`compare_grad`函数对比CPU、GPU计算结果。 - - 调用`check_grad`检查梯度稳定性。 + - 调用`check_grad`检查梯度稳定性,这里采用数值法检测梯度正确性。 + - 第一个参数`op` : 前向op。 + - 第二个参数`inputs` : 输入词典,词典的Key和`ProtoMaker`定义保持一致。 + - 第三个参数`set(["X", "Y"])` : 指定对输入变量`X`、`Y`做梯度检测。 + - 第四个参数`"Out"` : 指定前向网络最终的输出目标变量`Out` + + +### 编译和执行 + +单测完成之后,在[`python/paddle/v2/framework/tests/CMakeLists.txt`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/framework/tests/CMakeLists.txt)里添加编译: + +``` +py_test(test_mul_op SRCS test_mul_op.py) +``` + +编译时需要打开`WITH_TESTING`, 即 `cmake paddle_dir -DWITH_TESTING=ON`,编译成功之后执行单测命令为: + +``` +make test ARGS="-R test_mul_op -V" +``` +或者: + +``` +ctest -R test_mul_op +``` diff --git a/doc/index_en.rst b/doc/index_en.rst index 168c7667c61da..64684b8b9b27e 100644 --- a/doc/index_en.rst +++ b/doc/index_en.rst @@ -7,4 +7,3 @@ PaddlePaddle Documentation getstarted/index_en.rst howto/index_en.rst api/index_en.rst - about/index_en.rst diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index bfda18724cc8e..6b4c612cd8d92 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -124,6 +124,9 @@ static std::unique_ptr BackwardRecursive( std::list insert_position; for (auto& dup_output_op : dup_output_ops) { const std::string& name = dup_output_op.first; + // duplicate @Empty@ don't need to be added + if (name == kEmptyVarName) continue; + auto& dup_op = dup_output_op.second; // no duplicate output if (dup_op.size() == 1) continue; @@ -209,7 +212,7 @@ std::unique_ptr Backward( const OperatorBase& forwardOp, const std::unordered_set& no_grad_vars) { std::unordered_set no_grad_names; - no_grad_names.reserve(no_grad_vars.size()); + no_grad_names.reserve(no_grad_vars.size() + 1); no_grad_names.insert(std::string(kEmptyVarName) + kGradVarSuffix); diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md index c8fa3fefe5632..8aa6728a95bc4 100644 --- a/paddle/framework/backward.md +++ b/paddle/framework/backward.md @@ -1,23 +1,53 @@ -## Operator/expression 's Backward +# Operator/expression 's Backward -### Motivation +## Motivation -In Neural Network, the backpropagation algorithm follows the chain rule, so we need to compound the fundmental gradient operators/expressions together with chain rule . Every forward network need a backward network to construct the full computation lineage, the operator/ expression's Backward feature will generate the backward pass respect to forward pass. +In Neural Network, the backpropagation algorithm follows the chain rule, so we need to compound the fundmental gradient operators/expressions together with chain rule . Every forward network need a backward network to construct the full computation graph, the operator/expression's backward pass will be generated respect to forward pass. + +## Backward Operator Registry -### Implement : gradient operator registry +A backward network is built up with several backward operators. Backward operators take forward operators' inputs, outputs and output gradients and then calculate its input gradients. -| | forward operator | backward operator | -| ---------------------- | ---------------- | -------------------------------- | -| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients | -| **Operator::outputs_** | Outputs | InputGradients | +| | forward operator | backward operator +| ---------------------- | ---------------- |------------------------- | +| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients | +| **Operator::outputs_** | Outputs | InputGradients | -Inputs/Outputs means the input/output of the operator, InputGradients/OutputGradients is the gradient respect to forward opeartor. Forward operator and Backward operator are isomorphic, save their corresponding needs into member attribute. + In most cases, there is a one-to-one correspondence between forward and backward operators. These correspondences are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and make operators pluggable, the registry mechanism is introduced. -We use a global hash map record the gradient operators available, follow the philosophy of minimum core, make operator pluggable unit. Each gradient is an operator and it needs to regist itself. +For example, we have got a `mul_op`, and we can register it's information and corresponding backward operator by the following macro: -grad_op_builder(fengjiayi) +```cpp +REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad); +``` -### Implement : Backward network +`mul` is the operator's type. `MulOp` and `MulOpMaker` are the operator class and the operator maker class respectively. + +`mul_grad` is the type of backward operator, and `MulOpGrad` is its class name. + +## Backward Opeartor Creating + +Given a certain forward operator, we can get its corresponding backward opeartor by calling: + +```cpp +OperatorBase* bwd_op = BuildGradOp(const OperatorBase* fwd_op); +``` + +The function `BuildGradOp` will sequentially execute following processes: + +1. Get the `type_` of given forward operator, and then get the corresponding backward operator's type by looking up the `OpInfoMap`. + +2. Build two maps named `inputs` and `outputs` to temporary storage backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these are not necessary for gradient computing. + +3. Add forward inputs' gradient variables into map `output`, adding forward outputs' gradient variables into map `input`. + +4. Building backward operator with `inputs`, `outputs` and forward operator's attributes. + +## Backward Network Building + +A backward network is a series of backward operators. The main idea of building a backward network is creating backward operators in the inverted sequence and put them together. + +In our design, the network itself is also a kind of operator. So the operators contained by a big network may be some small network. given a forward network, it generates the backward network. We only care about the Gradients—`OutputGradients`,`InputGradients`. diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index c572a9d433bc1..f43f15e5cacb7 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -21,6 +21,8 @@ if(USE_NNPACK) endif() endif() +list(APPEND cpp_files neon/NeonDepthwiseConv.cpp) + add_library(paddle_function STATIC ${cpp_files} ${cu_objs}) add_dependencies(paddle_function ${external_project_dependencies}) add_dependencies(paddle_function paddle_proto) @@ -42,11 +44,11 @@ if(WITH_GPU) add_simple_unittest(RowConvOpTest) add_simple_unittest(BlockExpandOpTest) add_simple_unittest(CropOpTest) - add_simple_unittest(DepthwiseConvOpTest) endif() add_simple_unittest(Im2ColTest) add_simple_unittest(GemmConvOpTest) +add_simple_unittest(DepthwiseConvOpTest) endif() add_style_check_target(paddle_function ${h_files}) diff --git a/paddle/function/DepthwiseConvOpTest.cpp b/paddle/function/DepthwiseConvOpTest.cpp index f44ae0c342e95..d8e8c889d5c23 100644 --- a/paddle/function/DepthwiseConvOpTest.cpp +++ b/paddle/function/DepthwiseConvOpTest.cpp @@ -34,4 +34,13 @@ TEST(DepthwiseConv, BackwardFilter) { } #endif +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + +TEST(DepthwiseConv, Forward) { + DepthwiseConvolution( + "GemmConv-CPU", "NeonDepthwiseConv-CPU", forward); +} + +#endif + } // namespace paddle diff --git a/paddle/function/Im2Col.h b/paddle/function/Im2Col.h index 48e2e32f9256f..9b91e223a6a28 100644 --- a/paddle/function/Im2Col.h +++ b/paddle/function/Im2Col.h @@ -16,6 +16,7 @@ limitations under the License. */ #include "TensorShape.h" #include "TensorType.h" +#include "neon/neon_util.h" namespace paddle { @@ -93,4 +94,95 @@ class Col2ImFunctor { int paddingWidth); }; +template +struct Padding { + static void run(const T* src, + T* dest, + int channels, + int inputHeight, + int inputWidth, + int paddingHeight, + int paddingWidth) { + const int destWidth = inputWidth + 2 * paddingWidth; + for (int c = 0; c < channels; c++) { + if (paddingHeight > 0) { + memset(dest, 0, destWidth * paddingHeight * sizeof(T)); + dest += destWidth * paddingHeight; + } + + for (int i = 0; i < inputHeight; i++) { + // padding head + for (int j = 0; j < paddingWidth; j++) { + *dest++ = T(0); + } + + memcpy(dest, src, inputWidth * sizeof(T)); + dest += inputWidth; + src += inputWidth; + + // padding tail + for (int j = 0; j < paddingWidth; j++) { + *dest++ = T(0); + } + } + + if (paddingHeight > 0) { + memset(dest, 0, destWidth * paddingHeight * sizeof(T)); + dest += destWidth * paddingHeight; + } + } + } +}; + +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +template <> +struct Padding { + static void run(const float* src, + float* dest, + int channels, + int inputHeight, + int inputWidth, + int paddingHeight, + int paddingWidth) { + const int destWidth = inputWidth + 2 * paddingWidth; + for (int c = 0; c < channels; c++) { + if (paddingHeight > 0) { + memset(dest, 0, destWidth * paddingHeight * sizeof(float)); + dest += destWidth * paddingHeight; + } + + for (int i = 0; i < inputHeight; i++) { + // padding head + for (int j = 0; j < paddingWidth; j++) { + *dest++ = float(0); + } + + int step = inputWidth >> 2; + int remain = inputWidth & 3; + for (int s = 0; s < step; s++) { + float32x4_t s0 = vld1q_f32(src); + vst1q_f32(dest, s0); + src += 4; + dest += 4; + } + for (int r = 0; r < remain; r++) { + *dest++ = *src++; + } + + // padding tail + for (int j = 0; j < paddingWidth; j++) { + *dest++ = float(0); + } + } + + if (paddingHeight > 0) { + memset(dest, 0, destWidth * paddingHeight * sizeof(float)); + dest += destWidth * paddingHeight; + } + } + } +}; + +#endif + } // namespace paddle diff --git a/paddle/function/neon/NeonDepthwiseConv.cpp b/paddle/function/neon/NeonDepthwiseConv.cpp new file mode 100644 index 0000000000000..f09e98587d168 --- /dev/null +++ b/paddle/function/neon/NeonDepthwiseConv.cpp @@ -0,0 +1,577 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "neon_util.h" +#include "paddle/function/ConvOp.h" +#include "paddle/function/Im2Col.h" + +namespace paddle { + +namespace neon { + +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + +template +struct DepthwiseConvKernel {}; + +inline float32_t conv3x3(float32x4_t r0, + float32x4_t r1, + float32x4_t r2, + float32x4_t k0, + float32x4_t k1, + float32x4_t k2) { + float32x4_t tmp; + tmp = vmulq_f32(r0, k0); + tmp = vmlaq_f32(tmp, r1, k1); + tmp = vmlaq_f32(tmp, r2, k2); + return vaddvq_f32(tmp); +} + +inline float32_t conv4x4(float32x4_t r0, + float32x4_t r1, + float32x4_t r2, + float32x4_t r3, + float32x4_t k0, + float32x4_t k1, + float32x4_t k2, + float32x4_t k3) { + float32x4_t tmp; + tmp = vmulq_f32(r0, k0); + tmp = vmlaq_f32(tmp, r1, k1); + tmp = vmlaq_f32(tmp, r2, k2); + tmp = vmlaq_f32(tmp, r3, k3); + return vaddvq_f32(tmp); +} + +/** + * Each step calculates four elements of the output. + * First step: + * R0[0, 1, 2, 3...] * K[0][0] + * R0[1, 2, 3, 4...] * K[0][1] + * R0[2, 3, 4, 5...] * K[0][2] + * R1[0, 1, 2, 3...] * K[1][0] + * R1[1, 2, 3, 4...] * K[1][1] + * R1[2, 3, 4, 5...] * K[1][2] + * R2[0, 1, 2, 3...] * K[2][0] + * R2[1, 2, 3, 4...] * K[2][1] + * + R2[2, 3, 4, 5...] * K[2][2] + * ------------------------------ + * Output[0, 1, 2, 3] + */ +template <> +struct DepthwiseConvKernel<3, 1> { + static void run(const float* inputData, + const float* filterData, + int inputHeight, + int inputWidth, + int outputChannels, + int outputHeight, + int outputWidth, + int filterMultiplier, + float* outputData) { + const int steps = outputWidth >> 2; + const int remain = outputWidth & 3; + for (int c = 0; c < outputChannels; c++, filterData += 9) { + // Load the filters + float32x4_t k[3]; + k[0] = vld1q_f32(filterData); + k[1] = vld1q_f32(filterData + 3); + k[2] = vld1q_f32(filterData + 6); + k[0] = vsetq_lane_f32(0.f, k[0], 3); + k[1] = vsetq_lane_f32(0.f, k[1], 3); + k[2] = vsetq_lane_f32(0.f, k[2], 3); + + const float* r0 = + inputData + (c / filterMultiplier) * (inputHeight * inputWidth); + const float* r1 = r0 + inputWidth; + const float* r2 = r0 + inputWidth * 2; + float32x4_t input[3][3]; + for (int h = 0; h < outputHeight; h++) { + for (int s = 0; s < steps; s++) { + // Load the inputs + float32x4_t tmp; + input[0][0] = vld1q_f32(r0); + tmp = vld1q_f32(r0 + 4); + input[0][1] = vextq_f32(input[0][0], tmp, 1); + input[0][2] = vextq_f32(input[0][0], tmp, 2); + input[1][0] = vld1q_f32(r1); + tmp = vld1q_f32(r1 + 4); + input[1][1] = vextq_f32(input[1][0], tmp, 1); + input[1][2] = vextq_f32(input[1][0], tmp, 2); + input[2][0] = vld1q_f32(r2); + tmp = vld1q_f32(r2 + 4); + input[2][1] = vextq_f32(input[2][0], tmp, 1); + input[2][2] = vextq_f32(input[2][0], tmp, 2); + + float32x4_t tmp1 = vdupq_n_f32(0.f); + float32x4_t tmp2 = vdupq_n_f32(0.f); + tmp1 = vmlaq_laneq_f32(tmp1, input[0][0], k[0], 0); + tmp2 = vmlaq_laneq_f32(tmp2, input[0][1], k[0], 1); + tmp1 = vmlaq_laneq_f32(tmp1, input[0][2], k[0], 2); + tmp2 = vmlaq_laneq_f32(tmp2, input[1][0], k[1], 0); + tmp1 = vmlaq_laneq_f32(tmp1, input[1][1], k[1], 1); + tmp2 = vmlaq_laneq_f32(tmp2, input[1][2], k[1], 2); + tmp1 = vmlaq_laneq_f32(tmp1, input[2][0], k[2], 0); + tmp2 = vmlaq_laneq_f32(tmp2, input[2][1], k[2], 1); + tmp1 = vmlaq_laneq_f32(tmp1, input[2][2], k[2], 2); + tmp1 = vaddq_f32(tmp1, tmp2); + + vst1q_f32(outputData, tmp1); + r0 += 4; + r1 += 4; + r2 += 4; + outputData += 4; + } + + for (int r = 0; r < remain; r++) { + float32x4_t i0 = vld1q_f32(r0); + float32x4_t i1 = vld1q_f32(r1); + float32x4_t i2 = vld1q_f32(r2); + *outputData = conv3x3(i0, i1, i2, k[0], k[1], k[2]); + r0++; + r1++; + r2++; + outputData++; + } + + r0 += 2; + r1 += 2; + r2 += 2; + } + } + } +}; + +/** + * Each step calculates four elements of the output. + * First step: + * R0[0, 2, 4, 6...] * K[0][0] + * R0[1, 3, 5, 7...] * K[0][1] + * R0[2, 4, 6, 8...] * K[0][2] + * R1[0, 2, 4, 6...] * K[1][0] + * R1[1, 3, 5, 7...] * K[1][1] + * R1[2, 4, 6, 8...] * K[1][2] + * R2[0, 2, 4, 6...] * K[2][0] + * R2[1, 3, 5, 7...] * K[2][1] + * R2[2, 4, 6, 8...] * K[2][2] + * ------------------------------ + * Output[0, 1, 2, 3] + */ +template <> +struct DepthwiseConvKernel<3, 2> { + static void run(const float* inputData, + const float* filterData, + int inputHeight, + int inputWidth, + int outputChannels, + int outputHeight, + int outputWidth, + int filterMultiplier, + float* outputData) { + const int steps = outputWidth >> 2; + const int remain = outputWidth & 3; + for (int c = 0; c < outputChannels; c++, filterData += 9) { + // Load the filters + float32x4_t k[3]; + k[0] = vld1q_f32(filterData); + k[1] = vld1q_f32(filterData + 3); + k[2] = vld1q_f32(filterData + 6); + k[0] = vsetq_lane_f32(0.f, k[0], 3); + k[1] = vsetq_lane_f32(0.f, k[1], 3); + k[2] = vsetq_lane_f32(0.f, k[2], 3); + + const float* start = + inputData + (c / filterMultiplier) * (inputHeight * inputWidth); + float32x4_t input[3][3]; + for (int h = 0; h < outputHeight; h++) { + const float* r0 = start + 2 * h * inputWidth; + const float* r1 = start + (2 * h + 1) * inputWidth; + const float* r2 = start + (2 * h + 2) * inputWidth; + for (int s = 0; s < steps; s++) { + // Load the inputs + float32x4_t data1; + float32x4x2_t data2; + + data2 = vld2q_f32(r0); + input[0][0] = data2.val[0]; + input[0][1] = data2.val[1]; + data1 = vld1q_f32(r0 + 8); + input[0][2] = vextq_f32(data2.val[0], data1, 1); + + data2 = vld2q_f32(r1); + input[1][0] = data2.val[0]; + input[1][1] = data2.val[1]; + data1 = vld1q_f32(r1 + 8); + input[1][2] = vextq_f32(data2.val[0], data1, 1); + + data2 = vld2q_f32(r2); + input[2][0] = data2.val[0]; + input[2][1] = data2.val[1]; + data1 = vld1q_f32(r2 + 8); + input[2][2] = vextq_f32(data2.val[0], data1, 1); + + float32x4_t tmp1 = vdupq_n_f32(0.f); + float32x4_t tmp2 = vdupq_n_f32(0.f); + tmp1 = vmlaq_laneq_f32(tmp1, input[0][0], k[0], 0); + tmp2 = vmlaq_laneq_f32(tmp2, input[0][1], k[0], 1); + tmp1 = vmlaq_laneq_f32(tmp1, input[0][2], k[0], 2); + tmp2 = vmlaq_laneq_f32(tmp2, input[1][0], k[1], 0); + tmp1 = vmlaq_laneq_f32(tmp1, input[1][1], k[1], 1); + tmp2 = vmlaq_laneq_f32(tmp2, input[1][2], k[1], 2); + tmp1 = vmlaq_laneq_f32(tmp1, input[2][0], k[2], 0); + tmp2 = vmlaq_laneq_f32(tmp2, input[2][1], k[2], 1); + tmp1 = vmlaq_laneq_f32(tmp1, input[2][2], k[2], 2); + tmp1 = vaddq_f32(tmp1, tmp2); + + vst1q_f32(outputData, tmp1); + r0 += 8; + r1 += 8; + r2 += 8; + outputData += 4; + } + + for (int r = 0; r < remain; r++) { + float32x4_t i0 = vld1q_f32(r0); + float32x4_t i1 = vld1q_f32(r1); + float32x4_t i2 = vld1q_f32(r2); + *outputData = conv3x3(i0, i1, i2, k[0], k[1], k[2]); + r0 += 2; + r1 += 2; + r2 += 2; + outputData++; + } + } + } + } +}; + +/** + * Each step calculates four elements of the output. + */ +template <> +struct DepthwiseConvKernel<4, 1> { + static void run(const float* inputData, + const float* filterData, + int inputHeight, + int inputWidth, + int outputChannels, + int outputHeight, + int outputWidth, + int filterMultiplier, + float* outputData) { + const int steps = outputWidth >> 2; + const int remain = outputWidth & 3; + for (int c = 0; c < outputChannels; c++, filterData += 16) { + // Load the filters + float32x4_t k[4]; + k[0] = vld1q_f32(filterData); + k[1] = vld1q_f32(filterData + 4); + k[2] = vld1q_f32(filterData + 8); + k[3] = vld1q_f32(filterData + 12); + + const float* r0 = + inputData + (c / filterMultiplier) * (inputHeight * inputWidth); + const float* r1 = r0 + inputWidth; + const float* r2 = r0 + inputWidth * 2; + const float* r3 = r0 + inputWidth * 3; + float32x4_t input[4][4]; + for (int h = 0; h < outputHeight; h++) { + for (int s = 0; s < steps; s++) { + // Load the inputs + float32x4_t tmp; + input[0][0] = vld1q_f32(r0); + tmp = vld1q_f32(r0 + 4); + input[0][1] = vextq_f32(input[0][0], tmp, 1); + input[0][2] = vextq_f32(input[0][0], tmp, 2); + input[0][3] = vextq_f32(input[0][0], tmp, 3); + + input[1][0] = vld1q_f32(r1); + tmp = vld1q_f32(r1 + 4); + input[1][1] = vextq_f32(input[1][0], tmp, 1); + input[1][2] = vextq_f32(input[1][0], tmp, 2); + input[1][3] = vextq_f32(input[1][0], tmp, 3); + + input[2][0] = vld1q_f32(r2); + tmp = vld1q_f32(r2 + 4); + input[2][1] = vextq_f32(input[2][0], tmp, 1); + input[2][2] = vextq_f32(input[2][0], tmp, 2); + input[2][3] = vextq_f32(input[2][0], tmp, 3); + + input[3][0] = vld1q_f32(r3); + tmp = vld1q_f32(r3 + 4); + input[3][1] = vextq_f32(input[3][0], tmp, 1); + input[3][2] = vextq_f32(input[3][0], tmp, 2); + input[3][3] = vextq_f32(input[3][0], tmp, 3); + + float32x4_t tmp1 = vdupq_n_f32(0.f); + float32x4_t tmp2 = vdupq_n_f32(0.f); + tmp1 = vmlaq_laneq_f32(tmp1, input[0][0], k[0], 0); + tmp2 = vmlaq_laneq_f32(tmp2, input[0][1], k[0], 1); + tmp1 = vmlaq_laneq_f32(tmp1, input[0][2], k[0], 2); + tmp2 = vmlaq_laneq_f32(tmp2, input[0][3], k[0], 3); + tmp1 = vmlaq_laneq_f32(tmp1, input[1][0], k[1], 0); + tmp2 = vmlaq_laneq_f32(tmp2, input[1][1], k[1], 1); + tmp1 = vmlaq_laneq_f32(tmp1, input[1][2], k[1], 2); + tmp2 = vmlaq_laneq_f32(tmp2, input[1][3], k[1], 3); + tmp1 = vmlaq_laneq_f32(tmp1, input[2][0], k[2], 0); + tmp2 = vmlaq_laneq_f32(tmp2, input[2][1], k[2], 1); + tmp1 = vmlaq_laneq_f32(tmp1, input[2][2], k[2], 2); + tmp2 = vmlaq_laneq_f32(tmp2, input[2][3], k[2], 3); + tmp1 = vmlaq_laneq_f32(tmp1, input[3][0], k[3], 0); + tmp2 = vmlaq_laneq_f32(tmp2, input[3][1], k[3], 1); + tmp1 = vmlaq_laneq_f32(tmp1, input[3][2], k[3], 2); + tmp2 = vmlaq_laneq_f32(tmp2, input[3][3], k[3], 3); + tmp1 = vaddq_f32(tmp1, tmp2); + + vst1q_f32(outputData, tmp1); + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + outputData += 4; + } + + for (int r = 0; r < remain; r++) { + float32x4_t i0 = vld1q_f32(r0); + float32x4_t i1 = vld1q_f32(r1); + float32x4_t i2 = vld1q_f32(r2); + float32x4_t i3 = vld1q_f32(r3); + *outputData = conv4x4(i0, i1, i2, i3, k[0], k[1], k[2], k[3]); + r0++; + r1++; + r2++; + r3++; + outputData++; + } + + r0 += 3; + r1 += 3; + r2 += 3; + r3 += 3; + } + } + } +}; + +/** + * Each step calculates four elements of the output. + */ +template <> +struct DepthwiseConvKernel<4, 2> { + static void run(const float* inputData, + const float* filterData, + int inputHeight, + int inputWidth, + int outputChannels, + int outputHeight, + int outputWidth, + int filterMultiplier, + float* outputData) { + const int steps = outputWidth >> 2; + const int remain = outputWidth & 3; + for (int c = 0; c < outputChannels; c++, filterData += 16) { + // Load the filters + float32x4_t k[4]; + k[0] = vld1q_f32(filterData); + k[1] = vld1q_f32(filterData + 4); + k[2] = vld1q_f32(filterData + 8); + k[3] = vld1q_f32(filterData + 12); + + const float* start = + inputData + (c / filterMultiplier) * (inputHeight * inputWidth); + float32x4_t input[4][4]; + for (int h = 0; h < outputHeight; h++) { + const float* r0 = start + 2 * h * inputWidth; + const float* r1 = start + (2 * h + 1) * inputWidth; + const float* r2 = start + (2 * h + 2) * inputWidth; + const float* r3 = start + (2 * h + 3) * inputWidth; + for (int s = 0; s < steps; s++) { + // Load the inputs + float32x4x2_t data1; + float32x4x2_t data2; + + data1 = vld2q_f32(r0); + data2 = vld2q_f32(r0 + 8); + input[0][0] = data1.val[0]; + input[0][1] = data1.val[1]; + input[0][2] = vextq_f32(data1.val[0], data2.val[0], 1); + input[0][3] = vextq_f32(data1.val[1], data2.val[1], 1); + + data1 = vld2q_f32(r1); + data2 = vld2q_f32(r1 + 8); + input[1][0] = data1.val[0]; + input[1][1] = data1.val[1]; + input[1][2] = vextq_f32(data1.val[0], data2.val[0], 1); + input[1][3] = vextq_f32(data1.val[1], data2.val[1], 1); + + data1 = vld2q_f32(r2); + data2 = vld2q_f32(r2 + 8); + input[2][0] = data1.val[0]; + input[2][1] = data1.val[1]; + input[2][2] = vextq_f32(data1.val[0], data2.val[0], 1); + input[2][3] = vextq_f32(data1.val[1], data2.val[1], 1); + + data1 = vld2q_f32(r3); + data2 = vld2q_f32(r3 + 8); + input[3][0] = data1.val[0]; + input[3][1] = data1.val[1]; + input[3][2] = vextq_f32(data1.val[0], data2.val[0], 1); + input[3][3] = vextq_f32(data1.val[1], data2.val[1], 1); + + float32x4_t tmp1 = vdupq_n_f32(0.f); + float32x4_t tmp2 = vdupq_n_f32(0.f); + tmp1 = vmlaq_laneq_f32(tmp1, input[0][0], k[0], 0); + tmp2 = vmlaq_laneq_f32(tmp2, input[0][1], k[0], 1); + tmp1 = vmlaq_laneq_f32(tmp1, input[0][2], k[0], 2); + tmp2 = vmlaq_laneq_f32(tmp2, input[0][3], k[0], 3); + tmp1 = vmlaq_laneq_f32(tmp1, input[1][0], k[1], 0); + tmp2 = vmlaq_laneq_f32(tmp2, input[1][1], k[1], 1); + tmp1 = vmlaq_laneq_f32(tmp1, input[1][2], k[1], 2); + tmp2 = vmlaq_laneq_f32(tmp2, input[1][3], k[1], 3); + tmp1 = vmlaq_laneq_f32(tmp1, input[2][0], k[2], 0); + tmp2 = vmlaq_laneq_f32(tmp2, input[2][1], k[2], 1); + tmp1 = vmlaq_laneq_f32(tmp1, input[2][2], k[2], 2); + tmp2 = vmlaq_laneq_f32(tmp2, input[2][3], k[2], 3); + tmp1 = vmlaq_laneq_f32(tmp1, input[3][0], k[3], 0); + tmp2 = vmlaq_laneq_f32(tmp2, input[3][1], k[3], 1); + tmp1 = vmlaq_laneq_f32(tmp1, input[3][2], k[3], 2); + tmp2 = vmlaq_laneq_f32(tmp2, input[3][3], k[3], 3); + tmp1 = vaddq_f32(tmp1, tmp2); + + vst1q_f32(outputData, tmp1); + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + outputData += 4; + } + + for (int r = 0; r < remain; r++) { + float32x4_t i0 = vld1q_f32(r0); + float32x4_t i1 = vld1q_f32(r1); + float32x4_t i2 = vld1q_f32(r2); + float32x4_t i3 = vld1q_f32(r3); + *outputData = conv4x4(i0, i1, i2, i3, k[0], k[1], k[2], k[3]); + r0 += 2; + r1 += 2; + r2 += 2; + r3 += 2; + outputData++; + } + } + } + } +}; + +template +class NeonDepthwiseConvFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); + + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + size_t filterMultiplier = outputChannels / groups_; + CHECK_EQ(inputChannels, groups_); + + // only support strideH() == strideW() and filterHeight == filterWidth. + CHECK_EQ(strideH(), strideW()); + CHECK_EQ(filterHeight, filterWidth); + + float* inputData = inputs[0].data(); + float* filterData = inputs[1].data(); + float* outputData = outputs[0].data(); + + // padding the input + float* inputPadding = inputData; + if (paddingH() > 0 || paddingW() > 0) { + int newSize = batchSize * inputChannels * (inputHeight + 2 * paddingH()) * + (inputWidth + 2 * paddingW()); + resizeBuffer(newSize); + inputPadding = reinterpret_cast(memory_->getBuf()); + Padding::run(inputData, + inputPadding, + batchSize * inputChannels, + inputHeight, + inputWidth, + paddingH(), + paddingW()); + + // height and width of padding data + inputHeight += 2 * paddingH(); + inputWidth += 2 * paddingW(); + } + + std::function + DepthWiseConv; + + if (filterWidth == 3 && strideW() == 1) { + DepthWiseConv = DepthwiseConvKernel<3, 1>::run; + } else if (filterWidth == 3 && strideW() == 2) { + DepthWiseConv = DepthwiseConvKernel<3, 2>::run; + } else if (filterWidth == 4 && strideW() == 1) { + DepthWiseConv = DepthwiseConvKernel<4, 1>::run; + } else if (filterWidth == 4 && strideW() == 2) { + DepthWiseConv = DepthwiseConvKernel<4, 2>::run; + } else { + LOG(FATAL) << "Not supported"; + } + + for (size_t i = 0; i < batchSize; i++) { + DepthWiseConv(inputPadding, + filterData, + inputHeight, + inputWidth, + outputChannels, + outputHeight, + outputWidth, + filterMultiplier, + outputData); + inputPadding += inputChannels * inputHeight * inputWidth; + outputData += outputChannels * outputHeight * outputWidth; + } + } +}; + +REGISTER_TYPED_FUNC(NeonDepthwiseConv, CPU, NeonDepthwiseConvFunction); + +#endif + +} // namespace neon +} // namespace paddle diff --git a/paddle/function/neon/neon_util.h b/paddle/function/neon/neon_util.h new file mode 100644 index 0000000000000..56b3febe2d27b --- /dev/null +++ b/paddle/function/neon/neon_util.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + +#include + +namespace paddle { + +namespace neon { + +inline float32x4_t vld1q_f32_aligned(const float* p) { + return vld1q_f32( + (const float*)__builtin_assume_aligned(p, sizeof(float32x4_t))); +} + +#ifndef __aarch64__ +inline float32_t vaddvq_f32(float32x4_t a) { + float32x2_t v = vadd_f32(vget_high_f32(a), vget_low_f32(a)); + return vget_lane_f32(vpadd_f32(v, v), 0); +} + +inline float32x4_t vmlaq_laneq_f32(float32x4_t a, + float32x4_t b, + float32x4_t v, + const int lane) { + return vmlaq_n_f32(a, b, vgetq_lane_f32(v, lane)); +} +#endif + +} // namespace neon +} // namespace paddle + +#endif diff --git a/paddle/gserver/layers/CostLayer.cpp b/paddle/gserver/layers/CostLayer.cpp index 6bfdea3c6e3f7..ce071323ff585 100644 --- a/paddle/gserver/layers/CostLayer.cpp +++ b/paddle/gserver/layers/CostLayer.cpp @@ -572,13 +572,8 @@ void MultiBinaryLabelCrossEntropy::backwardImp(Matrix& output, } } -// -// Huber loss for robust 2-classes classification -// -REGISTER_LAYER(huber, HuberTwoClass); - -bool HuberTwoClass::init(const LayerMap& layerMap, - const ParameterMap& parameterMap) { +bool HuberCost::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { CostLayer::init(layerMap, parameterMap); if (useGpu_) { tmpCpuInput_.reserve(inputLayers_.size()); @@ -589,7 +584,7 @@ bool HuberTwoClass::init(const LayerMap& layerMap, return true; } -void HuberTwoClass::forwardImp(Matrix& output, Argument& label, Matrix& cost) { +void HuberCost::forwardImp(Matrix& output, Argument& label, Matrix& cost) { if (useGpu_) { for (size_t i = 0; i < inputLayers_.size(); i++) { tmpCpuInput_[i].resizeAndCopyFrom( @@ -597,13 +592,87 @@ void HuberTwoClass::forwardImp(Matrix& output, Argument& label, Matrix& cost) { } hl_stream_synchronize(HPPL_STREAM_DEFAULT); } - forwardImpIn(output, label, cost); } -void HuberTwoClass::forwardImpIn(Matrix& output, - Argument& label, - Matrix& target) { +// +// Huber loss for robust regression. +// +REGISTER_LAYER(huber_regression, HuberRegressionLoss); + +bool HuberRegressionLoss::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + HuberCost::init(layerMap, parameterMap); + delta_ = config_.delta(); + return true; +} + +void HuberRegressionLoss::forwardImp(Matrix& output, + Argument& label, + Matrix& target) { + HuberCost::forwardImp(output, label, target); + size_t numSamples = target.getHeight(); + size_t dim = output.getWidth(); + CHECK(label.value); + CHECK_EQ((*label.value).getHeight(), numSamples); + CHECK_EQ(output.getHeight(), numSamples); + CHECK_EQ(dim, (*label.value).getWidth()); + CHECK_EQ(target.getWidth(), (size_t)1); + + real* out = useGpu_ ? tmpCpuInput_[0].value->getData() : output.getData(); + real* lbl = + useGpu_ ? tmpCpuInput_[1].value->getData() : (*label.value).getData(); + std::vector cost(numSamples, 0); + for (size_t i = 0; i < numSamples; ++i) { + for (size_t j = 0; j < dim; ++j) { + int index = i * dim + j; + real a = std::abs(lbl[index] - out[index]); + if (a <= delta_) + cost[i] += a * a / 2; + else + cost[i] += delta_ * (a - delta_ / 2); + } + } + target.copyFrom(cost.data(), numSamples); +} + +void HuberRegressionLoss::backwardImp(Matrix& output, + Argument& label, + Matrix& outputG) { + size_t numSamples = output.getHeight(); + size_t dim = output.getWidth(); + real* out = useGpu_ ? tmpCpuInput_[0].value->getData() : output.getData(); + real* lbl = + useGpu_ ? tmpCpuInput_[1].value->getData() : (*label.value).getData(); + real* grad = useGpu_ ? tmpCpuInput_[0].grad->getData() : outputG.getData(); + for (size_t i = 0; i < numSamples; ++i) { + for (size_t j = 0; j < dim; ++j) { + int index = i * dim + j; + real a = lbl[index] - out[index]; + if (std::abs(a) <= delta_) + grad[index] += -a; + else + grad[index] += a > 0 ? -delta_ : delta_; + } + } + if (useGpu_) outputG.copyFrom(grad, numSamples * dim); +} + +// +// Huber loss for robust 2-classes classification +// +REGISTER_LAYER(huber_classification, HuberTwoClassification); + +bool HuberTwoClassification::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + return HuberCost::init(layerMap, parameterMap); +} + +void HuberTwoClassification::forwardImp(Matrix& output, + Argument& label, + Matrix& target) { + HuberCost::forwardImp(output, label, target); size_t numSamples = target.getHeight(); + CHECK(label.ids); CHECK_EQ((*label.ids).getSize(), numSamples); CHECK_EQ(output.getHeight(), numSamples); CHECK_EQ(output.getWidth(), (size_t)1); @@ -611,47 +680,35 @@ void HuberTwoClass::forwardImpIn(Matrix& output, real* out = useGpu_ ? tmpCpuInput_[0].value->getData() : output.getData(); int* lbl = useGpu_ ? tmpCpuInput_[1].ids->getData() : (*label.ids).getData(); - std::vector cost(numSamples); + std::vector cost(numSamples, 0); for (size_t i = 0; i < numSamples; ++i) { int y = 2 * lbl[i] - 1; - if (out[i] * y < -1) - cost[i] = -4 * out[i] * y; - else if (out[i] * y < 1) - cost[i] = (1 - out[i] * y) * (1 - out[i] * y); - else - cost[i] = 0; + real a = out[i] * y; + if (a < -1) + cost[i] = -4 * a; + else if (a < 1) + cost[i] = (1 - a) * (1 - a); } target.copyFrom(cost.data(), numSamples); } -void HuberTwoClass::backwardImp(Matrix& outputValue, - Argument& label, - Matrix& outputGrad) { - if (useGpu_) { - backwardImpIn( - *tmpCpuInput_[0].value, tmpCpuInput_[1], *tmpCpuInput_[0].grad); - outputGrad.copyFrom(*tmpCpuInput_[0].grad); - } else { - backwardImpIn(outputValue, label, outputGrad); - } -} - -void HuberTwoClass::backwardImpIn(Matrix& output, - Argument& label, - Matrix& outputG) { +void HuberTwoClassification::backwardImp(Matrix& output, + Argument& label, + Matrix& outputG) { size_t numSamples = output.getHeight(); - real* out = output.getData(); - real* grad = outputG.getData(); - int* lbl = (*label.ids).getData(); + real* out = useGpu_ ? tmpCpuInput_[0].value->getData() : output.getData(); + int* lbl = useGpu_ ? tmpCpuInput_[1].ids->getData() : (*label.ids).getData(); + real* grad = useGpu_ ? tmpCpuInput_[0].grad->getData() : outputG.getData(); for (size_t i = 0; i < numSamples; ++i) { int y = 2 * lbl[i] - 1; - if (y * out[i] < -1) + real a = out[i] * y; + if (a < -1) grad[i] += -4 * y; - else if (y * out[i] < 1) - grad[i] += -2 * (1 - y * out[i]) * y; + else if (a < 1) + grad[i] += -2 * (1 - a) * y; } + if (useGpu_) outputG.copyFrom(grad, numSamples); } - /** * This cost layer compute the sum of its input as loss. * \f[ diff --git a/paddle/gserver/layers/CostLayer.h b/paddle/gserver/layers/CostLayer.h index 14c0b33ec1a62..0f655b48eea05 100644 --- a/paddle/gserver/layers/CostLayer.h +++ b/paddle/gserver/layers/CostLayer.h @@ -304,37 +304,70 @@ class MultiBinaryLabelCrossEntropy : public CostLayer { Matrix& outputGrad) override; }; -/** - * Huber loss for robust 2-classes classification. - * - * For label={0, 1}, let y=2*label-1. Given output f, the loss is: - * \f[ - * Loss = - * \left\{\begin{matrix} - * 4 * y * f & \textit{if} \ \ y* f < -1 \\ - * (1 - y * f)^2 & \textit{if} \ \ -1 < y * f < 1 \\ - * 0 & \textit{otherwise} - * \end{matrix}\right. - * \f] +/* + * A base layer for HuberRegressionLoss and HuberTwoClassification. */ -class HuberTwoClass : public CostLayer { +class HuberCost : public CostLayer { +public: std::vector tmpCpuInput_; -public: - explicit HuberTwoClass(const LayerConfig& config) : CostLayer(config) {} + explicit HuberCost(const LayerConfig& config) : CostLayer(config) {} bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) override; void forwardImp(Matrix& output, Argument& label, Matrix& cost) override; - void forwardImpIn(Matrix& output, Argument& label, Matrix& cost); + void backwardImp(Matrix& outputValue, + Argument& label, + Matrix& outputGrad) override {} +}; + +/** + * Huber loss for robust regression. + * + * Given output f(x), label y and delta, the loss is: + * Loss = 0.5 * (1 - y * f)^2, if abs(y - f) <= delta \\ + * Loss = delta * abs(y - f) - 0.5 * delta^2, otherwise + */ +class HuberRegressionLoss : public HuberCost { +public: + explicit HuberRegressionLoss(const LayerConfig& config) : HuberCost(config) {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forwardImp(Matrix& output, Argument& label, Matrix& cost) override; void backwardImp(Matrix& outputValue, Argument& label, Matrix& outputGrad) override; - void backwardImpIn(Matrix& outputValue, Argument& label, Matrix& outputGrad); +protected: + real delta_; +}; + +/** + * Huber loss for robust 2-classes classification. + * + * For label={0, 1}, let y=2*label-1. Given output f(x), the loss is: + * Loss = 4 * y * f, if y* f < -1 \\ + * Loss = (1 - y * f)^2, if -1 < y * f < 1 \\ + * Loss = 0, otherwise + */ +class HuberTwoClassification : public HuberCost { +public: + explicit HuberTwoClassification(const LayerConfig& config) + : HuberCost(config) {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forwardImp(Matrix& output, Argument& label, Matrix& cost) override; + + void backwardImp(Matrix& outputValue, + Argument& label, + Matrix& outputGrad) override; }; typedef std::shared_ptr CostLayerPtr; diff --git a/paddle/gserver/layers/ExpandConvLayer.cpp b/paddle/gserver/layers/ExpandConvLayer.cpp index 0ece2799318ea..20de475fc3f6b 100644 --- a/paddle/gserver/layers/ExpandConvLayer.cpp +++ b/paddle/gserver/layers/ExpandConvLayer.cpp @@ -29,6 +29,10 @@ namespace paddle { REGISTER_LAYER(exconv, ExpandConvLayer); REGISTER_LAYER(exconvt, ExpandConvLayer); +inline bool isDepthwiseConv(int channels, int groups) { + return channels == groups; +} + bool ExpandConvLayer::init(const LayerMap &layerMap, const ParameterMap ¶meterMap) { /* Initialize the basic convolutional parent class */ @@ -47,14 +51,27 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, std::vector paddings = {(size_t)paddingY_[i], (size_t)padding_[i]}; std::vector strides = {(size_t)strideY_[i], (size_t)stride_[i]}; - if (useGpu_ && (size_t)groups_[i] == (size_t)channels_[i] && !isDeconv_) { + // Convolution Layer uses the GemmConv function by default. + convType = "GemmConv"; + convGradInputType = "GemmConvGradInput"; + convGradFilterType = "GemmConvGradFilter"; + + // If depth wise convolution and useGpu == true + if (useGpu_ && isDepthwiseConv(channels_[i], groups_[i]) && !isDeconv_) { convType = "DepthwiseConv"; convGradInputType = "DepthwiseConvGradInput"; convGradFilterType = "DepthwiseConvGradFilter"; - } else { - convType = "GemmConv"; - convGradInputType = "GemmConvGradInput"; - convGradFilterType = "GemmConvGradFilter"; + } + + // If depth wise convolution and useGpu == false and ARM-NEON + if (!useGpu_ && isDepthwiseConv(channels_[i], groups_[i]) && !isDeconv_) { +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + if ((filterSize_[i] == filterSizeY_[i]) && + (filterSize_[i] == 3 || filterSize_[i] == 4) && + (stride_[i] == strideY_[i]) && (stride_[i] == 1 || stride_[i] == 2)) { + convType = "NeonDepthwiseConv"; + } +#endif } if (FLAGS_use_nnpack && !isDeconv_) { diff --git a/paddle/gserver/layers/Layer.cpp b/paddle/gserver/layers/Layer.cpp index d5621412caee8..2bc20eee6c452 100644 --- a/paddle/gserver/layers/Layer.cpp +++ b/paddle/gserver/layers/Layer.cpp @@ -41,7 +41,7 @@ namespace paddle { Layer::Layer(const LayerConfig& config, bool useGpu) : config_(config), useGpu_(useGpu), - deviceId_(-1), + deviceId_(CPU_DEVICE), needSequenceInfo_(true) {} bool Layer::init(const LayerMap& layerMap, const ParameterMap& parameterMap) { diff --git a/paddle/gserver/layers/Layer.h b/paddle/gserver/layers/Layer.h index 0ed482889d0ce..edef36194aabd 100644 --- a/paddle/gserver/layers/Layer.h +++ b/paddle/gserver/layers/Layer.h @@ -59,7 +59,12 @@ class Layer { LayerConfig config_; /// whether to use GPU bool useGpu_; - /// Device Id. CPU is -1, and GPU is 0, 1, 2 ... + /// Paddle device ID, MKLDNN is -2, CPU is -1 + enum PADDLE_DEVICE_ID { + MKLDNN_DEVICE = -2, + CPU_DEVICE = -1, + }; + /// Device Id. MKLDNN is -2, CPU is -1, and GPU is 0, 1, 2 ... int deviceId_; /// Input layers std::vector inputLayers_; @@ -77,6 +82,7 @@ class Layer { Argument output_; /// Several outputs stored on different devices, used in 'parallel_nn' case, /// and record them by deviceId_. + /// Also used in 'use_mkldnn' case. std::vector outputOtherDevice_; /// If there are several outputs, map them by each name. std::map outputMap_; @@ -172,6 +178,13 @@ class Layer { return inputLayer.getOutput(deviceId_); } + /** + * Get the argument of input layer with deviceId. + */ + const Argument& getInput(size_t inputIndex, int deviceId) const { + return inputLayers_[inputIndex]->getOutput(deviceId); + } + /** * Get the forward-input value. */ @@ -186,6 +199,13 @@ class Layer { return inputLayer.getOutput(deviceId_).value; } + /** + * Get the forward-input value with deviceId. + */ + const MatrixPtr& getInputValue(int inputIndex, int deviceId) { + return inputLayers_[inputIndex]->getOutput(deviceId).value; + } + /** * Get the forward-input grad. */ @@ -200,6 +220,13 @@ class Layer { return inputLayer.getOutput(deviceId_).grad; } + /** + * Get the forward-input grad. + */ + const MatrixPtr& getInputGrad(int inputIndex, int deviceId) { + return inputLayers_[inputIndex]->getOutput(deviceId).grad; + } + /** * Get the forward-input label. */ diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp index d201fac65e045..8318c8c519a4c 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.cpp +++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp @@ -61,43 +61,42 @@ void MKLDNNFcLayer::convertWeightsFromPaddle() { return; } - // TODO(TJ): dst format should get from wgtVal_ - int dstFmt = PARAM_FORMAT_MKLDNN_OI; - int srcFmt = weight_->getParameterPtr()->getHeaderFormat(); - if (srcFmt == dstFmt) { - return; - } - - // The weight_ is transposed from initial paddle weight - MatrixPtr paddleWgt = Matrix::create( - weight_->getW()->getData(), iLayerSize_, oc_, false, false); - - // TODO(TJ): remove this print when do not need differ weights - std::ostringstream ostr; - paddleWgt->print(ostr); - VLOG(MKLDNN_ALL) << "Initial Weight from paddle: " << std::endl << ostr.str(); - - // The mkldnn weight is transposed from initial paddle matrix - MatrixPtr paddleWgtT; - paddleWgt->transpose(paddleWgtT, true); - weight_->getW()->copyFrom(*paddleWgtT); - weight_->getParameterPtr()->setHeaderFormat(dstFmt); + CHECK(wgtVal_) << "should have been initialized"; + bool hasNoSpatial_ = ih_ == 1 && iw_ == 1; + auto targetDim = wgtVal_->getDims(); + auto srcFmt = hasNoSpatial_ ? memory::format::io : memory::format::ihwo; + wgtVal_->reorderDataFrom(wgtVal_, srcFmt, targetDim); hasInitedWgt_ = true; } void MKLDNNFcLayer::convertWeightsToPaddle() { - MatrixPtr dnnWgt = weight_->getW(); - MatrixPtr paddleWgt; - dnnWgt->transpose(paddleWgt, true); - - // copy paddle weight and override on weight_ - MatrixPtr dnnWgtT = Matrix::create( - dnnWgt->getData(), dnnWgt->getWidth(), dnnWgt->getHeight(), false, false); - dnnWgtT->copyFrom(*paddleWgt); + CHECK(wgtVal_) << "should have been initialized"; + bool hasNoSpatial_ = ih_ == 1 && iw_ == 1; + auto targetDim = wgtVal_->getDims(); + auto dstFmt = hasNoSpatial_ ? memory::format::io : memory::format::ihwo; + wgtVal_->reorderDataTo(wgtVal_, dstFmt, targetDim); +} + +void MKLDNNFcLayer::convertOutputToOtherDevice() { + copyOutputInfoToOtherDevice(); + // find other cpu device and reorder output to cpu device + int cnt = 0; + for (size_t i = 0; i < outputOtherDevice_.size(); i++) { + if (outputOtherDevice_[i].deviceId == CPU_DEVICE) { + // fc cpu output value do not need convert + // just share point + outputOtherDevice_[i].value = output_.value; + ++cnt; + } + } + + if (cnt > 1) { + LOG(WARNING) << "should not have more than one CPU devie"; + } } void MKLDNNFcLayer::reshape() { - const Argument& input = getInput(0); + const Argument& input = getInput(0, getPrev(0)->getDeviceId()); int batchSize = input.getBatchSize(); if (bs_ == batchSize) { return; @@ -111,10 +110,6 @@ void MKLDNNFcLayer::reshape() { if (iw_ == 0) { iw_ = 1; } - hasSpatial_ = true; - if (ih_ == 1 && iw_ == 1) { - hasSpatial_ = false; - } CHECK_EQ(iLayerSize_, inputLayers_[0]->getSize()); ic_ = iLayerSize_ / (ih_ * iw_); CHECK_EQ(size_t(ic_ * ih_ * iw_), iLayerSize_) << "not divisible"; @@ -135,37 +130,53 @@ void MKLDNNFcLayer::reshape() { void MKLDNNFcLayer::resetFwd() { bool hasBias = biases_ && biases_->getW(); - real* iData = getInputValue(0)->getData(); - real* oData = getOutputValue()->getData(); - real* wData = weight_->getW()->getData(); - real* bData = hasBias ? biases_->getW()->getData() : NULL; - - // TODO(TJ): below create should be covered in MkldnnMatrix - // create memory desc - memory::desc iMD = hasSpatial_ ? createMD({bs_, ic_, ih_, iw_}, format::nchw) - : createMD({bs_, ic_}, format::nc); - memory::desc wMD = hasSpatial_ ? createMD({oc_, ic_, ih_, iw_}, format::oihw) - : createMD({oc_, ic_}, format::oi); - memory::desc bMD = bData != NULL ? createMD({oc_}, format::x) - : createMD({}, format::format_undef); - memory::desc oMD = createMD({bs_, oc_}, format::nc); - - // create memory primitive desc and memory self - inVal_.reset(new memory(memory::primitive_desc(iMD, engine_), iData)); - wgtVal_.reset(new memory(memory::primitive_desc(wMD, engine_), wData)); - outVal_.reset(new memory(memory::primitive_desc(oMD, engine_), oData)); + const MatrixPtr& wgt = weight_->getW(); + const MatrixPtr& bias = hasBias ? biases_->getW() : nullptr; + const MatrixPtr& out = output_.value; + + if (inputIsOnlyMKLDNN()) { + const MatrixPtr& in = getInputValue(0); + inVal_ = std::dynamic_pointer_cast(in); + CHECK(inVal_) << "Input should be MKLDNNMatrix"; + } else { + CHECK_EQ(getPrev(0)->getDeviceId(), CPU_DEVICE) << "Only support CPU yet"; + const MatrixPtr& in = getInputValue(0, CPU_DEVICE); + inVal_ = MKLDNNMatrix::create( + in, memory::dims{bs_, ic_, ih_, iw_}, format::nchw, engine_); + } + inVal_->downSpatial(); + wgtVal_ = MKLDNNMatrix::create( + wgt, memory::dims{oc_, ic_, ih_, iw_}, format::oihw, engine_); + wgtVal_->downSpatial(); + biasVal_ = + hasBias ? MKLDNNMatrix::create(bias, {oc_}, format::x, engine_) : nullptr; + outVal_ = MKLDNNMatrix::create(out, {bs_, oc_}, format::nc, engine_); + + // change original output value to mkldnn output value + output_.value = std::dynamic_pointer_cast(outVal_); + if (!outputIsOnlyMKLDNN()) { + convertOutputToOtherDevice(); + } + // create forward handle prop_kind pk = prop_kind::forward; - fc_fwd::desc fwdDesc = bData != NULL ? fc_fwd::desc(pk, iMD, wMD, bMD, oMD) - : fc_fwd::desc(pk, iMD, wMD, oMD); + fc_fwd::desc fwdDesc = hasBias ? fc_fwd::desc(pk, + inVal_->getMemoryDesc(), + wgtVal_->getMemoryDesc(), + biasVal_->getMemoryDesc(), + outVal_->getMemoryDesc()) + : fc_fwd::desc(pk, + inVal_->getMemoryDesc(), + wgtVal_->getMemoryDesc(), + outVal_->getMemoryDesc()); fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_); - - if (bData != NULL) { - biasVal_.reset(new memory(memory::primitive_desc(bMD, engine_), bData)); + if (hasBias) { fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *biasVal_, *outVal_)); } else { fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *outVal_)); } + printValueFormatFlow(); + pipelineFwd_.clear(); pipelineFwd_.push_back(*fwd_); } @@ -175,45 +186,46 @@ void MKLDNNFcLayer::resetBwd() { return; } needResetBwd_ = false; - bool hasBias = biases_ && biases_->getWGrad(); - real* iData = getInputValue(0)->getData(); - real* iDiff = getInputGrad(0) != nullptr ? getInputGrad(0)->getData() : NULL; - real* oDiff = getOutputGrad()->getData(); - real* wDiff = weight_->getWGrad()->getData(); - real* bDiff = hasBias ? biases_->getWGrad()->getData() : NULL; /// backward weight - // create memory desc for backward memory - memory::desc iMD = hasSpatial_ ? createMD({bs_, ic_, ih_, iw_}, format::nchw) - : createMD({bs_, ic_}, format::nc); - memory::desc wMD = hasSpatial_ ? createMD({oc_, ic_, ih_, iw_}, format::oihw) - : createMD({oc_, ic_}, format::oi); - memory::desc oMD = createMD({bs_, oc_}, format::nc); - memory::desc bMD = bDiff != NULL ? createMD({oc_}, format::x) - : createMD({}, format::format_undef); - - if (inVal_) { - // update data - inVal_->set_data_handle(iData); - } else { - inVal_.reset(new memory(memory::primitive_desc(iMD, engine_), iData)); - } - - // create memory primitive desc and memory self - wgtGrad_.reset(new memory(memory::primitive_desc(wMD, engine_), wDiff)); - outGrad_.reset(new memory(memory::primitive_desc(oMD, engine_), oDiff)); - - fc_fwd::desc fwdDesc = fc_fwd::desc(prop_kind::forward, iMD, wMD, oMD); + CHECK(inVal_) << "Should have input value"; + const MatrixPtr& wgt = weight_->getWGrad(); + const MatrixPtr& bias = hasBias ? biases_->getWGrad() : nullptr; + + // TODO(TJ): merge outgrad + int device = outputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE; + // for MKLDNN device: + // can not directly cast outputgrad to mkldnnmatrix, + // since each layer can not write the inputgrad to mkldnn inputgrad. + // So just create from matrix with outputvalue format. + // for CPU device: + // fc do not need to convert from cpu device since output is always nc format + // only need create from cpu device + const MatrixPtr& out = getOutput(device).grad; + outGrad_ = MKLDNNMatrix::create(out, outVal_->getPrimitiveDesc()); + wgtGrad_ = MKLDNNMatrix::create(wgt, wgtVal_->getPrimitiveDesc()); + biasGrad_ = hasBias ? MKLDNNMatrix::create(bias, biasVal_->getPrimitiveDesc()) + : nullptr; + + // create memory primitive desc + fc_fwd::desc fwdDesc = fc_fwd::desc(prop_kind::forward, + inVal_->getMemoryDesc(), + wgtGrad_->getMemoryDesc(), + outGrad_->getMemoryDesc()); fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_); - fc_bwdWgt::desc bwdWgtDesc = bDiff != NULL - ? fc_bwdWgt::desc(iMD, wMD, bMD, oMD) - : fc_bwdWgt::desc(iMD, wMD, oMD); + fc_bwdWgt::desc bwdWgtDesc = hasBias + ? fc_bwdWgt::desc(inVal_->getMemoryDesc(), + wgtGrad_->getMemoryDesc(), + biasGrad_->getMemoryDesc(), + outGrad_->getMemoryDesc()) + : fc_bwdWgt::desc(inVal_->getMemoryDesc(), + wgtGrad_->getMemoryDesc(), + outGrad_->getMemoryDesc()); fc_bwdWgt::primitive_desc bwdWgtPD = fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD); - if (bDiff != NULL) { - biasGrad_.reset(new memory(memory::primitive_desc(bMD, engine_), bDiff)); + if (hasBias) { bwdWgt_.reset( new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_, *biasGrad_)); } else { @@ -223,15 +235,26 @@ void MKLDNNFcLayer::resetBwd() { pipelineBwd_.push_back(*bwdWgt_); /// backward data - if (iDiff == NULL) { + device = inputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE; + const MatrixPtr& in = getInputGrad(0, device); + if (in == nullptr) { return; } - fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(iMD, wMD, oMD); + if (getInput(0, device).getAllCount() > 1) { + // TODO(TJ): use outputMaps_ ways when merge outgrad done + } else { + inGrad_ = MKLDNNMatrix::create(in, inVal_->getPrimitiveDesc()); + } + + fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(inVal_->getMemoryDesc(), + wgtGrad_->getMemoryDesc(), + outGrad_->getMemoryDesc()); fc_bwdData::primitive_desc bwdDataPD = fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD); - inGrad_.reset(new memory(memory::primitive_desc(iMD, engine_), iDiff)); + CHECK(wgtVal_) << "Should have weight memory"; bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_)); + printGradFormatFlow(); pipelineBwd_.push_back(*bwdData_); } @@ -241,11 +264,7 @@ void MKLDNNFcLayer::forward(PassType passType) { { REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str()); - - // update input data - // since it might be changed if this is after data layer - real* iData = getInputValue(0)->getData(); - inVal_->set_data_handle(iData); + syncInputValue(); // just submit forward pipeline stream_->submit(pipelineFwd_); @@ -267,10 +286,7 @@ void MKLDNNFcLayer::backward(const UpdateCallback& callback) { REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str()); resetBwd(); - // update diff - real* oDiff = getOutputGrad()->getData(); - outGrad_->set_data_handle(oDiff); - + syncOutputGrad(); // just sumbmit backward pipeline stream_->submit(pipelineBwd_); } diff --git a/paddle/gserver/layers/MKLDNNFcLayer.h b/paddle/gserver/layers/MKLDNNFcLayer.h index 7954852a23f81..e138a6faf181c 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.h +++ b/paddle/gserver/layers/MKLDNNFcLayer.h @@ -32,16 +32,13 @@ class MKLDNNFcLayer : public MKLDNNLayer { // if has already init the weight bool hasInitedWgt_; - // if input layer has image size info (ih>1 && iw>1) - bool hasSpatial_; - // fc weight and bias std::unique_ptr weight_; std::unique_ptr biases_; public: explicit MKLDNNFcLayer(const LayerConfig& config) - : MKLDNNLayer(config), hasInitedWgt_(false), hasSpatial_(true) {} + : MKLDNNLayer(config), hasInitedWgt_(false) {} ~MKLDNNFcLayer() {} @@ -75,6 +72,8 @@ class MKLDNNFcLayer : public MKLDNNLayer { * only would be called when needed */ void resetBwd(); + + void convertOutputToOtherDevice() override; }; } // namespace paddle diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index 63e29f447eede..b983b833d510b 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -18,9 +18,9 @@ limitations under the License. */ #include "Layer.h" #include "MKLDNNBase.h" #include "mkldnn.hpp" +#include "paddle/math/MKLDNNMatrix.h" DECLARE_bool(use_mkldnn); -DECLARE_bool(use_mkldnn_wgt); namespace paddle { @@ -52,15 +52,15 @@ class MKLDNNLayer : public Layer { std::vector pipelineFwd_; std::vector pipelineBwd_; - // TODO(TJ): change below memory as MKLDNNMatrixPtr type - std::shared_ptr inVal_; - std::shared_ptr inGrad_; - std::shared_ptr outVal_; - std::shared_ptr outGrad_; - std::shared_ptr wgtVal_; - std::shared_ptr wgtGrad_; - std::shared_ptr biasVal_; - std::shared_ptr biasGrad_; + // MKLDNNMatrixPtr + MKLDNNMatrixPtr inVal_; + MKLDNNMatrixPtr inGrad_; + MKLDNNMatrixPtr outVal_; + MKLDNNMatrixPtr outGrad_; + MKLDNNMatrixPtr wgtVal_; + MKLDNNMatrixPtr wgtGrad_; + MKLDNNMatrixPtr biasVal_; + MKLDNNMatrixPtr biasGrad_; public: explicit MKLDNNLayer(const LayerConfig& config) @@ -83,17 +83,21 @@ class MKLDNNLayer : public Layer { virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) { + CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn." + << "Please set WITH_MKLDNN=ON " + << "and set use_mkldnn=True"; + CHECK(!useGpu_) << "Do not support GPU yet"; + + // set device id before Layer::init + setDevice(MKLDNN_DEVICE); + // change param device to MKLDNN device + setParamsDevice(MKLDNN_DEVICE, parameterMap); if (!Layer::init(layerMap, parameterMap)) { return false; } - CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn." - << "Please set WITH_MKLDNN=ON " - << "and set use_mkldnn=True"; stream_.reset(new MKLDNNStream()); engine_ = CPUEngine::Instance().getEngine(); - - // TODO(TJ): deivecId return true; } @@ -109,6 +113,12 @@ class MKLDNNLayer : public Layer { */ virtual void convertWeightsToPaddle() {} + /** + * convert MKLDNN output to other device. + * only support CPU device yet + */ + virtual void convertOutputToOtherDevice() {} + /** * print info about sizes */ @@ -118,14 +128,124 @@ class MKLDNNLayer : public Layer { << ", oh: " << oh_ << ", ow: " << ow_; } - // TODO(TJ): move to MkldnnMatrix - // create memory desc - inline mkldnn::memory::desc createMD( - mkldnn::memory::dims dims, - mkldnn::memory::format fmt, - mkldnn::memory::data_type type = mkldnn::memory::data_type::f32) { - // TODO(TJ): isFmtSuppoted(fmt) - return mkldnn::memory::desc(dims, type, fmt); + /** + * Print the mkldnn memory format flow of value + */ + virtual void printValueFormatFlow() { + if (inVal_ && outVal_) { + VLOG(MKLDNN_FMTS) << "value format flow --- " << inVal_->getFormat() + << " >>> " << outVal_->getFormat(); + } + } + + /** + * Print the mkldnn memory format flow of grad + */ + virtual void printGradFormatFlow() { + if (inGrad_ && outGrad_) { + VLOG(MKLDNN_FMTS) << "grad format flow --- " << inGrad_->getFormat() + << " <<< " << outGrad_->getFormat(); + } + } + +protected: + /** + * copy image size and sequence info to other device + * @note: can not directly use Layer::copyOutputToOtherDevice since here only + * copy base info and do not copy data value + */ + void copyOutputInfoToOtherDevice() { + for (size_t i = 0; i < outputOtherDevice_.size(); i++) { + outputOtherDevice_[i].setFrameHeight(output_.getFrameHeight()); + outputOtherDevice_[i].setFrameWidth(output_.getFrameWidth()); + outputOtherDevice_[i].sequenceStartPositions = + output_.sequenceStartPositions; + outputOtherDevice_[i].subSequenceStartPositions = + output_.subSequenceStartPositions; + outputOtherDevice_[i].cpuSequenceDims = output_.cpuSequenceDims; + } + } + + /** + * If input only has MKLDNN device. + * Otherwise, only support the previous layer using CPU device. + */ + bool inputIsOnlyMKLDNN(int index = 0) { + int prevDevice = getPrev(index)->getDeviceId(); + if (prevDevice == MKLDNN_DEVICE) { + return true; + } else { + // do not support GPU yet + CHECK_EQ(prevDevice, CPU_DEVICE) << "Only support CPU yet"; + return false; + } + } + + /** + * If output only has MKLDNN device. + * Otherwise, other devices should only using CPU device. + */ + bool outputIsOnlyMKLDNN() { + for (size_t i = 0; i < outputOtherDevice_.size(); i++) { + CHECK_EQ(outputOtherDevice_[i].deviceId, CPU_DEVICE) + << "Only support other device is CPU yet"; + } + return outputOtherDevice_.size() == 0; + } + + /** + * Sync input value data + */ + void syncInputValue() { + if (inputIsOnlyMKLDNN()) { + return; + } + real* iData = getInputValue(0, CPU_DEVICE)->getData(); + // update input data + // since it might be changed if this is after data layer + inVal_->updateData(iData); + } + + /** + * Sync output grad data + */ + void syncOutputGrad() { + if (outputIsOnlyMKLDNN()) { + return; + } + + // update diff + real* oDiff = getOutput(CPU_DEVICE).grad->getData(); + outGrad_->updateData(oDiff); + } + + /** + * Set deviceId of this layer. + */ + void setDevice(int id) { deviceId_ = id; } + + /** + * Set deviceId of the params used in this layer. + */ + void setParamsDevice(int id, const ParameterMap& parameterMap) { + for (auto& inputConfig : config_.inputs()) { + if (inputConfig.has_input_parameter_name()) { + ParameterPtr parameter; + std::string name = inputConfig.input_parameter_name(); + CHECK(mapGet(name, parameterMap, ¶meter)) + << "Cannot find input parameter " << name << " for layer " + << getName(); + parameter->setDevice(id); + } + } + if (config_.has_bias_parameter_name()) { + ParameterPtr parameter; + std::string name = config_.bias_parameter_name(); + CHECK(mapGet(name, parameterMap, ¶meter)) + << "Cannot find bias parameter " << name << " for layer " + << getName(); + parameter->setDevice(id); + } } }; diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 9946f7666498e..93b6e3cc5bd7a 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -850,9 +850,27 @@ TEST(Layer, square_error_weighted) { } } +TEST(Layer, huber_regression_loss) { + TestConfig config; + config.layerConfig.set_type("huber_regression"); + config.biasSize = 0; + + config.inputDefs.push_back({INPUT_DATA, "layer_0", 10, 0}); + config.inputDefs.push_back({INPUT_DATA_TARGET, "layer_1", 10, 0}); + config.layerConfig.add_inputs(); + config.layerConfig.add_inputs(); + + for (auto useGpu : {false, true}) { + for (auto delta : {1, 3, 5}) { + config.layerConfig.set_delta(delta); + testLayerGrad(config, "huber_regression", 100, /* trans */ false, useGpu); + } + } +} + TEST(Layer, huber_two_class) { TestConfig config; - config.layerConfig.set_type("huber"); + config.layerConfig.set_type("huber_classification"); config.biasSize = 0; config.inputDefs.push_back({INPUT_DATA, "layer_0", 1, 0}); @@ -861,7 +879,7 @@ TEST(Layer, huber_two_class) { config.layerConfig.add_inputs(); for (auto useGpu : {false, true}) { - testLayerGrad(config, "huber", 100, /* trans */ false, useGpu); + testLayerGrad(config, "huber_two_class", 100, /* trans */ false, useGpu); } } diff --git a/paddle/math/Allocator.h b/paddle/math/Allocator.h index 666a8b8368e3e..94ef561f066a1 100644 --- a/paddle/math/Allocator.h +++ b/paddle/math/Allocator.h @@ -48,7 +48,13 @@ class CpuAllocator : public Allocator { */ virtual void* alloc(size_t size) { void* ptr; +#ifdef PADDLE_USE_MKLDNN + // refer to https://github.com/01org/mkl-dnn/blob/master/include/mkldnn.hpp + // memory alignment + CHECK_EQ(posix_memalign(&ptr, 4096ul, size), 0); +#else CHECK_EQ(posix_memalign(&ptr, 32ul, size), 0); +#endif CHECK(ptr) << "Fail to allocate CPU memory: size=" << size; return ptr; } diff --git a/paddle/math/CMakeLists.txt b/paddle/math/CMakeLists.txt index bf28092e82b77..68b5296228cd7 100644 --- a/paddle/math/CMakeLists.txt +++ b/paddle/math/CMakeLists.txt @@ -14,6 +14,17 @@ # file(GLOB MATH_HEADERS . *.h) file(GLOB MATH_SOURCES . *.cpp) + +if(NOT WITH_MKLDNN) + set(DNN_HEADER "${CMAKE_CURRENT_SOURCE_DIR}/MKLDNNMatrix.h") + set(DNN_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/MKLDNNMatrix.cpp") + list(REMOVE_ITEM MATH_HEADERS "${DNN_HEADER}") + list(REMOVE_ITEM MATH_SOURCES "${DNN_SOURCE}") + message(STATUS "Skip compiling with MKLDNNMatrix") +else() + message(STATUS "Compile with MKLDNNMatrix") +endif() + set(MATH_SOURCES "${PADDLE_SOURCE_DIR}/paddle/math/BaseMatrix.cu" "${PADDLE_SOURCE_DIR}/paddle/math/TrainingAlgorithmOp.cu" diff --git a/paddle/math/MKLDNNMatrix.cpp b/paddle/math/MKLDNNMatrix.cpp new file mode 100644 index 0000000000000..0a355e2644cce --- /dev/null +++ b/paddle/math/MKLDNNMatrix.cpp @@ -0,0 +1,144 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "MKLDNNMatrix.h" + +using namespace mkldnn; // NOLINT + +namespace paddle { + +MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, memory::primitive_desc pd) { + memory::desc md = pd.desc(); + size_t ndims = md.data.ndims; + int* dims = md.data.dims; + CHECK(ndims > 0) << "Input dims should not be empty"; + size_t cnts = 1; + for (size_t i = 0; i < ndims; ++i) { + cnts *= dims[i]; + } + + if (m == nullptr) { + size_t height = dims[0]; + size_t width = cnts / dims[0]; + m = Matrix::create(height, width, false, false); + } + + CHECK(m) << " Matrix should not be empty"; + CpuMatrixPtr cpuMatrix = std::dynamic_pointer_cast(m); + CHECK(cpuMatrix) << "Only support create from CPU matrix yet"; + + CHECK_EQ(cnts, m->getElementCnt()) << "Count size does not match"; + return std::make_shared( + m->getData(), m->getHeight(), m->getWidth(), pd); +} + +MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, + memory::dims dims, + memory::format fmt, + engine& eg, + mkldnn::memory::data_type dtype) { + return create(m, memory::primitive_desc(memory::desc(dims, dtype, fmt), eg)); +} + +void MKLDNNMatrix::reorderDataFrom(const MKLDNNMatrixPtr& m, + memory::format srcFmt, + memory::dims targetDim) { + memory::format dstFmt = getFormat(); + if (srcFmt == dstFmt) { + return; + } + CHECK_EQ(getElementCnt(), m->getElementCnt()) << "size should equal"; + reorderOnce(getData(), m->getData(), srcFmt, dstFmt, targetDim); +} + +void MKLDNNMatrix::reorderDataTo(const MKLDNNMatrixPtr& m, + memory::format dstFmt, + memory::dims targetDim) { + memory::format srcFmt = getFormat(); + if (srcFmt == dstFmt) { + return; + } + CHECK_EQ(getElementCnt(), m->getElementCnt()) << "size should equal"; + reorderOnce(getData(), m->getData(), srcFmt, dstFmt, targetDim); +} + +void MKLDNNMatrix::reorderOnce(void* srcData, + void* dstData, + memory::format srcFmt, + memory::format dstFmt, + memory::dims dm) { + CHECK(srcData); + CHECK(dstData); + MatrixPtr tmpSrc; + if (dstData == srcData) { + // inplace data + size_t sz = 1; + for (size_t i = 0; i < dm.size(); ++i) { + sz *= dm[i]; + } + tmpSrc = Matrix::create(sz, 1, false, false); + tmpSrc->copyFrom((real*)srcData, sz); + srcData = tmpSrc->getData(); + } + + auto dtype = this->getDtype(); + auto srcMD = memory::desc(dm, dtype, srcFmt); + auto dstMD = memory::desc(dm, dtype, dstFmt); + + auto eg = this->getEngine(); + auto src = memory(memory::primitive_desc(srcMD, eg), srcData); + auto dst = memory(memory::primitive_desc(dstMD, eg), dstData); + + auto r = reorder(src, dst); + stream(stream::kind::eager).submit({r}).wait(); +} + +void MKLDNNMatrix::downSpatial() { + int fmt = getFormat(); + if (!(fmt == memory::format::nchw || fmt == memory::format::oihw)) { + // only support nchw and oihw yet, later can support more like nhwc, ihwo + return; + } + + // TODO(TJ): change H(height) and W(width) if support nhwc or more + const int H = 2, W = 3; + memory::dims srcDims = getDims(); + if (srcDims[H] != 1 || srcDims[W] != 1) { + // can not down spatial + return; + } + + memory::dims dstDims = memory::dims{srcDims[0], srcDims[1]}; + memory::format dstFmt; + switch (fmt) { + case memory::format::nchw: + dstFmt = memory::format::nc; + break; + case memory::format::oihw: + dstFmt = memory::format::oi; + break; + default: + LOG(FATAL) << "unsupported format"; + } + memory::desc md = memory::desc(dstDims, getDtype(), dstFmt); + memory::primitive_desc pd = memory::primitive_desc(md, getEngine()); + mkldnn_primitive_t result; + mkldnn::error::wrap_c_api( + mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr), + "could not create a memory primitive"); + reset(result); + set_data_handle(getData()); +} + +} // namespace paddle diff --git a/paddle/math/MKLDNNMatrix.h b/paddle/math/MKLDNNMatrix.h new file mode 100644 index 0000000000000..e50f698b49571 --- /dev/null +++ b/paddle/math/MKLDNNMatrix.h @@ -0,0 +1,148 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "Matrix.h" +#include "mkldnn.hpp" +#include "paddle/parameter/Parameter.h" + +namespace paddle { + +class MKLDNNMatrix; +typedef std::shared_ptr MKLDNNMatrixPtr; + +/** + * @brief MKLDNN Matrix. + * + */ +class MKLDNNMatrix : public CpuMatrix, public mkldnn::memory { +public: + MKLDNNMatrix(real* data, + size_t height, + size_t width, + mkldnn::memory::primitive_desc pd) + : CpuMatrix(data, height, width, false), mkldnn::memory(pd, data) {} + + ~MKLDNNMatrix() {} + + /** + * Create MKLDNNMatrix from a MatrixPtr and memory primitive_desc + */ + static MKLDNNMatrixPtr create(MatrixPtr m, mkldnn::memory::primitive_desc pd); + + /** + * Create MKLDNNMatrix from a MatrixPtr and memory details info + */ + static MKLDNNMatrixPtr create( + MatrixPtr m, + mkldnn::memory::dims dims, + mkldnn::memory::format fmt, + mkldnn::engine& eg, + mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32); + +public: + /** + * Reorder this MKLDNNMatrix from other format. + * Support inplace reorder. + * @note: this function would only reorder the data layout. + * will NOT change this original dim or format info + */ + void reorderDataFrom(const MKLDNNMatrixPtr& m, + memory::format srcFmt, + memory::dims targetDim); + + /** + * Reorder this MKLDNNMatrix to other format. + * Support inplace reorder. + * @note: this function would only reorder the data layout. + * will NOT change the dst dim or format info + */ + void reorderDataTo(const MKLDNNMatrixPtr& m, + memory::format dstFmt, + memory::dims targetDim); + + /** + * Dimensionality reduction. + * Change format "nchw --> nc" or "oihw --> oi" if the h and w are both 1 + */ + void downSpatial(); + + /** + * Update the memory data handle. + * Caution: This will not check the buffer size of the data, + * it should be coverd by user. + */ + void updateData(void* data) { set_data_handle(data); } + + /** + * Get primitive descriptor. + */ + mkldnn::memory::primitive_desc getPrimitiveDesc() { + return this->get_primitive_desc(); + } + + /** + * Get memory descriptor. + */ + mkldnn::memory::desc getMemoryDesc() { return getPrimitiveDesc().desc(); } + + /** + * Get dimensions. + */ + mkldnn::memory::dims getDims() { + mkldnn::memory::desc md = getMemoryDesc(); + const int* src = md.data.dims; + int ndims = md.data.ndims; + mkldnn::memory::dims dst; + dst.resize(ndims); + for (int i = 0; i < ndims; ++i) { + dst[i] = src[i]; + } + return dst; + } + + /** + * Get format. + */ + mkldnn::memory::format getFormat() { + return (mkldnn::memory::format)(getMemoryDesc().data.format); + } + + /** + * Get memory data type. + */ + mkldnn::memory::data_type getDtype() { + return (mkldnn::memory::data_type)(getMemoryDesc().data.data_type); + } + + /** + * Get engine. + */ + mkldnn::engine getEngine() { return getPrimitiveDesc().get_engine(); } + +protected: + /** + * Do reorder once. + * Can support inplace. + */ + void reorderOnce(void* srcData, + void* dstData, + memory::format srcFmt, + memory::format dstFmt, + memory::dims dm); +}; + +} // namespace paddle diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc index 44d925f0b0cc5..78b5e27678423 100644 --- a/paddle/operators/net_op.cc +++ b/paddle/operators/net_op.cc @@ -31,10 +31,13 @@ void NetOp::CompleteAddOp(bool calc) { for (auto& op : ops_) { for (auto& ipt : op->Inputs()) { for (auto& var_name : ipt.second) { - if (!Contains(output_set, var_name)) { // Not other op's output - input_set.insert(var_name); - } else { + // If input variable has been in output set, then it will be + // added into intermediate_outputs_. Otherwise, it will be + // added into input set. + if (Contains(output_set, var_name)) { intermediate_outputs_.insert(var_name); + } else { + input_set.insert(var_name); } } } diff --git a/paddle/parameter/Parameter.h b/paddle/parameter/Parameter.h index 321f4275d8e68..04f12efaac15a 100644 --- a/paddle/parameter/Parameter.h +++ b/paddle/parameter/Parameter.h @@ -281,7 +281,11 @@ class Parameter { /** * @brief Set the format in header. */ - void setHeaderFormat(int32_t fmt) { headerFormat_ = fmt; } + void setHeaderFormat(int32_t fmt) { + CHECK(isHeaderFormatSupported(fmt)) << "Unsupported format version: " + << fmt; + headerFormat_ = fmt; + } /** * @brief Parameter Update Hook. diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 1ea1e05259652..1113d5aded1eb 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -499,6 +499,9 @@ message LayerConfig { optional int32 axis = 54 [ default = 2 ]; repeated uint32 offset = 55; repeated uint32 shape = 56; + + // for HuberRegressionLoss + optional double delta = 57 [ default = 1.0 ]; } message EvaluatorConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 752c686937f3b..0788e3994ebd9 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -2290,7 +2290,7 @@ def init(cls, name, inputs, device=None, coeff=1.): define_cost('SumOfSquaresCostLayer', 'square_error') define_cost('MultiBinaryLabelCrossEntropy', 'multi_binary_label_cross_entropy') define_cost('SoftBinaryClassCrossEntropy', 'soft_binary_class_cross_entropy') -define_cost('HuberTwoClass', 'huber') +define_cost('HuberTwoClassification', 'huber_classification') define_cost('SumCost', 'sum_cost') define_cost('SmoothL1Cost', 'smooth_l1') @@ -2352,6 +2352,17 @@ def __init__(self, name, inputs, NDCG_num=5, max_sort_size=-1, device=None): self.config.max_sort_size = max_sort_size +@config_layer('huber_regression') +class HuberRegressionLoss(LayerBase): + def __init__(self, name, inputs, delta=1., coeff=1., device=None): + super(HuberRegressionLoss, self).__init__( + name, 'huber_regression', 1, inputs=inputs, device=device) + config_assert( + len(self.inputs) == 2, 'HuberRegression must have 2 inputs') + self.config.delta = delta + self.config.coeff = coeff + + @config_layer('nce') class NCELayer(LayerBase): def __init__(self, diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index b9040f76a9ade..e73098910cb8a 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -111,7 +111,8 @@ 'sum_cost', 'rank_cost', 'lambda_cost', - 'huber_cost', + 'huber_regression_cost', + 'huber_classification_cost', 'block_expand_layer', 'maxout_layer', 'out_prod_layer', @@ -221,7 +222,8 @@ class LayerType(object): RANK_COST = 'rank-cost' LAMBDA_COST = 'lambda_cost' - HUBER = 'huber' + HUBER_REGRESSION = 'huber_regression' + HUBER_CLASSIFICATION = 'huber_classification' CROSS_ENTROPY = 'multi-class-cross-entropy' CROSS_ENTROPY_WITH_SELFNORM = 'multi_class_cross_entropy_with_selfnorm' CROSS_ENTROPY_OVER_BEAM = 'cross_entropy_over_beam' @@ -5650,16 +5652,77 @@ def sum_cost(input, name=None, layer_attr=None): @wrap_name_default() @layer_support() -def huber_cost(input, label, name=None, coeff=1.0, layer_attr=None): +def huber_regression_cost(input, + label, + name=None, + delta=1.0, + coeff=1.0, + layer_attr=None): + """ + In statistics, the Huber loss is a loss function used in robust regression, + that is less sensitive to outliers in data than the squared error loss. + Given a prediction f(x), a label y and :math:`\delta`, the loss function + is defined as: + + .. math: + loss = 0.5*\left ( y-f(x) \right )^2, \left | y-f(x) \right |\leq \delta + loss = \delta \left | y-f(x) \right |-0.5\delta ^2, otherwise + + The example usage is: + + .. code-block:: python + + cost = huber_regression_cost(input=input_layer, label=label_layer) + + :param input: The first input layer. + :type input: LayerOutput. + :param label: The input label. + :type input: LayerOutput. + :param name: The name of this layers. It is not necessary. + :type name: None|basestring. + :param delta: The difference between the observed and predicted values. + :type delta: float. + :param coeff: The coefficient affects the gradient in the backward. + :type coeff: float. + :param layer_attr: Extra Layer Attribute. + :type layer_attr: ExtraLayerAttribute + :return: LayerOutput object. + :rtype: LayerOutput. + """ + assert isinstance(input, LayerOutput) + Layer( + name=name, + type=LayerType.HUBER_REGRESSION, + inputs=[input.name, label.name], + delta=delta, + coeff=coeff, + **ExtraLayerAttribute.to_kwargs(layer_attr)) + return LayerOutput( + name, LayerType.HUBER_REGRESSION, parents=[input, label], size=1) + + +@wrap_name_default() +@layer_support() +def huber_classification_cost(input, + label, + name=None, + coeff=1.0, + layer_attr=None): """ - A loss layer for huber loss. + For classification purposes, a variant of the Huber loss called modified Huber + is sometimes used. Given a prediction f(x) (a real-valued classifier score) and + a true binary class label :math:`y\in \left \{-1, 1 \right \}`, the modified Huber + loss is defined as: + + .. math: + loss = \max \left ( 0, 1-yf(x) \right )^2, yf(x)\geq 1 + loss = -4yf(x), \text{otherwise} The example usage is: .. code-block:: python - cost = huber_cost(input=input_layer, - label=label_layer) + cost = huber_classification_cost(input=input_layer, label=label_layer) :param input: The first input layer. :type input: LayerOutput. @@ -5679,11 +5742,12 @@ def huber_cost(input, label, name=None, coeff=1.0, layer_attr=None): assert input.size == 1 Layer( name=name, - type=LayerType.HUBER, + type=LayerType.HUBER_CLASSIFICATION, inputs=[input.name, label.name], coeff=coeff, **ExtraLayerAttribute.to_kwargs(layer_attr)) - return LayerOutput(name, LayerType.HUBER, parents=[input, label], size=1) + return LayerOutput( + name, LayerType.HUBER_CLASSIFICATION, parents=[input, label], size=1) @wrap_name_default() diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_cost_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_cost_layers.protostr index 05847344be60b..55ab464ddf88f 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_cost_layers.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_cost_layers.protostr @@ -167,6 +167,20 @@ layers { softmax_selfnorm_alpha: 0.1 coeff: 1.0 } +layers { + name: "__huber_regression_cost_0__" + type: "huber_regression" + size: 1 + active_type: "" + inputs { + input_layer_name: "input" + } + inputs { + input_layer_name: "labels" + } + coeff: 1.0 + delta: 1.0 +} layers { name: "huber_probs" type: "data" @@ -180,8 +194,8 @@ layers { active_type: "" } layers { - name: "__huber_cost_0__" - type: "huber" + name: "__huber_classification_cost_0__" + type: "huber_classification" size: 1 active_type: "" inputs { @@ -300,7 +314,8 @@ output_layer_names: "__rank_cost_0__" output_layer_names: "__lambda_cost_0__" output_layer_names: "__cross_entropy_0__" output_layer_names: "__cross_entropy_with_selfnorm_0__" -output_layer_names: "__huber_cost_0__" +output_layer_names: "__huber_regression_cost_0__" +output_layer_names: "__huber_classification_cost_0__" output_layer_names: "__multi_binary_label_cross_entropy_0__" output_layer_names: "__sum_cost_0__" output_layer_names: "__nce_layer_0__" @@ -324,9 +339,10 @@ sub_models { layer_names: "__lambda_cost_0__" layer_names: "__cross_entropy_0__" layer_names: "__cross_entropy_with_selfnorm_0__" + layer_names: "__huber_regression_cost_0__" layer_names: "huber_probs" layer_names: "huber_label" - layer_names: "__huber_cost_0__" + layer_names: "__huber_classification_cost_0__" layer_names: "__multi_binary_label_cross_entropy_0__" layer_names: "__sum_cost_0__" layer_names: "__nce_layer_0__" @@ -349,7 +365,8 @@ sub_models { output_layer_names: "__lambda_cost_0__" output_layer_names: "__cross_entropy_0__" output_layer_names: "__cross_entropy_with_selfnorm_0__" - output_layer_names: "__huber_cost_0__" + output_layer_names: "__huber_regression_cost_0__" + output_layer_names: "__huber_classification_cost_0__" output_layer_names: "__multi_binary_label_cross_entropy_0__" output_layer_names: "__sum_cost_0__" output_layer_names: "__nce_layer_0__" diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_cost_layers.py b/python/paddle/trainer_config_helpers/tests/configs/test_cost_layers.py index d2a3b702a1d7b..7ce375c708af7 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/test_cost_layers.py +++ b/python/paddle/trainer_config_helpers/tests/configs/test_cost_layers.py @@ -33,7 +33,9 @@ input=probs, label=xe_label), cross_entropy_with_selfnorm( input=probs, label=xe_label), - huber_cost( + huber_regression_cost( + input=seq_in, label=labels), + huber_classification_cost( input=data_layer( name='huber_probs', size=1), label=data_layer( diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py index 9a7a7fbf5e63d..518f828bacd60 100644 --- a/python/paddle/v2/framework/tests/gradient_checker.py +++ b/python/paddle/v2/framework/tests/gradient_checker.py @@ -268,7 +268,7 @@ def check_grad(self, :param input_vars: numpy value of input variable. The following computation will use these variables. :param inputs_to_check: inputs var names that should check gradient. - :param output_name: output name that used to + :param output_name: the output variable name of forward network. :param max_relative_error: The relative tolerance parameter. :param no_grad_set: used when create backward ops :param only_cpu: only compute and check gradient on cpu kernel. diff --git a/python/paddle/v2/tests/test_layer.py b/python/paddle/v2/tests/test_layer.py index f2097e195f416..783a0ca85dc61 100644 --- a/python/paddle/v2/tests/test_layer.py +++ b/python/paddle/v2/tests/test_layer.py @@ -141,12 +141,13 @@ def test_cost_layer(self): cost8 = layer.rank_cost(left=score, right=score, label=score) cost9 = layer.lambda_cost(input=inference, score=score) cost10 = layer.sum_cost(input=inference) - cost11 = layer.huber_cost(input=score, label=label) + cost11 = layer.huber_regression_cost(input=score, label=label) + cost12 = layer.huber_classification_cost(input=score, label=label) print layer.parse_network([cost1, cost2]) print layer.parse_network([cost3, cost4]) print layer.parse_network([cost5, cost6]) - print layer.parse_network([cost7, cost8, cost9, cost10, cost11]) + print layer.parse_network([cost7, cost8, cost9, cost10, cost11, cost12]) crf = layer.crf(input=inference, label=label) crf_decoding = layer.crf_decoding(input=inference, size=3)