Summary
Gemma 4 models loaded via LLMModelFactory crash on the first decode step when GenerateParameters.kvBits is non-nil:
MLXLMCommon/KVCache.swift:1063: Fatal error: `update` was called on `QuantizedKVCache`. Use `updateQuantized` instead.
The fault lies in Libraries/MLXLLM/Models/Gemma4Text.swift's attention forward: it calls cache.update(keys:values:) unconditionally, which is hard-trapped by QuantizedKVCache. The sibling VLM Gemma 4 attention in Libraries/MLXVLM/Models/Gemma4.swift already implements the correct dispatch — it just needs to be ported over.
Affected versions
- mlx-swift-lm tag b436 (latest at time of filing).
- Presumably every earlier tag that ships the current Gemma4Text.swift attention forward.
Affected models
Every Gemma 4 variant that routes through LLMModelFactory (i.e. model_type == "gemma4" with LLM loading, or model_type == "gemma4_text"):
- mlx-community/gemma-4-e4b-it-8bit (dense, 42 layers, 8 attention heads) — verified
- mlx-community/gemma-4-26b-a4b-it-8bit (MoE) — verified
The bug is not MoE-specific — the dense E4B variant crashes at the same source line. It is a property of the shared attention forward, not of the MoE layers.
Not affected
Every Gemma 4 variant loaded via VLMModelFactory. Libraries/MLXVLM/Models/Gemma4.swift's Gemma4TextAttention.callAsFunction already checks for QuantizedKVCacheProtocol and dispatches accordingly.
Reproduction
Minimal, programmatic — no test harness or CLI required:
import MLX
import MLXLLM
import MLXLMCommon
let config = ModelConfiguration(id: "mlx-community/gemma-4-e4b-it-8bit")
let container = try await LLMModelFactory.shared.loadContainer(configuration: config)
let parameters = GenerateParameters(
maxTokens: 8,
kvBits: 4, // ← any non-nil value triggers it
quantizedKVStart: 32
)
let userInput = UserInput(chat: [.init(role: .user, content: "hi")])
let lmInput = try await container.prepare(input: userInput)
let stream = try await container.perform { ctx in
try MLXLMCommon.generate(input: lmInput, parameters: parameters, context: ctx)
}
for await _ in stream { } // fatalError fires on the first decoded token
Setting kvBits: nil makes the same configuration work perfectly, which isolates the fault to the quantized-cache code path.
Root cause
Libraries/MLXLLM/Models/Gemma4Text.swift around line 337, inside Gemma4TextAttention.callAsFunction:
if let cache {
let (updatedK, updatedV) = cache.update(keys: k, values: v)
keys = updatedK
values = updatedV
} else {
keys = k
values = v
}
KVCache is a protocol. When GenerateParameters.kvBits is set, the model's newCache(...) returns a concrete QuantizedKVCache. That class overrides update(keys:values:) to trap rather than silently run the wrong code path:
// Libraries/MLXLMCommon/KVCache.swift:1060-1066
/// This method is required by the KVCache protocol, but it is not intended to
/// be used with QuantizedKVCache.
/// Use `updateQuantized` instead.
public override func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) {
fatalError(
"`update` was called on `QuantizedKVCache`. Use `updateQuantized` instead."
)
}
The KVCache API contract is that quantized caches must be dispatched to updateQuantized(keys:values:), whose QuantizedKVCache.Key / .Value tuples then flow into quantizedScaledDotProductAttention(queries:quantizedKeys:quantizedValues:...) rather than MLXFast.scaledDotProductAttention(...).
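The dispatch contract can be illustrated with a self-contained sketch. All types here are mocks invented for illustration (not the real MLX API): the point is that the caller must downcast the protocol-typed cache with `as?` before calling `update`, because the quantized concrete type deliberately traps on the plain path.

```swift
// Mock stand-ins for the real KVCache hierarchy (names are hypothetical).
protocol MockKVCache {
    func update(keys: [Float], values: [Float]) -> ([Float], [Float])
}

protocol MockQuantizedKVCacheProtocol: MockKVCache {
    func updateQuantized(keys: [Float], values: [Float]) -> String
}

struct PlainCache: MockKVCache {
    func update(keys: [Float], values: [Float]) -> ([Float], [Float]) {
        (keys, values)  // plain caches run the regular path
    }
}

struct QuantizedCache: MockQuantizedKVCacheProtocol {
    func update(keys: [Float], values: [Float]) -> ([Float], [Float]) {
        fatalError("`update` was called on a quantized cache")  // mirrors the real trap
    }
    func updateQuantized(keys: [Float], values: [Float]) -> String {
        "quantized(\(keys.count) keys)"
    }
}

// The caller must branch BEFORE calling `update`, as the VLM attention does:
func dispatch(_ cache: MockKVCache) -> String {
    if let quantized = cache as? MockQuantizedKVCacheProtocol {
        return quantized.updateQuantized(keys: [1], values: [2])
    }
    let (k, _) = cache.update(keys: [1], values: [2])
    return "regular(\(k.count) keys)"
}

print(dispatch(PlainCache()))      // regular(1 keys)
print(dispatch(QuantizedCache()))  // quantized(1 keys)
```

The buggy LLM attention is the equivalent of calling `dispatch` without the `as?` branch: any `QuantizedCache` instance reaches `update` and trips the `fatalError`.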
Correct pattern — already present in the VLM sibling
Libraries/MLXVLM/Models/Gemma4.swift around lines 690-706 (Gemma4TextAttention.callAsFunction in the VLM file) handles both cache types correctly:
keys = rope(keys, offset: currentOffset)
if let quantizedCache = cache as? QuantizedKVCacheProtocol {
let (quantizedKeys, quantizedValues) = quantizedCache.updateQuantized(
keys: keys, values: values)
kvState = .quantized(
keys: quantizedKeys,
values: quantizedValues,
groupSize: quantizedCache.groupSize,
bits: quantizedCache.bits,
mode: quantizedCache.mode
)
} else {
if let cache {
(keys, values) = cache.update(keys: keys, values: values)
}
kvState = .regular(keys: keys, values: values)
}
…and then switches on kvState to route the downstream scaledDotProductAttention call:
switch kvState {
case .regular(let keys, let values):
MLXFast.scaledDotProductAttention(queries: queries, keys: keys, values: values, scale: scale, mask: localMask)
case .quantized(let keys, let values, let groupSize, let bits, let mode):
quantizedScaledDotProductAttention(queries: queries, quantizedKeys: keys, quantizedValues: values, ...)
}
Suggested fix
Port the VLM's as? QuantizedKVCacheProtocol dispatch + quantizedScaledDotProductAttention branch into Libraries/MLXLLM/Models/Gemma4Text.swift's Gemma4TextAttention.callAsFunction. The change is local to that one function.
Rough shape of the minimal diff:
// Replace the unconditional `cache.update(...)` branch around line 336-343 with:
let kvState: Gemma4SharedKVState
if let quantizedCache = cache as? QuantizedKVCacheProtocol {
let (qKeys, qValues) = quantizedCache.updateQuantized(keys: k, values: v)
kvState = .quantized(
keys: qKeys, values: qValues,
groupSize: quantizedCache.groupSize,
bits: quantizedCache.bits,
mode: quantizedCache.mode
)
} else if let cache {
let (updatedK, updatedV) = cache.update(keys: k, values: v)
keys = updatedK
values = updatedV
kvState = .regular(keys: keys, values: values)
} else {
keys = k
values = v
kvState = .regular(keys: keys, values: values)
}
// …then switch `kvState` where MLXFast.scaledDotProductAttention(...) is currently called.
Gemma4SharedKVState already exists in the VLM Gemma 4 file; it can be hoisted to a shared location or re-declared locally in Gemma4Text.swift.
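If it is re-declared locally, a minimal sketch might look like the following. This uses mock payload types because the real associated values (MLXArray and the quantized key/value tuples) live in MLX; the triple layout for quantized tensors and the omission of `mode` are assumptions, not the real declaration.

```swift
// Mock stand-in for MLXArray, for illustration only.
struct MockArray { var count: Int }

// Hypothetical local re-declaration mirroring the VLM dispatch code quoted above.
enum Gemma4SharedKVState {
    case regular(keys: MockArray, values: MockArray)
    case quantized(
        keys: (MockArray, MockArray, MockArray),    // assumed (weights, scales, biases)
        values: (MockArray, MockArray, MockArray),
        groupSize: Int,
        bits: Int
    )
}

// Downstream, attention switches on the state to pick the SDPA variant:
func attentionPath(for state: Gemma4SharedKVState) -> String {
    switch state {
    case .regular:
        return "MLXFast.scaledDotProductAttention"
    case .quantized:
        return "quantizedScaledDotProductAttention"
    }
}
```

Hoisting the real enum into MLXLMCommon would avoid the duplicate declaration, at the cost of a cross-library API change; a local copy keeps the fix confined to Gemma4Text.swift.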
Current workarounds
- Turn KV quantization off on the LLM path: set GenerateParameters.kvBits = nil when the target model is Gemma 4 via LLMModelFactory. The model runs correctly but uses the FP16 KV cache (2–4× more memory on long contexts).
- Load via VLMModelFactory: the same weights load through the VLM path, where the attention forward handles quantized caches correctly.
Both workarounds produce identical output since the weights are the same — they just take different forward-pass code paths.
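Until the fix lands, the first workaround can be applied mechanically with a small call-site guard. This is a hedged sketch, not part of the mlx-swift-lm API; `ModelKind` and `safeKVBits` are hypothetical names.

```swift
// Hypothetical classification of the loaded model (not an mlx-swift-lm type).
enum ModelKind {
    case gemma4TextViaLLM  // Gemma 4 routed through LLMModelFactory
    case other
}

// Drop the requested kvBits only on the affected path, so every other
// model keeps its quantized KV cache.
func safeKVBits(requested: Int?, for kind: ModelKind) -> Int? {
    if kind == .gemma4TextViaLLM {
        return nil  // fall back to the FP16 cache to avoid the fatalError
    }
    return requested
}

print(safeKVBits(requested: 4, for: .gemma4TextViaLLM) as Any)  // nil
print(safeKVBits(requested: 4, for: .other) as Any)             // Optional(4)
```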
References
- Unconditional cache.update call: Libraries/MLXLLM/Models/Gemma4Text.swift:337
- QuantizedKVCache.update trap: Libraries/MLXLMCommon/KVCache.swift:1060-1066
- Correct dispatch pattern: Libraries/MLXVLM/Models/Gemma4.swift:690-706
- StreamableMoE conformance in b434.