Add MoE support for T5 model (w/o expert parallel) #5409

Merged: 21 commits, Nov 15, 2022
Jenkinsfile: 44 additions & 0 deletions
@@ -3694,6 +3694,50 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
sh "rm -rf examples/nlp/language_modeling/t5_index_mappings"
}
}
stage('L2: Megatron T5 w/ Mixture of Expert Pretraining') {
when {
anyOf {
branch 'main'
changeRequest target: 'main'
}
}
failFast true
steps {
sh "python examples/nlp/language_modeling/megatron_t5_pretraining.py \
trainer.devices=2 \
trainer.accelerator=gpu \
trainer.log_every_n_steps=1 \
trainer.val_check_interval=10 \
trainer.limit_val_batches=2 \
trainer.accumulate_grad_batches=1 \
trainer.max_steps=10 \
trainer.precision=16 \
trainer.gradient_clip_val=1.0 \
exp_manager.exp_dir=examples/nlp/language_modeling/t5_pretrain_results \
model.pipeline_model_parallel_split_rank=1 \
model.seq_length=256 \
model.encoder.num_layers=4 \
model.decoder.num_layers=1 \
model.encoder.num_moe_experts=4 \
model.decoder.num_moe_experts=4 \
model.encoder.moe_frequency=3 \
model.decoder.moe_frequency=1 \
model.encoder.hidden_size=64 \
model.decoder.hidden_size=64 \
model.encoder.num_attention_heads=8 \
model.decoder.num_attention_heads=8 \
model.decoder.ffn_hidden_size=2048 \
model.encoder.activation='gelu' \
model.encoder.activations_checkpoint_method='block' \
model.encoder.activations_checkpoint_num_layers=1 \
model.encoder.transformer_block_type='pre_ln' \
model.decoder.transformer_block_type='post_ln' \
model.data.data_prefix=[.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document,.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document] \
model.data.index_mapping_dir=examples/nlp/language_modeling/t5_index_mappings"
sh "rm -rf examples/nlp/language_modeling/t5_pretrain_results"
sh "rm -rf examples/nlp/language_modeling/t5_index_mappings"
}
}
stage('L2: Megatron T5 Prompt Learning') {
when {
anyOf {
@@ -33,3 +33,6 @@ activations_checkpoint_method: null # 'uniform', 'block'
activations_checkpoint_num_layers: 1
megatron_legacy: False # Whether to use the legacy Megatron model. This affects the way q, k, v are partitioned from the mixed q, k, v layer in ParallelAttention. This needs to be True for models converted from HF.
normalize_attention_scores: True # Whether to scale the output Q * K^T by 1 / sqrt(hidden_size_per_head). This arg is provided as a configuration option mostly for compatibility with models that have been weight-converted from HF. You almost always want to set this to True.
num_moe_experts: 1 # When >1, FFN layers are replaced with MoE layers
moe_frequency: 1 # Every Nth FFN layer is made an MoE layer
moe_dropout: 0.0 # Dropout value for MoE layers
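Read together, these three keys determine where experts appear in the stack. A minimal sketch of how num_moe_experts and moe_frequency could select which FFN sublayers become MoE (the exact placement rule here is an assumption for illustration, not the NeMo implementation):

# Hypothetical layer-selection logic implied by the config comments above.
def is_moe_layer(layer_idx: int, num_moe_experts: int, moe_frequency: int) -> bool:
    """Return True if the FFN of the 1-based layer_idx should be an MoE layer."""
    if num_moe_experts <= 1:  # num_moe_experts=1 keeps every FFN dense
        return False
    return layer_idx % moe_frequency == 0  # every Nth FFN layer becomes MoE

if __name__ == "__main__":
    # With the CI overrides above (encoder: num_layers=4, moe_frequency=3),
    # only encoder layer 3 would get an MoE FFN; the decoder (num_layers=1,
    # moe_frequency=1) would get one on its single layer.
    for layer in range(1, 5):
        print(f"encoder layer {layer}: MoE={is_moe_layer(layer, 4, 3)}")
    print(f"decoder layer 1: MoE={is_moe_layer(1, 4, 1)}")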
@@ -82,6 +82,9 @@ def get_decoder_model(
normalize_attention_scores=True,
sequence_parallel=False,
gradient_accumulation_fusion=False,
num_moe_experts=1,
moe_frequency=1,
moe_dropout=0.0,
):
"""Build language model and return along with the key to save."""

@@ -134,6 +137,9 @@ def get_decoder_model(
parent_model_type=parent_model_type,
megatron_legacy=megatron_legacy,
normalize_attention_scores=normalize_attention_scores,
num_moe_experts=num_moe_experts,
moe_frequency=moe_frequency,
moe_dropout=moe_dropout,
)
elif arch == "retro":
decoder = MegatronRetrievalTransformerDecoderModule(
@@ -84,6 +84,9 @@ def get_encoder_model(
normalize_attention_scores=True,
sequence_parallel=False,
gradient_accumulation_fusion=False,
num_moe_experts=1,
moe_frequency=1,
moe_dropout=0.0,
):
"""Build language model and return along with the key to save."""

@@ -136,6 +139,9 @@ def get_encoder_model(
parent_model_type=parent_model_type,
megatron_legacy=megatron_legacy,
normalize_attention_scores=normalize_attention_scores,
num_moe_experts=num_moe_experts,
moe_frequency=moe_frequency,
moe_dropout=moe_dropout,
)
elif arch == "retro":
encoder = MegatronRetrievalTransformerEncoderModule(
@@ -80,6 +80,9 @@ def __init__(
parent_model_type=ModelType.encoder_or_decoder,
megatron_legacy=False,
normalize_attention_scores=True,
num_moe_experts=1,
moe_frequency=1,
moe_dropout=0.0,
):
super(MegatronTransformerDecoderModule, self).__init__()

@@ -139,6 +142,9 @@ def __init__(
gradient_accumulation_fusion=False, # TODO: This has to be False for enc-dec models for now.
megatron_legacy=megatron_legacy,
normalize_attention_scores=normalize_attention_scores,
num_moe_experts=num_moe_experts,
moe_frequency=moe_frequency,
moe_dropout=moe_dropout,
)
self._model_key = 'model'

@@ -77,6 +77,9 @@ def __init__(
parent_model_type=ModelType.encoder_or_decoder,
megatron_legacy=False,
normalize_attention_scores=True,
num_moe_experts=1,
moe_frequency=1,
moe_dropout=0.0,
):
super(MegatronTransformerEncoderModule, self).__init__()

@@ -137,6 +140,9 @@ def __init__(
gradient_accumulation_fusion=False, # TODO: This has to be False for enc-dec models for now.
megatron_legacy=megatron_legacy,
normalize_attention_scores=normalize_attention_scores,
num_moe_experts=num_moe_experts,
moe_frequency=moe_frequency,
moe_dropout=moe_dropout,
)
self._model_key = 'model'

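Both the encoder and decoder transformer modules now accept the MoE arguments and forward them into the underlying transformer, where the dense FFN would be swapped for a mixture-of-experts block. A rough, self-contained sketch of a top-1-routed MoE feed-forward layer (structure and routing are assumptions for illustration, not the Megatron/NeMo code) shows what num_moe_experts and moe_dropout control:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyMoEFFN(nn.Module):
    # Hypothetical top-1-routed MoE feed-forward block, for illustration only.
    def __init__(self, hidden_size, ffn_hidden_size, num_moe_experts, moe_dropout=0.0):
        super().__init__()
        self.router = nn.Linear(hidden_size, num_moe_experts)  # token-to-expert gating
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_size, ffn_hidden_size),
                nn.GELU(),
                nn.Linear(ffn_hidden_size, hidden_size),
            )
            for _ in range(num_moe_experts)
        )
        self.dropout = nn.Dropout(moe_dropout)  # what moe_dropout would govern

    def forward(self, x):  # x: [num_tokens, hidden_size]
        gates = F.softmax(self.router(x), dim=-1)  # routing probabilities per token
        top1 = gates.argmax(dim=-1)                # route each token to one expert
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = top1 == i
            if mask.any():
                # scale each expert output by its gate probability
                out[mask] = gates[mask, i].unsqueeze(-1) * expert(x[mask])
        return self.dropout(out)

# Example: 10 tokens; hidden_size=64 and num_moe_experts=4 echo the CI overrides
# above, while ffn_hidden_size=256 is arbitrary here.
moe = ToyMoEFFN(hidden_size=64, ffn_hidden_size=256, num_moe_experts=4, moe_dropout=0.1)
print(moe(torch.randn(10, 64)).shape)  # torch.Size([10, 64])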
@@ -196,6 +196,9 @@ def __init__(
num_self_attention_per_cross_attention=encoder_cfg.get('num_self_attention_per_cross_attention', 1),
megatron_legacy=encoder_cfg.get('megatron_legacy', False),
normalize_attention_scores=encoder_cfg.get('normalize_attention_scores', True),
num_moe_experts=encoder_cfg.get('num_moe_experts', 1),
moe_frequency=encoder_cfg.get('moe_frequency', 1),
moe_dropout=encoder_cfg.get('moe_dropout', 0.0),
)

if add_decoder:
@@ -300,6 +303,9 @@ def __init__(
parent_model_type=ModelType.encoder_and_decoder,
megatron_legacy=decoder_cfg.get('megatron_legacy', False),
normalize_attention_scores=decoder_cfg.get('normalize_attention_scores', True),
num_moe_experts=decoder_cfg.get('num_moe_experts', 1),
moe_frequency=decoder_cfg.get('moe_frequency', 1),
moe_dropout=decoder_cfg.get('moe_dropout', 0.0),
)

self.enc_dec_model = MegatronTransformerEncoderDecoderModule(
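Because the enc-dec model reads the new keys with a default (e.g. encoder_cfg.get('num_moe_experts', 1)), configs written before this change still load and simply build dense FFNs. A small illustration of that defaulting pattern, using a plain dict in place of the OmegaConf DictConfig that NeMo actually passes around (DictConfig supports the same get(key, default) call):

# Hypothetical old-style encoder config with no MoE keys.
old_style_encoder_cfg = {"hidden_size": 64, "num_layers": 4}

num_moe_experts = old_style_encoder_cfg.get("num_moe_experts", 1)  # -> 1: dense FFNs
moe_frequency = old_style_encoder_cfg.get("moe_frequency", 1)      # -> 1
moe_dropout = old_style_encoder_cfg.get("moe_dropout", 0.0)        # -> 0.0

print(num_moe_experts, moe_frequency, moe_dropout)  # 1 1 0.0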