Skip to content

[VPlan] Separate out logic to manage IR flags to VPIRFlags (NFC). #140621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 25, 2025

Conversation

fhahn
Copy link
Contributor

@fhahn fhahn commented May 19, 2025

This patch moves the logic to manage IR flags to a separate VPIRFlags class. For now, VPRecipeWithIRFlags is the only class that inherits VPIRFlags. The new class allows for simpler passing of flags when constructing recipes, simplifying the constructors for various recipes (VPInstruction in particular, which now just has 2 constructors, one taking an extra VPIRFlags argument.

This mirrors the approach taken for VPIRMetadata and makes it easier to extend in the future. The patch also adds a unified flagsValidForOpcode to check if the flags in a VPIRFlags match the provided opcode.

This patch moves the logic to manage IR flags to a separate VPIRFlags
class. For now, VPRecipeWithIRFlags is the only class that inherits
VPIRFlags. The new class allows for simpler passing of flags when
constructing recipes, simplifying the constructors for various recipes
(VPInstruction in particular, which now just has 2 constructors, one
taking an extra VPIRFlags argument.

This mirrors the approach taken for VPIRMetadata and makes it easier to
extend in the future. The patch also adds a unified flagsValidForOpcode
to check if the flags in a VPIRFlags match the provided opcode.
@llvmbot
Copy link
Member

llvmbot commented May 19, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Florian Hahn (fhahn)

Changes

This patch moves the logic to manage IR flags to a separate VPIRFlags class. For now, VPRecipeWithIRFlags is the only class that inherits VPIRFlags. The new class allows for simpler passing of flags when constructing recipes, simplifying the constructors for various recipes (VPInstruction in particular, which now just has 2 constructors, one taking an extra VPIRFlags argument.

This mirrors the approach taken for VPIRMetadata and makes it easier to extend in the future. The patch also adds a unified flagsValidForOpcode to check if the flags in a VPIRFlags match the provided opcode.


Patch is 27.07 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140621.diff

5 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h (+14-18)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+81-122)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+42-36)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+18-15)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp (+4-4)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index bae53c600c18c..c751f053cb65a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -164,25 +164,19 @@ class VPBuilder {
                               DebugLoc DL, const Twine &Name = "") {
     return createInstruction(Opcode, Operands, DL, Name);
   }
-  VPInstruction *createNaryOp(unsigned Opcode,
-                              std::initializer_list<VPValue *> Operands,
-                              std::optional<FastMathFlags> FMFs = {},
-                              DebugLoc DL = {}, const Twine &Name = "") {
-    if (FMFs)
-      return tryInsertInstruction(
-          new VPInstruction(Opcode, Operands, *FMFs, DL, Name));
-    return createInstruction(Opcode, Operands, DL, Name);
+  VPInstruction *createNaryOp(unsigned Opcode, ArrayRef<VPValue *> Operands,
+                              const VPIRFlags &Flags, DebugLoc DL = {},
+                              const Twine &Name = "") {
+    return tryInsertInstruction(
+        new VPInstruction(Opcode, Operands, Flags, DL, Name));
   }
+
   VPInstruction *createNaryOp(unsigned Opcode,
                               std::initializer_list<VPValue *> Operands,
-                              Type *ResultTy,
-                              std::optional<FastMathFlags> FMFs = {},
+                              Type *ResultTy, const VPIRFlags &Flags = {},
                               DebugLoc DL = {}, const Twine &Name = "") {
-    if (FMFs)
-      return tryInsertInstruction(new VPInstructionWithType(
-          Opcode, Operands, ResultTy, *FMFs, DL, Name));
     return tryInsertInstruction(
-        new VPInstructionWithType(Opcode, Operands, ResultTy, DL, Name));
+        new VPInstructionWithType(Opcode, Operands, ResultTy, Flags, DL, Name));
   }
 
   VPInstruction *createOverflowingOp(unsigned Opcode,
@@ -236,18 +230,20 @@ class VPBuilder {
     assert(Pred >= CmpInst::FIRST_ICMP_PREDICATE &&
            Pred <= CmpInst::LAST_ICMP_PREDICATE && "invalid predicate");
     return tryInsertInstruction(
-        new VPInstruction(Instruction::ICmp, Pred, A, B, DL, Name));
+        new VPInstruction(Instruction::ICmp, {A, B}, Pred, DL, Name));
   }
 
   VPInstruction *createPtrAdd(VPValue *Ptr, VPValue *Offset, DebugLoc DL = {},
                               const Twine &Name = "") {
     return tryInsertInstruction(
-        new VPInstruction(Ptr, Offset, GEPNoWrapFlags::none(), DL, Name));
+        new VPInstruction(VPInstruction::PtrAdd, {Ptr, Offset},
+                          GEPNoWrapFlags::none(), DL, Name));
   }
   VPValue *createInBoundsPtrAdd(VPValue *Ptr, VPValue *Offset, DebugLoc DL = {},
                                 const Twine &Name = "") {
     return tryInsertInstruction(
-        new VPInstruction(Ptr, Offset, GEPNoWrapFlags::inBounds(), DL, Name));
+        new VPInstruction(VPInstruction::PtrAdd, {Ptr, Offset},
+                          GEPNoWrapFlags::inBounds(), DL, Name));
   }
 
   VPInstruction *createScalarPhi(ArrayRef<VPValue *> IncomingValues,
@@ -269,7 +265,7 @@ class VPBuilder {
   VPInstruction *createScalarCast(Instruction::CastOps Opcode, VPValue *Op,
                                   Type *ResultTy, DebugLoc DL) {
     return tryInsertInstruction(
-        new VPInstructionWithType(Opcode, Op, ResultTy, DL));
+        new VPInstructionWithType(Opcode, Op, ResultTy, {}, DL));
   }
 
   VPWidenCastRecipe *createWidenCast(Instruction::CastOps Opcode, VPValue *Op,
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index e634de1e17c69..b38fb7e9b1adb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -577,8 +577,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
 #endif
 };
 
-/// Class to record LLVM IR flag for a recipe along with it.
-class VPRecipeWithIRFlags : public VPSingleDefRecipe {
+/// Class to record LLVM IR flags.
+class VPIRFlags {
   enum class OperationType : unsigned char {
     Cmp,
     OverflowingBinOp,
@@ -637,23 +637,10 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     unsigned AllFlags;
   };
 
-protected:
-  void transferFlags(VPRecipeWithIRFlags &Other) {
-    OpType = Other.OpType;
-    AllFlags = Other.AllFlags;
-  }
-
 public:
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL) {
-    OpType = OperationType::Other;
-    AllFlags = 0;
-  }
+  VPIRFlags() : OpType(OperationType::Other), AllFlags(0) {}
 
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      Instruction &I)
-      : VPSingleDefRecipe(SC, Operands, &I, I.getDebugLoc()) {
+  VPIRFlags(Instruction &I) {
     if (auto *Op = dyn_cast<CmpInst>(&I)) {
       OpType = OperationType::Cmp;
       CmpPredicate = Op->getPredicate();
@@ -681,63 +668,27 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     }
   }
 
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      CmpInst::Predicate Pred, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::Cmp),
-        CmpPredicate(Pred) {}
+  VPIRFlags(CmpInst::Predicate Pred)
+      : OpType(OperationType::Cmp), CmpPredicate(Pred) {}
 
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      WrapFlagsTy WrapFlags, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL),
-        OpType(OperationType::OverflowingBinOp), WrapFlags(WrapFlags) {}
+  VPIRFlags(WrapFlagsTy WrapFlags)
+      : OpType(OperationType::OverflowingBinOp), WrapFlags(WrapFlags) {}
 
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      FastMathFlags FMFs, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::FPMathOp),
-        FMFs(FMFs) {}
+  VPIRFlags(FastMathFlags FMFs) : OpType(OperationType::FPMathOp), FMFs(FMFs) {}
 
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      DisjointFlagsTy DisjointFlags, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
-        DisjointFlags(DisjointFlags) {}
+  VPIRFlags(DisjointFlagsTy DisjointFlags)
+      : OpType(OperationType::DisjointOp), DisjointFlags(DisjointFlags) {}
 
