Mcore dist opt ckpt fix #9156

Merged (12 commits) on May 22, 2024
@@ -154,6 +154,7 @@ model:
# Distributed checkpoint setup
dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU
+dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint

## Activation Checkpointing
# NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
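For reference, a minimal sketch (not part of the PR) of how a config key like the new 'dist_ckpt_parallel_save' is read from the model config via OmegaConf, mirroring the self.cfg.model.get(...) call in the trainer builder below; the config contents here are illustrative:

from omegaconf import OmegaConf

# Illustrative config; only the 'dist_ckpt_parallel_save' key comes from the diff.
cfg = OmegaConf.create({'model': {'dist_ckpt_parallel_save': True}})

# Missing keys fall back to False, matching the default added in this PR.
parallel_save = cfg.model.get('dist_ckpt_parallel_save', False)
print(parallel_save)  # True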
nemo/collections/nlp/parts/megatron_trainer_builder.py (1 addition, 0 deletions)
@@ -90,6 +90,7 @@ def _training_strategy(self) -> Union[NLPDDPStrategy, NLPFSDPStrategy]:
find_unused_parameters=False,
nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None),
sharp=self.cfg.model.get('sharp', False),
+dist_ckpt_parallel_save=self.cfg.model.get('dist_ckpt_parallel_save', False),
)

def _grad_scaler(self) -> GradScaler:
nemo/collections/nlp/parts/nlp_overrides.py (8 additions, 2 deletions)
@@ -78,6 +78,7 @@
from apex.transformer.pipeline_parallel.utils import get_num_microbatches

from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam
+from nemo.core.optim.mcore_optim import McoreDistributedOptimizer

HAVE_APEX = True

@@ -182,6 +183,7 @@ def __init__(
no_ddp_communication_hook: bool = False,
nccl_communicator_config_path: Optional[str] = None,
sharp: bool = False,
+dist_ckpt_parallel_save: bool = False,
**kwargs: Union[Any, Dict[str, Any]],
) -> None:
if not HAVE_APEX:
@@ -198,6 +200,7 @@
self.no_ddp_communication_hook = no_ddp_communication_hook
self.nccl_communicator_config_path = nccl_communicator_config_path
self.sharp = sharp
+self._dist_ckpt_parallel_save = dist_ckpt_parallel_save

def setup(self, trainer: "pl.Trainer") -> None:
"""
@@ -293,8 +296,11 @@ def optimizer_sharded_state_dict(self, unsharded_optim_state=None):
model_sharded_state_dict = {
key: value for key, value in model_sharded_state_dict.items() if not key.endswith('_extra_state')
}

-if isinstance(optimizer, MegatronDistributedFusedAdam):
+if isinstance(optimizer, McoreDistributedOptimizer):
+    return optimizer.sharded_state_dict(
+        model_sharded_state_dict, unsharded_optim_state, dist_ckpt_parallel_save=self._dist_ckpt_parallel_save
+    )
+elif isinstance(optimizer, MegatronDistributedFusedAdam):
return optimizer.sharded_state_dict(model_sharded_state_dict, unsharded_optim_state)
elif not isinstance(optimizer, MainParamsOptimizerWrapper):
# Regular optimizer, e.g. Adam or FusedAdam
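To make the new dispatch order concrete, here is a runnable toy sketch with stub classes standing in for the real optimizers (class and parameter names follow the diff; the returned dicts are invented for illustration):

# Stubs standing in for nemo.core.optim.mcore_optim.McoreDistributedOptimizer
# and nemo.core.optim.distributed_adam.MegatronDistributedFusedAdam.
class McoreDistributedOptimizer:
    def sharded_state_dict(self, model_sd, optim_sd=None, dist_ckpt_parallel_save=False):
        # Mirrors the flag-to-sharding-type mapping added in mcore_optim.py below.
        return {'sharding_type': 'fully_sharded_bucket_space' if dist_ckpt_parallel_save
                else 'dp_zero_gather_scatter'}

class MegatronDistributedFusedAdam:
    def sharded_state_dict(self, model_sd, optim_sd=None):
        return {'sharding_type': 'apex_path'}  # placeholder value

def optimizer_sharded_state_dict(optimizer, model_sd, optim_sd, parallel_save):
    # The Mcore optimizer is checked first and is the only branch that
    # receives the new dist_ckpt_parallel_save flag.
    if isinstance(optimizer, McoreDistributedOptimizer):
        return optimizer.sharded_state_dict(model_sd, optim_sd, dist_ckpt_parallel_save=parallel_save)
    elif isinstance(optimizer, MegatronDistributedFusedAdam):
        return optimizer.sharded_state_dict(model_sd, optim_sd)
    raise TypeError(f'unsupported optimizer: {type(optimizer).__name__}')

print(optimizer_sharded_state_dict(McoreDistributedOptimizer(), {}, None, True))
# -> {'sharding_type': 'fully_sharded_bucket_space'}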
nemo/core/optim/mcore_optim.py (7 additions, 2 deletions)
@@ -55,8 +55,13 @@ def state_dict(self):
def load_state_dict(self, state_dict):
self.mcore_optimizer.load_state_dict(state_dict)

-def sharded_state_dict(self, model_sharded_state_dict, is_loading: bool = False, **kwargs):
-    return self.mcore_optimizer.sharded_state_dict(model_sharded_state_dict, is_loading, **kwargs)
+def sharded_state_dict(
+    self, model_sharded_state_dict, optimizer_state_dict=None, is_loading=False, dist_ckpt_parallel_save=False
+):
+    sharding_type = 'fully_sharded_bucket_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter'
+    return self.mcore_optimizer.sharded_state_dict(
+        model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type
+    )

def step(self, closure):
"""Clip gradients (if needed) and step the base optimizer.
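The behavioral core of the PR is the one-line mapping from the boolean flag to Megatron-Core's sharding type. A self-contained restatement (names taken from the diff; the comments summarize the intended semantics per the config description above):

def select_sharding_type(dist_ckpt_parallel_save: bool) -> str:
    # True: each data-parallel worker writes its own part of the optimizer
    # state ('fully_sharded_bucket_space'), per the config comment above.
    # False: the default gather/scatter layout ('dp_zero_gather_scatter').
    return 'fully_sharded_bucket_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter'

assert select_sharding_type(True) == 'fully_sharded_bucket_space'
assert select_sharding_type(False) == 'dp_zero_gather_scatter'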