diff --git a/src/common/types.h b/src/common/types.h
index ca9e64cc60f..04b78947a6a 100644
--- a/src/common/types.h
+++ b/src/common/types.h
@@ -77,7 +77,8 @@ static const std::string G_OP_TYPE_BATCHNORM = "batch_norm";
 static const std::string G_OP_TYPE_BOX_CODER = "box_coder";
 static const std::string G_OP_TYPE_CONCAT = "concat";
 static const std::string G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
-static const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU = "FusionConvAddRelu";
+static const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU =
+    "fusion_conv_add_relu";
 static const std::string G_OP_TYPE_FC = "fc";
 static const std::string G_OP_TYPE_LRN = "lrn";
 static const std::string G_OP_TYPE_MUL = "mul";
@@ -92,6 +93,7 @@ static const std::string G_OP_TYPE_TRANSPOSE = "transpose";
 static const std::string G_OP_TYPE_SPLIT = "split";
 static const std::string G_OP_TYPE_FEED = "feed";
 static const std::string G_OP_TYPE_FETCH = "fetch";
+static const std::string G_OP_TYPE_DEPTHWISE_CONV = "depthwise_conv2d";
 
 static std::unordered_map<
     std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
diff --git a/src/framework/program/program-optimize/node.cpp b/src/framework/program/program-optimize/node.cpp
index 31377222db8..c165b6568aa 100644
--- a/src/framework/program/program-optimize/node.cpp
+++ b/src/framework/program/program-optimize/node.cpp
@@ -45,6 +45,47 @@ bool Node::operator==(const Node &in) {
   return true;
 }
 
+bool Node::CanSplit(std::unordered_set<std::string> complex_compute_set) {
+  bool split = false;
+  CanSplit(&split, false, 0, &complex_compute_set, this);
+  return split;
+}
+
+void Node::CanSplit(bool *split, bool spliting, int complex_count,
+                    std::unordered_set<std::string> *complex_compute_set,
+                    Node *pre_node) {
+  if (spliting) {
+    if (complex_compute_set->find(this->type_) != complex_compute_set->end()) {
+      complex_count++;
+    }
+  }
+
+  if (inputs_.size() > 1 && pre_node != inputs_.back()) {
+    return;
+  }
+  if (inputs_.size() > 1 && pre_node == inputs_.back()) {
+    if (complex_count > 1) {
+      *split = true;
+      return;
+    }
+  }
+
+  // multiple outputs: start counting complex ops along each branch
+  if (outputs_.size() > 1) {
+    spliting = true;
+    complex_count = 0;
+  } else {
+    if (spliting == true && inputs_.size() > 0) {
+      spliting = false;
+    } else {
+    }
+  }
+
+  for (auto &output : outputs_) {
+    output->CanSplit(split, spliting, complex_count, complex_compute_set, this);
+  }
+}
+
 std::vector<std::shared_ptr<framework::OpDesc>> Node::OpDescs(uint size) {
   std::vector<std::shared_ptr<framework::OpDesc>> op_descs;
   OpDescs(size - 1, &op_descs);
diff --git a/src/framework/program/program-optimize/node.h b/src/framework/program/program-optimize/node.h
index da7e26a9ac0..8ef26f897d2 100644
--- a/src/framework/program/program-optimize/node.h
+++ b/src/framework/program/program-optimize/node.h
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <map>
 #include <string>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 
@@ -36,6 +37,7 @@ class Node : PaddleMobileObject {
       : op_desc_(op_desc), type_(op_desc->Type()) {}
   Node &operator>(std::shared_ptr<Node> node);
   bool operator==(const Node &in);
+  bool CanSplit(std::unordered_set<std::string> complex_compute_set);
   std::string ToString() const;
   std::shared_ptr<Node> To(int size);
   uint Depth(uint begin = 0);
@@ -49,6 +51,9 @@ class Node : PaddleMobileObject {
   void Description();
 
  private:
+  void CanSplit(bool *split, bool spliting, int complex_count,
+                std::unordered_set<std::string> *complex_compute_set,
+                Node *pre_node);
   void OpDescs(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
                Node *node, bool adding_thread, int thread_num);
   void OpDescs(uint size,
diff --git a/src/framework/program/program-optimize/program_optimize.cpp b/src/framework/program/program-optimize/program_optimize.cpp
index 8b0bf295262..d9c3c51c3c8 100644
--- a/src/framework/program/program-optimize/program_optimize.cpp
+++ b/src/framework/program/program-optimize/program_optimize.cpp
@@ -99,6 +99,8 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize(
       //  DLOG << "node: \n" << *begin_node;
 
       std::vector<std::shared_ptr<framework::OpDesc>> op_descs;
+      //      bool can_splite = begin_node->CanSplit({G_OP_TYPE_CONV,
+      //      G_OP_TYPE_BATCHNORM, G_OP_TYPE_DEPTHWISE_CONV});
       GenerateOps(&op_descs, begin_node.get());
       block->ops_ = op_descs;
     }
@@ -111,6 +113,25 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize(
   return optimize_program;
 }
 
+void ProgramOptimize::GenerateOps(
+    std::vector<std::shared_ptr<framework::OpDesc>> *op_desc, Node *input_node,
+    Node *current_node) {
+  if (current_node->inputs_.size() > 1 &&
+      input_node != current_node->inputs_.back()) {
+    return;
+  } else if (current_node->inputs_.size() > 1 &&
+             input_node == current_node->inputs_.back()) {
+    op_desc->push_back(current_node->op_desc_);
+  } else {
+    op_desc->push_back(current_node->op_desc_);
+  }
+
+  for (int i = 0; i < current_node->outputs_.size(); ++i) {
+    auto &output = current_node->outputs_[i];
+    GenerateOps(op_desc, current_node, output.get());
+  }
+}
+
 void ProgramOptimize::GenerateOps(
     std::vector<std::shared_ptr<framework::OpDesc>> *op_desc, Node *input_node,
     Node *current_node, bool adding_thread, int thread_num,
@@ -234,7 +255,11 @@ void ProgramOptimize::GenerateOps(
   //          std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
   //          Node *input_node, Node *current_node, bool adding_thread, int
   //          thread_num
-  this->GenerateOps(op_descs, begin_node, begin_node, false, -1, nullptr);
+  if (false) {
+    this->GenerateOps(op_descs, begin_node, begin_node, false, -1, nullptr);
+  } else {
+    this->GenerateOps(op_descs, begin_node, begin_node);
+  }
 }
 
 }  // namespace framework
diff --git a/src/framework/program/program-optimize/program_optimize.h b/src/framework/program/program-optimize/program_optimize.h
index 32d8d1fa914..93943cf8395 100644
--- a/src/framework/program/program-optimize/program_optimize.h
+++ b/src/framework/program/program-optimize/program_optimize.h
@@ -33,9 +33,10 @@ class ProgramOptimize {
 
  private:
   int current_block_;
  std::vector<std::shared_ptr<BlockDesc>> new_blocks_;
-  void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_descs, Node *begin_node);
+  void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
+                   Node *input_node, Node *current_node);
   void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
                    Node *input_node, Node *current_node, bool adding_thread,
                    int thread_num, std::shared_ptr<BlockDesc> new_block);
 
diff --git a/src/io.cpp b/src/io.cpp
index ac89106e498..8f6a07f2dd1 100644
--- a/src/io.cpp
+++ b/src/io.cpp
@@ -220,13 +220,17 @@ const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
       }
     }
   }
-  originProgramDesc->Description("program: ");
 
   if (optimize) {
     framework::ProgramOptimize program_optimize;
     program.optimizeProgram =
         program_optimize.FushionOptimize(originProgramDesc);
   }
+  if (optimize) {
+    program.optimizeProgram->Description("optimize: ");
+  } else {
+    originProgramDesc->Description("program: ");
+  }
 
   paddle_mobile__framework__proto__program_desc__free_unpacked(c_program, NULL);
   return program;
@@ -254,6 +258,7 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
     std::vector<std::shared_ptr<framework::OpDesc>> ops = block_desc->Ops();
     for (int j = 0; j < ops.size(); ++j) {
       std::shared_ptr<framework::OpDesc> op = ops[j];
+      DLOG << "create op: " << op->Type();
       auto op_base = framework::OpRegistry<Dtype>::CreateOp(
           op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(),
           program_.scope);
diff --git a/src/operators/fusion_conv_add_relu_op.h b/src/operators/fusion_conv_add_relu_op.h
index 1fa3399cf22..e93c910d2b3 100644
--- a/src/operators/fusion_conv_add_relu_op.h
+++ b/src/operators/fusion_conv_add_relu_op.h
@@ -28,11 +28,11 @@ class FushionConvAddReluOpMatcher : public framework::FusionOpMatcher {
         std::make_shared<framework::Node>(G_OP_TYPE_RELU);
   }
 
-  void FolderNodes(framework::Node &node) {
+  void FolderNodes(framework::Node *node) {
     std::vector<std::shared_ptr<framework::OpDesc>> origin_descs =
-        node.OpDescs(node_.Depth());
-    node.Folder(node_.Depth(), Type(),
-                {{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}});
+        node->OpDescs(node_.Depth());
+    node->Folder(node_.Depth(), Type(),
+                 {{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}});
   }
   std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_RELU; }
 };
diff --git a/src/operators/fusion_fc_op.h b/src/operators/fusion_fc_op.h
index fb49fa61b20..9019ef4d496 100644
--- a/src/operators/fusion_fc_op.h
+++ b/src/operators/fusion_fc_op.h
@@ -32,11 +32,11 @@ class FusionFcMatcher : public framework::FusionOpMatcher {
     node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD);
   }
 
-  void FolderNodes(framework::Node &node) {
+  void FolderNodes(framework::Node *node) {
     vector<std::shared_ptr<framework::OpDesc>> origin_descs =
-        node.OpDescs(node_.Depth());
-    node.Folder(node_.Depth(), Type(),
-                {{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}});
+        node->OpDescs(node_.Depth());
+    node->Folder(node_.Depth(), Type(),
+                 {{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}});
   }
 
   std::string Type() { return G_OP_TYPE_FC; }
diff --git a/test/net/test_googlenet.cpp b/test/net/test_googlenet.cpp
index 0640af890cf..302cd3e726e 100644
--- a/test/net/test_googlenet.cpp
+++ b/test/net/test_googlenet.cpp
@@ -18,11 +18,12 @@ limitations under the License. */
 
 int main() {
   paddle_mobile::Loader<paddle_mobile::CPU> loader;
+  bool optimize = true;
   auto time1 = time();
-  auto program = loader.Load(g_googlenet, false);
+  auto program = loader.Load(g_googlenet, optimize);
   auto time2 = time();
   DLOG << "load cost :" << time_diff(time1, time2) << "ms\n";
-  paddle_mobile::Executor<paddle_mobile::CPU> executor(program, 1, false);
+  paddle_mobile::Executor<paddle_mobile::CPU> executor(program, 1, optimize);
   std::vector<float> input;
   std::vector<int64_t> dims{1, 3, 224, 224};
   GetInput<float>(g_test_image_1x3x224x224, &input, dims);
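
Reviewer note (illustrative, not part of the patch): the new Node::CanSplit hook is only referenced from a commented-out line in FushionOptimize, and the dispatch at the bottom of program_optimize.cpp is hard-coded to `if (false)`. The sketch below shows how the two could plausibly meet once the check is enabled. The op-type set comes from the commented-out call in this diff; the `can_split` name and the mapping of its result onto the two GenerateOps overloads are assumptions, not code from this patch.

    // Hypothetical replacement for the `if (false)` dispatch in
    // ProgramOptimize::GenerateOps: use the new CanSplit() check to decide
    // between the plain and the thread-aware overloads. Assumes the same
    // `op_descs` / `begin_node` arguments as in the hunk above.
    bool can_split = begin_node->CanSplit(
        {G_OP_TYPE_CONV, G_OP_TYPE_BATCHNORM, G_OP_TYPE_DEPTHWISE_CONV});
    if (can_split) {
      // graph has a branch carrying more than one "complex" op:
      // take the thread-aware path (currently parked behind `if (false)`)
      this->GenerateOps(op_descs, begin_node, begin_node, false, -1, nullptr);
    } else {
      this->GenerateOps(op_descs, begin_node, begin_node);
    }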