Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package sk.ainet.compile.hlo

import sk.ainet.lang.graph.ComputeGraph
import sk.ainet.lang.graph.GraphNode
import sk.ainet.lang.tensor.ops.TensorSpec
import sk.ainet.lang.tensor.ops.tensorEncoding

/**
* Context object for maintaining state during StableHLO conversion.
Expand Down Expand Up @@ -54,6 +56,38 @@ public class ConversionContext(
/**
 * Append [comment] to the module output as an indented MLIR line comment
 * (`  // …`). MLIR tooling ignores comments, so this is safe to call at
 * any point of the emission.
 */
public fun emitComment(comment: String) {
    stringBuilder.append(" // ").append(comment).appendLine()
}

/**
 * Surface the physical storage encoding of [spec] as an MLIR comment,
 * doing nothing when the spec carries no `tensorEncoding` (set via
 * [sk.ainet.lang.tensor.ops.withTensorEncoding]).
 *
 * The line produced is:
 *
 * ```mlir
 * // tensor_encoding: role=<role> index=<i> name=<spec.name> encoding=<enc.name>
 * ```
 *
 * Comments are invisible to MLIR tools yet survive textual round-trips,
 * which makes this the lowest-cost way to keep SKaiNET's quantization
 * metadata observable across the StableHLO emit boundary until a
 * structured attribute or quant-dialect lowering exists. A `null`
 * [sk.ainet.lang.tensor.storage.TensorEncoding] is the unknown /
 * not-carried state — deliberately distinct from `TensorEncoding.Dense`
 * — and is treated as silent.
 *
 * @param role Logical slot the spec fills for the node being emitted
 *   (e.g. `"input"` for function arguments, `"result"` for node
 *   outputs). Free-form, so converters calling this directly may pick
 *   finer-grained tags.
 * @param index Position of the spec within its role, e.g. the output
 *   port index of a multi-result node.
 */
public fun emitEncodingAnnotation(role: String, index: Int, spec: TensorSpec) {
    spec.tensorEncoding?.let { enc ->
        emitComment(
            "tensor_encoding: role=$role index=$index name=${spec.name} encoding=${enc.name}"
        )
    }
}

/**
* Get the complete generated MLIR content
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,19 @@ public class StableHloConverter(
inputNodes.forEachIndexed { idx, node ->
val valueName = "%arg$idx"
context.setValueName(node.id, valueName)

// Add comment for clarity
node.outputs.firstOrNull()?.let { spec ->
context.emitComment("input ${node.id}: ${spec.name} : ${typeMapper.mapTensorType(spec)}")
}

// Preserve any physical storage encoding carried on the input
// spec (Q4_K / Q8_0 / TernaryPacked / TurboQuant / …) as an
// MLIR comment so downstream tools see that quantization
// flowed through. No-op when the spec has no encoding.
node.outputs.forEachIndexed { outIdx, spec ->
context.emitEncodingAnnotation(role = "input", index = outIdx, spec = spec)
}
}
}

Expand All @@ -138,11 +146,19 @@ public class StableHloConverter(
private fun processNode(node: GraphNode, context: ConversionContext) {
val converter = registry.getConverter(node.operation.name)
?: throw UnsupportedOperationException("No converter found for operation: ${node.operation.name}")

// Get input operands from context
val inputNodes = context.getInputNodes(node)
val operands = inputNodes.mapNotNull { context.getValueName(it.id) }


// Surface any physical storage encoding declared on this node's
// result specs as an MLIR comment before the operation is
// emitted. Converters that want finer-grained placement can call
// ConversionContext.emitEncodingAnnotation themselves.
node.outputs.forEachIndexed { outIdx, spec ->
context.emitEncodingAnnotation(role = "result", index = outIdx, spec = spec)
}

// Convert the operation
val result = converter.convert(node, operands, context)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package sk.ainet.compile.hlo

import sk.ainet.lang.graph.DefaultComputeGraph
import sk.ainet.lang.graph.GraphEdge
import sk.ainet.lang.graph.GraphNode
import sk.ainet.lang.tensor.ops.AddOperation
import sk.ainet.lang.tensor.ops.InputOperation
import sk.ainet.lang.tensor.ops.TensorSpec
import sk.ainet.lang.tensor.ops.withTensorEncoding
import sk.ainet.lang.tensor.storage.TensorEncoding
import sk.ainet.lang.types.DType
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.test.assertTrue

/**
 * Exercises the emitter hook introduced for #473: whenever a [TensorSpec]
 * in the graph carries a non-null `tensorEncoding`, the generated
 * StableHLO module has to keep that fact visible as a comment, so both
 * downstream tools and humans reading the MLIR can tell that quantization
 * survived the compile boundary.
 */
