Skip to content

Commit

Permalink
[zero] revert PR microsoft#3166, it disabled grad clip for bf16 (micr…
Browse files Browse the repository at this point in the history
…osoft#3790)

* zero++ tutorial PR (microsoft#3783)

* [Fix] _conv_flops_compute when padding is a str and stride=1 (microsoft#3169)

* fix conv_flops_compute when padding is a str when stride=1

* fix error

* change type of paddings to tuple

* fix padding calculation

* apply formatting check

---------

Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* fix interpolate flops compute (microsoft#3782)

* use `Flops Profiler` to test `model.generate()` (microsoft#2515)

* Update profiler.py

* pre-commit run --all-files

* Delete .DS_Store

* Delete .DS_Store

* Delete .DS_Store

---------

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>

* revert PR microsoft#3166, it disabled grad clip for bf16

* ensure no loss scaling for non-fp16 dtypes

* revert PR microsoft#3611 (microsoft#3786)

* bump to 0.9.6

* ZeRO++ chinese blog (microsoft#3793)

* zeropp chinese blog

* try better quality images

* make title larger

* even larger...

* various fix

* center captions

* more fixes

* fix format

* remove staging trigger (microsoft#3792)

* DeepSpeed-Triton for Inference (microsoft#3748)

Co-authored-by: Stephen Youn <styoun@microsoft.com>
Co-authored-by: Arash Bakhtiari <arash@bakhtiari.org>
Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Ethan Doe <yidoe@microsoft.com>
Co-authored-by: yidoe <68296935+yidoe@users.noreply.github.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* ZeRO++ (microsoft#3784)

Co-authored-by: HeyangQin <heyangqin@microsoft.com>
Co-authored-by: GuanhuaWang <alexwgh333@gmail.com>
Co-authored-by: cmikeh2 <connorholmes@microsoft.com>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>

* adding zero++ to navigation panel of deepspeed.ai (microsoft#3796)

* Add ZeRO++ Japanese blog (microsoft#3797)

* zeropp chinese blog

* try better quality images

* make title larger

* even larger...

* various fix

* center captions

* more fixes

* fix format

* add ZeRO++ Japanese blog

* add links

---------

Co-authored-by: HeyangQin <heyangqin@microsoft.com>
Co-authored-by: Conglong Li <conglong.li@gmail.com>

* Bug Fixes for autotuner and flops profiler (microsoft#1880)

* fix autotuner when backward is not called

* fix format

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* Missing strided copy for gated MLP (microsoft#3788)

Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>

* Requires grad checking. (microsoft#3789)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* bump to 0.10.0

* Fix Bug in transform.cu (microsoft#3534)

* Bug fix

* Fixed formatting error

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>

* bug fix: triton importing error (microsoft#3799)

Co-authored-by: Stephen Youn <styoun@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

---------

Co-authored-by: Heyang Qin <heyangqin@microsoft.com>
Co-authored-by: Bill Luo <50068224+zhiruiluo@users.noreply.github.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Guorun <84232793+CaffreyR@users.noreply.github.com>
Co-authored-by: stephen youn <13525892+stephen-youn@users.noreply.github.com>
Co-authored-by: Stephen Youn <styoun@microsoft.com>
Co-authored-by: Arash Bakhtiari <arash@bakhtiari.org>
Co-authored-by: Ethan Doe <yidoe@microsoft.com>
Co-authored-by: yidoe <68296935+yidoe@users.noreply.github.com>
Co-authored-by: GuanhuaWang <alexwgh333@gmail.com>
Co-authored-by: cmikeh2 <connorholmes@microsoft.com>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Conglong Li <conglong.li@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Joe Mayer <114769929+jomayeri@users.noreply.github.com>
Co-authored-by: Ramya Ramineni <62723901+rraminen@users.noreply.github.com>
  • Loading branch information
21 people authored and zhangir-azerbayev committed Jul 25, 2023
1 parent 22fda1e commit e4f92ad
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,11 @@ def __init__(self,
dynamic_loss_args=dynamic_loss_args)
self.dynamic_loss_scale = self.loss_scaler.dynamic

if self.dtype != torch.float16:
# Only fp16 should use dynamic loss scaling
assert self.loss_scaler.cur_scale == 1.0
assert not self.dynamic_loss_scale

see_memory_usage("Before initializing optimizer states", force=True)
self.initialize_optimizer_states()
see_memory_usage("After initializing optimizer states", force=True)
Expand Down Expand Up @@ -1670,21 +1675,19 @@ def step(self, closure=None):
self.stop_timers(timer_names)
return

# Step 1:- Calculate gradient norm using fp-16 grads
if self.dtype == torch.float16:
see_memory_usage('Before norm calculation')
scaled_global_grad_norm = self.scaled_global_norm()
self._global_grad_norm = scaled_global_grad_norm / prev_scale
see_memory_usage('After norm before optimizer')
# Step 1:- Calculate gradient norm using bit-16 grads
see_memory_usage('Before norm calculation')
scaled_global_grad_norm = self.scaled_global_norm()
self._global_grad_norm = scaled_global_grad_norm / prev_scale
see_memory_usage('After norm before optimizer')

# Step 2:- run optimizer and upscaling simultaneously
for i, group in enumerate(self.bit16_groups):
self.start_timers([OPTIMIZER_GRADIENTS])
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
if self.cpu_offload:
single_grad_partition = self.single_partition_of_fp32_groups[i].grad
if self.dtype == torch.float16:
self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)

self.stop_timers([OPTIMIZER_GRADIENTS])
self.start_timers([OPTIMIZER_STEP])
Expand Down Expand Up @@ -1724,8 +1727,7 @@ def step(self, closure=None):

self.averaged_gradients[i] = None

if self.dtype == torch.float16:
self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)

self.stop_timers([OPTIMIZER_GRADIENTS])

Expand Down

0 comments on commit e4f92ad

Please sign in to comment.