Skip to content

Commit

Permalink
backend: add 3-bit shift fused instructions (#1022)
Browse files Browse the repository at this point in the history
This commit adds 3-bit shift fused instructions. When the program
tries to add 8-byte index, these may be used.

List of fused instructions added in this commit:

* szewl3: `slli r1, r0, 32` + `srli r1, r0, 29`

* sr29add: `srli r1, r0, 29` + `add r1, r1, r2`
  • Loading branch information
poemonsense committed Sep 12, 2021
1 parent 59a7cc9 commit a792bcf
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 39 deletions.
41 changes: 41 additions & 0 deletions src/main/scala/xiangshan/backend/decode/FusionDecoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,25 @@ class FusedSzewl2(pair: Seq[Valid[UInt]])(implicit p: Parameters) extends BaseFu
def fusionName: String = "slli32_srli30"
}

// Case: shift zero-extended word left by three
// Source: `slli r1, r0, 32` + `srli r1, r0, 29`
// Target: `szewl3 r1, r0` (customized internal opcode)
class FusedSzewl3(pair: Seq[Valid[UInt]])(implicit p: Parameters) extends BaseFusionCase(pair) {
def inst1Cond = instr(0) === Instructions.SLLI && instr(0)(25, 20) === 32.U
def inst2Cond = instr(1) === Instructions.SRLI && instr(1)(25, 20) === 29.U

def isValid: Bool = inst1Cond && inst2Cond && withSameDest && destToRs1
def target: CtrlSignals = {
val cs = getBaseCS(Instructions.ZEXT_H)
// replace the fuOpType with szewl3
cs.fuOpType := ALUOpType.szewl3
cs.lsrc(0) := instr1Rs1
cs
}

def fusionName: String = "slli32_srli29"
}

// Case: get the second byte
// Source: `srli r1, r0, 8` + `andi r1, r1, 255`
// Target: `byte2 r1, r0` (customized internal opcode)
Expand Down Expand Up @@ -251,6 +270,26 @@ class FusedSh4add(pair: Seq[Valid[UInt]])(implicit p: Parameters) extends BaseFu
def fusionName: String = "slli4_add"
}

// Case: shift right by 29 and add
// Source: `srli r1, r0, 29` + `add r1, r1, r2`
// Target: `sr29add r1, r0, r2` (customized internal opcode)
class FusedSr29add(pair: Seq[Valid[UInt]])(implicit p: Parameters) extends BaseFusionCase(pair) {
def inst1Cond = instr(0) === Instructions.SRLI && instr(0)(25, 20) === 29.U
def inst2Cond = instr(1) === Instructions.ADD

def isValid: Bool = inst1Cond && inst2Cond && withSameDest && (destToRs1 || destToRs2)
def target: CtrlSignals = {
val cs = getBaseCS(Instructions.SH3ADD)
// replace the fuOpType with sr29add
cs.fuOpType := ALUOpType.sr29add
cs.lsrc(0) := instr1Rs1
cs.lsrc(1) := Mux(destToRs1, instr2Rs2, instr2Rs1)
cs
}

def fusionName: String = "srli29_add"
}

// Case: shift right by 30 and add
// Source: `srli r1, r0, 30` + `add r1, r1, r2`
// Target: `sr30add r1, r0, r2` (customized internal opcode)
Expand Down Expand Up @@ -473,8 +512,10 @@ class FusionDecoder(implicit p: Parameters) extends XSModule {
new FusedSh3add(pair),
new FusedSzewl1(pair),
new FusedSzewl2(pair),
new FusedSzewl3(pair),
new FusedByte2(pair),
new FusedSh4add(pair),
new FusedSr29add(pair),
new FusedSr30add(pair),
new FusedSr31add(pair),
new FusedSr32add(pair),
Expand Down
48 changes: 24 additions & 24 deletions src/main/scala/xiangshan/backend/fu/Alu.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,19 @@ class MiscResultSelect(implicit p: Parameters) extends XSModule {
ALUOpType.or -> io.or,
ALUOpType.xnor -> io.xnor,
ALUOpType.xor -> io.xor,
ALUOpType.orh48 -> io.orh48
ALUOpType.orh48 -> io.orh48,
ALUOpType.orc_b -> io.orcb
).map(x => (x._1(2, 0) === io.func(2, 0), x._2)))
val maskedLogicRes = Cat(Fill(63, ~io.func(3)), 1.U(1.W)) & logicResSel

val miscRes = ParallelMux(List(
ALUOpType.sext_b -> io.sextb,
ALUOpType.sext_h -> io.sexth,
ALUOpType.zext_h -> io.zexth,
ALUOpType.orc_b -> io.orcb,
ALUOpType.rev8 -> io.rev8,
ALUOpType.szewl1 -> Cat(0.U(31.W), io.src(31, 0), 0.U(1.W)),
ALUOpType.szewl2 -> Cat(0.U(30.W), io.src(31, 0), 0.U(2.W)),
ALUOpType.szewl3 -> Cat(0.U(29.W), io.src(31, 0), 0.U(3.W)),
ALUOpType.byte2 -> Cat(0.U(56.W), io.src(15, 8))
).map(x => (x._1(2, 0) === io.func(2, 0), x._2)))

