Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sk.ainet.compile.hlo

import sk.ainet.compile.hlo.converters.ActivationOperationsConverter
import sk.ainet.compile.hlo.converters.ConstantOperationsConverter
import sk.ainet.compile.hlo.converters.GatherOperationsConverter
import sk.ainet.compile.hlo.converters.LegacyOperationsConverter
import sk.ainet.compile.hlo.converters.LinalgOperationsConverter
import sk.ainet.compile.hlo.converters.MathOperationsConverter
Expand Down Expand Up @@ -46,6 +47,10 @@ public object StableHloConverterFactory {
// Register constant operations converter
registry.register(ConstantOperationsConverter())

// Register gather / embedding / index_select converter — the
// LLM front-door op for token-id → embedding lookups.
registry.register(GatherOperationsConverter())

return StableHloConverter(registry, typeMapper, validator)
}

Expand Down Expand Up @@ -81,6 +86,10 @@ public object StableHloConverterFactory {
// Register constant operations converter
registry.register(ConstantOperationsConverter())

// Register gather / embedding / index_select converter — the
// LLM front-door op for token-id → embedding lookups.
registry.register(GatherOperationsConverter())

return StableHloConverter(registry, typeMapper, validator)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package sk.ainet.compile.hlo.converters

import sk.ainet.compile.hlo.ConversionContext
import sk.ainet.compile.hlo.ConversionResult
import sk.ainet.compile.hlo.StableHloOperationConverter
import sk.ainet.lang.graph.GraphNode

/**
 * Converter for memory-access / indexing operations.
 *
 * Today that's just `gather` and its framework aliases — the
 * critical path for LLM exports, where every transformer forward
 * pass begins with a token-id → embedding lookup. Without a
 * converter for `gather` / `embedding` / `index_select`, a traced
 * Llama / Mistral / Qwen / Gemma model fails at the very first
 * operation and never reaches the norms, activations, or attention
 * that the other P1 converters cover.
 *
 * The target lowering is the canonical `embedding(input_ids)`
 * shape: a 1-D index tensor indexing the leading dimension of a
 * 2-D embedding weight. Higher-rank gathers (attention-side index
 * gathers, multi-dim scatter/gather) can be added in follow-up PRs
 * once a traced model surfaces them; scoping this converter to the
 * LLM front-door case keeps review tight.
 *
 * Emitted shape:
 *
 * ```mlir
 * %out = stablehlo.gather(%weights, %indices)
 *          { dimension_numbers = #stablehlo.gather<
 *              offset_dims = [1],
 *              collapsed_slice_dims = [0],
 *              start_index_map = [0],
 *              index_vector_dim = 1>,
 *            slice_sizes = array<i64: 1, hidden_size>,
 *            indices_are_sorted = false }
 *          : (tensor<vocab_size x hidden_size x f32>, tensor<seq_len x i32>)
 *          -> tensor<seq_len x hidden_size x f32>
 * ```
 *
 * The `slice_sizes` vector is derived from the weight shape: a 1
 * along the gathered axis and the full extent of every other
 * dimension. `offset_dims`, `collapsed_slice_dims`, and
 * `start_index_map` are computed from the single gather axis.
 */
public class GatherOperationsConverter : StableHloOperationConverter {

    override val supportedOperations: Set<String> = setOf(
        "gather", "embedding", "Embedding", "index_select"
    )

    /**
     * Routes every supported indexing alias to the shared gather
     * lowering. Matching is case-insensitive so framework
     * spellings such as "Embedding" reach the same path; anything
     * else is reported as [ConversionResult.Unsupported] rather
     * than silently dropped.
     */
    override fun convert(
        node: GraphNode,
        operands: List<String>,
        context: ConversionContext
    ): ConversionResult {
        return when (node.operation.name.lowercase()) {
            "gather", "embedding", "index_select" -> convertGather(node, operands, context)
            else -> ConversionResult.Unsupported(
                node.operation.name,
                "Operation not supported by GatherOperationsConverter"
            )
        }
    }

    /**
     * Emits a single `stablehlo.gather` op for a single-axis gather.
     *
     * Operand order is (weights, indices). The dimension-numbers
     * attribute and `slice_sizes` are derived from the gather axis
     * (the node's "axis" or "dim" parameter, defaulting to 0) and
     * the weight / indices ranks taken from the node's tensor specs.
     */
    private fun convertGather(
        node: GraphNode,
        operands: List<String>,
        context: ConversionContext
    ): ConversionResult {
        if (operands.size < 2) {
            return ConversionResult.Failure(
                "Gather operation requires 2 operands (weights, indices), got ${operands.size}",
                "Unsupported gather arity for node ${node.id}"
            )
        }

        val weightSpec = node.inputs.getOrNull(0)
        val indicesSpec = node.inputs.getOrNull(1)
        val outputSpec = node.outputs.firstOrNull()

        // Fall back to unranked placeholder types when specs are
        // missing so a partially-specified graph still produces
        // inspectable output instead of crashing the export.
        val typeMapper = context.getTypeMapper()
        val weightType = weightSpec?.let { typeMapper.mapTensorType(it) } ?: "tensor<?x?xf32>"
        val indicesType = indicesSpec?.let { typeMapper.mapTensorType(it) } ?: "tensor<?xi32>"
        val outputType = outputSpec?.let { typeMapper.mapTensorType(it) } ?: "tensor<?x?xf32>"

        val weightShape = weightSpec?.shape ?: emptyList()
        val weightRank = weightShape.size
        val indicesRank = indicesSpec?.shape?.size ?: 1

        // Gather axis. Default to 0 (the conventional embedding-lookup
        // shape) and normalize negative axes against the weight rank.
        // Accept any boxed numeric (Int, Long, ...) — parameters that
        // round-trip through JSON/protobuf deserialization frequently
        // arrive as Long, and a bare `as? Int` would silently ignore
        // them and fall back to axis 0, producing a wrong lowering.
        val rawAxis = (node.operation.parameters["axis"] as? Number)?.toInt()
            ?: (node.operation.parameters["dim"] as? Number)?.toInt()
            ?: 0
        val axis = when {
            weightRank == 0 -> 0
            rawAxis < 0 -> weightRank + rawAxis
            else -> rawAxis
        }.coerceIn(0, (weightRank - 1).coerceAtLeast(0))

        // offset_dims: the axes of the output that carry "the rest of
        // the row" — every weight axis except the gathered one, offset
        // by the indices rank (which sits at the beginning of the
        // output shape for a canonical gather).
        val offsetDims = (0 until weightRank)
            .filter { it != axis }
            .mapIndexed { i, _ -> indicesRank + i }
            .joinToString(", ")

        // collapsed_slice_dims: the axes of the weight that are
        // "picked" by the indices — just the gathered axis for this
        // single-axis case.
        val collapsedSliceDims = "$axis"

        // start_index_map: index `i` in the indices tensor maps to
        // start coordinate along the weight's gathered axis.
        val startIndexMap = "$axis"

        // index_vector_dim: the axis of the indices tensor that holds
        // the multi-dim coordinate. For a 1-D index tensor indexing a
        // single axis, this is the rank (i.e. one past the last dim),
        // following StableHLO convention that a trailing scalar
        // "implicit index vector" is allowed.
        val indexVectorDim = indicesRank

        // slice_sizes: a 1 along the gathered axis, the full extent
        // along every other axis.
        // NOTE(review): with an absent weight spec this renders as
        // `array<i64: >`, which downstream MLIR tooling will reject —
        // acceptable for the placeholder-type path, but confirm.
        val sliceSizes = weightShape.mapIndexed { i, extent ->
            if (i == axis) 1 else extent
        }.joinToString(", ")

        // Operands are referenced as bare SSA values. (An earlier
        // draft hit the Kotlin string-template pitfall where
        // "$operands[0]" renders the whole list followed by "[0]";
        // naming the elements first avoids it.)
        val weightOperand = operands[0]
        val indicesOperand = operands[1]
        val resultValue = context.nextTempValue()
        val gatherOp = "$resultValue = stablehlo.gather($weightOperand, $indicesOperand) " +
            "{ dimension_numbers = #stablehlo.gather<" +
            "offset_dims = [$offsetDims], " +
            "collapsed_slice_dims = [$collapsedSliceDims], " +
            "start_index_map = [$startIndexMap], " +
            "index_vector_dim = $indexVectorDim>, " +
            "slice_sizes = array<i64: $sliceSizes>, " +
            "indices_are_sorted = false } " +
            ": ($weightType, $indicesType) -> $outputType"

        context.emitOperation(gatherOp)

        return ConversionResult.Success(
            outputValueName = resultValue,
            emittedOperations = listOf(gatherOp)
        )
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package sk.ainet.compile.hlo

import sk.ainet.lang.graph.DefaultComputeGraph
import sk.ainet.lang.graph.GraphEdge
import sk.ainet.lang.graph.GraphNode
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.tensor.ops.Operation
import sk.ainet.lang.tensor.ops.TensorSpec
import sk.ainet.lang.tensor.ops.ValidationResult
import sk.ainet.lang.types.DType
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.test.assertTrue

/**
 * Regression coverage for the gather / embedding converter (#483).
 *
 * LLM exports all start with a token-id → embedding lookup, and the
 * StableHLO emitter previously had no handler for `gather` or its
 * aliases — so a traced Llama / Mistral / Qwen / Gemma forward pass
 * died on its very first operation.
 *
 * The fixture models the canonical `embedding(input_ids)` shape: a
 * 1-D index tensor selecting rows from a 2-D embedding weight. The
 * expected lowering matches the StableHLO gather custom assembly
 * consumed by downstream MLIR tools (IREE in particular).
 */
class GatherConverterTest {

    @Test
    fun gather_and_embedding_aliases_are_supported() {
        val exported = buildEmbeddingModule(opName = "gather")
        assertTrue(exported.content.contains("stablehlo.gather"))
        assertFalse(
            exported.content.contains("Unsupported operation gather"),
            "`gather` must be claimed by a converter, not dropped as unsupported"
        )
        assertFalse(
            exported.content.contains("No converter found"),
            "`gather` must be claimed by a converter, not left without a handler"
        )
    }

    @Test
    fun embedding_alias_routes_to_same_lowering() {
        val exported = buildEmbeddingModule(opName = "embedding")
        assertTrue(exported.content.contains("stablehlo.gather"))
        assertFalse(exported.content.contains("Unsupported operation"))
    }

    @Test
    fun index_select_alias_routes_to_same_lowering() {
        val exported = buildEmbeddingModule(opName = "index_select")
        assertTrue(exported.content.contains("stablehlo.gather"))
        assertFalse(exported.content.contains("Unsupported operation"))
    }

    @Test
    fun embedding_lowering_carries_canonical_dim_numbers_and_slice_sizes() {
        val exported = buildEmbeddingModule(opName = "embedding")
        println("[DEBUG_LOG] gather/embedding export:\n${exported.content}")

        // The emitted op must carry the dim_numbers / slice_sizes
        // custom assembly that downstream MLIR tools expect for a
        // 1-D index tensor gathering rows from a 2-D weight.
        assertTrue(
            exported.content.contains("dimension_numbers"),
            "gather must emit a dimension_numbers attribute"
        )
        assertTrue(
            exported.content.contains("offset_dims = [1]"),
            "gather must declare offset_dims = [1] for an axis-0 row gather on a 2-D weight"
        )
        assertTrue(
            exported.content.contains("collapsed_slice_dims = [0]"),
            "gather must declare collapsed_slice_dims = [0] for the gathered axis"
        )
        assertTrue(
            exported.content.contains("start_index_map = [0]"),
            "gather must declare start_index_map = [0]"
        )
        assertTrue(
            exported.content.contains("slice_sizes = array<i64: 1, 4>"),
            "gather must declare slice_sizes = [1, hidden_size=4] matching the weight row shape"
        )

        // Tight regression check: the gather operands must be the
        // actual SSA value names, not a bracketed list expression.
        // (Earlier draft accidentally emitted
        // `stablehlo.gather([%arg0, %arg1][0], [%arg0, %arg1][1])`
        // because of a `$operands[0]` Kotlin string-template pitfall.)
        assertTrue(
            exported.content.contains("stablehlo.gather(%arg0, %arg1)"),
            "gather must reference operands as bare SSA values, not `[%arg0, %arg1][0]`"
        )
        assertFalse(
            exported.content.contains("stablehlo.gather([%"),
            "gather must not emit operand lists as Kotlin-string `[..., ...][0]` junk"
        )
    }

    /** Builds a 3-node graph (weight, ids, lookup) and converts it. */
    private fun buildEmbeddingModule(opName: String): StableHloModule {
        val vocab = 8
        val hidden = 4
        val tokens = 3

        val weights = GraphNode(
            id = "W",
            operation = markerInputOp(),
            inputs = emptyList(),
            outputs = listOf(TensorSpec("W", listOf(vocab, hidden), "FP32"))
        )
        val ids = GraphNode(
            id = "ids",
            operation = markerInputOp(),
            inputs = emptyList(),
            outputs = listOf(TensorSpec("ids", listOf(tokens), "INT32"))
        )
        val lookup = GraphNode(
            id = "embed1",
            operation = gatherOp(opName, axis = 0),
            inputs = listOf(
                TensorSpec("W", listOf(vocab, hidden), "FP32"),
                TensorSpec("ids", listOf(tokens), "INT32")
            ),
            outputs = listOf(TensorSpec("y", listOf(tokens, hidden), "FP32"))
        )

        val graph = DefaultComputeGraph().apply {
            addNode(weights)
            addNode(ids)
            addNode(lookup)
            addEdge(GraphEdge("e1", weights, lookup, 0, 0, weights.outputs[0]))
            addEdge(GraphEdge("e2", ids, lookup, 0, 1, ids.outputs[0]))
        }

        return StableHloConverterFactory.createExtended().convert(graph, "test_$opName")
    }

    /** Placeholder op for graph-input nodes; never executed. */
    private fun markerInputOp(): Operation = object : Operation {
        override val name: String = "input"
        override val type: String = "input"
        override val parameters: Map<String, Any> = emptyMap()
        override fun validateInputs(inputs: List<TensorSpec>): ValidationResult = ValidationResult.Valid
        override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> = emptyList()
        override fun clone(newParameters: Map<String, Any>): Operation = this
        override fun serialize(): Map<String, Any> = mapOf("name" to name, "type" to type)
        override fun <T : DType, V> execute(inputs: List<Tensor<T, V>>): List<Tensor<T, V>> =
            throw UnsupportedOperationException("test fixture only")
    }

    /** Indexing op stub carrying the alias under test and its axis. */
    private fun gatherOp(name: String, axis: Int): Operation = object : Operation {
        override val name: String = name
        override val type: String = "indexing"
        override val parameters: Map<String, Any> = mapOf("axis" to axis)
        override fun validateInputs(inputs: List<TensorSpec>): ValidationResult = ValidationResult.Valid
        override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> = inputs.take(1)
        override fun clone(newParameters: Map<String, Any>): Operation = this
        override fun serialize(): Map<String, Any> = mapOf(
            "name" to name, "type" to type, "parameters" to parameters
        )
        override fun <T : DType, V> execute(inputs: List<Tensor<T, V>>): List<Tensor<T, V>> =
            throw UnsupportedOperationException("test fixture only")
    }
}
Loading