Broadcast-based allgather in host for-loop#5925

Merged: Priya2698 merged 25 commits into main from pm/stream_broadcast on Mar 9, 2026

Priya2698 (Collaborator) commented Feb 6, 2026

[Image: Screenshot 2026-02-09 at 1 24 11 PM]

The broadcast version is very slow so I am not comparing timings until we integrate this with multicast

github-actions bot commented Feb 6, 2026

Review updated until commit 46c4698

Description

  • Implement broadcast-based allgather decomposition in host for-loop using loop index as root

  • Add stream parallelization support for DID->Stream transitions with swizzle handling

  • Modify shardByStream to return nullptr when stream parallelization cannot be applied

  • Add tests for column parallel linear forward with broadcast-based overlapping

Changes walkthrough

Relevant files

Enhancement

lower_to_communication.cpp (+50/-19)
Add broadcast-based allgather with root parameter
csrc/host_ir/lower_to_communication.cpp

  • Modify lowerToBroadcast to handle same mesh (broadcast-based allgather) and different meshes cases
  • Add root parameter to support for-loop index as root for broadcast-based allgather
  • Add stream parallelization check for DID->Stream transitions in getCommunicationInfoForParallelType
  • Update convertSingleOpToCommunication to accept and pass root parameter

convert_op_to_communication.cpp (+4/-1)
Update convertSingleOpToCommunication call signature
csrc/host_ir/pass/convert_op_to_communication.cpp

  • Update call to convertSingleOpToCommunication to pass root=null

propagation.cpp (+22/-15)
Add swizzle1d handling in canonicalizeLoopDomain
csrc/multidevice/propagation.cpp

  • Add Swizzle1D handling in canonicalizeLoopDomain for stream parallelization
  • Process swizzle1d transforms before split transforms

lower_to_communication.h (+6/-0)
Add root parameter to function declaration
csrc/host_ir/lower_to_communication.h

  • Add root parameter to convertSingleOpToCommunication function declaration

Bug fix

lowering.cpp (+31/-9)
Pass root index and add null checks in lowering
csrc/host_ir/lowering.cpp

  • Pass loop index as root to convertSingleOpToCommunication
  • Add null checks for sharded_in and sharded_out from shardByStream
  • Use replacement_map.contains() instead of direct access for SSA validation

ops.cpp (+18/-8)
Handle non-stream-parallelized cases in shardByStream
csrc/host_ir/ops.cpp

  • Change deviceAndStreamParallelTypes() to {ParallelType::Stream} for propagation
  • Return nullptr when destination loop domain is not stream-parallelized
  • Add better error messages for failed sharding operations

Tests

test_overlap.py (+114/-0)
Add tests for column parallel linear with broadcast
tests/python/multidevice/test_overlap.py

  • Add column_parallel_linear_forward helper function for testing
  • Add test_column_parallel_linear_forward test for correctness
  • Add test_column_parallel_linear_forward_benchmark for performance testing

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Missing performance data

    The PR description explicitly states "The broadcast version is very slow so I am not comparing timings until we integrate this with multicast".
    While this is acknowledged, it would be helpful to have some baseline performance numbers or at least an estimate of the expected performance gap.
    Please consider adding any available performance metrics or explaining why the current implementation is expected to be faster than the previous approach once integrated with multicast.

    // Either of the following cases is happening:
    // 1. Same mesh: a broadcast-based allgather in a host for loop. `root` is the
    //    for-loop index.
    // 2. Different meshes: we pick the first device in the sender mesh as root.
    void lowerToBroadcast(
        TensorView* input_tv,
        TensorView* output_tv,
        const CommunicatorBackend backend,
        Val* root,
        std::vector<Expr*>& comms) {
      const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
      const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
    
      Team team = receiver_mesh.vector();
    
      if (sender_mesh == receiver_mesh) {
        NVF_ERROR(
            root != nullptr,
            "Root must be provided for broadcast-based allgather in a host for "
            "loop.");
      } else {
        NVF_ERROR_EQ(sender_mesh.rank(), 1, "sender: ", input_tv);
        NVF_ERROR_EQ(receiver_mesh.rank(), 1, "receiver: ", output_tv);
        DeviceIdxType root_device = sender_mesh.at(0);
        if (!receiver_mesh.has(root_device)) {
          team.push_back(root_device);
        }
        root = IrBuilder::create<Val>(
            getRelativeIndex(team, root_device), DataType::Index);
      }
    
      comms.push_back(IrBuilder::create<Communication>(
          CommunicationType::Broadcast,
          output_tv,
          input_tv,
          team,
          root,
          c10d::ReduceOp::RedOpType::UNUSED,
          backend));
    }
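
The same-mesh case in `lowerToBroadcast` can be pictured with a small host-side simulation. This is only a sketch under stated assumptions: devices are modelled as plain `std::vector` buffers, a broadcast is a copy, and the names `Buffers`, `broadcastShard` and `broadcastBasedAllgather` are hypothetical and not part of nvFuser. It shows that issuing one broadcast per loop iteration, with the loop index as root, leaves every device holding every shard, which is what an allgather produces.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy model (not an nvFuser API): buffers[d] is device d's output buffer,
// with one slot per device in the mesh.
using Buffers = std::vector<std::vector<int>>;

// One broadcast: `root` sends its shard to slot `root` on every device.
void broadcastShard(
    const std::vector<int>& shards,
    std::size_t root,
    Buffers& buffers) {
  for (std::vector<int>& buffer : buffers) {
    buffer[root] = shards[root];
  }
}

// Allgather decomposed into one broadcast per host for-loop iteration.
// The loop index plays the role of `root`.
Buffers broadcastBasedAllgather(const std::vector<int>& shards) {
  const std::size_t num_devices = shards.size();
  Buffers buffers(num_devices, std::vector<int>(num_devices, 0));
  for (std::size_t i = 0; i < num_devices; ++i) {
    broadcastShard(shards, /*root=*/i, buffers);
  }
  return buffers;
}
```

Calling `broadcastBasedAllgather({10, 20, 30})` fills all three simulated device buffers with the full sequence of shards. In the real lowering each iteration's broadcast can be placed on its own stream, which is where the overlap with compute would come from.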
    Potential null return needs handling

    The shardByStream function can now return nullptr when the destination's loop domain is not stream-parallelized (lines 75-78).
    While this is handled in some call sites with NVF_ERROR checks, ensure all callers of shardByStream properly handle the nullptr case to avoid potential crashes.

    // Destination's loop domain may not be stream-parallelized if the
    // corresponding id is already sharded such as in
    // broadcast/collective-permute based decomposition of allgather.
    if (getShardedIterDomain(
            destination, ParallelType::Stream, DomainType::kLoop) == nullptr) {
      return nullptr;
    }
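
A caller-side pattern for the nullable return could look like the sketch below. Everything here is an assumption made for illustration: `Tensor`, `shardByStreamSketch` and `pickInput` are hypothetical stand-ins, and falling back to the unsharded source is just one possible policy. The actual call sites in lowering.cpp may handle the null case differently, for example with NVF_ERROR.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for a TensorView; not an nvFuser type.
struct Tensor {
  std::string name;
  bool stream_parallel;
};

// Mirrors the early return above: yields nullptr when the destination is not
// stream-parallelized, otherwise a pointer to a sharded alias of `src`.
const Tensor* shardByStreamSketch(
    const Tensor& src,
    const Tensor& dst,
    Tensor& storage) {
  if (!dst.stream_parallel) {
    return nullptr;
  }
  storage = Tensor{src.name + "_sharded", true};
  return &storage;
}

// Caller side: test the pointer before dereferencing it. Falling back to the
// unsharded source is an assumption of this sketch.
std::string pickInput(const Tensor& src, const Tensor& dst) {
  Tensor storage{};
  const Tensor* sharded = shardByStreamSketch(src, dst, storage);
  return sharded != nullptr ? sharded->name : src.name;
}
```

The point of the sketch is only that every caller needs an explicit branch for the null result; dereferencing unconditionally would crash in the non-stream-parallelized case.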
    New Swizzle1D handling logic

    The canonicalizeLoopDomain function now handles Swizzle1D transforms (lines 353-359).
    This is a new code path that should be carefully reviewed to ensure the swizzle handling is correct, especially the logic that skips swizzles with parallelized outputs and properly reconstructs the loop domain.

    if (auto* swizzle1d = dynamic_cast<Swizzle1D*>(transform)) {
      if (swizzle1d->out()->isParallelized()) {
        continue;
      }
      auto it = loop.erase(swizzle1d->out()).second;
      loop.insert(it, swizzle1d->in(), std::monostate());
      continue;
    }
    if (auto* split = dynamic_cast<Split*>(transform)) {
      if (split->outer()->isParallelized() ||
          split->inner()->isParallelized()) {
        continue;
      }
    
      if (!loop.contains(split->outer()) || !loop.contains(split->inner())) {
        continue;
      }
    
      loop.erase(split->outer());
      const auto inner_i = loop.erase(split->inner()).second;
      // `inner_i` is picked arbitrarily as the insertion point. Given `in`,
      // `outer` and `inner` are all serial, `in`'s position in the loop domain
      // doesn't matter.
      loop.insert(inner_i, split->in(), std::monostate());
      continue;
    }
    NVF_THROW("Expected a swizzle1d or split transform. Got: ", transform);

    Priya2698 marked this pull request as ready for review February 9, 2026 21:10
    Priya2698 requested a review from wujingyue February 9, 2026 21:11

    Priya2698 (Collaborator, Author) commented:

    !test

    greptile-apps: This comment was marked as outdated.

    greptile-apps bot (Contributor) left a comment:

    8 files reviewed, 2 comments

    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

    Priya2698 (Collaborator, Author) commented:

    !test

    greptile-apps bot (Contributor) left a comment:

    8 files reviewed, 3 comments

    greptile-apps bot (Contributor) commented Feb 9, 2026

    Additional Comments (1)

    csrc/multidevice/communication.cpp
    Root validation rejects non-const

    Communication::validate only enforces the root/type contract when root() is a const integral scalar. For StreamBroadcast, root is the host loop index (non-const), so hasRoot(type()) is never validated and invalid roots (e.g., non-integral or negative-at-runtime) can slip through. This can lead to runtime failures when postBroadcast interprets the root.

    Consider extending validation to require root() be Index dtype for StreamBroadcast/rooted collectives even when not constant, and/or add runtime checks where the root is consumed.

    Priya2698 (Collaborator, Author) commented:

    !test

    greptile-apps bot (Contributor) left a comment:

    8 files reviewed, 2 comments

    greptile-apps bot (Contributor) commented Feb 9, 2026

    Additional Comments (1)

    csrc/multidevice/communication.cpp
    Non-constant root accepted

    Communication::validate only checks root/type consistency when root() is a const integral scalar (communication.cpp:238-246). For StreamBroadcast, the root is intentionally a non-const Val* (host loop index), so this validation becomes a no-op: invalid roots (e.g., negative at runtime, wrong dtype) won’t be rejected here but later code assumes a valid rank/root. If StreamBroadcast relies on runtime root, it still needs a type/dtype/range validation path for non-const roots (at least DataType::Index and non-negative).
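
The runtime check the bot asks for could be as small as the sketch below. It is an assumption, not code from the PR: `rootInRange` and `checkRuntimeRoot` are hypothetical helpers, and the root is modelled as a plain `int64_t` that has already been evaluated from the loop index.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical helper: a root is usable only if it indexes into the team.
bool rootInRange(std::int64_t root, std::int64_t team_size) {
  return root >= 0 && root < team_size;
}

// Hypothetical runtime guard for a root that is not a compile-time constant,
// to be called where the evaluated root is consumed.
void checkRuntimeRoot(std::int64_t root, std::int64_t team_size) {
  if (!rootInRange(root, team_size)) {
    throw std::out_of_range(
        "root " + std::to_string(root) + " is outside [0, " +
        std::to_string(team_size) + ")");
  }
}
```

For a host loop that iterates over the mesh the index is in range by construction, so a check like this mainly guards against a mismatch between the loop extent and the team size.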

    wujingyue (Collaborator) left a comment:

    It's great to see this work functionally!

    greptile-apps bot (Contributor) left a comment:

    8 files reviewed, no comments

    Priya2698 (Collaborator, Author) commented:

    !test

    greptile-apps bot (Contributor) left a comment:

    8 files reviewed, 2 comments

    greptile-apps bot (Contributor) left a comment:

    9 files reviewed, 2 comments

    greptile-apps bot (Contributor) left a comment:

    9 files reviewed, no comments

    greptile-apps bot (Contributor) left a comment:

    12 files reviewed, 3 comments

    greptile-apps bot (Contributor) commented Feb 19, 2026

    Additional Comments (3)

    csrc/host_ir/evaluator.cpp
    StreamBroadcast missing from CUDA backend check

            communication->type() == CommunicationType::Broadcast ||
                communication->type() == CommunicationType::StreamBroadcast ||
                communication->type() == CommunicationType::Allgather,
    

    csrc/multidevice/cuda_p2p.cpp
    StreamBroadcast should be handled like Broadcast

      switch (communication->type()) {
        case CommunicationType::Broadcast:
        case CommunicationType::StreamBroadcast: {
          auto* broadcast_handle =
              dynamic_cast<SymMemForBroadcast*>(symmetric_memory_handle);
          NVF_ERROR(broadcast_handle != nullptr, "Invalid broadcast handle");
          postBroadcastWithCudaBackend(
              communication, input, broadcast_handle, stream, root);
    

    csrc/multidevice/cuda_p2p.cpp
    StreamBroadcast should be handled like Broadcast

      switch (communication->type()) {
        case CommunicationType::Broadcast:
        case CommunicationType::StreamBroadcast: {
          auto* broadcast_handle =
              dynamic_cast<SymMemForBroadcast*>(symmetric_memory_handle);
          NVF_ERROR(broadcast_handle != nullptr, "Invalid broadcast handle");
          waitBroadcastWithCudaBackend(
              communication, broadcast_handle, stream, root);
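
All three suggestions rely on the same C++ idiom: stacking case labels so that two enumerators share one handler. A minimal sketch, assuming a hypothetical `CommType` enum and `handlerFor` function that only mirror the shape of the real dispatch:

```cpp
#include <cassert>
#include <string>

// Hypothetical mirror of the communication types named in the review.
enum class CommType { Broadcast, StreamBroadcast, Allgather };

// Adjacent case labels share one body, so StreamBroadcast is dispatched
// exactly like Broadcast without duplicating the handler.
std::string handlerFor(CommType type) {
  switch (type) {
    case CommType::Broadcast:
    case CommType::StreamBroadcast:
      return "broadcast";
    case CommType::Allgather:
      return "allgather";
  }
  return "unsupported";
}
```

If the new enumerator is left out of a switch like this, it falls through to whatever the default path is, which is the failure mode the comments above point at.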
    

    greptile-apps bot (Contributor) left a comment:

    12 files reviewed, 1 comment

    greptile-apps: This comment was marked as outdated.

    Priya2698 marked this pull request as draft February 21, 2026 21:12

    greptile-apps bot (Contributor) left a comment:

    9 files reviewed, 6 comments

    greptile-apps: This comment was marked as outdated.

    Priya2698 (Collaborator, Author) commented:

    !test

    Priya2698 requested a review from wujingyue March 4, 2026 22:44
    Priya2698 marked this pull request as ready for review March 4, 2026 22:44
    Comment on lines +353 to 360

        if (auto* swizzle1d = dynamic_cast<Swizzle1D*>(transform)) {
          if (swizzle1d->out()->isParallelized()) {
            continue;
          }
          auto it = loop.erase(swizzle1d->out()).second;
          loop.insert(it, swizzle1d->in(), std::monostate());
          continue;
        }

    Review comment (Contributor):

    Missing loop.contains guard for Swizzle1D

    The Split case immediately below explicitly guards with if (!loop.contains(split->outer()) || !loop.contains(split->inner())) { continue; } before erasing. The rationale is that an intermediate ID can be absent from loop if it was already replaced by an earlier iteration (e.g., a downstream transform already "consumed" it in the reverse traversal).

    The new Swizzle1D case lacks an equivalent guard. If swizzle1d->out() has already been replaced (because a transform downstream of the swizzle was processed first and removed it from loop), calling loop.erase(swizzle1d->out()) on an absent key returns a potentially invalid iterator, and the subsequent loop.insert(it, swizzle1d->in(), ...) then inserts at the wrong position or inserts a duplicate. This diverges from the consistent defensive pattern used for Split.

    Suggested change
    -    if (auto* swizzle1d = dynamic_cast<Swizzle1D*>(transform)) {
    -      if (swizzle1d->out()->isParallelized()) {
    -        continue;
    -      }
    -      auto it = loop.erase(swizzle1d->out()).second;
    -      loop.insert(it, swizzle1d->in(), std::monostate());
    -      continue;
    -    }
    +    if (auto* swizzle1d = dynamic_cast<Swizzle1D*>(transform)) {
    +      if (swizzle1d->out()->isParallelized()) {
    +        continue;
    +      }
    +      if (!loop.contains(swizzle1d->out())) {
    +        continue;
    +      }
    +      auto it = loop.erase(swizzle1d->out()).second;
    +      loop.insert(it, swizzle1d->in(), std::monostate());
    +      continue;
    +    }
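
The guard-then-replace pattern can be tried in isolation with a standard container. This sketch assumes `std::list<std::string>` as a stand-in for the ordered loop-domain container and a hypothetical helper `replaceIfPresent`; nvFuser's own container has a different interface (its `erase` returns a pair, as the `.second` above shows).

```cpp
#include <algorithm>
#include <cassert>
#include <list>
#include <string>

// Replaces `out` with `in` at the same position in `loop`. Returns false and
// leaves `loop` untouched when `out` is absent, which is the role the
// loop.contains guard plays in the suggestion above.
bool replaceIfPresent(
    std::list<std::string>& loop,
    const std::string& out,
    const std::string& in) {
  auto it = std::find(loop.begin(), loop.end(), out);
  if (it == loop.end()) {
    return false;
  }
  it = loop.erase(it); // iterator to the element that followed `out`
  loop.insert(it, in); // inserts before `it`, i.e. where `out` used to be
  return true;
}
```

Starting from `{"a", "swizzled", "b"}`, replacing `"swizzled"` with `"i0"` gives `{"a", "i0", "b"}`; a second call with the same arguments is a no-op instead of an erase on a missing key.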

    Comment on lines +504 to +505

        # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward.
        h, t = 8192, 8192

    Review comment (Contributor):

    Copy-paste error in comment — wrong test name

    The comment says this ports RowParallelLinear_Forward, but this benchmark exercises column_parallel_linear_forward (column-parallel, not row-parallel). The functional test test_column_parallel_linear_forward above correctly refers to ColumnAndSequenceParallelLinear_Forward.

    Suggested change
    -    # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward.
    -    h, t = 8192, 8192
    +    # This is a port of CollectiveBasedOverlapTest.ColumnAndSequenceParallelLinear_Forward.

    Priya2698 (Collaborator, Author) commented:

    !test

    Priya2698 (Collaborator, Author) commented:

    !test

    // Check if we are going from DID -> Stream, which is a ring allgather.
    // This can be executed as a broadcast or send recvs, which is decided
    // by the presence of a swizzle in the stream id definition.
    if (c_logical_stream_id == p2c.at(p_logical_id)) {
    Review comment (Contributor):

    When c_logical_stream_id is nullptr (the consumer has no stream-parallelized ID), the expression p2c.at(p_logical_id) is still evaluated to perform the pointer comparison. If p_logical_id happens not to be a key in p2c, this throws std::out_of_range before we even reach the existing Allgather/Gather path below.

    The condition should short-circuit on the null check first:

    Suggested change
    -    if (c_logical_stream_id == p2c.at(p_logical_id)) {
    +    if (c_logical_stream_id != nullptr &&
    +        c_logical_stream_id == p2c.at(p_logical_id)) {

    This makes the intent explicit and avoids the map lookup entirely when there is no stream-parallel consumer ID.
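
The difference between the two conditions can be reproduced with a plain `std::unordered_map`. This is a sketch under assumptions: `IdMap`, `matchesUnguarded` and `matchesGuarded` are hypothetical, `int` objects stand in for IterDomains, and the real `p2c` map type may differ. It relies only on `at()` throwing `std::out_of_range` for a missing key and on `&&` short-circuiting.

```cpp
#include <cassert>
#include <stdexcept>
#include <unordered_map>

// Hypothetical stand-in for the producer-to-consumer ID map.
using IdMap = std::unordered_map<const int*, const int*>;

// Unguarded form: always evaluates p2c.at(p_id), which throws
// std::out_of_range when p_id is not a key.
bool matchesUnguarded(
    const int* c_stream_id,
    const IdMap& p2c,
    const int* p_id) {
  return c_stream_id == p2c.at(p_id);
}

// Guarded form: && short-circuits, so the lookup is skipped entirely when
// there is no stream-parallel consumer ID.
bool matchesGuarded(
    const int* c_stream_id,
    const IdMap& p2c,
    const int* p_id) {
  return c_stream_id != nullptr && c_stream_id == p2c.at(p_id);
}
```

With an empty map and a null `c_stream_id`, the guarded version returns false while the unguarded one throws. The guard does not protect the case where `c_stream_id` is non-null and the key is missing, so it narrows the failure window rather than closing it.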

    Priya2698 merged commit 54d48ae into main Mar 9, 2026
    52 checks passed
    Priya2698 deleted the pm/stream_broadcast branch March 9, 2026 18:34