Thanks for these efficient MoE kernels — we rely on them in training.
We found two independent latent bugs in `transformer_engine/common/permutation/permutation.cu` and its PyTorch extension caller. Both reproduce on `main` (264da2b), release tag `v2.13` (2877704), and `v2.9`. The buggy code has been unchanged since the kernel was introduced in #936 (Aug 2024).
Environment for repros
```
conda-forge:
  python                    3.13.13
  pytorch                   2.10.0  cuda129_mkl_py313_h623d66f_303
  transformer-engine-torch  2.13    py313hba49f57_1
  cuda-version              12.9
GPU: NVIDIA H20Z
```
Bug 1 — int32 overflow in moe_unpermute_kernel / moe_permute_kernel
`source_token * num_cols` and `source_row * num_cols` are computed in `int`:

```cpp
// moe_unpermute_kernel
const int source_token = blockIdx.x;
int source_row = row_id_map[source_token];
const T *source_row_ptr = input + source_row * num_cols;  // int * int → overflows
```
Once `num_out_tokens * num_cols ≥ 2³¹`, the pointer offset wraps and the kernel reads/writes at a bogus address. We hit this in DeepSeek-V3 MoE training (long-context runs with `hidden = 7168`, `topK = 8`; `num_out_tokens * hidden` crosses the 2³¹ boundary once the per-rank token count grows past a few hundred thousand). The failure mode is either silent corruption (sudden training-loss spikes / NaNs that took us a while to trace back to the kernel) or an outright `CUDA error: an illegal memory access was encountered`.
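The wraparound is easy to model on the host. A minimal pure-Python sketch (the `to_int32` helper is ours, not TE's) showing that the exact byte-row offset used by the repro below wraps to a negative value:

```python
def to_int32(x: int) -> int:
    """Reduce an integer to C's wrapping 32-bit signed semantics."""
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x >= (1 << 31) else x

num_out_tokens = 2**18 + 1   # 262_145 rows after permutation
num_cols = 2**13             # hidden size 8_192

# Element offset of the last row: (num_out_tokens - 1) * num_cols == 2**31
last_row_offset = (num_out_tokens - 1) * num_cols
print(last_row_offset)             # 2147483648 — needs 64 bits
print(to_int32(last_row_offset))   # -2147483648 — what the int kernel computes
```

One row fewer (Case A in the repro) keeps the offset at `2**31 - 2**13`, which still fits in `int32`, so the bug only fires at the boundary.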
Repro
```python
# repro_int_overflow.py (~20 GB GPU memory needed)
import torch
from transformer_engine.pytorch import moe_permute

num_tokens = 2 ** 18 + 1  # 262_145
num_cols = 2 ** 13        # 8_192 → (num_out_tokens - 1) * num_cols = 2**31 exactly
topk, num_experts = 1, 4

x = torch.randn(num_tokens, num_cols, dtype=torch.bfloat16, device="cuda")
indices = (torch.arange(num_tokens, dtype=torch.int32, device="cuda")
           .remainder(num_experts).view(-1, topk))

# Case A: just below 2**31 -> 262144 tokens -> OK, matches reference bit-for-bit
# Case B: at 2**31        -> 262145 tokens -> BUG
te_permuted, _ = moe_permute(x, indices, num_out_tokens=num_tokens * topk, map_type="index")
torch.cuda.synchronize()
```
Observed (case B):
```
RuntimeError: .../transformer_engine/common/permutation/permutation.cu:252 in
function nvte_permute_launcher: CUDA Error: an illegal memory access was encountered
```

Case A (`num_tokens = 2**18`, product just under `2**31`) runs clean, matching a `torch.argsort` + `index_select` reference with diff `0.0`.
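For reference, a NumPy stand-in for the `torch.argsort` + `index_select` check we diffed against (our own helper, not a TE API). For the no-drop, no-sentinel case, a stable sort over the flattened expert indices fully defines the permutation:

```python
import numpy as np

def reference_permute(x: np.ndarray, indices: np.ndarray) -> np.ndarray:
    """Gather token rows in expert order (no dropped tokens, no -1 sentinels)."""
    topk = indices.shape[1]
    flat = indices.reshape(-1)
    order = np.argsort(flat, kind="stable")  # stable: ties keep token order
    return x[order // topk]                  # each (token, slot) pair → its token row

x = np.arange(8, dtype=np.float32).reshape(4, 2)      # 4 tokens, hidden = 2
idx = np.array([[1], [0], [1], [0]], dtype=np.int32)  # topk = 1
# expert 0 gets tokens 1 and 3 first, then expert 1 gets tokens 0 and 2
print(reference_permute(x, idx))
```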
Bug 2 — Incorrect handling of -1 sentinels in routing_map
Libraries such as DeepEP (and any expert-parallel mask that sets non-local `(token, slot)` pairs to `-1`) feed a `routing_map` with `-1` entries. `cub::DeviceRadixSort::SortPairs` sorts signed keys ascending, so those sentinels land at the head of `sorted_row_id`, not the tail.
`moe_permute_row_map`, however, assumes the opposite:
```cpp
if (idx >= num_rows * topK) return;
int source_row = sorted_row_id[idx];
int source_token_id = source_row / topK;
int source_topK_id = source_row % topK;
if (idx >= num_out_tokens) {
  // Set the indices of dropped tokens to -1
  row_id_map[source_topK_id * num_rows + source_token_id] = -1;
} else {
  // Create a row id map for subsequent unpermute operation
  row_id_map[source_topK_id * num_rows + source_token_id] = idx;
}
```
For `idx < num_minus_ones`, the kernel reads `sorted_row_id[idx] == -1`, computes `source_token_id = -1 / topK` and `source_topK_id = -1 % topK` (the sign of `%` on a negative operand was implementation-defined before C++11), and writes a valid `idx` into an unrelated slot of `row_id_map`. Conversely, valid tokens that the sentinels push past `num_out_tokens` have their slots set to `-1`, so tokens that should be kept lose their mapping. The subsequent unpermute then silently mixes sentinel rows into the output.
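The premise is easy to check on the host. A pure-Python model of the sort stage (our own sketch; Python's stable `sorted` stands in for `cub::DeviceRadixSort::SortPairs` on signed keys) shows that the sentinel rows occupy exactly the `idx` range the kernel treats as kept:

```python
# Keys are the flattened routing map; -1 marks a masked (token, slot) pair.
keys = [0, -1, -1, 1, 3, -1]
rows = list(range(len(keys)))                  # row ids paired with the keys
sorted_row_id = sorted(rows, key=lambda r: keys[r])  # signed ascending, stable
print(sorted_row_id)   # [1, 2, 5, 0, 3, 4] — sentinel rows 1, 2, 5 at the head

num_minus_ones = sum(k == -1 for k in keys)    # 3
num_out_tokens = len(keys) - num_minus_ones    # 3 real (token, slot) pairs
# The kernel maps idx 0..num_out_tokens-1 as "kept": here those are the three
# sentinel rows, while the real rows 0, 3, 4 sit at idx >= num_out_tokens
# and get stamped -1 as "dropped".
```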
Expert-parallel context (why -1 appears)
```
4 global experts (E0/E1/E2/E3), 2 ranks.
  Rank 0 holds: E0, E1
  Rank 1 holds: E2, E3

Router assigns (replicated across ranks):
  token_a → [0, 3]
  token_b → [2, 1]
  token_c → [3, 2]

On Rank 0 we mask out non-local experts:
  a → [ 0, -1]   # E3 not on this rank
  b → [-1,  1]   # E2 not on this rank
  c → [-1, -1]   # neither expert local
  num_out_tokens = 2
```
Each rank then calls moe_permute with its masked indices, expecting -1 slots to be skipped.
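The `num_out_tokens` the caller passes is just the count of non-sentinel entries in the rank-local map; a NumPy illustration of the rank-0 numbers above (our own sketch, not DeepEP code):

```python
import numpy as np

# Rank 0's masked routing map from the example above (topk = 2)
local_indices = np.array([[ 0, -1],
                          [-1,  1],
                          [-1, -1]], dtype=np.int32)

# Only non-sentinel (token, slot) pairs survive the local permutation
num_out_tokens = int((local_indices != -1).sum())
print(num_out_tokens)  # 2: token_a's slot 0 and token_b's slot 1
```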
Repro
```python
# repro_minus_one.py (trivial, any GPU)
import torch
from transformer_engine.pytorch import moe_permute

def reference_permute(x, indices, num_out_tokens):
    topk = indices.size(1)
    flat = indices.view(-1)
    sorted_idx = torch.argsort(flat, stable=True)
    sorted_idx = sorted_idx[flat[sorted_idx] != -1]
    assert sorted_idx.numel() == num_out_tokens
    return x.index_select(0, sorted_idx // topk), sorted_idx

torch.manual_seed(0)
N, topk, H = 8, 2, 16
tokens = torch.randn(N, H, dtype=torch.bfloat16, device="cuda")
indices = torch.randint(0, 4, (N, topk), dtype=torch.int32, device="cuda")

# EP-style mask: half the (token, slot) pairs → -1
flat = indices.view(-1)
flat[torch.randperm(flat.numel(), device="cuda")[: flat.numel() // 2]] = -1
num_out_tokens = int((indices != -1).sum())

te_permuted, _ = moe_permute(tokens, indices,
                             num_out_tokens=num_out_tokens, map_type="index")
ref_permuted, _ = reference_permute(tokens, indices, num_out_tokens)
print("max |TE - ref| =", (te_permuted.float() - ref_permuted.float()).abs().max().item())
```
Observed (main / v2.13):
```
indices =
[[ 0 -1]
 [-1  2]
 [-1  2]
 [ 1  1]
 [ 3 -1]
 [ 3 -1]
 [-1  1]
 [-1 -1]]
num_out_tokens = 8
max |TE - ref| = 4.562e+00   # bf16 — full-scale garbage, not noise

TE permuted (first 3 rows):
[[-0.92578125 -0.42578125 -2.640625    0.14550781]
 [-0.12695312  1.21875     1.4375      1.0625    ]   ← wrong row
 [-0.59765625 -0.328125   -0.91015625 -0.8046875 ]]  ← wrong row

ref permuted (first 3 rows):
[[-0.92578125 -0.42578125 -2.640625    0.14550781]
 [-0.18359375  0.3828125   0.39257812 -0.08300781]
 [-0.18359375  0.3828125   0.39257812 -0.08300781]]
```
Expected: the first row (token 0 → expert 0) matches; rows 1 and 2 should both be token 3 (its two surviving slots both route to expert 1, so it appears twice), but TE returns two different arbitrary rows instead.
Related prior report
Issue #1336 ("the max error of moe_permute/unpermute.grad could reach 3.6e+00", Nov 2024) is very likely the same root cause — the reporter observed full-scale errors that grew with topK and noted "This has had some impact on training loss." Neither -1-sentinel nor int32-overflow handling has been touched since that report.
Fix
PR #2907 proposes a minimal fix:

- Widen `source_token`, `source_row`, and `dest_row` to `int64_t` inside the kernels.
- In the caller, advance `sorted_row_id_ptr` past the `num_minus_ones` sentinel prefix, pre-fill `row_id_map` with `-1` via `torch::full`, and switch the launcher grid to `num_out_tokens` blocks.

No public API / dtype changes. Happy-path workloads (no `-1`, offsets within `int32`) are unchanged.
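A host-side sketch of the corrected mapping (our own model, not the PR's code) shows why the caller-side changes suffice: with the map pre-filled with `-1` and the sentinel prefix skipped, masked and dropped slots read back as `-1`, and only real rows receive indices:

```python
def build_row_map_fixed(sorted_row_id, num_minus_ones, num_rows, topk, num_out_tokens):
    """Host model of the fixed mapping: skip sentinel prefix, prefill -1."""
    row_id_map = [-1] * (num_rows * topk)   # torch::full(..., -1) in the PR
    kept = sorted_row_id[num_minus_ones:]   # advance past the -1-keyed rows
    for idx, source_row in enumerate(kept[:num_out_tokens]):  # grid = num_out_tokens
        tok, slot = divmod(source_row, topk)
        row_id_map[slot * num_rows + tok] = idx
    return row_id_map

# Keys [0, -1, -1, 1] over 2 tokens × topk 2: rows 1 and 2 are masked.
# Stable ascending sort by key → sorted_row_id = [1, 2, 0, 3], sentinel prefix 2.
rowmap = build_row_map_fixed([1, 2, 0, 3], num_minus_ones=2,
                             num_rows=2, topk=2, num_out_tokens=2)
print(rowmap)  # → [0, -1, -1, 1]: masked slots stay -1, kept rows map to 0 and 1
```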