From 0448187486cb56234d50b2a9cc7b6832eb0d2495 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 24 Nov 2025 12:13:38 -0800 Subject: [PATCH 1/2] Change order of arguments to make jax work Signed-off-by: tdophung --- .../common/triton/permutation.py | 42 ++++++++++--------- .../pytorch/triton/permutation.py | 26 ++++++------ 2 files changed, 35 insertions(+), 33 deletions(-) diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py index 3a3a32014f..19d6b81ae2 100644 --- a/transformer_engine/common/triton/permutation.py +++ b/transformer_engine/common/triton/permutation.py @@ -81,10 +81,8 @@ def _argsort(x, indices, n_dims: tl.constexpr): @triton.jit def _row_id_map_pass_1_kernel( - # pointers + # input pointers routing_map_ptr, - row_id_map_ptr, - workspace_ptr, # sizes num_tokens, # strides @@ -92,6 +90,9 @@ def _row_id_map_pass_1_kernel( stride_routing_map_expert, stride_row_id_map_token, stride_row_id_map_expert, + # output pointers + row_id_map_ptr, + workspace_ptr, # metas BLOCK_SIZE: tl.constexpr, ): @@ -156,7 +157,7 @@ def _row_id_map_pass_3_kernel( # pointers row_id_map_ptr, # sizes - num_experts: tl.constexpr, + num_experts, # strides stride_row_id_map_token, stride_row_id_map_expert, @@ -194,17 +195,13 @@ def _row_id_map_pass_3_kernel( @triton.jit def _permute_kernel( - # pointers + # input pointers input_ptr, - output_ptr, row_id_map_ptr, probs_ptr, scale_ptr, - permuted_probs_ptr, permuted_scale_ptr, # sizes - num_experts: tl.constexpr, - hidden_size: tl.constexpr, scale_hidden_dim, # strides stride_row_id_map_token, @@ -220,7 +217,12 @@ def _permute_kernel( stride_permuted_probs_token, stride_permuted_scale_token, stride_permuted_scale_hidden, + # output pointers + output_ptr, + permuted_probs_ptr, # metas + num_experts: tl.constexpr, + hidden_size: tl.constexpr, PERMUTE_PROBS: tl.constexpr, PERMUTE_SCALE: tl.constexpr, BLOCK_SIZE: tl.constexpr, @@ -291,16 +293,11 @@ def _permute_kernel( @triton.jit 
def _unpermute_kernel( - # pointers + # input pointers input_ptr, - output_ptr, row_id_map_ptr, merging_probs_ptr, permuted_probs_ptr, - unpermuted_probs_ptr, - # sizes - num_experts: tl.constexpr, - hidden_size: tl.constexpr, # strides stride_row_id_map_token, stride_row_id_map_expert, @@ -313,7 +310,12 @@ def _unpermute_kernel( stride_permuted_probs_token, stride_unpermuted_probs_token, stride_unpermuted_probs_expert, + # output pointers + output_ptr, + unpermuted_probs_ptr, # metas + num_experts: tl.constexpr, + hidden_size: tl.constexpr, PROBS_LOAD_WIDTH: tl.constexpr, WITH_MERGING_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr, @@ -546,14 +548,10 @@ def _make_chunk_sort_map_kernel( @triton.jit def _sort_chunks_by_map_kernel( - # pointers + # input pointers input_ptr, - output_ptr, row_id_map_ptr, probs_ptr, - permuted_probs_ptr, - # sizes - hidden_size: tl.constexpr, # strides stride_input_token, stride_input_hidden, @@ -561,7 +559,11 @@ def _sort_chunks_by_map_kernel( stride_output_hidden, stride_probs_token, stride_permuted_probs_token, + # output pointers + output_ptr, + permuted_probs_ptr, # metas + hidden_size: tl.constexpr, PERMUTE_PROBS: tl.constexpr, BLOCK_SIZE: tl.constexpr, FORWARD: tl.constexpr, diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index da22299fe5..b676c43c26 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -72,13 +72,13 @@ def make_row_id_map( # [0, 0, 0, r, r, r, r]] _row_id_map_pass_1_kernel[grid]( routing_map, - row_id_map, - workspace_tensor, num_tokens, routing_map.stride(0), routing_map.stride(1), row_id_map.stride(0), row_id_map.stride(1), + row_id_map, + workspace_tensor, block_size, ) @@ -169,14 +169,10 @@ def permute_with_mask_map( grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"])) _permute_kernel[grid]( inp, - output, row_id_map, probs, scale, - permuted_probs, 
permuted_scale, - num_experts, - hidden_size, scale_hidden_dim, row_id_map.stride(0), row_id_map.stride(1), @@ -191,6 +187,10 @@ def permute_with_mask_map( permuted_probs.stride(0) if permuted_probs is not None else None, permuted_scale.stride(0) if permuted_scale is not None else None, permuted_scale.stride(1) if permuted_scale is not None else None, + output, + permuted_probs, + num_experts, + hidden_size, PERMUTE_PROBS=probs is not None, PERMUTE_SCALE=scale is not None, ) @@ -238,13 +238,9 @@ def unpermute_with_mask_map( grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"])) _unpermute_kernel[grid]( inp, - output, row_id_map, merging_probs, permuted_probs, - unpermuted_probs, - num_experts, - hidden_size, row_id_map.stride(0), row_id_map.stride(1), inp.stride(0), @@ -256,6 +252,10 @@ def unpermute_with_mask_map( permuted_probs.stride(0) if permuted_probs is not None else None, unpermuted_probs.stride(0) if unpermuted_probs is not None else None, unpermuted_probs.stride(1) if unpermuted_probs is not None else None, + output, + unpermuted_probs, + num_experts, + hidden_size, PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), WITH_MERGING_PROBS=merging_probs is not None, PERMUTE_PROBS=permuted_probs is not None, @@ -395,17 +395,17 @@ def sort_chunks_by_map( grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"])) _sort_chunks_by_map_kernel[grid]( inp, - output, row_id_map, probs, - permuted_probs, - hidden_size, inp.stride(0), inp.stride(1), output.stride(0), output.stride(1), probs.stride(0) if probs is not None else None, permuted_probs.stride(0) if permuted_probs is not None else None, + output, + permuted_probs, + hidden_size, PERMUTE_PROBS=probs is not None, FORWARD=is_forward, ) From f717722effca24d1a04ec10881f91eb2e26fde10 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 24 Nov 2025 13:38:58 -0800 Subject: [PATCH 2/2] make num_experts a tl.constexpr again Signed-off-by: tdophung --- 
transformer_engine/common/triton/permutation.py | 3 +-- transformer_engine/pytorch/triton/permutation.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py index 19d6b81ae2..e8c43f52d2 100644 --- a/transformer_engine/common/triton/permutation.py +++ b/transformer_engine/common/triton/permutation.py @@ -156,12 +156,11 @@ def _row_id_map_pass_2_kernel( def _row_id_map_pass_3_kernel( # pointers row_id_map_ptr, - # sizes - num_experts, # strides stride_row_id_map_token, stride_row_id_map_expert, # metas + num_experts: tl.constexpr, LOAD_SIZE: tl.constexpr, ): pid = tl.program_id(0) diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index b676c43c26..741dd60c06 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -110,9 +110,9 @@ def make_row_id_map( grid = (num_tokens,) _row_id_map_pass_3_kernel[grid]( row_id_map, - num_experts, row_id_map.stride(0), row_id_map.stride(1), + num_experts, triton.next_power_of_2(num_experts), ) return row_id_map