Add consume fake regst #10140
LGTM
void EraseFakeRegstsIf() override;

// ConsumeFakeRegsts is used for initializing CompTaskNode.consumed_regsts_ on the other ranks.
virtual void ConsumeFakeRegsts() = 0;
A TransportTaskNode in the BoxingTaskGraph is connected to upstream ComputeTaskNodes from multiple ranks. When BlobDesc inference runs inside one rank's task graph, the input regsts that belong to other ranks do not exist, but fake data is still needed so that the normal inference logic can pass.
In essence, this is a hacky, temporary workaround. In the future, the task graph build should be refactored so that it does not depend on input regsts.
// Provide a compute task node with a fake input regst, and its output regst can be inferred using
// SBP + Placement. The fake compute task node can help the task graph of one rank to infer blob
// desc, mainly to ensure that the transport task node has the correct input blob desc.
class FakeConsumedRegstProvider {
A mixin interface class.
@@ -265,16 +297,33 @@ RegstDesc* TaskNode::BuildCtrlRegstDesc(TaskNode* dst_node, std::string* name) {
}

void TaskNode::BindEdgeWithProducedRegst(TaskEdge* edge, const std::string& name) {
  if (edge->HasRegst(name)) { return; }
Avoids binding the same regst twice.
oneflow/core/graph/task_node.cpp
Outdated
std::shared_ptr<RegstDesc> TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem) {
  return ProduceRegst(name, enable_reuse_mem, 1, kMaxRegisterNum);
}

std::shared_ptr<RegstDesc> TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem,
                                                  int32_t min_register_num,
                                                  int32_t max_register_num) {
  // Because the Regst of separate compilation is not created in order, some Regst may have been
  // built. This implementation can avoid ProduceRegst being called multiple times.
  const auto& regst = GetOrCheckRegst(name, enable_reuse_mem, min_register_num, max_register_num);
Some registers have already been rebuilt in the boxing task graph; this avoids creating them a second time.
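A minimal sketch of the "return the existing regst instead of creating it again" behaviour described in this thread. `Regst` and `Producer` are hypothetical stand-ins, not the real `RegstDesc` or `TaskNode`, and the exact fields the real helper compares are an assumption.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for RegstDesc: only the creation parameters are kept.
struct Regst {
  bool enable_reuse_mem;
  int32_t min_register_num;
  int32_t max_register_num;
};

// Hypothetical stand-in for a task node that produces regsts by name.
class Producer {
 public:
  std::shared_ptr<Regst> ProduceRegst(const std::string& name, bool enable_reuse_mem,
                                      int32_t min_register_num, int32_t max_register_num) {
    auto it = produced_.find(name);
    if (it != produced_.end()) {
      // Already built earlier: verify the request is consistent, then reuse it.
      const auto& existing = it->second;
      assert(existing->enable_reuse_mem == enable_reuse_mem);
      assert(existing->min_register_num == min_register_num);
      assert(existing->max_register_num == max_register_num);
      return existing;
    }
    // First request for this name: create and remember the regst.
    auto regst =
        std::make_shared<Regst>(Regst{enable_reuse_mem, min_register_num, max_register_num});
    produced_.emplace(name, regst);
    return regst;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<Regst>> produced_;
};
```

Calling `ProduceRegst` twice with the same name and parameters yields the same object, so the order in which regsts are built no longer matters.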
for (const auto& pair : data_regst_desc_proto.lbi2blob_desc()) {
  *AddLbi(pair.lbi()) = BlobDesc(pair.blob_desc());
}
CHECK(!data_regst_desc_proto.has_time_shape());
Why is this false here?
> Why is this false here?

Because this field is not created during serialization:
https://github.com/Oneflow-Inc/oneflow/pull/10119/files#r1165306496
@@ -120,6 +120,32 @@ void RegstDesc::EraseUninitializedShapeBlob() {
  });
}

void RegstDesc::InitFromProtoExceptConsumers(const RegstDescProto& proto) {
  regst_desc_id_ = proto.regst_desc_id();
  CHECK_EQ(proto.producer_task_id(), producer_->task_id());
How should proto.consumer_task_id be checked? Should it be empty?
> How should proto.consumer_task_id be checked? Should it be empty?

It will not be empty. However, it cannot be consumed immediately when the task node is created. Consumption has to be deferred until the task nodes have produced their registers; only then can the regst be found and consumed. That is why a separate function, InitConsumedRegstsFromProto, was split out to consume it.
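The deferred consumption described above amounts to a two-phase initialization. The sketch below illustrates the idea under simplified assumptions: `TaskProtoLite`, `RegstProto`, `Regst`, `Task`, and `BuildTasks` are hypothetical names, not OneFlow types, and the real code resolves ids through a `RegstDesc4Id` callback rather than a plain map.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <vector>

// Hypothetical, simplified stand-ins for RegstDescProto and TaskProto.
struct RegstProto {
  int64_t regst_desc_id;
};
struct TaskProtoLite {
  int64_t task_id;
  std::vector<RegstProto> produced;
  std::vector<int64_t> consumed_regst_desc_ids;
};

struct Regst {
  int64_t regst_desc_id;
  int64_t producer_task_id;
};
struct Task {
  int64_t task_id;
  std::vector<std::shared_ptr<Regst>> produced;
  std::vector<std::shared_ptr<Regst>> consumed;
};

using RegstTable = std::unordered_map<int64_t, std::shared_ptr<Regst>>;

// Phase 1: build the task and its produced regsts only; consumed regsts are skipped.
Task InitExceptConsumed(const TaskProtoLite& proto, RegstTable* table) {
  Task task{proto.task_id, {}, {}};
  for (const auto& p : proto.produced) {
    auto regst = std::make_shared<Regst>(Regst{p.regst_desc_id, proto.task_id});
    task.produced.push_back(regst);
    (*table)[p.regst_desc_id] = regst;
  }
  return task;
}

// Phase 2: resolve consumed ids against the table filled during phase 1.
void InitConsumed(const TaskProtoLite& proto, const RegstTable& table, Task* task) {
  for (int64_t id : proto.consumed_regst_desc_ids) {
    task->consumed.push_back(table.at(id));  // throws if no task produced this id
  }
}

std::vector<Task> BuildTasks(const std::vector<TaskProtoLite>& protos) {
  RegstTable table;
  std::vector<Task> tasks;
  for (const auto& proto : protos) { tasks.push_back(InitExceptConsumed(proto, &table)); }
  for (std::size_t i = 0; i < protos.size(); ++i) { InitConsumed(protos[i], table, &tasks[i]); }
  return tasks;
}
```

Because every produced regst is registered before any lookup happens, a consumer can appear earlier in the list than its producer and still be resolved.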
oneflow/core/graph/task_node.h
Outdated
@@ -181,6 +189,7 @@ class TaskEdge final : public Edge<TaskNode, TaskEdge> {
  void AddLbis(const std::vector<LogicalBlobId>& lbis) { lbis_.insert(lbis.begin(), lbis.end()); }

  void CheckRegstLbiValid() const;
  bool OutHasBindRegst() const { return !name_in_producer2regst_.empty(); }
I don't understand what this name means.
> I don't understand what this name means.

Renamed to HasRegst. It is used to check that every edge has a regst bound; during an earlier compilation failure, some edges turned out to have no regst bound.
oneflow/core/graph/task_node.cpp
Outdated
@@ -367,6 +421,7 @@ std::vector<std::shared_ptr<RegstDesc>> TaskEdge::GetRegsts() const {

void TaskEdge::AddRegst(const std::string& name_in_producer,
                        const std::shared_ptr<RegstDesc>& regst) {
  if (HasRegst(name_in_producer)) { return; }
Returning directly here doesn't seem right. Shouldn't it at least check that the regst passed in matches name_in_producer2regst_.at(name_in_producer)?
> Returning directly here doesn't seem right. Shouldn't it at least check that the regst passed in matches name_in_producer2regst_.at(name_in_producer)?

OK.
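One way the suggested check could look, as a sketch only: `Edge` and `RegstDesc` below are simplified stand-ins for `TaskEdge` and the real `RegstDesc`, and the boolean return value is added purely for illustration. The merged code may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for oneflow's RegstDesc; only object identity matters here.
struct RegstDesc {
  int id = 0;
};

// Simplified stand-in for TaskEdge, not the real class.
class Edge {
 public:
  bool HasRegst(const std::string& name) const { return name2regst_.count(name) > 0; }

  // Returns true if a new binding was inserted, false if the same binding already existed.
  bool AddRegst(const std::string& name, const std::shared_ptr<RegstDesc>& regst) {
    auto it = name2regst_.find(name);
    if (it != name2regst_.end()) {
      // Re-binding is tolerated only when it is the very same regst object.
      assert(it->second == regst);
      return false;
    }
    name2regst_.emplace(name, regst);
    return true;
  }

  std::size_t size() const { return name2regst_.size(); }

 private:
  std::unordered_map<std::string, std::shared_ptr<RegstDesc>> name2regst_;
};
```

With this shape, a repeated bind of the same regst is a no-op, while binding a different regst under an existing name trips the assertion instead of being silently ignored.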
oneflow/core/graph/task_node.cpp
Outdated
 std::shared_ptr<RegstDesc> TaskEdge::GetSoleRegst() const {
-  CHECK_EQ(name_in_producer2regst_.size(), 1);
+  CHECK_EQ(name_in_producer2regst_.size(), 1)
+      << "edge: " << this << ", src: " << src_node()->task_id();
dst?
> dst?

Done.
@@ -100,6 +100,10 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
  std::string VisualStr() const override;
  virtual bool IsMeaningLess();
  void ToProto(TaskProto* task_proto) const { ToProto(task_proto, /*check*/ true); }
  virtual void InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto);
It would be good to add a comment here noting that this is used to create a task node from the plan midway through separate compilation.
> It would be good to add a comment here noting that this is used to create a task node from the plan midway through separate compilation.

Done.
chain_id_ = task_proto.chain_id();
order_in_chain_ = task_proto.order_in_chain();
// Step2: check exec_gph empty.
CHECK(task_proto.exec_sequence().exec_node().empty());
Does consumed_regst_desc_id need to be checked?
> Does consumed_regst_desc_id need to be checked?

Do you mean checking that it is non-empty? consumed_regst_desc_id is consumed later on.
oneflow/core/graph/task_node.cpp
Outdated
  edge->AddRegst(name, GetProducedRegst(name));
}

std::shared_ptr<RegstDesc> TaskNode::GetOrCheckRegst(const std::string& name, bool enable_reuse_mem,
Judging from the implementation, this does both the get and the check, right?
> Judging from the implementation, this does both the get and the check, right?

Renamed to GetAndCheckRegst.
for (const auto& pair : consumed_regsts()) {
  for (const auto& regst_desc : pair.second) {
    if (regst_desc->regst_desc_type().has_data_regst_desc()) {
      CHECK(data_regst_desc == nullptr);
I don't follow this. Does it mean there is at most one data regst?
> I don't follow this. Does it mean there is at most one data regst?

Yes, only one placeholder fake regst is created.
void NormalForwardCompTaskNode::ConsumeAllRegsts() {
  ForEachInDataEdge([&](TaskEdge* edge) {
    for (const auto& regst : edge->GetRegsts()) { ConsumeRegst("in", regst); }
  });
}
void NormalForwardCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); }
Then please add a comment.
> Then please add a comment.

Done.
for (const auto& pair : consumed_regsts()) {
  for (const auto& regst_desc : pair.second) {
    if (regst_desc->regst_desc_type().has_data_regst_desc()) {
      CHECK(data_regst_desc == nullptr);
}
if (data_regst_desc != nullptr) {
  for (const auto& ibn : op_node()->op().input_bns()) {
    data_regst_desc->AddLbi(op_node()->op().BnInOp2Lbi(ibn));
Multiple lbis point to this single placeholder fake regst.
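To illustrate the "one placeholder, many lbis" arrangement from this thread, here is a small sketch. `FakeRegst` and `MakeFakeInputRegst` are hypothetical names, and lbis are modelled as plain strings, whereas OneFlow uses a `LogicalBlobId` message.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for a data RegstDesc: it only records which lbis it carries.
struct FakeRegst {
  std::set<std::string> lbis;
  void AddLbi(const std::string& lbi) { lbis.insert(lbi); }
};

// Builds one placeholder regst and registers the lbi of every input blob name on it.
std::shared_ptr<FakeRegst> MakeFakeInputRegst(
    const std::vector<std::string>& input_bns,
    const std::function<std::string(const std::string&)>& BnInOp2Lbi) {
  auto regst = std::make_shared<FakeRegst>();
  for (const auto& ibn : input_bns) { regst->AddLbi(BnInOp2Lbi(ibn)); }
  return regst;
}
```

The single regst stands in for every input of the op, which is why the loop in the hunk above asserts that at most one data regst is found among the consumed regsts.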
const TaskProto& task_proto,
const std::function<Maybe<RegstDesc>(int64_t regst_desc_id)>& RegstDesc4Id) {
  // init consumed_regst.
  for (const auto& pair : task_proto.consumed_regst_desc_id()) {
consumed_regst_desc_id is consumed here.
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/10140/
CI failed when running job: cuda-module. PR label automerge has been removed
No description provided.