-  template <typename IterT>
-  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
-                      NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
-        NonNegFlags(NonNegFlags) {}
+  VPIRFlags(NonNegFlagsTy NonNegFlags)
+      : OpType(OperationType::NonNegOp), NonNegFlags(NonNegFlags) {}
 
-protected:
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      GEPNoWrapFlags GEPFlags, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::GEPOp),
-        GEPFlags(GEPFlags) {}
+  VPIRFlags(GEPNoWrapFlags GEPFlags)
+      : OpType(OperationType::GEPOp), GEPFlags(GEPFlags) {}
 
 public:
-  static inline bool classof(const VPRecipeBase *R) {
-    return R->getVPDefID() == VPRecipeBase::VPInstructionSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenCallSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenIntrinsicSC ||
-           R->getVPDefID() == VPRecipeBase::VPReductionSC ||
-           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
-           R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
-           R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
-           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
-           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
-           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
-  }
-
-  static inline bool classof(const VPUser *U) {
-    auto *R = dyn_cast<VPRecipeBase>(U);
-    return R && classof(R);
-  }
-
-  static inline bool classof(const VPValue *V) {
-    auto *R = dyn_cast_or_null<VPRecipeBase>(V->getDefiningRecipe());
-    return R && classof(R);
+  void transferFlags(VPIRFlags &Other) {
+    OpType = Other.OpType;
+    AllFlags = Other.AllFlags;
   }
 
   /// Drop all poison-generating flags.
@@ -851,11 +802,58 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     return DisjointFlags.IsDisjoint;
   }
 
+#if !defined(NDEBUG)
+  /// Returns true if the set flags are valid for \p Opcode.
+  bool flagsValidForOpcode(unsigned Opcode) const;
+#endif
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   void printFlags(raw_ostream &O) const;
 #endif
 };
 
+class VPRecipeWithIRFlags : public VPSingleDefRecipe, public VPIRFlags {
+public:
+  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
+                      DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), VPIRFlags() {}
+
+  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
+                      Instruction &I)
+      : VPSingleDefRecipe(SC, Operands, &I, I.getDebugLoc()), VPIRFlags(I) {}
+
+  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
+                      const VPIRFlags &Flags, DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), VPIRFlags(Flags) {}
+
+public:
+  static inline bool classof(const VPRecipeBase *R) {
+    return R->getVPDefID() == VPRecipeBase::VPInstructionSC ||
+           R->getVPDefID() == VPRecipeBase::VPWidenSC ||
+           R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
+           R->getVPDefID() == VPRecipeBase::VPWidenCallSC ||
+           R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
+           R->getVPDefID() == VPRecipeBase::VPWidenIntrinsicSC ||
+           R->getVPDefID() == VPRecipeBase::VPReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
+           R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
+           R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
+           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
+           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
+  }
+
+  static inline bool classof(const VPUser *U) {
+    auto *R = dyn_cast<VPRecipeBase>(U);
+    return R && classof(R);
+  }
+
+  static inline bool classof(const VPValue *V) {
+    auto *R = dyn_cast_or_null<VPRecipeBase>(V->getDefiningRecipe());
+    return R && classof(R);
+  }
+};
+
 /// Helper to access the operand that contains the unroll part for this recipe
 /// after unrolling.
 template <unsigned PartOpIdx> class VPUnrollPartAccessor {
@@ -958,54 +956,21 @@ class VPInstruction : public VPRecipeWithIRFlags,
   /// value for lane \p Lane.
   Value *generatePerLane(VPTransformState &State, const VPLane &Lane);
 
-#if !defined(NDEBUG)
-  /// Return true if the VPInstruction is a floating point math operation, i.e.
-  /// has fast-math flags.
-  bool isFPMathOp() const;
-#endif
-
 public:
-  VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL,
+  VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL = {},
                 const Twine &Name = "")
       : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DL),
         Opcode(Opcode), Name(Name.str()) {}
 
