[AArch64] Enable fixed-length vector support for partial-reductions #142032

Merged

Conversation

sdesmalen-arm (Collaborator) commented May 29, 2025

This enables the use of the [us]dot, [us]add[wt] and [us]mlal[bt] instructions in Streaming mode, and for wider vectors when the runtime vector length is known to be 256 bits or larger.
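
As a minimal sketch of the kind of IR this covers (the function name and the 256-bit vscale_range assumption are illustrative, mirroring the tests added below), a four-way i8 -> i32 partial reduction over 256-bit fixed-length vectors can now be lowered to an SVE dot instruction instead of being split into NEON-sized pieces:

; Illustrative example, not taken verbatim from the patch: with a known
; 256-bit vector length, the whole reduction can map onto an SVE udot.
define <8 x i32> @four_way_i8_i32_vl256_example(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
  %acc = load <8 x i32>, ptr %accptr
  %u = load <32 x i8>, ptr %uptr
  %s = load <32 x i8>, ptr %sptr
  %u.wide = zext <32 x i8> %u to <32 x i32>
  %s.wide = zext <32 x i8> %s to <32 x i32>
  %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
  ret <8 x i32> %partial.reduce
}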

llvmbot (Member) commented May 29, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Sander de Smalen (sdesmalen-arm)

Changes

Patch is 34.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142032.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+66-27)
  • (added) llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll (+791)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a07afea963e20..34ce8bb7dd2b2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1930,6 +1930,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                          Custom);
       setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
                          Custom);
+
+      // Must be lowered to SVE instructions.
+      setPartialReduceMLAAction(MVT::v2i64, MVT::v4i32, Custom);
+      setPartialReduceMLAAction(MVT::v2i64, MVT::v8i16, Custom);
+      setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
+      setPartialReduceMLAAction(MVT::v4i32, MVT::v8i16, Custom);
+      setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Custom);
+      setPartialReduceMLAAction(MVT::v8i16, MVT::v16i8, Custom);
     }
   }
 
@@ -2225,6 +2233,26 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
   bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
   bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
 
+  if (EnablePartialReduceNodes) {
+    unsigned NumElts = VT.getVectorNumElements();
+    if (VT.getVectorElementType() == MVT::i64) {
+      setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 8),
+                                Custom);
+      setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i16, NumElts * 4),
+                                Custom);
+      setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i32, NumElts * 2),
+                                Custom);
+    } else if (VT.getVectorElementType() == MVT::i32) {
+      setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 4),
+                                Custom);
+      setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i16, NumElts * 2),
+                                Custom);
+    } else if (VT.getVectorElementType() == MVT::i16) {
+      setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 2),
+                                Custom);
+    }
+  }
+
   // Lower fixed length vector operations to scalable equivalents.
   setOperationAction(ISD::ABDS, VT, Default);
   setOperationAction(ISD::ABDU, VT, Default);
@@ -29224,50 +29252,61 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
 SDValue
 AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
                                                SelectionDAG &DAG) const {
-  bool Scalable = Op.getValueType().isScalableVector();
-
-  assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
-         "SVE or StreamingSVE must be available when using scalable vectors.");
-  assert((Scalable || Subtarget->hasDotProd()) &&
-         "Dotprod must be available when targeting NEON dot product "
-         "instructions.");
-
   SDLoc DL(Op);
 
   SDValue Acc = Op.getOperand(0);
   SDValue LHS = Op.getOperand(1);
   SDValue RHS = Op.getOperand(2);
   EVT ResultVT = Op.getValueType();
+  EVT OrigResultVT = ResultVT;
+  EVT OpVT = LHS.getValueType();
 
-  assert((Scalable && ResultVT == MVT::nxv2i64 &&
-          LHS.getValueType() == MVT::nxv16i8) ||
-         (!Scalable && ResultVT == MVT::v2i64 &&
-          LHS.getValueType() == MVT::v16i8));
+  bool ConvertToScalable =
+      ResultVT.isFixedLengthVector() &&
+      useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
 
-  EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
+  if (ConvertToScalable) {
+    ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
+    OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
+    Acc = convertToScalableVector(DAG, ResultVT, Acc);
+    LHS = convertToScalableVector(DAG, OpVT, LHS);
+    RHS = convertToScalableVector(DAG, OpVT, RHS);
+    Op = DAG.getNode(Op.getOpcode(), DL, ResultVT, {Acc, LHS, RHS});
+  }
+
+  // Two-way and four-way partial reductions are supported by patterns.
+  // We only need to handle the 8-way partial reduction.
+  if (ResultVT.getScalarType() != MVT::i64 || OpVT.getScalarType() != MVT::i8)
+    return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Op)
+                             : Op;
+
+  EVT DotVT = ResultVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
   SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
                                 DAG.getConstant(0, DL, DotVT), LHS, RHS);
 
+  SDValue Res;
   bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
-  if (Scalable &&
-      (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
+  if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
     unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
     unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
     SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
-    return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
-  }
-
-  // Fold (nx)v4i32 into (nx)v2i64
-  auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
-  if (IsUnsigned) {
-    DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
-    DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
+    Res = DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
   } else {
-    DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
-    DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
+    // Fold (nx)v4i32 into (nx)v2i64
+    auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
+    if (IsUnsigned) {
+      DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
+      DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
+    } else {
+      DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
+      DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
+    }
+    auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
+    Res = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
   }
-  auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
-  return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
+
+  return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Res)
+                           : Res;
 }
 
 SDValue
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
new file mode 100644
index 0000000000000..79d766d1b9908
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
@@ -0,0 +1,791 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
+; RUN: llc -mattr=+sve,+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
+; RUN: llc -mattr=+sme -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME
+
+target triple = "aarch64"
+
+;
+; Two-way mla (i8 -> i16)
+;
+
+define <8 x i16> @two_way_i8_i16_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i8_i16_vl128:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    umlal v0.8h, v2.8b, v1.8b
+; COMMON-NEXT:    umlal2 v0.8h, v2.16b, v1.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: two_way_i8_i16_vl128:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr q0, [x0]
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    umlalb z0.h, z2.b, z1.b
+; SME-NEXT:    umlalt z0.h, z2.b, z1.b
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <8 x i16>, ptr %accptr
+  %u = load <16 x i8>, ptr %uptr
+  %s = load <16 x i8>, ptr %sptr
+  %u.wide = zext <16 x i8> %u to <16 x i16>
+  %s.wide = zext <16 x i8> %s to <16 x i16>
+  %mult = mul nuw nsw <16 x i16> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i16> @llvm.experimental.vector.partial.reduce.add(<8 x i16> %acc, <16 x i16> %mult)
+  ret <8 x i16> %partial.reduce
+}
+
+define <16 x i16> @two_way_i8_i16_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i8_i16_vl128_double_width:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldp q0, q1, [x0]
+; COMMON-NEXT:    ldp q2, q3, [x1]
+; COMMON-NEXT:    ldp q4, q5, [x2]
+; COMMON-NEXT:    umlal v0.8h, v4.8b, v2.8b
+; COMMON-NEXT:    umlal v1.8h, v5.8b, v3.8b
+; COMMON-NEXT:    umlal2 v0.8h, v4.16b, v2.16b
+; COMMON-NEXT:    umlal2 v1.8h, v5.16b, v3.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: two_way_i8_i16_vl128_double_width:
+; SME:       // %bb.0:
+; SME-NEXT:    ldp q0, q1, [x0]
+; SME-NEXT:    ldp q3, q2, [x1]
+; SME-NEXT:    ldp q5, q4, [x2]
+; SME-NEXT:    umlalb z0.h, z5.b, z3.b
+; SME-NEXT:    umlalb z1.h, z4.b, z2.b
+; SME-NEXT:    umlalt z0.h, z5.b, z3.b
+; SME-NEXT:    umlalt z1.h, z4.b, z2.b
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <16 x i16>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i16>
+  %s.wide = zext <32 x i8> %s to <32 x i16>
+  %mult = mul nuw nsw <32 x i16> %s.wide, %u.wide
+  %partial.reduce = tail call <16 x i16> @llvm.experimental.vector.partial.reduce.add(<16 x i16> %acc, <32 x i16> %mult)
+  ret <16 x i16> %partial.reduce
+}
+
+define <16 x i16> @two_way_i8_i16_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: two_way_i8_i16_vl256:
+; NEON:       // %bb.0:
+; NEON-NEXT:    ldp q0, q1, [x0]
+; NEON-NEXT:    ldp q2, q3, [x1]
+; NEON-NEXT:    ldp q4, q5, [x2]
+; NEON-NEXT:    umlal v0.8h, v4.8b, v2.8b
+; NEON-NEXT:    umlal v1.8h, v5.8b, v3.8b
+; NEON-NEXT:    umlal2 v0.8h, v4.16b, v2.16b
+; NEON-NEXT:    umlal2 v1.8h, v5.16b, v3.16b
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: two_way_i8_i16_vl256:
+; SVE:       // %bb.0:
+; SVE-NEXT:    ldr z0, [x1]
+; SVE-NEXT:    ldr z1, [x2]
+; SVE-NEXT:    ptrue p0.h
+; SVE-NEXT:    ldr z4, [x0]
+; SVE-NEXT:    uunpklo z2.h, z0.b
+; SVE-NEXT:    uunpklo z3.h, z1.b
+; SVE-NEXT:    uunpkhi z0.h, z0.b
+; SVE-NEXT:    uunpkhi z1.h, z1.b
+; SVE-NEXT:    mad z2.h, p0/m, z3.h, z4.h
+; SVE-NEXT:    mad z0.h, p0/m, z1.h, z2.h
+; SVE-NEXT:    mov z1.d, z0.d
+; SVE-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT:    ret
+;
+; SME-LABEL: two_way_i8_i16_vl256:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr z0, [x0]
+; SME-NEXT:    ldr z1, [x1]
+; SME-NEXT:    ldr z2, [x2]
+; SME-NEXT:    umlalb z0.h, z2.b, z1.b
+; SME-NEXT:    umlalt z0.h, z2.b, z1.b
+; SME-NEXT:    mov z1.d, z0.d
+; SME-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <16 x i16>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i16>
+  %s.wide = zext <32 x i8> %s to <32 x i16>
+  %mult = mul nuw nsw <32 x i16> %s.wide, %u.wide
+  %partial.reduce = tail call <16 x i16> @llvm.experimental.vector.partial.reduce.add(<16 x i16> %acc, <32 x i16> %mult)
+  ret <16 x i16> %partial.reduce
+}
+
+;
+; Two-way mla (i16 -> i32)
+;
+
+define <4 x i32> @two_way_i16_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i16_i32_vl128:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    umlal v0.4s, v2.4h, v1.4h
+; COMMON-NEXT:    umlal2 v0.4s, v2.8h, v1.8h
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: two_way_i16_i32_vl128:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr q0, [x0]
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    umlalb z0.s, z2.h, z1.h
+; SME-NEXT:    umlalt z0.s, z2.h, z1.h
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <4 x i32>, ptr %accptr
+  %u = load <8 x i16>, ptr %uptr
+  %s = load <8 x i16>, ptr %sptr
+  %u.wide = zext <8 x i16> %u to <8 x i32>
+  %s.wide = zext <8 x i16> %s to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <8 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <8 x i32> @two_way_i16_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i16_i32_vl128_double_width:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldp q0, q1, [x0]
+; COMMON-NEXT:    ldp q2, q3, [x1]
+; COMMON-NEXT:    ldp q4, q5, [x2]
+; COMMON-NEXT:    umlal v0.4s, v4.4h, v2.4h
+; COMMON-NEXT:    umlal v1.4s, v5.4h, v3.4h
+; COMMON-NEXT:    umlal2 v0.4s, v4.8h, v2.8h
+; COMMON-NEXT:    umlal2 v1.4s, v5.8h, v3.8h
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: two_way_i16_i32_vl128_double_width:
+; SME:       // %bb.0:
+; SME-NEXT:    ldp q0, q1, [x0]
+; SME-NEXT:    ldp q3, q2, [x1]
+; SME-NEXT:    ldp q5, q4, [x2]
+; SME-NEXT:    umlalb z0.s, z5.h, z3.h
+; SME-NEXT:    umlalb z1.s, z4.h, z2.h
+; SME-NEXT:    umlalt z0.s, z5.h, z3.h
+; SME-NEXT:    umlalt z1.s, z4.h, z2.h
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <8 x i32>, ptr %accptr
+  %u = load <16 x i16>, ptr %uptr
+  %s = load <16 x i16>, ptr %sptr
+  %u.wide = zext <16 x i16> %u to <16 x i32>
+  %s.wide = zext <16 x i16> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <16 x i32> %mult)
+  ret <8 x i32> %partial.reduce
+}
+
+define <8 x i32> @two_way_i16_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: two_way_i16_i32_vl256:
+; NEON:       // %bb.0:
+; NEON-NEXT:    ldp q0, q1, [x0]
+; NEON-NEXT:    ldp q2, q3, [x1]
+; NEON-NEXT:    ldp q4, q5, [x2]
+; NEON-NEXT:    umlal v0.4s, v4.4h, v2.4h
+; NEON-NEXT:    umlal v1.4s, v5.4h, v3.4h
+; NEON-NEXT:    umlal2 v0.4s, v4.8h, v2.8h
+; NEON-NEXT:    umlal2 v1.4s, v5.8h, v3.8h
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: two_way_i16_i32_vl256:
+; SVE:       // %bb.0:
+; SVE-NEXT:    ldr z0, [x1]
+; SVE-NEXT:    ldr z1, [x2]
+; SVE-NEXT:    ptrue p0.s
+; SVE-NEXT:    ldr z4, [x0]
+; SVE-NEXT:    uunpklo z2.s, z0.h
+; SVE-NEXT:    uunpklo z3.s, z1.h
+; SVE-NEXT:    uunpkhi z0.s, z0.h
+; SVE-NEXT:    uunpkhi z1.s, z1.h
+; SVE-NEXT:    mad z2.s, p0/m, z3.s, z4.s
+; SVE-NEXT:    mad z0.s, p0/m, z1.s, z2.s
+; SVE-NEXT:    mov z1.d, z0.d
+; SVE-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT:    ret
+;
+; SME-LABEL: two_way_i16_i32_vl256:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr z0, [x0]
+; SME-NEXT:    ldr z1, [x1]
+; SME-NEXT:    ldr z2, [x2]
+; SME-NEXT:    umlalb z0.s, z2.h, z1.h
+; SME-NEXT:    umlalt z0.s, z2.h, z1.h
+; SME-NEXT:    mov z1.d, z0.d
+; SME-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <8 x i32>, ptr %accptr
+  %u = load <16 x i16>, ptr %uptr
+  %s = load <16 x i16>, ptr %sptr
+  %u.wide = zext <16 x i16> %u to <16 x i32>
+  %s.wide = zext <16 x i16> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <16 x i32> %mult)
+  ret <8 x i32> %partial.reduce
+}
+
+;
+; Two-way mla (i32 -> i64)
+;
+
+define <2 x i64> @two_way_i32_i64_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i32_i64_vl128:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    umlal v0.2d, v2.2s, v1.2s
+; COMMON-NEXT:    umlal2 v0.2d, v2.4s, v1.4s
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: two_way_i32_i64_vl128:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr q0, [x0]
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    umlalb z0.d, z2.s, z1.s
+; SME-NEXT:    umlalt z0.d, z2.s, z1.s
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <2 x i64>, ptr %accptr
+  %u = load <4 x i32>, ptr %uptr
+  %s = load <4 x i32>, ptr %sptr
+  %u.wide = zext <4 x i32> %u to <4 x i64>
+  %s.wide = zext <4 x i32> %s to <4 x i64>
+  %mult = mul nuw nsw <4 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <4 x i64> %mult)
+  ret <2 x i64> %partial.reduce
+}
+
+define <4 x i64> @two_way_i32_i64_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: two_way_i32_i64_vl128_double_width:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldp q0, q1, [x0]
+; COMMON-NEXT:    ldp q2, q3, [x1]
+; COMMON-NEXT:    ldp q4, q5, [x2]
+; COMMON-NEXT:    umlal v0.2d, v4.2s, v2.2s
+; COMMON-NEXT:    umlal v1.2d, v5.2s, v3.2s
+; COMMON-NEXT:    umlal2 v0.2d, v4.4s, v2.4s
+; COMMON-NEXT:    umlal2 v1.2d, v5.4s, v3.4s
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: two_way_i32_i64_vl128_double_width:
+; SME:       // %bb.0:
+; SME-NEXT:    ldp q0, q1, [x0]
+; SME-NEXT:    ldp q3, q2, [x1]
+; SME-NEXT:    ldp q5, q4, [x2]
+; SME-NEXT:    umlalb z0.d, z5.s, z3.s
+; SME-NEXT:    umlalb z1.d, z4.s, z2.s
+; SME-NEXT:    umlalt z0.d, z5.s, z3.s
+; SME-NEXT:    umlalt z1.d, z4.s, z2.s
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <4 x i64>, ptr %accptr
+  %u = load <8 x i32>, ptr %uptr
+  %s = load <8 x i32>, ptr %sptr
+  %u.wide = zext <8 x i32> %u to <8 x i64>
+  %s.wide = zext <8 x i32> %s to <8 x i64>
+  %mult = mul nuw nsw <8 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add(<4 x i64> %acc, <8 x i64> %mult)
+  ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @two_way_i32_i64_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: two_way_i32_i64_vl256:
+; NEON:       // %bb.0:
+; NEON-NEXT:    ldp q0, q1, [x0]
+; NEON-NEXT:    ldp q2, q3, [x1]
+; NEON-NEXT:    ldp q4, q5, [x2]
+; NEON-NEXT:    umlal v0.2d, v4.2s, v2.2s
+; NEON-NEXT:    umlal v1.2d, v5.2s, v3.2s
+; NEON-NEXT:    umlal2 v0.2d, v4.4s, v2.4s
+; NEON-NEXT:    umlal2 v1.2d, v5.4s, v3.4s
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: two_way_i32_i64_vl256:
+; SVE:       // %bb.0:
+; SVE-NEXT:    ldr z0, [x1]
+; SVE-NEXT:    ldr z1, [x2]
+; SVE-NEXT:    ptrue p0.d
+; SVE-NEXT:    ldr z4, [x0]
+; SVE-NEXT:    uunpklo z2.d, z0.s
+; SVE-NEXT:    uunpklo z3.d, z1.s
+; SVE-NEXT:    uunpkhi z0.d, z0.s
+; SVE-NEXT:    uunpkhi z1.d, z1.s
+; SVE-NEXT:    mad z2.d, p0/m, z3.d, z4.d
+; SVE-NEXT:    mad z0.d, p0/m, z1.d, z2.d
+; SVE-NEXT:    mov z1.d, z0.d
+; SVE-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT:    ret
+;
+; SME-LABEL: two_way_i32_i64_vl256:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr z0, [x0]
+; SME-NEXT:    ldr z1, [x1]
+; SME-NEXT:    ldr z2, [x2]
+; SME-NEXT:    umlalb z0.d, z2.s, z1.s
+; SME-NEXT:    umlalt z0.d, z2.s, z1.s
+; SME-NEXT:    mov z1.d, z0.d
+; SME-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <4 x i64>, ptr %accptr
+  %u = load <8 x i32>, ptr %uptr
+  %s = load <8 x i32>, ptr %sptr
+  %u.wide = zext <8 x i32> %u to <8 x i64>
+  %s.wide = zext <8 x i32> %s to <8 x i64>
+  %mult = mul nuw nsw <8 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add(<4 x i64> %acc, <8 x i64> %mult)
+  ret <4 x i64> %partial.reduce
+}
+
+
+;
+; Four-way dot (i8 -> i32)
+;
+
+define <4 x i32> @four_way_i8_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i8_i32_vl128:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    udot v0.4s, v2.16b, v1...
[truncated]

Comment on lines 2236 to 2254
if (EnablePartialReduceNodes) {
  unsigned NumElts = VT.getVectorNumElements();
  if (VT.getVectorElementType() == MVT::i64) {
    setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 8),
                              Custom);
    setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i16, NumElts * 4),
                              Custom);
    setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i32, NumElts * 2),
                              Custom);
  } else if (VT.getVectorElementType() == MVT::i32) {
    setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 4),
                              Custom);
    setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i16, NumElts * 2),
                              Custom);
  } else if (VT.getVectorElementType() == MVT::i16) {
    setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 2),
                              Custom);
  }
}
Contributor commented:

How does this approach differ from the explicit assignments above? Do we need both, or can we get rid of one of them?

sdesmalen-arm (Collaborator, Author) replied:

The difference from the explicit assignments above is that we don't know the exact width of the vectors, because it applies to all legal vectors up to the target's vector length. For a VF=512, it could be e.g. partial.reduce(<64 x i8>) -> <8 x i64> or partial.reduce(<32 x i8>) -> <4 x i64>.
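
As a concrete sketch of the first case (hypothetical, assuming a target with a 512-bit vector length, i.e. vscale_range(4,4); the function and value names are made up for illustration):

define <8 x i64> @eight_way_i8_i64_vl512_example(<8 x i64> %acc, <64 x i8> %a, <64 x i8> %b) vscale_range(4,4) {
  ; The operand width (<64 x i8>) follows from the 512-bit vector length,
  ; which is why the MLA actions are registered per legal fixed-length VT
  ; rather than as a fixed list of types.
  %a.wide = zext <64 x i8> %a to <64 x i64>
  %b.wide = zext <64 x i8> %b to <64 x i64>
  %mult = mul nuw nsw <64 x i64> %a.wide, %b.wide
  %red = tail call <8 x i64> @llvm.experimental.vector.partial.reduce.add(<8 x i64> %acc, <64 x i64> %mult)
  ret <8 x i64> %red
}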

Contributor replied:

Aha, I'd missed that this function is only entered if the VT width is greater than 128.

Comment on lines +32 to +34
%acc = load <8 x i16>, ptr %accptr
%u = load <16 x i8>, ptr %uptr
%s = load <16 x i8>, ptr %sptr
Contributor commented:

Maybe nit: Is there a specific need for loading from a ptr? In the other partial reduce tests we have the operands come in as their types directly, rather than via ptr/load.
It would also clean up the checked asm by removing the ldrs.

define <8 x i16> @two_way_i8_i16_vl128(<8 x i16> %acc, <16 x i8> %u, <16 x i8> %s) {

sdesmalen-arm (Collaborator, Author) replied:

For these (128-bit vector) types maybe not, but wider vectors (e.g. 256- or 512-bit) are passed in 128-bit NEON registers as per the ABI and would need to be combined into a single SVE register, which makes the code look messier than just doing an SVE vector load from a pointer.
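
For comparison, a hypothetical direct-argument form of the 256-bit test (signature made up for illustration); following the point above, each 256-bit argument would arrive split across two 128-bit Q registers and would first have to be re-assembled into a single Z register before the SVE instructions:

define <16 x i16> @two_way_i8_i16_vl256_args(<16 x i16> %acc, <32 x i8> %u, <32 x i8> %s) vscale_range(2,2) {
  ; Same computation as two_way_i8_i16_vl256, but with the wide fixed-length
  ; vectors passed directly; the extra register re-assembly is what the
  ; load-from-pointer form used in the tests avoids.
  %u.wide = zext <32 x i8> %u to <32 x i16>
  %s.wide = zext <32 x i8> %s to <32 x i16>
  %mult = mul nuw nsw <32 x i16> %s.wide, %u.wide
  %partial.reduce = tail call <16 x i16> @llvm.experimental.vector.partial.reduce.add(<16 x i16> %acc, <32 x i16> %mult)
  ret <16 x i16> %partial.reduce
}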

sdesmalen-arm force-pushed the fixed-length-partial-reductions branch from 88aec0f to fd66c5f on May 30, 2025, 12:11
NickGuy-Arm (Contributor) left a comment:

Looks good to me. Though you will need to rebase it on top of 1651aa2

sdesmalen-arm force-pushed the fixed-length-partial-reductions branch from fd66c5f to 0fa3bd5 on May 30, 2025, 15:20
sdesmalen-arm merged commit 12bd049 into llvm:main on May 30, 2025
11 checks passed