Skip to content

Commit

Permalink
Support alternative mapping TP->PP->DP (#8909)
Browse files Browse the repository at this point in the history
* support new tp-pp-dp mapping

Signed-off-by: jxin <jxin@nvidia.com>

* add test

Signed-off-by: jxin <jxin@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refine

Signed-off-by: jxin <jxin@nvidia.com>

* change mcore commit

Signed-off-by: jxin <jxin@nvidia.com>

---------

Signed-off-by: jxin <jxin@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
  • Loading branch information
3 people authored Apr 13, 2024
1 parent 21913a0 commit ac53e22
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 68 deletions.
2 changes: 1 addition & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ pipeline {
steps {
sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \
cd Megatron-LM && \
git checkout f3a3020031f384ddafd9b7e9f3a587798c0aea21 && \
git checkout fbb375d4b5e88ce52f5f7125053068caff47f93f && \
pip install . && \
cd megatron/core/datasets && \
make'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1770,12 +1770,7 @@ def fwd_bwd_step(self, dataloader_iter, forward_only):
# we can avoid this broadcast by updating the PTL log function to accept specific ranks
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
if self.loss_broadcast_src_rank is None:
dp_size = parallel_state.get_data_parallel_world_size()
tp_size = parallel_state.get_tensor_model_parallel_world_size()
pp_size = parallel_state.get_pipeline_model_parallel_world_size()
rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * tp_size)
last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1)
self.loss_broadcast_src_rank = last_pipeline_stage_offset + rank_in_dp_tp_group
self.loss_broadcast_src_rank = parallel_state.get_pipeline_model_parallel_last_rank()
torch.distributed.broadcast(
loss_mean, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
pipeline_model_parallel_size=cfg.get('pipeline_model_parallel_size', 1),
virtual_pipeline_model_parallel_size=vp_size,
pipeline_model_parallel_split_rank=cfg.get('pipeline_model_parallel_split_rank', 0),
use_tp_pp_dp_mapping=cfg.get('use_tp_pp_dp_mapping', False),
context_parallel_size=cfg.get('context_parallel_size', 1),
micro_batch_size=cfg.get('micro_batch_size'),
global_batch_size=cfg.get('global_batch_size'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1310,13 +1310,7 @@ def on_validation_epoch_end(self):
# it should be casted to other pipeline stages for logging.
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
if self.loss_broadcast_src_rank is None:
dp_size = parallel_state.get_data_parallel_world_size()
cp_size = parallel_state.get_context_parallel_world_size()
tp_size = parallel_state.get_tensor_model_parallel_world_size()
pp_size = parallel_state.get_pipeline_model_parallel_world_size()
rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * cp_size * tp_size)
last_pipeline_stage_offset = (tp_size * cp_size * dp_size) * (pp_size - 1)
self.loss_broadcast_src_rank = last_pipeline_stage_offset + rank_in_dp_tp_group
self.loss_broadcast_src_rank = parallel_state.get_pipeline_model_parallel_last_rank()
torch.distributed.broadcast(
averaged_loss, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(),
)
Expand Down
95 changes: 41 additions & 54 deletions nemo/collections/nlp/modules/common/megatron/megatron_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
try:
from megatron.core import tensor_parallel
from megatron.core.parallel_state import (
RankGenerator,
get_pipeline_model_parallel_rank,
set_expert_model_parallel_rank,
set_expert_model_parallel_world_size,
Expand Down Expand Up @@ -74,6 +75,7 @@ def initialize_model_parallel_for_nemo(
init_mpi_proc_group=False,
seed=1234,
apex_transformer_log_level=30,
use_tp_pp_dp_mapping=False,
):

if virtual_pipeline_model_parallel_size is not None and not HAVE_INTERLEAVED:
Expand All @@ -84,6 +86,7 @@ def initialize_model_parallel_for_nemo(
app_state.global_rank = global_rank
app_state.world_size = world_size
app_state.local_rank = local_rank
app_state.use_tp_pp_dp_mapping = use_tp_pp_dp_mapping
app_state.expert_model_parallel_size = expert_model_parallel_size
app_state.tensor_model_parallel_size = tensor_model_parallel_size
app_state.pipeline_model_parallel_size = pipeline_model_parallel_size
Expand All @@ -108,6 +111,7 @@ def initialize_model_parallel_for_nemo(
pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
context_parallel_size_=context_parallel_size,
expert_model_parallel_size_=expert_model_parallel_size,
use_tp_pp_dp_mapping=use_tp_pp_dp_mapping,
)

# update apex.transformer globals
Expand Down Expand Up @@ -192,6 +196,7 @@ def fake_initialize_model_parallel(
virtual_pipeline_model_parallel_size_=None,
expert_model_parallel_size_=1,
context_parallel_size_=1,
use_tp_pp_dp_mapping=False,
):
"""
Fake initialize model data parallel groups so that we can instantiate model parallel models before DDP is initialized.
Expand Down Expand Up @@ -241,24 +246,29 @@ def fake_initialize_model_parallel(
if virtual_pipeline_model_parallel_size_ is not None:
virtual_pipeline_model_parallel_rank = 0

rank_generator = RankGenerator(
tp=tensor_model_parallel_size,
ep=expert_model_parallel_size_,
dp=data_parallel_size,
pp=pipeline_model_parallel_size,
cp=context_parallel_size,
order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
)

# Build the data-parallel groups.
all_data_parallel_group_ranks_with_cp = []
for i in range(pipeline_model_parallel_size):
start_rank = i * num_pipeline_model_parallel_groups
end_rank = (i + 1) * num_pipeline_model_parallel_groups
for j in range(context_parallel_size * tensor_model_parallel_size):
ranks = range(start_rank + j, end_rank, context_parallel_size * tensor_model_parallel_size)
if rank in ranks:
data_parallel_group = list(ranks)
logging.info(f'Rank {rank} has data parallel group : {data_parallel_group}')
for j in range(tensor_model_parallel_size):
ranks_with_cp = range(start_rank + j, end_rank, tensor_model_parallel_size)
all_data_parallel_group_ranks_with_cp.append(list(ranks_with_cp))
if rank in ranks_with_cp:
data_parallel_group_with_cp = list(ranks_with_cp)
logging.info(
f'Rank {rank} has combined group of data parallel and context parallel : {data_parallel_group_with_cp}'
)
for ranks in rank_generator.get_ranks('dp'):
if rank in ranks:
data_parallel_group = list(ranks)
logging.info(f'Rank {rank} has data parallel group : {data_parallel_group}')

for ranks_with_cp in rank_generator.get_ranks('dp-cp'):
all_data_parallel_group_ranks_with_cp.append(ranks_with_cp)
if rank in ranks_with_cp:
data_parallel_group_with_cp = ranks_with_cp
logging.info(
f'Rank {rank} has combined group of data parallel and context parallel : {data_parallel_group_with_cp}'
)

data_parallel_rank = data_parallel_group.index(rank)
logging.info(
Expand All @@ -268,32 +278,19 @@ def fake_initialize_model_parallel(

# Build the context-parallel groups.
all_context_parallel_group_ranks = []
for i in range(pipeline_model_parallel_size):
for j in range(data_parallel_size):
start_rank = (
i * num_pipeline_model_parallel_groups + j * tensor_model_parallel_size * context_parallel_size
)
end_rank = (
i * num_pipeline_model_parallel_groups + (j + 1) * tensor_model_parallel_size * context_parallel_size
)
for k in range(tensor_model_parallel_size):
ranks = range(start_rank + k, end_rank, tensor_model_parallel_size)
all_context_parallel_group_ranks.append(list(ranks))
if rank in ranks:
context_parallel_group = list(ranks)
logging.info(f'Rank {rank} has context parallel group: {context_parallel_group}')
for ranks in rank_generator.get_ranks('cp'):
all_context_parallel_group_ranks.append(ranks)
if rank in ranks:
context_parallel_group = ranks
logging.info(f'Rank {rank} has context parallel group: {context_parallel_group}')

context_parallel_rank = context_parallel_group.index(rank)
logging.info(f'All context parallel group ranks: {all_context_parallel_group_ranks}')
logging.info(f'Ranks {rank} has context parallel rank: {context_parallel_rank}')

# Build the model-parallel groups.
all_model_parallel_group_ranks = []
for i in range(data_parallel_size * context_parallel_size):
ranks = [
data_parallel_group_ranks_with_cp[i]
for data_parallel_group_ranks_with_cp in all_data_parallel_group_ranks_with_cp
]
for ranks in rank_generator.get_ranks('tp-pp'):
all_model_parallel_group_ranks.append(ranks)
if rank in ranks:
logging.info(f'Rank {rank} has model parallel group: {list(ranks)}')
Expand All @@ -302,11 +299,10 @@ def fake_initialize_model_parallel(
# Build the tensor model-parallel groups.
all_tensor_model_parallel_group_ranks = []
tensor_model_parallel_group = None
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
all_tensor_model_parallel_group_ranks.append(list(ranks))
for ranks in rank_generator.get_ranks('tp'):
all_tensor_model_parallel_group_ranks.append(ranks)
if rank in ranks:
tensor_model_parallel_group = list(ranks)
tensor_model_parallel_group = ranks
logging.info(f'Rank {rank} has tensor model parallel group: {tensor_model_parallel_group}')

tensor_model_parallel_rank = tensor_model_parallel_group.index(rank)
Expand All @@ -317,17 +313,9 @@ def fake_initialize_model_parallel(
# EP rank
expert_model_parallel_rank = 0
if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1:
tensor_and_data_group_size: int = tensor_model_parallel_size * data_parallel_size
num_tensor_and_data_groups: int = world_size // tensor_and_data_group_size
tensor_and_expert_group_size: int = tensor_model_parallel_size * expert_model_parallel_size_
num_expert_groups: int = data_parallel_size // expert_model_parallel_size_
for i in range(num_tensor_and_data_groups):
for j in range(num_expert_groups):
start_rank = i * tensor_and_data_group_size + j * tensor_and_expert_group_size
end_rank = i * tensor_and_data_group_size + (j + 1) * tensor_and_expert_group_size
ranks = range(start_rank, end_rank)
if rank in ranks:
expert_model_parallel_rank = list(ranks).index(rank) // tensor_model_parallel_size
for ranks in rank_generator.get_ranks('ep', independent_ep=True):
if rank in ranks:
expert_model_parallel_rank = list(ranks).index(rank) // tensor_model_parallel_size

# Build the pipeline model-parallel groups and embedding groups
# (first and last rank in each pipeline model-parallel group).
Expand All @@ -336,11 +324,10 @@ def fake_initialize_model_parallel(
pipeline_model_parallel_group = None
embedding_group = None
embedding_rank = None
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
all_pipeline_model_parallel_group_ranks.append(list(ranks))
for ranks in rank_generator.get_ranks('pp'):
all_pipeline_model_parallel_group_ranks.append(ranks)
if rank in ranks:
pipeline_model_parallel_group = list(ranks)
pipeline_model_parallel_group = ranks
logging.info(f'Rank {rank} has pipeline model parallel group: {pipeline_model_parallel_group}')

# Setup embedding group (to exchange gradients between
Expand Down
1 change: 1 addition & 0 deletions nemo/collections/nlp/parts/nlp_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def init_model_parallel(sharp: bool, nccl_communicator_config_path: str = None)
nccl_communicator_config_path=nccl_communicator_config_path,
use_sharp=sharp,
expert_model_parallel_size=app_state.expert_model_parallel_size,
order='tp-pp-dp' if app_state.use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
)

# assert that fake tp and pp rank match after model parallel init
Expand Down
9 changes: 9 additions & 0 deletions nemo/utils/app_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self):
self._is_megatron_initialized = False
self._data_parallel_size = None
self._data_parallel_group = None
self._use_tp_pp_dp_mapping = False
self._megatron_checkpoint_version = None
self._use_fp8 = False
self._context_parallel_size = None
Expand Down Expand Up @@ -191,6 +192,14 @@ def pipeline_model_parallel_size(self, size):
"""
self._pipeline_model_parallel_size = size

@property
def use_tp_pp_dp_mapping(self):
return self._use_tp_pp_dp_mapping

@use_tp_pp_dp_mapping.setter
def use_tp_pp_dp_mapping(self, use_new_mapping):
self._use_tp_pp_dp_mapping = use_new_mapping

@property
def virtual_pipeline_model_parallel_size(self):
""" Property returns the number of GPUs in each model parallel group.
Expand Down
134 changes: 134 additions & 0 deletions tests/collections/nlp/test_initialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel


def old_fake_initialize_model_parallel(
world_size,
rank,
tensor_model_parallel_size_,
pipeline_model_parallel_size_,
pipeline_model_parallel_split_rank_=None,
virtual_pipeline_model_parallel_size_=None,
expert_model_parallel_size_=1,
context_parallel_size_=1,
):
# Get world size and rank. Ensure some consistencies.
tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size
context_parallel_size = min(context_parallel_size_, world_size)

assert (
world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size) == 0
), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size} times context_parallel_size {context_parallel_size}'
data_parallel_size = world_size // (
tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
)

num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size

virtual_pipeline_model_parallel_rank = None
if virtual_pipeline_model_parallel_size_ is not None:
virtual_pipeline_model_parallel_rank = 0

# Build the tensor model-parallel groups.
tensor_model_parallel_group = None
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
if rank in ranks:
tensor_model_parallel_group = list(ranks)

tensor_model_parallel_rank = tensor_model_parallel_group.index(rank)

# EP rank
expert_model_parallel_rank = 0
if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1:
tensor_and_data_group_size: int = tensor_model_parallel_size * data_parallel_size
num_tensor_and_data_groups: int = world_size // tensor_and_data_group_size
tensor_and_expert_group_size: int = tensor_model_parallel_size * expert_model_parallel_size_
num_expert_groups: int = data_parallel_size // expert_model_parallel_size_
for i in range(num_tensor_and_data_groups):
for j in range(num_expert_groups):
start_rank = i * tensor_and_data_group_size + j * tensor_and_expert_group_size
end_rank = i * tensor_and_data_group_size + (j + 1) * tensor_and_expert_group_size
ranks = range(start_rank, end_rank)
if rank in ranks:
expert_model_parallel_rank = list(ranks).index(rank) // tensor_model_parallel_size

# Build the pipeline model-parallel groups and embedding groups
# (first and last rank in each pipeline model-parallel group).
pipeline_model_parallel_group = None
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
if rank in ranks:
pipeline_model_parallel_group = list(ranks)

pipeline_model_parallel_rank = pipeline_model_parallel_group.index(rank)

return (
tensor_model_parallel_rank,
pipeline_model_parallel_rank,
expert_model_parallel_rank,
model_parallel_size,
data_parallel_size,
pipeline_model_parallel_split_rank_,
virtual_pipeline_model_parallel_rank,
)


@pytest.mark.parametrize(
'nodes, num_gpu, tp, pp, cp, ep',
[
(1, 1, 1, 1, 1, 1),
(4, 8, 2, 4, 1, 1),
(8, 8, 8, 8, 1, 1),
(16, 8, 4, 8, 1, 1),
(16, 8, 4, 8, 4, 1),
(32, 8, 8, 8, 1, 1),
(32, 8, 4, 8, 1, 4),
(32, 8, 8, 8, 4, 1),
],
)
def test_fake_initialize(nodes, num_gpu, tp, pp, cp, ep):
(
tensor_model_parallel_rank,
pipeline_model_parallel_rank,
expert_model_parallel_rank,
model_parallel_size,
data_parallel_size,
pipeline_model_parallel_split_rank,
virtual_pipeline_model_parallel_rank,
) = old_fake_initialize_model_parallel(nodes * num_gpu, 0, tp, pp, None, None, ep, cp)

(
m_tensor_model_parallel_rank,
n_pipeline_model_parallel_rank,
n_expert_model_parallel_rank,
n_model_parallel_size,
n_data_parallel_size,
n_pipeline_model_parallel_split_rank,
n_virtual_pipeline_model_parallel_rank,
) = fake_initialize_model_parallel(nodes * num_gpu, 0, tp, pp, None, None, ep, cp)
assert m_tensor_model_parallel_rank == tensor_model_parallel_rank
assert n_pipeline_model_parallel_rank == pipeline_model_parallel_rank
assert n_expert_model_parallel_rank == expert_model_parallel_rank
assert n_model_parallel_size == model_parallel_size
assert n_data_parallel_size == data_parallel_size
assert n_pipeline_model_parallel_split_rank == pipeline_model_parallel_split_rank
assert n_virtual_pipeline_model_parallel_rank == virtual_pipeline_model_parallel_rank

0 comments on commit ac53e22

Please sign in to comment.