implement flow.utils.checkpoint #9053

Merged
merged 7 commits into from Sep 9, 2022

daquexian
Contributor

Follow-up to #8995.

Ports PyTorch's checkpointing implementation to OneFlow for use in eager mode.

Signed-off-by: daquexian <daquexian566@gmail.com>
    return None

def inner_unpack(packed):
    raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")
Contributor

Does this imply that second-order derivatives cannot be handled?

Contributor Author

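No. Only inner_pack is ever actually called here, but the API requires both a pack hook and an unpack hook, so I wrote an unused function as a placeholder.

As far as I can tell, this API still works under second-order derivatives, but it really is not suited to them. A second-order derivative requires two differentiation passes; recomputation is triggered during the first pass and occupies GPU memory, so it cannot reduce memory usage during the second pass. In addition, the second-order derivative is taken over the backward graph of the first-order derivative, and that part of the computation graph is not exposed to the user, so in principle this checkpointing API has no way to control it. Only a recomputation scheme like DTR can.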
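
The placeholder hook and the memory behavior described above can be illustrated with a toy, framework-free sketch. The names `saved_tensors_hooks`, `save_for_backward`, `square`, `fourth_power`, and `checkpoint` below are hypothetical pure-Python stand-ins, not OneFlow's or PyTorch's API; only the shape of the pack/unpack logic is modeled on the snippet under review, and plain numbers stand in for tensors.

```python
from contextlib import contextmanager

# Toy stand-in for an autograd engine's stack of saved-tensor hooks.
# Everything here is hypothetical and pure Python; no framework is involved.
_hook_stack = []


@contextmanager
def saved_tensors_hooks(pack, unpack):
    """Install a (pack, unpack) pair for values saved inside the block."""
    _hook_stack.append((pack, unpack))
    try:
        yield
    finally:
        _hook_stack.pop()


def save_for_backward(value):
    """Save a value for backward; return a getter that yields it later."""
    if not _hook_stack:
        return lambda: value
    pack, unpack = _hook_stack[-1]
    packed = pack(value)
    return lambda: unpack(packed)


def square(x):
    """Toy op: returns x * x plus a backward closure that needs x."""
    get_x = save_for_backward(x)
    return x * x, lambda grad: 2 * get_x() * grad


def fourth_power(x):
    """Two chained toy ops, so two values are saved for backward."""
    y, back_y = square(x)
    z, back_z = square(y)
    return z, lambda grad: back_y(back_z(grad))


def checkpoint(function, x):
    storage = {}  # index -> recomputed value, filled lazily during backward
    saved_count = 0

    def pack(_value):
        # Forward pass: drop the value, remember only its position.
        nonlocal saved_count
        saved_count += 1
        return saved_count - 1

    def unpack(index):
        if not storage:
            recomputed = 0

            def inner_pack(value):
                # The only inner hook that runs: capture recomputed values.
                nonlocal recomputed
                storage[recomputed] = value
                recomputed += 1
                return None

            def inner_unpack(_packed):
                # Placeholder: backward is never run on the recomputed graph.
                raise RuntimeError("inner_unpack should never be called")

            with saved_tensors_hooks(inner_pack, inner_unpack):
                function(x)  # rerun forward; its outputs are discarded
        return storage[index]

    with saved_tensors_hooks(pack, unpack):
        result = function(x)
    return result, storage


(out, backward), storage = checkpoint(fourth_power, 3)
stored_before = len(storage)  # nothing is kept after the forward pass
grad = backward(1)            # first backward call triggers recomputation
stored_after = len(storage)   # recomputed values now stay resident
print(out, grad, stored_before, stored_after)
```

In this toy, `storage` is empty after the forward pass and holds both recomputed values once the first backward call has run, which mirrors the point that a second differentiation pass would no longer see any memory savings. `inner_unpack` exists only because the hook context needs a pair of functions.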

@lixinqi lixinqi self-requested a review September 5, 2022 05:26

github-actions bot commented Sep 5, 2022

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@oneflow-ci-bot oneflow-ci-bot requested review from oneflow-ci-bot and removed request for oneflow-ci-bot September 5, 2022 07:40
@oneflow-ci-bot oneflow-ci-bot requested review from oneflow-ci-bot and removed request for oneflow-ci-bot September 5, 2022 13:17

github-actions bot commented Sep 5, 2022

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9053/


github-actions bot commented Sep 5, 2022

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 129.3ms (= 12926.9ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 143.0ms (= 14296.0ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.11 (= 143.0ms / 129.3ms)

OneFlow resnet50 time: 74.6ms (= 7462.8ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 85.5ms (= 8547.1ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.15 (= 85.5ms / 74.6ms)

OneFlow resnet50 time: 47.6ms (= 9512.4ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 54.6ms (= 10928.8ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.15 (= 54.6ms / 47.6ms)

OneFlow resnet50 time: 35.0ms (= 6990.7ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 40.7ms (= 8138.8ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.16 (= 40.7ms / 35.0ms)

OneFlow resnet50 time: 28.6ms (= 5710.4ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 37.8ms (= 7550.5ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.32 (= 37.8ms / 28.6ms)

OneFlow swin dataloader time: 0.268s (= 53.642s / 200, num_workers=1)
PyTorch swin dataloader time: 0.150s (= 29.986s / 200, num_workers=1)
Relative speed: 0.559 (= 0.150s / 0.268s)

OneFlow swin dataloader time: 0.071s (= 14.207s / 200, num_workers=4)
PyTorch swin dataloader time: 0.041s (= 8.275s / 200, num_workers=4)
Relative speed: 0.582 (= 0.041s / 0.071s)

OneFlow swin dataloader time: 0.040s (= 8.002s / 200, num_workers=8)
PyTorch swin dataloader time: 0.022s (= 4.357s / 200, num_workers=8)
Relative speed: 0.545 (= 0.022s / 0.040s)

❌ OneFlow resnet50 time: 138.2ms (= 13822.5ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 162.9ms (= 16286.9ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.18 (= 162.9ms / 138.2ms)

OneFlow resnet50 time: 85.7ms (= 8565.8ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 104.7ms (= 10469.4ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.22 (= 104.7ms / 85.7ms)

OneFlow resnet50 time: 58.2ms (= 11645.2ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 78.4ms (= 15686.1ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.35 (= 78.4ms / 58.2ms)

OneFlow resnet50 time: 45.1ms (= 9029.0ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 71.2ms (= 14232.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.58 (= 71.2ms / 45.1ms)

OneFlow resnet50 time: 38.8ms (= 7763.9ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 67.2ms (= 13437.4ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.73 (= 67.2ms / 38.8ms)


github-actions bot commented Sep 9, 2022

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 129.6ms (= 12960.5ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 148.5ms (= 14848.3ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.15 (= 148.5ms / 129.6ms)

OneFlow resnet50 time: 75.0ms (= 7496.6ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 83.6ms (= 8357.6ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.11 (= 83.6ms / 75.0ms)

OneFlow resnet50 time: 47.7ms (= 9536.5ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 60.6ms (= 12112.4ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.27 (= 60.6ms / 47.7ms)

OneFlow resnet50 time: 35.5ms (= 7104.0ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 41.5ms (= 8309.9ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.17 (= 41.5ms / 35.5ms)

OneFlow resnet50 time: 29.8ms (= 5954.2ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 37.7ms (= 7549.4ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.27 (= 37.7ms / 29.8ms)

OneFlow swin dataloader time: 0.263s (= 52.621s / 200, num_workers=1)
PyTorch swin dataloader time: 0.151s (= 30.248s / 200, num_workers=1)
Relative speed: 0.575 (= 0.151s / 0.263s)

OneFlow swin dataloader time: 0.075s (= 14.976s / 200, num_workers=4)
PyTorch swin dataloader time: 0.041s (= 8.173s / 200, num_workers=4)
Relative speed: 0.546 (= 0.041s / 0.075s)

OneFlow swin dataloader time: 0.040s (= 7.938s / 200, num_workers=8)
PyTorch swin dataloader time: 0.023s (= 4.609s / 200, num_workers=8)
Relative speed: 0.581 (= 0.023s / 0.040s)

❌ OneFlow resnet50 time: 139.9ms (= 13988.6ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 161.5ms (= 16145.1ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.15 (= 161.5ms / 139.9ms)

OneFlow resnet50 time: 86.2ms (= 8621.7ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 103.1ms (= 10306.6ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.20 (= 103.1ms / 86.2ms)

OneFlow resnet50 time: 58.8ms (= 11764.3ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 78.4ms (= 15677.9ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.33 (= 78.4ms / 58.8ms)

OneFlow resnet50 time: 45.2ms (= 9046.6ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 70.7ms (= 14149.4ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.56 (= 70.7ms / 45.2ms)

OneFlow resnet50 time: 40.3ms (= 8051.8ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 67.9ms (= 13578.9ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.69 (= 67.9ms / 40.3ms)


github-actions bot commented Sep 9, 2022

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9053/

@mergify mergify bot merged commit ecafd61 into master Sep 9, 2022
@mergify mergify bot deleted the eager_checkpointing branch September 9, 2022 09:09
4 participants