Fix softmax StableHLO lowering to use real reductions (#467)#476

Merged
michalharakal merged 2 commits into develop from feature/467-softmax-reductions on Apr 13, 2026

Conversation

@michalharakal
Contributor

Closes #467.

Summary

Unparks the P1 softmax fix that was written during the initial NPU roadmap scoping and held while P0 landed. With P0-1 step 2 (#475) and P0-2 step 2 (#474) now merged, it's time to address the first correctness bug surfaced by the NPU audit.

`ActivationOperationsConverter.convertSoftmax` currently emits numerically wrong MLIR: it hardcodes placeholder constants in place of the `max(x)` and `sum(exp(...))` terms instead of invoking real reductions.

```mlir
%maxValue = stablehlo.constant dense<0.0> : tensor<2x3xf32> // fake max
%shifted = stablehlo.subtract %input, %maxValue
%exp = stablehlo.exponential %shifted
%sumValue = stablehlo.constant dense<1.0> : tensor<2x3xf32> // fake sum
%result = stablehlo.divide %exp, %sumValue
```

Every softmax output from the exporter is therefore mathematically incorrect, and any transformer model exported through `StableHloConverter` with a softmax — which is every transformer — is unusable downstream (IREE, MLIR tools, NPU compile path).
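To make the failure mode concrete, here is a minimal Python sketch (illustrative only, not part of the codebase) of what the placeholder lowering actually computes versus the intended softmax:

```python
import math

def buggy_softmax(row):
    # Mirrors the placeholder lowering: subtract a fake max of 0.0,
    # exponentiate, then divide by a fake sum of 1.0 -> plain exp(x).
    shifted = [x - 0.0 for x in row]       # stablehlo.subtract with dense<0.0>
    exps = [math.exp(s) for s in shifted]  # stablehlo.exponential
    return [e / 1.0 for e in exps]         # stablehlo.divide with dense<1.0>

def correct_softmax(row):
    # What the lowering should compute: a real max and a real sum reduction.
    m = max(row)                           # @reduce_max
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)                          # @reduce_sum
    return [e / s for e in exps]

row = [1.0, 2.0, 3.0]
print(buggy_softmax(row))    # raw exp(x) values -- not a probability distribution
print(correct_softmax(row))  # entries sum to 1.0
```

The buggy path degenerates to `exp(x)`, so downstream consumers see unnormalized (and potentially overflowing) values rather than a distribution.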

Two commits

1. Failing test

Extends `ActivationOperationsConverterTest.testSoftmaxOperation` to assert that:

  • No `stablehlo.constant dense<0.0> : tensor<2x3xf32>` fake-max placeholder.
  • No `stablehlo.constant dense<1.0> : tensor<2x3xf32>` fake-sum placeholder.
  • `@reduce_max` appears (real reduction).
  • `@reduce_sum` appears (real reduction).
  • `stablehlo.broadcast_in_dim` appears (reduced values broadcast back to input shape).

The extended test is red against the pre-fix `StableHloConverter`.
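As a hypothetical sketch (Python here, not the actual test code), the five assertions amount to string checks over the emitted MLIR:

```python
def check_softmax_mlir(mlir: str) -> bool:
    # Sketch of the assertions added to testSoftmaxOperation.
    assert "stablehlo.constant dense<0.0> : tensor<2x3xf32>" not in mlir  # no fake max
    assert "stablehlo.constant dense<1.0> : tensor<2x3xf32>" not in mlir  # no fake sum
    assert "@reduce_max" in mlir                 # real max reduction
    assert "@reduce_sum" in mlir                 # real sum reduction
    assert "stablehlo.broadcast_in_dim" in mlir  # reduced values broadcast back
    return True

fixed = """
%max = stablehlo.custom_call @reduce_max(%input) : ...
%maxB = stablehlo.broadcast_in_dim %max, dims = [0] : ...
%sum = stablehlo.custom_call @reduce_sum(%exp) : ...
%sumB = stablehlo.broadcast_in_dim %sum, dims = [0] : ...
"""
print(check_softmax_mlir(fixed))  # True
```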

2. The fix

Rewrites `convertSoftmax` to emit the correct lowering:

```mlir
%max = stablehlo.custom_call @reduce_max(%input) {dimensions = [axis], keepdim = false} : tensor
%maxB = stablehlo.broadcast_in_dim %max, dims = [non-axis dims] : (tensor) -> tensor
%shift = stablehlo.subtract %input, %maxB : tensor
%exp = stablehlo.exponential %shift : tensor
%sum = stablehlo.custom_call @reduce_sum(%exp) {dimensions = [axis], keepdim = false} : tensor
%sumB = stablehlo.broadcast_in_dim %sum, dims = [non-axis dims] : (tensor) -> tensor
%out = stablehlo.divide %exp, %sumB : tensor
```

Handles negative-axis normalization against rank correctly, and uses `stablehlo.custom_call @reduce_max` / `@reduce_sum` to match the existing reduction-converter style (`ReductionOperationsConverter` emits `custom_call @reduce_sum` today). Migrating every reduction to proper `stablehlo.reduce` regions is a separate, larger refactor that's deliberately out of scope.
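The axis handling can be sketched as follows (hypothetical Python helpers; the names `normalize_axis` and `broadcast_dims` are illustrative, not from the codebase):

```python
def normalize_axis(axis: int, rank: int) -> int:
    # Negative axes count from the back: axis=-1 on a rank-2 tensor is dim 1.
    a = axis + rank if axis < 0 else axis
    if not 0 <= a < rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    return a

def broadcast_dims(rank: int, reduced_axis: int) -> list:
    # With keepdim = false the reduced value has rank-1 dims; broadcast_in_dim
    # needs the result dim each remaining operand dim maps to, i.e. every
    # input dim except the reduced one.
    return [d for d in range(rank) if d != reduced_axis]

print(normalize_axis(-1, 2))  # 1  (softmax over the last dim of tensor<2x3xf32>)
print(broadcast_dims(2, 1))   # [0]
```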

Test plan

Commit 1's extended `testSoftmaxOperation` is red against the pre-fix converter and green after commit 2; all 4 CI checks passed on the branch.

Out of scope

  • RMSNorm converter (separate issue).
  • LayerNorm proper lowering (currently `custom_call`, separate issue).
  • Migrating all reductions from `custom_call @reduce_*` to proper `stablehlo.reduce` regions.
  • Quantized softmax / int8 attention — depends on further P0-1 track work.

🤖 Generated with Claude Code

michalharakal and others added 2 commits April 13, 2026 12:50
Extends testSoftmaxOperation to assert that the converter must not
emit hardcoded dense<0.0> / dense<1.0> placeholder constants at the
output shape in place of the max(x) and sum(exp(...)) terms, and
must invoke real reductions plus a broadcast_in_dim back to the
input shape. Red against current ActivationOperationsConverter.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replaces the dense<0.0>/dense<1.0> placeholder constants with
custom_call @reduce_max and @reduce_sum (matching the codebase's
existing reduction-converter style) and broadcasts the reduced
values back to the input shape via stablehlo.broadcast_in_dim
before subtract / divide. Handles negative axis correctly.

Branch is parked pending P0 roadmap work (quant-in-IR + backend-api
extraction) — see issue #467 for context.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@michalharakal michalharakal merged commit abb6a2c into develop Apr 13, 2026
4 checks passed
@michalharakal michalharakal deleted the feature/467-softmax-reductions branch April 13, 2026 10:52