Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 22 additions & 21 deletions transformer_engine/common/triton/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,18 @@ def _argsort(x, indices, n_dims: tl.constexpr):

@triton.jit
def _row_id_map_pass_1_kernel(
# pointers
# input pointers
routing_map_ptr,
row_id_map_ptr,
workspace_ptr,
# sizes
num_tokens,
# strides
stride_routing_map_token,
stride_routing_map_expert,
stride_row_id_map_token,
stride_row_id_map_expert,
# output pointers
row_id_map_ptr,
workspace_ptr,
# metas
BLOCK_SIZE: tl.constexpr,
):
Expand Down Expand Up @@ -155,12 +156,11 @@ def _row_id_map_pass_2_kernel(
def _row_id_map_pass_3_kernel(
# pointers
row_id_map_ptr,
# sizes
num_experts: tl.constexpr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
# metas
num_experts: tl.constexpr,
LOAD_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
Expand Down Expand Up @@ -194,17 +194,13 @@ def _row_id_map_pass_3_kernel(

@triton.jit
def _permute_kernel(
# pointers
# input pointers
input_ptr,
output_ptr,
row_id_map_ptr,
probs_ptr,
scale_ptr,
permuted_probs_ptr,
permuted_scale_ptr,
# sizes
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
scale_hidden_dim,
# strides
stride_row_id_map_token,
Expand All @@ -220,7 +216,12 @@ def _permute_kernel(
stride_permuted_probs_token,
stride_permuted_scale_token,
stride_permuted_scale_hidden,
# output pointers
output_ptr,
permuted_probs_ptr,
# metas
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
PERMUTE_SCALE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
Expand Down Expand Up @@ -291,16 +292,11 @@ def _permute_kernel(

@triton.jit
def _unpermute_kernel(
# pointers
# input pointers
input_ptr,
output_ptr,
row_id_map_ptr,
merging_probs_ptr,
permuted_probs_ptr,
unpermuted_probs_ptr,
# sizes
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
Expand All @@ -313,7 +309,12 @@ def _unpermute_kernel(
stride_permuted_probs_token,
stride_unpermuted_probs_token,
stride_unpermuted_probs_expert,
# output pointers
output_ptr,
unpermuted_probs_ptr,
# metas
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
PROBS_LOAD_WIDTH: tl.constexpr,
WITH_MERGING_PROBS: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
Expand Down Expand Up @@ -546,22 +547,22 @@ def _make_chunk_sort_map_kernel(

@triton.jit
def _sort_chunks_by_map_kernel(
# pointers
# input pointers
input_ptr,
output_ptr,
row_id_map_ptr,
probs_ptr,
permuted_probs_ptr,
# sizes
hidden_size: tl.constexpr,
# strides
stride_input_token,
stride_input_hidden,
stride_output_token,
stride_output_hidden,
stride_probs_token,
stride_permuted_probs_token,
# output pointers
output_ptr,
permuted_probs_ptr,
# metas
hidden_size: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
FORWARD: tl.constexpr,
Expand Down
28 changes: 14 additions & 14 deletions transformer_engine/pytorch/triton/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ def make_row_id_map(
# [0, 0, 0, r, r, r, r]]
_row_id_map_pass_1_kernel[grid](
routing_map,
row_id_map,
workspace_tensor,
num_tokens,
routing_map.stride(0),
routing_map.stride(1),
row_id_map.stride(0),
row_id_map.stride(1),
row_id_map,
workspace_tensor,
block_size,
)

Expand Down Expand Up @@ -110,9 +110,9 @@ def make_row_id_map(
grid = (num_tokens,)
_row_id_map_pass_3_kernel[grid](
row_id_map,
num_experts,
row_id_map.stride(0),
row_id_map.stride(1),
num_experts,
triton.next_power_of_2(num_experts),
)
return row_id_map
Expand Down Expand Up @@ -169,14 +169,10 @@ def permute_with_mask_map(
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
_permute_kernel[grid](
inp,
output,
row_id_map,
probs,
scale,
permuted_probs,
permuted_scale,
num_experts,
hidden_size,
scale_hidden_dim,
row_id_map.stride(0),
row_id_map.stride(1),
Expand All @@ -191,6 +187,10 @@ def permute_with_mask_map(
permuted_probs.stride(0) if permuted_probs is not None else None,
permuted_scale.stride(0) if permuted_scale is not None else None,
permuted_scale.stride(1) if permuted_scale is not None else None,
output,
permuted_probs,
num_experts,
hidden_size,
PERMUTE_PROBS=probs is not None,
PERMUTE_SCALE=scale is not None,
)
Expand Down Expand Up @@ -238,13 +238,9 @@ def unpermute_with_mask_map(
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
_unpermute_kernel[grid](
inp,
output,
row_id_map,
merging_probs,
permuted_probs,
unpermuted_probs,
num_experts,
hidden_size,
row_id_map.stride(0),
row_id_map.stride(1),
inp.stride(0),
Expand All @@ -256,6 +252,10 @@ def unpermute_with_mask_map(
permuted_probs.stride(0) if permuted_probs is not None else None,
unpermuted_probs.stride(0) if unpermuted_probs is not None else None,
unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
output,
unpermuted_probs,
num_experts,
hidden_size,
PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
WITH_MERGING_PROBS=merging_probs is not None,
PERMUTE_PROBS=permuted_probs is not None,
Expand Down Expand Up @@ -395,17 +395,17 @@ def sort_chunks_by_map(
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
_sort_chunks_by_map_kernel[grid](
inp,
output,
row_id_map,
probs,
permuted_probs,
hidden_size,
inp.stride(0),
inp.stride(1),
output.stride(0),
output.stride(1),
probs.stride(0) if probs is not None else None,
permuted_probs.stride(0) if permuted_probs is not None else None,
output,
permuted_probs,
hidden_size,
PERMUTE_PROBS=probs is not None,
FORWARD=is_forward,
)
Expand Down
Loading