From ce3477410a3b5fd3435ffa617aa9dd01a4d01771 Mon Sep 17 00:00:00 2001 From: Jingyi Xi Date: Tue, 21 Apr 2026 15:57:43 +0800 Subject: [PATCH 1/3] [Common][PyTorch] Fix int32 overflow and -1 sentinel handling in moe_permute Two independent bugs in transformer_engine/common/permutation/permutation.cu and the PyTorch extension caller reproduce on main (264da2b) and v2.13: 1. int32 overflow in moe_unpermute_kernel and moe_permute_kernel. `source_token * num_cols` and `source_row * num_cols` are computed with int, so for long-sequence MoE workloads where num_out_tokens * num_cols reaches 2**31 (e.g. 2**18 tokens x 2**13 hidden), the pointer offset wraps and the kernel either reads garbage or raises `an illegal memory access was encountered`. Widening source_token, source_row and dest_row to int64_t inside the kernels keeps the index arithmetic in 64 bits without changing any public types. 2. Incorrect handling of -1 sentinels in the routing indices. Libraries such as DeepEP (and any expert-parallel mask that sets non-local (token, slot) pairs to -1) feed a routing_map that contains -1 entries. `cub::DeviceRadixSort::SortPairs` is signed ascending, so those sentinels land at the HEAD of sorted_row_id, not the tail. moe_permute_row_map currently writes -1 only for idx >= num_out_tokens and reads the sentinel prefix as if it were a valid sorted id, producing bogus row_id_map writes (for instance `source_row / topK == 0, source_row % topK == -1`). The caller now advances sorted_row_id_ptr past the num_minus_ones prefix and pre-fills row_id_map with -1 via torch::full, so the kernel only processes the valid suffix and never dereferences a sentinel. The launcher's grid switches from num_rows*topK blocks to num_out_tokens blocks to match the new valid range. No behaviour change on happy-path routing_map (no -1, no overflow). Reproducers: - 8-token, topK=2 routing_map with -1 masking: max |TE - ref| = 4.5e0 on bf16 with current main; 0.0 with this patch. 
- num_tokens=2**18+1, num_cols=2**13, topK=1: current main raises CUDA illegal memory access at permutation.cu:252; with this patch it succeeds. Signed-off-by: Jingyi Xi --- .../common/permutation/permutation.cu | 21 +++++++------------ .../pytorch/csrc/extensions/permutation.cpp | 14 +++++++++---- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index fbba27941c..d0a8f44f7c 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -19,19 +19,12 @@ static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id const int tid = threadIdx.x; const int idx = bid * blockDim.x + tid; - if (idx >= num_rows * topK) return; + if (idx >= num_out_tokens) return; int source_row = sorted_row_id[idx]; int source_token_id = source_row / topK; int source_topK_id = source_row % topK; - - if (idx >= num_out_tokens) { - // Set the indices of dropped tokens to -1 - row_id_map[source_topK_id * num_rows + source_token_id] = -1; - } else { - // Create a row id map for subsequent unpermute operation - row_id_map[source_topK_id * num_rows + source_token_id] = idx; - } + row_id_map[source_topK_id * num_rows + source_token_id] = idx; } template @@ -42,7 +35,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const TCompute *s_prob = reinterpret_cast(s_mem); // Each block corresponds to one dest token - const int source_token = blockIdx.x; + const int64_t source_token = blockIdx.x; const int tid = threadIdx.x; if (hasProb) { @@ -65,7 +58,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const TCompute frag_elem[kElementsPerAccess]; TCompute frag_sum[kElementsPerAccess]; - int source_row = row_id_map[source_token]; + int64_t source_row = row_id_map[source_token]; // source_row == -1 represents a dropped token if (source_row != 
-1) { @@ -134,7 +127,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac TCompute *s_prob = reinterpret_cast<TCompute *>(s_mem); // Each block corresponds to one source token - const int source_token = blockIdx.x; + const int64_t source_token = blockIdx.x; const int tid = threadIdx.x; if (hasProb) { @@ -172,7 +165,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac for (int k = 0; k < topKTile; k++) { if (k == topK) break; - int dest_row = row_id_map[index]; + int64_t dest_row = row_id_map[index]; index += num_rows; if (dest_row != -1) { @@ -239,7 +232,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, // moe_permute_fwd int threads = 64; - int blocks = (num_rows * topK + threads - 1) / threads; + int blocks = (num_out_tokens + threads - 1) / threads; moe_permute_row_map<<<blocks, threads, 0, stream>>>(sorted_row_id, row_id_map, num_rows, topK, num_out_tokens); diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cpp b/transformer_engine/pytorch/csrc/extensions/permutation.cpp index 226705b169..6505dba1ee 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cpp @@ -51,13 +51,19 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd( reinterpret_cast<int *>(sorted_indices_ptr), reinterpret_cast<int *>(row_id_ptr), reinterpret_cast<int *>(sorted_row_id_ptr), num_tokens * topK); - // Output buffer alloc + // Signed radix sort places -1 sentinel entries (e.g. expert-parallel rank mask) + // at the HEAD of sorted_row_id. Skip that prefix so the kernel sees only the + // valid suffix, and pre-fill row_id_map with -1 so the dropped slots are marked + // without the kernel ever dereferencing a sentinel. num_out_tokens = (num_out_tokens > 0) ? 
num_out_tokens : num_tokens * topK; + const int num_minus_ones = num_tokens * topK - num_out_tokens; + sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + num_minus_ones * sizeof(int); at::Tensor permuted_output = torch::empty({num_out_tokens, num_cols}, torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false)); - at::Tensor row_id_map = torch::empty( - {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + at::Tensor row_id_map = torch::full( + {num_tokens * topK}, -1, + torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -71,7 +77,7 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd( static_cast<size_t>(num_cols)}, dtype); auto sorted_row_id_cu = makeTransformerEngineTensor( - sorted_row_id_ptr, std::vector<size_t>{static_cast<size_t>(num_tokens * topK)}, + sorted_row_id_ptr, std::vector<size_t>{static_cast<size_t>(num_out_tokens)}, DType::kInt32); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); From b42546b045af06acc436c081facca36610e6dbc0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Apr 2026 08:00:21 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/csrc/extensions/permutation.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cpp b/transformer_engine/pytorch/csrc/extensions/permutation.cpp index 6505dba1ee..65c5f1cf4a 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cpp @@ -61,9 +61,9 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd( at::Tensor permuted_output = torch::empty({num_out_tokens, num_cols}, torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false)); - at::Tensor row_id_map = torch::full( - {num_tokens * topK}, -1, - 
torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + at::Tensor row_id_map = + torch::full({num_tokens * topK}, -1, + torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -77,8 +77,7 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd( static_cast<size_t>(num_cols)}, dtype); auto sorted_row_id_cu = makeTransformerEngineTensor( - sorted_row_id_ptr, std::vector<size_t>{static_cast<size_t>(num_out_tokens)}, - DType::kInt32); + sorted_row_id_ptr, std::vector<size_t>{static_cast<size_t>(num_out_tokens)}, DType::kInt32); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(), From b73a1f9df5cd33c777566266c5d714b0554ca9ea Mon Sep 17 00:00:00 2001 From: Jingyi Xi Date: Tue, 21 Apr 2026 16:07:32 +0800 Subject: [PATCH 3/3] Guard against invalid num_out_tokens in moe_permute_fwd Add an NVTE_CHECK that num_out_tokens <= num_tokens * topK and cast num_minus_ones to size_t before the pointer advance, so a negative num_minus_ones (from an invalid num_out_tokens) cannot silently wrap into a huge pointer offset. Signed-off-by: Jingyi Xi --- transformer_engine/pytorch/csrc/extensions/permutation.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cpp b/transformer_engine/pytorch/csrc/extensions/permutation.cpp index 65c5f1cf4a..1b3ba9cef9 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cpp @@ -56,8 +56,11 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd( // valid suffix, and pre-fill row_id_map with -1 so the dropped slots are marked // without the kernel ever dereferencing a sentinel. num_out_tokens = (num_out_tokens > 0) ? 
num_out_tokens : num_tokens * topK; + NVTE_CHECK(num_out_tokens <= num_tokens * topK, "num_out_tokens (", num_out_tokens, ") must not exceed num_tokens*topK (", num_tokens * topK, ")"); const int num_minus_ones = num_tokens * topK - num_out_tokens; - sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + num_minus_ones * sizeof(int); + sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + + static_cast<size_t>(num_minus_ones) * sizeof(int); at::Tensor permuted_output = torch::empty({num_out_tokens, num_cols}, torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false));