From c26c1e86b5f98a012ae94f005e9e3611e489646f Mon Sep 17 00:00:00 2001 From: Jinman Xie Date: Thu, 29 Jan 2026 20:48:39 -0800 Subject: [PATCH] fix qwen2 fp16 bug --- src/tilegym/ops/cutile/flash_decode.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tilegym/ops/cutile/flash_decode.py b/src/tilegym/ops/cutile/flash_decode.py index 3740da6..833dd3d 100644 --- a/src/tilegym/ops/cutile/flash_decode.py +++ b/src/tilegym/ops/cutile/flash_decode.py @@ -97,7 +97,8 @@ def attention_decode_kernel_grouped( # Compute qk - unconditional execution enables Tensor Core usage # (HEAD_DIM, QUERY_GROUP_TILE_SIZE) @ (TILE_N, HEAD_DIM).T # Result: (TILE_N, QUERY_GROUP_TILE_SIZE) - qk = ct.matmul(k, q) + qk = ct.full((TILE_N, QUERY_GROUP_TILE_SIZE), 0.0, dtype=ct.float32) + qk = ct.mma(k, q, qk) # Process boundary case (non-causal) - apply mask to result only if curr_n + TILE_N > S_kv: