[FSDP2] Supported `set_all_reduce_gradients=False` for HSDP (pytorch#126166)

**Context**

For FSDP, gradient accumulation across microbatches has two flavors: (1) reduce-scatter or (2) no reduce-scatter. Flavor (1) incurs the collective per microbatch backward but saves gradient memory (storing the sharded gradients), while (2) avoids the communication but uses more gradient memory (storing the unsharded gradients).

- FSDP2 offers (1) without any intervention. The user should simply make sure to run the optimizer step only after `K` microbatches for `K > 1`.
- FSDP2 offers (2) via `module.set_requires_gradient_sync()` (e.g. `module.set_requires_gradient_sync(is_last_microbatch)`).

For HSDP, since we reduce-scatter and then all-reduce, we have additional flexibility and get three flavors: (1) reduce-scatter and all-reduce, (2) reduce-scatter but no all-reduce, and (3) no reduce-scatter and no all-reduce. This PR adds support for (2).

- FSDP2 offers (1) without any intervention, as mentioned above.
- FSDP2 offers (3) via `module.set_requires_gradient_sync()`, as mentioned above.
- FSDP2 offers (2) via `module.set_requires_all_reduce()`, similar to `set_requires_gradient_sync()`.

**Overview**

For HSDP, to reduce-scatter but not all-reduce during gradient accumulation, the user can do something like the following (a fuller end-to-end sketch appears after this message):

```
for microbatch_idx, microbatch in enumerate(microbatches):
    is_last_microbatch = microbatch_idx == len(microbatches) - 1
    model.set_requires_all_reduce(is_last_microbatch)
    # Run forward/backward
```

This PR also makes the minor change of making the `recurse: bool` argument in these setter methods keyword-only.

**Developer Notes**

We choose to implement this by saving the partial reduce output to the `FSDPParamGroup` for simplicity, where we assume that the set of parameters that receive gradients does not change across microbatches. An alternative would be to view into the partial reduce output per parameter and save a view to each parameter. We prefer to avoid this alternative for now because it introduces extra complexity: viewing when saving the partial reduce output to each parameter, accumulating into the views, and accumulating back into the last microbatch's reduce output.

Pull Request resolved: pytorch#126166
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
ghstack dependencies: pytorch#126067, pytorch#126070, pytorch#126161
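The following is a minimal end-to-end sketch of flavor (2) for HSDP. It assumes the FSDP2 `fully_shard` API on a 2D (replicate, shard) device mesh launched via `torchrun` across 8 ranks; the toy model, mesh shape, and loss are illustrative only, not part of the PR:

```
# Minimal sketch (assumptions: 8 GPUs, torchrun launch; toy model and loss).
import os

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
# 2x4 mesh: outer dim replicates (all-reduce), inner dim shards (reduce-scatter)
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
for layer in model:
    fully_shard(layer, mesh=mesh)
fully_shard(model, mesh=mesh)  # root wrap

optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
microbatches = [torch.randn(8, 16, device="cuda") for _ in range(4)]

for microbatch_idx, microbatch in enumerate(microbatches):
    is_last_microbatch = microbatch_idx == len(microbatches) - 1
    # Flavor (2): reduce-scatter every microbatch, but all-reduce only on the
    # last one; earlier microbatches accumulate into the partial
    # reduce-scatter output.
    model.set_requires_all_reduce(is_last_microbatch)
    model(microbatch).sum().backward()

optim.step()
optim.zero_grad()
```

For flavor (3) (no communication until the last microbatch), the same loop would call `model.set_requires_gradient_sync(is_last_microbatch)` instead, trading the per-microbatch reduce-scatter for unsharded gradient memory.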
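To make the Developer Notes concrete, here is a toy sketch of the group-level accumulation scheme; the class and method names are hypothetical stand-ins, not the actual `FSDPParamGroup` internals:

```
from typing import Optional

import torch

class ToyParamGroup:
    """Hypothetical stand-in for a group's post-backward reduction logic."""

    def __init__(self) -> None:
        # Partial reduce output saved across microbatches (flat, group-level),
        # assuming the same parameters receive gradients every microbatch.
        self._partial_reduce_output: Optional[torch.Tensor] = None

    def post_backward(
        self, reduce_scatter_output: torch.Tensor, requires_all_reduce: bool
    ) -> None:
        if not requires_all_reduce:
            # Flavor (2): defer the all-reduce; accumulate the reduce-scatter
            # output on the group instead of viewing it per parameter.
            if self._partial_reduce_output is None:
                self._partial_reduce_output = reduce_scatter_output.clone()
            else:
                self._partial_reduce_output += reduce_scatter_output
            return
        # Last microbatch: fold the accumulated partial output back in, then
        # all-reduce across the replicate mesh dim (collective omitted here).
        if self._partial_reduce_output is not None:
            reduce_scatter_output += self._partial_reduce_output
            self._partial_reduce_output = None
        # dist.all_reduce(reduce_scatter_output, group=replicate_process_group)
```

The per-parameter alternative rejected in the notes would replace the single flat buffer with one saved view per parameter, requiring the extra view bookkeeping the message describes.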
Showing 5 changed files with 98 additions and 64 deletions.