Skip to content

[AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) -> scvtf(x, 2) #141480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jph-13
Copy link
Contributor

@jph-13 jph-13 commented May 26, 2025

This commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability. See: #91924

This update targets fmul(sitofp(x), C) where C is a constant reciprocal of a power of two. For both scalar and vector inputs, if we have sitofp(X) * C (where C is 1/2^N), this can be optimized to scvtf(X, 2^N). This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.

@llvmbot
Copy link
Member

llvmbot commented May 26, 2025

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-aarch64

Author: JP Hafer (jph-13)

Changes

This commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability. See: #91924

This update targets fmul(sitofp(x), C) where C is a constant reciprocal of a power of two. For both scalar and vector inputs, if we have sitofp(X) * C (where C is 1/2^N), this can be optimized to scvtf(X, 2^N). This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.


Full diff: https://github.com/llvm/llvm-project/pull/141480.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+152)
  • (added) llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll (+47)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f2800145cc603..bb094d9772c47 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1148,6 +1148,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine({ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
                        ISD::FP_TO_UINT_SAT, ISD::FADD});
 
+  // Try to fmul -> scvtf for powers of 2
+  setTargetDAGCombine(ISD::FMUL);
+
   // Try and combine setcc with csel
   setTargetDAGCombine(ISD::SETCC);
 
@@ -19250,6 +19253,153 @@ static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG,
   return FixConv;
 }
 
+/// Try to extract a log2 exponent from a uniform constant FP splat.
+/// Returns -1 if the value is not a power-of-two float.
+static int getUniformFPSplatLog2(const BuildVectorSDNode *BV, unsigned MaxExponent) {
+  SDValue FirstElt = BV->getOperand(0);
+  if (!isa<ConstantFPSDNode>(FirstElt))
+    return -1;
+
+  const ConstantFPSDNode *FirstConst = cast<ConstantFPSDNode>(FirstElt);
+  const APFloat &FirstVal = FirstConst->getValueAPF();
+  const fltSemantics &Sem = FirstVal.getSemantics();
+
+  // Check all elements are the same
+  for (unsigned i = 1, e = BV->getNumOperands(); i != e; ++i) {
+    SDValue Elt = BV->getOperand(i);
+    if (!isa<ConstantFPSDNode>(Elt))
+      return -1;
+    const APFloat &Val = cast<ConstantFPSDNode>(Elt)->getValueAPF();
+    if (!Val.bitwiseIsEqual(FirstVal))
+      return -1;
+  }
+
+  // Reject zero, NaN, or negative values
+  if (FirstVal.isZero() || FirstVal.isNaN() || FirstVal.isNegative())
+    return -1;
+
+  // Get raw bits
+  APInt Bits = FirstVal.bitcastToAPInt();
+
+  int ExponentBias = 0;
+  unsigned ExponentBits = 0;
+  unsigned MantissaBits = 0;
+
+  if (&Sem == &APFloat::IEEEsingle()) {
+    ExponentBias = 127;
+    ExponentBits = 8;
+    MantissaBits = 23;
+  } else if (&Sem == &APFloat::IEEEdouble()) {
+    ExponentBias = 1023;
+    ExponentBits = 11;
+    MantissaBits = 52;
+  } else {
+    // Unsupported type
+    return -1;
+  }
+
+  // Mask out mantissa and check it's zero (i.e., power of two)
+  APInt MantissaMask = APInt::getLowBitsSet(Bits.getBitWidth(), MantissaBits);
+  if ((Bits & MantissaMask) != 0)
+    return -1;
+
+  // Extract exponent
+  unsigned ExponentShift = MantissaBits;
+  APInt ExponentMask = APInt::getBitsSet(Bits.getBitWidth(),
+                                         ExponentShift,
+                                         ExponentShift + ExponentBits);
+  int Exponent = (Bits & ExponentMask).lshr(ExponentShift).getZExtValue();
+  int Log2 = ExponentBias - Exponent;
+
+  if (static_cast<unsigned>(Log2) > MaxExponent)
+    return -1;
+
+  return Log2;
+}
+
+/// Fold a floating-point multiply by power of two into fixed-point to
+/// floating-point conversion.
+static SDValue performFMulCombine(SDNode *N, SelectionDAG &DAG,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  const AArch64Subtarget *Subtarget) {
+                                    
+  if (!Subtarget->hasNEON())
+    return SDValue();
+
+  // N is the FMUL node.
+  if (N->getOpcode() != ISD::FMUL)
+      return SDValue();
+
+  // SINT_TO_FP or UINT_TO_FP
+  SDValue Op = N->getOperand(0);
+  unsigned Opc = Op->getOpcode();
+  if (!Op.getValueType().isVector() || !Op.getValueType().isSimple() ||
+      !Op.getOperand(0).getValueType().isSimple() ||
+      (Opc != ISD::SINT_TO_FP && Opc != ISD::UINT_TO_FP))
+    return SDValue();
+
+  SDValue ConstVec = N->getOperand(1);
+  if (!isa<BuildVectorSDNode>(ConstVec))
+    return SDValue();
+
+  MVT IntTy = Op.getOperand(0).getSimpleValueType().getVectorElementType();
+  int32_t IntBits = IntTy.getSizeInBits();
+  if (IntBits != 16 && IntBits != 32 && IntBits != 64)
+    return SDValue();
+
+  MVT FloatTy = N->getSimpleValueType(0).getVectorElementType();
+  int32_t FloatBits = FloatTy.getSizeInBits();
+  if (FloatBits != 32 && FloatBits != 64)
+    return SDValue();
+
+  if (IntBits > FloatBits)
+    return SDValue();
+
+  BitVector UndefElements;
+  BuildVectorSDNode *BV = cast<BuildVectorSDNode>(ConstVec);
+  int32_t IntrinsicC = getUniformFPSplatLog2(BV, FloatBits + 1);
+
+  // Handle cases where it's not a power of two, or is 2^0.
+  if (IntrinsicC == -1 || IntrinsicC == 0)
+    return SDValue();
+
+  // Check if IntrinsicC is within the valid range [1, FloatBits].
+  // The 's' value must be in [1, FloatBits].
+  if (IntrinsicC <= 0 || IntrinsicC > FloatBits)
+      return SDValue();
+
+  MVT ResTy;
+  unsigned NumLanes = Op.getValueType().getVectorNumElements();
+  switch (NumLanes) {
+  default:
+    return SDValue();
+  case 2:
+    ResTy = FloatBits == 32 ? MVT::v2i32 : MVT::v2i64;
+    break;
+  case 4:
+    ResTy = FloatBits == 32 ? MVT::v4i32 : MVT::v4i64;
+    break;
+  }
+
+  if (ResTy == MVT::v4i64 && DCI.isBeforeLegalizeOps())
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue ConvInput = Op.getOperand(0);
+  bool IsSigned = Opc == ISD::SINT_TO_FP;
+
+  if (IntBits < FloatBits)
+    ConvInput = DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL,
+                            ResTy, ConvInput);
+
+  unsigned IntrinsicOpcode = IsSigned ? Intrinsic::aarch64_neon_vcvtfxs2fp
+                                      : Intrinsic::aarch64_neon_vcvtfxu2fp;
+
+  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
+                     DAG.getConstant(IntrinsicOpcode, DL, MVT::i32), ConvInput,
+                     DAG.getConstant(IntrinsicC, DL, MVT::i32));
+}
+
 static SDValue tryCombineToBSL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                                const AArch64TargetLowering &TLI) {
   EVT VT = N->getValueType(0);
@@ -26693,6 +26843,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::FP_TO_SINT_SAT:
   case ISD::FP_TO_UINT_SAT:
     return performFpToIntCombine(N, DAG, DCI, Subtarget);
+  case ISD::FMUL:
+    return performFMulCombine(N, DAG, DCI, Subtarget);
   case ISD::OR:
     return performORCombine(N, DCI, Subtarget, *this);
   case ISD::AND:
diff --git a/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll b/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
new file mode 100644
index 0000000000000..befddb165fcce
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
@@ -0,0 +1,47 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -aarch64-neon-syntax=apple -verify-machineinstrs -o - %s | FileCheck %s
+
+; Test case 1: Scalar fdiv by 16.0
+define float @tests(i32 %in) {
+; CHECK-LABEL: tests:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    scvtf   s0, w0, #4
+; CHECK-NEXT:    ret
+entry:
+  %vcvt.i = sitofp i32 %in to float
+  %div.i = fdiv float %vcvt.i, 16.0
+  ret float %div.i
+}
+
+; Test case 2: Scalar fmul by (2^-4)
+define float @testsmul(i32 %in) local_unnamed_addr #0 {
+; CHECK-LABEL: testsmul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    scvtf   s0, w0, #4
+; CHECK-NEXT:    ret
+  %vcvt.i = sitofp i32 %in to float
+  %div.i = fmul float %vcvt.i, 6.250000e-02 ; 0.0625 is 2^-4
+  ret float %div.i
+}
+
+; Test case 3: Vector fdiv by 16.0
+define <2 x float> @testv(<2 x i32> %in) {
+; CHECK-LABEL: testv:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    scvtf.2s        v0, v0, #4
+; CHECK-NEXT:    ret
+entry:
+  %vcvt.i = sitofp <2 x i32> %in to <2 x float>
+  %div.i = fdiv <2 x float> %vcvt.i, <float 16.0, float 16.0>
+  ret <2 x float> %div.i
+}
+
+; Test case 4: Vector fmul by 2^-4 
+define <2 x float> @testvmul(<2 x i32> %in) local_unnamed_addr #0 {
+; CHECK-LABEL: testvmul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    scvtf.2s        v0, v0, #4
+; CHECK-NEXT:    ret
+  %vcvt.i = sitofp <2 x i32> %in to <2 x float>
+  %div.i = fmul <2 x float> %vcvt.i, splat (float 6.250000e-02) ; 0.0625 is 2^-4
+  ret <2 x float> %div.i
+}
\ No newline at end of file

