53 changes: 50 additions & 3 deletions fast_llm_external_models/apriel2/vllm/modeling_apriel2.py
@@ -585,8 +585,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
return loaded


class Apriel2Attention(nn.Module):
"""Apriel2 attention layer with rotary embeddings and GQA support."""
class Apriel2Attention(nn.Module, AttentionLayerBase):
"""Apriel2 attention layer with rotary embeddings and GQA support.

Inherits from AttentionLayerBase so that vLLM calls our get_kv_cache_spec(),
which returns the unified block size needed for hybrid models.
"""

def __init__(
self,
@@ -598,6 +602,7 @@ def __init__(
prefix: str = "",
) -> None:
super().__init__()
self.prefix = prefix
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
@@ -678,6 +683,16 @@ def get_layer_bias(layer_name: str) -> bool:
prefix=f"{prefix}.attn",
)

# Override the internal Attention's get_kv_cache_spec to use our unified block size.
# The internal Attention stays registered in static_forward_context (needed for forward
# pass lookup), but when vLLM collects cache specs, it will get our unified block size.
wrapper_self = self # Capture for closure
self.attn.get_kv_cache_spec = lambda vllm_config: wrapper_self.get_kv_cache_spec(vllm_config)

def get_attn_backend(self) -> type[AttentionBackend]:
"""Delegate to internal Attention's backend."""
return self.attn.get_attn_backend()
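For context, a minimal sketch of what a wrapper-level get_kv_cache_spec() could return. The real implementation lives elsewhere in this file; FullAttentionSpec and the exact field names below are assumptions based on vLLM's v1 kv_cache_interface:

from vllm.v1.kv_cache_interface import FullAttentionSpec

def get_kv_cache_spec(self, vllm_config) -> FullAttentionSpec:
    # Sketch only: advertise one block size for every layer so attention
    # and state-based mixers can share the hybrid model's cache pages.
    return FullAttentionSpec(
        block_size=vllm_config.cache_config.block_size,  # the unified value
        num_kv_heads=self.num_kv_heads,
        head_size=self.head_dim,
        dtype=vllm_config.model_config.dtype,
    )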

def forward(
self,
hidden_states: torch.Tensor,
@@ -1810,7 +1825,7 @@ def forward(

beta = self.b_proj(hidden_states)[0].float().sigmoid()
g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
g1 = fused_kda_gate(g1, self.A_log.float(), self.head_dim, g_bias=self.dt_bias)
beta = beta.unsqueeze(0)
g1 = g1.unsqueeze(0)
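The .float() change above is a precision guard: under bf16 inference, A_log is stored in bfloat16, and exponentiating a low-precision log-decay parameter inside the gate amplifies rounding error, so the parameter is upcast before use. A standalone illustration of the pattern, with an invented value:

import torch

# Invented value: a log-decay parameter held in bf16, as under bf16 inference.
A_log = torch.tensor([-4.6052], dtype=torch.bfloat16)  # roughly log(0.01)

rate_bf16 = A_log.exp().item()          # exponentiated at bf16 precision
rate_fp32 = A_log.float().exp().item()  # upcast first, as in the fix above

print(rate_bf16, rate_fp32)  # the fp32 path retains more precision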

@@ -2957,7 +2972,38 @@ def _patch_worker_for_placement_switching():
def _get_layer_placements(self) -> dict[int, str]:
return self.get_model().get_layer_placements()

def _clear_kv_cache(self) -> None:
"""Clear all KV cache tensors to prevent stale data after placement switch.

When mixer placement changes (e.g., layer 0 switches from KDA to attention),
the KV cache may contain data written by a different mixer type. Since different
mixers use incompatible cache formats, we must clear the cache to prevent NaN
errors from reading corrupted data.
"""
model_runner = getattr(self, "model_runner", None)
if model_runner is None:
return

kv_caches = getattr(model_runner, "kv_caches", [])
for cache_item in kv_caches:
if cache_item is None:
continue
# KV cache items can be either:
# - torch.Tensor for attention layers
# - list[torch.Tensor] for state-based layers (KDA, Mamba)
if isinstance(cache_item, list):
for tensor in cache_item:
if tensor is not None:
tensor.zero_()
else:
cache_item.zero_()

logger.info("Cleared KV cache tensors for placement switch")

def _set_layer_placements(self, placement: list[str]) -> dict[int, str]:
# Clear KV cache BEFORE changing placement to prevent reading stale data
# written by a different mixer type (which could cause NaN errors)
_clear_kv_cache(self)
return self.get_model().set_layer_placements(placement)

def _get_mixer_names(self) -> tuple[str, ...]:
@@ -2966,6 +3012,7 @@ def _get_mixer_names(self) -> tuple[str, ...]:
Worker.get_layer_placements = _get_layer_placements
Worker.set_layer_placements = _set_layer_placements
Worker.get_mixer_names = _get_mixer_names
Worker.clear_kv_cache = _clear_kv_cache


_patch_worker_for_placement_switching()
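Assuming these patched Worker methods are invoked from the driver through vLLM's collective_rpc (the model path and mixer names below are placeholders), a placement switch could look like:

from vllm import LLM

# Placeholder checkpoint; collective_rpc fans the call out to every Worker
# and returns one result per worker.
llm = LLM(model="org/apriel2-checkpoint", trust_remote_code=True)

placements = llm.collective_rpc("get_layer_placements")[0]  # dict[int, str]
new_placement = ["attention" if i == 0 else mixer for i, mixer in sorted(placements.items())]

# set_layer_placements zeroes the KV cache first, so in-flight state is dropped.
llm.collective_rpc("set_layer_placements", args=(new_placement,))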