From 7923e1d0113676ea1a3a47a212e4f71a11aeba15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Mon, 2 May 2022 12:15:58 +0200 Subject: [PATCH 1/5] Add bytecode tests with the status quo of codegen control flow. --- .../tools/nsc/backend/jvm/BytecodeTest.scala | 469 ++++++++++++++++++ 1 file changed, 469 insertions(+) diff --git a/test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala b/test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala index 7682bdf04084..22bb492cd272 100644 --- a/test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala +++ b/test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala @@ -325,4 +325,473 @@ class BytecodeTest extends BytecodeTesting { val mm = classOf[scala.collection.immutable.TreeMap[_, _]].getDeclaredMethod(mn, classOf[Object]) assertEquals(mn, mm.getName) } + + @Test def tailrecControlFlow(): Unit = { + // Without change of the `this` value + + val sourceFoo = + s"""class Foo { + | @scala.annotation.tailrec // explicit @tailrec here + | final def fact(n: Int, acc: Int): Int = + | if (n == 0) acc + | else fact(n - 1, acc * n) + |} + """.stripMargin + + val fooClass = compileClass(sourceFoo) + + assertSameCode(getMethod(fooClass, "fact"), List( + Label(0), + VarOp(ILOAD, 1), + Op(ICONST_0), + Jump(IF_ICMPNE, Label(8)), + VarOp(ILOAD, 2), + Jump(GOTO, Label(20)), + Label(8), + VarOp(ILOAD, 1), + Op(ICONST_1), + Op(ISUB), + VarOp(ILOAD, 2), + VarOp(ILOAD, 1), + Op(IMUL), + VarOp(ISTORE, 2), + VarOp(ISTORE, 1), + Jump(GOTO, Label(0)), + Label(20), + Op(IRETURN), + )) + + // With changing the `this` value + + val sourceIntList = + s"""class IntList(head: Int, tail: IntList) { + | // implicit @tailrec + | final def sum(acc: Int): Int = { + | val t = tail + | if (t == null) acc + head + | else t.sum(acc + head) + | } + |} + """.stripMargin + + val intListClass = compileClass(sourceIntList) + + assertSameCode(getMethod(intListClass, "sum"), List( + Label(0), + VarOp(ALOAD, 0), + Field(GETFIELD, "IntList", "tail", "LIntList;"), + VarOp(ASTORE, 3), + VarOp(ALOAD, 3), + Jump(IFNONNULL, Label(15)), + VarOp(ILOAD, 1), + VarOp(ALOAD, 0), + Field(GETFIELD, "IntList", "head", "I"), + Op(IADD), + Jump(GOTO, Label(26)), + Label(15), + VarOp(ALOAD, 3), + VarOp(ILOAD, 1), + VarOp(ALOAD, 0), + Field(GETFIELD, "IntList", "head", "I"), + Op(IADD), + VarOp(ISTORE, 1), + VarOp(ASTORE, 0), + Jump(GOTO, Label(0)), + Label(26), + Op(IRETURN), + )) + } + + @Test def patmatControlFlow(): Unit = { + val source = + s"""class Foo { + | def m1(xs: List[Int]): Int = xs match { + | case x :: xr => x + | case Nil => 20 + | } + | + | def m2(xs: List[Int]): Int = xs match { + | case (1 | 2) :: xr => 10 + | case x :: xr => x + | case _ => 20 + | } + | + | def m3(xs: List[Int]): Unit = xs match { + | case x :: _ => println(x) + | case Nil => println("nil") + | } + |} + """.stripMargin + + val fooClass = compileClass(source) + + // --------------- + + assertSameCode(getMethod(fooClass, "m1"), List( + VarOp(ALOAD, 1), + VarOp(ASTORE, 3), + VarOp(ALOAD, 3), + TypeOp(INSTANCEOF, "scala/collection/immutable/$colon$colon"), + Jump(IFEQ, Label(20)), + VarOp(ALOAD, 3), + TypeOp(CHECKCAST, "scala/collection/immutable/$colon$colon"), + VarOp(ASTORE, 4), + VarOp(ALOAD, 4), + Invoke(INVOKEVIRTUAL, "scala/collection/immutable/$colon$colon", "head", "()Ljava/lang/Object;", false), + Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "unboxToInt", "(Ljava/lang/Object;)I", false), + VarOp(ISTORE, 5), + VarOp(ILOAD, 5), + VarOp(ISTORE, 2), + Jump(GOTO, Label(44)), + Label(20), + Jump(GOTO, Label(23)), + Label(23), + Field(GETSTATIC, "scala/collection/immutable/Nil$", "MODULE$", "Lscala/collection/immutable/Nil$;"), + VarOp(ALOAD, 3), + Invoke(INVOKEVIRTUAL, "java/lang/Object", "equals", "(Ljava/lang/Object;)Z", false), + Jump(IFEQ, Label(33)), + IntOp(BIPUSH, 20), + VarOp(ISTORE, 2), + Jump(GOTO, Label(44)), + Label(33), + Jump(GOTO, Label(36)), + Label(36), + TypeOp(NEW, "scala/MatchError"), + Op(DUP), + VarOp(ALOAD, 3), + Invoke(INVOKESPECIAL, "scala/MatchError", "", "(Ljava/lang/Object;)V", false), + Op(ATHROW), + Label(44), + VarOp(ILOAD, 2), + Op(IRETURN), + )) + + // --------------- + + assertSameCode(getMethod(fooClass, "m2"), List( + Op(ICONST_0), + VarOp(ISTORE, 4), + Op(ACONST_NULL), + VarOp(ASTORE, 5), + VarOp(ALOAD, 1), + VarOp(ASTORE, 6), + VarOp(ALOAD, 6), + TypeOp(INSTANCEOF, "scala/collection/immutable/$colon$colon"), + Jump(IFEQ, Label(57)), + Op(ICONST_1), + VarOp(ISTORE, 4), + VarOp(ALOAD, 6), + TypeOp(CHECKCAST, "scala/collection/immutable/$colon$colon"), + VarOp(ASTORE, 5), + VarOp(ALOAD, 5), + Invoke(INVOKEVIRTUAL, "scala/collection/immutable/$colon$colon", "head", "()Ljava/lang/Object;", false), + Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "unboxToInt", "(Ljava/lang/Object;)I", false), + VarOp(ISTORE, 7), + Op(ICONST_1), + VarOp(ILOAD, 7), + Jump(IF_ICMPNE, Label(28)), + Op(ICONST_1), + VarOp(ISTORE, 3), + Jump(GOTO, Label(47)), + Label(28), + Jump(GOTO, Label(31)), + Label(31), + Op(ICONST_2), + VarOp(ILOAD, 7), + Jump(IF_ICMPNE, Label(39)), + Op(ICONST_1), + VarOp(ISTORE, 3), + Jump(GOTO, Label(47)), + Label(39), + Jump(GOTO, Label(42)), + Label(42), + Op(ICONST_0), + VarOp(ISTORE, 3), + Jump(GOTO, Label(47)), + Label(47), + VarOp(ILOAD, 3), + Jump(IFEQ, Label(54)), + IntOp(BIPUSH, 10), + VarOp(ISTORE, 2), + Jump(GOTO, Label(82)), + Label(54), + Jump(GOTO, Label(60)), + Label(57), + Jump(GOTO, Label(60)), + Label(60), + VarOp(ILOAD, 4), + Jump(IFEQ, Label(73)), + VarOp(ALOAD, 5), + Invoke(INVOKEVIRTUAL, "scala/collection/immutable/$colon$colon", "head", "()Ljava/lang/Object;", false), + Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "unboxToInt", "(Ljava/lang/Object;)I", false), + VarOp(ISTORE, 8), + VarOp(ILOAD, 8), + VarOp(ISTORE, 2), + Jump(GOTO, Label(82)), + Label(73), + Jump(GOTO, Label(76)), + Label(76), + IntOp(BIPUSH, 20), + VarOp(ISTORE, 2), + Jump(GOTO, Label(82)), + Label(82), + VarOp(ILOAD, 2), + Op(IRETURN), + )) + + // --------------- + + assertSameCode(getMethod(fooClass, "m3"), List( + VarOp(ALOAD, 1), + VarOp(ASTORE, 3), + VarOp(ALOAD, 3), + TypeOp(INSTANCEOF, "scala/collection/immutable/$colon$colon"), + Jump(IFEQ, Label(24)), + VarOp(ALOAD, 3), + TypeOp(CHECKCAST, "scala/collection/immutable/$colon$colon"), + VarOp(ASTORE, 4), + VarOp(ALOAD, 4), + Invoke(INVOKEVIRTUAL, "scala/collection/immutable/$colon$colon", "head", "()Ljava/lang/Object;", false), + Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "unboxToInt", "(Ljava/lang/Object;)I", false), + VarOp(ISTORE, 5), + Field(GETSTATIC, "scala/Predef$", "MODULE$", "Lscala/Predef$;"), + VarOp(ILOAD, 5), + Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "boxToInteger", "(I)Ljava/lang/Integer;", false), + Invoke(INVOKEVIRTUAL, "scala/Predef$", "println", "(Ljava/lang/Object;)V", false), + Field(GETSTATIC, "scala/runtime/BoxedUnit", "UNIT", "Lscala/runtime/BoxedUnit;"), + VarOp(ASTORE, 2), + Jump(GOTO, Label(51)), + Label(24), + Jump(GOTO, Label(27)), + Label(27), + Field(GETSTATIC, "scala/collection/immutable/Nil$", "MODULE$", "Lscala/collection/immutable/Nil$;"), + VarOp(ALOAD, 3), + Invoke(INVOKEVIRTUAL, "java/lang/Object", "equals", "(Ljava/lang/Object;)Z", false), + Jump(IFEQ, Label(40)), + Field(GETSTATIC, "scala/Predef$", "MODULE$", "Lscala/Predef$;"), + Ldc(LDC, "nil"), + Invoke(INVOKEVIRTUAL, "scala/Predef$", "println", "(Ljava/lang/Object;)V", false), + Field(GETSTATIC, "scala/runtime/BoxedUnit", "UNIT", "Lscala/runtime/BoxedUnit;"), + VarOp(ASTORE, 2), + Jump(GOTO, Label(51)), + Label(40), + Jump(GOTO, Label(43)), + Label(43), + TypeOp(NEW, "scala/MatchError"), + Op(DUP), + VarOp(ALOAD, 3), + Invoke(INVOKESPECIAL, "scala/MatchError", "", "(Ljava/lang/Object;)V", false), + Op(ATHROW), + Label(51), + Op(RETURN), + )) + } + + @Test def switchControlFlow(): Unit = { + val source = + s"""import scala.annotation.switch + | + |class Foo { + | def m1(x: Int): Int = (x: @switch) match { + | case 1 => 10 + | case 7 => 20 + | case 8 => 30 + | case 9 => 40 + | case _ => x + | } + | + | def m2(x: Int): Int = (x: @switch) match { + | case (1 | 2) => 10 + | case 7 => 20 + | case 8 => 30 + | case c if c > 100 => 20 + | } + |} + """.stripMargin + + val fooClass = compileClass(source) + + // --------------- + + assertSameCode(getMethod(fooClass, "m1"), List( + VarOp(ILOAD, 1), + VarOp(ISTORE, 2), + VarOp(ILOAD, 2), + LookupSwitch(LOOKUPSWITCH, Label(26), List(1, 7, 8, 9), List(Label(6), Label(11), Label(16), Label(21))), + Label(6), + IntOp(BIPUSH, 10), + Jump(GOTO, Label(31)), + Label(11), + IntOp(BIPUSH, 20), + Jump(GOTO, Label(31)), + Label(16), + IntOp(BIPUSH, 30), + Jump(GOTO, Label(31)), + Label(21), + IntOp(BIPUSH, 40), + Jump(GOTO, Label(31)), + Label(26), + VarOp(ILOAD, 1), + Jump(GOTO, Label(31)), + Label(31), + Op(IRETURN), + )) + + // --------------- + + assertSameCode(getMethod(fooClass, "m2"), List( + VarOp(ILOAD, 1), + VarOp(ISTORE, 2), + VarOp(ILOAD, 2), + LookupSwitch(LOOKUPSWITCH, Label(21), List(1, 2, 7, 8), List(Label(6), Label(6), Label(11), Label(16))), + Label(6), + IntOp(BIPUSH, 10), + Jump(GOTO, Label(40)), + Label(11), + IntOp(BIPUSH, 20), + Jump(GOTO, Label(40)), + Label(16), + IntOp(BIPUSH, 30), + Jump(GOTO, Label(40)), + Label(21), + VarOp(ILOAD, 2), + IntOp(BIPUSH, 100), + Jump(IF_ICMPLE, Label(29)), + IntOp(BIPUSH, 20), + Jump(GOTO, Label(37)), + Label(29), + TypeOp(NEW, "scala/MatchError"), + Op(DUP), + VarOp(ILOAD, 2), + Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "boxToInteger", "(I)Ljava/lang/Integer;", false), + Invoke(INVOKESPECIAL, "scala/MatchError", "", "(Ljava/lang/Object;)V", false), + Op(ATHROW), + Label(37), + Jump(GOTO, Label(40)), + Label(40), + Op(IRETURN), + )) + } + + @Test def ifThenElseControlFlow(): Unit = { + /* This is a test case coming from the Scala.js linker, where in Scala 2 we + * had to introduce a "useless" `return` to make the bytecode size smaller, + * measurably increasing performance (!). + */ + + val source = + s"""import java.io.Writer + | + |final class SourceMapWriter(out: Writer) { + | private val Base64Map = + | "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + + | "abcdefghijklmnopqrstuvwxyz" + + | "0123456789+/" + | + | private final val VLQBaseShift = 5 + | private final val VLQBase = 1 << VLQBaseShift + | private final val VLQBaseMask = VLQBase - 1 + | private final val VLQContinuationBit = VLQBase + | + | def entryPoint(value: Int): Unit = writeBase64VLQ(value) + | + | private def writeBase64VLQ(value0: Int): Unit = { + | val signExtended = value0 >> 31 + | val value = (((value0 ^ signExtended) - signExtended) << 1) | (signExtended & 1) + | if (value < 26) { + | out.write('A' + value) // was `return out...` + | } else { + | def writeBase64VLQSlowPath(value0: Int): Unit = { + | var value = value0 + | do { + | var digit = value & VLQBaseMask + | value = value >>> VLQBaseShift + | if (value != 0) + | digit |= VLQContinuationBit + | out.write(Base64Map.charAt(digit)) + | } while (value != 0) + | } + | writeBase64VLQSlowPath(value) + | } + | } + |} + """.stripMargin + + val sourceMapWriterClass = compileClass(source) + + // --------------- + + assertSameCode(getMethod(sourceMapWriterClass, "writeBase64VLQ"), List( + VarOp(ILOAD, 1), + IntOp(BIPUSH, 31), + Op(ISHR), + VarOp(ISTORE, 2), + VarOp(ILOAD, 1), + VarOp(ILOAD, 2), + Op(IXOR), + VarOp(ILOAD, 2), + Op(ISUB), + Op(ICONST_1), + Op(ISHL), + VarOp(ILOAD, 2), + Op(ICONST_1), + Op(IAND), + Op(IOR), + VarOp(ISTORE, 3), + VarOp(ILOAD, 3), + IntOp(BIPUSH, 26), + Jump(IF_ICMPGE, Label(34)), + VarOp(ALOAD, 0), + Field(GETFIELD, "SourceMapWriter", "out", "Ljava/io/Writer;"), + IntOp(BIPUSH, 65), + VarOp(ILOAD, 3), + Op(IADD), + Invoke(INVOKEVIRTUAL, "java/io/Writer", "write", "(I)V", false), + Jump(GOTO, Label(40)), + Label(34), + VarOp(ALOAD, 0), + VarOp(ILOAD, 3), + Invoke(INVOKESPECIAL, "SourceMapWriter", "writeBase64VLQSlowPath$1", "(I)V", false), + Label(40), + Op(RETURN), + )) + + // --------------- + + assertSameCode(getMethod(sourceMapWriterClass, "writeBase64VLQSlowPath$1"), List( + VarOp(ILOAD, 1), + VarOp(ISTORE, 2), + Label(4), + VarOp(ILOAD, 2), + IntOp(BIPUSH, 31), + Op(IAND), + VarOp(ISTORE, 3), + VarOp(ILOAD, 2), + Op(ICONST_5), + Op(IUSHR), + VarOp(ISTORE, 2), + VarOp(ILOAD, 2), + Op(ICONST_0), + Jump(IF_ICMPEQ, Label(29)), + VarOp(ILOAD, 3), + IntOp(BIPUSH, 32), + Op(IOR), + VarOp(ISTORE, 3), + Jump(GOTO, Label(29)), + Label(29), + VarOp(ALOAD, 0), + Field(GETFIELD, "SourceMapWriter", "out", "Ljava/io/Writer;"), + VarOp(ALOAD, 0), + Invoke(INVOKESPECIAL, "SourceMapWriter", "Base64Map", "()Ljava/lang/String;", false), + VarOp(ILOAD, 3), + Invoke(INVOKEVIRTUAL, "java/lang/String", "charAt", "(I)C", false), + Invoke(INVOKEVIRTUAL, "java/io/Writer", "write", "(I)V", false), + VarOp(ILOAD, 2), + Op(ICONST_0), + Jump(IF_ICMPEQ, Label(47)), + Jump(GOTO, Label(4)), + Label(47), + Op(RETURN), + )) + } } From ff3c9db8c00d95e978c8aa18aa5ed826c842afbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Mon, 2 May 2022 15:23:02 +0200 Subject: [PATCH 2/5] Remove a fallback in InlinerTest.oldInlineHigherOrderTest(). --- .../scala/tools/nsc/backend/jvm/opt/InlinerTest.scala | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala b/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala index c9dcb52582af..1d5a8633937e 100644 --- a/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala +++ b/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala @@ -1443,17 +1443,10 @@ class InlinerTest extends BytecodeTesting { val c = compileClass(code) // box-unbox will clean it up - try assertSameSummary(getMethod(c, "t"), List( + assertSameSummary(getMethod(c, "t"), List( ALOAD, "$anonfun$t$1", IFEQ /*A*/, "$anonfun$t$2", IRETURN, -1 /*A*/, "$anonfun$t$3", IRETURN)) - catch { case e: AssertionError => - try assertSameSummary(getMethod(c, "t"), List( // this is the new behaviour, after restarr'ing - ALOAD, "debug", IFEQ /*A*/, - "$anonfun$t$2", IRETURN, - -1 /*A*/, "$anonfun$t$3", IRETURN)) - catch { case _: AssertionError => throw e } - } } @Test From a793f9514f8e9df2ebe58cd52c0e791cc6f440f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Mon, 2 May 2022 15:26:07 +0200 Subject: [PATCH 3/5] Use explicit destinations in codegen to avoid uselessly jumping around. Previously, the codegen's main method `genLoad` always generated code that loaded the value on the stack before continuing. There were a number of situations where `genLoad` would be directly followed by unconditional jumps to instructions performing more jumps, returns and throws. This generated more spurious jumps than necessary, along with artifact dead code. We solve these limitations by introducing `LoadDestination`s that specify the destination of a loaded value: * FallThrough: as previously, load the value on the stack and continue. * Jump(label): load the value on the stack and jump to the given label. * Return: return the value from the enclosing method. * Throw: throw the value. We generalize `genLoad` as `genLoadTo`, taking a specific destination for the loaded value. `genLoadTo` can "push down" its destination into all control flow structures (except `Try`s, because of their cleanups). With that, when we get to the end of what amounts to "basic blocks", we know exactly the ultimate destination of the loaded value. We can therefore directly jump, return or throw to the final destination. This produces less bytecode, notably because fewer labels are necessary. For example, the method: def abs(x: Int): Int = if x < 0 then -x else x previously generated bytecode like ILOAD 1 ICONST_0 IF_ICMPGE Label(1) ILOAD 1 INEG GOTO Label(2) Label(1): ILOAD 1 Label(2): IRETURN Now, instead of jumping to Label(2), we directly perform an IRETURN: ILOAD 1 ICONST_0 IF_ICMPGE Label(1) ILOAD 1 INEG IRETURN Label(1): ILOAD 1 IRETURN While the changes are not very impressive on that simple example, they become more important in more complex cases, notably with switch matches. Examples can be found in the changed bytecode tests. An added benefit is that `genLoadTo` knows when loading a value results in an unconditional control flow change (jump, return or throw). It can then avoid inserting any useless adaptation. This removes all the dead bytecode that the codegen used to generate as artifacts of its own compilation scheme. (It will still generate dead bytecode if the original source code/inlined code contains dead code.) This is a backport of dotty's commit https://github.com/lampepfl/dotty/commit/4a2889f93a46372c3551beef8b17d4ba8f289ddf The improvements in scalac are not as good as in dotc because jumps to `LabelDef`s cannot be tracked in the same way as the returns from `Labeled` blocks. For the latter, by the time we see a return from a labeled block, we always already know the ultimate destination of the labeled block itself, and we can therefore jump to it. In scalac, we can find jumps to `LabelDef`s before knowing their ultimate destination. In this commit, we therefore do not push down the jumps to `LabelDef`s as `LoadDestination`s. We leave that to a separate commit, to keep this commit as close as possible to the dotty commit it backports. --- .../nsc/backend/jvm/BCodeBodyBuilder.scala | 141 ++++++++++++------ .../nsc/backend/jvm/BCodeSkelBuilder.scala | 34 +++-- .../tools/nsc/backend/jvm/BytecodeTest.scala | 50 +++---- .../nsc/backend/jvm/opt/InlinerTest.scala | 5 +- .../backend/jvm/opt/UnreachableCodeTest.scala | 6 +- 5 files changed, 144 insertions(+), 92 deletions(-) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala index 2c215f23a2d5..404b4fae273c 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala @@ -185,32 +185,44 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { generatedType } - def genLoadIf(tree: If, expectedType: BType): BType = { + def genLoadIfTo(tree: If, expectedType: BType, dest: LoadDestination): BType = { val If(condp, thenp, elsep) = tree val success = new asm.Label val failure = new asm.Label val hasElse = !elsep.isEmpty - val postIf = if (hasElse) new asm.Label else failure genCond(condp, success, failure, targetIfNoJump = success) markProgramPoint(success) - val thenKind = tpeTK(thenp) - val elseKind = if (!hasElse) UNIT else tpeTK(elsep) - def hasUnitBranch = (thenKind == UNIT || elseKind == UNIT) - val resKind = if (hasUnitBranch) UNIT else tpeTK(tree) - - genLoad(thenp, resKind) - if (hasElse) { bc goTo postIf } - markProgramPoint(failure) - if (hasElse) { - genLoad(elsep, resKind) - markProgramPoint(postIf) + if (dest == LoadDestination.FallThrough) { + if (hasElse) { + val thenKind = tpeTK(thenp) + val elseKind = tpeTK(elsep) + def hasUnitBranch = (thenKind == UNIT || elseKind == UNIT) + val resKind = if (hasUnitBranch) UNIT else tpeTK(tree) + + val postIf = new asm.Label + genLoadTo(thenp, resKind, LoadDestination.Jump(postIf)) + markProgramPoint(failure) + genLoadTo(elsep, resKind, LoadDestination.FallThrough) + markProgramPoint(postIf) + resKind + } else { + genLoad(thenp, UNIT) + markProgramPoint(failure) + UNIT + } + } else { + genLoadTo(thenp, expectedType, dest) + markProgramPoint(failure) + if (hasElse) + genLoadTo(elsep, expectedType, dest) + else + genAdaptAndSendToDest(UNIT, expectedType, dest) + expectedType } - - resKind } def genPrimitiveOp(tree: Apply, expectedType: BType): BType = { @@ -257,13 +269,20 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { } /* Generate code for trees that produce values on the stack */ - def genLoad(tree: Tree, expectedType: BType) { + def genLoad(tree: Tree, expectedType: BType): Unit = + genLoadTo(tree, expectedType, LoadDestination.FallThrough) + + /* Generate code for trees that produce values, sent to a given `LoadDestination`. */ + def genLoadTo(tree: Tree, expectedType: BType, dest: LoadDestination): Unit = { var generatedType = expectedType + var generatedDest: LoadDestination = LoadDestination.FallThrough lineNumber(tree) tree match { - case lblDf : LabelDef => genLabelDef(lblDf, expectedType) + case lblDf : LabelDef => + genLabelDefTo(lblDf, expectedType, dest) + generatedDest = dest case ValDef(_, nme.THIS, _, _) => debuglog("skipping trivial assign to _$this: " + tree) @@ -283,17 +302,20 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { generatedType = UNIT case t : If => - generatedType = genLoadIf(t, expectedType) + generatedType = genLoadIfTo(t, expectedType, dest) + generatedDest = dest case r : Return => genReturn(r) - generatedType = expectedType + generatedDest = LoadDestination.Return case t : Try => generatedType = genLoadTry(t) case Throw(expr) => - generatedType = genThrow(expr) + val thrownKind = tpeTK(expr) + genLoadTo(expr, thrownKind, LoadDestination.Throw) + generatedDest = LoadDestination.Throw case New(tpt) => abort(s"Unexpected New(${tpt.summaryString}/$tpt) reached GenBCode.\n" + @@ -370,11 +392,17 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { case _ => genConstant(value); generatedType = tpeTK(tree) } - case blck : Block => genBlock(blck, expectedType) + case blck : Block => + genBlockTo(blck, expectedType, dest) + generatedDest = dest - case Typed(Super(_, _), _) => genLoad(This(claszSymbol), expectedType) + case Typed(Super(_, _), _) => + genLoadTo(This(claszSymbol), expectedType, dest) + generatedDest = dest - case Typed(expr, _) => genLoad(expr, expectedType) + case Typed(expr, _) => + genLoadTo(expr, expectedType, dest) + generatedDest = dest case Assign(_, _) => generatedType = UNIT @@ -384,20 +412,40 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { generatedType = genArrayValue(av) case mtch : Match => - generatedType = genMatch(mtch) + generatedType = genMatchTo(mtch, expectedType, dest) + generatedDest = dest case EmptyTree => if (expectedType != UNIT) { emitZeroOf(expectedType) } case _ => abort(s"Unexpected tree in genLoad: $tree/${tree.getClass} at: ${tree.pos}") } - // emit conversion - if (generatedType != expectedType) { - adapt(generatedType, expectedType) - } + // emit conversion and send to the right destination + if (generatedDest == LoadDestination.FallThrough) + genAdaptAndSendToDest(generatedType, expectedType, dest) } // end of GenBCode.genLoad() + def genAdaptAndSendToDest(generatedType: BType, expectedType: BType, dest: LoadDestination): Unit = { + if (generatedType != expectedType) + adapt(generatedType, expectedType) + + dest match { + case LoadDestination.FallThrough => + () + case LoadDestination.Jump(label) => + bc goTo label + case LoadDestination.Return => + bc emitRETURN returnType + case LoadDestination.Throw => + val thrownKind = expectedType + // `throw null` is valid although scala.Null (as defined in src/library-aux) isn't a subtype of Throwable. + // Similarly for scala.Nothing (again, as defined in src/library-aux). + assert(thrownKind.isNullType || thrownKind.isNothingType || thrownKind.asClassBType.isSubtypeOf(jlThrowableRef).get) + emit(asm.Opcodes.ATHROW) + } + } + // ---------------- field load and store ---------------- /* @@ -475,27 +523,25 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { } } - private def genLabelDef(lblDf: LabelDef, expectedType: BType) { + private def genLabelDefTo(lblDf: LabelDef, expectedType: BType, dest: LoadDestination): Unit = { // duplication of LabelDefs contained in `finally`-clauses is handled when emitting RETURN. No bookkeeping for that required here. // no need to call index() over lblDf.params, on first access that magic happens (moreover, no LocalVariableTable entries needed for them). markProgramPoint(programPoint(lblDf.symbol)) lineNumber(lblDf) - genLoad(lblDf.rhs, expectedType) + genLoadTo(lblDf.rhs, expectedType, dest) } private def genReturn(r: Return) { val Return(expr) = r - val returnedKind = tpeTK(expr) - genLoad(expr, returnedKind) - adapt(returnedKind, returnType) - val saveReturnValue = (returnType != UNIT) - lineNumber(r) cleanups match { case Nil => // not an assertion: !shouldEmitCleanup (at least not yet, pendingCleanups() may still have to run, and reset `shouldEmitCleanup`. - bc emitRETURN returnType + genLoadTo(expr, returnType, LoadDestination.Return) case nextCleanup :: rest => + genLoad(expr, returnType) + lineNumber(r) + val saveReturnValue = (returnType != UNIT) if (saveReturnValue) { // regarding return value, the protocol is: in place of a `return-stmt`, a sequence of `adapt, store, jump` are inserted. if (earlyReturnVar == null) { @@ -764,10 +810,18 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { * * On a second pass, we emit the switch blocks, one for each different target. */ - private def genMatch(tree: Match): BType = { + private def genMatchTo(tree: Match, expectedType: BType, dest: LoadDestination): BType = { lineNumber(tree) genLoad(tree.selector, INT) - val generatedType = tpeTK(tree) + + val (generatedType, postMatch, postMatchDest) = { + if (dest == LoadDestination.FallThrough) { + val postMatch = new asm.Label + (tpeTK(tree), postMatch, LoadDestination.Jump(postMatch)) + } else { + (expectedType, null, dest) + } + } var flatKeys: List[Int] = Nil var targets: List[asm.Label] = Nil @@ -801,24 +855,23 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { bc.emitSWITCH(mkArrayReverse(flatKeys), mkArray(targets.reverse), default, MIN_SWITCH_DENSITY) // emit switch-blocks. - val postMatch = new asm.Label for (sb <- switchBlocks.reverse) { val (caseLabel, caseBody) = sb markProgramPoint(caseLabel) - genLoad(caseBody, generatedType) - bc goTo postMatch + genLoadTo(caseBody, generatedType, postMatchDest) } - markProgramPoint(postMatch) + if (postMatch != null) + markProgramPoint(postMatch) generatedType } - def genBlock(tree: Block, expectedType: BType) { + def genBlockTo(tree: Block, expectedType: BType, dest: LoadDestination): Unit = { val Block(stats, expr) = tree val savedScope = varsInScope varsInScope = Nil stats foreach genStat - genLoad(expr, expectedType) + genLoadTo(expr, expectedType, dest) val end = currProgramPoint() if (emitVars) { // add entries to LocalVariableTable JVM attribute for ((sym, start) <- varsInScope.reverse) { emitLocalVarScope(sym, start, end) } diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala index 5bd3c080ffc7..64d95ec6fc8a 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala @@ -32,6 +32,20 @@ abstract class BCodeSkelBuilder extends BCodeHelpers { import coreBTypes._ import genBCode.postProcessor.backendUtils + /** The destination of a value generated by `genLoadTo`. */ + sealed abstract class LoadDestination extends Product with Serializable + + object LoadDestination { + /** The value is put on the stack, and control flows through to the next opcode. */ + case object FallThrough extends LoadDestination + /** The value is put on the stack, and control flow is transferred to the given `label`. */ + case class Jump(label: asm.Label) extends LoadDestination + /** The value is RETURN'ed from the enclosing method. */ + case object Return extends LoadDestination + /** The value is ATHROW'n. */ + case object Throw extends LoadDestination + } + /* * There's a dedicated PlainClassBuilder for each CompilationUnit, * which simplifies the initialization of per-class data structures in `genPlainClass()` which in turn delegates to `initJClass()` @@ -607,17 +621,15 @@ abstract class BCodeSkelBuilder extends BCodeHelpers { def emitNormalMethodBody() { val veryFirstProgramPoint = currProgramPoint() - genLoad(rhs, returnType) - - rhs match { - case Return(_) | Block(_, Return(_)) | Throw(_) | Block(_, Throw(_)) => () - case EmptyTree => - globalError("Concrete method has no definition: " + dd + ( - if (settings.isDebug) "(found: " + methSymbol.owner.info.decls.toList.mkString(", ") + ")" - else "")) - case _ => - bc emitRETURN returnType + + if (rhs == EmptyTree) { + globalError("Concrete method has no definition: " + dd + ( + if (settings.isDebug) "(found: " + methSymbol.owner.info.decls.toList.mkString(", ") + ")" + else "")) + } else { + genLoadTo(rhs, returnType, LoadDestination.Return) } + if (emitVars) { // add entries to LocalVariableTable JVM attribute val onePastLastProgramPoint = currProgramPoint() @@ -726,7 +738,7 @@ abstract class BCodeSkelBuilder extends BCodeHelpers { } } - def genLoad(tree: Tree, expectedType: BType) + def genLoadTo(tree: Tree, expectedType: BType, dest: LoadDestination) } // end of class PlainSkelBuilder diff --git a/test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala b/test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala index 22bb492cd272..c5a59ef587f3 100644 --- a/test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala +++ b/test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala @@ -128,16 +128,14 @@ class BytecodeTest extends BytecodeTesting { // t1: no unnecessary GOTOs assertSameCode(getMethod(c, "t1"), List( VarOp(ILOAD, 1), Jump(IFEQ, Label(6)), - Op(ICONST_1), Jump(GOTO, Label(9)), - Label(6), Op(ICONST_2), - Label(9), Op(IRETURN))) + Op(ICONST_1), Op(IRETURN), + Label(6), Op(ICONST_2), Op(IRETURN))) // t2: no unnecessary GOTOs assertSameCode(getMethod(c, "t2"), List( VarOp(ILOAD, 1), IntOp(SIPUSH, 393), Jump(IF_ICMPNE, Label(7)), - Op(ICONST_1), Jump(GOTO, Label(10)), - Label(7), Op(ICONST_2), - Label(10), Op(IRETURN))) + Op(ICONST_1), Op(IRETURN), + Label(7), Op(ICONST_2), Op(IRETURN))) // t3: Array == is translated to reference equality, AnyRef == to null checks and equals assertSameCode(getMethod(c, "t3"), List( @@ -169,9 +167,8 @@ class BytecodeTest extends BytecodeTesting { VarOp(ILOAD, 1), IntOp(BIPUSH, 10), Jump(IF_ICMPNE, Label(7)), VarOp(ILOAD, 2), Jump(IFNE, Label(12)), Label(7), VarOp(ILOAD, 1), Op(ICONST_1), Jump(IF_ICMPEQ, Label(16)), - Label(12), Op(ICONST_1), Jump(GOTO, Label(19)), - Label(16), Op(ICONST_2), - Label(19), Op(IRETURN))) + Label(12), Op(ICONST_1), Op(IRETURN), + Label(16), Op(ICONST_2), Op(IRETURN))) // t7: universal equality assertInvoke(getMethod(c, "t7"), "scala/runtime/BoxesRunTime", "equals") @@ -208,7 +205,7 @@ class BytecodeTest extends BytecodeTesting { Label(0), Ldc(LDC, ""), VarOp(ASTORE, 1), Label(4), VarOp(ALOAD, 1), Jump(IFNULL, Label(20)), Label(9), VarOp(ALOAD, 0), Invoke(INVOKEVIRTUAL, "C", "foo", "()V", false), Label(13), Op(ACONST_NULL), VarOp(ASTORE, 1), Label(17), Jump(GOTO, Label(4)), - Label(20), VarOp(ALOAD, 0), Invoke(INVOKEVIRTUAL, "C", "bar", "()V", false), Label(25), Op(RETURN), Label(27))) + Label(20), VarOp(ALOAD, 0), Invoke(INVOKEVIRTUAL, "C", "bar", "()V", false), Op(RETURN), Label(26))) val labels = t.instructions collect { case l: Label => l } val x = t.localVars.find(_.name == "x").get assertEquals(x.start, labels(1)) @@ -346,7 +343,7 @@ class BytecodeTest extends BytecodeTesting { Op(ICONST_0), Jump(IF_ICMPNE, Label(8)), VarOp(ILOAD, 2), - Jump(GOTO, Label(20)), + Op(IRETURN), Label(8), VarOp(ILOAD, 1), Op(ICONST_1), @@ -357,8 +354,6 @@ class BytecodeTest extends BytecodeTesting { VarOp(ISTORE, 2), VarOp(ISTORE, 1), Jump(GOTO, Label(0)), - Label(20), - Op(IRETURN), )) // With changing the `this` value @@ -387,7 +382,7 @@ class BytecodeTest extends BytecodeTesting { VarOp(ALOAD, 0), Field(GETFIELD, "IntList", "head", "I"), Op(IADD), - Jump(GOTO, Label(26)), + Op(IRETURN), Label(15), VarOp(ALOAD, 3), VarOp(ILOAD, 1), @@ -397,8 +392,6 @@ class BytecodeTest extends BytecodeTesting { VarOp(ISTORE, 1), VarOp(ASTORE, 0), Jump(GOTO, Label(0)), - Label(26), - Op(IRETURN), )) } @@ -620,20 +613,18 @@ class BytecodeTest extends BytecodeTesting { LookupSwitch(LOOKUPSWITCH, Label(26), List(1, 7, 8, 9), List(Label(6), Label(11), Label(16), Label(21))), Label(6), IntOp(BIPUSH, 10), - Jump(GOTO, Label(31)), + Op(IRETURN), Label(11), IntOp(BIPUSH, 20), - Jump(GOTO, Label(31)), + Op(IRETURN), Label(16), IntOp(BIPUSH, 30), - Jump(GOTO, Label(31)), + Op(IRETURN), Label(21), IntOp(BIPUSH, 40), - Jump(GOTO, Label(31)), + Op(IRETURN), Label(26), VarOp(ILOAD, 1), - Jump(GOTO, Label(31)), - Label(31), Op(IRETURN), )) @@ -646,19 +637,19 @@ class BytecodeTest extends BytecodeTesting { LookupSwitch(LOOKUPSWITCH, Label(21), List(1, 2, 7, 8), List(Label(6), Label(6), Label(11), Label(16))), Label(6), IntOp(BIPUSH, 10), - Jump(GOTO, Label(40)), + Op(IRETURN), Label(11), IntOp(BIPUSH, 20), - Jump(GOTO, Label(40)), + Op(IRETURN), Label(16), IntOp(BIPUSH, 30), - Jump(GOTO, Label(40)), + Op(IRETURN), Label(21), VarOp(ILOAD, 2), IntOp(BIPUSH, 100), Jump(IF_ICMPLE, Label(29)), IntOp(BIPUSH, 20), - Jump(GOTO, Label(37)), + Op(IRETURN), Label(29), TypeOp(NEW, "scala/MatchError"), Op(DUP), @@ -666,10 +657,6 @@ class BytecodeTest extends BytecodeTesting { Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "boxToInteger", "(I)Ljava/lang/Integer;", false), Invoke(INVOKESPECIAL, "scala/MatchError", "", "(Ljava/lang/Object;)V", false), Op(ATHROW), - Label(37), - Jump(GOTO, Label(40)), - Label(40), - Op(IRETURN), )) } @@ -747,12 +734,11 @@ class BytecodeTest extends BytecodeTesting { VarOp(ILOAD, 3), Op(IADD), Invoke(INVOKEVIRTUAL, "java/io/Writer", "write", "(I)V", false), - Jump(GOTO, Label(40)), + Op(RETURN), Label(34), VarOp(ALOAD, 0), VarOp(ILOAD, 3), Invoke(INVOKESPECIAL, "SourceMapWriter", "writeBase64VLQSlowPath$1", "(I)V", false), - Label(40), Op(RETURN), )) diff --git a/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala b/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala index 1d5a8633937e..0e848626cd04 100644 --- a/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala +++ b/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala @@ -1445,8 +1445,9 @@ class InlinerTest extends BytecodeTesting { // box-unbox will clean it up assertSameSummary(getMethod(c, "t"), List( ALOAD, "$anonfun$t$1", IFEQ /*A*/, - "$anonfun$t$2", IRETURN, - -1 /*A*/, "$anonfun$t$3", IRETURN)) + "$anonfun$t$2", ISTORE, GOTO /*B*/, + -1 /*A*/, "$anonfun$t$3", ISTORE, + -1 /*B*/, ILOAD, IRETURN)) } @Test diff --git a/test/junit/scala/tools/nsc/backend/jvm/opt/UnreachableCodeTest.scala b/test/junit/scala/tools/nsc/backend/jvm/opt/UnreachableCodeTest.scala index e308480b291f..789153c3bcaf 100644 --- a/test/junit/scala/tools/nsc/backend/jvm/opt/UnreachableCodeTest.scala +++ b/test/junit/scala/tools/nsc/backend/jvm/opt/UnreachableCodeTest.scala @@ -129,9 +129,9 @@ class UnreachableCodeTest extends ClearAfterClass { // L2: ICONST_2 << dead // L3: IRETURN << dead // - // Finally, instructions in the dead basic blocks are replaced by ATHROW, as explained in - // a comment in BCodeBodyBuilder. - assertSameCode(noDce.dropNonOp, List(Op(ICONST_1), Op(IRETURN), Op(ATHROW), Op(ATHROW))) + // Finally, instructions in the dead basic blocks are replaced by NOP's and ATHROW's, + // as explained in a comment in BCodeBodyBuilder. + assertSameCode(noDce.dropNonOp, List(Op(ICONST_1), Op(IRETURN), Op(NOP), Op(ATHROW))) } @Test From 81f6f6b6a0332838eafeea9936a28976e0724e27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Tue, 3 May 2022 18:10:08 +0200 Subject: [PATCH 4/5] Optimize typical matchEnd() LabelDefs for LoadDestination. To do this, we intercept typical shapes produced by patmat in `genBlockTo`. --- .../nsc/backend/jvm/BCodeBodyBuilder.scala | 121 ++++++++++++++++-- .../nsc/backend/jvm/BCodeSkelBuilder.scala | 18 ++- .../nsc/backend/jvm/BCodeSyncAndTry.scala | 2 +- .../tools/nsc/backend/jvm/BytecodeTest.scala | 79 +++++------- .../backend/jvm/opt/MethodLevelOptsTest.scala | 9 +- .../transform/patmat/PatmatBytecodeTest.scala | 12 +- 6 files changed, 166 insertions(+), 75 deletions(-) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala index 404b4fae273c..e618d2359e18 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala @@ -321,6 +321,23 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { abort(s"Unexpected New(${tpt.summaryString}/$tpt) reached GenBCode.\n" + " Call was genLoad" + ((tree, expectedType))) + case app @ Apply(fun, args) if fun.symbol.isLabel => + // jump to a label + val sym = fun.symbol + getJumpDestOrCreate(sym) match { + case JumpDestination.Regular(label) => + val lblDef = labelDef.getOrElse(sym, { + abort("Not found: " + sym + " in " + labelDef) + }) + genLoadLabelArguments(args, lblDef, app.pos) + bc goTo label + generatedDest = LoadDestination.Jump(label) + case JumpDestination.LoadArgTo(paramType, jumpDest) => + val List(arg) = args + genLoadTo(arg, paramType, jumpDest) + generatedDest = jumpDest + } + case app : Apply => generatedType = genApply(app, expectedType) @@ -526,7 +543,10 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { private def genLabelDefTo(lblDf: LabelDef, expectedType: BType, dest: LoadDestination): Unit = { // duplication of LabelDefs contained in `finally`-clauses is handled when emitting RETURN. No bookkeeping for that required here. // no need to call index() over lblDf.params, on first access that magic happens (moreover, no LocalVariableTable entries needed for them). - markProgramPoint(programPoint(lblDf.symbol)) + + // If we get inside genLabelDefTo, no one has or will register a non-regular jump destination for this LabelDef + val JumpDestination.Regular(label) = getJumpDestOrCreate(lblDf.symbol) + markProgramPoint(label) lineNumber(lblDf) genLoadTo(lblDf.rhs, expectedType, dest) } @@ -699,11 +719,7 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { case app @ Apply(fun, args) => val sym = fun.symbol - if (sym.isLabel) { // jump to a label - def notFound() = abort("Not found: " + sym + " in " + labelDef) - genLoadLabelArguments(args, labelDef.getOrElse(sym, notFound()), app.pos) - bc goTo programPoint(sym) - } else if (isPrimitive(sym)) { // primitive method call + if (isPrimitive(sym)) { // primitive method call generatedType = genPrimitiveOp(app, expectedType) } else { // normal method call def isTraitSuperAccessorBodyCall = app.hasAttachment[UseInvokeSpecial.type] @@ -870,8 +886,97 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { val Block(stats, expr) = tree val savedScope = varsInScope varsInScope = Nil - stats foreach genStat - genLoadTo(expr, expectedType, dest) + + /* Optimize common patmat-generated shapes, so that we can push the + * `dest` down the various cases. + * + * The two most common shapes are + * + * { + * initStats + * case1() { ... matchEnd(caseBody1) ... } + * ... + * caseN() { ... matchEnd(caseBodyN) ... } + * matchEnd(x: R) { + * x + * } + * } + * + * for non-unit results, and + * + * { + * initStats + * case1() { ... matchEnd(caseBody1) ... } + * ... + * caseN() { ... matchEnd(caseBodyN) ... } + * matchEnd(x: BoxedUnit$) { + * () + * } + * } + * + * for unit results. + * + * If we do nothing, when we encounter the calls to `matchEnd` in the + * cases, we don't know yet what is the final `dest` of the block, so we + * cannot generate good code. + * + * Here, we recognize those shapes, and if we find them, we record a + * priori the ultimate `dest` of the full match. This allows to push + * `dest` to all the cases. + * + * For the transformation to be correct, control must not flow into + * `matchEnd` "normally" (i.e., not through a label apply). This is + * always the case for patmat-generated `matchEnd`s, but not for + * arbitrary LabelDefs. In particular, it is not true for + * tailrec-generarted LabelDefs. Therefore, we add specific tests to + * only recognize patmat-generated `matchEnd` labels this way. + * + * There are some rare cases where patmat will generate other shapes. + * For example, the source-code shape `return x match { ... }` transfers + * the `return` right around the `matchEnd`, for some reason, instead of + * around the entire Block. Those rare shapes are not recognized here. + * For them, the default (non-optimal) codegen will apply. + */ + + def default(): Unit = { + stats foreach genStat + genLoadTo(expr, expectedType, dest) + } + + def optimizedMatch(sym: Symbol): Unit = { + if (dest == LoadDestination.FallThrough) { + val label = new asm.Label + jumpDest += ((sym -> JumpDestination.LoadArgTo(expectedType, LoadDestination.Jump(label)))) + stats foreach genStat + markProgramPoint(label) + } else { + jumpDest += ((sym -> JumpDestination.LoadArgTo(expectedType, dest))) + stats foreach genStat + } + } + + def isMatchEndLabelDef(tree: LabelDef): Boolean = + treeInfo.hasSynthCaseSymbol(tree) && tree.symbol.name.startsWith("matchEnd") + + expr match { + case matchEnd @ LabelDef(_, singleArg :: Nil, body) if isMatchEndLabelDef(matchEnd) => + val sym = matchEnd.symbol + body match { + case _ if jumpDest.contains(sym) => + // We already generated a jump to this label in the regular way; we cannot optimize anymore + default() + case bodyIdent: Ident if bodyIdent.symbol == singleArg.symbol => + optimizedMatch(sym) + case Literal(Constant(())) => + optimizedMatch(sym) + case _ => + default() + } + + case _ => + default() + } + val end = currProgramPoint() if (emitVars) { // add entries to LocalVariableTable JVM attribute for ((sym, start) <- varsInScope.reverse) { emitLocalVarScope(sym, start, end) } diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala index 64d95ec6fc8a..2be31e33cb4e 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala @@ -292,13 +292,19 @@ abstract class BCodeSkelBuilder extends BCodeHelpers { * A related map is `labelDef`: it has the same keys as `jumpDest` but its values are LabelDef nodes not asm.Labels. * */ - var jumpDest: immutable.Map[ /* LabelDef */ Symbol, asm.Label ] = null - def programPoint(labelSym: Symbol): asm.Label = { + sealed abstract class JumpDestination + object JumpDestination { + case class Regular(label: asm.Label) extends JumpDestination + case class LoadArgTo(expectedType: BType, dest: LoadDestination) extends JumpDestination + } + + var jumpDest: immutable.Map[ /* LabelDef */ Symbol, JumpDestination ] = null + def getJumpDestOrCreate(labelSym: Symbol): JumpDestination = { assert(labelSym.isLabel, s"trying to map a non-label symbol to an asm.Label, at: ${labelSym.pos}") jumpDest.getOrElse(labelSym, { - val pp = new asm.Label - jumpDest += (labelSym -> pp) - pp + val regularDest = JumpDestination.Regular(new asm.Label) + jumpDest += (labelSym -> regularDest) + regularDest }) } @@ -479,7 +485,7 @@ abstract class BCodeSkelBuilder extends BCodeHelpers { // on entering a method def resetMethodBookkeeping(dd: DefDef) { locals.reset(isStaticMethod = methSymbol.isStaticMember) - jumpDest = immutable.Map.empty[ /* LabelDef */ Symbol, asm.Label ] + jumpDest = immutable.Map.empty // populate labelDefsAtOrUnder val ldf = new LabelDefsFinder(dd.rhs) ldf(dd.rhs) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSyncAndTry.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSyncAndTry.scala index 94a590ed2d10..601d1eb40d02 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSyncAndTry.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSyncAndTry.scala @@ -402,7 +402,7 @@ abstract class BCodeSyncAndTry extends BCodeBodyBuilder { /* `tmp` (if non-null) is the symbol of the local-var used to preserve the result of the try-body, see `guardResult` */ def emitFinalizer(finalizer: Tree, tmp: Symbol, isDuplicate: Boolean) { - var saved: immutable.Map[ /* LabelDef */ Symbol, asm.Label ] = null + var saved: immutable.Map[ /* LabelDef */ Symbol, JumpDestination ] = null if (isDuplicate) { saved = jumpDest for(ldef <- labelDefsAtOrUnder.getOrElse(finalizer, Nil)) { diff --git a/test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala b/test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala index c5a59ef587f3..3af4270d8b93 100644 --- a/test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala +++ b/test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala @@ -425,7 +425,7 @@ class BytecodeTest extends BytecodeTesting { VarOp(ASTORE, 3), VarOp(ALOAD, 3), TypeOp(INSTANCEOF, "scala/collection/immutable/$colon$colon"), - Jump(IFEQ, Label(20)), + Jump(IFEQ, Label(19)), VarOp(ALOAD, 3), TypeOp(CHECKCAST, "scala/collection/immutable/$colon$colon"), VarOp(ASTORE, 4), @@ -434,29 +434,24 @@ class BytecodeTest extends BytecodeTesting { Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "unboxToInt", "(Ljava/lang/Object;)I", false), VarOp(ISTORE, 5), VarOp(ILOAD, 5), - VarOp(ISTORE, 2), - Jump(GOTO, Label(44)), - Label(20), - Jump(GOTO, Label(23)), - Label(23), + Op(IRETURN), + Label(19), + Jump(GOTO, Label(22)), + Label(22), Field(GETSTATIC, "scala/collection/immutable/Nil$", "MODULE$", "Lscala/collection/immutable/Nil$;"), VarOp(ALOAD, 3), Invoke(INVOKEVIRTUAL, "java/lang/Object", "equals", "(Ljava/lang/Object;)Z", false), - Jump(IFEQ, Label(33)), + Jump(IFEQ, Label(31)), IntOp(BIPUSH, 20), - VarOp(ISTORE, 2), - Jump(GOTO, Label(44)), - Label(33), - Jump(GOTO, Label(36)), - Label(36), + Op(IRETURN), + Label(31), + Jump(GOTO, Label(34)), + Label(34), TypeOp(NEW, "scala/MatchError"), Op(DUP), VarOp(ALOAD, 3), Invoke(INVOKESPECIAL, "scala/MatchError", "", "(Ljava/lang/Object;)V", false), Op(ATHROW), - Label(44), - VarOp(ILOAD, 2), - Op(IRETURN), )) // --------------- @@ -470,7 +465,7 @@ class BytecodeTest extends BytecodeTesting { VarOp(ASTORE, 6), VarOp(ALOAD, 6), TypeOp(INSTANCEOF, "scala/collection/immutable/$colon$colon"), - Jump(IFEQ, Label(57)), + Jump(IFEQ, Label(56)), Op(ICONST_1), VarOp(ISTORE, 4), VarOp(ALOAD, 6), @@ -484,8 +479,7 @@ class BytecodeTest extends BytecodeTesting { VarOp(ILOAD, 7), Jump(IF_ICMPNE, Label(28)), Op(ICONST_1), - VarOp(ISTORE, 3), - Jump(GOTO, Label(47)), + Jump(GOTO, Label(44)), Label(28), Jump(GOTO, Label(31)), Label(31), @@ -493,42 +487,33 @@ class BytecodeTest extends BytecodeTesting { VarOp(ILOAD, 7), Jump(IF_ICMPNE, Label(39)), Op(ICONST_1), - VarOp(ISTORE, 3), - Jump(GOTO, Label(47)), + Jump(GOTO, Label(44)), Label(39), Jump(GOTO, Label(42)), Label(42), Op(ICONST_0), - VarOp(ISTORE, 3), - Jump(GOTO, Label(47)), - Label(47), - VarOp(ILOAD, 3), - Jump(IFEQ, Label(54)), + Jump(GOTO, Label(44)), + Label(44), + Jump(IFEQ, Label(53)), IntOp(BIPUSH, 10), - VarOp(ISTORE, 2), - Jump(GOTO, Label(82)), - Label(54), - Jump(GOTO, Label(60)), - Label(57), - Jump(GOTO, Label(60)), - Label(60), + Op(IRETURN), + Label(53), + Jump(GOTO, Label(59)), + Label(56), + Jump(GOTO, Label(59)), + Label(59), VarOp(ILOAD, 4), - Jump(IFEQ, Label(73)), + Jump(IFEQ, Label(71)), VarOp(ALOAD, 5), Invoke(INVOKEVIRTUAL, "scala/collection/immutable/$colon$colon", "head", "()Ljava/lang/Object;", false), Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "unboxToInt", "(Ljava/lang/Object;)I", false), VarOp(ISTORE, 8), VarOp(ILOAD, 8), - VarOp(ISTORE, 2), - Jump(GOTO, Label(82)), - Label(73), - Jump(GOTO, Label(76)), - Label(76), + Op(IRETURN), + Label(71), + Jump(GOTO, Label(74)), + Label(74), IntOp(BIPUSH, 20), - VarOp(ISTORE, 2), - Jump(GOTO, Label(82)), - Label(82), - VarOp(ILOAD, 2), Op(IRETURN), )) @@ -552,8 +537,8 @@ class BytecodeTest extends BytecodeTesting { Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "boxToInteger", "(I)Ljava/lang/Integer;", false), Invoke(INVOKEVIRTUAL, "scala/Predef$", "println", "(Ljava/lang/Object;)V", false), Field(GETSTATIC, "scala/runtime/BoxedUnit", "UNIT", "Lscala/runtime/BoxedUnit;"), - VarOp(ASTORE, 2), - Jump(GOTO, Label(51)), + Op(POP), + Op(RETURN), Label(24), Jump(GOTO, Label(27)), Label(27), @@ -565,8 +550,8 @@ class BytecodeTest extends BytecodeTesting { Ldc(LDC, "nil"), Invoke(INVOKEVIRTUAL, "scala/Predef$", "println", "(Ljava/lang/Object;)V", false), Field(GETSTATIC, "scala/runtime/BoxedUnit", "UNIT", "Lscala/runtime/BoxedUnit;"), - VarOp(ASTORE, 2), - Jump(GOTO, Label(51)), + Op(POP), + Op(RETURN), Label(40), Jump(GOTO, Label(43)), Label(43), @@ -575,8 +560,6 @@ class BytecodeTest extends BytecodeTesting { VarOp(ALOAD, 3), Invoke(INVOKESPECIAL, "scala/MatchError", "", "(Ljava/lang/Object;)V", false), Op(ATHROW), - Label(51), - Op(RETURN), )) } diff --git a/test/junit/scala/tools/nsc/backend/jvm/opt/MethodLevelOptsTest.scala b/test/junit/scala/tools/nsc/backend/jvm/opt/MethodLevelOptsTest.scala index 57f45399a45e..c7011ade300a 100644 --- a/test/junit/scala/tools/nsc/backend/jvm/opt/MethodLevelOptsTest.scala +++ b/test/junit/scala/tools/nsc/backend/jvm/opt/MethodLevelOptsTest.scala @@ -485,11 +485,10 @@ class MethodLevelOptsTest extends BytecodeTesting { val c = compileClass(code) assertSameSummary(getMethod(c, "t"), List( - BIPUSH, ILOAD, IF_ICMPNE, - BIPUSH, ILOAD, IF_ICMPNE, - LDC, ASTORE, GOTO, - -1, LDC, ASTORE, - -1, ALOAD, ARETURN)) + BIPUSH, ILOAD, IF_ICMPNE /*A*/, + BIPUSH, ILOAD, IF_ICMPNE /*A*/, + LDC, ARETURN, + -1 /*A*/, LDC, ARETURN)) } @Test diff --git a/test/junit/scala/tools/nsc/transform/patmat/PatmatBytecodeTest.scala b/test/junit/scala/tools/nsc/transform/patmat/PatmatBytecodeTest.scala index 40d981534d2f..8cb5faf8c9a5 100644 --- a/test/junit/scala/tools/nsc/transform/patmat/PatmatBytecodeTest.scala +++ b/test/junit/scala/tools/nsc/transform/patmat/PatmatBytecodeTest.scala @@ -116,9 +116,9 @@ class PatmatBytecodeTest extends BytecodeTesting { assertSameSummary(getMethod(c, "a"), List( NEW, DUP, ICONST_1, "boxToInteger", LDC, "", ASTORE /*1*/, ALOAD /*1*/, "y", ASTORE /*2*/, - ALOAD /*1*/, "x", INSTANCEOF, IFNE /*R*/, - NEW, DUP, ALOAD /*1*/, "", ATHROW, - /*R*/ -1, ALOAD /*2*/, ARETURN)) + ALOAD /*1*/, "x", INSTANCEOF, IFEQ /*E*/, + ALOAD /*2*/, ARETURN, + -1 /*E*/, NEW, DUP, ALOAD /*1*/, "", ATHROW)) } @Test @@ -137,10 +137,8 @@ class PatmatBytecodeTest extends BytecodeTesting { val expected = List( ALOAD /*1*/ , INSTANCEOF /*::*/ , IFEQ /*A*/ , - ALOAD, CHECKCAST /*::*/ , "head", "unboxToInt", - ISTORE, GOTO /*B*/ , - -1 /*A*/ , NEW /*MatchError*/ , DUP, ALOAD /*1*/ , "", ATHROW, - -1 /*B*/ , ILOAD, IRETURN) + ALOAD, CHECKCAST /*::*/ , "head", "unboxToInt", IRETURN, + -1 /*A*/ , NEW /*MatchError*/ , DUP, ALOAD /*1*/ , "", ATHROW) assertSameSummary(getMethod(c, "a"), expected) assertSameSummary(getMethod(c, "b"), expected) From 5fd3ca4f1902d4194eee3c9ecb159ea12fd80c9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Wed, 4 May 2022 14:36:03 +0200 Subject: [PATCH 5/5] Inliner: Load return values on the stack before jumping to post-call. Previously, the inliner would emit the following sequence to replace xRETURN instructions: 1. STORE the return value 2. DROP the rest of the stack 3. JUMP to post-call And at post-call, it would start by LOADing the return value back on the stack. This commit moves (and potentially duplicates) the LOAD instruction to before JUMPing to post-call. In the common case where the stack is empty at the point of the xRETURN (save for the return value itself), we can then completely omit storing and loading the value. This change recovers the optimized bytecode that the LoadDestination commit destroyed in `InlinerTest.oldInlineHigherOrderTest(). The changes in that commit make it much more frequent to have xRETURN nodes in the middle of a method body, and hence the new inliner behavior makes more sense. Note: there already existed a comment in Inliner.inlineCallsite() that pretended that loading the return value was part of the replacement sequence for xRETURN. It was basically lying. The new behavior is aligned with that comment. --- .../tools/nsc/backend/jvm/opt/Inliner.scala | 38 +++++++++++-------- test/files/neg/inlineMaxSize.check | 4 +- test/files/neg/inlineMaxSize.scala | 2 +- .../nsc/backend/jvm/opt/InlinerTest.scala | 21 +++++----- 4 files changed, 34 insertions(+), 31 deletions(-) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala index 97be68252457..ea6dd1487f95 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala @@ -498,29 +498,35 @@ abstract class Inliner { def drop(slot: Int) = returnReplacement add getPop(frame.peekStack(slot).getSize) - // for non-void methods, store the stack top into the return local variable - if (hasReturnValue) { - returnReplacement add returnValueStore(originalReturn) - stackHeight -= 1 - } + if (stackHeight == (if (hasReturnValue) 1 else 0)) { + // In most cases, the xRETURN we found should be at an empty stack height, + // either because it comes from Java in statement position, or because + // it comes from Scala in a tail expression position. For this common + // case, we don't have to manipulate the stack at all; we only leave + // the optional return value on the stack before jumping + } else { + // Otherwise, we have to empty the stack and only leave the optional + // return value on the stack. - // drop the rest of the stack - for (i <- 0 until stackHeight) drop(i) + // for non-void methods, store the stack top into the return local variable + if (hasReturnValue) { + returnReplacement add returnValueStore(originalReturn) + stackHeight -= 1 + } + + // drop the rest of the stack + for (i <- 0 until stackHeight) drop(i) + + // load the return value back on the stack + if (hasReturnValue) + returnReplacement add new VarInsnNode(returnType.getOpcode(ILOAD), returnValueIndex) + } returnReplacement add new JumpInsnNode(GOTO, postCallLabel) clonedInstructions.insert(inlinedReturn, returnReplacement) clonedInstructions.remove(inlinedReturn) } - // Load instruction for the return value - if (hasReturnValue) { - val retVarLoad = { - val opc = returnType.getOpcode(ILOAD) - new VarInsnNode(opc, returnValueIndex) - } - clonedInstructions.insert(postCallLabel, retVarLoad) - } - undo.saveMethodState(callsiteMethod) callsiteMethod.instructions.insert(callsiteInstruction, clonedInstructions) diff --git a/test/files/neg/inlineMaxSize.check b/test/files/neg/inlineMaxSize.check index b66b845bfbb6..a2f22f0aa794 100644 --- a/test/files/neg/inlineMaxSize.check +++ b/test/files/neg/inlineMaxSize.check @@ -2,8 +2,8 @@ inlineMaxSize.scala:8: warning: C::i()I is annotated @inline but could not be in The size of the callsite method C::j()I would exceed the JVM method size limit after inlining C::i()I. - @inline final def j = i + i + i - ^ + @inline final def j = i + i + i + i + ^ error: No warnings can be incurred under -Xfatal-warnings. one warning found one error found diff --git a/test/files/neg/inlineMaxSize.scala b/test/files/neg/inlineMaxSize.scala index 70192efd70bd..4eae0545ccab 100644 --- a/test/files/neg/inlineMaxSize.scala +++ b/test/files/neg/inlineMaxSize.scala @@ -5,5 +5,5 @@ class C { @inline final def g = f + f + f + f + f + f + f + f + f + f @inline final def h = g + g + g + g + g + g + g + g + g + g @inline final def i = h + h + h + h + h + h + h + h + h + h - @inline final def j = i + i + i + @inline final def j = i + i + i + i } diff --git a/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala b/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala index 0e848626cd04..26e24bfb40ba 100644 --- a/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala +++ b/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala @@ -80,8 +80,8 @@ class InlinerTest extends BytecodeTesting { assertSameCode(gConv, List( VarOp(ALOAD, 0), VarOp(ASTORE, 1), // store this - Op(ICONST_1), VarOp(ISTORE, 2), Jump(GOTO, Label(10)), // store return value - Label(10), VarOp(ILOAD, 2), // load return value + Op(ICONST_1), Jump(GOTO, Label(10)), // load return value + Label(10), VarOp(ALOAD, 0), Invoke(INVOKEVIRTUAL, "C", "f", "()I", false), Op(IADD), Op(IRETURN))) // line numbers are kept, so there's a line 2 (from the inlined f) @@ -115,10 +115,8 @@ class InlinerTest extends BytecodeTesting { Invoke(INVOKEVIRTUAL, "scala/Predef$", "$qmark$qmark$qmark", "()Lscala/runtime/Nothing$;", false)) val gBeforeLocalOpt = VarOp(ALOAD, 0) :: VarOp(ASTORE, 1) :: invokeQQQ ::: List( - VarOp(ASTORE, 2), - Jump(GOTO, Label(11)), - Label(11), - VarOp(ALOAD, 2), + Jump(GOTO, Label(14)), + Label(14), Op(ATHROW)) assertSameCode(convertMethod(g), gBeforeLocalOpt) @@ -372,13 +370,13 @@ class InlinerTest extends BytecodeTesting { assert(g1.maxStack == 7 && f1.maxStack == 6, s"${g1.maxStack} - ${f1.maxStack}") // locals in f1: this, x, a - // locals in g1 after inlining: this, this-of-f1, x, a, return value - assert(g1.maxLocals == 5 && f1.maxLocals == 3, s"${g1.maxLocals} - ${f1.maxLocals}") + // locals in g1 after inlining: this, this-of-f1, x, a + assert(g1.maxLocals == 4 && f1.maxLocals == 3, s"${g1.maxLocals} - ${f1.maxLocals}") // like maxStack in g1 / f1 assert(g2.maxStack == 5 && f2.maxStack == 4, s"${g2.maxStack} - ${f2.maxStack}") - // like maxLocals for g1 / f1, but no return value + // like maxLocals for g1 / f1 assert(g2.maxLocals == 4 && f2.maxLocals == 3, s"${g2.maxLocals} - ${f2.maxLocals}") } @@ -1445,9 +1443,8 @@ class InlinerTest extends BytecodeTesting { // box-unbox will clean it up assertSameSummary(getMethod(c, "t"), List( ALOAD, "$anonfun$t$1", IFEQ /*A*/, - "$anonfun$t$2", ISTORE, GOTO /*B*/, - -1 /*A*/, "$anonfun$t$3", ISTORE, - -1 /*B*/, ILOAD, IRETURN)) + "$anonfun$t$2", IRETURN, + -1 /*A*/, "$anonfun$t$3", IRETURN)) } @Test