Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import sk.ainet.lang.graph.GraphNode
import sk.ainet.lang.tensor.ops.Operation
import sk.ainet.lang.tensor.ops.TensorSpec
import sk.ainet.lang.tensor.ops.ValidationResult
import sk.ainet.lang.tensor.ops.inferTensorEncoding
import sk.ainet.lang.tensor.ops.withTensorEncoding

/**
* Shared builder to convert OpTrace streams into a ComputeGraph.
Expand Down Expand Up @@ -258,8 +260,14 @@ public class TraceToGraphBuilder(
// Try to resolve as a constant from the session
val tensor = if (!forceInput && embedConstants) session?.resolve(firstRef.tensorRef) else null
val constantValues = tensor?.let { extractFloatArray(it) }
// Resolved tensors that carry a concrete storage encoding (Q4_K,
// Q8_0, TernaryPacked, TurboQuant, …) propagate it onto the
// produced spec so later compile stages can preserve the
// quantization instead of silently re-materializing FP32.
val encoding = tensor?.data?.inferTensorEncoding()

val syntheticNode: GraphNode
val producedSpec: TensorSpec
if (constantValues != null) {
// Create a constant/weight node with embedded values
val weightShape = tensor!!.shape.dimensions.toList()
Expand All @@ -273,15 +281,16 @@ public class TraceToGraphBuilder(
"trainable" to false
)
)
producedSpec = TensorSpec(
name = tensorId,
shape = weightShape,
dtype = weightDtype
).withTensorEncoding(encoding)
syntheticNode = GraphNode(
id = nodeId,
operation = op,
inputs = emptyList(),
outputs = listOf(TensorSpec(
name = tensorId,
shape = weightShape,
dtype = weightDtype
))
outputs = listOf(producedSpec)
)
} else {
// Create an input placeholder node
Expand All @@ -291,17 +300,19 @@ public class TraceToGraphBuilder(
type = "input",
parameters = emptyMap()
)
producedSpec = spec.withTensorEncoding(encoding)
syntheticNode = GraphNode(
id = nodeId,
operation = op,
inputs = emptyList(),
outputs = listOf(spec)
outputs = listOf(producedSpec)
)
}

graph.addNode(syntheticNode)

// Wire edges to all consumers
// Wire edges to all consumers, propagating the encoding on the
// edge tensor spec so every consumer sees the quantization hint.
for (ref in refs) {
graph.addEdge(
GraphEdge(
Expand All @@ -310,13 +321,13 @@ public class TraceToGraphBuilder(
destination = ref.consumerNode,
sourceOutputIndex = 0,
destinationInputIndex = ref.inputIndex,
tensorSpec = ref.spec
tensorSpec = ref.spec.withTensorEncoding(encoding)
)
)
}

// Register as producer
producersByTensorId[tensorId] = Producer(syntheticNode, 0, spec)
producersByTensorId[tensorId] = Producer(syntheticNode, 0, producedSpec)
}
unresolvedByTensorId.clear()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package sk.ainet.lang.tensor.ops

import sk.ainet.lang.tensor.data.TensorData
import sk.ainet.lang.tensor.storage.PackedBlockStorage
import sk.ainet.lang.tensor.storage.TensorEncoding

/**
 * Metadata key under which a [TensorEncoding] is carried in a
 * [TensorSpec.metadata] map.
 *
 * Exposed publicly so that callers that need to read/write the raw metadata
 * map directly (for interop, serialization round-trips, etc.) use exactly the
 * same string as the typed accessors [tensorEncoding] and
 * [withTensorEncoding] below.
 */
public const val TENSOR_ENCODING_METADATA_KEY: String = "tensorEncoding"

/**
 * The physical storage encoding recorded on this spec, or `null` when no
 * producer has populated it.
 *
 * `null` means "unknown / not carried through the graph" — it is NOT the
 * same as [TensorEncoding.Dense]. Consumers needing a concrete encoding
 * should treat `null` as unknown and fall back to dtype-driven defaults
 * instead of assuming a dense layout.
 */
public val TensorSpec.tensorEncoding: TensorEncoding?
    get() {
        // The metadata map is untyped (Map<String, Any>); a safe cast keeps
        // foreign values stored under the same key from throwing here.
        val raw = metadata[TENSOR_ENCODING_METADATA_KEY]
        return raw as? TensorEncoding
    }

/**
 * Produce a copy of this spec whose metadata carries [encoding].
 *
 * A `null` [encoding] removes the entry entirely (no `null` placeholder is
 * left behind); a non-null value adds or replaces it. All other metadata
 * entries are preserved untouched.
 */
