From 569b279f195dc0ad7c83f0513d51b94afa3311ba Mon Sep 17 00:00:00 2001 From: Lingrui98 Date: Tue, 16 Nov 2021 20:11:32 +0800 Subject: [PATCH] bpu: extract wrbypass to be a module --- src/main/scala/xiangshan/frontend/Bim.scala | 59 +++------ .../scala/xiangshan/frontend/ITTAGE.scala | 78 +---------- src/main/scala/xiangshan/frontend/SC.scala | 64 +-------- src/main/scala/xiangshan/frontend/Tage.scala | 122 +++--------------- .../scala/xiangshan/frontend/WrBypass.scala | 104 +++++++++++++++ 5 files changed, 154 insertions(+), 273 deletions(-) create mode 100644 src/main/scala/xiangshan/frontend/WrBypass.scala diff --git a/src/main/scala/xiangshan/frontend/Bim.scala b/src/main/scala/xiangshan/frontend/Bim.scala index 36164f3136..85abd2e190 100644 --- a/src/main/scala/xiangshan/frontend/Bim.scala +++ b/src/main/scala/xiangshan/frontend/Bim.scala @@ -64,55 +64,32 @@ class BIM(implicit p: Parameters) extends BasePredictor with BimParams with BPUU // Update logic val u_valid = RegNext(io.update.valid) val update = RegNext(io.update.bits) - val u_idx = bimAddr.getIdx(update.pc) + + val update_mask = LowerMask(PriorityEncoderOH(update.preds.br_taken_mask.asUInt)) + val newCtrs = Wire(Vec(numBr, UInt(2.W))) + val need_to_update = VecInit((0 until numBr).map(i => u_valid && update.ftb_entry.brValids(i) && update_mask(i))) - // Bypass logic - val wrbypass_ctrs = RegInit(0.U.asTypeOf(Vec(bypassEntries, Vec(numBr, UInt(2.W))))) - val wrbypass_ctr_valids = RegInit(0.U.asTypeOf(Vec(bypassEntries, Vec(numBr, Bool())))) - val wrbypass_idx = RegInit(0.U.asTypeOf(Vec(bypassEntries, UInt(log2Up(bimSize).W)))) - val wrbypass_enq_ptr = RegInit(0.U(log2Up(bypassEntries).W)) - - val wrbypass_hits = VecInit((0 until bypassEntries).map(i => - !doing_reset && wrbypass_idx(i) === u_idx)) - val wrbypass_hit = wrbypass_hits.reduce(_||_) - val wrbypass_hit_idx = PriorityEncoder(wrbypass_hits) - val oldCtrs = VecInit((0 until numBr).map(i => - Mux(wrbypass_hit && 
wrbypass_ctr_valids(wrbypass_hit_idx)(i), - wrbypass_ctrs(wrbypass_hit_idx)(i), update.meta(2*i+1, 2*i)))) + // Bypass logic + val wrbypass = Module(new WrBypass(UInt(2.W), bypassEntries, log2Up(bimSize), numWays = numBr)) + wrbypass.io.wen := need_to_update.reduce(_||_) + wrbypass.io.write_idx := u_idx + wrbypass.io.write_data := newCtrs + wrbypass.io.write_way_mask.map(_ := need_to_update) + + val oldCtrs = + VecInit((0 until numBr).map(i => + Mux(wrbypass.io.hit && wrbypass.io.hit_data(i).valid, + wrbypass.io.hit_data(i).bits, + update.meta(2*i+1, 2*i)) + )) val newTakens = update.preds.br_taken_mask - val newCtrs = VecInit((0 until numBr).map(i => + newCtrs := VecInit((0 until numBr).map(i => satUpdate(oldCtrs(i), 2, newTakens(i)) )) - val update_mask = LowerMask(PriorityEncoderOH(update.preds.br_taken_mask.asUInt)) - val need_to_update = VecInit((0 until numBr).map(i => u_valid && update.ftb_entry.brValids(i) && update_mask(i))) - - when (reset.asBool) { wrbypass_ctr_valids.foreach(_ := VecInit(Seq.fill(numBr)(false.B)))} - - for (i <- 0 until numBr) { - when(need_to_update.reduce(_||_)) { - when(wrbypass_hit) { - when(need_to_update(i)) { - wrbypass_ctrs(wrbypass_hit_idx)(i) := newCtrs(i) - wrbypass_ctr_valids(wrbypass_hit_idx)(i) := true.B - } - }.otherwise { - wrbypass_ctr_valids(wrbypass_enq_ptr)(i) := false.B - when(need_to_update(i)) { - wrbypass_ctrs(wrbypass_enq_ptr)(i) := newCtrs(i) - wrbypass_ctr_valids(wrbypass_enq_ptr)(i) := true.B - } - } - } - } - - when (need_to_update.reduce(_||_) && !wrbypass_hit) { - wrbypass_idx(wrbypass_enq_ptr) := u_idx - wrbypass_enq_ptr := (wrbypass_enq_ptr + 1.U)(log2Up(bypassEntries)-1, 0) - } bim.io.w.apply( valid = need_to_update.asUInt.orR || doing_reset, diff --git a/src/main/scala/xiangshan/frontend/ITTAGE.scala b/src/main/scala/xiangshan/frontend/ITTAGE.scala index ea2b197f1d..0f047068af 100644 --- a/src/main/scala/xiangshan/frontend/ITTAGE.scala +++ b/src/main/scala/xiangshan/frontend/ITTAGE.scala @@ -157,21 
+157,6 @@ class ITTageTable val wrBypassEntries = 4 val phistLen = if (PathHistoryLength > histLen) histLen else PathHistoryLength - // def compute_tag_and_hash(unhashed_idx: UInt, hist: UInt, phist: UInt) = { - // val idx_history = compute_folded_ghist(hist, log2Ceil(nRows)) - // // val idx = (unhashed_idx ^ (unhashed_idx >> (log2Ceil(nRows)-tableIdx+1)) ^ idx_history ^ idx_phist)(log2Ceil(nRows) - 1, 0) - // val idx = (unhashed_idx ^ idx_history)(log2Ceil(nRows) - 1, 0) - // val tag_history = compute_folded_ghist(hist, tagLen) - // val alt_tag_history = compute_folded_ghist(hist, tagLen-1) - // // Use another part of pc to make tags - // val tag = ( - // if (tagLen > 1) - // ((unhashed_idx >> log2Ceil(nRows)) ^ tag_history ^ (alt_tag_history << 1)) (tagLen - 1, 0) - // else 0.U - // ) - // (idx, tag) - // } - require(histLen == 0 && tagLen == 0 || histLen != 0 && tagLen != 0) val idxFhInfo = (histLen, min(log2Ceil(nRows), histLen)) val tagFhInfo = (histLen, min(histLen, tagLen)) @@ -312,37 +297,14 @@ class ITTageTable waymask = true.B ) - val wrbypass_tags = RegInit(0.U.asTypeOf(Vec(wrBypassEntries, UInt(tagLen.W)))) - val wrbypass_idxs = RegInit(0.U.asTypeOf(Vec(wrBypassEntries, UInt(log2Ceil(nRows).W)))) - val wrbypass_ctrs = RegInit(0.U.asTypeOf(Vec(wrBypassEntries, UInt(ITTageCtrBits.W)))) - val wrbypass_enq_idx = RegInit(0.U(log2Ceil(wrBypassEntries).W)) - - - val wrbypass_hits = VecInit((0 until wrBypassEntries) map { i => - wrbypass_tags(i) === update_tag && - wrbypass_idxs(i) === update_idx - }) - - - val wrbypass_hit = wrbypass_hits.reduce(_||_) - // val wrbypass_rhit = wrbypass_rhits.reduce(_||_) - val wrbypass_hit_idx = ParallelPriorityEncoder(wrbypass_hits) - // val wrbypass_rhit_idx = PriorityEncoder(wrbypass_rhits) - - // val wrbypass_rctr_hits = VecInit((0 until TageBanks).map( b => wrbypass_ctr_valids(wrbypass_rhit_idx)(b))) - - // val rhit_ctrs = RegEnable(wrbypass_ctrs(wrbypass_rhit_idx), wrbypass_rhit) - - // when (RegNext(wrbypass_rhit)) { - 
// for (b <- 0 until TageBanks) { - // when (RegNext(wrbypass_rctr_hits(b.U + baseBank))) { - // io.resp(b).bits.ctr := rhit_ctrs(s2_bankIdxInOrder(b)) - // } - // } - // } + val wrbypass = Module(new WrBypass(UInt(ITTageCtrBits.W), wrBypassEntries, log2Ceil(nRows), tagWidth=tagLen)) + wrbypass.io.wen := io.update.valid + wrbypass.io.write_idx := update_idx + wrbypass.io.write_tag.map(_ := update_tag) + wrbypass.io.write_data.map(_ := update_wdata.ctr) - val old_ctr = Mux(wrbypass_hit, wrbypass_ctrs(wrbypass_hit_idx), io.update.oldCtr) + val old_ctr = Mux(wrbypass.io.hit, wrbypass.io.hit_data(0).bits, io.update.oldCtr) update_wdata.ctr := Mux(io.update.alloc, 2.U, inc_ctr(old_ctr, io.update.correct)) update_wdata.valid := true.B update_wdata.tag := update_tag @@ -352,22 +314,6 @@ class ITTageTable update_hi_wdata := io.update.u(1) update_lo_wdata := io.update.u(0) - when (io.update.valid) { - when (wrbypass_hit) { - wrbypass_ctrs(wrbypass_hit_idx) := update_wdata.ctr - } .otherwise { - wrbypass_ctrs(wrbypass_enq_idx) := update_wdata.ctr - } - } - - when (io.update.valid && !wrbypass_hit) { - wrbypass_tags(wrbypass_enq_idx) := update_tag - wrbypass_idxs(wrbypass_enq_idx) := update_idx - wrbypass_enq_idx := (wrbypass_enq_idx + 1.U)(log2Ceil(wrBypassEntries)-1,0) - } - - XSPerfAccumulate("ittage_table_wrbypass_hit", io.update.valid && wrbypass_hit) - XSPerfAccumulate("ittage_table_wrbypass_enq", io.update.valid && !wrbypass_hit) XSPerfAccumulate("ittage_table_hits", io.resp.valid) if (BPUDebug && debug) { @@ -388,20 +334,8 @@ class ITTageTable p"update ITTAGE Table: writing tag:${update_tag}, " + p"ctr: ${update_wdata.ctr}, target:${Hexadecimal(update_wdata.target)}" + p" in idx $update_idx\n") - val hitCtr = wrbypass_ctrs(wrbypass_hit_idx) - XSDebug(wrbypass_hit && io.update.valid, - p"wrbypass hit wridx:$wrbypass_hit_idx, idx:$update_idx, tag: $update_tag, " + - p"ctr:$hitCtr, newCtr:${update_wdata.ctr}\n") - XSDebug(RegNext(io.req.valid) && !s1_req_rhit, 
"TageTableResp: no hits!\n") - // when (wrbypass_rhit && wrbypass_ctr_valids(wrbypass_rhit_idx).reduce(_||_)) { - // for (b <- 0 until TageBanks) { - // XSDebug(wrbypass_ctr_valids(wrbypass_rhit_idx)(b), - // "wrbypass rhits, wridx:%d, tag:%x, idx:%d, hitctr:%d, bank:%d\n", - // wrbypass_rhit_idx, tag, idx, wrbypass_ctrs(wrbypass_rhit_idx)(b), b.U) - // } - // } // ------------------------------Debug------------------------------------- val valids = RegInit(0.U.asTypeOf(Vec(nRows, Bool()))) diff --git a/src/main/scala/xiangshan/frontend/SC.scala b/src/main/scala/xiangshan/frontend/SC.scala index afa844eef5..3e1738e8b7 100644 --- a/src/main/scala/xiangshan/frontend/SC.scala +++ b/src/main/scala/xiangshan/frontend/SC.scala @@ -116,67 +116,19 @@ class SCTable(val nRows: Int, val ctrBits: Int, val histLen: Int)(implicit p: Pa val wrBypassEntries = 4 - class SCWrBypass extends XSModule { - val io = IO(new Bundle { - val wen = Input(Bool()) - val update_idx = Input(UInt(log2Ceil(nRows).W)) - val update_ctrs = Flipped(ValidIO(SInt(ctrBits.W))) - val update_ctrPos = Input(UInt(log2Ceil(2).W)) - val update_altPos = Input(UInt(log2Ceil(2).W)) - - val hit = Output(Bool()) - val ctrs = Vec(2, ValidIO(SInt(ctrBits.W))) - }) - - val idxes = RegInit(0.U.asTypeOf(Vec(wrBypassEntries, UInt(log2Ceil(nRows).W)))) - val ctrs = RegInit(0.U.asTypeOf(Vec(wrBypassEntries, Vec(2, SInt(ctrBits.W))))) - val ctr_valids = RegInit(0.U.asTypeOf(Vec(wrBypassEntries, Vec(2, Bool())))) - val enq_idx = RegInit(0.U(log2Ceil(wrBypassEntries).W)) - - val hits = VecInit((0 until wrBypassEntries).map { i => idxes(i) === io.update_idx }) - - val hit = hits.reduce(_||_) - val hit_idx = ParallelPriorityEncoder(hits) - - io.hit := hit - - for (i <- 0 until 2) { - io.ctrs(i).valid := ctr_valids(hit_idx)(i) - io.ctrs(i).bits := ctrs(hit_idx)(i) - } - - when (io.wen) { - when (hit) { - ctrs(hit_idx)(io.update_ctrPos) := io.update_ctrs.bits - ctr_valids(hit_idx)(io.update_ctrPos) := io.update_ctrs.valid - 
}.otherwise { - ctr_valids(enq_idx)(io.update_altPos) := false.B - ctr_valids(enq_idx)(io.update_ctrPos) := io.update_ctrs.valid - ctrs(enq_idx)(io.update_ctrPos) := io.update_ctrs.bits - } - } - - when(io.wen && !hit) { - idxes(enq_idx) := io.update_idx - enq_idx := (enq_idx + 1.U)(log2Ceil(wrBypassEntries)-1, 0) - } - } - - val wrbypass = Module(new SCWrBypass) + val wrbypass = Module(new WrBypass(SInt(ctrBits.W), wrBypassEntries, log2Ceil(nRows), numWays=2)) val ctrPos = io.update.tagePred val altPos = !io.update.tagePred - val bypass_ctr = wrbypass.io.ctrs(ctrPos) + val bypass_ctr = wrbypass.io.hit_data(ctrPos) val hit_and_valid = wrbypass.io.hit && bypass_ctr.valid val oldCtr = Mux(hit_and_valid, bypass_ctr.bits, io.update.oldCtr) update_wdata := ctrUpdate(oldCtr, io.update.taken) wrbypass.io.wen := io.update.mask - wrbypass.io.update_ctrs.valid := io.update.mask - wrbypass.io.update_ctrs.bits := update_wdata - wrbypass.io.update_idx := update_idx - wrbypass.io.update_ctrPos := ctrPos - wrbypass.io.update_altPos := altPos + wrbypass.io.write_data.map(_ := update_wdata) // only one of them are used + wrbypass.io.write_idx := update_idx + wrbypass.io.write_way_mask.map(_ := UIntToOH(ctrPos).asTypeOf(Vec(2, Bool()))) val u = io.update XSDebug(io.req.valid, @@ -188,12 +140,6 @@ class SCTable(val nRows: Int, val ctrBits: Int, val histLen: Int)(implicit p: Pa XSDebug(io.update.mask, p"update Table: pc:${Hexadecimal(u.pc)}, " + p"tageTaken:${u.tagePred}, taken:${u.taken}, oldCtr:${u.oldCtr}\n") - val updateCtrPos = io.update.tagePred - val hitCtr = wrbypass.io.ctrs(updateCtrPos).bits - XSDebug(wrbypass.io.hit && wrbypass.io.ctrs(updateCtrPos).valid && io.update.mask, - p"wrbypass hit idx:$update_idx, ctr:$hitCtr, " + - p"taken:${io.update.taken} newCtr:${update_wdata}\n") - } class SCThreshold(val ctrBits: Int = 6)(implicit p: Parameters) extends SCBundle { diff --git a/src/main/scala/xiangshan/frontend/Tage.scala b/src/main/scala/xiangshan/frontend/Tage.scala index 
3f4dba8a78..778c29eacb 100644 --- a/src/main/scala/xiangshan/frontend/Tage.scala +++ b/src/main/scala/xiangshan/frontend/Tage.scala @@ -157,7 +157,6 @@ class TageBTable(implicit p: Parameters) extends XSModule with TBTParams{ val s1_read = bt.io.r.resp.data - //io.s1_cnt := Cat((0 until numBr reverse).map(i => s1_read(i)(1,0))).asUInt() io.s1_cnt := bt.io.r.resp.data // Update logic @@ -165,22 +164,23 @@ class TageBTable(implicit p: Parameters) extends XSModule with TBTParams{ val update = io.update.bits val u_idx = bimAddr.getIdx(update.pc) + val need_to_update = VecInit((0 until numBr).map(i => u_valid && update.ftb_entry.brValids(i)/* && update_mask(i)*/)) + + val newCtrs = Wire(Vec(numBr, UInt(2.W))) - // Bypass logic - val wrbypass_ctrs = RegInit(0.U.asTypeOf(Vec(bypassEntries, Vec(numBr, UInt(2.W))))) - val wrbypass_ctr_valids = RegInit(0.U.asTypeOf(Vec(bypassEntries, Vec(numBr, Bool())))) - val wrbypass_idx = RegInit(0.U.asTypeOf(Vec(bypassEntries, UInt(log2Up(BtSize).W)))) - val wrbypass_enq_ptr = RegInit(0.U(log2Up(bypassEntries).W)) + val wrbypass = Module(new WrBypass(UInt(2.W), bypassEntries, log2Up(BtSize), numWays = numBr)) + wrbypass.io.wen := need_to_update.reduce(_||_) + wrbypass.io.write_idx := u_idx + wrbypass.io.write_data := newCtrs + wrbypass.io.write_way_mask.map(_ := need_to_update) - val wrbypass_hits = VecInit((0 until bypassEntries).map(i => - !doing_reset && wrbypass_idx(i) === u_idx)) - val wrbypass_hit = wrbypass_hits.reduce(_||_) - val wrbypass_hit_idx = PriorityEncoder(wrbypass_hits) - val oldCtrs = VecInit((0 until numBr).map(i => - Mux(wrbypass_hit && wrbypass_ctr_valids(wrbypass_hit_idx)(i), - wrbypass_ctrs(wrbypass_hit_idx)(i), io.update_cnt(i)))) - //wrbypass_ctrs(wrbypass_hit_idx)(i), update.meta(2*i+1, 2*i)))) + val oldCtrs = + VecInit((0 until numBr).map(i => + Mux(wrbypass.io.hit && wrbypass.io.hit_data(i).valid, + wrbypass.io.hit_data(i).bits, + io.update_cnt(i)) + )) def satUpdate(old: UInt, len: Int, taken: Bool): UInt = 
{ val oldSatTaken = old === ((1 << len)-1).U @@ -191,37 +191,10 @@ class TageBTable(implicit p: Parameters) extends XSModule with TBTParams{ } val newTakens = update.preds.br_taken_mask - val newCtrs = VecInit((0 until numBr).map(i => + newCtrs := VecInit((0 until numBr).map(i => satUpdate(oldCtrs(i), 2, newTakens(i)) )) -// val update_mask = LowerMask(PriorityEncoderOH(update.preds.taken_mask.asUInt)) - val need_to_update = VecInit((0 until numBr).map(i => u_valid && update.ftb_entry.brValids(i)/* && update_mask(i)*/)) - - when (reset.asBool) { wrbypass_ctr_valids.foreach(_ := VecInit(Seq.fill(numBr)(false.B)))} - - for (i <- 0 until numBr) { - when(need_to_update.reduce(_||_)) { - when(wrbypass_hit) { - when(need_to_update(i)) { - wrbypass_ctrs(wrbypass_hit_idx)(i) := newCtrs(i) - wrbypass_ctr_valids(wrbypass_hit_idx)(i) := true.B - } - }.otherwise { - wrbypass_ctr_valids(wrbypass_enq_ptr)(i) := false.B - when(need_to_update(i)) { - wrbypass_ctrs(wrbypass_enq_ptr)(i) := newCtrs(i) - wrbypass_ctr_valids(wrbypass_enq_ptr)(i) := true.B - } - } - } - } - - when (need_to_update.reduce(_||_) && !wrbypass_hit) { - wrbypass_idx(wrbypass_enq_ptr) := u_idx - wrbypass_enq_ptr := (wrbypass_enq_ptr + 1.U)(log2Up(bypassEntries)-1, 0) - } - bt.io.w.apply( valid = need_to_update.asUInt.orR || doing_reset, data = Mux(doing_reset, VecInit(Seq.fill(numBr)(2.U(2.W))), newCtrs), @@ -243,10 +216,6 @@ class TageTable val resp = Output(Valid(new TageResp)) val update = Input(new TageUpdate) }) - // val folded_hist = Wire(new FoldedHistory(histLen, log2Ceil(nRows), numBr)) - // // val folded_tag_hist = Wire(new FoldedHistory(histLen, tagLen, numBr)) - // def zeros = VecInit(false.B, false.B) - // folded_hist.update(zeros, zeros, 0.U(64.W), 0.U(6.W)) // bypass entries for tage update val wrBypassEntries = 8 val phistLen = if (PathHistoryLength > histLen) histLen else PathHistoryLength @@ -387,59 +356,15 @@ class TageTable waymask = true.B ) - - class WrBypass extends XSModule { - val io = 
IO(new Bundle { - val wen = Input(Bool()) - val update_idx = Input(UInt(log2Ceil(nRows).W)) - val update_tag = Input(UInt(tagLen.W)) - val update_ctr = Input(UInt(TageCtrBits.W)) - - val hit = Output(Bool()) - val ctr = Output(UInt(TageCtrBits.W)) - }) - - val tags = RegInit(0.U.asTypeOf(Vec(wrBypassEntries, UInt(tagLen.W)))) - val idxes = RegInit(0.U.asTypeOf(Vec(wrBypassEntries, UInt(log2Ceil(nRows).W)))) - val ctrs = RegInit(0.U.asTypeOf(Vec(wrBypassEntries, UInt(TageCtrBits.W)))) - val enq_idx = RegInit(0.U(log2Ceil(wrBypassEntries).W)) - - val hits = VecInit((0 until wrBypassEntries).map { i => - tags(i) === io.update_tag && idxes(i) === io.update_idx - }) - - val hit = hits.reduce(_||_) - val hit_idx = ParallelPriorityEncoder(hits) - - io.hit := hit - io.ctr := ctrs(hit_idx) - - when (io.wen) { - when (hit) { - ctrs(hit_idx) := io.update_ctr - }.otherwise { - ctrs(enq_idx) := io.update_ctr - } - } - - when(io.wen && !hit) { - tags(enq_idx) := io.update_tag - idxes(enq_idx) := io.update_idx - enq_idx := (enq_idx + 1.U)(log2Ceil(wrBypassEntries)-1, 0) - } - } - - val wrbypass = Module(new WrBypass) + val wrbypass = Module(new WrBypass(UInt(TageCtrBits.W), wrBypassEntries, log2Ceil(nRows), tagWidth=tagLen)) wrbypass.io.wen := io.update.mask - wrbypass.io.update_ctr := update_wdata.ctr + wrbypass.io.write_data.map(_ := update_wdata.ctr) update_wdata.ctr := Mux(io.update.alloc, - Mux(io.update.taken, 4.U, - 3.U - ), + Mux(io.update.taken, 4.U, 3.U), Mux(wrbypass.io.hit, - inc_ctr(wrbypass.io.ctr, io.update.taken), + inc_ctr(wrbypass.io.hit_data(0).bits, io.update.taken), inc_ctr(io.update.oldCtr, io.update.taken) ) ) @@ -449,8 +374,8 @@ class TageTable update_hi_wdata := io.update.u(1) update_lo_wdata := io.update.u(0) - wrbypass.io.update_idx := update_idx - wrbypass.io.update_tag := update_tag + wrbypass.io.write_idx := update_idx + wrbypass.io.write_tag.map(_ := update_tag) @@ -475,11 +400,6 @@ class TageTable XSDebug(io.update.mask, p"update Table: writing 
tag:$update_tag, " + p"ctr: ${update_wdata.ctr} in idx ${update_idx}\n") - val hitCtr = wrbypass.io.ctr - XSDebug(wrbypass.io.hit && io.update.mask, - // p"bank $i wrbypass hit wridx:$wrbypass_hit_idx, idx:$update_idx, tag: $update_tag, " + - p"ctr:$hitCtr, newCtr:${update_wdata.ctr}") - XSDebug(RegNext(io.req.valid) && !req_rhit, "TageTableResp: not hit!\n") // ------------------------------Debug------------------------------------- diff --git a/src/main/scala/xiangshan/frontend/WrBypass.scala b/src/main/scala/xiangshan/frontend/WrBypass.scala new file mode 100644 index 0000000000..908eca9c20 --- /dev/null +++ b/src/main/scala/xiangshan/frontend/WrBypass.scala @@ -0,0 +1,104 @@ +/*************************************************************************************** +* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences +* Copyright (c) 2020-2021 Peng Cheng Laboratory +* +* XiangShan is licensed under Mulan PSL v2. +* You can use this software according to the terms and conditions of the Mulan PSL v2. +* You may obtain a copy of Mulan PSL v2 at: +* http://license.coscl.org.cn/MulanPSL2 +* +* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +* +* See the Mulan PSL v2 for more details. 
+***************************************************************************************/
+package xiangshan.frontend
+
+import chipsalliance.rocketchip.config.Parameters
+import chisel3._
+import chisel3.util._
+import xiangshan._
+import utils._
+import chisel3.experimental.chiselName
+
+/** Small associative buffer holding the most recent table writes, looked up by write index/tag.
+  *
+  * Lookup is combinational on `write_idx`/`write_tag`. A hitting write overwrites only the ways selected
+  * by `write_way_mask`; a missing write allocates the entry at `enq_ptr`, which then advances by one.
+  * @param gen        type of the data stored in each way
+  * @param numEntries number of bypass entries, must be positive
+  * @param idxWidth   width of the matched index
+  * @param numWays    ways per entry; `write_way_mask` is only present when this is greater than 1
+  * @param tagWidth   width of the matched tag; `write_tag` is only present when this is greater than 0
+  */
+class WrBypass[T <: Data](gen: T, val numEntries: Int, val idxWidth: Int,
+  val numWays: Int = 1, val tagWidth: Int = 0)(implicit p: Parameters) extends XSModule {
+  // `hits.reduce` below cannot be evaluated without entries, so zero entries is rejected up front.
+  require(numEntries > 0, "WrBypass needs at least one entry")
+  require(idxWidth > 0)
+  require(numWays >= 1)
+  require(tagWidth >= 0)
+  def hasTag = tagWidth > 0
+  def multipleWays = numWays > 1
+  val io = IO(new Bundle {
+    val wen = Input(Bool())
+    val write_idx = Input(UInt(idxWidth.W))
+    val write_tag = if (hasTag) Some(Input(UInt(tagWidth.W))) else None
+    val write_data = Input(Vec(numWays, gen))
+    val write_way_mask = if (multipleWays) Some(Input(Vec(numWays, Bool()))) else None
+    // hit: a written entry matches the index/tag; hit_data(i).valid: way i of that entry holds data
+    val hit = Output(Bool())
+    val hit_data = Output(Vec(numWays, Valid(gen)))
+  })
+
+  class WrBypassPtr extends CircularQueuePtr[WrBypassPtr](numEntries){
+    override def cloneType = (new WrBypassPtr).asInstanceOf[this.type]
+  }
+
+  val tags = RegInit(0.U.asTypeOf((Vec(numEntries, UInt(tagWidth.W)))))
+  val idxes = RegInit(0.U.asTypeOf((Vec(numEntries, UInt(idxWidth.W)))))
+  val datas = RegInit(0.U.asTypeOf(Vec(numEntries, Vec(numWays, gen))))
+  val valids = RegInit(0.U.asTypeOf(Vec(numEntries, Vec(numWays, Bool()))))
+  val enq_ptr = RegInit(0.U.asTypeOf(new WrBypassPtr))
+  val enq_idx = enq_ptr.value
+
+  val wr_tag = io.write_tag.getOrElse(0.U)
+  // Without a mask port (single way) every write targets way 0.
+  val way_wen = io.write_way_mask.getOrElse(VecInit(Seq.fill(numWays)(true.B)))
+  val hits = VecInit((0 until numEntries).map(i => idxes(i) === io.write_idx && tags(i) === wr_tag))
+  val hit = hits.reduce(_||_)
+  val hit_idx = ParallelPriorityEncoder(hits)
+
+  // idxes/tags reset to zero, so index 0 with tag 0 also matches entries that were never written.
+  // Such an entry has no valid way and must not be reported as a hit on its reset data.
+  io.hit := hit && valids(hit_idx).asUInt.orR
+  for (i <- 0 until numWays) {
+    io.hit_data(i).valid := valids(hit_idx)(i)
+    io.hit_data(i).bits := datas(hit_idx)(i)
+  }
+
+  for (i <- 0 until numWays) {
+    when (io.wen && hit && way_wen(i)) {
+      datas(hit_idx)(i) := io.write_data(i)
+      valids(hit_idx)(i) := true.B
+    }
+    when (io.wen && !hit) {
+      // A newly allocated entry must not keep ways left over from the entry it replaces.
+      valids(enq_idx)(i) := way_wen(i)
+      when (way_wen(i)) { datas(enq_idx)(i) := io.write_data(i) }
+    }
+  }
+  when (io.wen && !hit) {
+    idxes(enq_idx) := io.write_idx
+    tags(enq_idx) := wr_tag
+    enq_ptr := enq_ptr + 1.U
+  }
+
+  XSPerfAccumulate("wrbypass_hit", io.wen && hit)
+  XSPerfAccumulate("wrbypass_miss", io.wen && !hit)
+  XSDebug(io.wen && hit, p"wrbypass hit entry #${hit_idx}, idx ${io.write_idx}, " +
+    p"tag ${wr_tag}, data ${io.write_data}\n")
+  XSDebug(io.wen && !hit, p"wrbypass enq entry #${enq_idx}, idx ${io.write_idx}, " +
+    p"tag ${wr_tag}, data ${io.write_data}\n")
+}
\ No newline at end of file