-
Notifications
You must be signed in to change notification settings - Fork 667
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat multi input sharing graph, save and load compiled graph #9754
Changes from all commits
72cab02
9b968f1
ee42376
e68d89a
6c50cef
ec7a542
78a41da
10f0c9e
843b429
9dc411d
b5db0e7
3ba6e12
936b0fa
4e69b77
8bc4f0e
00b5a37
539770f
d1fcd7d
f4038f8
0fba176
9921dc2
97d4994
2f9f28f
5c48aa4
d5caa28
6be6b90
edb1061
21fc9ba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,45 +44,47 @@ class BufferMgr final { | |
}; | ||
|
||
static const std::string kBufferNameGlobalWaitJobId = "GlobalWaitJobId"; | ||
static const std::string kCallbackNotifierBufferNamePrefix = "CallbackNotifier-"; | ||
static const std::string kInputCriticalSectionWaitBufferNamePrefix = "InputCriticalSectionWait-"; | ||
static const std::string kInputCriticalSectionCallbackBufferNamePrefix = | ||
"InputCriticalSectionCallback-"; | ||
static const std::string kOutputCriticalSectionWaitBufferNamePrefix = "OutputCriticalSectionWait-"; | ||
static const std::string kOutputCriticalSectionCallbackBufferNamePrefix = | ||
"OutputCriticalSectionCallback-"; | ||
static const std::string kInputBufferNamePrefix = "Input-"; | ||
static const std::string kOutputBufferNamePrefix = "Output-"; | ||
static const std::string kSourceTickBufferNamePrefix = "SourceTick-"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 因为本文件外部要使用这些名字,所以放到了函数外部 |
||
|
||
inline std::string GetCallbackNotifierBufferName(const std::string& job_name) { | ||
static const std::string prefix = "CallbackNotifier-"; | ||
return prefix + job_name; | ||
return kCallbackNotifierBufferNamePrefix + job_name; | ||
} | ||
|
||
inline std::string GetInputCriticalSectionWaitBufferName(const std::string& job_name) { | ||
static const std::string prefix = "InputCriticalSectionWait-"; | ||
return prefix + job_name; | ||
return kInputCriticalSectionWaitBufferNamePrefix + job_name; | ||
} | ||
|
||
inline std::string GetInputCriticalSectionCallbackBufferName(const std::string& job_name) { | ||
static const std::string prefix = "InputCriticalSectionCallback-"; | ||
return prefix + job_name; | ||
return kInputCriticalSectionCallbackBufferNamePrefix + job_name; | ||
} | ||
|
||
inline std::string GetOutputCriticalSectionWaitBufferName(const std::string& job_name) { | ||
static const std::string prefix = "OutputCriticalSectionWait-"; | ||
return prefix + job_name; | ||
return kOutputCriticalSectionWaitBufferNamePrefix + job_name; | ||
} | ||
|
||
inline std::string GetOutputCriticalSectionCallbackBufferName(const std::string& job_name) { | ||
static const std::string prefix = "OutputCriticalSectionCallback-"; | ||
return prefix + job_name; | ||
return kOutputCriticalSectionCallbackBufferNamePrefix + job_name; | ||
} | ||
|
||
inline std::string GetInputBufferName(const std::string& job_name, const std::string& op_name) { | ||
static const std::string prefix = "Input-"; | ||
return prefix + job_name + "-" + op_name; | ||
return kInputBufferNamePrefix + job_name + "-" + op_name; | ||
} | ||
|
||
inline std::string GetOutputBufferName(const std::string& job_name, const std::string& op_name) { | ||
static const std::string prefix = "Output-"; | ||
return prefix + job_name + "-" + op_name; | ||
return kOutputBufferNamePrefix + job_name + "-" + op_name; | ||
} | ||
|
||
inline std::string GetSourceTickBufferName(const std::string& job_name) { | ||
static const std::string prefix = "SourceTick-"; | ||
return prefix + job_name; | ||
return kSourceTickBufferNamePrefix + job_name; | ||
} | ||
|
||
} // namespace oneflow | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ limitations under the License. | |
*/ | ||
#include "oneflow/core/framework/nn_graph.h" | ||
#include "oneflow/core/common/buffer_manager.h" | ||
#include "oneflow/core/common/hash_container.h" | ||
#include "oneflow/core/common/maybe.h" | ||
#include "oneflow/core/common/scalar.h" | ||
#include "oneflow/core/common/cost_util.h" | ||
|
@@ -26,6 +27,7 @@ limitations under the License. | |
#include "oneflow/core/framework/instructions_builder.h" | ||
#include "oneflow/core/framework/nd_sbp.h" | ||
#include "oneflow/core/framework/scope_util.h" | ||
#include "oneflow/core/framework/tensor.h" | ||
#include "oneflow/core/framework/tensor_name_scope.h" | ||
#include "oneflow/core/functional/functional.h" | ||
#include "oneflow/core/graph/op_graph.h" | ||
|
@@ -305,7 +307,7 @@ Maybe<void> NNGraph::RegisterNewVariableOpInJobPass() { | |
} | ||
|
||
Maybe<void> NNGraph::DeleteOutdatedVariableInVariableTensorMgr() { | ||
std::set<std::string> variable_names = [&]() -> Maybe<std::set<std::string>> { | ||
const auto& var_get_func = [&]() -> Maybe<std::set<std::string>> { | ||
std::set<std::string> variable_names_; | ||
OpGraph op_graph(job_); | ||
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> { | ||
|
@@ -314,8 +316,8 @@ Maybe<void> NNGraph::DeleteOutdatedVariableInVariableTensorMgr() { | |
return Maybe<void>::Ok(); | ||
})); | ||
return variable_names_; | ||
}() | ||
.GetOrThrow(); | ||
}; | ||
std::set<std::string> variable_names = *JUST(var_get_func()); | ||
|
||
auto mgr = Singleton<VariableTensorMgr>::Get(); | ||
for (auto& name : mgr->DumpNames()) { | ||
|
@@ -324,28 +326,88 @@ Maybe<void> NNGraph::DeleteOutdatedVariableInVariableTensorMgr() { | |
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> NNGraph::CompileAndInitRuntime() { | ||
Maybe<void> NNGraph::AlignStatesAfterLogicalGraphCompile() { | ||
auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true); | ||
CHECK_OR_RETURN(!runtime_inited_) | ||
<< Error::RuntimeError() << "nn.Graph runtime is already initialized"; | ||
JUST(RegisterFreeEagerTensorsToVariableOpNames()); | ||
JUST(RegisterNewVariableOpInJobPass()); | ||
JUST(DeleteOutdatedVariableInVariableTensorMgr()); | ||
|
||
// NOTE(chengcheng): TensorNameScope need to be cleared after current graph is built. | ||
one::TensorNameScope::Global()->Clear(); | ||
// Clear all backward pass scope | ||
ClearAllBackwardPassScope(); | ||
compile_tc->Count("[GraphCompile]" + name_ + " AlignStates", 0); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
// NOTE(chengcheng): Singleton<JobDesc> need be clear before GlobalJobDescScope construct. | ||
if (Singleton<JobDesc>::Get() != nullptr) { Singleton<JobDesc>::Delete(); } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这行逻辑挪到哪里了呢 |
||
Maybe<void> NNGraph::CompleteLogicalGraphForRuntime() { | ||
auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true); | ||
// A global variable to get graph configurations. | ||
auto current_graph_config = std::make_unique<GlobalJobDescScope>(job_.job_conf(), job_id()); | ||
// NOTE(chengcheng): do job compeleter for each rank. | ||
JUST(JobCompleter::Complete(&job_)); | ||
compile_tc->Count("[GraphCompile]" + name_ + " CompleteJob", 0); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
auto scope = std::make_unique<GlobalJobDescScope>(job_.job_conf(), job_id_); | ||
Maybe<void> NNGraph::BuildWithNewInputFromSharedGraph( | ||
const std::vector<std::string>& shared_inputs_op_names, | ||
const std::vector<std::shared_ptr<one::Tensor>>& new_input_tensors, | ||
const std::vector<std::string>& shared_op_names_from_ordered_original_graph, | ||
const std::string& new_serialized_original_job) { | ||
CHECK_EQ_OR_RETURN(shared_inputs_op_names.size(), new_input_tensors.size()); // NOLINE | ||
auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true); | ||
// Register inputs. | ||
JUST(RegisterInputOpNamesAndTensors(shared_inputs_op_names, new_input_tensors)); | ||
|
||
// Generate new input tensor getter. | ||
HashMap<std::string, std::shared_ptr<one::Tensor>> input_name2tensor; | ||
for (int64_t idx = 0; idx < shared_inputs_op_names.size(); ++idx) { | ||
input_name2tensor.emplace(shared_inputs_op_names[idx], new_input_tensors[idx]); | ||
} | ||
const auto& InputTensor4Name = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个是不是应该放在 RegisterInputOpNamesAndTensors 里 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
InputTensor4Name 只有下面用了,上面的RegisterInputOpNamesAndTensors不依赖这个查找,所以就写成了用完就释放的形式 |
||
[&input_name2tensor](const std::string& op_name) -> Maybe<std::shared_ptr<one::Tensor>> { | ||
auto iter = input_name2tensor.find(op_name); | ||
CHECK_OR_RETURN(iter != input_name2tensor.end()) | ||
<< "Can't find input tensor of " << op_name << "."; | ||
return iter->second; | ||
}; | ||
|
||
// Generate new OperatorConf getter. | ||
Job new_build_original_job; | ||
CHECK_OR_RETURN(new_build_original_job.ParseFromString(new_serialized_original_job)) | ||
<< "nn.Graph " << name_ << " parse job proto of new build graph failed."; | ||
CHECK_EQ_OR_RETURN(new_build_original_job.net().op_size(), | ||
shared_op_names_from_ordered_original_graph.size()) | ||
<< "nn.Graph " << name_ | ||
<< " new_build_original_job op size and shared_op_names_from_ordered_original_graph size are " | ||
"not " | ||
"equal."; | ||
HashMap<std::string, const OperatorConf*> shared_op_name2_new_op; | ||
for (int64_t op_idx = 0; op_idx < shared_op_names_from_ordered_original_graph.size(); ++op_idx) { | ||
// Assume that the new graph and the shared graph from nn.Graph.build have the same op order. | ||
const auto& op = new_build_original_job.mutable_net()->mutable_op()->at(op_idx); | ||
shared_op_name2_new_op.emplace(shared_op_names_from_ordered_original_graph[op_idx], &op); | ||
} | ||
const auto& NewOp4SharedOpName = | ||
[&shared_op_name2_new_op](const std::string& shared_op_name) -> Maybe<const OperatorConf*> { | ||
auto iter = shared_op_name2_new_op.find(shared_op_name); | ||
CHECK_OR_RETURN(iter != shared_op_name2_new_op.end()) | ||
<< "Can't find new operator conf of " << shared_op_name << "."; | ||
return iter->second; | ||
}; | ||
|
||
// A global variable to get graph configurations. | ||
auto current_graph_config = std::make_unique<GlobalJobDescScope>(job_.job_conf(), job_id()); | ||
// NOTE(chengcheng): do job compeleter for each rank. | ||
JUST(JobCompleter().Complete(&job_)); | ||
JUST(JobCompleter::UpdateSharedGraphForNewInput(&job_, InputTensor4Name, NewOp4SharedOpName)); | ||
compile_tc->Count("[GraphCompile]" + name_ + " CompleteJob", 0); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> NNGraph::CompilePlanForRuntime() { | ||
auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true); | ||
// A global variable to get graph configurations. | ||
auto current_graph_config = std::make_unique<GlobalJobDescScope>(job_.job_conf(), job_id()); | ||
if (GlobalProcessCtx::IsThisProcessMaster()) { | ||
// TODO(chengcheng): new memory reused by chunk | ||
Compiler().Compile(&job_, &plan_); | ||
|
@@ -389,7 +451,14 @@ Maybe<void> NNGraph::CompileAndInitRuntime() { | |
compile_tc->Count("[GraphCompile]" + name_ + " SyncPlan", 0, true); | ||
// NOTE(chengcheng): recovery op_attr | ||
PlanUtil::PopulateOpAttribute(&plan_, plan_.job_id2op_attribute_ref_table()); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> NNGraph::InitRuntime() { | ||
CHECK_OR_RETURN(!runtime_inited_) | ||
<< Error::RuntimeError() << "nn.Graph runtime is already initialized"; | ||
|
||
auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true); | ||
NewRuntimeBuffers(); | ||
|
||
JUST(GetVariableRealBlobAfterSyncPlan()); | ||
|
@@ -402,10 +471,19 @@ Maybe<void> NNGraph::CompileAndInitRuntime() { | |
runtime_.reset(new Runtime(plan_, variable_op_name2eager_blob_object_)); | ||
compile_tc->Count("[GraphCompile]" + name_ + " InitRuntime", 0, true); | ||
JUST(LogProgress("[GraphCompile]" + name_ + " Done", true)); | ||
|
||
runtime_inited_ = true; | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> NNGraph::CompileAndInitRuntime() { | ||
JUST(AlignStatesAfterLogicalGraphCompile()); | ||
JUST(CompleteLogicalGraphForRuntime()); | ||
JUST(CompilePlanForRuntime()); | ||
JUST(InitRuntime()); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> NNGraph::GetVariableRealBlobAfterSyncPlan() { | ||
CHECK_OR_RETURN(variable_op_name2eager_blob_object_.empty()) | ||
<< Error::RuntimeError() << kOfBugIssueUploadPrompt; | ||
|
@@ -611,7 +689,8 @@ Maybe<void> RunLazyNNGraph(const one::TensorTuple& inputs, const one::TensorTupl | |
<< "nn.Graph ONLY accepts static inputs tensor meta, please check whether your input " | ||
<< "tensor meta each step is the same as the input of first call graph.\nThe excepted " | ||
<< "tensor meta is: " << static_meta_str | ||
<< ", but the actual tensor meta is: " << tensor_meta_str; | ||
<< ", but the actual tensor meta is: " << tensor_meta_str << ". The input index is " << i | ||
<< "."; | ||
} | ||
for (int i = 0; i < outputs.size(); ++i) { | ||
CHECK_OR_RETURN(nn_graph->outputs_tensor_meta_str().at(i) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,12 +41,24 @@ class NNGraph final : public NNGraphIf { | |
session_ctx_(session_ctx), | ||
runtime_inited_(false), | ||
is_closed_(false) {} | ||
explicit NNGraph(const std::string& name, const Plan& plan, int64_t job_id, | ||
const std::shared_ptr<MultiClientSessionContext>& session_ctx) | ||
: name_(name), | ||
job_id_(job_id), | ||
session_ctx_(session_ctx), | ||
plan_(plan), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 支持重 plan 初始化 NNGraph |
||
runtime_inited_(false), | ||
is_closed_(false) {} | ||
OF_DISALLOW_COPY_AND_MOVE(NNGraph); | ||
~NNGraph(); | ||
|
||
const std::string& job_name() const override { return name_; } | ||
const Job& job() const { return job_; } | ||
void restore_job(const Job& job) { job_ = job; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这一行只是挪了位置吧? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是的,这个没改 |
||
int64_t job_id() const { return job_id_; } | ||
void restore_job_id(int64_t job_id) { job_id_ = job_id; } | ||
const Plan& plan() const { return plan_; } | ||
void restore_plan(const Plan& plan) { plan_ = plan; } | ||
const std::vector<std::string>& inputs_op_names() const override; | ||
const std::vector<std::string>& outputs_op_names() const override; | ||
const std::vector<bool>& inputs_valid() const override; | ||
|
@@ -56,9 +68,6 @@ class NNGraph final : public NNGraphIf { | |
int64_t variable_op_size() const; | ||
const std::shared_ptr<vm::EagerBlobObjectList>& var_blobs() const; | ||
|
||
void restore_job(const Job& job) { job_ = job; } | ||
void restore_job_id(int64_t job_id) { job_id_ = job_id; } | ||
|
||
Maybe<void> RegisterAdditionalVarOpNamesAndTensorsToBeLoaded( | ||
const std::vector<std::string>& additional_var_names, | ||
const std::vector<std::shared_ptr<one::Tensor>>& additional_var_tensors); | ||
|
@@ -73,6 +82,20 @@ class NNGraph final : public NNGraphIf { | |
const std::vector<std::shared_ptr<one::Tensor>>& variable_tensors); | ||
Maybe<std::vector<std::string>> GetAdditionalVarOpNames() const; | ||
Maybe<std::vector<std::shared_ptr<one::Tensor>>> GetAdditionalVarOpTensors() const; | ||
// After logical graph compile, some state variables should be cleaned or built. | ||
Maybe<void> AlignStatesAfterLogicalGraphCompile(); | ||
// Add special operators into logical graph for lazy runtime. | ||
Maybe<void> CompleteLogicalGraphForRuntime(); | ||
// Build graph with new inputs from a completed job of a shared graph. | ||
Maybe<void> BuildWithNewInputFromSharedGraph( | ||
const std::vector<std::string>& shared_inputs_op_names, | ||
const std::vector<std::shared_ptr<one::Tensor>>& new_input_tensors, | ||
const std::vector<std::string>& shared_op_names_from_ordered_original_graph, | ||
const std::string& new_serialized_original_job); | ||
// Generate execution plan for lazy runtime. Oneflow lazy runtime is an actor based runtime. | ||
Maybe<void> CompilePlanForRuntime(); | ||
// Initialize lazy runtime. | ||
Maybe<void> InitRuntime(); | ||
Maybe<void> CompileAndInitRuntime(); | ||
Maybe<void> Close(); | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -72,6 +72,7 @@ bool IsInterfaceOpConf(const OperatorConf& op_conf) { | |
} | ||
|
||
GlobalJobDescScope::GlobalJobDescScope(const JobConfigProto& job_conf, int64_t job_id) { | ||
if (Singleton<JobDesc>::Get() != nullptr) { Singleton<JobDesc>::Delete(); } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 哦哦看到挪到这里了 |
||
Singleton<JobDesc>::New(job_conf, job_id); | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个 ctx 是为了还原 save 的时候的 session?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个 ctx 是之前那个用引用计数来释放 graph/session/env 的 pr 引入的,这个 pr 只是新增了一个从 plan 构造 c nn graph 构造函数。