
Commit

Merge pull request #332 from codeWorm2015/develop
fix #331 fix fc crash
Eclipsess committed May 31, 2018
2 parents 7da97bc + ddd8e46 commit 2a345a2
Showing 9 changed files with 94 additions and 14 deletions.
4 changes: 3 additions & 1 deletion src/common/types.h
@@ -77,7 +77,8 @@ static const std::string G_OP_TYPE_BATCHNORM = "batch_norm";
static const std::string G_OP_TYPE_BOX_CODER = "box_coder";
static const std::string G_OP_TYPE_CONCAT = "concat";
static const std::string G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
static const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU = "FusionConvAddRelu";
static const std::string G_OP_TYPE_FUSION_CONV_ADD_RELU =
"fusion_conv_add_relu";
static const std::string G_OP_TYPE_FC = "fc";
static const std::string G_OP_TYPE_LRN = "lrn";
static const std::string G_OP_TYPE_MUL = "mul";
@@ -92,6 +93,7 @@ static const std::string G_OP_TYPE_TRANSPOSE = "transpose";
static const std::string G_OP_TYPE_SPLIT = "split";
static const std::string G_OP_TYPE_FEED = "feed";
static const std::string G_OP_TYPE_FETCH = "fetch";
static const std::string G_OP_TYPE_DEPTHWISE_CONV = "depthwise_conv2d";

static std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
41 changes: 41 additions & 0 deletions src/framework/program/program-optimize/node.cpp
@@ -45,6 +45,47 @@ bool Node::operator==(const Node &in) {
return true;
}

bool Node::CanSplit(std::unordered_set<std::string> complex_compute_set) {
bool split = false;
CanSplit(&split, false, 0, &complex_compute_set, this);
return split;
}

void Node::CanSplit(bool *split, bool spliting, int complex_count,
std::unordered_set<std::string> *complex_compute_set,
Node *pre_node) {
if (spliting) {
if (complex_compute_set->find(this->type_) != complex_compute_set->end()) {
complex_count++;
}
}

if (inputs_.size() > 1 && pre_node != inputs_.back()) {
return;
}
if (inputs_.size() > 1 && pre_node == inputs_.back()) {
if (complex_count > 1) {
*split = true;
return;
}
}

// multi output, to check
if (outputs_.size() > 1) {
spliting = true;
complex_compute_set = 0;
} else {
if (spliting == true && inputs_.size() > 0) {
spliting = false;
} else {
}
}

for (auto &output : outputs_) {
output->CanSplit(split, spliting, complex_count, complex_compute_set, this);
}
}

std::vector<std::shared_ptr<framework::OpDesc>> Node::OpDescs(uint size) {
std::vector<std::shared_ptr<framework::OpDesc>> op_descs;
OpDescs(size - 1, &op_descs);
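
The new Node::CanSplit entry point takes the set of op types considered compute-heavy and reports whether the sub-graph below the node should be split. A minimal usage sketch, mirroring the commented-out call added to program_optimize.cpp further down; begin_node and the G_OP_TYPE_* constants come from this diff, while obtaining begin_node from the ProgramDesc is assumed rather than shown:

// Sketch only: begin_node is assumed to be the node at the head of the block's graph.
std::unordered_set<std::string> heavy_ops = {
    G_OP_TYPE_CONV, G_OP_TYPE_BATCHNORM, G_OP_TYPE_DEPTHWISE_CONV};
bool can_split = begin_node->CanSplit(heavy_ops);
DLOG << "can split: " << can_split;
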
5 changes: 5 additions & 0 deletions src/framework/program/program-optimize/node.h
@@ -16,6 +16,7 @@ limitations under the License. */

#include <map>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

@@ -36,6 +37,7 @@ class Node : PaddleMobileObject {
: op_desc_(op_desc), type_(op_desc->Type()) {}
Node &operator>(std::shared_ptr<Node> node);
bool operator==(const Node &in);
bool CanSplit(std::unordered_set<std::string> complex_compute_set);
std::string ToString() const;
std::shared_ptr<Node> To(int size);
uint Depth(uint begin = 0);
@@ -49,6 +51,9 @@ class Node : PaddleMobileObject {
void Description();

private:
void CanSplit(bool *split, bool spliting, int complex_count,
std::unordered_set<std::string> *complex_compute_set,
Node *pre_node);
void OpDescs(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
Node *node, bool adding_thread, int thread_num);
void OpDescs(uint size,
27 changes: 26 additions & 1 deletion src/framework/program/program-optimize/program_optimize.cpp
@@ -99,6 +99,8 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize(
// DLOG << "node: \n" << *begin_node;

std::vector<std::shared_ptr<framework::OpDesc>> op_descs;
// bool can_splite = begin_node->CanSplit({G_OP_TYPE_CONV,
// G_OP_TYPE_BATCHNORM, G_OP_TYPE_DEPTHWISE_CONV});
GenerateOps(&op_descs, begin_node.get());
block->ops_ = op_descs;
}
@@ -111,6 +113,25 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize(
return optimize_program;
}

void ProgramOptimize::GenerateOps(
std::vector<std::shared_ptr<framework::OpDesc>> *op_desc, Node *input_node,
Node *current_node) {
if (current_node->inputs_.size() > 1 &&
input_node != current_node->inputs_.back()) {
return;
} else if (current_node->inputs_.size() > 1 &&
input_node == current_node->inputs_.back()) {
op_desc->push_back(current_node->op_desc_);
} else {
op_desc->push_back(current_node->op_desc_);
}

for (int i = 0; i < current_node->outputs_.size(); ++i) {
auto &output = current_node->outputs_[i];
GenerateOps(op_desc, current_node, output.get());
}
}

void ProgramOptimize::GenerateOps(
std::vector<std::shared_ptr<framework::OpDesc>> *op_desc, Node *input_node,
Node *current_node, bool adding_thread, int thread_num,
@@ -234,7 +255,11 @@ void ProgramOptimize::GenerateOps(
// std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
// Node *input_node, Node *current_node, bool adding_thread, int
// thread_num
this->GenerateOps(op_descs, begin_node, begin_node, false, -1, nullptr);
if (false) {
this->GenerateOps(op_descs, begin_node, begin_node, false, -1, nullptr);
} else {
this->GenerateOps(op_descs, begin_node, begin_node);
}
}

} // namespace framework
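
The new three-argument GenerateOps is the simplified path taken above: it appends each op exactly once by only emitting a multi-input node when the walk reaches it through its last recorded input, so earlier visits return without duplicating it. A self-contained sketch of that guard pattern, with plain structs standing in for framework::Node (all names here are illustrative, not from the library):

#include <iostream>
#include <string>
#include <vector>

// Minimal stand-in for framework::Node: an op type plus input/output edges.
struct MiniNode {
  std::string type;
  std::vector<MiniNode *> inputs;
  std::vector<MiniNode *> outputs;
};

// Emit each node once: a node with several inputs is skipped until the walk
// arrives via its last input, matching the guard in GenerateOps above.
void Generate(std::vector<std::string> *ops, MiniNode *input, MiniNode *node) {
  if (node->inputs.size() > 1 && input != node->inputs.back()) {
    return;
  }
  ops->push_back(node->type);
  for (MiniNode *out : node->outputs) {
    Generate(ops, node, out);
  }
}

int main() {
  MiniNode a{"conv"}, b{"relu"}, c{"scale"}, d{"elementwise_add"};
  a.outputs = {&b, &c};
  b.inputs = {&a};
  b.outputs = {&d};
  c.inputs = {&a};
  c.outputs = {&d};
  d.inputs = {&b, &c};  // two inputs: emitted only when reached via &c
  std::vector<std::string> ops;
  Generate(&ops, &a, &a);
  for (const auto &op : ops) {
    std::cout << op << "\n";  // conv relu scale elementwise_add
  }
  return 0;
}
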
3 changes: 2 additions & 1 deletion src/framework/program/program-optimize/program_optimize.h
@@ -33,9 +33,10 @@ class ProgramOptimize {
private:
int current_block_;
std::vector<std::shared_ptr<BlockDesc>> new_blocks_;

void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_descs,
Node *begin_node);
void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
Node *input_node, Node *current_node);
void GenerateOps(std::vector<std::shared_ptr<framework::OpDesc>> *op_desc,
Node *input_node, Node *current_node, bool adding_thread,
int thread_num, std::shared_ptr<BlockDesc> new_block);
7 changes: 6 additions & 1 deletion src/io.cpp
@@ -220,13 +220,17 @@ const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
}
}
}
originProgramDesc->Description("program: ");

if (optimize) {
framework::ProgramOptimize program_optimize;
program.optimizeProgram =
program_optimize.FushionOptimize(originProgramDesc);
}
if (optimize) {
program.optimizeProgram->Description("optimize: ");
} else {
originProgramDesc->Description("program: ");
}

paddle_mobile__framework__proto__program_desc__free_unpacked(c_program, NULL);
return program;
@@ -254,6 +258,7 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
std::vector<std::shared_ptr<framework::OpDesc>> ops = block_desc->Ops();
for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<framework::OpDesc> op = ops[j];
DLOG << "create op: " << op->Type();
auto op_base = framework::OpRegistry<Dtype>::CreateOp(
op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(),
program_.scope);
8 changes: 4 additions & 4 deletions src/operators/fusion_conv_add_relu_op.h
@@ -28,11 +28,11 @@ class FushionConvAddReluOpMatcher : public framework::FusionOpMatcher {
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}

void FolderNodes(framework::Node &node) {
void FolderNodes(framework::Node *node) {
std::vector<std::shared_ptr<framework::OpDesc>> origin_descs =
node.OpDescs(node_.Depth());
node.Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}});
node->OpDescs(node_.Depth());
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}});
}
std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_RELU; }
};
8 changes: 4 additions & 4 deletions src/operators/fusion_fc_op.h
@@ -32,11 +32,11 @@ class FusionFcMatcher : public framework::FusionOpMatcher {
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD);
}

void FolderNodes(framework::Node &node) {
void FolderNodes(framework::Node *node) {
vector<std::shared_ptr<framework::OpDesc>> origin_descs =
node.OpDescs(node_.Depth());
node.Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}});
node->OpDescs(node_.Depth());
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}});
}

std::string Type() { return G_OP_TYPE_FC; }
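
In both matchers, FolderNodes now receives the node to fold by pointer rather than by reference, so callers have to pass an address. A hypothetical caller-side sketch of that adjustment; the real invocation lives in the fusion pass outside this commit, and the helper name below is illustrative:

// Hypothetical helper: shows only the pointer-vs-reference change implied by
// the new signature, given the headers changed in this commit.
void FoldFcSubgraph(FusionFcMatcher *matcher, framework::Node *match_node) {
  matcher->FolderNodes(match_node);  // previously: matcher->FolderNodes(*match_node);
}
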
5 changes: 3 additions & 2 deletions test/net/test_googlenet.cpp
@@ -18,11 +18,12 @@ limitations under the License. */

int main() {
paddle_mobile::Loader<paddle_mobile::CPU> loader;
bool optimize = true;
auto time1 = time();
auto program = loader.Load(g_googlenet, false);
auto program = loader.Load(g_googlenet, optimize);
auto time2 = time();
DLOG << "load cost :" << time_diff(time1, time2) << "ms\n";
paddle_mobile::Executor<paddle_mobile::CPU> executor(program, 1, false);
paddle_mobile::Executor<paddle_mobile::CPU> executor(program, 1, optimize);
std::vector<float> input;
std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224, &input, dims);
