diff --git a/src/tilegym/ops/cutile/flash_decode.py b/src/tilegym/ops/cutile/flash_decode.py
index 3740da6..833dd3d 100644
--- a/src/tilegym/ops/cutile/flash_decode.py
+++ b/src/tilegym/ops/cutile/flash_decode.py
@@ -97,7 +97,8 @@ def attention_decode_kernel_grouped(
     # Compute qk - unconditional execution enables Tensor Core usage
     # (HEAD_DIM, QUERY_GROUP_TILE_SIZE) @ (TILE_N, HEAD_DIM).T
     # Result: (TILE_N, QUERY_GROUP_TILE_SIZE)
-    qk = ct.matmul(k, q)
+    qk = ct.full((TILE_N, QUERY_GROUP_TILE_SIZE), 0.0, dtype=ct.float32)
+    qk = ct.mma(k, q, qk)
 
     # Process boundary case (non-causal) - apply mask to result only
     if curr_n + TILE_N > S_kv: