Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,30 +36,36 @@ public class MlirValidator {

for ((lineNum, line) in lines.withIndex()) {
val trimmed = line.trim()

// Skip empty lines and comments
if (trimmed.isEmpty() || trimmed.startsWith("//")) continue

// Check brace balance
braceCount += trimmed.count { it == '{' }
braceCount -= trimmed.count { it == '}' }

// Check module structure
if (trimmed.startsWith("module")) {
if (inModule) {
errors.add("Line ${lineNum + 1}: Nested modules not allowed")
}
inModule = true
// Module headers may carry a `module attributes { ... } {`
// preamble whose attribute dict contains `name = "value"`
// entries. These aren't SSA assignments and must not be
// fed into validateSSAValue, so stop processing this line
// here.
continue
}

// Check function structure
if (trimmed.contains("func.func")) {
if (!inModule) {
errors.add("Line ${lineNum + 1}: Function must be inside module")
}
inFunction = true
}

// Check for basic SSA value format
if (trimmed.contains(" = ") && !validateSSAValue(trimmed)) {
errors.add("Line ${lineNum + 1}: Invalid SSA value format")
Expand Down Expand Up @@ -90,10 +96,13 @@ public class MlirValidator {

for ((lineNum, line) in lines.withIndex()) {
val trimmed = line.trim()

// Skip empty lines and comments
if (trimmed.isEmpty() || trimmed.startsWith("//")) continue


// Skip empty lines, comments, and module header lines (which
// may carry a `module attributes { ... }` dictionary whose
// `name = "value"` entries look like SSA assignments but are
// not).
if (trimmed.isEmpty() || trimmed.startsWith("//") || trimmed.startsWith("module")) continue

// Extract defined SSA values
if (trimmed.contains(" = ")) {
val parts = trimmed.split(" = ", limit = 2)
Expand Down Expand Up @@ -162,8 +171,10 @@ public class MlirValidator {
*/
public fun validateModule(content: String): List<String> {
val errors = mutableListOf<String>()

if (!content.contains("module {")) {

// Accept both the bare `module {` and the attributes-carrying
// `module attributes { ... } {` header forms.
if (!content.contains("module {") && !content.contains("module attributes")) {
errors.add("Missing module declaration")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package sk.ainet.compile.hlo
import sk.ainet.lang.graph.ComputeGraph
import sk.ainet.lang.graph.GraphNode
import sk.ainet.lang.tensor.ops.TensorSpec
import sk.ainet.lang.tensor.ops.tensorEncoding
import sk.ainet.lang.tensor.storage.TensorEncoding

/**
* Main converter class that orchestrates the conversion process from ComputeGraph to StableHLO MLIR.
Expand Down Expand Up @@ -45,9 +47,26 @@ public class StableHloConverter(

// Build function signature with proper return types
val functionSignature = buildFunctionSignature(inputNodes, outputSpecs, functionName)

// Start building MLIR content
context.emitLine("module {")

// Collect every TensorSpec with a non-null tensorEncoding into a
// single name -> encoding map. Emitting this as a structured
// MLIR attribute on the module header lets downstream tools
// enumerate every encoded tensor via one attribute lookup
// instead of string-matching against scattered comments.
val tensorEncodings = collectTensorEncodings(topo)

// Start building MLIR content — promote to `module attributes`
// only when we have at least one encoded tensor. Dense graphs
// keep the bare `module {` header for byte-for-byte backward
// compatibility with existing round-trip tests.
if (tensorEncodings.isNotEmpty()) {
val dictEntries = tensorEncodings.entries
.sortedBy { it.key }
.joinToString(", ") { (name, encoding) -> "$name = \"${encoding.name}\"" }
context.emitLine("module attributes {skainet.tensor_encodings = {$dictEntries}} {")
} else {
context.emitLine("module {")
}
context.emitLine(" func.func $functionSignature {")

// Initialize input values in context
Expand Down Expand Up @@ -176,6 +195,27 @@ public class StableHloConverter(
}
}

/**
 * Gather the `name -> encoding` mapping for every [TensorSpec] flowing
 * through the given nodes that carries a non-null [TensorEncoding].
 *
 * Per node, output specs are visited before input specs, and a name
 * already present in the map is never overwritten (first-writer-wins),
 * so the same tensor appearing in multiple nodes collapses to a single
 * entry. A linked map preserves insertion order for deterministic
 * attribute emission.
 */
private fun collectTensorEncodings(nodes: List<GraphNode>): Map<String, TensorEncoding> {
    val encodings = linkedMapOf<String, TensorEncoding>()
    nodes.forEach { node ->
        (node.outputs.asSequence() + node.inputs.asSequence()).forEach { spec ->
            spec.tensorEncoding?.let { encodings.putIfAbsent(spec.name, it) }
        }
    }
    return encodings
}

/**
* Determine output specifications from output nodes
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package sk.ainet.compile.hlo

import sk.ainet.lang.graph.DefaultComputeGraph
import sk.ainet.lang.graph.GraphEdge
import sk.ainet.lang.graph.GraphNode
import sk.ainet.lang.tensor.ops.AddOperation
import sk.ainet.lang.tensor.ops.InputOperation
import sk.ainet.lang.tensor.ops.TensorSpec
import sk.ainet.lang.tensor.ops.withTensorEncoding
import sk.ainet.lang.tensor.storage.TensorEncoding
import sk.ainet.lang.types.DType
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.test.assertTrue

/**
 * Exercises the structured module-level attribute emission for #477:
 * any TensorSpec in the graph whose tensorEncoding is non-null must be
 * listed exactly once in the `skainet.tensor_encodings` dictionary
 * carried by the emitted `module attributes { ... }` header, giving
 * downstream tooling a single attribute lookup instead of
 * string-matching against scattered comments.
 */
class TensorEncodingsModuleAttributeTest {

    // Fixture helper: a source node with no inputs and one output spec.
    private fun inputNode(id: String, spec: TensorSpec): GraphNode = GraphNode(
        id = id,
        operation = InputOperation<DType, Any>(),
        inputs = emptyList(),
        outputs = listOf(spec)
    )

    @Test
    fun encoded_weights_produce_module_attributes_block() {
        val graph = DefaultComputeGraph()

        val plainInput = inputNode("a", TensorSpec("a", listOf(1, 4), "FP32"))

        // Two weight inputs with distinct encodings, mirroring the shape
        // TraceToGraphBuilder.finalize() produces post-#469 when a
        // session resolves quantized weights.
        val q8Spec = TensorSpec("w_q8", listOf(1, 4), "FP32")
            .withTensorEncoding(TensorEncoding.Q8_0)
        val q4Spec = TensorSpec("w_q4", listOf(1, 4), "FP32")
            .withTensorEncoding(TensorEncoding.Q4_K)
        val q8Node = inputNode("w_q8", q8Spec)
        val q4Node = inputNode("w_q4", q4Spec)

        val firstAdd = GraphNode(
            id = "add1",
            operation = AddOperation<DType, Any>(),
            inputs = listOf(TensorSpec("a", listOf(1, 4), "FP32"), q8Spec),
            outputs = listOf(TensorSpec("sum1", listOf(1, 4), "FP32"))
        )
        val secondAdd = GraphNode(
            id = "add2",
            operation = AddOperation<DType, Any>(),
            inputs = listOf(TensorSpec("sum1", listOf(1, 4), "FP32"), q4Spec),
            outputs = listOf(TensorSpec("sum2", listOf(1, 4), "FP32"))
        )

        listOf(plainInput, q8Node, q4Node, firstAdd, secondAdd).forEach(graph::addNode)
        graph.addEdge(GraphEdge("e1", plainInput, firstAdd, 0, 0, plainInput.outputs[0]))
        graph.addEdge(GraphEdge("e2", q8Node, firstAdd, 0, 1, q8Spec))
        graph.addEdge(GraphEdge("e3", firstAdd, secondAdd, 0, 0, firstAdd.outputs[0]))
        graph.addEdge(GraphEdge("e4", q4Node, secondAdd, 0, 1, q4Spec))

        val mlir = toStableHlo(graph, "quant_chain").content
        println("[DEBUG_LOG] module-attribute export:\n$mlir")

        // One structured attribute on the module header must enumerate
        // every encoded tensor in a single place.
        assertTrue(
            mlir.contains("module attributes"),
            "module header must be emitted with `module attributes { ... }` when encodings are present"
        )
        assertTrue(
            mlir.contains("skainet.tensor_encodings"),
            "module attributes must include the `skainet.tensor_encodings` dictionary"
        )

        // Each encoded tensor appears in the dictionary by name, mapped
        // to its TensorEncoding.name.
        assertTrue(
            mlir.contains("w_q8 = \"Q8_0\""),
            "dictionary must map `w_q8` to `\"Q8_0\"`"
        )
        assertTrue(
            mlir.contains("w_q4 = \"Q4_K\""),
            "dictionary must map `w_q4` to `\"Q4_K\"`"
        )
    }

    @Test
    fun dense_graph_keeps_bare_module_header() {
        // With no encoding metadata anywhere in the graph, the emitter
        // must keep the bare `module {` header and stay silent: a
        // `null` tensorEncoding is the unknown / not-carried state —
        // not Dense.
        val graph = DefaultComputeGraph()

        val lhs = inputNode("a", TensorSpec("a", listOf(1, 4), "FP32"))
        val rhs = inputNode("b", TensorSpec("b", listOf(1, 4), "FP32"))
        val sum = GraphNode(
            id = "add1",
            operation = AddOperation<DType, Any>(),
            inputs = listOf(
                TensorSpec("a", listOf(1, 4), "FP32"),
                TensorSpec("b", listOf(1, 4), "FP32")
            ),
            outputs = listOf(TensorSpec("c", listOf(1, 4), "FP32"))
        )

        listOf(lhs, rhs, sum).forEach(graph::addNode)
        graph.addEdge(GraphEdge("e1", lhs, sum, 0, 0, lhs.outputs[0]))
        graph.addEdge(GraphEdge("e2", rhs, sum, 0, 1, rhs.outputs[0]))

        val mlir = toStableHlo(graph, "dense_add").content

        assertFalse(
            mlir.contains("module attributes"),
            "dense graph must not emit a `module attributes` block"
        )
        assertFalse(
            mlir.contains("skainet.tensor_encodings"),
            "dense graph must not emit the `skainet.tensor_encodings` dictionary"
        )
        assertTrue(
            mlir.contains("module {"),
            "dense graph must keep the bare `module {` header"
        )
    }
}
Loading