diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 30225a4a99d99..3b7cbcd98927b 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -58,6 +58,8 @@ class MulOpMaker : public OpProtoAndCheckerMaker { AddInput("X", "A"); AddInput("Y", "B"); AddOutput("Out", "Out"); + AddAttr("x_num_col_dims", "").SetDefault(1).EqualGreaterThan(1); + AddAttr("y_num_col_dims", "").SetDefault(1).EqualGreaterThan(1); AddComment("Mul"); } }; @@ -440,6 +442,28 @@ TEST(Backward, simple_single_op) { std::vector({f::GradVarName("b")})); } +TEST(Backward, default_attribute) { + f::ProgramDesc *program_desc = GetNewProgramDesc(); + f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::BlockDescBind *block = program.Block(0); + f::OpDescBind *op = block->AppendOp(); + op->SetType("mul"); + op->SetInput("X", {"x"}); + op->SetInput("Y", {"y"}); + op->SetOutput("Out", {"out"}); + + AppendBackward(program, {}); + + ASSERT_EQ(block->AllOps().size(), 2UL); + EXPECT_EQ(boost::get(op->GetAttr("x_num_col_dims")), 1); + EXPECT_EQ(boost::get(op->GetAttr("y_num_col_dims")), 1); + + f::OpDescBind *grad_op = block->AllOps()[1]; + ASSERT_EQ(grad_op->Type(), "mul_grad"); + EXPECT_EQ(boost::get(grad_op->GetAttr("x_num_col_dims")), 1); + EXPECT_EQ(boost::get(grad_op->GetAttr("y_num_col_dims")), 1); +} + TEST(Backward, simple_mult_op) { f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h index 2de270f60ec2a..3437e89923da8 100644 --- a/paddle/framework/block_desc.h +++ b/paddle/framework/block_desc.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include #include "paddle/framework/op_desc.h" diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h index 55e3931f870d6..649899d42572c 100644 --- a/paddle/framework/data_type.h +++ b/paddle/framework/data_type.h @@ -28,7 +28,6 @@ inline DataType ToDataType(std::type_index type) { return DataType::INT32; } else { PADDLE_THROW("Not supported"); - return static_cast(-1); } } diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index ac2827e54773f..b7a63f9ba10b7 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ syntax = "proto2"; +option optimize_for = LITE_RUNTIME; package paddle.framework; enum AttrType { diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index 02aa74a8420a5..c2e796b7c1b6e 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -25,6 +25,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs, inputs_ = inputs; outputs_ = outputs; attrs_ = attrs; + need_update_ = true; } OpDesc *OpDescBind::Proto() { diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index b729029412661..d0c314771c04d 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -52,8 +52,6 @@ class OpDescBind { void SetOutput(const std::string ¶m_name, const std::vector &args); - std::string DebugString() { return this->Proto()->DebugString(); } - bool HasAttr(const std::string &name) const { return attrs_.find(name) != attrs_.end(); } diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h index 9b34a06aeff94..f29b1c54e7160 100644 --- a/paddle/framework/program_desc.h +++ b/paddle/framework/program_desc.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include "paddle/framework/framework.pb.h" #include "paddle/platform/macros.h" @@ -31,8 +32,6 @@ class ProgramDescBind { BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); } - std::string DebugString() { return Proto()->DebugString(); } - size_t Size() const { return blocks_.size(); } ProgramDesc *Proto(); diff --git a/paddle/framework/type_defs.h b/paddle/framework/type_defs.h index a5b94722136ec..6f65a942ba2a4 100644 --- a/paddle/framework/type_defs.h +++ b/paddle/framework/type_defs.h @@ -15,6 +15,7 @@ #pragma once #include #include +#include #include "paddle/platform/variant.h" namespace paddle { diff --git a/paddle/math/tests/test_GpuProfiler.cpp b/paddle/math/tests/test_GpuProfiler.cpp index 9402bd3ec48fb..d9f146f0d1f63 100644 --- a/paddle/math/tests/test_GpuProfiler.cpp +++ b/paddle/math/tests/test_GpuProfiler.cpp @@ -162,4 +162,4 @@ int main(int argc, char** argv) { return RUN_ALL_TESTS(); } -#endif /* PADDLE_ONLY_CPU */ +#endif diff --git a/paddle/memory/detail/buddy_allocator.cc b/paddle/memory/detail/buddy_allocator.cc index fdc5ed19dc297..e212f7737a409 100644 --- a/paddle/memory/detail/buddy_allocator.cc +++ b/paddle/memory/detail/buddy_allocator.cc @@ -182,7 +182,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() { max_chunk_size_ = platform::GpuMaxChunkSize(); } } -#endif // PADDLE_ONLY_CPU +#endif // Allocate a new maximum sized block size_t index = 0; diff --git a/paddle/memory/detail/system_allocator.cc b/paddle/memory/detail/system_allocator.cc index 6c9a46dd09c15..33166d9ce23a4 100644 --- a/paddle/memory/detail/system_allocator.cc +++ b/paddle/memory/detail/system_allocator.cc @@ -134,7 +134,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) { bool GPUAllocator::UseGpu() const { return true; } -#endif // PADDLE_ONLY_CPU +#endif } // namespace detail } // namespace memory diff --git a/paddle/memory/detail/system_allocator.h b/paddle/memory/detail/system_allocator.h index ee9b012f91a96..552cab4f96ff2 100644 --- a/paddle/memory/detail/system_allocator.h +++ b/paddle/memory/detail/system_allocator.h @@ -51,7 +51,7 @@ class GPUAllocator : public SystemAllocator { size_t gpu_alloc_size_ = 0; size_t fallback_alloc_size_ = 0; }; -#endif // PADDLE_ONLY_CPU +#endif } // namespace detail } // namespace memory diff --git a/paddle/memory/detail/system_allocator_test.cc b/paddle/memory/detail/system_allocator_test.cc index cd563844e7fa2..6a8558937bf0c 100644 --- a/paddle/memory/detail/system_allocator_test.cc +++ b/paddle/memory/detail/system_allocator_test.cc @@ -62,4 +62,4 @@ TEST(GPUAllocator, Alloc) { TestAllocator(a, 2048); TestAllocator(a, 0); } -#endif // PADDLE_ONLY_CPU +#endif diff --git a/paddle/memory/memcpy.cc b/paddle/memory/memcpy.cc index 790420a8ab41b..1df88a6da9fb0 100644 --- a/paddle/memory/memcpy.cc +++ b/paddle/memory/memcpy.cc @@ -89,7 +89,7 @@ void Copy(platform::GPUPlace dst_place, platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice); } -#endif // PADDLE_ONLY_CPU +#endif } // namespace memory } // namespace paddle diff --git a/paddle/memory/memcpy.h b/paddle/memory/memcpy.h index 0bccee58c3a22..9b36182c2b619 100644 --- a/paddle/memory/memcpy.h +++ b/paddle/memory/memcpy.h @@ -53,7 +53,7 @@ template void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, cudaStream_t stream); -#endif // PADDLE_ONLY_CPU +#endif } // namespace memory } // namespace paddle diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc index 30ce8a82e16ed..5087c02385f7f 100644 --- a/paddle/memory/memory.cc +++ b/paddle/memory/memory.cc @@ -111,7 +111,7 @@ size_t Used(platform::GPUPlace place) { return GetGPUBuddyAllocator(place.device)->Used(); } -#endif // PADDLE_ONLY_CPU +#endif } // namespace memory } // namespace paddle diff --git a/paddle/memory/memory_test.cc b/paddle/memory/memory_test.cc index 0d402038a06f4..2444931e26774 100644 --- a/paddle/memory/memory_test.cc +++ b/paddle/memory/memory_test.cc @@ -135,4 +135,4 @@ TEST(BuddyAllocator, GPUMultAlloc) { } } -#endif // PADDLE_ONLY_CPU +#endif diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 4adb5a088f810..7dae8fe2f99f9 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -55,12 +55,20 @@ function(op_library TARGET) set(pybind_flag 1) endif() + # pool_op contains several operators if ("${TARGET}" STREQUAL "pool_op") set(pybind_flag 1) # It's enough to just adding one operator to pybind file(APPEND ${pybind_file} "USE_OP(pool2d);\n") endif() + # pool_with_index_op contains several operators + if ("${TARGET}" STREQUAL "pool_with_index_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n") + endif() + # activation_op contains several operators if ("${TARGET}" STREQUAL "activation_op") set(pybind_flag 1) diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc index 5e5df49b0788c..92db62907924d 100644 --- a/paddle/operators/activation_op.cc +++ b/paddle/operators/activation_op.cc @@ -201,6 +201,27 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker { } }; +template +class ELUOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ELUOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(Tensor) The input of ELU operator, it shouldn't be empty. Input " + "is flattened and treated as a 1D array."); + AddOutput("Y", + "(Tensor) The output of ELU operator. It has the same shape as " + "the input."); + AddAttr( + "alpha", "(float, default 1.0) Alpha value in the elu formulation.") + .SetDefault(static_cast(1.)); + AddComment(R"DOC( + ELU activation operator. It applies this element-wise computation on + the input: f(x) = max(0, x) + min(0, alpha * (exp(x) - 1)). + Check .. _Link: https://arxiv.org/abs/1511.07289 for more details.)DOC"); + } +}; + template class Relu6OpMaker : public framework::OpProtoAndCheckerMaker { public: @@ -289,6 +310,9 @@ REGISTER_OP(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker, REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker, soft_relu_grad, ops::ActivationOpGrad); +REGISTER_OP(elu, ops::ActivationOp, ops::ELUOpMaker, elu_grad, + ops::ActivationOpGrad); + REGISTER_OP(relu6, ops::ActivationOp, ops::Relu6OpMaker, relu6_grad, ops::ActivationOpGrad); diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index f127468125c26..123f0c4dbca65 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -384,6 +384,35 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor { } }; +template +struct ELUFunctor : public BaseActivationFunctor { + float alpha; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + + template + void operator()(Device d, X x, Y y) const { + y.device(d) = + x.cwiseMax(static_cast(0)) + + (alpha * (x.exp() - static_cast(1))).cwiseMin(static_cast(0)); + } +}; + +template +struct ELUGradFunctor : public BaseActivationFunctor { + float alpha; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + template + void operator()(Device d, X x, Y y, dY dy, dX dx) const { + dx.device(d) = + dy * (x > static_cast(0)).template cast() + + dy * (y + alpha) * (x < static_cast(0)).template cast(); + } +}; + template struct PowFunctor : public BaseActivationFunctor { float factor; @@ -440,21 +469,22 @@ struct STanhGradFunctor : public BaseActivationFunctor { } // namespace operators } // namespace paddle -#define FOR_EACH_KERNEL_FUNCTOR(__macro) \ - __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ - __macro(exp, ExpFunctor, ExpGradFunctor); \ - __macro(relu, ReluFunctor, ReluGradFunctor); \ - __macro(tanh, TanhFunctor, TanhGradFunctor); \ - __macro(sqrt, SqrtFunctor, SqrtGradFunctor); \ - __macro(abs, AbsFunctor, AbsGradFunctor); \ - __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ - __macro(log, LogFunctor, LogGradFunctor); \ - __macro(square, SquareFunctor, SquareGradFunctor); \ - __macro(brelu, BReluFunctor, BReluGradFunctor); \ - __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \ - __macro(pow, PowFunctor, PowGradFunctor); \ - __macro(stanh, STanhFunctor, STanhGradFunctor); \ - __macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \ - __macro(relu6, Relu6Functor, Relu6GradFunctor); \ - __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \ - __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor) +#define FOR_EACH_KERNEL_FUNCTOR(__macro) \ + __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ + __macro(exp, ExpFunctor, ExpGradFunctor); \ + __macro(relu, ReluFunctor, ReluGradFunctor); \ + __macro(tanh, TanhFunctor, TanhGradFunctor); \ + __macro(sqrt, SqrtFunctor, SqrtGradFunctor); \ + __macro(abs, AbsFunctor, AbsGradFunctor); \ + __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ + __macro(log, LogFunctor, LogGradFunctor); \ + __macro(square, SquareFunctor, SquareGradFunctor); \ + __macro(brelu, BReluFunctor, BReluGradFunctor); \ + __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \ + __macro(pow, PowFunctor, PowGradFunctor); \ + __macro(stanh, STanhFunctor, STanhGradFunctor); \ + __macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \ + __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \ + __macro(relu6, Relu6Functor, Relu6GradFunctor); \ + __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \ + __macro(elu, ELUFunctor, ELUGradFunctor) diff --git a/paddle/operators/dynamic_recurrent_op.cc b/paddle/operators/dynamic_recurrent_op.cc index 909992b450788..aa98aa51b3d7d 100644 --- a/paddle/operators/dynamic_recurrent_op.cc +++ b/paddle/operators/dynamic_recurrent_op.cc @@ -53,7 +53,7 @@ class DynamicRecurrentOpProtoAndCheckerMaker "names of pre-memories"); AddAttr>(name.memories, "names of memories"); - AddComment("This is a recurrent group operator."); + AddComment("This is a RNN operator for varience-length sequences."); } }; @@ -65,9 +65,10 @@ void DynamicRecurrentOp::Run(const Scope& scope, WriteStepInputs(); InitStates(); + // call stepnet in all the time steps for (size_t step = 0; step < cache_.num_steps; step++) { - // call stepnet - stepnet_->Run(scope, dev_ctx); + auto& step_scope = cache_.GetScope(step); + stepnet_->Run(step_scope, dev_ctx); } WriteStepOutputs(); @@ -96,10 +97,10 @@ void DynamicRecurrentOp::SplitInputs() const { } void DynamicRecurrentOp::WriteStepInputs() const { - const auto& inlinks = cache_.inlinks; - for (auto& item : inlinks) { + for (auto& item : cache_.inlinks) { auto ta_it = step_inputs_.find(item.first); - PADDLE_ENFORCE(ta_it != step_inputs_.end(), ""); + PADDLE_ENFORCE(ta_it != step_inputs_.end(), + "step_inputs_ not compatible with memory set"); TensorArray& ta = step_inputs_[item.first]; for (size_t step = 0; step < ta.size(); step++) { auto tensor = ta.Read(step); @@ -178,8 +179,8 @@ void DynamicRecurrentOp::InitStates() const { const auto& dims = boot_state.dims(); for (size_t step = 0; step < cache_.num_steps; step++) { - // link pre-state to boot_state auto& cur_scope = cache_.GetScope(step); + // link pre-state to boot_state // init state and pre-state auto* pre_state = cur_scope.FindVar(memory.pre_var); PADDLE_ENFORCE_NOT_NULL(pre_state); @@ -194,9 +195,9 @@ void DynamicRecurrentOp::InitStates() const { states_[memory.var].WriteShared(step, state->Get()); // link previous scope's state to the pre-states in current scope if (step == 0) { - auto* cur_state_tensor = pre_state->GetMutable(); - cur_state_tensor->Resize(boot_state.dims()); - cur_state_tensor->ShareDataWith(boot_state); + auto* pre_state_tensor = pre_state->GetMutable(); + pre_state_tensor->Resize(boot_state.dims()); + pre_state_tensor->ShareDataWith(boot_state); } else { auto& pre_scope = cache_.GetScope(step - 1); auto* state_pre = pre_scope.FindVar(memory.var); diff --git a/paddle/operators/dynamic_recurrent_op_test.cc b/paddle/operators/dynamic_recurrent_op_test.cc index 5f06c14ef5be1..f7778415b7c0f 100644 --- a/paddle/operators/dynamic_recurrent_op_test.cc +++ b/paddle/operators/dynamic_recurrent_op_test.cc @@ -55,7 +55,7 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test { framework::OpDesc CreateOpDesc() { // create op paddle::framework::OpDesc op_desc; - op_desc.set_type("dynamic_recurrent_op"); + op_desc.set_type("dynamic_recurrent"); OpDescNewVar(argname.inlinks, {"in0"}, op_desc.add_inputs()); OpDescNewVar(argname.boot_memories, {"boot_mem"}, op_desc.add_inputs()); diff --git a/paddle/operators/fill_constant_op.cc b/paddle/operators/fill_constant_op.cc new file mode 100644 index 0000000000000..65d03d5fa4426 --- /dev/null +++ b/paddle/operators/fill_constant_op.cc @@ -0,0 +1,68 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/fill_constant_op.h" + +namespace paddle { +namespace operators { + +class FillConstantOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of FillConstantOp should not be null."); + auto &shape = ctx->Attrs().Get>("shape"); + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), + [](int a) { return static_cast(a); }); + auto dims = framework::make_ddim(shape_int64); + ctx->SetOutputDim("Out", dims); + } + + framework::DataType IndicateDataType( + const framework::ExecutionContext &ctx) const override { + return static_cast(ctx.Attr("dataType")); + } +}; + +class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { + public: + FillConstantOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr("dataType", + "(int, default 5 (FP32)) " + "Output data type") + .SetDefault(framework::DataType::FP32); + AddAttr>("shape", "(vector) The shape of the output"); + AddAttr("value", "(float, default 0) The value to be filled") + .SetDefault(0.0f); + AddOutput("Out", + "(Tensor) Tensor of specified shape will be filled " + "with the specified value"); + AddComment(R"DOC(Fill up a variable with specified constant value.)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(fill_constant, ops::FillConstantOp, + ops::FillConstantOpMaker); +REGISTER_OP_CPU_KERNEL( + fill_constant, + ops::FillConstantOpKernel); diff --git a/paddle/operators/fill_constant_op.cu b/paddle/operators/fill_constant_op.cu new file mode 100644 index 0000000000000..eef8fcbd7f65a --- /dev/null +++ b/paddle/operators/fill_constant_op.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/framework/op_registry.h" +#include "paddle/operators/fill_constant_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + fill_constant, + ops::FillConstantOpKernel); diff --git a/paddle/operators/fill_constant_op.h b/paddle/operators/fill_constant_op.h new file mode 100644 index 0000000000000..53b8b548eca6d --- /dev/null +++ b/paddle/operators/fill_constant_op.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class FillConstantOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + auto value = ctx.Attr("value"); + + auto out_eigen = framework::EigenVector::Flatten(*out); + auto place = ctx.GetEigenDevice(); + out_eigen.device(place) = out_eigen.constant(static_cast(value)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/interp_op.cc b/paddle/operators/interp_op.cc new file mode 100644 index 0000000000000..d02b01c3f3a1b --- /dev/null +++ b/paddle/operators/interp_op.cc @@ -0,0 +1,113 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/net_op.h" + +namespace paddle { +namespace operators { + +class InterpOp : public NetOp { + public: + InterpOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : NetOp(type, inputs, outputs, attrs) { + PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName, + "Input(X) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Input("Y"), framework::kEmptyVarName, + "Input(Y) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Input("W"), framework::kEmptyVarName, + "Input(W) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Output("SubOut"), framework::kEmptyVarName, + "Output(SubOut) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Output("MulOut"), framework::kEmptyVarName, + "Output(MulOut) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName, + "Output(Out) of InterpOp should not be null."); + + // SubOut = X - Y + auto x = Input("X"); + auto y = Input("Y"); + auto sub_out = Output("SubOut"); + AppendOp(framework::OpRegistry::CreateOp( + "elementwise_sub", {{"X", {x}}, {"Y", {y}}}, {{"Out", {sub_out}}}, {})); + + // MulOut = SubOut * W = (X - Y) * W + auto w = Input("W"); + auto mul_out = Output("MulOut"); + AppendOp(framework::OpRegistry::CreateOp( + "elementwise_mul", {{"X", {sub_out}}, {"Y", {w}}}, {{"Out", {mul_out}}}, + {{"axis", 0}})); + + // Out = MulOut + Y = (X - Y) * W + Y = X * W + Y * (1 - W) + AppendOp(framework::OpRegistry::CreateOp("elementwise_add", + {{"X", {mul_out}}, {"Y", {y}}}, + {{"Out", {Output("Out")}}}, {})); + + CompleteAddOp(false); + } +}; + +class InterpOpMaker : public framework::OpProtoAndCheckerMaker { + public: + InterpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(Tensor), 2-D Matrix of shape [batch_size, data_dim]" + "containing data samples, the first input of interp_op"); + AddInput("Y", + "(Tensor), 2-D Matrix of shape `[batch_size, data_dim]`" + "containing data samples, the second input of interp_op"); + AddInput("W", + "(Tensor), 1-D Vector of shape [batch_size]," + "the interpolated values in the half-open interval [0.0, 1.0)"); + AddOutput("SubOut", + "(Tensor), the intermediate subtraction outputs, saving X - Y.") + .AsIntermediate(); + AddOutput("MulOut", + "(Tensor), the intermediate multiplication outputs," + "saving the elementwise multiplication of (X - Y) and W.") + .AsIntermediate(); + AddOutput("Out", + "(Tensor), the output of interp_op, same shape with X," + "returns the first-dimensional piecewise linear interpolant " + "between X and Y"); + AddComment(R"DOC( + Linear Interpolation with two inputs, used in NEURAL TURING MACHINE. + + Equation: + Out.row[i] = X.row[i] * W[i] + Y.row[i] * (1 - W[i]) + = (X.row[i] - Y.row[i]) * W[i] + Y.row[i] + + Example: + X = [[1,2],[3,4]], + Y = [[2,1],[4,3]], + W = [0.3, 0.4] + + Then, Out = [[1.7,1.3],[3.6,3.4]] + + where 1.7 = 1*0.3+2*(1-0.3), + 1.3 = 2*0.3+1*(1-0.3), + 3.6 = 3*0.4+4*(1-0.4), + 3.4 = 4*0.4+3*(1-0.4) +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(interp, ops::InterpOp, ops::InterpOpMaker); diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc index 3b706529d8f1e..50cfb88bb5700 100644 --- a/paddle/operators/math/pooling.cc +++ b/paddle/operators/math/pooling.cc @@ -18,6 +18,11 @@ namespace paddle { namespace operators { namespace math { +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class Pool2dFunctor { public: @@ -73,6 +78,11 @@ class Pool2dFunctor { } }; +/* +* All tensors are in NCHW format. +* Ksize, strides, paddings are two elements. These two elements represent height +* and width, respectively. +*/ template class Pool2dGradFunctor { public: @@ -135,6 +145,11 @@ class Pool2dGradFunctor { } }; +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class MaxPool2dGradFunctor { public: @@ -197,7 +212,7 @@ class MaxPool2dGradFunctor { }; template class MaxPool2dGradFunctor; -// template class MaxPool2dGradFunctor; +template class MaxPool2dGradFunctor; template class Pool2dFunctor, float>; @@ -216,6 +231,11 @@ template class Pool2dGradFunctor< template class Pool2dGradFunctor< platform::CPUPlace, paddle::operators::math::AvgPoolGrad, double>; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class Pool3dFunctor { public: @@ -286,6 +306,11 @@ class Pool3dFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class Pool3dGradFunctor { public: @@ -364,6 +389,11 @@ class Pool3dGradFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class MaxPool3dGradFunctor { public: @@ -440,7 +470,7 @@ class MaxPool3dGradFunctor { }; template class MaxPool3dGradFunctor; -// template class MaxPool3dGradFunctor; +template class MaxPool3dGradFunctor; template class Pool3dFunctor, float>; @@ -458,6 +488,253 @@ template class Pool3dGradFunctor< platform::CPUPlace, paddle::operators::math::MaxPoolGrad, double>; template class Pool3dGradFunctor< platform::CPUPlace, paddle::operators::math::AvgPoolGrad, double>; + +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ +template +class MaxPool2dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + const int input_stride = input_height * input_width; + const int output_stride = output_height * output_width; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + for (int c = 0; c < output_channels; ++c) { + for (int ph = 0; ph < output_height; ++ph) { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + for (int pw = 0; pw < output_width; ++pw) { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + + T ele = static_cast(-FLT_MAX); + int index = -1; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (ele < input_data[h * input_width + w]) { + ele = input_data[h * input_width + w]; + index = h * input_width + w; + } + } + } + output_data[ph * output_width + pw] = ele; + mask_data[ph * output_width + pw] = index; + } + } + // offset + input_data += input_stride; + output_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ +template +class MaxPool2dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_height = input_grad.dims()[2]; + const int input_width = input_grad.dims()[3]; + const int output_channels = output_grad.dims()[1]; + const int output_height = output_grad.dims()[2]; + const int output_width = output_grad.dims()[3]; + const int input_stride = input_height * input_width; + const int output_stride = output_height * output_width; + + const T* mask_data = mask.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < output_channels; ++c) { + for (int ph = 0; ph < output_height; ++ph) { + for (int pw = 0; pw < output_width; ++pw) { + const int output_idx = ph * output_width + pw; + const int input_idx = static_cast(mask_data[output_idx]); + input_grad_data[input_idx] += output_grad_data[output_idx]; + } + } + // offset + input_grad_data += input_stride; + output_grad_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; + +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ +template +class MaxPool3dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + const int input_stride = input_depth * input_height * input_width; + const int output_stride = output_depth * output_height * output_width; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + for (int c = 0; c < output_channels; ++c) { + for (int pd = 0; pd < output_depth; ++pd) { + int dstart = pd * stride_depth - padding_depth; + int dend = std::min(dstart + ksize_depth, input_depth); + dstart = std::max(dstart, 0); + for (int ph = 0; ph < output_height; ++ph) { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + for (int pw = 0; pw < output_width; ++pw) { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + + int output_idx = (pd * output_height + ph) * output_width + pw; + T ele = static_cast(-FLT_MAX); + int index = -1; + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_idx = (d * input_height + h) * input_width + w; + if (ele < input_data[input_idx]) { + index = input_idx; + ele = input_data[input_idx]; + } + } + } + } + output_data[output_idx] = ele; + mask_data[output_idx] = index; + } + } + } + // offset + input_data += input_stride; + output_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ +template +class MaxPool3dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_depth = input_grad.dims()[2]; + const int input_height = input_grad.dims()[3]; + const int input_width = input_grad.dims()[4]; + const int output_channels = output_grad.dims()[1]; + const int output_depth = output_grad.dims()[2]; + const int output_height = output_grad.dims()[3]; + const int output_width = output_grad.dims()[4]; + const int input_stride = input_depth * input_height * input_width; + const int output_stride = output_depth * output_height * output_width; + + const T* mask_data = mask.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < output_channels; ++c) { + for (int pd = 0; pd < output_depth; ++pd) { + for (int ph = 0; ph < output_height; ++ph) { + for (int pw = 0; pw < output_width; ++pw) { + const int output_idx = + (pd * output_height + ph) * output_width + pw; + const int input_idx = static_cast(mask_data[output_idx]); + input_grad_data[input_idx] += output_grad_data[output_idx]; + } + } + } + // offset + input_grad_data += input_stride; + output_grad_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu index 8aeccd1f8e885..736327f4b7b9e 100644 --- a/paddle/operators/math/pooling.cu +++ b/paddle/operators/math/pooling.cu @@ -144,11 +144,16 @@ __global__ void KernelMaxPool2DGrad( if (maxIndex != -1) { // atomic add - atomicAdd(input_grad + maxIndex, output_grad[index]); + platform::CudaAtomicAdd(input_grad + maxIndex, output_grad[index]); } } } +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class Pool2dFunctor { public: @@ -190,6 +195,11 @@ class Pool2dFunctor { } }; +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class Pool2dGradFunctor { public: @@ -234,6 +244,11 @@ class Pool2dGradFunctor { } }; +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class MaxPool2dGradFunctor { public: @@ -278,9 +293,7 @@ class MaxPool2dGradFunctor { }; template class MaxPool2dGradFunctor; -// template class MaxPool2dGradFunctor; // The -// 64-bit floating-point version of atomicAdd() is only supported by devices of -// compute capability 6.x and higher. +template class MaxPool2dGradFunctor; template class Pool2dFunctor, float>; @@ -453,11 +466,16 @@ __global__ void KernelMaxPool3DGrad( } if (maxIdx != -1) { // atomic add - atomicAdd(input_grad + maxIdx, output_grad[index]); + platform::CudaAtomicAdd(input_grad + maxIdx, output_grad[index]); } } } +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class Pool3dFunctor { public: @@ -506,6 +524,11 @@ class Pool3dFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class Pool3dGradFunctor { public: @@ -558,6 +581,11 @@ class Pool3dGradFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class MaxPool3dGradFunctor { public: @@ -609,9 +637,7 @@ class MaxPool3dGradFunctor { }; template class MaxPool3dGradFunctor; -// template class MaxPool3dGradFunctor; // The -// 64-bit floating-point version of atomicAdd() is only supported by devices of -// compute capability 6.x and higher. +template class MaxPool3dGradFunctor; template class Pool3dFunctor, float>; @@ -630,6 +656,404 @@ template class Pool3dGradFunctor< template class Pool3dGradFunctor< platform::GPUPlace, paddle::operators::math::AvgPoolGrad, double>; +template +__global__ void KernelMaxPool2dWithIdx( + const int nthreads, const T* input_data, T* output_data, T* mask_data, + const int channels, const int input_height, const int input_width, + const int output_height, const int output_width, const int ksize_height, + const int ksize_width, const int stride_height, const int stride_width, + const int padding_height, const int padding_width) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int pw = index % output_width; + int ph = (index / output_width) % output_height; + int c = (index / output_width / output_height) % channels; + int batch_idx = index / output_width / output_height / channels; + + int hstart = ph * stride_height - padding_height; + int hend = min(hstart + ksize_height, input_height); + hstart = max(hstart, 0); + + int wstart = pw * stride_width - padding_width; + int wend = min(wstart + ksize_width, input_width); + wstart = max(wstart, 0); + + input_data += (batch_idx * channels + c) * input_height * input_width; + T ele = -FLT_MAX; + int max_index = -1; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_index = h * input_width + w; + if (ele < input_data[input_index]) { + max_index = input_index; + ele = input_data[input_index]; + } + } + } + output_data[index] = ele; + mask_data[index] = max_index; + } +} + +template +__global__ void KernelMaxPool2DWithIdxGrad( + const int nthreads, T* input_grad, const T* output_grad, const T* mask_data, + const int channels, const int input_height, const int input_width, + const int output_height, const int output_width, const int ksize_height, + const int ksize_width, const int stride_height, const int stride_width, + const int padding_height, const int padding_width) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int w_offset = index % input_width; + int h_offset = (index / input_width) % input_height; + int c_offset = (index / input_width / input_height) % channels; + int batch_idx = index / input_width / input_height / channels; + + int ph_start = + (h_offset + padding_height < ksize_height) + ? 0 + : (h_offset + padding_height - ksize_height) / stride_height + 1; + int pw_start = + (w_offset + padding_width < ksize_width) + ? 0 + : (w_offset + padding_width - ksize_width) / stride_width + 1; + int ph_end = + min((h_offset + padding_height) / stride_height + 1, output_height); + int pw_end = + min((w_offset + padding_width) / stride_width + 1, output_width); + + T gradient = 0; + int input_current_featuremap_idx = h_offset * input_width + w_offset; + int output_idx = + (batch_idx * channels + c_offset) * output_height * output_width; + + mask_data += output_idx; + output_grad += output_idx; + for (int ph = ph_start; ph < ph_end; ++ph) { + for (int pw = pw_start; pw < pw_end; ++pw) { + if (mask_data[ph * output_width + pw] == input_current_featuremap_idx) + gradient += output_grad[ph * output_width + pw]; + } + } + input_grad[index] = gradient; + } +} + +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ +template +class MaxPool2dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); + + int nthreads = batch_size * output_channels * output_height * output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool2dWithIdx< + T><<(context) + .stream()>>>(nthreads, input_data, output_data, mask_data, + input_channels, input_height, input_width, + output_height, output_width, ksize_height, + ksize_width, stride_height, stride_width, + padding_height, padding_width); + } +}; + +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ +template +class MaxPool2dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_channels = input_grad.dims()[1]; + const int input_height = input_grad.dims()[2]; + const int input_width = input_grad.dims()[3]; + const int output_height = output_grad.dims()[2]; + const int output_width = output_grad.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + const T* mask_data = mask.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + int nthreads = batch_size * input_channels * input_height * input_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool2DWithIdxGrad< + T><<(context) + .stream()>>>(nthreads, input_grad_data, output_grad_data, + mask_data, input_channels, input_height, + input_width, output_height, output_width, + ksize_height, ksize_width, stride_height, + stride_width, padding_height, padding_width); + } +}; + +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; + +template +__global__ void KernelMaxPool3DWithIdx( + const int nthreads, const T* input_data, T* output_data, T* mask_data, + const int channels, const int input_depth, const int input_height, + const int input_width, const int output_depth, const int output_height, + const int output_width, const int ksize_depth, const int ksize_height, + const int ksize_width, const int stride_depth, const int stride_height, + const int stride_width, const int padding_depth, const int padding_height, + const int padding_width) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int pw = index % output_width; + int ph = (index / output_width) % output_height; + int pd = (index / output_width / output_height) % output_depth; + int c = (index / output_width / output_height / output_depth) % channels; + int batch_idx = + index / output_width / output_height / output_depth / channels; + + int dstart = pd * stride_depth - padding_depth; + int hstart = ph * stride_height - padding_height; + int wstart = pw * stride_width - padding_width; + int dend = min(dstart + ksize_depth, input_depth); + int hend = min(hstart + ksize_height, input_height); + int wend = min(wstart + ksize_width, input_width); + dstart = max(dstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + + T ele = -FLT_MAX; + int max_index = -1; + input_data += + (batch_idx * channels + c) * input_depth * input_height * input_width; + + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (ele < input_data[(d * input_height + h) * input_width + w]) { + max_index = (d * input_height + h) * input_width + w; + ele = input_data[max_index]; + } + } + } + } + output_data[index] = ele; + mask_data[index] = max_index; + } +} + +template +__global__ void KernelMaxPool3DWithIdxGrad( + const int nthreads, T* input_grad, const T* output_grad, const T* mask, + const int channels, const int input_depth, const int input_height, + const int input_width, const int output_depth, const int output_height, + const int output_width, const int ksize_depth, const int ksize_height, + const int ksize_width, const int stride_depth, const int stride_height, + const int stride_width, const int padding_depth, const int padding_height, + const int padding_width) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int w_offset = index % input_width; + int h_offset = (index / input_width) % input_height; + int d_offset = (index / input_width / input_height) % input_depth; + int c_offset = + (index / input_width / input_height / input_depth) % channels; + int batch_idx = index / input_width / input_height / input_depth / channels; + + int pd_start = + (d_offset + padding_depth < ksize_depth) + ? 0 + : (d_offset + padding_depth - ksize_depth) / stride_depth + 1; + int ph_start = + (h_offset + padding_height < ksize_height) + ? 0 + : (h_offset + padding_height - ksize_height) / stride_height + 1; + int pw_start = + (w_offset + padding_width < ksize_width) + ? 0 + : (w_offset + padding_width - ksize_width) / stride_width + 1; + int pd_end = + min((d_offset + padding_depth) / stride_depth + 1, output_depth); + int ph_end = + min((h_offset + padding_height) / stride_height + 1, output_height); + int pw_end = + min((w_offset + padding_width) / stride_width + 1, output_width); + + T gradient = 0; + int input_current_feature_map_idx = + (d_offset * input_height + h_offset) * input_width + w_offset; + int output_idx = (batch_idx * channels + c_offset) * output_depth * + output_height * output_width; + mask += output_idx; + output_grad += output_idx; + + for (int pd = pd_start; pd < pd_end; ++pd) { + for (int ph = ph_start; ph < ph_end; ++ph) { + for (int pw = pw_start; pw < pw_end; ++pw) { + if (mask[(pd * output_height + ph) * output_width + pw] == + input_current_feature_map_idx) + gradient += + output_grad[(pd * output_height + ph) * output_width + pw]; + } + } + } + input_grad[index] = gradient; + } +} + +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ +template +class MaxPool3dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); + + int nthreads = batch_size * output_channels * output_depth * output_height * + output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool3DWithIdx< + T><<(context) + .stream()>>>( + nthreads, input_data, output_data, mask_data, input_channels, + input_depth, input_height, input_width, output_depth, output_height, + output_width, ksize_depth, ksize_height, ksize_width, stride_depth, + stride_height, stride_width, padding_depth, padding_height, + padding_width); + } +}; + +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ +template +class MaxPool3dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_channels = input_grad.dims()[1]; + const int input_depth = input_grad.dims()[2]; + const int input_height = input_grad.dims()[3]; + const int input_width = input_grad.dims()[4]; + const int output_depth = output_grad.dims()[2]; + const int output_height = output_grad.dims()[3]; + const int output_width = output_grad.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + + const T* output_grad_data = output_grad.data(); + const T* mask_data = mask.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + int nthreads = + batch_size * input_channels * input_depth * input_height * input_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool3DWithIdxGrad< + T><<(context) + .stream()>>>( + nthreads, input_grad_data, output_grad_data, mask_data, input_channels, + input_depth, input_height, input_width, output_depth, output_height, + output_width, ksize_depth, ksize_height, ksize_width, stride_depth, + stride_height, stride_width, padding_depth, padding_height, + padding_width); + } +}; + +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/pooling.h b/paddle/operators/math/pooling.h index d214c689235ad..c50c57b5c52cd 100644 --- a/paddle/operators/math/pooling.h +++ b/paddle/operators/math/pooling.h @@ -21,15 +21,27 @@ limitations under the License. */ namespace paddle { namespace operators { namespace math { -////////////////////// -#define FLT_MAX __FLT_MAX__ // +#define FLT_MAX \ + __FLT_MAX__ // It might need to be placed in another file, but I'm still + // wondering where to put it. + +/* + * \brief Extracting simple operations from pooling. + * Both MaxPool and AvgPool need "initial", "compute" and "finalize" + * operation. + * MaxPool initializes temp variable to the negative maximum to find the + * maximum value in the pooling field. + * AvgPool initializes temp variable to the zero to accumulate all values + * in pool pooling, and finally takes the average. + * MaxPoolGrad and AvgPoolGrad are gradient operations respectively. + */ template class MaxPool { public: DEVICE inline T initial() { return static_cast(-FLT_MAX); } DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; } - DEVICE inline void finalize(T& y, const T& poo_size) {} + DEVICE inline void finalize(T& y, const T& pool_field) {} }; template @@ -37,8 +49,9 @@ class AvgPool { public: DEVICE inline T initial() { return static_cast(0); } DEVICE inline void compute(T& y, const T& x) { y += x; } - DEVICE inline void finalize(T& y, const T& poo_size) { y /= poo_size; } + DEVICE inline void finalize(T& y, const T& pool_field) { y /= pool_field; } }; + template class MaxPoolGrad { public: @@ -57,6 +70,20 @@ class AvgPoolGrad { } }; +/* + * \brief Getting pooling results, and calculating gradient. + * + * In pool2d, all tensors are in NCHW format. Where N is batch size, C is the + * number of channels, H and W is the height and width of feature. + * In pool3d, all tensors are in NCDHW format. Where N is batch size, C is the + * number of channels, D, H and W is the depth, height and width of feature. + * + * In max pooling, it is possible that the pooling region has multiple maximum + * elements. In this case, we should compute the gradient of the first maximum + * element. + * This is different from average pooling. So we rewrite the max_pool_grad: + * MaxPool2dGradFunctor, MaxPool3dGradFunctor. + */ template class Pool2dFunctor { public: @@ -117,6 +144,51 @@ class MaxPool3dGradFunctor { std::vector& strides, std::vector& paddings); }; +/* + * \brief Getting max pooling results and corresponding max index, and + * calculating gradient. + * In up-sampling-pooling, it is necessary to know max element index. + * In pool2d, all tensors are in NCHW format. In pool3d, all tensors are in + * NCDHW format. + */ +template +class MaxPool2dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + +template +class MaxPool2dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + +template +class MaxPool3dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + +template +class MaxPool3dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 2388b094d2285..ebeb262d9621f 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include "paddle/framework/framework.pb.h" #include "paddle/framework/op_registry.h" diff --git a/paddle/operators/pool_with_index_op.cc b/paddle/operators/pool_with_index_op.cc new file mode 100644 index 0000000000000..7b6afcfd1f7e3 --- /dev/null +++ b/paddle/operators/pool_with_index_op.cc @@ -0,0 +1,228 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/pool_with_index_op.h" + +namespace paddle { +namespace operators { + +inline int OutputSizeMaxPool(int input_size, int filter_size, int padding, + int stride) { + int output_size = (input_size - filter_size + 2 * padding) / stride + 1; + return output_size; +} + +class MaxPoolWithIndexOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "X(Input) of Pooling should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Out(Output) of Pooling should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Mask"), + "Mask(Output) of Pooling should not be null."); + + auto in_x_dims = ctx->GetInputDim("X"); + + std::vector ksize = ctx->Attrs().Get>("ksize"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + + PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5, + "Pooling intput should be 4-D or 5-D"); + + if (ctx->Attrs().Get("globalPooling")) { + ksize.resize(static_cast(in_x_dims.size()) - 2); + for (size_t i = 0; i < ksize.size(); ++i) + ksize[i] = static_cast(in_x_dims[i + 2]); + } + + PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U, + "Intput size and pooling size should be consistent."); + PADDLE_ENFORCE_EQ(ksize.size(), strides.size(), + "Strides size and pooling size should be the same."); + PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(), + "Paddings size and pooling size should be the same."); + + std::vector output_shape({in_x_dims[0], in_x_dims[1]}); + for (size_t i = 0; i < ksize.size(); ++i) { + output_shape.push_back(OutputSizeMaxPool(in_x_dims[i + 2], ksize[i], + paddings[i], strides[i])); + } + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + ctx->SetOutputDim("Mask", framework::make_ddim(output_shape)); + } +}; + +class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Input(X@GRAD) should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; + +class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MaxPool2dWithIndexOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "The input tensor of pooling operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of image."); + AddOutput("Out", + "The output tensor of pooling operator." + "The format of output tensor is also NCHW." + "Where N is batch size, C is " + "the number of channels, H and W is the height and " + "width of image."); + AddOutput("Mask", + "The Mask tensor of pooling operator." + "The format of output tensor is also NCHW." + "Where N is batch size, C is the number of channels, H and W " + "is the height and width of image." + "The value in it is the index in current feature map"); + + AddAttr>( + "ksize", + "The pooling size(height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " + "specified."); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + AddAttr( + "globalPooling", + "Whether to use the globalPooling." + "Bool constant equal to false or true." + "Default false." + "If globalPooling = true, ksize is ignored and need not be specified.") + .SetDefault(false); + AddAttr>("strides", + "Strides(height, width) of pooling operator." + "Default {1,1}.") + .SetDefault({1, 1}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + AddAttr>("paddings", + "Paddings(height, width) of pooling operator." + "Default {0,0}.") + .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + + AddComment(R"DOC( +The maxPooling2d with index operation calculates the output and the mask +based on the input and ksize, strides, paddings parameters. Input(X) and +output(Out, Mask) are in NCHW format. Where N is batch size, C is the +number of channels, H and W is the height and width of feature. +Parameters(ksize, strides, paddings) are two elements. +These two elements represent height and width, respectively. +)DOC"); + } +}; + +class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MaxPool3dWithIndexOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "The input tensor of pooling operator. " + "The format of input tensor is NCDHW. Where N is batch size, C is " + "the number of channels, D, H and W is the depth, height and width of " + "image."); + AddOutput("Out", + "The output tensor of pooling operator." + "The format of output tensor is also NCDHW." + "Where N is batch size, C is " + "the number of channels, D, H and W is the depth, height and " + "width of image."); + AddOutput("Mask", + "The Mask tensor of pooling operator." + "The format of output tensor is also NCDHW." + "Where N is batch size, C is the number of channels, D, H and W " + "is the depth, height and width of image." + "The value in it is the index in current feature map"); + + AddAttr>( + "ksize", + "The pooling size(depth, height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " + "specified."); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + AddAttr( + "globalPooling", + "Whether to use the globalPooling." + "Bool constant equal to false or true." + "Default false." + "If globalPooling = true, ksize is ignored and need not be specified.") + .SetDefault(false); + AddAttr>( + "strides", + "Strides(depth, height, width) of pooling operator." + "Default {1,1,1}.") + .SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + AddAttr>( + "paddings", + "Paddings(depth, height, width) of pooling operator." + "Default {0,0,0}.") + .SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + + AddComment(R"DOC( +The maxpooling3d with index operation calculates the output and the mask +based on the input and ksize, strides, paddings parameters. +Input(X) and output(Out, Mask) are in NCDHW format. Where N is batch +size, C is the number of channels, D, H and W is the depth, height and +width of feature. Parameters(ksize, strides, paddings) are three elements. +These three elements represent depth, height and width, respectively. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP(max_pool2d_with_index, ops::MaxPoolWithIndexOp, + ops::MaxPool2dWithIndexOpMaker, max_pool2d_with_index_grad, + ops::MaxPoolWithIndexOpGrad); + +REGISTER_OP_CPU_KERNEL( + max_pool2d_with_index, + ops::MaxPoolWithIndexKernel); +REGISTER_OP_CPU_KERNEL( + max_pool2d_with_index_grad, + ops::MaxPoolWithIndexGradKernel) + +REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp, + ops::MaxPool3dWithIndexOpMaker, max_pool3d_with_index_grad, + ops::MaxPoolWithIndexOpGrad); + +REGISTER_OP_CPU_KERNEL( + max_pool3d_with_index, + ops::MaxPoolWithIndexKernel); +REGISTER_OP_CPU_KERNEL( + max_pool3d_with_index_grad, + ops::MaxPoolWithIndexGradKernel) diff --git a/paddle/operators/pool_with_index_op.cu b/paddle/operators/pool_with_index_op.cu new file mode 100644 index 0000000000000..287657d4b1c57 --- /dev/null +++ b/paddle/operators/pool_with_index_op.cu @@ -0,0 +1,31 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/pool_with_index_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + max_pool2d_with_index, + ops::MaxPoolWithIndexKernel); +REGISTER_OP_GPU_KERNEL( + max_pool2d_with_index_grad, + ops::MaxPoolWithIndexGradKernel) + +REGISTER_OP_GPU_KERNEL( + max_pool3d_with_index, + ops::MaxPoolWithIndexKernel); +REGISTER_OP_GPU_KERNEL( + max_pool3d_with_index_grad, + ops::MaxPoolWithIndexGradKernel) diff --git a/paddle/operators/pool_with_index_op.h b/paddle/operators/pool_with_index_op.h new file mode 100644 index 0000000000000..01b961ca8295f --- /dev/null +++ b/paddle/operators/pool_with_index_op.h @@ -0,0 +1,103 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/pooling.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class MaxPoolWithIndexKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* in_x = context.Input("X"); + Tensor* out = context.Output("Out"); + Tensor* mask = context.Output("Mask"); + + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + if (context.Attr("globalPooling")) { + for (size_t i = 0; i < ksize.size(); ++i) { + ksize[i] = static_cast(in_x->dims()[i + 2]); + } + } + + switch (ksize.size()) { + case 2: { + paddle::operators::math::MaxPool2dWithIndexFunctor + pool2d_forward; + pool2d_forward(context.device_context(), *in_x, *out, *mask, ksize, + strides, paddings); + } break; + case 3: { + paddle::operators::math::MaxPool3dWithIndexFunctor + pool3d_forward; + pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize, + strides, paddings); + } break; + } + } +}; + +template +class MaxPoolWithIndexGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* mask = context.Input("Mask"); + const Tensor* out_grad = + context.Input(framework::GradVarName("Out")); + Tensor* in_x_grad = context.Output(framework::GradVarName("X")); + + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + if (context.Attr("globalPooling")) { + for (size_t i = 0; i < ksize.size(); ++i) { + ksize[i] = static_cast(in_x_grad->dims()[i + 2]); + } + } + + if (in_x_grad) { + in_x_grad->mutable_data(context.GetPlace()); + auto temp = framework::EigenVector::Flatten(*in_x_grad); + temp.device(context.GetEigenDevice()) = + temp.constant(static_cast(0)); + + switch (ksize.size()) { + case 2: { + paddle::operators::math::MaxPool2dWithIndexGradFunctor + pool2d_backward; + pool2d_backward(context.device_context(), *in_x_grad, *out_grad, + *mask, ksize, strides, paddings); + } break; + case 3: { + paddle::operators::math::MaxPool3dWithIndexGradFunctor + pool3d_backward; + pool3d_backward(context.device_context(), *in_x_grad, *out_grad, + *mask, ksize, strides, paddings); + } break; + } + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index a9b6b799036a4..36450e9268913 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -136,7 +136,7 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; } -#endif // PADDLE_ONLY_CPU +#endif } // namespace platform } // namespace paddle diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 15d8446cd8dce..cd906c3fa9375 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -41,7 +41,7 @@ limitations under the License. */ #include #include -#endif // PADDLE_ONLY_CPU +#endif namespace paddle { namespace platform { diff --git a/paddle/platform/gpu_info.h b/paddle/platform/gpu_info.h index fb33db07bd54d..37665b97d764f 100644 --- a/paddle/platform/gpu_info.h +++ b/paddle/platform/gpu_info.h @@ -63,4 +63,4 @@ void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device, } // namespace platform } // namespace paddle -#endif // PADDLE_ONLY_CPU +#endif diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 218821b35bb69..47bd7bc3bb3f0 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -117,7 +117,6 @@ void BindProgramDesc(py::module &m) { .def("append_block", &ProgramDescBind::AppendBlock, py::return_value_policy::reference) .def("block", &ProgramDescBind::Block, py::return_value_policy::reference) - .def("__str__", &ProgramDescBind::DebugString) .def("num_blocks", &ProgramDescBind::Size); } @@ -191,8 +190,6 @@ void BindOpDesc(py::module &m) { .def("output", &OpDescBind::Output) .def("output_names", &OpDescBind::OutputNames) .def("set_output", &OpDescBind::SetOutput) - .def("__str__", &OpDescBind::DebugString) - .def("__repr__", &OpDescBind::DebugString) .def("has_attr", &OpDescBind::HasAttr) .def("attr_type", &OpDescBind::GetAttrType) .def("attr_names", &OpDescBind::AttrNames) diff --git a/python/paddle/v2/framework/tests/test_activation_op.py b/python/paddle/v2/framework/tests/test_activation_op.py index 8b76decaecdcb..4528ed555d6bd 100644 --- a/python/paddle/v2/framework/tests/test_activation_op.py +++ b/python/paddle/v2/framework/tests/test_activation_op.py @@ -181,6 +181,26 @@ def test_check_grad(self): self.check_grad(['X'], 'Y', max_relative_error=0.02) +class TestELU(OpTest): + def setUp(self): + self.op_type = "elu" + x = np.random.uniform(-3, 3, [4, 4]).astype("float32") + alpha = 1. + # Note: unlike other Relu extensions, point 0 on standard ELU function (i.e. alpha = 1) + # is differentiable, so we can skip modifications like x[np.abs(x) < 0.005] = 0.02 here + self.inputs = {'X': x} + self.attrs = {'alpha': alpha} + self.outputs = { + 'Y': np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1)) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.02) + + class TestReciprocal(OpTest): def setUp(self): self.op_type = "reciprocal" diff --git a/python/paddle/v2/framework/tests/test_fill_constant_op.py b/python/paddle/v2/framework/tests/test_fill_constant_op.py new file mode 100644 index 0000000000000..dff7b615aa378 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_fill_constant_op.py @@ -0,0 +1,35 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestFillConstantOp1(OpTest): + def setUp(self): + '''Test fill_constant op with specified value + ''' + self.op_type = "fill_constant" + + self.inputs = {} + self.attrs = {'shape': [123, 92], 'value': 3.8} + self.outputs = {'Out': np.full((123, 92), 3.8)} + + def test_check_output(self): + self.check_output() + + +class TestFillConstantOp2(OpTest): + def setUp(self): + '''Test fill_constant op with default value + ''' + self.op_type = "fill_constant" + + self.inputs = {} + self.attrs = {'shape': [123, 92]} + self.outputs = {'Out': np.full((123, 92), 0.0)} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_interp_op.py b/python/paddle/v2/framework/tests/test_interp_op.py new file mode 100644 index 0000000000000..066569b96c961 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_interp_op.py @@ -0,0 +1,28 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestInterpOp(OpTest): + def setUp(self): + self.op_type = "interp" + x = np.random.random((2, 3)).astype("float32") + y = np.random.random((2, 3)).astype("float32") + w = np.random.random(2).astype("float32") + + sub_out = x - y + mul_out = sub_out * w.reshape(2, 1) + out = mul_out + y + + self.inputs = {'X': x, 'Y': y, 'W': w} + self.outputs = {'Out': out, 'SubOut': sub_out, 'MulOut': mul_out} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_pool_max_op.py b/python/paddle/v2/framework/tests/test_pool_max_op.py new file mode 100644 index 0000000000000..f0f8aa6089c74 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_pool_max_op.py @@ -0,0 +1,212 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def max_pool3D_forward_naive(x, + ksize, + strides, + paddings=[0, 0, 0], + global_pool=0): + + N, C, D, H, W = x.shape + if global_pool == 1: + ksize = [D, H, W] + D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1 + H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1 + W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1 + out = np.zeros((N, C, D_out, H_out, W_out)) + mask = np.zeros((N, C, D_out, H_out, W_out)) + for k in xrange(D_out): + d_start = np.max((k * strides[0] - paddings[0], 0)) + d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D)) + for i in xrange(H_out): + h_start = np.max((i * strides[0] - paddings[0], 0)) + h_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + for j in xrange(W_out): + w_start = np.max((j * strides[1] - paddings[1], 0)) + w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end] + + out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4)) + + for n in xrange(N): + for c in xrange(C): + arr = x_masked[n, c, :, :, :] + index = np.where(arr == np.max(arr)) + sub_deep = index[0][0] + sub_row = index[1][0] + sub_col = index[2][0] + index = ((d_start + sub_deep) * H + + (h_start + sub_row)) * W + w_start + sub_col + mask[n, c, k, i, j] = index + + return out, mask + + +def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): + + N, C, H, W = x.shape + if global_pool == 1: + ksize = [H, W] + H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1 + W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1 + out = np.zeros((N, C, H_out, W_out)) + mask = np.zeros((N, C, H_out, W_out)) + for i in xrange(H_out): + for j in xrange(W_out): + r_start = np.max((i * strides[0] - paddings[0], 0)) + r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + c_start = np.max((j * strides[1] - paddings[1], 0)) + c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + x_masked = x[:, :, r_start:r_end, c_start:c_end] + + out[:, :, i, j] = np.max(x_masked, axis=(2, 3)) + + for n in xrange(N): + for c in xrange(C): + arr = x_masked[n, c, :, :] + index = np.where(arr == np.max(arr)) + sub_row = index[0][0] + sub_col = index[1][0] + index = (r_start + sub_row) * W + c_start + sub_col + mask[n, c, i, j] = index + + return out, mask + + +class TestMaxPoolWithIndex_Op(OpTest): + def setUp(self): + self.initTestCase() + input = np.random.random(self.shape).astype("float32") + output, mask = self.pool_forward_naive(input, self.ksize, self.strides, + self.paddings, self.global_pool) + + self.attrs = { + 'strides': self.strides, + 'paddings': self.paddings, + 'ksize': self.ksize, + 'globalPooling': self.global_pool, + } + + self.inputs = {'X': input} + self.outputs = {'Out': output, "Mask": mask} + + def test_check_output(self): + self.check_output() + + # def test_check_grad(self): + # self.check_grad(set(['X']), ['Out'], max_relative_error=0.07) + + def initTestCase(self): + self.global_pool = True + self.index = "max_pool3d_with_index" + self.op_type = "%s" % self.index + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase1(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "max_pool3d_with_index" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase2(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "max_pool3d_with_index" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 7, 7, 7] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase3(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "max_pool3d_with_index" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 7, 7, 7] + self.ksize = [3, 3, 3] + self.strides = [2, 2, 2] + self.paddings = [0, 0, 0] + + +class TestCase4(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "max_pool3d_with_index" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase5(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "max_pool3d_with_index" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [2, 2, 2] + self.paddings = [0, 0, 0] + + +class TestCase6(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "max_pool2d_with_index" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 1] + + +class TestCase7(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "max_pool2d_with_index" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [2, 2] + self.paddings = [0, 0] + + +class TestCase8(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "max_pool2d_with_index" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 5, 5] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 1] + + +class TestCase9(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "max_pool2d_with_index" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 5, 5] + self.ksize = [3, 3] + self.strides = [2, 2] + self.paddings = [0, 0] + + +if __name__ == '__main__': + unittest.main()