Commit

merge dynamic op changes

Superjomn committed Oct 10, 2017
1 parent dccf0cc commit b5e8a8a
Showing 41 changed files with 1,818 additions and 63 deletions.
24 changes: 24 additions & 0 deletions paddle/framework/backward_test.cc
@@ -58,6 +58,8 @@ class MulOpMaker : public OpProtoAndCheckerMaker {
AddInput("X", "A");
AddInput("Y", "B");
AddOutput("Out", "Out");
AddAttr<int>("x_num_col_dims", "").SetDefault(1).EqualGreaterThan(1);
AddAttr<int>("y_num_col_dims", "").SetDefault(1).EqualGreaterThan(1);
AddComment("Mul");
}
};
@@ -440,6 +442,28 @@ TEST(Backward, simple_single_op) {
std::vector<std::string>({f::GradVarName("b")}));
}

TEST(Backward, default_attribute) {
f::ProgramDesc *program_desc = GetNewProgramDesc();
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op = block->AppendOp();
op->SetType("mul");
op->SetInput("X", {"x"});
op->SetInput("Y", {"y"});
op->SetOutput("Out", {"out"});

AppendBackward(program, {});

ASSERT_EQ(block->AllOps().size(), 2UL);
EXPECT_EQ(boost::get<int>(op->GetAttr("x_num_col_dims")), 1);
EXPECT_EQ(boost::get<int>(op->GetAttr("y_num_col_dims")), 1);

f::OpDescBind *grad_op = block->AllOps()[1];
ASSERT_EQ(grad_op->Type(), "mul_grad");
EXPECT_EQ(boost::get<int>(grad_op->GetAttr("x_num_col_dims")), 1);
EXPECT_EQ(boost::get<int>(grad_op->GetAttr("y_num_col_dims")), 1);
}

TEST(Backward, simple_mult_op) {
f::ProgramDesc *program_desc = GetNewProgramDesc();
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
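
The default_attribute test above checks that attributes declared with SetDefault in MulOpMaker are filled in on both the forward op and the generated mul_grad op even though the test never sets them. A minimal standalone sketch of that default-then-validate pattern (the types and names below are illustrative, not the framework's actual OpAttrChecker API):

#include <map>
#include <stdexcept>
#include <string>

// Hedged sketch mirroring SetDefault(1).EqualGreaterThan(1): fill the value
// only when the attribute is absent, then enforce the lower bound.
void CheckNumColDims(std::map<std::string, int>* attrs, const std::string& name) {
  if (attrs->count(name) == 0) {
    (*attrs)[name] = 1;  // SetDefault(1)
  }
  if ((*attrs)[name] < 1) {
    throw std::runtime_error(name + " must be >= 1");  // EqualGreaterThan(1)
  }
}
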
1 change: 1 addition & 0 deletions paddle/framework/block_desc.h
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once

#include <deque>
#include <memory>
#include <unordered_map>
#include <vector>
#include "paddle/framework/op_desc.h"
1 change: 0 additions & 1 deletion paddle/framework/data_type.h
@@ -28,7 +28,6 @@ inline DataType ToDataType(std::type_index type) {
return DataType::INT32;
} else {
PADDLE_THROW("Not supported");
return static_cast<DataType>(-1);
}
}

1 change: 1 addition & 0 deletions paddle/framework/framework.proto
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

syntax = "proto2";
option optimize_for = LITE_RUNTIME;
package paddle.framework;

enum AttrType {
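
Adding optimize_for = LITE_RUNTIME makes the generated C++ messages derive from google::protobuf::MessageLite, which keeps binary serialization but drops reflection helpers such as DebugString(); that is consistent with the DebugString() members being removed from OpDescBind and ProgramDescBind below. A minimal sketch of what remains available under the lite runtime (assuming the generated header paddle/framework/framework.pb.h):

#include <string>
#include "paddle/framework/framework.pb.h"

// Hedged sketch: MessageLite-based messages still serialize and parse.
std::string RoundTrip(const paddle::framework::OpDesc& op_desc) {
  std::string buf = op_desc.SerializeAsString();  // available on MessageLite
  paddle::framework::OpDesc copy;
  copy.ParseFromString(buf);                      // available on MessageLite
  // copy.DebugString();  // not generated under LITE_RUNTIME
  return buf;
}
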
1 change: 1 addition & 0 deletions paddle/framework/op_desc.cc
@@ -25,6 +25,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
inputs_ = inputs;
outputs_ = outputs;
attrs_ = attrs;
need_update_ = true;
}

OpDesc *OpDescBind::Proto() {
2 changes: 0 additions & 2 deletions paddle/framework/op_desc.h
@@ -52,8 +52,6 @@ class OpDescBind {
void SetOutput(const std::string &param_name,
const std::vector<std::string> &args);

std::string DebugString() { return this->Proto()->DebugString(); }

bool HasAttr(const std::string &name) const {
return attrs_.find(name) != attrs_.end();
}
3 changes: 1 addition & 2 deletions paddle/framework/program_desc.h
@@ -14,6 +14,7 @@ limitations under the License. */

#pragma once

#include <memory>
#include <vector>
#include "paddle/framework/framework.pb.h"
#include "paddle/platform/macros.h"
@@ -31,8 +32,6 @@ class ProgramDescBind {

BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }

std::string DebugString() { return Proto()->DebugString(); }

size_t Size() const { return blocks_.size(); }

ProgramDesc *Proto();
1 change: 1 addition & 0 deletions paddle/framework/type_defs.h
@@ -15,6 +15,7 @@
#pragma once
#include <functional>
#include <map>
#include <memory>
#include "paddle/platform/variant.h"

namespace paddle {
2 changes: 1 addition & 1 deletion paddle/math/tests/test_GpuProfiler.cpp
@@ -162,4 +162,4 @@ int main(int argc, char** argv) {
return RUN_ALL_TESTS();
}

#endif /* PADDLE_ONLY_CPU */
#endif
2 changes: 1 addition & 1 deletion paddle/memory/detail/buddy_allocator.cc
@@ -182,7 +182,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() {
max_chunk_size_ = platform::GpuMaxChunkSize();
}
}
#endif // PADDLE_ONLY_CPU
#endif

// Allocate a new maximum sized block
size_t index = 0;
2 changes: 1 addition & 1 deletion paddle/memory/detail/system_allocator.cc
@@ -134,7 +134,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) {

bool GPUAllocator::UseGpu() const { return true; }

#endif // PADDLE_ONLY_CPU
#endif

} // namespace detail
} // namespace memory
2 changes: 1 addition & 1 deletion paddle/memory/detail/system_allocator.h
@@ -51,7 +51,7 @@ class GPUAllocator : public SystemAllocator {
size_t gpu_alloc_size_ = 0;
size_t fallback_alloc_size_ = 0;
};
#endif // PADDLE_ONLY_CPU
#endif

} // namespace detail
} // namespace memory
2 changes: 1 addition & 1 deletion paddle/memory/detail/system_allocator_test.cc
@@ -62,4 +62,4 @@ TEST(GPUAllocator, Alloc) {
TestAllocator(a, 2048);
TestAllocator(a, 0);
}
#endif // PADDLE_ONLY_CPU
#endif
2 changes: 1 addition & 1 deletion paddle/memory/memcpy.cc
@@ -89,7 +89,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
}

#endif // PADDLE_ONLY_CPU
#endif

} // namespace memory
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/memory/memcpy.h
@@ -53,7 +53,7 @@ template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
cudaStream_t stream);

#endif // PADDLE_ONLY_CPU
#endif

} // namespace memory
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/memory/memory.cc
@@ -111,7 +111,7 @@ size_t Used<platform::GPUPlace>(platform::GPUPlace place) {
return GetGPUBuddyAllocator(place.device)->Used();
}

#endif // PADDLE_ONLY_CPU
#endif

} // namespace memory
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/memory/memory_test.cc
@@ -135,4 +135,4 @@ TEST(BuddyAllocator, GPUMultAlloc) {
}
}

#endif // PADDLE_ONLY_CPU
#endif
8 changes: 8 additions & 0 deletions paddle/operators/CMakeLists.txt
@@ -55,12 +55,20 @@ function(op_library TARGET)
set(pybind_flag 1)
endif()

# pool_op contains several operators
if ("${TARGET}" STREQUAL "pool_op")
set(pybind_flag 1)
# It is enough to add just one operator to pybind
file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
endif()

# pool_with_index_op contains several operators
if ("${TARGET}" STREQUAL "pool_with_index_op")
set(pybind_flag 1)
# It is enough to add just one operator to pybind
file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
endif()

# activation_op contains several operators
if ("${TARGET}" STREQUAL "activation_op")
set(pybind_flag 1)
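
To illustrate what these branches produce, a hypothetical excerpt of the generated pybind registration file after configuration (the exact contents depend on which op libraries are built):

// Hypothetical excerpt of ${pybind_file}: for a library that defines several
// operators, registering any one of them is enough to pull the whole library
// into the pybind target.
USE_OP(pool2d);                 // appended for pool_op
USE_OP(max_pool2d_with_index);  // appended for pool_with_index_op
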
24 changes: 24 additions & 0 deletions paddle/operators/activation_op.cc
@@ -201,6 +201,27 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
}
};

template <typename AttrType>
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ELUOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor) The input of ELU operator, it shouldn't be empty. Input "
"is flattened and treated as a 1D array.");
AddOutput("Y",
"(Tensor) The output of ELU operator. It has the same shape as "
"the input.");
AddAttr<AttrType>(
"alpha", "(float, default 1.0) Alpha value in the elu formulation.")
.SetDefault(static_cast<AttrType>(1.));
AddComment(R"DOC(
ELU activation operator. It applies the element-wise computation
f(x) = max(0, x) + min(0, alpha * (exp(x) - 1)) to its input.
See https://arxiv.org/abs/1511.07289 for more details.)DOC");
}
};

template <typename AttrType>
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
public:
@@ -289,6 +310,9 @@ REGISTER_OP(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker<float>,
REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>,
soft_relu_grad, ops::ActivationOpGrad);

REGISTER_OP(elu, ops::ActivationOp, ops::ELUOpMaker<float>, elu_grad,
ops::ActivationOpGrad);

REGISTER_OP(relu6, ops::ActivationOp, ops::Relu6OpMaker<float>, relu6_grad,
ops::ActivationOpGrad);

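
As a quick sanity check of the formula documented in ELUOpMaker, a standalone scalar sketch of the same element-wise computation (the operator itself works on tensors):

#include <algorithm>
#include <cmath>
#include <cstdio>

// f(x) = max(0, x) + min(0, alpha * (exp(x) - 1))
float elu(float x, float alpha = 1.0f) {
  return std::max(0.0f, x) + std::min(0.0f, alpha * (std::exp(x) - 1.0f));
}

int main() {
  std::printf("%f\n", elu(2.0f));   //  2.000000  (identity for x >= 0)
  std::printf("%f\n", elu(-1.0f));  // -0.632121  (alpha * (e^-1 - 1) for x < 0)
  return 0;
}
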
66 changes: 48 additions & 18 deletions paddle/operators/activation_op.h
@@ -384,6 +384,35 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct ELUFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}

template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) =
x.cwiseMax(static_cast<T>(0)) +
(alpha * (x.exp() - static_cast<T>(1))).cwiseMin(static_cast<T>(0));
}
};

template <typename T>
struct ELUGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) =
dy * (x > static_cast<T>(0)).template cast<T>() +
dy * (y + alpha) * (x < static_cast<T>(0)).template cast<T>();
}
};

template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
float factor;
@@ -440,21 +469,22 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
} // namespace operators
} // namespace paddle

#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
__macro(abs, AbsFunctor, AbsGradFunctor); \
__macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, LogFunctor, LogGradFunctor); \
__macro(square, SquareFunctor, SquareGradFunctor); \
__macro(brelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanhFunctor, STanhGradFunctor); \
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(relu6, Relu6Functor, Relu6GradFunctor); \
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
__macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor)
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
__macro(abs, AbsFunctor, AbsGradFunctor); \
__macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, LogFunctor, LogGradFunctor); \
__macro(square, SquareFunctor, SquareGradFunctor); \
__macro(brelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanhFunctor, STanhGradFunctor); \
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
__macro(relu6, Relu6Functor, Relu6GradFunctor); \
__macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
__macro(elu, ELUFunctor, ELUGradFunctor)
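
One detail worth noting in ELUGradFunctor above: for x < 0 the forward output is y = alpha * (exp(x) - 1), so dy/dx = alpha * exp(x) = y + alpha, which is why the backward pass can reuse y instead of recomputing exp(x). A hedged scalar sketch of the same gradient rule:

// Scalar mirror of ELUGradFunctor: gradient 1 on the positive side,
// alpha * exp(x) (recovered as y + alpha) on the negative side.
float elu_grad(float x, float y, float dy, float alpha = 1.0f) {
  float dx = 0.0f;
  if (x > 0.0f) dx += dy;                // derivative is 1 for x > 0
  if (x < 0.0f) dx += dy * (y + alpha);  // equals dy * alpha * exp(x)
  return dx;
}
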
21 changes: 11 additions & 10 deletions paddle/operators/dynamic_recurrent_op.cc
@@ -53,7 +53,7 @@ class DynamicRecurrentOpProtoAndCheckerMaker
"names of pre-memories");
AddAttr<std::vector<std::string>>(name.memories, "names of memories");

AddComment("This is a recurrent group operator.");
AddComment("This is a RNN operator for varience-length sequences.");
}
};

@@ -65,9 +65,10 @@ void DynamicRecurrentOp::Run(const Scope& scope,
WriteStepInputs();
InitStates();

// call stepnet in all the time steps
for (size_t step = 0; step < cache_.num_steps; step++) {
// call stepnet
stepnet_->Run(scope, dev_ctx);
auto& step_scope = cache_.GetScope(step);
stepnet_->Run(step_scope, dev_ctx);
}

WriteStepOutputs();
@@ -96,10 +97,10 @@ void DynamicRecurrentOp::SplitInputs() const {
}

void DynamicRecurrentOp::WriteStepInputs() const {
const auto& inlinks = cache_.inlinks;
for (auto& item : inlinks) {
for (auto& item : cache_.inlinks) {
auto ta_it = step_inputs_.find(item.first);
PADDLE_ENFORCE(ta_it != step_inputs_.end(), "");
PADDLE_ENFORCE(ta_it != step_inputs_.end(),
"step_inputs_ not compatible with memory set");
TensorArray& ta = step_inputs_[item.first];
for (size_t step = 0; step < ta.size(); step++) {
auto tensor = ta.Read(step);
@@ -178,8 +179,8 @@ void DynamicRecurrentOp::InitStates() const {
const auto& dims = boot_state.dims();

for (size_t step = 0; step < cache_.num_steps; step++) {
// link pre-state to boot_state
auto& cur_scope = cache_.GetScope(step);
// link pre-state to boot_state
// init state and pre-state
auto* pre_state = cur_scope.FindVar(memory.pre_var);
PADDLE_ENFORCE_NOT_NULL(pre_state);
@@ -194,9 +195,9 @@
states_[memory.var].WriteShared(step, state->Get<LoDTensor>());
// link previous scope's state to the pre-states in current scope
if (step == 0) {
auto* cur_state_tensor = pre_state->GetMutable<LoDTensor>();
cur_state_tensor->Resize(boot_state.dims());
cur_state_tensor->ShareDataWith<value_type>(boot_state);
auto* pre_state_tensor = pre_state->GetMutable<LoDTensor>();
pre_state_tensor->Resize(boot_state.dims());
pre_state_tensor->ShareDataWith<value_type>(boot_state);
} else {
auto& pre_scope = cache_.GetScope(step - 1);
auto* state_pre = pre_scope.FindVar(memory.var);
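
The two fixes above, running the step net in each step's scope and sharing the boot state into step 0's pre-state, establish a simple chaining rule: step 0's pre-state aliases the boot state, and step t's pre-state aliases the state produced at step t-1. A hedged sketch of that linking pattern, with plain pointers standing in for the shared LoDTensors (names are illustrative, not the framework API):

#include <vector>

// Build the pre-state chain: step 0 aliases the boot state, later steps alias
// the previous step's state (mirroring ShareDataWith in InitStates).
std::vector<const float*> LinkPreStates(const float* boot_state,
                                        const std::vector<const float*>& states) {
  std::vector<const float*> pre_states(states.size());
  for (size_t step = 0; step < states.size(); ++step) {
    pre_states[step] = (step == 0) ? boot_state : states[step - 1];
  }
  return pre_states;
}
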
2 changes: 1 addition & 1 deletion paddle/operators/dynamic_recurrent_op_test.cc
@@ -55,7 +55,7 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test {
framework::OpDesc CreateOpDesc() {
// create op
paddle::framework::OpDesc op_desc;
op_desc.set_type("dynamic_recurrent_op");
op_desc.set_type("dynamic_recurrent");

OpDescNewVar(argname.inlinks, {"in0"}, op_desc.add_inputs());
OpDescNewVar(argname.boot_memories, {"boot_mem"}, op_desc.add_inputs());