[AArch64] Enable fixed-length vector support for partial-reductions #142032
Conversation
@llvm/pr-subscribers-backend-aarch64

Author: Sander de Smalen (sdesmalen-arm)

Changes

Patch is 34.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142032.diff

2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a07afea963e20..34ce8bb7dd2b2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1930,6 +1930,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
Custom);
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
Custom);
+
+ // Must be lowered to SVE instructions.
+ setPartialReduceMLAAction(MVT::v2i64, MVT::v4i32, Custom);
+ setPartialReduceMLAAction(MVT::v2i64, MVT::v8i16, Custom);
+ setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
+ setPartialReduceMLAAction(MVT::v4i32, MVT::v8i16, Custom);
+ setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Custom);
+ setPartialReduceMLAAction(MVT::v8i16, MVT::v16i8, Custom);
}
}
@@ -2225,6 +2233,26 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
+ if (EnablePartialReduceNodes) {
+ unsigned NumElts = VT.getVectorNumElements();
+ if (VT.getVectorElementType() == MVT::i64) {
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 8),
+ Custom);
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i16, NumElts * 4),
+ Custom);
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i32, NumElts * 2),
+ Custom);
+ } else if (VT.getVectorElementType() == MVT::i32) {
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 4),
+ Custom);
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i16, NumElts * 2),
+ Custom);
+ } else if (VT.getVectorElementType() == MVT::i16) {
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 2),
+ Custom);
+ }
+ }
+
// Lower fixed length vector operations to scalable equivalents.
setOperationAction(ISD::ABDS, VT, Default);
setOperationAction(ISD::ABDU, VT, Default);
@@ -29224,50 +29252,61 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
SDValue
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
SelectionDAG &DAG) const {
- bool Scalable = Op.getValueType().isScalableVector();
-
- assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
- "SVE or StreamingSVE must be available when using scalable vectors.");
- assert((Scalable || Subtarget->hasDotProd()) &&
- "Dotprod must be available when targeting NEON dot product "
- "instructions.");
-
SDLoc DL(Op);
SDValue Acc = Op.getOperand(0);
SDValue LHS = Op.getOperand(1);
SDValue RHS = Op.getOperand(2);
EVT ResultVT = Op.getValueType();
+ EVT OrigResultVT = ResultVT;
+ EVT OpVT = LHS.getValueType();
- assert((Scalable && ResultVT == MVT::nxv2i64 &&
- LHS.getValueType() == MVT::nxv16i8) ||
- (!Scalable && ResultVT == MVT::v2i64 &&
- LHS.getValueType() == MVT::v16i8));
+ bool ConvertToScalable =
+ ResultVT.isFixedLengthVector() &&
+ useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
- EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
+ if (ConvertToScalable) {
+ ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
+ OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
+ Acc = convertToScalableVector(DAG, ResultVT, Acc);
+ LHS = convertToScalableVector(DAG, OpVT, LHS);
+ RHS = convertToScalableVector(DAG, OpVT, RHS);
+ Op = DAG.getNode(Op.getOpcode(), DL, ResultVT, {Acc, LHS, RHS});
+ }
+
+ // Two-way and four-way partial reductions are supported by patterns.
+ // We only need to handle the 8-way partial reduction.
+ if (ResultVT.getScalarType() != MVT::i64 || OpVT.getScalarType() != MVT::i8)
+ return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Op)
+ : Op;
+
+ EVT DotVT = ResultVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
DAG.getConstant(0, DL, DotVT), LHS, RHS);
+ SDValue Res;
bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
- if (Scalable &&
- (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
+ if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
- return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
- }
-
- // Fold (nx)v4i32 into (nx)v2i64
- auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
- if (IsUnsigned) {
- DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
- DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
+ Res = DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
} else {
- DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
- DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
+ // Fold (nx)v4i32 into (nx)v2i64
+ auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
+ if (IsUnsigned) {
+ DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
+ DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
+ } else {
+ DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
+ DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
+ }
+ auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
+ Res = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
}
- auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
- return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
+
+ return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Res)
+ : Res;
}
SDValue
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
new file mode 100644
index 0000000000000..79d766d1b9908
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
@@ -0,0 +1,791 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
+; RUN: llc -mattr=+sve,+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
+; RUN: llc -mattr=+sme -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME
+
+target triple = "aarch64"
+
+;
+; Two-way mla (i8 -> i16)
+;
+
+define <8 x i16> @two_way_i8_i16_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i8_i16_vl128:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldr q0, [x0]
+; COMMON-NEXT: ldr q1, [x1]
+; COMMON-NEXT: ldr q2, [x2]
+; COMMON-NEXT: umlal v0.8h, v2.8b, v1.8b
+; COMMON-NEXT: umlal2 v0.8h, v2.16b, v1.16b
+; COMMON-NEXT: ret
+;
+; SME-LABEL: two_way_i8_i16_vl128:
+; SME: // %bb.0:
+; SME-NEXT: ldr q0, [x0]
+; SME-NEXT: ldr q1, [x1]
+; SME-NEXT: ldr q2, [x2]
+; SME-NEXT: umlalb z0.h, z2.b, z1.b
+; SME-NEXT: umlalt z0.h, z2.b, z1.b
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: ret
+ %acc = load <8 x i16>, ptr %accptr
+ %u = load <16 x i8>, ptr %uptr
+ %s = load <16 x i8>, ptr %sptr
+ %u.wide = zext <16 x i8> %u to <16 x i16>
+ %s.wide = zext <16 x i8> %s to <16 x i16>
+ %mult = mul nuw nsw <16 x i16> %s.wide, %u.wide
+ %partial.reduce = tail call <8 x i16> @llvm.experimental.vector.partial.reduce.add(<8 x i16> %acc, <16 x i16> %mult)
+ ret <8 x i16> %partial.reduce
+}
+
+define <16 x i16> @two_way_i8_i16_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i8_i16_vl128_double_width:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldp q0, q1, [x0]
+; COMMON-NEXT: ldp q2, q3, [x1]
+; COMMON-NEXT: ldp q4, q5, [x2]
+; COMMON-NEXT: umlal v0.8h, v4.8b, v2.8b
+; COMMON-NEXT: umlal v1.8h, v5.8b, v3.8b
+; COMMON-NEXT: umlal2 v0.8h, v4.16b, v2.16b
+; COMMON-NEXT: umlal2 v1.8h, v5.16b, v3.16b
+; COMMON-NEXT: ret
+;
+; SME-LABEL: two_way_i8_i16_vl128_double_width:
+; SME: // %bb.0:
+; SME-NEXT: ldp q0, q1, [x0]
+; SME-NEXT: ldp q3, q2, [x1]
+; SME-NEXT: ldp q5, q4, [x2]
+; SME-NEXT: umlalb z0.h, z5.b, z3.b
+; SME-NEXT: umlalb z1.h, z4.b, z2.b
+; SME-NEXT: umlalt z0.h, z5.b, z3.b
+; SME-NEXT: umlalt z1.h, z4.b, z2.b
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <16 x i16>, ptr %accptr
+ %u = load <32 x i8>, ptr %uptr
+ %s = load <32 x i8>, ptr %sptr
+ %u.wide = zext <32 x i8> %u to <32 x i16>
+ %s.wide = zext <32 x i8> %s to <32 x i16>
+ %mult = mul nuw nsw <32 x i16> %s.wide, %u.wide
+ %partial.reduce = tail call <16 x i16> @llvm.experimental.vector.partial.reduce.add(<16 x i16> %acc, <32 x i16> %mult)
+ ret <16 x i16> %partial.reduce
+}
+
+define <16 x i16> @two_way_i8_i16_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: two_way_i8_i16_vl256:
+; NEON: // %bb.0:
+; NEON-NEXT: ldp q0, q1, [x0]
+; NEON-NEXT: ldp q2, q3, [x1]
+; NEON-NEXT: ldp q4, q5, [x2]
+; NEON-NEXT: umlal v0.8h, v4.8b, v2.8b
+; NEON-NEXT: umlal v1.8h, v5.8b, v3.8b
+; NEON-NEXT: umlal2 v0.8h, v4.16b, v2.16b
+; NEON-NEXT: umlal2 v1.8h, v5.16b, v3.16b
+; NEON-NEXT: ret
+;
+; SVE-LABEL: two_way_i8_i16_vl256:
+; SVE: // %bb.0:
+; SVE-NEXT: ldr z0, [x1]
+; SVE-NEXT: ldr z1, [x2]
+; SVE-NEXT: ptrue p0.h
+; SVE-NEXT: ldr z4, [x0]
+; SVE-NEXT: uunpklo z2.h, z0.b
+; SVE-NEXT: uunpklo z3.h, z1.b
+; SVE-NEXT: uunpkhi z0.h, z0.b
+; SVE-NEXT: uunpkhi z1.h, z1.b
+; SVE-NEXT: mad z2.h, p0/m, z3.h, z4.h
+; SVE-NEXT: mad z0.h, p0/m, z1.h, z2.h
+; SVE-NEXT: mov z1.d, z0.d
+; SVE-NEXT: ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT: ret
+;
+; SME-LABEL: two_way_i8_i16_vl256:
+; SME: // %bb.0:
+; SME-NEXT: ldr z0, [x0]
+; SME-NEXT: ldr z1, [x1]
+; SME-NEXT: ldr z2, [x2]
+; SME-NEXT: umlalb z0.h, z2.b, z1.b
+; SME-NEXT: umlalt z0.h, z2.b, z1.b
+; SME-NEXT: mov z1.d, z0.d
+; SME-NEXT: ext z1.b, z1.b, z0.b, #16
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <16 x i16>, ptr %accptr
+ %u = load <32 x i8>, ptr %uptr
+ %s = load <32 x i8>, ptr %sptr
+ %u.wide = zext <32 x i8> %u to <32 x i16>
+ %s.wide = zext <32 x i8> %s to <32 x i16>
+ %mult = mul nuw nsw <32 x i16> %s.wide, %u.wide
+ %partial.reduce = tail call <16 x i16> @llvm.experimental.vector.partial.reduce.add(<16 x i16> %acc, <32 x i16> %mult)
+ ret <16 x i16> %partial.reduce
+}
+
+;
+; Two-way mla (i16 -> i32)
+;
+
+define <4 x i32> @two_way_i16_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i16_i32_vl128:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldr q0, [x0]
+; COMMON-NEXT: ldr q1, [x1]
+; COMMON-NEXT: ldr q2, [x2]
+; COMMON-NEXT: umlal v0.4s, v2.4h, v1.4h
+; COMMON-NEXT: umlal2 v0.4s, v2.8h, v1.8h
+; COMMON-NEXT: ret
+;
+; SME-LABEL: two_way_i16_i32_vl128:
+; SME: // %bb.0:
+; SME-NEXT: ldr q0, [x0]
+; SME-NEXT: ldr q1, [x1]
+; SME-NEXT: ldr q2, [x2]
+; SME-NEXT: umlalb z0.s, z2.h, z1.h
+; SME-NEXT: umlalt z0.s, z2.h, z1.h
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: ret
+ %acc = load <4 x i32>, ptr %accptr
+ %u = load <8 x i16>, ptr %uptr
+ %s = load <8 x i16>, ptr %sptr
+ %u.wide = zext <8 x i16> %u to <8 x i32>
+ %s.wide = zext <8 x i16> %s to <8 x i32>
+ %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <8 x i32> %mult)
+ ret <4 x i32> %partial.reduce
+}
+
+define <8 x i32> @two_way_i16_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i16_i32_vl128_double_width:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldp q0, q1, [x0]
+; COMMON-NEXT: ldp q2, q3, [x1]
+; COMMON-NEXT: ldp q4, q5, [x2]
+; COMMON-NEXT: umlal v0.4s, v4.4h, v2.4h
+; COMMON-NEXT: umlal v1.4s, v5.4h, v3.4h
+; COMMON-NEXT: umlal2 v0.4s, v4.8h, v2.8h
+; COMMON-NEXT: umlal2 v1.4s, v5.8h, v3.8h
+; COMMON-NEXT: ret
+;
+; SME-LABEL: two_way_i16_i32_vl128_double_width:
+; SME: // %bb.0:
+; SME-NEXT: ldp q0, q1, [x0]
+; SME-NEXT: ldp q3, q2, [x1]
+; SME-NEXT: ldp q5, q4, [x2]
+; SME-NEXT: umlalb z0.s, z5.h, z3.h
+; SME-NEXT: umlalb z1.s, z4.h, z2.h
+; SME-NEXT: umlalt z0.s, z5.h, z3.h
+; SME-NEXT: umlalt z1.s, z4.h, z2.h
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <8 x i32>, ptr %accptr
+ %u = load <16 x i16>, ptr %uptr
+ %s = load <16 x i16>, ptr %sptr
+ %u.wide = zext <16 x i16> %u to <16 x i32>
+ %s.wide = zext <16 x i16> %s to <16 x i32>
+ %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <16 x i32> %mult)
+ ret <8 x i32> %partial.reduce
+}
+
+define <8 x i32> @two_way_i16_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: two_way_i16_i32_vl256:
+; NEON: // %bb.0:
+; NEON-NEXT: ldp q0, q1, [x0]
+; NEON-NEXT: ldp q2, q3, [x1]
+; NEON-NEXT: ldp q4, q5, [x2]
+; NEON-NEXT: umlal v0.4s, v4.4h, v2.4h
+; NEON-NEXT: umlal v1.4s, v5.4h, v3.4h
+; NEON-NEXT: umlal2 v0.4s, v4.8h, v2.8h
+; NEON-NEXT: umlal2 v1.4s, v5.8h, v3.8h
+; NEON-NEXT: ret
+;
+; SVE-LABEL: two_way_i16_i32_vl256:
+; SVE: // %bb.0:
+; SVE-NEXT: ldr z0, [x1]
+; SVE-NEXT: ldr z1, [x2]
+; SVE-NEXT: ptrue p0.s
+; SVE-NEXT: ldr z4, [x0]
+; SVE-NEXT: uunpklo z2.s, z0.h
+; SVE-NEXT: uunpklo z3.s, z1.h
+; SVE-NEXT: uunpkhi z0.s, z0.h
+; SVE-NEXT: uunpkhi z1.s, z1.h
+; SVE-NEXT: mad z2.s, p0/m, z3.s, z4.s
+; SVE-NEXT: mad z0.s, p0/m, z1.s, z2.s
+; SVE-NEXT: mov z1.d, z0.d
+; SVE-NEXT: ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT: ret
+;
+; SME-LABEL: two_way_i16_i32_vl256:
+; SME: // %bb.0:
+; SME-NEXT: ldr z0, [x0]
+; SME-NEXT: ldr z1, [x1]
+; SME-NEXT: ldr z2, [x2]
+; SME-NEXT: umlalb z0.s, z2.h, z1.h
+; SME-NEXT: umlalt z0.s, z2.h, z1.h
+; SME-NEXT: mov z1.d, z0.d
+; SME-NEXT: ext z1.b, z1.b, z0.b, #16
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <8 x i32>, ptr %accptr
+ %u = load <16 x i16>, ptr %uptr
+ %s = load <16 x i16>, ptr %sptr
+ %u.wide = zext <16 x i16> %u to <16 x i32>
+ %s.wide = zext <16 x i16> %s to <16 x i32>
+ %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <16 x i32> %mult)
+ ret <8 x i32> %partial.reduce
+}
+
+;
+; Two-way mla (i32 -> i64)
+;
+
+define <2 x i64> @two_way_i32_i64_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i32_i64_vl128:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldr q0, [x0]
+; COMMON-NEXT: ldr q1, [x1]
+; COMMON-NEXT: ldr q2, [x2]
+; COMMON-NEXT: umlal v0.2d, v2.2s, v1.2s
+; COMMON-NEXT: umlal2 v0.2d, v2.4s, v1.4s
+; COMMON-NEXT: ret
+;
+; SME-LABEL: two_way_i32_i64_vl128:
+; SME: // %bb.0:
+; SME-NEXT: ldr q0, [x0]
+; SME-NEXT: ldr q1, [x1]
+; SME-NEXT: ldr q2, [x2]
+; SME-NEXT: umlalb z0.d, z2.s, z1.s
+; SME-NEXT: umlalt z0.d, z2.s, z1.s
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: ret
+ %acc = load <2 x i64>, ptr %accptr
+ %u = load <4 x i32>, ptr %uptr
+ %s = load <4 x i32>, ptr %sptr
+ %u.wide = zext <4 x i32> %u to <4 x i64>
+ %s.wide = zext <4 x i32> %s to <4 x i64>
+ %mult = mul nuw nsw <4 x i64> %s.wide, %u.wide
+ %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <4 x i64> %mult)
+ ret <2 x i64> %partial.reduce
+}
+
+define <4 x i64> @two_way_i32_i64_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i32_i64_vl128_double_width:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldp q0, q1, [x0]
+; COMMON-NEXT: ldp q2, q3, [x1]
+; COMMON-NEXT: ldp q4, q5, [x2]
+; COMMON-NEXT: umlal v0.2d, v4.2s, v2.2s
+; COMMON-NEXT: umlal v1.2d, v5.2s, v3.2s
+; COMMON-NEXT: umlal2 v0.2d, v4.4s, v2.4s
+; COMMON-NEXT: umlal2 v1.2d, v5.4s, v3.4s
+; COMMON-NEXT: ret
+;
+; SME-LABEL: two_way_i32_i64_vl128_double_width:
+; SME: // %bb.0:
+; SME-NEXT: ldp q0, q1, [x0]
+; SME-NEXT: ldp q3, q2, [x1]
+; SME-NEXT: ldp q5, q4, [x2]
+; SME-NEXT: umlalb z0.d, z5.s, z3.s
+; SME-NEXT: umlalb z1.d, z4.s, z2.s
+; SME-NEXT: umlalt z0.d, z5.s, z3.s
+; SME-NEXT: umlalt z1.d, z4.s, z2.s
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <4 x i64>, ptr %accptr
+ %u = load <8 x i32>, ptr %uptr
+ %s = load <8 x i32>, ptr %sptr
+ %u.wide = zext <8 x i32> %u to <8 x i64>
+ %s.wide = zext <8 x i32> %s to <8 x i64>
+ %mult = mul nuw nsw <8 x i64> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add(<4 x i64> %acc, <8 x i64> %mult)
+ ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @two_way_i32_i64_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: two_way_i32_i64_vl256:
+; NEON: // %bb.0:
+; NEON-NEXT: ldp q0, q1, [x0]
+; NEON-NEXT: ldp q2, q3, [x1]
+; NEON-NEXT: ldp q4, q5, [x2]
+; NEON-NEXT: umlal v0.2d, v4.2s, v2.2s
+; NEON-NEXT: umlal v1.2d, v5.2s, v3.2s
+; NEON-NEXT: umlal2 v0.2d, v4.4s, v2.4s
+; NEON-NEXT: umlal2 v1.2d, v5.4s, v3.4s
+; NEON-NEXT: ret
+;
+; SVE-LABEL: two_way_i32_i64_vl256:
+; SVE: // %bb.0:
+; SVE-NEXT: ldr z0, [x1]
+; SVE-NEXT: ldr z1, [x2]
+; SVE-NEXT: ptrue p0.d
+; SVE-NEXT: ldr z4, [x0]
+; SVE-NEXT: uunpklo z2.d, z0.s
+; SVE-NEXT: uunpklo z3.d, z1.s
+; SVE-NEXT: uunpkhi z0.d, z0.s
+; SVE-NEXT: uunpkhi z1.d, z1.s
+; SVE-NEXT: mad z2.d, p0/m, z3.d, z4.d
+; SVE-NEXT: mad z0.d, p0/m, z1.d, z2.d
+; SVE-NEXT: mov z1.d, z0.d
+; SVE-NEXT: ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT: ret
+;
+; SME-LABEL: two_way_i32_i64_vl256:
+; SME: // %bb.0:
+; SME-NEXT: ldr z0, [x0]
+; SME-NEXT: ldr z1, [x1]
+; SME-NEXT: ldr z2, [x2]
+; SME-NEXT: umlalb z0.d, z2.s, z1.s
+; SME-NEXT: umlalt z0.d, z2.s, z1.s
+; SME-NEXT: mov z1.d, z0.d
+; SME-NEXT: ext z1.b, z1.b, z0.b, #16
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <4 x i64>, ptr %accptr
+ %u = load <8 x i32>, ptr %uptr
+ %s = load <8 x i32>, ptr %sptr
+ %u.wide = zext <8 x i32> %u to <8 x i64>
+ %s.wide = zext <8 x i32> %s to <8 x i64>
+ %mult = mul nuw nsw <8 x i64> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add(<4 x i64> %acc, <8 x i64> %mult)
+ ret <4 x i64> %partial.reduce
+}
+
+
+;
+; Four-way dot (i8 -> i32)
+;
+
+define <4 x i32> @four_way_i8_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i8_i32_vl128:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldr q0, [x0]
+; COMMON-NEXT: ldr q1, [x1]
+; COMMON-NEXT: ldr q2, [x2]
+; COMMON-NEXT: udot v0.4s, v2.16b, v1...
[truncated]
    if (EnablePartialReduceNodes) {
      unsigned NumElts = VT.getVectorNumElements();
      if (VT.getVectorElementType() == MVT::i64) {
        setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 8),
                                  Custom);
        setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i16, NumElts * 4),
                                  Custom);
        setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i32, NumElts * 2),
                                  Custom);
      } else if (VT.getVectorElementType() == MVT::i32) {
        setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 4),
                                  Custom);
        setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i16, NumElts * 2),
                                  Custom);
      } else if (VT.getVectorElementType() == MVT::i16) {
        setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 2),
                                  Custom);
      }
    }
How does this approach differ from the explicit assignments above? Do we need both, or can we get rid of one of them?
The difference from the explicit assignments above is that we don't know the exact width of the vectors, because it applies to all legal vectors up to the target's vector length. For a 512-bit vector length, it could be e.g. partial.reduce(<64 x i8>) -> <8 x i64> or partial.reduce(<32 x i8>) -> <4 x i64>.
Aha, I'd missed that this function is only entered if the VT width is greater than 128.
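To make that mapping concrete, here is a minimal LLVM IR sketch of the 512-bit case mentioned above (not part of the patch's tests; the function name and vscale_range attribute are illustrative):

define <8 x i64> @eight_way_i8_i64_vl512(<8 x i64> %acc, <64 x i8> %u) vscale_range(4,4) {
  ; Widen the 64 x i8 input to the accumulator element type, then fold
  ; eight input elements into each i64 lane of the accumulator.
  %u.wide = zext <64 x i8> %u to <64 x i64>
  %partial.reduce = tail call <8 x i64> @llvm.experimental.vector.partial.reduce.add(<8 x i64> %acc, <64 x i64> %u.wide)
  ret <8 x i64> %partial.reduce
}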
  %acc = load <8 x i16>, ptr %accptr
  %u = load <16 x i8>, ptr %uptr
  %s = load <16 x i8>, ptr %sptr
Maybe nit: Is there a specific need for loading from a ptr? In the other partial reduce tests we have the operands come in as their types directly, rather than via ptr/load.
It would also clean up the checked asm by removing the ldrs:
define <8 x i16> @two_way_i8_i16_vl128(<8 x i16> %acc, <16 x i8> %u, <16 x i8> %s) {
For these (128-bit vector) types maybe not, but wider vectors (e.g. 256- or 512-bit) are passed in 128-bit NEON registers as per the ABI and would need to be combined into a single SVE register, which makes the code look messier than just doing an SVE vector load from a pointer.
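For reference, a sketch of the direct-argument form of the 256-bit test the suggestion would imply (hypothetical function name; same body as two_way_i8_i16_vl256 minus the loads). As the comment above notes, the wide arguments arrive split across 128-bit NEON registers and the checked asm would first have to rejoin them into single SVE registers:

define <16 x i16> @two_way_i8_i16_vl256_args(<16 x i16> %acc, <32 x i8> %u, <32 x i8> %s) vscale_range(2,2) {
  ; Same computation as the load-based test, but with by-value vector
  ; arguments instead of loads from pointers.
  %u.wide = zext <32 x i8> %u to <32 x i16>
  %s.wide = zext <32 x i8> %s to <32 x i16>
  %mult = mul nuw nsw <32 x i16> %s.wide, %u.wide
  %partial.reduce = tail call <16 x i16> @llvm.experimental.vector.partial.reduce.add(<16 x i16> %acc, <32 x i16> %mult)
  ret <16 x i16> %partial.reduce
}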
(force-pushed 88aec0f to fd66c5f)
Looks good to me. Though you will need to rebase it on top of 1651aa2.
(force-pushed fd66c5f to 0fa3bd5)
This enables the use of the [us]dot, [us]addw[bt] and [us]mlal[bt] instructions in Streaming mode, and for wider vectors when the runtime vector length is known to be 256 bits or larger.
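As a rough illustration of the 8-way case that the updated LowerPARTIAL_REDUCE_MLA handles explicitly (a sketch, not one of the tests added here; the function name is made up): the i8 inputs are reduced via a v16i8 dot product into v4i32, which is then widened into the i64 accumulator, using uaddwb/uaddwt when SVE2 or streaming SVE is available, or a split-extend-add sequence otherwise.

define <2 x i64> @eight_way_i8_i64_vl128(<2 x i64> %acc, <16 x i8> %u, <16 x i8> %s) {
  ; Unsigned 8-way multiply-accumulate: 16 i8 products folded into the
  ; two i64 lanes of the accumulator.
  %u.wide = zext <16 x i8> %u to <16 x i64>
  %s.wide = zext <16 x i8> %s to <16 x i64>
  %mult = mul nuw nsw <16 x i64> %s.wide, %u.wide
  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <16 x i64> %mult)
  ret <2 x i64> %partial.reduce
}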