[kernels] hybrid_w4a16: fix bf16 asymmetric dequant slowdown on gfx1151 #953
fyi @marcusr-amd for bf16 perf
roberteg16
left a comment
Are we planning to follow this up with a ticket for the compiler team?
I started to look into this, but my agent ran into a dead end. I will need to backtrack and properly report this for further investigation.
In my opinion, the merge commit should be removed, this PR should be rebased, and the golden values should be remeasured by running the full set of tests. My concern is that the subsets of tests that you and Robert added separately might interact and make the job intermittent.
Friendly reminder: the performance tests currently allow a ridiculously wide tolerance (±80%) for the hybrid_triton_w4a16 kernel.
016e3fa to 1bb9844
… pre-dequant stub, threadwise element-op

Wires several orthogonal template axes through aiter's CK W4A16 b_scale GEMM op surface so each can be runtime-toggled without rebuild. CK submodule bumped from ad2f19fd3 to 2cfd5509f to pick up the matching threadwise BElementwiseOperation plumbing + bf16 truncate element-ops.

Axes added (composable as 2x2x2x2 instantiations on top of T, ScaleBlockK):

* ScaleBlockK (32 | 128): the original group_size axis was hardcoded to Scale_Block_K=128; now a template parameter so group_size=32 models (e.g. cyankiwi/Qwen3-VL-4B-Instruct-AWQ-4bit, cyankiwi/gemma-4-31B-it-AWQ-4bit) reach the CK kernel.
* bf16 enable: the wrapper's bf16 branch previously TORCH_CHECK'd because the CK submodule didn't ship bf16 DequantPack8 overloads. CK c387fb4 added them; this commit lights up the wrapper's bf16 path with run_kernel<B16>(...).
* TruncateBf16Round (false | true): selects DequantPack8WithZpTruncate vs DequantPack8WithZp at template-instantiation time. After the CK threadwise edits (commit 2cfd5509f) this is a true runtime switch: both flavors live in the same .so as distinct template-mangled symbols. Skips the IEEE round-to-nearest-even chain in the bf16 dequant (worst-case 4e-3 ULP error, inside the W4A16 op-test tolerance). CK analog of vLLM PR ROCm/vllm#953's Triton-side fp32->bf16 truncate.
* PreDequantToLDS (false | true): stub for the pre-dequant-to-LDS variant (Stage 1: load packed B + scales -> dequant in registers -> store to LDS scratch; Stage 2: standard non-b_scale wmma reading bf16 from LDS). Currently the true specialization aliases to the false device-op + TORCH_CHECKs at dispatch with a TODO(AIESW-32282) pointing at the two implementation strategies. Template surface is in place so a follow-up agent can drop in the kernel body without touching the wrapper.

Public op signature (torch.ops aiter._C.gemm_w4a16) gains two trailing optional bool args:

    pre_dequant_to_lds: Optional[bool] = None
    truncate_bf16_round: Optional[bool] = None

Default behaviour is bit-identical to before for callers that don't pass them. The vLLM dispatcher (matthias.ck-w4a16-aiter-bench branch) flips these via VLLM_CK_W4A16_PRE_DEQUANT and VLLM_CK_W4A16_TRUNCATE_BF16 env vars; the latter is default-on after measurement showed CK_TRUNC is the only bf16 config that beats Triton on gfx1151.

Smoke test op_tests/test_gemm_w4a16.py covers all 16 combos (sym/asym x fp16/bf16 x G=32/G=128 x RTE/TRUNC) at TOL_REL=5e-3.

Changes:

- csrc/ck_w4a16/include/gemm_w4a16_common.cuh: 4-axis DeviceGemmInstanceImpl<T, ScaleBlockK, PreDequantToLDS, TruncateBf16Round> with explicit specializations + inline comments on every CK device-op template arg.
- csrc/ck_w4a16/gemm_w4a16.cu: 2x2x2 runtime dispatch on group_size, pre_dequant_to_lds, truncate_bf16_round; PreDequantToLDS=true path TORCH_CHECK'd as not yet implemented.
- csrc/ck_w4a16/include/gemm_w4a16.h: torch op signature gains the two trailing optionals.
- csrc/include/rocm_ops.hpp: pybind binding + py::arg defaults.
- aiter/ops/gemm_w4a16.py: Python wrapper + fake gain pre_dequant_to_lds / truncate_bf16_round optionals.
- op_tests/test_gemm_w4a16.py: --pre-dequant / --truncate-bf16 CLI flags adding parallel test axes.
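
For readers unfamiliar with the RTE vs. truncate distinction behind TruncateBf16Round, here is a minimal NumPy sketch of the two fp32->bf16 conversions at the bit level. It is an illustration only, not the CK DequantPack8WithZpTruncate code; the function names are hypothetical, and NaN/Inf handling is deliberately left out.

```python
import numpy as np


def fp32_to_bf16_bits_truncate(x):
    """Keep the top 16 bits of each fp32 value (rounds toward zero)."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits >> 16).astype(np.uint16)


def fp32_to_bf16_bits_rte(x):
    """Round to nearest, ties to even, then keep the top 16 bits.

    Sketch only: NaN and Inf inputs are not special-cased here.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    # Bias is 0x7FFF, plus 1 when the bit that will become the bf16 LSB is set.
    rounding_bias = np.uint32(0x7FFF) + ((bits >> 16) & np.uint32(1))
    return ((bits + rounding_bias) >> 16).astype(np.uint16)


def bf16_bits_to_fp32(b):
    """Widen bf16 bit patterns back to fp32 by zero-filling the low half."""
    return (np.asarray(b, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)


# Compare the two conversions on some positive test values.
rng = np.random.default_rng(0)
x = rng.uniform(0.5, 100.0, size=1000).astype(np.float32)
err_trunc = np.abs(bf16_bits_to_fp32(fp32_to_bf16_bits_truncate(x)) - x) / x
err_rte = np.abs(bf16_bits_to_fp32(fp32_to_bf16_bits_rte(x)) - x) / x
print("max relative error, truncate:", err_trunc.max())
print("max relative error, RTE:     ", err_rte.max())
```

Truncation is a plain shift, while RTE needs an extra add and a bit test per element. That extra work is what the truncate path avoids, at the price of a larger rounding error.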
Native-bf16 asymmetric w4a16 models (RedHatAI/Qwen3-8B-quantized.w4a16,
Orion-zhen/Qwen3-1.7B-AWQ, cyankiwi/Qwen3-VL-4B-Instruct-AWQ-4bit, ...)
suffered a ~15% prefill TTFT regression on Strix Halo (gfx1151) vs
forcing the same models to fp16, even when the kernel autotune config
was retuned for bf16.
The bottleneck is the HAS_ZP=True branch of _triton_w4a16_skinny_fmt_kernel:
b_fp = (b.to(scales.dtype) - zp_raw[:, None]) * scales[:, None]
For bf16 this casts the full [BLOCK_N, BLOCK_K] int32 nibble tile to
bf16 before subtracting the zero-point. On RDNA3.5 the compiler then
emits an inner loop of:
    v_cvt_f32_ubyte0      ; int -> fp32
    v_cvt_f16_f32         ; explicit narrowing
    v_sub_bf16 (slow)     ; bf16 ALU goes via fp32
    v_pk_mul_bf16
Mirroring the symmetric path -- subtract in INT first (zp values are
0..15, fits exactly in any int width), then cast once after the subtract
-- lets LLVM fold the conversion into v_dot2_bf16_bf16, since bf16 is
the top 16 bits of fp32 and v_cvt_f32_i32 implicitly produces a usable
bf16 representation. The inner loop collapses to:
    v_sub_nc_u32          ; int sub (cheap)
    v_cvt_f32_i32         ; int -> fp32 (bf16 lives in high half, free)
    v_dot2_bf16_bf16      ; fused MAC
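
The reordering is numerically safe because nibbles and zero-points both lie in 0..15, so their difference lies in -15..15 and is exactly representable in bf16. The NumPy sketch below checks that the two orderings agree; it emulates bf16 by truncating fp32 (NumPy has no native bf16), the helper names are hypothetical, and it says nothing about which instructions the compiler selects.

```python
import numpy as np


def to_bf16(x):
    """Emulate bf16 by zeroing the low 16 bits of fp32 (truncation)."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)


def dequant_cast_first(b, zp, scales):
    # Original order: cast the whole nibble tile, then subtract in floating point.
    return to_bf16((to_bf16(b) - to_bf16(zp)[:, None]) * scales[:, None])


def dequant_int_sub_first(b, zp, scales):
    # Reordered: subtract in int32, then cast the difference once.
    return to_bf16(to_bf16(b - zp[:, None]) * scales[:, None])


rng = np.random.default_rng(0)
BLOCK_N, BLOCK_K = 8, 32  # illustrative tile sizes, not the tuned config
b = rng.integers(0, 16, size=(BLOCK_N, BLOCK_K), dtype=np.int32)
zp = rng.integers(0, 16, size=BLOCK_N, dtype=np.int32)
scales = to_bf16(rng.uniform(0.001, 0.1, size=BLOCK_N))

same = np.array_equal(
    dequant_cast_first(b, zp, scales),
    dequant_int_sub_first(b, zp, scales),
)
print("orderings agree:", same)
```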
Kernel microbench (M=128, gfx1151, Strix Halo, asymmetric Qwen3-8B
prefill projection shapes, group_size=128):
    gate_up (24576x4096): 1519 -> 1271 us (-16.3%)
    down    (4096x12288):  815 ->  685 us (-16.0%)
    QKV     (6144x4096):   371 ->  324 us (-12.7%)
    o_proj  (4096x4096):   271 ->  234 us (-13.7%)
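
The percentages above can be re-derived from the raw timings with a few lines of Python. This is only an arithmetic check; rows are keyed by shape and `pct_change` is a hypothetical helper, not part of the benchmark suite.

```python
# Before/after kernel times in microseconds, keyed by weight shape.
timings_us = {
    "24576x4096": (1519, 1271),
    "4096x12288": (815, 685),
    "6144x4096": (371, 324),
    "4096x4096": (271, 234),
}


def pct_change(before, after):
    """Relative change in percent; negative means faster."""
    return (after - before) / before * 100.0


changes = {shape: round(pct_change(*t), 1) for shape, t in timings_us.items()}
for shape, pct in changes.items():
    print(f"{shape}: {pct:+.1f}%")
```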
The fp16 branch is left untouched on purpose: switching fp16 to the
int-subtract-first form does NOT collapse to a packed fused MAC (no
v_dot2_f16_f16 on gfx1151, and fp16 is not a sub-bit-pattern of fp32 so
the v_cvt_f16_f32 narrowing is unavoidable). A unified rewrite
regressed fp16 by 6-9% per shape in microbench.
End-to-end on RedHatAI/Qwen3-8B-quantized.w4a16 (input_len=128
output_len=128 num_prompts=10 max_num_seqs=1, Strix Halo):
                          bf16        fp16
    Median TTFT before    137.6 ms    118.8 ms
    Median TTFT after     117 ms      119.5 ms   (bf16 -15.0%, fp16 noise)
    Median TPOT           24.01 ms    24.07 ms   (unchanged, decode path)
bf16 is now ~1.5% faster than fp16, so the dtype: float16 pin can be
removed from this model in rocm-scripts/regression_config.yaml.
Correctness: tests/kernels/quantization/test_hybrid_w4a16_triton.py
(20 cases covering sym/asym x fp16/bf16 x 5 shapes) all pass.
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Extends the perf-regression suite so the asymmetric-bf16 fast path introduced in the prior commit is guarded by golden TFLOP/s baselines, not just by the kernel-fix microbench.

Parametrizes weight allocation and measurement on activation dtype (_provider_dtype helper) and doubles PROVIDERS to four variants: hybrid-w4a16, hybrid-w4a16-zp, hybrid-w4a16-bf16, hybrid-w4a16-zp-bf16.

Adds the four Qwen3-8B prefill projection shapes (qkv 4096x6144, o 4096x4096, gate_up 4096x24576, down 12288x4096) that the prior commit's microbench reports on.

Golden updates on gfx1151:
- Old shapes keep their existing fp16 and fp16-zp baselines untouched; bf16 and bf16-zp providers are added.
- New Qwen3-8B shapes get all four providers.
- The PR #951 Qwen3.5-35B-A3B shape gets bf16 providers added on top of its existing fp16 baselines.

Documents the -zp/-bf16 provider suffix convention in golden/README.md.

22 shapes x 4 providers x 13 batch sizes = 1144 baselines. Full-suite verify on gfx1151: 88 passed, 0 skipped, 0 failed; ~7 min wall time.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
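
As a rough illustration of the -zp/-bf16 suffix convention and the baseline count, the sketch below decodes a provider name into a zero-point flag and an activation dtype name. `parse_provider` is a hypothetical stand-in and is not the suite's _provider_dtype helper, whose real signature is not shown in this PR text.

```python
PROVIDERS = [
    "hybrid-w4a16",
    "hybrid-w4a16-zp",
    "hybrid-w4a16-bf16",
    "hybrid-w4a16-zp-bf16",
]


def parse_provider(name):
    """Return (has_zero_point, activation_dtype_name) from the name suffixes.

    Assumption: anything after the 'hybrid-w4a16' base is a suffix, '-zp'
    means asymmetric (zero-point) weights, and '-bf16' means bfloat16
    activations, with float16 as the default.
    """
    suffixes = name.split("-")[2:]
    has_zp = "zp" in suffixes
    dtype_name = "bfloat16" if "bf16" in suffixes else "float16"
    return has_zp, dtype_name


for provider in PROVIDERS:
    print(provider, parse_provider(provider))

# Counts quoted in the commit message.
num_shapes, num_batch_sizes = 22, 13
num_baselines = num_shapes * len(PROVIDERS) * num_batch_sizes
num_test_cases = num_shapes * len(PROVIDERS)
print(num_baselines, num_test_cases)
```

If the suite creates one test per (shape, provider) pair, the 88 passing tests would line up with 22 shapes x 4 providers. That pairing is an inference from the numbers, not something the commit message states.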
1bb9844 to 9351c8d
Improve bf16 perf in w4a16 triton kernel and add bfloat16 support to benchmark infra.