
Add RMSNorm StableHLO converter (P1) #479

@michalharakal

Description


Context

Follow-up to #467 (softmax fix). With softmax lowered correctly, the next correctness gap in the StableHLO emitter for modern transformer exports is RMSNorm. Every Llama / Mistral / Qwen / Gemma family model normalizes its activations with RMSNorm, not LayerNorm, and there is no converter for it today. `ActivationOperationsConverter` and `NeuralNetOperationsConverter` have zero entries for `rmsNorm` / `rms_norm` / `RMSNorm`, so a traced Llama model hits the converter registry's "no converter found" path and fails.

RMSNorm is:

```
rms = sqrt(mean(x^2, axis) + eps)
out = scale * x / rms
```

(No mean-centering, no offset — that's what distinguishes it from LayerNorm.)
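For reference (not part of the PR itself), the formula above can be sketched as a minimal numpy implementation; the function name and signature here are illustrative, not the converter API:

```python
import numpy as np

def rms_norm(x, scale=None, axis=-1, eps=1e-6):
    """Reference RMSNorm: no mean-centering, no offset (unlike LayerNorm)."""
    rms = np.sqrt(np.mean(np.square(x), axis=axis, keepdims=True) + eps)
    out = x / rms
    # The final multiply is skipped when no scale operand is present,
    # matching the optional-scale behavior described in this PR.
    return out * scale if scale is not None else out

x = np.array([3.0, 4.0])
y = rms_norm(x)
# mean(x^2) = 12.5, so y ≈ [0.8485, 1.1314]
```

This mirrors the emitted op sequence one-to-one: square, mean-reduce with keepdim, add eps, sqrt, divide, optional scale multiply.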

This PR

Add an `RMSNormConverter` (or extend `NeuralNetOperationsConverter`, whichever matches the codebase convention best for normalization ops) that:

  1. Accepts operation names `rmsNorm`, `rms_norm`, `RMSNorm`, `RmsNorm`.
  2. Takes 1 or 2 operands: `(input)` or `(input, scale)`. If `scale` is absent, emits the norm without the final multiply (scale=1.0).
  3. Reads parameters: `eps` (default `1e-6`, typical for Llama/Mistral), `normalized_shape` or `axis` (default: last dim).
  4. Emits a lowering using the codebase's existing `custom_call @reduce_mean` style (same as `ReductionOperationsConverter` uses for mean / variance) so it's syntactically consistent with the rest of the emitter today:

```mlir
%x_squared = stablehlo.multiply %x, %x : tensor<...>
%mean_sq = stablehlo.custom_call @reduce_mean(%x_squared) {dimensions = [axis], keepdim = true} : tensor<...>
%eps_const = stablehlo.constant dense<1.0e-06> : tensor
%eps_b = stablehlo.broadcast_in_dim %eps_const, dims = [] : (tensor) -> tensor<...>
%var_eps = stablehlo.add %mean_sq, %eps_b : tensor<...>
%rms = stablehlo.sqrt %var_eps : tensor<...>
%normed = stablehlo.divide %x, %rms : tensor<...>
%out = stablehlo.multiply %normed, %scale : tensor<...> // only when scale operand present
```

  5. Handles axis normalization against rank (negative axes).
  6. Registers the converter in `StableHloConverterFactory` so `registry.isSupported("rmsNorm")` returns true end-to-end.
  7. Adds a unit test in the appropriate test file: a graph with an RMSNorm node plus the optional scale operand, asserting that the emitted module contains all expected ops and *no* placeholder-constant fakes.
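The axis-normalization step above is the usual negative-index mapping; a minimal sketch (helper name is hypothetical, not an existing function in the codebase):

```python
def normalize_axis(axis: int, rank: int) -> int:
    """Map a possibly-negative axis index into [0, rank).

    e.g. axis=-1 with rank=4 resolves to 3 (the last dimension,
    the default reduction axis for RMSNorm).
    """
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    return axis + rank if axis < 0 else axis
```

Validating the range before resolving keeps an out-of-range axis from silently wrapping into a valid dimension.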

Out of scope

  • Real `stablehlo.reduce` region bodies. Every reduction in the emitter uses `custom_call @reduce_*` today — migrating all reductions to proper regions is a separate, larger refactor.
  • Fusing RMSNorm with the following matmul / attention. That's an IREE-side optimization.
  • Quantized RMSNorm — depends on further P0-1 track work.
