4 changes: 2 additions & 2 deletions .github/workflows/tilegym-ci.yml
@@ -252,7 +252,7 @@ jobs:
test-ops:
name: test-ops
needs: [config, build]
timeout-minutes: 12
timeout-minutes: 17
if: |
always() &&
needs.config.outputs.run_ops == 'true' &&
@@ -277,7 +277,7 @@ jobs:
password: ${{ secrets.GITHUB_TOKEN }}

- name: Pull and run ops tests
timeout-minutes: 10
timeout-minutes: 15
run: |
OWNER_LOWER=$(echo '${{ github.repository_owner }}' | tr '[:upper:]' '[:lower:]')
IMAGE="ghcr.io/${OWNER_LOWER}/${{ needs.config.outputs.image_name }}:${{ needs.config.outputs.image_tag }}"
4 changes: 2 additions & 2 deletions src/tilegym/ops/__init__.py
@@ -33,7 +33,7 @@
from .attn_interface import get_fmha_interface
from .attn_interface import mla_decoding_interface
from .attn_interface import mla_interface
from .moe_interface import fused_moe_kernel_interface
from .moe_interface import fused_moe

# Import all operation interfaces from the unified ops module
from .ops import *
@@ -53,7 +53,7 @@
"get_fmha_gemma3_interface",
"mla_interface",
"mla_decoding_interface",
"fused_moe_kernel_interface",
"fused_moe",
]

# Add cutile to exports only if successfully imported
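With the rename, downstream code only needs the package-level import. A minimal usage sketch (the shapes, dtypes, and random routing below are illustrative assumptions, not taken from the PR):

import torch

from tilegym.ops import fused_moe  # re-exported from tilegym.ops.moe_interface

tokens, hidden, inter, n_experts, top_k = 9, 2048, 1408, 64, 6
hidden_states = torch.randn(tokens, hidden, device="cuda", dtype=torch.bfloat16)
w1 = torch.randn(n_experts, 2 * inter, hidden, device="cuda", dtype=torch.bfloat16)  # gate + up, stacked
w2 = torch.randn(n_experts, hidden, inter, device="cuda", dtype=torch.bfloat16)      # down projection
router_probs = torch.softmax(torch.randn(tokens, n_experts, device="cuda"), dim=-1)
topk_weights, topk_ids = torch.topk(router_probs, top_k, dim=-1)
out = fused_moe(hidden_states, w1, w2, topk_weights, topk_ids)  # [tokens, hidden]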
69 changes: 48 additions & 21 deletions src/tilegym/ops/moe_interface.py
@@ -73,15 +73,31 @@ def fused_experts_impl(
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
config: Optional[dict] = None,
):
"""
Standard MoE implementation with chunked execution.

hidden_states: [batch_size * seq_len, hidden_size] (9, 2048)
w1: [n_experts, 2 * moe_intermediate_size, hidden_size] (64, 1408 * 2, 2048)
w2: [n_experts, hidden_size, moe_intermediate_size] (64, 2048, 1408)
topk_weights: [batch_size * seq_len, top_k] (9, 6)
topk_ids: [batch_size * seq_len, top_k] (9, 6)
"""
_backend = get_current_backend()
if _backend == "cutile":
config = {
"TILE_SIZE_M": 128,
"TILE_SIZE_N": 128,
"TILE_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4,
}
# Override block sizes to match scale tensor shapes for FP8
if use_fp8_w8a8 and w1_scale is not None:
_, N, K = w1.shape
config["TILE_SIZE_N"] = N // w1_scale.shape[1]
config["TILE_SIZE_K"] = K // w1_scale.shape[2]
device = hidden_states.device
if not w1.is_cuda:
w1 = w1.to(device)
@@ -238,38 +254,49 @@ def fused_experts_impl(
return out_hidden_states


def fused_moe_kernel_interface(
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
):
) -> torch.Tensor:
"""
hidden_states: [batch_size, top_k, moe_intermediate_size]
topk_ids: [batch_size, top_k]
topk_weights: [batch_size, top_k]
Unified MoE kernel interface.

Args:
hidden_states: Input activations [batch_size * seq_len, hidden_size]
w1: Expert gate+up weights [n_experts, intermediate_size*2, hidden_size]
w2: Expert down weights [n_experts, hidden_size, intermediate_size]
topk_weights: Router weights [batch_size * seq_len, top_k]
topk_ids: Selected expert IDs [batch_size * seq_len, top_k]

Returns:
Output tensor [batch_size * seq_len, hidden_size]


Examples:
# Standard FP16/BF16 MoE
>>> out = fused_moe(hidden, w1, w2, topk_weights, topk_ids)
"""
backend = get_current_backend()
if backend == "cutile":
kernel_configs = {
"TILE_SIZE_M": 128,
"TILE_SIZE_N": 128,
"TILE_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4,
}
assert activation == "silu", "Only silu is supported for now"
return _call_fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids)


def _call_fused_experts_impl(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
):
"""Standard implementation (no quantization - FP16/BF16)."""
inplace = False
return fused_experts_impl(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
inplace,
config=kernel_configs,
inplace=inplace,
use_fp8_w8a8=False,
)
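For reference, the semantics fused_moe is expected to reproduce can be written as an unfused per-token loop. This is a sketch only: it assumes the first intermediate_size rows of each expert's w1 hold the gate projection and the remaining rows the up projection, which the PR does not state explicitly.

import torch
import torch.nn.functional as F

def moe_reference(hidden_states, w1, w2, topk_weights, topk_ids):
    # Unfused SiLU-gated MoE: route each token to its top-k experts and
    # accumulate the weighted expert outputs.
    tokens = hidden_states.shape[0]
    inter = w2.shape[2]
    out = torch.zeros_like(hidden_states)
    for t in range(tokens):
        x = hidden_states[t]
        for k in range(topk_ids.shape[1]):
            e = int(topk_ids[t, k])
            w = topk_weights[t, k].to(x.dtype)
            gate_up = w1[e] @ x                              # [2 * inter]
            act = F.silu(gate_up[:inter]) * gate_up[inter:]  # SwiGLU activation
            out[t] += w * (w2[e] @ act)                      # [hidden]
    return out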
4 changes: 2 additions & 2 deletions src/tilegym/transformers/deepseek2/modeling_deepseek.py
@@ -21,7 +21,7 @@
raise ImportError("In new transformers version, past_key_value is named to past_key_values")

from tilegym.logger import get_logger
from tilegym.ops import fused_moe_kernel_interface
from tilegym.ops import fused_moe
from tilegym.ops import get_fused_swiglu_module
from tilegym.ops import group_gemm
from tilegym.ops import mla_interface
@@ -311,7 +311,7 @@ def init_merged_expert_weights(self):
self.init = True

def moe_infer(self, x, topk_ids, topk_weight):
out = fused_moe_kernel_interface(
out = fused_moe(
x,
w1=self.w13_merged,
w2=self.w2_merged,
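moe_infer feeds fused_moe with self.w13_merged and self.w2_merged, which init_merged_expert_weights builds from the per-expert MLPs. The PR does not show that code; the sketch below only illustrates the layouts fused_moe expects, and the expert attribute names (gate_proj, up_proj, down_proj) are assumptions.

import torch
import torch.nn as nn

def merge_expert_weights(experts: nn.ModuleList):
    # Hypothetical merge into the layouts fused_moe expects:
    #   w13_merged: [n_experts, 2 * intermediate_size, hidden_size]  (gate and up stacked)
    #   w2_merged:  [n_experts, hidden_size, intermediate_size]      (down projection)
    w13_merged = torch.stack(
        [torch.cat([e.gate_proj.weight, e.up_proj.weight], dim=0) for e in experts]
    )
    w2_merged = torch.stack([e.down_proj.weight for e in experts])
    return w13_merged, w2_merged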
29 changes: 7 additions & 22 deletions tests/ops/test_moe.py
@@ -8,6 +8,7 @@
import torch

from tilegym import set_backend
from tilegym.ops.moe_interface import fused_moe

from .. import common

@@ -156,12 +157,11 @@ def test_op(
topk_weights = topk_weights.contiguous()
topk_ids = topk_ids.contiguous()

# Define wrapper for fused_moe_kernel_interface
# Define wrapper for fused_moe
from tilegym.ops.moe_interface import fused_experts_impl
from tilegym.ops.moe_interface import fused_moe_kernel_interface

def moe_wrapper(hidden_states, w1, w2, topk_weights, topk_ids):
return fused_moe_kernel_interface(
return fused_moe(
hidden_states,
w1,
w2,
@@ -170,27 +170,12 @@ def moe_wrapper(hidden_states, w1, w2, topk_weights, topk_ids):
)

def moe_wrapper_fp8(hidden_states, w1, w2, topk_weights, topk_ids):
kernel_configs = {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": quant_block,
"BLOCK_SIZE_K": quant_block,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4,
}

return fused_experts_impl(
hidden_states_fp8,
w1_fp8,
w2_fp8,
return fused_moe(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
inplace=False,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=hidden_states_scale,
config=kernel_configs,
)

# Set tolerances based on dtype
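The FP8 branch of the test previously built its own kernel config; with this PR the cutile path derives its tile sizes from the scale tensors inside fused_experts_impl, so a direct FP8 call no longer needs an explicit config. A sketch of such a call, using the quantized tensors and blockwise scales the test already constructs (variable names follow the removed code; treat the exact keyword set as an assumption based on the signature shown in this diff):

# Inside the test, after hidden_states_fp8 / w1_fp8 / w2_fp8 and their blockwise
# scales have been prepared (one scale per quant_block x quant_block weight tile):
out = fused_experts_impl(
    hidden_states_fp8,
    w1_fp8,
    w2_fp8,
    topk_weights,
    topk_ids,
    inplace=False,
    use_fp8_w8a8=True,
    w1_scale=w1_scale,
    w2_scale=w2_scale,
    a1_scale=hidden_states_scale,
)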