
Support tensor and optimizer serialization #6087

Merged
merged 10 commits into from
Aug 29, 2021

Conversation

wyg1997
Copy link
Contributor

@wyg1997 wyg1997 commented Aug 28, 2021

  • Support pickle serialization and deserialization of Tensor/DType
  • Add Optimizer state_dict/load_state_dict interfaces to save/load training state

@wyg1997 wyg1997 requested a review from daquexian August 28, 2021 05:06
@wyg1997
Copy link
Contributor Author

wyg1997 commented Aug 28, 2021

The Optimizer's training state can be saved with:

import pickle
with open("optim_state.pkl", "wb") as f:
    pickle.dump(sgd.state_dict(), f)
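For the reverse direction, the matching load could look like the sketch below. Here `state` is a plain dict standing in for a real optimizer state_dict, and the final `sgd.load_state_dict(...)` call is shown only as a comment, since it needs a live optimizer:

```python
import os
import pickle
import tempfile

# Stand-in for an optimizer state_dict (hypothetical contents; a real
# oneflow optimizer returns parameter groups and per-parameter state).
state = {"lr": 0.1, "step": 42}

path = os.path.join(tempfile.mkdtemp(), "optim_state.pkl")

# Save the training state, as in the snippet above.
with open(path, "wb") as f:
    pickle.dump(state, f)

# Restore it later with the matching load call.
with open(path, "rb") as f:
    restored = pickle.load(f)
# In real code: sgd.load_state_dict(restored)
```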

@daquexian
Copy link
Contributor

daquexian commented Aug 29, 2021

The Optimizer's training state can be saved with:

import pickle
with open("optim_state.pkl", "wb") as f:
    pickle.dump(sgd.state_dict(), f)

Isn't this still misaligned with PyTorch? torch.save/load supports pickle.

To support very large weights, we save one file per weight, unlike PyTorch, which serializes the whole state_dict at once. So once pickle is supported, the save logic can become: iterate over the passed-in dict; if an element is a tensor, keep the original logic (to avoid introducing risk); if it is not a tensor, dump it to its own .pkl file.
The load logic can become: if a file can be loaded by pickle, load it with pickle; otherwise fall back to the original logic.
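The proposed save/load flow could be sketched roughly as follows. This is a hedged sketch, not the PR's implementation: `_save_tensor`/`_load_tensor` stand in for the existing one-file-per-weight tensor path, `is_tensor` stands in for a check against `flow.Tensor`, and plain lists play the role of tensors so the example is self-contained:

```python
import ast
import os
import pickle
import tempfile

# Stand-ins for the existing one-file-per-weight tensor path
# (hypothetical; the real code serializes oneflow tensors).
def _save_tensor(value, path):
    with open(path, "w") as f:
        f.write(repr(value))

def _load_tensor(path):
    with open(path) as f:
        return ast.literal_eval(f.read())

def save_state_dict(state, dirname, is_tensor):
    """Walk the dict: tensors keep the original per-weight logic,
    everything else is pickled into its own .pkl file."""
    os.makedirs(dirname, exist_ok=True)
    for key, value in state.items():
        if is_tensor(value):
            _save_tensor(value, os.path.join(dirname, key))
        else:
            with open(os.path.join(dirname, key + ".pkl"), "wb") as f:
                pickle.dump(value, f)

def load_state_dict(dirname):
    """Try pickle first; fall back to the original tensor loader
    when the file is not a pickle."""
    state = {}
    for name in os.listdir(dirname):
        path = os.path.join(dirname, name)
        try:
            with open(path, "rb") as f:
                value = pickle.load(f)
            state[name[: -len(".pkl")]] = value
        except pickle.UnpicklingError:
            state[name] = _load_tensor(path)
    return state

# Round-trip a mixed dict: "weight" takes the tensor path,
# "lr" takes the pickle path.
d = tempfile.mkdtemp()
save_state_dict({"weight": [1.0, 2.0], "lr": 0.1},
                d, is_tensor=lambda v: isinstance(v, list))
restored = load_state_dict(d)
```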

@wyg1997
Copy link
Contributor Author

wyg1997 commented Aug 29, 2021

Isn't this still misaligned with PyTorch? torch.save/load supports pickle.

Our flow.save/load interfaces currently save to and load from a directory. Once those two interfaces are aligned, we could call flow.save/load(optim.state_dict()) directly. But wouldn't that be inconsistent with how lazy-mode checkpoints are used?

@daquexian
Copy link
Contributor

Isn't this still misaligned with PyTorch? torch.save/load supports pickle.

Our flow.save/load interfaces currently save to and load from a directory. Once those two interfaces are aligned, we could call flow.save/load(optim.state_dict()) directly. But wouldn't that be inconsistent with how lazy-mode checkpoints are used?

What does "the lazy checkpoint" refer to? Network weights are already saved via flow.save(module.state_dict()) as well.
I've updated my previous comment with some concrete ideas.

@wyg1997
Copy link
Contributor Author

wyg1997 commented Aug 29, 2021

Isn't this still misaligned with PyTorch? torch.save/load supports pickle.

Our flow.save/load interfaces currently save to and load from a directory. Once those two interfaces are aligned, we could call flow.save/load(optim.state_dict()) directly. But wouldn't that be inconsistent with how lazy-mode checkpoints are used?

What does "the lazy checkpoint" refer to? Network weights are already saved via flow.save(module.state_dict()) as well.
I've updated my previous comment with some concrete ideas.

OK, I'll change it.

@daquexian
Copy link
Contributor

Note: after discussion, we found that aligning with flow.save/load would require supporting nested dicts, which is a fairly large change to the existing mechanism and not suitable to make now. So this PR keeps its current approach.

@oneflow-ci-bot oneflow-ci-bot self-requested a review August 29, 2021 09:45
@github-actions
Copy link
Contributor

Speed stats:
GPU Name: GeForce GTX 1080 

OneFlow resnet50 time: 128.2ms (= 6407.6ms / 50, input_shape=[16, 3, 224, 224], backward is enabled)
OneFlow GPU used (rank 0): 0 MiB
PyTorch resnet50 time: 142.2ms (= 7110.4ms / 50, input_shape=[16, 3, 224, 224], backward is enabled)
PyTorch GPU used (rank 0, estimated): 0 MiB
Relative speed: 1.11 (= 142.2ms / 128.2ms)

OneFlow resnet50 time: 74.5ms (= 3724.9ms / 50, input_shape=[8, 3, 224, 224], backward is enabled)
OneFlow GPU used (rank 0): 0 MiB
PyTorch resnet50 time: 83.8ms (= 4190.5ms / 50, input_shape=[8, 3, 224, 224], backward is enabled)
PyTorch GPU used (rank 0, estimated): 0 MiB
Relative speed: 1.13 (= 83.8ms / 74.5ms)

OneFlow resnet50 time: 47.7ms (= 2385.4ms / 50, input_shape=[4, 3, 224, 224], backward is enabled)
OneFlow GPU used (rank 0): 0 MiB
PyTorch resnet50 time: 59.2ms (= 2958.1ms / 50, input_shape=[4, 3, 224, 224], backward is enabled)
PyTorch GPU used (rank 0, estimated): 0 MiB
Relative speed: 1.24 (= 59.2ms / 47.7ms)

OneFlow resnet50 time: 39.1ms (= 1952.6ms / 50, input_shape=[2, 3, 224, 224], backward is enabled)
OneFlow GPU used (rank 0): 0 MiB
PyTorch resnet50 time: 49.8ms (= 2489.3ms / 50, input_shape=[2, 3, 224, 224], backward is enabled)
PyTorch GPU used (rank 0, estimated): 0 MiB
Relative speed: 1.27 (= 49.8ms / 39.1ms)

OneFlow resnet50 time: 44.2ms (= 2210.4ms / 50, input_shape=[1, 3, 224, 224], backward is enabled)
OneFlow GPU used (rank 0): 0 MiB
PyTorch resnet50 time: 45.9ms (= 2294.4ms / 50, input_shape=[1, 3, 224, 224], backward is enabled)
PyTorch GPU used (rank 0, estimated): 0 MiB
Relative speed: 1.04 (= 45.9ms / 44.2ms)

OneFlow resnet50 time: 142.4ms (= 7121.1ms / 50, input_shape=[16, 3, 224, 224], backward is enabled)
OneFlow GPU used (rank 0): 0 MiB
PyTorch resnet50 time: 150.1ms (= 7504.8ms / 50, input_shape=[16, 3, 224, 224], backward is enabled)
PyTorch GPU used (rank 0, estimated): 0 MiB
Relative speed: 1.05 (= 150.1ms / 142.4ms)

OneFlow resnet50 time: 92.3ms (= 4615.9ms / 50, input_shape=[8, 3, 224, 224], backward is enabled)
OneFlow GPU used (rank 0): 0 MiB
PyTorch resnet50 time: 92.9ms (= 4643.7ms / 50, input_shape=[8, 3, 224, 224], backward is enabled)
PyTorch GPU used (rank 0, estimated): 0 MiB
Relative speed: 1.01 (= 92.9ms / 92.3ms)

OneFlow resnet50 time: 67.8ms (= 3387.8ms / 50, input_shape=[4, 3, 224, 224], backward is enabled)
OneFlow GPU used (rank 0): 0 MiB
PyTorch resnet50 time: 67.3ms (= 3366.1ms / 50, input_shape=[4, 3, 224, 224], backward is enabled)
PyTorch GPU used (rank 0, estimated): 0 MiB
Relative speed: 0.99 (= 67.3ms / 67.8ms)

OneFlow resnet50 time: 62.5ms (= 3124.7ms / 50, input_shape=[2, 3, 224, 224], backward is enabled)
OneFlow GPU used (rank 0): 0 MiB
PyTorch resnet50 time: 53.9ms (= 2697.5ms / 50, input_shape=[2, 3, 224, 224], backward is enabled)
PyTorch GPU used (rank 0, estimated): 0 MiB
Relative speed: 0.86 (= 53.9ms / 62.5ms)

OneFlow resnet50 time: 54.1ms (= 2704.5ms / 50, input_shape=[1, 3, 224, 224], backward is enabled)
OneFlow GPU used (rank 0): 0 MiB
PyTorch resnet50 time: 52.9ms (= 2643.5ms / 50, input_shape=[1, 3, 224, 224], backward is enabled)
PyTorch GPU used (rank 0, estimated): 0 MiB
Relative speed: 0.98 (= 52.9ms / 54.1ms)

@oneflow-ci-bot oneflow-ci-bot merged commit dd7c382 into master Aug 29, 2021
@oneflow-ci-bot oneflow-ci-bot deleted the feat-tensor_optim_serialization branch August 29, 2021 10:50