add async grad allreduce and chunk optimization #4084

Merged: 28 commits, May 17, 2022

Commits (28)
b3f7a84
O2 runs but O1 does not
ericharper Mar 18, 2022
b91f0e3
disable async for O1
ericharper Mar 22, 2022
32120c9
Merge branch 'main' into gpt_sync_handler
ericharper Mar 22, 2022
f49ec58
typo
ericharper Mar 23, 2022
73eb1a5
update async flag in configure_optimizers
ericharper Mar 23, 2022
d79cee2
typo
ericharper Mar 23, 2022
f9355ee
Merge branch 'main' into gpt_sync_handler
ericharper Mar 24, 2022
8fef3a9
revert
ericharper Mar 24, 2022
2e4da67
Merge branch 'gpt_sync_handler' of github.com:NVIDIA/NeMo into gpt_sy…
ericharper Mar 24, 2022
5f7a45b
update _require if using async
ericharper Mar 24, 2022
21e8e34
Merge branch 'main' into gpt_sync_handler
ericharper Mar 25, 2022
eabaf62
Merge branch 'main' into gpt_sync_handler
ericharper Apr 18, 2022
e0271a8
Merge branch 'main' into gpt_sync_handler
ericharper Apr 21, 2022
fcb58de
clean comments
ericharper Apr 21, 2022
4761b24
always all_reduce
ericharper Apr 27, 2022
b0ba0a9
add async grad allreduce and chunk optimization to T5
xrennvidia Apr 28, 2022
08b2a70
Merge branch 'main' into xren/t5_sync_handler
ericharper Apr 29, 2022
04eac5e
push reformatted files after style check
xrennvidia Apr 29, 2022
db9cbb9
Merge branch 'main' into xren/t5_sync_handler
ericharper Apr 29, 2022
28b795c
set chunk size as 0 while async grad allreduce is off
xrennvidia Apr 30, 2022
f6b2f3b
more experiments show that 125MB is a better default chunk size for m…
xrennvidia May 3, 2022
382c533
add grad_allreduce_chunk_size_mb for GPT-3
xrennvidia May 3, 2022
36dcfd2
at the end of each training step, wait until all async grad allreduce…
xrennvidia May 11, 2022
e7e8124
replace individual allreduce work.wait() with a single GPU device syn…
xrennvidia May 11, 2022
6b8b59e
record the status of each allreduce work seems too much for perf
xrennvidia May 12, 2022
2280634
add more comments
xrennvidia May 12, 2022
0469315
push a reformatted file
xrennvidia May 12, 2022
2ba62cf
Merge branch 'main' into xren/t5_sync_handler
ericharper May 17, 2022
5 changes: 3 additions & 2 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -64,7 +64,6 @@ model:
pre_process: True # add embedding
post_process: True # add pooler
persist_layer_norm: True # Use of persistent fused layer norm kernel.
gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)

tokenizer:
library: 'megatron'
@@ -82,13 +81,15 @@ model:
fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16

# Megatron O2-style half-precision
megatron_amp_O2: False # Enable O2-level automatic mixed precision using master parameters
megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters
grad_allreduce_chunk_size_mb: 125

# miscellaneous
seed: 1234
use_cpu_initialization: False # Init weights on the CPU (slow for large models)
onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this
gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)

activations_checkpoint_method: null # 'uniform', 'block'
activations_checkpoint_num_layers: 1
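The new grad_allreduce_chunk_size_mb key controls how much gradient data goes into each asynchronous all-reduce. As a rough, illustrative sketch (not the NeMo/Apex code path; the helper names below are hypothetical), chunked async gradient all-reduce packs gradients into buffers of roughly that size and launches one non-blocking collective per buffer:

```python
# Illustrative sketch of chunked async grad all-reduce; helper names are hypothetical.
import torch
import torch.distributed as dist


def launch_chunked_grad_allreduce(params, chunk_size_mb=125):
    """Pack gradients into ~chunk_size_mb buckets and all-reduce each bucket
    asynchronously, so communication overlaps with the rest of the backward pass."""
    limit_bytes = chunk_size_mb * 1024 * 1024
    handles, bucket, bucket_bytes = [], [], 0

    def flush():
        grads = [p.grad for p in bucket]
        flat = torch.cat([g.reshape(-1) for g in grads])  # one contiguous buffer per chunk
        work = dist.all_reduce(flat, op=dist.ReduceOp.SUM, async_op=True)
        handles.append((work, flat, grads))

    for p in params:
        if p.grad is None:
            continue
        bucket.append(p)
        bucket_bytes += p.grad.numel() * p.grad.element_size()
        if bucket_bytes >= limit_bytes:
            flush()
            bucket, bucket_bytes = [], 0
    if bucket:
        flush()
    return handles


def finish_chunked_grad_allreduce(handles):
    """Block until every chunk is reduced, then copy the averaged result back into p.grad."""
    world_size = dist.get_world_size()
    for work, flat, grads in handles:
        work.wait()
        flat /= world_size  # average across data-parallel ranks
        offset = 0
        for g in grads:
            g.copy_(flat[offset:offset + g.numel()].view_as(g))
            offset += g.numel()
```

Larger chunks mean fewer collectives but coarser overlap with the backward pass; per the commit history above, experiments settled on 125 MB as the default.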
@@ -52,6 +52,7 @@ model:
post_process: True # add pooler

megatron_amp_O2: False # use AMP with O2 style mixed precision instead of native amp on-the-fly weight autocasting.
grad_allreduce_chunk_size_mb: 125

seq_length: 512
max_position_embeddings: ${.seq_length}
3 changes: 2 additions & 1 deletion examples/nlp/language_modeling/megatron_gpt_pretraining.py
@@ -37,9 +37,10 @@ def main(cfg) -> None:
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)

plugins = [
NLPDDPPlugin(
no_ddp_communication_hook=True,
no_ddp_communication_hook=True, # we don't use DDP for async grad allreduce
gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
find_unused_parameters=False,
)
2 changes: 1 addition & 1 deletion examples/nlp/language_modeling/megatron_t5_pretraining.py
@@ -40,7 +40,7 @@ def main(cfg) -> None:
megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
plugins = [
NLPDDPPlugin(
no_ddp_communication_hook=True,
no_ddp_communication_hook=True, # we don't use DDP for async grad allreduce
gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
find_unused_parameters=False,
)
40 changes: 28 additions & 12 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -172,7 +172,6 @@ def training_step(self, batch, batch_idx):
tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]

if self.cfg.get('pipeline_model_parallel_size', 1) > 1:

losses_reduced_per_micro_batch = forward_backward_pipelining_without_interleaving(
forward_step_func=self.get_forward_output_and_loss_func(),
batch=batch_for_pipeline,
@@ -183,6 +182,12 @@ def training_step(self, batch, batch_idx):
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
)
else:
# no pipeline parallelism so we reduce grads asynchronously
if self.megatron_amp_o2:
custom_sync_context_handler = self._optimizer.no_sync
else:
# TODO: enable async grad all reduce for O1/autocast mixed precision training
custom_sync_context_handler = None
losses_reduced_per_micro_batch = forward_backward_no_pipelining(
forward_step_func=self.get_forward_output_and_loss_func(),
batch=batch_for_pipeline,
@@ -191,6 +196,7 @@ def training_step(self, batch, batch_idx):
tensor_shape=tensor_shape,
dtype=self.autocast_dtype,
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
custom_sync_context_handler=custom_sync_context_handler,
)
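The custom_sync_context_handler passed here serves the same purpose as DDP's no_sync(): gradient synchronization is skipped for all but the last micro-batch, so the reduction happens once per step and can overlap with the final backward pass. A minimal sketch of that general pattern, shown with plain torch DDP rather than the Apex forward-backward function:

```python
# Generic gradient-accumulation pattern that custom_sync_context_handler mirrors;
# shown with torch DDP for illustration, not the Apex forward_backward internals.
from contextlib import nullcontext


def accumulate_micro_batches(ddp_model, optimizer, micro_batches, loss_fn):
    for i, micro_batch in enumerate(micro_batches):
        last = i == len(micro_batches) - 1
        # skip gradient synchronization for every micro-batch except the last one
        ctx = nullcontext() if last else ddp_model.no_sync()
        with ctx:
            loss = loss_fn(ddp_model(micro_batch))
            loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

With megatron_amp_O2 the handler is self._optimizer.no_sync, so the MainParamsOptimizerWrapper controls when its chunked all-reduces fire; for O1/autocast it stays None, which is what the TODO above refers to.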

# only the last stages of the pipeline return losses
@@ -202,19 +208,26 @@ def training_step(self, batch, batch_idx):
else:
loss_mean = torch.tensor(0.0).cuda()

# TODO: if we're not using pipeline, then we should do async allreduce (better perf)
# in order to do this with O2, we need the async handler to be added to apex fwd/bwd function
if self.megatron_amp_o2:
# main grads are stored in the MainParamsOptimizer wrapper
self._optimizer.allreduce_main_grads() # @sangkug we think this is fine

self.allreduce_first_last_embeddings()
# when using pipeline parallelism grads must be reduced after the pipeline (not asynchronously)
if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
# main grads are stored in the MainParamsOptimizer wrapper
self._optimizer.allreduce_main_grads()
else:

# async grad allreduce is not currently implemented for O1/autocasting mixed precision training
# so we allreduce gradients after the pipeline
self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf)

if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
# when using pipeline parallelism the first and last stage must keep embeddings in sync
self.allreduce_first_last_embeddings()

# while async grad allreduce is enabled, bprop will keep moving forward without waiting for
# the async grad allreduce work to finish. Hence, to guarantee correct gradient reduction,
# we cannot start the weight update until all async grad allreduce work is done.
if self.megatron_amp_o2 and self.cfg.get('pipeline_model_parallel_size', 1) == 1:
torch.cuda.synchronize()

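Per the commit history, this single torch.cuda.synchronize() replaced waiting on each asynchronous all-reduce handle individually, because tracking every Work object was too costly. Roughly, the two alternatives look like this (outstanding_handles is a hypothetical list of Work objects returned by async collectives):

```python
import torch


def wait_per_handle(outstanding_handles):
    # Option A: wait on each asynchronous all-reduce Work handle individually.
    # Per the commits above, the bookkeeping for every handle hurt performance.
    for work in outstanding_handles:
        work.wait()


def wait_whole_device():
    # Option B (what this hunk does): one blocking, device-wide synchronization
    # right before the weight update.
    torch.cuda.synchronize()
```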
## logging
# we can only log on one rank if it is rank zero so we broadcast from last rank
# we can avoid this broadcast by updating the PTL log function to accept specific ranks
@@ -616,17 +629,20 @@ def configure_optimizers(self):
"fp16 training is not yet supported with O2. Please set megatron_amp_O2 to False in the model config."
)

# TODO: this should be true when not using pipeline parallelism
# we will support that for bf16 when we have async handler from apex
# and we will support it for fp16 when we have it implemented in the O2 recipe
async_grad_allreduce = False
# if using tensor parallel only, we can use async grad all-reduce
if self.cfg.get('pipeline_model_parallel_size', 1) == 1:
async_grad_allreduce = True
else:
async_grad_allreduce = False

self._optimizer = MainParamsOptimizerWrapper(
self._optimizer,
fp32_grad_accum=fp32_grad_accum,
contiguous_grad_bucket=contiguous_grad_bucket,
async_grad_allreduce=async_grad_allreduce,
grad_allreduce_chunk_size_mb=self.cfg.get('grad_allreduce_chunk_size_mb', 125),
)

assert self._trainer.max_steps is not None, "'max_steps' is missing in trainer config."
sched_config = self._cfg.optim.sched
sched_config['max_steps'] = self._trainer.max_steps
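Putting the configure_optimizers and training_step changes together, the conditions under which the new asynchronous, chunked gradient all-reduce path is exercised can be summarized as follows (a paraphrase of the hunks above, not an additional NeMo API):

```python
def uses_async_chunked_grad_allreduce(cfg, megatron_amp_o2: bool) -> bool:
    """Sketch of when this PR's fast path is active (paraphrase, not NeMo API)."""
    # configure_optimizers only enables async all-reduce without pipeline parallelism,
    # and training_step only installs the no_sync handler on the O2 (main params) path.
    no_pipeline = cfg.get('pipeline_model_parallel_size', 1) == 1
    return bool(megatron_amp_o2) and no_pipeline
```

When the path is active, the chunk size comes from cfg.get('grad_allreduce_chunk_size_mb', 125); otherwise gradients are still reduced after the forward-backward pass as before.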
@@ -216,6 +216,12 @@ def training_step(self, batch, batch_idx):
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
)
else:
# no pipeline parallelism so we reduce grads asynchronously
if self.megatron_amp_o2:
custom_sync_context_handler = self._optimizer.no_sync
else:
# TODO: enable async grad all reduce for O1/autocast mixed precision training
custom_sync_context_handler = None
losses_reduced_per_micro_batch = forward_backward_no_pipelining(
forward_step_func=self.get_forward_output_and_loss_func(),
batch=batch_for_pipeline,
@@ -225,6 +231,7 @@ def training_step(self, batch, batch_idx):
decoder_sequence_length=decoder_seq_length,
dtype=self.autocast_dtype,
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
custom_sync_context_handler=custom_sync_context_handler,
)

# only the last stages of the pipeline return losses
@@ -236,19 +243,26 @@ def training_step(self, batch, batch_idx):
else:
loss_mean = torch.tensor(0.0).cuda()

# TODO: if we're not using pipeline, then we should do async allreduce (better perf)
# in order to do this with O2, we need the async handler to be added to apex fwd/bwd function
if self.megatron_amp_o2:
# main grads are stored in the MainParamsOptimizer wrapper
self._optimizer.allreduce_main_grads() # @sangkug we think this is fine

self.allreduce_word_and_position_embeddings()
# when using pipeline parallelism grads must be reduced after the pipeline (not asynchronously)
if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
# main grads are stored in the MainParamsOptimizer wrapper
self._optimizer.allreduce_main_grads()
else:

# async grad allreduce is not currently implemented for O1/autocasting mixed precision training
# so we allreduce gradients after the pipeline
self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf)

if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
# when using pipeline parallelism, we need to keep the word and position embeddings in sync
self.allreduce_word_and_position_embeddings()

# while async grad allreduce is enabled, bprop will keep moving forward without waiting for
# the async grad allreduce work to finish. Hence, to guarantee correct gradient reduction,
# we cannot start the weight update until all async grad allreduce work is done.
if self.megatron_amp_o2 and self.cfg.get('pipeline_model_parallel_size', 1) == 1:
torch.cuda.synchronize()

## logging
# we can only log on one rank if it is rank zero so we broadcast from last rank
# we can avoid this broadcast by updating the PTL log function to accept specific ranks
@@ -270,6 +284,7 @@ def training_step(self, batch, batch_idx):
prog_bar=True,
rank_zero_only=True,
)

return loss_mean

def on_train_batch_end(self, outputs, batch, batch_idx: int, unused: Optional[int] = 0) -> None:
@@ -720,17 +735,20 @@ def configure_optimizers(self):
# TODO: contiguous grad bucket for fp16 is also planned to be supported
contiguous_grad_bucket = False

# TODO: this should be true when not using pipeline parallelism
# we will support that for bf16 when we have async handler from apex
# and we will support it for fp16 when we have it implemented in the O2 recipe
async_grad_allreduce = False
# if using tensor parallel only, we can use async grad all-reduce
if self.cfg.get('pipeline_model_parallel_size', 1) == 1:
async_grad_allreduce = True
else:
async_grad_allreduce = False

self._optimizer = MainParamsOptimizerWrapper(
self._optimizer,
fp32_grad_accum=fp32_grad_accum,
contiguous_grad_bucket=contiguous_grad_bucket,
async_grad_allreduce=async_grad_allreduce,
grad_allreduce_chunk_size_mb=self.cfg.get('grad_allreduce_chunk_size_mb', 125),
)

assert self._trainer.max_steps is not None, "'max_steps' is missing in trainer config."
if hasattr(self._cfg.optim, 'sched'):
sched_config = self._cfg.optim.sched
10 changes: 2 additions & 8 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -110,7 +110,7 @@ def configure_ddp(self):
process_group=parallel_state.get_data_parallel_group(),
**self._ddp_kwargs,
)
self._register_ddp_hooks()

if self.no_ddp_communication_hook:
# When using custom gradient accumulation and allreduce, disable
# DDP communication hook that works on the gradient bucket.
@@ -582,13 +582,7 @@ def optimizer_step(

if self.scaler is None:
assert optimizer.fp32_grad_accumulation, "BF16 uses FP32 grad accumulation"
if optimizer.async_master_grads_allreudce:
# Execute the last step with asynchronous master gradients all-reduce
with optimizer.grad_sync():
_ = closure()
else:
_ = closure()

_ = closure()
self._after_closure(model, optimizer, optimizer_idx)
return optimizer.step(**kwargs)

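For context, the no_ddp_communication_hook branch in configure_ddp above keeps DDP from all-reducing its own gradient buckets, leaving the reduction to the optimizer wrapper. A hedged sketch of what disabling the hook can look like (illustrative only; the plugin's actual code is collapsed in this diff):

```python
# Illustrative only: a no-op DDP communication hook that returns the bucket's
# gradients unchanged, so DDP performs no all-reduce of its own.
import torch
import torch.distributed as dist


def noop_comm_hook(state, bucket: dist.GradBucket) -> torch.futures.Future:
    fut = torch.futures.Future()
    fut.set_result(bucket.buffer())
    return fut


# Usage sketch:
#   ddp_model.register_comm_hook(state=None, hook=noop_comm_hook)
```

PyTorch also ships an equivalent noop_hook under torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks.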