diff --git a/cmake/external/lite.cmake b/cmake/external/lite.cmake index d8eb7b762051f..ce3860f6406ea 100644 --- a/cmake/external/lite.cmake +++ b/cmake/external/lite.cmake @@ -43,7 +43,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR) ${LITE_PROJECT} ${EXTERNAL_PROJECT_LOG_ARGS} GIT_REPOSITORY "https://github.com/PaddlePaddle/Paddle-Lite.git" - GIT_TAG 947cda26637d46dc23f4e39d2b52e7d9a1fa6eef + GIT_TAG b30dc65b264f7bc3753ba862ff4e529ea2af6665 PREFIX ${LITE_SOURCES_DIR} UPDATE_COMMAND "" BUILD_COMMAND ${LITE_BUILD_COMMAND} diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 5f4f31abea489..8f8f5766147a3 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -64,7 +64,14 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper) -set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass buffer_shared_inplace_op_pass buffer_shared_cross_op_memory_reuse_pass) +set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto + sequential_execution_pass + modify_op_lock_and_record_event_pass + all_reduce_deps_pass + reference_count_pass + eager_deletion_pass + buffer_shared_inplace_op_pass + buffer_shared_cross_op_memory_reuse_pass) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope @@ -91,23 +98,22 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle) +set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass + multi_devices_graph_print_pass multi_devices_graph_check_pass + fuse_elewise_add_act_pass fuse_bn_act_pass + multi_batch_merge_pass + fuse_relu_depthwise_conv_pass + lock_free_optimize_pass + coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass + fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass + sync_batch_norm_pass runtime_context_cache_pass) +if(WITH_GPU) + set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass) +endif() if(WITH_NGRAPH) - set(NGRAPH_BS_DEPS ngraph) -else() - set(NGRAPH_BS_DEPS) + set(IR_PASS_DEPS ${IR_PASS_DEPS} ngraph) endif() - -cc_library(build_strategy SRCS build_strategy.cc DEPS - graph_viz_pass multi_devices_graph_pass - multi_devices_graph_print_pass multi_devices_graph_check_pass - fuse_elewise_add_act_pass fuse_bn_act_pass multi_batch_merge_pass - fuse_relu_depthwise_conv_pass - lock_free_optimize_pass - coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass - fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass - sync_batch_norm_pass runtime_context_cache_pass - pass_builder - ${NGRAPH_BS_DEPS}) +cc_library(build_strategy SRCS build_strategy.cc DEPS pass_builder ${IR_PASS_DEPS}) if (WITH_MKLDNN) target_link_libraries(build_strategy mkldnn_placement_pass) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index ca6871cc3ee35..c524af5fcaa66 100644 --- a/paddle/fluid/framework/details/build_strategy.cc 
+++ b/paddle/fluid/framework/details/build_strategy.cc @@ -165,9 +165,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { void AppendOpFusePasses() { AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_, "fuse_relu_depthwise_conv_pass"); + AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass"); +#ifdef PADDLE_WITH_CUDA + AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass"); +#endif AppendPassWithCheck(strategy_.fuse_elewise_add_act_ops_, "fuse_elewise_add_act_pass"); - AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass"); // for single card training, fuse_all_reduce_ops is unnecessary. // coalesce_grad_tensor_pass should be before of MultiDevPass. AppendPassWithCheck(strategy_.fuse_all_reduce_ops_, @@ -370,6 +373,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, "GPU, skipped."; continue; } + } else if (pass->Type() == "fusion_group_pass") { + pass->Set("use_gpu", new bool(use_cuda)); + if (!use_cuda) { + LOG(WARNING) << "fusion_group_pass is only supported on GPU, skipped."; + continue; + } } else if (pass->Type() == "fuse_bn_act_pass") { if (!use_cuda) { LOG(WARNING) << "fuse_bn_act_pass is only supported on " @@ -427,3 +436,6 @@ USE_PASS(mkldnn_placement_pass); #ifdef PADDLE_WITH_NGRAPH USE_PASS(ngraph_subgraph_pass); #endif +#ifdef PADDLE_WITH_CUDA +USE_PASS(fusion_group_pass); +#endif diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 0b12f16727f1c..0e59969989868 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -86,8 +86,9 @@ struct BuildStrategy { // Operator fusion // TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have // cycle. 
- bool fuse_elewise_add_act_ops_{false}; bool fuse_bn_act_ops_{false}; + bool fuse_elewise_add_act_ops_{false}; + bool enable_auto_fusion_{false}; // Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients // should not be sparse types boost::optional fuse_all_optimizer_ops_{false}; diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 73264140f9452..4a3d01e669b32 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -6,7 +6,7 @@ file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n") add_subdirectory(fuse_optimizer_ops_pass) add_subdirectory(memory_optimize_pass) add_subdirectory(multi_devices_graph_pass) -if(NOT APPLE AND NOT WIN32) +if(NOT APPLE AND NOT WIN32 AND WITH_GPU) add_subdirectory(fusion_group) endif() diff --git a/paddle/fluid/framework/ir/fusion_group/CMakeLists.txt b/paddle/fluid/framework/ir/fusion_group/CMakeLists.txt index 1887f425d12a0..fe2bd27524fbf 100644 --- a/paddle/fluid/framework/ir/fusion_group/CMakeLists.txt +++ b/paddle/fluid/framework/ir/fusion_group/CMakeLists.txt @@ -1,9 +1,11 @@ -cc_library(code_generator SRCS operation.cc code_generator.cc code_generator_helper.cc DEPS graph) +cc_library(code_generator + SRCS operation.cc code_generator.cc code_generator_helper.cc + DEPS graph subgraph_detector) if(WITH_GPU) cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor graph_viz_pass) endif() cc_library(fusion_group_pass SRCS fusion_group_pass.cc elementwise_group_detector.cc - DEPS graph_pattern_detector pass code_generator) + DEPS subgraph_detector fuse_pass_base code_generator device_code) cc_test(test_fusion_group_pass SRCS fusion_group_pass_tester.cc DEPS fusion_group_pass graph_viz_pass) diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator.cc b/paddle/fluid/framework/ir/fusion_group/code_generator.cc index c41a2ed835914..0f9ee83a41108 100644 --- a/paddle/fluid/framework/ir/fusion_group/code_generator.cc +++ b/paddle/fluid/framework/ir/fusion_group/code_generator.cc @@ -33,7 +33,7 @@ CodeGenerator::CodeGenerator() { std::string CodeGenerator::Generate(SubGraph* subgraph) { std::vector expressions = ConvertToExpressions(subgraph); - return Generate(subgraph->func_name, expressions); + return Generate(subgraph->GetFuncName(), expressions); } static bool HasInput(Node* n, std::string name) { diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc b/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc index 9515237f96440..a5409cb9d6abf 100644 --- a/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc +++ b/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc @@ -227,7 +227,7 @@ std::vector TestMain( std::string code_str = code_generator.Generate(subgraph); VLOG(3) << code_str; - TestMainImpl(subgraph->func_name, code_str, cpu_tensors, n, input_ids, + TestMainImpl(subgraph->GetFuncName(), code_str, cpu_tensors, n, input_ids, output_ids); // Need to check the accuracy according to expressions. 
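For context on what the code-generator tester above verifies, here is a minimal NumPy sketch of the same idea: fuse a chain of elementwise expressions into one callable and check it against evaluating the ops one at a time. This is an illustration only, not Paddle's actual CodeGenerator; the helper name fuse_elementwise is made up.

import numpy as np

def fuse_elementwise(expressions):
    # Hypothetical stand-in for generating one fused kernel from a list of
    # elementwise expressions.
    def fused(x):
        out = x
        for op in expressions:
            out = op(out)
        return out
    return fused

expressions = [np.tanh, lambda t: t * 2.0, lambda t: t + 1.0]
x = np.random.rand(4, 8).astype(np.float32)

# Unfused reference: apply each op separately, materializing intermediates.
ref = x
for op in expressions:
    ref = op(ref)

# Fused version: one call, no externally visible intermediates.
np.testing.assert_allclose(fuse_elementwise(expressions)(x), ref, rtol=1e-5)
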
diff --git a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc index fb6f7b8d74650..7aa22dc6d6756 100644 --- a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc +++ b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc @@ -13,8 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h" +#include +#include +#include #include "paddle/fluid/framework/ir/fusion_group/operation.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/subgraph_detector.h" namespace paddle { namespace framework { @@ -26,20 +29,22 @@ static std::unordered_set unary_op_types; static std::unordered_set& GetBinaryOpTypes() { if (binary_op_types.empty()) { - binary_op_types = OperationMap::Instance().Find(0, 2); + binary_op_types = + OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 2); } return binary_op_types; } static std::unordered_set& GetUnaryOpTypes() { if (unary_op_types.empty()) { - unary_op_types = OperationMap::Instance().Find(0, 1); + unary_op_types = + OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 1); } return unary_op_types; } static bool IsSpecifiedOp(const std::unordered_set& op_types, - Node* n) { + const Node* n) { if (n && n->IsOp() && n->Op() && n->outputs.size() > 0U) { auto iter = op_types.find(n->Op()->Type()); if (iter != op_types.end()) { @@ -49,114 +54,71 @@ static bool IsSpecifiedOp(const std::unordered_set& op_types, return false; } -static bool IsBinaryOp(Node* n) { - if (IsSpecifiedOp(GetBinaryOpTypes(), n) && n->inputs.size() == 2U) { - auto* x = n->inputs[0]; - auto* y = n->inputs[1]; +static bool IsGradOp(const Node* n) { + PADDLE_ENFORCE_EQ(n && n->IsOp() && n->Op(), true, + platform::errors::InvalidArgument( + "Expected node %p to be an operator node.", n)); + std::string suffix = "_grad"; + std::string op_type = n->Op()->Type(); + size_t pos = op_type.rfind(suffix); + return pos != std::string::npos && + pos == (op_type.length() - suffix.length()); +} - std::vector x_shape; - std::vector y_shape; - if (x && x->IsVar() && x->Var()) { - x_shape = x->Var()->GetShape(); - } - if (y && y->IsVar() && y->Var()) { - y_shape = y->Var()->GetShape(); - } - if (x_shape.size() == 0U || x_shape.size() != y_shape.size()) { +static bool IsEqual(const std::vector& l, + const std::vector& r) { + if (l.size() == 0U || r.size() == 0U || l.size() != r.size()) { + return false; + } + for (size_t i = 0; i < l.size(); ++i) { + if (l[i] != r[i]) { return false; } - for (size_t i = 0; i < x_shape.size(); ++i) { - if (x_shape[i] != y_shape[i]) { - return false; - } - } - return true; } - return false; + return true; } -static bool IsUnaryOp(Node* n) { return IsSpecifiedOp(GetUnaryOpTypes(), n); } +static bool IsBinaryOp(const Node* n) { + if (IsSpecifiedOp(GetBinaryOpTypes(), n)) { + if ((!IsGradOp(n) && n->inputs.size() != 2U) || n->inputs.size() == 0U) { + return false; + } -bool ElementwiseGroupDetector::IsElementwiseOp(Node* n) { - return IsBinaryOp(n) || IsUnaryOp(n); -} + // The shape of all inputs should be the same. 
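+    // fusion_group only fuses elementwise ops without broadcast, so all of the
+    // op's input shapes must match exactly for it to join a group.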
+ std::vector shape_0; + for (size_t i = 0; i < n->inputs.size(); ++i) { + auto* in_i = n->inputs[i]; + if (!(in_i && in_i->IsVar() && in_i->Var())) { + return false; + } -bool ElementwiseGroupDetector::IsInputOfElementwiseOp(Node* n, - std::string name) { - if (n && n->IsVar() && n->Var()) { - for (auto* op : n->outputs) { - if (IsElementwiseOp(op)) { - if (name.empty()) { - return true; - } else if (IsNthInput(n, op, name, 0)) { - return true; + std::vector shape_i = in_i->Var()->GetShape(); + if (i == 0U) { + shape_0 = shape_i; + } else { + if (!IsEqual(shape_0, shape_i)) { + return false; } } } + return true; } return false; } -bool ElementwiseGroupDetector::IsOutputOfElementwiseOp(Node* n) { - if (n && n->IsVar() && n->Var()) { - for (auto* op : n->inputs) { - if (IsElementwiseOp(op)) { - return true; - } - } - } - return false; +static bool IsUnaryOp(const Node* n) { + return IsSpecifiedOp(GetUnaryOpTypes(), n); } -int ElementwiseGroupDetector::Search(Node* n, std::vector except_nodes) { - std::unordered_set except_nodes_set; - for (size_t i = 0; i < except_nodes.size(); ++i) { - except_nodes_set.insert(except_nodes[i]); - } - - int num_operations = 0; - if (IsElementwiseOp(n)) { - subgraph_.Insert(n); - num_operations += 1; - for (auto* var : n->inputs) { - subgraph_.Insert(var); - if (except_nodes_set.find(var) == except_nodes_set.end()) { - num_operations += Search(var, {n}); - } - } - for (auto* var : n->outputs) { - subgraph_.Insert(var); - if (except_nodes_set.find(var) == except_nodes_set.end()) { - num_operations += Search(var, {n}); - } - } - } else if (n && n->IsVar() && n->Var()) { - for (auto* op : n->inputs) { - if (IsElementwiseOp(op) && - except_nodes_set.find(op) == except_nodes_set.end()) { - num_operations += Search(op, {n}); - } - } - for (auto* op : n->outputs) { - if (IsElementwiseOp(op) && - except_nodes_set.find(op) == except_nodes_set.end()) { - num_operations += Search(op, {n}); - } - } - } - return num_operations; +bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) { + return IsBinaryOp(n) || IsUnaryOp(n); } -int ElementwiseGroupDetector::operator()(Node* n) { - if (!IsOutputOfElementwiseOp(n) && IsInputOfElementwiseOp(n, "X")) { - name_ = n->Name(); - subgraph_.Insert(n); - num_operations_ = Search(n, n->inputs); - VLOG(4) << "Detect elementwise subgraph begin with " << name_ << ", " - << num_operations_ << " operations, " << GetSubgraph().GetNumNodes() - << " nodes"; - } - return num_operations_; +std::vector> ElementwiseGroupDetector::operator()( + Graph* graph) { + auto teller = [&](const Node* n) -> bool { return IsElementwiseOp(n); }; + + return SubgraphDetector(graph, teller)(); } } // namespace fusion_group diff --git a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h index 49d472eaab870..ff4db720f5dea 100644 --- a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h +++ b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h @@ -14,10 +14,8 @@ limitations under the License. 
*/ #pragma once -#include -#include #include -#include "paddle/fluid/framework/ir/fusion_group/subgraph.h" +#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/node.h" namespace paddle { @@ -27,21 +25,10 @@ namespace fusion_group { class ElementwiseGroupDetector { public: - int operator()(Node* n); - - SubGraph GetSubgraph() const { return subgraph_; } - - private: - bool IsElementwiseOp(Node* n); - bool IsInputOfElementwiseOp(Node* n, std::string name = ""); - bool IsOutputOfElementwiseOp(Node* n); - - int Search(Node* n, std::vector except_nodes = {}); + std::vector> operator()(Graph* graph); private: - std::string name_; - int num_operations_{0}; - SubGraph subgraph_; + bool IsElementwiseOp(const Node* n); }; } // namespace fusion_group diff --git a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc index 4999acbf7daf2..1d9d4ab5d232d 100644 --- a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc +++ b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc @@ -13,57 +13,88 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h" +#include +#include #include +#include "paddle/fluid/framework/ir/fusion_group/code_generator.h" #include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/platform/device_code.h" namespace paddle { namespace framework { namespace ir { void FusionGroupPass::ApplyImpl(ir::Graph* graph) const { - PADDLE_ENFORCE_NOT_NULL(graph); - - int num_elementwise_groups = DetectFusionGroup(graph, 0); - LOG(INFO) << "Detect " << num_elementwise_groups + FusePassBase::Init("fusion_group_pass", graph); + if (Get("use_gpu")) { + fusion_group::OperationMap::Init(); + int num_elementwise_groups = DetectFusionGroup(graph, 0); + VLOG(3) << "Detect " << num_elementwise_groups << " elementwise fusion groups."; + } } int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const { - std::vector subgraphs; - std::unordered_set all_nodes = graph->Nodes(); - for (Node* n : all_nodes) { - bool is_found = false; - for (auto& subgraph : subgraphs) { - if (subgraph.Has(n)) { - is_found = true; - break; - } - } - if (is_found) { - continue; + // TODO(liuyiqun): supported different places + platform::CUDAPlace place = platform::CUDAPlace(0); + int index = platform::DeviceCodePool::Init({place}).size(place); + + std::vector> subgraphs = + fusion_group::ElementwiseGroupDetector()(graph); + + int num_subgraphs = 0; + size_t min_subgraph_size = 2; + bool save_intermediate_out = true; + for (auto& vec : subgraphs) { + if (vec.size() >= min_subgraph_size) { + std::string func_name = "fused_elementwise_" + std::to_string(index++); + fusion_group::SubGraph subgraph( + type, func_name, save_intermediate_out, + std::unordered_set(vec.begin(), vec.end())); + VLOG(3) << "subgraph: {\n" + << DebugString(subgraph.SortedNodes()) << "}\n"; + + GenerateCode(&subgraph); + InsertFusionGroupOp(graph, &subgraph); + num_subgraphs++; } + } + return num_subgraphs; +} - fusion_group::SubGraph subgraph; - if (type == 0) { - fusion_group::ElementwiseGroupDetector detector; - int num_operations = detector(n); - if (num_operations >= 2) { - subgraph = detector.GetSubgraph(); - } 
- } +void FusionGroupPass::GenerateCode(fusion_group::SubGraph* subgraph) const { + fusion_group::CodeGenerator code_generator; + std::string code_str = code_generator.Generate(subgraph); + VLOG(3) << code_str; + + // TODO(liuyiqun): supported different places + platform::CUDAPlace place = platform::CUDAPlace(0); + std::unique_ptr device_code( + new platform::CUDADeviceCode(place, subgraph->GetFuncName(), code_str)); + device_code->Compile(); + + platform::DeviceCodePool& pool = platform::DeviceCodePool::Init({place}); + pool.Set(std::move(device_code)); +} - if (!subgraph.IsEmpty()) { - subgraphs.push_back(subgraph); +static int ExtractOpRole(fusion_group::SubGraph* subgraph) { + std::unordered_set op_roles; + std::string attr_name = OpProtoAndCheckerMaker::OpRoleAttrName(); + for (auto* n : subgraph->Nodes()) { + if (n && n->IsOp() && n->Op()) { + if (n->Op()->HasAttr(attr_name)) { + op_roles.insert(boost::get(n->Op()->GetAttr(attr_name))); + } } } - - // TODO(liuyiqun): check whether there are intersection between subgraphs - for (size_t i = 0; i < subgraphs.size(); ++i) { - InsertFusionGroupOp(graph, &subgraphs[i]); + if (op_roles.size() == 1U) { + return *(op_roles.begin()); + } else { + return static_cast(OpRole::kNotSpecified); } - return subgraphs.size(); } void FusionGroupPass::InsertFusionGroupOp( @@ -90,10 +121,12 @@ void FusionGroupPass::InsertFusionGroupOp( external_nodes.insert(n); } op_desc.SetOutput("Outs", output_names); - op_desc.SetAttr("type", subgraph->type); - op_desc.SetAttr("func_name", subgraph->func_name); + op_desc.SetAttr("type", subgraph->GetType()); + op_desc.SetAttr("func_name", subgraph->GetFuncName()); + op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + ExtractOpRole(subgraph)); - auto fusion_group_node = graph->CreateOpNode(&op_desc); + Node* fusion_group_node = graph->CreateOpNode(&op_desc); for (auto* in : input_vars_of_subgraph) { IR_NODE_LINK_TO(in, fusion_group_node); } @@ -114,4 +147,5 @@ void FusionGroupPass::InsertFusionGroupOp( } // namespace framework } // namespace paddle -REGISTER_PASS(fusion_group_pass, paddle::framework::ir::FusionGroupPass); +REGISTER_PASS(fusion_group_pass, paddle::framework::ir::FusionGroupPass) + .RequirePassAttr("use_gpu"); diff --git a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h index 8bdddf8877c06..72c7250e7205e 100644 --- a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h +++ b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h @@ -16,19 +16,20 @@ limitations under the License. 
*/ #include #include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fusion_group/subgraph.h" -#include "paddle/fluid/framework/ir/pass.h" namespace paddle { namespace framework { namespace ir { -class FusionGroupPass : public Pass { +class FusionGroupPass : public FusePassBase { protected: void ApplyImpl(Graph* graph) const override; private: int DetectFusionGroup(Graph* graph, int type = 0) const; + void GenerateCode(fusion_group::SubGraph* subgraph) const; void InsertFusionGroupOp(Graph* graph, fusion_group::SubGraph* subgraph) const; diff --git a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass_tester.cc b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass_tester.cc index 172ec0c0ee84d..2446716019cda 100644 --- a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass_tester.cc +++ b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass_tester.cc @@ -138,19 +138,15 @@ int TestMain(std::unique_ptr graph, std::string prefix) { } TEST(FusionGroupPass, elementwise_list) { - fusion_group::OperationMap::Init(); - - std::unique_ptr graph = BuildElementwiseListGraph(false); + std::unique_ptr graph = BuildElementwiseListGraph(true); int num_fusion_group_ops = TestMain(std::move(graph), "elementwise_list"); - EXPECT_EQ(num_fusion_group_ops, 1); + EXPECT_EQ(num_fusion_group_ops, 2); } TEST(FusionGroupPass, elementwise_tree) { - fusion_group::OperationMap::Init(); - - std::unique_ptr graph = BuildElementwiseTreeGraph(false); + std::unique_ptr graph = BuildElementwiseTreeGraph(true); int num_fusion_group_ops = TestMain(std::move(graph), "elementwise_tree"); - EXPECT_EQ(num_fusion_group_ops, 2); + EXPECT_EQ(num_fusion_group_ops, 4); } } // namespace ir diff --git a/paddle/fluid/framework/ir/fusion_group/subgraph.h b/paddle/fluid/framework/ir/fusion_group/subgraph.h index 1dd9caa10c98a..b9810882e1cc7 100644 --- a/paddle/fluid/framework/ir/fusion_group/subgraph.h +++ b/paddle/fluid/framework/ir/fusion_group/subgraph.h @@ -20,48 +20,59 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ir/fusion_group/operation.h" #include "paddle/fluid/framework/ir/node.h" -#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/framework/ir/subgraph_detector.h" namespace paddle { namespace framework { namespace ir { namespace fusion_group { -struct SubGraph { - int type{-1}; - std::string func_name; - bool save_intermediate_out{false}; - +class SubGraph { + public: SubGraph() = default; - SubGraph(int t, std::string f, bool s, const std::unordered_set& n) - : type(t), func_name(f), save_intermediate_out(s), nodes_set(n) {} + explicit SubGraph(int type) : type_(type) {} + SubGraph(int type, std::string func_name, bool save_intermediate_out, + const std::unordered_set& nodes_set) + : type_(type), + func_name_(func_name), + save_intermediate_out_(save_intermediate_out) { + for (auto* n : nodes_set) { + nodes_set_.insert(n); + if (n && n->IsOp() && n->Op()) { + // If the node is an op node, then add its input/output var nodes + // into the subgraph. 
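+        // Keeping these surrounding var nodes in nodes_set_ lets the subgraph
+        // be sorted topologically later and lets its input/output vars be derived.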
+ for (auto* in : n->inputs) { + nodes_set_.insert(in); + } + for (auto* out : n->outputs) { + nodes_set_.insert(out); + } + } + } + } - bool IsEmpty() { return nodes_set.empty(); } + bool IsEmpty() { return nodes_set_.empty(); } - const std::unordered_set& Nodes() const { return nodes_set; } + int GetType() const { return type_; } + void SetFuncName(std::string func_name) { func_name_ = func_name; } + std::string GetFuncName() const { return func_name_; } + + const std::unordered_set& Nodes() const { return nodes_set_; } const std::vector& SortedNodes() { - if (!is_sorted) { - Sort(); + if (!is_sorted_) { + TopologicalSort(); } - return sorted_nodes; + return sorted_nodes_; } - size_t GetNumNodes() { return nodes_set.size(); } + size_t GetNumNodes() { return nodes_set_.size(); } - bool Has(Node* n) { return nodes_set.find(n) != nodes_set.end(); } - - void Insert(Node* n) { - if (nodes_set.find(n) == nodes_set.end()) { - VLOG(5) << "Insert " << n->Name() << " to subgraph " << this; - nodes_set.insert(n); - is_sorted = false; - } - } + bool Has(Node* n) { return nodes_set_.find(n) != nodes_set_.end(); } int GetNumOperations() { int num_operations = 0; - for (auto* n : nodes_set) { + for (auto* n : nodes_set_) { if (n && n->IsOp() && n->Op()) { num_operations++; } @@ -96,203 +107,108 @@ struct SubGraph { std::vector GetOutputVarNodes() { // The order of output nodes should be consistant anywhere.. - std::vector output_vars; + std::vector output_vars_all; for (auto* n : SortedNodes()) { if (n && n->IsVar() && n->Var()) { - if (save_intermediate_out) { - // If the var_node is the output of some op_node in the subgraph, it - // is considered the output var node of the subgraph. - bool is_found = false; - for (auto* in : n->inputs) { - if (Has(in)) { - is_found = true; - } - } - if (is_found) { - output_vars.push_back(n); - } - } else { - // If one of the var_node's outputs is the input of some operator - // outside the subgraph, it is considered the output var node of the - // subgraph. - bool is_found = true; - if (n->outputs.size() == 0U) { - is_found = false; - } - for (auto* out : n->outputs) { - if (!Has(out)) { - is_found = false; - } - } - if (!is_found) { - output_vars.push_back(n); + // If the var_node is the output of some op_node in the subgraph, it + // is considered the output var node of the subgraph. + bool is_found = false; + for (auto* in : n->inputs) { + if (Has(in)) { + is_found = true; } } + if (is_found) { + output_vars_all.push_back(n); + } } } - return output_vars; - } - private: - int FindIndexInSortedNodes(Node* n) { - for (size_t i = 0; i < sorted_nodes.size(); ++i) { - if (n == sorted_nodes[i]) { - return static_cast(i); - } + if (save_intermediate_out_) { + return output_vars_all; } - return -1; - } - - void SortVarsBasedOnSortedOps() { - // Insert var nodes to sorted_nodes. - std::unordered_map sorted_vars; - for (auto* n : nodes_set) { - if (n && n->IsVar() && n->Var()) { - int from = 0; - int to = sorted_nodes.size(); - - for (auto* in : n->inputs) { - if (in && in->IsOp() && in->Op()) { - int index = FindIndexInSortedNodes(in); - // Insert after input op node - if (index >= 0) { - from = index + 1 > from ? index + 1 : from; - } - } - } - - for (auto* out : n->outputs) { - if (out && out->IsOp() && out->Op()) { - int index = FindIndexInSortedNodes(out); - // Insert before output op node - if (index >= 0) { - to = index < to ? 
index : to; - } - } - } - if (from > to) { - LOG(INFO) << "subgraph: {\n" << DebugString(Nodes()) << "}\n"; - LOG(INFO) << "sorted nodes: {\n" - << DebugString(sorted_nodes) << "}\n"; + std::vector output_vars_outside; + for (auto* n : output_vars_all) { + // If one of the var_node's outputs is the input of some operator + // outside the subgraph, it is considered the output var node of the + // subgraph. + bool is_found = true; + if (n->outputs.size() == 0U) { + is_found = false; + } + for (auto* out : n->outputs) { + if (!Has(out)) { + is_found = false; } - PADDLE_ENFORCE_LE(from, to, "Range [%d, %d] is invalid.", from, to); - sorted_nodes.insert(sorted_nodes.begin() + to, n); - sorted_vars[n->Name()] = n; + } + if (!is_found) { + output_vars_outside.push_back(n); } } + return output_vars_outside; } - std::vector SortedOps() { - Node* start_op_n = nullptr; - std::unordered_set ops; - for (auto* op_n : nodes_set) { - if (op_n && op_n->IsOp() && op_n->Op()) { - // Initialize ops to all ops in the subgraph. - ops.insert(op_n); + private: + void TopologicalSort() { + if (!is_sorted_) { + std::unordered_map> inputs_map; + std::unordered_map> outputs_map; + for (auto* n : nodes_set_) { + inputs_map[n] = n->inputs; + outputs_map[n] = n->outputs; + } - if (!start_op_n) { - // Find start op node whose inputs are produced outside the subgraph. - bool is_found = false; - for (auto* prev_op_n : GetPrevOpNodes(op_n)) { - if (Has(prev_op_n)) { - is_found = true; - break; + for (auto* n : nodes_set_) { + if (n && n->IsVar() && n->Var()) { + // Set the input of subgraph's input var node to null. + std::vector inputs; + for (auto* in : n->inputs) { + if (Has(in)) { + inputs.push_back(in); } } - if (!is_found) { - start_op_n = op_n; + // Set the output of subgraph's output var node to null. + std::vector outputs; + for (auto* out : n->outputs) { + if (Has(out)) { + outputs.push_back(out); + } } + n->inputs = inputs; + n->outputs = outputs; } } - } - - std::vector sorted_ops; - sorted_ops.push_back(start_op_n); - ops.erase(start_op_n); - while (ops.size() > 0U) { - std::unordered_set erased_ops; - for (auto* op_n : ops) { - bool found_connected_ops = false; - int from = 1; - int to = sorted_ops.size(); - std::unordered_set prev_op_nodes = GetPrevOpNodes(op_n); - std::unordered_set next_op_nodes = GetNextOpNodes(op_n); - for (int i = sorted_ops.size(); i >= 0; --i) { - if (prev_op_nodes.find(sorted_ops[i]) != prev_op_nodes.end()) { - // Insert after i (i + 1) - found_connected_ops = true; - from = (i + 1 > from) ? i + 1 : from; - } - if (next_op_nodes.find(sorted_ops[i]) != next_op_nodes.end()) { - // Insert before i - found_connected_ops = true; - to = (i < to) ? i : to; - } - } - if (found_connected_ops) { - if (from > to) { - LOG(INFO) << "subgraph: {\n" << DebugString(Nodes()) << "}\n"; - } - PADDLE_ENFORCE_LE(from, to, "Range [%d, %d] is invalid.", from, to); - sorted_ops.insert(sorted_ops.begin() + to, op_n); - erased_ops.insert(op_n); + // Collect the start points of the subgraph. + std::vector start_points; + for (auto* n : nodes_set_) { + if (n->inputs.empty()) { + start_points.push_back(n); } } - PADDLE_ENFORCE_GT(erased_ops.size(), 0U); - for (auto* op_n : erased_ops) { - ops.erase(op_n); + // Sort the subgraph. 
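+      // The links of the subgraph's var nodes were pruned to in-subgraph nodes
+      // above, so var nodes fed only from outside end up with no inputs and act
+      // as roots; NodesTSIterator then visits the nodes in topological order.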
+ NodesTSIterator x(start_points); + for (auto& n : iterator_range( + NodesTSIterator(start_points), NodesTSIterator())) { + sorted_nodes_.push_back(&n); } - } - return sorted_ops; - } - - std::unordered_set GetPrevOpNodes(Node* op_n) { - PADDLE_ENFORCE_EQ(op_n && op_n->IsOp() && op_n->Op(), true, - "Node %p is not a op node.", op_n); - - std::unordered_set prev_op_nodes; - for (auto* in_var : op_n->inputs) { - if (in_var && in_var->IsVar() && in_var->Var()) { - for (auto* prev_op_n : in_var->inputs) { - if (prev_op_n && prev_op_n->IsOp() && prev_op_n->Op()) { - prev_op_nodes.insert(prev_op_n); - } - } + // Reset the inputs, outputs. + for (auto* n : nodes_set_) { + n->inputs = inputs_map[n]; + n->outputs = outputs_map[n]; } } - return prev_op_nodes; - } - - std::unordered_set GetNextOpNodes(Node* op_n) { - PADDLE_ENFORCE_EQ(op_n && op_n->IsOp() && op_n->Op(), true, - "Node %p is not a op node.", op_n); - - std::unordered_set next_op_nodes; - for (auto* out_var : op_n->outputs) { - if (out_var && out_var->IsVar() && out_var->Var()) { - for (auto* next_op_n : out_var->outputs) { - if (next_op_n && next_op_n->IsOp() && next_op_n->Op()) { - next_op_nodes.insert(next_op_n); - } - } - } - } - return next_op_nodes; - } - - void Sort() { - if (!is_sorted) { - sorted_nodes = SortedOps(); - SortVarsBasedOnSortedOps(); - } - is_sorted = true; + is_sorted_ = true; } private: - std::unordered_set nodes_set; - bool is_sorted{false}; - std::vector sorted_nodes; + int type_{-1}; + std::string func_name_; + bool save_intermediate_out_{true}; + + std::unordered_set nodes_set_; + bool is_sorted_{false}; + std::vector sorted_nodes_; }; } // namespace fusion_group diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 42095c7bdc3e2..72dec87847b7b 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -167,7 +167,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) { } { - platform::RecordEvent record_event(Type() + "_op"); + platform::RecordEvent record_event(Type()); RunImpl(scope, place); } @@ -950,7 +950,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, std::vector transfered_inplace_vars; Scope* transfer_scope = nullptr; { - platform::RecordEvent record_event("prepare_data"); + platform::RecordEvent record_event("prepare_data_inner_op"); transfer_scope = PrepareData(scope, *kernel_type_, &transfered_inplace_vars, runtime_ctx); } @@ -963,7 +963,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, } if (!all_kernels_must_compute_runtime_shape_) { - platform::RecordEvent record_event("infer_shape"); + platform::RecordEvent record_event("infer_shape_inner_op"); RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx); this->InferShape(&infer_shape_ctx); } @@ -975,7 +975,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext // not Scope. Imperative mode only pass inputs and get outputs. 
{ - platform::RecordEvent record_event("compute"); + platform::RecordEvent record_event("compute_inner_op"); (*kernel_func_)(ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx, kernel_configs)); } diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index c3590ef606006..353cad77b3b47 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -367,6 +367,9 @@ download_data(${LITE_MODEL_INSTALL_DIR} "mul_model_fp32.tgz") inference_analysis_test(lite_mul_model_test SRCS lite_mul_model_test.cc EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS --infer_model=${LITE_MODEL_INSTALL_DIR}) +inference_analysis_test(lite_resnet50_test SRCS lite_resnet50_test.cc + EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} + ARGS --infer_model=${RESNET50_MODEL_DIR}) inference_analysis_test(test_analyzer_capi SRCS analyzer_capi_tester.cc EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c ARGS --infer_model=${RESNET50_MODEL_DIR}/model) diff --git a/paddle/fluid/inference/tests/api/lite_resnet50_test.cc b/paddle/fluid/inference/tests/api/lite_resnet50_test.cc new file mode 100644 index 0000000000000..e028c6e6410e7 --- /dev/null +++ b/paddle/fluid/inference/tests/api/lite_resnet50_test.cc @@ -0,0 +1,71 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include +#include +#include + +#include "paddle/fluid/inference/tests/api/tester_helper.h" + +namespace paddle { +namespace inference { + +TEST(AnalysisPredictor, use_gpu) { + std::string model_dir = FLAGS_infer_model + "/" + "model"; + AnalysisConfig config; + config.EnableUseGpu(100, 0); + config.SetModel(model_dir + "/model", model_dir + "/params"); + config.EnableLiteEngine(paddle::AnalysisConfig::Precision::kFloat32); + + std::vector inputs; + auto predictor = CreatePaddlePredictor(config); + const int batch = 1; + const int channel = 3; + const int height = 318; + const int width = 318; + const int input_num = batch * channel * height * width; + std::vector input(input_num, 1); + + PaddleTensor in; + in.shape = {1, 3, 318, 318}; + in.data = + PaddleBuf(static_cast(input.data()), input_num * sizeof(float)); + in.dtype = PaddleDType::FLOAT32; + inputs.emplace_back(in); + + std::vector outputs; + ASSERT_TRUE(predictor->Run(inputs, &outputs)); + + const std::vector truth_values = { + 127.780396, 738.16656, 1013.2264, -438.17206, 366.4022, 927.66187, + 736.2241, -633.68567, -329.92737, -430.15637, -633.0639, -146.54858, + -1324.2804, -1349.3661, -242.67671, 117.44864, -801.7251, -391.51495, + -404.8202, 454.16132, 515.48206, -133.03114, 69.293076, 590.09753, + -1434.6917, -1070.8903, 307.0744, 400.52573, -316.12177, -587.1265, + -161.05742, 800.3663, -96.47157, 748.708, 868.17645, -447.9403, + 112.73656, 1127.1992, 47.43518, 677.7219, 593.1881, -336.4011, + 551.3634, 397.82474, 78.39835, -715.4006, 405.96988, 404.25684, + 246.01978, -8.430191, 131.36617, -648.0528}; + + const size_t expected_size = 1; + EXPECT_EQ(outputs.size(), expected_size); + float* data_o = static_cast(outputs[0].data.data()); + for (size_t j = 0; j < outputs[0].data.length() / sizeof(float); j += 10) { + EXPECT_NEAR(data_o[j], truth_values[j / 10], 6e-3); + } +} + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc index 1995b7ba048bb..33a1aafa0fd2e 100644 --- a/paddle/fluid/operators/argsort_op.cc +++ b/paddle/fluid/operators/argsort_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/argsort_op.h" +#include namespace paddle { namespace operators { @@ -21,7 +22,7 @@ class ArgsortOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of ArgsortOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -49,6 +50,24 @@ class ArgsortOp : public framework::OperatorWithKernel { } }; +class ArgsortGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } +}; + class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -83,16 +102,42 @@ Output(Indices) gives the sorted order along the given axis Attr(axis). } }; +template +class ArgsortGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new T()); + op->SetType("argsort_grad"); + op->SetInput("Indices", this->Output("Indices")); + op->SetInput("X", this->Input("X")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + return op; + } +}; + +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ArgsortGradNoNeedBufferVarInference, "X"); + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR( - argsort, ops::ArgsortOp, ops::ArgsortOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); +REGISTER_OPERATOR(argsort, ops::ArgsortOp, ops::ArgsortOpMaker, + ops::ArgsortGradOpMaker, + ops::ArgsortGradOpMaker); +REGISTER_OPERATOR(argsort_grad, ops::ArgsortGradOp, + ops::ArgsortGradNoNeedBufferVarInference); REGISTER_OP_CPU_KERNEL(argsort, ops::ArgsortKernel, ops::ArgsortKernel, ops::ArgsortKernel, ops::ArgsortKernel); +REGISTER_OP_CPU_KERNEL( + argsort_grad, ops::ArgsortGradientKernel, + ops::ArgsortGradientKernel, + ops::ArgsortGradientKernel, + ops::ArgsortGradientKernel); diff --git a/paddle/fluid/operators/argsort_op.cu b/paddle/fluid/operators/argsort_op.cu index 0ea7e3dcb1486..006bf559195aa 100644 --- a/paddle/fluid/operators/argsort_op.cu +++ b/paddle/fluid/operators/argsort_op.cu @@ -58,6 +58,19 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) { } } +template +static __global__ void FillGrad(const T* dO, const IndType* indices, T* dX, + IndType num_rows, IndType num_cols) { + int col_id = threadIdx.x; + int row_id = blockIdx.x; + + for (IndType j = row_id; j < num_rows; j += gridDim.x) { + for (IndType i = col_id; i < num_cols; i += blockDim.x) { + dX[j * num_cols + indices[j * num_cols + i]] = dO[j * num_cols + i]; + } + } +} + // Sort by flag descending, True: descending. False: Ascending. // Default is false. 
template @@ -160,6 +173,35 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input, temp_storage_bytes, cudaGetErrorString(err)); } +template +void ArgFullAssign(const platform::CUDADeviceContext& ctx, const Tensor* dO, + const Tensor* indices, Tensor* dX, const IndType num_rows, + const IndType num_cols) { + auto cu_stream = ctx.stream(); + + auto ComputeBlockSize = [](IndType col) { + if (col > 512) + return 1024; + else if (col > 256 && col <= 512) + return 512; + else if (col > 128 && col <= 256) + return 256; + else if (col > 64 && col <= 128) + return 128; + else + return 64; + }; + + int block_size = ComputeBlockSize(num_cols); + + int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x; + // actually, int num_rows < max_grid_size + int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX; + FillGrad<<>>( + dO->data(), indices->data(), dX->data(), num_rows, + num_cols); +} + template class ArgsortOpCUDAKernel : public framework::OpKernel { public: @@ -234,6 +276,81 @@ class ArgsortOpCUDAKernel : public framework::OpKernel { } }; +template +class ArgsortGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* indices = ctx.Input("Indices"); + auto* dX = ctx.Output(framework::GradVarName("X")); + auto* dO = ctx.Input(framework::GradVarName("Out")); + int axis = ctx.Attr("axis"); + + dX->mutable_data(ctx.GetPlace()); + auto dxt = framework::EigenVector::Flatten(*dX); + auto& place = *ctx.template device_context() + .eigen_device(); + dxt.device(place) = dxt.constant(static_cast(0)); + if (dO->numel() == 0) return; + + auto in_dims = indices->dims(); + axis = (axis < 0) ? (in_dims.size() + axis) : axis; + + int64_t numel = indices->numel(); + + // Special case for full sort, speedup ~190x. 
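+    // For a full sort along the last axis, the gradient is a row-wise scatter:
+    // dX[row][indices[row][col]] = dOut[row][col], which ArgFullAssign performs
+    // in a single kernel launch.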
+ if (axis == -1 || axis + 1 == in_dims.size()) { + const int64_t input_height = framework::product( + framework::slice_ddim(in_dims, 0, in_dims.size() - 1)); + const int64_t input_width = in_dims[in_dims.size() - 1]; + const auto& dev_ctx = ctx.cuda_device_context(); + ArgFullAssign(dev_ctx, dO, indices, dX, input_height, + input_width); + } else { + // if not full sort, do transpose first + std::vector trans; + for (int i = 0; i < axis; i++) { + trans.push_back(i); + } + trans.push_back(in_dims.size() - 1); + for (int i = axis + 1; i < in_dims.size() - 1; i++) { + trans.push_back(i); + } + trans.push_back(axis); + framework::DDim trans_dims(in_dims); + for (int i = 0; i < trans.size(); i++) { + trans_dims[i] = in_dims[trans[i]]; + } + + Tensor trans_dO; + trans_dO.mutable_data(trans_dims, ctx.GetPlace()); + Tensor trans_ind; + trans_ind.mutable_data(trans_dims, ctx.GetPlace()); + int ndims = trans.size(); + const auto& dev_ctx = ctx.cuda_device_context(); + // Do transpose + TransCompute(ndims, dev_ctx, *dO, + &trans_dO, trans); + TransCompute( + ndims, dev_ctx, *indices, &trans_ind, trans); + + const int64_t input_height = framework::product( + framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); + const int64_t input_width = trans_dims[trans_dims.size() - 1]; + + Tensor tmp_out; + tmp_out.mutable_data(trans_dims, ctx.GetPlace()); + + ArgFullAssign(dev_ctx, &trans_dO, &trans_ind, &tmp_out, + input_height, input_width); + + // transpose back + TransCompute(ndims, dev_ctx, tmp_out, dX, + trans); + return; + } + } +}; + } // namespace operators } // namespace paddle @@ -243,3 +360,9 @@ REGISTER_OP_CUDA_KERNEL( paddle::operators::ArgsortOpCUDAKernel, paddle::operators::ArgsortOpCUDAKernel, paddle::operators::ArgsortOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL( + argsort_grad, paddle::operators::ArgsortGradOpCUDAKernel, + paddle::operators::ArgsortGradOpCUDAKernel, + paddle::operators::ArgsortGradOpCUDAKernel, + paddle::operators::ArgsortGradOpCUDAKernel, + paddle::operators::ArgsortGradOpCUDAKernel); diff --git a/paddle/fluid/operators/argsort_op.h b/paddle/fluid/operators/argsort_op.h index c48c4c14a83ec..fb353a8a2367b 100644 --- a/paddle/fluid/operators/argsort_op.h +++ b/paddle/fluid/operators/argsort_op.h @@ -68,6 +68,31 @@ static void FullSort(Type input_height, Type input_width, int input_dim, } } } + +template +static void FullAssign(Type input_height, Type input_width, int input_dim, + const framework::Tensor* input, + const framework::Tensor* indices, T* t_out) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (Type i = 0; i < input_height; ++i) { + if (input_dim == 1) { + auto e_input = EigenVector::Flatten(*input); + auto e_indices = EigenVector::Flatten(*indices); + for (Type j = 0; j < input_width; ++j) { + t_out[i * input_width + e_indices(j)] = e_input(e_indices(j)); + } + } else { + auto e_input = EigenMatrix::Reshape(*input, input_dim - 1); + auto e_indices = EigenMatrix::Reshape(*indices, input_dim - 1); + for (Type j = 0; j < input_width; ++j) { + t_out[i * input_width + e_indices(i, j)] = e_input(i, e_indices(i, j)); + } + } + } +} + template class ArgsortKernel : public framework::OpKernel { public: @@ -142,5 +167,77 @@ class ArgsortKernel : public framework::OpKernel { } }; +template +class ArgsortGradientKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* indices = ctx.Input("Indices"); + auto* dX = ctx.Output(framework::GradVarName("X")); + auto* dO = 
ctx.Input(framework::GradVarName("Out")); + int axis = ctx.Attr("axis"); + + auto in_dims = indices->dims(); + axis = (axis < 0) ? (in_dims.size() + axis) : axis; + + dX->mutable_data(ctx.GetPlace()); + auto dxt = framework::EigenVector::Flatten(*dX); + auto& place = *ctx.template device_context() + .eigen_device(); + dxt.device(place) = dxt.constant(static_cast(0)); + if (dO->numel() == 0) return; + + // Do full assign + if (axis == -1 || axis + 1 == in_dims.size()) { + const int64_t input_height = framework::product( + framework::slice_ddim(in_dims, 0, in_dims.size() - 1)); + const int64_t input_width = in_dims[in_dims.size() - 1]; + + FullAssign(input_height, input_width, in_dims.size(), dO, + indices, dX->data()); + } else { + // If not full assign do transpose + std::vector trans; + for (int i = 0; i < axis; i++) { + trans.push_back(i); + } + trans.push_back(in_dims.size() - 1); + for (int i = axis + 1; i < in_dims.size() - 1; i++) { + trans.push_back(i); + } + trans.push_back(axis); + framework::DDim trans_dims(in_dims); + for (size_t i = 0; i < trans.size(); i++) { + trans_dims[i] = in_dims[trans[i]]; + } + + Tensor trans_dO; + trans_dO.mutable_data(trans_dims, ctx.GetPlace()); + Tensor trans_ind; + trans_ind.mutable_data(trans_dims, ctx.GetPlace()); + int ndims = trans.size(); + auto& dev_ctx = ctx.template device_context(); + // Do transpose + TransCompute(ndims, dev_ctx, *dO, + &trans_dO, trans); + TransCompute( + ndims, dev_ctx, *indices, &trans_ind, trans); + + const int64_t input_height = framework::product( + framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); + const int64_t input_width = trans_dims[trans_dims.size() - 1]; + + Tensor tmp_out; + T* t_out = tmp_out.mutable_data(trans_dims, ctx.GetPlace()); + + FullAssign(input_height, input_width, in_dims.size(), + &trans_dO, &trans_ind, t_out); + + // transpose back + TransCompute(ndims, dev_ctx, tmp_out, dX, + trans); + } + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc index a32b89bc4ce02..baa5c2743f9d0 100644 --- a/paddle/fluid/platform/profiler.cc +++ b/paddle/fluid/platform/profiler.cc @@ -372,12 +372,13 @@ void PrintProfiler(const std::vector> &events_table, std::vector> child_table; std::vector table; bool do_next = false; - std::string op_end_str = "_op"; + std::string op_end_str = "inner_op"; for (auto it = child_map.begin(); it != child_map.end(); it++) { if (it->first == event_item.name) { table.push_back(it->second); - do_next = it->second.name.rfind(op_end_str) == - (it->second.name.length() - op_end_str.length()); + if (!do_next) + do_next = !(it->second.name.rfind(op_end_str) == + (it->second.name.length() - op_end_str.length())); } } child_table.push_back(table); @@ -579,6 +580,7 @@ void ParseEvents(const std::vector> &events, std::vector event_items; std::vector main_event_items; std::unordered_map event_idx; + std::multimap sub_child_map; for (size_t j = 0; j < (*analyze_events)[i].size(); j++) { Event analyze_event = (*analyze_events)[i][j]; @@ -599,7 +601,7 @@ void ParseEvents(const std::vector> &events, (cname[fname.length()] == '/' && cname.rfind('/') == fname.length()); if (condition) { - child_map.insert( + sub_child_map.insert( std::pair(fname, event_items[k])); child_index[k] = 1; } @@ -618,9 +620,9 @@ void ParseEvents(const std::vector> &events, item.ave_time = item.total_time / item.calls; item.ratio = item.total_time / total; } - for (auto it = child_map.begin(); it != child_map.end(); it++) { + 
for (auto it = sub_child_map.begin(); it != sub_child_map.end(); it++) { it->second.ratio = it->second.total_time / total; - it->second.ave_time = it->second.ave_time / it->second.calls; + it->second.ave_time = it->second.total_time / it->second.calls; } // sort @@ -636,6 +638,11 @@ void ParseEvents(const std::vector> &events, << "\', which will be ignored in profiling report."; ++rit; } + + for (auto it = sub_child_map.begin(); it != sub_child_map.end(); it++) { + child_map.insert( + std::pair(it->first, it->second)); + } } // Print report diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 3762b5a419036..08c709dd854e6 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -2014,6 +2014,27 @@ All parameter, weight, gradient are variables in Paddle. build_strategy = fluid.BuildStrategy() build_strategy.fuse_bn_act_ops = True )DOC") + .def_property( + "enable_auto_fusion", + [](const BuildStrategy &self) { return self.enable_auto_fusion_; }, + [](BuildStrategy &self, bool b) { + PADDLE_ENFORCE_EQ(!self.IsFinalized(), true, + platform::errors::PreconditionNotMet( + "BuildStrategy is finlaized.")); + self.enable_auto_fusion_ = b; + }, + R"DOC((bool, optional): Whether to enable fusing subgraph to a + fusion_group. Now we only support fusing subgraph that composed + of elementwise-like operators, such as elementwise_add/mul + without broadcast and activations. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + build_strategy = fluid.BuildStrategy() + build_strategy.enable_auto_fusion = True + )DOC") .def_property( "fuse_relu_depthwise_conv", [](const BuildStrategy &self) { diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py index 6f57f086e13b0..7400f45e05926 100644 --- a/python/paddle/fluid/compiler.py +++ b/python/paddle/fluid/compiler.py @@ -62,6 +62,28 @@ def _prune_feed_ops(program): program.global_block()._remove_op(index) +def _has_optimize_op(block): + for op in block.ops: + op_maker = core.op_proto_and_checker_maker + optimize = core.op_proto_and_checker_maker.OpRole.Optimize + if op_maker.kOpRoleVarAttrName() in op.attr_names and \ + int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(optimize): + return True + return False + + +def _has_optimizer_in_control_flow(program): + if not program: + program = framework.default_main_program() + for op in program.global_block().ops: + if op.type == "conditional_block_grad": + sub_block = program.block(op._block_attr_id("sub_block")) + if _has_optimize_op(sub_block): + return True + + return False + + class CompiledProgram(object): """ The CompiledProgram is used to transform a program or graph for @@ -386,6 +408,16 @@ def _compile(self, scope, place): self._places = self._get_places(self._place, self._places) else: self._places = [self._place] + + # Todo(liym27):If optimizer is used in control flow, + # training on multi-places is not supported now, will + # be supported later. 
+ if len(self._places) > 1 and \ + _has_optimizer_in_control_flow(self._program): + raise NotImplementedError( + "If optimizer is used in control flow, " + "training on multi-places is not supported now.") + self._executor = self._compile_data_parallel( use_cuda=isinstance(self._place, core.CUDAPlace), scope=self._scope, diff --git a/python/paddle/fluid/layer_helper_base.py b/python/paddle/fluid/layer_helper_base.py index 3140a84c09b7a..f6cf2a7d49c97 100644 --- a/python/paddle/fluid/layer_helper_base.py +++ b/python/paddle/fluid/layer_helper_base.py @@ -331,6 +331,14 @@ def create_parameter(self, if in_dygraph_mode(): # In dygraph mode, we want the returned parameter to be # initialized so that it can be used imperatively. + # check parameter name + is_used = unique_name.dygraph_parameter_name_checker(attr.name) + if is_used: + raise ValueError( + "parameter name [{}] have be been used. " + "In dygraph mode, the name of parameter can't be same." + "Please check the parameter attr value passed to self.create_parameter or " + "constructor of dygraph Layers".format(attr.name)) return self.main_program.global_block().create_parameter( dtype=dtype, shape=shape, diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 7e7beb494b83e..b56fcb35aed0d 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -189,6 +189,10 @@ if (APPLE OR WIN32) list(REMOVE_ITEM TEST_OPS test_dataset_dataloader) endif() +if(NOT WITH_GPU OR WIN32 OR APPLE) + list(REMOVE_ITEM TEST_OPS test_build_strategy_fusion_group_pass) +endif() + # Some ops need to check results when gc is enabled # Currently, only ops that register NoNeedBufferVarsInference need to do this test set(TEST_OPS_WITH_GC @@ -333,4 +337,5 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu test_parallel_executor_feed_persistable_var test_parallel_executor_crf_auto_growth test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass test_data_norm_op test_imperative_using_non_zero_gpu test_fuse_bn_act_pass + test_optimizer_in_control_flow test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST") diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py index cae1239cbb40e..ee917b059b87c 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py @@ -15,7 +15,8 @@ from __future__ import print_function import unittest -from paddle.fluid.tests.unittests.test_pool2d_op import * +import numpy as np +from paddle.fluid.tests.unittests.test_pool2d_op import TestPool2D_Op, TestCase1, TestCase2, TestCase3, TestCase4, TestCase5, avg_pool2D_forward_naive def create_test_mkldnn_use_ceil_class(parent): diff --git a/python/paddle/fluid/tests/unittests/test_argsort_op.py b/python/paddle/fluid/tests/unittests/test_argsort_op.py old mode 100755 new mode 100644 index 89ff5d7101a9a..44cd34879a69f --- a/python/paddle/fluid/tests/unittests/test_argsort_op.py +++ b/python/paddle/fluid/tests/unittests/test_argsort_op.py @@ -48,7 +48,7 @@ def init_axis(self): self.axis = -1 def init_datatype(self): - self.dtype = "float32" + self.dtype = "float64" def init_direction(self): self.descending = False @@ -56,6 +56,9 @@ def init_direction(self): def test_check_output(self): self.check_output() + def test_check_grad(self): 
+        self.check_grad(['X'], 'Out')
+
 
 class TestArgsortOpAxis0(TestArgsortOp):
     def init_axis(self):
@@ -146,5 +149,18 @@ def init_direction(self):
         self.descending = True
 
 
+class TestArgsortOpFP32Axis(TestArgsortOp):
+    def init_datatype(self):
+        self.dtype = "float32"
+
+
+class TestArgsortOpFP32DescendingAxis(TestArgsortOp):
+    def init_datatype(self):
+        self.dtype = "float32"
+
+    def init_direction(self):
+        self.descending = True
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_build_strategy_fusion_group_pass.py b/python/paddle/fluid/tests/unittests/test_build_strategy_fusion_group_pass.py
new file mode 100644
index 0000000000000..4b41cfcd3136f
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_build_strategy_fusion_group_pass.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from test_eager_deletion_padding_rnn import RNNConfig, PaddingRNNTestBase
+
+
+class FusionGroupPaddingRNNTest(PaddingRNNTestBase):
+    def set_customed_config(self):
+        # Enable fusion_group_pass
+        self.build_strategy.enable_auto_fusion = True
+
+        # Use CUDA executor
+        if core.is_compiled_with_cuda():
+            self.exe = fluid.Executor(fluid.CUDAPlace(0))
+
+    def test_train_enable_fusion_group(self):
+        rnn_model = "static"
+        config = RNNConfig("test", rnn_model)
+        with fluid.scope_guard(fluid.Scope()):
+            self.train(config, parallel=True, use_program_cache=False)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_mnist_fp16.py b/python/paddle/fluid/tests/unittests/test_dygraph_mnist_fp16.py
index 0efc887384dba..7503a9172fc21 100644
--- a/python/paddle/fluid/tests/unittests/test_dygraph_mnist_fp16.py
+++ b/python/paddle/fluid/tests/unittests/test_dygraph_mnist_fp16.py
@@ -121,7 +121,7 @@ def test_mnist_fp16(self):
         if not fluid.is_compiled_with_cuda():
             return
         x = np.random.randn(1, 3, 224, 224).astype("float16")
-        y = np.random.randn(1, 1).astype("int64")
+        y = np.random.randint(10, size=[1, 1], dtype="int64")
         with fluid.dygraph.guard(fluid.CUDAPlace(0)):
             model = MNIST(dtype="float16")
             x = fluid.dygraph.to_variable(x)
diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_padding_rnn.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_padding_rnn.py
index c0fd448d43397..92d1499843979 100644
--- a/python/paddle/fluid/tests/unittests/test_eager_deletion_padding_rnn.py
+++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_padding_rnn.py
@@ -21,7 +21,6 @@
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 import paddle.fluid.layers as layers
-import time
 import os
 
 from paddle.fluid import ParamAttr
@@ -118,8 +117,7 @@ def lm_model(hidden_size,
              num_steps=20,
              init_scale=0.1,
              dropout=None,
-             rnn_model='static',
-             use_py_reader=False):
+             rnn_model='static'):
     def padding_rnn(input_embedding, len=3, init_hidden=None, init_cell=None):
         weight_1_arr = []
         weight_2_arr = []
@@ -279,38 +277,9 @@ def encoder_static(input_embedding, len=3, init_hidden=None,
             gate_input = layers.elementwise_add(gate_input, bias)
             i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
 
-            try:
-                from paddle.fluid.contrib.layers import fused_elemwise_activation
-                # fluid.contrib.layers.fused_elemwise_activation can do a fused
-                # operation, like:
-                # 1) x + sigmoid(y); x + tanh(y)
-                # 2) tanh(x + y)
-                # Now the unary operation supported in this fused op is limit, and
-                # we will extent this operation to support more unary operations and
-                # do this kind of fusion automitically in future version of paddle.fluid.
-                # layers.sigmoid(i) * layers.tanh(j)
-                tmp0 = fused_elemwise_activation(
-                    x=layers.tanh(j),
-                    y=i,
-                    functor_list=['elementwise_mul', 'sigmoid'],
-                    save_intermediate_out=False)
-                # pre_cell * layers.sigmoid(f)
-                tmp1 = fused_elemwise_activation(
-                    x=pre_cell,
-                    y=f,
-                    functor_list=['elementwise_mul', 'sigmoid'],
-                    save_intermediate_out=False)
-                c = tmp0 + tmp1
-                # layers.tanh(c) * layers.sigmoid(o)
-                m = fused_elemwise_activation(
-                    x=layers.tanh(c),
-                    y=o,
-                    functor_list=['elementwise_mul', 'sigmoid'],
-                    save_intermediate_out=False)
-            except ImportError:
-                c = pre_cell * layers.sigmoid(f) + layers.sigmoid(
-                    i) * layers.tanh(j)
-                m = layers.tanh(c) * layers.sigmoid(o)
+            c = pre_cell * layers.sigmoid(f) + layers.sigmoid(
+                i) * layers.tanh(j)
+            m = layers.tanh(c) * layers.sigmoid(o)
 
             hidden_array[k] = m
             cell_array[k] = c
@@ -342,23 +311,16 @@ def encoder_static(input_embedding, len=3, init_hidden=None,
         return real_res, last_hidden, last_cell
 
     batch_size_each = batch_size
-    if use_py_reader:
-        feed_shapes = [[batch_size_each, num_steps, 1],
-                       [batch_size_each * num_steps, 1]]
-        py_reader = fluid.layers.py_reader(
-            capacity=16, shapes=feed_shapes, dtypes=['int64', 'int64'])
-        x, y = fluid.layers.read_file(py_reader)
-    else:
-        x = layers.data(
-            name="x",
-            shape=[batch_size_each, num_steps, 1],
-            dtype='int64',
-            append_batch_size=False)
-        y = layers.data(
-            name="y",
-            shape=[batch_size_each * num_steps, 1],
-            dtype='int64',
-            append_batch_size=False)
+    x = layers.data(
+        name="x",
+        shape=[batch_size_each, num_steps, 1],
+        dtype='int64',
+        append_batch_size=False)
+    y = layers.data(
+        name="y",
+        shape=[batch_size_each * num_steps, 1],
+        dtype='int64',
+        append_batch_size=False)
 
     init_hidden = layers.data(
         name="init_hidden",
@@ -472,10 +434,7 @@ def encoder_static(input_embedding, len=3, init_hidden=None,
         layers.assign(input=last_hidden, output=init_hidden)
 
     feeding_list = ['x', 'y', 'init_hidden', 'init_cell']
-    if use_py_reader:
-        return loss, last_hidden, last_cell, feeding_list, py_reader
-    else:
-        return loss, last_hidden, last_cell, feeding_list
+    return loss, last_hidden, last_cell, feeding_list
 
 
 class PaddingRNNTestBase(unittest.TestCase):
@@ -483,7 +442,24 @@ def setUp(self):
         self.reader = Reader()
         self.device_count = 1
 
-    def prepare_program(self, config, parallel=True):
+        # Default exec_strategy
+        self.exec_strategy = fluid.ExecutionStrategy()
+        self.exec_strategy.num_threads = self.device_count
+        self.exec_strategy.num_iteration_per_drop_scope = 100
+
+        # Default build_strategy
+        self.build_strategy = fluid.BuildStrategy()
+        self.build_strategy.enable_inplace = True
+        self.build_strategy.memory_optimize = False
+        self.build_strategy.fuse_all_optimizer_ops = True
+
+        # Default executor
+        self.exe = Executor(fluid.CPUPlace())
+
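+    # Hook for subclasses (e.g. the fusion_group pass test) to customize the
+    # build/exec strategies or the executor before training starts.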
+    def set_customed_config(self):
+        pass
+
+    def _prepare_program(self, config, parallel=True):
         self.main_program = fluid.Program()
         self.startup_program = fluid.Program()
         self.startup_program.random_seed = config.random_seed
@@ -497,8 +473,7 @@ def prepare_program(self, config, parallel=True):
                 num_steps=config.num_steps,
                 init_scale=config.init_scale,
                 dropout=config.dropout,
-                rnn_model=config.rnn_model,
-                use_py_reader=False)
+                rnn_model=config.rnn_model)
             self.loss, self.last_hidden, self.last_cell, self.feed_order = res_vars
 
             fluid.clip.set_gradient_clip(
@@ -515,28 +490,19 @@ def prepare_program(self, config, parallel=True):
                 optimizer = fluid.optimizer.SGD(
                     learning_rate=self.learning_rate)
                 optimizer.minimize(self.loss)
-        self.exe = Executor(fluid.CPUPlace())
+
+        self.exe.run(self.startup_program)
 
         if parallel:
-            exec_strategy = fluid.ExecutionStrategy()
-            exec_strategy.num_threads = self.device_count
-            exec_strategy.num_iteration_per_drop_scope = 100
-
-            build_strategy = fluid.BuildStrategy()
-            build_strategy.enable_inplace = True
-            build_strategy.memory_optimize = False
-            build_strategy.fuse_all_optimizer_ops = True
-
             self.train_program = fluid.compiler.CompiledProgram(
                 self.main_program).with_data_parallel(
                     loss_name=self.loss.name,
-                    build_strategy=build_strategy,
-                    exec_strategy=exec_strategy)
+                    build_strategy=self.build_strategy,
+                    exec_strategy=self.exec_strategy)
         else:
            self.train_program = self.main_program

-    def generate_init_data(self):
+    def _generate_init_data(self):
         init_hidden = np.zeros(
             (self.config.num_layers, self.config.batch_size,
              self.config.hidden_size),
@@ -547,19 +513,19 @@ def generate_init_data(self):
             dtype='float32')
         return init_hidden, init_cell
 
-    def generate_new_lr(self, epoch_id=0, device_count=1):
+    def _generate_new_lr(self, epoch_id=0, device_count=1):
         new_lr = self.config.base_learning_rate * (self.config.lr_decay**max(
             epoch_id + 1 - self.config.epoch_start_decay, 0.0))
         lr = np.ones((self.device_count), dtype='float32') * new_lr
         return lr
 
-    def prepare_input(self,
-                      batch,
-                      init_hidden=None,
-                      init_cell=None,
-                      epoch_id=0,
-                      with_lr=True,
-                      device_count=1):
+    def _prepare_input(self,
+                       batch,
+                       init_hidden=None,
+                       init_cell=None,
+                       epoch_id=0,
+                       with_lr=True,
+                       device_count=1):
         x, y = batch
         x = x.reshape((-1, self.config.num_steps, 1))
         y = y.reshape((-1, 1))
@@ -572,19 +538,19 @@ def prepare_input(self,
         if init_cell is not None:
             res['init_cell'] = init_cell
         if with_lr:
-            res['learning_rate'] = self.generate_new_lr(epoch_id, device_count)
+            res['learning_rate'] = self._generate_new_lr(epoch_id, device_count)
         return res
 
-    def train_an_epoch(self, epoch_id, batch_times, use_program_cache=True):
+    def _train_an_epoch(self, epoch_id, use_program_cache=True):
         train_data_iter = self.reader.get_data_iter(self.config)
 
         total_loss = 0
         iters = 0
-        init_hidden, init_cell = self.generate_init_data()
+        init_hidden, init_cell = self._generate_init_data()
         ppl = np.zeros(shape=(0))
         for batch_id, batch in enumerate(train_data_iter):
-            input_data_feed = self.prepare_input(
+            input_data_feed = self._prepare_input(
                 batch,
                 init_hidden=init_hidden,
                 init_cell=init_cell,
@@ -592,7 +558,6 @@ def train_an_epoch(self, epoch_id, batch_times, use_program_cache=True):
                 with_lr=True,
                 device_count=self.device_count)
 
-            batch_start_time = time.time()
             fetch_outs = self.exe.run(self.train_program,
                                       feed=input_data_feed,
                                       fetch_list=[
@@ -601,8 +566,6 @@ def train_an_epoch(self, epoch_id, batch_times, use_program_cache=True):
                                           self.last_cell.name
                                       ],
                                       use_program_cache=use_program_cache)
-            batch_time = time.time() - batch_start_time
-            batch_times.append(batch_time)
 
             cost_train = np.array(fetch_outs[0])
             lr = np.array(fetch_outs[1])
@@ -617,17 +580,13 @@ def train_an_epoch(self, epoch_id, batch_times, use_program_cache=True):
         return ppl
 
     def train(self, config, parallel=True, use_program_cache=True):
+        self.set_customed_config()
+
         self.config = config
-        self.prepare_program(config, parallel)
-        total_time = 0.0
+        self._prepare_program(config, parallel)
         ppl = np.zeros(shape=(0, config.batch_size))
         for epoch_id in range(config.max_epoch):
-            batch_times = []
-            epoch_start_time = time.time()
-            train_ppl = self.train_an_epoch(epoch_id, batch_times,
-                                            use_program_cache)
-            epoch_time = time.time() - epoch_start_time
-            total_time += epoch_time
+            train_ppl = self._train_an_epoch(epoch_id, use_program_cache)
             ppl = np.append(ppl, train_ppl)
         return ppl
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index cde4264ff14c2..384acc2360af0 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -1550,6 +1550,15 @@ def test_accuracy(self):
 
 
 class TestBook(LayerTest):
+    def setUp(self):
+        self.only_static_set = set({"make_word_embedding"})
+        self.not_compare_static_dygraph_set = set({
+            "make_gaussian_random", "make_gaussian_random_batch_size_like",
+            "make_kldiv_loss", "make_prelu",
+            "make_sampled_softmax_with_cross_entropy", "make_sampling_id",
+            "make_uniform_random_batch_size_like"
+        })
+
     def test_all_layers(self):
         attrs = (getattr(self, name) for name in dir(self))
         methods = filter(inspect.ismethod, attrs)
@@ -1572,9 +1581,12 @@ def test_all_layers(self):
                     feed=self._feed_dict,
                     fetch_list=fetch_list,
                     force_to_use_cpu=self._force_to_use_cpu)
+
            else:
                assert method.__name__ in ('make_get_places')
                continue
+            if method.__name__ in self.only_static_set:
+                continue
 
            with self.dynamic_graph(self._force_to_use_cpu):
                dy_result = method()
@@ -1582,7 +1594,9 @@ def test_all_layers(self):
                    dy_result = dy_result[0]
                dy_result_value = dy_result.numpy()
 
-            self.assertTrue(np.array_equal(static_result[0], dy_result_value))
+            if method.__name__ not in self.not_compare_static_dygraph_set:
+                self.assertTrue(
+                    np.array_equal(static_result[0], dy_result_value))
 
     def _get_np_data(self, shape, dtype, append_batch_size=True):
         np.random.seed(self.seed)
diff --git a/python/paddle/fluid/tests/unittests/test_optimizer_in_control_flow.py b/python/paddle/fluid/tests/unittests/test_optimizer_in_control_flow.py
index 63579ee80acc2..4b2914c223a08 100644
--- a/python/paddle/fluid/tests/unittests/test_optimizer_in_control_flow.py
+++ b/python/paddle/fluid/tests/unittests/test_optimizer_in_control_flow.py
@@ -22,6 +22,8 @@
 import paddle.fluid.optimizer as optimizer
 from paddle.fluid.framework import Program, program_guard
 import paddle.fluid.core as core
+import paddle.fluid.compiler as compiler
+import os
 
 BATCH_SIZE = 1
 INPUT_SIZE = 784
@@ -104,20 +106,20 @@ def fn_2(opt, avg_loss=None, pred=None, label=None):
     avg_loss = layers.case([(mod_two, lambda: fn_1(adam, avg_loss_1))],
                            lambda: fn_2(sgd, avg_loss_2))
 
-    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
-    exe = fluid.Executor(place)
-    exe.run(fluid.default_startup_program())
-
-    for epoch in range(EPOCH_NUM):
-        feed_image, feed_label = train_data[epoch]
-        fetch_list = [hidden, prediction, avg_loss]
-        feed = {
-            'image': feed_image,
-            'label': feed_label,
-            'id': np.array([epoch]).astype('int32')
-        }
-        out = exe.run(main_program, feed=feed, fetch_list=fetch_list)
-        out_hidden, out_pred, loss = out
+        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        exe.run(startup_program)
+
+        for epoch in range(EPOCH_NUM):
+            feed_image, feed_label = train_data[epoch]
+            fetch_list = [hidden, prediction, avg_loss]
+            feed = {
+                'image': feed_image,
+                'label': feed_label,
+                'id': np.array([epoch]).astype('int32')
+            }
+            out = exe.run(main_program, feed=feed, fetch_list=fetch_list)
+            out_hidden, out_pred, loss = out
 
     return out_hidden, out_pred, loss
 
@@ -225,5 +227,58 @@ def test_optimzier_in_switch(self):
                               loss_2))
 
 
+class TestMultiOptimizersMultiCardsError(unittest.TestCase):
+    def test_error(self):
+        startup_program = Program()
+        main_program = Program()
+        use_cuda = core.is_compiled_with_cuda()
+        with program_guard(main_program, startup_program):
+
+            def fn_1(opt, avg_loss):
+                opt.minimize(avg_loss)
+
+            def fn_2(opt, avg_loss):
+                opt.minimize(avg_loss)
+
+            x = fluid.layers.data("X", [10], 'float32')
+            hidden = layers.fc(x, 5)
+            avg_loss = layers.mean(hidden)
+
+            adam = optimizer.Adam(learning_rate=LR)
+            sgd = optimizer.SGD(learning_rate=LR)
+
+            cond = layers.fill_constant([1], 'bool', True)
+
+            layers.case([(cond, lambda: fn_1(adam, avg_loss))],
+                        lambda: fn_2(sgd, avg_loss))
+
+        cpu_place = fluid.CPUPlace()
+        cuda_place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+
+        for place in [cpu_place, cuda_place]:
+
+            exe = fluid.Executor(place)
+            exe.run(startup_program)
+
+            np.random.seed(SEED)
+            os.environ['CPU_NUM'] = str(2)
+            pe_exe = fluid.ParallelExecutor(
+                use_cuda=use_cuda,
+                main_program=main_program,
+                loss_name=avg_loss.name)
+            num_devices = pe_exe.device_count
+
+            def not_implemented_error():
+                pe_exe.run(feed={
+                    'X': np.random.random(size=[64, 10]).astype('float32'),
+                },
+                           fetch_list=[avg_loss.name])
+
+            if num_devices > 1:
+                self.assertRaises(NotImplementedError, not_implemented_error)
+            else:
+                not_implemented_error()
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/unique_name.py b/python/paddle/fluid/unique_name.py
index 550364a22fe34..6129bd1c802e4 100644
--- a/python/paddle/fluid/unique_name.py
+++ b/python/paddle/fluid/unique_name.py
@@ -51,6 +51,33 @@ def __call__(self, key):
         return self.prefix + "_".join([key, str(tmp)])
 
 
+class DygraphParameterNameChecker(object):
+    """
+    Check whether a parameter name has already been used.
+    """
+
+    def __init__(self):
+        self._name_set = set()
+
+    def __call__(self, name):
+        '''
+        Check whether the name has been used. If not, record it in _name_set.
+
+        Args:
+            name(str): The parameter name to check.
+
+        Returns(bool): True if the name is already in _name_set; otherwise False.
+
+        '''
+        if name in self._name_set:
+            return True
+        else:
+            self._name_set.add(name)
+            return False
+
+
+dygraph_parameter_name_checker = DygraphParameterNameChecker()
+
 generator = UniqueNameGenerator()
 
 
@@ -101,7 +128,7 @@ def generate_with_ignorable_key(key):
     return generator(key)
 
 
-def switch(new_generator=None):
+def switch(new_generator=None, new_para_name_checker=None):
     """
     Switch the namespace of in current context to a new namespace. Though
     :code:`switch()` and :code:`guard()` can both change namespace,
@@ -112,9 +139,13 @@ def switch(new_generator=None):
        new_generator(UniqueNameGenerator, optional): A new UniqueNameGenerator, not required normally.
            Default is None, which means switch to a new anonymous namespace.
+       new_para_name_checker(DygraphParameterNameChecker, optional): A new DygraphParameterNameChecker,
+           not required normally. Default is None, which means switch to a new parameter name
+           checker.
 
    Returns:
        UniqueNameGenerator: The previous UniqueNameGenerator.
+       DygraphParameterNameChecker: The previous DygraphParameterNameChecker.
 
    Examples:
 
@@ -125,22 +156,29 @@ def switch(new_generator=None):
            name2 = fluid.unique_name.generate('fc')
            print(name1, name2) # fc_0, fc_1
 
-           pre_generator = fluid.unique_name.switch() # switch to a new anonymous namespace.
+           pre_generator, pre_dygraph_name_checker = fluid.unique_name.switch() # switch to a new anonymous namespace.
            name2 = fluid.unique_name.generate('fc')
            print(name2) # fc_0
 
-           fluid.unique_name.switch(pre_generator) # switch back to pre_generator.
+           fluid.unique_name.switch(pre_generator, pre_dygraph_name_checker) # switch back to pre_generator.
            name3 = fluid.unique_name.generate('fc')
            print(name3) # fc_2, since pre_generator has generated fc_0, fc_1.
    """
    global generator
-    old = generator
+    old_generator = generator
+    global dygraph_parameter_name_checker
+    old_para_name_checker = dygraph_parameter_name_checker
    if new_generator is None:
        generator = UniqueNameGenerator()
    else:
        generator = new_generator
-    return old
+
+    if new_para_name_checker is None:
+        dygraph_parameter_name_checker = DygraphParameterNameChecker()
+    else:
+        dygraph_parameter_name_checker = new_para_name_checker
+    return old_generator, old_para_name_checker
 
 
 @signature_safe_contextmanager
@@ -180,6 +218,7 @@ def guard(new_generator=None):
         new_generator = UniqueNameGenerator(new_generator)
     elif isinstance(new_generator, six.binary_type):
         new_generator = UniqueNameGenerator(new_generator.decode())
-    old = switch(new_generator)
+
+    old_generator, old_para_name_checker = switch(new_generator)
     yield
-    switch(old)
+    switch(old_generator, old_para_name_checker)