
[GPTOSS] Support sequence parallelism with attention sinks#558

Closed
SumanthRH wants to merge 13 commits into NovaSky-AI:main from SumanthRH:gptoss-flex-seq-parallel

Conversation

SumanthRH (Member) commented Oct 22, 2025

What does this PR do?

Follow-up PR to #515. Adds support for sequence parallelism in the custom flex attention implementation, to scale to longer context lengths.

Summary

  1. Support sequence parallelism for attention sinks: Currently, we use Unsloth's flex attention implementation, which has a score-mod function sink_score_mod. This function uses per-attention-head sink weights to modify the attention score. With Ulysses sequence parallelism, each rank initially holds query states of shape (bsz, seq_len // sp_size, num_heads, hidden_dim); after the All2All this becomes (bsz, seq_len, num_heads // sp_size, hidden_dim). Different SP ranks therefore compute attention for different heads, so we need to index into the sink weights appropriately for each SP rank.

  2. Support custom attention functions for Ulysses sequence parallelism: The current Ulysses sequence parallel implementation supports only flash attention. This PR adds a simple wrapper to support custom attention functions.

  3. Transpose dimensions for Q, K, V states in GPT-OSS: There's also an issue with the current GPT-OSS attention forward implementation: it receives query states in the format (bsz, num_heads, seq_len, hidden_dim) instead of the standard (bsz, seq_len, num_heads, hidden_dim). This PR adds an additional transpose before calling the flex attention function, making it compatible with the Ulysses sequence parallelism implementation.
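To make the head-partitioning concrete, here is a minimal pure-Python sketch of the shape bookkeeping described in items 1 and 3. The helper names (sink_slice_for_rank, shard_shapes) are illustrative, not the actual functions in this PR; the real implementation operates on tensors and slices the sink-weight tensor inside the score-mod.

```python
# Illustrative sketch of Ulysses SP head partitioning for attention sinks.
# Helper names are hypothetical; the PR's real code works on torch tensors.

def sink_slice_for_rank(num_heads: int, sp_size: int, sp_rank: int) -> slice:
    """After the All2All, rank sp_rank holds heads
    [sp_rank * heads_per_rank, (sp_rank + 1) * heads_per_rank), so its
    score-mod must use the matching slice of the per-head sink weights."""
    assert num_heads % sp_size == 0, "num_heads must be divisible by sp_size"
    heads_per_rank = num_heads // sp_size
    return slice(sp_rank * heads_per_rank, (sp_rank + 1) * heads_per_rank)

def shard_shapes(bsz, seq_len, num_heads, hidden_dim, sp_size):
    """Query-state shapes before and after the All2All: sequence-sharded
    with all heads, then head-sharded with the full sequence."""
    before = (bsz, seq_len // sp_size, num_heads, hidden_dim)
    after = (bsz, seq_len, num_heads // sp_size, hidden_dim)
    return before, after
```

For example, with 64 heads and sp_size=4, rank 1 handles heads 16..31 and must index sink weights with slice(16, 32). Item 3's transpose is what puts the states into the (bsz, seq_len, num_heads, hidden_dim) layout that this sharding assumes.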

Benchmarking

Using our full context training scripts, I tested the maximum context lengths I could use: 4K, 8K, 16K, 32K, and 64K, for single-node training on one 8xH100 node. Without sequence parallelism I can scale up to a 32K context length; with sequence parallelism I can go beyond 32K.

Correctness checks

I modified our existing test_model_wrapper.py file to use GPT-OSS to validate that logprobs with the new implementation are correct. Logprobs match but differ in the second decimal place (the error is higher because of the flex attention implementation).
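A check of this kind amounts to comparing logprobs under a loose absolute tolerance, since agreement is only expected to roughly the second decimal place. A minimal sketch, with illustrative names (not the actual test code):

```python
# Sketch of a logprob correctness check with a loose absolute tolerance,
# matching "differ in the second decimal place". Names are illustrative.
import math

def logprobs_close(ref, new, atol=1e-2):
    """True if every pair of logprobs agrees within atol."""
    return all(math.isclose(a, b, abs_tol=atol) for a, b in zip(ref, new))
```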

E2E Validation

