Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Job pass maybe system #5503

Merged
merged 14 commits into from
Jul 16, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 1 addition & 2 deletions oneflow/core/graph/task_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,7 @@ void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_node
CHECK_EQ(src_task_nodes.size(), dst_task_nodes.size());
FOR_RANGE(int32_t, i, 0, src_task_nodes.size()) {
std::string regst_desc_name;
RegstDesc* ctrl_regst_desc =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

解决一直存在的 unused variable 编译警告。

src_task_nodes.at(i)->BuildCtrlRegstDesc(dst_task_nodes.at(i), &regst_desc_name);
src_task_nodes.at(i)->BuildCtrlRegstDesc(dst_task_nodes.at(i), &regst_desc_name);
TaskEdge* edge = NewEdge();
Connect<TaskNode>(src_task_nodes.at(i), edge, dst_task_nodes.at(i));
src_task_nodes.at(i)->BindEdgeWithProducedRegst(edge, regst_desc_name);
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/job/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void CreateOpAttributeRef(Plan* plan, int64_t job_id, TaskProto* task_proto) {

void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
// Step1: ensure job is completed.
if (need_job_complete) { JobCompleter().Complete(job); }
if (need_job_complete) { CHECK_JUST(JobCompleter().Complete(job)); }

// Step2: new Global<OpGraph> and set log configs.
Global<OpGraph>::New(*job);
Expand Down
179 changes: 99 additions & 80 deletions oneflow/core/job_rewriter/autotick.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,21 +171,22 @@ Maybe<void> ConnectSrcSubsetTickAndOtherTick(const OperatorConf& src_subset_tick
auto mut_helper = NewMutOpConTickInputHelper(op.op_conf());
if (!mut_helper) { return Maybe<void>::Ok(); }
if (mut_helper->IsTickInputBound() == true) { return Maybe<void>::Ok(); }
job_builder->MutOpsOnlyOnce({mut_helper->NewTickInputBoundOpConf(src_lbn)});
JUST(job_builder->MutOpOnlyOnce(mut_helper->NewTickInputBoundOpConf(src_lbn)));
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}

const OpNode* GetSrcSubsetTickOpNode(const OpGraph& op_graph) {
Maybe<const OpNode*> GetSrcSubsetTickOpNode(const OpGraph& op_graph) {
const OpNode* src_subset_tick = nullptr;
op_graph.ForEachNode([&](OpNode* op_node) {
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
if (op_node->op().op_conf().has_src_subset_tick_conf()) {
CHECK_ISNULL(src_subset_tick);
CHECK_ISNULL_OR_RETURN(src_subset_tick);
src_subset_tick = op_node;
}
});
CHECK_NOTNULL(src_subset_tick);
return Maybe<void>::Ok();
}));
CHECK_NOTNULL_OR_RETURN(src_subset_tick);
return src_subset_tick;
}

Expand Down Expand Up @@ -277,77 +278,81 @@ std::vector<std::string> GetOpNames(const HashSet<const OpNode*>& op_nodes) {
return ret;
};

void InitOpTypeCase2OpNodes(
Maybe<void> InitOpTypeCase2OpNodes(
const OpGraph& op_graph,
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>>* op_type_case2op_nodes) {
op_graph.ForEachNode([&](OpNode* op_node) {
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
const auto& op_conf = op_node->op().op_conf();
if (IsInterfaceOpConf(op_conf)) {
CHECK((*op_type_case2op_nodes)[op_conf.op_type_case()].emplace(op_node).second);
CHECK_OR_RETURN((*op_type_case2op_nodes)[op_conf.op_type_case()].emplace(op_node).second);
}
});
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}

void ForEachInputCriticalSectionOpNodes(
Maybe<void> ForEachInputCriticalSectionOpNodes(
const OpGraph& op_graph,
const std::function<void(const HashSet<const OpNode*>&, const std::vector<std::string>&)>&
Handler) {
const std::function<Maybe<void>(const HashSet<const OpNode*>&,
const std::vector<std::string>&)>& Handler) {
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes;
InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes);
JUST(InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes));
OperatorConf::OpTypeCase op_type_case = OperatorConf::kInputConf;
if (op_type_case2op_nodes[op_type_case].empty()) { return; }
if (op_type_case2op_nodes[op_type_case].empty()) { return Maybe<void>::Ok(); }
HashSet<const OpNode*> op_nodes = op_type_case2op_nodes[op_type_case];
for (const OpNode* op_node : op_type_case2op_nodes[op_type_case]) {
op_node->ForEachNodeOnOutEdge([&](OpNode* out_node) { op_nodes.insert(out_node); });
}
Handler(op_nodes, GetOpNames(op_type_case2op_nodes[op_type_case]));
JUST(Handler(op_nodes, GetOpNames(op_type_case2op_nodes[op_type_case])));
return Maybe<void>::Ok();
}

