Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,31 +30,73 @@ public class MathOperationsConverter : StableHloOperationConverter {
// Additional mathematical operations
"pow", "mod", "remainder",
// Element-wise operations
"element_add", "element_sub", "element_mul", "element_div"
"element_add", "element_sub", "element_mul", "element_div",
// Element-wise type conversion. Not strictly "math", but
// MathOperationsConverter already owns the elementwise-op
// family and cast is an elementwise primitive.
"cast", "convert", "to"
)

/**
 * Convert one graph node into StableHLO text.
 *
 * Basic arithmetic is delegated wholesale to [basicMathConverter]
 * whenever it claims the (lowercased) operation name; everything
 * else is dispatched locally. Unknown names produce
 * ConversionResult.Unsupported rather than throwing.
 *
 * NOTE(review): this span is diff residue — the duplicated
 * signature lines below are the old+new sides of the hunk.
 */
override fun convert(
node: GraphNode,
operands: List<String>,
node: GraphNode,
operands: List<String>,
context: ConversionContext
): ConversionResult {
// Delegate basic math operations to BasicMathConverter
if (basicMathConverter.supportedOperations.contains(node.operation.name.lowercase())) {
return basicMathConverter.convert(node, operands, context)
}

// Handle additional mathematical operations
return when (node.operation.name.lowercase()) {
"pow" -> convertPower(node, operands, context)
"mod", "remainder" -> convertRemainder(node, operands, context)
"element_add", "element_sub", "element_mul", "element_div" ->
"element_add", "element_sub", "element_mul", "element_div" ->
convertElementWise(node, operands, context)
"cast", "convert", "to" -> convertCast(node, operands, context)
else -> ConversionResult.Unsupported(
node.operation.name,
"Operation not supported by MathOperationsConverter"
)
}
}

/**
 * Lower `cast` / `convert` / `to` onto `stablehlo.convert`.
 *
 * The source and target tensor types are taken from the node's first
 * input and output specs respectively — the output spec carries the
 * target dtype on the normal tracing path. Either side falls back to
 * `tensor<?xf32>` when its spec is absent. Emits the MLIR
 * type-transition signature `(<from>) -> <to>`.
 *
 * NOTE(review): no `to` / `to_dtype` / `dtype` parameter is consulted
 * here; the output spec alone drives the target type.
 */
private fun convertCast(
    node: GraphNode,
    operands: List<String>,
    context: ConversionContext
): ConversionResult {
    // A cast is strictly unary.
    if (operands.size != 1) {
        return ConversionResult.Failure(
            "Cast operation requires exactly 1 operand, got ${operands.size}",
            "Unsupported cast arity for node ${node.id}"
        )
    }

    val mapper = context.getTypeMapper()
    val fromType = node.inputs.firstOrNull()?.let(mapper::mapTensorType) ?: "tensor<?xf32>"
    val toType = node.outputs.firstOrNull()?.let(mapper::mapTensorType) ?: "tensor<?xf32>"

    val result = context.nextTempValue()
    val line = "$result = stablehlo.convert ${operands[0]} : ($fromType) -> $toType"
    context.emitOperation(line)

    return ConversionResult.Success(
        outputValueName = result,
        emittedOperations = listOf(line)
    )
}