class EncodingAnnotationTest {

    /** Builds a source (input) node whose single output is [spec]. */
    private fun sourceNode(id: String, spec: TensorSpec): GraphNode = GraphNode(
        id = id,
        operation = InputOperation<DType, Any>(),
        inputs = emptyList(),
        outputs = listOf(spec)
    )

    @Test
    fun q8_0_weight_produces_tensor_encoding_comment() {
        val graph = DefaultComputeGraph()

        val lhs = sourceNode("a", TensorSpec("a", listOf(1, 4), "FP32"))

        // A Q8_0 weight shaped exactly like TraceToGraphBuilder.finalize()
        // has produced weight nodes since #469 landed: the spec's metadata
        // carries TensorEncoding.Q8_0.
        val weightSpec = TensorSpec("w", listOf(1, 4), "FP32")
            .withTensorEncoding(TensorEncoding.Q8_0)
        val weight = sourceNode("w", weightSpec)

        val sum = GraphNode(
            id = "add1",
            operation = AddOperation<DType, Any>(),
            inputs = listOf(
                TensorSpec("a", listOf(1, 4), "FP32"),
                weightSpec
            ),
            outputs = listOf(TensorSpec("out", listOf(1, 4), "FP32"))
        )

        graph.addNode(lhs)
        graph.addNode(weight)
        graph.addNode(sum)
        graph.addEdge(GraphEdge("e1", lhs, sum, 0, 0, lhs.outputs[0]))
        graph.addEdge(GraphEdge("e2", weight, sum, 0, 1, weightSpec))

        val mlir = toStableHlo(graph, "quant_add").content
        println("[DEBUG_LOG] quant-annotated export:\n$mlir")

        // The encoding must appear as a comment near the weight input's
        // initialization. MLIR tools skip comments, but textual round-trips
        // keep them — the cheapest channel for SKaiNET's quantization
        // metadata through the StableHLO emit boundary.
        assertTrue(
            mlir.contains("tensor_encoding"),
            "emitter must include a tensor_encoding annotation comment"
        )
        assertTrue(
            mlir.contains("encoding=Q8_0"),
            "annotation must name the concrete TensorEncoding (Q8_0)"
        )
        assertTrue(
            mlir.contains("name=w"),
            "annotation must identify the tensor the encoding applies to"
        )
    }

    @Test
    fun dense_graph_emits_no_encoding_comment() {
        // An all-FP32 graph carrying no encoding metadata must stay free of
        // spurious tensor_encoding comments. A null tensorEncoding means
        // "unknown / not carried" — deliberately not the same as "Dense" —
        // and the emitter has to treat it as silent.
        val graph = DefaultComputeGraph()

        val lhs = sourceNode("a", TensorSpec("a", listOf(1, 4), "FP32"))
        val rhs = sourceNode("b", TensorSpec("b", listOf(1, 4), "FP32"))
        val sum = GraphNode(
            id = "add1",
            operation = AddOperation<DType, Any>(),
            inputs = listOf(
                TensorSpec("a", listOf(1, 4), "FP32"),
                TensorSpec("b", listOf(1, 4), "FP32")
            ),
            outputs = listOf(TensorSpec("c", listOf(1, 4), "FP32"))
        )

        graph.addNode(lhs)
        graph.addNode(rhs)
        graph.addNode(sum)
        graph.addEdge(GraphEdge("e1", lhs, sum, 0, 0, lhs.outputs[0]))
        graph.addEdge(GraphEdge("e2", rhs, sum, 0, 1, rhs.outputs[0]))

        val mlir = toStableHlo(graph, "dense_add").content
        assertFalse(
            mlir.contains("tensor_encoding"),
            "dense graph must not emit any tensor_encoding annotation"
        )
    }
}
Loading