Skip to content

Enable torch.autocast with ZeRO #6993

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 89 commits into from
Jun 19, 2025
Merged

Enable torch.autocast with ZeRO #6993

merged 89 commits into from
Jun 19, 2025

Conversation

tohtana
Copy link
Contributor

@tohtana tohtana commented Feb 3, 2025

DeepSpeed supports mixed precision training, but the behavior is different from torch.autocast. DeepSpeed maintains parameters and gradients both in FP32 and a lower precision (FP16/BF16) (NVIDIA Apex AMP style) and computes all modules in the lower precision while torch.autocast maintains parameters in FP32 but computes only certain operators in the lower precision.
This leads to differences in:

  • performance: torch.autocast needs downcast in forward/backward
  • memory usage: DeepSpeed needs more memory to keep copies of parameters and gradients in lower precision
  • accuracy: torch.autocast has a list of modules that can safely be computed in lower precision. Some precision-sensitive operators (e.g. softmax) are computed in FP32.

To align DeepSpeed's behavior with torch.autocast when necessary, this PR adds the integration with torch.autocast with ZeRO. Here is an examples of the configuration.

"torch_autocast": {
  "enabled": true,
  "dtype": "bfloat16",
  "lower_precision_safe_modules": ["torch.nn.Linear", "torch.nn.Conv2d"]
}