Expand Down Expand Up @@ -187,30 +188,29 @@ class AluDataModule(implicit p: Parameters) extends XSModule {
// For 64-bit adder:
// BITS(2, 1): shamt (0, 1, 2, 3)
// BITS(3 ): different fused cases
val shaddShamt = func(2,1)
addModule.io.src(0) := (Cat(Fill(32, func(0)), Fill(32,1.U)) & src1) << shaddShamt
addModule.io.src(1) := src2
val wordMaskAddSource = Cat(Fill(32, func(0)), Fill(32, 1.U)) & src1
val shaddSource = VecInit(Seq(
Cat(wordMaskAddSource(62, 0), 0.U(1.W)),
Cat(wordMaskAddSource(61, 0), 0.U(2.W)),
Cat(wordMaskAddSource(60, 0), 0.U(3.W)),
Cat(wordMaskAddSource(59, 0), 0.U(4.W))
))
val sraddSource = VecInit(Seq(
ZeroExt(src1(63, 29), XLEN),
ZeroExt(src1(63, 30), XLEN),
ZeroExt(src1(63, 31), XLEN),
ZeroExt(src1(63, 32), XLEN)
))
// TODO: use decoder or other libraries to optimize timing
when (func(4)) {
addModule.io.src(0) := ZeroExt(src1(0), XLEN)
}
when (func(3)) {
val sourceVec = VecInit(Seq(
Cat(src1(59, 0), 0.U(4.W)),
ZeroExt(src1(63, 30), XLEN),
ZeroExt(src1(63, 31), XLEN),
ZeroExt(src1(63, 32), XLEN)
))
addModule.io.src(0) := sourceVec(func(2, 1))
}
// Now we assume shadd has the worst timing.
addModule.io.src(0) := Mux(ALUOpType.isShAdd(func), shaddSource(func(2, 1)),
Mux(ALUOpType.isSrAdd(func), sraddSource(func(2, 1)),
Mux(ALUOpType.isAddOddBit(func), ZeroExt(src1(0), XLEN), wordMaskAddSource))
)
addModule.io.src(1) := src2
val add = addModule.io.add
// For 32-bit adder:
// BITS(4 ): different fused cases
// BITS(2, 1): result mask (ffffffff, 0x1, 0xff)
addModule.io.srcw := src1(31,0)
when (func(4)) {
addModule.io.srcw := ZeroExt(src1(0), XLEN)
}
// For 32-bit adder: its source comes from lower 32bits or lowest bit.
addModule.io.srcw := Mux(ALUOpType.isAddOddBit(func), ZeroExt(src1(0), XLEN), src1(31,0))
val byteMask = Cat(Fill(56, ~func(1)), 0xff.U(8.W))
val bitMask = Cat(Fill(63, ~func(2)), 0x1.U(1.W))
val addw = addModule.io.addw & byteMask & bitMask
Expand Down
36 changes: 21 additions & 15 deletions src/main/scala/xiangshan/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ package object xiangshan {
def xor = "b0_00_00_100".U
def xnor = "b0_00_00_101".U
def orh48 = "b0_00_00_110".U
def orc_b = "b0_00_00_111".U

def andlsb = "b0_00_11_000".U
def andnlsb = "b0_00_11_001".U
Expand All @@ -225,13 +226,11 @@ package object xiangshan {
def sext_b = "b0_00_01_000".U
def sext_h = "b0_00_01_001".U
def zext_h = "b0_00_01_010".U
def rev8 = "b0_00_01_011".U
// TOOD: optimize it
def szewl1 = "b0_00_01_011".U
def orc_b = "b0_00_01_100".U
def rev8 = "b0_00_01_101".U
// TOOD: optimize it
def szewl2 = "b0_00_01_110".U
// TOOD: optimize it
def szewl1 = "b0_00_01_100".U
def szewl2 = "b0_00_01_101".U
def szewl3 = "b0_00_01_110".U
def byte2 = "b0_00_01_111".U

def beq = "b0_00_10_000".U
Expand All @@ -244,14 +243,18 @@ package object xiangshan {
// add & sub optype
def add_uw = "b0_01_00_000".U
def add = "b0_01_00_001".U
def oddadd = "b0_01_10_001".U
def sh1add_uw = "b0_01_00_010".U
def sh1add = "b0_01_00_011".U
def sh2add_uw = "b0_01_00_100".U
def sh2add = "b0_01_00_101".U
def sh3add_uw = "b0_01_00_110".U
def sh3add = "b0_01_00_111".U
def sh4add = "b0_01_01_001".U

def oddadd = "b0_01_11_001".U

def sh1add_uw = "b0_01_10_000".U
def sh1add = "b0_01_10_001".U
def sh2add_uw = "b0_01_10_010".U
def sh2add = "b0_01_10_011".U
def sh3add_uw = "b0_01_10_100".U
def sh3add = "b0_01_10_101".U
def sh4add = "b0_01_10_111".U

def sr29add = "b0_01_01_001".U
def sr30add = "b0_01_01_011".U
def sr31add = "b0_01_01_101".U
def sr32add = "b0_01_01_111".U
Expand Down Expand Up @@ -283,7 +286,7 @@ package object xiangshan {
def addw = "b1_01_00_001".U
def addwbyte = "b1_01_00_011".U
def addwbit = "b1_01_00_101".U
def oddaddw = "b1_01_10_001".U
def oddaddw = "b1_01_11_001".U
def subw = "b1_11_00_000".U
def sllw = "b1_10_00_000".U
def srlw = "b1_10_01_001".U
Expand All @@ -298,6 +301,9 @@ package object xiangshan {
def isBranch(func: UInt) = func(6, 3) === "b0010".U
def getBranchType(func: UInt) = func(2, 1)
def isBranchInvert(func: UInt) = func(0)
def isAddOddBit(func: UInt) = func(4, 3) === "b11".U(2.W)
def isShAdd(func: UInt) = func(4, 3) === "b10".U(2.W)
def isSrAdd(func: UInt) = func(4, 3) === "b01".U(2.W)

def apply() = UInt(8.W)
}
Expand Down

0 comments on commit a792bcf

Please sign in to comment.