void ForEachOutputCriticalSectionOpNodes(
Maybe<void> ForEachOutputCriticalSectionOpNodes(
const OpGraph& op_graph,
const std::function<void(const HashSet<const OpNode*>&, const std::vector<std::string>&)>&
Handler) {
const std::function<Maybe<void>(const HashSet<const OpNode*>&,
const std::vector<std::string>&)>& Handler) {
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes;
InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes);
JUST(InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes));
if (op_type_case2op_nodes[OperatorConf::kReturnConf].empty() == false) {
Handler(op_type_case2op_nodes[OperatorConf::kReturnConf],
GetOpNames(op_type_case2op_nodes[OperatorConf::kReturnConf]));
JUST(Handler(op_type_case2op_nodes[OperatorConf::kReturnConf],
GetOpNames(op_type_case2op_nodes[OperatorConf::kReturnConf])));
}
if (op_type_case2op_nodes[OperatorConf::kOutputConf].empty() == false) {
Handler(op_type_case2op_nodes[OperatorConf::kOutputConf],
GetOpNames(op_type_case2op_nodes[OperatorConf::kOutputConf]));
JUST(Handler(op_type_case2op_nodes[OperatorConf::kOutputConf],
GetOpNames(op_type_case2op_nodes[OperatorConf::kOutputConf])));
}
return Maybe<void>::Ok();
}