Each configuration works as follows:

  • enabled: Enable the integration with torch.autocast if this is set to True. You don't need to call torch.autocast in your code. The grad scaler is also applied in the DeepSpeed optimizer.
  • dtype: lower precision dtype passed to torch.autocast. Gradients for allreduce (reduce-scatter) and parameters for allgather (only for ZeRO3) of lower_precision_safe_modules are also downcasted to this dtype.
  • lower_precision_safe_modules: Downcast for allreduce (reduce-scatter) and allgather (ZeRO3) are applied only to modules specified in this list. (The precision for PyTorch operators in forward/backward follows torch.autocast's policy, not this list.) You can set names of classes with their packages. If you don't set this item, DeepSpeed uses the default list: [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d].

Note that we only maintain FP32 parameters with this feature enabled. For consistency, you cannot enable fp16 or bf16 in DeepSpeed config.

tjruwase and others added 30 commits February 28, 2025 22:53
Fix #6772

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
…#6967)

- Issues with nv-sd updates, will follow up with a subsequent PR

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
NVIDIA Blackwell GPU generation has number 10. The SM code and
architecture should be `100`, but the current code generates `1.`,
because it expects a 2 characters string.

This change modifies the logic to consider it as a string that contains
a `.`, hence splits the string and uses the array of strings.

Signed-off-by: Fabien Dupont <fdupont@redhat.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: Logan Adams <loadams@microsoft.com>
Signed-off-by: Fabien Dupont <fdupont@redhat.com>
Co-authored-by: Fabien Dupont <fabiendupont@fabiendupont.fr>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
1. update intel oneAPI basekit to 2025.0
2. update torch/ipex/oneccl to 2.5

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Same as [this PR](#6922).
[affeb88](affeb88)
I noticed the CI updated the DCO check recently. Using the suggested
rebase method for sign-off would reintroduce many conflicts, so I opted
for a squash merge with sign-off instead. thanks: )

Signed-off-by: inkcherry <mingzhi.liu@intel.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Those files have code that gets run when importing them, so in systems
that doesn't support triton but have triton installed this causes
issues.

In general, I think it is better to import triton when it is installed
and supported.

Signed-off-by: Omar Elayan <oelayan@habana.ai>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Logan Adams <loadams@microsoft.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Fix #7014
Avoid naming collision on `partition()`

---------

Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Fix typos

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
BUGFIX for Apple Silicon hostname
#6497

---------

Signed-off-by: Fabien Dupont <fdupont@redhat.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: Logan Adams <loadams@microsoft.com>
Signed-off-by: inkcherry <mingzhi.liu@intel.com>
Signed-off-by: Roman Fitzjalen <romaactor@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Fabien Dupont <fabiendupont@fabiendupont.fr>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Liangliang Ma <1906710196@qq.com>
Co-authored-by: inkcherry <mingzhi.liu@intel.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
- Update existing workflows that use cu121 to cu124. Note, this means
that where we download torch latest, we will now be getting torch 2.6
rather than the torch latest 2.5 provided with cuda 12.1.
- Note, nv-nightly is failing in master currently due to unrelated
errors, so this could be ignored in this PR (nv-nightly tested locally,
where it passes with 12.1 and it also passes with 12.4).

---------

Signed-off-by: Fabien Dupont <fdupont@redhat.com>
Signed-off-by: Logan Adams <loadams@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: inkcherry <mingzhi.liu@intel.com>
Signed-off-by: Omar Elayan <oelayan@habana.ai>
Co-authored-by: Fabien Dupont <fabiendupont@fabiendupont.fr>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Liangliang Ma <1906710196@qq.com>
Co-authored-by: inkcherry <mingzhi.liu@intel.com>
Co-authored-by: Omar Elayan <142979319+oelayan7@users.noreply.github.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
This change is required to successfully build fp_quantizer extension on
ROCm.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
cc @tjruwase @jomayeri

---------

Co-authored-by: root <root@ftqtmec25000000.taxzvufipdhelhupulxcbvr15f.ux.internal.cloudapp.net>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Fix #7029
- Add Chinese blog for deepspeed windows
- Fix format in README.md

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Adding compile support for AIO library on AMD GPUs.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Make trace cache warnings configurable, and disabled by default.

Fix #6985, #4081, #5033, #5006, #5662

---------

Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Update CUDA compute capability for cross compile according to wiki page.
https://en.wikipedia.org/wiki/CUDA#GPUs_supported

---------

Signed-off-by: Hongwei <hongweichen@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Logan Adams <loadams@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Propagate API change.

Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
@tohtana
Copy link
Contributor Author

tohtana commented Apr 22, 2025

Hi @stas00, I tried to add the detection of nested autocast. This validation is called before the engine's forward.

I measured the default reduce_bucket_size 5e8 to consume 4GB peak memory usage when comms are in fp32. and only 1GB in bf16.

I see, I didn't know this behavior. It seems very weird that they allocate an additional buffer only for FP32, not for BF16. Perhaps this is a separate topic from this PR, but I will investigate it more when I have a chance.

@sfc-gh-sbekman
Copy link
Contributor

Thank you for looking into it, Masahiro. No problem doing it elsewhere.

Using torch mem profiler will be very helpful to see the reduction memory spikes

https://pytorch.org/blog/understanding-gpu-memory-1/ - it's very easy to set up - if you need help please let me know.

@sfc-gh-sbekman
Copy link
Contributor

Good, I can an assertion to detect that torch.autocast is enabled outside of DeepSpeed but ds_config doesn't set torch_autocast's enabled. Or it might be better to automatically enable it.

If it has to be on and it breaks nothing then automatically enabling it is probably a better idea to help with ease of use.

@tohtana
Copy link
Contributor Author

tohtana commented May 23, 2025

Hi @stas,
Thank you for your feedback!

If it has to be on and it breaks nothing then automatically enabling it is probably a better idea to help with ease of use.

After reviewing the design, I now feel automatically enabling it wouldn't be straightforward. This autocast feature sets some flags to parameters before the optimizer is initialized. However, we only know whether torch.autocast is enabled or not just before a forward pass call as with torch.autocast(...) is placed to wrap a forward call. Reinitializing parts of the optimizer at that point would complicate the code.

Given that, I think it’s better to throw an error with the explanation.

@stas00
Copy link
Collaborator

stas00 commented May 23, 2025

Then assert is the way to go, Masahiro

@tohtana
Copy link
Contributor Author

tohtana commented May 24, 2025

Then assert is the way to go, Masahiro

Thank you @stas00, then can you approve this PR?

@stas00
Copy link
Collaborator

stas00 commented May 27, 2025

Hmm, I can't just hit approve, that would be defeat the purpose of doing the review.

We have only discussed one small aspect of this PR, which has been resolved, but the rest of the PR I don't know and currently rushing to finish the porting of Ulysses to Hf/DS so until that is done I won't have time to do a serious review.

@tohtana tohtana enabled auto-merge (squash) June 19, 2025 20:23
@tohtana tohtana merged commit ed5f737 into master Jun 19, 2025
12 checks passed
@tohtana tohtana deleted the tohtana/support_autocast branch June 19, 2025 21:36
tohtana added a commit that referenced this pull request Jun 22, 2025
#6993 broke many paths in ZeRO1/2 optimizer. This PR fixes most of the
issues the PR caused. Currently we still have one error with tests in
`unit/runtime/zero`.

```
====================================== short test summary info ======================================
FAILED test_zero.py::TestParamPartitioningSkipInit::test[dtype1] - RuntimeError: mat1 and mat2 must have the same dtype, but got Half and BFloat16
========= 1 failed, 204 passed, 66 skipped, 15 deselected, 5 warnings in 2305.03s (0:38:25) =========
```

---------

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.