Problem
ActivationOperationsConverter.convertSoftmax() currently emits numerically incorrect StableHLO. It hardcodes placeholder constants instead of actual reductions:
%maxValue = stablehlo.constant dense<0.0> : tensor<?xf32> // should be reduce(max)
%shifted = stablehlo.subtract %input, %maxValue
%exp = stablehlo.exponential %shifted
%sumValue = stablehlo.constant dense<1.0> : tensor<?xf32> // should be reduce(add)
%result = stablehlo.divide %exp, %sumValue
The max(x) and sum(exp(...)) terms are replaced with constants, so:
- Every softmax output is mathematically wrong.
- Any model exported through StableHloConverter with a softmax is unusable downstream (IREE, MLIR tools, NPU compile path).
- This is especially blocking for the LLM → StableHLO → IREE → NPU (int8/int4) path, where softmax is on the critical path inside every attention block.
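To make the breakage concrete, here is a minimal Python sketch (hypothetical helper names, not project code) of what the currently emitted graph computes versus a correct softmax:

```python
import math

def buggy_lowering(x):
    # mirrors the emitted StableHLO: max replaced by constant 0.0, sum by 1.0
    shifted = [v - 0.0 for v in x]        # subtract placeholder "max"
    exp = [math.exp(v) for v in shifted]
    return [v / 1.0 for v in exp]         # divide by placeholder "sum"

def softmax(x):
    m = max(x)                            # real reduce(max)
    exp = [math.exp(v - m) for v in x]
    s = sum(exp)                          # real reduce(add)
    return [v / s for v in exp]

# buggy_lowering returns raw exponentials; the output does not sum to 1
print(buggy_lowering([1.0, 2.0, 3.0]))
print(softmax([1.0, 2.0, 3.0]))
```

The buggy form degenerates to `exp(x)`, so every probability distribution produced by attention is wrong, not just slightly off.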
File: skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ActivationOperationsConverter.kt (around lines 92-150).
Expected lowering
// shape-preserving softmax along `axis`
%neg_inf = stablehlo.constant dense<-3.4028235e+38> : tensor<f32>
%max = stablehlo.reduce(%input init: %neg_inf) applies stablehlo.maximum across dimensions = [axis]
%max_b = stablehlo.broadcast_in_dim %max, dims = [non-axis dims] : tensor<...>
%shifted = stablehlo.subtract %input, %max_b
%exp = stablehlo.exponential %shifted
%zero = stablehlo.constant dense<0.0> : tensor<f32>
%sum = stablehlo.reduce(%exp init: %zero) applies stablehlo.add across dimensions = [axis]
%sum_b = stablehlo.broadcast_in_dim %sum, dims = [non-axis dims] : tensor<...>
%result = stablehlo.divide %exp, %sum_b
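The same lowering, expressed as a pure-Python reference for a 2D tensor (an illustrative sketch only; function and parameter names are made up for this example), so the reduce / broadcast / divide steps can be checked numerically:

```python
import math

def softmax_2d(x, axis):
    # reference semantics for the expected lowering on a 2D list-of-lists tensor
    if axis == 1:
        out = []
        for row in x:
            m = max(row)                          # reduce(max) along axis
            e = [math.exp(v - m) for v in row]    # subtract broadcast max, exp
            s = sum(e)                            # reduce(add) along axis
            out.append([v / s for v in e])        # divide by broadcast sum
        return out
    # axis == 0: transpose, reuse the axis-1 path, transpose back
    t = [list(col) for col in zip(*x)]
    return [list(col) for col in zip(*softmax_2d(t, 1))]
```

Each slice along `axis` sums to 1, which is the property the unit tests below can assert against the fixed converter output.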
Scope
- Rewrite convertSoftmax to emit the real reduction form above.
- Handle the axis parameter correctly (negative axes normalized against rank).
- Unit test in ActivationOperationsConverterTest: assert no dense<0.0> / dense<1.0> placeholder constants are emitted, and that two stablehlo.reduce blocks are present.
- Update any golden MLIR in StableHloExportTest or round-trip tests that accepted the wrong form.
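For the axis handling, the usual normalization rule is the NumPy/ONNX one: a negative axis counts from the end. A small sketch of that rule (hypothetical helper name, shown in Python for brevity rather than the project's Kotlin):

```python
def normalize_axis(axis, rank):
    # negative axes count from the end: -1 on a rank-4 tensor means axis 3
    if axis < 0:
        axis += rank
    if not 0 <= axis < rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    return axis
```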
Out of scope
- RMSNorm converter (separate issue).
- LayerNorm proper lowering (currently custom_call, separate issue).
- Quantized softmax / int8 attention (depends on P0 quant-in-IR work).
Context
Part of the priority-ordered NPU/IREE roadmap. This is P1 (core compile path, correctness prerequisite) — cheap, high-signal, unblocks every transformer export.