Skip to content

Commit

Permalink
when: emit switch for String if possible
Browse files Browse the repository at this point in the history
Effectively, the following when structure:

  when (s) {
    s1, s2 -> e1,
    s3 -> e2,
    s4 -> e3,
    ...
    else -> e
  }

is implemented as:

  when (s.hashCode()) {
    h1 -> {
      if (s == s1)
        e1
      else if (s == s2)
        e1
      else if (s == s3)
        e2
      else
        e
    }
    h2 -> if (s == s4) e3 else e,
    ...
    else -> e
  }

where s1.hashCode() == s2.hashCode() == s3.hashCode() == h1,
      s4.hashCode() == h2.

A tableswitch or lookupswitch is used for the hash code lookup.

Change-Id: I087bf623dbb4a41d3cc64399a1b42342a50757a6
  • Loading branch information
ting-yuan authored and max-kammerer committed Mar 20, 2019
1 parent 1a9ed88 commit f6cf434
Show file tree
Hide file tree
Showing 22 changed files with 573 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@ import java.util.*
class SwitchGenerator(private val expression: IrWhen, private val data: BlockInfo, private val codegen: ExpressionCodegen) {
private val mv = codegen.mv

// A branch's result expression paired with the label at which the code for that result is emitted.
data class ExpressionToLabel(val expression: IrExpression, val label: Label)
// A single comparison call taken from a branch condition, paired with the label of that branch's result.
data class CallToLabel(val call: IrCall, val label: Label)
// A switch key (a case constant, or a hash code for string switches) paired with the label to jump to on a match.
data class ValueToLabel(val value: Any?, val label: Label)

// @return null if the IrWhen cannot be emitted as lookupswitch or tableswitch.
fun generate(): StackValue? {
val endLabel = Label()
var defaultLabel = endLabel
val thenExpressions = ArrayList<Pair<IrExpression, Label>>()
val expressionToLabels = ArrayList<ExpressionToLabel>()
var elseExpression: IrExpression? = null
val allConditions = ArrayList<Pair<IrCall, Label>>()
val callToLabels = ArrayList<CallToLabel>()

// Parse the when structure. Note that the condition can be nested. See matchConditions() for details.
for (branch in expression.branches) {
Expand All @@ -38,41 +42,95 @@ class SwitchGenerator(private val expression: IrWhen, private val data: BlockInf
} else {
val conditions = matchConditions(branch.condition) ?: return null
val thenLabel = Label()
thenExpressions.add(Pair(branch.result, thenLabel))
allConditions += conditions.map { Pair(it, thenLabel) }
expressionToLabels.add(ExpressionToLabel(branch.result, thenLabel))
callToLabels += conditions.map { CallToLabel(it, thenLabel) }
}
}

// IF is more compact when there are only 1 or fewer branches, in addition to else.
if (allConditions.size <= 1)
// switch isn't applicable if there's no case at all, e.g., when() { else -> ... }
if (callToLabels.size == 0)
return null

if (areConstIntComparisons(allConditions.map { it.first })) {
// if all conditions are CALL EQEQ(tmp_variable, some_int_constant)
val cases = allConditions.mapTo(ArrayList()) { Pair((it.first.getValueArgument(1) as IrConst<*>).value as Int, it.second) }
val subject = allConditions[0].first.getValueArgument(0)!! as IrGetValue
return gen(cases, subject, defaultLabel, endLabel, elseExpression, thenExpressions)
}
val calls = callToLabels.map { it.call }

// TODO: String, Enum, etc.
return null
// Checks if all conditions are CALL EQEQ(tmp_variable, some_constant)
if (!areConstComparisons(calls))
return null

// Subject should be the same for all conditions. Let's pick the first.
val subject = callToLabels[0].call.getValueArgument(0)!! as IrGetValue

// Don't generate repeated cases, which are unreachable but allowed in Kotlin.
// Only keep the first encountered case:
val cases =
callToLabels.map { ValueToLabel((it.call.getValueArgument(1) as IrConst<*>).value, it.label) }.distinctBy { it.value }

// Remove labels and "then expressions" that are not reachable.
val reachableLabels = HashSet(cases.map { it.label })
expressionToLabels.removeIf { it.label !in reachableLabels }

return when {
areConstIntComparisons(calls) ->
IntSwitch(
subject,
defaultLabel,
endLabel,
elseExpression,
expressionToLabels,
cases
)
areConstStringComparisons(calls) ->
StringSwitch(
subject,
defaultLabel,
endLabel,
elseExpression,
expressionToLabels,
cases
)
else -> null // TODO: Enum, etc.
}?.genOptimizedIfEnoughCases()
}

// A lookup/table switch can be used if...
private fun areConstIntComparisons(conditions: List<IrCall>): Boolean {
// 1. All branches are CALL 'EQEQ(Any?, Any?)': Boolean
private fun areConstComparisons(conditions: List<IrCall>): Boolean {
// All branches must be CALL 'EQEQ(Any?, Any?)': Boolean
if (conditions.any { it.symbol != codegen.classCodegen.context.irBuiltIns.eqeqSymbol })
return false

// 2. All types of variables involved in comparison are Int.
// 3. All arg0 refer to the same value.
// All LHS refer to the same tmp variable.
val lhs = conditions.map { it.getValueArgument(0) as? IrGetValue }
if (lhs.any { it == null || it.symbol != lhs[0]!!.symbol || !it.type.isInt() })
if (lhs.any { it == null || it.symbol != lhs[0]!!.symbol })
return false

// All RHS are constants
if (conditions.any { it.getValueArgument(1) !is IrConst<*> })
return false

return true
}

// True when every compared subject has type Int and every compared constant has kind IrConstKind.Int.
private fun areConstIntComparisons(conditions: List<IrCall>): Boolean =
    checkTypeSpecifics(
        conditions,
        subjectTypePredicate = { subjectType -> subjectType.isInt() },
        irConstPredicate = { constant -> constant.kind == IrConstKind.Int }
    )

// True when every compared subject is a String or nullable String and every compared constant is
// either a string constant or the null constant.
private fun areConstStringComparisons(conditions: List<IrCall>): Boolean =
    checkTypeSpecifics(
        conditions,
        subjectTypePredicate = { subjectType -> subjectType.isString() || subjectType.isNullableString() },
        irConstPredicate = { constant -> constant.kind == IrConstKind.String || constant.kind == IrConstKind.Null }
    )

private fun checkTypeSpecifics(
conditions: List<IrCall>,
subjectTypePredicate: (IrType) -> Boolean,
irConstPredicate: (IrConst<*>) -> Boolean
): Boolean {
val lhs = conditions.map { it.getValueArgument(0) as IrGetValue }
if (lhs.any { !subjectTypePredicate(it.type) })
return false

// 4. All arg1 are IrConst<*>.
val rhs = conditions.map { it.getValueArgument(1) as? IrConst<*> }
if (rhs.any { it == null || it.kind != IrConstKind.Int })
val rhs = conditions.map { it.getValueArgument(1) as IrConst<*> }
if (rhs.any { !irConstPredicate(it) })
return false

return true
Expand Down Expand Up @@ -137,58 +195,187 @@ class SwitchGenerator(private val expression: IrWhen, private val data: BlockInf
// Convenience forwarder to ExpressionCodegen.coerceNotToUnit, so code in this class can call it without the `codegen.` prefix.
private fun coerceNotToUnit(fromType: Type, fromKotlinType: KotlinType?, toKotlinType: KotlinType): StackValue =
codegen.coerceNotToUnit(fromType, fromKotlinType, toKotlinType)

private fun gen(
cases: ArrayList<Pair<Int, Label>>,
// Common skeleton for emitting a `when` as a JVM switch instruction.
//
// Subclasses decide whether a switch is worthwhile (shouldOptimize), how the value being switched on is
// emitted (genSubject) and how the dispatch itself is emitted (genSwitch). This class emits the branch
// bodies, the else part and the end label.
abstract inner class Switch(
    val subject: IrGetValue,
    val defaultLabel: Label,
    val endLabel: Label,
    val elseExpression: IrExpression?,
    val expressionToLabels: ArrayList<ExpressionToLabel>
) {
    // Off by default; a subclass opts in when it has enough cases.
    open fun shouldOptimize(): Boolean = false

    // Emits the whole switch and returns the resulting stack value, or returns null without emitting
    // anything when shouldOptimize() says a switch is not worthwhile.
    open fun genOptimizedIfEnoughCases(): StackValue? {
        if (!shouldOptimize()) return null

        genSubject()
        genSwitch()
        genThenExpressions()
        // The end label is marked only after the else part has been generated.
        return genElseExpression().also { mv.mark(endLabel) }
    }

    protected abstract fun genSubject()

    protected abstract fun genSwitch()

    // Emits a tableswitch or lookupswitch over the given Int-valued cases; unmatched values go to defaultLabel.
    protected fun genIntSwitch(unsortedIntCases: List<ValueToLabel>) {
        val sortedCases = unsortedIntCases.sortedBy { it.value as Int }
        val lowest = sortedCases.first().value as Int
        val highest = sortedCases.last().value as Int
        // Computed in Long so that a range spanning Int.MIN_VALUE..Int.MAX_VALUE does not overflow.
        val span = highest.toLong() - lowest.toLong() + 1L

        // A lookupswitch entry is twice the size of a tableswitch entry, but lookupswitch is sparse whereas
        // tableswitch has to enumerate every value in the range. preferLookupOverSwitch picks between them.
        if (preferLookupOverSwitch(sortedCases.size, span)) {
            val keys = sortedCases.map { it.value as Int }.toIntArray()
            val targets = sortedCases.map { it.label }.toTypedArray()
            mv.lookupswitch(defaultLabel, keys, targets)
            return
        }

        // Dense table: every slot starts at the default label, then the real cases overwrite their slots.
        val table = Array(span.toInt()) { defaultLabel }
        sortedCases.forEach { entry ->
            table[(entry.value as Int) - lowest] = entry.label
        }
        mv.tableswitch(lowest, highest, defaultLabel, *table)
    }

    // Emits each branch body at its label, each followed by a jump to the end label.
    protected fun genThenExpressions() {
        expressionToLabels.forEach { branch ->
            mv.visitLabel(branch.label)
            val branchValue = gen(branch.expression, data)
            coerceNotToUnit(branchValue.type, branchValue.kotlinType, expression.type.toKotlinType())
            mv.goTo(endLabel)
        }
    }

    // Emits the else part at defaultLabel when there is one; otherwise calls StackValue.putUnitInstance
    // and returns a void-typed on-stack value.
    protected fun genElseExpression(): StackValue {
        val fallback = elseExpression
        if (fallback == null) {
            // There is no else part.
            StackValue.putUnitInstance(mv)
            return onStack(Type.VOID_TYPE)
        }

        mv.visitLabel(defaultLabel)
        val fallbackValue = gen(fallback, data)
        return coerceNotToUnit(fallbackValue.type, fallbackValue.kotlinType, expression.type.toKotlinType())
    }
}

inner class IntSwitch(
subject: IrGetValue,
defaultLabel: Label,
endLabel: Label,
elseExpression: IrExpression?,
thenExpressions: ArrayList<Pair<IrExpression, Label>>
): StackValue {
cases.sortBy { it.first }

// Emit the temporary variable for subject.
gen(subject, data)

val caseMin = cases.first().first
val caseMax = cases.last().first
val rangeLength = caseMax - caseMin + 1L

// Emit either tableswitch or lookupswitch, depending on the code size.
//
// lookupswitch is 2X as large as tableswitch with the same entries. However, lookupswitch is sparse while tableswitch must
// enumerate all the entries in the range.
if (preferLookupOverSwitch(cases.size, rangeLength)) {
mv.lookupswitch(defaultLabel, cases.map { it.first }.toIntArray(), cases.map { it.second }.toTypedArray())
} else {
val labels = Array(rangeLength.toInt()) { defaultLabel }
for (case in cases)
labels[case.first - caseMin] = case.second
mv.tableswitch(caseMin, caseMax, defaultLabel, *labels)
expressionToLabels: ArrayList<ExpressionToLabel>,
private val cases: List<ValueToLabel>
) : Switch(subject, defaultLabel, endLabel, elseExpression, expressionToLabels) {

// An if-cascade is more compact than a switch when there is at most one case besides else,
// so a switch is only emitted for two or more cases.
override fun shouldOptimize() = cases.size > 1

// Emits the subject itself; for Int cases its value is what the switch dispatches on.
override fun genSubject() {
gen(subject, data)
}

// all entries except else
for (thenExpression in thenExpressions) {
mv.visitLabel(thenExpression.second)
val stackValue = thenExpression.first.run { gen(this, data) }
coerceNotToUnit(stackValue.type, stackValue.kotlinType, expression.type.toKotlinType())
mv.goTo(endLabel)
// Dispatches directly on the case constants: each case's label is the label of its branch body.
override fun genSwitch() {
genIntSwitch(cases)
}
}

// The following when structure:
//
// when (s) {
// s1, s2 -> e1,
// s3 -> e2,
// s4 -> e3,
// ...
// else -> e
// }
//
// is implemented as:
//
// // if s is String?, generate the following null check:
// if (s == null)
// // jump to the case where null is handled, if defined.
// // otherwise, jump out of the when().
// ...
// ...
// when (s.hashCode()) {
// h1 -> {
// if (s == s1)
// e1
// else if (s == s2)
// e1
// else if (s == s3)
// e2
// else
// e
// }
//     h2 -> if (s == s4) e3 else e,
// ...
// else -> e
// }
//
// where s1.hashCode() == s2.hashCode() == s3.hashCode() == h1,
// s4.hashCode() == h2.
//
// A tableswitch or lookupswitch is then used for the hash code lookup.

inner class StringSwitch(
subject: IrGetValue,
defaultLabel: Label,
endLabel: Label,
elseExpression: IrExpression?,
expressionToLabels: ArrayList<ExpressionToLabel>,
private val cases: List<ValueToLabel>
) : Switch(subject, defaultLabel, endLabel, elseExpression, expressionToLabels) {

private val hashToStringAndExprLabels = HashMap<Int, ArrayList<ValueToLabel>>()
private val hashAndSwitchLabels = ArrayList<ValueToLabel>()

init {
    // Group the string cases by hash code. A null case is skipped: it is handled specially and is
    // never dispatched from the switch.
    cases.filter { it.value != null }.forEach { case ->
        val bucket = hashToStringAndExprLabels.getOrPut(case.value.hashCode()) { ArrayList() }
        bucket.add(ValueToLabel(case.value, case.label))
    }

    // One fresh switch target per distinct hash code.
    hashToStringAndExprLabels.keys.mapTo(hashAndSwitchLabels) { hash -> ValueToLabel(hash, Label()) }
}

// else
val result = if (elseExpression == null) {
// There's no else part. No stack value will be generated.
StackValue.putUnitInstance(mv)
onStack(Type.VOID_TYPE)
} else {
// Generate the else part.
mv.visitLabel(defaultLabel)
val stackValue = elseExpression.run { gen(this, data) }
coerceNotToUnit(stackValue.type, stackValue.kotlinType, expression.type.toKotlinType())
// Using a switch, the subject string is traversed at least twice: once for hashCode() and once per equals()
// against each of the N strings sharing the matched hash bucket. That is no better than an if-cascade
// when there are 2 or fewer switch targets (distinct hash codes).
override fun shouldOptimize() = hashAndSwitchLabels.size > 2

// Emits the value to switch on: the subject's hashCode(). For a nullable String subject a null check
// comes first, jumping to the branch for the `null` case if there is one and to the default label otherwise.
override fun genSubject() {
    if (subject.type.isNullableString()) {
        val nullCase = cases.firstOrNull { it.value == null }
        val nullTarget = if (nullCase != null) nullCase.label else defaultLabel
        gen(subject, data)
        mv.ifnull(nullTarget)
    }

    gen(subject, data)
    mv.invokevirtual("java/lang/String", "hashCode", "()I", false)
}

mv.mark(endLabel)
return result
// Emits the hash-code switch, then one if-cascade per hash bucket.
override fun genSwitch() {
    genIntSwitch(hashAndSwitchLabels)

    // Several strings can share a hash code, so each bucket compares the subject against every string
    // in it with String.equals and jumps to the matching branch; no match falls through to the default label.
    hashAndSwitchLabels.forEach { bucket ->
        mv.visitLabel(bucket.label)
        val candidates = hashToStringAndExprLabels[bucket.value]!!
        candidates.forEach { candidate ->
            gen(subject, data)
            mv.aconst(candidate.value)
            mv.invokevirtual("java/lang/String", "equals", "(Ljava/lang/Object;)Z", false)
            mv.ifne(candidate.label)
        }
        mv.goTo(defaultLabel)
    }
}
}
}

20 changes: 20 additions & 0 deletions compiler/testData/codegen/box/when/edgeCases.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Maps the two extreme Int values to "MAX" / "MIN" and everything else to "else".
// The case constants are deliberately the largest and smallest Int literals.
fun foo(x: Int): String =
    when (x) {
        2_147_483_647 -> "MAX"
        -2_147_483_648 -> "MIN"
        else -> "else"
    }

// Test entry point: returns "OK" when foo handles 0, Int.MAX_VALUE and Int.MIN_VALUE correctly,
// otherwise a message naming the first failing input and the value foo returned for it.
fun box(): String {
    val atZero = foo(0)
    if (atZero != "else") return "0: $atZero"

    val atMax = foo(Int.MAX_VALUE)
    if (atMax != "MAX") return "Int.MAX_VALUE: $atMax"

    val atMin = foo(Int.MIN_VALUE)
    if (atMin != "MIN") return "Int.MIN_VALUE: $atMin"

    return "OK"
}
Loading

0 comments on commit f6cf434

Please sign in to comment.