Tensor-parallel communication overlap with userbuffer backend (#6780)
* Add interfaces for tensor-parallel communication overlap

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Add an interface to provide custom userbuffer communicator settings via a YAML file


* Construct MPI process group for userbuffers support

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Abhinav Khattar <aklife97@gmail.com>
3 people authored and web-flow committed Jun 1, 2023
1 parent 23f1c42 commit 0ca05a3
Showing 9 changed files with 77 additions and 0 deletions.
7 changes: 7 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -166,6 +166,13 @@ model:
fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history
reduce_amax: True # Perform reduction to sync amax tensors across GPUs after every iteration
use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False.
ub_tp_comm_overlap: False
# Use the userbuffer backend to overlap tensor-parallel communication with computation.
# This feature requires both Transformer Engine and sequence parallelism to be enabled and currently supports only GPT models.
ub_tp_comm_overlap_cfg: null
# A YAML file with userbuffer communicator configurations. To use custom settings, the file should provide `method`,
# `dtype`, `num_sm`, `num_splits`, `cga_size`, `set_sm_margin`, and `aggregate` for each communicator.
# If no configuration file is provided, default settings are used for all communicators.

data:
# Path to data must be specified by the user.
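
For reference, here is a hypothetical example of the per-communicator settings that `ub_tp_comm_overlap_cfg` could point to, shown as the Python dict that `yaml.safe_load` would return. The communicator name 'qkv_fprop' and all values are illustrative assumptions; only the key names come from the config comment above.

ub_cfgs = {
    'qkv_fprop': {                  # hypothetical communicator name
        'method': 'ring_exchange',  # assumed value
        'dtype': 'bf16',            # assumed value
        'num_sm': 4,
        'num_splits': 4,
        'cga_size': 2,
        'set_sm_margin': 0,
        'aggregate': 0,
    },
}
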
@@ -163,6 +163,7 @@ def __init__(
fp8_amax_compute_algo='most_recent',
reduce_amax=True,
use_emha=False,
ub_tp_comm_overlap=False,
):
super(GPTModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)

@@ -243,6 +244,7 @@ def __init__(
fp8_amax_compute_algo=fp8_amax_compute_algo,
reduce_amax=reduce_amax,
use_emha=use_emha,
ub_tp_comm_overlap=ub_tp_comm_overlap,
)

if self.share_embeddings_and_output_weights:
@@ -123,6 +123,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
global_batch_size=cfg.get('global_batch_size'),
rampup_batch_size=cfg.get('rampup_batch_size'),
use_fp8=cfg.get('fp8', False),
init_mpi_proc_group=cfg.get('ub_tp_comm_overlap', False),
seed=self.cfg.get('seed', 1234),
apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
)
@@ -538,6 +539,14 @@ def _validate_and_override_config(self):
'Make sure the number of model chunks is the same across all pipeline stages.'
)

if self.cfg.get('ub_tp_comm_overlap', False):
if not self.cfg.get('transformer_engine', False) or not self.cfg.get('sequence_parallel', False):
logging.info(
"Userbuffer tensor-parallel communication overlap is available with both Transformer Engine and sequence-parallelism."
)
with open_dict(self.cfg):
self.cfg.ub_tp_comm_overlap = False

def is_data_parallel_rank_zero(self):
if is_global_rank_zero():
return True
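
To make the check in `_validate_and_override_config` above concrete, here is a minimal sketch of the flag combination it expects; the values are illustrative and `OmegaConf.create` stands in for the real model config.

from omegaconf import OmegaConf

# ub_tp_comm_overlap only stays enabled when both prerequisites are also enabled.
model_cfg = OmegaConf.create({
    'transformer_engine': True,      # prerequisite
    'sequence_parallel': True,       # prerequisite
    'ub_tp_comm_overlap': True,
    'ub_tp_comm_overlap_cfg': None,  # optional YAML with custom communicator settings
})
# If either prerequisite is False, the validation above logs a message and forces
# ub_tp_comm_overlap back to False under open_dict(self.cfg).
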
@@ -85,6 +85,7 @@

try:
import transformer_engine
from transformer_engine.pytorch import module as te_module

HAVE_TE = True

@@ -276,6 +277,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self._nsys_profile_end_step *= grad_accum_steps

self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True)
self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False)

def get_gpt_module_list(self):
if isinstance(self.model, list):
@@ -351,6 +353,7 @@ def model_provider_func(self, pre_process, post_process):
fp8_amax_compute_algo=self.cfg.get('fp8_amax_compute_algo', 'most_recent'),
reduce_amax=self.cfg.get('reduce_amax', True),
use_emha=self.cfg.get('use_emha', False),
ub_tp_comm_overlap=self.cfg.get('ub_tp_comm_overlap', False),
)

return model
@@ -505,6 +508,31 @@ def training_step(self, dataloader_iter, batch_idx):
The input batch to each micro-batch is fetched using the dataloader function
in the micro-batch fwd function.
"""
# Initialize userbuffer communicators. Initialization is done only once at the
# beginning of the first training step.
if self.initialize_ub:
input_shape = [
self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
self.cfg.get('hidden_size'),
]
ub_cfg_file_name = self.cfg.get('ub_tp_comm_overlap_cfg', None)
if ub_cfg_file_name is not None:
try:
import yaml

with open(ub_cfg_file_name, 'r') as ub_cfg_file:
ub_cfgs = yaml.safe_load(ub_cfg_file)
except (ImportError, TypeError):
print("Failed to read the ub_tp_comm_overlap config file; falling back to default userbuffer communicator settings.")
ub_cfgs = None
else:
ub_cfgs = None
te_module.initialize_ub(
shape=input_shape,
tp_size=self.cfg.get('tensor_model_parallel_size'),
use_fp8=self.cfg.get('fp8'),
ub_cfgs=ub_cfgs,
)
self.initialize_ub = False

# we zero grads here because we also call backward in the megatron-core fwd/bwd functions
self._optimizer.zero_grad()
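
For readability, here is the one-time initialization from the diff above pulled out into a standalone sketch; the helper name `init_userbuffers_once` is hypothetical and `cfg` is assumed to be the model's DictConfig.

import yaml
from transformer_engine.pytorch import module as te_module

def init_userbuffers_once(cfg):
    # Flattened activation shape seen by each tensor-parallel GEMM: (seq_len * micro_batch, hidden).
    input_shape = [
        cfg.get('encoder_seq_length') * cfg.get('micro_batch_size'),
        cfg.get('hidden_size'),
    ]
    ub_cfgs = None
    ub_cfg_file_name = cfg.get('ub_tp_comm_overlap_cfg', None)
    if ub_cfg_file_name is not None:
        with open(ub_cfg_file_name, 'r') as ub_cfg_file:
            ub_cfgs = yaml.safe_load(ub_cfg_file)
    # ub_cfgs=None lets Transformer Engine fall back to default communicator settings.
    te_module.initialize_ub(
        shape=input_shape,
        tp_size=cfg.get('tensor_model_parallel_size'),
        use_fp8=cfg.get('fp8'),
        ub_cfgs=ub_cfgs,
    )
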
@@ -116,6 +116,7 @@ def get_language_model(
fp8_amax_compute_algo='most_recent',
reduce_amax=True,
use_emha=False,
ub_tp_comm_overlap=False,
):
"""Build language model and return along with the key to save."""

@@ -191,6 +192,7 @@
fp8_amax_compute_algo=fp8_amax_compute_algo,
reduce_amax=reduce_amax,
use_emha=use_emha,
ub_tp_comm_overlap=ub_tp_comm_overlap,
)
# key used for checkpoints.
language_model_key = 'language_model'
@@ -497,6 +499,7 @@ def __init__(
fp8_amax_compute_algo='most_recent',
reduce_amax=True,
use_emha=False,
ub_tp_comm_overlap=False,
):
super(TransformerLanguageModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)

@@ -602,6 +605,7 @@ def __init__(
fp8_amax_compute_algo=fp8_amax_compute_algo,
reduce_amax=reduce_amax,
use_emha=use_emha,
ub_tp_comm_overlap=ub_tp_comm_overlap,
)
self._encoder_key = 'encoder'

2 changes: 2 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/megatron_init.py
@@ -67,6 +67,7 @@ def initialize_model_parallel_for_nemo(
global_batch_size=None,
rampup_batch_size=None,
use_fp8=False,
init_mpi_proc_group=False,
seed=1234,
apex_transformer_log_level=30,
):
@@ -83,6 +84,7 @@
app_state.pipeline_model_parallel_size = pipeline_model_parallel_size
app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
app_state.use_fp8 = use_fp8
app_state.init_mpi_proc_group = init_mpi_proc_group
(
app_state.tensor_model_parallel_rank,
app_state.pipeline_model_parallel_rank,
4 changes: 4 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -792,6 +792,7 @@ def __init__(
layer_type: str = "encoder",
drop_path_rate: float = 0,
use_emha: bool = False,
ub_tp_comm_overlap: bool = False,
autocast_dtype: Any = 16,
zero_centered_gamma: bool = False,
) -> None:
@@ -824,6 +825,7 @@
set_parallel_mode=tp_size > 1,
fuse_qkv_params=True,
zero_centered_gamma=zero_centered_gamma,
ub_tp_comm_overlap=ub_tp_comm_overlap,
)
# use_emha=use_emha,

@@ -919,6 +921,7 @@
fp8_amax_compute_algo='most_recent',
reduce_amax=True,
use_emha=False,
ub_tp_comm_overlap=False,
normalize_attention_scores=True,
multi_query_attention=False,
num_moe_experts=1,
@@ -1058,6 +1061,7 @@ def build_layer(layer_number):
apply_residual_connection_post_layernorm=False,
autocast_dtype=precision,
use_emha=use_emha,
ub_tp_comm_overlap=ub_tp_comm_overlap,
zero_centered_gamma=normalization == 'layernorm1p',
)
else:
4 changes: 4 additions & 0 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -180,6 +180,10 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None:
app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
app_state.pipeline_model_parallel_group = parallel_state.get_pipeline_model_parallel_group()

# create MPI process group for UCX-based communication APIs
if app_state.init_mpi_proc_group:
torch.distributed.new_group(backend='mpi')

def save_checkpoint(
self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None
) -> None:
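
The MPI group creation above relies on PyTorch being built with MPI support; a hedged sketch of an explicit guard (the error message is illustrative):

import torch.distributed as dist

if app_state.init_mpi_proc_group:
    # The 'mpi' backend is only available when PyTorch was built with MPI.
    if not dist.is_mpi_available():
        raise RuntimeError(
            "ub_tp_comm_overlap requested an MPI process group, but this PyTorch build has no MPI backend."
        )
    dist.new_group(backend='mpi')
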
17 changes: 17 additions & 0 deletions nemo/utils/app_state.py
@@ -55,6 +55,7 @@ def __init__(self):
self._data_parallel_group = None
self._megatron_checkpoint_version = None
self._use_fp8 = False
self._init_mpi_proc_group = False

self._random_seed = None

@@ -363,6 +364,22 @@ def use_fp8(self, use_fp8):
"""
self._use_fp8 = use_fp8

@property
def init_mpi_proc_group(self):
""" Property sets the initialization of mpi process group.
Returns:
Initialize mpi process group.
"""
return self._init_mpi_proc_group

@init_mpi_proc_group.setter
def init_mpi_proc_group(self, init_mpi_proc_group):
""" Property sets the initialization of mpi process group.
Args:
init_mpi_proc_group: Initialize mpi process group.
"""
self._init_mpi_proc_group = init_mpi_proc_group

@property
def random_seed(self):
""" Property returns the random seed.