public fun TensorSpec.withTensorEncoding(encoding: TensorEncoding?): TensorSpec =
    copy(
        metadata = when (encoding) {
            // Clearing: drop the key rather than storing a null value.
            null -> metadata - TENSOR_ENCODING_METADATA_KEY
            // Setting: plus() replaces any existing entry for the key.
            else -> metadata + (TENSOR_ENCODING_METADATA_KEY to encoding)
        }
    )

/**
 * Derive a [TensorEncoding] from a concrete [TensorData] instance, or `null`
 * when the layout is dense / unknown.
 *
 * This is the single source of truth for the data-subclass → encoding
 * mapping, so trace builders and loaders stay in agreement. Any [TensorData]
 * implementing [PackedBlockStorage] already exposes its own `encoding`,
 * which makes this a one-liner today — but it centralizes the contract for
 * future non-packed quantized layouts.
 */
public fun TensorData<*, *>.inferTensorEncoding(): TensorEncoding? =
    (this as? PackedBlockStorage)?.encoding
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package sk.ainet.lang.tensor.ops

import sk.ainet.lang.tensor.storage.TensorEncoding
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNull
import kotlin.test.assertSame
import kotlin.test.assertTrue

class TensorSpecEncodingTest {

    @Test
    fun unset_encoding_reads_as_null() {
        // A freshly constructed spec carries no encoding metadata at all.
        val plain = TensorSpec(name = "x", shape = listOf(2, 3), dtype = "FP32")
        assertNull(plain.tensorEncoding)
    }

    @Test
    fun withTensorEncoding_round_trips_Q8_0() {
        val base = TensorSpec(name = "w", shape = listOf(32), dtype = "FP32")
        val tagged = base.withTensorEncoding(TensorEncoding.Q8_0)

        assertSame(TensorEncoding.Q8_0, tagged.tensorEncoding)
        // The helper returns a copy (TensorSpec is a data class), so the
        // source spec must remain unannotated.
        assertNull(base.tensorEncoding)
    }

    @Test
    fun withTensorEncoding_round_trips_Q4_K() {
        val tagged = TensorSpec(name = "w", shape = listOf(256), dtype = "FP32")
            .withTensorEncoding(TensorEncoding.Q4_K)
        assertSame(TensorEncoding.Q4_K, tagged.tensorEncoding)
    }

    @Test
    fun withTensorEncoding_round_trips_TernaryPacked() {
        val tagged = TensorSpec(name = "w", shape = listOf(128), dtype = "FP32")
            .withTensorEncoding(TensorEncoding.TernaryPacked)
        assertSame(TensorEncoding.TernaryPacked, tagged.tensorEncoding)
    }

    @Test
    fun withTensorEncoding_round_trips_Dense() {
        val denseEncoding = TensorEncoding.Dense(bytesPerElement = 4)
        val tagged = TensorSpec(name = "x", shape = listOf(4), dtype = "FP32")
            .withTensorEncoding(denseEncoding)
        // Dense carries state, so equality (not identity) is the right check.
        assertEquals(denseEncoding, tagged.tensorEncoding)
    }

    @Test
    fun passing_null_removes_the_encoding_entry() {
        val annotated = TensorSpec(name = "w", shape = listOf(32), dtype = "FP32")
            .withTensorEncoding(TensorEncoding.Q8_0)
        assertSame(TensorEncoding.Q8_0, annotated.tensorEncoding)

        val cleared = annotated.withTensorEncoding(null)
        assertNull(cleared.tensorEncoding)
        val keyStillPresent = cleared.metadata.containsKey(TENSOR_ENCODING_METADATA_KEY)
        assertTrue(
            !keyStillPresent,
            "clearing should remove the metadata key entirely, not leave a null"
        )
    }

    @Test
    fun withTensorEncoding_preserves_other_metadata() {
        val base = TensorSpec(
            name = "w",
            shape = listOf(32),
            dtype = "FP32",
            metadata = mapOf("owner" to "attention.q_proj", "frozen" to true)
        )
        val tagged = base.withTensorEncoding(TensorEncoding.Q8_0)

        // Pre-existing entries survive the copy unchanged.
        assertEquals("attention.q_proj", tagged.metadata["owner"])
        assertEquals(true, tagged.metadata["frozen"])
        assertSame(TensorEncoding.Q8_0, tagged.tensorEncoding)
    }

    @Test
    fun replacing_encoding_overwrites_previous_value() {
        // A second annotation replaces the first rather than stacking.
        val reTagged = TensorSpec(name = "w", shape = listOf(32), dtype = "FP32")
            .withTensorEncoding(TensorEncoding.Q8_0)
            .withTensorEncoding(TensorEncoding.Q4_K)
        assertSame(TensorEncoding.Q4_K, reTagged.tensorEncoding)
    }
}
Loading