
Support Mixtral 8*7B MOE #667

Open
matrixssy wants to merge 6 commits into main
Conversation

@matrixssy commented Jan 18, 2024

Support the Mixtral 8x7B MoE model structure and a weight converter from Hugging Face.

You can use the following command to convert the Hugging Face weights to Megatron:

python tools/checkpoint/util.py --model-type GPT \
    --loader mixtral_hf \
    --saver mixtral \
    --load-dir ../models/Mixtral-8x7B-Instruct-v0.1 \
    --save-dir ../models/Mixtral-8x7B-Instruct-v0.1-tp2-pp4 \
    --tokenizer-model ../models/Mixtral-8x7B-Instruct-v0.1/tokenizer.model \
    --target-tensor-parallel-size 2 \
    --target-pipeline-parallel-size 4

To enable Mixtral MoE during training, add:

--num-experts 8 \
--moe-type mixtral \

Note:
Implementing the Hugging Face load-balancing loss equivalently in Megatron would require many modifications to return the router logits. To keep the work simple, I therefore use the existing Sinkhorn algorithm to balance the routing probabilities across experts instead of Hugging Face's load_balancing_loss_func.
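For reference, a minimal standalone sketch of the Sinkhorn-style balancing applied to the router logits (plain PyTorch, not this PR's exact code; the fixed iteration count is illustrative):

import torch

def sinkhorn(logits: torch.Tensor, n_iters: int = 3) -> torch.Tensor:
    # Alternately rescale rows (tokens) and columns (experts) so the
    # routing probabilities end up roughly balanced across experts.
    cost = torch.exp(logits)                   # [num_tokens, num_experts]
    d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
    d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
    for _ in range(n_iters):
        d0 = 1.0 / (cost @ d1 + 1e-8)          # row (token) scaling
        d1 = 1.0 / (cost.t() @ d0 + 1e-8)      # column (expert) scaling
    return cost * d0.unsqueeze(1) * d1.unsqueeze(0)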

@matrixssy (Author)

#649

matrixssy marked this pull request as ready for review on January 18, 2024 12:31
@cdj0311 commented Jan 19, 2024

Great work! Could you provide a script to convert Megatron Mixtral back to HF?

@matrixssy (Author)

Great work! Could you provide a script to convert Megatron Mixtral back to HF?

Still working on it.

Fixed the bug where the b and s dimensions were mixed up.
@ftgreat commented Jan 22, 2024

Great work!

Great work! Looking forward to a script to convert Megatron Mixtral back to HF.

@TissueC commented Jan 28, 2024

Hi, I wonder whether the loss is normal after converting and training Mixtral with Megatron on your machine. I applied this PR and the initial loss is quite high, which seems to indicate that the forward step is not aligned with the Hugging Face version, especially when TP > 1.
@matrixssy

@matrixssy (Author)

Hi, I wonder whether the loss is normal after converting and training Mixtral with Megatron on your machine. I applied this PR and the initial loss is quite high, which seems to indicate that the forward step is not aligned with the Hugging Face version, especially when TP > 1. @matrixssy

Hello, first of all, thank you for giving it a try. In my example, I trained Mixtral 8x7B-v0.1 on the GPT-4 Alpaca zh dataset, and the loss decreased from 4.5 to around 1.0 after 300 iterations; training Mixtral 8x7B-Instruct-v0.1 on the same dataset, the loss decreases from 1.5 to around 1.0 after 300 iterations. Could you share your HF training loss curve? Although this PR does not implement the load-balancing loss (which is around the 1e-2 level), the loss should not differ significantly from HF.

@matrixssy (Author)

Hi, I wonder whether the loss is normal after converting and training Mixtral with Megatron on your machine. I applied this PR and the initial loss is quite high, which seems to indicate that the forward step is not aligned with the Hugging Face version, especially when TP > 1. @matrixssy

By the way, I have verified that the relative error of the average forward logits between the converted Megatron model and the HF (Hugging Face) model is within 1%, and the cosine similarity is 0.9999.
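For anyone who wants to reproduce that check, a rough sketch of the comparison (assuming you already have two logits tensors of identical shape from matched forward passes; the function name is illustrative):

import torch
import torch.nn.functional as F

def compare_logits(megatron_logits: torch.Tensor, hf_logits: torch.Tensor):
    # Mean relative error of the logits, treating HF as the reference.
    rel_err = (megatron_logits - hf_logits).abs().mean() / hf_logits.abs().mean()
    # Cosine similarity of the flattened logits tensors.
    cos_sim = F.cosine_similarity(
        megatron_logits.flatten(), hf_logits.flatten(), dim=0
    )
    return rel_err.item(), cos_sim.item()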

@TissueC commented Jan 29, 2024

Hi, I fixed a bug in my script and now the initial loss is normal (around 2.3 on an arXiv dataset). Thanks for your contribution!

Also, one extra question: since the gate linear is shared across TP groups, why not define it as a tensor_parallel.RowParallelLinear?

@matrixssy (Author) commented Jan 29, 2024

Hi, I fixed a bug in my script and now the initial loss is normal (around 2.3 on an arXiv dataset). Thanks for your contribution!

Also, one extra question: since the gate linear is shared across TP groups, why not define it as a tensor_parallel.RowParallelLinear?

Good question! My initial thinking was that the router weight has shape (hidden_size, n_experts), which is not particularly large, and there is only one per layer, so the benefit of parallelizing it is small. Additionally, when implementing the load-balancing loss in the future, obtaining the router logits would become difficult with a parallel layer.
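To make that concrete, a minimal sketch of a replicated per-layer router as described above (plain PyTorch, not this PR's exact code; MixtralRouter and the argument names are illustrative):

import torch
import torch.nn as nn

class MixtralRouter(nn.Module):
    # One small [hidden_size, num_experts] gate per layer, replicated on
    # every TP rank instead of sharded, so the full router logits stay
    # available for a future load-balancing loss.
    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # [num_tokens, hidden_size] -> [num_tokens, num_experts] logits.
        return self.gate(hidden_states)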

@Victarry (Contributor) commented Feb 5, 2024

Hi, @matrixssy. Thanks for your contribution. There are some ongoing efforts internally at NVIDIA on a Mixtral 8x7B example. We will support converting HF checkpoints to MCore checkpoints with different EP/TP/PP sizes. The code will be released soon, after some refactoring.

@matrixssy (Author)

Hi, @matrixssy. Thanks for your contribution. There are some ongoing efforts internally at NVIDIA on a Mixtral 8x7B example. We will support converting HF checkpoints to MCore checkpoints with different EP/TP/PP sizes. The code will be released soon, after some refactoring.

Cool! How much longer will this take (converting HF checkpoints to MCore checkpoints with EP), and are there any preceding pull requests?

@Victarry (Contributor) commented Feb 6, 2024

Actually, the functionality for changing the EP size has already been implemented.

However, there is a preceding MR (on our internal GitLab) still under review that implements the MLM-legacy to MCore model converter. After that MR is finished, I will need some time for code refactoring. I think the MR for the Mixtral checkpoint converter will be released this month.

@ZhangEnmao

Hi, when I run your code, I get two errors. Could you help me and give some advice?
[error screenshots omitted]

@ZhangEnmao commented Feb 8, 2024

Hi, when I set target-tensor-parallel-size > 1, I get the following errors; only target-tensor-parallel-size = 1 works. Could it be related to the following warning? I am using the latest NVIDIA PyTorch Docker image. What can I do to resolve this missing-packages problem? Thanks very much.

[error and warning screenshots omitted]

@matrixssy (Author)

Hi, when I set target-tensor-parallel-size > 1, I get the following errors; only target-tensor-parallel-size = 1 works. Could it be related to the following warning? I am using the latest NVIDIA PyTorch Docker image. What can I do to resolve this missing-packages problem? Thanks very much.
[screenshots omitted]

Yes, you need to set --sequence-parallel
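For example, a TP > 1 run would add something like the following to the training arguments (values illustrative, matching the TP=2 conversion example above):

--tensor-model-parallel-size 2 \
--sequence-parallel \
--num-experts 8 \
--moe-type mixtral \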

@github-actions (bot)

Marking as stale. No activity in 60 days.

github-actions bot added the "stale" label (No activity in 60 days on issue or PR) on Apr 17, 2024
@shamanez

Any update on this?

github-actions bot removed the "stale" label on Apr 21, 2024
@passaglia

@Victarry
Have the plans to release a checkpoint converter that supports MoE (mentioned here as "Coming Soon") changed?
