[RISCV] Migrate zvqdotq reduce matching to use partial_reduce infrastructure #142212


Merged

Conversation

preames
Collaborator

@preames preames commented May 30, 2025

This involves a slight codegen regression at the moment due to the issue described in 443cdd0, but it aligns the lowering paths for this case and makes it less likely that future bugs go undetected.

@llvmbot
Member

llvmbot commented May 30, 2025

@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

Changes

This involves a slight codegen regression at the moment due to the issue described in 443cdd0, but it aligns the lowering paths for this case and makes it less likely that future bugs go undetected.
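For intuition, the equivalence this patch relies on can be sketched in Python. This is a rough model only: `partial_reduce_smla` below is an illustrative stand-in for the `ISD::PARTIAL_REDUCE_SMLA` node (not an LLVM API), modeling how 4-element signed dot products accumulate into i32 lanes, as `vqdot.vv` does per lane. Reducing a sign-extended i8 vector is the same as partial-reducing it against an all-ones vector and then reducing the narrower accumulator.

```python
def partial_reduce_smla(acc, a, b):
    # acc: list of n//4 i32 accumulator lanes; a, b: lists of n i8 values.
    # Each group of 4 products lands in one accumulator lane.
    out = list(acc)
    for i, (x, y) in enumerate(zip(a, b)):
        out[i // 4] += x * y  # signed (sign-extended) multiply-accumulate
    return out

a = [-3, 7, 100, -128, 5, 6, 7, 8]            # i8 inputs
wide_reduce = sum(a)                          # reduce(sext(a))
lanes = partial_reduce_smla([0, 0], a, [1] * len(a))  # b = splat(1)
assert sum(lanes) == wide_reduce              # an outer reduce is still needed
```

Note the outer `sum(lanes)` at the end: the partial reduce narrows the vector but does not replace the final reduction, which is what the review discussion below touches on.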


Full diff: https://github.com/llvm/llvm-project/pull/142212.diff

3 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+30-26)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll (+24-48)
  • (modified) llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll (+24-48)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b7fd0c93fa93f..ff2500b8adbc2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18338,17 +18338,15 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
   if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
       InVec.getOpcode() == ISD::SIGN_EXTEND) {
     SDValue A = InVec.getOperand(0);
-    if (A.getValueType().getVectorElementType() != MVT::i8 ||
-        !TLI.isTypeLegal(A.getValueType()))
+    EVT OpVT = A.getValueType();
+    if (OpVT.getVectorElementType() != MVT::i8 || !TLI.isTypeLegal(OpVT))
       return SDValue();
 
     MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
-    A = DAG.getBitcast(ResVT, A);
-    SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
-
+    SDValue B = DAG.getConstant(0x1, DL, OpVT);
     bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
-    unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
-    return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+    unsigned Opc = IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+    return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A, B});
   }
 
   // mul (sext, sext) -> vqdot
@@ -18362,32 +18360,38 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
 
   SDValue A = InVec.getOperand(0);
   SDValue B = InVec.getOperand(1);
-  unsigned Opc = 0;
+
+  if (!ISD::isExtOpcode(A.getOpcode()))
+    return SDValue();
+
+  EVT OpVT = A.getOperand(0).getValueType();
+  if (OpVT.getVectorElementType() != MVT::i8 ||
+      OpVT != B.getOperand(0).getValueType() ||
+      !TLI.isTypeLegal(A.getValueType()))
+    return SDValue();
+
+  MVT ResVT = getQDOTXResultType(OpVT.getSimpleVT());
+  // Use the partial_reduce_*mla path if possible
   if (A.getOpcode() == B.getOpcode()) {
-    if (A.getOpcode() == ISD::SIGN_EXTEND)
-      Opc = RISCVISD::VQDOT_VL;
-    else if (A.getOpcode() == ISD::ZERO_EXTEND)
-      Opc = RISCVISD::VQDOTU_VL;
-    else
-      return SDValue();
-  } else {
-    if (B.getOpcode() != ISD::ZERO_EXTEND)
-      std::swap(A, B);
-    if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
+    // TODO: handle ANY_EXTEND and zext nonneg here
+    if (A.getOpcode() != ISD::SIGN_EXTEND &&
+        A.getOpcode() != ISD::ZERO_EXTEND)
       return SDValue();
-    Opc = RISCVISD::VQDOTSU_VL;
-  }
-  assert(Opc);
 
-  if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
-      A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
-      !TLI.isTypeLegal(A.getValueType()))
+    bool IsSigned = A.getOpcode() == ISD::SIGN_EXTEND;
+    unsigned Opc = IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+    return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
+  }
+  // We don't yet have a partial_reduce_sumla node, so directly lower to the
+  // target node instead.
+  if (B.getOpcode() != ISD::ZERO_EXTEND)
+    std::swap(A, B);
+  if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
     return SDValue();
 
-  MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
   A = DAG.getBitcast(ResVT, A.getOperand(0));
   B = DAG.getBitcast(ResVT, B.getOperand(0));
-  return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+  return lowerVQDOT(RISCVISD::VQDOTSU_VL, A, B, DL, DAG, Subtarget);
 }
 
 static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 0237faea9efb7..8ef691622415c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1,8 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
 ; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
 
 define i32 @vqdot_vv(<16 x i8> %a, <16 x i8> %b) {
 ; NODOT-LABEL: vqdot_vv:
@@ -230,29 +230,17 @@ define i32 @reduce_of_sext(<16 x i8> %a) {
 ; NODOT-NEXT:    vmv.x.s a0, v8
 ; NODOT-NEXT:    ret
 ;
-; DOT32-LABEL: reduce_of_sext:
-; DOT32:       # %bb.0: # %entry
-; DOT32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT32-NEXT:    vmv.v.i v9, 0
-; DOT32-NEXT:    lui a0, 4112
-; DOT32-NEXT:    addi a0, a0, 257
-; DOT32-NEXT:    vqdot.vx v9, v8, a0
-; DOT32-NEXT:    vmv.s.x v8, zero
-; DOT32-NEXT:    vredsum.vs v8, v9, v8
-; DOT32-NEXT:    vmv.x.s a0, v8
-; DOT32-NEXT:    ret
-;
-; DOT64-LABEL: reduce_of_sext:
-; DOT64:       # %bb.0: # %entry
-; DOT64-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT64-NEXT:    vmv.v.i v9, 0
-; DOT64-NEXT:    lui a0, 4112
-; DOT64-NEXT:    addiw a0, a0, 257
-; DOT64-NEXT:    vqdot.vx v9, v8, a0
-; DOT64-NEXT:    vmv.s.x v8, zero
-; DOT64-NEXT:    vredsum.vs v8, v9, v8
-; DOT64-NEXT:    vmv.x.s a0, v8
-; DOT64-NEXT:    ret
+; DOT-LABEL: reduce_of_sext:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v9, 1
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdot.vv v10, v8, v9
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v10, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.ext = sext <16 x i8> %a to <16 x i32>
   %res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
@@ -269,29 +257,17 @@ define i32 @reduce_of_zext(<16 x i8> %a) {
 ; NODOT-NEXT:    vmv.x.s a0, v8
 ; NODOT-NEXT:    ret
 ;
-; DOT32-LABEL: reduce_of_zext:
-; DOT32:       # %bb.0: # %entry
-; DOT32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT32-NEXT:    vmv.v.i v9, 0
-; DOT32-NEXT:    lui a0, 4112
-; DOT32-NEXT:    addi a0, a0, 257
-; DOT32-NEXT:    vqdotu.vx v9, v8, a0
-; DOT32-NEXT:    vmv.s.x v8, zero
-; DOT32-NEXT:    vredsum.vs v8, v9, v8
-; DOT32-NEXT:    vmv.x.s a0, v8
-; DOT32-NEXT:    ret
-;
-; DOT64-LABEL: reduce_of_zext:
-; DOT64:       # %bb.0: # %entry
-; DOT64-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT64-NEXT:    vmv.v.i v9, 0
-; DOT64-NEXT:    lui a0, 4112
-; DOT64-NEXT:    addiw a0, a0, 257
-; DOT64-NEXT:    vqdotu.vx v9, v8, a0
-; DOT64-NEXT:    vmv.s.x v8, zero
-; DOT64-NEXT:    vredsum.vs v8, v9, v8
-; DOT64-NEXT:    vmv.x.s a0, v8
-; DOT64-NEXT:    ret
+; DOT-LABEL: reduce_of_zext:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v9, 1
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdotu.vv v10, v8, v9
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v10, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.ext = zext <16 x i8> %a to <16 x i32>
   %res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index d0fc915a0d07e..1948904493e8f 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -1,8 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
 ; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
 
 define i32 @vqdot_vv(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; NODOT-LABEL: vqdot_vv:
@@ -230,29 +230,17 @@ define i32 @reduce_of_sext(<vscale x 16 x i8> %a) {
 ; NODOT-NEXT:    vmv.x.s a0, v8
 ; NODOT-NEXT:    ret
 ;
-; DOT32-LABEL: reduce_of_sext:
-; DOT32:       # %bb.0: # %entry
-; DOT32-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
-; DOT32-NEXT:    vmv.v.i v10, 0
-; DOT32-NEXT:    lui a0, 4112
-; DOT32-NEXT:    addi a0, a0, 257
-; DOT32-NEXT:    vqdot.vx v10, v8, a0
-; DOT32-NEXT:    vmv.s.x v8, zero
-; DOT32-NEXT:    vredsum.vs v8, v10, v8
-; DOT32-NEXT:    vmv.x.s a0, v8
-; DOT32-NEXT:    ret
-;
-; DOT64-LABEL: reduce_of_sext:
-; DOT64:       # %bb.0: # %entry
-; DOT64-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
-; DOT64-NEXT:    vmv.v.i v10, 0
-; DOT64-NEXT:    lui a0, 4112
-; DOT64-NEXT:    addiw a0, a0, 257
-; DOT64-NEXT:    vqdot.vx v10, v8, a0
-; DOT64-NEXT:    vmv.s.x v8, zero
-; DOT64-NEXT:    vredsum.vs v8, v10, v8
-; DOT64-NEXT:    vmv.x.s a0, v8
-; DOT64-NEXT:    ret
+; DOT-LABEL: reduce_of_sext:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 1
+; DOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v12, 0
+; DOT-NEXT:    vqdot.vv v12, v8, v10
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %a.ext)
@@ -269,29 +257,17 @@ define i32 @reduce_of_zext(<vscale x 16 x i8> %a) {
 ; NODOT-NEXT:    vmv.x.s a0, v8
 ; NODOT-NEXT:    ret
 ;
-; DOT32-LABEL: reduce_of_zext:
-; DOT32:       # %bb.0: # %entry
-; DOT32-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
-; DOT32-NEXT:    vmv.v.i v10, 0
-; DOT32-NEXT:    lui a0, 4112
-; DOT32-NEXT:    addi a0, a0, 257
-; DOT32-NEXT:    vqdotu.vx v10, v8, a0
-; DOT32-NEXT:    vmv.s.x v8, zero
-; DOT32-NEXT:    vredsum.vs v8, v10, v8
-; DOT32-NEXT:    vmv.x.s a0, v8
-; DOT32-NEXT:    ret
-;
-; DOT64-LABEL: reduce_of_zext:
-; DOT64:       # %bb.0: # %entry
-; DOT64-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
-; DOT64-NEXT:    vmv.v.i v10, 0
-; DOT64-NEXT:    lui a0, 4112
-; DOT64-NEXT:    addiw a0, a0, 257
-; DOT64-NEXT:    vqdotu.vx v10, v8, a0
-; DOT64-NEXT:    vmv.s.x v8, zero
-; DOT64-NEXT:    vredsum.vs v8, v10, v8
-; DOT64-NEXT:    vmv.x.s a0, v8
-; DOT64-NEXT:    ret
+; DOT-LABEL: reduce_of_zext:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 1
+; DOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v12, 0
+; DOT-NEXT:    vqdotu.vv v12, v8, v10
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %a.ext)


github-actions bot commented May 30, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@@ -18455,61 +18430,62 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
}
}

-  // reduce (zext a) <--> reduce (mul zext a, zext 1)
-  // reduce (sext a) <--> reduce (mul sext a, sext 1)
+  // reduce (zext a) <--> partial_reduce_umla 0, a, 1
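A rough Python model of the unsigned form named in the new comment (illustrative only; `partial_reduce_umla` is a stand-in for the `ISD::PARTIAL_REDUCE_UMLA` node, not an LLVM API). The operands are treated as unsigned bytes, i.e. zero-extended before the multiply:

```python
def partial_reduce_umla(acc, a, b):
    # acc: i32 accumulator lanes, one per 4 input elements;
    # a, b: i8 element lists, treated as unsigned (zero-extended).
    out = list(acc)
    for i, (x, y) in enumerate(zip(a, b)):
        out[i // 4] += (x & 0xFF) * (y & 0xFF)
    return out

a = [0xFF, 1, 2, 3, 4, 5, 6, 7]
lanes = partial_reduce_umla([0, 0], a, [1] * 8)    # b = splat(1)
assert sum(lanes) == sum(x & 0xFF for x in a)      # == reduce(zext(a))
```

As in the signed case, the result still has to be reduced across the i32 accumulator lanes afterward.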
Collaborator


There's still an outer reduce after the partial_reduce right?

Collaborator Author


Yep, let me adjust the comment.

Collaborator Author

@preames preames Jun 9, 2025


Actually, staring at it, I'm not sure how to make the comment much clearer. The reduce here isn't the outermost reduce; we might be well into an interior add in the reduce tree. Any suggestions?

Collaborator


You can drop the reduce; it's not in the comments for mul.

Collaborator

@topperc topperc left a comment


LGTM

@preames preames merged commit 2680afb into llvm:main Jun 10, 2025
7 checks passed
@preames preames deleted the pr-zvqdotq-reduce-of-ext-via-partial-reduce branch June 10, 2025 00:47
rorth pushed a commit to rorth/llvm-project that referenced this pull request Jun 11, 2025
…ructure (llvm#142212)

This involves a codegen regression at the moment due to the issue
described in 443cdd0, but this aligns the lowering paths for this case
and makes it less likely future bugs go undetected.
DhruvSrivastavaX pushed a commit to DhruvSrivastavaX/lldb-for-aix that referenced this pull request Jun 12, 2025
…ructure (llvm#142212)

This involves a codegen regression at the moment due to the issue
described in 443cdd0, but this aligns the lowering paths for this case
and makes it less likely future bugs go undetected.
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
…ructure (llvm#142212)

This involves a codegen regression at the moment due to the issue
described in 443cdd0, but this aligns the lowering paths for this case
and makes it less likely future bugs go undetected.