From 5b2663590d45bb32e7454124648319b05f845a5b Mon Sep 17 00:00:00 2001 From: LeanBitLab <245915690+LeanBitLab@users.noreply.github.com> Date: Mon, 27 Apr 2026 20:31:24 +0000 Subject: [PATCH] Fix ONNX model cache to only use valid PKV when Sequence Length > 0 --- .../keyboard/latin/utils/ProofreadService.kt | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/app/src/offline/java/helium314/keyboard/latin/utils/ProofreadService.kt b/app/src/offline/java/helium314/keyboard/latin/utils/ProofreadService.kt index aa43f2dd3..35b69bd32 100644 --- a/app/src/offline/java/helium314/keyboard/latin/utils/ProofreadService.kt +++ b/app/src/offline/java/helium314/keyboard/latin/utils/ProofreadService.kt @@ -457,8 +457,19 @@ class ProofreadService(private val context: Context) { for (step in 0 until maxTokens) { // For KV-cache models, only pass the last token after first step + var isValidPkv = false + if (hasPkvInputs && pastKeyValues != null) { + val currentPkv = pastKeyValues!!.values.firstOrNull() + if (currentPkv != null) { + val sequenceLength = currentPkv.info.shape[2] + if (sequenceLength > 0) { + isValidPkv = true + } + } + } + // CRITICAL FIX: Only use valid pastKeyValues if model actually accepts PKV inputs - val inputTokens = if (step > 0 && pastKeyValues != null && hasPkvInputs) { + val inputTokens = if (step > 0 && isValidPkv) { longArrayOf(generatedTokens.last()) } else { generatedTokens.toLongArray() @@ -488,7 +499,7 @@ class ProofreadService(private val context: Context) { } // Add past_key_values from previous step (if available and model expects them) - if (hasPkvInputs && pastKeyValues != null) { + if (isValidPkv) { for ((name, tensor) in pastKeyValues!!) { // Map present.X.* output names to past_key_values.X.* or pkv_* input names val inputName = name.replace("present", "past_key_values") @@ -502,8 +513,8 @@ class ProofreadService(private val context: Context) { } } } - } else if (hasPkvInputs && step == 0) { - // First step with PKV model: provide zero tensors + } else if (hasPkvInputs) { + // First step with PKV model or invalid cache: provide zero tensors // T5 pkv format: pkv_0 to pkv_N where first half is decoder self-attn, second half is encoder cross-attn // Shape: [batch, num_heads, seq_len, head_dim]