Skip to content

Commit

Permalink
refactor: change alignment implementation via compiler plugin to be d…
Browse files Browse the repository at this point in the history
…ebug-friendly (issue #337) (#347)
  • Loading branch information
nicolasfara committed May 18, 2024
1 parent 85e18d3 commit c4ac438
Show file tree
Hide file tree
Showing 15 changed files with 144 additions and 380 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class TestCompileString : FreeSpec({
disassembled shouldNot beNull()
disassembled shouldNot beBlank()
val alignedOnCalls = disassembled.lines().filter {
"// InterfaceMethod it/unibo/collektive/aggregate/api/Aggregate.alignedOn:" in it
"// InterfaceMethod it/unibo/collektive/aggregate/api/Aggregate.align" in it
}
alignedOnCalls.size should beGreaterThan(1)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
@file:Suppress("ReturnCount")

package it.unibo.collektive

import it.unibo.collektive.transformers.AggregateCallTransformer
import it.unibo.collektive.transformers.EnabledCompilerPluginTransformer
import it.unibo.collektive.utils.common.AggregateFunctionNames
import it.unibo.collektive.utils.common.AggregateFunctionNames.ALIGN_FUNCTION
import it.unibo.collektive.utils.common.AggregateFunctionNames.DEALIGN_RAW_FUNCTION
import it.unibo.collektive.utils.common.AggregateFunctionNames.PROJECT_FUNCTION
import it.unibo.collektive.utils.logging.error
import org.jetbrains.kotlin.backend.common.extensions.IrGenerationExtension
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.cli.common.messages.MessageCollector
import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
import org.jetbrains.kotlin.ir.util.functions
import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.ClassId
Expand All @@ -19,13 +24,15 @@ import org.jetbrains.kotlin.name.Name
* The generation extension is used to register the transformer plugin, which is going to modify
* the IR using the function responsible for the alignment.
*/
@Suppress("ReturnCount")
class AlignmentIrGenerationExtension(private val logger: MessageCollector) : IrGenerationExtension {
override fun generate(moduleFragment: IrModuleFragment, pluginContext: IrPluginContext) {
// Aggregate Context class that has the reference to the stack
val aggregateClass = pluginContext.referenceClass(
ClassId.topLevel(FqName(AggregateFunctionNames.AGGREGATE_CLASS)),
)
if (aggregateClass == null) {
return logger.error("Unable to find the aggregate class")
}

val projectFunction = pluginContext.referenceFunctions(
CallableId(
Expand All @@ -35,49 +42,28 @@ class AlignmentIrGenerationExtension(private val logger: MessageCollector) : IrG
).firstOrNull() ?: return logger.error("Unable to find the 'project' function")

// Function that handles the alignment
val alignedOnFunction = aggregateClass
?.functions
?.filter { it.owner.name == Name.identifier(AggregateFunctionNames.ALIGNED_ON_FUNCTION) }
?.firstOrNull()

requireNotNull(alignedOnFunction) {
val error = """
Aggregate alignment requires function ${AggregateFunctionNames.ALIGNED_ON_FUNCTION} to be available.
Please, add the required library TODO TODO (gradle block):
""".trimIndent()
error.also(logger::error)
}
val (alignFunction, aggClass) = getBothOrNull(alignedOnFunction, aggregateClass)
?: return logger.error(
"The function and the class used to handle the alignment have not been found.",
)
val alignRawFunction = aggregateClass.getFunctionReferenceWithName(ALIGN_FUNCTION)
?: return logger.error("Unable to find the `$ALIGN_FUNCTION` function")

val isCompilerPluginAppliedFunction = pluginContext.referenceFunctions(
CallableId(
FqName("it.unibo.collektive.aggregate.api.impl"),
Name.identifier(AggregateFunctionNames.IS_COMPILER_PLUGIN_APPLIED_FUNCTION),
),
).firstOrNull() ?: return logger.error("Unable to find the 'isCompilerPluginApplied' function")
val dealignFunction = aggregateClass.getFunctionReferenceWithName(DEALIGN_RAW_FUNCTION)
?: return logger.error("Unable to find the `$DEALIGN_RAW_FUNCTION` function")

/*
This applies the alignment call on all the aggregate functions
*/
moduleFragment.transform(
AggregateCallTransformer(
pluginContext,
logger,
aggClass.owner,
alignFunction.owner,
aggregateClass.owner,
alignRawFunction.owner,
dealignFunction.owner,
projectFunction.owner,
),
null,
)
/*
This transformation changes the `isCompilerPluginApplied` function to return true.
*/
moduleFragment.transform(
EnabledCompilerPluginTransformer(pluginContext, logger, isCompilerPluginAppliedFunction.owner),
null,
)
}

private fun <F, S> getBothOrNull(first: F?, second: S?): Pair<F, S>? =
if (first != null && second != null) first to second else null
private fun IrClassSymbol.getFunctionReferenceWithName(functionName: String): IrFunctionSymbol? =
functions.firstOrNull { it.owner.name == Name.identifier(functionName) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class AggregateCallTransformer(
private val pluginContext: IrPluginContext,
private val logger: MessageCollector,
private val aggregateClass: IrClass,
private val alignedOnFunction: IrFunction,
private val alignRawFunction: IrFunction,
private val dealignFunction: IrFunction,
private val projectFunction: IrFunction,
) : IrElementTransformerVoid() {

Expand All @@ -41,10 +42,16 @@ class AggregateCallTransformer(
null,
)
/*
This transformation is needed to add the `alignOn` function call to the aggregate functions.
This transformation is needed to add the `alignRaw` and `dealign` function call to the aggregate functions.
*/
declaration.transformChildren(
AlignmentTransformer(pluginContext, logger, aggregateClass, declaration, alignedOnFunction),
AlignmentTransformer(
pluginContext,
aggregateClass,
declaration,
alignRawFunction,
dealignFunction,
),
StackFunctionCall(),
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package it.unibo.collektive.transformers

import it.unibo.collektive.utils.branch.addBranchAlignment
import it.unibo.collektive.utils.call.buildAlignedOnCall
import it.unibo.collektive.utils.common.AggregateFunctionNames
import it.unibo.collektive.utils.common.AggregateFunctionNames.ALIGN_FUNCTION
import it.unibo.collektive.utils.common.AggregateFunctionNames.DEALIGN_RAW_FUNCTION
import it.unibo.collektive.utils.common.findAggregateReference
import it.unibo.collektive.utils.common.getAlignmentToken
import it.unibo.collektive.utils.common.irStatement
import it.unibo.collektive.utils.common.isAssignableFrom
Expand All @@ -11,13 +11,23 @@ import it.unibo.collektive.utils.stack.StackFunctionCall
import it.unibo.collektive.visitors.collectAggregateReference
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.backend.jvm.ir.receiverAndArgs
import org.jetbrains.kotlin.cli.common.messages.MessageCollector
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.builders.IrBlockBodyBuilder
import org.jetbrains.kotlin.ir.builders.createTmpVariable
import org.jetbrains.kotlin.ir.builders.irBlock
import org.jetbrains.kotlin.ir.builders.irBoolean
import org.jetbrains.kotlin.ir.builders.irCall
import org.jetbrains.kotlin.ir.builders.irGet
import org.jetbrains.kotlin.ir.builders.irString
import org.jetbrains.kotlin.ir.declarations.IrClass
import org.jetbrains.kotlin.ir.declarations.IrFunction
import org.jetbrains.kotlin.ir.expressions.IrBranch
import org.jetbrains.kotlin.ir.expressions.IrCall
import org.jetbrains.kotlin.ir.expressions.IrConst
import org.jetbrains.kotlin.ir.expressions.IrContainerExpression
import org.jetbrains.kotlin.ir.expressions.IrElseBranch
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.putArgument
import org.jetbrains.kotlin.ir.util.defaultType
import org.jetbrains.kotlin.ir.visitors.IrElementTransformer

Expand All @@ -28,10 +38,10 @@ import org.jetbrains.kotlin.ir.visitors.IrElementTransformer
*/
class AlignmentTransformer(
private val pluginContext: IrPluginContext,
private val logger: MessageCollector,
private val aggregateContextClass: IrClass,
private val aggregateLambdaBody: IrFunction,
private val alignedOnFunction: IrFunction,
private val functionToAlign: IrFunction,
private val alignRawFunction: IrFunction,
private val dealignFunction: IrFunction,
) : IrElementTransformer<StackFunctionCall> {
private var alignedFunctions = emptyMap<String, Int>()

Expand All @@ -41,54 +51,83 @@ class AlignmentTransformer(
?: collectAggregateReference(aggregateContextClass, expression.symbol.owner)

val alignmentToken = expression.getAlignmentToken()
// If the context is null, this means that the function is not an aggregate function
if (contextReference == null) {
data.push(alignmentToken)
}
return contextReference?.let { context ->
// We don't want to align the alignedOn function :)
if (expression.simpleFunctionName() == AggregateFunctionNames.ALIGNED_ON_FUNCTION) {
// We don't want to align the alignRaw and dealign functions :)
val functionName = expression.simpleFunctionName()
if (functionName == ALIGN_FUNCTION || functionName == DEALIGN_RAW_FUNCTION) {
return super.visitCall(expression, data)
}
// If no function, the first time the counter is 1
val actualCounter = alignedFunctions[alignmentToken]?.let { it + 1 } ?: 1
alignedFunctions += alignmentToken to actualCounter

// If the expression contains a lambda, this recursion is necessary to visit the children
expression.transformChildren(this, StackFunctionCall())

irStatement(pluginContext, aggregateLambdaBody, expression) {
with(logger) {
buildAlignedOnCall(
pluginContext,
aggregateLambdaBody,
context,
alignedOnFunction,
expression,
data,
alignedFunctions,
)
}
}
val tokenCount = alignedFunctions[alignmentToken] ?: error(
"""
Unable to find the count for the token $alignmentToken.
This is may due to a bug in collektive compiler plugin.
""".trimIndent(),
)
val alignmentTokenRepresentation = "$data$alignmentToken.$tokenCount"
// Return the modified function body to have as a first statement the alignRaw function,
// then the body of the function to align and finally the dealign function
generateAlignmentCode(context, functionToAlign, expression) { irString(alignmentTokenRepresentation) }
} ?: super.visitCall(expression, data)
}

override fun visitBranch(branch: IrBranch, data: StackFunctionCall): IrBranch {
with(logger) {
branch.addBranchAlignment(pluginContext, aggregateContextClass, aggregateLambdaBody, alignedOnFunction)
}
branch.generateBranchAlignmentCode(true)
return super.visitBranch(branch, data)
}

override fun visitElseBranch(branch: IrElseBranch, data: StackFunctionCall): IrElseBranch {
with(logger) {
branch.addBranchAlignment(
pluginContext,
aggregateContextClass,
aggregateLambdaBody,
alignedOnFunction,
false,
)
}
branch.generateBranchAlignmentCode(false)
return super.visitElseBranch(branch, data)
}

private fun IrBranch.generateBranchAlignmentCode(condition: Boolean) {
result.findAggregateReference(aggregateContextClass)?.let {
result = generateAlignmentCode(it, functionToAlign, result) { irBoolean(condition) }
}
}

private fun <T> generateAlignmentCode(
context: IrExpression,
function: IrFunction,
expressionBody: IrExpression,
alignmentToken: IrBlockBodyBuilder.() -> IrConst<T>,
): IrContainerExpression {
return irStatement(pluginContext, function, expressionBody) {
// Call the `alignRaw` function before the body of the function to align
irBlock {
// Call the alignRaw function
+irCall(alignRawFunction).apply {
putArgument(
alignRawFunction.dispatchReceiverParameter
?: error("The alignRaw function has no dispatch receiver parameter"),
context,
)
putValueArgument(0, alignmentToken(this@irStatement))
}
val code = irBlock { +expressionBody }
// Call the body of the function to align
val variableName = "blockResult"
val variableType = expressionBody.type
val tmpVar = createTmpVariable(code, irType = variableType, nameHint = variableName)
// Call the `dealign` function after the body of the function to align
+irCall(dealignFunction).apply {
putArgument(
dealignFunction.dispatchReceiverParameter
?: error("The dealign function has no dispatch receiver parameter"),
context,
)
}
+irGet(tmpVar)
}
}
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@ class FieldTransformer(
private val projectFunction: IrFunction,
) : IrElementTransformerVoid() {
override fun visitCall(expression: IrCall): IrExpression {
if (expression.symbol.owner.name == Name.identifier(AggregateFunctionNames.ALIGNED_ON_FUNCTION)) {
logger.debug("Found alignedOn function call: ${expression.dumpKotlinLike()}")
val symbolName = expression.symbol.owner.name
val alignRawIdentifier = Name.identifier(AggregateFunctionNames.ALIGN_FUNCTION)
val alignedOnIdentifier = Name.identifier(AggregateFunctionNames.ALIGNED_ON_FUNCTION)
if (symbolName == alignRawIdentifier || symbolName == alignedOnIdentifier) {
logger.debug("Found alignedRaw function call: ${expression.dumpKotlinLike()}")
val contextReference = expression.receiverAndArgs()
.find { it.type.isAssignableFrom(aggregateClass.defaultType) }
?: collectAggregateReference(aggregateClass, expression.symbol.owner)
Expand Down
Loading

0 comments on commit c4ac438

Please sign in to comment.