
Emit tensor_encoding comments in StableHLO output (#473)#475

Merged
michalharakal merged 2 commits into develop from feature/473-stablehlo-encoding-comments
Apr 13, 2026

Conversation

@michalharakal
Contributor

Closes #473.

Summary

First step after #469: teach the StableHLO emitter to read `TensorSpec.tensorEncoding` and preserve it through the compile boundary. Today, even though `TraceToGraphBuilder.finalize` carries `TensorEncoding.Q8_0` / `Q4_K` / `TernaryPacked` / `TurboQuant` onto weight-node output specs, `StableHloConverter` ignores the metadata and emits MLIR that looks identical to a dense FP32 graph. This PR is the cheapest reversible hook that fixes that.

Two commits:

1. Failing test

`EncodingAnnotationTest` with two cases:

  • Positive: build a graph whose weight input `TensorSpec` is annotated with `TensorEncoding.Q8_0` via `withTensorEncoding` (matching what `TraceToGraphBuilder.finalize` produces after #469, "Plumb TensorEncoding into TensorSpec.metadata", P0-1 step 1), run it through `StableHloConverter`, and assert the emitted module text contains a `tensor_encoding` comment naming `encoding=Q8_0` and `name=w`. Red against `main`.
  • Negative: a dense FP32 graph with no encoding metadata must emit no `tensor_encoding` annotation. A `null` `tensorEncoding` is the unknown / not-carried state, intentionally distinct from `TensorEncoding.Dense`, so the emitter has to treat it as silent. Green baseline.
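The two test expectations boil down to substring assertions over the emitted module text. A minimal sketch (the module strings below are hypothetical fixtures, not SKaiNET's actual output; the real test drives `StableHloConverter`):

```python
# Hypothetical emitted text for the annotated (Q8_0) graph.
annotated_module = """\
// tensor_encoding: role=input index=0 name=w encoding=Q8_0
func.func @main(%w: tensor<4x4xf32>) -> tensor<4x4xf32> { }
"""
# Hypothetical emitted text for the dense FP32 graph (no metadata).
dense_module = """\
func.func @main(%w: tensor<4x4xf32>) -> tensor<4x4xf32> { }
"""

# Positive case: the annotated module names the encoding and the tensor.
assert "encoding=Q8_0" in annotated_module and "name=w" in annotated_module

# Negative case: a dense FP32 module carries no annotation at all.
assert "tensor_encoding" not in dense_module
```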

2. The fix

  • `ConversionContext.emitEncodingAnnotation(role, index, spec)` — helper that emits a single `tensor_encoding` MLIR comment when `spec.tensorEncoding` is non-null. No-op otherwise.
  • `StableHloConverter.initializeInputValues` annotates each function-argument input with `role="input"`.
  • `StableHloConverter.processNode` annotates each non-input node's output specs with `role="result"` immediately before the converter is invoked, so the annotation precedes the op that produces the encoded value.
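The helper's contract can be sketched as follows (the real helper is Kotlin on `ConversionContext`; the Python names here are illustrative stand-ins):

```python
def emit_encoding_annotation(role, index, name, tensor_encoding):
    """Return one tensor_encoding MLIR comment line, or '' when silent."""
    if tensor_encoding is None:
        # A null encoding is the unknown / not-carried state, intentionally
        # distinct from an explicit Dense encoding: emit nothing.
        return ""
    return (f"// tensor_encoding: role={role} index={index} "
            f"name={name} encoding={tensor_encoding}\n")

print(emit_encoding_annotation("result", 0, "w", "Q8_0"), end="")
# -> // tensor_encoding: role=result index=0 name=w encoding=Q8_0
```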

Net effect: a Q8_0 weight node flowing through the emitter now produces a line like

```mlir
// tensor_encoding: role=result index=0 name=w encoding=Q8_0
```

right before the op that materializes `%w`. MLIR tools ignore comments but text round-trips preserve them, so IREE and any downstream consumer can read the encoding from the emitted `.mlir` file. Converters that want finer-grained placement can call the helper themselves.
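Because the annotation is plain text, a downstream consumer only needs a line scan to recover it. A hedged sketch of such a reader (this parser is hypothetical, not part of this PR; the MLIR snippet is illustrative):

```python
import re

# Match lines of the form: // tensor_encoding: key=value key=value ...
ANNOTATION = re.compile(r"^//\s*tensor_encoding:\s*(?P<fields>.+)$", re.MULTILINE)

def read_encodings(mlir_text):
    """Return one dict of key=value fields per tensor_encoding comment."""
    return [dict(kv.split("=", 1) for kv in m.group("fields").split())
            for m in ANNOTATION.finditer(mlir_text)]

module = """\
// tensor_encoding: role=result index=0 name=w encoding=Q8_0
%w = stablehlo.constant dense<1.0> : tensor<4x4xf32>
"""
print(read_encodings(module))
# -> [{'role': 'result', 'index': '0', 'name': 'w', 'encoding': 'Q8_0'}]
```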

Why comments, not real quant dialect ops

StableHLO's `quant.` dialect uses typed quant element types (`!quant.uniform<i8:f32, 0.1:128>`) that are fiddly to emit as text and are not yet consumed anywhere in the SKaiNET pipeline. Emitting them prematurely would just produce MLIR that no existing tool in this repo validates. Comments are the cheapest reversible first hop. Follow-up PRs in the P0-1 track can either grow the comment into a structured `#skainet.tensor_encoding` attribute or cut over to real `stablehlo.custom_call @dequantize_q8_0` stubs matching the style already used by `ReductionOperationsConverter`.

Test plan

  • `EncodingAnnotationTest` — both cases green
  • Full `:skainet-compile:skainet-compile-hlo:jvmTest` — green (no regressions from the new comment)
  • CI: full multiplatform build

Out of scope

  • Real `stablehlo.uniform_quantize` / quant dialect emission.
  • Teaching IREE or any downstream tool to actually consume the comments. That's downstream work.
  • Changing the shape of `TensorEncoding` or `TensorSpec`.
  • Softmax / conv / attention lowering (#467, "Fix softmax StableHLO lowering to use real reductions", is separate).

🤖 Generated with Claude Code

michalharakal and others added 2 commits April 13, 2026 12:41
Adds EncodingAnnotationTest with two cases:

1. q8_0_weight_produces_tensor_encoding_comment — builds a graph
   whose weight input TensorSpec carries TensorEncoding.Q8_0 via
   withTensorEncoding (the shape that TraceToGraphBuilder.finalize
   already produces after #469 landed) and asserts the emitted
   MLIR contains the `tensor_encoding` annotation comment naming
   `encoding=Q8_0` and `name=w`. Red against StableHloConverter
   today — the metadata is dropped at the emit boundary.

2. dense_graph_emits_no_encoding_comment — dense FP32 graph with
   no encoding metadata must not emit any spurious annotation.
   A `null` tensorEncoding is the unknown / not-carried state,
   not TensorEncoding.Dense, so the emitter has to treat it as
   silent. Green baseline.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Teaches the StableHLO emitter to surface the physical storage
encoding carried on TensorSpec.metadata (introduced in #469) as
diagnostic MLIR comments so quantization metadata survives the
compile boundary instead of being silently erased into FP32.

- ConversionContext gains `emitEncodingAnnotation(role, index,
  spec)`, a helper that emits a single `tensor_encoding` comment
  when `spec.tensorEncoding` is non-null and is a no-op when it's
  null. A `null` encoding remains the unknown / not-carried
  state, intentionally distinct from `TensorEncoding.Dense`.
- StableHloConverter.initializeInputValues annotates function-
  argument input specs with role="input".
- StableHloConverter.processNode annotates each non-input node's
  output specs with role="result" just before the converter is
  invoked, so the annotation precedes the operation that
  produces the encoded value.

Net effect: a Q8_0 weight node flowing through the emitter now
produces a line like

    // tensor_encoding: role=result index=0 name=w encoding=Q8_0

right before the op that materializes `%w`. MLIR tools ignore
comments but text round-trips preserve them, so IREE and any
downstream consumer can read the encoding from the emitted
.mlir file. Converters that want finer-grained placement can
call the helper themselves.

This is the cheapest reversible first hop on the quant-in-IR
path. Follow-up PRs in the P0-1 track can grow this into a
structured `#skainet.tensor_encoding` attribute or cut over to
real `stablehlo.custom_call @dequantize_q8_0` stubs matching
the existing custom_call style — neither of which is useful
until the metadata actually flows through the emitter, which
this change establishes.

Full compile-hlo jvmTest suite stays green.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
michalharakal merged commit 3bfd2f6 into develop Apr 13, 2026
4 checks passed
michalharakal deleted the feature/473-stablehlo-encoding-comments branch April 13, 2026 10:49
michalharakal added a commit that referenced this pull request Apr 13, 2026
Captures the goal, phases, non-goals, and risks of implementing
a new `skainet-backend-nnapi` module that runs SKaiNET models
on an Amlogic Android dev board's NPU via Android's NNAPI HAL.

Important placement notes in the PRD itself:
- The backend lives in a NEW sibling repo, not in mainline
  SKaiNET. Mainline stays general and IREE-focused.
- The backend builds on top of the already-merged
  skainet-backend-api module (#470) and the TensorEncoding
  metadata flow (#471 / #475 / #478) — no mainline code
  changes are required to ship Phase 1-3.
- Orthogonal to SKaiNET-transformers, which owns LLM modules.

Phases:
  0. Board bring-up + NNAPI device capability dump
  1. FP32 dense matmul end-to-end
  2. int8 quantization path hitting the NPU driver
  3. Target model (MobileNetV3 int8 or TinyLlama candidate)
  4. Optional production packaging

Also documents the known deprecation risk (Android 15 marked
NNAPI deprecated in favor of LiteRT) and captures this as
accepted: ship the Amlogic use case now, plan a LiteRT
successor later.

This file is a planning artifact; it will be moved / referenced
from the new backend repo once that repo exists.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Development

Successfully merging this pull request may close these issues.

Emit TensorEncoding into StableHLO output (P0-1 step 2)
