Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SLP]Fix perfect diamond match with extractelements in scalars #132466

Conversation

alexey-bataev
Copy link
Member

@alexey-bataev alexey-bataev commented Mar 21, 2025

Need to drop all previous estimations/vectorizations, when found
a perfect diamond match. This improves cost estimation and improves code
emission.
Also, need to adjust getScalarizationOverhead cost for non-poison input
vector. Currently, it does not allow to estimate it correctly, so
instead use conservative element-by-element insertelement cost for each
unique scalar.

Created using spr 1.3.5
@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2025

@llvm/pr-subscribers-vectorizers

Author: Alexey Bataev (alexey-bataev)

Changes

Need to drop all previous estimations/vectorizations, when found
a perfect diamond match. This improves cost estimation and improves code
emission.
Also, need to adjust getScalarizationOverhead cost for non-poison input
vector. Currently, it does not allow to estimate it correctly, so
instead add a cost of the insertion of the first vector element into
non-poison vector value and then remaining elements.


Full diff: https://github.com/llvm/llvm-project/pull/132466.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+45-14)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/buildvector-with-reuses.ll (+9-12)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/reduction-transpose.ll (+8-36)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 0201955b8b559..7050549d61d74 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -5310,12 +5310,11 @@ getShuffleCost(const TargetTransformInfo &TTI, TTI::ShuffleKind Kind,
 /// This is similar to TargetTransformInfo::getScalarizationOverhead, but if
 /// ScalarTy is a FixedVectorType, a vector will be inserted or extracted
 /// instead of a scalar.
-static InstructionCost getScalarizationOverhead(const TargetTransformInfo &TTI,
-                                                Type *ScalarTy, VectorType *Ty,
-                                                const APInt &DemandedElts,
-                                                bool Insert, bool Extract,
-                                                TTI::TargetCostKind CostKind,
-                                                ArrayRef<Value *> VL = {}) {
+static InstructionCost
+getScalarizationOverhead(const TargetTransformInfo &TTI, Type *ScalarTy,
+                         VectorType *Ty, const APInt &DemandedElts, bool Insert,
+                         bool Extract, TTI::TargetCostKind CostKind,
+                         bool ForPoisonSrc = true, ArrayRef<Value *> VL = {}) {
   assert(!isa<ScalableVectorType>(Ty) &&
          "ScalableVectorType is not supported.");
   assert(getNumElements(ScalarTy) * DemandedElts.getBitWidth() ==
@@ -5339,8 +5338,19 @@ static InstructionCost getScalarizationOverhead(const TargetTransformInfo &TTI,
     }
     return Cost;
   }
-  return TTI.getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
-                                      CostKind, VL);
+  APInt NewDemandedElts = DemandedElts;
+  InstructionCost Cost = 0;
+  if (!ForPoisonSrc && Insert) {
+    // Handle insert into non-poison vector.
+    unsigned LeftMostBit = NewDemandedElts.countr_zero();
+    NewDemandedElts.clearBit(LeftMostBit);
+    Cost += TTI.getVectorInstrCost(Instruction::InsertElement, Ty, CostKind,
+                                   LeftMostBit, Constant::getNullValue(Ty));
+  }
+  return Cost + (NewDemandedElts.isZero()
+                     ? 0
+                     : TTI.getScalarizationOverhead(Ty, NewDemandedElts, Insert,
+                                                    Extract, CostKind, VL));
 }
 
 /// Correctly creates insert_subvector, checking that the index is multiple of
@@ -11684,6 +11694,15 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     // No need to delay the cost estimation during analysis.
     return std::nullopt;
   }
