From 52246548a9cf7614d17f4e731d62813968d0b05e Mon Sep 17 00:00:00 2001 From: Theinruj Toranavikrai Date: Wed, 27 May 2026 11:31:46 +0700 Subject: [PATCH 1/2] Add wide-token MoE prefill tiles (n64/n128 mul_mm_id). Use 64/128-token expert-major tiles for Q4_K, Q2_K, and IQ2_XXS routed prefill when batch length is aligned. Cap via DS4_METAL_MOE_TILE_MAX (default 128; set 32 to force legacy tiles for A/B). Co-authored-by: Cursor --- ds4_metal.m | 112 ++++++++++++++++++++++++++++++++++++++++-------- metal/moe.metal | 10 +++++ 2 files changed, 104 insertions(+), 18 deletions(-) diff --git a/ds4_metal.m b/ds4_metal.m index 465fb6294..39bd0b3fd 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -795,6 +795,19 @@ static int ds4_gpu_mpp_available(void) { return g_metal4_tensor_api_enabled && !g_quality_mode; } +static uint32_t ds4_gpu_env_u32(const char *name, uint32_t default_val) { + const char *env = getenv(name); + if (!env || !env[0]) { + return default_val; + } + char *end = NULL; + const long v = strtol(env, &end, 10); + if (end == env || v <= 0) { + return default_val; + } + return (uint32_t)v; +} + /* * Retained Metal4 defaults live here instead of behind user-visible options. 
* The public runtime has one automatic accelerated path plus the global @@ -2556,7 +2569,8 @@ static int ds4_gpu_encode_mul_mm_id_mapped_tile( id src1, NSUInteger src1_off, id dst, - NSUInteger dst_off); + NSUInteger dst_off, + uint32_t token_tile_n); typedef struct { int32_t ne11; @@ -12613,24 +12627,78 @@ static NSUInteger ds4_gpu_routed_mv_smem(uint32_t type) { } } -static id ds4_gpu_routed_mm_pipeline(uint32_t type) { +static id ds4_gpu_routed_mm_pipeline_for_tile(uint32_t type, + uint32_t tile_n) { switch (type) { case DS4_METAL_TENSOR_IQ2_XXS: + if (tile_n == 128u) { + return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f32_n128", false); + } + if (tile_n == 64u) { + return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f32_n64", false); + } return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f32", false); case DS4_METAL_TENSOR_Q2_K: + if (tile_n == 128u) { + return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f32_n128", false); + } + if (tile_n == 64u) { + return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f32_n64", false); + } return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f32", false); case DS4_METAL_TENSOR_Q4_K: + if (tile_n == 128u) { + return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f32_n128", false); + } + if (tile_n == 64u) { + return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f32_n64", false); + } return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f32", false); default: return nil; } } -static id ds4_gpu_routed_mm_f16_rhs_pipeline(uint32_t type) { +static id ds4_gpu_routed_mm_pipeline(uint32_t type) { + return ds4_gpu_routed_mm_pipeline_for_tile(type, 32u); +} + +static uint32_t ds4_gpu_moe_mm_tile_n(uint32_t gate_type, uint32_t n_tokens) { + const uint32_t tile_max = ds4_gpu_env_u32("DS4_METAL_MOE_TILE_MAX", 128u); + if (tile_max <= 32u) { + return 32u; + } + if (gate_type == DS4_METAL_TENSOR_Q4_K || + gate_type == DS4_METAL_TENSOR_Q2_K || + 
gate_type == DS4_METAL_TENSOR_IQ2_XXS) { + if (tile_max >= 128u && n_tokens >= 128u && (n_tokens % 128u) == 0) { + return 128u; + } + if (tile_max >= 64u && n_tokens >= 64u && (n_tokens % 64u) == 0) { + return 64u; + } + } + return 32u; +} + +static id ds4_gpu_routed_mm_f16_rhs_pipeline_for_tile(uint32_t type, + uint32_t tile_n) { switch (type) { case DS4_METAL_TENSOR_IQ2_XXS: + if (tile_n == 128u) { + return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f16_n128", false); + } + if (tile_n == 64u) { + return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f16_n64", false); + } return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f16", false); case DS4_METAL_TENSOR_Q2_K: + if (tile_n == 128u) { + return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f16_n128", false); + } + if (tile_n == 64u) { + return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f16_n64", false); + } return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f16", false); case DS4_METAL_TENSOR_Q4_K: return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f16", false); @@ -12639,6 +12707,10 @@ static NSUInteger ds4_gpu_routed_mv_smem(uint32_t type) { } } +static id ds4_gpu_routed_mm_f16_rhs_pipeline(uint32_t type) { + return ds4_gpu_routed_mm_f16_rhs_pipeline_for_tile(type, 32u); +} + static int ds4_gpu_encode_mul_mv_id( id cb, id pipeline, @@ -12939,19 +13011,17 @@ static int ds4_gpu_encode_mul_mm_id_mapped_tile( id src1, NSUInteger src1_off, id dst, - NSUInteger dst_off) { + NSUInteger dst_off, + uint32_t token_tile_n) { if (!cb || !mm_pipeline || !mm_args || !src0 || !src1 || !dst || !g_moe_id_map_buffer || mm_args->ne00 <= 0 || mm_args->ne0 <= 0 || mm_args->ne20 <= 0 || mm_args->ne21 <= 0 || mm_args->ne02 <= 0) { return 0; } - /* - * The routed MoE grouped matmul uses the legacy 32-token expert-major tile. 
- * The removed TensorOps variant was not semantically stable on evals, so keep - * this encoder tied to the tested simdgroup kernel shape. - */ - const NSUInteger tile_n = 32u; + const NSUInteger tile_n = token_tile_n != 0u ? (NSUInteger)token_tile_n : + (((NSUInteger)mm_args->ne21 % 64u) == 0u && (NSUInteger)mm_args->ne21 >= 64u) ? + 64u : 32u; const NSUInteger tpe_bytes = (NSUInteger)mm_args->ne02 * sizeof(int32_t); const NSUInteger hids_bytes = (NSUInteger)mm_args->ne02 * (NSUInteger)mm_args->ne21 * sizeof(int32_t); @@ -12995,7 +13065,8 @@ static int ds4_gpu_encode_mul_mm_id_mapped( src1, src1_off, dst, - dst_off); + dst_off, + 0u); } static int ds4_gpu_encode_attn_out_low_q8_mpp( @@ -14317,6 +14388,7 @@ int ds4_gpu_routed_moe_batch_tensor( ds4_gpu_mul_mm_id_args gate_mm_args = { 0 }; ds4_gpu_mul_mm_id_args down_mm_args = { 0 }; id map_pipeline = nil; + uint32_t moe_mm_tile_n = 32u; /* * The grouped routed-MoE matmul loads activation tiles as half before * using SIMD-group MMA. Store the SwiGLU/route-weight intermediate in @@ -14337,12 +14409,13 @@ int ds4_gpu_routed_moe_batch_tensor( n_expert, n_expert, n_tokens, request_mid_f16 ? sizeof(uint16_t) : sizeof(float)); + moe_mm_tile_n = ds4_gpu_moe_mm_tile_n(gate_type, n_tokens); map_pipeline = ds4_gpu_get_pipeline(ds4_gpu_mul_mm_id_map0_name(n_expert)); - gate_mm_pipeline = ds4_gpu_routed_mm_pipeline(gate_type); - up_mm_pipeline = ds4_gpu_routed_mm_pipeline(gate_type); + gate_mm_pipeline = ds4_gpu_routed_mm_pipeline_for_tile(gate_type, moe_mm_tile_n); + up_mm_pipeline = ds4_gpu_routed_mm_pipeline_for_tile(gate_type, moe_mm_tile_n); down_mm_pipeline = request_mid_f16 ? 
- ds4_gpu_routed_mm_f16_rhs_pipeline(down_type) : - ds4_gpu_routed_mm_pipeline(down_type); + ds4_gpu_routed_mm_f16_rhs_pipeline_for_tile(down_type, moe_mm_tile_n) : + ds4_gpu_routed_mm_pipeline_for_tile(down_type, moe_mm_tile_n); if (!map_pipeline || !gate_mm_pipeline || !up_mm_pipeline || !down_mm_pipeline) { return 0; } @@ -14435,7 +14508,8 @@ int ds4_gpu_routed_moe_batch_tensor( xbuf, ds4_gpu_tensor_offset(x), gatebuf, - ds4_gpu_tensor_offset(gate)); + ds4_gpu_tensor_offset(gate), + moe_mm_tile_n); DS4_METAL_PROFILE_MOE_STAGE("gate"); } if (ok) { @@ -14447,7 +14521,8 @@ int ds4_gpu_routed_moe_batch_tensor( xbuf, ds4_gpu_tensor_offset(x), upbuf, - ds4_gpu_tensor_offset(up)); + ds4_gpu_tensor_offset(up), + moe_mm_tile_n); DS4_METAL_PROFILE_MOE_STAGE("up"); } } else if (use_tiny_pair_mv) { @@ -14627,7 +14702,8 @@ int ds4_gpu_routed_moe_batch_tensor( midbuf, ds4_gpu_tensor_offset(mid), down_dst, - down_dst_off); + down_dst_off, + moe_mm_tile_n); } else { ok = ds4_gpu_encode_mul_mv_id(cb, down_mv_pipeline, diff --git a/metal/moe.metal b/metal/moe.metal index c776e8ddc..17c482cca 100644 --- a/metal/moe.metal +++ b/metal/moe.metal @@ -1751,12 +1751,22 @@ typedef decltype(kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, ha // Host-visible batched MoE matmul variants for the DS4 quant formats. 
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>; template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>; +template [[host_name("kernel_mul_mm_id_q2_K_f32_n64")]] kernel mul_mm_id kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>; +template [[host_name("kernel_mul_mm_id_q2_K_f32_n128")]] kernel mul_mm_id kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>; template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>; +template [[host_name("kernel_mul_mm_id_q4_K_f32_n64")]] kernel mul_mm_id kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>; +template [[host_name("kernel_mul_mm_id_q4_K_f32_n128")]] kernel mul_mm_id kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>; template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32_n64")]] kernel mul_mm_id kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, 
half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32_n128")]] kernel mul_mm_id kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>; template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, half, half4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>; +template [[host_name("kernel_mul_mm_id_q2_K_f16_n64")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>; +template [[host_name("kernel_mul_mm_id_q2_K_f16_n128")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f16_n64")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, 
half, half2x4>; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f16_n128")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>; #ifdef DS4_METAL_HAS_TENSOR // Attention-output low-rank projection retained for Metal4 prefill. It uses From d517726a2c8c5b6c4d6726d017ea462935c54cd9 Mon Sep 17 00:00:00 2001 From: Theinruj Toranavikrai Date: Wed, 27 May 2026 20:57:42 +0700 Subject: [PATCH 2/2] Fix: add wide (n64/n128) Q4_K F16 kernels and update the host pipeline resolver to match Q2_K/IQ2_XXS. --- ds4_metal.m | 6 ++++++ metal/moe.metal | 2 ++ 2 files changed, 8 insertions(+) diff --git a/ds4_metal.m b/ds4_metal.m index 39bd0b3fd..4e265b9cc 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -12701,6 +12701,12 @@ static uint32_t ds4_gpu_moe_mm_tile_n(uint32_t gate_type, uint32_t n_tokens) { } return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f16", false); case DS4_METAL_TENSOR_Q4_K: + if (tile_n == 128u) { + return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f16_n128", false); + } + if (tile_n == 64u) { + return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f16_n64", false); + } return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f16", false); default: return nil; diff --git a/metal/moe.metal b/metal/moe.metal index 17c482cca..abbb84ccd 100644 --- a/metal/moe.metal +++ b/metal/moe.metal @@ -1764,6 +1764,8 @@ template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id_f16 template [[host_name("kernel_mul_mm_id_q2_K_f16_n64")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_q2_K_f16_n128")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL,
dequantize_q2_K, half, half4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half4x4, half, half2x4>; +template [[host_name("kernel_mul_mm_id_q4_K_f16_n64")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half4x4, half, half2x4>; +template [[host_name("kernel_mul_mm_id_q4_K_f16_n128")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_iq2_xxs_f16_n64")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_iq2_xxs_f16_n128")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>;