
Implement async distributed checkpoint save #9028

Merged: 86 commits, May 15, 2024

Commits (86)
1ade5a6
Prevent duplicated checkpoints
mikolajblaz Apr 18, 2024
adba993
Introduce DistributedCheckpointIO
mikolajblaz Apr 18, 2024
8dbd4f0
Fix DistCkptIO usage
mikolajblaz Apr 19, 2024
5c66f7a
Use NeMo logger
mikolajblaz Apr 19, 2024
a87313f
[DCIO] Fix save_to dist ckpt path
mikolajblaz Apr 22, 2024
dafba35
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2024
5ba8f13
Add versioning to save_to
mikolajblaz Apr 22, 2024
c02d3f5
Add versioning logic to all .nemo files
mikolajblaz Apr 23, 2024
96361d4
Add versioning test
mikolajblaz Apr 23, 2024
e6a9706
Add dist-ckpt test
mikolajblaz Apr 23, 2024
76df0d4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2024
272233f
Rename existing ckpts instead of using different name
mikolajblaz Apr 23, 2024
10297aa
Add comment
mikolajblaz Apr 23, 2024
1322cc9
Merge branch 'main' into mblaz/prevent-duplicated-checkpoints
mikolajblaz Apr 23, 2024
9ac26f6
Use dist ckpt flag in all methods
mikolajblaz Apr 23, 2024
051496b
Improve error msg
mikolajblaz Apr 23, 2024
fd80a90
Add dist ckpt unit tests
mikolajblaz Apr 23, 2024
048ccf4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2024
2504e01
Fix load_checkpoint
mikolajblaz Apr 23, 2024
7bcaf2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2024
b8fc103
Merge remote-tracking branch 'origin/mblaz/dist-ckpt-io' into mblaz/d…
mikolajblaz Apr 23, 2024
db642eb
Fix auto-issues
mikolajblaz Apr 24, 2024
78715de
Fix ckpt_dir var
mikolajblaz Apr 24, 2024
86e3b16
Restore skipping behavior
mikolajblaz Apr 25, 2024
aba83f1
Fix steps on single-GPU machine
mikolajblaz Apr 25, 2024
37f845f
Run dist-ckpt test on GPU
mikolajblaz Apr 25, 2024
332108c
Add docs
mikolajblaz Apr 25, 2024
7557358
Apply black
mikolajblaz Apr 25, 2024
863abfd
Merge branch 'main' into mblaz/dist-ckpt-io
mikolajblaz Apr 26, 2024
5f01f91
Prevent saving last for non-equal val intervals
mikolajblaz Apr 26, 2024
23fb08a
Move checkpoint on rank 0
mikolajblaz Apr 29, 2024
daf9eb6
Fix num steps in tests
mikolajblaz Apr 29, 2024
d5e50d6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2024
feec31d
Merge branch 'mblaz/dist-ckpt-io' into mblaz/async-dist-ckpt-minimal-…
mikolajblaz Apr 29, 2024
b4dbeb3
Add async ckpt implementation
mikolajblaz Apr 17, 2024
e4807f7
Abstract AsyncFinalizableCheckpointIO away
mikolajblaz Apr 19, 2024
f9d58e4
Change async_save flag location
mikolajblaz Apr 19, 2024
6e0f997
Add debug info
mikolajblaz Apr 22, 2024
78919b4
Apply formatting
mikolajblaz Apr 24, 2024
4043806
Handle multiple async saves
mikolajblaz Apr 24, 2024
211042a
Apply formatting
mikolajblaz Apr 24, 2024
76e1b18
Move finalization calls to a callback
mikolajblaz Apr 24, 2024
2bd6366
Avoid deadlock in teardown
mikolajblaz Apr 24, 2024
4033f59
Adjust to MCore implementation
mikolajblaz Apr 24, 2024
7e39dc2
Add notes and copyrights
mikolajblaz Apr 24, 2024
843a76e
Apply formatting
mikolajblaz Apr 24, 2024
b7c15e5
Fix async_request attribute
mikolajblaz Apr 24, 2024
06541e6
Add MCore import guards
mikolajblaz Apr 25, 2024
9ad0795
Add async test
mikolajblaz Apr 25, 2024
7d8f3d7
Fix finalize_fn arg
mikolajblaz Apr 25, 2024
e0cf16b
Add docs
mikolajblaz Apr 25, 2024
466d38f
Remove checkpoints from accurate steps
mikolajblaz Apr 25, 2024
58bb90d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2024
19561e6
Fix MCore class usage
mikolajblaz Apr 25, 2024
8c1f854
Update docs
mikolajblaz Apr 26, 2024
d377216
Fix logger usage
mikolajblaz Apr 26, 2024
c535c82
Fix rebase
mikolajblaz Apr 29, 2024
50068b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2024
b9f1ba2
Fix code scan issues
mikolajblaz Apr 29, 2024
cb495ec
Merge branch 'main' into mblaz/async-dist-ckpt-minimal-base
mikolajblaz Apr 29, 2024
0f546d8
Merge branch 'mblaz/async-dist-ckpt-minimal-base' into mblaz/async-di…
mikolajblaz Apr 29, 2024
13aa794
Remove unsused import
mikolajblaz Apr 29, 2024
8034497
Use dist-ckpt for Bert
mikolajblaz Apr 29, 2024
8f00d0b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2024
645e863
Fix load checkpoint return val
mikolajblaz Apr 29, 2024
9f89ee1
Merge branch 'mblaz/dist-ckpt-io' of github.com:NVIDIA/NeMo into mbla…
mikolajblaz Apr 29, 2024
29d4728
Merge branch 'mblaz/dist-ckpt-io' into mblaz/async-dist-ckpt-minimal-…
mikolajblaz Apr 29, 2024
e54e137
Merge branch 'mblaz/async-dist-ckpt-minimal-base' into mblaz/async-di…
mikolajblaz Apr 29, 2024
52b06bc
Use dist-ckpt based on sharded_state_dict
mikolajblaz Apr 29, 2024
781a78f
Merge branch 'mblaz/dist-ckpt-io' into mblaz/async-dist-ckpt-minimal-…
mikolajblaz Apr 29, 2024
853e23a
Add async logging
mikolajblaz Apr 30, 2024
7e8a16c
Remove deprecated argument
mikolajblaz Apr 30, 2024
ab367e0
Use correct checkpoint_io
mikolajblaz Apr 30, 2024
366d6d2
Merge branch 'mblaz/dist-ckpt-io' into mblaz/async-dist-ckpt-minimal-…
mikolajblaz Apr 30, 2024
af7bdec
Merge branch 'mblaz/async-dist-ckpt-minimal-base' into mblaz/async-di…
mikolajblaz Apr 30, 2024
0d756ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 30, 2024
1af7919
Fix bad merge
mikolajblaz Apr 30, 2024
2b3246e
Improve debug msg
mikolajblaz Apr 30, 2024
305c209
Merge branch 'main' into mblaz/async-dist-ckpt
mikolajblaz May 6, 2024
df86290
Run async test on GPU
mikolajblaz May 7, 2024
44f0949
Fix async ckpt unit test
mikolajblaz May 13, 2024
6361431
Merge branch 'main' into mblaz/async-dist-ckpt
mikolajblaz May 13, 2024
53529a0
Apply isort and black reformatting
mikolajblaz May 13, 2024
958acd1
Clarify async logs
mikolajblaz May 14, 2024
9b80b53
Merge branch 'main' into mblaz/async-dist-ckpt
mikolajblaz May 14, 2024
f038ffd
Add schema print
mikolajblaz May 14, 2024
@@ -52,6 +52,7 @@ exp_manager:
save_nemo_on_train_end: False # not recommended when training large models on clusters with short time limits
filename: 'megatron_gpt--{val_loss:.2f}-{step}-{consumed_samples}'
model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}
async_save: False # Set to True to enable async checkpoint save. Currently works only with distributed checkpoints

model:
# use GPTModel from megatron.core
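For context, a minimal sketch of how this new flag is consumed downstream: the trainer builder reads exp_manager.checkpoint_callback_params.async_save with a False default (see the _plugins change below). The config fragment here is hypothetical and far smaller than a real training config.

from omegaconf import OmegaConf

# Hypothetical, minimal config fragment; a real megatron_gpt config sets many more fields.
cfg = OmegaConf.create(
    {
        "exp_manager": {"checkpoint_callback_params": {"async_save": True}},
        "model": {"mcore_gpt": True},  # distributed (dist-ckpt) checkpoints are required for async save
    }
)

# Mirrors how _plugins() reads the flag below.
async_save = cfg.exp_manager.checkpoint_callback_params.get("async_save", False)
print(async_save)  # True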
62 changes: 47 additions & 15 deletions nemo/collections/nlp/parts/megatron_trainer_builder.py
@@ -13,8 +13,9 @@
# limitations under the License.

import sys
from typing import Union
from typing import Optional, Union

from lightning_fabric.utilities.exceptions import MisconfigurationException
from omegaconf import DictConfig
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary
@@ -31,7 +32,11 @@
PipelineMixedPrecisionPlugin,
)
from nemo.utils import logging
from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO
from nemo.utils.callbacks.dist_ckpt_io import (
AsyncFinalizableCheckpointIO,
AsyncFinalizerCallback,
DistributedCheckpointIO,
)


class MegatronTrainerBuilder:
@@ -51,7 +56,10 @@ def _training_strategy(self) -> Union[NLPDDPStrategy, NLPFSDPStrategy]:
_IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)
if _IS_INTERACTIVE and self.cfg.trainer.devices == 1:
logging.info("Detected interactive environment, using NLPDDPStrategyNotebook")
return NLPDDPStrategyNotebook(no_ddp_communication_hook=True, find_unused_parameters=False,)
return NLPDDPStrategyNotebook(
no_ddp_communication_hook=True,
find_unused_parameters=False,
)

if self.cfg.model.get('fsdp', False):
assert (
@@ -89,7 +97,7 @@ def _grad_scaler(self) -> GradScaler:
Returns a scaler for precision plugins.
"""
return GradScaler(
init_scale=self.cfg.model.get('native_amp_init_scale', 2 ** 32),
init_scale=self.cfg.model.get('native_amp_init_scale', 2**32),
growth_interval=self.cfg.model.get('native_amp_growth_interval', 1000),
hysteresis=self.cfg.model.get('hysteresis', 2),
)
@@ -137,19 +145,41 @@ def _plugins(self) -> list:
use_dist_ckpt = not self.cfg.model.get('fsdp', False) and (
self.cfg.model.get('mcore_gpt', False) or self.cfg.model.get('mcore_bert', False)
)
async_save = self.cfg.exp_manager.checkpoint_callback_params.get('async_save', False)
if use_dist_ckpt:
plugins.append(DistributedCheckpointIO.from_config(self.cfg.model))
checkpoint_io = DistributedCheckpointIO.from_config(self.cfg.model, async_save)
if async_save:
checkpoint_io = AsyncFinalizableCheckpointIO(checkpoint_io)
plugins.append(checkpoint_io)
elif async_save:
raise MisconfigurationException(
'exp_manager.checkpoint_callback_params.async_save=True without '
'distributed checkpoints is currently not supported'
)

return plugins
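AsyncFinalizableCheckpointIO is only imported in this file, so as a rough, hypothetical sketch (not the actual NeMo implementation), the wrapper pattern used above looks roughly like this: saves are delegated to the wrapped checkpoint IO, which is assumed to return a handle for the in-flight write, and finalization is deferred until a later call.

class AsyncCheckpointIOSketch:
    """Illustrative stand-in for AsyncFinalizableCheckpointIO (simplified, hypothetical API)."""

    def __init__(self, wrapped_io):
        self.wrapped_io = wrapped_io
        self.pending = []  # handles for in-flight checkpoint writes

    def save_checkpoint(self, checkpoint, path):
        # Assumed: the wrapped IO schedules an asynchronous write and returns a request handle.
        self.pending.append(self.wrapped_io.save_checkpoint(checkpoint, path))

    def maybe_finalize(self, blocking=False):
        # Complete writes that are already done; when blocking, flush everything (e.g. at teardown).
        still_pending = []
        for request in self.pending:
            if blocking or request.is_done():  # is_done()/finalize() are hypothetical methods
                request.finalize()
            else:
                still_pending.append(request)
        self.pending = still_pending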

def _callbacks(self, callbacks: Optional[list]) -> list:
"""
Returns:
callbacks: list of callbacks passed to Trainer.callbacks.
"""
if callbacks is None:
callbacks = []
# enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks
if 'enable_progress_bar' not in self.cfg.trainer or self.cfg.trainer.enable_progress_bar:
callbacks.append(CustomProgressBar())

if self.cfg.exp_manager.checkpoint_callback_params.get('async_save', False):
callbacks.append(AsyncFinalizerCallback())
return callbacks
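AsyncFinalizerCallback, appended above when async_save is enabled, is what periodically drains those pending writes during training. A hedged sketch of the idea, reusing the hypothetical maybe_finalize method from the checkpoint-IO sketch above (the real callback's hooks and method names may differ):

from pytorch_lightning.callbacks import Callback


class AsyncFinalizerCallbackSketch(Callback):
    """Illustrative stand-in for AsyncFinalizerCallback (hypothetical)."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Opportunistically finalize any checkpoint writes that have already finished.
        trainer.strategy.checkpoint_io.maybe_finalize(blocking=False)

    def on_train_end(self, trainer, pl_module):
        # Block until every outstanding async checkpoint is fully written.
        trainer.strategy.checkpoint_io.maybe_finalize(blocking=True)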

def create_trainer(self, callbacks=None) -> Trainer:
# cfg.trainer.precision becomes None in Trainer if precision plugins exist, since precision plugins and the precision flag cannot both be passed to the Trainer
precision = self.cfg.trainer.precision
strategy = self._training_strategy()
plugins = self._plugins()
# enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks
if 'enable_progress_bar' not in self.cfg.trainer or self.cfg.trainer.enable_progress_bar:
callbacks = [CustomProgressBar()]
callbacks = self._callbacks(callbacks)
trainer = Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=callbacks)
# Restore the precision value after Trainer is built.
self.cfg.trainer.precision = precision
@@ -161,21 +191,23 @@ class MegatronBertTrainerBuilder(MegatronTrainerBuilder):

def _grad_scaler(self) -> GradScaler:
return GradScaler(
init_scale=self.cfg.model.get('native_amp_init_scale', 2 ** 32),
init_scale=self.cfg.model.get('native_amp_init_scale', 2**32),
growth_interval=self.cfg.model.get('native_amp_growth_interval', 1000),
)


class MegatronT5TrainerBuilder(MegatronTrainerBuilder):
"""Builder for T5 model Trainer with overrides."""

def create_trainer(self) -> Trainer:
def _callbacks(self, callbacks: Optional[list]) -> list:
callbacks = super()._callbacks(callbacks)
callbacks.append(ModelSummary(max_depth=3))
return callbacks

def create_trainer(self, callbacks=None) -> Trainer:
strategy = self._training_strategy()
plugins = self._plugins()
callbacks = [ModelSummary(max_depth=3)]
# enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks
if 'enable_progress_bar' not in self.cfg.trainer or self.cfg.trainer.enable_progress_bar:
callbacks.append(CustomProgressBar())
callbacks = self._callbacks(callbacks)
return Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=callbacks)


@@ -207,7 +239,7 @@ class MegatronLMPPTrainerBuilder(MegatronTrainerBuilder):

def _grad_scaler(self) -> GradScaler:
return GradScaler(
init_scale=self.cfg.model.get("native_amp_init_scale", 2 ** 32),
init_scale=self.cfg.model.get("native_amp_init_scale", 2**32),
growth_interval=self.cfg.model.get("native_amp_growth_interval", 1000),
hysteresis=self.cfg.model.get("hysteresis", 2),
enabled=False if self.cfg.model.pipeline_model_parallel_size > 1 else True,
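Finally, a minimal usage sketch of the builder entry point this PR touches. The config path below is an assumption for illustration; in practice the flag is set in the experiment config as shown at the top of the diff.

from omegaconf import OmegaConf

from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder

# Hypothetical path to a full training config (e.g. a megatron_gpt config).
cfg = OmegaConf.load("conf/megatron_gpt_config.yaml")
cfg.exp_manager.checkpoint_callback_params.async_save = True  # enable async dist-ckpt save

# With async_save on and an MCore model, _plugins() wraps DistributedCheckpointIO in
# AsyncFinalizableCheckpointIO, and _callbacks() appends AsyncFinalizerCallback.
trainer = MegatronTrainerBuilder(cfg).create_trainer()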