+  /// Reset the builder to handle perfect diamond match.
+  void resetForSameNode() {
+    IsFinalized = false;
+    CommonMask.clear();
+    InVectors.clear();
+    Cost = 0;
+    VectorizedVals.clear();
+    SameNodesEstimated = true;
+  }
   void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) {
     if (&E1 == &E2) {
       assert(all_of(Mask,
@@ -14890,15 +14909,18 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc,
     ShuffledElements.setBit(I);
     ShuffleMask[I] = Res.first->second;
   }
-  if (!DemandedElements.isZero())
-    Cost += getScalarizationOverhead(*TTI, ScalarTy, VecTy, DemandedElements,
-                                     /*Insert=*/true,
-                                     /*Extract=*/false, CostKind, VL);
-  if (ForPoisonSrc)
+  if (ForPoisonSrc) {
     Cost = getScalarizationOverhead(*TTI, ScalarTy, VecTy,
                                     /*DemandedElts*/ ~ShuffledElements,
                                     /*Insert*/ true,
-                                    /*Extract*/ false, CostKind, VL);
+                                    /*Extract*/ false, CostKind,
+                                    /*ForPoisonSrc=*/true, VL);
+  } else if (!DemandedElements.isZero()) {
+    Cost += getScalarizationOverhead(*TTI, ScalarTy, VecTy, DemandedElements,
+                                     /*Insert=*/true,
+                                     /*Extract=*/false, CostKind,
+                                     /*ForPoisonSrc=*/false, VL);
+  }
   if (DuplicateNonConst)
     Cost += ::getShuffleCost(*TTI, TargetTransformInfo::SK_PermuteSingleSrc,
                              VecTy, ShuffleMask);
@@ -15556,6 +15578,12 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
         PoisonValue::get(PointerType::getUnqual(ScalarTy->getContext())),
         MaybeAlign());
   }
+  /// Reset the builder to handle perfect diamond match.
+  void resetForSameNode() {
+    IsFinalized = false;
+    CommonMask.clear();
+    InVectors.clear();
+  }
   /// Adds 2 input vectors (in form of tree entries) and the mask for their
   /// shuffling.
   void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) {
@@ -16111,6 +16139,9 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
             Mask[I] = FrontTE->findLaneForValue(V);
           }
         }
