From 5012aaca27b8574bd6099afe96b3ebc2ae5bc168 Mon Sep 17 00:00:00 2001
From: rachitg
Date: Wed, 10 Apr 2024 13:00:49 -0700
Subject: [PATCH 1/5] remove fp8 checkpoints for Attention

Signed-off-by: rachitg
---
 .../nlp/models/language_modeling/megatron_gpt_model.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 6648abac8ee0..69e5732da22f 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -101,6 +101,8 @@
     from megatron.core.transformer.module import Float16Module as MCoreFloat16Module
     from megatron.core.transformer.transformer_config import TransformerConfig
     from megatron.core.utils import drain_embedding_wgrad_compute, init_method_normal, scaled_init_method_normal
+    from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
+    from megatron.core.dist_checkpointing.mapping import ShardedObject, LocalNonpersitentObject

     # TODO @tmoon: Use once available in Megatron-LM
     # from megatron.core.pipeline_parallel.schedules import DataIteratorList
@@ -1722,6 +1724,13 @@ def sharded_state_dict(self, prefix: str = '') -> Dict[str, Any]:
             if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
                 parallel_state.set_virtual_pipeline_model_parallel_rank(0)

+            def skip_fp8_load(x):
+                if isinstance(x, ShardedObject) and 'fused_attention' in x.key and '_extra_state' in x.key:
+                    x = LocalNonpersitentObject(x.data)  # use the FP8 state from initialization, not from ckpt
+                return x
+            if True: #if self.cfg.skip_attn_fp8_load:
+                dict_list_map_inplace(skip_fp8_load, sharded_state_dict)
+
             return sharded_state_dict

     def parameters(self):

From 69722e19f308fd017a6254e92da37c2c1796909d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 10 Apr 2024 20:03:22 +0000
Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../nlp/models/language_modeling/megatron_gpt_model.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 69e5732da22f..8a6451646429 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -89,6 +89,8 @@
     from megatron.core import InferenceParams, parallel_state, tensor_parallel
     from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
     from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
+    from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
+    from megatron.core.dist_checkpointing.mapping import LocalNonpersitentObject, ShardedObject

     # NeMo's implementation of the get_gpt_layer_ammo_spec function is temporarily used
     # from megatron.core.inference.gpt.model_specs import get_gpt_layer_ammo_spec
@@ -101,8 +103,6 @@
     from megatron.core.transformer.module import Float16Module as MCoreFloat16Module
     from megatron.core.transformer.transformer_config import TransformerConfig
     from megatron.core.utils import drain_embedding_wgrad_compute, init_method_normal, scaled_init_method_normal
-    from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
-    from megatron.core.dist_checkpointing.mapping import ShardedObject, LocalNonpersitentObject

     # TODO @tmoon: Use once available in Megatron-LM
     # from megatron.core.pipeline_parallel.schedules import DataIteratorList
@@ -1728,7 +1728,8 @@ def skip_fp8_load(x):
                 if isinstance(x, ShardedObject) and 'fused_attention' in x.key and '_extra_state' in x.key:
                     x = LocalNonpersitentObject(x.data)  # use the FP8 state from initialization, not from ckpt
                 return x
-            if True: #if self.cfg.skip_attn_fp8_load:
+
+            if True:  # if self.cfg.skip_attn_fp8_load:
                 dict_list_map_inplace(skip_fp8_load, sharded_state_dict)

             return sharded_state_dict

From a43423724b5c00be8902fae2cac2d6add70f1aed Mon Sep 17 00:00:00 2001
From: rachitg
Date: Thu, 11 Apr 2024 11:46:33 -0700
Subject: [PATCH 3/5] fixes

Signed-off-by: rachitg
---
 .../nlp/models/language_modeling/megatron_gpt_model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 69e5732da22f..077ae09e6be6 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1724,11 +1724,12 @@ def sharded_state_dict(self, prefix: str = '') -> Dict[str, Any]:
             if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
                 parallel_state.set_virtual_pipeline_model_parallel_rank(0)

+            # WAR: This is a temporary fix to skip loading FP8 parameters for Dot Product Attention
             def skip_fp8_load(x):
                 if isinstance(x, ShardedObject) and 'fused_attention' in x.key and '_extra_state' in x.key:
                     x = LocalNonpersitentObject(x.data)  # use the FP8 state from initialization, not from ckpt
                 return x
-            if True: #if self.cfg.skip_attn_fp8_load:
+            if self.cfg.fp8_dot_product_attention:
                 dict_list_map_inplace(skip_fp8_load, sharded_state_dict)

             return sharded_state_dict

From d68c04bba401da0fe1b1bab2a68b1b066f358d36 Mon Sep 17 00:00:00 2001
From: rachitg
Date: Fri, 12 Apr 2024 05:56:49 +0000
Subject: [PATCH 4/5] set default value and support mha

Signed-off-by: rachitg
---
 .../nlp/models/language_modeling/megatron_gpt_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 841ccb4936d5..dbf4e0644c3f 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1730,7 +1730,7 @@ def skip_fp8_load(x):
                     x = LocalNonpersitentObject(x.data)  # use the FP8 state from initialization, not from ckpt
                 return x

-            if self.cfg.fp8_dot_product_attention:
+            if self.cfg.get('fp8_dot_product_attention', False) or self.cfg.get('fp8_multi_head_attention', False):
                 dict_list_map_inplace(skip_fp8_load, sharded_state_dict)

             return sharded_state_dict

From 5536c42b53e41155ee56c7313d5e25a1f695b73b Mon Sep 17 00:00:00 2001
From: rachitg
Date: Fri, 12 Apr 2024 19:02:16 -0700
Subject: [PATCH 5/5] skip by default

Signed-off-by: rachitg
---
 .../nlp/models/language_modeling/megatron_gpt_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index ede72439615e..4493532f88bf 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1741,7 +1741,7 @@ def skip_fp8_load(x):
                     x = LocalNonpersitentObject(x.data)  # use the FP8 state from initialization, not from ckpt
                 return x

-            if self.cfg.get('fp8_dot_product_attention', False) or self.cfg.get('fp8_multi_head_attention', False):
+            if self.cfg.get('skip_fp8_attention_checkpoint_load', True):
                 dict_list_map_inplace(skip_fp8_load, sharded_state_dict)

             return sharded_state_dict
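
Reviewer note: the sketch below restates, outside the NeMo class, the mechanism the series converges on. It is illustrative only; the free-standing helper name and the `enabled` argument are not part of the PR, which gates this inside `MegatronGPTModel.sharded_state_dict()` with `cfg.get('skip_fp8_attention_checkpoint_load', True)`. Every `ShardedObject` whose key marks fused-attention `_extra_state` is wrapped in `LocalNonpersitentObject`, so dist checkpointing keeps the FP8 attention state created at initialization rather than restoring it from the checkpoint.

# Illustrative sketch of the skip-FP8-attention-load workaround (not PR code).
from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
from megatron.core.dist_checkpointing.mapping import LocalNonpersitentObject, ShardedObject


def skip_fp8_attention_load(sharded_state_dict, enabled=True):
    """Wrap fused-attention FP8 extra state so it is not read from the checkpoint."""

    def skip_fp8_load(x):
        # Only fused-attention `_extra_state` entries are affected; all other
        # sharded objects and tensors are loaded from the checkpoint as usual.
        if isinstance(x, ShardedObject) and 'fused_attention' in x.key and '_extra_state' in x.key:
            x = LocalNonpersitentObject(x.data)  # keep the initialization-time FP8 state
        return x

    if enabled:  # in the PR: self.cfg.get('skip_fp8_attention_checkpoint_load', True)
        dict_list_map_inplace(skip_fp8_load, sharded_state_dict)
    return sharded_state_dict

In the model itself this runs at the end of `sharded_state_dict()`, after the MCore module has produced its sharded mapping, so the wrapping is in place before the dict reaches the dist-checkpointing load path.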