[GPTOSS] Support sequence parallelism with attention sinks #558
Closed
SumanthRH wants to merge 13 commits into NovaSky-AI:main from
Conversation
What does this PR do?
Follow-up PR to #515. Adds support for sequence parallelism to the custom flex attention implementation in order to scale to longer context lengths.
Summary
- **Support sequence parallelism for attention sinks:** Currently, we use Unsloth's flex attention implementation, which has a score mod function `sink_score_mod`. This function uses per-attention-head sink weights to change the attention score. With Ulysses sequence parallelism, each rank initially holds query states of size `(bsz, seq_len // sp_size, num_heads, hidden_dim)`, and after the All2All this becomes `(bsz, seq_len, num_heads // sp_size, hidden_dim)`. Different SP ranks thus handle compute for different attention heads, so we need to index into the sink weights appropriately for each SP rank (see the sketch after this list).
- **Support custom attention functions for Ulysses sequence parallelism:** The current Ulysses sequence parallel implementation supports only flash attention. This PR adds a simple wrapper to support custom attention functions.
- **Transpose dimensions for Q, K, V states in GPT-OSS:** There is also an issue with the current GPT-OSS attention `forward` implementation: it receives query states in the format `(bsz, num_heads, seq_len, hidden_dim)` instead of the standard `(bsz, seq_len, num_heads, hidden_dim)`. This PR adds a transpose before calling the flex attention function, which makes it compatible with the Ulysses sequence parallelism implementation.
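To make the sink-weight indexing concrete, here is a minimal sketch of a head-offset `score_mod` for flex attention. All names (`make_sharded_sink_score_mod`, `sinks`, `sp_rank`, `num_local_heads`) are illustrative rather than the actual identifiers in this PR, and the sink adjustment itself is simplified; the point is shifting the local head index into the global head space before reading the sink weights:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def make_sharded_sink_score_mod(sinks: torch.Tensor, sp_rank: int, num_local_heads: int):
    """Build a score_mod that reads sink weights for this SP rank's head shard.

    After the Ulysses All2All, rank `sp_rank` owns global heads
    [sp_rank * num_local_heads, (sp_rank + 1) * num_local_heads), but flex
    attention passes score_mod the *local* head index, so we add an offset.
    """
    head_offset = sp_rank * num_local_heads

    def sink_score_mod(score, b, h, q_idx, kv_idx):
        # `h` is the local head index on this rank; shift it into the global
        # head index before indexing the per-head sink weights. (The actual
        # sink adjustment in the real score mod is more involved.)
        return score + sinks[head_offset + h]

    return sink_score_mod

# Hypothetical usage, with q, k, v of shape
# (bsz, num_heads // sp_size, seq_len, head_dim) after the All2All,
# i.e. the (B, H, S, D) layout that flex_attention expects:
# out = flex_attention(q, k, v,
#                      score_mod=make_sharded_sink_score_mod(sinks, sp_rank, q.shape[1]))
```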
Benchmarking
Using our full-context training scripts, I tested the maximum context lengths I could use. I tested 4K, 8K, 16K, 32K, and 64K context lengths for single-node training on one 8xH100 node. Without sequence parallelism I can scale up to a 32K context length, and with sequence parallelism I can go beyond 32K.
Correctness checks
I modified our existing `test_model_wrapper.py` file to use GPT-OSS and validated that logprobs with the new implementation are correct. Logprobs match but differ in the second decimal place (higher error because of the flex attention implementation); a sketch of this kind of check follows.
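For reference, here is a minimal sketch of the kind of logprob comparison such a test can perform. The helper name and tolerance are illustrative, not the actual test code, with the absolute tolerance loosened to roughly the second decimal place noted above:

```python
import torch

def assert_logprobs_close(ref_logits: torch.Tensor,
                          new_logits: torch.Tensor,
                          atol: float = 5e-2):
    """Compare per-token logprobs between the baseline and new attention paths."""
    ref_logprobs = torch.log_softmax(ref_logits.float(), dim=-1)
    new_logprobs = torch.log_softmax(new_logits.float(), dim=-1)
    # rtol=0 makes this a purely absolute-error check, matching the
    # "differs in the second decimal place" observation.
    torch.testing.assert_close(new_logprobs, ref_logprobs, atol=atol, rtol=0.0)
```

E2E Validation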