-  VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
-                DebugLoc DL = {}, const Twine &Name = "")
-      : VPInstruction(Opcode, ArrayRef<VPValue *>(Operands), DL, Name) {}
-
-  VPInstruction(unsigned Opcode, CmpInst::Predicate Pred, VPValue *A,
-                VPValue *B, DebugLoc DL = {}, const Twine &Name = "");
-
-  VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
-                WrapFlagsTy WrapFlags, DebugLoc DL = {}, const Twine &Name = "")
-      : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, WrapFlags, DL),
-        Opcode(Opcode), Name(Name.str()) {}
-
-  VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
-                DisjointFlagsTy DisjointFlag, DebugLoc DL = {},
-                const Twine &Name = "")
-      : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DisjointFlag, DL),
-        Opcode(Opcode), Name(Name.str()) {
-    assert(Opcode == Instruction::Or && "only OR opcodes can be disjoint");
-  }
-
-  VPInstruction(VPValue *Ptr, VPValue *Offset, GEPNoWrapFlags Flags,
-                DebugLoc DL = {}, const Twine &Name = "")
-      : VPRecipeWithIRFlags(VPDef::VPInstructionSC,
-                            ArrayRef<VPValue *>({Ptr, Offset}), Flags, DL),
-        Opcode(VPInstruction::PtrAdd), Name(Name.str()) {}
-
-  VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
-                FastMathFlags FMFs, DebugLoc DL = {}, const Twine &Name = "");
+  VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
+                const VPIRFlags &Flags, DebugLoc DL = {},
+                const Twine &Name = "");
 
   VP_CLASSOF_IMPL(VPDef::VPInstructionSC)
 
   VPInstruction *clone() override {
     SmallVector<VPValue *, 2> Operands(operands());
-    auto *New = new VPInstruction(Opcode, Operands, getDebugLoc(), Name);
-    New->transferFlags(*this);
-    return New;
+    return new VPInstruction(Opcode, Operands, *this, getDebugLoc(), Name);
   }
 
   unsigned getOpcode() const { return Opcode; }
@@ -1082,13 +1047,9 @@ class VPInstructionWithType : public VPInstruction {
 
 public:
   VPInstructionWithType(unsigned Opcode, ArrayRef<VPValue *> Operands,
-                        Type *ResultTy, DebugLoc DL, const Twine &Name = "")
-      : VPInstruction(Opcode, Operands, DL, Name), ResultTy(ResultTy) {}
-  VPInstructionWithType(unsigned Opcode,
-                        std::initializer_list<VPValue *> Operands,
-                        Type *ResultTy, FastMathFlags FMFs, DebugLoc DL = {},
+                        Type *ResultTy, const VPIRFlags &Flags, DebugLoc DL,
                         const Twine &Name = "")
-      : VPInstruction(Opcode, Operands, FMFs, DL, Name), ResultTy(ResultTy) {}
+      : VPInstruction(Opcode, Operands, Flags, DL, Name), ResultTy(ResultTy) {}
 
   static inline bool classof(const VPRecipeBase *R) {
     // VPInstructionWithType are VPInstructions with specific opcodes requiring
@@ -1113,8 +1074,9 @@ class VPInstructionWithType : public VPInstruction {
 
   VPInstruction *clone() override {
     SmallVector<VPValue *, 2> Operands(operands());
-    auto *New = new VPInstructionWithType(
-        getOpcode(), Operands, getResultType(), getDebugLoc(), getName());
+    auto *New =
+        new VPInstructionWithType(getOpcode(), Operands, getResultType(), *this,
+                                  getDebugLoc(), getName());
     New->setUnderlyingValue(getUnderlyingValue());
     return New;
   }
@@ -1373,15 +1335,12 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
   }
 
   VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
-                    DebugLoc DL = {})
-      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, DL), VPIRMetadata(),
-        Opcode(Opcode), ResultTy(ResultTy) {}
-
-  VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
-                    bool IsNonNeg, DebugLoc DL = {})
-      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, NonNegFlagsTy(IsNonNeg),
-                            DL),
-        Opcode(Opcode), ResultTy(ResultTy) {}
+                    const VPIRFlags &Flags = {}, DebugLoc DL = {})
+      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, Flags, DL),
+        VPIRMetadata(), Opcode(Opcode), ResultTy(ResultTy) {
+    assert(flagsValidForOpcode(Opcode) &&
+           "Set flags not supported for the provided opcode");
+  }
 
   ~VPWidenCastRecipe() override = default;
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 14ed40f16683a..3a57ce4c8af6e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -368,7 +368,7 @@ void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
 }
 #endif
 
-FastMathFlags VPRecipeWithIRFlags::getFastMathFlags() const {
+FastMathFlags VPIRFlags::getFastMathFlags() const {
   assert(OpType == OperationType::FPMathOp &&
          "recipe doesn't have fast math flags");
   FastMathFlags Res;
@@ -406,23 +406,13 @@ template class VPUnrollPartAccessor<2>;
 template class VPUnrollPartAccessor<3>;
 }
 
-VPInstruction::VPInstruction(unsigned Opcode, CmpInst::Predicate Pred,
-                             VPValue *A, VPValue *B, DebugLoc DL,
+VPInstruction::VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
+                             const VPIRFlags &Flags, DebugLoc DL,
                              const Twine &Name)
-    : VPRecipeWithIRFlags(VPDef::VPInstructionSC, ArrayRef<VPValue *>({A, B}),
-                          Pred, DL),
+    : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, Flags, DL),
       Opcode(Opcode), Name(Name.str()) {
-  assert(Opcode == Instruction::ICmp &&
-         "only ICmp predicates supported at the moment");
-}
-
-VPInstruction::VPInstruction(unsigned Opcode,
-                             std::initializer_list<VPValue *> Operands,
-                             FastMathFlags FMFs, DebugLoc DL, const Twine &Name)
-    : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, FMFs, DL),
-      Opcode(Opcode), Name(Name.str()) {
-  // Make sure the VPInstruction is a floating-point operation.
-  assert(isFPMathOp() && "this op can't take fast-math flags");
+  assert(flagsValidForOpcode(getOpcode()) &&
+         "Set flags not supported for the provided opcode");
 }
 
 bool VPInstruction::doesGeneratePerAllLanes() const {
@@ -864,24 +854,11 @@ bool VPInstruction::isSingleScalar() const {
          getOpcode() == Instruction::PHI;
 }
 
-#if !defined(NDEBUG)
-bool VPInstruction::isFPMathOp() const {
-  // Inspired by FPMathOperator::classof. Notable differences are that we don't
-  // support Call, PHI and Select opcodes here yet.
-  return Opcode == Instruction::FAdd || Opcode == Instruction::FMul ||
-         Opcode == Instruction::FNeg || Opcode == Instruction::FSub ||
-         Opcode == Instruction::FDiv || Opcode == Instruction::FRem ||
-         Opcode == Instruction::FCmp || Opcode == Instruction::Select ||
-         Opcode == VPInstruction::WideIVStep;
-}
-#endif
-
 void VPInstruction::execute(VPTransformState &State) {
   assert(!State.Lane && "VPInstruction executing an Lane");
   IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
-  assert((hasFastMathFlags() == isFPMathOp() ||
-          getOpcode() == Instruction::Select) &&
-         "Recipe not a FPMathOp but has fast-math flags?");
+  assert(flagsValidForOpcode(getOpcode()) &&
+         "Set flags not supported for the provided opcode");
   if (hasFastMathFlags())
     State.Builder.setFastMathFlags(getFastMathFlags());
   bool GeneratesPerFirstLaneOnly = canGenerateScalarForFirstLane() &&
@@ -1606,8 +1583,7 @@ InstructionCost VPWidenSelectRecipe::computeCost(ElementCount VF,
       {TTI::OK_AnyValue, TTI::OP_None}, {TTI::OK_AnyValue, TTI::OP_None}, SI);
 }
 
-VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy(
-    const FastMathFlags &FMF) {
+VPIRFlags::FastMathFlagsTy::FastMathFlagsTy(const FastMathFlags &FMF) {
   AllowReassoc = FMF.allowReassoc();
   NoNaNs = FMF.noNaNs();
   NoInfs = FMF.noInfs();
@@ -1617,8 +1593,39 @@ VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy(
   ApproxFunc = FMF.approxFunc();
 }
 
+#if !defined(NDEBUG)
+bool VPIRFlags::flagsValidForOpcode(unsigned Opcode) const {
+  switch (OpType) {
+  case OperationType::OverflowingBinOp:
+    return Opcode == Instruction::Add || Opcode == Instruction::Sub ||
+           Opcode == Instruction::Mul ||
+           Opcode == VPInstruction::VPInstruction::CanonicalIVIncrementForPart;
+  case OperationType::DisjointOp:
+    return O...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 19, 2025

@llvm/pr-subscribers-vectorizers

Author: Florian Hahn (fhahn)

Changes

This patch moves the logic to manage IR flags to a separate VPIRFlags class. For now, VPRecipeWithIRFlags is the only class that inherits VPIRFlags. The new class allows for simpler passing of flags when constructing recipes, simplifying the constructors for various recipes (VPInstruction in particular, which now just has 2 constructors, one taking an extra VPIRFlags argument.

This mirrors the approach taken for VPIRMetadata and makes it easier to extend in the future. The patch also adds a unified flagsValidForOpcode to check if the flags in a VPIRFlags match the provided opcode.


Patch is 27.07 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140621.diff

5 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h (+14-18)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+81-122)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+42-36)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+18-15)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp (+4-4)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index bae53c600c18c..c751f053cb65a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -164,25 +164,19 @@ class VPBuilder {
                               DebugLoc DL, const Twine &Name = "") {
     return createInstruction(Opcode, Operands, DL, Name);
   }
-  VPInstruction *createNaryOp(unsigned Opcode,
-                              std::initializer_list<VPValue *> Operands,
-                              std::optional<FastMathFlags> FMFs = {},
-                              DebugLoc DL = {}, const Twine &Name = "") {
-    if (FMFs)
-      return tryInsertInstruction(
-          new VPInstruction(Opcode, Operands, *FMFs, DL, Name));
-    return createInstruction(Opcode, Operands, DL, Name);
+  VPInstruction *createNaryOp(unsigned Opcode, ArrayRef<VPValue *> Operands,
+                              const VPIRFlags &Flags, DebugLoc DL = {},
+                              const Twine &Name = "") {
+    return tryInsertInstruction(
+        new VPInstruction(Opcode, Operands, Flags, DL, Name));
   }
+
   VPInstruction *createNaryOp(unsigned Opcode,
                               std::initializer_list<VPValue *> Operands,
-                              Type *ResultTy,
-                              std::optional<FastMathFlags> FMFs = {},
+                              Type *ResultTy, const VPIRFlags &Flags = {},
                               DebugLoc DL = {}, const Twine &Name = "") {
-    if (FMFs)
-      return tryInsertInstruction(new VPInstructionWithType(
-          Opcode, Operands, ResultTy, *FMFs, DL, Name));
     return tryInsertInstruction(
-        new VPInstructionWithType(Opcode, Operands, ResultTy, DL, Name));
+        new VPInstructionWithType(Opcode, Operands, ResultTy, Flags, DL, Name));
   }
 
   VPInstruction *createOverflowingOp(unsigned Opcode,
@@ -236,18 +230,20 @@ class VPBuilder {
     assert(Pred >= CmpInst::FIRST_ICMP_PREDICATE &&
            Pred <= CmpInst::LAST_ICMP_PREDICATE && "invalid predicate");
     return tryInsertInstruction(
-        new VPInstruction(Instruction::ICmp, Pred, A, B, DL, Name));
+        new VPInstruction(Instruction::ICmp, {A, B}, Pred, DL, Name));
   }
 
   VPInstruction *createPtrAdd(VPValue *Ptr, VPValue *Offset, DebugLoc DL = {},
                               const Twine &Name = "") {
     return tryInsertInstruction(
-        new VPInstruction(Ptr, Offset, GEPNoWrapFlags::none(), DL, Name));
+        new VPInstruction(VPInstruction::PtrAdd, {Ptr, Offset},
+                          GEPNoWrapFlags::none(), DL, Name));
   }
   VPValue *createInBoundsPtrAdd(VPValue *Ptr, VPValue *Offset, DebugLoc DL = {},
                                 const Twine &Name = "") {
     return tryInsertInstruction(
-        new VPInstruction(Ptr, Offset, GEPNoWrapFlags::inBounds(), DL, Name));
+        new VPInstruction(VPInstruction::PtrAdd, {Ptr, Offset},
+                          GEPNoWrapFlags::inBounds(), DL, Name));
   }
 
   VPInstruction *createScalarPhi(ArrayRef<VPValue *> IncomingValues,
@@ -269,7 +265,7 @@ class VPBuilder {
   VPInstruction *createScalarCast(Instruction::CastOps Opcode, VPValue *Op,
                                   Type *ResultTy, DebugLoc DL) {
     return tryInsertInstruction(
-        new VPInstructionWithType(Opcode, Op, ResultTy, DL));
+        new VPInstructionWithType(Opcode, Op, ResultTy, {}, DL));
   }
 
   VPWidenCastRecipe *createWidenCast(Instruction::CastOps Opcode, VPValue *Op,
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index e634de1e17c69..b38fb7e9b1adb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -577,8 +577,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
 #endif
 };
 
-/// Class to record LLVM IR flag for a recipe along with it.
-class VPRecipeWithIRFlags : public VPSingleDefRecipe {
+/// Class to record LLVM IR flags.
+class VPIRFlags {
   enum class OperationType : unsigned char {
     Cmp,
     OverflowingBinOp,
@@ -637,23 +637,10 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     unsigned AllFlags;
   };
 
-protected:
-  void transferFlags(VPRecipeWithIRFlags &Other) {
-    OpType = Other.OpType;
-    AllFlags = Other.AllFlags;
-  }
-
 public:
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL) {
-    OpType = OperationType::Other;
-    AllFlags = 0;
-  }
+  VPIRFlags() : OpType(OperationType::Other), AllFlags(0) {}
 
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      Instruction &I)
-      : VPSingleDefRecipe(SC, Operands, &I, I.getDebugLoc()) {
+  VPIRFlags(Instruction &I) {
     if (auto *Op = dyn_cast<CmpInst>(&I)) {
       OpType = OperationType::Cmp;
       CmpPredicate = Op->getPredicate();
@@ -681,63 +668,27 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     }
   }
 
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      CmpInst::Predicate Pred, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::Cmp),
-        CmpPredicate(Pred) {}
+  VPIRFlags(CmpInst::Predicate Pred)
+      : OpType(OperationType::Cmp), CmpPredicate(Pred) {}
 
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      WrapFlagsTy WrapFlags, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL),
-        OpType(OperationType::OverflowingBinOp), WrapFlags(WrapFlags) {}
+  VPIRFlags(WrapFlagsTy WrapFlags)
+      : OpType(OperationType::OverflowingBinOp), WrapFlags(WrapFlags) {}
 
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      FastMathFlags FMFs, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::FPMathOp),
-        FMFs(FMFs) {}
+  VPIRFlags(FastMathFlags FMFs) : OpType(OperationType::FPMathOp), FMFs(FMFs) {}
 
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      DisjointFlagsTy DisjointFlags, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
-        DisjointFlags(DisjointFlags) {}
+  VPIRFlags(DisjointFlagsTy DisjointFlags)
+      : OpType(OperationType::DisjointOp), DisjointFlags(DisjointFlags) {}
 
-  template <typename IterT>
-  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
-                      NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
-        NonNegFlags(NonNegFlags) {}
+  VPIRFlags(NonNegFlagsTy NonNegFlags)
+      : OpType(OperationType::NonNegOp), NonNegFlags(NonNegFlags) {}
 
-protected:
-  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
-                      GEPNoWrapFlags GEPFlags, DebugLoc DL = {})
-      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::GEPOp),
-        GEPFlags(GEPFlags) {}
+  VPIRFlags(GEPNoWrapFlags GEPFlags)
+      : OpType(OperationType::GEPOp), GEPFlags(GEPFlags) {}
 
 public:
-  static inline bool classof(const VPRecipeBase *R) {
-    return R->getVPDefID() == VPRecipeBase::VPInstructionSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenCallSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
-           R->getVPDefID() == VPRecipeBase::VPWidenIntrinsicSC ||
-           R->getVPDefID() == VPRecipeBase::VPReductionSC ||
-           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
-           R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
-           R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
-           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
-           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
-           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
-  }
-
-  static inline bool classof(const VPUser *U) {
-    auto *R = dyn_cast<VPRecipeBase>(U);
-    return R && classof(R);
-  }
-
-  static inline bool classof(const VPValue *V) {
-    auto *R = dyn_cast_or_null<VPRecipeBase>(V->getDefiningRecipe());
-    return R && classof(R);
+  void transferFlags(VPIRFlags &Other) {
+    OpType = Other.OpType;
+    AllFlags = Other.AllFlags;
   }
 
   /// Drop all poison-generating flags.
@@ -851,11 +802,58 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     return DisjointFlags.IsDisjoint;
   }
 
+#if !defined(NDEBUG)
+  /// Returns true if the set flags are valid for \p Opcode.
+  bool flagsValidForOpcode(unsigned Opcode) const;
+#endif
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   void printFlags(raw_ostream &O) const;
 #endif
 };
 
+class VPRecipeWithIRFlags : public VPSingleDefRecipe, public VPIRFlags {
+public:
+  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
+                      DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), VPIRFlags() {}
+
+  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
+                      Instruction &I)
+      : VPSingleDefRecipe(SC, Operands, &I, I.getDebugLoc()), VPIRFlags(I) {}
+
+  VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
+                      const VPIRFlags &Flags, DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), VPIRFlags(Flags) {}
+
+public:
+  static inline bool classof(const VPRecipeBase *R) {
+    return R->getVPDefID() == VPRecipeBase::VPInstructionSC ||
+           R->getVPDefID() == VPRecipeBase::VPWidenSC ||
+           R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
+           R->getVPDefID() == VPRecipeBase::VPWidenCallSC ||
+           R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
+           R->getVPDefID() == VPRecipeBase::VPWidenIntrinsicSC ||
+           R->getVPDefID() == VPRecipeBase::VPReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
+           R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
+           R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
+           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
+           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
+  }
+
+  static inline bool classof(const VPUser *U) {
+    auto *R = dyn_cast<VPRecipeBase>(U);
+    return R && classof(R);
+  }
+
+  static inline bool classof(const VPValue *V) {
+    auto *R = dyn_cast_or_null<VPRecipeBase>(V->getDefiningRecipe());
+    return R && classof(R);
+  }
+};
+
 /// Helper to access the operand that contains the unroll part for this recipe
 /// after unrolling.
 template <unsigned PartOpIdx> class VPUnrollPartAccessor {
@@ -958,54 +956,21 @@ class VPInstruction : public VPRecipeWithIRFlags,
   /// value for lane \p Lane.
   Value *generatePerLane(VPTransformState &State, const VPLane &Lane);
 
-#if !defined(NDEBUG)
-  /// Return true if the VPInstruction is a floating point math operation, i.e.
-  /// has fast-math flags.
-  bool isFPMathOp() const;
-#endif
-
 public:
-  VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL,
+  VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL = {},
                 const Twine &Name = "")
       : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DL),
         Opcode(Opcode), Name(Name.str()) {}
 
-  VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
-                DebugLoc DL = {}, const Twine &Name = "")
-      : VPInstruction(Opcode, ArrayRef<VPValue *>(Operands), DL, Name) {}
-
-  VPInstruction(unsigned Opcode, CmpInst::Predicate Pred, VPValue *A,
-                VPValue *B, DebugLoc DL = {}, const Twine &Name = "");
-
-  VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
-                WrapFlagsTy WrapFlags, DebugLoc DL = {}, const Twine &Name = "")
-      : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, WrapFlags, DL),
-        Opcode(Opcode), Name(Name.str()) {}
-
-  VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
-                DisjointFlagsTy DisjointFlag, DebugLoc DL = {},
-                const Twine &Name = "")
-      : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DisjointFlag, DL),
-        Opcode(Opcode), Name(Name.str()) {
-    assert(Opcode == Instruction::Or && "only OR opcodes can be disjoint");
-  }
-
-  VPInstruction(VPValue *Ptr, VPValue *Offset, GEPNoWrapFlags Flags,
-                DebugLoc DL = {}, const Twine &Name = "")
-      : VPRecipeWithIRFlags(VPDef::VPInstructionSC,
-                            ArrayRef<VPValue *>({Ptr, Offset}), Flags, DL),
-        Opcode(VPInstruction::PtrAdd), Name(Name.str()) {}
-
-  VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
-                FastMathFlags FMFs, DebugLoc DL = {}, const Twine &Name = "");
+  VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
+                const VPIRFlags &Flags, DebugLoc DL = {},
+                const Twine &Name = "");
 
   VP_CLASSOF_IMPL(VPDef::VPInstructionSC)
 
   VPInstruction *clone() override {
     SmallVector<VPValue *, 2> Operands(operands());
-    auto *New = new VPInstruction(Opcode, Operands, getDebugLoc(), Name);
-    New->transferFlags(*this);
-    return New;
+    return new VPInstruction(Opcode, Operands, *this, getDebugLoc(), Name);
   }
 
   unsigned getOpcode() const { return Opcode; }
