
support llama pretraining (NVIDIA#166)
* add external user record

* support llama pretraining

* move example script to examples_deepspeed/

---------

Co-authored-by: Conglong <conglong.li@gmail.com>
Co-authored-by: LydiaXiaohongLi <xiaohong.li@ahrefs.com>
3 people committed Jul 20, 2023
1 parent 99d1eea commit 69e3c6a
Showing 5 changed files with 292 additions and 75 deletions.
118 changes: 118 additions & 0 deletions examples_deepspeed/pretrain_llama_distributed.sh
@@ -0,0 +1,118 @@
#!/bin/bash
# This example script is contributed by external user https://github.com/LydiaXiaohongLi
set -ex

######################################
# Change the below configurations here
BASE_PATH=./tmp
DS_CONFIG=${BASE_PATH}/deepspeed.json
DATASET_1="./tmp/data/bookcorpus_train_1m_text_sentence"
DATASET="1 ${DATASET_1}"
CHECKPOINT_PATH=./tmp
TOKENIZER_PATH=./tmp/tokenizer.model # official llama tokenizer.model
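# Note: the ./tmp paths above are placeholders. The dataset prefix is expected to
# point at a corpus already preprocessed into Megatron's indexed format (e.g. with
# tools/preprocess_data.py), and tokenizer.model must be downloaded separately.
# A typical invocation from the repository root, assuming those files exist, is:
#   bash examples_deepspeed/pretrain_llama_distributed.sh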

TP=2
PP=2
ZERO_STAGE=0
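# ZeRO stage 0 (above) leaves ZeRO partitioning disabled, so optimizer states are
# fully replicated across data-parallel ranks; stages 1/2/3 progressively partition
# optimizer states, gradients, and parameters.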

GPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0

HIDDEN_SIZE=2048 # e.g. llama-13b: 5120
FFN_HIDDEN_SIZE=5504 # e.g. llama-13b: 13824
NUM_LAYERS=24 # e.g. llama-13b: 40
NUM_HEADS=16 # e.g. llama-13b: 40
SEQ_LENGTH=2048
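# Rough scale of the config above (assuming the 32k-entry llama sentencepiece vocab):
# each layer has about 4*H^2 attention + 3*H*FFN swiglu parameters, i.e. ~50M, so
# 24 layers give ~1.2B, plus ~65M per (untied) embedding table, roughly 1.3B total.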

MICRO_BATCH_SIZE=4
GLOBAL_BATCH_SIZE=32 # e.g. llama: 4M tokens
TRAIN_STEPS=250000 # e.g. llama: 1T tokens / 4M tokens_per_batch = 250000 steps
LR=3e-4
MIN_LR=3e-5
LR_WARMUP_STEPS=2000
WEIGHT_DECAY=0.1
GRAD_CLIP=1
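# Note: tokens per step = GLOBAL_BATCH_SIZE * SEQ_LENGTH = 32 * 2048 = 65,536 here,
# far below the ~4M tokens/batch quoted above for llama; the 250000-step figure comes
# from the paper's budget (1T tokens / 4M tokens per step). With this script's settings,
# 250000 steps covers only about 16B tokens, so scale GLOBAL_BATCH_SIZE and/or
# TRAIN_STEPS to your own token budget.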

# The configuration below is required for the llama model, per the llama paper;
# these flags are passed to pretrain_gpt.py in the torchrun command further down.
# --no-query-key-layer-scaling \
# --attention-dropout 0 \
# --hidden-dropout 0 \
# --use-rotary-position-embeddings \
# --untie-embeddings-and-output-weights \
# --swiglu \
# --normalization rmsnorm \
# --disable-bias-linear \
######################################



cat <<EOT > $DS_CONFIG
{
"train_batch_size" : $GLOBAL_BATCH_SIZE,
"train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
"steps_per_print": 1,
"zero_optimization": {
"stage": $ZERO_STAGE
},
"bf16": {
"enabled": true
}
}
EOT

ds_args=""
ds_args=" --deepspeed ${ds_args}"
ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}"
ds_args=" --zero-stage=$ZERO_STAGE ${ds_args}"
ds_args=" --deepspeed-activation-checkpointing ${ds_args}"


DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

torchrun $DISTRIBUTED_ARGS \
pretrain_gpt.py \
--tensor-model-parallel-size $TP \
--pipeline-model-parallel-size $PP \
--num-layers $NUM_LAYERS \
--hidden-size $HIDDEN_SIZE \
--ffn-hidden-size $FFN_HIDDEN_SIZE \
--num-attention-heads $NUM_HEADS \
--micro-batch-size $MICRO_BATCH_SIZE \
--global-batch-size $GLOBAL_BATCH_SIZE \
--seq-length $SEQ_LENGTH \
--max-position-embeddings $SEQ_LENGTH \
--train-iters $TRAIN_STEPS \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATASET \
--data-impl mmap \
--tokenizer-type GPTSentencePieceTokenizer \
--tokenizer-model $TOKENIZER_PATH \
--split 949,50,1 \
--distributed-backend nccl \
--lr $LR \
--lr-decay-style cosine \
--min-lr $MIN_LR \
--weight-decay $WEIGHT_DECAY \
--clip-grad $GRAD_CLIP \
--lr-warmup-iters $LR_WARMUP_STEPS \
--optimizer adam \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--bf16 \
--no-query-key-layer-scaling \
--attention-dropout 0 \
--hidden-dropout 0 \
--use-rotary-position-embeddings \
--untie-embeddings-and-output-weights \
--swiglu \
--normalization rmsnorm \
--disable-bias-linear \
$ds_args
5 changes: 5 additions & 0 deletions megatron/arguments.py
@@ -597,6 +597,11 @@ def _add_network_size_args(parser):
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value.'
                       'This is added for computational efficiency reasons.')
    group.add_argument('--normalization', type=str, default='layernorm',
                       choices=['layernorm', 'rmsnorm'],
                       help='Options for layer normalization type:'
                            ' layernorm'
                            ' rmsnorm')
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='Layer norm epsilon.')
    group.add_argument('--apply-layernorm-1p', action='store_true',
68 changes: 43 additions & 25 deletions megatron/model/gpt_model.py
@@ -16,9 +16,11 @@

from megatron.model import LayerNorm
from .language_model import EmbeddingPipe
from .transformer import ParallelTransformerLayerPipe
from .transformer import ParallelTransformerLayerPipe, LMHeadPipe
from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec

from apex.normalization import MixedFusedRMSNorm

def post_language_model_processing(lm_output, labels, logit_weights,
                                   parallel_output,
                                   fp16_lm_cross_entropy):
@@ -40,7 +42,7 @@ def post_language_model_processing(lm_output, labels, logit_weights,
        loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
    else:
        loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)

    # [s b] => [b, s]
    loss = loss.transpose(0,1).contiguous()
    return loss
@@ -74,7 +76,7 @@ def __init__(self,
            pre_process=self.pre_process,
            post_process=self.post_process,
            num_experts=args.num_experts)

        if not args.untie_embeddings_and_output_weights:
            self.initialize_word_embeddings()

@@ -210,16 +212,26 @@ def _to_float16(inputs):
        self.specs.append(_to_float16)

        # Embedding layer
        self.specs.append(TiedLayerSpec('embed',
                                        EmbeddingPipe,
        if args.untie_embeddings_and_output_weights:
            self.specs.append(LayerSpec(EmbeddingPipe,
                                        args.hidden_size,
                                        args.padded_vocab_size,
                                        args.max_position_embeddings,
                                        args.hidden_dropout,
                                        config,
                                        num_tokentypes=num_tokentypes,
                                        embedding_weights_in_fp32=args.embedding_weights_in_fp32,
                                        tied_weight_attr='word_embeddings_weight'))
                                        embedding_weights_in_fp32=args.embedding_weights_in_fp32,))
        else:
            self.specs.append(TiedLayerSpec('embed',
                                            EmbeddingPipe,
                                            args.hidden_size,
                                            args.padded_vocab_size,
                                            args.max_position_embeddings,
                                            args.hidden_dropout,
                                            config,
                                            num_tokentypes=num_tokentypes,
                                            embedding_weights_in_fp32=args.embedding_weights_in_fp32,
                                            tied_weight_attr='word_embeddings_weight'))
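        # With untied embeddings (the llama setting), the input embedding is a plain
        # LayerSpec and the output projection is provided by a separate LMHeadPipe
        # further below; the TiedLayerSpec('embed', ...) path is kept for the tied
        # case so DeepSpeed can share word_embeddings_weight between the first and
        # last pipeline stages.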

        for layer_idx in range(args.num_layers):
            self.specs.append(
@@ -229,31 +241,37 @@ def _to_float16(inputs):
                          self_attn_mask_type=AttnMaskType.causal))

        # Final layernorm after transformer layers
        self.specs.append(
            LayerSpec(LayerNorm,
                      args.hidden_size,
                      eps=args.layernorm_epsilon))
        if args.normalization == 'layernorm':
            self.specs.append(LayerSpec(LayerNorm,
                                        args.hidden_size,
                                        eps=args.layernorm_epsilon))
        else:
            self.specs.append(LayerSpec(MixedFusedRMSNorm, args.hidden_size, args.layernorm_epsilon))
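        # For reference, MixedFusedRMSNorm is apex's fused RMSNorm kernel; a minimal
        # unfused sketch of the same computation (illustrative only, names below are
        # not from this repo) would look like:
        #
        #     class RMSNormSketch(torch.nn.Module):
        #         def __init__(self, hidden_size, eps=1e-5):
        #             super().__init__()
        #             self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        #             self.eps = eps
        #
        #         def forward(self, x):
        #             # scale by the reciprocal root-mean-square of the last dimension;
        #             # unlike LayerNorm there is no mean subtraction and no bias term
        #             rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        #             return x * rms * self.weight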

        def _logits_helper(embedding, lm_output):
            """A wrapper to massage inputs/outputs from pipeline. """
            return parallel_lm_logits(
                lm_output,
                embedding.word_embeddings_weight,
                self.parallel_output)

        self.specs.append(
            TiedLayerSpec('embed',
                          EmbeddingPipe,
                          args.hidden_size,
                          args.padded_vocab_size,
                          args.max_position_embeddings,
                          args.hidden_dropout,
                          config,
                          num_tokentypes=num_tokentypes,
                          embedding_weights_in_fp32=args.embedding_weights_in_fp32,
                          forward_fn=_logits_helper,
                          tied_weight_attr='word_embeddings_weight')
        )
        if args.untie_embeddings_and_output_weights:
            self.specs.append(
                LayerSpec(LMHeadPipe, args.hidden_size, args.padded_vocab_size, config)
            )
        else:
            self.specs.append(
                TiedLayerSpec('embed',
                              EmbeddingPipe,
                              args.hidden_size,
                              args.padded_vocab_size,
                              args.max_position_embeddings,
                              args.hidden_dropout,
                              config,
                              num_tokentypes=num_tokentypes,
                              embedding_weights_in_fp32=args.embedding_weights_in_fp32,
                              forward_fn=_logits_helper,
                              tied_weight_attr='word_embeddings_weight')
            )

        # Convert to fp32 if needed
        if args.fp16 or args.bf16:
