[AArch64][SelectionDAG] Add type legalization for partial reduce wide adds #141075
Conversation
@llvm/pr-subscribers-backend-aarch64

Author: Nicholas Guy (NickGuy-Arm)

Changes

Based on work initially done by @JamesChesterman.

Patch is 36.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141075.diff

5 Files Affected:
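For context, a minimal sketch of the kind of IR this patch targets (hypothetical function name, mirroring the existing tests in sve-partial-reduce-wide-add.ll): the input is sign- or zero-extended to twice its element width and then partially reduced into an accumulator with half as many elements, i.e. a wide add with no multiply.

define <vscale x 2 x i64> @wide_add_sketch(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input) {
entry:
  ; Extend the 32-bit lanes to 64 bits, then accumulate them pairwise into
  ; the 2-element accumulator: a partial reduction with no multiply.
  %input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
  %partial.reduce = call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
  ret <vscale x 2 x i64> %partial.reduce
}

With -aarch64-enable-partial-reduce-nodes and +sve2, the updated checks below show this lowering to an saddwb/saddwt pair instead of sunpklo/sunpkhi followed by two adds.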
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d6e288a59b2ee..0ac8f6f3a8171 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12644,6 +12644,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
+ EVT ResultVT = N->getValueType(0);
+
bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
unsigned NewOpcode =
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
@@ -12657,7 +12659,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
(LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
return SDValue();
- return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
+ return DAG.getNode(NewOpcode, DL, ResultVT, Acc, LHSExtOp,
DAG.getConstant(CTrunc, DL, LHSExtOpVT));
}
@@ -12678,8 +12680,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
Op1.getValueType().getVectorElementType() != AccElemVT)
return SDValue();
- return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
- RHSExtOp);
+ return DAG.getNode(NewOpcode, DL, ResultVT, Acc, LHSExtOp, RHSExtOp);
}
// partial.reduce.umla(acc, zext(op), splat(1))
@@ -12703,7 +12704,10 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDValue UnextOp1 = Op1.getOperand(0);
EVT UnextOp1VT = UnextOp1.getValueType();
- if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
+ auto *Context = DAG.getContext();
+ if (!TLI.isPartialReduceMLALegalOrCustom(
+ TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
+ TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
return SDValue();
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 13fb6a32233fe..d602a62eaaf84 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1870,6 +1870,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
+
+ // Wide add types
+ if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
+ setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Custom);
+ setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Custom);
+ setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Custom);
+ }
}
// Handle operations that are only available in non-streaming SVE mode.
@@ -29530,6 +29537,35 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
SDValue LHS = Op.getOperand(1);
SDValue RHS = Op.getOperand(2);
EVT ResultVT = Op.getValueType();
+
+ // Recognise Op as a wide add, if it is then we leave it as-is
+ // Base: nxv2i64, Subdivision: nxv4i32
+ auto IsEVTSubdivision = [](EVT Base, EVT Subdivision) -> bool {
+ assert(Base.isVector() && Subdivision.isVector());
+ assert(Base.isScalableVector() == Subdivision.isScalableVector());
+
+ ElementCount BaseCount = Base.getVectorElementCount();
+ ElementCount SubCount = Subdivision.getVectorElementCount();
+ if (BaseCount * 2 != SubCount)
+ return false;
+
+ uint64_t BaseScalarSize = Base.getScalarSizeInBits();
+ uint64_t SubScalarSize = Subdivision.getScalarSizeInBits();
+ if (BaseScalarSize != SubScalarSize * 2)
+ return false;
+
+ return true;
+ };
+ if (IsEVTSubdivision(ResultVT, LHS.getValueType())) {
+ // If it looks like a real wide add, we can leave it as-is and treat it as
+ // Legal
+ APInt C;
+ if (ISD::isConstantSplatVector(RHS.getNode(), C) && C.isOne())
+ return Op;
+ // If it doesn't, then we need to expand it.
+ return SDValue();
+ }
+
assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index d6bd59adef03b..b15caa25b604e 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3787,6 +3787,19 @@ let Predicates = [HasSVE2_or_SME] in {
defm USUBWB_ZZZ : sve2_wide_int_arith_wide<0b110, "usubwb", int_aarch64_sve_usubwb>;
defm USUBWT_ZZZ : sve2_wide_int_arith_wide<0b111, "usubwt", int_aarch64_sve_usubwt>;
+ def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$Input, (nxv4i32 (splat_vector (i32 1))))),
+ (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>;
+ def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$Input, (nxv4i32 (splat_vector (i32 1))))),
+ (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
+ def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$Input, (nxv8i16 (splat_vector (i32 1))))),
+ (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>;
+ def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$Input, (nxv8i16 (splat_vector (i32 1))))),
+ (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
+ def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))),
+ (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>;
+ def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))),
+ (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
+
// SVE2 integer multiply long
defm SQDMULLB_ZZZ : sve2_wide_int_arith_long<0b11000, "sqdmullb", int_aarch64_sve_sqdmullb>;
defm SQDMULLT_ZZZ : sve2_wide_int_arith_long<0b11001, "sqdmullt", int_aarch64_sve_sqdmullt>;
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 5bc9a101b1e44..baa63a4ca31a2 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -561,31 +561,34 @@ define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
; CHECK-NEXT: add z1.d, z1.d, z3.d
; CHECK-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_8to64:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z3.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z24.d, z5.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z25.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z24.d
-; CHECK-NEWLOWERING-NEXT: add z5.d, z5.d, z25.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z3.d, z7.d, z6.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z5.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z4.d
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: udot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SVE: // %bb.0:
+; CHECK-NEWLOWERING-SVE-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT: mov z4.b, #1 // =0x1
+; CHECK-NEWLOWERING-SVE-NEXT: udot z3.s, z2.b, z4.b
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z2.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z3.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: udot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SVE2: // %bb.0:
+; CHECK-NEWLOWERING-SVE2-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT: mov z4.b, #1 // =0x1
+; CHECK-NEWLOWERING-SVE2-NEXT: udot z3.s, z2.b, z4.b
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwb z0.d, z0.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwt z0.d, z0.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: udot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SME: // %bb.0:
+; CHECK-NEWLOWERING-SME-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEWLOWERING-SME-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SME-NEXT: uaddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT: uaddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT: ret
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
ret <vscale x 4 x i64> %partial.reduce
@@ -603,31 +606,34 @@ define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
; CHECK-NEXT: add z1.d, z1.d, z3.d
; CHECK-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_8to64:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z3.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z24.d, z5.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z25.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z24.d
-; CHECK-NEWLOWERING-NEXT: add z5.d, z5.d, z25.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z3.d, z7.d, z6.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z5.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z4.d
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SVE: // %bb.0:
+; CHECK-NEWLOWERING-SVE-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT: mov z4.b, #1 // =0x1
+; CHECK-NEWLOWERING-SVE-NEXT: sdot z3.s, z2.b, z4.b
+; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z2.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z3.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SVE2: // %bb.0:
+; CHECK-NEWLOWERING-SVE2-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT: mov z4.b, #1 // =0x1
+; CHECK-NEWLOWERING-SVE2-NEXT: sdot z3.s, z2.b, z4.b
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SME: // %bb.0:
+; CHECK-NEWLOWERING-SME-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEWLOWERING-SME-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SME-NEXT: saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT: saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT: ret
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
ret <vscale x 4 x i64> %partial.reduce
@@ -647,18 +653,44 @@ define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %
; CHECK-NEXT: mla z0.s, p0/m, z1.s, z2.s
; CHECK-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: not_udot:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: and z1.h, z1.h, #0xff
-; CHECK-NEWLOWERING-NEXT: and z2.h, z2.h, #0xff
-; CHECK-NEWLOWERING-NEXT: ptrue p0.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: not_udot:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-SVE-NEXT: and z2.h, z2.h, #0xff
+; CHECK-NEWLOWERING-SVE-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT: mla z0.s, p0/m, z3.s, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: not_udot:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-SVE2-NEXT: and z2.h, z2.h, #0xff
+; CHECK-NEWLOWERING-SVE2-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.s, p0/m, z4.s, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: not_udot:
+; CHECK-NEWLOWERING-SME: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT: and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-SME-NEXT: and z2.h, z2.h, #0xff
+; CHECK-NEWLOWERING-SME-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-SME-NEXT: uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-SME-NEXT: uunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SME-NEXT: mla z0.s, p0/m, z4.s, z3.s
+; CHECK-NEWLOWERING-SME-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-SME-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
%b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
@@ -681,18 +713,44 @@ define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x
; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
; CHECK-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: not_udot_wide:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: and z1.s, z1.s, #0xffff
-; CHECK-NEWLOWERING-NEXT: and z2.s, z2.s, #0xffff
-; CHECK-NEWLOWERING-NEXT: ptrue p0.d
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z3.d, z4.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z1.d, z2.d
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: not_udot_wide:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SVE-NEXT: and z2.s, z2.s, #0xffff
+; CHECK-NEWLOWERING-SVE-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z3.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z4.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT: mla z0.d, p0/m, z3.d, z4.d
+; CHECK-NEWLOWERING-SVE-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: not_udot_wide:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SVE2-NEXT: and z2.s, z2.s, #0xffff
+; CHECK-NEWLOWERING-SVE2-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z3.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z4.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.d, p0/m, z4.d, z3.d
+; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: not_udot_wide:
+; CHECK-NEWLOWERING-SME: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SME-NEXT: and z2.s, z2.s, #0xffff
+; CHECK-NEWLOWERING-SME-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-SME-NEXT: uunpklo z3.d, z2.s
+; CHECK-NEWLOWERING-SME-NEXT: uunpklo z4.d, z1.s
+; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SME-NEXT: mla z0.d, p0/m, z4.d, z3.d
+; CHECK-NEWLOWERING-SME-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-SME-NEXT: ret
entry:
%a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
%b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
index 5148d3da6c737..8f9f26a5d5b23 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
@@ -1,7 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE2
-; RUN: llc -mtriple=aarch64 -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE
-; RUN: llc -mtriple=aarch64 -mattr=+sve2 -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK-SVE2
+; RUN: llc -mtriple=aarch64 -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK-SVE
+; RUN: llc -mtriple=aarch64 -mattr=+sve -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING-SVE
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING-SVE2
define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
; CHECK-SVE2-LABEL: signed_wide_add_nxv4i32:
@@ -18,13 +19,19 @@ define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vsc
; CHECK-SVE-NEXT: add z0.d, z0.d, z1.d
; CHECK-SVE-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: signed_wide_add_nxv4i32:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.d, z1.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: signed_wide_add_nxv4i32:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z2.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z1.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: signed_wide_add_nxv4i32:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
entry:
%input.wide = sext <vscale x 4 x i32> %inpu...
[truncated]
@@ -12703,7 +12704,10 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
   SDValue UnextOp1 = Op1.getOperand(0);
   EVT UnextOp1VT = UnextOp1.getValueType();
-  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
+  auto *Context = DAG.getContext();
+  if (!TLI.isPartialReduceMLALegalOrCustom(
Can you explain the motivation behind this change?
My understanding is that this combine happens before type legalisation, so it needs to check if the partial reduce will be supported for the legal type.
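For illustration, a hypothetical example (not one of the patch's tests): here the accumulator type nxv4i64 and the unextended operand type nxv8i32 are both illegal, so the combine sees the node before the legaliser splits it. Querying isPartialReduceMLALegalOrCustom on the raw types would reject the fold, whereas getTypeToTransformTo maps them to nxv2i64 and nxv4i32, the pair this patch marks as Custom, so the fold is kept and each half should become a wide add after splitting.

define <vscale x 4 x i64> @wide_add_illegal_acc(<vscale x 4 x i64> %acc, <vscale x 8 x i32> %in) {
entry:
  ; nxv4i64 and nxv8i64 are illegal on SVE; type legalisation splits this
  ; into two nxv2i64 <- nxv4i32 partial reductions.
  %in.wide = zext <vscale x 8 x i32> %in to <vscale x 8 x i64>
  %r = call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %in.wide)
  ret <vscale x 4 x i64> %r
}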
@@ -12703,7 +12704,10 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) { | |||
|
|||
SDValue UnextOp1 = Op1.getOperand(0); | |||
EVT UnextOp1VT = UnextOp1.getValueType(); | |||
if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT)) | |||
auto *Context = DAG.getContext(); | |||
if (!TLI.isPartialReduceMLALegalOrCustom( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My understanding is that this combine happens before type legalisation, so it needs to check if the partial reduce will be supported for the legal type.
Force-pushed from 3010a2b to 3936654, then from 3936654 to 26c1098.
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/13487
… adds (#141075) Based on work initially done by @JamesChesterman.