Expert Parallelism: common C API + NCCL EP backend #3034
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Greptile Summary
This PR introduces the Expert Parallelism (EP) foundation for TransformerEngine: a public C API (
Confidence Score: 3/5
The new EP layer is not yet wired to any Python framework, so no production training path is affected today, but the handle-memory sizing inconsistency means the first framework integration could allocate an undersized buffer and corrupt memory inside ncclEpInitHandle.
Two issues in ep_backend.cpp warrant attention before the next framework PR lands: register_layer uses a manually zero-initialised ncclEpHandleConfig_t instead of the NCCL_EP_HANDLE_CONFIG_INIT macro used everywhere else, almost certainly omitting the version field and making ncclEpHandleMemSize return a buffer size that does not match what ncclEpInitHandle will actually write; and validate_config skips the max_recv_tokens_per_rank > 0 check that an inline comment already flags as mandatory.
transformer_engine/common/ep/ep_backend.cpp (handle config init and missing validation) and setup.py (arch detection and NCCL_HOME handling)
Sequence Diagram
sequenceDiagram
participant Caller
participant C_API as nvte_ep_* (ep_api.cpp)
participant Backend as EPBackend singleton
participant NCCL_EP as libnccl_ep.so
Caller->>C_API: nvte_ep_initialize(ep_comm, group_config)
C_API->>Backend: EPBackend::initialize(ncclComm_t, config)
Backend->>Backend: validate_config
Backend->>NCCL_EP: ncclEpCreateGroup
Caller->>C_API: nvte_ep_register_layer(layer_config, handle_mem_size)
C_API->>Backend: register_layer
Backend->>NCCL_EP: ncclEpHandleMemSize
Backend-->>Caller: handle_id
loop Per training step
Caller->>C_API: nvte_ep_prepare(handle, topk_idx, stream)
C_API->>Backend: prepare
Backend->>NCCL_EP: ncclEpInitHandle + ncclEpUpdateHandle
Caller->>C_API: nvte_ep_dispatch(handle, tokens, stream)
C_API->>Backend: dispatch
Backend->>NCCL_EP: ncclEpInitHandle + ncclEpDispatch
Caller->>C_API: nvte_ep_combine(handle, expert_out, stream)
C_API->>Backend: combine
Backend->>NCCL_EP: ncclEpInitHandle + ncclEpCombine
end
Caller->>C_API: nvte_ep_shutdown()
C_API->>Backend: EPBackend::shutdown()
Backend->>NCCL_EP: ncclEpGroupDestroy
Reviews (1): Last reviewed commit: "Expert Parallelism: common C API + NCCL ..."
    ncclEpHandleConfig_t hcfg{};
    hcfg.size = static_cast<unsigned int>(sizeof(hcfg));
    hcfg.dispatch_output_per_expert_alignment = layer_config.dispatch_output_per_expert_alignment;
    size_t hm_size = 0;
    NVTE_CHECK_NCCL(ncclEpHandleMemSize(ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, &hm_size,
                                        layer_config.top_k));
Inconsistent ncclEpHandleConfig_t initialization may produce wrong buffer size
register_layer initializes the config with {} and manually sets only hcfg.size, while open_handle uses NCCL_EP_HANDLE_CONFIG_INIT (which also sets version and likely other fields to their expected defaults). ncclEpHandleMemSize and ncclEpInitHandle can disagree on the required buffer size when the version field is 0 instead of NCCL_EP_API_VERSION, causing the caller to allocate an undersized handle_mem buffer and leading to an out-of-bounds write in ncclEpInitHandle.
    NVTE_CHECK(config.max_num_sms >= 0, "max_num_sms must be >= 0 (0 = auto), got ",
               config.max_num_sms);
max_recv_tokens_per_rank is not validated in validate_config
The comment at line 243 explicitly notes "Must be > 0; NCCL EP errors out on 0", but validate_config never enforces this. A zero value would cause ncclEpCreateGroup to fail with a cryptic NCCL internal error instead of the clear TE diagnostic that all other config fields get.
Suggested change:
-   NVTE_CHECK(config.max_num_sms >= 0, "max_num_sms must be >= 0 (0 = auto), got ",
-              config.max_num_sms);
+   NVTE_CHECK(config.max_num_sms >= 0, "max_num_sms must be >= 0 (0 = auto), got ",
+              config.max_num_sms);
+   NVTE_CHECK(config.max_recv_tokens_per_rank > 0,
+              "max_recv_tokens_per_rank must be positive, got ",
+              config.max_recv_tokens_per_rank);
    env_home = os.environ.get("NCCL_HOME")
    if env_home and (Path(env_home) / "include" / "nccl.h").exists():
        return env_home
NCCL_HOME set to a wrong path is silently ignored
If a user sets NCCL_HOME to an incorrect prefix that doesn't contain include/nccl.h, the function falls through to the system probe list without any warning. The function should warn when NCCL_HOME is set but doesn't resolve to a valid NCCL install.
Suggested change:
-   env_home = os.environ.get("NCCL_HOME")
-   if env_home and (Path(env_home) / "include" / "nccl.h").exists():
-       return env_home
+   env_home = os.environ.get("NCCL_HOME")
+   if env_home:
+       if (Path(env_home) / "include" / "nccl.h").exists():
+           return env_home
+       print(
+           f"[NCCL EP] WARNING: NCCL_HOME='{env_home}' is set but "
+           f"'{env_home}/include/nccl.h' was not found; falling back to system probes."
+       )
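For reference, the resolution order with the suggested warning can be sketched as a standalone helper. This is a minimal sketch, not the code in setup.py: the names `resolve_nccl_home`, `has_nccl_header`, and `SYSTEM_PREFIXES` are hypothetical, the probe list is taken from the PR description, `warnings.warn` stands in for the `print` in the suggestion, and the final `ldconfig -p` fallback is left out.

```python
import os
import warnings
from pathlib import Path

# Probe prefixes named in the PR description; the constant name is hypothetical.
SYSTEM_PREFIXES = ("/opt/nvidia/nccl", "/usr/local/nccl", "/usr")


def has_nccl_header(prefix):
    """Return True if `prefix` contains include/nccl.h."""
    return (Path(prefix) / "include" / "nccl.h").exists()


def resolve_nccl_home(env=None, prefixes=SYSTEM_PREFIXES):
    """Resolve an NCCL prefix: explicit NCCL_HOME first, then the probe list.

    Returns None when nothing matches; the `ldconfig -p` fallback mentioned
    in the PR description is intentionally omitted from this sketch.
    """
    env = os.environ if env is None else env
    env_home = env.get("NCCL_HOME")
    if env_home:
        if has_nccl_header(env_home):
            return env_home
        # Assumption: warn instead of silently ignoring a bad NCCL_HOME.
        warnings.warn(
            f"NCCL_HOME='{env_home}' is set but '{env_home}/include/nccl.h' "
            "was not found; falling back to system probes."
        )
    for prefix in prefixes:
        if has_nccl_header(prefix):
            return prefix
    return None
```

Passing `env` and `prefixes` as parameters keeps the helper testable without touching the real environment or filesystem layout.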
    has_hopper_or_newer = any(
        int(a.strip().rstrip("af")) >= 90
        for a in str(archs or "").split(";")
        if a.strip().rstrip("af").isdigit()
    )
NVTE_CUDA_ARCHS=native silently disables NCCL EP on valid Hopper hardware
The arch parsing rejects any token that is not isdigit() after stripping a/f suffixes. The CMake keyword "native" is silently skipped, so has_hopper_or_newer stays False and NCCL EP is auto-disabled even on a Hopper machine.
Suggested change:
-   has_hopper_or_newer = any(
-       int(a.strip().rstrip("af")) >= 90
-       for a in str(archs or "").split(";")
-       if a.strip().rstrip("af").isdigit()
-   )
+   arch_tokens = [a.strip() for a in str(archs or "").split(";") if a.strip()]
+   has_native = any(t.lower() == "native" for t in arch_tokens)
+   has_hopper_or_newer = has_native or any(
+       int(t.rstrip("af")) >= 90
+       for t in arch_tokens
+       if t.rstrip("af").isdigit()
+   )
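To make the behavior easier to check, the arch parsing and the build decision it feeds can be sketched as small pure functions. This is an illustrative sketch under assumptions, not the code in setup.py: the function names are hypothetical, `explicit` models the value of NVTE_BUILD_WITH_NCCL_EP (None when unset), and `native` is optimistically treated as eligible, as in the suggestion, without probing the actual GPU.

```python
def parse_arch(token):
    """Return the numeric part of a CUDA arch token ("90a" -> 90), or None."""
    digits = token.strip().rstrip("af")
    return int(digits) if digits.isdigit() else None


def targets_hopper_or_newer(archs):
    """True if any requested arch is >= 90 or the list contains 'native'."""
    tokens = [t.strip() for t in str(archs or "").split(";") if t.strip()]
    if any(t.lower() == "native" for t in tokens):
        return True  # assumption: 'native' is not probed, just accepted
    for token in tokens:
        value = parse_arch(token)
        if value is not None and value >= 90:
            return True
    return False


def should_build_nccl_ep(archs, explicit=None):
    """Decide whether to build NCCL EP.

    explicit: None (unset), "0" (opt out), or "1" (force on).
    Forcing the build on with every arch < 90 is reported as a user error.
    """
    supported = targets_hopper_or_newer(archs)
    if explicit == "0":
        return False
    if explicit == "1":
        if not supported:
            raise ValueError(
                "NVTE_BUILD_WITH_NCCL_EP=1 requires at least one CUDA arch >= 90"
            )
        return True
    return supported  # auto mode: follow the requested archs
```

For example, `should_build_nccl_ep("80;90a")` enables the build, `should_build_nccl_ep("70;80")` auto-disables it, and `should_build_nccl_ep("80", explicit="1")` raises.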
Summary
First PR in the TE Expert Parallelism (EP) series. Lands the common C API and NCCL EP backend that later framework PRs (PyTorch, JAX) build on. No Python bindings yet — common-lib foundation plus build wiring only. Build/load works on any arch; SM and NCCL version gates fire at runtime.
Every network-bound payload tensor takes an optional NVTECommWindow. When the window is provided, the backend uses NCCL EP's symmetric-memory zero-copy path, which skips the D2D memcpy from the user buffers to the symmetric staging buffers.
Implementation
Public C API (transformer_engine/common/include/transformer_engine/{ep.h,comm_window.h})
Types: NVTEEpGroupConfig, NVTEEpLayerConfig, NVTEEpHandle, NVTECommWindow (side-band {ncclWindow_t window, size_t offset}; NCCL peer handles are not carried on NVTETensor).
Lifecycle (host-only, eager):
nvte_ep_initialize: borrow an external ncclComm_t for the EP sub-group and init the singleton backend.
nvte_ep_shutdown: tear down the backend; idempotent; does not destroy ep_comm.
nvte_ep_register_layer: reserve a handle_id for a layer config and report the handle_mem buffer size the caller must allocate. The pair {id, mem} becomes the per-step NVTEEpHandle.
Per-step (allocation-free, CUDA-graph capturable):
nvte_ep_prepare: all-gather the routing map and write routing maps to handle.mem.
nvte_ep_dispatch: scatter tokens and routing weights from source ranks to expert ranks. tokens, topk_weights, recv_tokens, recv_topk_weights each accept an optional symm-mem window for zero-copy.
nvte_ep_combine: scatter-sum expert outputs back to source ranks (unweighted; caller pre-multiplies by recv_topk_weights). expert_out accepts a window.
nvte_ep_dispatch_bwd: backward of dispatch; routes token and weight grads back to source. grad and g_recv_topk_weights accept windows; the gathered outputs (grad_tokens, grad_topk_weights).
nvte_ep_combine_bwd: backward of combine; grad and grad_expert_out accept windows. Padded slots in grad_expert_out are zeroed.
Backend + build
Backend (transformer_engine/common/ep/): EPBackend singleton, HT-mode dispatch/combine over NCCL EP (libnccl_ep.so), group/layer registration. Internal helper make_payload_tensor() builds the per-call ncclEpTensor_t: when the caller's NVTECommWindow.window != nullptr it sets win_hdl + win_offset (zero-copy); otherwise it sets data from nvte_tensor_data(t) (HBM fallback).
Runtime gates (EPBackend::initialize): SM >= 90 (via cudaDeviceGetAttribute), NCCL >= 2.30.4 (via ncclGetVersion), CUDA multicast/NVLS support.
With NVTE_WITH_NCCL_EP=OFF, ep/ep_api_stub.cpp provides throwing nvte_ep_* stubs so framework bindings link unconditionally; failure surfaces at first nvte_ep_initialize.
setup.py builds libnccl_ep.so from 3rdparty/nccl by default; auto-disables NCCL EP when no requested CUDA arch >= 90. Explicit NVTE_BUILD_WITH_NCCL_EP=1 with all archs < 90 is treated as user error; set NVTE_BUILD_WITH_NCCL_EP=0 to opt out.
NCCL_HOME resolved dynamically: explicit env → /opt/nvidia/nccl, /usr/local/nccl, /usr → ldconfig -p fallback.
Testing
tests/cpp_distributed/.