Description:
Context:
Training on extremely long-context tasks (32k+ tokens) requires Sequence Parallelism or Context Parallelism, where a single sequence is sharded across multiple GPUs. Our current loss functions (LogP, KL, Masked-Mean) assume that the full sequence resides on a single device.
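A minimal sketch of contiguous sequence sharding, in plain Python with ranks simulated as list entries rather than GPUs. It assumes (this issue does not say so) that the LogP operator scores next-token labels, as in standard causal-LM training. Under that assumption, labels should be shifted on the full sequence before sharding; shifting inside each shard silently drops one prediction at every shard boundary. Contiguous chunks are also an assumption, since some context-parallel implementations use non-contiguous, load-balanced layouts. All helper names (shard_contiguous, shift_then_shard, shard_then_shift, count_pairs) are hypothetical.

```python
from typing import List, Tuple

def shard_contiguous(seq: List[int], world_size: int) -> List[List[int]]:
    """Split one sequence into world_size contiguous chunks; the last may be shorter."""
    chunk = -(-len(seq) // world_size)  # ceil(len(seq) / world_size)
    return [seq[r * chunk:(r + 1) * chunk] for r in range(world_size)]

def shift_then_shard(tokens: List[int], world_size: int) -> List[Tuple[List[int], List[int]]]:
    """Build next-token (input, label) pairs on the full sequence, then shard both identically."""
    inputs, labels = tokens[:-1], tokens[1:]
    return list(zip(shard_contiguous(inputs, world_size), shard_contiguous(labels, world_size)))

def shard_then_shift(tokens: List[int], world_size: int) -> List[Tuple[List[int], List[int]]]:
    """Naive variant: shift inside each shard, which drops the pair at every shard boundary."""
    return [(s[:-1], s[1:]) for s in shard_contiguous(tokens, world_size)]

def count_pairs(sharded: List[Tuple[List[int], List[int]]]) -> int:
    """Total number of (input, label) positions scored across all simulated ranks."""
    return sum(len(labels) for _, labels in sharded)

tokens = list(range(16))  # toy token ids for one sequence
print(count_pairs(shift_then_shard(tokens, 4)))  # 15: every next-token prediction kept
print(count_pairs(shard_then_shift(tokens, 4)))  # 12: one pair lost per shard boundary
```

In a real setup the per-rank logits come from the model, but the label-alignment concern at shard boundaries is the same.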
Tasks:
- Update the selected_logprobs and KL operators to accept input sequences sharded along the sequence dimension, so that each rank computes per-token values only for its local shard.
- Implement cross-rank reductions for metrics like masked_mean and masked_sum (e.g., for masked_mean: per-rank partial sum and count of masked tokens -> AllReduce -> global division).
- Ensure that, during the backward pass, gradients flow back to the sequence shard (and rank) that owns each token.
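The reduction pattern in the tasks above can be sketched in plain Python, with ranks simulated as list entries and a hypothetical all_reduce_sum standing in for an AllReduce(SUM) collective (in PyTorch, torch.distributed.all_reduce plays that role). Per-token values such as selected log-probs or KL terms are assumed to be already computed locally on each shard; only the reductions cross ranks. All function names are illustrative and do not refer to existing code in this repo. The gradient helper shows the math behind the last task: since masked_mean = sum_t(m_t * x_t) / N_global, the gradient with respect to token value x_t is m_t / N_global, so a rank needs only the all-reduced count (a constant with respect to the model) to compute gradients for the tokens it owns.

```python
from typing import List, Tuple

# One simulated rank's shard: per-token values (e.g. log-probs or KL terms) and a 0/1 mask.
Shard = Tuple[List[float], List[int]]

def all_reduce_sum(partials: List[float]) -> List[float]:
    """Stand-in for AllReduce(SUM): every rank ends up holding the same global total."""
    total = sum(partials)
    return [total] * len(partials)

def sharded_masked_sum(shards: List[Shard]) -> List[float]:
    """Per-rank partial masked sum -> AllReduce."""
    partials = [sum(v * m for v, m in zip(vals, mask)) for vals, mask in shards]
    return all_reduce_sum(partials)

def global_token_counts(shards: List[Shard]) -> List[float]:
    """Per-rank count of unmasked tokens -> AllReduce."""
    return all_reduce_sum([float(sum(mask)) for _, mask in shards])

def sharded_masked_mean(shards: List[Shard]) -> List[float]:
    """Global masked sum divided by global token count, identical on every rank."""
    sums = sharded_masked_sum(shards)
    counts = global_token_counts(shards)
    # max(..., 1.0) guards against a sequence with no unmasked tokens at all.
    return [s / max(c, 1.0) for s, c in zip(sums, counts)]

def local_grads_of_masked_mean(shards: List[Shard]) -> List[List[float]]:
    """d(masked_mean)/d(x_t) = m_t / N_global for the tokens each rank owns.
    N_global does not depend on the values, so no gradient leaves the owning rank."""
    counts = global_token_counts(shards)
    return [[m / max(c, 1.0) for m in mask] for (_, mask), c in zip(shards, counts)]

def naive_mean_of_local_means(shards: List[Shard]) -> float:
    """Single-device masked_mean applied per shard, then averaged: biased if counts differ."""
    local = [sum(v * m for v, m in zip(vals, mask)) / max(sum(mask), 1) for vals, mask in shards]
    return sum(local) / len(local)

# Two ranks; the second shard is mostly padding (mask = 0).
shards = [([1.0, 2.0, 3.0, 4.0], [1, 1, 1, 1]), ([5.0, 6.0, 7.0, 8.0], [1, 0, 0, 0])]
print(sharded_masked_sum(shards))          # [15.0, 15.0]
print(sharded_masked_mean(shards))         # [3.0, 3.0], i.e. (1+2+3+4+5) / 5
print(naive_mean_of_local_means(shards))   # 3.75, biased toward the padded shard
print(local_grads_of_masked_mean(shards))  # [[0.2, 0.2, 0.2, 0.2], [0.2, 0.0, 0.0, 0.0]]
```

In an autograd framework, one way to get this behavior (assuming the count is treated as detached) is for each rank to backpropagate local_sum / N_global through its own shard and to all-reduce the detached loss value only for logging. Whether the shared model parameters' gradients then need an extra reduction across the sequence-parallel group depends on how the training framework handles SP/CP, and is outside this sketch.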