
Fix Gpt3 MultiHeadAttention out projection dimension #2812

Merged

copybara-service[bot] merged 1 commit into AI-Hypercomputer:main from CIeNET-International:fix/Gpt3-projection-dimension on Dec 11, 2025

Conversation

@hsuan-lun-chiang (Collaborator)

Description

The output dimension of the output projection in the Gpt3 `MultiHeadAttention` layer is mistakenly set to `self.num_heads * self.head_dim`; it should be `config.emb_dim`. The two only coincide when `emb_dim == num_heads * head_dim`, so on configs where they differ the attention output no longer matches the width of the residual stream it is added back into. A sketch of the change follows.
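The following is a minimal sketch of the shape of the fix, not the actual MaxText source: the module structure, field names, and the elided attention computation are illustrative assumptions, and only the dimension change (`num_heads * head_dim` to `emb_dim`) comes from this PR.

```python
from flax import linen as nn


class MultiHeadAttention(nn.Module):
    """Illustrative multi-head attention; only the out-projection width is the point."""
    num_heads: int
    head_dim: int
    emb_dim: int  # corresponds to config.emb_dim

    @nn.compact
    def __call__(self, x):
        # Project inputs to concatenated per-head queries/keys/values.
        qkv_features = self.num_heads * self.head_dim
        q = nn.Dense(qkv_features, name="query")(x)
        k = nn.Dense(qkv_features, name="key")(x)
        v = nn.Dense(qkv_features, name="value")(x)
        del q, k  # the actual attention computation is elided for brevity
        attn_out = v  # placeholder with width num_heads * head_dim

        # Before: nn.Dense(self.num_heads * self.head_dim, name="out"), which is
        # only correct when emb_dim == num_heads * head_dim.
        # After: project back to emb_dim so the output matches the model width.
        return nn.Dense(self.emb_dim, name="out")(attn_out)
```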

Tests

```bash
python3 -m MaxText.train MaxText/configs/base.yml run_name=gpt3-train-run base_output_directory=gs://maxtext-test/train_gpt3/13/ model_name=gpt3-6b dataset_type=synthetic steps=10
```

Logs: `base_emb_dim=3584 base_mlp_dim=3584`
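As a quick sanity check, the sketch above can be exercised with a hypothetical configuration in which the two widths differ; only `emb_dim=3584` comes from the logs, while the head counts here are made-up values chosen so that `num_heads * head_dim != emb_dim`:

```python
import jax

# Hypothetical numbers: 16 * 256 = 4096 != 3584, so the bug would be visible.
emb_dim, num_heads, head_dim = 3584, 16, 256
x = jax.numpy.zeros((2, 8, emb_dim))  # (batch, sequence, embedding)
mha = MultiHeadAttention(num_heads=num_heads, head_dim=head_dim, emb_dim=emb_dim)
y, _ = mha.init_with_output(jax.random.PRNGKey(0), x)
assert y.shape == x.shape  # holds after the fix; the old projection returned width 4096
```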

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have added necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

copybara-service[bot] merged commit a581e00 into AI-Hypercomputer:main on Dec 11, 2025.
127 of 133 checks passed.
