Feat multi input sharing graph, save and load compiled graph #9754

Merged Feb 1, 2023 (28 commits). The diff below shows changes from 25 of the 28 commits.

Commits (28)
72cab02  refact complete and gen plan (strint, Jan 15, 2023)
9b968f1  add align states (strint, Jan 15, 2023)
ee42376  mut input and re-infer (strint, Jan 15, 2023)
e68d89a  add build from shared draft (strint, Jan 16, 2023)
6c50cef  infer shape and build new output (strint, Jan 16, 2023)
ec7a542  support the main logical, will stuck (strint, Jan 17, 2023)
78a41da  fix tick (strint, Jan 17, 2023)
10f0c9e  infer blob with input tensor (strint, Jan 17, 2023)
843b429  support shape attr update (strint, Jan 18, 2023)
9dc411d  rm useless header (strint, Jan 18, 2023)
b5db0e7  rm debug code (strint, Jan 18, 2023)
3ba6e12  Graph save/load from runtime states (#9779) (strint, Jan 21, 2023)
936b0fa  refine (strint, Jan 24, 2023)
4e69b77  fix tensor share and restore device type (strint, Jan 26, 2023)
8bc4f0e  auto format by CI (oneflow-ci-bot, Jan 26, 2023)
00b5a37  refine mem (strint, Jan 28, 2023)
539770f  Merge branch 'feat_multi_in' of https://github.com/Oneflow-Inc/oneflo… (strint, Jan 28, 2023)
d1fcd7d  auto format by CI (oneflow-ci-bot, Jan 28, 2023)
f4038f8  reduce cost (strint, Jan 28, 2023)
0fba176  fix conflict (strint, Jan 28, 2023)
9921dc2  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into … (strint, Jan 29, 2023)
97d4994  save with file (strint, Jan 29, 2023)
2f9f28f  test with sub process (strint, Jan 30, 2023)
5c48aa4  add review (strint, Jan 31, 2023)
d5caa28  refine infer shape (strint, Jan 31, 2023)
6be6b90  address review, refine states for share and save runtime (strint, Feb 1, 2023)
edb1061  address review (strint, Feb 1, 2023)
21fc9ba  Merge branch 'master' into feat_multi_in (strint, Feb 1, 2023)
34 changes: 32 additions & 2 deletions oneflow/api/python/framework/nn_graph.cpp
@@ -55,10 +55,24 @@ ONEFLOW_API_PYBIND11_MODULE("nn.graph.", m) {
const std::shared_ptr<MultiClientSessionContext>& session_ctx) {
Job job;
if (!job.ParseFromString(serialized_job)) {
PyErr_SetString(PyExc_TypeError, "the second argument is not a valid job");
PyErr_SetString(PyExc_TypeError, "The second argument is not a valid job");
}
return std::make_shared<NNGraph>(name, job, job_id, session_ctx);
}))
.def(py::init([](const std::string& name, const std::string& serialized_plan, int64_t job_id,
const std::shared_ptr<MultiClientSessionContext>& session_ctx,
Review comment (Contributor): Is this ctx meant to restore the session from when the graph was saved?

Reply (Author): That ctx was introduced by the earlier PR that releases graph/session/env via reference counting; this PR only adds a new constructor that builds the C++ NNGraph from a plan.

bool init_from_plan) {
if (!init_from_plan) {
PyErr_SetString(
PyExc_TypeError,
"init_from_plan must be True when init CNNGraph with this bool parameter.");
}
Plan plan;
if (!plan.ParseFromString(serialized_plan)) {
PyErr_SetString(PyExc_TypeError, "The second argument is not a valid plan");
}
return std::make_shared<NNGraph>(name, plan, job_id, session_ctx);
}))
.def_property_readonly("name", &NNGraph::job_name)
.def_property(
"job", /*getter*/
@@ -73,14 +87,30 @@ ONEFLOW_API_PYBIND11_MODULE("nn.graph.", m) {
})
.def_property("job_id", &NNGraph::job_id,
[](NNGraph& nn_graph, int64_t job_id) { nn_graph.restore_job_id(job_id); })
.def_property(
"plan", /*getter*/
[](const NNGraph& nn_graph) { return py::bytes(nn_graph.plan().SerializeAsString()); },
/*setter*/
[](NNGraph& nn_graph, const std::string& serialized_plan) {
Plan plan;
if (!plan.ParseFromString(serialized_plan)) {
PyErr_SetString(PyExc_TypeError, "the value is not a valid plan");
}
nn_graph.restore_plan(plan);
})
.def("register_input_op_names_and_tensors", &NNGraph::RegisterInputOpNamesAndTensors)
.def("register_output_op_names_and_tensors", &NNGraph::RegisterOutputOpNamesAndTensors)
.def("register_variable_op_names_and_tensors", &NNGraph::RegisterVariableOpNamesAndTensors)
.def("register_additional_variable_names_and_tensors",
&NNGraph::RegisterAdditionalVarOpNamesAndTensorsToBeLoaded)
.def_property_readonly("additional_var_names", &APINNGraphAdditionalVarNames)
.def_property_readonly("additional_var_tensors", &APINNGraphAdditionalVarTensors)
.def("complie_and_init_runtime", &NNGraph::CompileAndInitRuntime)
.def("align_states_after_logical_graph_compile",
&NNGraph::AlignStatesAfterLogicalGraphCompile)
.def("complete_graph_for_runtime", &NNGraph::CompleteLogicalGraphForRuntime)
.def("build_with_new_input_from_shared_graph", &NNGraph::BuildWithNewInputFromSharedGraph)
.def("compile_plan_for_runtime", &NNGraph::CompilePlanForRuntime)
.def("init_runtime", &NNGraph::InitRuntime)
.def("get_current_job_str", &APINNGraphGetCurrentSerializedJob);

m.def("RunLazyNNGraph", &RunLazyNNGraph);
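Taken together, the new bindings let a graph be reconstructed from a previously serialized plan, skipping compilation entirely. Below is a minimal C++ sketch of that path using the constructor and methods added in this PR; serialized_plan, job_id, and session_ctx are assumed to come from an earlier save, and variable registration is omitted.

    // Sketch: restore an NNGraph from a plan serialized by an earlier run.
    Maybe<void> RestoreGraphFromPlan(const std::string& serialized_plan, int64_t job_id,
                                     const std::shared_ptr<MultiClientSessionContext>& session_ctx) {
      Plan plan;
      CHECK_OR_RETURN(plan.ParseFromString(serialized_plan)) << "not a valid serialized plan";
      auto graph = std::make_shared<NNGraph>("restored_graph", plan, job_id, session_ctx);
      // The plan is already compiled, so the logical-graph and plan-compile stages
      // can be skipped; only the lazy runtime needs to be brought up.
      JUST(graph->InitRuntime());
      return Maybe<void>::Ok();
    }
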
34 changes: 18 additions & 16 deletions oneflow/core/common/buffer_manager.h
@@ -44,45 +44,47 @@ class BufferMgr final {
};

static const std::string kBufferNameGlobalWaitJobId = "GlobalWaitJobId";
static const std::string kCallbackNotifierBufferNamePrefix = "CallbackNotifier-";
static const std::string kInputCriticalSectionWaitBufferNamePrefix = "InputCriticalSectionWait-";
static const std::string kInputCriticalSectionCallbackBufferNamePrefix =
"InputCriticalSectionCallback-";
static const std::string kOutputCriticalSectionWaitBufferNamePrefix = "OutputCriticalSectionWait-";
static const std::string kOutputCriticalSectionCallbackBufferNamePrefix =
"OutputCriticalSectionCallback-";
static const std::string kInputBufferNamePrefix = "Input-";
static const std::string kOutputBufferNamePrefix = "Output-";
static const std::string kSourceTickBufferNamePrefix = "SourceTick-";
Review comment (Author): These names are needed outside this file, so the prefixes were moved out of the functions to namespace scope.

inline std::string GetCallbackNotifierBufferName(const std::string& job_name) {
static const std::string prefix = "CallbackNotifier-";
return prefix + job_name;
return kCallbackNotifierBufferNamePrefix + job_name;
}

inline std::string GetInputCriticalSectionWaitBufferName(const std::string& job_name) {
static const std::string prefix = "InputCriticalSectionWait-";
return prefix + job_name;
return kInputCriticalSectionWaitBufferNamePrefix + job_name;
}

inline std::string GetInputCriticalSectionCallbackBufferName(const std::string& job_name) {
static const std::string prefix = "InputCriticalSectionCallback-";
return prefix + job_name;
return kInputCriticalSectionCallbackBufferNamePrefix + job_name;
}

inline std::string GetOutputCriticalSectionWaitBufferName(const std::string& job_name) {
static const std::string prefix = "OutputCriticalSectionWait-";
return prefix + job_name;
return kOutputCriticalSectionWaitBufferNamePrefix + job_name;
}

inline std::string GetOutputCriticalSectionCallbackBufferName(const std::string& job_name) {
static const std::string prefix = "OutputCriticalSectionCallback-";
return prefix + job_name;
return kOutputCriticalSectionCallbackBufferNamePrefix + job_name;
}

inline std::string GetInputBufferName(const std::string& job_name, const std::string& op_name) {
static const std::string prefix = "Input-";
return prefix + job_name + "-" + op_name;
return kInputBufferNamePrefix + job_name + "-" + op_name;
}

inline std::string GetOutputBufferName(const std::string& job_name, const std::string& op_name) {
static const std::string prefix = "Output-";
return prefix + job_name + "-" + op_name;
return kOutputBufferNamePrefix + job_name + "-" + op_name;
}

inline std::string GetSourceTickBufferName(const std::string& job_name) {
static const std::string prefix = "SourceTick-";
return prefix + job_name;
return kSourceTickBufferNamePrefix + job_name;
}

} // namespace oneflow
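As the author notes, hoisting the prefixes to namespace scope lets other translation units reference them. A small hypothetical example of such an external use (IsInputBufferName is not part of this PR):

    #include "oneflow/core/common/buffer_manager.h"

    namespace oneflow {
    // Classify a buffer by name via the shared prefix constant,
    // instead of re-declaring the "Input-" literal locally.
    inline bool IsInputBufferName(const std::string& buffer_name) {
      return buffer_name.rfind(kInputBufferNamePrefix, 0) == 0;  // starts-with check
    }
    }  // namespace oneflow
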
98 changes: 86 additions & 12 deletions oneflow/core/framework/nn_graph.cpp
@@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/framework/nn_graph.h"
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/common/hash_container.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/scalar.h"
#include "oneflow/core/common/cost_util.h"
@@ -26,6 +27,7 @@ limitations under the License.
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/scope_util.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_name_scope.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/graph/op_graph.h"
@@ -305,7 +307,7 @@ Maybe<void> NNGraph::RegisterNewVariableOpInJobPass() {
}

Maybe<void> NNGraph::DeleteOutdatedVariableInVariableTensorMgr() {
std::set<std::string> variable_names = [&]() -> Maybe<std::set<std::string>> {
const auto& var_get_func = [&]() -> Maybe<std::set<std::string>> {
std::set<std::string> variable_names_;
OpGraph op_graph(job_);
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
@@ -314,8 +316,8 @@ Maybe<void> NNGraph::DeleteOutdatedVariableInVariableTensorMgr() {
return Maybe<void>::Ok();
}));
return variable_names_;
}()
.GetOrThrow();
};
std::set<std::string> variable_names = *JUST(var_get_func());

auto mgr = Singleton<VariableTensorMgr>::Get();
for (auto& name : mgr->DumpNames()) {
@@ -324,28 +326,83 @@ Maybe<void> NNGraph::DeleteOutdatedVariableInVariableTensorMgr() {
return Maybe<void>::Ok();
}

Maybe<void> NNGraph::CompileAndInitRuntime() {
Maybe<void> NNGraph::AlignStatesAfterLogicalGraphCompile() {
auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true);
CHECK_OR_RETURN(!runtime_inited_)
<< Error::RuntimeError() << "nn.Graph runtime is already initialized";
JUST(RegisterFreeEagerTensorsToVariableOpNames());
JUST(RegisterNewVariableOpInJobPass());
JUST(DeleteOutdatedVariableInVariableTensorMgr());

// NOTE(chengcheng): TensorNameScope need to be cleared after current graph is built.
one::TensorNameScope::Global()->Clear();
// Clear all backward pass scope
ClearAllBackwardPassScope();
compile_tc->Count("[GraphCompile]" + name_ + " AlignStates", 0);
return Maybe<void>::Ok();
}

// NOTE(chengcheng): Singleton<JobDesc> need be clear before GlobalJobDescScope construct.
if (Singleton<JobDesc>::Get() != nullptr) { Singleton<JobDesc>::Delete(); }
Review comment (Contributor): Where did this line's logic move to?

Maybe<void> NNGraph::CompleteLogicalGraphForRuntime() {
auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true);
// A global variable to get graph configurations.
auto current_graph_config = std::make_unique<GlobalJobDescScope>(job_.job_conf(), job_id());
// NOTE(chengcheng): do job completer for each rank.
JUST(JobCompleter::Complete(&job_));
compile_tc->Count("[GraphCompile]" + name_ + " CompleteJob", 0);
return Maybe<void>::Ok();
}

auto scope = std::make_unique<GlobalJobDescScope>(job_.job_conf(), job_id_);
Maybe<void> NNGraph::BuildWithNewInputFromSharedGraph(
const std::vector<std::string>& shared_inputs_op_names,
const std::vector<std::shared_ptr<one::Tensor>>& new_input_tensors,
const std::vector<std::string>& shared_op_names, const std::string& new_serialized_job) {
Review comment (Contributor): What is the shared_op_names parameter for? Isn't all of it already in new_serialized_job?

Reply (Author): shared_op_names is taken from the original logical graph produced directly by build:

        shared_op_names = []
        for op_idx in range(len(self._forward_job_proto.net.op)):
            shared_op_names.append(
                self._shared_graph._forward_job_proto.net.op[op_idx].name
            )

new_serialized_job already contains the optimized graph, and the optimized graph no longer guarantees op order.

Review comment (Contributor): Then how is it guaranteed that shared_op_names and new_serialized_job stay consistent? new_serialized_job may no longer contain some of the ops in shared_op_names.

CHECK_EQ_OR_RETURN(shared_inputs_op_names.size(), new_input_tensors.size()); // NOLINT
auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true);
// Register inputs.
JUST(RegisterInputOpNamesAndTensors(shared_inputs_op_names, new_input_tensors));

// Generate new input tensor getter.
HashMap<std::string, std::shared_ptr<one::Tensor>> input_name2tensor;
for (int64_t idx = 0; idx < shared_inputs_op_names.size(); ++idx) {
input_name2tensor.emplace(shared_inputs_op_names[idx], new_input_tensors[idx]);
}
const auto& InputTensor4Name =
Review comment (Contributor): Shouldn't this be placed inside RegisterInputOpNamesAndTensors?

Reply (Author): InputTensor4Name is only used below; RegisterInputOpNamesAndTensors above does not depend on this lookup, so it is written in a use-and-release form.

[&input_name2tensor](const std::string& op_name) -> Maybe<std::shared_ptr<one::Tensor>> {
auto iter = input_name2tensor.find(op_name);
CHECK_OR_RETURN(iter != input_name2tensor.end())
<< "Can't find input tensor of " << op_name << ".";
return iter->second;
};

// Generate new OperatorConf getter.
Job new_build_job;
CHECK_OR_RETURN(new_build_job.ParseFromString(new_serialized_job))
<< "nn.Graph " << name_ << " parse job proto of new build graph failed.";
CHECK_EQ_OR_RETURN(new_build_job.net().op_size(), shared_op_names.size())
<< "nn.Graph " << name_ << " new_build_job op size and shared_op_names size are not equal.";
HashMap<std::string, const OperatorConf*> shared_op_name2_new_op;
for (int64_t op_idx = 0; op_idx < shared_op_names.size(); ++op_idx) {
// Assume that the new graph and the shared graph from nn.Graph.build have the same op order.
const auto& op = new_build_job.mutable_net()->mutable_op()->at(op_idx);
shared_op_name2_new_op.emplace(shared_op_names[op_idx], &op);
Review comment (Contributor): Couldn't new_build_job just be passed directly to CompleteSharedGraphForNewInput?

Reply (Author): This map actually goes from shared op name to the corresponding op in new_build_job; the correspondence is established through op order, to prepare for modifying the shared graph's op attrs later. So passing only new_build_job is not enough. I will rename it and add a comment.

Review comment (Contributor): Regarding "passing only new_build_job is not enough": is my understanding right that this is because new_build_job does not carry the op order?

Reply (Author, strint, Feb 1, 2023): new_build_job is a temporary job produced by the new build function (it was renamed later); it serves as a dictionary of attrs for the new graph. So the op name information has to be passed separately to maintain the correspondence between the new and old ops.

}
const auto& NewOp4SharedOpName =
[&shared_op_name2_new_op](const std::string& shared_op_name) -> Maybe<const OperatorConf*> {
auto iter = shared_op_name2_new_op.find(shared_op_name);
CHECK_OR_RETURN(iter != shared_op_name2_new_op.end())
<< "Can't find new operator conf of " << shared_op_name << ".";
return iter->second;
};

// A global variable to get graph configurations.
auto current_graph_config = std::make_unique<GlobalJobDescScope>(job_.job_conf(), job_id());
// NOTE(chengcheng): do job completer for each rank.
JUST(JobCompleter().Complete(&job_));
JUST(JobCompleter::CompleteSharedGraphForNewInput(&job_, InputTensor4Name, NewOp4SharedOpName));
Review comment (Contributor): Is job_ empty here? From the interface naming, this should update or rewrite job_?

Reply (Author, strint, Feb 1, 2023): job_ is not empty; it is a copied, already-optimized job. Renamed to update.

Review comment (Contributor): Where does that copy happen?

Reply (Author): Answered in another comment: the Python graph passes it in when initializing the C++ NNGraph.

compile_tc->Count("[GraphCompile]" + name_ + " CompleteJob", 0);
return Maybe<void>::Ok();
}

Maybe<void> NNGraph::CompilePlanForRuntime() {
auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true);
// A global variable to get graph configurations.
auto current_graph_config = std::make_unique<GlobalJobDescScope>(job_.job_conf(), job_id());
if (GlobalProcessCtx::IsThisProcessMaster()) {
// TODO(chengcheng): new memory reused by chunk
Compiler().Compile(&job_, &plan_);
@@ -389,7 +446,14 @@ Maybe<void> NNGraph::CompileAndInitRuntime() {
compile_tc->Count("[GraphCompile]" + name_ + " SyncPlan", 0, true);
// NOTE(chengcheng): recovery op_attr
PlanUtil::PopulateOpAttribute(&plan_, plan_.job_id2op_attribute_ref_table());
return Maybe<void>::Ok();
}

Maybe<void> NNGraph::InitRuntime() {
CHECK_OR_RETURN(!runtime_inited_)
<< Error::RuntimeError() << "nn.Graph runtime is already initialized";

auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true);
NewRuntimeBuffers();

JUST(GetVariableRealBlobAfterSyncPlan());
@@ -402,10 +466,19 @@ Maybe<void> NNGraph::CompileAndInitRuntime() {
runtime_.reset(new Runtime(plan_, variable_op_name2eager_blob_object_));
compile_tc->Count("[GraphCompile]" + name_ + " InitRuntime", 0, true);
JUST(LogProgress("[GraphCompile]" + name_ + " Done", true));

runtime_inited_ = true;
return Maybe<void>::Ok();
}

Maybe<void> NNGraph::CompileAndInitRuntime() {
JUST(AlignStatesAfterLogicalGraphCompile());
JUST(CompleteLogicalGraphForRuntime());
JUST(CompilePlanForRuntime());
JUST(InitRuntime());
return Maybe<void>::Ok();
}

Maybe<void> NNGraph::GetVariableRealBlobAfterSyncPlan() {
CHECK_OR_RETURN(variable_op_name2eager_blob_object_.empty())
<< Error::RuntimeError() << kOfBugIssueUploadPrompt;
@@ -611,7 +684,8 @@ Maybe<void> RunLazyNNGraph(const one::TensorTuple& inputs, const one::TensorTupl
<< "nn.Graph ONLY accepts static inputs tensor meta, please check whether your input "
<< "tensor meta each step is the same as the input of first call graph.\nThe excepted "
<< "tensor meta is: " << static_meta_str
<< ", but the actual tensor meta is: " << tensor_meta_str;
<< ", but the actual tensor meta is: " << tensor_meta_str << ". The input index is " << i
<< ".";
}
for (int i = 0; i < outputs.size(); ++i) {
CHECK_OR_RETURN(nn_graph->outputs_tensor_meta_str().at(i)
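The net effect of these changes is that CompileAndInitRuntime is now a thin composition of four stages, and a graph that shares another graph's compilation can swap out the second stage. A sketch of that sharing path, with the call order assumed from the new methods (these calls would live in a Maybe<void> function; in practice the Python layer drives them):

    JUST(graph->AlignStatesAfterLogicalGraphCompile());
    // Instead of CompleteLogicalGraphForRuntime, reuse the shared graph's
    // completed job and rewrite only its input ops for the new tensors.
    JUST(graph->BuildWithNewInputFromSharedGraph(shared_input_op_names, new_input_tensors,
                                                 shared_op_names, serialized_shared_job));
    JUST(graph->CompilePlanForRuntime());
    JUST(graph->InitRuntime());
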
28 changes: 25 additions & 3 deletions oneflow/core/framework/nn_graph.h
@@ -41,12 +41,24 @@ class NNGraph final : public NNGraphIf {
session_ctx_(session_ctx),
runtime_inited_(false),
is_closed_(false) {}
explicit NNGraph(const std::string& name, const Plan& plan, int64_t job_id,
const std::shared_ptr<MultiClientSessionContext>& session_ctx)
: name_(name),
job_id_(job_id),
session_ctx_(session_ctx),
plan_(plan),
Review comment (Author): Supports initializing NNGraph from a plan.

runtime_inited_(false),
is_closed_(false) {}
OF_DISALLOW_COPY_AND_MOVE(NNGraph);
~NNGraph();

const std::string& job_name() const override { return name_; }
const Job& job() const { return job_; }
void restore_job(const Job& job) { job_ = job; }
Review comment (Contributor): This line was just moved, right?

Reply (Author): Yes, it is unchanged.

int64_t job_id() const { return job_id_; }
void restore_job_id(int64_t job_id) { job_id_ = job_id; }
const Plan& plan() const { return plan_; }
void restore_plan(const Plan& plan) { plan_ = plan; }
const std::vector<std::string>& inputs_op_names() const override;
const std::vector<std::string>& outputs_op_names() const override;
const std::vector<bool>& inputs_valid() const override;
@@ -56,9 +68,6 @@ class NNGraph final : public NNGraphIf {
int64_t variable_op_size() const;
const std::shared_ptr<vm::EagerBlobObjectList>& var_blobs() const;

void restore_job(const Job& job) { job_ = job; }
void restore_job_id(int64_t job_id) { job_id_ = job_id; }

Maybe<void> RegisterAdditionalVarOpNamesAndTensorsToBeLoaded(
const std::vector<std::string>& additional_var_names,
const std::vector<std::shared_ptr<one::Tensor>>& additional_var_tensors);
@@ -73,6 +82,19 @@ class NNGraph final : public NNGraphIf {
const std::vector<std::shared_ptr<one::Tensor>>& variable_tensors);
Maybe<std::vector<std::string>> GetAdditionalVarOpNames() const;
Maybe<std::vector<std::shared_ptr<one::Tensor>>> GetAdditionalVarOpTensors() const;
// After logical graph compile, some state variables should be cleaned or built.
Maybe<void> AlignStatesAfterLogicalGraphCompile();
// Add special operators into logical graph for lazy runtime.
Maybe<void> CompleteLogicalGraphForRuntime();
// Build graph with new inputs from a completed job of a shared graph.
Maybe<void> BuildWithNewInputFromSharedGraph(
const std::vector<std::string>& shared_inputs_op_names,
const std::vector<std::shared_ptr<one::Tensor>>& new_input_tensors,
const std::vector<std::string>& shared_op_names, const std::string& new_serialized_job);
// Generate execution plan for lazy runtime. Oneflow lazy runtime is an actor based runtime.
Maybe<void> CompilePlanForRuntime();
// Initialize lazy runtime.
Maybe<void> InitRuntime();
Maybe<void> CompileAndInitRuntime();
Maybe<void> Close();

5 changes: 0 additions & 5 deletions oneflow/core/job/compiler.cpp
@@ -53,11 +53,6 @@ void Compiler::Compile(Job* job, Plan* plan) const {
// Step1: new Singleton<OpGraph> and set log configs.
Singleton<OpGraph>::New(*job);
const JobDesc& job_desc = GlobalJobDesc();
if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
TeePersistentLogStream::Create(StrCat("optimized_job", job_desc.job_id()))->Write(*job);
Singleton<OpGraph>::Get()->ToDotWithFilePath(
"optimized_dlnet_" + std::to_string(job_desc.job_id()) + "_op_graph.dot");
}
compile_tc->Count("[GraphCompile]" + job_name + " NewOpGraph", 1);

// Step2: build task_gph.
4 changes: 2 additions & 2 deletions oneflow/core/job/job_build_and_infer_ctx.cpp
@@ -919,8 +919,8 @@ Maybe<void> LazyJobBuildAndInferCtx::Complete() {
<< " Sorry, nn.Graph need at least 1 op in net, but get 0 now.";
auto compile_tc = std::make_unique<CostCounter<std::chrono::seconds>>(true, true);
CHECK_NOTNULL(Singleton<JobDesc>::Get());
Singleton<JobDesc>::Delete();
auto scope = std::make_unique<GlobalJobDescScope>(mut_job()->job_conf(), job_id());
// A global variable to get graph configurations.
auto current_graph_config = std::make_unique<GlobalJobDescScope>(mut_job()->job_conf(), job_id());
JobPassCtx job_pass_ctx(GlobalJobDesc());
const auto job_name = job().job_conf().job_name();
auto LogJob = [&](const std::string& name_suffix) -> void {
1 change: 1 addition & 0 deletions oneflow/core/job/job_desc.cpp
@@ -72,6 +72,7 @@ bool IsInterfaceOpConf(const OperatorConf& op_conf) {
}

GlobalJobDescScope::GlobalJobDescScope(const JobConfigProto& job_conf, int64_t job_id) {
if (Singleton<JobDesc>::Get() != nullptr) { Singleton<JobDesc>::Delete(); }
Review comment (Contributor): Ah, I see: it moved here.

Singleton<JobDesc>::New(job_conf, job_id);
}

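Moving the delete into the constructor means construction sites no longer have to clear the singleton themselves, which is why the manual guards in nn_graph.cpp and job_build_and_infer_ctx.cpp could be dropped. A sketch of the resulting contract (job_a, job_b, and the job ids are placeholders):

    {
      GlobalJobDescScope scope(job_a.job_conf(), /*job_id=*/0);
      // ... compile job_a under its configuration ...
    }
    // No manual Singleton<JobDesc>::Delete() is needed before the next scope;
    // the constructor replaces any JobDesc left over from a previous job.
    GlobalJobDescScope scope(job_b.job_conf(), /*job_id=*/1);
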
2 changes: 1 addition & 1 deletion oneflow/core/job/oneflow.cpp
@@ -186,7 +186,7 @@ Maybe<void> CompileCurJobOnMaster(Job* job, Plan* plan, bool need_job_complete)
const JobDesc& job_desc = GlobalJobDesc();
if (GlobalProcessCtx::IsThisProcessMaster()) {
double start = GetCurTime();
if (need_job_complete) { JUST(JobCompleter().Complete(job)); }
if (need_job_complete) { JUST(JobCompleter::Complete(job)); }
Compiler().Compile(job, plan);
PlanUtil::GenMemBlockAndChunk4Plan(plan);
