diff --git a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/trace/TraceToGraphBuilder.kt b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/trace/TraceToGraphBuilder.kt index b31990a1..55a5a304 100644 --- a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/trace/TraceToGraphBuilder.kt +++ b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/trace/TraceToGraphBuilder.kt @@ -6,6 +6,8 @@ import sk.ainet.lang.graph.GraphNode import sk.ainet.lang.tensor.ops.Operation import sk.ainet.lang.tensor.ops.TensorSpec import sk.ainet.lang.tensor.ops.ValidationResult +import sk.ainet.lang.tensor.ops.inferTensorEncoding +import sk.ainet.lang.tensor.ops.withTensorEncoding /** * Shared builder to convert OpTrace streams into a ComputeGraph. @@ -258,8 +260,14 @@ public class TraceToGraphBuilder( // Try to resolve as a constant from the session val tensor = if (!forceInput && embedConstants) session?.resolve(firstRef.tensorRef) else null val constantValues = tensor?.let { extractFloatArray(it) } + // Resolved tensors that carry a concrete storage encoding (Q4_K, + // Q8_0, TernaryPacked, TurboQuant, …) propagate it onto the + // produced spec so later compile stages can preserve the + // quantization instead of silently re-materializing FP32. + val encoding = tensor?.data?.inferTensorEncoding() val syntheticNode: GraphNode + val producedSpec: TensorSpec if (constantValues != null) { // Create a constant/weight node with embedded values val weightShape = tensor!!.shape.dimensions.toList() @@ -273,15 +281,16 @@ public class TraceToGraphBuilder( "trainable" to false ) ) + producedSpec = TensorSpec( + name = tensorId, + shape = weightShape, + dtype = weightDtype + ).withTensorEncoding(encoding) syntheticNode = GraphNode( id = nodeId, operation = op, inputs = emptyList(), - outputs = listOf(TensorSpec( - name = tensorId, - shape = weightShape, - dtype = weightDtype - )) + outputs = listOf(producedSpec) ) } else { // Create an input placeholder node @@ -291,17 +300,19 @@ public class TraceToGraphBuilder( type = "input", parameters = emptyMap() ) + producedSpec = spec.withTensorEncoding(encoding) syntheticNode = GraphNode( id = nodeId, operation = op, inputs = emptyList(), - outputs = listOf(spec) + outputs = listOf(producedSpec) ) } graph.addNode(syntheticNode) - // Wire edges to all consumers + // Wire edges to all consumers, propagating the encoding on the + // edge tensor spec so every consumer sees the quantization hint. for (ref in refs) { graph.addEdge( GraphEdge( @@ -310,13 +321,13 @@ public class TraceToGraphBuilder( destination = ref.consumerNode, sourceOutputIndex = 0, destinationInputIndex = ref.inputIndex, - tensorSpec = ref.spec + tensorSpec = ref.spec.withTensorEncoding(encoding) ) ) } // Register as producer - producersByTensorId[tensorId] = Producer(syntheticNode, 0, spec) + producersByTensorId[tensorId] = Producer(syntheticNode, 0, producedSpec) } unresolvedByTensorId.clear() } diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorSpecEncoding.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorSpecEncoding.kt new file mode 100644 index 00000000..9cc4e620 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorSpecEncoding.kt @@ -0,0 +1,54 @@ +package sk.ainet.lang.tensor.ops + +import sk.ainet.lang.tensor.data.TensorData +import sk.ainet.lang.tensor.storage.PackedBlockStorage +import sk.ainet.lang.tensor.storage.TensorEncoding + +/** + * Metadata key used to carry a [TensorEncoding] on a [TensorSpec]. + * + * Exposed so that callers that need to read/write the raw metadata map + * directly (for interop, serialization round-trips, etc.) use the same + * string the typed accessors below use. + */ +public const val TENSOR_ENCODING_METADATA_KEY: String = "tensorEncoding" + +/** + * Physical storage encoding carried on this spec, or `null` if the producer + * did not populate it. + * + * A `null` return means "unknown / not carried through the graph" — it is + * NOT equivalent to [TensorEncoding.Dense]. Consumers that need a concrete + * encoding should treat `null` as unknown and fall back to dtype-driven + * defaults rather than assuming dense. + */ +public val TensorSpec.tensorEncoding: TensorEncoding? + get() = metadata[TENSOR_ENCODING_METADATA_KEY] as? TensorEncoding + +/** + * Return a copy of this spec with [encoding] stored in its metadata map. + * Passing `null` removes the entry; passing a non-null value adds or + * replaces it, leaving all other metadata untouched. + */ +public fun TensorSpec.withTensorEncoding(encoding: TensorEncoding?): TensorSpec { + val newMetadata: Map = if (encoding == null) { + metadata - TENSOR_ENCODING_METADATA_KEY + } else { + metadata + (TENSOR_ENCODING_METADATA_KEY to encoding) + } + return copy(metadata = newMetadata) +} + +/** + * Infer a [TensorEncoding] from a concrete [TensorData] instance, or return + * `null` when the layout is dense / unknown. Single source of truth for the + * data-subclass → encoding mapping so trace builders and loaders agree. + * + * Any [TensorData] implementing [PackedBlockStorage] already exposes its + * own `encoding`, so this helper is one line today but centralizes the + * contract for future non-packed quantized layouts. + */ +public fun TensorData<*, *>.inferTensorEncoding(): TensorEncoding? = when (this) { + is PackedBlockStorage -> this.encoding + else -> null +} diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/TensorSpecEncodingTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/TensorSpecEncodingTest.kt new file mode 100644 index 00000000..bd4022d0 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/tensor/ops/TensorSpecEncodingTest.kt @@ -0,0 +1,87 @@ +package sk.ainet.lang.tensor.ops + +import sk.ainet.lang.tensor.storage.TensorEncoding +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull +import kotlin.test.assertSame +import kotlin.test.assertTrue + +class TensorSpecEncodingTest { + + @Test + fun unset_encoding_reads_as_null() { + val spec = TensorSpec(name = "x", shape = listOf(2, 3), dtype = "FP32") + assertNull(spec.tensorEncoding) + } + + @Test + fun withTensorEncoding_round_trips_Q8_0() { + val spec = TensorSpec(name = "w", shape = listOf(32), dtype = "FP32") + val annotated = spec.withTensorEncoding(TensorEncoding.Q8_0) + + assertSame(TensorEncoding.Q8_0, annotated.tensorEncoding) + // Original spec is untouched — TensorSpec is a data class and the + // helper returns a copy. + assertNull(spec.tensorEncoding) + } + + @Test + fun withTensorEncoding_round_trips_Q4_K() { + val spec = TensorSpec(name = "w", shape = listOf(256), dtype = "FP32") + val annotated = spec.withTensorEncoding(TensorEncoding.Q4_K) + assertSame(TensorEncoding.Q4_K, annotated.tensorEncoding) + } + + @Test + fun withTensorEncoding_round_trips_TernaryPacked() { + val spec = TensorSpec(name = "w", shape = listOf(128), dtype = "FP32") + val annotated = spec.withTensorEncoding(TensorEncoding.TernaryPacked) + assertSame(TensorEncoding.TernaryPacked, annotated.tensorEncoding) + } + + @Test + fun withTensorEncoding_round_trips_Dense() { + val spec = TensorSpec(name = "x", shape = listOf(4), dtype = "FP32") + val dense = TensorEncoding.Dense(bytesPerElement = 4) + val annotated = spec.withTensorEncoding(dense) + assertEquals(dense, annotated.tensorEncoding) + } + + @Test + fun passing_null_removes_the_encoding_entry() { + val spec = TensorSpec(name = "w", shape = listOf(32), dtype = "FP32") + .withTensorEncoding(TensorEncoding.Q8_0) + assertSame(TensorEncoding.Q8_0, spec.tensorEncoding) + + val cleared = spec.withTensorEncoding(null) + assertNull(cleared.tensorEncoding) + assertTrue( + !cleared.metadata.containsKey(TENSOR_ENCODING_METADATA_KEY), + "clearing should remove the metadata key entirely, not leave a null" + ) + } + + @Test + fun withTensorEncoding_preserves_other_metadata() { + val spec = TensorSpec( + name = "w", + shape = listOf(32), + dtype = "FP32", + metadata = mapOf("owner" to "attention.q_proj", "frozen" to true) + ) + val annotated = spec.withTensorEncoding(TensorEncoding.Q8_0) + + assertEquals("attention.q_proj", annotated.metadata["owner"]) + assertEquals(true, annotated.metadata["frozen"]) + assertSame(TensorEncoding.Q8_0, annotated.tensorEncoding) + } + + @Test + fun replacing_encoding_overwrites_previous_value() { + val spec = TensorSpec(name = "w", shape = listOf(32), dtype = "FP32") + .withTensorEncoding(TensorEncoding.Q8_0) + .withTensorEncoding(TensorEncoding.Q4_K) + assertSame(TensorEncoding.Q4_K, spec.tensorEncoding) + } +}