
Task infer blob desc support choosing method #10124

Merged
merged 18 commits into master from sep2_custom_blobdesc_infer on Apr 20, 2023

Conversation

strint
Contributor

@strint strint commented Apr 13, 2023

No description provided.

@strint strint requested a review from chengtbf as a code owner April 13, 2023 04:41
@@ -71,6 +71,9 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
DeviceType device_type() const;
virtual const ParallelContext* parallel_ctx() const { return nullptr; }

// Different types of ExecNode choose different output BlobDesc inference methods
Contributor

ExecNode doesn't have a type; what's meant here should be the TaskNode type, right?

Contributor Author

> ExecNode doesn't have a type; what's meant here should be the TaskNode type, right?

I'll change it.

Contributor Author

done

@@ -72,7 +72,8 @@ class ExecNode final : public Node<ExecNode, ExecEdge> {
std::string VisualStr() const override { return op_->op_name(); }
void ToProto(const ParallelContext*, ExecNodeProto*) const;

void InferBlobDescs(const ParallelContext* parallel_ctx);
typedef void (ExecNode::*InferBlobDescsMethod)(const ParallelContext*);
void InferBlobDescsByInputs(const ParallelContext* parallel_ctx);
Contributor

Is only one method provided here?

Contributor Author

> Is only one method provided here?

Yes. This branch doesn't change the execution logic; it only provides the interface.
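
To make the new interface concrete, below is a minimal standalone sketch of the pointer-to-member dispatch being introduced; the stub ParallelContext, the InferExecNodeBlobDescs call site, and main are illustrative assumptions, not the actual OneFlow code.

#include <iostream>

// Stub standing in for oneflow::ParallelContext, just for this sketch.
struct ParallelContext {};

class ExecNode {
 public:
  // Pointer-to-member type: any ExecNode method taking a const ParallelContext*.
  typedef void (ExecNode::*InferBlobDescsMethod)(const ParallelContext*);

  // The only method provided in this PR: infer output BlobDescs from inputs.
  void InferBlobDescsByInputs(const ParallelContext* /*parallel_ctx*/) {
    std::cout << "infer output BlobDescs from input BlobDescs\n";
  }
};

class TaskNode {
 public:
  virtual ~TaskNode() = default;

  // Each TaskNode subtype overrides this to choose the inference method
  // used for its ExecNodes.
  virtual ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const {
    return &ExecNode::InferBlobDescsByInputs;
  }

  // Hypothetical call site: dispatch through the chosen pointer-to-member.
  void InferExecNodeBlobDescs(ExecNode* exec_node, const ParallelContext* ctx) const {
    (exec_node->*GetInferBlobDescsMethod())(ctx);
  }
};

int main() {
  TaskNode task;
  ExecNode exec;
  ParallelContext ctx;
  task.InferExecNodeBlobDescs(&exec, &ctx);  // prints: infer output BlobDescs from input BlobDescs
  return 0;
}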

@@ -43,6 +43,10 @@ class CompTaskNode : public TaskNode {
// op
std::shared_ptr<const Operator> op() const { return op_node_->shared_op(); }

ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const override {
return &ExecNode::InferBlobDescsByInputs;
Contributor

Shouldn't this also support inference from SBP and the logical shape?

Contributor Author

> Shouldn't this also support inference from SBP and the logical shape?

The from-SBP inference method is coupled with the compilation mode, so it wasn't added in this branch.

Contributor

Actually, to speed things up, couldn't master compilation also use from-SBP here? Wouldn't that be faster?

Contributor Author

> Actually, to speed things up, couldn't master compilation also use from-SBP here? Wouldn't that be faster?

Right now, using from-SBP is unlikely to be faster:

  • The cost of from-SBP inference itself is probably about the same as the existing physical BlobDesc inference.
  • It only works for user ops; making it general would touch many places, and a follow-up PR is working on that.
  • Without parallelizing the master-side inference, the speedup would not be noticeable.

Contributor

Some ops don't support from-SBP, right? For example ops that involve averaging, summation, or taking a maximum. (I remember there is currently one op like that, maybe called bn?)

Another question: what happens if the SBP changes? Currently it would have to be re-inferred. For example, auto-parallelism modifies SBP on a large scale; in that case it would be better to keep a stored logical desc.

Contributor Author

> Some ops don't support from-SBP, right? For example ops that involve averaging, summation, or taking a maximum. (I remember there is currently one op like that, maybe called bn?)

All user ops do satisfy it. And in the scenario used here, the result of the earlier physical inference is also checked against the SBP afterwards, so it is consistent as well.

> Another question: what happens if the SBP changes? Currently it would have to be re-inferred. For example, auto-parallelism modifies SBP on a large scale; in that case it would be better to keep a stored logical desc.

Because this inference happens during plan generation, which comes after auto-parallelism, the SBP should already be stable by then.
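
As a follow-up illustration of the discussion above, here is a hedged sketch of how a from-SBP method could later slot in as an alternative InferBlobDescsMethod; the method name InferBlobDescsByLogicalShapeAndSbp and the SbpAwareCompTaskNode subtype are hypothetical and not part of this PR.

#include <iostream>

struct ParallelContext {};  // stub, as in the previous sketch

class ExecNode {
 public:
  typedef void (ExecNode::*InferBlobDescsMethod)(const ParallelContext*);

  // Method actually added by this PR.
  void InferBlobDescsByInputs(const ParallelContext*) {
    std::cout << "infer from input BlobDescs\n";
  }
  // Hypothetical future method: infer from the logical shape plus SBP.
  void InferBlobDescsByLogicalShapeAndSbp(const ParallelContext*) {
    std::cout << "infer from logical shape and SBP\n";
  }
};

class TaskNode {
 public:
  virtual ~TaskNode() = default;
  virtual ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const {
    return &ExecNode::InferBlobDescsByInputs;
  }
};

// Hypothetical TaskNode subtype that opts into from-SBP inference once the
// follow-up PR mentioned above makes it generally applicable.
class SbpAwareCompTaskNode : public TaskNode {
 public:
  ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const override {
    return &ExecNode::InferBlobDescsByLogicalShapeAndSbp;
  }
};

int main() {
  ExecNode exec;
  ParallelContext ctx;
  SbpAwareCompTaskNode task;
  (exec.*task.GetInferBlobDescsMethod())(&ctx);  // prints: infer from logical shape and SBP
  return 0;
}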

@strint strint changed the title from "support infer desc choose method" to "Task infer blob desc support choosing method" on Apr 16, 2023
strint and others added 2 commits April 17, 2023 18:25
Co-authored-by: Yipeng Li <jamesonli1313@gmail.com>
@github-actions

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

Contributor

@Yipeng1994 Yipeng1994 left a comment

LGTM

Base automatically changed from sep0_task_proto to master April 19, 2023 21:58
@github-actions

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 141.3ms (= 14131.6ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 147.3ms (= 14725.1ms / 100, input_shape=[16, 3, 224, 224])
❌ Relative speed: 1.04 (= 147.3ms / 141.3ms)

OneFlow resnet50 time: 82.8ms (= 8283.8ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 93.5ms (= 9346.2ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.13 (= 93.5ms / 82.8ms)

OneFlow resnet50 time: 51.6ms (= 10315.5ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 70.5ms (= 14095.9ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.37 (= 70.5ms / 51.6ms)

OneFlow resnet50 time: 34.0ms (= 6805.8ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 64.7ms (= 12940.5ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.90 (= 64.7ms / 34.0ms)

OneFlow resnet50 time: 25.7ms (= 5147.8ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 61.7ms (= 12343.9ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 2.40 (= 61.7ms / 25.7ms)

OneFlow swin dataloader time: 0.242s (= 48.435s / 200, num_workers=1)
PyTorch swin dataloader time: 0.150s (= 29.976s / 200, num_workers=1)
Relative speed: 0.619 (= 0.150s / 0.242s)

OneFlow swin dataloader time: 0.072s (= 14.333s / 200, num_workers=4)
PyTorch swin dataloader time: 0.042s (= 8.374s / 200, num_workers=4)
Relative speed: 0.584 (= 0.042s / 0.072s)

OneFlow swin dataloader time: 0.046s (= 9.208s / 200, num_workers=8)
PyTorch swin dataloader time: 0.022s (= 4.387s / 200, num_workers=8)
Relative speed: 0.476 (= 0.022s / 0.046s)

❌ OneFlow resnet50 time: 154.4ms (= 15436.7ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 165.0ms (= 16499.8ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.07 (= 165.0ms / 154.4ms)

OneFlow resnet50 time: 94.5ms (= 9453.6ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 104.3ms (= 10429.0ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.10 (= 104.3ms / 94.5ms)

OneFlow resnet50 time: 62.0ms (= 12395.1ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 88.7ms (= 17742.1ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.43 (= 88.7ms / 62.0ms)

OneFlow resnet50 time: 44.2ms (= 8834.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 72.2ms (= 14430.5ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.63 (= 72.2ms / 44.2ms)

OneFlow resnet50 time: 37.4ms (= 7487.9ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 70.8ms (= 14151.0ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.89 (= 70.8ms / 37.4ms)

@github-actions

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/10124/

@mergify mergify bot merged commit 2d54365 into master Apr 20, 2023
@mergify mergify bot deleted the sep2_custom_blobdesc_infer branch April 20, 2023 07:32