Fix the bug of using loss before assignment #700

Open · wants to merge 1 commit into base: main
97 changes: 97 additions & 0 deletions examples/pretrain_gpt_moe_demo.sh
@@ -0,0 +1,97 @@
#!/bin/bash
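# Demo script: pretrain a GPT model with Mixture-of-Experts (MoE) layers via
# Megatron-Core's pretrain_gpt.py, configured for a single 8-GPU node with
# tensor parallelism 2 and pipeline parallelism 2.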


export CUDA_DEVICE_MAX_CONNECTIONS=1

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

CHECKPOINT_PATH=checkpoint
VOCAB_FILE=/data/gpt2-vocab.json
MERGE_FILE=/data/gpt2-merges.txt
DATA_PATH=./train_data/my-gpt2_text_document


DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
"

GPT_ARGS="
--tensor-model-parallel-size 2 \
--pipeline-model-parallel-size 2 \
--sequence-parallel \
--ffn-hidden-size 6784 \
--norm-epsilon 1e-5 \
--num-layers 18 \
--hidden-size 2560 \
--num-attention-heads 20 \
--seq-length 4096 \
--max-position-embeddings 4096 \
    --bf16 \
--micro-batch-size 2 \
--global-batch-size 32 \
--lr 3.4e-4 \
--train-iters 10000 \
--lr-decay-iters 8000 \
--lr-decay-style cosine \
--min-lr 3.4e-5 \
--weight-decay 0.1 \
--lr-warmup-iters 2000 \
--clip-grad 1.0 \
--use-mcore-models \
--use-flash-attn \
--untie-embeddings-and-output-weights \
--use-rotary-position-embeddings \
--disable-bias-linear \
--normalization RMSNorm \
--no-position-embedding \
--no-masked-softmax-fusion \
--swiglu \
--attention-dropout 0 \
--hidden-dropout 0 \
"

DATA_ARGS="
--data-path $DATA_PATH \
--vocab-file $VOCAB_FILE \
--merge-file $MERGE_FILE \
--dataloader-type cyclic \
    --split 949,50,1 \
"

# --expert-model-parallel-size 2 \
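# MoE settings: 4 experts with top-1 routing, grouped GEMM for the expert FFNs,
# and an aux_loss load-balancing loss weighted by 1e-2.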
MOE_ARGS="
--num-experts 4 \
--moe-grouped-gemm \
--moe-router-topk 1 \
--moe-router-load-balancing-type aux_loss \
--moe-aux-loss-coeff 1e-2 \
"


OUTPUT_ARGS="
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--tensorboard-dir ./tensorboard/test_mcore_moe \
--tensorboard-log-interval 1 \
"

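# Launch training with torchrun; checkpoints are saved to and resumed from $CHECKPOINT_PATH.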
torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
$GPT_ARGS \
$MOE_ARGS \
$DATA_ARGS \
$OUTPUT_ARGS \
--distributed-backend nccl \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH
2 changes: 1 addition & 1 deletion megatron/core/pipeline_parallel/schedules.py
@@ -213,7 +213,7 @@ def forward_step(
    if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
        loss_scale = (
-           config.grad_scale_func(torch.tensor(1.0, device=loss.device))
+           config.grad_scale_func(torch.tensor(1.0, device=output_tensor.device))
            if config.grad_scale_func is not None
            else torch.tensor(1.0)
        )
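For context on the one-line fix above: in forward_step the MoE auxiliary-loss scaling runs right after the forward pass, before loss has been assigned (and on non-final pipeline stages it never is), so reading loss.device raised an UnboundLocalError, while output_tensor is already available at that point. The sketch below is a simplified reconstruction under those assumptions, not the verbatim Megatron-LM source; the function name and signature are illustrative only.

import torch

def forward_step_sketch(forward_step_func, data_iterator, model, config):
    # The forward pass produces output_tensor; `loss` has not been assigned yet.
    output_tensor, loss_func = forward_step_func(data_iterator, model)

    if getattr(config, 'num_moe_experts', None) is not None:
        # Before the fix this read `loss.device`, which raised UnboundLocalError
        # because `loss` is only assigned further down (and only on the last
        # pipeline stage). `output_tensor` always exists at this point.
        loss_scale = (
            config.grad_scale_func(torch.tensor(1.0, device=output_tensor.device))
            if config.grad_scale_func is not None
            else torch.tensor(1.0)
        )
        # ... scale and accumulate the MoE auxiliary losses with loss_scale ...

    # Only later (on the last pipeline stage) is the loss actually computed:
    # loss = loss_func(output_tensor)
    return output_tensor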