Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 100 additions & 18 deletions ds4_metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,19 @@ static int ds4_gpu_mpp_available(void) {
return g_metal4_tensor_api_enabled && !g_quality_mode;
}

static uint32_t ds4_gpu_env_u32(const char *name, uint32_t default_val) {
const char *env = getenv(name);
if (!env || !env[0]) {
return default_val;
}
char *end = NULL;
const long v = strtol(env, &end, 10);
if (end == env || v <= 0) {
return default_val;
}
return (uint32_t)v;
}

/*
* Retained Metal4 defaults live here instead of behind user-visible options.
* The public runtime has one automatic accelerated path plus the global
Expand Down Expand Up @@ -2556,7 +2569,8 @@ static int ds4_gpu_encode_mul_mm_id_mapped_tile(
id<MTLBuffer> src1,
NSUInteger src1_off,
id<MTLBuffer> dst,
NSUInteger dst_off);
NSUInteger dst_off,
uint32_t token_tile_n);

typedef struct {
int32_t ne11;
Expand Down Expand Up @@ -12613,32 +12627,96 @@ static NSUInteger ds4_gpu_routed_mv_smem(uint32_t type) {
}
}

static id<MTLComputePipelineState> ds4_gpu_routed_mm_pipeline(uint32_t type) {
static id<MTLComputePipelineState> ds4_gpu_routed_mm_pipeline_for_tile(uint32_t type,
uint32_t tile_n) {
switch (type) {
case DS4_METAL_TENSOR_IQ2_XXS:
if (tile_n == 128u) {
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f32_n128", false);
}
if (tile_n == 64u) {
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f32_n64", false);
}
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f32", false);
case DS4_METAL_TENSOR_Q2_K:
if (tile_n == 128u) {
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f32_n128", false);
}
if (tile_n == 64u) {
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f32_n64", false);
}
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f32", false);
case DS4_METAL_TENSOR_Q4_K:
if (tile_n == 128u) {
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f32_n128", false);
}
if (tile_n == 64u) {
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f32_n64", false);
}
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f32", false);
default:
return nil;
}
}

static id<MTLComputePipelineState> ds4_gpu_routed_mm_f16_rhs_pipeline(uint32_t type) {
static id<MTLComputePipelineState> ds4_gpu_routed_mm_pipeline(uint32_t type) {
return ds4_gpu_routed_mm_pipeline_for_tile(type, 32u);
}

static uint32_t ds4_gpu_moe_mm_tile_n(uint32_t gate_type, uint32_t n_tokens) {
const uint32_t tile_max = ds4_gpu_env_u32("DS4_METAL_MOE_TILE_MAX", 128u);
if (tile_max <= 32u) {
return 32u;
}
if (gate_type == DS4_METAL_TENSOR_Q4_K ||
gate_type == DS4_METAL_TENSOR_Q2_K ||
gate_type == DS4_METAL_TENSOR_IQ2_XXS) {
if (tile_max >= 128u && n_tokens >= 128u && (n_tokens % 128u) == 0) {
return 128u;
}
if (tile_max >= 64u && n_tokens >= 64u && (n_tokens % 64u) == 0) {
return 64u;
}
}
return 32u;
}

static id<MTLComputePipelineState> ds4_gpu_routed_mm_f16_rhs_pipeline_for_tile(uint32_t type,
uint32_t tile_n) {
switch (type) {
case DS4_METAL_TENSOR_IQ2_XXS:
if (tile_n == 128u) {
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f16_n128", false);
}
if (tile_n == 64u) {
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f16_n64", false);
}
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_iq2_xxs_f16", false);
case DS4_METAL_TENSOR_Q2_K:
if (tile_n == 128u) {
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f16_n128", false);
}
if (tile_n == 64u) {
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f16_n64", false);
}
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q2_K_f16", false);
case DS4_METAL_TENSOR_Q4_K:
if (tile_n == 128u) {
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f16_n128", false);
}
if (tile_n == 64u) {
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f16_n64", false);
}
return ds4_gpu_get_mul_mm_id_pipeline("kernel_mul_mm_id_q4_K_f16", false);
default:
return nil;
}
}

static id<MTLComputePipelineState> ds4_gpu_routed_mm_f16_rhs_pipeline(uint32_t type) {
return ds4_gpu_routed_mm_f16_rhs_pipeline_for_tile(type, 32u);
}

static int ds4_gpu_encode_mul_mv_id(
id<MTLCommandBuffer> cb,
id<MTLComputePipelineState> pipeline,
Expand Down Expand Up @@ -12939,19 +13017,17 @@ static int ds4_gpu_encode_mul_mm_id_mapped_tile(
id<MTLBuffer> src1,
NSUInteger src1_off,
id<MTLBuffer> dst,
NSUInteger dst_off) {
NSUInteger dst_off,
uint32_t token_tile_n) {
if (!cb || !mm_pipeline || !mm_args || !src0 || !src1 || !dst ||
!g_moe_id_map_buffer ||
mm_args->ne00 <= 0 || mm_args->ne0 <= 0 ||
mm_args->ne20 <= 0 || mm_args->ne21 <= 0 || mm_args->ne02 <= 0) {
return 0;
}
/*
* The routed MoE grouped matmul uses the legacy 32-token expert-major tile.
* The removed TensorOps variant was not semantically stable on evals, so keep
* this encoder tied to the tested simdgroup kernel shape.
*/
const NSUInteger tile_n = 32u;
const NSUInteger tile_n = token_tile_n != 0u ? (NSUInteger)token_tile_n :
(((NSUInteger)mm_args->ne21 % 64u) == 0u && (NSUInteger)mm_args->ne21 >= 64u) ?
64u : 32u;

