Lower layerNorm to real StableHLO ops instead of custom_call stub (P1) #480

@michalharakal

Description

Context

Part of the P1 correctness track for the NPU / IREE export path. `NeuralNetOperationsConverter.convertLayerNorm` currently emits a placeholder `custom_call`:

```kotlin
"$resultValue = stablehlo.custom_call @layer_norm($input, $scale, $offset) {epsilon = $epsilon} : $outputType"
```

No MLIR tool in the repo understands `@layer_norm`, and IREE certainly doesn't. So every `.mlir` module exported today that contains a LayerNorm has an unlowerable hole at the most expensive op in every transformer block.

The sibling `softmax` lowering was fixed in #467 (merged) using `custom_call @reduce_max` + `@reduce_sum` + `broadcast_in_dim`. This PR does the same transformation for LayerNorm: replace the hand-wave `@layer_norm` with the actual elementwise decomposition, matching the codebase's existing reduction-via-custom-call style.

Target lowering

```
layer_norm(x) = scale * (x - mean) / sqrt(var + eps) + offset
```
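As a sanity check for the decomposition, here is a hypothetical scalar reference in Kotlin, normalizing a flat array over its single axis. The function name, signature, and population-variance semantics for `@reduce_variance` are assumptions for illustration, not the converter's API.

```kotlin
import kotlin.math.sqrt

// Reference layer_norm over a 1-D array: scale * (x - mean) / sqrt(var + eps) + offset.
// scale/offset are optional, mirroring the converter's optional operands.
fun layerNormRef(
    x: FloatArray,
    scale: FloatArray? = null,
    offset: FloatArray? = null,
    eps: Float = 1e-5f,
): FloatArray {
    val mean = x.sum() / x.size
    // Population variance (divide by N) — assumed to match @reduce_variance here.
    val variance = x.map { (it - mean) * (it - mean) }.sum() / x.size
    val invStd = 1.0f / sqrt(variance + eps)
    return FloatArray(x.size) { i ->
        var v = (x[i] - mean) * invStd
        scale?.let { v *= it[i] }
        offset?.let { v += it[i] }
        v
    }
}
```

A lowered graph can be compared numerically against this on a few random tensors to catch sign or broadcast mistakes early.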

Emission pattern (matching softmax's style — `custom_call @reduce_mean` / `@reduce_variance` already exist via `ReductionOperationsConverter`):

```mlir
%mean = stablehlo.custom_call @reduce_mean(%x) {dimensions = [axis], keepdim = true} : tensor
%mean_b = stablehlo.broadcast_in_dim %mean, dims = [non-axis dims] : (tensor) -> tensor
%centered = stablehlo.subtract %x, %mean_b : tensor
%var = stablehlo.custom_call @reduce_variance(%x) {dimensions = [axis], keepdim = true} : tensor
%eps_const = stablehlo.constant dense<1.0e-05> : tensor
%eps_b = stablehlo.broadcast_in_dim %eps_const, dims = [] : (tensor) -> tensor
%var_eps = stablehlo.add %var, %eps_b : tensor
%std = stablehlo.sqrt %var_eps : tensor
%std_b = stablehlo.broadcast_in_dim %std, dims = [non-axis dims] : (tensor) -> tensor
%normed = stablehlo.divide %centered, %std_b : tensor
%scaled = stablehlo.multiply %normed, %scale : tensor // only when scale is present
%out = stablehlo.add %scaled, %offset : tensor // only when offset is present (operand is %normed if scale was omitted)
```
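The optional tail of that pattern (multiply only when scale exists, add only when offset exists, each feeding off whatever the previous result was) is the easiest part to get wrong. A minimal Kotlin sketch of how the emitter could append it — `appendScaleOffset` and its string-building style are hypothetical, chosen to mirror the line-oriented emission above:

```kotlin
// Appends the optional multiply/add tail to an emitted op list and
// returns the SSA name of the final result. All names are illustrative.
fun appendScaleOffset(
    lines: MutableList<String>,
    normed: String,      // SSA value holding (x - mean) / sqrt(var + eps)
    scale: String?,      // null when the op has no scale operand
    offset: String?,     // null when the op has no offset operand
    type: String,
): String {
    var current = normed
    if (scale != null) {
        lines += "%scaled = stablehlo.multiply $current, $scale : $type"
        current = "%scaled"
    }
    if (offset != null) {
        lines += "%out = stablehlo.add $current, $offset : $type"
        current = "%out"
    }
    return current
}
```

Because `current` threads through both branches, the offset add correctly consumes `%normed` when scale is absent.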

Scope

  • Rewrite `convertLayerNorm` and `buildLayerNormOperation` in `NeuralNetOperationsConverter` to emit the elementwise decomposition above.
  • Handle negative axis / `normalized_shape` against rank correctly.
  • Scale / offset are optional — omit the final multiply / add when the operand is null.
  • Update / add unit tests in `NeuralNetOperationsConverterTest` to assert: no `@layer_norm` custom_call in the output, and the presence of `@reduce_mean`, `@reduce_variance`, `broadcast_in_dim`, `sqrt`, `divide`.
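For the negative-axis item above, the resolution is a one-liner worth pinning down before emission, since every `dimensions = [axis]` attribute depends on it. A hedged sketch — the helper name is illustrative, not an existing function in `NeuralNetOperationsConverter`:

```kotlin
// Resolves a possibly-negative axis against the tensor rank,
// so e.g. axis = -1 on a rank-4 tensor becomes 3.
fun normalizeAxis(axis: Int, rank: Int): Int {
    val resolved = if (axis < 0) axis + rank else axis
    require(resolved in 0 until rank) { "axis $axis out of range for rank $rank" }
    return resolved
}
```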

Out of scope

  • Real `stablehlo.reduce` region bodies. Every reduction in the emitter uses `custom_call @reduce_*` today; migrating all reductions to proper regions is a separate, larger refactor.
  • RMSNorm converter. That's its own issue.
  • Quantized LayerNorm — depends on further P0-1 work.
