Skip to content

Commit

Permalink
Job pass maybe system (#5503)
Browse files Browse the repository at this point in the history
* refactor job_pass by maybe_system

* remove useless files

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
lixinqi and oneflow-ci-bot committed Jul 16, 2021
1 parent 810d8db commit 50e1c34
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 127 deletions.
3 changes: 1 addition & 2 deletions oneflow/core/graph/task_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,7 @@ void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_node
CHECK_EQ(src_task_nodes.size(), dst_task_nodes.size());
FOR_RANGE(int32_t, i, 0, src_task_nodes.size()) {
std::string regst_desc_name;
RegstDesc* ctrl_regst_desc =
src_task_nodes.at(i)->BuildCtrlRegstDesc(dst_task_nodes.at(i), &regst_desc_name);
src_task_nodes.at(i)->BuildCtrlRegstDesc(dst_task_nodes.at(i), &regst_desc_name);
TaskEdge* edge = NewEdge();
Connect<TaskNode>(src_task_nodes.at(i), edge, dst_task_nodes.at(i));
src_task_nodes.at(i)->BindEdgeWithProducedRegst(edge, regst_desc_name);
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/job/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void CreateOpAttributeRef(Plan* plan, int64_t job_id, TaskProto* task_proto) {

void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
// Step1: ensure job is completed.
if (need_job_complete) { JobCompleter().Complete(job); }
if (need_job_complete) { CHECK_JUST(JobCompleter().Complete(job)); }

// Step2: new Global<OpGraph> and set log configs.
Global<OpGraph>::New(*job);
Expand Down
179 changes: 99 additions & 80 deletions oneflow/core/job_rewriter/autotick.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,21 +171,22 @@ Maybe<void> ConnectSrcSubsetTickAndOtherTick(const OperatorConf& src_subset_tick
auto mut_helper = NewMutOpConTickInputHelper(op.op_conf());
if (!mut_helper) { return Maybe<void>::Ok(); }
if (mut_helper->IsTickInputBound() == true) { return Maybe<void>::Ok(); }
job_builder->MutOpsOnlyOnce({mut_helper->NewTickInputBoundOpConf(src_lbn)});
JUST(job_builder->MutOpOnlyOnce(mut_helper->NewTickInputBoundOpConf(src_lbn)));
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}

const OpNode* GetSrcSubsetTickOpNode(const OpGraph& op_graph) {
Maybe<const OpNode*> GetSrcSubsetTickOpNode(const OpGraph& op_graph) {
const OpNode* src_subset_tick = nullptr;
op_graph.ForEachNode([&](OpNode* op_node) {
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
if (op_node->op().op_conf().has_src_subset_tick_conf()) {
CHECK_ISNULL(src_subset_tick);
CHECK_ISNULL_OR_RETURN(src_subset_tick);
src_subset_tick = op_node;
}
});
CHECK_NOTNULL(src_subset_tick);
return Maybe<void>::Ok();
}));
CHECK_NOTNULL_OR_RETURN(src_subset_tick);
return src_subset_tick;
}

Expand Down Expand Up @@ -277,77 +278,81 @@ std::vector<std::string> GetOpNames(const HashSet<const OpNode*>& op_nodes) {
return ret;
};

void InitOpTypeCase2OpNodes(
Maybe<void> InitOpTypeCase2OpNodes(
const OpGraph& op_graph,
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>>* op_type_case2op_nodes) {
op_graph.ForEachNode([&](OpNode* op_node) {
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
const auto& op_conf = op_node->op().op_conf();
if (IsInterfaceOpConf(op_conf)) {
CHECK((*op_type_case2op_nodes)[op_conf.op_type_case()].emplace(op_node).second);
CHECK_OR_RETURN((*op_type_case2op_nodes)[op_conf.op_type_case()].emplace(op_node).second);
}
});
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}

void ForEachInputCriticalSectionOpNodes(
Maybe<void> ForEachInputCriticalSectionOpNodes(
const OpGraph& op_graph,
const std::function<void(const HashSet<const OpNode*>&, const std::vector<std::string>&)>&
Handler) {
const std::function<Maybe<void>(const HashSet<const OpNode*>&,
const std::vector<std::string>&)>& Handler) {
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes;
InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes);
JUST(InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes));
OperatorConf::OpTypeCase op_type_case = OperatorConf::kInputConf;
if (op_type_case2op_nodes[op_type_case].empty()) { return; }
if (op_type_case2op_nodes[op_type_case].empty()) { return Maybe<void>::Ok(); }
HashSet<const OpNode*> op_nodes = op_type_case2op_nodes[op_type_case];
for (const OpNode* op_node : op_type_case2op_nodes[op_type_case]) {
op_node->ForEachNodeOnOutEdge([&](OpNode* out_node) { op_nodes.insert(out_node); });
}
Handler(op_nodes, GetOpNames(op_type_case2op_nodes[op_type_case]));
JUST(Handler(op_nodes, GetOpNames(op_type_case2op_nodes[op_type_case])));
return Maybe<void>::Ok();
}

void ForEachOutputCriticalSectionOpNodes(
Maybe<void> ForEachOutputCriticalSectionOpNodes(
const OpGraph& op_graph,
const std::function<void(const HashSet<const OpNode*>&, const std::vector<std::string>&)>&
Handler) {
const std::function<Maybe<void>(const HashSet<const OpNode*>&,
const std::vector<std::string>&)>& Handler) {
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes;
InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes);
JUST(InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes));
if (op_type_case2op_nodes[OperatorConf::kReturnConf].empty() == false) {
Handler(op_type_case2op_nodes[OperatorConf::kReturnConf],
GetOpNames(op_type_case2op_nodes[OperatorConf::kReturnConf]));
JUST(Handler(op_type_case2op_nodes[OperatorConf::kReturnConf],
GetOpNames(op_type_case2op_nodes[OperatorConf::kReturnConf])));
}
if (op_type_case2op_nodes[OperatorConf::kOutputConf].empty() == false) {
Handler(op_type_case2op_nodes[OperatorConf::kOutputConf],
GetOpNames(op_type_case2op_nodes[OperatorConf::kOutputConf]));
JUST(Handler(op_type_case2op_nodes[OperatorConf::kOutputConf],
GetOpNames(op_type_case2op_nodes[OperatorConf::kOutputConf])));
}
return Maybe<void>::Ok();
}

