diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConversionContext.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConversionContext.kt
index 31df5816..925b3319 100644
--- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConversionContext.kt
+++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConversionContext.kt
@@ -2,6 +2,8 @@ package sk.ainet.compile.hlo
 
 import sk.ainet.lang.graph.ComputeGraph
 import sk.ainet.lang.graph.GraphNode
+import sk.ainet.lang.tensor.ops.TensorSpec
+import sk.ainet.lang.tensor.ops.tensorEncoding
 
 /**
  * Context object for maintaining state during StableHLO conversion.
@@ -54,6 +56,38 @@ public class ConversionContext(
     public fun emitComment(comment: String) {
         stringBuilder.appendLine("    // $comment")
     }
+
+    /**
+     * Emit a `tensor_encoding` diagnostic comment when [spec] carries a
+     * non-null `tensorEncoding` (set via [sk.ainet.lang.tensor.ops.withTensorEncoding]).
+     *
+     * The emitted line has the shape:
+     *
+     * ```mlir
+     * // tensor_encoding: role=<role> index=<index> name=<name> encoding=<encoding>
+     * ```
+     *
+     * MLIR tools ignore comments but text round-trips preserve them, so
+     * this is the cheapest way to keep SKaiNET's quantization metadata
+     * visible through the StableHLO emit boundary until a structured
+     * attribute or quant-dialect lowering lands. Emits nothing when the
+     * spec has no encoding — a `null` [sk.ainet.lang.tensor.storage.TensorEncoding]
+     * is the unknown / not-carried state, intentionally distinct from
+     * `TensorEncoding.Dense`.
+     *
+     * @param role Logical slot the spec occupies for the node being
+     *   emitted (e.g. `"input"` for function arguments, `"result"` for
+     *   node outputs). Free-form so individual converters can use
+     *   finer-grained tags if they call this helper directly.
+     * @param index Positional index of the spec within its role, e.g.
+     *   the output port index for a multi-result node.
+     */
+    public fun emitEncodingAnnotation(role: String, index: Int, spec: TensorSpec) {
+        val encoding = spec.tensorEncoding ?: return
+        emitComment(
+            "tensor_encoding: role=$role index=$index name=${spec.name} encoding=${encoding.name}"
+        )
+    }
 
     /**
      * Get the complete generated MLIR content
diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt
index 5c327ba4..4a6c3713 100644
--- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt
+++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt
@@ -111,11 +111,19 @@ public class StableHloConverter(
         inputNodes.forEachIndexed { idx, node ->
             val valueName = "%arg$idx"
             context.setValueName(node.id, valueName)
-            
+
             // Add comment for clarity
             node.outputs.firstOrNull()?.let { spec ->
                 context.emitComment("input ${node.id}: ${spec.name} : ${typeMapper.mapTensorType(spec)}")
             }
+
+            // Preserve any physical storage encoding carried on the input
+            // spec (Q4_K / Q8_0 / TernaryPacked / TurboQuant / …) as an
+            // MLIR comment so downstream tools see that quantization
+            // flowed through. No-op when the spec has no encoding.
+            node.outputs.forEachIndexed { outIdx, spec ->
+                context.emitEncodingAnnotation(role = "input", index = outIdx, spec = spec)
+            }
         }
     }
@@ -138,11 +146,19 @@ public class StableHloConverter(
     private fun processNode(node: GraphNode, context: ConversionContext) {
         val converter = registry.getConverter(node.operation.name)
             ?: throw UnsupportedOperationException("No converter found for operation: ${node.operation.name}")
-        
+
         // Get input operands from context
         val inputNodes = context.getInputNodes(node)
         val operands = inputNodes.mapNotNull { context.getValueName(it.id) }
-        
+
+        // Surface any physical storage encoding declared on this node's
+        // result specs as an MLIR comment before the operation is
+        // emitted. Converters that want finer-grained placement can call
+        // ConversionContext.emitEncodingAnnotation themselves.
+        node.outputs.forEachIndexed { outIdx, spec ->
+            context.emitEncodingAnnotation(role = "result", index = outIdx, spec = spec)
+        }
+
         // Convert the operation
         val result = converter.convert(node, operands, context)
diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/EncodingAnnotationTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/EncodingAnnotationTest.kt
new file mode 100644
index 00000000..53b1315d
--- /dev/null
+++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/EncodingAnnotationTest.kt
@@ -0,0 +1,128 @@
+package sk.ainet.compile.hlo
+
+import sk.ainet.lang.graph.DefaultComputeGraph
+import sk.ainet.lang.graph.GraphEdge
+import sk.ainet.lang.graph.GraphNode
+import sk.ainet.lang.tensor.ops.AddOperation
+import sk.ainet.lang.tensor.ops.InputOperation
+import sk.ainet.lang.tensor.ops.TensorSpec
+import sk.ainet.lang.tensor.ops.withTensorEncoding
+import sk.ainet.lang.tensor.storage.TensorEncoding
+import sk.ainet.lang.types.DType
+import kotlin.test.Test
+import kotlin.test.assertFalse
+import kotlin.test.assertTrue
+
+/**
+ * Covers the emitter hook added for #473: when a [TensorSpec] flowing
+ * through the graph carries a non-null `tensorEncoding`, the emitted
+ * StableHLO module must preserve that information as a comment so
+ * downstream tools (and humans reading the MLIR) can see that
+ * quantization flowed through the compile boundary.
+ */
+class EncodingAnnotationTest {
+
+    @Test
+    fun q8_0_weight_produces_tensor_encoding_comment() {
+        val graph = DefaultComputeGraph()
+
+        val inputA = GraphNode(
+            id = "a",
+            operation = InputOperation(),
+            inputs = emptyList(),
+            outputs = listOf(TensorSpec("a", listOf(1, 4), "FP32"))
+        )
+
+        // A Q8_0 weight, synthesized the way TraceToGraphBuilder.finalize()
+        // produces its weight nodes after #469 landed: the spec carries
+        // TensorEncoding.Q8_0 on its metadata.
+        val weightSpec = TensorSpec("w", listOf(1, 4), "FP32")
+            .withTensorEncoding(TensorEncoding.Q8_0)
+        val weightNode = GraphNode(
+            id = "w",
+            operation = InputOperation(),
+            inputs = emptyList(),
+            outputs = listOf(weightSpec)
+        )
+
+        val add = GraphNode(
+            id = "add1",
+            operation = AddOperation(),
+            inputs = listOf(
+                TensorSpec("a", listOf(1, 4), "FP32"),
+                weightSpec
+            ),
+            outputs = listOf(TensorSpec("out", listOf(1, 4), "FP32"))
+        )
+
+        graph.addNode(inputA)
+        graph.addNode(weightNode)
+        graph.addNode(add)
+        graph.addEdge(GraphEdge("e1", inputA, add, 0, 0, inputA.outputs[0]))
+        graph.addEdge(GraphEdge("e2", weightNode, add, 0, 1, weightSpec))
+
+        val mlir = toStableHlo(graph, "quant_add").content
+        println("[DEBUG_LOG] quant-annotated export:\n$mlir")
+
+        // The emitter must have surfaced the encoding as a comment near
+        // the weight input's initialization. MLIR tools ignore comments
+        // but the text round-trips preserve them, so this is the cheapest
+        // way to keep SKaiNET's quantization metadata visible through the
+        // StableHLO emit boundary.
+        assertTrue(
+            mlir.contains("tensor_encoding"),
+            "emitter must include a tensor_encoding annotation comment"
+        )
+        assertTrue(
+            mlir.contains("encoding=Q8_0"),
+            "annotation must name the concrete TensorEncoding (Q8_0)"
+        )
+        assertTrue(
+            mlir.contains("name=w"),
+            "annotation must identify the tensor the encoding applies to"
+        )
+    }
+
+    @Test
+    fun dense_graph_emits_no_encoding_comment() {
+        // An all-FP32 graph with no encoding metadata must not introduce
+        // spurious tensor_encoding comments. A `null` tensorEncoding is
+        // the unknown / not-carried state, not "Dense", and the emitter
+        // must treat it as silent.
+        val graph = DefaultComputeGraph()
+
+        val inputA = GraphNode(
+            id = "a",
+            operation = InputOperation(),
+            inputs = emptyList(),
+            outputs = listOf(TensorSpec("a", listOf(1, 4), "FP32"))
+        )
+        val inputB = GraphNode(
+            id = "b",
+            operation = InputOperation(),
+            inputs = emptyList(),
+            outputs = listOf(TensorSpec("b", listOf(1, 4), "FP32"))
+        )
+        val add = GraphNode(
+            id = "add1",
+            operation = AddOperation(),
+            inputs = listOf(
+                TensorSpec("a", listOf(1, 4), "FP32"),
+                TensorSpec("b", listOf(1, 4), "FP32")
+            ),
+            outputs = listOf(TensorSpec("c", listOf(1, 4), "FP32"))
+        )
+
+        graph.addNode(inputA)
+        graph.addNode(inputB)
+        graph.addNode(add)
+        graph.addEdge(GraphEdge("e1", inputA, add, 0, 0, inputA.outputs[0]))
+        graph.addEdge(GraphEdge("e2", inputB, add, 0, 1, inputB.outputs[0]))
+
+        val mlir = toStableHlo(graph, "dense_add").content
+        assertFalse(
+            mlir.contains("tensor_encoding"),
+            "dense graph must not emit any tensor_encoding annotation"
+        )
+    }
+}