diff --git a/ds4_metal.m b/ds4_metal.m
index 465fb6294..4e265b9cc 100644
--- a/ds4_metal.m
+++ b/ds4_metal.m
@@ -795,6 +795,30 @@ static int ds4_gpu_mpp_available(void) {
     return g_metal4_tensor_api_enabled && !g_quality_mode;
 }
 
+/*
+ * Read a positive decimal integer from the environment variable `name`.
+ * Returns default_val when the variable is unset or empty, is not a clean
+ * decimal number, is zero or negative, or does not fit in uint32_t.
+ */
+static uint32_t ds4_gpu_env_u32(const char *name, uint32_t default_val) {
+    const char *env = getenv(name);
+    if (!env || !env[0]) {
+        return default_val;
+    }
+    char *end = NULL;
+    const long v = strtol(env, &end, 10);
+    /*
+     * Reject trailing junk such as "64abc", and range-check before the cast
+     * so an oversized value cannot be truncated into a small one (where long
+     * is 64-bit this also covers strtol's saturated overflow result).
+     */
+    if (end == env || *end != '\0' || v <= 0 ||
+        (unsigned long long)v > (unsigned long long)UINT32_MAX) {
+        return default_val;
+    }
+    return (uint32_t)v;
+}
+
 /*
  * Retained Metal4 defaults live here instead of behind user-visible options.
  * The public runtime has one automatic accelerated path plus the global
@@ -2556,7 +2580,8 @@ static int ds4_gpu_encode_mul_mm_id_mapped_tile(
         id src1,
         NSUInteger src1_off,
         id dst,
-        NSUInteger dst_off);
+        NSUInteger dst_off,
+        uint32_t token_tile_n);
 
 typedef struct {
     int32_t ne11;
@@ -12613,32 +12638,113 @@ static NSUInteger ds4_gpu_routed_mv_smem(uint32_t type) {
     }
 }
 
-static id ds4_gpu_routed_mm_pipeline(uint32_t type) {
+/*
+ * Look up the kernel_mul_mm_id_*_f32 pipeline for a weight type and token tile
+ * width. 64 and 128 select the _n64/_n128 kernels; any other width gets the
+ * unsuffixed kernel. Returns nil for weight types without a routed kernel.
+ */
+static id ds4_gpu_routed_mm_pipeline_for_tile(uint32_t type,
+                                              uint32_t tile_n) {
     switch (type) {
     case DS4_METAL_TENSOR_IQ2_XXS:
+        if (tile_n == 128u) {
+            return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f32_n128", false);
+        }
+        if (tile_n == 64u) {
+            return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f32_n64", false);
+        }
         return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f32", false);
     case DS4_METAL_TENSOR_Q2_K:
+        if (tile_n == 128u) {
+            return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f32_n128", false);
+        }
+        if (tile_n == 64u) {
+            return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f32_n64", false);
+        }
         return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f32", false);
     case DS4_METAL_TENSOR_Q4_K:
+        if (tile_n == 128u) {
+            return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f32_n128", false);
+        }
+        if (tile_n == 64u) {
+            return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f32_n64", false);
+        }
         return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f32", false);
     default:
         return nil;
     }
 }
 
