Add RMSNorm StableHLO converter (#479) #481
Merged
michalharakal merged 2 commits into develop on Apr 13, 2026
Conversation
Adds RmsNormConverterTest with three cases:

1. rmsNorm_operation_is_supported_by_neural_net_converter — asserts NeuralNetOperationsConverter registers `rmsNorm` plus the `rms_norm` and `RMSNorm` aliases. Red today.
2. rmsNorm_with_scale_lowers_to_real_ops — builds a 2×4 FP32 graph with an RMSNorm node and a per-channel scale operand, runs the converter, and asserts the emitted module contains `@reduce_mean`, `sqrt`, `divide`, `broadcast_in_dim`, and `multiply`, and is not labelled "Unsupported operation rmsNorm". Red today because no converter claims the op.
3. rmsNorm_without_scale_still_normalizes — the scale operand is optional (RMSNorm can be used without the trailing affine multiply, though most LLMs do include it). The core norm must still lower to real ops.

Tests use a minimal in-file fixture op stub rather than a real RMSNorm Operation subclass, since the converter only reads `operation.name` and `operation.parameters`.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
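Since the converter is said to read only `operation.name` and `operation.parameters`, the fixture stub can be very small. A minimal sketch of such a stand-in (the class and field names here are assumptions for illustration, not the project's actual test fixture):

```kotlin
// Hypothetical minimal stand-in for what the converter sees: just a name
// and a parameter map, no real Operation subclass needed.
data class FakeOperation(
    val name: String,
    val parameters: Map<String, Any> = emptyMap()
)

fun main() {
    // A test can feed any of the registered aliases through this stub.
    val op = FakeOperation("rmsNorm", mapOf("epsilon" to 1e-6f))
    check(op.name == "rmsNorm")
    check(op.parameters["epsilon"] == 1e-6f)
    println("ok")
}
```

A data class like this keeps the tests independent of the real graph model while still exercising the converter's name/parameter dispatch.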
Extends NeuralNetOperationsConverter with a convertRmsNorm method
covering the rmsNorm / rms_norm / RMSNorm / RmsNorm operation
names and registers them in supportedOperations. The lowering is
the standard Llama-family form:
rms = sqrt(mean(x^2, axis) + eps)
out = scale * x / rms (scale operand optional)
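The two formulas above can be stated as a scalar Kotlin reference (illustrative only — the real converter emits StableHLO ops rather than running Kotlin; the function name is an assumption):

```kotlin
import kotlin.math.sqrt

// Scalar reference for the Llama-family RMSNorm lowering:
//   rms = sqrt(mean(x^2) + eps)
//   out = scale * x / rms      (scale operand optional)
fun rmsNorm(x: FloatArray, scale: FloatArray? = null, eps: Float = 1e-6f): FloatArray {
    val meanSq = x.fold(0f) { acc, v -> acc + v * v } / x.size
    val rms = sqrt(meanSq + eps)
    // With no scale operand the affine multiply is skipped entirely.
    return FloatArray(x.size) { i -> x[i] / rms * (scale?.get(i) ?: 1f) }
}

fun main() {
    val out = rmsNorm(floatArrayOf(1f, 2f, 3f, 4f))
    // After RMSNorm the mean square of the output is ~1 (up to eps).
    val ms = out.fold(0f) { a, v -> a + v * v } / out.size
    check(kotlin.math.abs(ms - 1f) < 1e-4f)
    println("ok")
}
```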
Emission style matches the softmax fix (#467) and the rest of
the emitter: reductions go through `stablehlo.custom_call
@reduce_mean`, the reduced tensor is broadcast back to the input
shape via `stablehlo.broadcast_in_dim` for the final divide, and
the epsilon is materialized as a scalar constant broadcast into
the reduced shape. Migrating all reductions to real
`stablehlo.reduce` regions is a separate refactor, explicitly
out of scope.
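For a 2×4 input, the emitted module described above would have roughly this shape (a hand-written sketch in StableHLO's textual form, with assumed value names — not the converter's literal output):

```mlir
// mean(x^2) over the last axis via the emitter's custom_call convention
%sq   = stablehlo.multiply %x, %x : tensor<2x4xf32>
%mean = stablehlo.custom_call @reduce_mean(%sq) : (tensor<2x4xf32>) -> tensor<2x1xf32>
// epsilon as a scalar constant broadcast into the reduced shape
%eps  = stablehlo.constant dense<1.000000e-06> : tensor<f32>
%epsb = stablehlo.broadcast_in_dim %eps, dims = [] : (tensor<f32>) -> tensor<2x1xf32>
%rms  = stablehlo.sqrt (stablehlo.add %mean, %epsb) : tensor<2x1xf32>
// broadcast back to the input shape for the final divide
%rmsb = stablehlo.broadcast_in_dim %rms, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x4xf32>
%norm = stablehlo.divide %x, %rmsb : tensor<2x4xf32>
```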
Axis normalization against rank handles negative axes and also
accepts an `IntArray` `normalized_shape` parameter for callers
that prefer PyTorch-style configuration. Default epsilon is
1e-6, matching Llama / Mistral / Qwen / Gemma; callers can
override via `eps` or `epsilon`.
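The axis handling reduces to a small helper; a sketch under the assumption that negative axes wrap PyTorch-style (the function name is illustrative, not the converter's actual API):

```kotlin
// Normalize a possibly negative axis against the tensor rank,
// e.g. axis -1 on a rank-4 tensor resolves to axis 3.
fun normalizeAxis(axis: Int, rank: Int): Int {
    val a = if (axis < 0) axis + rank else axis
    require(a in 0 until rank) { "axis $axis out of range for rank $rank" }
    return a
}

fun main() {
    check(normalizeAxis(-1, 4) == 3) // last axis of a rank-4 tensor
    check(normalizeAxis(2, 4) == 2)  // non-negative axes pass through
    println("ok")
}
```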
Without a scale operand the final affine multiply is skipped and
the normalized value is returned directly — a few implementations
use RMSNorm without a learnable scale, and dropping the multiply
keeps the emitted MLIR faithful to the input graph.
Tests: 3/3 in RmsNormConverterTest green, full compile-hlo
jvmTest suite still green.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Closes #479.
Summary
Adds a real RMSNorm lowering so modern transformer exports (Llama / Mistral / Qwen / Gemma — every open-weight LLM family uses RMSNorm rather than LayerNorm) stop dropping through the converter registry's "no converter found" path.
Two commits
1. Failing test — `RmsNormConverterTest`
Red against the pre-fix converter: no handler claims `rmsNorm` at all.
2. The fix
Test plan
3/3 in `RmsNormConverterTest` green; the full compile-hlo jvmTest suite stays green.
Out of scope
Migrating all reductions to real `stablehlo.reduce` regions is a separate refactor.
🤖 Generated with Claude Code