Skip to content

Commit

Permalink
Implement unit suspend functions tail-call optimisation
Browse files Browse the repository at this point in the history
Unlike previously, this optimisation works on every callee return type.
Tail-calls inside unit functions can be either
INVOKE...
ARETURN
or
INVOKE
POP
GETSTATIC kotlin/Unit.INSTANCE
ARETURN
The first pattern is already covered. The second one is a bit tricky,
since we cannot just assume than the function is tail-call, we also need
to check whether the callee returned COROUTINE_SUSPENDED marker.
Thus, resulting bytecode of function's 'epilogue' look like
DUP
INVOKESTATIC getCOROUTINE_SUSPENDED
IF_ACMPNE LN
ARETURN
LN:
POP

 #KT-28938 Fixed
  • Loading branch information
ilmirus committed Jul 29, 2019
1 parent a16e036 commit cc06798
Show file tree
Hide file tree
Showing 13 changed files with 156 additions and 327 deletions.
Expand Up @@ -2533,8 +2533,6 @@ public void invokeMethodWithArguments(
callGenerator.genCall(callableMethod, resolvedCall, defaultMaskWasGenerated, this);

if (isSuspendNoInlineCall) {
addReturnsUnitMarkerIfNecessary(v, resolvedCall);

addSuspendMarker(v, false);
addInlineMarker(v, false);
}
Expand Down
Expand Up @@ -13,6 +13,7 @@ import org.jetbrains.kotlin.codegen.StackValue
import org.jetbrains.kotlin.codegen.TransformationMethodVisitor
import org.jetbrains.kotlin.codegen.inline.*
import org.jetbrains.kotlin.codegen.optimization.DeadCodeEliminationMethodTransformer
import org.jetbrains.kotlin.codegen.optimization.boxing.isUnitInstance
import org.jetbrains.kotlin.codegen.optimization.common.*
import org.jetbrains.kotlin.codegen.optimization.fixStack.FixStackMethodTransformer
import org.jetbrains.kotlin.codegen.optimization.fixStack.top
Expand All @@ -31,9 +32,7 @@ import org.jetbrains.org.objectweb.asm.Opcodes
import org.jetbrains.org.objectweb.asm.Type
import org.jetbrains.org.objectweb.asm.commons.InstructionAdapter
import org.jetbrains.org.objectweb.asm.tree.*
import org.jetbrains.org.objectweb.asm.tree.analysis.Frame
import org.jetbrains.org.objectweb.asm.tree.analysis.SourceInterpreter
import org.jetbrains.org.objectweb.asm.tree.analysis.SourceValue
import org.jetbrains.org.objectweb.asm.tree.analysis.*
import kotlin.math.max

private const val COROUTINES_DEBUG_METADATA_VERSION = 1
Expand Down Expand Up @@ -85,6 +84,9 @@ class CoroutineTransformerMethodVisitor(
override fun performTransformations(methodNode: MethodNode) {
removeFakeContinuationConstructorCall(methodNode)

// Remove redundant markers which came from compiled bytecode
cleanUpReturnsUnitMarkers(methodNode)

replaceFakeContinuationsWithRealOnes(
methodNode,
if (isForNamedFunction) getLastParameterIndex(methodNode.desc, methodNode.access) else 0
Expand All @@ -105,13 +107,17 @@ class CoroutineTransformerMethodVisitor(
val actualCoroutineStart = methodNode.instructions.first

if (isForNamedFunction) {
ReturnUnitMethodTransformer.transform(containingClassInternalName, methodNode)

if (putContinuationParameterToLvt) {
addCompletionParameterToLVT(methodNode)
}

if (allSuspensionPointsAreTailCalls(containingClassInternalName, methodNode, suspensionPoints)) {
val examined = ExaminedMethodNode(
languageVersionSettings,
containingClassInternalName,
methodNode
)
if (examined.allSuspensionPointsAreTailCalls(suspensionPoints)) {
examined.replacePopsBeforeSafeUnitInstancesWithCoroutineSuspendedChecks()
dropSuspensionMarkers(methodNode, suspensionPoints)
return
}
Expand All @@ -123,8 +129,6 @@ class CoroutineTransformerMethodVisitor(
continuationIndex = methodNode.maxLocals++

prepareMethodNodePreludeForNamedFunction(methodNode)
} else {
ReturnUnitMethodTransformer.cleanUpReturnsUnitMarkers(methodNode, ReturnUnitMethodTransformer.findReturnsUnitMarks(methodNode))
}

for (suspensionPoint in suspensionPoints) {
Expand Down Expand Up @@ -229,6 +233,12 @@ class CoroutineTransformerMethodVisitor(
)
}

private fun cleanUpReturnsUnitMarkers(methodNode: MethodNode) {
for (marker in methodNode.instructions.asSequence().filter(::isReturnsUnitMarker)) {
methodNode.instructions.removeAll(listOf(marker.previous, marker))
}
}

private fun findSuspensionPointLineNumber(suspensionPoint: SuspensionPoint) =
suspensionPoint.suspensionCallBegin.findPreviousOrNull { it is LineNumberNode } as LineNumberNode?

Expand Down Expand Up @@ -810,6 +820,88 @@ class CoroutineTransformerMethodVisitor(
private data class SpilledVariableDescriptor(val fieldName: String, val variableName: String)
}

// TODO Use this in variable liveness analysis
private class ExaminedMethodNode(
val languageVersionSettings: LanguageVersionSettings,
val containingClassInternalName: String,
val methodNode: MethodNode
) {
// DO NOT REORDER: pops and areturns collecting depends on unit instances collecting
// Which, in turn depends on frames and cfg
val sourceFrames: Array<Frame<SourceValue>?> =
MethodTransformer.analyze(containingClassInternalName, methodNode, IgnoringCopyOperationSourceInterpreter())
val controlFlowGraph = ControlFlowGraph.build(methodNode)

private val safeUnitInstances = collectSafeUnitInstances()
private val popsBeforeSafeUnitInstances = collectPopsBeforeSafeUnitInstances()
private val areturnsAfterSafeUnitInstances = collectAreturnsAfterSafeUnitInstances()

private fun collectSafeUnitInstances() = methodNode.instructions.asSequence().filter { unit ->
unit.isUnitInstance() &&
methodNode.instructions.asSequence().any { insn ->
insn.opcode == Opcodes.POP && insn.meaningfulSuccessors().let { succs ->
succs.all { it.isUnitInstance() } && unit in succs
}
} && unit.meaningfulSuccessors().all { it.opcode == Opcodes.ARETURN }
}.toList()

private fun collectPopsBeforeSafeUnitInstances() = methodNode.instructions.asSequence().filter { pop ->
pop.opcode == Opcodes.POP && pop.meaningfulSuccessors().all { it.isSafeUnitInstance() }
}.toList()

private fun collectAreturnsAfterSafeUnitInstances() = methodNode.instructions.asSequence().filter { areturn ->
areturn.opcode == Opcodes.ARETURN && sourceFrames[areturn.index()]?.top()?.insns?.all { it.isSafeUnitInstance() } == true
}

fun AbstractInsnNode.index() = methodNode.instructions.indexOf(this)

// GETSTATIC kotlin/Unit.INSTANCE is considered safe iff
// it is part of POP, PUSH Unit, ARETURN sequence.
fun AbstractInsnNode.isSafeUnitInstance(): Boolean = this in safeUnitInstances

fun AbstractInsnNode.isPopBeforeSafeUnitInstance(): Boolean = this in popsBeforeSafeUnitInstances
fun AbstractInsnNode.isAreturnAfterSafeUnitInstance(): Boolean = this in areturnsAfterSafeUnitInstances

private fun AbstractInsnNode.meaningfulSuccessors(): List<AbstractInsnNode> {
fun AbstractInsnNode.isMeaningful() = isMeaningful && opcode != Opcodes.NOP && opcode != Opcodes.GOTO && this !is LineNumberNode

val visited = arrayListOf<AbstractInsnNode>()
fun dfs(insn: AbstractInsnNode) {
if (insn in visited) return
visited += insn
if (!insn.isMeaningful()) {
for (succIndex in controlFlowGraph.getSuccessorsIndices(insn)) {
dfs(methodNode.instructions[succIndex])
}
}
}

for (succIndex in controlFlowGraph.getSuccessorsIndices(this)) {
dfs(methodNode.instructions[succIndex])
}
return visited.filter { it.isMeaningful() }
}

fun replacePopsBeforeSafeUnitInstancesWithCoroutineSuspendedChecks() {
val basicAnalyser = Analyzer(BasicInterpreter())
basicAnalyser.analyze(containingClassInternalName, methodNode)
val typedFrames = basicAnalyser.frames

for (pop in popsBeforeSafeUnitInstances) {
if (!isUnreachable(pop.index(), sourceFrames) && typedFrames[pop.index()]?.top()?.isReference == true) {
val label = Label()
methodNode.instructions.insertBefore(pop, withInstructionAdapter {
dup()
loadCoroutineSuspendedMarker(languageVersionSettings)
ifacmpne(label)
areturn(AsmTypes.OBJECT_TYPE)
mark(label)
})
}
}
}
}

internal fun InstructionAdapter.generateContinuationConstructorCall(
objectTypeForState: Type?,
methodNode: MethodNode,
Expand Down Expand Up @@ -939,13 +1031,8 @@ private fun getAllParameterTypes(desc: String, hasDispatchReceiver: Boolean, thi
listOfNotNull(if (!hasDispatchReceiver) null else Type.getObjectType(thisName)).toTypedArray() +
Type.getArgumentTypes(desc)

private fun allSuspensionPointsAreTailCalls(
thisName: String,
methodNode: MethodNode,
suspensionPoints: List<SuspensionPoint>
): Boolean {
val sourceFrames = MethodTransformer.analyze(thisName, methodNode, IgnoringCopyOperationSourceInterpreter())
val safelyReachableReturns = findSafelyReachableReturns(methodNode, sourceFrames)
private fun ExaminedMethodNode.allSuspensionPointsAreTailCalls(suspensionPoints: List<SuspensionPoint>): Boolean {
val safelyReachableReturns = findSafelyReachableReturns()

val instructions = methodNode.instructions
return suspensionPoints.all { suspensionPoint ->
Expand All @@ -963,7 +1050,7 @@ private fun allSuspensionPointsAreTailCalls(
if (insideTryBlock) return@all false

safelyReachableReturns[endIndex + 1]?.all { returnIndex ->
sourceFrames[returnIndex].top().sure {
sourceFrames[returnIndex]?.top().sure {
"There must be some value on stack to return"
}.insns.any { sourceInsn ->
sourceInsn?.let(instructions::indexOf) in beginIndex..endIndex
Expand All @@ -985,20 +1072,24 @@ internal class IgnoringCopyOperationSourceInterpreter : SourceInterpreter(Opcode
*
* @return indices of safely reachable returns for each instruction in the method node
*/
private fun findSafelyReachableReturns(methodNode: MethodNode, sourceFrames: Array<Frame<SourceValue?>?>): Array<Set<Int>?> {
val controlFlowGraph = ControlFlowGraph.build(methodNode)

private fun ExaminedMethodNode.findSafelyReachableReturns(): Array<Set<Int>?> {
val insns = methodNode.instructions
val reachableReturnsIndices = Array<Set<Int>?>(insns.size()) init@ { index ->
val reachableReturnsIndices = Array<Set<Int>?>(insns.size()) init@{ index ->
val insn = insns[index]

if (insn.opcode == Opcodes.ARETURN) {
if (insn.opcode == Opcodes.ARETURN && !insn.isAreturnAfterSafeUnitInstance()) {
if (isUnreachable(index, sourceFrames)) return@init null
return@init setOf(index)
}

if (!insn.isMeaningful || insn.opcode in SAFE_OPCODES || insn.isInvisibleInDebugVarInsn(methodNode) ||
isInlineMarker(insn)) {
// Since POP, PUSH Unit, ARETURN behaves like normal return in terms of tail-call optimization, set return index to POP
if (insn.isPopBeforeSafeUnitInstance()) {
return@init setOf(index)
}

if (!insn.isMeaningful || insn.opcode in SAFE_OPCODES || insn.isInvisibleInDebugVarInsn(methodNode) || isInlineMarker(insn)
|| insn.isSafeUnitInstance() || insn.isAreturnAfterSafeUnitInstance()
) {
setOf<Int>()
} else null
}
Expand Down Expand Up @@ -1029,7 +1120,8 @@ private fun findSafelyReachableReturns(methodNode: MethodNode, sourceFrames: Arr
}

// Check whether this instruction is unreachable, i.e. there is no path leading to this instruction
internal fun isUnreachable(index: Int, sourceFrames: Array<Frame<SourceValue?>?>) = sourceFrames[index] == null
internal fun isUnreachable(index: Int, sourceFrames: Array<Frame<SourceValue>?>): Boolean =
sourceFrames.size <= index || sourceFrames[index] == null

private fun AbstractInsnNode?.isInvisibleInDebugVarInsn(methodNode: MethodNode): Boolean {
val insns = methodNode.instructions
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright 2010-2018 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Copyright 2010-2019 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/

Expand All @@ -12,8 +12,10 @@ import org.jetbrains.kotlin.codegen.optimization.common.isMeaningful
import org.jetbrains.kotlin.codegen.optimization.common.removeAll
import org.jetbrains.kotlin.codegen.optimization.transformer.MethodTransformer
import org.jetbrains.kotlin.config.LanguageVersionSettings
import org.jetbrains.kotlin.utils.keysToMap
import org.jetbrains.org.objectweb.asm.Opcodes
import org.jetbrains.org.objectweb.asm.tree.*
import org.jetbrains.org.objectweb.asm.tree.analysis.SourceInterpreter

// Inliner emits a lot of locals during inlining.
// Remove all of them since these locals are
Expand Down Expand Up @@ -222,4 +224,22 @@ class RedundantLocalsEliminationMethodTransformer(private val languageVersionSet
assert(this is VarInsnNode)
return (this as VarInsnNode).`var`
}
}
}

private fun findSourceInstructions(
internalClassName: String,
methodNode: MethodNode,
insns: Collection<AbstractInsnNode>,
ignoreCopy: Boolean
): Map<AbstractInsnNode, Collection<AbstractInsnNode>> {
val frames = MethodTransformer.analyze(
internalClassName,
methodNode,
if (ignoreCopy) IgnoringCopyOperationSourceInterpreter() else SourceInterpreter()
)
return insns.keysToMap {
val index = methodNode.instructions.indexOf(it)
if (isUnreachable(index, frames)) return@keysToMap emptySet<AbstractInsnNode>()
frames[index].getStack(0).insns
}
}

0 comments on commit cc06798

Please sign in to comment.