-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fea/infer executor #13451
fea/infer executor #13451
Changes from all commits
ec0cddf
1f8f3e4
7be81ed
1f10465
9b24dc7
409239c
8e83d13
96cb893
3ec4687
f976a33
d9b065d
76cf652
9f24533
fd421d6
3f275e8
45dfdf9
b5a376f
76a5f7f
439afda
9f54699
5422d55
39c42cb
3647712
6b5bd44
435fa1f
fd69b67
828018d
ee75ad3
aee43f7
4bdffdb
f1b0fea
af5c86c
2001f9b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -56,9 +56,9 @@ else() | |
cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor) | ||
endif() | ||
if (NOT WIN32) | ||
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio version) | ||
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio version) | ||
else() | ||
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto version) | ||
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto version) | ||
endif (NOT WIN32) | ||
|
||
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) | ||
|
@@ -141,12 +141,15 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor) | |
|
||
cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog) | ||
|
||
cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass) | ||
|
||
if(WITH_DISTRIBUTE) | ||
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass) | ||
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") | ||
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) | ||
else() | ||
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass) | ||
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass elementwise_add_op) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可简化为 |
||
endif() | ||
|
||
if (NOT WIN32) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,9 +28,9 @@ cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph grap | |
pass_library(graph_to_program_pass base) | ||
pass_library(graph_viz_pass base) | ||
pass_library(fc_fuse_pass inference) | ||
if(WITH_MKLDNN) | ||
pass_library(conv_relu_mkldnn_fuse_pass inference) | ||
endif() | ||
if (WITH_MKLDNN) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个空格没必要,下同。 |
||
pass_library(conv_relu_mkldnn_fuse_pass inference) | ||
endif () | ||
pass_library(attention_lstm_fuse_pass inference) | ||
pass_library(infer_clean_graph_pass inference) | ||
pass_library(fc_lstm_fuse_pass inference) | ||
|
@@ -47,6 +47,6 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r | |
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass) | ||
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) | ||
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) | ||
if(WITH_MKLDNN) | ||
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass) | ||
endif() | ||
if (WITH_MKLDNN) | ||
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass) | ||
endif () |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/framework/naive_executor.h" | ||
#include "paddle/fluid/framework/channel.h" | ||
#include "paddle/fluid/framework/feed_fetch_method.h" | ||
#include "paddle/fluid/framework/lod_rank_table.h" | ||
#include "paddle/fluid/framework/lod_tensor_array.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/reader.h" | ||
#include "paddle/fluid/string/pretty_log.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
// Shared logic with Executor: default-construct the concrete payload that
// matches a variable's declared proto type, so operators can use it later.
static void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
  switch (var_type) {
    case proto::VarType::LOD_TENSOR:
      var->GetMutable<LoDTensor>();
      break;
    case proto::VarType::SELECTED_ROWS:
      var->GetMutable<SelectedRows>();
      break;
    case proto::VarType::FEED_MINIBATCH:
    case proto::VarType::FETCH_LIST:
      // Both feed minibatches and fetch lists share the same container type.
      var->GetMutable<FeedFetchList>();
      break;
    case proto::VarType::STEP_SCOPES:
      var->GetMutable<std::vector<framework::Scope>>();
      break;
    case proto::VarType::LOD_RANK_TABLE:
      var->GetMutable<LoDRankTable>();
      break;
    case proto::VarType::LOD_TENSOR_ARRAY:
      var->GetMutable<LoDTensorArray>();
      break;
    case proto::VarType::PLACE_LIST:
      var->GetMutable<platform::PlaceList>();
      break;
    case proto::VarType::READER:
      var->GetMutable<ReaderHolder>();
      break;
    case proto::VarType::CHANNEL:
      var->GetMutable<ChannelHolder>();
      break;
    case proto::VarType::RAW:
      // GetMutable will be called in operator
      break;
    default:
      PADDLE_THROW(
          "Variable type %d is not in "
          "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
          "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
          var_type);
  }
}
|
||
// Builds the executor's state for one program block: binds (or creates) the
// scope, materializes the block's variables, then instantiates its operators.
// @parent_scope: if null, a brand-new root scope is owned; otherwise a child
//                scope of it is used.
// @with_feed_fetch_ops: whether to keep the feed and fetch operators.
void NaiveExecutor::Prepare(Scope *parent_scope,
                            const ProgramDesc &program_desc, int block_id,
                            bool with_feed_fetch_ops) {
  scope_ = parent_scope ? &parent_scope->NewScope() : new framework::Scope;
  CreateVariables(program_desc, scope_, block_id);
  CreateOps(program_desc, block_id, with_feed_fetch_ops);
}
|
||
void NaiveExecutor::Run() { | ||
for (auto &op : ops_) { | ||
VLOG(4) << "run " << op->Type(); | ||
op->Run(*scope_, place_); | ||
} | ||
} | ||
|
||
void NaiveExecutor::CreateVariables(const ProgramDesc &desc, Scope *scope, | ||
int block_id) { | ||
PADDLE_ENFORCE(scope); | ||
auto &global_block = desc.Block(block_id); | ||
|
||
const Scope *ancestor_scope = scope; | ||
while (ancestor_scope->parent()) { | ||
ancestor_scope = ancestor_scope->parent(); | ||
} | ||
|
||
if (ancestor_scope != scope) { | ||
for (auto &var : global_block.AllVars()) { | ||
if (var->Name() == framework::kEmptyVarName) { | ||
continue; | ||
} | ||
// Create persistable vars in ancestor scope. | ||
if (var->Persistable()) { | ||
auto *ptr = const_cast<Scope *>(ancestor_scope)->Var(var->Name()); | ||
InitializeVariable(ptr, var->GetType()); | ||
VLOG(3) << "Create Variable " << var->Name() | ||
<< " global, which pointer is " << ptr; | ||
} else { // Create temporary variables in local scope. | ||
auto *ptr = scope->Var(var->Name()); | ||
InitializeVariable(ptr, var->GetType()); | ||
VLOG(3) << "Create Variable " << var->Name() | ||
<< " locally, which pointer is " << ptr; | ||
} | ||
} | ||
} else { | ||
for (auto &var : global_block.AllVars()) { | ||
auto *ptr = scope->Var(var->Name()); | ||
InitializeVariable(ptr, var->GetType()); | ||
VLOG(3) << "Create variable " << var->Name() << ", which pointer is " | ||
<< ptr; | ||
} | ||
} | ||
} | ||
|
||
void NaiveExecutor::CreateOps(const ProgramDesc &desc, int block_id, | ||
bool with_feed_fetch_ops) { | ||
for (const auto &op_desc : desc.Block(block_id).AllOps()) { | ||
if (!with_feed_fetch_ops && | ||
(op_desc->Type() == "feed" || op_desc->Type() == "fetch")) { | ||
string::PrettyLogEndl(string::Style::detail(), "--- skip [%s], %s -> %s", | ||
op_desc->Input("X")[0], op_desc->Type(), | ||
op_desc->Output("Out")[0]); | ||
continue; | ||
} | ||
ops_.emplace_back(OpRegistry::CreateOp(*op_desc)); | ||
} | ||
} | ||
|
||
// Looks up the LoDTensor called `name` in the executor's scope, so callers
// can read/write it directly without feed/fetch operators.
// Requires Prepare() to have been called first.
LoDTensor *NaiveExecutor::FindTensor(const std::string &name) {
  PADDLE_ENFORCE(scope_, "Need to init scope first");
  auto *var = scope_->FindVar(name);
  // BUGFIX: the "%s" placeholder previously had no argument, so the message
  // never named the missing variable; pass `name` to fill it in.
  PADDLE_ENFORCE(var, "No variable [%s] in the scope", name);
  auto *tensor = const_cast<LoDTensor *>(&var->Get<LoDTensor>());
  return tensor;
}
|
||
void NaiveExecutor::CleanFeedFetchOps() { | ||
std::vector<std::unique_ptr<OperatorBase>> ops; | ||
for (auto &op : ops_) { | ||
if (op->Type() != "feed" && op->Type() != "fetch") { | ||
ops.emplace_back(std::move(op)); | ||
} | ||
} | ||
ops_.swap(ops); | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include "paddle/fluid/framework/operator.h" | ||
#include "paddle/fluid/framework/program_desc.h" | ||
#include "paddle/fluid/framework/scope.h" | ||
#include "paddle/fluid/platform/device_context.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
/* | ||
* Simple, intuitive and effective. Only single thread is supported, and | ||
* currently designed for inference. | ||
*/ | ||
class NaiveExecutor { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NaiveExecutor是否取名为InferenceExecutor更合理 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不一定是 inference 的。 |
||
public: | ||
explicit NaiveExecutor(const platform::Place& place) : place_(place) {} | ||
|
||
// Create child scope. | ||
// Create variables. | ||
// @with_feed_fetch_ops: whether to work with the feed and fetch operators. | ||
void Prepare(Scope* parent_scope, const ProgramDesc& program_desc, | ||
int block_id, bool with_feed_fetch_ops); | ||
|
||
// Run all the operators. | ||
void Run(); | ||
|
||
// Get an tensor to operating directly, without the need for feed_ops. | ||
LoDTensor* FindTensor(const std::string& name); | ||
|
||
Scope* scope() { return scope_; } | ||
|
||
void CleanFeedFetchOps(); | ||
|
||
protected: | ||
void CreateVariables(const ProgramDesc& desc, Scope* scope, int block_id); | ||
|
||
void CreateOps(const ProgramDesc& desc, int block_id, | ||
bool with_feed_fetch_ops); | ||
|
||
private: | ||
const platform::Place place_; | ||
// Catch the required resource to avoid recreate. | ||
std::vector<std::unique_ptr<OperatorBase>> ops_; | ||
Scope* scope_; | ||
}; | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/framework/naive_executor.h" | ||
#include <gtest/gtest.h> | ||
#include <algorithm> | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/program_desc.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
// End-to-end smoke test: build a one-op (elementwise_add) program, run it
// with NaiveExecutor on CPU, and verify c = a + b element-wise.
TEST(NaiveExecutor, Basic) {
  ProgramDesc program;
  auto* main_block = program.MutableBlock(0);
  // Declare the three LoDTensor variables: inputs a, b and output c.
  for (const char* name : {"a", "b", "c"}) {
    main_block->Var(name)->SetType(proto::VarType::LOD_TENSOR);
  }

  auto* add_op = main_block->AppendOp();
  add_op->SetType("elementwise_add");
  add_op->SetInput("X", {"a"});
  add_op->SetInput("Y", {"b"});
  add_op->SetOutput("Out", {"c"});

  platform::CPUPlace place;
  NaiveExecutor executor(place);
  executor.Prepare(nullptr, program, 0, false /*with feed fetch ops*/);

  auto* a_tensor = executor.FindTensor("a");
  auto* b_tensor = executor.FindTensor("b");
  auto* c_tensor = executor.FindTensor("c");
  a_tensor->Resize({1, 4});
  b_tensor->Resize({1, 4});
  c_tensor->Resize({1, 4});

  const float a_init[] = {0.f, 1.f, 2.f, 3.f};
  const float b_init[] = {0.f, .1f, .2f, .3f};
  std::copy_n(a_init, 4, a_tensor->mutable_data<float>(place));
  std::copy_n(b_init, 4, b_tensor->mutable_data<float>(place));

  executor.Run();

  const float* c_data = c_tensor->mutable_data<float>(place);
  for (int i = 0; i < 4; ++i) {
    EXPECT_NEAR(c_data[i], 1.1 * i, 1e-3);
  }
}
|
||
} // namespace framework | ||
} // namespace paddle | ||
|
||
// Force-link the elementwise_add operator into the test binary.
USE_OP(elementwise_add); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
144行应该放在if (NOT WITH_DISTRIBUTE),即151行后面吧。对分布式不影响。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
即使是 WITH_DISTRIBUTE的情况,也是需要跑inference的单测的,那个需要naive_executor。