-
Notifications
You must be signed in to change notification settings - Fork 662
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat multi input sharing graph, save and load compiled graph #9754
Changes from 9 commits
72cab02
9b968f1
ee42376
e68d89a
6c50cef
ec7a542
78a41da
10f0c9e
843b429
9dc411d
b5db0e7
3ba6e12
936b0fa
4e69b77
8bc4f0e
00b5a37
539770f
d1fcd7d
f4038f8
0fba176
9921dc2
97d4994
2f9f28f
5c48aa4
d5caa28
6be6b90
edb1061
21fc9ba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,12 @@ See the License for the specific language governing permissions and | |
limitations under the License. | ||
*/ | ||
#include "oneflow/core/framework/nn_graph.h" | ||
#include <cstdint> | ||
#include <memory> | ||
#include <string> | ||
#include <vector> | ||
#include "oneflow/core/common/buffer_manager.h" | ||
#include "oneflow/core/common/hash_container.h" | ||
#include "oneflow/core/common/maybe.h" | ||
#include "oneflow/core/common/scalar.h" | ||
#include "oneflow/core/common/cost_util.h" | ||
|
@@ -26,6 +31,7 @@ limitations under the License. | |
#include "oneflow/core/framework/instructions_builder.h" | ||
#include "oneflow/core/framework/nd_sbp.h" | ||
#include "oneflow/core/framework/scope_util.h" | ||
#include "oneflow/core/framework/tensor.h" | ||
#include "oneflow/core/framework/tensor_name_scope.h" | ||
#include "oneflow/core/functional/functional.h" | ||
#include "oneflow/core/graph/op_graph.h" | ||
|
@@ -305,7 +311,7 @@ Maybe<void> NNGraph::RegisterNewVariableOpInJobPass() { | |
} | ||
|
||
Maybe<void> NNGraph::DeleteOutdatedVariableInVariableTensorMgr() { | ||
std::set<std::string> variable_names = [&]() -> Maybe<std::set<std::string>> { | ||
const auto& var_get_func = [&]() -> Maybe<std::set<std::string>> { | ||
std::set<std::string> variable_names_; | ||
OpGraph op_graph(job_); | ||
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> { | ||
|
@@ -314,8 +320,8 @@ Maybe<void> NNGraph::DeleteOutdatedVariableInVariableTensorMgr() { | |
return Maybe<void>::Ok(); | ||
})); | ||
return variable_names_; | ||
}() | ||
.GetOrThrow(); | ||
}; | ||
std::set<std::string> variable_names = *JUST(var_get_func()); | ||
|
||
auto mgr = Singleton<VariableTensorMgr>::Get(); | ||
for (auto& name : mgr->DumpNames()) { | ||
|
@@ -324,28 +330,83 @@ Maybe<void> NNGraph::DeleteOutdatedVariableInVariableTensorMgr() { | |
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> NNGraph::CompileAndInitRuntime() { | ||
Maybe<void> NNGraph::AlignStatesAfterLogicalGraphCompile() { | ||
auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true); | ||
CHECK_OR_RETURN(!runtime_inited_) | ||
<< Error::RuntimeError() << "nn.Graph runtime is already initialized"; | ||
JUST(RegisterFreeEagerTensorsToVariableOpNames()); | ||
JUST(RegisterNewVariableOpInJobPass()); | ||
JUST(DeleteOutdatedVariableInVariableTensorMgr()); | ||
|
||
// NOTE(chengcheng): TensorNameScope need to be cleared after current graph is built. | ||
one::TensorNameScope::Global()->Clear(); | ||
// Clear all backward pass scope | ||
ClearAllBackwardPassScope(); | ||
compile_tc->Count("[GraphCompile]" + name_ + " AlignStates", 0); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
// NOTE(chengcheng): Singleton<JobDesc> need be clear before GlobalJobDescScope construct. | ||
if (Singleton<JobDesc>::Get() != nullptr) { Singleton<JobDesc>::Delete(); } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这行逻辑挪到哪里了呢 |
||
Maybe<void> NNGraph::CompleteLogicalGraphForRuntime() { | ||
auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true); | ||
// A global variable to get graph configurations. | ||
auto current_graph_config = std::make_unique<GlobalJobDescScope>(job_.job_conf(), job_id()); | ||
// NOTE(chengcheng): do job completer for each rank. | ||
JUST(JobCompleter::Complete(&job_)); | ||
compile_tc->Count("[GraphCompile]" + name_ + " CompleteJob", 0); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
auto scope = std::make_unique<GlobalJobDescScope>(job_.job_conf(), job_id_); | ||
Maybe<void> NNGraph::BuildWithNewInputFromSharedGraph( | ||
const std::vector<std::string>& shared_inputs_op_names, | ||
const std::vector<std::shared_ptr<one::Tensor>>& new_input_tensors, | ||
const std::vector<std::string>& shared_op_names, const std::string& new_serialized_job) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shared_op_names 这个参数有什么用呢, new_serialized_job 里是不是都有 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
shared_op_names 是从 build 那里直接产生的原始逻辑图得到的,new_serialized_job 里面已经是优化后的图了。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
那如何保证 shared_op_names 和 new_serialized_job 两者相同呢,可能 new_serialized_job 没有 shared_op_names 里的 op 了 |
||
CHECK_EQ_OR_RETURN(shared_inputs_op_names.size(), new_input_tensors.size()); // NOLINT | ||
auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true); | ||
// Register inputs. | ||
JUST(RegisterInputOpNamesAndTensors(shared_inputs_op_names, new_input_tensors)); | ||
|
||
// Generate new input tensor getter. | ||
HashMap<std::string, std::shared_ptr<one::Tensor>> input_name2tensor; | ||
for (int64_t idx = 0; idx < shared_inputs_op_names.size(); ++idx) { | ||
input_name2tensor.emplace(shared_inputs_op_names[idx], new_input_tensors[idx]); | ||
} | ||
const auto& InputTensor4Name = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个是不是应该放在 RegisterInputOpNamesAndTensors 里 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
InputTensor4Name 只有下面用了,上面的RegisterInputOpNamesAndTensors不依赖这个查找,所以就写成了用完就释放的形式 |
||
[&input_name2tensor](const std::string& op_name) -> Maybe<std::shared_ptr<one::Tensor>> { | ||
auto iter = input_name2tensor.find(op_name); | ||
CHECK_OR_RETURN(iter != input_name2tensor.end()) | ||
<< "Can't find input tensor of " << op_name << "."; | ||
return iter->second; | ||
}; | ||
|
||
// Generate new OperatorConf getter. | ||
Job new_build_job; | ||
CHECK_OR_RETURN(new_build_job.ParseFromString(new_serialized_job)) | ||
<< "nn.Graph " << name_ << " parse job proto of new build graph failed."; | ||
CHECK_EQ_OR_RETURN(new_build_job.net().op_size(), shared_op_names.size()) | ||
<< "nn.Graph " << name_ << " new_build_job op size and shared_op_names size are not equal."; | ||
HashMap<std::string, const OperatorConf*> shared_op_name2_new_op; | ||
for (int64_t op_idx = 0; op_idx < shared_op_names.size(); ++op_idx) { | ||
// Assume that the new graph and the shared graph from nn.Graph.build have the same op order. | ||
const auto& op = new_build_job.mutable_net()->mutable_op()->at(op_idx); | ||
shared_op_name2_new_op.emplace(shared_op_names[op_idx], &op); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 其实可以把 : new_build_job 直接传给: CompleteSharedGraphForNewInput 吧 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
这个 map 其实是从:shared op name 到new build job 中的 op。 中间用 op 顺序做了下对应, shared op name 都 op order 到 new build job 中的 op。 以给后面修改 shared graph op attr 做准备。所以只传递 new_build_job 还不行。 这个我改下名字,然后注释下。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
这里我理解是 new_build_job 是不包含 op 顺序导致的? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. new_build_job 是新的 build 函数产生的临时 job,后来改了下名字,它作为新 graph attr 的词典存在。 所以要额外传递 op name 信息来维护新老 op 的对应关系 |
||
} | ||
const auto& NewOp4SharedOpName = | ||
[&shared_op_name2_new_op](const std::string& shared_op_name) -> Maybe<const OperatorConf*> { | ||
auto iter = shared_op_name2_new_op.find(shared_op_name); | ||
CHECK_OR_RETURN(iter != shared_op_name2_new_op.end()) | ||
<< "Can't find new operator conf of " << shared_op_name << "."; | ||
return iter->second; | ||
}; | ||
|
||
// A global variable to get graph configurations. | ||
auto current_graph_config = std::make_unique<GlobalJobDescScope>(job_.job_conf(), job_id()); | ||
// NOTE(chengcheng): do job completer for each rank. | ||
JUST(JobCompleter().Complete(&job_)); | ||
JUST(JobCompleter::CompleteSharedGraphForNewInput(&job_, InputTensor4Name, NewOp4SharedOpName)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的 job_ 是不是一个空的,从接口命名上应该 update 或者 rewrite job_ ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. job_ 不是空的,是 copy 过来的一个优化后的 job; There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个 copy 是哪里发生的 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 另外一个 comment 里面回复了,是 python graph 在 初始化 c nn graph 时传递进来的。 |
||
compile_tc->Count("[GraphCompile]" + name_ + " CompleteJob", 0); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> NNGraph::CompilePlanForRuntime() { | ||
auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true); | ||
// A global variable to get graph configurations. | ||
auto current_graph_config = std::make_unique<GlobalJobDescScope>(job_.job_conf(), job_id()); | ||
if (GlobalProcessCtx::IsThisProcessMaster()) { | ||
// TODO(chengcheng): new memory reused by chunk | ||
Compiler().Compile(&job_, &plan_); | ||
|
@@ -389,7 +450,14 @@ Maybe<void> NNGraph::CompileAndInitRuntime() { | |
compile_tc->Count("[GraphCompile]" + name_ + " SyncPlan", 0, true); | ||
// NOTE(chengcheng): recovery op_attr | ||
PlanUtil::PopulateOpAttribute(&plan_, plan_.job_id2op_attribute_ref_table()); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> NNGraph::InitRuntime() { | ||
CHECK_OR_RETURN(!runtime_inited_) | ||
<< Error::RuntimeError() << "nn.Graph runtime is already initialized"; | ||
|
||
auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true); | ||
NewRuntimeBuffers(); | ||
|
||
JUST(GetVariableRealBlobAfterSyncPlan()); | ||
|
@@ -402,10 +470,19 @@ Maybe<void> NNGraph::CompileAndInitRuntime() { | |
runtime_.reset(new Runtime(plan_, variable_op_name2eager_blob_object_)); | ||
compile_tc->Count("[GraphCompile]" + name_ + " InitRuntime", 0, true); | ||
JUST(LogProgress("[GraphCompile]" + name_ + " Done", true)); | ||
|
||
runtime_inited_ = true; | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> NNGraph::CompileAndInitRuntime() { | ||
JUST(AlignStatesAfterLogicalGraphCompile()); | ||
JUST(CompleteLogicalGraphForRuntime()); | ||
JUST(CompilePlanForRuntime()); | ||
JUST(InitRuntime()); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> NNGraph::GetVariableRealBlobAfterSyncPlan() { | ||
CHECK_OR_RETURN(variable_op_name2eager_blob_object_.empty()) | ||
<< Error::RuntimeError() << kOfBugIssueUploadPrompt; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -72,6 +72,7 @@ bool IsInterfaceOpConf(const OperatorConf& op_conf) { | |
} | ||
|
||
GlobalJobDescScope::GlobalJobDescScope(const JobConfigProto& job_conf, int64_t job_id) { | ||
if (Singleton<JobDesc>::Get() != nullptr) { Singleton<JobDesc>::Delete(); } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 哦哦看到挪到这里了 |
||
Singleton<JobDesc>::New(job_conf, job_id); | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
因为本文件外部要使用这些名字,所以放到了函数外部