Add kernel based alltoallv and cuda backend for MoE dispatch and combine #5863
samnordmann merged 24 commits into main
Conversation
!test
Greptile Summary

Adds GPU-initiated communication support for MoE dispatch/combine operations using the CUDA backend, avoiding GPU-to-CPU synchronization overhead compared to NCCL. Major changes:
Issues found:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant R0 as Rank 0
participant R1 as Rank 1
participant Store as TCPStore
participant GPU0 as GPU 0
participant GPU1 as GPU 1
Note over R0,R1: Dispatch Phase
R0->>R0: Sort tokens by expert_id
R1->>R1: Sort tokens by expert_id
R0->>R0: Compute send_counts per rank
R1->>R1: Compute send_counts per rank
Note over R0,Store: CUDA Backend: Exchange metadata via TCPStore
R0->>Store: Put send_counts
R1->>Store: Put send_counts
R0->>Store: Get all ranks' send_counts
R1->>Store: Get all ranks' send_counts
R0->>R0: Compute recv_counts/offsets
R1->>R1: Compute recv_counts/offsets
Note over GPU0,GPU1: Allocate symmetric memory for send/recv buffers
R0->>GPU0: Allocate send_x_sym, recv_x_sym
R1->>GPU1: Allocate send_x_sym, recv_x_sym
R0->>R1: Exchange IPC handles
R1->>R0: Exchange IPC handles
Note over GPU0,GPU1: GPU-initiated alltoallv kernel
GPU0->>GPU1: Write to remote recv_x_sym
GPU1->>GPU0: Write to remote recv_x_sym
Note over R0,R1: Barrier synchronization
R0->>R1: Barrier
R1->>R0: Barrier
Note over R0,R1: Combine Phase
R0->>R0: Sort by src_rank
R1->>R1: Sort by src_rank
GPU0->>GPU1: Alltoallv send_x back
GPU1->>GPU0: Alltoallv send_x back
R0->>R0: Scatter by src_idx
R1->>R1: Scatter by src_idx
Last reviewed commit: 374c8b3
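For concreteness, here is a minimal host-side sketch of the metadata-exchange step shown in the diagram, assuming a `c10d::TCPStore` handle; the helper name `exchangeSendCounts` and the key naming are illustrative, not the PR's actual API, and the real implementation may do this differently (the PR description notes the alltoallv itself consumes GPU-resident split sizes).

```cpp
#include <torch/csrc/distributed/c10d/TCPStore.hpp>

#include <cstdint>
#include <string>
#include <vector>

// Each rank publishes its per-destination send counts through the store,
// then reads every other rank's counts to derive how much it will receive.
// recv_counts[r] is the number of elements rank r sends to this rank.
std::vector<int64_t> exchangeSendCounts(
    c10d::TCPStore& store,
    int64_t my_rank,
    int64_t world_size,
    const std::vector<int64_t>& send_counts) {
  const auto* bytes = reinterpret_cast<const uint8_t*>(send_counts.data());
  store.set(
      "alltoallv_send_counts_" + std::to_string(my_rank),
      std::vector<uint8_t>(
          bytes, bytes + send_counts.size() * sizeof(int64_t)));

  std::vector<int64_t> recv_counts(world_size);
  for (int64_t r = 0; r < world_size; ++r) {
    // Store::get blocks until rank r has published its counts.
    auto blob = store.get("alltoallv_send_counts_" + std::to_string(r));
    const auto* counts = reinterpret_cast<const int64_t*>(blob.data());
    recv_counts[r] = counts[my_rank];
  }
  return recv_counts;
}
```

The recv offsets are then prefix sums of `recv_counts`, which is what the "Compute recv_counts/offsets" step in the diagram refers to.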
logic: checking wrong backend constant - should check backend parameter, not hardcoded kNccl
```diff
-  if (!communicator_->isBackendAvailable(CommunicatorBackend::kNccl)) {
-    GTEST_SKIP() << "Backend " << backend << " not available.";
-  }
+  if (!communicator_->isBackendAvailable(backend)) {
+    GTEST_SKIP() << "Backend " << backend << " not available.";
+  }
```
Review updated until commit 374c8b3

Description

Relevant files:
- Enhancement
- Tests
- Documentation
- Configuration changes
PR Reviewer Guide
Here are some key observations to aid the review process:
🧪 PR contains tests

⚡ Recommended focus areas for review
- Runtime Compilation Risk
!test
| opts.push_back("-I/usr/local/cuda/include"); | ||
| opts.push_back("-I/usr/local/cuda/include/cccl"); |
hardcoded CUDA include paths may break on non-standard installations
| opts.push_back("-I/usr/local/cuda/include"); | |
| opts.push_back("-I/usr/local/cuda/include/cccl"); | |
| // Use CUDA_HOME environment variable or CMake-detected paths | |
| std::string cuda_home = std::getenv("CUDA_HOME") ? std::getenv("CUDA_HOME") : "/usr/local/cuda"; | |
| opts.push_back(("-I" + cuda_home + "/include").c_str()); | |
| opts.push_back(("-I" + cuda_home + "/include/cccl").c_str()); |
```cpp
std::string alltoallvBarrierKey(const std::string& tag, int64_t rank) {
  return "nvfuser_alltoallv_barrier_" + tag + "_" + std::to_string(rank);
}
```
unused function - alltoallvBarrierKey is defined but never called
```cpp
NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram(
    &prog,
    nvfuser_resources::alltoallv_cu,
    "alltoallv.cu",
```
Why nvrtc? Can't we simply alltoallv_kernel<<<...>>>?
Good question. I used nvrtc here to match nvFuser’s existing runtime-kernel pattern, to the best of my understanding, not because alltoallv_kernel<<<...>>> is impossible.
In our codebase, IIUC, these helper CUDA kernels are treated as runtime resources:
CMakeLists.txt adds csrc/multidevice/alltoallv.cu to NVFUSER_RUNTIME_FILES, then stringifies it into nvfuser_resources/alltoallv.h.
At runtime, we compile/load that source with NVRTC + cuModuleLoadData, same style as other runtime kernels in csrc/runtime/compiled_kernel.cpp, and similarly to the multicast helper in cuda_p2p.cpp.
If you’d prefer the static CUDA-launch route (alltoallv_kernel<<<...>>>), I can switch to that — could you clarify the exact direction you want?
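For reference, the runtime-compilation path described above, NVRTC followed by driver-API module loading, looks roughly like the sketch below. It is not the exact code in compiled_kernel.cpp or cuda_p2p.cpp; error handling is omitted, and the kernel name and options are assumptions.

```cpp
#include <cuda.h>
#include <nvrtc.h>

#include <string>
#include <vector>

// Compile the stringified alltoallv.cu source with NVRTC, then load the
// resulting PTX with the CUDA driver API and return a launchable handle.
CUfunction compileAlltoallvKernel(const char* source) {
  nvrtcProgram prog{};
  nvrtcCreateProgram(&prog, source, "alltoallv.cu", 0, nullptr, nullptr);

  std::vector<const char*> opts = {"--std=c++17"};
  nvrtcCompileProgram(prog, static_cast<int>(opts.size()), opts.data());

  // Retrieve the compiled PTX.
  size_t ptx_size = 0;
  nvrtcGetPTXSize(prog, &ptx_size);
  std::string ptx(ptx_size, '\0');
  nvrtcGetPTX(prog, ptx.data());
  nvrtcDestroyProgram(&prog);

  // Load the module and look up the kernel entry point (name assumed here).
  CUmodule module = nullptr;
  CUfunction kernel = nullptr;
  cuModuleLoadData(&module, ptx.data());
  cuModuleGetFunction(&kernel, module, "alltoallv_kernel");
  return kernel;
}
```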
cc @naoyam because it's out of my expertise.
IIUC, nvfuser_resources ought to contain device functions that are fused into kernels generated on the fly. nvrtc makes sense there because kernels are generated as a string.
When the code is already a kernel (e.g. alltoallv_kernel) in a standalone .cu file, I've yet to see a good reason to nvrtc that.
Ok I see, thanks for the explanation. Actually, the goal is to eventually be able to fuse this kernel with others, so in this regard it could still make sense to keep dynamic compilation.
Sam is right that the nvfuser library has no compiled device code, except for the CUTLASS integration @jacobhinkle worked on. Making everything dynamically compiled has some merits when considering shipping the compiled nvFuser library. For example, you don't need to worry about target specialization of kernel compilation. I doubt this would matter for this case, but for things like matmuls, I'd imagine we would need to compile nvFuser itself for multiple target device versions that could potentially be used by the compiled nvFuser library, whereas if nvrtc is used, it is obviously compiled only once, for the actual device, at run time.
As for the CUTLASS integration, it is virtually impossible to do nvrtc as it just takes too long to compile, so I'd say it's an exception. Please chime in if not, @jacobhinkle.
```cpp
static CUmodule module = nullptr;
static CUfunction kernel = nullptr;
```
static variables without thread safety - multiple threads calling launchAlltoallvKernel concurrently could race on the module == nullptr check during first initialization
```diff
-static CUmodule module = nullptr;
-static CUfunction kernel = nullptr;
+static std::once_flag init_flag;
+static CUmodule module = nullptr;
+static CUfunction kernel = nullptr;
+std::call_once(init_flag, [&]() {
```
Additional Comments (1)
!test
- alltoallv implementation using GPU-initiated comms (SM-driven NVLink), taking only GPU buffers, even for the alltoallv "metadata" such as splitSize. Available through the kCuda backend. Requires the recv buffer to be allocated as symmetric memory.
- Cuda backend for dispatch and combine, which avoids GPU->CPU sync (compared to the NCCL-backed version).
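To illustrate the GPU-initiated data movement this describes, here is a minimal device-side sketch (not the PR's alltoallv.cu): one block per destination rank copies this rank's split directly into that peer's symmetric recv buffer through an IPC-mapped pointer. All parameter names are illustrative.

```cpp
// Hypothetical GPU-initiated alltoallv kernel over symmetric memory.
// Launch with gridDim.x == world_size; block `peer` streams this rank's
// split for `peer` into that peer's recv buffer (an SM-driven NVLink
// write when the destination is a remote GPU).
__global__ void alltoallvSketch(
    const float* send_x,          // elements sorted by destination rank
    float* const* peer_recv_ptrs, // recv buffer of every rank, IPC-mapped
    const int64_t* send_offsets,  // prefix sums of send counts (elements)
    const int64_t* send_counts,   // elements destined for each rank
    const int64_t* dst_offsets) { // where our data lands in each peer's buffer
  const int64_t peer = blockIdx.x;
  const float* src = send_x + send_offsets[peer];
  float* dst = peer_recv_ptrs[peer] + dst_offsets[peer];
  for (int64_t i = threadIdx.x; i < send_counts[peer]; i += blockDim.x) {
    dst[i] = src[i];
  }
}
```

As in the sequence diagram, a barrier is still needed after the kernel so each rank knows its peers' writes have landed before reading its recv buffer.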