-
Notifications
You must be signed in to change notification settings - Fork 14.1k
[AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) -> scvtf(x, 2) #141480
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-backend-aarch64 Author: JP Hafer (jph-13) ChangesThis commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability. See: #91924 This update targets Full diff: https://github.com/llvm/llvm-project/pull/141480.diff 2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f2800145cc603..bb094d9772c47 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1148,6 +1148,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine({ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
ISD::FP_TO_UINT_SAT, ISD::FADD});
+ // Try to fmul -> scvtf for powers of 2
+ setTargetDAGCombine(ISD::FMUL);
+
// Try and combine setcc with csel
setTargetDAGCombine(ISD::SETCC);
@@ -19250,6 +19253,153 @@ static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG,
return FixConv;
}
+/// Try to extract a log2 exponent from a uniform constant FP splat.
+/// Returns -1 if the value is not a power-of-two float.
+static int getUniformFPSplatLog2(const BuildVectorSDNode *BV, unsigned MaxExponent) {
+ SDValue FirstElt = BV->getOperand(0);
+ if (!isa<ConstantFPSDNode>(FirstElt))
+ return -1;
+
+ const ConstantFPSDNode *FirstConst = cast<ConstantFPSDNode>(FirstElt);
+ const APFloat &FirstVal = FirstConst->getValueAPF();
+ const fltSemantics &Sem = FirstVal.getSemantics();
+
+ // Check all elements are the same
+ for (unsigned i = 1, e = BV->getNumOperands(); i != e; ++i) {
+ SDValue Elt = BV->getOperand(i);
+ if (!isa<ConstantFPSDNode>(Elt))
+ return -1;
+ const APFloat &Val = cast<ConstantFPSDNode>(Elt)->getValueAPF();
+ if (!Val.bitwiseIsEqual(FirstVal))
+ return -1;
+ }
+
+ // Reject zero, NaN, or negative values
+ if (FirstVal.isZero() || FirstVal.isNaN() || FirstVal.isNegative())
+ return -1;
+
+ // Get raw bits
+ APInt Bits = FirstVal.bitcastToAPInt();
+
+ int ExponentBias = 0;
+ unsigned ExponentBits = 0;
+ unsigned MantissaBits = 0;
+
+ if (&Sem == &APFloat::IEEEsingle()) {
+ ExponentBias = 127;
+ ExponentBits = 8;
+ MantissaBits = 23;
+ } else if (&Sem == &APFloat::IEEEdouble()) {
+ ExponentBias = 1023;
+ ExponentBits = 11;
+ MantissaBits = 52;
+ } else {
+ // Unsupported type
+ return -1;
+ }
+
+ // Mask out mantissa and check it's zero (i.e., power of two)
+ APInt MantissaMask = APInt::getLowBitsSet(Bits.getBitWidth(), MantissaBits);
+ if ((Bits & MantissaMask) != 0)
+ return -1;
+
+ // Extract exponent
+ unsigned ExponentShift = MantissaBits;
+ APInt ExponentMask = APInt::getBitsSet(Bits.getBitWidth(),
+ ExponentShift,
+ ExponentShift + ExponentBits);
+ int Exponent = (Bits & ExponentMask).lshr(ExponentShift).getZExtValue();
+ int Log2 = ExponentBias - Exponent;
+
+ if (static_cast<unsigned>(Log2) > MaxExponent)
+ return -1;
+
+ return Log2;
+}
+
+/// Fold a floating-point multiply by power of two into fixed-point to
+/// floating-point conversion.
+static SDValue performFMulCombine(SDNode *N, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const AArch64Subtarget *Subtarget) {
+
+ if (!Subtarget->hasNEON())
+ return SDValue();
+
+ // N is the FMUL node.
+ if (N->getOpcode() != ISD::FMUL)
+ return SDValue();
+
+ // SINT_TO_FP or UINT_TO_FP
+ SDValue Op = N->getOperand(0);
+ unsigned Opc = Op->getOpcode();
+ if (!Op.getValueType().isVector() || !Op.getValueType().isSimple() ||
+ !Op.getOperand(0).getValueType().isSimple() ||
+ (Opc != ISD::SINT_TO_FP && Opc != ISD::UINT_TO_FP))
+ return SDValue();
+
+ SDValue ConstVec = N->getOperand(1);
+ if (!isa<BuildVectorSDNode>(ConstVec))
+ return SDValue();
+
+ MVT IntTy = Op.getOperand(0).getSimpleValueType().getVectorElementType();
+ int32_t IntBits = IntTy.getSizeInBits();
+ if (IntBits != 16 && IntBits != 32 && IntBits != 64)
+ return SDValue();
+
+ MVT FloatTy = N->getSimpleValueType(0).getVectorElementType();
+ int32_t FloatBits = FloatTy.getSizeInBits();
+ if (FloatBits != 32 && FloatBits != 64)
+ return SDValue();
+
+ if (IntBits > FloatBits)
+ return SDValue();
+
+ BitVector UndefElements;
+ BuildVectorSDNode *BV = cast<BuildVectorSDNode>(ConstVec);
+ int32_t IntrinsicC = getUniformFPSplatLog2(BV, FloatBits + 1);
+
+ // Handle cases where it's not a power of two, or is 2^0.
+ if (IntrinsicC == -1 || IntrinsicC == 0)
+ return SDValue();
+
+ // Check if IntrinsicC is within the valid range [1, FloatBits].
+ // The 's' value must be in [1, FloatBits].
+ if (IntrinsicC <= 0 || IntrinsicC > FloatBits)
+ return SDValue();
+
+ MVT ResTy;
+ unsigned NumLanes = Op.getValueType().getVectorNumElements();
+ switch (NumLanes) {
+ default:
+ return SDValue();
+ case 2:
+ ResTy = FloatBits == 32 ? MVT::v2i32 : MVT::v2i64;
+ break;
+ case 4:
+ ResTy = FloatBits == 32 ? MVT::v4i32 : MVT::v4i64;
+ break;
+ }
+
+ if (ResTy == MVT::v4i64 && DCI.isBeforeLegalizeOps())
+ return SDValue();
+
+ SDLoc DL(N);
+ SDValue ConvInput = Op.getOperand(0);
+ bool IsSigned = Opc == ISD::SINT_TO_FP;
+
+ if (IntBits < FloatBits)
+ ConvInput = DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL,
+ ResTy, ConvInput);
+
+ unsigned IntrinsicOpcode = IsSigned ? Intrinsic::aarch64_neon_vcvtfxs2fp
+ : Intrinsic::aarch64_neon_vcvtfxu2fp;
+
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
+ DAG.getConstant(IntrinsicOpcode, DL, MVT::i32), ConvInput,
+ DAG.getConstant(IntrinsicC, DL, MVT::i32));
+}
+
static SDValue tryCombineToBSL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64TargetLowering &TLI) {
EVT VT = N->getValueType(0);
@@ -26693,6 +26843,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::FP_TO_SINT_SAT:
case ISD::FP_TO_UINT_SAT:
return performFpToIntCombine(N, DAG, DCI, Subtarget);
+ case ISD::FMUL:
+ return performFMulCombine(N, DAG, DCI, Subtarget);
case ISD::OR:
return performORCombine(N, DCI, Subtarget, *this);
case ISD::AND:
diff --git a/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll b/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
new file mode 100644
index 0000000000000..befddb165fcce
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
@@ -0,0 +1,47 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -aarch64-neon-syntax=apple -verify-machineinstrs -o - %s | FileCheck %s
+
+; Test case 1: Scalar fdiv by 16.0
+define float @tests(i32 %in) {
+; CHECK-LABEL: tests:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: scvtf s0, w0, #4
+; CHECK-NEXT: ret
+entry:
+ %vcvt.i = sitofp i32 %in to float
+ %div.i = fdiv float %vcvt.i, 16.0
+ ret float %div.i
+}
+
+; Test case 2: Scalar fmul by (2^-4)
+define float @testsmul(i32 %in) local_unnamed_addr #0 {
+; CHECK-LABEL: testsmul:
+; CHECK: // %bb.0:
+; CHECK-NEXT: scvtf s0, w0, #4
+; CHECK-NEXT: ret
+ %vcvt.i = sitofp i32 %in to float
+ %div.i = fmul float %vcvt.i, 6.250000e-02 ; 0.0625 is 2^-4
+ ret float %div.i
+}
+
+; Test case 3: Vector fdiv by 16.0
+define <2 x float> @testv(<2 x i32> %in) {
+; CHECK-LABEL: testv:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: scvtf.2s v0, v0, #4
+; CHECK-NEXT: ret
+entry:
+ %vcvt.i = sitofp <2 x i32> %in to <2 x float>
+ %div.i = fdiv <2 x float> %vcvt.i, <float 16.0, float 16.0>
+ ret <2 x float> %div.i
+}
+
+; Test case 4: Vector fmul by 2^-4
+define <2 x float> @testvmul(<2 x i32> %in) local_unnamed_addr #0 {
+; CHECK-LABEL: testvmul:
+; CHECK: // %bb.0:
+; CHECK-NEXT: scvtf.2s v0, v0, #4
+; CHECK-NEXT: ret
+ %vcvt.i = sitofp <2 x i32> %in to <2 x float>
+ %div.i = fmul <2 x float> %vcvt.i, splat (float 6.250000e-02) ; 0.0625 is 2^-4
+ ret <2 x float> %div.i
+}
\ No newline at end of file
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
It looks like the existing scalar code uses tablegen patterns, via SelectCVTFixedPosRecipOperand. Would it be possible to do the same for vector operations? It would need to detect the splat constant. |
I will take a look at the tablegen impls. Maybe that will help me understand why I am having issues with half too. |
I just resolved all the original flags since the new implementation is very different. I did try to get f16 working but I became very confused. As of now it doesn't appear to have a match in TD. I started creating one but I am not sure if I shold replace all the round tripping or not. So I figured I would see if we could get this in, then maybe try another pass at half later. |
fbb93fd
to
67a4484
Compare
Please ignore for now, not sure what I broke when I squashed. Sorry. |
This is incomplete, but I could really use some help. I have never touched tablegen before this and feel I am making it way too difficult. I am having specific issues with the smaller registers and sizes (are my matchers too restrictive?). I left in commented out vNi16 code for I can't get any of it to not conflict. The v1i32 and v1i64 are also escaping me. As for the f16 implementations, it seems I need to handle a sext, but I figure the rest should be working first. I could use any comments or guidance folks have. Thanks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like this got quite far. I've left a few comments inline, tablegen can be finicky at times.
llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
Outdated
Show resolved
Hide resolved
…tf(x, 2) This commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability. See: llvm#91924 This update targets `fmul(sitofp(x), C)` where `C` is a constant reciprocal of a power of two. For both scalar and vector inputs, if we have `sitofp(X) * C` (where `C` is `1/2^N`), this can be optimized to `scvtf(X, 2^N)`. This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.
This commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability. See: #91924
This update targets
fmul(sitofp(x), C)
whereC
is a constant reciprocal of a power of two. For both scalar and vector inputs, if we havesitofp(X) * C
(whereC
is1/2^N
), this can be optimized toscvtf(X, 2^N)
. This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.