@@ -1082,13 +1047,9 @@ class VPInstructionWithType : public VPInstruction {
 
 public:
   VPInstructionWithType(unsigned Opcode, ArrayRef<VPValue *> Operands,
-                        Type *ResultTy, DebugLoc DL, const Twine &Name = "")
-      : VPInstruction(Opcode, Operands, DL, Name), ResultTy(ResultTy) {}
-  VPInstructionWithType(unsigned Opcode,
-                        std::initializer_list<VPValue *> Operands,
-                        Type *ResultTy, FastMathFlags FMFs, DebugLoc DL = {},
+                        Type *ResultTy, const VPIRFlags &Flags, DebugLoc DL,
                         const Twine &Name = "")
-      : VPInstruction(Opcode, Operands, FMFs, DL, Name), ResultTy(ResultTy) {}
+      : VPInstruction(Opcode, Operands, Flags, DL, Name), ResultTy(ResultTy) {}
 
   static inline bool classof(const VPRecipeBase *R) {
     // VPInstructionWithType are VPInstructions with specific opcodes requiring
@@ -1113,8 +1074,9 @@ class VPInstructionWithType : public VPInstruction {
 
   VPInstruction *clone() override {
     SmallVector<VPValue *, 2> Operands(operands());
-    auto *New = new VPInstructionWithType(
-        getOpcode(), Operands, getResultType(), getDebugLoc(), getName());
+    auto *New =
+        new VPInstructionWithType(getOpcode(), Operands, getResultType(), *this,
+                                  getDebugLoc(), getName());
     New->setUnderlyingValue(getUnderlyingValue());
     return New;
   }
@@ -1373,15 +1335,12 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
   }
 
   VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
-                    DebugLoc DL = {})
-      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, DL), VPIRMetadata(),
-        Opcode(Opcode), ResultTy(ResultTy) {}
-
-  VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
-                    bool IsNonNeg, DebugLoc DL = {})
-      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, NonNegFlagsTy(IsNonNeg),
-                            DL),
-        Opcode(Opcode), ResultTy(ResultTy) {}
+                    const VPIRFlags &Flags = {}, DebugLoc DL = {})
+      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, Flags, DL),
+        VPIRMetadata(), Opcode(Opcode), ResultTy(ResultTy) {
+    assert(flagsValidForOpcode(Opcode) &&
+           "Set flags not supported for the provided opcode");
+  }
 
   ~VPWidenCastRecipe() override = default;
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 14ed40f16683a..3a57ce4c8af6e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -368,7 +368,7 @@ void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
 }
 #endif
 
-FastMathFlags VPRecipeWithIRFlags::getFastMathFlags() const {
+FastMathFlags VPIRFlags::getFastMathFlags() const {
   assert(OpType == OperationType::FPMathOp &&
          "recipe doesn't have fast math flags");
   FastMathFlags Res;
@@ -406,23 +406,13 @@ template class VPUnrollPartAccessor<2>;
 template class VPUnrollPartAccessor<3>;
 }
 
-VPInstruction::VPInstruction(unsigned Opcode, CmpInst::Predicate Pred,
-                             VPValue *A, VPValue *B, DebugLoc DL,
+VPInstruction::VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
+                             const VPIRFlags &Flags, DebugLoc DL,
                              const Twine &Name)
-    : VPRecipeWithIRFlags(VPDef::VPInstructionSC, ArrayRef<VPValue *>({A, B}),
-                          Pred, DL),
+    : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, Flags, DL),
       Opcode(Opcode), Name(Name.str()) {
-  assert(Opcode == Instruction::ICmp &&
-         "only ICmp predicates supported at the moment");
-}
-
-VPInstruction::VPInstruction(unsigned Opcode,
-                             std::initializer_list<VPValue *> Operands,
-                             FastMathFlags FMFs, DebugLoc DL, const Twine &Name)
-    : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, FMFs, DL),
-      Opcode(Opcode), Name(Name.str()) {
-  // Make sure the VPInstruction is a floating-point operation.
-  assert(isFPMathOp() && "this op can't take fast-math flags");
+  assert(flagsValidForOpcode(getOpcode()) &&
+         "Set flags not supported for the provided opcode");
 }
 
 bool VPInstruction::doesGeneratePerAllLanes() const {
@@ -864,24 +854,11 @@ bool VPInstruction::isSingleScalar() const {
          getOpcode() == Instruction::PHI;
 }
 
-#if !defined(NDEBUG)
-bool VPInstruction::isFPMathOp() const {
-  // Inspired by FPMathOperator::classof. Notable differences are that we don't
-  // support Call, PHI and Select opcodes here yet.
-  return Opcode == Instruction::FAdd || Opcode == Instruction::FMul ||
-         Opcode == Instruction::FNeg || Opcode == Instruction::FSub ||
-         Opcode == Instruction::FDiv || Opcode == Instruction::FRem ||
-         Opcode == Instruction::FCmp || Opcode == Instruction::Select ||
-         Opcode == VPInstruction::WideIVStep;
-}
-#endif
-
 void VPInstruction::execute(VPTransformState &State) {
   assert(!State.Lane && "VPInstruction executing an Lane");
   IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
-  assert((hasFastMathFlags() == isFPMathOp() ||
-          getOpcode() == Instruction::Select) &&
-         "Recipe not a FPMathOp but has fast-math flags?");
+  assert(flagsValidForOpcode(getOpcode()) &&
+         "Set flags not supported for the provided opcode");
   if (hasFastMathFlags())
     State.Builder.setFastMathFlags(getFastMathFlags());
   bool GeneratesPerFirstLaneOnly = canGenerateScalarForFirstLane() &&
@@ -1606,8 +1583,7 @@ InstructionCost VPWidenSelectRecipe::computeCost(ElementCount VF,
       {TTI::OK_AnyValue, TTI::OP_None}, {TTI::OK_AnyValue, TTI::OP_None}, SI);
 }
 
-VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy(
-    const FastMathFlags &FMF) {
+VPIRFlags::FastMathFlagsTy::FastMathFlagsTy(const FastMathFlags &FMF) {
   AllowReassoc = FMF.allowReassoc();
   NoNaNs = FMF.noNaNs();
   NoInfs = FMF.noInfs();
@@ -1617,8 +1593,39 @@ VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy(
   ApproxFunc = FMF.approxFunc();
 }
 
+#if !defined(NDEBUG)
+bool VPIRFlags::flagsValidForOpcode(unsigned Opcode) const {
+  switch (OpType) {
+  case OperationType::OverflowingBinOp:
+    return Opcode == Instruction::Add || Opcode == Instruction::Sub ||
+           Opcode == Instruction::Mul ||
+           Opcode == VPInstruction::VPInstruction::CanonicalIVIncrementForPart;
+  case OperationType::DisjointOp:
+    return O...
[truncated]

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void printFlags(raw_ostream &O) const;
#endif
};

class VPRecipeWithIRFlags : public VPSingleDefRecipe, public VPIRFlags {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment for this class

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added thanks.

: VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, NonNegFlagsTy(IsNonNeg),
DL),
Opcode(Opcode), ResultTy(ResultTy) {}
const VPIRFlags &Flags = {}, DebugLoc DL = {})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I presume changing the bool to VPIRFlags is to bring it inline with VPInstructionWithType, so we can replace it in #129712?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The immediate case is #140623 (using VPInstructionWithType for uniform recipes), but it will also be useful for #129712, once I can get around to updating it

@fhahn fhahn merged commit c0506a1 into llvm:main May 25, 2025
10 of 11 checks passed
@fhahn fhahn deleted the vplan-separate-vpirflags branch May 25, 2025 10:13
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request May 25, 2025
… (NFC). (#140621)

This patch moves the logic to manage IR flags to a separate VPIRFlags
class. For now, VPRecipeWithIRFlags is the only class that inherits
VPIRFlags. The new class allows for simpler passing of flags when
constructing recipes, simplifying the constructors for various recipes
(VPInstruction in particular, which now just has 2 constructors, one
taking an extra VPIRFlags argument.

This mirrors the approach taken for VPIRMetadata and makes it easier to
extend in the future. The patch also adds a unified flagsValidForOpcode
to check if the flags in a VPIRFlags match the provided opcode.

PR: llvm/llvm-project#140621
Copy link
Collaborator

@ayalz ayalz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Post-commit consistency nits.

@@ -164,25 +164,19 @@ class VPBuilder {
DebugLoc DL, const Twine &Name = "") {
return createInstruction(Opcode, Operands, DL, Name);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, also above between lines 162,163

Suggested change
}
}

OpType = OperationType::Other;
AllFlags = 0;
}
VPIRFlags() : OpType(OperationType::Other), AllFlags(0) {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to document the interface here and below.

static inline bool classof(const VPValue *V) {
auto *R = dyn_cast_or_null<VPRecipeBase>(V->getDefiningRecipe());
return R && classof(R);
void transferFlags(VPIRFlags &Other) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now transferFlags() needs to be public rather than protected?

Used only by recipe clone()'s - can it be replaced by copy constructor (or calls to flags-supporting constructors), as done in VPIRMetadata.

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void printFlags(raw_ostream &O) const;
#endif
};

/// A pure-virtual common base class for recipes defining a single VPValue and
/// using IR flags.
struct VPRecipeWithIRFlags : public VPSingleDefRecipe, public VPIRFlags {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This retains existing name, but perhaps VPSingleDefWithIRFlags would be more accurate, following its documentation above. Or VPSingleDefWithFlags, considering IR Flags are wrapped in VPIRFlags (along with condition predicates and potentially other constants).

auto *New = new VPInstruction(Opcode, Operands, getDebugLoc(), Name);
New->transferFlags(*this);
return New;
return new VPInstruction(Opcode, Operands, *this, getDebugLoc(), Name);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would this

Suggested change
return new VPInstruction(Opcode, Operands, *this, getDebugLoc(), Name);
return new VPInstruction(Opcode, operands(), *this, getDebugLoc(), Name);

also work? Additional instances below.

@@ -2524,11 +2525,12 @@ static void expandVPExtendedReduction(VPExtendedReductionRecipe *ExtRed) {
// Only ZExt contains non-neg flags.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consistency nit: rather than calling two distinct constructors - either w/ or w/o flags, could Flags be preset and used consistently, here and below, as in

Suggested change
// Only ZExt contains non-neg flags.
// Only ZExt contains non-neg flags.
VPIRFlags Flags;
if (ExtRed->isZExt())
Flags = VPIRFlags(*ExtRed); // or even set to VPIRFlags::NonNegFlagsTy(true/ExtRed->isNonNeg()) ?
Ext = new VPWidenCastRecipe(ExtRed->getExtOpcode(), ExtRed->getVecOp(),
ExtRed->getResultType(), Flags,
ExtRed->getDebugLoc());

?

if (isa_and_present<FPMathOperator>(ID.getInductionBinOp()))
FMFs = ID.getInductionBinOp()->getFastMathFlags();
Flags = ID.getInductionBinOp()->getFastMathFlags();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consistency nit: {} are used above in Flags = {VPI->getFastMathFlags()};?

VPInstruction *createNaryOp(unsigned Opcode, ArrayRef<VPValue *> Operands,
const VPIRFlags &Flags, DebugLoc DL = {},
const Twine &Name = "") {
return tryInsertInstruction(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consistency nit: worth overloading createInstruction() with a Flags-supporting version, to be used here and below?

Comment on lines +851 to +852
assert(flagsValidForOpcode(getOpcode()) &&
"Set flags not supported for the provided opcode");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consistency: if desired both at construction and at execution (considering flags as mutable), should VPWidenCastRecipe::execute() also assert?

sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Jun 3, 2025
…vm#140621)

This patch moves the logic to manage IR flags to a separate VPIRFlags
class. For now, VPRecipeWithIRFlags is the only class that inherits
VPIRFlags. The new class allows for simpler passing of flags when
constructing recipes, simplifying the constructors for various recipes
(VPInstruction in particular, which now just has 2 constructors, one
taking an extra VPIRFlags argument.

This mirrors the approach taken for VPIRMetadata and makes it easier to
extend in the future. The patch also adds a unified flagsValidForOpcode
to check if the flags in a VPIRFlags match the provided opcode.

PR: llvm#140621
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants