[FEAT][distributed]: implement Sequence-Parallel aware LogProb and Loss reductions #49

@Flink-ddd

Description:

Context:

Training on extremely long-context tasks (32k+ tokens) requires Sequence Parallelism (SP) or Context Parallelism (CP), where a single sequence is sharded across multiple GPUs. Our current loss functions (LogP, KL, Masked-Mean) assume the full sequence resides on a single device.

Tasks:

  1. Update the selected_logprobs and KL operators to accept sharded input sequences.
  2. Implement cross-rank reductions (e.g., partial sum of masked tokens -> AllReduce -> global division) for metrics like masked_mean and masked_sum.
  3. Ensure that gradients route back to the correct sequence shards during the backward pass.
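
To make the reduction in task 2 concrete, here is a minimal single-process sketch in plain Python. It is not the project's implementation: the helper names (`local_partials`, `simulated_all_reduce_sum`, `sp_masked_sum`, `sp_masked_mean`) are hypothetical, and the "all-reduce" is simulated by summing a list of per-rank partials. In real training code that step would be a SUM collective such as `torch.distributed.all_reduce` across the sequence-parallel group.

```python
from typing import List, Sequence, Tuple

# One shard = (values, mask) for the slice of the sequence held by one rank.
Shard = Tuple[Sequence[float], Sequence[int]]


def local_partials(values: Sequence[float], mask: Sequence[int]) -> Tuple[float, float]:
    """Per-rank partials: (sum of unmasked values, number of unmasked tokens)."""
    numerator = sum(v * m for v, m in zip(values, mask))
    denominator = sum(mask)
    return float(numerator), float(denominator)


def simulated_all_reduce_sum(partials: List[Tuple[float, float]]) -> Tuple[float, float]:
    """Stand-in for a SUM all-reduce: afterwards every rank would hold these totals."""
    total_num = sum(p[0] for p in partials)
    total_den = sum(p[1] for p in partials)
    return total_num, total_den


def sp_masked_sum(shards: List[Shard]) -> float:
    """Global masked sum: local partial sums -> all-reduce."""
    total_num, _ = simulated_all_reduce_sum([local_partials(v, m) for v, m in shards])
    return total_num


def sp_masked_mean(shards: List[Shard]) -> float:
    """Global masked mean: local partials -> all-reduce -> one global division."""
    total_num, total_den = simulated_all_reduce_sum(
        [local_partials(v, m) for v, m in shards]
    )
    # Assumption: return 0.0 when no token is unmasked anywhere.
    return total_num / total_den if total_den > 0 else 0.0


def naive_mean_of_local_means(shards: List[Shard]) -> float:
    """Incorrect baseline: averaging per-rank means ignores unequal mask counts."""
    local_means = []
    for values, mask in shards:
        num, den = local_partials(values, mask)
        local_means.append(num / den if den > 0 else 0.0)
    return sum(local_means) / len(local_means)


# A 6-token sequence split across 2 ranks with an uneven mask.
example_shards: List[Shard] = [
    ([1.0, 2.0, 3.0], [1, 1, 1]),  # rank 0: three unmasked tokens
    ([4.0, 5.0, 6.0], [0, 0, 1]),  # rank 1: one unmasked token
]

print(sp_masked_sum(example_shards))              # 12.0
print(sp_masked_mean(example_shards))             # 3.0
print(naive_mean_of_local_means(example_shards))  # 4.0 (differs from the true mean)
```

The example shows why the division has to happen after the cross-rank sum: when shards carry different numbers of unmasked tokens, averaging per-rank means gives 4.0 instead of the full-sequence value of 3.0. For task 3, one point to check in a real implementation is whether the collective used is autograd-aware; the in-place `torch.distributed.all_reduce` is commonly applied only to non-differentiable quantities such as the token count.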

Metadata

Assignees

No one assigned

    Labels

    component: distributed (Tasks involving Ray actor management, cross-node scheduling, and communication synchronization.)
    component: kernels (Tasks involving the development of CUDA and Triton underlying operators.)
    type: design (Issues requiring in-depth discussion of architecture design.)
