Skip to content

Commit

Permalink
Detect block cycles using stackalloc op (scala-native#3416)
Browse files Browse the repository at this point in the history
* Detect block cycles using stackalloc to improve stack memory management of inlined, potentially looped functions
* Fix issue with resetting state when memory is assigned to a variable and thus escapes the loop.
* Ensure the stackrestore target exists. Fixes issues found in release mode

(cherry picked from commit 3273166)
  • Loading branch information
WojciechMazur committed Sep 1, 2023
1 parent 2121321 commit 94f7ada
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 18 deletions.
1 change: 1 addition & 0 deletions nir/src/main/scala/scala/scalanative/nir/Buffer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class Buffer(implicit fresh: Fresh) {

/** Returns the buffered instructions as a `Seq`. */
def toSeq: Seq[Inst] = buffer.toSeq

/** Number of instructions currently held in the buffer. */
def size: Int = buffer.size

/** Applies `fn` to every buffered instruction. */
def foreach(fn: Inst => Unit): Unit = buffer.foreach(fn)

/** Tests whether at least one buffered instruction satisfies `pred`. */
def exists(pred: Inst => Boolean): Boolean = buffer.exists(pred)

// Control-flow ops
Expand Down
16 changes: 1 addition & 15 deletions tools/src/main/scala/scala/scalanative/interflow/Inline.scala
Original file line number Diff line number Diff line change
Expand Up @@ -221,21 +221,7 @@ trait Inline { self: Interflow =>
}
}

// Check if inlined function performed stack allocation, if so add
// insert stacksave/stackrestore LLVM Intrinsics to prevent affecting.
// By definition every stack allocation of inlined function is only needed within its body
val allocatesOnStack = emit.exists {
case Inst.Let(_, _: Op.Stackalloc, _) => true
case _ => false
}
if (allocatesOnStack) {
import Interflow.LLVMIntrinsics._
val stackState = state.emit
.call(StackSaveSig, StackSave, Nil, Next.None)
state.emit ++= emit
state.emit
.call(StackRestoreSig, StackRestore, Seq(stackState), Next.None)
} else state.emit ++= emit
state.emit ++= emit
state.inherit(endState, res +: args)

val Type.Function(_, retty) = defn.ty: @unchecked
Expand Down
38 changes: 38 additions & 0 deletions tools/src/main/scala/scala/scalanative/interflow/MergeBlock.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package interflow

import scala.collection.mutable
import scalanative.nir._
import scala.annotation.tailrec

final class MergeBlock(val label: Inst.Label, val name: Local) {
var incoming = mutable.Map.empty[Local, (Seq[Val], State)]
Expand All @@ -17,11 +18,23 @@ final class MergeBlock(val label: Inst.Label, val name: Local) {
else label.pos
}

private var stackSavePtr: Val.Local = _
private[interflow] var emitStackSaveOp = false
private[interflow] var emitStackRestoreFor: List[Local] = Nil

def toInsts(): Seq[Inst] = {
import Interflow.LLVMIntrinsics._
val block = this
val result = new nir.Buffer()(Fresh(0))
def mergeNext(next: Next.Label): Next.Label = {
val nextBlock = outgoing(next.name)
if (nextBlock.stackSavePtr != null &&
emitStackRestoreFor.contains(next.name)) {
emitIfMissing(
end.fresh(),
Op.Call(StackRestoreSig, StackRestore, Seq(nextBlock.stackSavePtr))
)(result, block)
}
val mergeValues = nextBlock.phis.flatMap {
case MergePhi(_, incoming) =>
incoming.collect {
Expand All @@ -39,8 +52,17 @@ final class MergeBlock(val label: Inst.Label, val name: Local) {
case _ =>
util.unreachable
}

val params = block.phis.map(_.param)
result.label(block.name, params)
if (emitStackSaveOp) {
val id = block.end.fresh()
val emmited = emitIfMissing(
id = id,
op = Op.Call(StackSaveSig, StackSave, Nil)
)(result, block)
if (emmited) block.stackSavePtr = Val.Local(id, Type.Ptr)
}
result ++= block.end.emit
block.cf match {
case ret: Inst.Ret =>
Expand All @@ -66,4 +88,20 @@ final class MergeBlock(val label: Inst.Label, val name: Local) {
}
result.toSeq
}

/** Appends `op` to `result` as a new `let` unless `block` already contains
 *  an identical op among its own emitted instructions.
 *
 *  @param id     name for the new instruction; by-name, so it is evaluated
 *                only when the instruction is actually emitted
 *  @param op     call to emit
 *  @param result buffer receiving the new instruction
 *  @param block  block whose `end.emit` instructions are searched for `op`
 *  @return `true` when the op was appended, `false` when it was already present
 */
private def emitIfMissing(
    id: => Local,
    op: Op.Call
)(result: nir.Buffer, block: MergeBlock): Boolean = {
  // The original definition may already contain this exact op.
  val shouldEmit = !block.end.emit.exists {
    case Inst.Let(_, emitted, _) => op == emitted
    case _                       => false
  }
  if (shouldEmit) result.let(id, op, Next.None)
  shouldEmit
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ final class MergeProcessor(
}

orderedBlocks ++= sortedBlocks.filter(isExceptional)
orderedBlocks.toSeq
orderedBlocks.toList
}
}

Expand Down
156 changes: 155 additions & 1 deletion tools/src/main/scala/scala/scalanative/interflow/Opt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package interflow

import scalanative.nir._
import scalanative.linker._
import scala.collection.mutable
import scala.annotation.tailrec

trait Opt { self: Interflow =>

Expand Down Expand Up @@ -134,6 +136,158 @@ trait Opt { self: Interflow =>
popMergeProcessor()
}

processor.toSeq(retTy)
val blocks = processor.toSeq(retTy)
postProcess(blocks)
}

/** Post-processing pass over the merged blocks.
 *
 *  Runs `emitStackStateResetForCycles` for every block and returns the same
 *  sequence it was given; the pass works by mutating the blocks.
 */
def postProcess(blocks: Seq[MergeBlock]): Seq[MergeBlock] = {
  // Lazy and passed by-name: the index map is only built if some block needs it.
  lazy val blockIndices = blocks.zipWithIndex.toMap

  for (block <- blocks)
    emitStackStateResetForCycles(block, blocks, blockIndices)

  blocks
}

// Marks the blocks that should receive stack save / stack restore operations
// when `block` performs a stack allocation and lies on a control-flow cycle.
//
// Nothing is emitted here: the method only sets `emitStackSaveOp` on the block
// chosen as the cycle start and prepends the start's name to
// `emitStackRestoreFor` of the block chosen as the cycle end.
//
// block        - candidate block; ignored unless it contains an Op.Stackalloc
// blocks       - all blocks of the method, indexed by the values of `blockIndices`
// blockIndices - by-name map from block to its position in `blocks`; only
//                evaluated once a cycle is actually found
private def emitStackStateResetForCycles(
block: MergeBlock,
blocks: Seq[MergeBlock],
blockIndices: => Map[MergeBlock, Int]
): Unit = {
// Detect cycles involving stackalloc memory
// Insert StackSave/StackRestore instructions at its first/last block
def allocatesOnStack(block: MergeBlock) = block.end.emit.exists {
case Inst.Let(_, _: Op.Stackalloc, _) => true
case _ => false
}

if (allocatesOnStack(block)) {
val allocationEscapeCheck = new TrackStackallocEscape()
// Searches cycles through `block` and marks their start/end blocks.
// On the first call `innerCycle` is empty and `innerCycleStart` is None;
// the single recursive call below passes the cycle that could not be guarded
// and its start block, so that an enclosing cycle can be tried instead.
def tryEmit(
block: MergeBlock,
innerCycle: BlocksCycle,
innerCycleStart: Option[MergeBlock]
): Unit = {
findCycles(block)
.filter { cycle =>
val isDirectLoop = innerCycle.isEmpty
def isEnclosingLoop = innerCycle.toSet != cycle.toSet
// 1st run: accept every cycle.
// 2nd run: accept only cycles that differ from the inner one and in which
// no block allocates on the stack.
isDirectLoop || // 1st run
(isEnclosingLoop && !cycle.exists(allocatesOnStack)) // 2nd run
}
.foreach { cycle =>
// The cycle start is the member that appears earliest in `blocks`.
val startIdx = cycle.map(blockIndices(_)).min
val start = blocks(startIdx)
val startName = start.label.name
// `end` is the element following `start` in the cycle list, wrapping around.
// findCycles builds the list by prepending visited blocks, so this is
// presumably the block that jumps back to `start` - verify if findCycles changes.
val endIdx = (cycle.indexOf(start) + 1) % cycle.size
val end = cycle(endIdx)

val isNewCycle = !end.emitStackRestoreFor.contains(startName)
// Defined as a def so the escape analysis runs only when actually needed.
def canEscapeAlloc = allocationEscapeCheck(
allocatingBlock = block,
entryBlock = start,
cycle = cycle
)
if (isNewCycle) { // ensure unique
// If memory escapes current loop we cannot create stack state guards
// Instead try to insert guard in outer loop
if (!canEscapeAlloc || innerCycleStart.exists(cycle.contains)) {
start.emitStackSaveOp = true
end.emitStackRestoreFor ::= startName
} else if (innerCycleStart.isEmpty) {
// If allocation escapes direct loop try to create state restore in outer loop
// Outer loop is a while loop which does not perform stack allocation, but is a cycle
// containing entry to inner loop
tryEmit(
start,
innerCycle = cycle,
innerCycleStart = Some(start)
)
}
}
}
}
tryEmit(block, innerCycle = Nil, innerCycleStart = None)
}
}

private type BlocksCycle = List[MergeBlock]
/** Collects the cycles that start and end at `targetNode`.
 *
 *  Performs a depth-first walk over `outgoing` edges. Each reported cycle is
 *  the traversal path in reverse order: the most recently visited block comes
 *  first and `targetNode` comes last. Cycles closed by a back edge to any
 *  other block on the path are ignored.
 */
private def findCycles(targetNode: MergeBlock): List[BlocksCycle] = {
  def dfs(
      current: MergeBlock,
      visited: Set[MergeBlock],
      stack: List[MergeBlock]
  ): List[List[MergeBlock]] =
    if (!visited(current)) {
      val path = current :: stack
      val onPath = visited + current
      for {
        (_, successor) <- current.outgoing.toList
        cycle <- dfs(successor, onPath, path)
        if cycle.nonEmpty
      } yield cycle
    } else if (stack.lastOption.contains(current)) {
      // The back edge points to the block the search started from.
      List(stack)
    } else {
      // ignore cycle if backward edge does not point to targetNode
      Nil
    }

  dfs(targetNode, Set.empty, Nil)
}

/** NIR traversal used to check if stack-allocated memory might escape the
 *  cycle, meaning it might be referenced in later runs of the loop.
 *
 *  Instances keep mutable state between callbacks and are not thread-safe.
 */
private class TrackStackallocEscape() extends nir.Traverse {
  // Locals that hold stack-allocated memory, or a pointer-typed value derived from it.
  private val tracked = mutable.Set.empty[Local]
  // Instruction currently being traversed; read by `onVal`.
  private var curInst: Inst = _

  /** Checks whether stack memory allocated in the cycle may flow into the
   *  pointer-typed loop state of `entryBlock`.
   *
   *  @param allocatingBlock block performing the stack allocation
   *  @param entryBlock      block whose phi values define the loop state
   *  @param cycle           blocks of the cycle; only those from
   *                         `allocatingBlock` up to (excluding) `entryBlock`
   *                         are inspected
   *  @return `true` if a tracked local is one of the pointer-typed phi
   *          values of `entryBlock`
   */
  def apply(
      allocatingBlock: MergeBlock,
      entryBlock: MergeBlock,
      cycle: Seq[MergeBlock]
  ): Boolean = {
    // Pointer-typed locals passed into the phis of the entry block.
    val loopStateVals = mutable.Set.empty[Local]
    entryBlock.phis.foreach {
      case MergePhi(_, values) =>
        values.foreach {
          case (_, v: Val.Local) =>
            if (Type.isPtrType(v.ty)) loopStateVals += v.name
          case _ => ()
        }
    }
    if (loopStateVals.isEmpty) false
    else {
      tracked.clear()
      cycle.view
        .dropWhile(_ ne allocatingBlock)
        .takeWhile(_ ne entryBlock)
        .foreach(_.end.emit.foreach(onInst))
      tracked.exists(loopStateVals.contains)
    }
  }

  override def onInst(inst: Inst): Unit = {
    curInst = inst
    inst match {
      case Inst.Let(name, _: Op.Stackalloc, _) => tracked += name
      case _                                   => ()
    }
    super.onInst(inst)
  }

  // Propagates tracking: a pointer-typed `let` that uses a tracked local
  // becomes tracked as well. Presumably invoked by nir.Traverse for the
  // values of the instruction passed to `super.onInst` - verify against Traverse.
  override def onVal(value: Val): Unit = value match {
    case Val.Local(valName, _) =>
      curInst match {
        case Inst.Let(instName, op, _) if Type.isPtrType(op.resty) =>
          if (tracked.contains(valName)) tracked += instName
        case _ => ()
      }
    case _ => ()
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ object UseDef {
// Registers a new instruction definition for local `n` with empty deps/uses.
// Fails with a descriptive assertion error when `n` has already been defined.
def enterInst(n: Local) = {
  val deps = mutable.UnrolledBuffer.empty[Def]
  val uses = mutable.UnrolledBuffer.empty[Def]
  // Single assertion with a message; a message-less duplicate in front of it
  // would fire first and hide the offending id.
  assert(!defs.contains(n), s"duplicate local ids: $n")
  defs += ((n, InstDef(n, deps, uses)))
}
def deps(n: Local, deps: Seq[Local]) = {
Expand Down
35 changes: 35 additions & 0 deletions tools/src/test/scala/scala/scalanative/OptimizerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package scala.scalanative

import scala.scalanative.build.{Config, NativeConfig, Mode}
import scala.scalanative.build.core.ScalaNative
import scala.scalanative.nir._

/** Base class to test the optimizer */
abstract class OptimizerSpec extends LinkerSpec {
Expand Down Expand Up @@ -32,4 +33,38 @@ abstract class OptimizerSpec extends LinkerSpec {
fn(config, optimized)
}

/** Looks up the entry point of the test program among the linked definitions.
 *
 *  The main method of the `Test$` module is preferred; the static forwarder
 *  in `Test` is searched only when the former is missing. An assertion fails
 *  when neither is found.
 */
protected def findEntry(linked: Seq[Defn]): Option[Defn.Define] = {
  import OptimizerSpec._
  linked
    .collectFirst { case defn @ Defn.Define(_, TestMain(), _, _) => defn }
    .orElse {
      // Evaluated only when the companion method was not found.
      linked.collectFirst {
        case defn @ Defn.Define(_, TestMainForwarder(), _, _) => defn
      }
    }
    .ensuring(_.isDefined, "Not found linked method")
}
}

object OptimizerSpec {

  /** Extractor matching the main method of the `Test$` module, including
   *  names whose unmangled signature is a duplicate of that main signature.
   */
  private object TestMain {
    val TestModule = Global.Top("Test$")
    val CompanionMain =
      TestModule.member(Rt.ScalaMainSig.copy(scope = Sig.Scope.Public))

    def unapply(name: Global): Boolean = name match {
      case CompanionMain                  => true
      case Global.Member(TestModule, sig) => isDuplicateOfMain(sig)
      case _                              => false
    }

    // True when `sig` unmangles to a duplicate of the companion main signature.
    private def isDuplicateOfMain(sig: Sig): Boolean = sig.unmangled match {
      case Sig.Duplicate(of, _) => of == CompanionMain.sig
      case _                    => false
    }
  }

  /** Extractor matching the static main forwarder defined in `Test`. */
  private object TestMainForwarder {
    val staticForwarder = Global.Top("Test").member(Rt.ScalaMainSig)

    def unapply(name: Global): Boolean = name == staticForwarder
  }
}

0 comments on commit 94f7ada

Please sign in to comment.