Fix softmax StableHLO lowering to use real reductions #467

@michalharakal

Description

Problem

ActivationOperationsConverter.convertSoftmax() currently emits numerically incorrect StableHLO. It hardcodes placeholder constants instead of actual reductions:

%maxValue  = stablehlo.constant dense<0.0> : tensor<?xf32>   // should be reduce(max)
%shifted   = stablehlo.subtract %input, %maxValue
%exp       = stablehlo.exponential %shifted
%sumValue  = stablehlo.constant dense<1.0> : tensor<?xf32>   // should be reduce(add)
%result    = stablehlo.divide %exp, %sumValue

The max(x) and sum(exp(...)) terms are replaced with constants, so:

  • Every softmax output is mathematically wrong.
  • Any model exported through StableHloConverter with a softmax is unusable downstream (IREE, MLIR tools, NPU compile path).
  • This is especially blocking for the LLM → StableHLO → IREE → NPU (int8/int4) path, where softmax is on the critical path inside every attention block.

File: skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ActivationOperationsConverter.kt (around lines 92-150).
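To make the failure concrete, here is a minimal pure-Python sketch (1-D input, illustrative only — the converter itself is Kotlin) comparing what the current placeholder lowering computes against a real softmax:

```python
import math

def placeholder_softmax(xs):
    # Mirrors the broken lowering: the reductions are hardcoded constants.
    max_value = 0.0                      # should be max(xs), i.e. reduce(max)
    exp = [math.exp(x - max_value) for x in xs]
    sum_value = 1.0                      # should be sum(exp), i.e. reduce(add)
    return [e / sum_value for e in exp]  # degenerates to plain exp(x)

def softmax(xs):
    m = max(xs)                          # reduce(max)
    exp = [math.exp(x - m) for x in xs]
    s = sum(exp)                         # reduce(add)
    return [e / s for e in exp]

xs = [1.0, 2.0, 3.0]
print(sum(placeholder_softmax(xs)))  # ~30.19 — not a probability distribution
print(sum(softmax(xs)))              # 1.0
```

With the constants in place the op reduces to `exp(x)`: the outputs neither sum to 1 nor are shifted for stability, so every consumer of the exported module sees wrong attention weights.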

Expected lowering

// shape-preserving softmax along `axis`
%neg_inf = stablehlo.constant dense<-3.4028235e+38> : tensor<f32>
%max     = stablehlo.reduce(%input init: %neg_inf) applies stablehlo.maximum across dimensions = [axis]
%max_b   = stablehlo.broadcast_in_dim %max, dims = [non-axis dims] : tensor<...>
%shifted = stablehlo.subtract %input, %max_b
%exp     = stablehlo.exponential %shifted
%zero    = stablehlo.constant dense<0.0> : tensor<f32>
%sum     = stablehlo.reduce(%exp init: %zero) applies stablehlo.add across dimensions = [axis]
%sum_b   = stablehlo.broadcast_in_dim %sum, dims = [non-axis dims] : tensor<...>
%result  = stablehlo.divide %exp, %sum_b
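The max subtraction is not only a correctness fix but the standard numerical-stability trick: without it, `exp` overflows for large logits, which attention blocks routinely produce. A pure-Python sketch of the same dataflow (2-D input reduced along the last axis, mirroring the `-inf` init, reduce, broadcast, subtract, exp, reduce, divide sequence above):

```python
import math

def softmax_rows(x):
    # Reduction along the last axis of a 2-D "tensor" (list of rows).
    out = []
    for row in x:
        m = max(row)                          # %max: reduce(maximum), init -inf
        exp = [math.exp(v - m) for v in row]  # %shifted, %exp
        s = sum(exp)                          # %sum: reduce(add), init 0.0
        out.append([e / s for e in exp])      # %result: elementwise divide
    return out

# Large logits overflow without the shift: math.exp(1000.0) raises
# OverflowError, while exp(1000.0 - 1000.0) = exp(0.0) is exact.
print(softmax_rows([[1000.0, 1000.0]]))  # [[0.5, 0.5]]
```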

Scope

  • Rewrite convertSoftmax to emit the real reduction form above.
  • Handle the axis parameter correctly (negative axes normalized against rank).
  • Unit test in ActivationOperationsConverterTest: assert no dense<0.0> / dense<1.0> placeholder constants are emitted, and that two stablehlo.reduce blocks are present.
  • Update any golden MLIR in StableHloExportTest or round-trip tests that was accepting the wrong form.
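For the axis handling, the usual convention (a sketch, not the converter's actual helper names) is to normalize negative axes against the rank and derive the `broadcast_in_dim` dimension mapping as every dimension except the reduced one:

```python
def normalize_axis(axis, rank):
    # Negative axes count from the end: axis -1 on rank 3 means dim 2.
    norm = axis + rank if axis < 0 else axis
    if not 0 <= norm < rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    return norm

def broadcast_dims(rank, axis):
    # broadcast_in_dim mapping for the rank-reduced result of stablehlo.reduce:
    # each remaining dim of the reduced value maps to its original position.
    return [d for d in range(rank) if d != axis]

print(normalize_axis(-1, 3))  # 2
print(broadcast_dims(3, 2))   # [0, 1]
```

The unit test asserting two `stablehlo.reduce` blocks should also cover at least one negative-axis case so the normalization path is exercised.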

Out of scope

  • RMSNorm converter (separate issue).
  • LayerNorm proper lowering (currently custom_call, separate issue).
  • Quantized softmax / int8 attention (depends on P0 quant-in-IR work).

Context

Part of the priority-ordered NPU/IREE roadmap. This is P1 (core compile path, correctness prerequisite) — cheap, high-signal, unblocks every transformer export.