private fun convertPower(
node: GraphNode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,133 @@ import sk.ainet.lang.graph.GraphNode
public class ShapeOperationsConverter : StableHloOperationConverter {

/** Operation names this converter accepts; [convert] matches them case-insensitively via lowercase(). */
override val supportedOperations: Set<String> = setOf(
"reshape", "flatten", "squeeze", "unsqueeze"
"reshape", "flatten", "squeeze", "unsqueeze",
// Structural tensor ops — generic companions to reshape /
// flatten / squeeze. concat glues tensors along an axis,
// slice extracts a static window of a tensor.
"concat", "concatenate", "cat", "stack",
"slice"
)

/**
 * Convert one shape-manipulation graph node into StableHLO text.
 *
 * Dispatches on the lowercased operation name to the matching
 * private lowering; unknown names produce
 * ConversionResult.Unsupported rather than throwing.
 *
 * NOTE(review): this span is diff residue — the duplicated
 * signature lines below are the old+new sides of the hunk.
 */
override fun convert(
node: GraphNode,
operands: List<String>,
node: GraphNode,
operands: List<String>,
context: ConversionContext
): ConversionResult {
return when (node.operation.name.lowercase()) {
"reshape" -> convertReshape(node, operands, context)
"flatten" -> convertFlatten(node, operands, context)
"squeeze" -> convertSqueeze(node, operands, context)
"unsqueeze" -> convertUnsqueeze(node, operands, context)
"concat", "concatenate", "cat", "stack" -> convertConcat(node, operands, context)
"slice" -> convertSlice(node, operands, context)
else -> ConversionResult.Unsupported(
node.operation.name,
"Operation not supported by ShapeOperationsConverter"
)
}
}

/**
 * Lower `concat` / `concatenate` / `cat` / `stack` onto
 * `stablehlo.concatenate`.
 *
 * The join axis comes from the `axis` or `dim` parameter (default 0);
 * a negative axis is normalized against the rank of the first input
 * (or, failing that, the first output). Emits
 *
 *     %out = stablehlo.concatenate %a, %b, ..., dim = <axis> : <type>
 *
 * NOTE(review): "stack" is lowered exactly like concat — no new axis
 * is inserted here; confirm that callers expand dims beforehand.
 */
private fun convertConcat(
    node: GraphNode,
    operands: List<String>,
    context: ConversionContext
): ConversionResult {
    if (operands.isEmpty()) {
        return ConversionResult.Failure(
            "Concat operation requires at least 1 operand, got 0",
            "Unsupported concat arity for node ${node.id}"
        )
    }

    val firstOutput = node.outputs.firstOrNull()
    val resultType = firstOutput?.let { context.getTypeMapper().mapTensorType(it) } ?: "tensor<?xf32>"

    // Rank is only needed to normalize Python-style negative axes;
    // an unknown rank (0) leaves a negative axis untouched.
    val rank = node.inputs.firstOrNull()?.shape?.size ?: firstOutput?.shape?.size ?: 0
    val params = node.operation.parameters
    val requested = (params["axis"] as? Int) ?: (params["dim"] as? Int) ?: 0
    val dim = if (requested < 0 && rank > 0) requested + rank else requested

    val result = context.nextTempValue()
    val line = "$result = stablehlo.concatenate ${operands.joinToString(", ")}, dim = $dim : $resultType"
    context.emitOperation(line)

    return ConversionResult.Success(
        outputValueName = result,
        emittedOperations = listOf(line)
    )
}

/**
 * Lower `slice` onto a static `stablehlo.slice`.
 *
 * Per-dimension bounds are read from the node parameters:
 * `start_indices`/`starts` (default 0 per dim),
 * `limit_indices`/`limits` (default: the input shape), and
 * `strides` (default 1 per dim). The emitted text is
 *
 *     %out = stablehlo.slice %x {start_indices = [...],
 *            limit_indices = [...], strides = [...]} : <type>
 *
 * Dynamic slicing (runtime bounds) is deliberately out of scope for
 * this first pass.
 */
private fun convertSlice(
    node: GraphNode,
    operands: List<String>,
    context: ConversionContext
): ConversionResult {
    // Slicing is strictly unary.
    if (operands.size != 1) {
        return ConversionResult.Failure(
            "Slice operation requires exactly 1 operand, got ${operands.size}",
            "Unsupported slice arity for node ${node.id}"
        )
    }

    val resultType = node.outputs.firstOrNull()
        ?.let { context.getTypeMapper().mapTensorType(it) }
        ?: "tensor<?xf32>"

    val shape = node.inputs.firstOrNull()?.shape ?: emptyList()
    val params = node.operation.parameters

    // Parameters are stored untyped; tolerate either naming scheme.
    @Suppress("UNCHECKED_CAST")
    fun intList(primary: String, alias: String): List<Int>? =
        (params[primary] as? List<Int>) ?: (params[alias] as? List<Int>)

    val startIdx = intList("start_indices", "starts") ?: List(shape.size) { 0 }
    val limitIdx = intList("limit_indices", "limits") ?: shape
    @Suppress("UNCHECKED_CAST")
    val strideIdx = (params["strides"] as? List<Int>) ?: List(shape.size) { 1 }

    val result = context.nextTempValue()
    val line = "$result = stablehlo.slice ${operands[0]} " +
        "{start_indices = [${startIdx.joinToString(", ")}], " +
        "limit_indices = [${limitIdx.joinToString(", ")}], " +
        "strides = [${strideIdx.joinToString(", ")}]} : $resultType"
    context.emitOperation(line)

    return ConversionResult.Success(
        outputValueName = result,
        emittedOperations = listOf(line)
    )
}

/**
* Convert reshape operation using stablehlo.reshape.
Expand Down
Loading
Loading