[Bug]: MLXLLM Gemma 4 attention crashes when kvBits is set #71

@notatestuser

Description

Summary

Gemma 4 models loaded via LLMModelFactory crash on the first decode step when GenerateParameters.kvBits is non-nil:

MLXLMCommon/KVCache.swift:1063: Fatal error: `update` was called on `QuantizedKVCache`. Use `updateQuantized` instead.

The fault is in Libraries/MLXLLM/Models/Gemma4Text.swift's attention forward: it calls cache.update(keys:values:) unconditionally, a method that QuantizedKVCache deliberately hard-traps. The sibling VLM Gemma 4 attention in Libraries/MLXVLM/Models/Gemma4.swift already implements the correct dispatch; it just needs to be ported over.

Affected versions

  • mlx-swift-lm tag b436 (latest at time of filing).
  • Presumably every earlier tag that ships the current Gemma4Text.swift attention forward.

Affected models

Every Gemma 4 variant that routes through LLMModelFactory (i.e. model_type == "gemma4" with LLM loading, or model_type == "gemma4_text"):

  • mlx-community/gemma-4-e4b-it-8bit (dense, 42 layers, 8 attention heads) — verified
  • mlx-community/gemma-4-26b-a4b-it-8bit (MoE) — verified

The bug is not MoE-specific — the dense E4B variant crashes at the same source line. It is a property of the shared attention forward, not of the MoE layers.

Not affected

Every Gemma 4 variant loaded via VLMModelFactory. Libraries/MLXVLM/Models/Gemma4.swift's Gemma4TextAttention.callAsFunction already checks for QuantizedKVCacheProtocol and dispatches accordingly.

Reproduction

Minimal, programmatic — no test harness or CLI required:

```swift
import MLX
import MLXLLM
import MLXLMCommon

let config = ModelConfiguration(id: "mlx-community/gemma-4-e4b-it-8bit")
let container = try await LLMModelFactory.shared.loadContainer(configuration: config)

let parameters = GenerateParameters(
    maxTokens: 8,
    kvBits: 4,            // ← any non-nil value triggers the crash
    quantizedKVStart: 32
)

let userInput = UserInput(chat: [.init(role: .user, content: "hi")])
let lmInput = try await container.prepare(input: userInput)

let stream = try await container.perform { ctx in
    try MLXLMCommon.generate(input: lmInput, parameters: parameters, context: ctx)
}
for await _ in stream { }   // fatalError fires on the first decoded token
```

Setting kvBits: nil makes the same configuration work, which isolates the failure to the quantized-cache code path.

Root cause

Libraries/MLXLLM/Models/Gemma4Text.swift around line 337, inside Gemma4TextAttention.callAsFunction:

```swift
if let cache {
    let (updatedK, updatedV) = cache.update(keys: k, values: v)
    keys = updatedK
    values = updatedV
} else {
    keys = k
    values = v
}
```

KVCache is a protocol. When GenerateParameters.kvBits is set, the model's newCache(...) returns a concrete QuantizedKVCache. That class overrides update(keys:values:) to trap rather than silently run the wrong code path:

```swift
// Libraries/MLXLMCommon/KVCache.swift:1060-1066
/// This method is required by the KVCache protocol, but it is not intended to
/// be used with QuantizedKVCache.
/// Use `updateQuantized` instead.
public override func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) {
    fatalError(
        "`update` was called on `QuantizedKVCache`. Use `updateQuantized` instead."
    )
}
```

The KVCache API contract is that quantized caches must be dispatched to updateQuantized(keys:values:), whose QuantizedKVCache.Key / .Value tuples then flow into quantizedScaledDotProductAttention(queries:quantizedKeys:quantizedValues:...) rather than MLXFast.scaledDotProductAttention(...).
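The dispatch requirement can be illustrated with a toy, MLX-free sketch (all names here are illustrative, not the real library types): a protocol-required method that a quantized conformer deliberately traps, so call sites must downcast before deciding which update to call.

```swift
// Toy model of the KVCache contract: `update` is protocol-required, but the
// quantized implementation traps it and offers `updateQuantized` instead.
protocol ToyCache {
    func update(_ x: Int) -> Int
}

protocol ToyQuantizedCacheProtocol: ToyCache {
    func updateQuantized(_ x: Int) -> Int
}

struct ToyPlainCache: ToyCache {
    func update(_ x: Int) -> Int { x + 1 }
}

struct ToyQuantizedCache: ToyQuantizedCacheProtocol {
    // Mirrors QuantizedKVCache: the protocol-required method hard-traps.
    func update(_ x: Int) -> Int {
        fatalError("`update` was called on ToyQuantizedCache. Use `updateQuantized` instead.")
    }
    func updateQuantized(_ x: Int) -> Int { x + 2 }
}

func step(cache: ToyCache, x: Int) -> Int {
    // The safe pattern: downcast first, exactly as the VLM attention does.
    if let quantized = cache as? ToyQuantizedCacheProtocol {
        return quantized.updateQuantized(x)
    }
    return cache.update(x)
}
```

Calling `step` with either cache type is safe; calling `cache.update(...)` unconditionally, as the LLM attention does, traps on the quantized conformer.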

Correct pattern — already present in the VLM sibling

Libraries/MLXVLM/Models/Gemma4.swift around lines 690-706 (Gemma4TextAttention.callAsFunction in the VLM file) handles both cache types correctly:

```swift
keys = rope(keys, offset: currentOffset)
if let quantizedCache = cache as? QuantizedKVCacheProtocol {
    let (quantizedKeys, quantizedValues) = quantizedCache.updateQuantized(
        keys: keys, values: values)
    kvState = .quantized(
        keys: quantizedKeys,
        values: quantizedValues,
        groupSize: quantizedCache.groupSize,
        bits: quantizedCache.bits,
        mode: quantizedCache.mode
    )
} else {
    if let cache {
        (keys, values) = cache.update(keys: keys, values: values)
    }
    kvState = .regular(keys: keys, values: values)
}
```

…and then switches on kvState to route the downstream scaledDotProductAttention call:

```swift
switch kvState {
case .regular(let keys, let values):
    MLXFast.scaledDotProductAttention(queries: queries, keys: keys, values: values, scale: scale, mask: localMask)
case .quantized(let keys, let values, let groupSize, let bits, let mode):
    quantizedScaledDotProductAttention(queries: queries, quantizedKeys: keys, quantizedValues: values, ...)
}
```

Suggested fix

Port the VLM's as? QuantizedKVCacheProtocol dispatch + quantizedScaledDotProductAttention branch into Libraries/MLXLLM/Models/Gemma4Text.swift's Gemma4TextAttention.callAsFunction. The change is local to that one function.

Rough shape of the minimal-diff:

```swift
// Replace the unconditional `cache.update(...)` branch around line 336-343 with:
let kvState: Gemma4SharedKVState
if let quantizedCache = cache as? QuantizedKVCacheProtocol {
    let (qKeys, qValues) = quantizedCache.updateQuantized(keys: k, values: v)
    kvState = .quantized(
        keys: qKeys, values: qValues,
        groupSize: quantizedCache.groupSize,
        bits: quantizedCache.bits,
        mode: quantizedCache.mode
    )
} else if let cache {
    let (updatedK, updatedV) = cache.update(keys: k, values: v)
    keys = updatedK
    values = updatedV
    kvState = .regular(keys: keys, values: values)
} else {
    keys = k
    values = v
    kvState = .regular(keys: keys, values: values)
}

// …then switch `kvState` where MLXFast.scaledDotProductAttention(...) is currently called.
```

// …then switch `kvState` where MLXFast.scaledDotProductAttention(...) is currently called.

Gemma4SharedKVState already exists in the VLM Gemma 4 file; it can be hoisted to a shared location or re-declared locally in Gemma4Text.swift.
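If re-declared locally, the enum would look roughly like the following. This is a sketch: the associated values of the quantized case are assumptions and must match exactly what `updateQuantized(keys:values:)` returns and what `quantizedScaledDotProductAttention` expects in the installed MLXLMCommon version.

```swift
import MLX

// Hypothetical local re-declaration for Gemma4Text.swift, mirroring the shape
// used in the VLM file. Tuple layout and the `mode` type name are assumptions;
// copy the real declaration from Libraries/MLXVLM/Models/Gemma4.swift if hoisting
// is not an option.
enum Gemma4SharedKVState {
    case regular(keys: MLXArray, values: MLXArray)
    case quantized(
        keys: (MLXArray, MLXArray, MLXArray),     // packed weights, scales, biases (assumed layout)
        values: (MLXArray, MLXArray, MLXArray),
        groupSize: Int,
        bits: Int,
        mode: QuantizationMode                    // assumed name; use whatever type `quantizedCache.mode` exposes
    )
}
```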

Current workarounds

  1. Disable KV quantization on the LLM path: set GenerateParameters.kvBits = nil when the target model is Gemma 4 via LLMModelFactory. The model runs correctly but uses the FP16 KV cache (2–4× more memory on long contexts).
  2. Load via VLMModelFactory: the same weights load through the VLM path, where the attention forward handles quantized caches correctly.

Both workarounds produce identical output since the weights are the same; they just take different forward-pass code paths.
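Workaround 1 can be applied mechanically at the call site. A sketch, assuming `GenerateParameters.kvBits` is a mutable property; the substring match on the repo id is an ad-hoc heuristic for this bug, not a library API:

```swift
import MLXLMCommon

// Sketch: fall back to the unquantized FP16 cache only for Gemma 4 models on the
// LLM path, leaving kvBits intact for every other model. The `gemma-4` substring
// check is an ad-hoc heuristic, not a library API.
func parametersAvoidingGemma4Crash(
    modelID: String, base: GenerateParameters
) -> GenerateParameters {
    var parameters = base
    if modelID.lowercased().contains("gemma-4") {
        parameters.kvBits = nil   // any non-nil value would hit the QuantizedKVCache trap
    }
    return parameters
}
```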