Copy link

github-actions bot commented May 26, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@davemgreen
Copy link
Collaborator

It looks like the existing scalar code uses tablegen patterns, via SelectCVTFixedPosRecipOperand. Would it be possible to do the same for vector operations? It would need to detect the splat constant.

@jph-13
Copy link
Contributor Author

jph-13 commented May 27, 2025

I will take a look at the tablegen impls. Maybe that will help me understand why I am having issues with half too.

@jph-13
Copy link
Contributor Author

jph-13 commented May 30, 2025

I just resolved all the original flags since the new implementation is very different. I did try to get f16 working but I became very confused. As of now it doesn't appear to have a match in TD. I started creating one but I am not sure if I shold replace all the round tripping or not. So I figured I would see if we could get this in, then maybe try another pass at half later.

@jph-13 jph-13 force-pushed the issue_94909 branch 2 times, most recently from fbb93fd to 67a4484 Compare May 30, 2025 18:54
@jph-13
Copy link
Contributor Author

jph-13 commented May 30, 2025

Please ignore for now, not sure what I broke when I squashed. Sorry.

@jph-13
Copy link
Contributor Author

jph-13 commented Jun 10, 2025

This is incomplete, but I could really use some help. I have never touched tablegen before this and feel I am making it way too difficult.

I am having specific issues with the smaller registers and sizes (are my matchers too restrictive?). I left in commented out vNi16 code for I can't get any of it to not conflict. The v1i32 and v1i64 are also escaping me. As for the f16 implementations, it seems I need to handle a sext, but I figure the rest should be working first.

I could use any comments or guidance folks have.

Thanks.

Copy link
Collaborator

@davemgreen davemgreen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this got quite far. I've left a few comments inline, tablegen can be finicky at times.

@jph-13 jph-13 marked this pull request as draft June 16, 2025 13:52
…tf(x, 2)

This commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability.
See: llvm#91924

This update targets `fmul(sitofp(x), C)` where `C` is a constant reciprocal of a power of two. For both scalar and vector inputs, if we have `sitofp(X) * C` (where `C` is `1/2^N`), this can be optimized to `scvtf(X, 2^N)`. This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.
@jph-13 jph-13 marked this pull request as ready for review June 18, 2025 15:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants