Problem
In long-context training (e.g., seq_length=65536) with variable-length sequence packing (THD format), the fused RoPE kernel becomes the dominant bottleneck when a micro-batch contains many packed sequences.
Most training samples are shorter than max_seq_length, so a single micro-batch routinely packs hundreds to thousands of spans. In this regime, RoPE time grows linearly with n_seqs and can exceed the combined cost of attention + MLP.
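To make the packing regime concrete, here is a minimal host-side sketch of how the `cu_seqlens` offsets used by the THD path are derived from the per-sequence lengths of a packed micro-batch. The helper name `build_cu_seqlens` is illustrative, not a Transformer Engine API:

```cpp
#include <numeric>
#include <vector>

// Build cu_seqlens for a packed micro-batch in THD layout:
// cu_seqlens[i] is the token offset where sequence i starts,
// and cu_seqlens[n_seqs] equals total_tokens.
std::vector<int> build_cu_seqlens(const std::vector<int>& seq_lens) {
    std::vector<int> cu(seq_lens.size() + 1, 0);
    std::partial_sum(seq_lens.begin(), seq_lens.end(), cu.begin() + 1);
    return cu;
}
```

With hundreds to thousands of short sequences packed this way, `n_seqs` grows while `total_tokens` stays fixed at `seq_length`, which is exactly the regime where the grid sizing below becomes pathological.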
Cause
The kernel grid is dim3(max_seqlen, n_seqs), launching max_seqlen × n_seqs CUDA blocks. Only total_tokens blocks do useful work; the rest read cu_seqlens and early-exit. When n_seqs is large, the vast majority of blocks are wasted.
```cpp
// fused_rope.cu — forward & backward launchers
dim3 blocks(s, b);  // s = max_seqlen, b = n_seqs

// Inside the kernel — THD path
int t_id = s_id + start;
if (t_id >= end) return;  // most blocks exit here
```
Example: total_tokens=65536, n_seqs=2401 → 157M blocks launched, 65K useful (99.96% wasted).
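The arithmetic behind that example can be checked directly. This is a sketch with illustrative function names (not part of any library); it reproduces the launched/useful block counts for a `dim3(max_seqlen, n_seqs)` grid:

```cpp
#include <cstdint>

// Blocks launched by a dim3(max_seqlen, n_seqs) grid.
int64_t launched_blocks(int64_t max_seqlen, int64_t n_seqs) {
    return max_seqlen * n_seqs;
}

// Only one block per real token does work; the rest read cu_seqlens
// and early-exit, so the wasted fraction is 1 - total_tokens / launched.
double wasted_pct(int64_t max_seqlen, int64_t n_seqs, int64_t total_tokens) {
    const int64_t launched = launched_blocks(max_seqlen, n_seqs);
    return 100.0 * double(launched - total_tokens) / double(launched);
}
```

For `max_seqlen=65536`, `n_seqs=2401`, `total_tokens=65536`, this gives 157,351,936 launched blocks and a wasted fraction of about 99.96%, matching the figure above. Note the waste scales as `1 - 1/n_seqs` whenever sequences exactly fill the micro-batch, so it approaches 100% as packing density grows.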
Profiling Data
0.9B model, 40 layers, H100, TP=2, seq_length=65536:
| n_seqs | RoPE / layer (×24) | % of layer time |
|--------|--------------------|-----------------|
| <50    | 22 ms              | ~10%            |
| 50–200 | 201 ms             | ~59%            |
| 200–500| 488 ms             | ~79%            |
| ~2400  | 4,620 ms           | ~97%            |
Environment
- Transformer Engine 2.6.0.post1
- H100 80GB, CUDA 12.8
- Megatron-LM, THD format, span-based attention