diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index fbba27941c..d0a8f44f7c 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -19,19 +19,12 @@ static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id const int tid = threadIdx.x; const int idx = bid * blockDim.x + tid; - if (idx >= num_rows * topK) return; + if (idx >= num_out_tokens) return; int source_row = sorted_row_id[idx]; int source_token_id = source_row / topK; int source_topK_id = source_row % topK; - - if (idx >= num_out_tokens) { - // Set the indices of dropped tokens to -1 - row_id_map[source_topK_id * num_rows + source_token_id] = -1; - } else { - // Create a row id map for subsequent unpermute operation - row_id_map[source_topK_id * num_rows + source_token_id] = idx; - } + row_id_map[source_topK_id * num_rows + source_token_id] = idx; } template @@ -42,7 +35,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const TCompute *s_prob = reinterpret_cast(s_mem); // Each block corresponds to one dest token - const int source_token = blockIdx.x; + const int64_t source_token = blockIdx.x; const int tid = threadIdx.x; if (hasProb) { @@ -65,7 +58,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const TCompute frag_elem[kElementsPerAccess]; TCompute frag_sum[kElementsPerAccess]; - int source_row = row_id_map[source_token]; + int64_t source_row = row_id_map[source_token]; // source_row == -1 represents a dropped token if (source_row != -1) { @@ -134,7 +127,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac TCompute *s_prob = reinterpret_cast(s_mem); // Each block corresponds to one source token - const int source_token = blockIdx.x; + const int64_t source_token = blockIdx.x; const int tid = threadIdx.x; if (hasProb) { @@ -172,7 +165,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac for (int k = 0; k < topKTile; k++) { if (k == topK) break; - int dest_row = row_id_map[index]; + int64_t dest_row = row_id_map[index]; index += num_rows; if (dest_row != -1) { @@ -239,7 +232,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, // moe_permute_fwd int threads = 64; - int blocks = (num_rows * topK + threads - 1) / threads; + int blocks = (num_out_tokens + threads - 1) / threads; moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, topK, num_out_tokens); diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cpp b/transformer_engine/pytorch/csrc/extensions/permutation.cpp index 226705b169..1b3ba9cef9 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cpp @@ -51,13 +51,22 @@ std::tuple> moe_permute_fwd( reinterpret_cast(sorted_indices_ptr), reinterpret_cast(row_id_ptr), reinterpret_cast(sorted_row_id_ptr), num_tokens * topK); - // Output buffer alloc + // Signed radix sort places -1 sentinel entries (e.g. expert-parallel rank mask) + // at the HEAD of sorted_row_id. Skip that prefix so the kernel sees only the + // valid suffix, and pre-fill row_id_map with -1 so the dropped slots are marked + // without the kernel ever dereferencing a sentinel. num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK; + NVTE_CHECK(num_out_tokens <= num_tokens * topK, "num_out_tokens (", num_out_tokens, + ") must not exceed num_tokens*topK (", num_tokens * topK, ")"); + const int num_minus_ones = num_tokens * topK - num_out_tokens; + sorted_row_id_ptr = reinterpret_cast(sorted_row_id_ptr) + + static_cast(num_minus_ones) * sizeof(int); at::Tensor permuted_output = torch::empty({num_out_tokens, num_cols}, torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false)); - at::Tensor row_id_map = torch::empty( - {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + at::Tensor row_id_map = + torch::full({num_tokens * topK}, -1, + torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -71,8 +80,7 @@ std::tuple> moe_permute_fwd( static_cast(num_cols)}, dtype); auto sorted_row_id_cu = makeTransformerEngineTensor( - sorted_row_id_ptr, std::vector{static_cast(num_tokens * topK)}, - DType::kInt32); + sorted_row_id_ptr, std::vector{static_cast(num_out_tokens)}, DType::kInt32); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(),