const NSUInteger tpe_bytes = (NSUInteger)mm_args->ne02 * sizeof(int32_t);
const NSUInteger hids_bytes = (NSUInteger)mm_args->ne02 * (NSUInteger)mm_args->ne21 * sizeof(int32_t);
Expand Down Expand Up @@ -12995,7 +13071,8 @@ static int ds4_gpu_encode_mul_mm_id_mapped(
src1,
src1_off,
dst,
dst_off);
dst_off,
0u);
}

static int ds4_gpu_encode_attn_out_low_q8_mpp(
Expand Down Expand Up @@ -14317,6 +14394,7 @@ int ds4_gpu_routed_moe_batch_tensor(
ds4_gpu_mul_mm_id_args gate_mm_args = { 0 };
ds4_gpu_mul_mm_id_args down_mm_args = { 0 };
id<MTLComputePipelineState> map_pipeline = nil;
uint32_t moe_mm_tile_n = 32u;
/*
* The grouped routed-MoE matmul loads activation tiles as half before
* using SIMD-group MMA. Store the SwiGLU/route-weight intermediate in
Expand All @@ -14337,12 +14415,13 @@ int ds4_gpu_routed_moe_batch_tensor(
n_expert, n_expert, n_tokens,
request_mid_f16 ? sizeof(uint16_t) : sizeof(float));

moe_mm_tile_n = ds4_gpu_moe_mm_tile_n(gate_type, n_tokens);
map_pipeline = ds4_gpu_get_pipeline(ds4_gpu_mul_mm_id_map0_name(n_expert));
gate_mm_pipeline = ds4_gpu_routed_mm_pipeline(gate_type);
up_mm_pipeline = ds4_gpu_routed_mm_pipeline(gate_type);
gate_mm_pipeline = ds4_gpu_routed_mm_pipeline_for_tile(gate_type, moe_mm_tile_n);
up_mm_pipeline = ds4_gpu_routed_mm_pipeline_for_tile(gate_type, moe_mm_tile_n);
down_mm_pipeline = request_mid_f16 ?
ds4_gpu_routed_mm_f16_rhs_pipeline(down_type) :
ds4_gpu_routed_mm_pipeline(down_type);
ds4_gpu_routed_mm_f16_rhs_pipeline_for_tile(down_type, moe_mm_tile_n) :
ds4_gpu_routed_mm_pipeline_for_tile(down_type, moe_mm_tile_n);
if (!map_pipeline || !gate_mm_pipeline || !up_mm_pipeline || !down_mm_pipeline) {
return 0;
}
Expand Down Expand Up @@ -14435,7 +14514,8 @@ int ds4_gpu_routed_moe_batch_tensor(
xbuf,
ds4_gpu_tensor_offset(x),
gatebuf,
ds4_gpu_tensor_offset(gate));
ds4_gpu_tensor_offset(gate),
moe_mm_tile_n);
DS4_METAL_PROFILE_MOE_STAGE("gate");
}
if (ok) {
Expand All @@ -14447,7 +14527,8 @@ int ds4_gpu_routed_moe_batch_tensor(
xbuf,
ds4_gpu_tensor_offset(x),
upbuf,
ds4_gpu_tensor_offset(up));
ds4_gpu_tensor_offset(up),
moe_mm_tile_n);
DS4_METAL_PROFILE_MOE_STAGE("up");
}
} else if (use_tiny_pair_mv) {
Expand Down Expand Up @@ -14627,7 +14708,8 @@ int ds4_gpu_routed_moe_batch_tensor(
midbuf,
ds4_gpu_tensor_offset(mid),
down_dst,
down_dst_off);
down_dst_off,
moe_mm_tile_n);
} else {
ok = ds4_gpu_encode_mul_mv_id(cb,
down_mv_pipeline,
Expand Down
12 changes: 12 additions & 0 deletions metal/moe.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1751,12 +1751,24 @@ typedef decltype(kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, ha
// Host-visible batched MoE matmul variants for the DS4 quant formats.
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f32_n64")]] kernel mul_mm_id kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f32_n128")]] kernel mul_mm_id kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f32_n64")]] kernel mul_mm_id kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f32_n128")]] kernel mul_mm_id kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32_n64")]] kernel mul_mm_id kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32_n128")]] kernel mul_mm_id kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16_n64")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16_n128")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f16_n64")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f16_n128")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16_n64")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<64, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16_n128")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<128, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>;

#ifdef DS4_METAL_HAS_TENSOR
// Attention-output low-rank projection retained for Metal4 prefill. It uses
Expand Down