
[Bug] moe_permute CUDA kernel: int32 overflow and incorrect -1 sentinel handling #2908

@jing-4369

Description

Thanks for these efficient MoE kernels — we rely on them in training.

Two independent latent bugs in transformer_engine/common/permutation/permutation.cu and its PyTorch extension caller. Both reproduce on main (264da2b), release tag v2.13 (2877704), and v2.9. The buggy code has been unchanged since the kernel was introduced in #936 (Aug 2024).

Environment for repros

conda-forge:
  python              3.13.13
  pytorch             2.10.0   cuda129_mkl_py313_h623d66f_303
  transformer-engine-torch  2.13  py313hba49f57_1
  cuda-version        12.9
GPU: NVIDIA H20Z

Bug 1 — int32 overflow in moe_unpermute_kernel / moe_permute_kernel

source_token * num_cols and source_row * num_cols are computed in int:

// moe_unpermute_kernel
const int source_token = blockIdx.x;
int source_row = row_id_map[source_token];
const T *source_row_ptr = input + source_row * num_cols;   // int * int → overflows

Once num_out_tokens * num_cols ≥ 2³¹, the pointer offset wraps and the kernel reads/writes at a bogus address. We hit this on DeepSeek-V3 MoE training (long-context runs with hidden = 7168, topK = 8; num_out_tokens * hidden crosses the 2³¹ boundary once per-rank token count grows past a few hundred thousand). The failure mode is either silent corruption — sudden training-loss spikes / NaNs that took us a while to trace back to the kernel — or an outright CUDA error: an illegal memory access was encountered.
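The wraparound is easy to model in plain Python (a sketch of the 32-bit arithmetic the kernel performs; the numbers are taken from the repro below):

```python
def int32(x):
    # model C two's-complement int32 wraparound
    x &= 0xFFFFFFFF
    return x - 0x100000000 if x >= 0x80000000 else x

num_cols = 8192        # hidden size from the repro
source_row = 262144    # last permuted row when num_tokens = 2**18 + 1, topk = 1

assert source_row * num_cols == 2**31     # true offset, one past INT32_MAX
assert int32(source_row * num_cols) == -2**31   # the int32 product wraps negative
# widening either operand to 64-bit before the multiply avoids the wrap
```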

Repro

# repro_int_overflow.py  (~20 GB GPU memory needed)
import torch
from transformer_engine.pytorch import moe_permute

num_tokens = 2 ** 18 + 1   # 262_145
num_cols   = 2 ** 13       # 8_192  →  (num_out_tokens - 1) * num_cols = 2**31 exactly
topk, num_experts = 1, 4

x       = torch.randn(num_tokens, num_cols, dtype=torch.bfloat16, device="cuda")
indices = (torch.arange(num_tokens, dtype=torch.int32, device="cuda")
           .remainder(num_experts).view(-1, topk))

# Case A: just below 2**31  -> 262144 tokens  -> OK, matches reference bit-for-bit
# Case B: at 2**31          -> 262145 tokens  -> BUG
te_permuted, _ = moe_permute(x, indices, num_out_tokens=num_tokens * topk, map_type="index")
torch.cuda.synchronize()

Observed (case B):

RuntimeError: .../transformer_engine/common/permutation/permutation.cu:252 in
function nvte_permute_launcher: CUDA Error: an illegal memory access was encountered

Case A (num_tokens = 2**18, product just under 2**31) runs clean with diff 0.0 against torch.argsort + index_select.


Bug 2 — Incorrect handling of -1 sentinels in routing_map

Libraries such as DeepEP (and any expert-parallel mask that sets non-local (token, slot) pairs to -1) feed a routing_map with -1 entries. cub::DeviceRadixSort::SortPairs is signed ascending, so those sentinels land at the head of sorted_row_id, not the tail.
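That the sentinels land at the head is easy to check with any stable signed-ascending sort (plain-Python sketch):

```python
# flattened routing_map slots for one rank; -1 marks non-local (token, slot) pairs
flat = [0, -1, 2, -1, 1, -1]
sorted_row_id = sorted(range(len(flat)), key=lambda i: flat[i])
# the sentinels occupy the first num_minus_ones positions, not the tail
assert [flat[i] for i in sorted_row_id[:3]] == [-1, -1, -1]
```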

moe_permute_row_map, however, assumes the opposite:

if (idx >= num_rows * topK) return;

int source_row = sorted_row_id[idx];
int source_token_id = source_row / topK;
int source_topK_id = source_row % topK;

if (idx >= num_out_tokens) {
  // Set the indices of dropped tokens to -1
  row_id_map[source_topK_id * num_rows + source_token_id] = -1;
} else {
  // Create a row id map for subsequent unpermute operation
  row_id_map[source_topK_id * num_rows + source_token_id] = idx;
}

For idx < num_minus_ones, the kernel reads sorted_row_id[idx] == -1 and computes source_token_id = -1 / topK == 0 and source_topK_id = -1 % topK == -1 (well-defined since C++11: truncation toward zero), so it stores a valid idx through a negative offset (-num_rows) outside row_id_map. Conversely, the last num_minus_ones valid entries fall past num_out_tokens and are wrongly marked dropped, so tokens that should be kept lose their mapping. The subsequent unpermute then silently mixes sentinel rows into the output.
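Where the sentinel store lands can be checked by modeling C++/CUDA truncating division in Python (sketch):

```python
def c_div(a, b):
    # C++11 / CUDA integer division truncates toward zero
    q = abs(a) // abs(b)
    return -q if (a < 0) != (b < 0) else q

def c_mod(a, b):
    return a - c_div(a, b) * b

topK, num_rows = 2, 8
source_row = -1                              # sentinel read from sorted_row_id
source_token_id = c_div(source_row, topK)    # 0
source_topK_id = c_mod(source_row, topK)     # -1
# the write lands at a negative offset, outside row_id_map
assert source_topK_id * num_rows + source_token_id == -num_rows
```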

Expert-parallel context (why -1 appears)

4 global experts (E0/E1/E2/E3), 2 ranks.
  Rank 0 holds: E0, E1
  Rank 1 holds: E2, E3
Router assigns (replicated across ranks):
  token_a → [0, 3]
  token_b → [2, 1]
  token_c → [3, 2]

On Rank 0 we mask out non-local experts:
  a → [ 0, -1]    # E3 not on this rank
  b → [-1,  1]
  c → [-1, -1]
  num_out_tokens = 2

Each rank then calls moe_permute with its masked indices, expecting -1 slots to be skipped.
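For the Rank 0 example above, the expected drop semantics can be written out directly (plain-Python sketch of the reference behavior):

```python
# masked indices for Rank 0: token_a, token_b, token_c
indices = [[0, -1], [-1, 1], [-1, -1]]
topK = 2
flat = [e for row in indices for e in row]
# stable signed-ascending sort, then drop the -1 sentinel prefix
order = sorted(range(len(flat)), key=lambda i: flat[i])
kept = [i for i in order if flat[i] != -1]
assert len(kept) == 2                        # num_out_tokens
assert [i // topK for i in kept] == [0, 1]   # token_a (E0), then token_b (E1)
# token_c contributes nothing: both of its slots are masked
```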

Repro

# repro_minus_one.py  (trivial, any GPU)
import torch
from transformer_engine.pytorch import moe_permute

def reference_permute(x, indices, num_out_tokens):
    topk = indices.size(1)
    flat = indices.view(-1)
    sorted_idx = torch.argsort(flat, stable=True)
    sorted_idx = sorted_idx[flat[sorted_idx] != -1]
    assert sorted_idx.numel() == num_out_tokens
    return x.index_select(0, sorted_idx // topk), sorted_idx

torch.manual_seed(0)
N, topk, H = 8, 2, 16
tokens  = torch.randn(N, H, dtype=torch.bfloat16, device="cuda")
indices = torch.randint(0, 4, (N, topk), dtype=torch.int32, device="cuda")
# EP-style mask: half the (token, slot) pairs → -1
flat = indices.view(-1)
flat[torch.randperm(flat.numel(), device="cuda")[: flat.numel() // 2]] = -1
num_out_tokens = int((indices != -1).sum())

te_permuted, _  = moe_permute(tokens, indices,
                              num_out_tokens=num_out_tokens, map_type="index")
ref_permuted, _ = reference_permute(tokens, indices, num_out_tokens)

print("max |TE - ref| =", (te_permuted.float() - ref_permuted.float()).abs().max().item())

Observed (main / v2.13):

indices =
[[ 0 -1]
 [-1  2]
 [-1  2]
 [ 1  1]
 [ 3 -1]
 [ 3 -1]
 [-1  1]
 [-1 -1]]
num_out_tokens = 8

max |TE - ref| = 4.562e+00          # bf16 — full-scale garbage, not noise
TE permuted (first 3 rows):
[[-0.92578125 -0.42578125 -2.640625    0.14550781]
 [-0.12695312  1.21875     1.4375      1.0625    ]    ← wrong row
 [-0.59765625 -0.328125   -0.91015625 -0.8046875 ]]   ← wrong row
ref permuted (first 3 rows):
[[-0.92578125 -0.42578125 -2.640625    0.14550781]
 [-0.18359375  0.3828125   0.39257812 -0.08300781]
 [-0.18359375  0.3828125   0.39257812 -0.08300781]]

Expected: the first row (token 0 → expert 0) matches; rows 1 and 2 should both be token 3, which is routed to expert 1 in both of its slots and is therefore duplicated in the output, but TE returns two different arbitrary rows instead.


Related prior report

Issue #1336 ("the max error of moe_permute/unpermute.grad could reach 3.6e+00", Nov 2024) is very likely the same root cause — the reporter observed full-scale errors that grew with topK and noted "This has had some impact on training loss." Neither -1-sentinel nor int32-overflow handling has been touched since that report.

Fix

PR #2907 proposes a minimal fix:

  1. Widen source_token, source_row, dest_row to int64_t inside the kernels.
  2. In the caller, advance sorted_row_id_ptr past the num_minus_ones sentinel prefix and pre-fill row_id_map with -1 via torch::full; switch the launcher grid to num_out_tokens blocks.

No public API / dtype changes. Happy-path workloads (no -1, offset within int32) are unchanged.
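A Python model of the corrected row-map construction (a sketch of the PR's approach, not the actual CUDA code; names are illustrative), replayed on the Rank 0 example:

```python
def build_row_id_map(sorted_row_id, num_rows, topK, num_out_tokens, num_minus_ones):
    # pre-fill with -1 so dropped / non-local slots need no explicit write
    row_id_map = [-1] * (num_rows * topK)
    # advance past the sentinel prefix; launch only num_out_tokens "blocks"
    for idx in range(num_out_tokens):
        source_row = sorted_row_id[num_minus_ones + idx]
        row_id_map[(source_row % topK) * num_rows + source_row // topK] = idx
    return row_id_map

# Rank 0 example: flat = [0, -1, -1, 1, -1, -1], sorted signed-ascending
m = build_row_id_map([1, 2, 4, 5, 0, 3], num_rows=3, topK=2,
                     num_out_tokens=2, num_minus_ones=4)
assert m == [0, -1, -1, -1, 1, -1]   # token_a slot 0 → row 0, token_b slot 1 → row 1
```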
