Add consume fake regst #10140

Merged
merged 25 commits into from
May 6, 2023

Conversation

@strint strint commented Apr 14, 2023

No description provided.

@strint strint requested a review from chengtbf as a code owner April 14, 2023 15:20
@Yipeng1994 Yipeng1994 left a comment

LGTM

Base automatically changed from sep2_custom_blobdesc_infer to master April 20, 2023 07:32
@strint strint added graph graph mode feature labels Apr 20, 2023
@strint strint requested a review from oneflow-ci-bot April 20, 2023 07:46

void EraseFakeRegstsIf() override;

// ConsumeFakeRegsts is used for initializing CompTaskNode.consumed_regsts_ on the other ranks.
virtual void ConsumeFakeRegsts() = 0;
Contributor Author

The TransportTaskNode in BoxingTaskGraph carries upstream ComputeTaskNodes from multiple ranks. When the task graph infers BlobDesc within a single rank, the input regsts that belong to other ranks do not exist, but a fake one is still needed so that the normal inference logic can run through.

Contributor

This is essentially a temporary workaround. In the future, the task graph build should be refactored so that it does not depend on input regsts.
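
The idea in this thread can be sketched roughly as follows: a compute node on the local rank consumes the regsts that really exist upstream, while a node that belongs to another rank consumes a single placeholder. This is only an illustration under simplifying assumptions. The types `Regst` and `CompNode` and the helper `ConsumeInputs` are hypothetical stand-ins, not OneFlow's actual classes.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a register descriptor.
struct Regst {
  std::string name;
  bool is_fake = false;
};

// Hypothetical compute node that records what it consumes as input.
struct CompNode {
  int rank = 0;
  std::vector<std::shared_ptr<Regst>> consumed_in;

  // A node on the local rank consumes the regsts actually produced upstream.
  void ConsumeAllRegsts(const std::vector<std::shared_ptr<Regst>>& upstream) {
    for (const auto& regst : upstream) { consumed_in.push_back(regst); }
  }

  // A node on another rank has no real input regst here, so it consumes one placeholder.
  void ConsumeFakeRegsts() {
    auto fake = std::make_shared<Regst>();
    fake->name = "in";
    fake->is_fake = true;
    consumed_in.push_back(fake);
  }
};

// Chooses the real or the fake path depending on whether the node belongs to this rank.
inline void ConsumeInputs(CompNode& node, int this_rank,
                          const std::vector<std::shared_ptr<Regst>>& upstream) {
  if (node.rank == this_rank) {
    node.ConsumeAllRegsts(upstream);
  } else {
    node.ConsumeFakeRegsts();
  }
}
```

In this sketch the placeholder carries no shape information; per the comment in the diff, the output regst of such a node is expected to be inferred from SBP and placement instead.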

// Provide a compute task node with a fake input regst, and its output regst can be inferred using
// SBP + Placement. The fake compute task node can help the task graph of one rank to infer blob
// desc, mainly to ensure that the transport task node has the correct input blob desc.
class FakeConsumedRegstProvider {
Contributor Author

A mixin interface class.
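
For readers unfamiliar with the pattern, a mixin interface in C++ is usually a small abstract class that a concrete node inherits alongside its main base class. The sketch below shows that shape only. `TaskNodeBase` and `DemoCompNode` are hypothetical, and the behavior given to `EraseFakeRegstsIf` (clear the placeholder if one was consumed) is an assumption made for illustration, not the PR's implementation.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Mixin interface: adds the fake-regst capability to whichever node type inherits it.
class FakeConsumedRegstProvider {
 public:
  virtual ~FakeConsumedRegstProvider() = default;
  virtual void ConsumeFakeRegsts() = 0;
  virtual void EraseFakeRegstsIf() = 0;
};

// Hypothetical main base class of the node hierarchy.
class TaskNodeBase {
 public:
  virtual ~TaskNodeBase() = default;
  const std::vector<std::string>& consumed() const { return consumed_; }

 protected:
  std::vector<std::string> consumed_;
};

// A concrete node opts in to the capability by inheriting the mixin as a second base.
class DemoCompNode : public TaskNodeBase, public FakeConsumedRegstProvider {
 public:
  void ConsumeFakeRegsts() override {
    consumed_.push_back("fake:in");
    has_fake_ = true;
  }
  // Assumed semantics: drop the placeholder only if one was consumed.
  void EraseFakeRegstsIf() override {
    if (has_fake_) {
      consumed_.clear();
      has_fake_ = false;
    }
  }

 private:
  bool has_fake_ = false;
};
```

Code that only holds a `TaskNodeBase*` can test for the capability with `dynamic_cast<FakeConsumedRegstProvider*>` and skip nodes that do not provide it.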

@@ -265,16 +297,33 @@ RegstDesc* TaskNode::BuildCtrlRegstDesc(TaskNode* dst_node, std::string* name) {
}

void TaskNode::BindEdgeWithProducedRegst(TaskEdge* edge, const std::string& name) {
if (edge->HasRegst(name)) { return; }
Contributor Author

Avoids binding the same regst twice.

std::shared_ptr<RegstDesc> TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem) {
  return ProduceRegst(name, enable_reuse_mem, 1, kMaxRegisterNum);
}

std::shared_ptr<RegstDesc> TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem,
                                                  int32_t min_register_num,
                                                  int32_t max_register_num) {
  // Because the Regst of separate compilation is not created in order, some Regst may have been
  // built. This implementation can avoid ProduceRegst being called multiple times.
  const auto& regst = GetOrCheckRegst(name, enable_reuse_mem, min_register_num, max_register_num);
Contributor Author

Some registers have already been rebuilt in the boxing task graph, so this avoids creating them a second time.
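
A get-or-create step of this kind might look like the sketch below: if a regst with the given name already exists, its parameters are checked against the request and the existing object is returned; otherwise a new one is created. The `Regst` and `Node` types are hypothetical, and throwing `std::logic_error` is a stand-in for whatever check mechanism the real code uses.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical, simplified register descriptor.
struct Regst {
  bool enable_reuse_mem;
  int32_t min_register_num;
  int32_t max_register_num;
};

class Node {
 public:
  // Returns the existing regst for `name` after checking its parameters,
  // or creates and records a new one if none exists yet.
  std::shared_ptr<Regst> ProduceRegst(const std::string& name, bool enable_reuse_mem,
                                      int32_t min_register_num, int32_t max_register_num) {
    auto it = produced_.find(name);
    if (it != produced_.end()) {
      const Regst& existing = *it->second;
      if (existing.enable_reuse_mem != enable_reuse_mem
          || existing.min_register_num != min_register_num
          || existing.max_register_num != max_register_num) {
        throw std::logic_error("regst '" + name + "' already exists with different parameters");
      }
      return it->second;
    }
    auto regst =
        std::make_shared<Regst>(Regst{enable_reuse_mem, min_register_num, max_register_num});
    produced_.emplace(name, regst);
    return regst;
  }

 private:
  std::map<std::string, std::shared_ptr<Regst>> produced_;
};
```

Calling `ProduceRegst` twice with identical arguments returns the same object, which is what makes the out-of-order construction described in the code comment safe in this sketch.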

for (const auto& pair : data_regst_desc_proto.lbi2blob_desc()) {
*AddLbi(pair.lbi()) = BlobDesc(pair.blob_desc());
}
CHECK(!data_regst_desc_proto.has_time_shape());
Contributor

Why is this false here?

Contributor Author

Why is this false here?

Because this field is not created during serialization:
https://github.com/Oneflow-Inc/oneflow/pull/10119/files#r1165306496

@@ -120,6 +120,32 @@ void RegstDesc::EraseUninitializedShapeBlob() {
});
}

void RegstDesc::InitFromProtoExceptConsumers(const RegstDescProto& proto) {
regst_desc_id_ = proto.regst_desc_id();
CHECK_EQ(proto.producer_task_id(), producer_->task_id());
Contributor

How should proto.consumer_task_id be checked? Should it be empty?

Contributor Author

How should proto.consumer_task_id be checked? Should it be empty?

It will not be empty. However, it cannot be consumed at the moment the task node is created. Consumption has to be deferred until the task nodes have produced their registers, so a separate function, InitConsumedRegstsFromProto, was split out to consume it.
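
The two-phase order described in this reply can be sketched as follows: first every node restores what it produces, then every node resolves what it consumes through an id lookup. The types `RegstDesc`, `TaskRecord` and `Node` and the helper `BuildNodes` are hypothetical simplifications, not the proto messages or classes from the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <vector>

struct RegstDesc {
  int64_t id;
};

// Hypothetical, flattened stand-in for a serialized task.
struct TaskRecord {
  std::vector<int64_t> produced_ids;
  std::vector<int64_t> consumed_ids;
};

using RegstLookup = std::function<std::shared_ptr<RegstDesc>(int64_t)>;

struct Node {
  std::vector<std::shared_ptr<RegstDesc>> produced;
  std::vector<std::shared_ptr<RegstDesc>> consumed;

  // Phase 1: restore everything except the consumed regsts.
  void InitExceptConsumed(const TaskRecord& rec) {
    for (int64_t id : rec.produced_ids) {
      produced.push_back(std::make_shared<RegstDesc>(RegstDesc{id}));
    }
  }

  // Phase 2: resolve consumed ids once every producer has been restored.
  void InitConsumed(const TaskRecord& rec, const RegstLookup& regst4id) {
    for (int64_t id : rec.consumed_ids) {
      std::shared_ptr<RegstDesc> regst = regst4id(id);
      assert(regst != nullptr);  // every consumed id must have a producer
      consumed.push_back(regst);
    }
  }
};

inline std::vector<Node> BuildNodes(const std::vector<TaskRecord>& records) {
  std::vector<Node> nodes(records.size());
  std::map<int64_t, std::shared_ptr<RegstDesc>> id2regst;
  for (std::size_t i = 0; i < records.size(); ++i) {
    nodes[i].InitExceptConsumed(records[i]);
    for (const auto& regst : nodes[i].produced) { id2regst[regst->id] = regst; }
  }
  RegstLookup lookup = [&id2regst](int64_t id) -> std::shared_ptr<RegstDesc> {
    auto it = id2regst.find(id);
    if (it == id2regst.end()) { return nullptr; }
    return it->second;
  };
  for (std::size_t i = 0; i < records.size(); ++i) { nodes[i].InitConsumed(records[i], lookup); }
  return nodes;
}
```

Splitting the work this way means the order of the records no longer matters: a node may consume a regst whose producer appears later in the list.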

@@ -181,6 +189,7 @@ class TaskEdge final : public Edge<TaskNode, TaskEdge> {
void AddLbis(const std::vector<LogicalBlobId>& lbis) { lbis_.insert(lbis.begin(), lbis.end()); }

void CheckRegstLbiValid() const;
bool OutHasBindRegst() const { return !name_in_producer2regst_.empty(); }
Contributor

I don't understand what this name means.

Contributor Author

I don't understand what this name means.

Renamed to HasRegst. It is used to check that every edge has a regst bound. An earlier compilation failure involved some edges that had no regst bound.

@@ -367,6 +421,7 @@ std::vector<std::shared_ptr<RegstDesc>> TaskEdge::GetRegsts() const {

void TaskEdge::AddRegst(const std::string& name_in_producer,
const std::shared_ptr<RegstDesc>& regst) {
if (HasRegst(name_in_producer)) { return; }
Contributor

This shouldn't just return directly, should it? It should at least check that the regst passed in matches name_in_producer2regst_.at(name_in_producer).

Contributor Author

This shouldn't just return directly, should it? It should at least check that the regst passed in matches name_in_producer2regst_.at(name_in_producer).

OK.
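
One way to address the reviewer's point is to keep the early return but only after confirming that the regst already bound under that name is the same object. The sketch below uses a hypothetical `Edge` class and throws `std::logic_error` on a mismatch; the real code would more likely use a CHECK macro, and the final form adopted in the PR is not shown in this thread.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

struct Regst {
  int id = 0;
};

// Hypothetical, simplified edge that maps a producer-side name to a regst.
class Edge {
 public:
  bool HasRegst(const std::string& name) const { return name2regst_.count(name) > 0; }

  // Binding the same regst again is a no-op. Binding a different regst
  // under an existing name is reported as an error instead of being ignored.
  void AddRegst(const std::string& name, const std::shared_ptr<Regst>& regst) {
    auto it = name2regst_.find(name);
    if (it != name2regst_.end()) {
      if (it->second != regst) {
        throw std::logic_error("conflicting regst bound to name: " + name);
      }
      return;
    }
    name2regst_.emplace(name, regst);
  }

  std::size_t size() const { return name2regst_.size(); }

 private:
  std::map<std::string, std::shared_ptr<Regst>> name2regst_;
};
```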

std::shared_ptr<RegstDesc> TaskEdge::GetSoleRegst() const {
CHECK_EQ(name_in_producer2regst_.size(), 1);
CHECK_EQ(name_in_producer2regst_.size(), 1)
<< "edge: " << this << ", src: " << src_node()->task_id();
Contributor

dst ?

Contributor Author

dst ?

done

@@ -100,6 +100,10 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
std::string VisualStr() const override;
virtual bool IsMeaningLess();
void ToProto(TaskProto* task_proto) const { ToProto(task_proto, /*check*/ true); }
virtual void InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto);
Contributor

It would help to add a comment here noting that this is used to create task nodes from the plan partway through separate compilation.

Contributor Author

It would help to add a comment here noting that this is used to create task nodes from the plan partway through separate compilation.

done

chain_id_ = task_proto.chain_id();
order_in_chain_ = task_proto.order_in_chain();
// Step2: check exec_gph empty.
CHECK(task_proto.exec_sequence().exec_node().empty());
Contributor

Does consumed_regst_desc_id need to be checked?

Contributor Author

Does consumed_regst_desc_id need to be checked?

Do you mean checking that it is non-empty? consumed_regst_desc_id is consumed later on.

edge->AddRegst(name, GetProducedRegst(name));
}

std::shared_ptr<RegstDesc> TaskNode::GetOrCheckRegst(const std::string& name, bool enable_reuse_mem,
Contributor

Judging from the implementation, this does both the get and the check, right?

Contributor Author

Judging from the implementation, this does both the get and the check, right?

Renamed to GetAndCheckRegst.

for (const auto& pair : consumed_regsts()) {
  for (const auto& regst_desc : pair.second) {
    if (regst_desc->regst_desc_type().has_data_regst_desc()) {
      CHECK(data_regst_desc == nullptr);
Contributor

I don't follow this. Does it mean there is at most one data regst?

Contributor Author

I don't follow this. Does it mean there is at most one data regst?

Yes, only a single placeholder fake regst is created.

void NormalForwardCompTaskNode::ConsumeAllRegsts() {
  ForEachInDataEdge([&](TaskEdge* edge) {
    for (const auto& regst : edge->GetRegsts()) { ConsumeRegst("in", regst); }
  });
}

void NormalForwardCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); }

Contributor

Then please add a comment.

Contributor Author

Then please add a comment.

done

}
if (data_regst_desc != nullptr) {
  for (const auto& ibn : op_node()->op().input_bns()) {
    data_regst_desc->AddLbi(op_node()->op().BnInOp2Lbi(ibn));
Contributor Author

Multiple lbis point to this one placeholder fake regst.
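
The two steps discussed in this thread (finding the single data regst, then registering every input lbi on it) can be sketched as below. This is a simplified illustration: `Regst` is a hypothetical struct, lbis are modeled as plain strings, and `std::logic_error` stands in for the `CHECK` used in the diff.

```cpp
#include <cassert>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical, simplified consumed regst: either a data regst or something else (e.g. ctrl).
struct Regst {
  bool is_data;
  std::set<std::string> lbis;
};

// Returns the only data regst, or nullptr if there is none.
// More than one data regst violates the single-placeholder assumption.
inline Regst* FindSoleDataRegst(std::vector<Regst>& consumed) {
  Regst* found = nullptr;
  for (Regst& regst : consumed) {
    if (!regst.is_data) { continue; }
    if (found != nullptr) { throw std::logic_error("more than one data regst"); }
    found = &regst;
  }
  return found;
}

// Registers every input lbi on the single placeholder, so all inputs share one regst.
inline void AddInputLbis(std::vector<Regst>& consumed,
                         const std::vector<std::string>& input_lbis) {
  Regst* data_regst = FindSoleDataRegst(consumed);
  if (data_regst == nullptr) { return; }
  for (const auto& lbi : input_lbis) { data_regst->lbis.insert(lbi); }
}
```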

const TaskProto& task_proto,
const std::function<Maybe<RegstDesc>(int64_t regst_desc_id)>& RegstDesc4Id) {
// init consumed_regst.
for (const auto& pair : task_proto.consumed_regst_desc_id()) {
Contributor Author

consumed_regst_desc_id is consumed here.

github-actions bot commented May 5, 2023

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/10140/

github-actions bot commented May 5, 2023

Speed stats:
GPU Name: NVIDIA GeForce RTX 3080 Ti 

❌ OneFlow resnet50 time: 43.0ms (= 4300.6ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 60.2ms (= 6016.9ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.40 (= 60.2ms / 43.0ms)

OneFlow resnet50 time: 25.9ms (= 2587.3ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 37.7ms (= 3773.5ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.46 (= 37.7ms / 25.9ms)

OneFlow resnet50 time: 18.1ms (= 3616.0ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 35.2ms (= 7040.0ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.95 (= 35.2ms / 18.1ms)

OneFlow resnet50 time: 16.6ms (= 3324.8ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 31.0ms (= 6201.9ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.87 (= 31.0ms / 16.6ms)

OneFlow resnet50 time: 15.9ms (= 3170.2ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 28.5ms (= 5699.3ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.80 (= 28.5ms / 15.9ms)

OneFlow swin dataloader time: 0.201s (= 40.190s / 200, num_workers=1)
PyTorch swin dataloader time: 0.132s (= 26.413s / 200, num_workers=1)
Relative speed: 0.657 (= 0.132s / 0.201s)

OneFlow swin dataloader time: 0.058s (= 11.554s / 200, num_workers=4)
PyTorch swin dataloader time: 0.033s (= 6.574s / 200, num_workers=4)
Relative speed: 0.569 (= 0.033s / 0.058s)

OneFlow swin dataloader time: 0.032s (= 6.472s / 200, num_workers=8)
PyTorch swin dataloader time: 0.017s (= 3.377s / 200, num_workers=8)
Relative speed: 0.522 (= 0.017s / 0.032s)

❌ OneFlow resnet50 time: 47.2ms (= 4719.2ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 63.9ms (= 6394.6ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.36 (= 63.9ms / 47.2ms)

OneFlow resnet50 time: 30.5ms (= 3053.7ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 43.4ms (= 4341.6ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.42 (= 43.4ms / 30.5ms)

OneFlow resnet50 time: 23.7ms (= 4743.2ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 40.0ms (= 8005.9ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.69 (= 40.0ms / 23.7ms)

OneFlow resnet50 time: 22.3ms (= 4464.3ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 36.6ms (= 7321.6ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.64 (= 36.6ms / 22.3ms)

OneFlow resnet50 time: 20.4ms (= 4076.8ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 33.7ms (= 6736.0ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.65 (= 33.7ms / 20.4ms)

@strint strint requested review from oneflow-ci-bot and removed request for oneflow-ci-bot May 6, 2023 01:04
github-actions bot commented May 6, 2023

CI failed when running job: cuda-module. PR label automerge has been removed

@github-actions github-actions bot removed the automerge label May 6, 2023
@strint strint requested review from oneflow-ci-bot and removed request for oneflow-ci-bot May 6, 2023 01:30
@strint strint requested review from oneflow-ci-bot and removed request for oneflow-ci-bot May 6, 2023 01:55
github-actions bot commented May 6, 2023

CI failed when running job: cuda-module. PR label automerge has been removed

@github-actions github-actions bot removed the automerge label May 6, 2023
@strint strint requested review from oneflow-ci-bot and removed request for oneflow-ci-bot May 6, 2023 02:23
github-actions bot commented May 6, 2023

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/10140/

github-actions bot commented May 6, 2023

Speed stats:
GPU Name: NVIDIA GeForce RTX 3090 

❌ OneFlow resnet50 time: 42.7ms (= 4265.7ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 57.9ms (= 5787.2ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.36 (= 57.9ms / 42.7ms)

OneFlow resnet50 time: 26.5ms (= 2648.2ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 37.7ms (= 3769.2ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.42 (= 37.7ms / 26.5ms)

OneFlow resnet50 time: 18.2ms (= 3634.6ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 35.0ms (= 6990.7ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.92 (= 35.0ms / 18.2ms)

OneFlow resnet50 time: 16.7ms (= 3348.4ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 33.7ms (= 6736.0ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 2.01 (= 33.7ms / 16.7ms)

OneFlow resnet50 time: 15.8ms (= 3161.9ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 29.7ms (= 5948.5ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.88 (= 29.7ms / 15.8ms)

OneFlow swin dataloader time: 0.201s (= 40.254s / 200, num_workers=1)
PyTorch swin dataloader time: 0.129s (= 25.894s / 200, num_workers=1)
Relative speed: 0.643 (= 0.129s / 0.201s)

OneFlow swin dataloader time: 0.055s (= 11.075s / 200, num_workers=4)
PyTorch swin dataloader time: 0.033s (= 6.571s / 200, num_workers=4)
Relative speed: 0.593 (= 0.033s / 0.055s)

OneFlow swin dataloader time: 0.036s (= 7.211s / 200, num_workers=8)
PyTorch swin dataloader time: 0.017s (= 3.324s / 200, num_workers=8)
Relative speed: 0.461 (= 0.017s / 0.036s)

❌ OneFlow resnet50 time: 48.4ms (= 4840.7ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 64.4ms (= 6444.8ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.33 (= 64.4ms / 48.4ms)

OneFlow resnet50 time: 36.3ms (= 3630.0ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 47.3ms (= 4734.2ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.30 (= 47.3ms / 36.3ms)

OneFlow resnet50 time: 28.6ms (= 5714.3ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 41.3ms (= 8252.3ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.44 (= 41.3ms / 28.6ms)

OneFlow resnet50 time: 25.6ms (= 5117.3ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 38.8ms (= 7751.8ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.51 (= 38.8ms / 25.6ms)

OneFlow resnet50 time: 24.0ms (= 4800.3ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 37.4ms (= 7475.6ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.56 (= 37.4ms / 24.0ms)

@mergify mergify bot merged commit 9cd5d64 into master May 6, 2023
@mergify mergify bot deleted the sep3_fake_regst branch May 6, 2023 03:40