Skip to content

Commit

Permalink
backend,rs: load balance for issue selection (#1048)
Browse files Browse the repository at this point in the history
This commit adds a load-balancing strategy to the issue selection logic of
reservation stations.

Previously we had a load-balance option in ExuBlock, but it cannot work
when the function units send feedback to the RS. It is removed in this
commit.

This commit also adds a victim index option for oldestFirst, i.e. the issue
port that the oldest instruction is allowed to take over. For LOAD, the
first issue port has better performance and thus we set the victim index
to 0. For other function units, we use the last issue port.
  • Loading branch information
poemonsense committed Sep 19, 2021
1 parent ebb8ebf commit 7bb7bf3
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 76 deletions.
54 changes: 37 additions & 17 deletions src/main/scala/utils/BitUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,12 @@ object OnesMoreThan {
}

// Common interface of the selection policies in this file: each implementation
// answers "which entry is the n-th selection?" through getNthOH.
abstract class SelectOne {
def getNthOH(n: Int): (Bool, Vec[Bool])
// One-bit register that is inverted every cycle (reset value false).
// Implementations that support balanced selection read it to alternate
// priorities between two selection ports; callers can observe it through getBalance2.
protected val balance2 = RegInit(false.B)
balance2 := !balance2

// Returns a pair for the n-th selection (n starts at 1): the Bool says whether
// that selection is valid, the Vec[Bool] marks the selected entry.
// need_balance: for balanced selections only (DO NOT use this if you don't know what it is)
def getNthOH(n: Int, need_balance: Boolean = false): (Bool, Vec[Bool])
// Current value of the alternating balance bit.
def getBalance2: Bool = balance2
}

class NaiveSelectOne(bits: Seq[Bool], max_sel: Int = -1) extends SelectOne {
Expand All @@ -199,9 +204,9 @@ class NaiveSelectOne(bits: Seq[Bool], max_sel: Int = -1) extends SelectOne {
}
}

def getNthOH(n: Int): (Bool, Vec[Bool]) = {
require(n > 0, s"n should be positive to select the n-th one")
require(n <= n_sel, s"n should not be larger than n_sel")
def getNthOH(n: Int, need_balance: Boolean = false): (Bool, Vec[Bool]) = {
require(n > 0, s"$n should be positive to select the n-th one")
require(n <= n_sel, s"$n should not be larger than $n_sel")
// bits(i) is true.B and bits(i - 1, 0) has n - 1
val selValid = OnesMoreThan(bits, n)
val sel = VecInit(bits.zip(matrix).map{ case (b, m) => b && m(n - 1) })
Expand All @@ -216,15 +221,28 @@ class CircSelectOne(bits: Seq[Bool], max_sel: Int = -1) extends SelectOne {

val sel_forward = new NaiveSelectOne(bits, (n_sel + 1) / 2)
val sel_backward = new NaiveSelectOne(bits.reverse, n_sel / 2)

def getNthOH(n: Int): (Bool, Vec[Bool]) = {
val selValid = OnesMoreThan(bits, n)
val moreThan = Seq(1, 2).map(i => OnesMoreThan(bits, i))

// Returns (valid, select vector) for the n-th selection, n starting at 1.
// Odd n is answered by sel_forward and even n by sel_backward; sel_backward
// works on the reversed bits, so its vector is reversed back before returning.
// need_balance only changes how the valid flag is computed and is restricted
// to the two-selection configuration (max_sel == 2).
def getNthOH(n: Int, need_balance: Boolean = false): (Bool, Vec[Bool]) = {
require(!need_balance || max_sel == 2, s"does not support load balance between $max_sel selections")
val selValid = if (!need_balance) {
OnesMoreThan(bits, n)
} else {
// moreThan.head / moreThan.last are OnesMoreThan(bits, 1) / OnesMoreThan(bits, 2).
// NOTE(review): presumably "at least 1 / at least 2 requests" - confirm against OnesMoreThan.
// The two ports swap these conditions depending on balance2, so the port
// that needs only the weaker condition alternates with the balance bit.
if (n == 1) {
// When balance2 bit is set, we prefer the second selection port.
Mux(balance2, moreThan.last, moreThan.head)
}
else {
require(n == 2)
Mux(balance2, moreThan.head, moreThan.last)
}
}
// Selections 1,2 map to index 1 of the forward/backward selector, 3,4 to index 2, ...
val sel_index = (n + 1) / 2
if (n % 2 == 1) {
(selValid, sel_forward.getNthOH(sel_index)._2)
(selValid, sel_forward.getNthOH(sel_index, need_balance)._2)
}
else {
(selValid, VecInit(sel_backward.getNthOH(sel_index)._2.reverse))
(selValid, VecInit(sel_backward.getNthOH(sel_index, need_balance)._2.reverse))
}
}
}
Expand All @@ -240,26 +258,28 @@ class OddEvenSelectOne(bits: Seq[Bool], max_sel: Int = -1) extends SelectOne {
val n_odd = n_bits / 2
val sel_odd = new CircSelectOne((0 until n_odd).map(i => bits(2 * i + 1)), (n_sel + 1) / 2)

def getNthOH(n: Int): (Bool, Vec[Bool]) = {
// Returns (valid, select vector) for the n-th selection, n starting at 1.
// Odd n is delegated to sel_even and even n to sel_odd; the sub-selector's
// result is then expanded back to a full n_bits-wide vector.
def getNthOH(n: Int, need_balance: Boolean = false): (Bool, Vec[Bool]) = {
// Selections 1,2 map to index 1 of the even/odd sub-selector, 3,4 to index 2, ...
val sel_index = (n + 1) / 2
if (n % 2 == 1) {
val selected = sel_even.getNthOH(sel_index)
val selected = sel_even.getNthOH(sel_index, need_balance)
// Only even positions can be selected here; position i reads sub-selector entry i / 2.
// NOTE(review): assumes sel_even was built from the even-indexed bits - its definition is not visible here.
val sel = VecInit((0 until n_bits).map(i => if (i % 2 == 0) selected._2(i / 2) else false.B))
(selected._1, sel)
}
else {
val selected = sel_odd.getNthOH(sel_index)
val selected = sel_odd.getNthOH(sel_index, need_balance)
// sel_odd is built from bits(2 * i + 1), so entry i / 2 maps back to odd position i.
val sel = VecInit((0 until n_bits).map(i => if (i % 2 == 1) selected._2(i / 2) else false.B))
(selected._1, sel)
}
}
}

// Factory for the selection policies.
// policy is matched case-insensitively: "naive" -> NaiveSelectOne,
// "circ" -> CircSelectOne, "oddeven" -> OddEvenSelectOne.
// bits and max_sel are passed to the chosen implementation unchanged.
// Any other policy name throws IllegalArgumentException.
object SelectOne {
def apply(policy: String, bits: Seq[Bool], max_sel: Int = -1): SelectOne = policy.toLowerCase match {
case "naive" => new NaiveSelectOne(bits, max_sel)
case "circ" => new CircSelectOne(bits, max_sel)
case "oddeven" => new OddEvenSelectOne(bits, max_sel)
case _ => throw new IllegalArgumentException(s"unknown select policy")
def apply(policy: String, bits: Seq[Bool], max_sel: Int = -1): SelectOne = {
policy.toLowerCase match {
case "naive" => new NaiveSelectOne(bits, max_sel)
case "circ" => new CircSelectOne(bits, max_sel)
case "oddeven" => new OddEvenSelectOne(bits, max_sel)
case _ => throw new IllegalArgumentException(s"unknown select policy")
}
}
}
21 changes: 0 additions & 21 deletions src/main/scala/xiangshan/backend/ExuBlock.scala
Original file line number Diff line number Diff line change
Expand Up @@ -169,27 +169,6 @@ class ExuBlockImp(outer: ExuBlock)(implicit p: Parameters) extends LazyModuleImp
}
}

// Optimizations for load balance between different issue ports
// When a reservation station has at least two issue ports and
// the corresponding function unit does not have fixed latency (not pipelined),
// we let the function units alternate between each two issue ports.
val multiIssueFuConfigs = fuConfigs.filter(_._2 >= 2).filter(_._1.needLoadBalance).map(_._1)
val multiIssuePortsIdx = flattenFuConfigs.zipWithIndex.filter(x => multiIssueFuConfigs.contains(x._1))
val multiIssue = multiIssueFuConfigs.map(cfg => multiIssuePortsIdx.filter(_._1 == cfg).map(_._2))
multiIssue.foreach(ports => {
val numPingPong = ports.length / 2
for (i <- 0 until numPingPong) {
val index = ports.drop(2 * i).take(2)
println(s"Enable issue load balance between ports $index")
val pingpong = RegInit(false.B)
pingpong := !pingpong
when (pingpong) {
scheduler.io.issue(index(0)) <> fuBlock.io.issue(index(1))
scheduler.io.issue(index(1)) <> fuBlock.io.issue(index(0))
}
}
})

// By default, instructions do not have exceptions when they enter the function units.
fuBlock.io.issue.map(_.bits.uop.clearExceptions())
// For exe units that don't have exceptions, we assign zeroes to their exception vector.
Expand Down
15 changes: 8 additions & 7 deletions src/main/scala/xiangshan/backend/Scheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,6 @@ class SchedulerImp(outer: Scheduler) extends LazyModuleImp(outer) with HasXSPara

// print rs info
println("Scheduler: ")
for ((rs, i) <- rs_all.zipWithIndex) {
println(s"RS $i: $rs")
println(s" innerIntUop: ${outer.innerIntFastSources(i).map(_._2)}")
println(s" innerFpUop: ${outer.innerFpFastSources(i).map(_._2)}")
println(s" innerFastPorts: ${outer.innerFastPorts(i)}")
println(s" outFastPorts: ${outer.outFastPorts(i)}")
}
println(s" number of issue ports: ${outer.numIssuePorts}")
println(s" number of replay ports: ${outer.numReplayPorts}")
println(s" size of load and store RSes: ${outer.getMemRsEntries}")
Expand All @@ -166,6 +159,14 @@ class SchedulerImp(outer: Scheduler) extends LazyModuleImp(outer) with HasXSPara
if (fpRfConfig._1) {
println(s"FP Regfile: ${fpRfConfig._2}R${fpRfConfig._3}W")
}
for ((rs, i) <- rs_all.zipWithIndex) {
println(s"RS $i: $rs")
println(s" innerIntUop: ${outer.innerIntFastSources(i).map(_._2)}")
println(s" innerFpUop: ${outer.innerFpFastSources(i).map(_._2)}")
println(s" innerFastPorts: ${outer.innerFastPorts(i)}")
println(s" outFastPorts: ${outer.outFastPorts(i)}")
println(s" loadBalance: ${rs_all(i).params.needBalance}")
}

class SchedulerExtraIO extends XSBundle {
// feedback ports
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/xiangshan/backend/exu/Exu.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ case class ExuConfig
val allWakeupFromRS = !hasUncertainlatency && (wbIntPriority <= 1 || wbFpPriority <= 1)
val wakeupFromExu = !wakeupFromRS
val hasExclusiveWbPort = (wbIntPriority == 0 && writeIntRf) || (wbFpPriority == 0 && writeFpRf)
val needLoadBalance = hasUncertainlatency && !wakeupFromRS
val needLoadBalance = hasUncertainlatency

def canAccept(fuType: UInt): Bool = {
Cat(fuConfigs.map(_.fuType === fuType)).orR()
Expand Down
59 changes: 31 additions & 28 deletions src/main/scala/xiangshan/backend/issue/ReservationStation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ case class RSParams
){
def allWakeup: Int = numFastWakeup + numWakeup
def indexWidth: Int = log2Up(numEntries)
def oldestFirst: Boolean = exuCfg.get != AluExeUnitCfg
// oldestFirst: (Enable_or_not, Need_balance, Victim_index)
def oldestFirst: (Boolean, Boolean, Int) = (true, !isLoad, if (isLoad) 0 else numDeq - 1)
def needScheduledBit: Boolean = hasFeedback || delayedRf
def needBalance: Boolean = exuCfg.get.needLoadBalance

override def toString: String = {
s"type ${exuCfg.get.name}, size $numEntries, enq $numEnq, deq $numDeq, numSrc $numSrc, fast $numFastWakeup, wakeup $numWakeup"
Expand Down Expand Up @@ -311,35 +313,37 @@ class ReservationStation(params: RSParams)(implicit p: Parameters) extends XSMod
* S1: read uop and data
*/
val issueVec = Wire(Vec(params.numDeq, Valid(UInt(params.numEntries.W))))
// When the reservation station has oldestFirst, we need to issue the oldest instruction if possible.
// However, in this case, the select policy always selects at maximum numDeq instructions to issue.
// Thus, we need an arbitration between the numDeq + 1 possibilities.
// For better performance, we always let the last issue port be the victim.
def doIssueArbitration(oldest: Valid[UInt], in: Vec[ValidIO[UInt]], out: Vec[ValidIO[UInt]]): Bool = {
require(in.length == out.length)
out := in
// When the oldest is not matched in in.dropRight(1), we always select the oldest.
// We don't need to compare the last selection here, because we will select the oldest when
// either the last matches the oldest or the last does not match the oldest.
val oldestMatchVec = in.dropRight(1).map(i => i.valid && i.bits === oldest.bits)
val oldestMatchIn = if (params.numDeq > 1) VecInit(oldestMatchVec).asUInt().orR() else false.B
val oldestNotSelected = params.oldestFirst.B && oldest.valid && !oldestMatchIn
out.last.valid := in.last.valid || oldestNotSelected
when (oldestNotSelected) {
out.last.bits := oldest.bits
val oldestOverride = Wire(Vec(params.numDeq, Bool()))
if (params.oldestFirst._1) {
// When the reservation station has oldestFirst, we need to issue the oldest instruction if possible.
// However, in this case, the select policy always selects at maximum numDeq instructions to issue.
// Thus, we need an arbitration between the numDeq + 1 possibilities.
val oldestSelection = Module(new OldestSelection(params))
oldestSelection.io.in := RegNext(select.io.grant)
oldestSelection.io.oldest := RegNext(oldestSel)
// By default, we use the default victim index set in parameters.
oldestSelection.io.canOverride := (0 until params.numDeq).map(_ == params.oldestFirst._3).map(_.B)
// When deq width is two, we have a balance bit to indicate selection priorities.
// For better performance, we decide the victim according to selection priorities.
if (params.needBalance && params.oldestFirst._2 && params.numDeq == 2) {
// When balance2 bit is set, selection prefers the second selection port.
// Thus, the first is the victim if balance2 bit is set.
oldestSelection.io.canOverride(0) := select.io.grantBalance
oldestSelection.io.canOverride(1) := !select.io.grantBalance
}
XSPerfAccumulate("oldest_override_last", oldestNotSelected)
// returns whether the oldest is selected
oldestNotSelected
issueVec := oldestSelection.io.out
oldestOverride := oldestSelection.io.isOverrided
}
else {
issueVec := RegNext(select.io.grant)
oldestOverride.foreach(_ := false.B)
}
val oldestOverride = doIssueArbitration(RegNext(oldestSel), RegNext(select.io.grant), issueVec)

// pipeline registers for stage one
val s1_out = Wire(Vec(params.numDeq, Decoupled(new ExuInput)))
// Do the read data arbitration
s1_out.zip(payloadArray.io.read.dropRight(1)).foreach{ case (o, r) => o.bits.uop := r.data }
when (oldestOverride) {
s1_out.last.bits.uop := payloadArray.io.read.last.data
for ((doOverride, i) <- oldestOverride.zipWithIndex) {
s1_out(i).bits.uop := Mux(doOverride, payloadArray.io.read.last.data, payloadArray.io.read(i).data)
}
s1_out.foreach(_.bits.uop.debugInfo.selectTime := GTimer())

Expand Down Expand Up @@ -431,10 +435,9 @@ class ReservationStation(params: RSParams)(implicit p: Parameters) extends XSMod
dataArray.io.read.last.addr := oldestSel.bits
// Do the read data arbitration
s1_out.foreach(_.bits.src := DontCare)
for (i <- 0 until params.numSrc) {
s1_out.zip(dataArray.io.read.dropRight(1)).foreach{ case (o, r) => o.bits.src(i) := r.data(i) }
when (oldestOverride) {
s1_out.last.bits.src(i) := dataArray.io.read.last.data(i)
for ((doOverride, i) <- oldestOverride.zipWithIndex) {
for (j <- 0 until params.numSrc) {
s1_out(i).bits.src(j) := Mux(doOverride, dataArray.io.read.last.data(j), dataArray.io.read(i).data(j))
}
}

Expand Down
38 changes: 36 additions & 2 deletions src/main/scala/xiangshan/backend/issue/SelectPolicy.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class SelectPolicy(params: RSParams)(implicit p: Parameters) extends XSModule {
val allocate = Vec(params.numEnq, ValidIO(UInt(params.numEntries.W)))
// select for issue
val request = Input(UInt(params.numEntries.W))
val grant = Vec(params.numDeq, ValidIO(UInt(params.numEntries.W))) //TODO: optimize it
val grant = Vec(params.numDeq, ValidIO(UInt(params.numEntries.W)))
val grantBalance = Output(Bool())
})

val policy = if (params.numDeq > 2 && params.numEntries > 32) "oddeven" else if (params.numDeq >= 2) "circ" else "naive"
Expand All @@ -49,14 +50,47 @@ class SelectPolicy(params: RSParams)(implicit p: Parameters) extends XSModule {
val request = io.request.asBools
val select = SelectOne(policy, request, params.numDeq)
for (i <- 0 until params.numDeq) {
val sel = select.getNthOH(i + 1)
val sel = select.getNthOH(i + 1, params.needBalance)
io.grant(i).valid := sel._1
io.grant(i).bits := sel._2.asUInt

XSError(io.grant(i).valid && PopCount(io.grant(i).bits.asBools) =/= 1.U,
p"grant vec ${Binary(io.grant(i).bits)} is not onehot")
XSDebug(io.grant(i).valid, p"select for issue request: ${Binary(io.grant(i).bits)}\n")
}
io.grantBalance := select.getBalance2

}

/**
  * Arbitrates between the regular issue grants and the oldest ready entry.
  *
  * io.in          : the numDeq grants produced by the select policy
  * io.oldest      : the oldest entry, in the same encoding as the grants
  * io.canOverride : per port, whether this port may be taken over by the oldest entry
  * io.out         : the grants after arbitration
  * io.isOverrided : per port, whether the oldest entry replaced the original grant
  *
  * A port is taken over only when it is allowed to be, the oldest entry is valid,
  * and no OTHER port has already granted that same entry. The port's own grant is
  * deliberately not compared: whether it matches the oldest or not, the port ends
  * up issuing the oldest entry.
  */
class OldestSelection(params: RSParams)(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val in = Vec(params.numDeq, Flipped(ValidIO(UInt(params.numEntries.W))))
    val oldest = Flipped(ValidIO(UInt(params.numEntries.W)))
    val canOverride = Vec(params.numDeq, Input(Bool()))
    val out = Vec(params.numDeq, ValidIO(UInt(params.numEntries.W)))
    val isOverrided = Vec(params.numDeq, Output(Bool()))
  })

  // Default: forward every grant unchanged; overridden ports are patched below.
  io.out := io.in

  // Per port: does the (valid) grant point at the same entry as the oldest one?
  val oldestIdx = OHToUInt(io.oldest.bits)
  val grantHitsOldest = io.in.map(grant => grant.valid && OHToUInt(grant.bits) === oldestIdx)

  for (port <- 0 until params.numDeq) {
    // Has any other port already picked the oldest entry?
    val hitsElsewhere = grantHitsOldest.zipWithIndex.collect { case (hit, other) if other != port => hit }
    val grantedElsewhere = if (params.numDeq > 1) VecInit(hitsElsewhere).asUInt.orR else false.B

    val takeOver = io.canOverride(port) && io.oldest.valid && !grantedElsewhere
    io.isOverrided(port) := takeOver

    // An overridden port is always valid and carries the oldest entry.
    io.out(port).valid := io.in(port).valid || takeOver
    when (takeOver) {
      io.out(port).bits := io.oldest.bits
    }

    XSPerfAccumulate(s"oldest_override_$port", takeOver)
  }
}

class AgeDetector(numEntries: Int, numEnq: Int, regOut: Boolean = true)(implicit p: Parameters) extends XSModule {
Expand Down

0 comments on commit 7bb7bf3

Please sign in to comment.