
feat(models): per-layer intermediate_size for Gemma 4 double-wide MLP #10

Merged
mikeumus merged 1 commit into main from feat/gemma4-per-layer-intermediate-size on Apr 18, 2026

Conversation

@mikeumus

Summary

  • Adds `ModelArchitecture::intermediate_size_for_layer(layer)` (default = `config.intermediate_size`)
  • `Gemma4Arch` overrides via the precomputed `kv_sources` set when `use_double_wide_mlp=True`
  • Threads the per-layer lookup through `edit_py`, `edit_cmd`, and `memit`
  • Parses `use_double_wide_mlp` in `detect.rs`

Why

Crown-scan on google/gemma-4-e2b-it fails with `intermediate-size mismatch in captured keys` because the default start layer (3n/5 = 21 for n = 35 layers) lands in the KV-shared region, where the MLP is double-wide (12288 vs. the base 6144). Verified against actual HF tensor shapes: L0/L14 = (6144, 1536); L15/L21/L34 = (12288, 1536).
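For concreteness, the layer arithmetic (an illustrative sketch only, using the numbers from this PR: 35 layers, last 20 KV-shared):

```rust
fn main() {
    let n_layers = 35; // gemma-4-e2b-it
    let default_start = 3 * n_layers / 5; // crown-scan default: 3*35/5 = 21
    let first_shared = n_layers - 20; // last 20 layers KV-shared: 15..=34
    // 21 >= 15, so the default start layer lands in the double-wide region.
    assert!(default_start >= first_shared);
}
```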

Mirrors HuggingFace's modeling_gemma4.py:

```python
use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer
self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1)
```

Test plan

  • `cargo test -p larql-models --lib` — 74 pass (3 new)
  • `cargo test -p larql-inference --lib` — 10 pass
  • `cargo check -p larql-models -p larql-inference -p larql-cli -p larql-python` — clean
  • Rebuild the Python wheel on an L4 Colab and re-run the Day-0 sanity check on gemma-4-e2b-it
  • Verify the edit + apply_patch round-trip on both a double-wide layer (L21) and a normal layer (L0)

🤖 Generated with Claude Code

Gemma 4's `use_double_wide_mlp=True` widens gate/up/down_proj to 2× base
`intermediate_size` on KV-shared layers. On gemma-4-e2b-it (35 layers,
last 20 shared), layers 15–34 have `intermediate=12288`, layers 0–14
have 6144. Crown-scan defaults to `(3n/5) = 21` and lands on a
double-wide layer, so the rank-1 edit hits `intermediate-size mismatch
in captured keys` against the config-wide base size.

Adds `ModelArchitecture::intermediate_size_for_layer(layer) -> usize`
(default = `config.intermediate_size`, mirroring `head_dim_for_layer`).
`Gemma4Arch` overrides by reusing the precomputed `kv_sources` set —
one source of truth for KV-shared-layer membership.
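
A minimal sketch of that shape (the trait method, `kv_sources`, and
`use_double_wide_mlp` come from this PR; everything else, including the
`config()` accessor and field layout, is an assumption, not the crate's
exact code):

```rust
use std::collections::HashSet;

// Hypothetical slice of ModelConfig; the real struct has many more fields.
struct ModelConfig {
    intermediate_size: usize,
    use_double_wide_mlp: bool,
}

trait ModelArchitecture {
    fn config(&self) -> &ModelConfig;

    // Default: every layer uses the config-wide base size,
    // mirroring how `head_dim_for_layer` defaults.
    fn intermediate_size_for_layer(&self, _layer: usize) -> usize {
        self.config().intermediate_size
    }
}

struct Gemma4Arch {
    config: ModelConfig,
    // Precomputed once at detect time; the single source of truth
    // for KV-shared-layer membership.
    kv_sources: HashSet<usize>,
}

impl ModelArchitecture for Gemma4Arch {
    fn config(&self) -> &ModelConfig {
        &self.config
    }

    // Double the base width only on KV-shared layers, and only
    // when the model sets the flag.
    fn intermediate_size_for_layer(&self, layer: usize) -> usize {
        let base = self.config.intermediate_size;
        if self.config.use_double_wide_mlp && self.kv_sources.contains(&layer) {
            base * 2
        } else {
            base
        }
    }
}
```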

Thread the per-layer lookup through:
- `edit_py.rs`: compute `intermediate` after `chosen_layer` is picked
  (see the sketch after this list).
- `edit_cmd.rs`: same for the CLI path.
- `memit.rs`: `ffn_dim` now per-layer; `run_memit` already solves per
  layer, so covariances remain correctly sized across mixed layers.
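
Roughly, the call-site change shared by `edit_py.rs` and `edit_cmd.rs`
(a sketch reusing the hypothetical trait above; the surrounding code is
assumed):

```rust
// Resolve the MLP width only after the target layer is known.
// Was (wrong on L15..=34): arch.config().intermediate_size
fn captured_key_dim(arch: &dyn ModelArchitecture, chosen_layer: usize) -> usize {
    arch.intermediate_size_for_layer(chosen_layer)
}
```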

Parse `use_double_wide_mlp` in `detect.rs`; add to `ModelConfig`.
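
A parsing sketch, assuming `detect.rs` deserializes the HF `config.json`
with serde (the struct and its defaults here are illustrative):

```rust
use serde::Deserialize;

// Hypothetical slice of the HF config.json relevant to this change.
#[derive(Deserialize)]
struct RawConfig {
    intermediate_size: usize,
    // Absent on models without the feature (e.g. the 31B), so default
    // to false rather than failing to parse.
    #[serde(default)]
    use_double_wide_mlp: bool,
}

fn parse_config(json: &str) -> serde_json::Result<RawConfig> {
    serde_json::from_str(json)
}
```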

Tests (in `detect.rs`):
- `test_detect_gemma4_e2b`: asserts 6144 on L0/L14, 12288 on L15/L21/L34
  — matches the actual HF tensor shapes verified in the Colab repl
  (a condensed version appears after this list).
- `test_gemma4_31b_no_double_wide`: 31B lacks the flag → base everywhere.
- `test_non_gemma4_intermediate_default`: Llama returns base for all
  layers via the default trait impl.
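
Condensed, the first test's assertions look like this (reusing the
hypothetical types from the trait sketch above; the real test drives
`detect.rs` rather than building `Gemma4Arch` by hand):

```rust
#[test]
fn double_wide_layers_match_hf_shapes() {
    // gemma-4-e2b-it: 35 layers, last 20 KV-shared (15..=34).
    let arch = Gemma4Arch {
        config: ModelConfig { intermediate_size: 6144, use_double_wide_mlp: true },
        kv_sources: (15..35).collect(),
    };
    for layer in [0, 14] {
        assert_eq!(arch.intermediate_size_for_layer(layer), 6144);
    }
    for layer in [15, 21, 34] {
        assert_eq!(arch.intermediate_size_for_layer(layer), 12288);
    }
}
```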

The bare `weights.intermediate_size` field is left as "base" for
display / metadata call sites (demos, patch-print, vindex stats).
Patch file-format unchanged: `compute_rank1` / `compute_dense` already
derive `intermediate_size` from the runtime tensor, so new patches for
double-wide layers store 12288 correctly without a version bump.
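
Why no version bump is needed, in sketch form (the shape convention is
taken from the HF shapes quoted above; the function is a stand-in, not
larql's API):

```rust
// Hypothetical: the size is read off the runtime tensor, not the config,
// so a patch computed on a double-wide layer records 12288 automatically.
fn intermediate_from_gate_proj(gate_proj_shape: (usize, usize)) -> usize {
    // gate/up_proj are (intermediate, hidden), e.g. (12288, 1536) on L21.
    gate_proj_shape.0
}
```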

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
mikeumus merged commit 44d549b into main on Apr 18, 2026