fixing optimizer sanity check (microsoft#2742)
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2 people authored and NonSkull committed Mar 6, 2023
1 parent 42e2046 commit 8510a51
Showing 2 changed files with 14 additions and 15 deletions.
25 changes: 10 additions & 15 deletions deepspeed/runtime/engine.py
@@ -881,7 +881,7 @@ def get_data_types(self):
             model_dtype = torch.bfloat16
 
         if self._config.grad_accum_dtype == None:
-            if model_dtype == torch.bfloat16:
+            if model_dtype == torch.bfloat16 and not self.zero_optimization():
                 grad_accum_dtype = torch.float32
             else:
                 grad_accum_dtype = model_dtype
@@ -1204,10 +1204,6 @@ def _do_optimizer_sanity_check(self, basic_optimizer):
             not (amp_enabled and zero_enabled)
         ), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
         if zero_enabled:
-            if model_dtype != grad_accum_dtype:
-                raise NotImplementedError(
-                    "Model data type and gradient accumulation data type must be equal to use ZeRO"
-                )
             if not is_zero_supported_optimizer(basic_optimizer):
                 assert (
                     self.zero_allow_untested_optimizer()
@@ -1217,16 +1213,15 @@ def _do_optimizer_sanity_check(self, basic_optimizer):
                 logger.warning(
                     "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                 )
-            # BF16 optimizer supports stage 1 optimizations
-            if model_dtype == torch.bfloat16:
-                if grad_accum_dtype != torch.float32:
-                    raise NotImplementedError(
-                        "BF16 optimizer for ZeRO requires fp32 gradient accumulation")
-                if self.zero_optimization_stage() == 1:
-                    return BFLOAT16
-                else:
-                    raise NotImplementedError(
-                        "ZeRO stages 2 and 3 are not supported with the BF16 optimizer")
+
+            if model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32 and self.zero_optimization_stage(
+            ) == 1:
+                return BFLOAT16
+
+            if model_dtype != grad_accum_dtype:
+                raise NotImplementedError(
+                    "Model data type and gradient accumulation data type must be equal to use ZeRO"
+                )
             return ZERO_OPTIMIZATION
         elif amp_enabled:
             if model_dtype != grad_accum_dtype:
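Taken together, the engine.py changes do three things: with ZeRO enabled, a bf16 model no longer defaults its gradient accumulation dtype to fp32; the BF16 optimizer wrapper is selected only for the bf16 model / fp32 accumulation / ZeRO stage 1 combination; and the dtype-equality check now runs after that special case instead of before it. The sketch below restates the ZeRO branch of the new logic as a standalone function. It is an illustrative approximation, not DeepSpeed code: the function name, the Wrapper enum, and the plain arguments stand in for the engine's config accessors and its BFLOAT16 / ZERO_OPTIMIZATION constants.

import torch
from enum import Enum

class Wrapper(Enum):
    # Stand-ins for DeepSpeed's BFLOAT16 and ZERO_OPTIMIZATION optimizer-mode constants.
    BFLOAT16 = "bf16_optimizer"
    ZERO_OPTIMIZATION = "zero_optimizer"

def pick_zero_wrapper(model_dtype, grad_accum_dtype, zero_stage):
    # With ZeRO enabled, an unspecified grad_accum_dtype now follows the model dtype
    # (previously a bf16 model forced fp32 accumulation at this point).
    if grad_accum_dtype is None:
        grad_accum_dtype = model_dtype

    # Special case checked first: bf16 weights with fp32 accumulation under ZeRO stage 1
    # go to the dedicated BF16 optimizer.
    if model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32 and zero_stage == 1:
        return Wrapper.BFLOAT16

    # Everything else must accumulate gradients in the model dtype to use the ZeRO optimizer.
    if model_dtype != grad_accum_dtype:
        raise NotImplementedError(
            "Model data type and gradient accumulation data type must be equal to use ZeRO")
    return Wrapper.ZERO_OPTIMIZATION

# Combinations mirroring the updated unit-test expectations:
assert pick_zero_wrapper(torch.bfloat16, torch.float32, 1) is Wrapper.BFLOAT16    # ('zero1', 'bf16', 'fp32')
assert pick_zero_wrapper(torch.bfloat16, None, 2) is Wrapper.ZERO_OPTIMIZATION    # ('zero2', 'bf16', None)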
4 changes: 4 additions & 0 deletions tests/unit/runtime/test_ds_initialize.py
@@ -172,12 +172,16 @@ def test(self, optimizer_extension, model_dtype, grad_accum_dtype):
         # ZeRO 1 Wrapper
         is_supported[('zero1', 'fp16', None)] = True
         is_supported[('zero1', 'fp16', 'fp16')] = True
+        is_supported[('zero1', 'bf16', None)] = True
+        is_supported[('zero1', 'bf16', 'bf16')] = True
         is_supported[('zero1', 'bf16', 'fp32')] = True
         is_supported[('zero1', 'fp32', None)] = True
         is_supported[('zero1', 'fp32', 'fp32')] = True
         # ZeRO 2 Wrapper
         is_supported[('zero2', 'fp16', None)] = True
         is_supported[('zero2', 'fp16', 'fp16')] = True
+        is_supported[('zero2', 'bf16', None)] = True
+        is_supported[('zero2', 'bf16', 'bf16')] = True
         is_supported[('zero2', 'fp32', None)] = True
         is_supported[('zero2', 'fp32', 'fp32')] = True
         # Amp Wrapper
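The new test entries correspond to user configurations along the following lines. This is a hypothetical usage sketch rather than anything taken from the test file; in particular, the data_types.grad_accum_dtype key and the single-process initialize call should be checked against the DeepSpeed version and launcher you actually use.

import torch
import deepspeed

# Hypothetical config for one of the newly supported combinations:
# bf16 weights, bf16 gradient accumulation, ZeRO stage 2.
ds_config = {
    "train_batch_size": 8,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2},
    # Leaving this block out matches the ('zero2', 'bf16', None) case;
    # spelling it out matches ('zero2', 'bf16', 'bf16').
    "data_types": {"grad_accum_dtype": "bf16"},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}

model = torch.nn.Linear(16, 16)
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)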
