Skip to content

Add gather / embedding StableHLO converter (#483)#487

Merged
michalharakal merged 2 commits intodevelopfrom
feature/483-gather-embedding-converter
Apr 13, 2026
Merged

Add gather / embedding StableHLO converter (#483)#487
michalharakal merged 2 commits intodevelopfrom
feature/483-gather-embedding-converter

Conversation

@michalharakal
Copy link
Copy Markdown
Contributor

Closes #483.

Summary

Adds a real converter for `gather` / `embedding` / `index_select` / `Embedding`. Every LLM export begins with a token-id \u2192 embedding lookup, and without this converter a traced Llama / Mistral / Qwen / Gemma forward pass was failing at its very first operation — it never reached the norms, softmax, RMSNorm, LayerNorm, or attention that the other P1 converters cover.

Target lowering

For the canonical `embedding(input_ids)` shape (vocab_size=8, hidden_size=4, seq_len=3, axis=0):

```mlir
%out = stablehlo.gather(%W, %ids)
{ dimension_numbers = #stablehlo.gather<
offset_dims = [1],
collapsed_slice_dims = [0],
start_index_map = [0],
index_vector_dim = 1>,
slice_sizes = array<i64: 1, 4>,
indices_are_sorted = false }
: (tensor<8x4xf32>, tensor<3xi32>) -> tensor<3x4xf32>
```

The `offset_dims` / `collapsed_slice_dims` / `start_index_map` are derived from the gather axis. `slice_sizes` gets a `1` along the gathered axis and the full extent along every other dimension. Negative axes are normalized against the weight rank.

Registered in `StableHloConverterFactory.createBasic` / `createExtended` / `createFast`.

Two commits

1. Failing test

`GatherConverterTest` with 4 cases:

  • Three alias-registration tests (`gather`, `embedding`, `index_select`) asserting each is claimed by a converter and not dropped as unsupported.
  • `embedding_lowering_carries_canonical_dim_numbers_and_slice_sizes` asserting the exact `offset_dims = [1]` / `collapsed_slice_dims = [0]` / `start_index_map = [0]` / `slice_sizes = array<i64: 1, 4>` shape for the canonical embedding case, plus a tight operand-shape assertion:
    ```kotlin
    assertTrue(mlir.contains("stablehlo.gather(%arg0, %arg1)"))
    assertFalse(mlir.contains("stablehlo.gather([%"))
    ```
    This catches a Kotlin string-template pitfall I hit in development: `"$operands[0]"` expands to `[a, b][0]` literal, producing `stablehlo.gather([%arg0, %arg1][0], [%arg0, %arg1][1])`. The assertion pins the correct shape so the bug can never regress silently.

2. The converter

`GatherOperationsConverter.kt` implementing the lowering above, plus factory registration.

Test plan

  • `GatherConverterTest` — 4/4 green
  • `./gradlew :skainet-compile:skainet-compile-hlo:allTests -x kotlinWasmStoreYarnLock` — green across jvmTest, wasmJsTest, wasmJsBrowserTest, wasmWasiTest, wasmWasiNodeTest, macosArm64Test, iosSimulatorArm64Test (`linuxX64Test` skipped on macOS host)
  • CI: full multiplatform build

Out of scope

  • Higher-rank gathers (attention-side index gathers, multi-dim scatter/gather). Add when a traced model surfaces the pattern.
  • Scatter. Separate converter when needed.
  • Quantized embedding tables — depends on further P0-1 track work.
  • Rotary position embedding (RoPE). Separate issue.

🤖 Generated with Claude Code

michalharakal and others added 2 commits April 13, 2026 13:21
Adds GatherConverterTest with four cases:

1. gather_and_embedding_aliases_are_supported — `gather` must
   be claimed by a converter, not fall through to the registry's
   "No converter found" path. Red against develop today.
2. embedding_alias_routes_to_same_lowering — same for the
   `embedding` alias.
3. index_select_alias_routes_to_same_lowering — same for the
   `index_select` alias (PyTorch-style name).
4. embedding_lowering_carries_canonical_dim_numbers_and_slice_sizes
   — the emitted `stablehlo.gather` must carry the dim_numbers
   (offset_dims, collapsed_slice_dims, start_index_map) and
   slice_sizes attributes that downstream MLIR tools (IREE)
   expect for a 1-D index tensor indexing the leading dim of
   a 2-D embedding weight.

Test fixture builds the canonical embedding-lookup shape:
vocab_size=8, hidden_size=4, seq_len=3, axis=0.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Introduces GatherOperationsConverter claiming the operation
names `gather`, `embedding`, `Embedding`, and `index_select`.
Every LLM export begins with a token-id -> embedding lookup, so
without this converter a traced Llama / Mistral / Qwen / Gemma
forward pass was failing at its very first operation and never
reaching the norms, activations, softmax, or attention the other
P1 converters cover.

Target lowering is the canonical `embedding(input_ids)` shape:
1-D index tensor indexing the leading dim of a 2-D embedding
weight. For vocab_size=8, hidden_size=4, seq_len=3 and axis=0
the emitted op is:

    %out = stablehlo.gather(%W, %ids)
      { dimension_numbers = #stablehlo.gather<
          offset_dims = [1],
          collapsed_slice_dims = [0],
          start_index_map = [0],
          index_vector_dim = 1>,
        slice_sizes = array<i64: 1, 4>,
        indices_are_sorted = false }
      : (tensor<8x4xf32>, tensor<3xi32>) -> tensor<3x4xf32>

offset_dims / collapsed_slice_dims / start_index_map are derived
from a single gather axis. slice_sizes gets a 1 along the
gathered axis and the full extent of every other dim. Negative
axes are normalized against weight rank. Higher-rank gathers
(attention-side index gathers, multi-dim scatter/gather) can
be added in follow-ups when a traced model surfaces them; the
first PR deliberately targets the LLM front-door case only.

Registered in StableHloConverterFactory.createBasic /
createExtended / createFast.

Tests: 4/4 in GatherConverterTest — registration-via-missing
for each of the three aliases plus a canonical dim_numbers /
slice_sizes assertion. The canonical test also pins the exact
`stablehlo.gather(%arg0, %arg1)` operand shape so a prior
Kotlin-string-template bug that emitted
`stablehlo.gather([%arg0, %arg1][0], [%arg0, %arg1][1])` can
never regress.

Verified locally with
`./gradlew :skainet-compile:skainet-compile-hlo:allTests
 -x kotlinWasmStoreYarnLock` — green across jvmTest, wasmJsTest,
wasmJsBrowserTest, wasmWasiTest, wasmWasiNodeTest, macosArm64Test,
and iosSimulatorArm64Test.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@michalharakal michalharakal merged commit 8a12d42 into develop Apr 13, 2026
4 checks passed
@michalharakal michalharakal deleted the feature/483-gather-embedding-converter branch April 13, 2026 11:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add gather / embedding StableHLO converter (P1)

1 participant