add async grad allreduce and chunk optimization (#4084)
* O2 runs but O1 does not

Signed-off-by: ericharper <complex451@gmail.com>

* disable async for O1

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

* update async flag in configure_optimizers

Signed-off-by: ericharper <complex451@gmail.com>

* typo

Signed-off-by: ericharper <complex451@gmail.com>

* revert

Signed-off-by: ericharper <complex451@gmail.com>

* update _require if using async

Signed-off-by: ericharper <complex451@gmail.com>

* clean comments

Signed-off-by: ericharper <complex451@gmail.com>

* always all_reduce

Signed-off-by: ericharper <complex451@gmail.com>

* add async grad allreduce and chunk optimization to T5

* push reformatted files after style check

* set chunk size to 0 when async grad allreduce is off

* more experiments show that 125MB is a better default chunk size for most cases

* add grad_allreduce_chunk_size_mb for GPT-3

* at the end of each training step, wait until all async grad allreduce works are done

* replace individual allreduce work.wait() calls with a single GPU device synchronization

* recording the status of each allreduce work seems too expensive for perf

* add more comments

* push a reformatted file

Co-authored-by: ericharper <complex451@gmail.com>
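
For illustration only, a hedged standalone sketch of the idea the commit describes, not the NeMo/apex implementation: gradients are grouped into roughly grad_allreduce_chunk_size_mb-sized chunks, each chunk is all-reduced asynchronously so communication overlaps with the rest of backprop, and one device synchronization at the end of the step replaces per-work wait() calls. The class name is hypothetical, and the sketch assumes PyTorch >= 2.1 (for register_post_accumulate_grad_hook), an initialized process group, and gradients that share a dtype and device.

import torch
import torch.distributed as dist

class ChunkedAsyncGradAllreduce:
    """Bucket gradients into ~chunk_size_mb chunks as they become ready during
    backward and launch one asynchronous all-reduce per chunk, so communication
    overlaps with the rest of backprop."""

    def __init__(self, params, chunk_size_mb=125):
        self.chunk_bytes = chunk_size_mb * 1024 * 1024
        self.bucket, self.bucket_bytes = [], 0
        self.pending = []  # (work_handle, flat_buffer, params_in_chunk)
        for p in params:
            if p.requires_grad:
                # fires once p.grad has been fully accumulated during backward
                p.register_post_accumulate_grad_hook(self._on_grad_ready)

    def _on_grad_ready(self, param):
        self.bucket.append(param)
        self.bucket_bytes += param.grad.numel() * param.grad.element_size()
        if self.bucket_bytes >= self.chunk_bytes:
            self._flush()

    def _flush(self):
        if not self.bucket:
            return
        # assumes all grads in the bucket share dtype/device
        flat = torch.cat([p.grad.reshape(-1) for p in self.bucket])
        flat.div_(dist.get_world_size())  # pre-divide so the SUM all-reduce yields an average
        work = dist.all_reduce(flat, async_op=True)  # returns immediately
        self.pending.append((work, flat, self.bucket))
        self.bucket, self.bucket_bytes = [], 0

    def finalize(self):
        self._flush()  # reduce whatever is left over
        # a single GPU device synchronization instead of waiting on each work object
        torch.cuda.synchronize()
        for _, flat, params in self.pending:
            offset = 0
            for p in params:
                n = p.grad.numel()
                p.grad.copy_(flat[offset:offset + n].view_as(p.grad))
                offset += n
        self.pending.clear()

With no pipeline stage waiting downstream, deferring the wait to one torch.cuda.synchronize() right before the optimizer step keeps the bookkeeping cheap, which is what the "recording the status of each allreduce work" bullet above alludes to.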
xrennvidia and ericharper committed May 17, 2022
1 parent 89994de commit de0b445
Showing 8 changed files with 151 additions and 47 deletions.
5 changes: 3 additions & 2 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -64,7 +64,6 @@ model:
pre_process: True # add embedding
post_process: True # add pooler
persist_layer_norm: True # Use of persistent fused layer norm kernel.
gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)

tokenizer:
library: 'megatron'
@@ -82,13 +81,15 @@ model:
fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16

# Megatron O2-style half-precision
megatron_amp_O2: False # Enable O2-level automatic mixed precision using master parameters
megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters
grad_allreduce_chunk_size_mb: 125

# miscellaneous
seed: 1234
use_cpu_initialization: False # Init weights on the CPU (slow for large models)
onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this
gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)

activations_checkpoint_method: null # 'uniform', 'block'
activations_checkpoint_num_layers: 1
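
For reference, an illustrative way to read the new key from the config above with OmegaConf; this snippet is not part of the commit and assumes it is run from the NeMo repo root so the path in the diff header resolves.

from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/nlp/language_modeling/conf/megatron_gpt_config.yaml")
# falls back to the same 125 MB default used by the model code further down
chunk_mb = cfg.model.get("grad_allreduce_chunk_size_mb", 125)
use_bucket_view = cfg.model.gradient_as_bucket_view  # now documented as a PyTorch DDP argument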
@@ -52,6 +52,7 @@ model:
post_process: True # add pooler

megatron_amp_O2: False # use AMP with O2 style mixed precision instead of native amp on-the-fly weight autocasting.
grad_allreduce_chunk_size_mb: 125

seq_length: 512
max_position_embeddings: ${.seq_length}
3 changes: 2 additions & 1 deletion examples/nlp/language_modeling/megatron_gpt_pretraining.py
@@ -37,9 +37,10 @@ def main(cfg) -> None:
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)

plugins = [
NLPDDPPlugin(
no_ddp_communication_hook=True,
no_ddp_communication_hook=True, # we don't use DDP for async grad allreduce
gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
find_unused_parameters=False,
)
2 changes: 1 addition & 1 deletion examples/nlp/language_modeling/megatron_t5_pretraining.py
@@ -40,7 +40,7 @@ def main(cfg) -> None:
megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
plugins = [
NLPDDPPlugin(
no_ddp_communication_hook=True,
no_ddp_communication_hook=True, # we don't use DDP for async grad allreduce
gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
find_unused_parameters=False,
)
40 changes: 28 additions & 12 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -172,7 +172,6 @@ def training_step(self, batch, batch_idx):
tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]

if self.cfg.get('pipeline_model_parallel_size', 1) > 1:

losses_reduced_per_micro_batch = forward_backward_pipelining_without_interleaving(
forward_step_func=self.get_forward_output_and_loss_func(),
batch=batch_for_pipeline,
@@ -183,6 +182,12 @@ def training_step(self, batch, batch_idx):
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
)
else:
# no pipeline parallelism so we reduce grads asynchronously
if self.megatron_amp_o2:
custom_sync_context_handler = self._optimizer.no_sync
else:
# TODO: enable async grad all reduce for O1/autocast mixed precision training
custom_sync_context_handler = None
losses_reduced_per_micro_batch = forward_backward_no_pipelining(
forward_step_func=self.get_forward_output_and_loss_func(),
batch=batch_for_pipeline,
@@ -191,6 +196,7 @@ def training_step(self, batch, batch_idx):
tensor_shape=tensor_shape,
dtype=self.autocast_dtype,
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
custom_sync_context_handler=custom_sync_context_handler,
)

# only the last stages of the pipeline return losses
Expand All @@ -202,19 +208,26 @@ def training_step(self, batch, batch_idx):
else:
loss_mean = torch.tensor(0.0).cuda()

# TODO: if we're not using pipeline, then we should do async allreduce (better perf)
# in order to do this with O2, we need the async handler to be added to apex fwd/bwd function
if self.megatron_amp_o2:
# main grads are stored in the MainParamsOptimizer wrapper
self._optimizer.allreduce_main_grads() # @sangkug we think this is fine

self.allreduce_first_last_embeddings()
# when using pipeline parallelism grads must be reduced after the pipeline (not asynchronously)
if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
# main grads are stored in the MainParamsOptimizer wrapper
self._optimizer.allreduce_main_grads()
else:

# async grad allreduce is not currently implemented for O1/autocasting mixed precision training
# so we allreduce gradients after the pipeline
self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf)

if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
# when using pipeline parallelism, the first and last stages must keep embeddings in sync
self.allreduce_first_last_embeddings()

# when async grad allreduce is enabled, bprop keeps moving forward without waiting for
# the async grad allreduce work to finish. Hence, to guarantee correct gradient reduction,
# we cannot start the weight update until all async grad allreduce work is done.
if self.megatron_amp_o2 and self.cfg.get('pipeline_model_parallel_size', 1) == 1:
torch.cuda.synchronize()

## logging
# we can only log on one rank if it is rank zero so we broadcast from last rank
# we can avoid this broadcast by updating the PTL log function to accept specific ranks
@@ -616,17 +629,20 @@ def configure_optimizers(self):
"fp16 training is not yet supported with O2. Please set megatron_amp_O2 to False in the model config."
)

# TODO: this should be true when not using pipeline parallelism
# we will support that for bf16 when we have async handler from apex
# and we will support it for fp16 when we have it implemented in the O2 recipe
async_grad_allreduce = False
# if using tensor parallel only, we can use async grad all-reduce
if self.cfg.get('pipeline_model_parallel_size', 1) == 1:
async_grad_allreduce = True
else:
async_grad_allreduce = False

self._optimizer = MainParamsOptimizerWrapper(
self._optimizer,
fp32_grad_accum=fp32_grad_accum,
contiguous_grad_bucket=contiguous_grad_bucket,
async_grad_allreduce=async_grad_allreduce,
grad_allreduce_chunk_size_mb=self.cfg.get('grad_allreduce_chunk_size_mb', 125),
)

assert self._trainer.max_steps is not None, "'max_steps' is missing in trainer config."
sched_config = self._cfg.optim.sched
sched_config['max_steps'] = self._trainer.max_steps
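
A condensed, hedged outline of the control flow these hunks add to training_step (the same pattern is repeated for the encoder-decoder model below). Here optimizer stands in for the MainParamsOptimizerWrapper, allreduce_gradients for the O1 path, model(batch) is assumed to return a scalar loss, and the pipeline schedule and the embedding all-reduce are not modeled, so treat it as an outline rather than the actual implementation.

import torch
from contextlib import nullcontext

def training_step_outline(model, optimizer, micro_batches, use_pipeline, amp_o2, allreduce_gradients):
    # with tensor parallelism only (no pipeline) and O2, suppress per-micro-batch
    # gradient sync so the wrapper can all-reduce asynchronously on the last one
    sync_ctx = optimizer.no_sync() if (amp_o2 and not use_pipeline) else nullcontext()

    with sync_ctx:
        for batch in micro_batches[:-1]:
            model(batch).backward()        # accumulate grads, no communication yet
    model(micro_batches[-1]).backward()    # async grad all-reduce overlaps with this backward

    if amp_o2:
        if use_pipeline:
            # grads must be reduced after the pipeline, not asynchronously;
            # main grads live in the MainParamsOptimizer wrapper
            optimizer.allreduce_main_grads()
        else:
            # backprop returns before the async all-reduce work completes, so wait
            # for the whole device once before starting the weight update
            torch.cuda.synchronize()
    else:
        # O1/autocast: async grad all-reduce is not implemented yet, reduce afterwards
        allreduce_gradients()

    optimizer.step()
    optimizer.zero_grad()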
@@ -216,6 +216,12 @@ def training_step(self, batch, batch_idx):
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
)
else:
# no pipeline parallelism so we reduce grads asynchronously
if self.megatron_amp_o2:
custom_sync_context_handler = self._optimizer.no_sync
else:
# TODO: enable async grad all reduce for O1/autocast mixed precision training
custom_sync_context_handler = None
losses_reduced_per_micro_batch = forward_backward_no_pipelining(
forward_step_func=self.get_forward_output_and_loss_func(),
batch=batch_for_pipeline,
@@ -225,6 +231,7 @@ def training_step(self, batch, batch_idx):
decoder_sequence_length=decoder_seq_length,
dtype=self.autocast_dtype,
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
custom_sync_context_handler=custom_sync_context_handler,
)

# only the last stages of the pipeline return losses
@@ -236,19 +243,26 @@ def training_step(self, batch, batch_idx):
else:
loss_mean = torch.tensor(0.0).cuda()

# TODO: if we're not using pipeline, then we should do async allreduce (better perf)
# in order to do this with O2, we need the async handler to be added to apex fwd/bwd function
if self.megatron_amp_o2:
# main grads are stored in the MainParamsOptimizer wrapper
self._optimizer.allreduce_main_grads() # @sangkug we think this is fine

self.allreduce_word_and_position_embeddings()
# when using pipeline parallelism grads must be reduced after the pipeline (not asynchronously)
if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
# main grads are stored in the MainParamsOptimizer wrapper
self._optimizer.allreduce_main_grads()
else:

# async grad allreduce is not currently implemented for O1/autocasting mixed precision training
# so we allreduce gradients after the pipeline
self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf)

if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
# when using pipeline parallelism, we need to keep the word and position embeddings in sync
self.allreduce_word_and_position_embeddings()

# when async grad allreduce is enabled, bprop keeps moving forward without waiting for
# the async grad allreduce work to finish. Hence, to guarantee correct gradient reduction,
# we cannot start the weight update until all async grad allreduce work is done.
if self.megatron_amp_o2 and self.cfg.get('pipeline_model_parallel_size', 1) == 1:
torch.cuda.synchronize()

## logging
# we can only log on one rank if it is rank zero so we broadcast from last rank
# we can avoid this broadcast by updating the PTL log function to accept specific ranks
@@ -270,6 +284,7 @@ def training_step(self, batch, batch_idx):
prog_bar=True,
rank_zero_only=True,
)

return loss_mean

def on_train_batch_end(self, outputs, batch, batch_idx: int, unused: Optional[int] = 0) -> None:
@@ -720,17 +735,20 @@ def configure_optimizers(self):
# TODO: contiguous grad bucket for fp16 is also planned to be supported
contiguous_grad_bucket = False

# TODO: this should be true when not using pipeline parallelism
# we will support that for bf16 when we have async handler from apex
# and we will support it for fp16 when we have it implemented in the O2 recipe
async_grad_allreduce = False
# if using tensor parallel only, we can use async grad all-reduce
if self.cfg.get('pipeline_model_parallel_size', 1) == 1:
async_grad_allreduce = True
else:
async_grad_allreduce = False

self._optimizer = MainParamsOptimizerWrapper(
self._optimizer,
fp32_grad_accum=fp32_grad_accum,
contiguous_grad_bucket=contiguous_grad_bucket,
async_grad_allreduce=async_grad_allreduce,
grad_allreduce_chunk_size_mb=self.cfg.get('grad_allreduce_chunk_size_mb', 125),
)

assert self._trainer.max_steps is not None, "'max_steps' is missing in trainer config."
if hasattr(self._cfg.optim, 'sched'):
sched_config = self._cfg.optim.sched
10 changes: 2 additions & 8 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -110,7 +110,7 @@ def configure_ddp(self):
process_group=parallel_state.get_data_parallel_group(),
**self._ddp_kwargs,
)
self._register_ddp_hooks()

if self.no_ddp_communication_hook:
# When using custom gradient accumulation and allreduce, disable
# DDP communication hook that works on the gradient bucket.
@@ -582,13 +582,7 @@ def optimizer_step(

if self.scaler is None:
assert optimizer.fp32_grad_accumulation, "BF16 uses FP32 grad accumulation"
if optimizer.async_master_grads_allreudce:
# Execute the last step with asynchronous master gradients all-reduce
with optimizer.grad_sync():
_ = closure()
else:
_ = closure()

_ = closure()
self._after_closure(model, optimizer, optimizer_idx)
return optimizer.step(**kwargs)

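
The plugin disables the default DDP communication hook when no_ddp_communication_hook is set, since gradient reduction is handled by the custom async path instead. One generic way to get that effect with stock PyTorch DDP is to register a no-op communication hook; this is a sketch only (the hook name is mine), not what NLPDDPPlugin does internally, which is outside this diff.

import torch
import torch.distributed as dist

def noop_hook(state, bucket: dist.GradBucket) -> torch.futures.Future:
    # return the bucket's gradients unchanged so DDP performs no all-reduce of
    # its own, leaving reduction to the custom asynchronous path
    fut = torch.futures.Future()
    fut.set_result(bucket.buffer())
    return fut

# usage: ddp_model is a torch.nn.parallel.DistributedDataParallel instance
# ddp_model.register_comm_hook(state=None, hook=noop_hook)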