[AArch64][SelectionDAG] Add type legalization for partial reduce wide adds #141075


Conversation

NickGuy-Arm (Contributor)

Based on work initially done by @JamesChesterman.

llvmbot (Member) commented May 22, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Nicholas Guy (NickGuy-Arm)

Changes

Based on work initially done by @JamesChesterman.


Patch is 36.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141075.diff

5 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+8-4)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+36)
  • (modified) llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td (+13)
  • (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll (+132-74)
  • (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll (+238-84)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d6e288a59b2ee..0ac8f6f3a8171 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12644,6 +12644,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
     return SDValue();
 
+  EVT ResultVT = N->getValueType(0);
+
   bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
   unsigned NewOpcode =
       ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
@@ -12657,7 +12659,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
         (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
       return SDValue();
 
-    return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
+    return DAG.getNode(NewOpcode, DL, ResultVT, Acc, LHSExtOp,
                        DAG.getConstant(CTrunc, DL, LHSExtOpVT));
   }
 
@@ -12678,8 +12680,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
       Op1.getValueType().getVectorElementType() != AccElemVT)
     return SDValue();
 
-  return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
-                     RHSExtOp);
+  return DAG.getNode(NewOpcode, DL, ResultVT, Acc, LHSExtOp, RHSExtOp);
 }
 
 // partial.reduce.umla(acc, zext(op), splat(1))
@@ -12703,7 +12704,10 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
 
   SDValue UnextOp1 = Op1.getOperand(0);
   EVT UnextOp1VT = UnextOp1.getValueType();
-  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
+  auto *Context = DAG.getContext();
+  if (!TLI.isPartialReduceMLALegalOrCustom(
+          TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
+          TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
     return SDValue();
 
   bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
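
The guard above now queries the types that the partial reduction will legalise to, rather than the original types, because this combine runs before type legalisation (see the review discussion further down). As a hand-written illustration (not from the patch; the function name and types are assumptions chosen for the example), consider a partial reduce whose accumulator type is illegal:

; Illustrative sketch only: nxv4i64 is illegal and is split in half during
; type legalisation, so the combine checks
; getTypeToTransformTo(nxv4i64) = nxv2i64 against
; getTypeToTransformTo(nxv8i32) = nxv4i32, not the unlegalised pair.
define <vscale x 4 x i64> @wide_add_illegal_acc(<vscale x 4 x i64> %acc, <vscale x 8 x i32> %in) {
  %ext = zext <vscale x 8 x i32> %in to <vscale x 8 x i64>
  %red = call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %ext)
  ret <vscale x 4 x i64> %red
}

After splitting, each half is an nxv2i64 accumulator fed by an nxv4i32 input, which is exactly one of the pairs the new Custom actions below cover.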
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 13fb6a32233fe..d602a62eaaf84 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1870,6 +1870,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
 
     setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
+
+    // Wide add types
+    if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
+      setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Custom);
+      setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Custom);
+      setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Custom);
+    }
   }
 
   // Handle operations that are only available in non-streaming SVE mode.
@@ -29530,6 +29537,35 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
   SDValue LHS = Op.getOperand(1);
   SDValue RHS = Op.getOperand(2);
   EVT ResultVT = Op.getValueType();
+
+  // Recognise Op as a wide add; if it is, we leave it as-is.
+  // Base: nxv2i64, Subdivision: nxv4i32
+  auto IsEVTSubdivision = [](EVT Base, EVT Subdivision) -> bool {
+    assert(Base.isVector() && Subdivision.isVector());
+    assert(Base.isScalableVector() == Subdivision.isScalableVector());
+
+    ElementCount BaseCount = Base.getVectorElementCount();
+    ElementCount SubCount = Subdivision.getVectorElementCount();
+    if (BaseCount * 2 != SubCount)
+      return false;
+
+    uint64_t BaseScalarSize = Base.getScalarSizeInBits();
+    uint64_t SubScalarSize = Subdivision.getScalarSizeInBits();
+    if (BaseScalarSize != SubScalarSize * 2)
+      return false;
+
+    return true;
+  };
+  if (IsEVTSubdivision(ResultVT, LHS.getValueType())) {
+    // If it looks like a real wide add, we can leave it as-is and treat it as
+    // Legal
+    APInt C;
+    if (ISD::isConstantSplatVector(RHS.getNode(), C) && C.isOne())
+      return Op;
+    // If it doesn't, then we need to expand it.
+    return SDValue();
+  }
+
   assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
 
   SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
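
For reference, the shape accepted by the IsEVTSubdivision check above is the one exercised by the wide-add tests further down: the input type has twice the element count and half the element size of the result type. A minimal hand-written sketch (mirroring the signed_wide_add_nxv4i32 test in this patch; the function name is made up for illustration):

define <vscale x 2 x i64> @wide_add_shape(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %in) {
  ; nxv4i32 is a subdivision of nxv2i64: double the lanes, half the lane width.
  %ext = sext <vscale x 4 x i32> %in to <vscale x 4 x i64>
  %red = call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %ext)
  ret <vscale x 2 x i64> %red
}

The combiner turns this into PARTIAL_REDUCE_SMLA(acc, in, splat(1)); LowerPARTIAL_REDUCE_MLA now returns such nodes unchanged so that the SVE2 patterns added below can select the saddwb/saddwt pair.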
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index d6bd59adef03b..b15caa25b604e 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3787,6 +3787,19 @@ let Predicates = [HasSVE2_or_SME] in {
   defm USUBWB_ZZZ : sve2_wide_int_arith_wide<0b110, "usubwb", int_aarch64_sve_usubwb>;
   defm USUBWT_ZZZ : sve2_wide_int_arith_wide<0b111, "usubwt", int_aarch64_sve_usubwt>;
 
+  def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$Input, (nxv4i32 (splat_vector (i32 1))))),
+            (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>;
+  def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$Input, (nxv4i32 (splat_vector (i32 1))))),
+            (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
+  def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$Input, (nxv8i16 (splat_vector (i32 1))))),
+            (UADDWT_ZZZ_S (UADDWB_ZZZ_S $Acc, $Input), $Input)>;
+  def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$Input, (nxv8i16 (splat_vector (i32 1))))),
+            (SADDWT_ZZZ_S (SADDWB_ZZZ_S $Acc, $Input), $Input)>;
+  def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))),
+            (UADDWT_ZZZ_H (UADDWB_ZZZ_H $Acc, $Input), $Input)>;
+  def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))),
+            (SADDWT_ZZZ_H (SADDWB_ZZZ_H $Acc, $Input), $Input)>;
+
   // SVE2 integer multiply long
   defm SQDMULLB_ZZZ : sve2_wide_int_arith_long<0b11000, "sqdmullb", int_aarch64_sve_sqdmullb>;
   defm SQDMULLT_ZZZ : sve2_wide_int_arith_long<0b11001, "sqdmullt", int_aarch64_sve_sqdmullt>;
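
These patterns lean on the fact that a bottom/top widening-add pair consumes every input lane exactly once: the "wb" form adds the even (bottom) narrow lanes of the input, widened, to the accumulator, and the "wt" form does the same for the odd (top) lanes. Since a partial reduction only promises some association of the lane-wise adds, the bottom/top split is as valid as any other. A hand-written IR model of the 64-bit unsigned pair (an illustration, not code from the patch):

; Semantic sketch of (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input).
define <vscale x 2 x i64> @uaddw_pair_model(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %in) {
  %dei = call { <vscale x 2 x i32>, <vscale x 2 x i32> } @llvm.vector.deinterleave2.nxv4i32(<vscale x 4 x i32> %in)
  %bot = extractvalue { <vscale x 2 x i32>, <vscale x 2 x i32> } %dei, 0
  %top = extractvalue { <vscale x 2 x i32>, <vscale x 2 x i32> } %dei, 1
  %bot.wide = zext <vscale x 2 x i32> %bot to <vscale x 2 x i64>
  %top.wide = zext <vscale x 2 x i32> %top to <vscale x 2 x i64>
  %t = add <vscale x 2 x i64> %acc, %bot.wide   ; models uaddwb z0.d, z0.d, z1.s
  %r = add <vscale x 2 x i64> %t, %top.wide     ; models uaddwt z0.d, z0.d, z1.s
  ret <vscale x 2 x i64> %r
}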
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 5bc9a101b1e44..baa63a4ca31a2 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -561,31 +561,34 @@ define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
 ; CHECK-NEXT:    add z1.d, z1.d, z3.d
 ; CHECK-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_8to64:
-; CHECK-NEWLOWERING:       // %bb.0:
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z7.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z24.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z25.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    add z5.d, z5.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z3.d, z7.d, z6.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z4.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: udot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SVE:       // %bb.0:
+; CHECK-NEWLOWERING-SVE-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT:    mov z4.b, #1 // =0x1
+; CHECK-NEWLOWERING-SVE-NEXT:    udot z3.s, z2.b, z4.b
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z2.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z3.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: udot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0:
+; CHECK-NEWLOWERING-SVE2-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT:    mov z4.b, #1 // =0x1
+; CHECK-NEWLOWERING-SVE2-NEXT:    udot z3.s, z2.b, z4.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwb z0.d, z0.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uaddwt z0.d, z0.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: udot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SME:       // %bb.0:
+; CHECK-NEWLOWERING-SME-NEXT:    mov z3.b, #1 // =0x1
+; CHECK-NEWLOWERING-SME-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT:    udot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SME-NEXT:    uaddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    uaddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    ret
   %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
   ret <vscale x 4 x i64> %partial.reduce
@@ -603,31 +606,34 @@ define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
 ; CHECK-NEXT:    add z1.d, z1.d, z3.d
 ; CHECK-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_8to64:
-; CHECK-NEWLOWERING:       // %bb.0:
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z7.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z24.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z25.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z24.d
-; CHECK-NEWLOWERING-NEXT:    add z5.d, z5.d, z25.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z3.d, z7.d, z6.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z5.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z3.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    add z1.d, z1.d, z4.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SVE:       // %bb.0:
+; CHECK-NEWLOWERING-SVE-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT:    mov z4.b, #1 // =0x1
+; CHECK-NEWLOWERING-SVE-NEXT:    sdot z3.s, z2.b, z4.b
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z2.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z3.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0:
+; CHECK-NEWLOWERING-SVE2-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT:    mov z4.b, #1 // =0x1
+; CHECK-NEWLOWERING-SVE2-NEXT:    sdot z3.s, z2.b, z4.b
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.d, z0.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.d, z0.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SME:       // %bb.0:
+; CHECK-NEWLOWERING-SME-NEXT:    mov z3.b, #1 // =0x1
+; CHECK-NEWLOWERING-SME-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT:    sdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SME-NEXT:    saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT:    ret
   %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
   ret <vscale x 4 x i64> %partial.reduce
@@ -647,18 +653,44 @@ define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %
 ; CHECK-NEXT:    mla z0.s, p0/m, z1.s, z2.s
 ; CHECK-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: not_udot:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    and z1.h, z1.h, #0xff
-; CHECK-NEWLOWERING-NEXT:    and z2.h, z2.h, #0xff
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: not_udot:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-SVE-NEXT:    and z2.h, z2.h, #0xff
+; CHECK-NEWLOWERING-SVE-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT:    mla z0.s, p0/m, z3.s, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: not_udot:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-SVE2-NEXT:    and z2.h, z2.h, #0xff
+; CHECK-NEWLOWERING-SVE2-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.s, p0/m, z4.s, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: not_udot:
+; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-SME-NEXT:    and z2.h, z2.h, #0xff
+; CHECK-NEWLOWERING-SME-NEXT:    ptrue p0.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SME-NEXT:    mla z0.s, p0/m, z4.s, z3.s
+; CHECK-NEWLOWERING-SME-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
   %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
@@ -681,18 +713,44 @@ define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x
 ; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
 ; CHECK-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: not_udot_wide:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    and z1.s, z1.s, #0xffff
-; CHECK-NEWLOWERING-NEXT:    and z2.s, z2.s, #0xffff
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z3.d, z4.d
-; CHECK-NEWLOWERING-NEXT:    mla z0.d, p0/m, z1.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: not_udot_wide:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SVE-NEXT:    and z2.s, z2.s, #0xffff
+; CHECK-NEWLOWERING-SVE-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z3.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpklo z4.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT:    mla z0.d, p0/m, z3.d, z4.d
+; CHECK-NEWLOWERING-SVE-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: not_udot_wide:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SVE2-NEXT:    and z2.s, z2.s, #0xffff
+; CHECK-NEWLOWERING-SVE2-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z3.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpklo z4.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.d, p0/m, z4.d, z3.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: not_udot_wide:
+; CHECK-NEWLOWERING-SME:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SME-NEXT:    and z2.s, z2.s, #0xffff
+; CHECK-NEWLOWERING-SME-NEXT:    ptrue p0.d
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z3.d, z2.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpklo z4.d, z1.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-SME-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SME-NEXT:    mla z0.d, p0/m, z4.d, z3.d
+; CHECK-NEWLOWERING-SME-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-SME-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
   %b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
index 5148d3da6c737..8f9f26a5d5b23 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
@@ -1,7 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE2
-; RUN: llc -mtriple=aarch64 -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE
-; RUN: llc -mtriple=aarch64 -mattr=+sve2 -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK-SVE2
+; RUN: llc -mtriple=aarch64 -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK-SVE
+; RUN: llc -mtriple=aarch64 -mattr=+sve -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING-SVE
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING-SVE2
 
 define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
 ; CHECK-SVE2-LABEL: signed_wide_add_nxv4i32:
@@ -18,13 +19,19 @@ define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vsc
 ; CHECK-SVE-NEXT:    add z0.d, z0.d, z1.d
 ; CHECK-SVE-NEXT:    ret
 ;
-; CHECK-NEWLOWERING-LABEL: signed_wide_add_nxv4i32:
-; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z2.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEWLOWERING-NEXT:    add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-NEWLOWERING-SVE-LABEL: signed_wide_add_nxv4i32:
+; CHECK-NEWLOWERING-SVE:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpklo z2.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    sunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT:    add z0.d, z0.d, z1.d
+; CHECK-NEWLOWERING-SVE-NEXT:    ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: signed_wide_add_nxv4i32:
+; CHECK-NEWLOWERING-SVE2:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    saddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT:    ret
 entry:
     %input.wide = sext <vscale x 4 x i32> %inpu...
[truncated]

llvmbot (Member) commented May 22, 2025

@llvm/pr-subscribers-llvm-selectiondag

Collaborator review comment on llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp:

@@ -12703,7 +12704,10 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {

   SDValue UnextOp1 = Op1.getOperand(0);
   EVT UnextOp1VT = UnextOp1.getValueType();
-  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
+  auto *Context = DAG.getContext();
+  if (!TLI.isPartialReduceMLALegalOrCustom(

Can you explain the motivation behind this change?

Collaborator reply:
My understanding is that this combine happens before type legalisation, so it needs to check if the partial reduce will be supported for the legal type.

NickGuy-Arm force-pushed the JamesChesterman/legal-partial-reduction/wideadd branch from 3010a2b to 3936654 on May 28, 2025 14:53
NickGuy-Arm force-pushed the JamesChesterman/legal-partial-reduction/wideadd branch from 3936654 to 26c1098 on May 28, 2025 15:56
NickGuy-Arm merged commit a5d97eb into llvm:main on May 29, 2025
11 checks passed
llvm-ci (Collaborator) commented May 29, 2025

LLVM Buildbot has detected a new failure on builder mlir-nvidia-gcc7 running on mlir-nvidia while building llvm at step 7 "test-build-check-mlir-build-only-check-mlir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/13487

Here is the relevant piece of the build log for reference:
Step 7 (test-build-check-mlir-build-only-check-mlir) failure: test (failure)
******************** TEST 'MLIR :: Integration/GPU/CUDA/async.mlir' FAILED ********************
Exit Code: 1

Command Output (stdout):
--
# RUN: at line 1
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -gpu-kernel-outlining  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm),nvvm-attach-target)'  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -gpu-async-region -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary="format=fatbin"  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -async-to-async-runtime -async-runtime-ref-counting  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -convert-async-to-llvm -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-runner    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_cuda_runtime.so    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_async_runtime.so    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_runner_utils.so    --entry-point-result=void -O0  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -gpu-kernel-outlining
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt '-pass-pipeline=builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm),nvvm-attach-target)'
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -gpu-async-region -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary=format=fatbin
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -async-to-async-runtime -async-runtime-ref-counting
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -convert-async-to-llvm -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-runner --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_cuda_runtime.so --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_async_runtime.so --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_runner_utils.so --entry-point-result=void -O0
# .---command stderr------------
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventSynchronize(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# `-----------------------------
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# .---command stderr------------
# | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir:68:12: error: CHECK: expected string not found in input
# |  // CHECK: [84, 84]
# |            ^
# | <stdin>:1:1: note: scanning from here
# | Unranked Memref base@ = 0x5b62d2839480 rank = 1 offset = 0 sizes = [2] strides = [1] data = 
# | ^
# | <stdin>:2:1: note: possible intended match here
# | [42, 84]
# | ^
# | 
# | Input file: <stdin>
# | Check file: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# | 
# | -dump-input=help explains the following input dump.
# | 
# | Input was:
# | <<<<<<
# |             1: Unranked Memref base@ = 0x5b62d2839480 rank = 1 offset = 0 sizes = [2] strides = [1] data =  
# | check:68'0     X~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ error: no match found
# |             2: [42, 84] 
# | check:68'0     ~~~~~~~~~
# | check:68'1     ?         possible intended match
...

svkeerthy pushed a commit that referenced this pull request May 29, 2025
google-yfyang pushed a commit to google-yfyang/llvm-project that referenced this pull request May 29, 2025
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Jun 3, 2025
Labels: backend:AArch64, llvm:SelectionDAG
5 participants