diff --git a/examples_deepspeed/pretrain_llama_distributed.sh b/examples_deepspeed/pretrain_llama_distributed.sh
new file mode 100644
index 0000000000..960b3ee674
--- /dev/null
+++ b/examples_deepspeed/pretrain_llama_distributed.sh
@@ -0,0 +1,118 @@
+#!/bin/bash
+# This example script is contributed by external user https://github.com/LydiaXiaohongLi
+set -ex
+
+######################################
+# Change the below configurations here
+BASE_PATH=./tmp
+DS_CONFIG=${BASE_PATH}/deepspeed.json
+DATASET_1="./tmp/data/bookcorpus_train_1m_text_sentence"
+DATASET="1 ${DATASET_1}"
+CHECKPOINT_PATH=./tmp
+TOKENIZER_PATH=./tmp/tokenizer.model # official llama tokenizer.model
+
+TP=2
+PP=2
+ZERO_STAGE=0
+
+GPUS_PER_NODE=8
+MASTER_ADDR=localhost
+MASTER_PORT=6000
+NNODES=1
+NODE_RANK=0
+
+HIDDEN_SIZE=2048 # e.g. llama-13b: 5120
+FFN_HIDDEN_SIZE=5504 # e.g. llama-13b: 13824
+NUM_LAYERS=24 # e.g. llama-13b: 40
+NUM_HEADS=16 # e.g. llama-13b: 40
+SEQ_LENGTH=2048
+
+MICRO_BATCH_SIZE=4
+GLOBAL_BATCH_SIZE=32 # e.g. llama: 4M tokens
+TRAIN_STEPS=250000 # e.g. llama: 1T tokens / 4M tokens_per_batch = 250000 steps
+LR=3e-4
+MIN_LR=3e-5
+LR_WARMUP_STEPS=2000
+WEIGHT_DECAY=0.1
+GRAD_CLIP=1
+
+# Below configuration required for llama model as per llama paper
+# --no-query-key-layer-scaling \
+# --attention-dropout 0 \
+# --hidden-dropout 0 \
+# --use-rotary-position-embeddings \
+# --untie-embeddings-and-output-weights \
+# --swiglu \
+# --normalization rmsnorm \
+# --disable-bias-linear \
+######################################
+
+
+
+cat <<EOT > $DS_CONFIG
+{
+  "train_batch_size" : $GLOBAL_BATCH_SIZE,
+  "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
+  "steps_per_print": 1,
+  "zero_optimization": {
+    "stage": $ZERO_STAGE
+  },
+  "bf16": {
+    "enabled": true
+  }
+}
+EOT
+
+ds_args=""
+ds_args=" --deepspeed ${ds_args}"
+ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}"
+ds_args=" --zero-stage=$ZERO_STAGE ${ds_args}"
+ds_args=" --deepspeed-activation-checkpointing ${ds_args}"
+
+
+DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
+
+torchrun $DISTRIBUTED_ARGS \
+       pretrain_gpt.py \
+       --tensor-model-parallel-size $TP \
+       --pipeline-model-parallel-size $PP \
+       --num-layers $NUM_LAYERS \
+       --hidden-size $HIDDEN_SIZE \
+       --ffn-hidden-size $FFN_HIDDEN_SIZE \
+       --num-attention-heads $NUM_HEADS \
+       --micro-batch-size $MICRO_BATCH_SIZE \
+       --global-batch-size $GLOBAL_BATCH_SIZE \
+       --seq-length $SEQ_LENGTH \
+       --max-position-embeddings $SEQ_LENGTH \
+       --train-iters $TRAIN_STEPS \
+       --save $CHECKPOINT_PATH \
+       --load $CHECKPOINT_PATH \
+       --data-path $DATASET \
+       --data-impl mmap \
+       --tokenizer-type GPTSentencePieceTokenizer \
+       --tokenizer-model $TOKENIZER_PATH \
+       --split 949,50,1 \
+       --distributed-backend nccl \
+       --lr $LR \
+       --lr-decay-style cosine \
+       --min-lr $MIN_LR \
+       --weight-decay $WEIGHT_DECAY \
+       --clip-grad $GRAD_CLIP \
+       --lr-warmup-iters $LR_WARMUP_STEPS \
+       --optimizer adam \
+       --adam-beta1 0.9 \
+       --adam-beta2 0.95 \
+       --log-interval 1 \
+       --save-interval 10000 \
+       --eval-interval 1000 \
+       --eval-iters 10 \
+       --bf16 \
+       --no-query-key-layer-scaling \
+       --attention-dropout 0 \
+       --hidden-dropout 0 \
+       --use-rotary-position-embeddings \
+       --untie-embeddings-and-output-weights \
+       --swiglu \
+       --normalization rmsnorm \
+       --disable-bias-linear \
+       $ds_args
\ No newline at end of file
diff --git a/megatron/arguments.py b/megatron/arguments.py
index 4bf1d725bb..7fdfaa506b 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -597,6 +597,11 @@ def _add_network_size_args(parser):
     group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                        help='Pad the vocab size to be divisible by this value.'
                        'This is added for computational efficieny reasons.')
+    group.add_argument('--normalization', type=str, default='layernorm',
+                       choices=['layernorm', 'rmsnorm'],
+                       help='Options for layer normalization type:'
+                       '  layernorm'
+                       '  rmsnorm')
     group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                        help='Layer norm epsilon.')
     group.add_argument('--apply-layernorm-1p', action='store_true',
diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py
index c632a932f8..b06fb9f81a 100644
--- a/megatron/model/gpt_model.py
+++ b/megatron/model/gpt_model.py
@@ -16,9 +16,11 @@
 from megatron.model import LayerNorm
 from .language_model import EmbeddingPipe
-from .transformer import ParallelTransformerLayerPipe
+from .transformer import ParallelTransformerLayerPipe, LMHeadPipe
 from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec
 
+from apex.normalization import MixedFusedRMSNorm
+
 
 def post_language_model_processing(lm_output, labels, logit_weights,
                                    parallel_output,
                                    fp16_lm_cross_entropy):
@@ -40,7 +42,7 @@ def post_language_model_processing(lm_output, labels, logit_weights,
             loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
         else:
             loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
-        
+
         # [s b] => [b, s]
         loss = loss.transpose(0,1).contiguous()
         return loss
@@ -74,7 +76,7 @@ def __init__(self,
             pre_process=self.pre_process,
             post_process=self.post_process,
             num_experts=args.num_experts)
-        
+
         if not args.untie_embeddings_and_output_weights:
             self.initialize_word_embeddings()
 
@@ -210,16 +212,26 @@ def _to_float16(inputs):
         self.specs.append(_to_float16)
 
         # Embedding layer
-        self.specs.append(TiedLayerSpec('embed',
-                                        EmbeddingPipe,
+        if args.untie_embeddings_and_output_weights:
+            self.specs.append(LayerSpec(EmbeddingPipe,
                                         args.hidden_size,
                                         args.padded_vocab_size,
                                         args.max_position_embeddings,
                                         args.hidden_dropout,
                                         config,
                                         num_tokentypes=num_tokentypes,
-                                        embedding_weights_in_fp32=args.embedding_weights_in_fp32,
-                                        tied_weight_attr='word_embeddings_weight'))
+                                        embedding_weights_in_fp32=args.embedding_weights_in_fp32,))
+        else:
+            self.specs.append(TiedLayerSpec('embed',
+                                            EmbeddingPipe,
+                                            args.hidden_size,
+                                            args.padded_vocab_size,
+                                            args.max_position_embeddings,
+                                            args.hidden_dropout,
+                                            config,
+                                            num_tokentypes=num_tokentypes,
+                                            embedding_weights_in_fp32=args.embedding_weights_in_fp32,
+                                            tied_weight_attr='word_embeddings_weight'))
 
         for layer_idx in range(args.num_layers):
             self.specs.append(
@@ -229,10 +241,12 @@ def _to_float16(inputs):
                           self_attn_mask_type=AttnMaskType.causal))
 
         # Final layernorm after transformer layers
-        self.specs.append(
-            LayerSpec(LayerNorm,
-                      args.hidden_size,
-                      eps=args.layernorm_epsilon))
+        if args.normalization == 'layernorm':
+            self.specs.append(LayerSpec(LayerNorm,
+                                        args.hidden_size,
+                                        eps=args.layernorm_epsilon))
+        else:
+            self.specs.append(LayerSpec(MixedFusedRMSNorm, args.hidden_size, args.layernorm_epsilon))
 
         def _logits_helper(embedding, lm_output):
             """A wrapper to massage inputs/outputs from pipeline. """
""" @@ -240,20 +254,24 @@ def _logits_helper(embedding, lm_output): lm_output, embedding.word_embeddings_weight, self.parallel_output) - - self.specs.append( - TiedLayerSpec('embed', - EmbeddingPipe, - args.hidden_size, - args.padded_vocab_size, - args.max_position_embeddings, - args.hidden_dropout, - config, - num_tokentypes=num_tokentypes, - embedding_weights_in_fp32=args.embedding_weights_in_fp32, - forward_fn=_logits_helper, - tied_weight_attr='word_embeddings_weight') - ) + if args.untie_embeddings_and_output_weights: + self.specs.append( + LayerSpec(LMHeadPipe, args.hidden_size, args.padded_vocab_size, config) + ) + else: + self.specs.append( + TiedLayerSpec('embed', + EmbeddingPipe, + args.hidden_size, + args.padded_vocab_size, + args.max_position_embeddings, + args.hidden_dropout, + config, + num_tokentypes=num_tokentypes, + embedding_weights_in_fp32=args.embedding_weights_in_fp32, + forward_fn=_logits_helper, + tied_weight_attr='word_embeddings_weight') + ) # Convert to fp32 if needed if args.fp16 or args.bf16: diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 215c98d38f..0c33ccb29f 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -45,6 +45,8 @@ flash_attn_builder = FlashAttentionBuilder().load() except TypeError: flash_attn_builder = None +from apex.normalization import MixedFusedRMSNorm + """ We use the following notation throughout this file: h: hidden size @@ -760,18 +762,20 @@ def __init__(self, config, self.fp32_residual_connection = config.fp32_residual_connection # Layernorm on the input data. - if get_accelerator().device_name() == 'cuda': - self.input_layernorm = LayerNorm( - config.hidden_size, - eps=config.layernorm_epsilon, - no_persist_layer_norm=args.no_persist_layer_norm, - sequence_parallel=config.sequence_parallel, - apply_layernorm_1p=args.apply_layernorm_1p) + if args.normalization == 'layernorm': + if get_accelerator().device_name() == 'cuda': + self.input_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=args.no_persist_layer_norm, + sequence_parallel=config.sequence_parallel, + apply_layernorm_1p=args.apply_layernorm_1p) + else: + self.input_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon) else: - self.input_layernorm = LayerNorm( - config.hidden_size, - eps=config.layernorm_epsilon) - + self.input_layernorm = MixedFusedRMSNorm(config.hidden_size, config.layernorm_epsilon) # Self attention. 
         self.self_attention = ParallelAttention(
             config,
@@ -783,19 +787,21 @@ def __init__(self, config,
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
 
         # Layernorm on the attention output
-        if get_accelerator().device_name() == 'cuda':
-            self.post_attention_layernorm = LayerNorm(
-                config.hidden_size,
-                eps=config.layernorm_epsilon,
-                no_persist_layer_norm=not config.persist_layer_norm,
-                sequence_parallel=config.sequence_parallel,
-                apply_layernorm_1p=args.apply_layernorm_1p)
+        if args.normalization == 'layernorm':
+            if get_accelerator().device_name() == 'cuda':
+                self.post_attention_layernorm = LayerNorm(
+                    config.hidden_size,
+                    eps=config.layernorm_epsilon,
+                    no_persist_layer_norm=not config.persist_layer_norm,
+                    sequence_parallel=config.sequence_parallel,
+                    apply_layernorm_1p=args.apply_layernorm_1p)
+            else:
+                self.post_attention_layernorm = LayerNorm(
+                    config.hidden_size,
+                    eps=config.layernorm_epsilon)
         else:
-            self.post_attention_layernorm = LayerNorm(
-                config.hidden_size,
-                eps=config.layernorm_epsilon)
-
-        # Cross attention.
+            self.post_attention_layernorm = MixedFusedRMSNorm(config.hidden_size, config.layernorm_epsilon)
+        # Cross attention.
         if self.layer_type in (LayerType.decoder,
                                LayerType.retro_decoder,
                                LayerType.retro_decoder_with_retriever,
@@ -805,12 +811,15 @@ def __init__(self, config,
                 layer_number,
                 attention_type=AttnType.cross_attn)
             # Layernorm on the attention output.
-            self.post_inter_attention_layernorm = LayerNorm(
-                config.hidden_size,
-                eps=config.layernorm_epsilon,
-                no_persist_layer_norm=not config.persist_layer_norm,
-                sequence_parallel=config.sequence_parallel,
-                apply_layernorm_1p=args.apply_layernorm_1p)
+            if args.normalization == 'layernorm':
+                self.post_inter_attention_layernorm = LayerNorm(
+                    config.hidden_size,
+                    eps=config.layernorm_epsilon,
+                    no_persist_layer_norm=not config.persist_layer_norm,
+                    sequence_parallel=config.sequence_parallel,
+                    apply_layernorm_1p=args.apply_layernorm_1p)
+            else:
+                self.post_inter_attention_layernorm = MixedFusedRMSNorm(config.hidden_size, config.layernorm_epsilon)
 
         # MLP
         self.num_experts = num_experts
@@ -825,7 +834,7 @@ def __init__(self, config,
                            ParallelMLP(config,
                                        moe=True,
                                        enable_expert_tensor_parallelism=enable_expert_tensor_parallelism),
-                           num_experts=self.num_experts, 
+                           num_experts=self.num_experts,
                            ep_size=args.moe_expert_parallel_size,
                            k=args.topk,
                            use_residual=(args.mlp_type == 'residual'),
@@ -1230,20 +1239,21 @@ class ParallelTransformerLayerPipe(ParallelTransformerLayer):
     """
    def forward(self, inputs, **kwargs):
         assert torch.is_tensor(inputs) or isinstance(inputs, tuple)
+        if not hasattr(self, '_args'):
+            self._args = get_args()
+        rotary_pos_emb = self._args.rotary_pos_emb if self._args.use_rotary_position_embeddings else None
         if torch.is_tensor(inputs) or len(inputs) == 1:
             # No attention mask forwarded, search for args.attn_mask
-            if not hasattr(self, '_args'):
-                self._args = get_args()
             hidden_states, attention_mask = inputs, self._args.attn_mask
             # HACK: currently MoE model does not support pipeline parallel, so
             # here we just ignore the moe_loss returned by forward()
-            return super().forward(hidden_states, attention_mask, **kwargs)[0]
+            return super().forward(hidden_states, attention_mask, **kwargs, rotary_pos_emb=rotary_pos_emb)[0]
         elif len(inputs) == 2:
             # Attention mask is an activation.
             hidden_states, attention_mask = inputs[0], inputs[1]
             # HACK: currently MoE model does not support pipeline parallel, so
             # here we just ignore the moe_loss returned by forward()
-            return super().forward(*inputs, **kwargs)[0], attention_mask
+            return super().forward(*inputs, **kwargs, rotary_pos_emb=rotary_pos_emb)[0], attention_mask
         else:
             raise RuntimeError('Received more inputs than understood.')
 
@@ -1557,18 +1567,20 @@ def build_layer(layer_number, n_e):
 
         if self.post_process and self.post_layer_norm:
             # Final layer norm before output.
-            if get_accelerator().device_name() == 'cuda':
-                self.final_layernorm = LayerNorm(
-                    config.hidden_size,
-                    eps=config.layernorm_epsilon,
-                    no_persist_layer_norm=args.no_persist_layer_norm,
-                    sequence_parallel=config.sequence_parallel,
-                    apply_layernorm_1p=args.apply_layernorm_1p)
+            if args.normalization == 'layernorm':
+                if get_accelerator().device_name() == 'cuda':
+                    self.final_layernorm = LayerNorm(
+                        config.hidden_size,
+                        eps=config.layernorm_epsilon,
+                        no_persist_layer_norm=args.no_persist_layer_norm,
+                        sequence_parallel=config.sequence_parallel,
+                        apply_layernorm_1p=args.apply_layernorm_1p)
+                else:
+                    self.final_layernorm = LayerNorm(
+                        config.hidden_size,
+                        eps=config.layernorm_epsilon)
             else:
-                self.final_layernorm = LayerNorm(
-                    config.hidden_size,
-                    eps=config.layernorm_epsilon)
-
+                self.final_layernorm = MixedFusedRMSNorm(config.hidden_size, config.layernorm_epsilon)
         if deepspeed.checkpointing.is_configured():
             global get_cuda_rng_tracker, checkpoint
             get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
@@ -1710,7 +1722,7 @@ def forward(self, hidden_states, attention_mask,
             assert self.recompute_granularity is None, \
                 'inference does not work with activation checkpointing'
 
-        # TODO: Below old DeepSpeed code are commented because it's unsure whether 
+        # TODO: Below old DeepSpeed code are commented because it's unsure whether
         # it is still relevant.
         # # Reza's note: DeepSpeed inference does not support transposes
         # if not self.ds_inference:
@@ -1837,7 +1849,7 @@ def forward(self, hidden_states, attention_mask,
 
         # Final layer norm.
         if self.post_process and self.post_layer_norm:
-            # TODO: Below old DeepSpeed code are commented because it's unsure whether 
+            # TODO: Below old DeepSpeed code are commented because it's unsure whether
             # it is still relevant.
             # if not self.ds_inference:
             #     # Reverting data format change [s b h] --> [b s h].
@@ -1845,3 +1857,47 @@ def forward(self, hidden_states, attention_mask,
             hidden_states = self.final_layernorm(hidden_states)
 
         return (hidden_states, *moe_losses)
+
+class LMHeadPipe(MegatronModule):
+    """
+    Arguments:
+        vocab_size: size of vocabulary.
+        hidden_size: hidden size
+        gather_output: whether output logits are gathered or not.
+        init_method: init method for weight initialization
+        config:
+    """
+
+    def __init__(self, hidden_size, vocab_size, config):
+        args = get_args()
+        super(LMHeadPipe, self).__init__()
+        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=hidden_size,
+                                                            output_size=vocab_size,
+                                                            bias=False,
+                                                            config=config,
+                                                            init_method=config.init_method,)
+
+    def forward(self, inputs, **kwargs):
+        assert torch.is_tensor(inputs) or isinstance(inputs, tuple)
+        if isinstance(inputs, tuple):
+            hidden_states = inputs[0]
+        else:
+            hidden_states = inputs
+
+        if not hasattr(self, '_args'):
+            self._args = get_args()
+
+        if hasattr(self._args, 'attn_mask'):
+            attention_mask = None
+        else:
+            attention_mask = inputs[1]
+
+        logits, _ = self.lm_head(hidden_states)
+
+        # If cmd args has attn_mask, we don't forward it as an activation.
+        if hasattr(self._args, 'attn_mask'):
+            return logits
+        else:
+            return logits, attention_mask
+
+
diff --git a/pretrain_gpt.py b/pretrain_gpt.py
index c0b95d0718..d6389953ea 100644
--- a/pretrain_gpt.py
+++ b/pretrain_gpt.py
@@ -13,6 +13,7 @@
 from megatron.core.enums import ModelType
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
 from megatron.model import GPTModel, GPTModelPipe
+from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb, RotaryEmbedding
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import average_losses_across_data_parallel_group
@@ -67,6 +68,25 @@ def model_provider(pre_process=True, post_process=True):
             # Attention mask must be bool.
             args.attn_mask = attention_mask.to(torch.bool)
 
+            # For pretraining, since sequence length is fixed, cache the rotary embedding in args to avoid passing it around
+            if args.use_rotary_position_embeddings:
+                rotary_dim = args.hidden_size // args.num_attention_heads \
+                    if args.kv_channels is None else args.kv_channels
+
+                if args.rotary_percent < 1.0:
+                    rotary_dim = int(rotary_dim * args.rotary_percent)
+
+                # partial rotary embeddings, which is better than full rotary
+                # Wang and Komatsuzaki et al
+                # https://github.com/kingoflolz/mesh-transformer-jax/
+                rotary_pos_emb = RotaryEmbedding(rotary_dim)(args.seq_length).to(
+                    get_accelerator().current_device_name())
+                if args.fp16:
+                    rotary_pos_emb = rotary_pos_emb.half()
+                elif args.bf16:
+                    rotary_pos_emb = rotary_pos_emb.bfloat16()
+                args.rotary_pos_emb = rotary_pos_emb
+
     else:
         model = GPTModel(
             config,
@@ -176,7 +196,7 @@ def loss_func(loss_mask, moe_loss, mos_loss, output_tensor):
     losses = output_tensor.float()
     loss_mask = loss_mask.view(-1).float()
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
-    
+
     # Reduce loss for logging.
     averaged_loss = average_losses_across_data_parallel_group([loss])
     if args.mos or args.kd:
@@ -199,7 +219,7 @@ def calculate_mos_loss(args, stu_output, teacher_model, tokens, position_ids, at
     alpha = args.kd_alpha_ce
     beta = args.kd_beta_ce
     kd_temp = args.kd_temp
-    
+
     if teacher_model:
         with torch.no_grad():
             if args.curriculum_learning_legacy and args.curriculum_seqlen < args.seq_length:
@@ -262,7 +282,7 @@ def forward_step(data_iterator, model):
         if args.teacher_forward and args.teacher_model is not None:
             mos_loss = calculate_mos_loss(args, stu_output, args.teacher_model[0],
                                           tokens, position_ids, attention_mask)
-    
+
     # Output_tensor stores the standard loss, loos_func calculates the total loss.
     return output_tensor, partial(loss_func, loss_mask, moe_loss, mos_loss)
 
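For readers new to the normalization change in this patch: with --normalization rmsnorm the model builds apex.normalization.MixedFusedRMSNorm in place of LayerNorm, which is the RMSNorm used by the llama paper. The sketch below is only a plain PyTorch reference for what that normalization computes; the class and variable names are illustrative and not part of the patch, and the fused apex kernel is what the code above actually instantiates.

import torch


class ReferenceRMSNorm(torch.nn.Module):
    """Reference RMSNorm: scale each feature vector by its root mean square.

    Unlike LayerNorm there is no mean subtraction and no bias term,
    only a learned per-feature gain (the fused kernel exposes a
    comparable weight parameter).
    """

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))

    def forward(self, x):
        # Mean of squares over the hidden dimension, then x / sqrt(ms + eps).
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return self.weight * x * torch.rsqrt(mean_square + self.eps)

With --normalization layernorm (the default added in megatron/arguments.py) the existing LayerNorm code paths are untouched, so current configurations keep their behavior.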