From ac53e2296fcd9f699c928b62948a0b673c3817bc Mon Sep 17 00:00:00 2001 From: Jie Xin <932141413@qq.com> Date: Sat, 13 Apr 2024 09:33:30 +0800 Subject: [PATCH 1/4] Support alternative mapping TP->PP->DP (#8909) * support new tp-pp-dp mapping Signed-off-by: jxin * add test Signed-off-by: jxin * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refine Signed-off-by: jxin * change mcore commit Signed-off-by: jxin --------- Signed-off-by: jxin Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Harper --- Jenkinsfile | 2 +- .../stable_diffusion/ldm/ddpm.py | 7 +- .../language_modeling/megatron_base_model.py | 1 + .../language_modeling/megatron_gpt_model.py | 8 +- .../modules/common/megatron/megatron_init.py | 95 ++++++------- nemo/collections/nlp/parts/nlp_overrides.py | 1 + nemo/utils/app_state.py | 9 ++ tests/collections/nlp/test_initialize.py | 134 ++++++++++++++++++ 8 files changed, 189 insertions(+), 68 deletions(-) create mode 100644 tests/collections/nlp/test_initialize.py diff --git a/Jenkinsfile b/Jenkinsfile index c98d13fbed38..55e836eea13a 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -87,7 +87,7 @@ pipeline { steps { sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \ cd Megatron-LM && \ - git checkout f3a3020031f384ddafd9b7e9f3a587798c0aea21 && \ + git checkout fbb375d4b5e88ce52f5f7125053068caff47f93f && \ pip install . && \ cd megatron/core/datasets && \ make' diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py index 33a194500a69..a96c3c47e44e 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py @@ -1770,12 +1770,7 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): # we can avoid this broadcast by updating the PTL log function to accept specific ranks if parallel_state.get_pipeline_model_parallel_world_size() > 1: if self.loss_broadcast_src_rank is None: - dp_size = parallel_state.get_data_parallel_world_size() - tp_size = parallel_state.get_tensor_model_parallel_world_size() - pp_size = parallel_state.get_pipeline_model_parallel_world_size() - rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * tp_size) - last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1) - self.loss_broadcast_src_rank = last_pipeline_stage_offset + rank_in_dp_tp_group + self.loss_broadcast_src_rank = parallel_state.get_pipeline_model_parallel_last_rank() torch.distributed.broadcast( loss_mean, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(), ) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 035d194de09f..f431d43716b9 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -195,6 +195,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True): pipeline_model_parallel_size=cfg.get('pipeline_model_parallel_size', 1), virtual_pipeline_model_parallel_size=vp_size, pipeline_model_parallel_split_rank=cfg.get('pipeline_model_parallel_split_rank', 0), + use_tp_pp_dp_mapping=cfg.get('use_tp_pp_dp_mapping', False), context_parallel_size=cfg.get('context_parallel_size', 
1), micro_batch_size=cfg.get('micro_batch_size'), global_batch_size=cfg.get('global_batch_size'), diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index d3f5a7afd631..ede72439615e 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -1310,13 +1310,7 @@ def on_validation_epoch_end(self): # it should be casted to other pipeline stages for logging. if parallel_state.get_pipeline_model_parallel_world_size() > 1: if self.loss_broadcast_src_rank is None: - dp_size = parallel_state.get_data_parallel_world_size() - cp_size = parallel_state.get_context_parallel_world_size() - tp_size = parallel_state.get_tensor_model_parallel_world_size() - pp_size = parallel_state.get_pipeline_model_parallel_world_size() - rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * cp_size * tp_size) - last_pipeline_stage_offset = (tp_size * cp_size * dp_size) * (pp_size - 1) - self.loss_broadcast_src_rank = last_pipeline_stage_offset + rank_in_dp_tp_group + self.loss_broadcast_src_rank = parallel_state.get_pipeline_model_parallel_last_rank() torch.distributed.broadcast( averaged_loss, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(), ) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index 7ba2e28008ac..5d5b65b360ee 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -32,6 +32,7 @@ try: from megatron.core import tensor_parallel from megatron.core.parallel_state import ( + RankGenerator, get_pipeline_model_parallel_rank, set_expert_model_parallel_rank, set_expert_model_parallel_world_size, @@ -74,6 +75,7 @@ def initialize_model_parallel_for_nemo( init_mpi_proc_group=False, seed=1234, apex_transformer_log_level=30, + use_tp_pp_dp_mapping=False, ): if virtual_pipeline_model_parallel_size is not None and not HAVE_INTERLEAVED: @@ -84,6 +86,7 @@ def initialize_model_parallel_for_nemo( app_state.global_rank = global_rank app_state.world_size = world_size app_state.local_rank = local_rank + app_state.use_tp_pp_dp_mapping = use_tp_pp_dp_mapping app_state.expert_model_parallel_size = expert_model_parallel_size app_state.tensor_model_parallel_size = tensor_model_parallel_size app_state.pipeline_model_parallel_size = pipeline_model_parallel_size @@ -108,6 +111,7 @@ def initialize_model_parallel_for_nemo( pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank, context_parallel_size_=context_parallel_size, expert_model_parallel_size_=expert_model_parallel_size, + use_tp_pp_dp_mapping=use_tp_pp_dp_mapping, ) # update apex.transformer globals @@ -192,6 +196,7 @@ def fake_initialize_model_parallel( virtual_pipeline_model_parallel_size_=None, expert_model_parallel_size_=1, context_parallel_size_=1, + use_tp_pp_dp_mapping=False, ): """ Fake initialize model data parallel groups so that we can instantiate model parallel models before DDP is initialized. 
@@ -241,24 +246,29 @@ def fake_initialize_model_parallel( if virtual_pipeline_model_parallel_size_ is not None: virtual_pipeline_model_parallel_rank = 0 + rank_generator = RankGenerator( + tp=tensor_model_parallel_size, + ep=expert_model_parallel_size_, + dp=data_parallel_size, + pp=pipeline_model_parallel_size, + cp=context_parallel_size, + order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp', + ) + # Build the data-parallel groups. all_data_parallel_group_ranks_with_cp = [] - for i in range(pipeline_model_parallel_size): - start_rank = i * num_pipeline_model_parallel_groups - end_rank = (i + 1) * num_pipeline_model_parallel_groups - for j in range(context_parallel_size * tensor_model_parallel_size): - ranks = range(start_rank + j, end_rank, context_parallel_size * tensor_model_parallel_size) - if rank in ranks: - data_parallel_group = list(ranks) - logging.info(f'Rank {rank} has data parallel group : {data_parallel_group}') - for j in range(tensor_model_parallel_size): - ranks_with_cp = range(start_rank + j, end_rank, tensor_model_parallel_size) - all_data_parallel_group_ranks_with_cp.append(list(ranks_with_cp)) - if rank in ranks_with_cp: - data_parallel_group_with_cp = list(ranks_with_cp) - logging.info( - f'Rank {rank} has combined group of data parallel and context parallel : {data_parallel_group_with_cp}' - ) + for ranks in rank_generator.get_ranks('dp'): + if rank in ranks: + data_parallel_group = list(ranks) + logging.info(f'Rank {rank} has data parallel group : {data_parallel_group}') + + for ranks_with_cp in rank_generator.get_ranks('dp-cp'): + all_data_parallel_group_ranks_with_cp.append(ranks_with_cp) + if rank in ranks_with_cp: + data_parallel_group_with_cp = ranks_with_cp + logging.info( + f'Rank {rank} has combined group of data parallel and context parallel : {data_parallel_group_with_cp}' + ) data_parallel_rank = data_parallel_group.index(rank) logging.info( @@ -268,20 +278,11 @@ def fake_initialize_model_parallel( # Build the context-parallel groups. all_context_parallel_group_ranks = [] - for i in range(pipeline_model_parallel_size): - for j in range(data_parallel_size): - start_rank = ( - i * num_pipeline_model_parallel_groups + j * tensor_model_parallel_size * context_parallel_size - ) - end_rank = ( - i * num_pipeline_model_parallel_groups + (j + 1) * tensor_model_parallel_size * context_parallel_size - ) - for k in range(tensor_model_parallel_size): - ranks = range(start_rank + k, end_rank, tensor_model_parallel_size) - all_context_parallel_group_ranks.append(list(ranks)) - if rank in ranks: - context_parallel_group = list(ranks) - logging.info(f'Rank {rank} has context parallel group: {context_parallel_group}') + for ranks in rank_generator.get_ranks('cp'): + all_context_parallel_group_ranks.append(ranks) + if rank in ranks: + context_parallel_group = ranks + logging.info(f'Rank {rank} has context parallel group: {context_parallel_group}') context_parallel_rank = context_parallel_group.index(rank) logging.info(f'All context parallel group ranks: {all_context_parallel_group_ranks}') @@ -289,11 +290,7 @@ def fake_initialize_model_parallel( # Build the model-parallel groups. 
all_model_parallel_group_ranks = [] - for i in range(data_parallel_size * context_parallel_size): - ranks = [ - data_parallel_group_ranks_with_cp[i] - for data_parallel_group_ranks_with_cp in all_data_parallel_group_ranks_with_cp - ] + for ranks in rank_generator.get_ranks('tp-pp'): all_model_parallel_group_ranks.append(ranks) if rank in ranks: logging.info(f'Rank {rank} has model parallel group: {list(ranks)}') @@ -302,11 +299,10 @@ def fake_initialize_model_parallel( # Build the tensor model-parallel groups. all_tensor_model_parallel_group_ranks = [] tensor_model_parallel_group = None - for i in range(num_tensor_model_parallel_groups): - ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) - all_tensor_model_parallel_group_ranks.append(list(ranks)) + for ranks in rank_generator.get_ranks('tp'): + all_tensor_model_parallel_group_ranks.append(ranks) if rank in ranks: - tensor_model_parallel_group = list(ranks) + tensor_model_parallel_group = ranks logging.info(f'Rank {rank} has tensor model parallel group: {tensor_model_parallel_group}') tensor_model_parallel_rank = tensor_model_parallel_group.index(rank) @@ -317,17 +313,9 @@ def fake_initialize_model_parallel( # EP rank expert_model_parallel_rank = 0 if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1: - tensor_and_data_group_size: int = tensor_model_parallel_size * data_parallel_size - num_tensor_and_data_groups: int = world_size // tensor_and_data_group_size - tensor_and_expert_group_size: int = tensor_model_parallel_size * expert_model_parallel_size_ - num_expert_groups: int = data_parallel_size // expert_model_parallel_size_ - for i in range(num_tensor_and_data_groups): - for j in range(num_expert_groups): - start_rank = i * tensor_and_data_group_size + j * tensor_and_expert_group_size - end_rank = i * tensor_and_data_group_size + (j + 1) * tensor_and_expert_group_size - ranks = range(start_rank, end_rank) - if rank in ranks: - expert_model_parallel_rank = list(ranks).index(rank) // tensor_model_parallel_size + for ranks in rank_generator.get_ranks('ep', independent_ep=True): + if rank in ranks: + expert_model_parallel_rank = list(ranks).index(rank) // tensor_model_parallel_size # Build the pipeline model-parallel groups and embedding groups # (first and last rank in each pipeline model-parallel group). 
@@ -336,11 +324,10 @@ def fake_initialize_model_parallel( pipeline_model_parallel_group = None embedding_group = None embedding_rank = None - for i in range(num_pipeline_model_parallel_groups): - ranks = range(i, world_size, num_pipeline_model_parallel_groups) - all_pipeline_model_parallel_group_ranks.append(list(ranks)) + for ranks in rank_generator.get_ranks('pp'): + all_pipeline_model_parallel_group_ranks.append(ranks) if rank in ranks: - pipeline_model_parallel_group = list(ranks) + pipeline_model_parallel_group = ranks logging.info(f'Rank {rank} has pipeline model parallel group: {pipeline_model_parallel_group}') # Setup embedding group (to exchange gradients between diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index d4a75e3353c7..983b76784a66 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -137,6 +137,7 @@ def init_model_parallel(sharp: bool, nccl_communicator_config_path: str = None) nccl_communicator_config_path=nccl_communicator_config_path, use_sharp=sharp, expert_model_parallel_size=app_state.expert_model_parallel_size, + order='tp-pp-dp' if app_state.use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp', ) # assert that fake tp and pp rank match after model parallel init diff --git a/nemo/utils/app_state.py b/nemo/utils/app_state.py index 8ba9880219ec..34a03fc28871 100644 --- a/nemo/utils/app_state.py +++ b/nemo/utils/app_state.py @@ -55,6 +55,7 @@ def __init__(self): self._is_megatron_initialized = False self._data_parallel_size = None self._data_parallel_group = None + self._use_tp_pp_dp_mapping = False self._megatron_checkpoint_version = None self._use_fp8 = False self._context_parallel_size = None @@ -191,6 +192,14 @@ def pipeline_model_parallel_size(self, size): """ self._pipeline_model_parallel_size = size + @property + def use_tp_pp_dp_mapping(self): + return self._use_tp_pp_dp_mapping + + @use_tp_pp_dp_mapping.setter + def use_tp_pp_dp_mapping(self, use_new_mapping): + self._use_tp_pp_dp_mapping = use_new_mapping + @property def virtual_pipeline_model_parallel_size(self): """ Property returns the number of GPUs in each model parallel group. diff --git a/tests/collections/nlp/test_initialize.py b/tests/collections/nlp/test_initialize.py new file mode 100644 index 000000000000..b8e27573ce61 --- /dev/null +++ b/tests/collections/nlp/test_initialize.py @@ -0,0 +1,134 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel + + +def old_fake_initialize_model_parallel( + world_size, + rank, + tensor_model_parallel_size_, + pipeline_model_parallel_size_, + pipeline_model_parallel_split_rank_=None, + virtual_pipeline_model_parallel_size_=None, + expert_model_parallel_size_=1, + context_parallel_size_=1, +): + # Get world size and rank. Ensure some consistencies. 
+ tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size) + pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size) + model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size + context_parallel_size = min(context_parallel_size_, world_size) + + assert ( + world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size) == 0 + ), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size} times context_parallel_size {context_parallel_size}' + data_parallel_size = world_size // ( + tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size + ) + + num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size + num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size + + virtual_pipeline_model_parallel_rank = None + if virtual_pipeline_model_parallel_size_ is not None: + virtual_pipeline_model_parallel_rank = 0 + + # Build the tensor model-parallel groups. + tensor_model_parallel_group = None + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + if rank in ranks: + tensor_model_parallel_group = list(ranks) + + tensor_model_parallel_rank = tensor_model_parallel_group.index(rank) + + # EP rank + expert_model_parallel_rank = 0 + if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1: + tensor_and_data_group_size: int = tensor_model_parallel_size * data_parallel_size + num_tensor_and_data_groups: int = world_size // tensor_and_data_group_size + tensor_and_expert_group_size: int = tensor_model_parallel_size * expert_model_parallel_size_ + num_expert_groups: int = data_parallel_size // expert_model_parallel_size_ + for i in range(num_tensor_and_data_groups): + for j in range(num_expert_groups): + start_rank = i * tensor_and_data_group_size + j * tensor_and_expert_group_size + end_rank = i * tensor_and_data_group_size + (j + 1) * tensor_and_expert_group_size + ranks = range(start_rank, end_rank) + if rank in ranks: + expert_model_parallel_rank = list(ranks).index(rank) // tensor_model_parallel_size + + # Build the pipeline model-parallel groups and embedding groups + # (first and last rank in each pipeline model-parallel group). 
+ pipeline_model_parallel_group = None + for i in range(num_pipeline_model_parallel_groups): + ranks = range(i, world_size, num_pipeline_model_parallel_groups) + if rank in ranks: + pipeline_model_parallel_group = list(ranks) + + pipeline_model_parallel_rank = pipeline_model_parallel_group.index(rank) + + return ( + tensor_model_parallel_rank, + pipeline_model_parallel_rank, + expert_model_parallel_rank, + model_parallel_size, + data_parallel_size, + pipeline_model_parallel_split_rank_, + virtual_pipeline_model_parallel_rank, + ) + + +@pytest.mark.parametrize( + 'nodes, num_gpu, tp, pp, cp, ep', + [ + (1, 1, 1, 1, 1, 1), + (4, 8, 2, 4, 1, 1), + (8, 8, 8, 8, 1, 1), + (16, 8, 4, 8, 1, 1), + (16, 8, 4, 8, 4, 1), + (32, 8, 8, 8, 1, 1), + (32, 8, 4, 8, 1, 4), + (32, 8, 8, 8, 4, 1), + ], +) +def test_fake_initialize(nodes, num_gpu, tp, pp, cp, ep): + ( + tensor_model_parallel_rank, + pipeline_model_parallel_rank, + expert_model_parallel_rank, + model_parallel_size, + data_parallel_size, + pipeline_model_parallel_split_rank, + virtual_pipeline_model_parallel_rank, + ) = old_fake_initialize_model_parallel(nodes * num_gpu, 0, tp, pp, None, None, ep, cp) + + ( + m_tensor_model_parallel_rank, + n_pipeline_model_parallel_rank, + n_expert_model_parallel_rank, + n_model_parallel_size, + n_data_parallel_size, + n_pipeline_model_parallel_split_rank, + n_virtual_pipeline_model_parallel_rank, + ) = fake_initialize_model_parallel(nodes * num_gpu, 0, tp, pp, None, None, ep, cp) + assert m_tensor_model_parallel_rank == tensor_model_parallel_rank + assert n_pipeline_model_parallel_rank == pipeline_model_parallel_rank + assert n_expert_model_parallel_rank == expert_model_parallel_rank + assert n_model_parallel_size == model_parallel_size + assert n_data_parallel_size == data_parallel_size + assert n_pipeline_model_parallel_split_rank == pipeline_model_parallel_split_rank + assert n_virtual_pipeline_model_parallel_rank == virtual_pipeline_model_parallel_rank From cb22d71e335bc25bfe09947c64a2223550fc65ae Mon Sep 17 00:00:00 2001 From: Eric Harper Date: Fri, 12 Apr 2024 20:03:36 -0600 Subject: [PATCH 2/4] update package info (#8793) Signed-off-by: eharper --- Dockerfile | 2 +- nemo/package_info.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 970c34a690d4..fa825d61f015 100644 --- a/Dockerfile +++ b/Dockerfile @@ -141,7 +141,7 @@ COPY . . # start building the final container FROM nemo-deps as nemo -ARG NEMO_VERSION=1.23.0 +ARG NEMO_VERSION=2.0.0 # Check that NEMO_VERSION is set. Build will fail without this. Expose NEMO and base container # version information as runtime environment variable for introspection purposes diff --git a/nemo/package_info.py b/nemo/package_info.py index e0ff2247e6ad..b253927a6b38 100644 --- a/nemo/package_info.py +++ b/nemo/package_info.py @@ -13,8 +13,8 @@ # limitations under the License. 
-MAJOR = 1 -MINOR = 23 +MAJOR = 2 +MINOR = 0 PATCH = 0 PRE_RELEASE = 'rc0' From 378a9b3d9845a02eacc392e267f2e66dc62f151f Mon Sep 17 00:00:00 2001 From: Rachit Garg Date: Sat, 13 Apr 2024 10:05:18 -0700 Subject: [PATCH 3/4] Rachitg/dpa (#8911) * remove fp8 checkpoints for Attention Signed-off-by: rachitg * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: rachitg * set default value and support mha Signed-off-by: rachitg * skip by default Signed-off-by: rachitg --------- Signed-off-by: rachitg Co-authored-by: rachitg Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../nlp/models/language_modeling/megatron_gpt_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index ede72439615e..4493532f88bf 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -1741,7 +1741,7 @@ def skip_fp8_load(x): x = LocalNonpersitentObject(x.data) # use the FP8 state from initialization, not from ckpt return x - if self.cfg.get('fp8_dot_product_attention', False) or self.cfg.get('fp8_multi_head_attention', False): + if self.cfg.get('skip_fp8_attention_checkpoint_load', True): dict_list_map_inplace(skip_fp8_load, sharded_state_dict) return sharded_state_dict From de983ff6eb164944197c0e96807c3ee74119057c Mon Sep 17 00:00:00 2001 From: Pablo Garay Date: Sat, 13 Apr 2024 15:28:26 -0700 Subject: [PATCH 4/4] update mcore (#8917) --- .github/workflows/cicd-main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 29ea34dba197..c4350a42f59b 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -114,7 +114,7 @@ jobs: # Megatron Core installation git clone https://github.com/NVIDIA/Megatron-LM.git && \ pushd Megatron-LM && \ - git checkout f3a3020031f384ddafd9b7e9f3a587798c0aea21 && \ + git checkout fbb375d4b5e88ce52f5f7125053068caff47f93f && \ pip install . && \ pushd megatron/core/datasets && \ make && \
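
Editor's note (illustration only, not part of the patch series): the heart of PATCH 1/4 is that fake_initialize_model_parallel now delegates group enumeration to Megatron-LM's RankGenerator and merely switches its order string from 'tp-cp-ep-dp-pp' to 'tp-pp-dp' when use_tp_pp_dp_mapping is set. The sketch below is a minimal stand-in for what such an order string means for group assignment; it is not Megatron's implementation, and the helper name build_groups is invented for this note. Dimensions listed first vary fastest across consecutive ranks.

    from itertools import product

    def build_groups(order, sizes, group_dim):
        """Enumerate the rank groups of one parallelism dimension.

        order     : dimension names, fastest-varying first (mirrors the
                    RankGenerator order string, e.g. ['tp', 'pp', 'dp']).
        sizes     : dict mapping each dimension name to its parallel size.
        group_dim : the dimension whose groups we want (e.g. 'pp' or 'dp').
        """
        # The stride of a dimension is the product of all faster dimensions' sizes.
        strides, stride = {}, 1
        for name in order:
            strides[name] = stride
            stride *= sizes[name]

        others = [d for d in order if d != group_dim]
        groups = []
        # Fix the coordinates of every other dimension and sweep group_dim.
        for coords in product(*(range(sizes[d]) for d in others)):
            base = sum(c * strides[d] for c, d in zip(coords, others))
            groups.append([base + i * strides[group_dim] for i in range(sizes[group_dim])])
        return groups

    sizes = {'tp': 2, 'pp': 2, 'dp': 2}                   # toy 8-rank world, cp = ep = 1
    print(build_groups(['tp', 'dp', 'pp'], sizes, 'pp'))  # default order: [[0, 4], [2, 6], [1, 5], [3, 7]]
    print(build_groups(['tp', 'pp', 'dp'], sizes, 'pp'))  # tp-pp-dp:      [[0, 2], [4, 6], [1, 3], [5, 7]]

With the default order each pipeline group is strided across the whole world, whereas under 'tp-pp-dp' every pipeline group stays inside a contiguous block of tp * pp ranks and the data-parallel replicas become the widely strided ones. That layout change is also why the loss-broadcast source rank in ddpm.py and megatron_gpt_model.py can no longer be derived from the fixed (tp_size * dp_size) * (pp_size - 1) offset and is now taken from parallel_state.get_pipeline_model_parallel_last_rank().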
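
As a usage sketch: the new mapping in PATCH 1/4 is driven by a single model-config key read in megatron_base_model.py, and PATCH 3/4 likewise gates FP8 attention checkpoint loading on one key. Only the two flag names below come from these patches; the surrounding structure is an assumed, typical NeMo-style config fragment, not taken from this diff.

    from omegaconf import OmegaConf

    # Assumed config fragment; cfg.get('use_tp_pp_dp_mapping', False) and
    # cfg.get('skip_fp8_attention_checkpoint_load', True) are the reads changed above.
    overrides = OmegaConf.create(
        {
            'model': {
                'tensor_model_parallel_size': 2,
                'pipeline_model_parallel_size': 2,
                'use_tp_pp_dp_mapping': True,                # rank order becomes 'tp-pp-dp'
                'skip_fp8_attention_checkpoint_load': True,  # default; set False to load FP8 attention state from the checkpoint
            }
        }
    )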