Skip to content

[VectorCombine] Support nary operands and intrinsics in scalarizeOpOrCmp #138406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 28, 2025

Conversation

lukel97
Copy link
Contributor

@lukel97 lukel97 commented May 3, 2025

This adds support for unary operands, and unary + ternary intrinsics in scalarizeOpOrCmp (FKA scalarizeBinOpOrCmp).

The motivation behind this is to scalarize more intrinsics in VectorCombine rather than in DAGCombine, so we can sink splats across basic blocks: see #137786

The main change required is to generalize the existing VecC0/VecC1 rules across n-ary ops:

  • An operand can either be a constant vector or an insert of a scalar into a constant vector
  • If it's an insert, the index needs to be static and in bounds
  • If it's an insert, all indices need to be the same across all operands
  • If all the operands are constant vectors, bail as it will get constant folded anyway

Copy link

github-actions bot commented May 3, 2025

✅ With the latest revision this PR passed the undef deprecator.

@lukel97 lukel97 force-pushed the vector-combine/scalarize-nary branch from 2698f8a to 8e12a1f Compare May 23, 2025 11:58
@lukel97 lukel97 changed the title [WIP][VectorCombine] Support nary intrinsics in scalarizeBinOpOrCmp [VectorCombine] Support nary operands and intrinsics in scalarizeOpOrCmp May 23, 2025
@lukel97 lukel97 marked this pull request as ready for review May 23, 2025 12:08
@llvmbot
Copy link
Member

llvmbot commented May 23, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Luke Lau (lukel97)

Changes

This adds support for unary operands, and unary + ternary intrinsics in scalarizeOpOrCmp (FKA scalarizeBinOpOrCmp).

The motivation behind this is to scalarize more intrinsics in VectorCombine rather than in DAGCombine, so we can sink splats across basic blocks: see #137786

The main change required is to generalize the existing VecC0/VecC1 rules across n-ary ops:

  • An operand can either be a constant vector or an insert of a scalar into a constant vector
  • If it's an insert, the index needs to be static and in bounds
  • If it's an insert, all indices need to be the same across all operands
  • If all the operands are constant vectors, bail as it will get constant folded anyway

Stacked on #137823


Patch is 36.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138406.diff

9 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+108-92)
  • (added) llvm/test/Transforms/VectorCombine/RISCV/binop-scalarize.ll (+16)
  • (modified) llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant-inseltpoison.ll (+4-4)
  • (modified) llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant.ll (+66-36)
  • (modified) llvm/test/Transforms/VectorCombine/X86/insert-binop.ll (+13-6)
  • (modified) llvm/test/Transforms/VectorCombine/X86/scalarize-cmp-inseltpoison.ll (+13-9)
  • (modified) llvm/test/Transforms/VectorCombine/X86/scalarize-cmp.ll (+22-13)
  • (modified) llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll (+56)
  • (added) llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll (+26)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index fe1d930f295ce..bf33292544497 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -19,6 +19,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -46,7 +47,7 @@ STATISTIC(NumVecCmp, "Number of vector compares formed");
 STATISTIC(NumVecBO, "Number of vector binops formed");
 STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
 STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
-STATISTIC(NumScalarBO, "Number of scalar binops formed");
+STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
 STATISTIC(NumScalarCmp, "Number of scalar compares formed");
 STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");
 
@@ -113,7 +114,7 @@ class VectorCombine {
   bool foldInsExtBinop(Instruction &I);
   bool foldInsExtVectorToShuffle(Instruction &I);
   bool foldBitcastShuffle(Instruction &I);
-  bool scalarizeBinopOrCmp(Instruction &I);
+  bool scalarizeOpOrCmp(Instruction &I);
   bool scalarizeVPIntrinsic(Instruction &I);
   bool foldExtractedCmps(Instruction &I);
   bool foldBinopOfReductions(Instruction &I);
@@ -1017,28 +1018,20 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
   return true;
 }
 
-/// Match a vector binop, compare or binop-like intrinsic with at least one
-/// inserted scalar operand and convert to scalar binop/cmp/intrinsic followed
+/// Match a vector op/compare/intrinsic with at least one
+/// inserted scalar operand and convert to scalar op/cmp/intrinsic followed
 /// by insertelement.
-bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
-  CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
-  Value *Ins0, *Ins1;
-  if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
-      !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1)))) {
-    // TODO: Allow unary and ternary intrinsics
-    // TODO: Allow intrinsics with different argument types
-    // TODO: Allow intrinsics with scalar arguments
-    if (auto *II = dyn_cast<IntrinsicInst>(&I);
-        II && II->arg_size() == 2 &&
-        isTriviallyVectorizable(II->getIntrinsicID()) &&
-        all_of(II->args(),
-               [&II](Value *Arg) { return Arg->getType() == II->getType(); })) {
-      Ins0 = II->getArgOperand(0);
-      Ins1 = II->getArgOperand(1);
-    } else {
+bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
+  if (!isa<UnaryOperator, BinaryOperator, CmpInst, IntrinsicInst>(I))
+    return false;
+
+  // TODO: Allow intrinsics with different argument types
+  // TODO: Allow intrinsics with scalar arguments
+  if (auto *II = dyn_cast<IntrinsicInst>(&I))
+    if (!isTriviallyVectorizable(II->getIntrinsicID()) ||
+        !all_of(II->args(),
+                [&II](Value *Arg) { return Arg->getType() == II->getType(); }))
       return false;
-    }
-  }
 
   // Do not convert the vector condition of a vector select into a scalar
   // condition. That may cause problems for codegen because of differences in
@@ -1049,50 +1042,47 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
       if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
         return false;
 
-  // Match against one or both scalar values being inserted into constant
-  // vectors:
-  // vec_op VecC0, (inselt VecC1, V1, Index)
-  // vec_op (inselt VecC0, V0, Index), VecC1
-  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
-  // TODO: Deal with mismatched index constants and variable indexes?
-  Constant *VecC0 = nullptr, *VecC1 = nullptr;
-  Value *V0 = nullptr, *V1 = nullptr;
-  uint64_t Index0 = 0, Index1 = 0;
-  if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
-                               m_ConstantInt(Index0))) &&
-      !match(Ins0, m_Constant(VecC0)))
-    return false;
-  if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
-                               m_ConstantInt(Index1))) &&
-      !match(Ins1, m_Constant(VecC1)))
-    return false;
-
-  bool IsConst0 = !V0;
-  bool IsConst1 = !V1;
-  if (IsConst0 && IsConst1)
-    return false;
-  if (!IsConst0 && !IsConst1 && Index0 != Index1)
-    return false;
-
-  auto *VecTy0 = cast<VectorType>(Ins0->getType());
-  auto *VecTy1 = cast<VectorType>(Ins1->getType());
-  if (VecTy0->getElementCount().getKnownMinValue() <= Index0 ||
-      VecTy1->getElementCount().getKnownMinValue() <= Index1)
-    return false;
+  // Match constant vectors or scalars being inserted into constant vectors:
+  // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
+  SmallVector<Constant *> VecCs;
+  SmallVector<Value *> ScalarOps;
+  std::optional<uint64_t> Index;
+
+  auto Ops = isa<IntrinsicInst>(I) ? cast<IntrinsicInst>(I).args()
+                                   : I.operand_values();
+  for (Value *Op : Ops) {
+    Constant *VecC;
+    Value *V;
+    uint64_t InsIdx = 0;
+    VectorType *OpTy = cast<VectorType>(Op->getType());
+    if (match(Op, m_InsertElt(m_Constant(VecC), m_Value(V),
+                              m_ConstantInt(InsIdx)))) {
+      // Bail if any inserts are out of bounds.
+      if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
+        return false;
+      // All inserts must have the same index.
+      // TODO: Deal with mismatched index constants and variable indexes?
+      if (!Index)
+        Index = InsIdx;
+      else if (InsIdx != *Index)
+        return false;
+      VecCs.push_back(VecC);
+      ScalarOps.push_back(V);
+    } else if (match(Op, m_Constant(VecC))) {
+      VecCs.push_back(VecC);
+      ScalarOps.push_back(nullptr);
+    } else {
+      return false;
+    }
+  }
 
-  // Bail for single insertion if it is a load.
-  // TODO: Handle this once getVectorInstrCost can cost for load/stores.
-  auto *I0 = dyn_cast_or_null<Instruction>(V0);
-  auto *I1 = dyn_cast_or_null<Instruction>(V1);
-  if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
-      (IsConst1 && I0 && I0->mayReadFromMemory()))
+  // Bail if all operands are constant.
+  if (!Index.has_value())
     return false;
 
-  uint64_t Index = IsConst0 ? Index1 : Index0;
-  Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
-  Type *VecTy = I.getType();
+  VectorType *VecTy = cast<VectorType>(I.getType());
+  Type *ScalarTy = VecTy->getScalarType();
   assert(VecTy->isVectorTy() &&
-         (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
          (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
           ScalarTy->isPointerTy()) &&
          "Unexpected types for insert element into binop or cmp");
@@ -1105,7 +1095,7 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
         Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
     VectorOpCost = TTI.getCmpSelInstrCost(
         Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
-  } else if (isa<BinaryOperator>(I)) {
+  } else if (isa<UnaryOperator, BinaryOperator>(I)) {
     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
   } else {
@@ -1120,15 +1110,37 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
     VectorOpCost = TTI.getIntrinsicInstrCost(VectorICA, CostKind);
   }
 
+  // Fold the vector constants in the original vectors into a new base vector to
+  // get more accurate cost modelling.
+  Value *NewVecC = nullptr;
+  if (auto *CI = dyn_cast<CmpInst>(&I))
+    NewVecC = ConstantFoldCompareInstOperands(CI->getPredicate(), VecCs[0],
+                                              VecCs[1], *DL);
+  else if (isa<UnaryOperator>(I))
+    NewVecC = ConstantFoldUnaryOpOperand((Instruction::UnaryOps)Opcode,
+                                         VecCs[0], *DL);
+  else if (isa<BinaryOperator>(I))
+    NewVecC = ConstantFoldBinaryOpOperands((Instruction::BinaryOps)Opcode,
+                                           VecCs[0], VecCs[1], *DL);
+  else if (isa<IntrinsicInst>(I) && cast<IntrinsicInst>(I).arg_size() == 2)
+    NewVecC =
+        ConstantFoldBinaryIntrinsic(cast<IntrinsicInst>(I).getIntrinsicID(),
+                                    VecCs[0], VecCs[1], I.getType(), &I);
+
   // Get cost estimate for the insert element. This cost will factor into
   // both sequences.
-  InstructionCost InsertCost = TTI.getVectorInstrCost(
-      Instruction::InsertElement, VecTy, CostKind, Index);
-  InstructionCost OldCost =
-      (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost;
-  InstructionCost NewCost = ScalarOpCost + InsertCost +
-                            (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
-                            (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);
+  InstructionCost OldCost = VectorOpCost;
+  InstructionCost NewCost =
+      ScalarOpCost + TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
+                                            CostKind, *Index, NewVecC);
+  for (auto [Op, VecC, Scalar] : zip(Ops, VecCs, ScalarOps)) {
+    if (!Scalar)
+      continue;
+    InstructionCost InsertCost = TTI.getVectorInstrCost(
+        Instruction::InsertElement, VecTy, CostKind, *Index, VecC, Scalar);
+    OldCost += InsertCost;
+    NewCost += !Op->hasOneUse() * InsertCost;
+  }
 
   // We want to scalarize unless the vector variant actually has lower cost.
   if (OldCost < NewCost || !NewCost.isValid())
@@ -1138,25 +1150,25 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
   // inselt NewVecC, (scalar_op V0, V1), Index
   if (isa<CmpInst>(I))
     ++NumScalarCmp;
-  else if (isa<BinaryOperator>(I))
-    ++NumScalarBO;
+  else if (isa<UnaryOperator, BinaryOperator>(I))
+    ++NumScalarOps;
   else if (isa<IntrinsicInst>(I))
     ++NumScalarIntrinsic;
 
   // For constant cases, extract the scalar element, this should constant fold.
-  if (IsConst0)
-    V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
-  if (IsConst1)
-    V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));
+  for (auto [OpIdx, Scalar, VecC] : enumerate(ScalarOps, VecCs))
+    if (!Scalar)
+      ScalarOps[OpIdx] = ConstantExpr::getExtractElement(
+          cast<Constant>(VecC), Builder.getInt64(*Index));
 
   Value *Scalar;
-  if (isa<CmpInst>(I))
-    Scalar = Builder.CreateCmp(Pred, V0, V1);
-  else if (isa<BinaryOperator>(I))
-    Scalar = Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);
+  if (auto *CI = dyn_cast<CmpInst>(&I))
+    Scalar = Builder.CreateCmp(CI->getPredicate(), ScalarOps[0], ScalarOps[1]);
+  else if (isa<UnaryOperator, BinaryOperator>(I))
+    Scalar = Builder.CreateNAryOp(Opcode, ScalarOps);
   else
     Scalar = Builder.CreateIntrinsic(
-        ScalarTy, cast<IntrinsicInst>(I).getIntrinsicID(), {V0, V1});
+        ScalarTy, cast<IntrinsicInst>(I).getIntrinsicID(), ScalarOps);
 
   Scalar->setName(I.getName() + ".scalar");
 
@@ -1165,16 +1177,20 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
   if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
     ScalarInst->copyIRFlags(&I);
 
-  // Fold the vector constants in the original vectors into a new base vector.
-  Value *NewVecC;
-  if (isa<CmpInst>(I))
-    NewVecC = Builder.CreateCmp(Pred, VecC0, VecC1);
-  else if (isa<BinaryOperator>(I))
-    NewVecC = Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1);
-  else
-    NewVecC = Builder.CreateIntrinsic(
-        VecTy, cast<IntrinsicInst>(I).getIntrinsicID(), {VecC0, VecC1});
-  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
+  // Create a new base vector if the constant folding failed.
+  if (!NewVecC) {
+    SmallVector<Value *> VecCValues;
+    VecCValues.reserve(VecCs.size());
+    append_range(VecCValues, VecCs);
+    if (auto *CI = dyn_cast<CmpInst>(&I))
+      NewVecC = Builder.CreateCmp(CI->getPredicate(), VecCs[0], VecCs[1]);
+    else if (isa<UnaryOperator, BinaryOperator>(I))
+      NewVecC = Builder.CreateNAryOp(Opcode, VecCValues);
+    else
+      NewVecC = Builder.CreateIntrinsic(
+          VecTy, cast<IntrinsicInst>(I).getIntrinsicID(), VecCValues);
+  }
+  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, *Index);
   replaceValue(I, *Insert);
   return true;
 }
@@ -3560,7 +3576,7 @@ bool VectorCombine::run() {
     // This transform works with scalable and fixed vectors
     // TODO: Identify and allow other scalable transforms
     if (IsVectorType) {
-      MadeChange |= scalarizeBinopOrCmp(I);
+      MadeChange |= scalarizeOpOrCmp(I);
       MadeChange |= scalarizeLoadExtract(I);
       MadeChange |= scalarizeVPIntrinsic(I);
       MadeChange |= foldInterleaveIntrinsics(I);
diff --git a/llvm/test/Transforms/VectorCombine/RISCV/binop-scalarize.ll b/llvm/test/Transforms/VectorCombine/RISCV/binop-scalarize.ll
new file mode 100644
index 0000000000000..ec4f6cc7520d1
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/RISCV/binop-scalarize.ll
@@ -0,0 +1,16 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -S -p vector-combine -mtriple=riscv64 -mattr=+v | FileCheck %s
+
+define <4 x i32> @add_constant_load(ptr %p) {
+; CHECK-LABEL: define <4 x i32> @add_constant_load(
+; CHECK-SAME: ptr [[P:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:    [[X:%.*]] = load i32, ptr [[P]], align 4
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = add i32 [[X]], 42
+; CHECK-NEXT:    [[V:%.*]] = insertelement <4 x i32> poison, i32 [[V_SCALAR]], i64 0
+; CHECK-NEXT:    ret <4 x i32> [[V]]
+;
+  %x = load i32, ptr %p
+  %ins = insertelement <4 x i32> poison, i32 %x, i32 0
+  %v = add <4 x i32> %ins, splat (i32 42)
+  ret <4 x i32> %v
+}
diff --git a/llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant-inseltpoison.ll b/llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant-inseltpoison.ll
index d45d5f4d44ff3..564c9a795a794 100644
--- a/llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant-inseltpoison.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant-inseltpoison.ll
@@ -153,8 +153,8 @@ define <2 x i64> @shl_constant_op0_not_undef_lane(i64 %x) {
 define <2 x i64> @shl_constant_op0_load(ptr %p) {
 ; CHECK-LABEL: @shl_constant_op0_load(
 ; CHECK-NEXT:    [[LD:%.*]] = load i64, ptr [[P:%.*]], align 8
-; CHECK-NEXT:    [[INS:%.*]] = insertelement <2 x i64> poison, i64 [[LD]], i32 1
-; CHECK-NEXT:    [[BO:%.*]] = shl <2 x i64> <i64 undef, i64 2>, [[INS]]
+; CHECK-NEXT:    [[BO_SCALAR:%.*]] = shl i64 2, [[LD]]
+; CHECK-NEXT:    [[BO:%.*]] = insertelement <2 x i64> poison, i64 [[BO_SCALAR]], i64 1
 ; CHECK-NEXT:    ret <2 x i64> [[BO]]
 ;
   %ld = load i64, ptr %p
@@ -204,8 +204,8 @@ define <2 x i64> @shl_constant_op1_not_undef_lane(i64 %x) {
 define <2 x i64> @shl_constant_op1_load(ptr %p) {
 ; CHECK-LABEL: @shl_constant_op1_load(
 ; CHECK-NEXT:    [[LD:%.*]] = load i64, ptr [[P:%.*]], align 8
-; CHECK-NEXT:    [[INS:%.*]] = insertelement <2 x i64> poison, i64 [[LD]], i32 0
-; CHECK-NEXT:    [[BO:%.*]] = shl nuw <2 x i64> [[INS]], <i64 5, i64 2>
+; CHECK-NEXT:    [[BO_SCALAR:%.*]] = shl nuw i64 [[LD]], 5
+; CHECK-NEXT:    [[BO:%.*]] = insertelement <2 x i64> poison, i64 [[BO_SCALAR]], i64 0
 ; CHECK-NEXT:    ret <2 x i64> [[BO]]
 ;
   %ld = load i64, ptr %p
diff --git a/llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant.ll b/llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant.ll
index 2b5a58ea44de4..cf3bd00527f81 100644
--- a/llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant.ll
@@ -153,8 +153,8 @@ define <2 x i64> @shl_constant_op0_not_undef_lane(i64 %x) {
 define <2 x i64> @shl_constant_op0_load(ptr %p) {
 ; CHECK-LABEL: @shl_constant_op0_load(
 ; CHECK-NEXT:    [[LD:%.*]] = load i64, ptr [[P:%.*]], align 8
-; CHECK-NEXT:    [[INS:%.*]] = insertelement <2 x i64> undef, i64 [[LD]], i32 1
-; CHECK-NEXT:    [[BO:%.*]] = shl <2 x i64> <i64 undef, i64 2>, [[INS]]
+; CHECK-NEXT:    [[BO_SCALAR:%.*]] = shl i64 2, [[LD]]
+; CHECK-NEXT:    [[BO:%.*]] = insertelement <2 x i64> poison, i64 [[BO_SCALAR]], i64 1
 ; CHECK-NEXT:    ret <2 x i64> [[BO]]
 ;
   %ld = load i64, ptr %p
@@ -204,8 +204,8 @@ define <2 x i64> @shl_constant_op1_not_undef_lane(i64 %x) {
 define <2 x i64> @shl_constant_op1_load(ptr %p) {
 ; CHECK-LABEL: @shl_constant_op1_load(
 ; CHECK-NEXT:    [[LD:%.*]] = load i64, ptr [[P:%.*]], align 8
-; CHECK-NEXT:    [[INS:%.*]] = insertelement <2 x i64> undef, i64 [[LD]], i32 0
-; CHECK-NEXT:    [[BO:%.*]] = shl nuw <2 x i64> [[INS]], <i64 5, i64 2>
+; CHECK-NEXT:    [[BO_SCALAR:%.*]] = shl nuw i64 [[LD]], 5
+; CHECK-NEXT:    [[BO:%.*]] = insertelement <2 x i64> zeroinitializer, i64 [[BO_SCALAR]], i64 0
 ; CHECK-NEXT:    ret <2 x i64> [[BO]]
 ;
   %ld = load i64, ptr %p
@@ -479,10 +479,15 @@ define <2 x i64> @sdiv_constant_op1_not_undef_lane(i64 %x) {
 }
 
 define <2 x i64> @and_constant(i64 %x) {
-; CHECK-LABEL: @and_constant(
-; CHECK-NEXT:    [[BO_SCALAR:%.*]] = and i64 [[X:%.*]], 42
-; CHECK-NEXT:    [[BO:%.*]] = insertelement <2 x i64> <i64 0, i64 undef>, i64 [[BO_SCALAR]], i64 0
-; CHECK-NEXT:    ret <2 x i64> [[BO]]
+; SSE-LABEL: @and_constant(
+; SSE-NEXT:    [[INS:%.*]] = insertelement <2 x i64> undef, i64 [[X:%.*]], i32 0
+; SSE-NEXT:    [[BO:%.*]] = and <2 x i64> [[INS]], <i64 42, i64 undef>
+; SSE-NEXT:    ret <2 x i64> [[BO]]
+;
+; AVX-LABEL: @and_constant(
+; AVX-NEXT:    [[BO_SCALAR:%.*]] = and i64 [[X:%.*]], 42
+; AVX-NEXT:    [[BO:%.*]] = insertelement <2 x i64> <i64 0, i64 undef>, i64 [[BO_SCALAR]], i64 0
+; AVX-NEXT:    ret <2 x i64> [[BO]]
 ;
   %ins = insertelement <2 x i64> undef, i64 %x, i32 0
   %bo = and <2 x i64> %ins, <i64 42, i64 undef>
@@ -490,10 +495,15 @@ define <2 x i64> @and_constant(i64 %x) {
 }
 
 define <2 x i64> @and_constant_not_undef_lane(i64 %x) {
-; CHECK-LABEL: @and_constant_not_undef_lane(
-; CHECK-NEXT:    [[BO_SCALAR:%.*]] = and i64 [[X:%.*]], 42
-; CHECK-NEXT:    [[BO:%.*]] = insertelement <2 x i64> zeroinitializer, i64 [[BO_SCALAR]], i64 0
-; CHECK-NEXT:    ret <2 x i64> [[BO]]
+; SSE-LABEL: @and_constant_not_undef_lane(
+; SSE-NEXT:    [[INS:%.*]] = insertelement <2 x i64> undef, i64 [[X:%.*]], i32 0
+; SSE-NEXT:    [[BO:%.*]] = and <2 x i64> [[INS]], <i64 42, i64 -42>
+; SSE-NEXT:    ret <2 x i64> [[BO]]
+;
+; AVX-LABEL: @and_constant_not_undef_lane(
+; AVX-NEXT:    [[BO_SCALAR:%.*]] = and i64 [[X:%.*]], 42
+; AVX-NEXT:    [[BO:%.*]] = insertelement <2 x i64> zeroinitializer, i64 [[BO_SCALAR]], i64 0
+; AVX-NEXT:    ret <2 x i64> [[BO]]
 ;
   %ins = insertelement <2 x i64> undef, i64 %x, i32 0
   %bo = and <2 x i64> %ins, <i64 42, i64 -42>
@@ -523,10 +533,15 @@ define <2 x i64> @or_constant_not_undef_lane(i64 %x) {
 }
 
 define <2 x i64> @xor_constant(i64 %x) {
-; CHECK-LABEL: @xor_constant(
-; CHECK-NEXT:    [[BO_SCALAR:%.*]] = xor i64 [[X:%.*]], 42
-; CHECK-NEXT:    [[BO:%.*]] = insertelement <2 x i64> <i64 undef, i64 0>, i64 [[BO_SCALAR]], i64 0
-; CHECK-NEXT:    ret <2 x i64> [[BO]]
+; SSE-LABEL: @xor_constant(
+; SSE-NEXT:    [[INS:%.*]] = insertelement <2 x i64> undef, i64 [[X:%.*]], i32 0
+; SSE-NEXT:    [[BO:%.*]] = xor <2 x i64> [[INS]], <i64 42, i64 undef>
+; SSE-NEXT:    ret <2 x i64> [[BO]]
+;
+; AVX-LABEL: @xor_constant(
+; AVX-NEXT:    [[BO_SCALAR:%.*]] = xor i64 [[X:%.*]], 42
+; AVX-NEXT:    [[BO:%.*]] = insertelement <2 x i64> <i64 undef, i64 0>, i64 [[BO_SCALAR]], i64 0
+; AVX-NEXT:    ret <2 x i64> [[BO]]
 ;
   %ins = insertelement <2 x i64> undef, i64 %x, i32 0
   %bo = xor <2 x i64> %ins, <i64 42, i64 undef>
@@ -546,8 +561,8 @@ define <2 x i64> @xor_constant_not_undef_lane(i64 %x) {
 
 define <2 x double> @fadd_constant(double %x) {
 ; CHECK-LABEL: @fadd_constant(
-; CHECK-NEXT:    [[BO_SCALAR:%.*]] = fadd double [[X:%.*]], 4.200000e+01
-; CHECK-NEXT:    [[BO:%.*]] = insertelement <2 x double> <double 0x7FF8000000000000, double undef>, double [[BO_SCALAR]], i64 0
+; CHECK-NEXT:    [[INS:%.*]] = insertelement <2 x double> undef, double [[X:%.*]], i32 0
+; CHECK-NEXT:    [[BO:%.*]] = fadd <2 x double> [[INS]], <double 4.200000e+01, double u...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 23, 2025

@llvm/pr-subscribers-vectorizers

Author: Luke Lau (lukel97)

Changes

This adds support for unary operands, and unary + ternary intrinsics in scalarizeOpOrCmp (FKA scalarizeBinOpOrCmp).

The motivation behind this is to scalarize more intrinsics in VectorCombine rather than in DAGCombine, so we can sink splats across basic blocks: see #137786

The main change required is to generalize the existing VecC0/VecC1 rules across n-ary ops:

  • An operand can either be a constant vector or an insert of a scalar into a constant vector
  • If it's an insert, the index needs to be static and in bounds
  • If it's an insert, all indices need to be the same across all operands
  • If all the operands are constant vectors, bail as it will get constant folded anyway

Stacked on #137823


Patch is 36.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138406.diff

9 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+108-92)
  • (added) llvm/test/Transforms/VectorCombine/RISCV/binop-scalarize.ll (+16)
  • (modified) llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant-inseltpoison.ll (+4-4)
  • (modified) llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant.ll (+66-36)
  • (modified) llvm/test/Transforms/VectorCombine/X86/insert-binop.ll (+13-6)
  • (modified) llvm/test/Transforms/VectorCombine/X86/scalarize-cmp-inseltpoison.ll (+13-9)
  • (modified) llvm/test/Transforms/VectorCombine/X86/scalarize-cmp.ll (+22-13)
  • (modified) llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll (+56)
  • (added) llvm/test/Transforms/VectorCombine/unary-op-scalarize.ll (+26)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index fe1d930f295ce..bf33292544497 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -19,6 +19,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -46,7 +47,7 @@ STATISTIC(NumVecCmp, "Number of vector compares formed");
 STATISTIC(NumVecBO, "Number of vector binops formed");
 STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
 STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
-STATISTIC(NumScalarBO, "Number of scalar binops formed");
+STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
 STATISTIC(NumScalarCmp, "Number of scalar compares formed");
 STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");
 
@@ -113,7 +114,7 @@ class VectorCombine {
   bool foldInsExtBinop(Instruction &I);
   bool foldInsExtVectorToShuffle(Instruction &I);
   bool foldBitcastShuffle(Instruction &I);
-  bool scalarizeBinopOrCmp(Instruction &I);
+  bool scalarizeOpOrCmp(Instruction &I);
   bool scalarizeVPIntrinsic(Instruction &I);
   bool foldExtractedCmps(Instruction &I);
   bool foldBinopOfReductions(Instruction &I);
@@ -1017,28 +1018,20 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
   return true;
 }
 
-/// Match a vector binop, compare or binop-like intrinsic with at least one
-/// inserted scalar operand and convert to scalar binop/cmp/intrinsic followed
+/// Match a vector op/compare/intrinsic with at least one
+/// inserted scalar operand and convert to scalar op/cmp/intrinsic followed
 /// by insertelement.
-bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
-  CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
-  Value *Ins0, *Ins1;
-  if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
-      !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1)))) {
-    // TODO: Allow unary and ternary intrinsics
-    // TODO: Allow intrinsics with different argument types
-    // TODO: Allow intrinsics with scalar arguments
-    if (auto *II = dyn_cast<IntrinsicInst>(&I);
-        II && II->arg_size() == 2 &&
-        isTriviallyVectorizable(II->getIntrinsicID()) &&
-        all_of(II->args(),
-               [&II](Value *Arg) { return Arg->getType() == II->getType(); })) {
-      Ins0 = II->getArgOperand(0);
-      Ins1 = II->getArgOperand(1);
-    } else {
+bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
+  if (!isa<UnaryOperator, BinaryOperator, CmpInst, IntrinsicInst>(I))
+    return false;
+
+  // TODO: Allow intrinsics with different argument types
+  // TODO: Allow intrinsics with scalar arguments
+  if (auto *II = dyn_cast<IntrinsicInst>(&I))
+    if (!isTriviallyVectorizable(II->getIntrinsicID()) ||
+        !all_of(II->args(),
+                [&II](Value *Arg) { return Arg->getType() == II->getType(); }))
       return false;
-    }
-  }
 
   // Do not convert the vector condition of a vector select into a scalar
   // condition. That may cause problems for codegen because of differences in
@@ -1049,50 +1042,47 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
       if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
         return false;
 
-  // Match against one or both scalar values being inserted into constant
-  // vectors:
-  // vec_op VecC0, (inselt VecC1, V1, Index)
-  // vec_op (inselt VecC0, V0, Index), VecC1
-  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
-  // TODO: Deal with mismatched index constants and variable indexes?
-  Constant *VecC0 = nullptr, *VecC1 = nullptr;
-  Value *V0 = nullptr, *V1 = nullptr;
-  uint64_t Index0 = 0, Index1 = 0;
-  if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
-                               m_ConstantInt(Index0))) &&
-      !match(Ins0, m_Constant(VecC0)))
-    return false;
-  if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
-                               m_ConstantInt(Index1))) &&
-      !match(Ins1, m_Constant(VecC1)))
-    return false;
-
-  bool IsConst0 = !V0;
-  bool IsConst1 = !V1;
-  if (IsConst0 && IsConst1)
-    return false;
-  if (!IsConst0 && !IsConst1 && Index0 != Index1)
-    return false;
-
-  auto *VecTy0 = cast<VectorType>(Ins0->getType());
-  auto *VecTy1 = cast<VectorType>(Ins1->getType());
-  if (VecTy0->getElementCount().getKnownMinValue() <= Index0 ||
-      VecTy1->getElementCount().getKnownMinValue() <= Index1)
-    return false;
+  // Match constant vectors or scalars being inserted into constant vectors:
+  // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
+  SmallVector<Constant *> VecCs;
+  SmallVector<Value *> ScalarOps;
+  std::optional<uint64_t> Index;
+
+  auto Ops = isa<IntrinsicInst>(I) ? cast<IntrinsicInst>(I).args()
+                                   : I.operand_values();
+  for (Value *Op : Ops) {
+    Constant *VecC;
+    Value *V;
+    uint64_t InsIdx = 0;
+    VectorType *OpTy = cast<VectorType>(Op->getType());
+    if (match(Op, m_InsertElt(m_Constant(VecC), m_Value(V),
+                              m_ConstantInt(InsIdx)))) {
+      // Bail if any inserts are out of bounds.
+      if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
+        return false;
+      // All inserts must have the same index.
+      // TODO: Deal with mismatched index constants and variable indexes?
+      if (!Index)
+        Index = InsIdx;
+      else if (InsIdx != *Index)
+        return false;
+      VecCs.push_back(VecC);
+      ScalarOps.push_back(V);
+    } else if (match(Op, m_Constant(VecC))) {
+      VecCs.push_back(VecC);
+      ScalarOps.push_back(nullptr);
+    } else {
+      return false;
+    }
+  }
 
-  // Bail for single insertion if it is a load.
-  // TODO: Handle this once getVectorInstrCost can cost for load/stores.
-  auto *I0 = dyn_cast_or_null<Instruction>(V0);
-  auto *I1 = dyn_cast_or_null<Instruction>(V1);
-  if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
-      (IsConst1 && I0 && I0->mayReadFromMemory()))
+  // Bail if all operands are constant.
+  if (!Index.has_value())
     return false;
 
-  uint64_t Index = IsConst0 ? Index1 : Index0;
-  Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
-  Type *VecTy = I.getType();
+  VectorType *VecTy = cast<VectorType>(I.getType());
+  Type *ScalarTy = VecTy->getScalarType();
   assert(VecTy->isVectorTy() &&
-         (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
          (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
           ScalarTy->isPointerTy()) &&
          "Unexpected types for insert element into binop or cmp");
@@ -1105,7 +1095,7 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
         Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
     VectorOpCost = TTI.getCmpSelInstrCost(
         Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
-  } else if (isa<BinaryOperator>(I)) {
+  } else if (isa<UnaryOperator, BinaryOperator>(I)) {
     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
   } else {
@@ -1120,15 +1110,37 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
     VectorOpCost = TTI.getIntrinsicInstrCost(VectorICA, CostKind);
   }
 
+  // Fold the vector constants in the original vectors into a new base vector to
+  // get more accurate cost modelling.
+  Value *NewVecC = nullptr;
+  if (auto *CI = dyn_cast<CmpInst>(&I))
+    NewVecC = ConstantFoldCompareInstOperands(CI->getPredicate(), VecCs[0],
+                                              VecCs[1], *DL);
+  else if (isa<UnaryOperator>(I))
+    NewVecC = ConstantFoldUnaryOpOperand((Instruction::UnaryOps)Opcode,
+                                         VecCs[0], *DL);
+  else if (isa<BinaryOperator>(I))
+    NewVecC = ConstantFoldBinaryOpOperands((Instruction::BinaryOps)Opcode,
+                                           VecCs[0], VecCs[1], *DL);
+  else if (isa<IntrinsicInst>(I) && cast<IntrinsicInst>(I).arg_size() == 2)
+    NewVecC =
+        ConstantFoldBinaryIntrinsic(cast<IntrinsicInst>(I).getIntrinsicID(),
+                                    VecCs[0], VecCs[1], I.getType(), &I);
+
   // Get cost estimate for the insert element. This cost will factor into
   // both sequences.
-  InstructionCost InsertCost = TTI.getVectorInstrCost(
-      Instruction::InsertElement, VecTy, CostKind, Index);
-  InstructionCost OldCost =
-      (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost;
-  InstructionCost NewCost = ScalarOpCost + InsertCost +
-                            (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
-                            (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);
+  InstructionCost OldCost = VectorOpCost;
+  InstructionCost NewCost =
+      ScalarOpCost + TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
+                                            CostKind, *Index, NewVecC);
+  for (auto [Op, VecC, Scalar] : zip(Ops, VecCs, ScalarOps)) {
+    if (!Scalar)
+      continue;
+    InstructionCost InsertCost = TTI.getVectorInstrCost(
+        Instruction::InsertElement, VecTy, CostKind, *Index, VecC, Scalar);
+    OldCost += InsertCost;
+    NewCost += !Op->hasOneUse() * InsertCost;
+  }
 
   // We want to scalarize unless the vector variant actually has lower cost.
   if (OldCost < NewCost || !NewCost.isValid())
@@ -1138,25 +1150,25 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
   // inselt NewVecC, (scalar_op V0, V1), Index
   if (isa<CmpInst>(I))
     ++NumScalarCmp;
-  else if (isa<BinaryOperator>(I))
-    ++NumScalarBO;
+  else if (isa<UnaryOperator, BinaryOperator>(I))
+    ++NumScalarOps;
   else if (isa<IntrinsicInst>(I))
     ++NumScalarIntrinsic;
 
   // For constant cases, extract the scalar element, this should constant fold.
-  if (IsConst0)
-    V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
-  if (IsConst1)
-    V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));
+  for (auto [OpIdx, Scalar, VecC] : enumerate(ScalarOps, VecCs))
+    if (!Scalar)
+      ScalarOps[OpIdx] = ConstantExpr::getExtractElement(
+          cast<Constant>(VecC), Builder.getInt64(*Index));
 
   Value *Scalar;
-  if (isa<CmpInst>(I))
-    Scalar = Builder.CreateCmp(Pred, V0, V1);
-  else if (isa<BinaryOperator>(I))
-    Scalar = Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);
+  if (auto *CI = dyn_cast<CmpInst>(&I))
+    Scalar = Builder.CreateCmp(CI->getPredicate(), ScalarOps[0], ScalarOps[1]);
+  else if (isa<UnaryOperator, BinaryOperator>(I))
+    Scalar = Builder.CreateNAryOp(Opcode, ScalarOps);
   else
     Scalar = Builder.CreateIntrinsic(
-        ScalarTy, cast<IntrinsicInst>(I).getIntrinsicID(), {V0, V1});
+        ScalarTy, cast<IntrinsicInst>(I).getIntrinsicID(), ScalarOps);
 
   Scalar->setName(I.getName() + ".scalar");
 
@@ -1165,16 +1177,20 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
   if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
     ScalarInst->copyIRFlags(&I);
 
-  // Fold the vector constants in the original vectors into a new base vector.
-  Value *NewVecC;
-  if (isa<CmpInst>(I))
-    NewVecC = Builder.CreateCmp(Pred, VecC0, VecC1);
-  else if (isa<BinaryOperator>(I))
-    NewVecC = Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1);
-  else
-    NewVecC = Builder.CreateIntrinsic(
-        VecTy, cast<IntrinsicInst>(I).getIntrinsicID(), {VecC0, VecC1});
-  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
+  // Create a new base vector if the constant folding failed.
+  if (!NewVecC) {
+    SmallVector<Value *> VecCValues;
+    VecCValues.reserve(VecCs.size());
+    append_range(VecCValues, VecCs);
+    if (auto *CI = dyn_cast<CmpInst>(&I))
+      NewVecC = Builder.CreateCmp(CI->getPredicate(), VecCs[0], VecCs[1]);
+    else if (isa<UnaryOperator, BinaryOperator>(I))
+      NewVecC = Builder.CreateNAryOp(Opcode, VecCValues);
+    else
+      NewVecC = Builder.CreateIntrinsic(
+          VecTy, cast<IntrinsicInst>(I).getIntrinsicID(), VecCValues);
+  }
+  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, *Index);
   replaceValue(I, *Insert);
   return true;
 }
@@ -3560,7 +3576,7 @@ bool VectorCombine::run() {
     // This transform works with scalable and fixed vectors
     // TODO: Identify and allow other scalable transforms
     if (IsVectorType) {
-      MadeChange |= scalarizeBinopOrCmp(I);
+      MadeChange |= scalarizeOpOrCmp(I);
       MadeChange |= scalarizeLoadExtract(I);
       MadeChange |= scalarizeVPIntrinsic(I);
       MadeChange |= foldInterleaveIntrinsics(I);
diff --git a/llvm/test/Transforms/VectorCombine/RISCV/binop-scalarize.ll b/llvm/test/Transforms/VectorCombine/RISCV/binop-scalarize.ll
new file mode 100644
index 0000000000000..ec4f6cc7520d1
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/RISCV/binop-scalarize.ll
@@ -0,0 +1,16 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -S -p vector-combine -mtriple=riscv64 -mattr=+v | FileCheck %s
+
+define <4 x i32> @add_constant_load(ptr %p) {
+; CHECK-LABEL: define <4 x i32> @add_constant_load(
+; CHECK-SAME: ptr [[P:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:    [[X:%.*]] = load i32, ptr [[P]], align 4
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = add i32 [[X]], 42
+; CHECK-NEXT:    [[V:%.*]] = insertelement <4 x i32> poison, i32 [[V_SCALAR]], i64 0
+; CHECK-NEXT:    ret <4 x i32> [[V]]
+;
+  %x = load i32, ptr %p
+  %ins = insertelement <4 x i32> poison, i32 %x, i32 0
+  %v = add <4 x i32> %ins, splat (i32 42)
+  ret <4 x i32> %v
+}
diff --git a/llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant-inseltpoison.ll b/llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant-inseltpoison.ll
index d45d5f4d44ff3..564c9a795a794 100644
--- a/llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant-inseltpoison.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant-inseltpoison.ll
@@ -153,8 +153,8 @@ define <2 x i64> @shl_constant_op0_not_undef_lane(i64 %x) {
 define <2 x i64> @shl_constant_op0_load(ptr %p) {
 ; CHECK-LABEL: @shl_constant_op0_load(
 ; CHECK-NEXT:    [[LD:%.*]] = load i64, ptr [[P:%.*]], align 8
-; CHECK-NEXT:    [[INS:%.*]] = insertelement <2 x i64> poison, i64 [[LD]], i32 1
-; CHECK-NEXT:    [[BO:%.*]] = shl <2 x i64> <i64 undef, i64 2>, [[INS]]
+; CHECK-NEXT:    [[BO_SCALAR:%.*]] = shl i64 2, [[LD]]
+; CHECK-NEXT:    [[BO:%.*]] = insertelement <2 x i64> poison, i64 [[BO_SCALAR]], i64 1
 ; CHECK-NEXT:    ret <2 x i64> [[BO]]
 ;
   %ld = load i64, ptr %p
@@ -204,8 +204,8 @@ define <2 x i64> @shl_constant_op1_not_undef_lane(i64 %x) {
 define <2 x i64> @shl_constant_op1_load(ptr %p) {
 ; CHECK-LABEL: @shl_constant_op1_load(
 ; CHECK-NEXT:    [[LD:%.*]] = load i64, ptr [[P:%.*]], align 8
-; CHECK-NEXT:    [[INS:%.*]] = insertelement <2 x i64> poison, i64 [[LD]], i32 0
-; CHECK-NEXT:    [[BO:%.*]] = shl nuw <2 x i64> [[INS]], <i64 5, i64 2>
+; CHECK-NEXT:    [[BO_SCALAR:%.*]] = shl nuw i64 [[LD]], 5
+; CHECK-NEXT:    [[BO:%.*]] = insertelement <2 x i64> poison, i64 [[BO_SCALAR]], i64 0
 ; CHECK-NEXT:    ret <2 x i64> [[BO]]
 ;
   %ld = load i64, ptr %p
diff --git a/llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant.ll b/llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant.ll
index 2b5a58ea44de4..cf3bd00527f81 100644
--- a/llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/insert-binop-with-constant.ll
@@ -153,8 +153,8 @@ define <2 x i64> @shl_constant_op0_not_undef_lane(i64 %x) {
 define <2 x i64> @shl_constant_op0_load(ptr %p) {
 ; CHECK-LABEL: @shl_constant_op0_load(
 ; CHECK-NEXT:    [[LD:%.*]] = load i64, ptr [[P:%.*]], align 8
-; CHECK-NEXT:    [[INS:%.*]] = insertelement <2 x i64> undef, i64 [[LD]], i32 1
-; CHECK-NEXT:    [[BO:%.*]] = shl <2 x i64> <i64 undef, i64 2>, [[INS]]
+; CHECK-NEXT:    [[BO_SCALAR:%.*]] = shl i64 2, [[LD]]
+; CHECK-NEXT:    [[BO:%.*]] = insertelement <2 x i64> poison, i64 [[BO_SCALAR]], i64 1
 ; CHECK-NEXT:    ret <2 x i64> [[BO]]
 ;
   %ld = load i64, ptr %p
@@ -204,8 +204,8 @@ define <2 x i64> @shl_constant_op1_not_undef_lane(i64 %x) {
 define <2 x i64> @shl_constant_op1_load(ptr %p) {
 ; CHECK-LABEL: @shl_constant_op1_load(
 ; CHECK-NEXT:    [[LD:%.*]] = load i64, ptr [[P:%.*]], align 8
-; CHECK-NEXT:    [[INS:%.*]] = insertelement <2 x i64> undef, i64 [[LD]], i32 0
-; CHECK-NEXT:    [[BO:%.*]] = shl nuw <2 x i64> [[INS]], <i64 5, i64 2>
+; CHECK-NEXT:    [[BO_SCALAR:%.*]] = shl nuw i64 [[LD]], 5
+; CHECK-NEXT:    [[BO:%.*]] = insertelement <2 x i64> zeroinitializer, i64 [[BO_SCALAR]], i64 0
 ; CHECK-NEXT:    ret <2 x i64> [[BO]]
 ;
   %ld = load i64, ptr %p
@@ -479,10 +479,15 @@ define <2 x i64> @sdiv_constant_op1_not_undef_lane(i64 %x) {
 }
 
 define <2 x i64> @and_constant(i64 %x) {
-; CHECK-LABEL: @and_constant(
-; CHECK-NEXT:    [[BO_SCALAR:%.*]] = and i64 [[X:%.*]], 42
-; CHECK-NEXT:    [[BO:%.*]] = insertelement <2 x i64> <i64 0, i64 undef>, i64 [[BO_SCALAR]], i64 0
-; CHECK-NEXT:    ret <2 x i64> [[BO]]
+; SSE-LABEL: @and_constant(
+; SSE-NEXT:    [[INS:%.*]] = insertelement <2 x i64> undef, i64 [[X:%.*]], i32 0
+; SSE-NEXT:    [[BO:%.*]] = and <2 x i64> [[INS]], <i64 42, i64 undef>
+; SSE-NEXT:    ret <2 x i64> [[BO]]
+;
+; AVX-LABEL: @and_constant(
+; AVX-NEXT:    [[BO_SCALAR:%.*]] = and i64 [[X:%.*]], 42
+; AVX-NEXT:    [[BO:%.*]] = insertelement <2 x i64> <i64 0, i64 undef>, i64 [[BO_SCALAR]], i64 0
+; AVX-NEXT:    ret <2 x i64> [[BO]]
 ;
   %ins = insertelement <2 x i64> undef, i64 %x, i32 0
   %bo = and <2 x i64> %ins, <i64 42, i64 undef>
@@ -490,10 +495,15 @@ define <2 x i64> @and_constant(i64 %x) {
 }
 
 define <2 x i64> @and_constant_not_undef_lane(i64 %x) {
-; CHECK-LABEL: @and_constant_not_undef_lane(
-; CHECK-NEXT:    [[BO_SCALAR:%.*]] = and i64 [[X:%.*]], 42
-; CHECK-NEXT:    [[BO:%.*]] = insertelement <2 x i64> zeroinitializer, i64 [[BO_SCALAR]], i64 0
-; CHECK-NEXT:    ret <2 x i64> [[BO]]
+; SSE-LABEL: @and_constant_not_undef_lane(
+; SSE-NEXT:    [[INS:%.*]] = insertelement <2 x i64> undef, i64 [[X:%.*]], i32 0
+; SSE-NEXT:    [[BO:%.*]] = and <2 x i64> [[INS]], <i64 42, i64 -42>
+; SSE-NEXT:    ret <2 x i64> [[BO]]
+;
+; AVX-LABEL: @and_constant_not_undef_lane(
+; AVX-NEXT:    [[BO_SCALAR:%.*]] = and i64 [[X:%.*]], 42
+; AVX-NEXT:    [[BO:%.*]] = insertelement <2 x i64> zeroinitializer, i64 [[BO_SCALAR]], i64 0
+; AVX-NEXT:    ret <2 x i64> [[BO]]
 ;
   %ins = insertelement <2 x i64> undef, i64 %x, i32 0
   %bo = and <2 x i64> %ins, <i64 42, i64 -42>
@@ -523,10 +533,15 @@ define <2 x i64> @or_constant_not_undef_lane(i64 %x) {
 }
 
 define <2 x i64> @xor_constant(i64 %x) {
-; CHECK-LABEL: @xor_constant(
-; CHECK-NEXT:    [[BO_SCALAR:%.*]] = xor i64 [[X:%.*]], 42
-; CHECK-NEXT:    [[BO:%.*]] = insertelement <2 x i64> <i64 undef, i64 0>, i64 [[BO_SCALAR]], i64 0
-; CHECK-NEXT:    ret <2 x i64> [[BO]]
+; SSE-LABEL: @xor_constant(
+; SSE-NEXT:    [[INS:%.*]] = insertelement <2 x i64> undef, i64 [[X:%.*]], i32 0
+; SSE-NEXT:    [[BO:%.*]] = xor <2 x i64> [[INS]], <i64 42, i64 undef>
+; SSE-NEXT:    ret <2 x i64> [[BO]]
+;
+; AVX-LABEL: @xor_constant(
+; AVX-NEXT:    [[BO_SCALAR:%.*]] = xor i64 [[X:%.*]], 42
+; AVX-NEXT:    [[BO:%.*]] = insertelement <2 x i64> <i64 undef, i64 0>, i64 [[BO_SCALAR]], i64 0
+; AVX-NEXT:    ret <2 x i64> [[BO]]
 ;
   %ins = insertelement <2 x i64> undef, i64 %x, i32 0
   %bo = xor <2 x i64> %ins, <i64 42, i64 undef>
@@ -546,8 +561,8 @@ define <2 x i64> @xor_constant_not_undef_lane(i64 %x) {
 
 define <2 x double> @fadd_constant(double %x) {
 ; CHECK-LABEL: @fadd_constant(
-; CHECK-NEXT:    [[BO_SCALAR:%.*]] = fadd double [[X:%.*]], 4.200000e+01
-; CHECK-NEXT:    [[BO:%.*]] = insertelement <2 x double> <double 0x7FF8000000000000, double undef>, double [[BO_SCALAR]], i64 0
+; CHECK-NEXT:    [[INS:%.*]] = insertelement <2 x double> undef, double [[X:%.*]], i32 0
+; CHECK-NEXT:    [[BO:%.*]] = fadd <2 x double> [[INS]], <double 4.200000e+01, double u...
[truncated]

VecCs[0], VecCs[1], *DL);
else if (isa<IntrinsicInst>(I) && cast<IntrinsicInst>(I).arg_size() == 2)
NewVecC =
ConstantFoldBinaryIntrinsic(cast<IntrinsicInst>(I).getIntrinsicID(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there nothing more general we can use to fold other op counts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not that I could find, there's ConstantFoldCall but it looks like it doesn't handle any scalable intrinsics, apart from aarch64_sve_convert_from_svbool. For fixed vectors it actually calls ConstantFoldScalarCall on each element, which in turn calls ConstantFoldIntrinsicCall2 for binary intrinsics.

else if (isa<BinaryOperator>(I))
NewVecC = ConstantFoldBinaryOpOperands((Instruction::BinaryOps)Opcode,
VecCs[0], VecCs[1], *DL);
else if (isa<IntrinsicInst>(I) && cast<IntrinsicInst>(I).arg_size() == 2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
  if (II->arg_size() == 2)
    NewVecC = ConstantFoldBinaryIntrinsic(II->getIntrinsicID(), VecCs[0], VecCs[1], I.getType(), &I);
}


// TODO: Allow intrinsics with different argument types
// TODO: Allow intrinsics with scalar arguments
if (auto *II = dyn_cast<IntrinsicInst>(&I))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we hoist the various dyn_cast we have for UnaryOperator/BinaryOperator/CmpInst/IntrinsicInst so we just call them all once and then check for non-null instead of a mix of dyn_cast/isa/cast everywhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so, done in e99fed6

@lukel97 lukel97 force-pushed the vector-combine/scalarize-nary branch from b8529e5 to e99fed6 Compare May 27, 2025 17:30
Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM - cheers

@lukel97 lukel97 merged commit 2b9ded6 into llvm:main May 28, 2025
11 checks passed
@llvm-ci
Copy link
Collaborator

llvm-ci commented May 28, 2025

LLVM Buildbot has detected a new failure on builder openmp-offload-sles-build-only running on rocm-worker-hw-04-sles while building llvm at step 5 "compile-openmp".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/140/builds/23952

Here is the relevant piece of the build log for the reference
Step 5 (compile-openmp) failure: build (failure)
...
/home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/clang/include/clang/Basic/PointerAuthOptions.h:75:39: warning: ‘clang::PointerAuthSchema::DiscriminationKind’ is too small to hold all values of ‘enum class clang::PointerAuthSchema::Discrimination’
   Discrimination DiscriminationKind : 2;
                                       ^
cc1plus: warning: unrecognized command line option ‘-Wno-unnecessary-virtual-specifier’
7.599 [3468/32/3705] Building CXX object tools/clang/lib/Basic/CMakeFiles/obj.clangBasic.dir/XRayLists.cpp.o
7.605 [3467/32/3706] Building LinalgStructuredOps.h.inc...
7.608 [3466/32/3707] Building LinalgStructuredOps.cpp.inc...
7.609 [3465/32/3708] Building CXX object tools/clang/lib/Lex/CMakeFiles/obj.clangLex.dir/HeaderMap.cpp.o
7.611 [3464/32/3709] Building CXX object tools/clang/lib/Lex/CMakeFiles/obj.clangLex.dir/DependencyDirectivesScanner.cpp.o
7.616 [3463/32/3710] Building CXX object lib/Transforms/Vectorize/CMakeFiles/LLVMVectorize.dir/VectorCombine.cpp.o
FAILED: lib/Transforms/Vectorize/CMakeFiles/LLVMVectorize.dir/VectorCombine.cpp.o 
ccache /usr/bin/c++ -DGTEST_HAS_RTTI=0 -D_DEBUG -D_GLIBCXX_ASSERTIONS -D_GNU_SOURCE -D__STDC_CONSTANT_MACROS -D__STDC_FORMAT_MACROS -D__STDC_LIMIT_MACROS -Ilib/Transforms/Vectorize -I/home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/llvm/lib/Transforms/Vectorize -Iinclude -I/home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/llvm/include -fPIC -fno-semantic-interposition -fvisibility-inlines-hidden -Werror=date-time -Wall -Wextra -Wno-unused-parameter -Wwrite-strings -Wcast-qual -Wno-missing-field-initializers -pedantic -Wno-long-long -Wimplicit-fallthrough -Wno-uninitialized -Wno-nonnull -Wno-noexcept-type -Wno-unnecessary-virtual-specifier -Wdelete-non-virtual-dtor -Wno-comment -Wno-misleading-indentation -fdiagnostics-color -ffunction-sections -fdata-sections -O3 -DNDEBUG  -fno-exceptions -funwind-tables -fno-rtti -UNDEBUG -std=c++1z -MD -MT lib/Transforms/Vectorize/CMakeFiles/LLVMVectorize.dir/VectorCombine.cpp.o -MF lib/Transforms/Vectorize/CMakeFiles/LLVMVectorize.dir/VectorCombine.cpp.o.d -o lib/Transforms/Vectorize/CMakeFiles/LLVMVectorize.dir/VectorCombine.cpp.o -c /home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
/home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/llvm/lib/Transforms/Vectorize/VectorCombine.cpp: In member function ‘bool {anonymous}::VectorCombine::scalarizeOpOrCmp(llvm::Instruction&)’:
/home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/llvm/lib/Transforms/Vectorize/VectorCombine.cpp:1055:17: error: operands to ?: have different types ‘llvm::iterator_range<llvm::Use*>’ and ‘llvm::iterator_range<llvm::User::value_op_iterator>’
   auto Ops = II ? II->args() : I.operand_values();
              ~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/llvm/lib/Transforms/Vectorize/VectorCombine.cpp:1055:17: note:   and each type can be converted to the other
/home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/llvm/lib/Transforms/Vectorize/VectorCombine.cpp:1056:20: error: unable to deduce ‘auto&&’ from ‘Ops’
   for (Value *Op : Ops) {
                    ^~~
At global scope:
cc1plus: warning: unrecognized command line option ‘-Wno-unnecessary-virtual-specifier’
7.617 [3463/31/3711] Building CXX object tools/clang/lib/Lex/CMakeFiles/obj.clangLex.dir/LexHLSLRootSignature.cpp.o
7.618 [3463/30/3712] Building CXX object tools/clang/lib/Lex/CMakeFiles/obj.clangLex.dir/InitHeaderSearch.cpp.o
7.621 [3463/29/3713] Building LinalgRelayoutOps.cpp.inc...
7.623 [3463/28/3714] Building LinalgRelayoutOps.h.inc...
7.670 [3463/27/3715] Building CXX object tools/clang/lib/Tooling/DependencyScanning/CMakeFiles/obj.clangDependencyScanning.dir/ModuleDepCollector.cpp.o
In file included from /home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/clang/include/clang/Basic/CodeGenOptions.h:17:0,
                 from /home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/clang/include/clang/Frontend/CompilerInvocation.h:13,
                 from /home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/clang/include/clang/Tooling/DependencyScanning/ModuleDepCollector.h:15,
                 from /home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/clang/lib/Tooling/DependencyScanning/ModuleDepCollector.cpp:9:
/home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/clang/include/clang/Basic/PointerAuthOptions.h:70:18: warning: ‘clang::PointerAuthSchema::TheKind’ is too small to hold all values of ‘enum class clang::PointerAuthSchema::Kind’
   Kind TheKind : 2;
                  ^
/home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/clang/include/clang/Basic/PointerAuthOptions.h:74:58: warning: ‘clang::PointerAuthSchema::SelectedAuthenticationMode’ is too small to hold all values of ‘enum class clang::PointerAuthenticationMode’
   PointerAuthenticationMode SelectedAuthenticationMode : 2;
                                                          ^
/home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/clang/include/clang/Basic/PointerAuthOptions.h:75:39: warning: ‘clang::PointerAuthSchema::DiscriminationKind’ is too small to hold all values of ‘enum class clang::PointerAuthSchema::Discrimination’
   Discrimination DiscriminationKind : 2;
                                       ^
In file included from /home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/clang/include/clang/Tooling/DependencyScanning/ModuleDepCollector.h:19:0,
                 from /home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/clang/lib/Tooling/DependencyScanning/ModuleDepCollector.cpp:9:
/home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/clang/include/clang/Serialization/ASTReader.h:245:16: warning: ‘virtual bool clang::ASTReaderListener::visitInputFile(llvm::StringRef, llvm::StringRef, bool, bool, bool)’ was hidden [-Woverloaded-virtual]
   virtual bool visitInputFile(StringRef FilenameAsRequested, StringRef Filename,
                ^~~~~~~~~~~~~~
/home/botworker/bbot/builds/openmp-offload-sles-build/llvm.src/clang/include/clang/Serialization/ASTReader.h:306:8: warning:   by ‘virtual bool clang::ChainedASTReaderListener::visitInputFile(llvm::StringRef, bool, bool, bool)’ [-Woverloaded-virtual]
   bool visitInputFile(StringRef Filename, bool isSystem,
        ^~~~~~~~~~~~~~
cc1plus: warning: unrecognized command line option ‘-Wno-unnecessary-virtual-specifier’

@llvm llvm locked and limited conversation to collaborators May 29, 2025
@llvm llvm unlocked this conversation May 29, 2025
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Jun 3, 2025
…Cmp (llvm#138406)

This adds support for unary operands, and unary + ternary intrinsics in
scalarizeOpOrCmp (FKA scalarizeBinOpOrCmp).

The motivation behind this is to scalarize more intrinsics in
VectorCombine rather than in DAGCombine, so we can sink splats across
basic blocks: see llvm#137786

The main change required is to generalize the existing VecC0/VecC1 rules
across n-ary ops:

- An operand can either be a constant vector or an insert of a scalar
into a constant vector
- If it's an insert, the index needs to be static and in bounds
- If it's an insert, all indices need to be the same across all operands
- If all the operands are constant vectors, bail as it will get constant
folded anyway
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants