[RISCV] Migrate zvqdotq reduce matching to use partial_reduce infrastructure #142212
Conversation
[RISCV] Migrate zvqdotq reduce matching to use partial_reduce infrastructure

This involves a slight codegen regression at the moment due to the issue described in 443cdd0, but this aligns the lowering paths for this case and makes it less likely that future bugs go undetected.
@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

Changes: This involves a codegen regression at the moment due to the issue described in 443cdd0, but this aligns the lowering paths for this case and makes it less likely that future bugs go undetected.

Full diff: https://github.com/llvm/llvm-project/pull/142212.diff

3 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b7fd0c93fa93f..ff2500b8adbc2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18338,17 +18338,15 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
InVec.getOpcode() == ISD::SIGN_EXTEND) {
SDValue A = InVec.getOperand(0);
- if (A.getValueType().getVectorElementType() != MVT::i8 ||
- !TLI.isTypeLegal(A.getValueType()))
+ EVT OpVT = A.getValueType();
+ if (OpVT.getVectorElementType() != MVT::i8 || !TLI.isTypeLegal(OpVT))
return SDValue();
MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
- A = DAG.getBitcast(ResVT, A);
- SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
-
+ SDValue B = DAG.getConstant(0x1, DL, OpVT);
bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
- unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
- return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+ unsigned Opc = IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A, B});
}
// mul (sext, sext) -> vqdot
@@ -18362,32 +18360,38 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
SDValue A = InVec.getOperand(0);
SDValue B = InVec.getOperand(1);
- unsigned Opc = 0;
+
+ if (!ISD::isExtOpcode(A.getOpcode()))
+ return SDValue();
+
+ EVT OpVT = A.getOperand(0).getValueType();
+ if (OpVT.getVectorElementType() != MVT::i8 ||
+ OpVT != B.getOperand(0).getValueType() ||
+ !TLI.isTypeLegal(A.getValueType()))
+ return SDValue();
+
+ MVT ResVT = getQDOTXResultType(OpVT.getSimpleVT());
+ // Use the partial_reduce_*mla path if possible
if (A.getOpcode() == B.getOpcode()) {
- if (A.getOpcode() == ISD::SIGN_EXTEND)
- Opc = RISCVISD::VQDOT_VL;
- else if (A.getOpcode() == ISD::ZERO_EXTEND)
- Opc = RISCVISD::VQDOTU_VL;
- else
- return SDValue();
- } else {
- if (B.getOpcode() != ISD::ZERO_EXTEND)
- std::swap(A, B);
- if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
+ // TODO: handle ANY_EXTEND and zext nonneg here
+ if (A.getOpcode() != ISD::SIGN_EXTEND &&
+ A.getOpcode() != ISD::ZERO_EXTEND)
return SDValue();
- Opc = RISCVISD::VQDOTSU_VL;
- }
- assert(Opc);
- if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
- A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
- !TLI.isTypeLegal(A.getValueType()))
+ bool IsSigned = A.getOpcode() == ISD::SIGN_EXTEND;
+ unsigned Opc = IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
+ }
+ // We don't yet have a partial_reduce_sumla node, so directly lower to the
+ // target node instead.
+ if (B.getOpcode() != ISD::ZERO_EXTEND)
+ std::swap(A, B);
+ if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
return SDValue();
- MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
A = DAG.getBitcast(ResVT, A.getOperand(0));
B = DAG.getBitcast(ResVT, B.getOperand(0));
- return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+ return lowerVQDOT(RISCVISD::VQDOTSU_VL, A, B, DL, DAG, Subtarget);
}
static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 0237faea9efb7..8ef691622415c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1,8 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
define i32 @vqdot_vv(<16 x i8> %a, <16 x i8> %b) {
; NODOT-LABEL: vqdot_vv:
@@ -230,29 +230,17 @@ define i32 @reduce_of_sext(<16 x i8> %a) {
; NODOT-NEXT: vmv.x.s a0, v8
; NODOT-NEXT: ret
;
-; DOT32-LABEL: reduce_of_sext:
-; DOT32: # %bb.0: # %entry
-; DOT32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT32-NEXT: vmv.v.i v9, 0
-; DOT32-NEXT: lui a0, 4112
-; DOT32-NEXT: addi a0, a0, 257
-; DOT32-NEXT: vqdot.vx v9, v8, a0
-; DOT32-NEXT: vmv.s.x v8, zero
-; DOT32-NEXT: vredsum.vs v8, v9, v8
-; DOT32-NEXT: vmv.x.s a0, v8
-; DOT32-NEXT: ret
-;
-; DOT64-LABEL: reduce_of_sext:
-; DOT64: # %bb.0: # %entry
-; DOT64-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT64-NEXT: vmv.v.i v9, 0
-; DOT64-NEXT: lui a0, 4112
-; DOT64-NEXT: addiw a0, a0, 257
-; DOT64-NEXT: vqdot.vx v9, v8, a0
-; DOT64-NEXT: vmv.s.x v8, zero
-; DOT64-NEXT: vredsum.vs v8, v9, v8
-; DOT64-NEXT: vmv.x.s a0, v8
-; DOT64-NEXT: ret
+; DOT-LABEL: reduce_of_sext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT: vmv.v.i v9, 1
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.ext = sext <16 x i8> %a to <16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
@@ -269,29 +257,17 @@ define i32 @reduce_of_zext(<16 x i8> %a) {
; NODOT-NEXT: vmv.x.s a0, v8
; NODOT-NEXT: ret
;
-; DOT32-LABEL: reduce_of_zext:
-; DOT32: # %bb.0: # %entry
-; DOT32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT32-NEXT: vmv.v.i v9, 0
-; DOT32-NEXT: lui a0, 4112
-; DOT32-NEXT: addi a0, a0, 257
-; DOT32-NEXT: vqdotu.vx v9, v8, a0
-; DOT32-NEXT: vmv.s.x v8, zero
-; DOT32-NEXT: vredsum.vs v8, v9, v8
-; DOT32-NEXT: vmv.x.s a0, v8
-; DOT32-NEXT: ret
-;
-; DOT64-LABEL: reduce_of_zext:
-; DOT64: # %bb.0: # %entry
-; DOT64-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT64-NEXT: vmv.v.i v9, 0
-; DOT64-NEXT: lui a0, 4112
-; DOT64-NEXT: addiw a0, a0, 257
-; DOT64-NEXT: vqdotu.vx v9, v8, a0
-; DOT64-NEXT: vmv.s.x v8, zero
-; DOT64-NEXT: vredsum.vs v8, v9, v8
-; DOT64-NEXT: vmv.x.s a0, v8
-; DOT64-NEXT: ret
+; DOT-LABEL: reduce_of_zext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT: vmv.v.i v9, 1
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotu.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.ext = zext <16 x i8> %a to <16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index d0fc915a0d07e..1948904493e8f 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -1,8 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
define i32 @vqdot_vv(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; NODOT-LABEL: vqdot_vv:
@@ -230,29 +230,17 @@ define i32 @reduce_of_sext(<vscale x 16 x i8> %a) {
; NODOT-NEXT: vmv.x.s a0, v8
; NODOT-NEXT: ret
;
-; DOT32-LABEL: reduce_of_sext:
-; DOT32: # %bb.0: # %entry
-; DOT32-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT32-NEXT: vmv.v.i v10, 0
-; DOT32-NEXT: lui a0, 4112
-; DOT32-NEXT: addi a0, a0, 257
-; DOT32-NEXT: vqdot.vx v10, v8, a0
-; DOT32-NEXT: vmv.s.x v8, zero
-; DOT32-NEXT: vredsum.vs v8, v10, v8
-; DOT32-NEXT: vmv.x.s a0, v8
-; DOT32-NEXT: ret
-;
-; DOT64-LABEL: reduce_of_sext:
-; DOT64: # %bb.0: # %entry
-; DOT64-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT64-NEXT: vmv.v.i v10, 0
-; DOT64-NEXT: lui a0, 4112
-; DOT64-NEXT: addiw a0, a0, 257
-; DOT64-NEXT: vqdot.vx v10, v8, a0
-; DOT64-NEXT: vmv.s.x v8, zero
-; DOT64-NEXT: vredsum.vs v8, v10, v8
-; DOT64-NEXT: vmv.x.s a0, v8
-; DOT64-NEXT: ret
+; DOT-LABEL: reduce_of_sext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 1
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdot.vv v12, v8, v10
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %a.ext)
@@ -269,29 +257,17 @@ define i32 @reduce_of_zext(<vscale x 16 x i8> %a) {
; NODOT-NEXT: vmv.x.s a0, v8
; NODOT-NEXT: ret
;
-; DOT32-LABEL: reduce_of_zext:
-; DOT32: # %bb.0: # %entry
-; DOT32-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT32-NEXT: vmv.v.i v10, 0
-; DOT32-NEXT: lui a0, 4112
-; DOT32-NEXT: addi a0, a0, 257
-; DOT32-NEXT: vqdotu.vx v10, v8, a0
-; DOT32-NEXT: vmv.s.x v8, zero
-; DOT32-NEXT: vredsum.vs v8, v10, v8
-; DOT32-NEXT: vmv.x.s a0, v8
-; DOT32-NEXT: ret
-;
-; DOT64-LABEL: reduce_of_zext:
-; DOT64: # %bb.0: # %entry
-; DOT64-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT64-NEXT: vmv.v.i v10, 0
-; DOT64-NEXT: lui a0, 4112
-; DOT64-NEXT: addiw a0, a0, 257
-; DOT64-NEXT: vqdotu.vx v10, v8, a0
-; DOT64-NEXT: vmv.s.x v8, zero
-; DOT64-NEXT: vredsum.vs v8, v10, v8
-; DOT64-NEXT: vmv.x.s a0, v8
-; DOT64-NEXT: ret
+; DOT-LABEL: reduce_of_zext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 1
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdotu.vv v12, v8, v10
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %a.ext)
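To make the diff above easier to follow, here is the IR-level pattern this combine targets, taken from the reduce_of_sext test in fixed-vectors-zvqdotq.ll (a minimal sketch; the trailing declare is added here only so the snippet is self-contained):

```llvm
; Reduce of a sign-extended <16 x i8>. With this patch and
; +experimental-zvqdotq, the extend operand is rewritten during DAG combine to
; PARTIAL_REDUCE_SMLA(0, %a, splat(1)), which then lowers to the
; vmv.v.i/vqdot.vv/vredsum.vs sequence shown in the DOT check lines above.
define i32 @reduce_of_sext(<16 x i8> %a) {
entry:
  %a.ext = sext <16 x i8> %a to <16 x i32>
  %res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
  ret i32 %res
}

declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)
```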
✅ With the latest revision this PR passed the C/C++ code formatter.
@@ -18455,61 +18430,62 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
  }
}

// reduce (zext a) <--> reduce (mul zext a. zext 1)
// reduce (sext a) <--> reduce (mul sext a. sext 1)
// reduce (zext a) <--> partial_reduce_umla 0, a, 1
There's still an outer reduce after the partial_reduce, right?
Yep, let me adjust the comment.
Actually, staring at it, I'm not sure how to make the comment much clearer. The reduce here isn't the outermost reduce - we might be well into an interior add in the reduce tree. Any suggestions?
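For concreteness, a hedged sketch of the shape being described (illustrative IR only; the value names are made up): the extend being folded can be an operand of an interior add of the reduce tree rather than the direct operand of the final reduce.

```llvm
; Illustrative only: the folded extend feeds an interior add, and the
; vecreduce sits further up the tree.
%a.ext    = sext <16 x i8> %a to <16 x i32>
%b.ext    = sext <16 x i8> %b to <16 x i32>
%interior = add <16 x i32> %a.ext, %b.ext   ; fold may fire on %a.ext or %b.ext
%sum      = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %interior)
```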
You can drop the `reduce`. It's not in the comments for `mul`.
LGTM