Skip to content

Commit

Permalink
Merge pull request scala#10022 from sjrd/codegen-destinations
Browse files Browse the repository at this point in the history
Use explicit destinations in codegen to avoid uselessly jumping around.
  • Loading branch information
lrytz committed May 4, 2022
2 parents ae5d4e8 + 5fd3ca4 commit 5786230
Show file tree
Hide file tree
Showing 11 changed files with 738 additions and 130 deletions.
260 changes: 209 additions & 51 deletions src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -185,32 +185,44 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
generatedType
}

def genLoadIf(tree: If, expectedType: BType): BType = {
def genLoadIfTo(tree: If, expectedType: BType, dest: LoadDestination): BType = {
val If(condp, thenp, elsep) = tree

val success = new asm.Label
val failure = new asm.Label

val hasElse = !elsep.isEmpty
val postIf = if (hasElse) new asm.Label else failure

genCond(condp, success, failure, targetIfNoJump = success)
markProgramPoint(success)

val thenKind = tpeTK(thenp)
val elseKind = if (!hasElse) UNIT else tpeTK(elsep)
def hasUnitBranch = (thenKind == UNIT || elseKind == UNIT)
val resKind = if (hasUnitBranch) UNIT else tpeTK(tree)

genLoad(thenp, resKind)
if (hasElse) { bc goTo postIf }
markProgramPoint(failure)
if (hasElse) {
genLoad(elsep, resKind)
markProgramPoint(postIf)
if (dest == LoadDestination.FallThrough) {
if (hasElse) {
val thenKind = tpeTK(thenp)
val elseKind = tpeTK(elsep)
def hasUnitBranch = (thenKind == UNIT || elseKind == UNIT)
val resKind = if (hasUnitBranch) UNIT else tpeTK(tree)

val postIf = new asm.Label
genLoadTo(thenp, resKind, LoadDestination.Jump(postIf))
markProgramPoint(failure)
genLoadTo(elsep, resKind, LoadDestination.FallThrough)
markProgramPoint(postIf)
resKind
} else {
genLoad(thenp, UNIT)
markProgramPoint(failure)
UNIT
}
} else {
genLoadTo(thenp, expectedType, dest)
markProgramPoint(failure)
if (hasElse)
genLoadTo(elsep, expectedType, dest)
else
genAdaptAndSendToDest(UNIT, expectedType, dest)
expectedType
}

resKind
}

def genPrimitiveOp(tree: Apply, expectedType: BType): BType = {
Expand Down Expand Up @@ -257,13 +269,20 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
}

/* Generate code for trees that produce values on the stack */
def genLoad(tree: Tree, expectedType: BType) {
def genLoad(tree: Tree, expectedType: BType): Unit =
genLoadTo(tree, expectedType, LoadDestination.FallThrough)

/* Generate code for trees that produce values, sent to a given `LoadDestination`. */
def genLoadTo(tree: Tree, expectedType: BType, dest: LoadDestination): Unit = {
var generatedType = expectedType
var generatedDest: LoadDestination = LoadDestination.FallThrough

lineNumber(tree)

tree match {
case lblDf : LabelDef => genLabelDef(lblDf, expectedType)
case lblDf : LabelDef =>
genLabelDefTo(lblDf, expectedType, dest)
generatedDest = dest

case ValDef(_, nme.THIS, _, _) =>
debuglog("skipping trivial assign to _$this: " + tree)
Expand All @@ -283,22 +302,42 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
generatedType = UNIT

case t : If =>
generatedType = genLoadIf(t, expectedType)
generatedType = genLoadIfTo(t, expectedType, dest)
generatedDest = dest

case r : Return =>
genReturn(r)
generatedType = expectedType
generatedDest = LoadDestination.Return

case t : Try =>
generatedType = genLoadTry(t)

case Throw(expr) =>
generatedType = genThrow(expr)
val thrownKind = tpeTK(expr)
genLoadTo(expr, thrownKind, LoadDestination.Throw)
generatedDest = LoadDestination.Throw

case New(tpt) =>
abort(s"Unexpected New(${tpt.summaryString}/$tpt) reached GenBCode.\n" +
" Call was genLoad" + ((tree, expectedType)))

case app @ Apply(fun, args) if fun.symbol.isLabel =>
// jump to a label
val sym = fun.symbol
getJumpDestOrCreate(sym) match {
case JumpDestination.Regular(label) =>
val lblDef = labelDef.getOrElse(sym, {
abort("Not found: " + sym + " in " + labelDef)
})
genLoadLabelArguments(args, lblDef, app.pos)
bc goTo label
generatedDest = LoadDestination.Jump(label)
case JumpDestination.LoadArgTo(paramType, jumpDest) =>
val List(arg) = args
genLoadTo(arg, paramType, jumpDest)
generatedDest = jumpDest
}

case app : Apply =>
generatedType = genApply(app, expectedType)

Expand Down Expand Up @@ -370,11 +409,17 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
case _ => genConstant(value); generatedType = tpeTK(tree)
}

case blck : Block => genBlock(blck, expectedType)
case blck : Block =>
genBlockTo(blck, expectedType, dest)
generatedDest = dest

case Typed(Super(_, _), _) => genLoad(This(claszSymbol), expectedType)
case Typed(Super(_, _), _) =>
genLoadTo(This(claszSymbol), expectedType, dest)
generatedDest = dest

case Typed(expr, _) => genLoad(expr, expectedType)
case Typed(expr, _) =>
genLoadTo(expr, expectedType, dest)
generatedDest = dest

case Assign(_, _) =>
generatedType = UNIT
Expand All @@ -384,20 +429,40 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
generatedType = genArrayValue(av)

case mtch : Match =>
generatedType = genMatch(mtch)
generatedType = genMatchTo(mtch, expectedType, dest)
generatedDest = dest

case EmptyTree => if (expectedType != UNIT) { emitZeroOf(expectedType) }

case _ => abort(s"Unexpected tree in genLoad: $tree/${tree.getClass} at: ${tree.pos}")
}

// emit conversion
if (generatedType != expectedType) {
adapt(generatedType, expectedType)
}
// emit conversion and send to the right destination
if (generatedDest == LoadDestination.FallThrough)
genAdaptAndSendToDest(generatedType, expectedType, dest)

} // end of GenBCode.genLoad()

def genAdaptAndSendToDest(generatedType: BType, expectedType: BType, dest: LoadDestination): Unit = {
if (generatedType != expectedType)
adapt(generatedType, expectedType)

dest match {
case LoadDestination.FallThrough =>
()
case LoadDestination.Jump(label) =>
bc goTo label
case LoadDestination.Return =>
bc emitRETURN returnType
case LoadDestination.Throw =>
val thrownKind = expectedType
// `throw null` is valid although scala.Null (as defined in src/library-aux) isn't a subtype of Throwable.
// Similarly for scala.Nothing (again, as defined in src/library-aux).
assert(thrownKind.isNullType || thrownKind.isNothingType || thrownKind.asClassBType.isSubtypeOf(jlThrowableRef).get)
emit(asm.Opcodes.ATHROW)
}
}

// ---------------- field load and store ----------------

/*
Expand Down Expand Up @@ -475,27 +540,28 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
}
}

private def genLabelDef(lblDf: LabelDef, expectedType: BType) {
private def genLabelDefTo(lblDf: LabelDef, expectedType: BType, dest: LoadDestination): Unit = {
// duplication of LabelDefs contained in `finally`-clauses is handled when emitting RETURN. No bookkeeping for that required here.
// no need to call index() over lblDf.params, on first access that magic happens (moreover, no LocalVariableTable entries needed for them).
markProgramPoint(programPoint(lblDf.symbol))

// If we get inside genLabelDefTo, no one has or will register a non-regular jump destination for this LabelDef
val JumpDestination.Regular(label) = getJumpDestOrCreate(lblDf.symbol)
markProgramPoint(label)
lineNumber(lblDf)
genLoad(lblDf.rhs, expectedType)
genLoadTo(lblDf.rhs, expectedType, dest)
}

private def genReturn(r: Return) {
val Return(expr) = r
val returnedKind = tpeTK(expr)
genLoad(expr, returnedKind)
adapt(returnedKind, returnType)
val saveReturnValue = (returnType != UNIT)
lineNumber(r)

cleanups match {
case Nil =>
// not an assertion: !shouldEmitCleanup (at least not yet, pendingCleanups() may still have to run, and reset `shouldEmitCleanup`.
bc emitRETURN returnType
genLoadTo(expr, returnType, LoadDestination.Return)
case nextCleanup :: rest =>
genLoad(expr, returnType)
lineNumber(r)
val saveReturnValue = (returnType != UNIT)
if (saveReturnValue) {
// regarding return value, the protocol is: in place of a `return-stmt`, a sequence of `adapt, store, jump` are inserted.
if (earlyReturnVar == null) {
Expand Down Expand Up @@ -653,11 +719,7 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
case app @ Apply(fun, args) =>
val sym = fun.symbol

if (sym.isLabel) { // jump to a label
def notFound() = abort("Not found: " + sym + " in " + labelDef)
genLoadLabelArguments(args, labelDef.getOrElse(sym, notFound()), app.pos)
bc goTo programPoint(sym)
} else if (isPrimitive(sym)) { // primitive method call
if (isPrimitive(sym)) { // primitive method call
generatedType = genPrimitiveOp(app, expectedType)
} else { // normal method call
def isTraitSuperAccessorBodyCall = app.hasAttachment[UseInvokeSpecial.type]
Expand Down Expand Up @@ -764,10 +826,18 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
*
* On a second pass, we emit the switch blocks, one for each different target.
*/
private def genMatch(tree: Match): BType = {
private def genMatchTo(tree: Match, expectedType: BType, dest: LoadDestination): BType = {
lineNumber(tree)
genLoad(tree.selector, INT)
val generatedType = tpeTK(tree)

val (generatedType, postMatch, postMatchDest) = {
if (dest == LoadDestination.FallThrough) {
val postMatch = new asm.Label
(tpeTK(tree), postMatch, LoadDestination.Jump(postMatch))
} else {
(expectedType, null, dest)
}
}

var flatKeys: List[Int] = Nil
var targets: List[asm.Label] = Nil
Expand Down Expand Up @@ -801,24 +871,112 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
bc.emitSWITCH(mkArrayReverse(flatKeys), mkArray(targets.reverse), default, MIN_SWITCH_DENSITY)

// emit switch-blocks.
val postMatch = new asm.Label
for (sb <- switchBlocks.reverse) {
val (caseLabel, caseBody) = sb
markProgramPoint(caseLabel)
genLoad(caseBody, generatedType)
bc goTo postMatch
genLoadTo(caseBody, generatedType, postMatchDest)
}

markProgramPoint(postMatch)
if (postMatch != null)
markProgramPoint(postMatch)
generatedType
}

def genBlock(tree: Block, expectedType: BType) {
def genBlockTo(tree: Block, expectedType: BType, dest: LoadDestination): Unit = {
val Block(stats, expr) = tree
val savedScope = varsInScope
varsInScope = Nil
stats foreach genStat
genLoad(expr, expectedType)

/* Optimize common patmat-generated shapes, so that we can push the
* `dest` down the various cases.
*
* The two most common shapes are
*
* {
* initStats
* case1() { ... matchEnd(caseBody1) ... }
* ...
* caseN() { ... matchEnd(caseBodyN) ... }
* matchEnd(x: R) {
* x
* }
* }
*
* for non-unit results, and
*
* {
* initStats
* case1() { ... matchEnd(caseBody1) ... }
* ...
* caseN() { ... matchEnd(caseBodyN) ... }
* matchEnd(x: BoxedUnit$) {
* ()
* }
* }
*
* for unit results.
*
* If we do nothing, when we encounter the calls to `matchEnd` in the
* cases, we don't know yet what is the final `dest` of the block, so we
* cannot generate good code.
*
* Here, we recognize those shapes, and if we find them, we record a
* priori the ultimate `dest` of the full match. This allows to push
* `dest` to all the cases.
*
* For the transformation to be correct, control must not flow into
* `matchEnd` "normally" (i.e., not through a label apply). This is
* always the case for patmat-generated `matchEnd`s, but not for
* arbitrary LabelDefs. In particular, it is not true for
* tailrec-generarted LabelDefs. Therefore, we add specific tests to
* only recognize patmat-generated `matchEnd` labels this way.
*
* There are some rare cases where patmat will generate other shapes.
* For example, the source-code shape `return x match { ... }` transfers
* the `return` right around the `matchEnd`, for some reason, instead of
* around the entire Block. Those rare shapes are not recognized here.
* For them, the default (non-optimal) codegen will apply.
*/

def default(): Unit = {
stats foreach genStat
genLoadTo(expr, expectedType, dest)
}

def optimizedMatch(sym: Symbol): Unit = {
if (dest == LoadDestination.FallThrough) {
val label = new asm.Label
jumpDest += ((sym -> JumpDestination.LoadArgTo(expectedType, LoadDestination.Jump(label))))
stats foreach genStat
markProgramPoint(label)
} else {
jumpDest += ((sym -> JumpDestination.LoadArgTo(expectedType, dest)))
stats foreach genStat
}
}

def isMatchEndLabelDef(tree: LabelDef): Boolean =
treeInfo.hasSynthCaseSymbol(tree) && tree.symbol.name.startsWith("matchEnd")

expr match {
case matchEnd @ LabelDef(_, singleArg :: Nil, body) if isMatchEndLabelDef(matchEnd) =>
val sym = matchEnd.symbol
body match {
case _ if jumpDest.contains(sym) =>
// We already generated a jump to this label in the regular way; we cannot optimize anymore
default()
case bodyIdent: Ident if bodyIdent.symbol == singleArg.symbol =>
optimizedMatch(sym)
case Literal(Constant(())) =>
optimizedMatch(sym)
case _ =>
default()
}

case _ =>
default()
}

val end = currProgramPoint()
if (emitVars) { // add entries to LocalVariableTable JVM attribute
for ((sym, start) <- varsInScope.reverse) { emitLocalVarScope(sym, start, end) }
Expand Down
Loading

0 comments on commit 5786230

Please sign in to comment.