std::vector<OperatorConf> AddTickForTimeShape(const Shape& src_time_shape,
const HashSet<const OpNode*>& op_nodes,
JobBuilder* job_builder) {
Maybe<std::vector<OperatorConf>> AddTickForTimeShape(const Shape& src_time_shape,
const HashSet<const OpNode*>& op_nodes,
JobBuilder* job_builder) {
HashMap<std::pair<ParallelDesc, std::pair<Shape, Shape>>, std::list<const OpNode*>>
pd7ts2op_nodes;
for (const OpNode* op_node : op_nodes) {
auto ts = std::make_pair(*CHECK_JUST(op_node->op().GetInputOutputFastestTimeShape()),
*CHECK_JUST(op_node->op().GetOpTimeShape()));
auto ts = std::make_pair(*JUST(op_node->op().GetInputOutputFastestTimeShape()),
*JUST(op_node->op().GetOpTimeShape()));
pd7ts2op_nodes[{op_node->parallel_desc(), ts}].push_back(op_node);
}
std::vector<OperatorConf> op_confs;
for (const auto& pair : pd7ts2op_nodes) {
const std::pair<Shape, Shape>& ts = pair.first.second;
if (ts.second.elem_cnt() == src_time_shape.elem_cnt()) {
CHECK_GE(ts.first.elem_cnt(), ts.second.elem_cnt());
CHECK_GE_OR_RETURN(ts.first.elem_cnt(), ts.second.elem_cnt());
op_confs.push_back(
AppendTick("Append", pair.second, std::make_shared<const Shape>(ts.second), job_builder));
} else if (ts.second.elem_cnt() > src_time_shape.elem_cnt()) {
op_confs.push_back(AppendAccTick(src_time_shape, pair.second, job_builder));
} else {
UNIMPLEMENTED();
UNIMPLEMENTED_THEN_RETURN();
}
}
return op_confs;
}

void AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names,
JobBuilder* job_builder) {
Maybe<void> AddGlobalInputOutputCriticalSection(
const HashSet<const OpNode*>& op_nodes, const std::vector<std::string>& lbi_producer_op_names,
JobBuilder* job_builder) {
auto* critical_section =
Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id());
{
Expand All @@ -358,22 +363,20 @@ void AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes,
auto time_shape = std::make_unique<Shape>(DimVector{1, 1});
HashMap<ParallelDesc, HashSet<const OpNode*>> parallel_desc2op_nodes;
for (const OpNode* op_node : op_nodes) {
CHECK(parallel_desc2op_nodes[op_node->parallel_desc()].insert(op_node).second);
CHECK_OR_RETURN(parallel_desc2op_nodes[op_node->parallel_desc()].insert(op_node).second);
}
std::vector<OperatorConf> source_ticks;
std::vector<OperatorConf> sink_ticks;
for (const auto& pair : parallel_desc2op_nodes) {
source_ticks.push_back(PrependTick(pair.second, job_builder));
for (const auto& sink_tick : AddTickForTimeShape(*time_shape, pair.second, job_builder)) {
sink_ticks.push_back(sink_tick);
}
const auto& ops = JUST(AddTickForTimeShape(*time_shape, pair.second, job_builder));
for (const auto& sink_tick : *ops) { sink_ticks.push_back(sink_tick); }
}
OperatorConf src_subset_tick_op;
{
CHECK_EQ(source_ticks.empty(), false);
CHECK_JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder));
CHECK_JUST(
CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick_op, job_builder));
CHECK_EQ_OR_RETURN(source_ticks.empty(), false);
JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder));
JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick_op, job_builder));
for (auto& op_conf : source_ticks) {
op_conf.mutable_tick_conf()->add_tick(src_subset_tick_op.name() + "/"
+ src_subset_tick_op.src_subset_tick_conf().out());
Expand All @@ -384,70 +387,86 @@ void AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes,
for (const auto& op_conf : sink_ticks) {
LogicalBlobId lbi;
lbi.set_op_name(op_conf.name());
CHECK(op_conf.has_device_tick_conf());
CHECK_OR_RETURN(op_conf.has_device_tick_conf());
lbi.set_blob_name(op_conf.device_tick_conf().out());
CHECK(tick_lbis.insert(lbi).second);
CHECK_OR_RETURN(tick_lbis.insert(lbi).second);
}
CHECK_JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick_op, tick_lbis,
job_builder));
JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick_op, tick_lbis,
job_builder));
return Maybe<void>::Ok();
}

} // namespace

void AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder) {
Maybe<void> AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder) {
PrependTickByParallelDesc(op_graph, job_builder);
OperatorConf src_subset_tick_op;
CHECK_JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder));
CHECK_JUST(ConnectSrcSubsetTickAndOtherTick(src_subset_tick_op, job_builder));
JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder));
JUST(ConnectSrcSubsetTickAndOtherTick(src_subset_tick_op, job_builder));
return Maybe<void>::Ok();
}

void AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder) {
const auto& src_time_shape = *CHECK_JUST(GetSrcSubsetTickOpNode(op_graph)->op().GetOpTimeShape());
Maybe<void> AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder) {
const auto* op_node = JUST(GetSrcSubsetTickOpNode(op_graph));
const auto& src_time_shape = *JUST(op_node->op().GetOpTimeShape());
HashSet<const OpNode*> sink_op_nodes;
op_graph.ForEachNode([&](OpNode* op_node) {
CHECK(!op_node->op().op_conf().has_sink_tick_conf());
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
CHECK_OR_RETURN(!op_node->op().op_conf().has_sink_tick_conf());
size_t out_cnt = 0;
op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; });
if (out_cnt == 0) { sink_op_nodes.insert(op_node); }
});
AddTickForTimeShape(src_time_shape, sink_op_nodes, job_builder);
return Maybe<void>::Ok();
}));
JUST(AddTickForTimeShape(src_time_shape, sink_op_nodes, job_builder));
return Maybe<void>::Ok();
}

void AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) {
Maybe<void> AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) {
auto* critical_section =
Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id());
critical_section->mutable_total_job_critical_section();
op_graph.ForEachNode([&](OpNode* node) { CHECK(!node->op().op_conf().has_sink_tick_conf()); });
const auto& src_time_shape = *CHECK_JUST(GetSrcSubsetTickOpNode(op_graph)->op().GetOpTimeShape());
JUST(op_graph.MaybeForEachNode([&](OpNode* node) -> Maybe<void> {
CHECK_OR_RETURN(!node->op().op_conf().has_sink_tick_conf());
return Maybe<void>::Ok();
}));
const auto* op_node = JUST(GetSrcSubsetTickOpNode(op_graph));
const auto& src_time_shape = JUST(op_node->op().GetOpTimeShape());
HashSet<LogicalBlobId> tick_lbis;
op_graph.ForEachNode([&](OpNode* op_node) {
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
size_t out_cnt = 0;
op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; });
if (out_cnt > 0) { return; }
CHECK(op_node->op().op_conf().has_device_tick_conf());
CHECK(CHECK_JUST(op_node->op().GetOpTimeShape())->elem_cnt() == src_time_shape.elem_cnt());
CHECK(tick_lbis.emplace(op_node->op().BnInOp2Lbi(op_node->op().SoleObn())).second);
});
OperatorConf src_subset_tick = CHECK_JUST(FindSrcSubsetTickOpConf(job_builder->job()));
CHECK_JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick, job_builder));
CHECK_JUST(
CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick, tick_lbis, job_builder));
if (out_cnt > 0) { return Maybe<void>::Ok(); }
CHECK_OR_RETURN(op_node->op().op_conf().has_device_tick_conf());
CHECK_OR_RETURN(JUST(op_node->op().GetOpTimeShape())->elem_cnt() == src_time_shape->elem_cnt());
CHECK_OR_RETURN(tick_lbis.emplace(op_node->op().BnInOp2Lbi(op_node->op().SoleObn())).second);
return Maybe<void>::Ok();
}));
OperatorConf src_subset_tick = JUST(FindSrcSubsetTickOpConf(job_builder->job()));
JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick, job_builder));
JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick, tick_lbis, job_builder));
return Maybe<void>::Ok();
}

void AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) {
ForEachInputCriticalSectionOpNodes(
op_graph, [&](const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names) {
AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder);
});
Maybe<void> AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) {
JUST(ForEachInputCriticalSectionOpNodes(
op_graph,
[&](const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names) -> Maybe<void> {
JUST(AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder));
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}

void AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) {
ForEachOutputCriticalSectionOpNodes(
op_graph, [&](const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names) {
AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder);
});
Maybe<void> AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) {
JUST(ForEachOutputCriticalSectionOpNodes(
op_graph,
[&](const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names) -> Maybe<void> {
JUST(AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder));
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}

} // namespace oneflow
10 changes: 5 additions & 5 deletions oneflow/core/job_rewriter/autotick.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ limitations under the License.

namespace oneflow {

void AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder);
void AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder);
void AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder);
void AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder);
void AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder);

class MutOpConTickInputHelper {
public:
Expand Down
7 changes: 4 additions & 3 deletions oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ limitations under the License.

namespace oneflow {

void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder) {
Maybe<void> GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder) {
HashMap<LogicalBlobId, HashMap<std::pair<ParallelDesc, cfg::ParallelDistribution>,
std::vector<std::pair<const OpNode*, std::string>>>>
lbi2consumer_grouped_by_parallel;
Expand Down Expand Up @@ -76,13 +76,14 @@ void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder)
OperatorConf& consumer_op_conf = op_node2op_conf[consumer];
const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn,
GenLogicalBlobName(grouped_lbi));
CHECK_EQ(GenLogicalBlobName(lbi), old_val);
CHECK_EQ_OR_RETURN(GenLogicalBlobName(lbi), old_val);
}
}
}
for (const auto& op_node7op_conf : op_node2op_conf) {
job_builder->MutOpsOnlyOnce({op_node7op_conf.second});
JUST(job_builder->MutOpOnlyOnce(op_node7op_conf.second));
}
return Maybe<void>::Ok();
}

} // namespace oneflow
2 changes: 1 addition & 1 deletion oneflow/core/job_rewriter/group_boxing_by_dst_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace oneflow {
class OpGraph;
class Job;

void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder);

} // namespace oneflow

Expand Down

0 comments on commit 50e1c34

Please sign in to comment.