std::vector<OperatorConf> AddTickForTimeShape(const Shape& src_time_shape,
const HashSet<const OpNode*>& op_nodes,
JobBuilder* job_builder) {
Maybe<std::vector<OperatorConf>> AddTickForTimeShape(const Shape& src_time_shape,
const HashSet<const OpNode*>& op_nodes,
JobBuilder* job_builder) {
HashMap<std::pair<ParallelDesc, std::pair<Shape, Shape>>, std::list<const OpNode*>>
pd7ts2op_nodes;
for (const OpNode* op_node : op_nodes) {
auto ts = std::make_pair(*CHECK_JUST(op_node->op().GetInputOutputFastestTimeShape()),
*CHECK_JUST(op_node->op().GetOpTimeShape()));
auto ts = std::make_pair(*JUST(op_node->op().GetInputOutputFastestTimeShape()),
*JUST(op_node->op().GetOpTimeShape()));
pd7ts2op_nodes[{op_node->parallel_desc(), ts}].push_back(op_node);
}
std::vector<OperatorConf> op_confs;
for (const auto& pair : pd7ts2op_nodes) {
const std::pair<Shape, Shape>& ts = pair.first.second;
if (ts.second.elem_cnt() == src_time_shape.elem_cnt()) {
CHECK_GE(ts.first.elem_cnt(), ts.second.elem_cnt());
CHECK_GE_OR_RETURN(ts.first.elem_cnt(), ts.second.elem_cnt());
op_confs.push_back(
AppendTick("Append", pair.second, std::make_shared<const Shape>(ts.second), job_builder));
} else if (ts.second.elem_cnt() > src_time_shape.elem_cnt()) {
op_confs.push_back(AppendAccTick(src_time_shape, pair.second, job_builder));
} else {
UNIMPLEMENTED();
UNIMPLEMENTED_THEN_RETURN();
}
}
return op_confs;
}

void AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names,
JobBuilder* job_builder) {
Maybe<void> AddGlobalInputOutputCriticalSection(
const HashSet<const OpNode*>& op_nodes, const std::vector<std::string>& lbi_producer_op_names,
JobBuilder* job_builder) {
auto* critical_section =
Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id());
{
Expand All @@ -358,22 +363,20 @@ void AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes,
auto time_shape = std::make_unique<Shape>(DimVector{1, 1});
HashMap<ParallelDesc, HashSet<const OpNode*>> parallel_desc2op_nodes;
for (const OpNode* op_node : op_nodes) {
CHECK(parallel_desc2op_nodes[op_node->parallel_desc()].insert(op_node).second);
CHECK_OR_RETURN(parallel_desc2op_nodes[op_node->parallel_desc()].insert(op_node).second);
}
std::vector<OperatorConf> source_ticks;
std::vector<OperatorConf> sink_ticks;
for (const auto& pair : parallel_desc2op_nodes) {
source_ticks.push_back(PrependTick(pair.second, job_builder));
for (const auto& sink_tick : AddTickForTimeShape(*time_shape, pair.second, job_builder)) {
sink_ticks.push_back(sink_tick);
}
const auto& ops = JUST(AddTickForTimeShape(*time_shape, pair.second, job_builder));
for (const auto& sink_tick : *ops) { sink_ticks.push_back(sink_tick); }
}
OperatorConf src_subset_tick_op;
{
CHECK_EQ(source_ticks.empty(), false);
CHECK_JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder));
CHECK_JUST(
CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick_op, job_builder));
CHECK_EQ_OR_RETURN(source_ticks.empty(), false);
JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder));
JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick_op, job_builder));
for (auto& op_conf : source_ticks) {
op_conf.mutable_tick_conf()->add_tick(src_subset_tick_op.name() + "/"
+ src_subset_tick_op.src_subset_tick_conf().out());
Expand All @@ -384,70 +387,86 @@ void AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes,
for (const auto& op_conf : sink_ticks) {
LogicalBlobId lbi;
lbi.set_op_name(op_conf.name());
CHECK(op_conf.has_device_tick_conf());
CHECK_OR_RETURN(op_conf.has_device_tick_conf());
lbi.set_blob_name(op_conf.device_tick_conf().out());
CHECK(tick_lbis.insert(lbi).second);
CHECK_OR_RETURN(tick_lbis.insert(lbi).second);
}
CHECK_JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick_op, tick_lbis,
job_builder));
JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick_op, tick_lbis,
job_builder));
return Maybe<void>::Ok();
}

} // namespace

void AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder) {
Maybe<void> AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder) {
PrependTickByParallelDesc(op_graph, job_builder);
OperatorConf src_subset_tick_op;
CHECK_JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder));
CHECK_JUST(ConnectSrcSubsetTickAndOtherTick(src_subset_tick_op, job_builder));
JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder));
JUST(ConnectSrcSubsetTickAndOtherTick(src_subset_tick_op, job_builder));
return Maybe<void>::Ok();
}

void AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder) {
const auto& src_time_shape = *CHECK_JUST(GetSrcSubsetTickOpNode(op_graph)->op().GetOpTimeShape());
Maybe<void> AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder) {
const auto* op_node = JUST(GetSrcSubsetTickOpNode(op_graph));
const auto& src_time_shape = *JUST(op_node->op().GetOpTimeShape());
HashSet<const OpNode*> sink_op_nodes;
op_graph.ForEachNode([&](OpNode* op_node) {
CHECK(!op_node->op().op_conf().has_sink_tick_conf());
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
CHECK_OR_RETURN(!op_node->op().op_conf().has_sink_tick_conf());
size_t out_cnt = 0;
op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; });
if (out_cnt == 0) { sink_op_nodes.insert(op_node); }
});
AddTickForTimeShape(src_time_shape, sink_op_nodes, job_builder);
return Maybe<void>::Ok();
}));
JUST(AddTickForTimeShape(src_time_shape, sink_op_nodes, job_builder));
return Maybe<void>::Ok();
}

void AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) {
Maybe<void> AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) {
auto* critical_section =
Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id());
critical_section->mutable_total_job_critical_section();
op_graph.ForEachNode([&](OpNode* node) { CHECK(!node->op().op_conf().has_sink_tick_conf()); });
const auto& src_time_shape = *CHECK_JUST(GetSrcSubsetTickOpNode(op_graph)->op().GetOpTimeShape());
JUST(op_graph.MaybeForEachNode([&](OpNode* node) -> Maybe<void> {
CHECK_OR_RETURN(!node->op().op_conf().has_sink_tick_conf());
return Maybe<void>::Ok();
}));
const auto* op_node = JUST(GetSrcSubsetTickOpNode(op_graph));
const auto& src_time_shape = JUST(op_node->op().GetOpTimeShape());
HashSet<LogicalBlobId> tick_lbis;
op_graph.ForEachNode([&](OpNode* op_node) {
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
size_t out_cnt = 0;
op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; });
if (out_cnt > 0) { return; }
CHECK(op_node->op().op_conf().has_device_tick_conf());
CHECK(CHECK_JUST(op_node->op().GetOpTimeShape())->elem_cnt() == src_time_shape.elem_cnt());
CHECK(tick_lbis.emplace(op_node->op().BnInOp2Lbi(op_node->op().SoleObn())).second);
});
OperatorConf src_subset_tick = CHECK_JUST(FindSrcSubsetTickOpConf(job_builder->job()));
CHECK_JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick, job_builder));
CHECK_JUST(
CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick, tick_lbis, job_builder));
if (out_cnt > 0) { return Maybe<void>::Ok(); }
CHECK_OR_RETURN(op_node->op().op_conf().has_device_tick_conf());
CHECK_OR_RETURN(JUST(op_node->op().GetOpTimeShape())->elem_cnt() == src_time_shape->elem_cnt());
CHECK_OR_RETURN(tick_lbis.emplace(op_node->op().BnInOp2Lbi(op_node->op().SoleObn())).second);
return Maybe<void>::Ok();
}));
OperatorConf src_subset_tick = JUST(FindSrcSubsetTickOpConf(job_builder->job()));
JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick, job_builder));
JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick, tick_lbis, job_builder));
return Maybe<void>::Ok();
}

void AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) {
ForEachInputCriticalSectionOpNodes(
op_graph, [&](const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names) {
AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder);
});
Maybe<void> AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) {
JUST(ForEachInputCriticalSectionOpNodes(
op_graph,
[&](const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names) -> Maybe<void> {
JUST(AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder));
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}

void AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) {
ForEachOutputCriticalSectionOpNodes(
op_graph, [&](const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names) {
AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder);
});
Maybe<void> AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) {
JUST(ForEachOutputCriticalSectionOpNodes(
op_graph,
[&](const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names) -> Maybe<void> {
JUST(AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder));
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}

} // namespace oneflow
10 changes: 5 additions & 5 deletions oneflow/core/job_rewriter/autotick.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ limitations under the License.

namespace oneflow {

void AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder);
void AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder);
void AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder);
void AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder);
void AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder);

class MutOpConTickInputHelper {
public:
Expand Down
7 changes: 4 additions & 3 deletions oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ limitations under the License.

namespace oneflow {

void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder) {
Maybe<void> GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder) {
HashMap<LogicalBlobId, HashMap<std::pair<ParallelDesc, cfg::ParallelDistribution>,
std::vector<std::pair<const OpNode*, std::string>>>>
lbi2consumer_grouped_by_parallel;
Expand Down Expand Up @@ -76,13 +76,14 @@ void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder)
OperatorConf& consumer_op_conf = op_node2op_conf[consumer];
const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn,
GenLogicalBlobName(grouped_lbi));
CHECK_EQ(GenLogicalBlobName(lbi), old_val);
CHECK_EQ_OR_RETURN(GenLogicalBlobName(lbi), old_val);
}
}
}
for (const auto& op_node7op_conf : op_node2op_conf) {
job_builder->MutOpsOnlyOnce({op_node7op_conf.second});
JUST(job_builder->MutOpOnlyOnce(op_node7op_conf.second));
}
return Maybe<void>::Ok();
}

} // namespace oneflow
2 changes: 1 addition & 1 deletion oneflow/core/job_rewriter/group_boxing_by_dst_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace oneflow {
class OpGraph;
class Job;

void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder);

} // namespace oneflow

Expand Down