From 9d13f159bb150faf0a1842b0ce364e2f5d3a59a0 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 9 May 2024 18:01:54 +0000 Subject: [PATCH 01/11] Mcore dist opt ckpt fix Signed-off-by: Alexandros Koumparoulis --- nemo/collections/nlp/parts/nlp_overrides.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 079732f6b9c5..0c81f47ac83b 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -78,7 +78,7 @@ from apex.transformer.pipeline_parallel.utils import get_num_microbatches from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam - + from nemo.core.optim.mcore_optim import McoreDistributedOptimizer HAVE_APEX = True except (ImportError, ModuleNotFoundError): @@ -294,7 +294,7 @@ def optimizer_sharded_state_dict(self, unsharded_optim_state=None): key: value for key, value in model_sharded_state_dict.items() if not key.endswith('_extra_state') } - if isinstance(optimizer, MegatronDistributedFusedAdam): + if isinstance(optimizer, MegatronDistributedFusedAdam) or isinstance(optimizer, McoreDistributedOptimizer): return optimizer.sharded_state_dict(model_sharded_state_dict, unsharded_optim_state) elif not isinstance(optimizer, MainParamsOptimizerWrapper): # Regular optimizer, e.g. Adam or FusedAdam From 1a6490416fbb5478fa1ed11c897b7ad6f8e9ea2e Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Wed, 15 May 2024 07:13:36 +0000 Subject: [PATCH 02/11] pass dp_zero_gather_scatter to sharded-state-dict Signed-off-by: Alexandros Koumparoulis --- nemo/core/optim/mcore_optim.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/nemo/core/optim/mcore_optim.py b/nemo/core/optim/mcore_optim.py index 0d4b524049ca..5103cce2dd84 100644 --- a/nemo/core/optim/mcore_optim.py +++ b/nemo/core/optim/mcore_optim.py @@ -37,6 +37,8 @@ def __init__(self, optim): self.mcore_optimizer = optim self.param_groups = self.mcore_optimizer.param_groups self.state = self.mcore_optimizer.state + self.sharding_type = 'dp_zero_gather_scatter' + # 'fully_sharded_bucket_space' if args.ckpt_fully_parallel_save else 'dp_zero_gather_scatter' def zero_grad(self, set_to_none: bool = True): """We only need to zero the model related parameters, i.e., @@ -55,8 +57,9 @@ def state_dict(self): def load_state_dict(self, state_dict): self.mcore_optimizer.load_state_dict(state_dict) - def sharded_state_dict(self, model_sharded_state_dict, is_loading: bool = False, **kwargs): - return self.mcore_optimizer.sharded_state_dict(model_sharded_state_dict, is_loading, **kwargs) + def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None): + return self.mcore_optimizer.sharded_state_dict( + model_sharded_state_dict, is_loading=False, sharding_type='dp_zero_gather_scatter') def step(self, closure): """Clip gradients (if needed) and step the base optimizer. 
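Note on PATCH 01/02: taken together, these two patches route Megatron-Core's distributed optimizer through NeMo's distributed-checkpoint path: optimizer_sharded_state_dict in nlp_overrides.py now also accepts an McoreDistributedOptimizer, and the wrapper forwards the call to the underlying mcore optimizer with a fixed sharding_type. A self-contained sketch of that delegation is below; _FakeMcoreOptimizer and _SketchMcoreDistributedOptimizer are stand-ins invented for illustration, not actual NeMo or Megatron-Core classes.

    # Minimal sketch (illustrative stand-in classes) of the delegation added in PATCH 01/02.
    class _FakeMcoreOptimizer:
        """Stand-in for Megatron-Core's DistributedOptimizer; illustration only."""

        def sharded_state_dict(self, model_sharded_state_dict, is_loading=False,
                               sharding_type='fully_sharded_bucket_space'):
            # The real optimizer returns ShardedTensor objects; a plain dict is enough here.
            return {'sharding_type': sharding_type, 'is_loading': is_loading}

    class _SketchMcoreDistributedOptimizer:
        """Mirrors the forwarding behaviour PATCH 02 gives the NeMo wrapper."""

        def __init__(self, optim):
            self.mcore_optimizer = optim

        def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
            # PATCH 02 pins the layout to the DP-ZeRO gather/scatter scheme.
            return self.mcore_optimizer.sharded_state_dict(
                model_sharded_state_dict, is_loading=False, sharding_type='dp_zero_gather_scatter'
            )

    wrapper = _SketchMcoreDistributedOptimizer(_FakeMcoreOptimizer())
    print(wrapper.sharded_state_dict(model_sharded_state_dict={}))
    # -> {'sharding_type': 'dp_zero_gather_scatter', 'is_loading': False}
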
From cb78ce98f6959e1253a40b3a43759bc592b501a7 Mon Sep 17 00:00:00 2001 From: akoumpa Date: Wed, 15 May 2024 16:26:00 +0000 Subject: [PATCH 03/11] Apply isort and black reformatting Signed-off-by: akoumpa --- nemo/collections/nlp/parts/nlp_overrides.py | 1 + nemo/core/optim/mcore_optim.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 0c81f47ac83b..71714acf6be9 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -79,6 +79,7 @@ from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam from nemo.core.optim.mcore_optim import McoreDistributedOptimizer + HAVE_APEX = True except (ImportError, ModuleNotFoundError): diff --git a/nemo/core/optim/mcore_optim.py b/nemo/core/optim/mcore_optim.py index 5103cce2dd84..caad2ba17c6e 100644 --- a/nemo/core/optim/mcore_optim.py +++ b/nemo/core/optim/mcore_optim.py @@ -59,7 +59,8 @@ def load_state_dict(self, state_dict): def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None): return self.mcore_optimizer.sharded_state_dict( - model_sharded_state_dict, is_loading=False, sharding_type='dp_zero_gather_scatter') + model_sharded_state_dict, is_loading=False, sharding_type='dp_zero_gather_scatter' + ) def step(self, closure): """Clip gradients (if needed) and step the base optimizer. From 006b6d85c0c145f5fd4294f1323fa7b933ad4018 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 16 May 2024 09:45:50 -0700 Subject: [PATCH 04/11] introduce dist_ckpt_parallel_save option Signed-off-by: Alexandros Koumparoulis --- examples/nlp/language_modeling/conf/megatron_gpt_config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index 20e20744833c..6df315eba2b9 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -154,6 +154,7 @@ model: # Distributed checkpoint setup dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format. dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU + dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint ## Activation Checkpointing # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. 
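Note on PATCH 04: it only declares the new model.dist_ckpt_parallel_save flag with a default of False; nothing consumes it yet at this point in the series. A minimal sketch of how the key sits in the model config and how it can be read back with OmegaConf follows; the in-memory dict is a stand-in for the real megatron_gpt_config.yaml, and the .get access pattern matches what the trainer builder adds later in PATCH 07.

    from omegaconf import OmegaConf

    # Stand-in for the distributed-checkpoint block of megatron_gpt_config.yaml after PATCH 04.
    cfg = OmegaConf.create(
        {
            'model': {
                'dist_ckpt_format': 'zarr',          # or 'torch_dist'
                'dist_ckpt_load_on_device': True,    # load checkpoint weights directly on GPU
                'dist_ckpt_parallel_save': False,    # new flag: each worker writes its own part
            }
        }
    )

    # Same access pattern the trainer builder uses later in this series (PATCH 07).
    parallel_save = cfg.model.get('dist_ckpt_parallel_save', False)
    print(parallel_save)  # False unless overridden, e.g. via model.dist_ckpt_parallel_save=True
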
From 545fc51da1c1965c4b72cca0633b8df4b90e28df Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 16 May 2024 09:56:14 -0700 Subject: [PATCH 05/11] determine sharding type from dist_ckpt_parallel_save Signed-off-by: Alexandros Koumparoulis --- nemo/core/optim/mcore_optim.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/nemo/core/optim/mcore_optim.py b/nemo/core/optim/mcore_optim.py index caad2ba17c6e..c48de6bdcbee 100644 --- a/nemo/core/optim/mcore_optim.py +++ b/nemo/core/optim/mcore_optim.py @@ -37,8 +37,6 @@ def __init__(self, optim): self.mcore_optimizer = optim self.param_groups = self.mcore_optimizer.param_groups self.state = self.mcore_optimizer.state - self.sharding_type = 'dp_zero_gather_scatter' - # 'fully_sharded_bucket_space' if args.ckpt_fully_parallel_save else 'dp_zero_gather_scatter' def zero_grad(self, set_to_none: bool = True): """We only need to zero the model related parameters, i.e., @@ -57,9 +55,10 @@ def state_dict(self): def load_state_dict(self, state_dict): self.mcore_optimizer.load_state_dict(state_dict) - def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None): + def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None, is_loading=False, **kwargs): + sharding_type = 'fully_sharded_bucket_space' if kwargs.get('dist_ckpt_parallel_save', False) else 'dp_zero_gather_scatter' return self.mcore_optimizer.sharded_state_dict( - model_sharded_state_dict, is_loading=False, sharding_type='dp_zero_gather_scatter' + model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type ) def step(self, closure): From 4643470fe6569a39b2279e8d4670f28ac028bcaf Mon Sep 17 00:00:00 2001 From: akoumpa Date: Thu, 16 May 2024 17:00:09 +0000 Subject: [PATCH 06/11] Apply isort and black reformatting Signed-off-by: akoumpa --- nemo/core/optim/mcore_optim.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/core/optim/mcore_optim.py b/nemo/core/optim/mcore_optim.py index c48de6bdcbee..0edac2b83f8f 100644 --- a/nemo/core/optim/mcore_optim.py +++ b/nemo/core/optim/mcore_optim.py @@ -56,7 +56,9 @@ def load_state_dict(self, state_dict): self.mcore_optimizer.load_state_dict(state_dict) def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None, is_loading=False, **kwargs): - sharding_type = 'fully_sharded_bucket_space' if kwargs.get('dist_ckpt_parallel_save', False) else 'dp_zero_gather_scatter' + sharding_type = ( + 'fully_sharded_bucket_space' if kwargs.get('dist_ckpt_parallel_save', False) else 'dp_zero_gather_scatter' + ) return self.mcore_optimizer.sharded_state_dict( model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type ) From 8fa988de3c281aa4b7c6964c69211821187544b8 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 17 May 2024 15:10:25 -0700 Subject: [PATCH 07/11] read model.dist_ckpt_parallel_save from cfg and pass it to mcore dist ckpt Signed-off-by: Alexandros Koumparoulis --- .../collections/nlp/parts/megatron_trainer_builder.py | 1 + nemo/collections/nlp/parts/nlp_overrides.py | 11 +++++++++-- nemo/core/optim/mcore_optim.py | 4 ++-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py index e1a780f09756..4bd0e223b939 100644 --- a/nemo/collections/nlp/parts/megatron_trainer_builder.py +++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py @@ -90,6 +90,7 @@ def 
_training_strategy(self) -> Union[NLPDDPStrategy, NLPFSDPStrategy]: find_unused_parameters=False, nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None), sharp=self.cfg.model.get('sharp', False), + dist_ckpt_parallel_save=self.cfg.model.get('dist_ckpt_parallel_save', False), ) def _grad_scaler(self) -> GradScaler: diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 71714acf6be9..43fca57189c2 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -183,6 +183,7 @@ def __init__( no_ddp_communication_hook: bool = False, nccl_communicator_config_path: Optional[str] = None, sharp: bool = False, + dist_ckpt_parallel_save: bool = False, **kwargs: Union[Any, Dict[str, Any]], ) -> None: if not HAVE_APEX: @@ -199,6 +200,7 @@ def __init__( self.no_ddp_communication_hook = no_ddp_communication_hook self.nccl_communicator_config_path = nccl_communicator_config_path self.sharp = sharp + self._dist_ckpt_parallel_save = dist_ckpt_parallel_save def setup(self, trainer: "pl.Trainer") -> None: """ @@ -294,8 +296,13 @@ def optimizer_sharded_state_dict(self, unsharded_optim_state=None): model_sharded_state_dict = { key: value for key, value in model_sharded_state_dict.items() if not key.endswith('_extra_state') } - - if isinstance(optimizer, MegatronDistributedFusedAdam) or isinstance(optimizer, McoreDistributedOptimizer): + if isinstance(optimizer, McoreDistributedOptimizer): + return optimizer.sharded_state_dict( + model_sharded_state_dict, + unsharded_optim_state, + dist_ckpt_parallel_save=self._dist_ckpt_parallel_save + ) + elif isinstance(optimizer, MegatronDistributedFusedAdam): return optimizer.sharded_state_dict(model_sharded_state_dict, unsharded_optim_state) elif not isinstance(optimizer, MainParamsOptimizerWrapper): # Regular optimizer, e.g. 
Adam or FusedAdam diff --git a/nemo/core/optim/mcore_optim.py b/nemo/core/optim/mcore_optim.py index 0edac2b83f8f..ab4a0dd6af53 100644 --- a/nemo/core/optim/mcore_optim.py +++ b/nemo/core/optim/mcore_optim.py @@ -55,9 +55,9 @@ def state_dict(self): def load_state_dict(self, state_dict): self.mcore_optimizer.load_state_dict(state_dict) - def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None, is_loading=False, **kwargs): + def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None, is_loading=False, dist_ckpt_parallel_save=False): sharding_type = ( - 'fully_sharded_bucket_space' if kwargs.get('dist_ckpt_parallel_save', False) else 'dp_zero_gather_scatter' + 'fully_sharded_bucket_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter' ) return self.mcore_optimizer.sharded_state_dict( model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type From 4b297e0d4ede06ea9dbcd6bf8a15f4839ecef2e2 Mon Sep 17 00:00:00 2001 From: akoumpa Date: Fri, 17 May 2024 22:16:18 +0000 Subject: [PATCH 08/11] Apply isort and black reformatting Signed-off-by: akoumpa --- nemo/collections/nlp/parts/nlp_overrides.py | 4 +--- nemo/core/optim/mcore_optim.py | 8 ++++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 43fca57189c2..97b95893214f 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -298,9 +298,7 @@ def optimizer_sharded_state_dict(self, unsharded_optim_state=None): } if isinstance(optimizer, McoreDistributedOptimizer): return optimizer.sharded_state_dict( - model_sharded_state_dict, - unsharded_optim_state, - dist_ckpt_parallel_save=self._dist_ckpt_parallel_save + model_sharded_state_dict, unsharded_optim_state, dist_ckpt_parallel_save=self._dist_ckpt_parallel_save ) elif isinstance(optimizer, MegatronDistributedFusedAdam): return optimizer.sharded_state_dict(model_sharded_state_dict, unsharded_optim_state) diff --git a/nemo/core/optim/mcore_optim.py b/nemo/core/optim/mcore_optim.py index ab4a0dd6af53..d64ac392f0eb 100644 --- a/nemo/core/optim/mcore_optim.py +++ b/nemo/core/optim/mcore_optim.py @@ -55,10 +55,10 @@ def state_dict(self): def load_state_dict(self, state_dict): self.mcore_optimizer.load_state_dict(state_dict) - def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None, is_loading=False, dist_ckpt_parallel_save=False): - sharding_type = ( - 'fully_sharded_bucket_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter' - ) + def sharded_state_dict( + self, model_sharded_state_dict, optimizer_state_dict=None, is_loading=False, dist_ckpt_parallel_save=False + ): + sharding_type = 'fully_sharded_bucket_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter' return self.mcore_optimizer.sharded_state_dict( model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type ) From 82b07c957dc559cb33585c4c3ca30eeb3ceca574 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Tue, 21 May 2024 02:23:18 -0700 Subject: [PATCH 09/11] Pass is_loading to mcore_optim.py's sharded_state_dict Signed-off-by: Alexandros Koumparoulis --- nemo/collections/nlp/parts/nlp_overrides.py | 7 ++++--- nemo/core/optim/mcore_optim.py | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 97b95893214f..510731c0981f 100644 --- 
a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -278,7 +278,7 @@ def configure_ddp(self): else: super().configure_ddp() - def optimizer_sharded_state_dict(self, unsharded_optim_state=None): + def optimizer_sharded_state_dict(self, unsharded_optim_state=None, is_loading=False): """ Sharded state dictionary for an MainParamsOptimizerWrapper. Used to save and load the optimizer state when training with distributed_checkpoint. @@ -298,7 +298,8 @@ def optimizer_sharded_state_dict(self, unsharded_optim_state=None): } if isinstance(optimizer, McoreDistributedOptimizer): return optimizer.sharded_state_dict( - model_sharded_state_dict, unsharded_optim_state, dist_ckpt_parallel_save=self._dist_ckpt_parallel_save + model_sharded_state_dict, unsharded_optim_state, is_loading=is_loading, + dist_ckpt_parallel_save=self._dist_ckpt_parallel_save ) elif isinstance(optimizer, MegatronDistributedFusedAdam): return optimizer.sharded_state_dict(model_sharded_state_dict, unsharded_optim_state) @@ -442,7 +443,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: # after dist_checkpointing.load, sharded tensors will be replaced with tensors checkpoint['state_dict'] = sharded_state_dict - checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()] + checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict(is_loading=True)] return self.checkpoint_io.load_checkpoint(checkpoint_path, sharded_state_dict=checkpoint) # Legacy model parallel checkpointing logic, does not use megatron core diff --git a/nemo/core/optim/mcore_optim.py b/nemo/core/optim/mcore_optim.py index d64ac392f0eb..6de93aa5ec3a 100644 --- a/nemo/core/optim/mcore_optim.py +++ b/nemo/core/optim/mcore_optim.py @@ -58,7 +58,8 @@ def load_state_dict(self, state_dict): def sharded_state_dict( self, model_sharded_state_dict, optimizer_state_dict=None, is_loading=False, dist_ckpt_parallel_save=False ): - sharding_type = 'fully_sharded_bucket_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter' + # TODO(@akoumparouli, @mikolajblaz): switch to sharding_type once support for fully_sharded_model_space merged in mcore. 
+ sharding_type = 'fully_sharded_model_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter' return self.mcore_optimizer.sharded_state_dict( model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type ) From 27eb5535a508e0aa8c36dfa80df7f477400fd96b Mon Sep 17 00:00:00 2001 From: akoumpa Date: Tue, 21 May 2024 09:24:44 +0000 Subject: [PATCH 10/11] Apply isort and black reformatting Signed-off-by: akoumpa --- nemo/collections/nlp/parts/nlp_overrides.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 510731c0981f..43dbb4a68f21 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -298,8 +298,10 @@ def optimizer_sharded_state_dict(self, unsharded_optim_state=None, is_loading=Fa } if isinstance(optimizer, McoreDistributedOptimizer): return optimizer.sharded_state_dict( - model_sharded_state_dict, unsharded_optim_state, is_loading=is_loading, - dist_ckpt_parallel_save=self._dist_ckpt_parallel_save + model_sharded_state_dict, + unsharded_optim_state, + is_loading=is_loading, + dist_ckpt_parallel_save=self._dist_ckpt_parallel_save, ) elif isinstance(optimizer, MegatronDistributedFusedAdam): return optimizer.sharded_state_dict(model_sharded_state_dict, unsharded_optim_state) From 0a9cd7137933be87b3412855cfc07a0104dc278e Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Wed, 22 May 2024 10:00:20 -0700 Subject: [PATCH 11/11] Update nemo/core/optim/mcore_optim.py Co-authored-by: mikolajblaz Signed-off-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> --- nemo/core/optim/mcore_optim.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo/core/optim/mcore_optim.py b/nemo/core/optim/mcore_optim.py index 6de93aa5ec3a..234680f49249 100644 --- a/nemo/core/optim/mcore_optim.py +++ b/nemo/core/optim/mcore_optim.py @@ -59,7 +59,8 @@ def sharded_state_dict( self, model_sharded_state_dict, optimizer_state_dict=None, is_loading=False, dist_ckpt_parallel_save=False ): # TODO(@akoumparouli, @mikolajblaz): switch to sharding_type once support for fully_sharded_model_space merged in mcore. - sharding_type = 'fully_sharded_model_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter' + # sharding_type = 'fully_sharded_model_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter' + sharding_type = 'dp_zero_gather_scatter' return self.mcore_optimizer.sharded_state_dict( model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type )