
Reduce and Allreduce NVLS implementations for the cuda backend#6038

Merged
nsarka merged 33 commits into NVIDIA:main from nsarka:nsarka/reduce-nvls-cuda-backend
Mar 25, 2026

Conversation

@nsarka
Member

@nsarka nsarka commented Mar 20, 2026

Built on top of #5620. Adds Reduce and Allreduce NVLS implementations. Both use the same ld_reduce kernel and synchronize using a symmetric integer tensor as a semaphore.

@nsarka nsarka requested review from samnordmann and wujingyue March 20, 2026 14:24
@nsarka nsarka force-pushed the nsarka/reduce-nvls-cuda-backend branch from 6cc186a to 2697b92 on March 20, 2026 14:25
@greptile-apps
Contributor

greptile-apps bot commented Mar 20, 2026

Greptile Summary

This PR extends the CUDA (NVLS) backend with Reduce and Allreduce collective implementations, completing the set alongside the existing Broadcast and Allgather. Both collectives share a new multimem_ld_reduce_sum_f32_kernel (NVRTC-compiled at runtime) that issues the multimem.ld_reduce.global.add.v4.f32 PTX instruction, which performs the in-network SUM reduction across all ranks in a single load.

Key changes:

  • runtime/multicast_reduce.cu — new kernel; 16-byte vector loads over the multicast VA; requires SM90+/Hopper.
  • csrc/multidevice/ipc_handle.{h,cpp} — two new SymmetricMemoryHandle subclasses (SymmetricMemoryForAllreduce, SymmetricMemoryForReduce) allocating per-rank input symmetric buffers and per-rank semaphore tensors; IpcHandleCache changed from reference to pointer.
  • csrc/multidevice/cuda_p2p.cpp — two-phase semaphore protocols (write-ready → wait-ready → kernel → write-done → wait-done) for Allreduce; single-phase with root-as-orchestrator for Reduce; launchMulticastReduceKernel promoted to public API with alignment/size checks.
  • csrc/host_ir/evaluator.cpp — CUDA-backend dispatch extended to Allreduce/Reduce; non-root Reduce ranks correctly fall back to the input tensor as the cache key when there is no output tensor.
  • tests/cpp/utils.cpp — pre-existing bug fixed: found_seed was never set, causing repeated seed recomputation.

Minor issues found:

  • std::array<void*, 9> in the refactored launchAlltoallvKernel has only 8 initializers; the ninth element is silently value-initialized to nullptr. Functionally harmless, since cuLaunchKernel reads only the kernel's declared parameter count, but the declared size is incorrect.
  • The cache_root / cache_buffer derivation block is duplicated between handle(Communication*) and handle(Wait*) and could be extracted into a shared helper.
  • postAllreduceWithCudaBackend step 4 writes kIdle to the self-slot which was never set to kInProgress; the write is a no-op but adds a minor asymmetry with step 1.

Confidence Score: 4/5

  • PR is safe to merge; all issues found are non-blocking style/clarity concerns with no correctness impact.
  • The two-phase semaphore protocols for Allreduce and Reduce are logically correct and consistent with the existing Allgather/Broadcast patterns. The new kernel is straightforward. The only actionable item is the off-by-one array size in the alltoallv refactoring (std::array<void*, 9> vs 8), which is functionally harmless. Previous review concerns (alignment checks in launchMulticastReduceKernel, copyright year) have been addressed.
  • csrc/multidevice/cuda_p2p.cpp — specifically the std::array size mismatch in launchAlltoallvKernel introduced during refactoring.

Important Files Changed

  • runtime/multicast_reduce.cu — New CUDA kernel using multimem.ld_reduce.global.add.v4.f32 PTX for allreduce/reduce via NVLink SHARP. Requires 16-byte alignment (enforced by the caller) and SM90+. Logic is straightforward and correct.
  • csrc/multidevice/cuda_p2p.cpp — Adds postAllreduceWithCudaBackend, waitAllreduceWithCudaBackend, postReduceWithCudaBackend, waitReduceWithCudaBackend, and launchMulticastReduceKernel. Two-phase semaphore protocols look correct. Minor: std::array size mismatch (9 declared vs 8 initialized elements) introduced in the alltoallv refactoring; the self-slot kIdle write in allreduce step 4 is redundant but harmless.
  • csrc/multidevice/ipc_handle.h — Adds SymmetricMemoryForAllreduce and SymmetricMemoryForReduce classes. IpcHandleCache changed to take a pointer instead of a reference; lifetime is safe given usage in HostIrEvaluator. Override specifiers added to destructors.
  • csrc/multidevice/ipc_handle.cpp — Implements constructors and accessors for SymmetricMemoryForAllreduce and SymmetricMemoryForReduce. Semaphore initialization via synchronous cudaMemcpy is intentional (pre-stream setup). SymmetricMemoryHandleCache::get updated to dispatch the two new handle types.
  • csrc/host_ir/evaluator.cpp — Extends CUDA-backend dispatch to Allreduce and Reduce. Correct use of the cache_buffer fallback for non-root Reduce ranks that lack an output tensor. The cache_root/cache_buffer derivation is duplicated between handle(Communication*) and handle(Wait*).
  • csrc/multidevice/cuda_p2p.h — postWithCudaBackend gains an output tensor parameter; launchMulticastReduceKernel declared as public API. Removed duplicate getP2pProtocol declaration. P2pProtocol given an explicit underlying type.

Sequence Diagram

sequenceDiagram
    participant A as Rank A (all ranks)
    participant Sym as Symmetric Buffer (NVLS MC VA)
    participant B as Rank B (all ranks)

    Note over A,B: postAllreduceWithCudaBackend
    A->>Sym: cudaMemcpyAsync(inputBuffer, input)
    B->>Sym: cudaMemcpyAsync(inputBuffer, input)

    A->>A: cuStreamWriteValue32(local_sem[B]=kInProgress)
    B->>B: cuStreamWriteValue32(local_sem[A]=kInProgress)

    A->>B: cuStreamWaitValue32(B.sem[A]==kInProgress)
    B->>A: cuStreamWaitValue32(A.sem[B]==kInProgress)

    A->>A: launchMulticastReduceKernel(mc_ptr → output)
    B->>B: launchMulticastReduceKernel(mc_ptr → output)

    A->>B: cuStreamWriteValue32(B.sem[A]=kIdle)
    B->>A: cuStreamWriteValue32(A.sem[B]=kIdle)

    Note over A,B: waitAllreduceWithCudaBackend
    A->>A: cuStreamWaitValue32(local_sem[B]==kIdle)
    B->>B: cuStreamWaitValue32(local_sem[A]==kIdle)

    Note over A,B: postReduceWithCudaBackend (root=R)
    A->>Sym: cudaMemcpyAsync(inputBuffer, input)
    A->>A: cuStreamWriteValue32(own_sem=kInProgress)

    alt rank == root
        A->>B: cuStreamWaitValue32(B.sem==kInProgress)
        A->>Sym: launchMulticastReduceKernel(mc_ptr → output)
        A->>B: cuStreamWriteValue32(B.sem=kIdle)
    end

    Note over A,B: waitReduceWithCudaBackend
    B->>B: cuStreamWaitValue32(own_sem==kIdle)

Reviews (11): Last reviewed commit: "Add barrier"

Comment on lines +25 to +41
size_t n_vec = n_bytes / 16;

for (size_t i = idx; i < n_vec; i += stride) {
  float r0, r1, r2, r3;
  const void* addr = mc_src_c + i * 16;
  asm volatile(
      "multimem.ld_reduce.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
      : "=f"(r0), "=f"(r1), "=f"(r2), "=f"(r3)
      : "l"(addr)
      : "memory");
  float4 out;
  out.x = r0;
  out.y = r1;
  out.z = r2;
  out.w = r3;
  ((float4*)dst_c)[i] = out;
}
Contributor


P2 Tail elements silently dropped when n_bytes is not a 16-byte multiple

The kernel computes n_vec = n_bytes / 16 using integer division, so any bytes in the range [n_vec*16, n_bytes) are silently ignored and never reduced. The caller (launchMulticastReduceKernelImpl) does check size % 16 == 0 via NVF_CHECK, so this cannot be triggered through the normal code path. However, the public launchMulticastReduceKernel wrapper (used by tests, per the header comment) does not perform that check, leaving callers free to pass a size that is not a multiple of 16 and silently get wrong results.

Consider adding the alignment assertion inside the kernel or inside launchMulticastReduceKernel itself so test callers also get a clear error.

nsarka and others added 5 commits March 20, 2026 16:40
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
@nsarka
Member Author

nsarka commented Mar 20, 2026

!test

@nsarka
Member Author

nsarka commented Mar 20, 2026

!test

@nsarka
Member Author

nsarka commented Mar 20, 2026

!test

@nsarka
Member Author

nsarka commented Mar 20, 2026

!test

// NOLINTNEXTLINE(performance-enum-size)
enum class IpcSemaphore : cuuint32_t { kIdle, kInProgress };

// Basic IPC handle for legacy P2P communication using cudaIpc* APIs
Collaborator


@samnordmann I remember you said some IPC handles can be removed in favor of symmetric memory. Is this one example of that?

Collaborator


yes, or more exactly, this class as an interface could stay untouched but its members and implementation could replace all the legacy cudaIpc* API by cuMem* etc.

Collaborator

@samnordmann samnordmann left a comment


LGTM! Thanks

Comment on lines +889 to +895
// Copy input to symmetric buffer
NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
    reduce_handle->inputBuffer().data_ptr(),
    input.data_ptr(),
    size,
    cudaMemcpyDeviceToDevice,
    stream));
Collaborator


It would be more efficient to operate directly on the user's input buffer, giving a true zero-copy, one-shot algorithm. This requires assuming the input buffer is allocated from symmetric memory.

Member Author


Thanks! I'll make a new PR for this just so it can be smaller.

@nsarka
Member Author

nsarka commented Mar 24, 2026

!test

@nsarka nsarka merged commit 5c17125 into NVIDIA:main Mar 25, 2026
51 checks passed