-static id ds4_gpu_routed_mm_f16_rhs_pipeline(uint32_t type) {
+/* 32-token-tile form of ds4_gpu_routed_mm_pipeline_for_tile. */
+static id ds4_gpu_routed_mm_pipeline(uint32_t type) {
+    return ds4_gpu_routed_mm_pipeline_for_tile(type, 32u);
+}
+
+/*
+ * Token tile width for the routed-MoE grouped matmuls: the wider of 128/64
+ * that divides n_tokens exactly and does not exceed DS4_METAL_MOE_TILE_MAX
+ * (default 128), otherwise 32. Only Q4_K, Q2_K and IQ2_XXS gate weights are
+ * given the wider tiles.
+ */
+static uint32_t ds4_gpu_moe_mm_tile_n(uint32_t gate_type, uint32_t n_tokens) {
+    const uint32_t tile_max = ds4_gpu_env_u32("DS4_METAL_MOE_TILE_MAX", 128u);
+    if (tile_max <= 32u) {
+        return 32u;
+    }
+    if (gate_type == DS4_METAL_TENSOR_Q4_K ||
+        gate_type == DS4_METAL_TENSOR_Q2_K ||
+        gate_type == DS4_METAL_TENSOR_IQ2_XXS) {
+        if (tile_max >= 128u && n_tokens >= 128u && (n_tokens % 128u) == 0) {
+            return 128u;
+        }
+        if (tile_max >= 64u && n_tokens >= 64u && (n_tokens % 64u) == 0) {
+            return 64u;
+        }
+    }
+    return 32u;
+}
+
+/*
+ * Same lookup as ds4_gpu_routed_mm_pipeline_for_tile, for the
+ * kernel_mul_mm_id_*_f16 pipelines.
+ */
+static id ds4_gpu_routed_mm_f16_rhs_pipeline_for_tile(uint32_t type,
+                                                      uint32_t tile_n) {
     switch (type) {
     case DS4_METAL_TENSOR_IQ2_XXS:
+        if (tile_n == 128u) {
+            return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f16_n128", false);
+        }
+        if (tile_n == 64u) {
+            return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f16_n64", false);
+        }
         return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f16", false);
     case DS4_METAL_TENSOR_Q2_K:
+        if (tile_n == 128u) {
+            return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f16_n128", false);
+        }
+        if (tile_n == 64u) {
+            return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f16_n64", false);
+        }
         return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f16", false);
     case DS4_METAL_TENSOR_Q4_K:
+        if (tile_n == 128u) {
+            return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f16_n128", false);
+        }
+        if (tile_n == 64u) {
+            return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f16_n64", false);
+        }
         return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f16", false);
     default:
         return nil;
     }
 }
 
+/* 32-token-tile form of ds4_gpu_routed_mm_f16_rhs_pipeline_for_tile. */
+static id ds4_gpu_routed_mm_f16_rhs_pipeline(uint32_t type) {
+    return ds4_gpu_routed_mm_f16_rhs_pipeline_for_tile(type, 32u);
+}
+
 static int ds4_gpu_encode_mul_mv_id(
         id cb,
         id pipeline,
@@ -12939,19 +13045,20 @@ static int ds4_gpu_encode_mul_mm_id_mapped_tile(
         id src1,
         NSUInteger src1_off,
         id dst,
-        NSUInteger dst_off) {
+        NSUInteger dst_off,
+        uint32_t token_tile_n) {
     if (!cb || !mm_pipeline || !mm_args || !src0 || !src1 || !dst ||
         !g_moe_id_map_buffer || mm_args->ne00 <= 0 || mm_args->ne0 <= 0 ||
         mm_args->ne20 <= 0 || mm_args->ne21 <= 0 || mm_args->ne02 <= 0) {
         return 0;
     }
 
-    /*
-     * The routed MoE grouped matmul uses the legacy 32-token expert-major tile.
-     * The removed TensorOps variant was not semantically stable on evals, so keep
-     * this encoder tied to the tested simdgroup kernel shape.
-     */
-    const NSUInteger tile_n = 32u;
+    /*
+     * token_tile_n must be the tile width the caller used to pick mm_pipeline.
+     * 0 means the legacy 32-token tile: the encoder cannot tell which kernel
+     * variant mm_pipeline is, so it must not widen the tile from ne21 alone.
+     */
+    const NSUInteger tile_n = token_tile_n != 0u ? (NSUInteger)token_tile_n : 32u;
     const NSUInteger tpe_bytes = (NSUInteger)mm_args->ne02 * sizeof(int32_t);
     const NSUInteger hids_bytes =
         (NSUInteger)mm_args->ne02 * (NSUInteger)mm_args->ne21 * sizeof(int32_t);
@@ -12995,7 +13102,8 @@ static int ds4_gpu_encode_mul_mm_id_mapped(
         src1,
         src1_off,
         dst,
-        dst_off);
+        dst_off,
+        0u);
 }
 
 static int ds4_gpu_encode_attn_out_low_q8_mpp(
@@ -14317,6 +14425,7 @@ int ds4_gpu_routed_moe_batch_tensor(
     ds4_gpu_mul_mm_id_args gate_mm_args = { 0 };
     ds4_gpu_mul_mm_id_args down_mm_args = { 0 };
     id map_pipeline = nil;
+    uint32_t moe_mm_tile_n = 32u;
     /*
      * The grouped routed-MoE matmul loads activation tiles as half before
      * using SIMD-group MMA. Store the SwiGLU/route-weight intermediate in
@@ -14337,12 +14446,13 @@ int ds4_gpu_routed_moe_batch_tensor(
             n_expert, n_expert, n_tokens,
             request_mid_f16 ? sizeof(uint16_t) : sizeof(float));
 
+        moe_mm_tile_n = ds4_gpu_moe_mm_tile_n(gate_type, n_tokens);
         map_pipeline = ds4_gpu_get_pipeline(ds4_gpu_mul_mm_id_map0_name(n_expert));
-        gate_mm_pipeline = ds4_gpu_routed_mm_pipeline(gate_type);
-        up_mm_pipeline = ds4_gpu_routed_mm_pipeline(gate_type);
+        gate_mm_pipeline = ds4_gpu_routed_mm_pipeline_for_tile(gate_type, moe_mm_tile_n);
+        up_mm_pipeline = ds4_gpu_routed_mm_pipeline_for_tile(gate_type, moe_mm_tile_n);
         down_mm_pipeline = request_mid_f16 ?
-            ds4_gpu_routed_mm_f16_rhs_pipeline(down_type) :
-            ds4_gpu_routed_mm_pipeline(down_type);
+            ds4_gpu_routed_mm_f16_rhs_pipeline_for_tile(down_type, moe_mm_tile_n) :
+            ds4_gpu_routed_mm_pipeline_for_tile(down_type, moe_mm_tile_n);
         if (!map_pipeline || !gate_mm_pipeline || !up_mm_pipeline || !down_mm_pipeline) {
             return 0;
         }
@@ -14435,7 +14545,8 @@ int ds4_gpu_routed_moe_batch_tensor(
                 xbuf,
                 ds4_gpu_tensor_offset(x),
                 gatebuf,
-                ds4_gpu_tensor_offset(gate));
+                ds4_gpu_tensor_offset(gate),
+                moe_mm_tile_n);
             DS4_METAL_PROFILE_MOE_STAGE("gate");
         }
         if (ok) {
@@ -14447,7 +14558,8 @@ int ds4_gpu_routed_moe_batch_tensor(
                 xbuf,
                 ds4_gpu_tensor_offset(x),
                 upbuf,
-                ds4_gpu_tensor_offset(up));
+                ds4_gpu_tensor_offset(up),
+                moe_mm_tile_n);
             DS4_METAL_PROFILE_MOE_STAGE("up");
         }
     } else if (use_tiny_pair_mv) {
@@ -14627,7 +14739,8 @@ int ds4_gpu_routed_moe_batch_tensor(
                 midbuf,
                 ds4_gpu_tensor_offset(mid),
                 down_dst,
-                down_dst_off);
+                down_dst_off,
+                moe_mm_tile_n);
         } else {
             ok = ds4_gpu_encode_mul_mv_id(cb,
                                           down_mv_pipeline,
diff --git a/metal/moe.metal b/metal/moe.metal
index c776e8ddc..abbb84ccd 100644
--- a/metal/moe.metal
+++ b/metal/moe.metal
@@ -1751,12 +1751,24 @@ typedef decltype(kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, ha
 // Host-visible batched MoE matmul variants for the DS4 quant formats.
 template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
 template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
+template [[host_name("kernel_mul_mm_id_q2_K_f32_n64")]] kernel mul_mm_id kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
+template [[host_name("kernel_mul_mm_id_q2_K_f32_n128")]] kernel mul_mm_id kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
 template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
+template [[host_name("kernel_mul_mm_id_q4_K_f32_n64")]] kernel mul_mm_id kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
+template [[host_name("kernel_mul_mm_id_q4_K_f32_n128")]] kernel mul_mm_id kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
 template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f32_n64")]] kernel mul_mm_id kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f32_n128")]] kernel mul_mm_id kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
 template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, half, half4x4, half, half2x4>;
 template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>;
+template [[host_name("kernel_mul_mm_id_q2_K_f16_n64")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>;
+template [[host_name("kernel_mul_mm_id_q2_K_f16_n128")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>;
 template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half4x4, half, half2x4>;
+template [[host_name("kernel_mul_mm_id_q4_K_f16_n64")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half4x4, half, half2x4>;
+template [[host_name("kernel_mul_mm_id_q4_K_f16_n128")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half4x4, half, half2x4>;
 template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f16_n64")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f16_n128")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>;
 
 #ifdef DS4_METAL_HAS_TENSOR
 // Attention-output low-rank projection retained for Metal4 prefill. It uses