Plan rank compiler #10141
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.
@@ -622,4 +649,17 @@ void OpGraph::PrintSBPGraphDebugInfo() const {
}
}

OpGraphSingletonGuard::OpGraphSingletonGuard(const Job& job) {

Provides an RAII-style OpGraph.
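As a rough illustration of what an RAII-style singleton guard does, here is a minimal sketch. It is not the PR's implementation: `Job`, `OpGraph`, `GlobalOpGraphSlot`, and `ScopedOpGraph` are hypothetical stand-ins, and OneFlow's real singleton storage is assumed to differ.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-ins for the real OneFlow types.
struct Job {
  std::string name;
};
struct OpGraph {
  explicit OpGraph(const Job& job) : job_name(job.name) {}
  std::string job_name;
};

// Hypothetical global slot standing in for the framework's singleton storage.
inline std::unique_ptr<OpGraph>& GlobalOpGraphSlot() {
  static std::unique_ptr<OpGraph> slot;
  return slot;
}

// Installs the OpGraph on construction and tears it down on destruction,
// so the singleton cannot outlive the scope that needs it.
class ScopedOpGraph {
 public:
  explicit ScopedOpGraph(const Job& job) {
    assert(GlobalOpGraphSlot() == nullptr);  // only one graph at a time
    GlobalOpGraphSlot() = std::make_unique<OpGraph>(job);
  }
  ~ScopedOpGraph() { GlobalOpGraphSlot().reset(); }
  ScopedOpGraph(const ScopedOpGraph&) = delete;
  ScopedOpGraph& operator=(const ScopedOpGraph&) = delete;
};
```

The benefit over explicit create/destroy calls is that early returns and error paths release the graph automatically.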
comp_task_node->set_thrd_id(EncodeStreamIdToInt64(StreamId{device_id, stream_index}));
comp_task_node->set_op_node(op_node);
sorted_comp_tasks->emplace_back(comp_task_node);
sorted_comp_tasks->emplace_back(GenCompTaskNode(op_node, parallel_idx++, &GetStreamId));

This just factors the logic above out into a GenCompTaskNode function.
@@ -491,47 +533,6 @@ Maybe<void> RegisterCreateSubTskGphBuilderFn(DeviceType device_type,
return Maybe<void>::Ok();
}

TaskGraph::TaskGraph() {

The task graph constructor has been moved into the subclasses.
InplaceObasInfo safe_inplace_obas_info;
GetSafeInplaceOpBlobArgList(&safe_inplace_obas_info, dev_nodes, IsOpNameDataOrCtrlReachable);
SetTaskRegstInplaceInfo(safe_inplace_obas_info, dev_nodes);
EnableInplaceMemSharing(dev_nodes, IsOpNameDataOrCtrlReachable);

This only splits the logic out into a separate function.
|| (straighten_algorithm_tag == StraightenAlgorithmTag::kOverlap4Transfer
&& GlobalProcessCtx::WorldSize() == 1)) {
InitOrderedTaskNodes();
Maybe<void> GlobalTaskGraph::Init() {

This is the original task graph, i.e. the complete task graph.
@@ -237,7 +237,7 @@ void GenChunkForMultiNNGraphMemoryReuseInMultiClient(

} // namespace

void PlanUtil::MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job) {
void PlanUtil::MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job, int64_t limited_rank) {

With separate compilation, the valid ranks need to be filtered here.
@@ -801,10 +816,13 @@ std::function<RegstDescProto*(int64_t)> PlanUtil::MakeMutRegstDesc4Id(Plan* plan
};
}

void PlanUtil::SetForceInplaceMemBlock(Plan* plan) {
void PlanUtil::SetForceInplaceMemBlock(Plan* plan, int64_t limited_rank) {

With separate compilation, the valid ranks need to be filtered here.
#ifdef WITH_CUDA
// Use the right device when some plan compilation needs cuda to avoid creating unnecessary cuda
// context on cuda:0.
CudaCurrentDeviceGuard guard(GetCudaDeviceIndex());

This fixes the extra CUDA context that was being created during separate compilation.

} // namespace

Maybe<void> RankCompiler::Compile(const HashSet<std::string>& var_op_names, Job* job,

This is the per-rank compiler, as opposed to the original global compiler.
// context on cuda:0.
CudaCurrentDeviceGuard guard(GetCudaDeviceIndex());
#endif // WITH_CUDA
auto task_gph = JUST(RankTaskGraph::New(boxing_task_graph_proto_, var_op_names, rank_));

Compilation continues from the boxing task graph.
…/oneflow into sep4_rank_task_graph
IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan);
PlanUtil::MergeMemBlockIdByLogicalChainId(plan, *job, rank_);
PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan);
PlanUtil::SetForceInplaceMemBlock(plan, rank_);

The rest of the flow here is similar to the previous plan compiler; only a few specific places need to account for single-rank handling.
@@ -27,26 +27,6 @@ limitations under the License.

namespace oneflow {

void CreateOpAttributeRef(Plan* plan, int64_t job_id, TaskProto* task_proto) {

Moved into PlanUtil as a shared function.
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/10141/
@@ -399,6 +399,7 @@ message OperatorConf {
optional string loc = 11 [default = ""];
optional int64 logical_chain_id = 12 [default = -1];
optional int64 order_in_logical_chain = 13 [default = -1];
optional string calculation_pass_name = 14 [default = "forward_pass"];

Hmm... it looks like
optional string pass_tag = 10;
is already this calculation pass name. Could you search for that keyword?

I checked. pass_tag takes the following values:
static const std::string kNoPassTag = "";
static const std::string kMainOp = "main_op";
whereas calculation_pass_name takes the following values:
const std::string kForwardPass = "forward_pass";
const std::string kBackwardPass = "backward_pass";
const std::string kOptimizerPass = "optimizer_pass";
So they do not appear to be the same thing.

😂 Sorry, I misremembered.
Maybe<void> Graph<NodeType, EdgeType>::MaybeForEachEdge(
std::function<Maybe<void>(EdgeType*)> EdgeHandler) const {
for (auto& x : edges_) {
if (x->src_node() == nullptr && x->dst_node() == nullptr) { continue; }

In what situation would these be nullptr?

This was adapted from the earlier version without Maybe; the check is presumably there for rigor, because by default an edge's src_node and dst_node are both nullptr when it is initialized:
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ForEachEdge(std::function<void(EdgeType*)> EdgeHandler) const {
for (auto& x : edges_) {
if (x->src_node() == nullptr && x->dst_node() == nullptr) { continue; }
EdgeHandler(x.get());
}
}
oneflow/core/graph/op_graph.cpp
Outdated
for (const auto& lbi : *lbis_) {
const auto& obn = CHECK_JUST(MapAt(*lbi2obn_, lbi));
for (const auto& ibn : CHECK_JUST(MapAt(*lbi2ibns_, lbi))) {
if (src_node()->NdSbp4BnInOp(obn) != dst_node()->NdSbp4BnInOp(ibn)) { return true; }

This criterion is not rigorous; it breaks in some 2D cases. See:
// NOTE(chengcheng): nd_sbp need to be reduction like from [P, P] to [P]
A reduction check is needed here.

done in bded88f
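To make the reviewer's point concrete, the following is a simplified sketch of comparing two nd_sbp signatures after reduction instead of comparing them directly. The `Sbp` enum, `NdSbp` alias, and both functions are hypothetical, and the sketch deliberately ignores the parallel hierarchy, which a real reduction would also have to take into account. It is not the code from commit bded88f.

```cpp
#include <cassert>
#include <vector>

// Hypothetical, simplified SBP tag: B = broadcast, P = partial sum,
// S0 / S1 = split along axis 0 / 1.
enum class Sbp { B, P, S0, S1 };
using NdSbp = std::vector<Sbp>;

// Collapse runs of identical adjacent entries, so [P, P] becomes [P].
NdSbp ReduceNdSbp(const NdSbp& nd_sbp) {
  NdSbp reduced;
  for (Sbp sbp : nd_sbp) {
    if (reduced.empty() || reduced.back() != sbp) { reduced.push_back(sbp); }
  }
  return reduced;
}

// Compare the reduced forms, so [P, P] and [P] are treated as the same
// signature even though a direct != comparison would report a mismatch.
bool IsSameAfterReduction(const NdSbp& lhs, const NdSbp& rhs) {
  return ReduceNdSbp(lhs) == ReduceNdSbp(rhs);
}
```

With a direct comparison, a [P, P] producer feeding a [P] consumer would be flagged as different; after reduction the two compare equal.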
oneflow/core/graph/op_graph.cpp
Outdated
}

void OpGraph::UpdateCachedPredicatorIsReachable() {
cached_predicator_is_reachable_ = MakePredicatorIsReachable();

Why is an update needed here?

There used to be a single-process, multi-threaded compilation mode: the main thread built a cache of the reachability relation, and the threads compiling sub-plans could reuse it to cut overhead.
We have since removed that single-process multi-threaded debug mode. The default is multi-process mode, in which every process has to build its own reachability lambda anyway, with no reuse, so this optimization no longer pays off. I will delete this part.
} // namespace

/*static*/ bool BoxingTaskGraph::SelectTaskNodeByRank(TaskNode* task_node, int64_t rank) {
return TaskNodeVisitor<bool>(

This function is written in a rather obscure way.

Several places use the same handling logic, so I provided this template function.
- task_node: the input task node
- HandleTansportTaskNodeT: if the task node is a transport task node, this handler is called (visit)
- HandleComputeTaskNodeT: if the task node is a compute task node, this handler is called
- RetT: the function's return type
I will add a comment at the function declaration.
template<typename RetT, typename HandleTansportTaskNodeT, typename HandleComputeTaskNodeT>
RetT TaskNodeVisitor(TaskNode* task_node, const HandleTansportTaskNodeT& HandleTansportTaskNode,
const HandleComputeTaskNodeT& HandleComputeTaskNode)
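The body of TaskNodeVisitor is not shown in this thread. One plausible shape for such a two-way visitor, sketched under the assumption that dispatch happens on the dynamic type of the node, is below. `TaskNode`, `TransportTaskNode`, `CompTaskNode`, and `VisitTaskNode` here are simplified, hypothetical stand-ins rather than the OneFlow classes.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical, minimal node hierarchy for illustration only.
struct TaskNode {
  virtual ~TaskNode() = default;
};
struct TransportTaskNode : TaskNode {};
struct CompTaskNode : TaskNode {};

// Calls exactly one handler depending on the dynamic type of task_node and
// returns its result as RetT.
template<typename RetT, typename HandleTransportT, typename HandleComputeT>
RetT VisitTaskNode(TaskNode* task_node, const HandleTransportT& HandleTransport,
                   const HandleComputeT& HandleCompute) {
  if (auto* transport = dynamic_cast<TransportTaskNode*>(task_node)) {
    return HandleTransport(transport);
  }
  if (auto* compute = dynamic_cast<CompTaskNode*>(task_node)) {
    return HandleCompute(compute);
  }
  throw std::logic_error("unsupported task node kind");
}
```

Callers pass two lambdas, one per node kind, which keeps the type test in a single place instead of repeating it at every call site.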
for (auto* out_edge : task_node->out_edges()) { TryInsertEdge(out_edge); }
}
return rank_task_edges;
}();

Can these parentheses be written right after the {}?
lambda = [&] { ... } ();

This is equivalent to lambda = [&] { ... }; followed by const auto rank_task_edges = lambda()

Oh, I see now. It is just the shorthand that omits the ():
lambda = [&] () { ... };
lambda = [&] { ... };
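The pattern under discussion is an immediately invoked lambda: the trailing `()` calls the lambda on the spot, so a `const` variable can be initialized from a multi-statement computation. A small self-contained sketch, with `CollectEven` as a hypothetical example function unrelated to the PR:

```cpp
#include <cassert>
#include <vector>

std::vector<int> CollectEven(const std::vector<int>& values) {
  // The lambda has no parameters, so its () parameter list may be omitted.
  // The trailing () invokes it immediately; evens is the returned vector.
  const auto evens = [&] {
    std::vector<int> result;
    for (int v : values) {
      if (v % 2 == 0) { result.push_back(v); }
    }
    return result;
  }();
  return evens;
}
```

Writing `[&]() { ... }()` is equivalent; the empty parameter list is optional when the lambda takes no arguments.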
oneflow/core/graph/task_graph.cpp
Outdated
int64_t parallel_id) {
auto* comp_task_node = JUST(TryGetBoxingRelatedComTaskNode(op_node, parallel_id));
if (comp_task_node != nullptr) { return comp_task_node; }
auto** comp_task_node_ptr = &op_node2comp_task_node_[op_node];

This is also written oddly. Why not find in the map and return, or else create? That should avoid the pointer-to-pointer.

Done, fixed.
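The "find in map and return, or create" shape the reviewer suggests can be sketched as follows. The revised PR code is not shown in this thread, so `NodeRegistry`, its string key, and this simplified `CompTaskNode` are hypothetical.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical, simplified node type.
struct CompTaskNode {
  std::string op_name;
};

class NodeRegistry {
 public:
  // Return the existing node for op_name, or create and store a new one.
  CompTaskNode* FindOrCreate(const std::string& op_name) {
    auto it = nodes_.find(op_name);
    if (it != nodes_.end()) { return it->second.get(); }
    auto node = std::make_unique<CompTaskNode>(CompTaskNode{op_name});
    CompTaskNode* raw = node.get();
    nodes_.emplace(op_name, std::move(node));
    return raw;
  }

 private:
  std::unordered_map<std::string, std::unique_ptr<CompTaskNode>> nodes_;
};
```

The lookup and the creation are separate, explicit steps, so no address of a map slot needs to be held.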
oneflow/core/graph/task_graph.cpp
Outdated
<< "parallel_id not found.";
auto* comp_task_node = JUST(TryGetBoxingRelatedComTaskNode(op_node, parallel_id));
if (comp_task_node != nullptr) { return comp_task_node; }
return op_node2comp_task_node_[op_node];

Could this return nullptr?

The compute task nodes have all been generated by this point, so it should not return nullptr.
I will add an error message for the case where the lookup fails.
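For a standard map of raw pointers, `operator[]` on a missing key silently inserts a null pointer, which is why an explicit lookup with an error is safer. The sketch below uses an exception purely for illustration; `GetCompTaskNodeOrThrow` and the string-keyed map are hypothetical, and the PR itself presumably reports the failure through OneFlow's own Maybe/CHECK mechanisms.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Hypothetical, simplified node type.
struct CompTaskNode {
  int parallel_id;
};

// Look the node up without inserting; fail loudly if it is absent or null.
CompTaskNode* GetCompTaskNodeOrThrow(
    const std::unordered_map<std::string, CompTaskNode*>& op_name2node,
    const std::string& op_name) {
  auto it = op_name2node.find(op_name);
  if (it == op_name2node.end() || it->second == nullptr) {
    throw std::out_of_range("comp task node not found for op: " + op_name);
  }
  return it->second;
}
```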
…/oneflow into sep4_rank_task_graph
TaskEdge* edge = NewEdge();
Connect<TaskNode>(src_task_nodes.at(i), edge, dst_task_nodes.at(i));
src_task_nodes.at(i)->BindEdgeWithProducedRegst(edge, regst_desc_name);
ConnectCtrlEdge(src_task_nodes.at(i), dst_task_nodes.at(i));

Then there is actually no such overhead 😂
CI failed when running job: cuda-module. PR label automerge has been removed
No description provided.