+        // Reset the builder(s) to correctly handle perfect diamond matched
+        // nodes.
+        ShuffleBuilder.resetForSameNode();
         ShuffleBuilder.add(*FrontTE, Mask);
         // Full matched entry found, no need to insert subvectors.
         Res = ShuffleBuilder.finalize(E->getCommonMask(), {}, {});
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/buildvector-with-reuses.ll b/llvm/test/Transforms/SLPVectorizer/X86/buildvector-with-reuses.ll
index 75a413ffc1fb1..579239bc659bd 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/buildvector-with-reuses.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/buildvector-with-reuses.ll
@@ -10,18 +10,15 @@ define <4 x double> @test(ptr %ia, ptr %ib, ptr %ic, ptr %id, ptr %ie, ptr %x) {
 ; CHECK-NEXT:    [[I4275:%.*]] = load double, ptr [[ID]], align 8
 ; CHECK-NEXT:    [[I4277:%.*]] = load double, ptr [[IE]], align 8
 ; CHECK-NEXT:    [[I4326:%.*]] = load <4 x double>, ptr [[X]], align 8
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x double> [[I4326]], <4 x double> poison, <2 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x double> poison, double [[I4238]], i32 0
-; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <2 x double> [[TMP2]], double [[I4252]], i32 1
-; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast <2 x double> [[TMP1]], [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <2 x double> [[TMP1]], double [[I4275]], i32 1
-; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <2 x double> poison, double [[I4264]], i32 0
-; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <2 x double> [[TMP6]], double [[I4277]], i32 1
-; CHECK-NEXT:    [[TMP8:%.*]] = fmul fast <2 x double> [[TMP5]], [[TMP7]]
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <2 x double> [[TMP4]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x double> [[TMP8]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
-; CHECK-NEXT:    [[I44281:%.*]] = shufflevector <4 x double> [[TMP9]], <4 x double> [[TMP10]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
-; CHECK-NEXT:    ret <4 x double> [[I44281]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x double> [[I4326]], <4 x double> poison, <2 x i32> <i32 0, i32 poison>
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x double> [[TMP1]], double [[I4275]], i32 1
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <2 x double> [[TMP2]], <2 x double> poison, <4 x i32> <i32 0, i32 0, i32 0, i32 1>
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x double> poison, double [[I4238]], i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> [[TMP4]], double [[I4252]], i32 1
+; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <4 x double> [[TMP5]], double [[I4264]], i32 2
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP6]], double [[I4277]], i32 3
+; CHECK-NEXT:    [[TMP8:%.*]] = fmul fast <4 x double> [[TMP3]], [[TMP7]]
+; CHECK-NEXT:    ret <4 x double> [[TMP8]]
 ;
   %i4238 = load double, ptr %ia, align 8
   %i4252 = load double, ptr %ib, align 8
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reduction-transpose.ll b/llvm/test/Transforms/SLPVectorizer/X86/reduction-transpose.ll
index cb4783010965e..32dccd353da17 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reduction-transpose.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reduction-transpose.ll
@@ -49,24 +49,10 @@ define i32 @reduce_and4(i32 %acc, <4 x i32> %v1, <4 x i32> %v2, <4 x i32> %v3, <
 ;
 ; AVX512-LABEL: @reduce_and4(
 ; AVX512-NEXT:  entry:
-; AVX512-NEXT:    [[VECEXT:%.*]] = extractelement <4 x i32> [[V1:%.*]], i64 0
-; AVX512-NEXT:    [[VECEXT1:%.*]] = extractelement <4 x i32> [[V1]], i64 1
-; AVX512-NEXT:    [[VECEXT2:%.*]] = extractelement <4 x i32> [[V1]], i64 2
-; AVX512-NEXT:    [[VECEXT4:%.*]] = extractelement <4 x i32> [[V1]], i64 3
-; AVX512-NEXT:    [[VECEXT7:%.*]] = extractelement <4 x i32> [[V2:%.*]], i64 0
-; AVX512-NEXT:    [[VECEXT8:%.*]] = extractelement <4 x i32> [[V2]], i64 1
-; AVX512-NEXT:    [[VECEXT10:%.*]] = extractelement <4 x i32> [[V2]], i64 2
-; AVX512-NEXT:    [[VECEXT12:%.*]] = extractelement <4 x i32> [[V2]], i64 3
-; AVX512-NEXT:    [[TMP0:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <16 x i32> <i32 1, i32 0, i32 2, i32 3, i32 5, i32 4, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; AVX512-NEXT:    [[TMP1:%.*]] = insertelement <16 x i32> [[TMP0]], i32 [[VECEXT8]], i32 8
-; AVX512-NEXT:    [[TMP2:%.*]] = insertelement <16 x i32> [[TMP1]], i32 [[VECEXT7]], i32 9
-; AVX512-NEXT:    [[TMP3:%.*]] = insertelement <16 x i32> [[TMP2]], i32 [[VECEXT10]], i32 10
-; AVX512-NEXT:    [[TMP4:%.*]] = insertelement <16 x i32> [[TMP3]], i32 [[VECEXT12]], i32 11
-; AVX512-NEXT:    [[TMP5:%.*]] = insertelement <16 x i32> [[TMP4]], i32 [[VECEXT1]], i32 12
-; AVX512-NEXT:    [[TMP6:%.*]] = insertelement <16 x i32> [[TMP5]], i32 [[VECEXT]], i32 13
-; AVX512-NEXT:    [[TMP7:%.*]] = insertelement <16 x i32> [[TMP6]], i32 [[VECEXT2]], i32 14
-; AVX512-NEXT:    [[TMP8:%.*]] = insertelement <16 x i32> [[TMP7]], i32 [[VECEXT4]], i32 15
-; AVX512-NEXT:    [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[TMP8]])
+; AVX512-NEXT:    [[TMP0:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <8 x i32> <i32 1, i32 0, i32 2, i32 3, i32 5, i32 4, i32 6, i32 7>
+; AVX512-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V2:%.*]], <4 x i32> [[V1:%.*]], <8 x i32> <i32 1, i32 0, i32 2, i32 3, i32 5, i32 4, i32 6, i32 7>
+; AVX512-NEXT:    [[RDX_OP:%.*]] = and <8 x i32> [[TMP0]], [[TMP1]]
+; AVX512-NEXT:    [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v8i32(<8 x i32> [[RDX_OP]])
 ; AVX512-NEXT:    [[OP_RDX1:%.*]] = and i32 [[OP_RDX]], [[ACC:%.*]]
 ; AVX512-NEXT:    ret i32 [[OP_RDX1]]
 ;
@@ -144,24 +130,10 @@ define i32 @reduce_and4_transpose(i32 %acc, <4 x i32> %v1, <4 x i32> %v2, <4 x i
 ; AVX2-NEXT:    ret i32 [[OP_RDX]]
 ;
 ; AVX512-LABEL: @reduce_and4_transpose(
-; AVX512-NEXT:    [[VECEXT:%.*]] = extractelement <4 x i32> [[V1:%.*]], i64 0
-; AVX512-NEXT:    [[VECEXT1:%.*]] = extractelement <4 x i32> [[V2:%.*]], i64 0
-; AVX512-NEXT:    [[VECEXT7:%.*]] = extractelement <4 x i32> [[V1]], i64 1
-; AVX512-NEXT:    [[VECEXT8:%.*]] = extractelement <4 x i32> [[V2]], i64 1
-; AVX512-NEXT:    [[VECEXT15:%.*]] = extractelement <4 x i32> [[V1]], i64 2
-; AVX512-NEXT:    [[VECEXT16:%.*]] = extractelement <4 x i32> [[V2]], i64 2
-; AVX512-NEXT:    [[VECEXT23:%.*]] = extractelement <4 x i32> [[V1]], i64 3
-; AVX512-NEXT:    [[VECEXT24:%.*]] = extractelement <4 x i32> [[V2]], i64 3
-; AVX512-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <16 x i32> <i32 3, i32 2, i32 1, i32 0, i32 7, i32 6, i32 5, i32 4, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; AVX512-NEXT:    [[TMP2:%.*]] = insertelement <16 x i32> [[TMP1]], i32 [[VECEXT24]], i32 8
-; AVX512-NEXT:    [[TMP3:%.*]] = insertelement <16 x i32> [[TMP2]], i32 [[VECEXT16]], i32 9
-; AVX512-NEXT:    [[TMP4:%.*]] = insertelement <16 x i32> [[TMP3]], i32 [[VECEXT8]], i32 10
-; AVX512-NEXT:    [[TMP5:%.*]] = insertelement <16 x i32> [[TMP4]], i32 [[VECEXT1]], i32 11
-; AVX512-NEXT:    [[TMP6:%.*]] = insertelement <16 x i32> [[TMP5]], i32 [[VECEXT23]], i32 12
-; AVX512-NEXT:    [[TMP7:%.*]] = insertelement <16 x i32> [[TMP6]], i32 [[VECEXT15]], i32 13
-; AVX512-NEXT:    [[TMP8:%.*]] = insertelement <16 x i32> [[TMP7]], i32 [[VECEXT7]], i32 14
-; AVX512-NEXT:    [[TMP9:%.*]] = insertelement <16 x i32> [[TMP8]], i32 [[VECEXT]], i32 15
-; AVX512-NEXT:    [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[TMP9]])
+; AVX512-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <8 x i32> <i32 3, i32 2, i32 1, i32 0, i32 7, i32 6, i32 5, i32 4>
+; AVX512-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32> [[V2:%.*]], <4 x i32> [[V1:%.*]], <8 x i32> <i32 3, i32 2, i32 1, i32 0, i32 7, i32 6, i32 5, i32 4>
+; AVX512-NEXT:    [[RDX_OP:%.*]] = and <8 x i32> [[TMP1]], [[TMP2]]
+; AVX512-NEXT:    [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v8i32(<8 x i32> [[RDX_OP]])
 ; AVX512-NEXT:    [[OP_RDX1:%.*]] = and i32 [[OP_RDX]], [[ACC:%.*]]
 ; AVX512-NEXT:    ret i32 [[OP_RDX1]]
 ;

@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

Changes

Need to drop all previous estimations/vectorizations, when found
a perfect diamond match. This improves cost estimation and improves code
emission.
Also, need to adjust getScalarizationOverhead cost for non-poison input
vector. Currently, it does not allow to estimate it correctly, so
instead add a cost of the insertion of the first vector element into
non-poison vector value and then remaining elements.


Full diff: https://github.com/llvm/llvm-project/pull/132466.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+45-14)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/buildvector-with-reuses.ll (+9-12)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/reduction-transpose.ll (+8-36)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 0201955b8b559..7050549d61d74 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -5310,12 +5310,11 @@ getShuffleCost(const TargetTransformInfo &TTI, TTI::ShuffleKind Kind,
 /// This is similar to TargetTransformInfo::getScalarizationOverhead, but if
 /// ScalarTy is a FixedVectorType, a vector will be inserted or extracted
 /// instead of a scalar.
-static InstructionCost getScalarizationOverhead(const TargetTransformInfo &TTI,
-                                                Type *ScalarTy, VectorType *Ty,
-                                                const APInt &DemandedElts,
-                                                bool Insert, bool Extract,
-                                                TTI::TargetCostKind CostKind,
-                                                ArrayRef<Value *> VL = {}) {
+static InstructionCost
+getScalarizationOverhead(const TargetTransformInfo &TTI, Type *ScalarTy,
+                         VectorType *Ty, const APInt &DemandedElts, bool Insert,
+                         bool Extract, TTI::TargetCostKind CostKind,
+                         bool ForPoisonSrc = true, ArrayRef<Value *> VL = {}) {
   assert(!isa<ScalableVectorType>(Ty) &&
          "ScalableVectorType is not supported.");
   assert(getNumElements(ScalarTy) * DemandedElts.getBitWidth() ==
@@ -5339,8 +5338,19 @@ static InstructionCost getScalarizationOverhead(const TargetTransformInfo &TTI,
     }
     return Cost;
   }
-  return TTI.getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
-                                      CostKind, VL);
+  APInt NewDemandedElts = DemandedElts;
+  InstructionCost Cost = 0;
+  if (!ForPoisonSrc && Insert) {
+    // Handle insert into non-poison vector.
+    unsigned LeftMostBit = NewDemandedElts.countr_zero();
+    NewDemandedElts.clearBit(LeftMostBit);
+    Cost += TTI.getVectorInstrCost(Instruction::InsertElement, Ty, CostKind,
+                                   LeftMostBit, Constant::getNullValue(Ty));
+  }
+  return Cost + (NewDemandedElts.isZero()
+                     ? 0
+                     : TTI.getScalarizationOverhead(Ty, NewDemandedElts, Insert,
+                                                    Extract, CostKind, VL));
 }
 
 /// Correctly creates insert_subvector, checking that the index is multiple of
@@ -11684,6 +11694,15 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     // No need to delay the cost estimation during analysis.
     return std::nullopt;
   }
+  /// Reset the builder to handle perfect diamond match.
+  void resetForSameNode() {
+    IsFinalized = false;
+    CommonMask.clear();
+    InVectors.clear();
+    Cost = 0;
+    VectorizedVals.clear();
+    SameNodesEstimated = true;
+  }
   void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) {
     if (&E1 == &E2) {
       assert(all_of(Mask,
@@ -14890,15 +14909,18 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc,
     ShuffledElements.setBit(I);
     ShuffleMask[I] = Res.first->second;
   }
-  if (!DemandedElements.isZero())
-    Cost += getScalarizationOverhead(*TTI, ScalarTy, VecTy, DemandedElements,
-                                     /*Insert=*/true,
-                                     /*Extract=*/false, CostKind, VL);
-  if (ForPoisonSrc)
+  if (ForPoisonSrc) {
     Cost = getScalarizationOverhead(*TTI, ScalarTy, VecTy,
                                     /*DemandedElts*/ ~ShuffledElements,
                                     /*Insert*/ true,
-                                    /*Extract*/ false, CostKind, VL);
+                                    /*Extract*/ false, CostKind,
+                                    /*ForPoisonSrc=*/true, VL);
+  } else if (!DemandedElements.isZero()) {
+    Cost += getScalarizationOverhead(*TTI, ScalarTy, VecTy, DemandedElements,
+                                     /*Insert=*/true,
+                                     /*Extract=*/false, CostKind,
+                                     /*ForPoisonSrc=*/false, VL);
+  }
   if (DuplicateNonConst)
     Cost += ::getShuffleCost(*TTI, TargetTransformInfo::SK_PermuteSingleSrc,
                              VecTy, ShuffleMask);
@@ -15556,6 +15578,12 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
         PoisonValue::get(PointerType::getUnqual(ScalarTy->getContext())),
         MaybeAlign());
   }
+  /// Reset the builder to handle perfect diamond match.
+  void resetForSameNode() {
+    IsFinalized = false;
+    CommonMask.clear();
+    InVectors.clear();
+  }
   /// Adds 2 input vectors (in form of tree entries) and the mask for their
   /// shuffling.
   void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) {
@@ -16111,6 +16139,9 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
             Mask[I] = FrontTE->findLaneForValue(V);
           }
         }
+        // Reset the builder(s) to correctly handle perfect diamond matched
+        // nodes.
+        ShuffleBuilder.resetForSameNode();
         ShuffleBuilder.add(*FrontTE, Mask);
         // Full matched entry found, no need to insert subvectors.
         Res = ShuffleBuilder.finalize(E->getCommonMask(), {}, {});
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/buildvector-with-reuses.ll b/llvm/test/Transforms/SLPVectorizer/X86/buildvector-with-reuses.ll
index 75a413ffc1fb1..579239bc659bd 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/buildvector-with-reuses.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/buildvector-with-reuses.ll
@@ -10,18 +10,15 @@ define <4 x double> @test(ptr %ia, ptr %ib, ptr %ic, ptr %id, ptr %ie, ptr %x) {
 ; CHECK-NEXT:    [[I4275:%.*]] = load double, ptr [[ID]], align 8
 ; CHECK-NEXT:    [[I4277:%.*]] = load double, ptr [[IE]], align 8
 ; CHECK-NEXT:    [[I4326:%.*]] = load <4 x double>, ptr [[X]], align 8
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x double> [[I4326]], <4 x double> poison, <2 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x double> poison, double [[I4238]], i32 0
-; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <2 x double> [[TMP2]], double [[I4252]], i32 1
-; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast <2 x double> [[TMP1]], [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <2 x double> [[TMP1]], double [[I4275]], i32 1
-; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <2 x double> poison, double [[I4264]], i32 0
-; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <2 x double> [[TMP6]], double [[I4277]], i32 1
-; CHECK-NEXT:    [[TMP8:%.*]] = fmul fast <2 x double> [[TMP5]], [[TMP7]]
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <2 x double> [[TMP4]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x double> [[TMP8]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
-; CHECK-NEXT:    [[I44281:%.*]] = shufflevector <4 x double> [[TMP9]], <4 x double> [[TMP10]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
-; CHECK-NEXT:    ret <4 x double> [[I44281]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x double> [[I4326]], <4 x double> poison, <2 x i32> <i32 0, i32 poison>
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x double> [[TMP1]], double [[I4275]], i32 1
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <2 x double> [[TMP2]], <2 x double> poison, <4 x i32> <i32 0, i32 0, i32 0, i32 1>
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x double> poison, double [[I4238]], i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> [[TMP4]], double [[I4252]], i32 1
+; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <4 x double> [[TMP5]], double [[I4264]], i32 2
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP6]], double [[I4277]], i32 3
+; CHECK-NEXT:    [[TMP8:%.*]] = fmul fast <4 x double> [[TMP3]], [[TMP7]]
+; CHECK-NEXT:    ret <4 x double> [[TMP8]]
 ;
   %i4238 = load double, ptr %ia, align 8
   %i4252 = load double, ptr %ib, align 8
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reduction-transpose.ll b/llvm/test/Transforms/SLPVectorizer/X86/reduction-transpose.ll
index cb4783010965e..32dccd353da17 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reduction-transpose.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reduction-transpose.ll
@@ -49,24 +49,10 @@ define i32 @reduce_and4(i32 %acc, <4 x i32> %v1, <4 x i32> %v2, <4 x i32> %v3, <
 ;
 ; AVX512-LABEL: @reduce_and4(
 ; AVX512-NEXT:  entry:
-; AVX512-NEXT:    [[VECEXT:%.*]] = extractelement <4 x i32> [[V1:%.*]], i64 0
-; AVX512-NEXT:    [[VECEXT1:%.*]] = extractelement <4 x i32> [[V1]], i64 1
-; AVX512-NEXT:    [[VECEXT2:%.*]] = extractelement <4 x i32> [[V1]], i64 2
-; AVX512-NEXT:    [[VECEXT4:%.*]] = extractelement <4 x i32> [[V1]], i64 3
-; AVX512-NEXT:    [[VECEXT7:%.*]] = extractelement <4 x i32> [[V2:%.*]], i64 0
-; AVX512-NEXT:    [[VECEXT8:%.*]] = extractelement <4 x i32> [[V2]], i64 1
-; AVX512-NEXT:    [[VECEXT10:%.*]] = extractelement <4 x i32> [[V2]], i64 2
-; AVX512-NEXT:    [[VECEXT12:%.*]] = extractelement <4 x i32> [[V2]], i64 3
-; AVX512-NEXT:    [[TMP0:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <16 x i32> <i32 1, i32 0, i32 2, i32 3, i32 5, i32 4, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; AVX512-NEXT:    [[TMP1:%.*]] = insertelement <16 x i32> [[TMP0]], i32 [[VECEXT8]], i32 8
-; AVX512-NEXT:    [[TMP2:%.*]] = insertelement <16 x i32> [[TMP1]], i32 [[VECEXT7]], i32 9
-; AVX512-NEXT:    [[TMP3:%.*]] = insertelement <16 x i32> [[TMP2]], i32 [[VECEXT10]], i32 10
-; AVX512-NEXT:    [[TMP4:%.*]] = insertelement <16 x i32> [[TMP3]], i32 [[VECEXT12]], i32 11
-; AVX512-NEXT:    [[TMP5:%.*]] = insertelement <16 x i32> [[TMP4]], i32 [[VECEXT1]], i32 12
-; AVX512-NEXT:    [[TMP6:%.*]] = insertelement <16 x i32> [[TMP5]], i32 [[VECEXT]], i32 13
-; AVX512-NEXT:    [[TMP7:%.*]] = insertelement <16 x i32> [[TMP6]], i32 [[VECEXT2]], i32 14
-; AVX512-NEXT:    [[TMP8:%.*]] = insertelement <16 x i32> [[TMP7]], i32 [[VECEXT4]], i32 15
-; AVX512-NEXT:    [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[TMP8]])
+; AVX512-NEXT:    [[TMP0:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <8 x i32> <i32 1, i32 0, i32 2, i32 3, i32 5, i32 4, i32 6, i32 7>
+; AVX512-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V2:%.*]], <4 x i32> [[V1:%.*]], <8 x i32> <i32 1, i32 0, i32 2, i32 3, i32 5, i32 4, i32 6, i32 7>
+; AVX512-NEXT:    [[RDX_OP:%.*]] = and <8 x i32> [[TMP0]], [[TMP1]]
+; AVX512-NEXT:    [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v8i32(<8 x i32> [[RDX_OP]])
 ; AVX512-NEXT:    [[OP_RDX1:%.*]] = and i32 [[OP_RDX]], [[ACC:%.*]]
 ; AVX512-NEXT:    ret i32 [[OP_RDX1]]
 ;
@@ -144,24 +130,10 @@ define i32 @reduce_and4_transpose(i32 %acc, <4 x i32> %v1, <4 x i32> %v2, <4 x i
 ; AVX2-NEXT:    ret i32 [[OP_RDX]]
 ;
 ; AVX512-LABEL: @reduce_and4_transpose(
-; AVX512-NEXT:    [[VECEXT:%.*]] = extractelement <4 x i32> [[V1:%.*]], i64 0
-; AVX512-NEXT:    [[VECEXT1:%.*]] = extractelement <4 x i32> [[V2:%.*]], i64 0
-; AVX512-NEXT:    [[VECEXT7:%.*]] = extractelement <4 x i32> [[V1]], i64 1
-; AVX512-NEXT:    [[VECEXT8:%.*]] = extractelement <4 x i32> [[V2]], i64 1
-; AVX512-NEXT:    [[VECEXT15:%.*]] = extractelement <4 x i32> [[V1]], i64 2
-; AVX512-NEXT:    [[VECEXT16:%.*]] = extractelement <4 x i32> [[V2]], i64 2
-; AVX512-NEXT:    [[VECEXT23:%.*]] = extractelement <4 x i32> [[V1]], i64 3
-; AVX512-NEXT:    [[VECEXT24:%.*]] = extractelement <4 x i32> [[V2]], i64 3
-; AVX512-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <16 x i32> <i32 3, i32 2, i32 1, i32 0, i32 7, i32 6, i32 5, i32 4, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; AVX512-NEXT:    [[TMP2:%.*]] = insertelement <16 x i32> [[TMP1]], i32 [[VECEXT24]], i32 8
-; AVX512-NEXT:    [[TMP3:%.*]] = insertelement <16 x i32> [[TMP2]], i32 [[VECEXT16]], i32 9
-; AVX512-NEXT:    [[TMP4:%.*]] = insertelement <16 x i32> [[TMP3]], i32 [[VECEXT8]], i32 10
-; AVX512-NEXT:    [[TMP5:%.*]] = insertelement <16 x i32> [[TMP4]], i32 [[VECEXT1]], i32 11
-; AVX512-NEXT:    [[TMP6:%.*]] = insertelement <16 x i32> [[TMP5]], i32 [[VECEXT23]], i32 12
-; AVX512-NEXT:    [[TMP7:%.*]] = insertelement <16 x i32> [[TMP6]], i32 [[VECEXT15]], i32 13
-; AVX512-NEXT:    [[TMP8:%.*]] = insertelement <16 x i32> [[TMP7]], i32 [[VECEXT7]], i32 14
-; AVX512-NEXT:    [[TMP9:%.*]] = insertelement <16 x i32> [[TMP8]], i32 [[VECEXT]], i32 15
-; AVX512-NEXT:    [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[TMP9]])
+; AVX512-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <8 x i32> <i32 3, i32 2, i32 1, i32 0, i32 7, i32 6, i32 5, i32 4>
+; AVX512-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32> [[V2:%.*]], <4 x i32> [[V1:%.*]], <8 x i32> <i32 3, i32 2, i32 1, i32 0, i32 7, i32 6, i32 5, i32 4>
+; AVX512-NEXT:    [[RDX_OP:%.*]] = and <8 x i32> [[TMP1]], [[TMP2]]
+; AVX512-NEXT:    [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v8i32(<8 x i32> [[RDX_OP]])
 ; AVX512-NEXT:    [[OP_RDX1:%.*]] = and i32 [[OP_RDX]], [[ACC:%.*]]
 ; AVX512-NEXT:    ret i32 [[OP_RDX1]]
 ;

@hiraditya
Copy link
Collaborator

lgtm unless @RKSimon has any feedback on this.

APInt NewDemandedElts = DemandedElts;
InstructionCost Cost = 0;
if (!ForPoisonSrc && Insert) {
// Handle insert into non-poison vector.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not certain this is entirely correct, afaict I can tell you are trying to cost inserting a number of elements into an existing (non-poison) vector, while the TTI::getScalarizationOverhead assumes Insert is for a ISD::BUILD_VECTOR style pattern. Would a SK_Select shuffle not be closer to what you're after? The isZero checks below suggest you've had to write this mainly for the single element diamond case, so maybe that need to be handled by a DemandedElts.isOneBitSet() special case instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use a conservative solution for now. Currently, codegen just generates insertelement for each unique scalar here, so I just replicated this for non-poison input value.
Need to teach ScalarizationOverhead of the non-poison input vector. Also, not sure the current generic implementation is fully correct. It currently supposes that all insertvector instructions are inserted into poison vector, but it is true only for the first instruction.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK - add a TODO comment for now please

Created using spr 1.3.5
Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with one minor

Created using spr 1.3.5
@alexey-bataev alexey-bataev merged commit ad9909d into main Mar 24, 2025
6 of 10 checks passed
@alexey-bataev alexey-bataev deleted the users/alexey-bataev/spr/slpfix-perfect-diamond-match-with-extractelements-in-scalars branch March 24, 2025 13:29
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Mar 24, 2025
…lars

Need to drop all previous estimations/vectorizations, when found
a perfect diamond match. This improves cost estimation and improves code
emission.
Also, need to adjust getScalarizationOverhead cost for non-poison input
vector. Currently, it does not allow to estimate it correctly, so
instead use conservative element-by-element insertelement cost for each
unique scalar.

Reviewers: RKSimon, hiraditya

Reviewed By: RKSimon

Pull Request: llvm/llvm-project#132466
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants