Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,19 @@ class ProofreadService(private val context: Context) {

for (step in 0 until maxTokens) {
// For KV-cache models, only pass the last token after the first step
var isValidPkv = false
if (hasPkvInputs && pastKeyValues != null) {
val currentPkv = pastKeyValues!!.values.firstOrNull()
if (currentPkv != null) {
val sequenceLength = currentPkv.info.shape[2]
if (sequenceLength > 0) {
isValidPkv = true
}
}
}

// CRITICAL FIX: Only reuse pastKeyValues when the model accepts PKV inputs and the cached tensor is non-empty (shape[2] > 0)
val inputTokens = if (step > 0 && pastKeyValues != null && hasPkvInputs) {
val inputTokens = if (step > 0 && isValidPkv) {
longArrayOf(generatedTokens.last())
} else {
generatedTokens.toLongArray()
Expand Down Expand Up @@ -488,7 +499,7 @@ class ProofreadService(private val context: Context) {
}

// Add past_key_values from previous step (if available and model expects them)
if (hasPkvInputs && pastKeyValues != null) {
if (isValidPkv) {
for ((name, tensor) in pastKeyValues!!) {
// Map present.X.* output names to past_key_values.X.* or pkv_* input names
val inputName = name.replace("present", "past_key_values")
Expand All @@ -502,8 +513,8 @@ class ProofreadService(private val context: Context) {
}
}
}
} else if (hasPkvInputs && step == 0) {
// First step with PKV model: provide zero tensors
} else if (hasPkvInputs) {
// First step with PKV model or invalid cache: provide zero tensors
// T5 pkv format: pkv_0 to pkv_N where first half is decoder self-attn, second half is encoder cross-attn
// Shape: [batch, num_heads, seq_len, head_dim]

Expand Down