
Make torch_dist ckpt strategy the default #9852

Merged on Jul 25, 2024 (8 commits). Changes shown from all commits.
2 changes: 1 addition & 1 deletion Dockerfile.ci

@@ -34,7 +34,7 @@ WORKDIR /workspace
 # Install NeMo requirements
 ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea
 ARG MODELOPT_VERSION=0.13.0
-ARG MCORE_TAG=c7a1f82d761577e6ca0338d3521eac82f2aa0904
+ARG MCORE_TAG=338af51452a53982d202e8386db6233adad1ce86
 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
 RUN \
     --mount=type=bind,source=requirements,target=requirements \
Model training config YAML (file name not shown)

@@ -174,7 +174,7 @@ model:
   fsdp_sharded_checkpoint: False # Store and load FSDP shared checkpoint.

   # Distributed checkpoint setup
-  dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
+  dist_ckpt_format: 'torch_dist' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
   dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU
   dist_ckpt_parallel_save: True # if true, each worker will write its own part of the dist checkpoint
   dist_ckpt_parallel_save_within_dp: False # if true, save will be parallelized only within a DP group (whole world otherwise), which might slightly reduce the save overhead
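With this change, configs that leave `model.dist_ckpt_format` unset now save checkpoints in the PyTorch Distributed (`torch_dist`) format; the legacy `zarr` backend can still be selected explicitly. Below is a minimal sketch of such an override using OmegaConf (the config library NeMo builds on); the config values shown are illustrative, not the full training config.

```python
# Minimal sketch: overriding the distributed-checkpoint format back to the
# legacy 'zarr' backend. Config contents here are illustrative only.
from omegaconf import OmegaConf

# Defaults as they now ship: torch_dist is the default format.
base_cfg = OmegaConf.create(
    {
        "model": {
            "dist_ckpt_format": "torch_dist",
            "dist_ckpt_load_on_device": True,
            "dist_ckpt_parallel_save": True,
            "dist_ckpt_parallel_save_within_dp": False,
        }
    }
)

# A command-line-style dotlist override, e.g. `model.dist_ckpt_format=zarr`.
override = OmegaConf.from_dotlist(["model.dist_ckpt_format=zarr"])
cfg = OmegaConf.merge(base_cfg, override)

print(cfg.model.dist_ckpt_format)  # -> zarr (deprecated; expect a warning at save time)
```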
7 changes: 7 additions & 0 deletions nemo/lightning/io/pl.py

@@ -208,6 +208,13 @@ def _determine_dist_ckpt_save_strategy(self):
         are passed in config or in case of a fully parallel save in which case
         a parallelization wrapper is applied.
         """
+        if self.save_ckpt_format == 'zarr':
+            logging.warning(
+                f'`zarr` distributed checkpoint backend is deprecated.'
+                f' Distributed optimizer checkpoint saving might be extremely slow.'
+                f' Please switch to PyTorch Distributed format (model.dist_ckpt_format=torch_dist).'
+            )
+
         if self.async_save and self.save_ckpt_format != 'torch_dist':
             raise ValueError('Async dist-ckpt save supported only for torch_dist format')
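The new branch only warns; saving with `zarr` still works, and only the async-save combination is rejected. A self-contained sketch of the same checks, written as a standalone function rather than the actual `_determine_dist_ckpt_save_strategy` method (the function name and plain-logging setup here are illustrative):

```python
import logging

def check_ckpt_format(save_ckpt_format: str, async_save: bool) -> None:
    """Illustrative restatement of the checks added above (not the NeMo method itself)."""
    if save_ckpt_format == 'zarr':
        # Deprecated backend: allowed, but emits a warning on every save-strategy setup.
        logging.warning(
            '`zarr` distributed checkpoint backend is deprecated.'
            ' Distributed optimizer checkpoint saving might be extremely slow.'
            ' Please switch to PyTorch Distributed format (model.dist_ckpt_format=torch_dist).'
        )
    if async_save and save_ckpt_format != 'torch_dist':
        raise ValueError('Async dist-ckpt save supported only for torch_dist format')

check_ckpt_format('torch_dist', async_save=True)   # ok: new default works with async save
check_ckpt_format('zarr', async_save=False)        # warns: deprecated but still allowed
# check_ckpt_format('zarr', async_save=True)       # would raise ValueError
```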
9 changes: 8 additions & 1 deletion nemo/utils/callbacks/dist_ckpt_io.py

@@ -242,7 +242,7 @@ def from_config(cls, model_cfg: dict, async_save: bool = False):
             it should be provided separately. Defaults to False.
         """
         return cls(
-            save_ckpt_format=model_cfg.get('dist_ckpt_format', 'zarr'),
+            save_ckpt_format=model_cfg.get('dist_ckpt_format', 'torch_dist'),
             load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True),
             load_strictness=model_cfg.get('dist_ckpt_load_strictness', None),
             async_save=async_save,
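Because `from_config` reads the format with `dict.get`, the new default only applies when `dist_ckpt_format` is absent from the model config; an explicitly configured value is untouched. A small sketch with plain dicts (the configs here are hypothetical):

```python
# Hypothetical model configs illustrating how the default is resolved.
cfg_without_format = {}                             # key omitted -> picks up the new default
cfg_with_legacy = {'dist_ckpt_format': 'zarr'}      # explicit value -> behavior unchanged

print(cfg_without_format.get('dist_ckpt_format', 'torch_dist'))  # torch_dist (was zarr before this PR)
print(cfg_with_legacy.get('dist_ckpt_format', 'torch_dist'))     # zarr
```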
@@ -360,6 +360,13 @@ def _determine_dist_ckpt_save_strategy(self):
         are passed in config or in case of a fully parallel save in which case
         a parallelization wrapper is applied.
         """
+        if self.save_ckpt_format == 'zarr':
+            logging.warning(
+                f'`zarr` distributed checkpoint backend is deprecated.'
+                f' Distributed optimizer checkpoint saving might be extremely slow.'
+                f' Please switch to PyTorch Distributed format (model.dist_ckpt_format=torch_dist).'
+            )
+
         if self.async_save and self.save_ckpt_format != 'torch_dist':
             raise ValueError('Async dist-ckpt save supported only for torch_dist format')
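Taken together, the two checks define a simple compatibility matrix: `torch_dist` works with both synchronous and async saves, while `zarr` only works synchronously and always warns. A small standalone sketch enumerating the combinations, for illustration only:

```python
# Illustrative enumeration of (format, async_save) combinations and the resulting behavior.
combinations = [
    ('torch_dist', False),  # default: saves normally
    ('torch_dist', True),   # default + async save: supported
    ('zarr', False),        # deprecated: saves, but emits the warning above
    ('zarr', True),         # unsupported: raises ValueError
]

for fmt, async_save in combinations:
    if async_save and fmt != 'torch_dist':
        outcome = 'ValueError: async save requires torch_dist'
    elif fmt == 'zarr':
        outcome = 'ok, with deprecation warning'
    else:
        outcome = 'ok'
    print(f'{fmt:10s} async={async_save!s:5s} -> {outcome}')
```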