
Is it possible to collect state dict in cpu? #4

Open
JiaquanYe opened this issue Aug 12, 2021 · 2 comments
Labels
Good Issue Good reference for newcomers

Comments

@JiaquanYe

When I finish one epoch of training, the main_worker function calls ts.collect_state_dict(model, state_dict).
But because of limited GPU memory, calling ts.collect_state_dict(model, state_dict) raises an out-of-memory error on my machine.
I found that it gathers the state_dict on GPU. Is there any way to gather it on CPU?

@kaiyuyue
Owner

It is impossible to perform the gather operation on CPU because it relies on the NCCL backend. But there is a way to avoid gathering on GPU on the fly: save the state_dict of each shard locally, then write a post-processing script to merge them. For example, if using 16 GPUs across 16 ranks, save 16 checkpoints during training, like model_state_rank_001.pth, model_state_rank_002.pth, … and model_state_rank_016.pth. After training finishes, write a post-processing script to gather these 16 checkpoints into one. Pay attention to keeping the shards in the right order, and run an inference test to check the result.
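A post-processing merge along these lines could look like the sketch below. Everything here is illustrative, not TorchShard API: the `merge_rank_checkpoints` helper is hypothetical, the filename pattern follows the example above, and shards are represented as plain Python lists so the ordering logic is easy to see. With real checkpoints you would `torch.load()` each file and `torch.cat()` the tensors along each parameter's shard dimension.

```python
import re

def merge_rank_checkpoints(shards):
    """Merge per-rank state dicts into one full state dict.

    `shards` maps a checkpoint filename (e.g. "model_state_rank_001.pth")
    to that rank's state dict. Each sharded parameter is represented here
    as a plain list; with real checkpoints each value would be a tensor
    loaded via torch.load() and merged with torch.cat().
    """
    # Sort by the rank number embedded in the filename so shard order
    # is preserved (rank 001 first, rank 016 last) regardless of the
    # order the files were discovered in.
    def rank_of(name):
        return int(re.search(r"rank_(\d+)", name).group(1))

    ordered = [shards[name] for name in sorted(shards, key=rank_of)]

    merged = {}
    for key in ordered[0]:
        # Concatenate each rank's shard of this parameter in rank order.
        merged[key] = [x for shard in ordered for x in shard[key]]
    return merged

# Two ranks, each holding half of one parameter.
shards = {
    "model_state_rank_002.pth": {"fc.weight": [3.0, 4.0]},
    "model_state_rank_001.pth": {"fc.weight": [1.0, 2.0]},
}
full = merge_rank_checkpoints(shards)
# full["fc.weight"] == [1.0, 2.0, 3.0, 4.0]
```

Note that this assumes every parameter is sharded and concatenated along the same dimension; replicated parameters (e.g. biases or non-parallel layers) would instead be taken from a single rank.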

@JiaquanYe
Author

> It is impossible to perform the gather operation on CPU because it relies on the NCCL backend. But there is a way to avoid gathering on GPU on the fly: save the state_dict of each shard locally, then write a post-processing script to merge them. For example, if using 16 GPUs across 16 ranks, save 16 checkpoints during training, like model_state_rank_001.pth, model_state_rank_002.pth, … and model_state_rank_016.pth. After training finishes, write a post-processing script to gather these 16 checkpoints into one. Pay attention to keeping the shards in the right order, and run an inference test to check the result.

It is an excellent solution! Thanks.

@kaiyuyue kaiyuyue added the Good Issue Good reference for newcomers label Aug 24, 2021