Skip to content

[MLIR] Add bufferization state to getBufferType and resolveConflicts interface methods #141466

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 28, 2025

Conversation

mscuttari
Copy link
Member

The PR continues the work started in #141019 by adding the BufferizationState class also to the getBufferType and resolveConflicts interface methods, together with the additional support functions that are used throughout the bufferization infrastructure.

@llvmbot
Copy link
Member

llvmbot commented May 26, 2025

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-mlprogram

Author: Michele Scuttari (mscuttari)

Changes

The PR continues the work started in #141019 by adding the BufferizationState class also to the getBufferType and resolveConflicts interface methods, together with the additional support functions that are used throughout the bufferization infrastructure.


Patch is 76.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141466.diff

22 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+7-3)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+7-4)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+2-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+3-2)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h (+2-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h (+4-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+10-8)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+30-21)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+10-7)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+6-5)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+8-5)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp (+12-9)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp (+13-13)
  • (modified) mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+61-44)
  • (modified) mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+46-37)
  • (modified) mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp (+11-6)
  • (modified) mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp (+5-1)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 43c97d57e1834..328d928c9ebdb 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -598,13 +598,14 @@ class BufferizationState {
 FailureOr<Value>
 allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
                              const BufferizationOptions &options,
-                             bool copy = true);
+                             BufferizationState &state, bool copy = true);
 
 /// Lookup the buffer for the given value. If the value was not bufferized
 /// yet, wrap it in a ToBufferOp. Otherwise, it is the result of a ToTensorOp,
 /// from which the memref operand is returned.
 FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
-                           const BufferizationOptions &options);
+                           const BufferizationOptions &options,
+                           BufferizationState &state);
 
 /// Return the buffer type for a given Value (tensor) after bufferization
 /// without bufferizing any IR.
@@ -615,7 +616,8 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
 ///
 /// This function is a wrapper around BufferizableOpInterface::getBufferType.
 FailureOr<BaseMemRefType> getBufferType(Value value,
-                                        const BufferizationOptions &options);
+                                        const BufferizationOptions &options,
+                                        BufferizationState &state);
 
 /// Return the buffer type for a given Value (tensor) after bufferization
 /// without bufferizing any IR. This function (and not the other overload
@@ -629,6 +631,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
 /// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
 FailureOr<BaseMemRefType> getBufferType(Value value,
                                         const BufferizationOptions &options,
+                                        BufferizationState &state,
                                         SmallVector<Value> &invocationStack);
 
 /// Return "true" if the given op has tensor semantics and should be bufferized.
@@ -709,6 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// places.
 FailureOr<BaseMemRefType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
+                     BufferizationState &state,
                      SmallVector<Value> &invocationStack);
 
 /// This is the default implementation of
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index b599a9f053215..80f9b72531660 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -381,13 +381,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*retType=*/"::llvm::LogicalResult",
         /*methodName=*/"resolveConflicts",
         /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
-                      "const ::mlir::bufferization::AnalysisState &":$state),
+                      "const ::mlir::bufferization::AnalysisState &":$analysisState,
+                      "::mlir::bufferization::BufferizationState &":$bufferizationState),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           auto bufferizableOp =
               ::llvm::cast<BufferizableOpInterface>($_op.getOperation());
           return bufferizableOp.resolveTensorOpOperandConflicts(
-              rewriter, state);
+              rewriter, analysisState, bufferizationState);
         }]
       >,
       InterfaceMethod<
@@ -523,6 +524,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*methodName=*/"getBufferType",
         /*args=*/(ins "::mlir::Value":$value,
                       "const ::mlir::bufferization::BufferizationOptions &":$options,
+                      "::mlir::bufferization::BufferizationState &":$state,
                       "::llvm::SmallVector<::mlir::Value> &":$invocationStack),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
@@ -531,7 +533,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           assert(invocationStack.back() == value &&
                  "inconsistant invocation stack");
           return ::mlir::bufferization::detail::defaultGetBufferType(
-              value, options, invocationStack);
+              value, options, state, invocationStack);
         }]
       >,
       InterfaceMethod<
@@ -616,7 +618,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
     /// form of `bufferization.alloc_tensor` ops.
     ::llvm::LogicalResult resolveTensorOpOperandConflicts(
         ::mlir::RewriterBase &rewriter,
-        const ::mlir::bufferization::AnalysisState &state);
+        const ::mlir::bufferization::AnalysisState &analysisState,
+        ::mlir::bufferization::BufferizationState &bufferizationState);
 
     /// Return `true` if the given OpOperand creates an alias but does neither
     /// read nor write. This implies that `bufferizesToMemoryRead` and
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index dafa4b9b183f2..0ee4f79144158 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -112,6 +112,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
 
     FailureOr<BaseMemRefType> getBufferType(
         Value value, const BufferizationOptions &options,
+        BufferizationState &state,
         SmallVector<Value> &invocationStack);
 
     RankedTensorType getType() {
@@ -471,7 +472,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     FailureOr<BaseMemRefType> getBufferType(
         Value value, const BufferizationOptions &options,
-        SmallVector<Value> &invocationStack) {
+        BufferizationState &state, SmallVector<Value> &invocationStack) {
       return ::llvm::cast<BaseMemRefType>(getMemref().getType());
     }
   }];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index cf86b9a23f59e..7c07f705c8435 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -34,12 +34,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+    BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     // Note: The user may want to override this function for OpResults in
     // case the bufferized result type is different from the bufferized type of
     // the aliasing OpOperand (if any).
     if (isa<OpResult>(value))
-      return bufferization::detail::defaultGetBufferType(value, options,
+      return bufferization::detail::defaultGetBufferType(value, options, state,
                                                          invocationStack);
 
     // Compute the buffer type of the block argument by computing the bufferized
@@ -65,7 +66,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         callerType = memrefType;
       } else {
         FailureOr<BaseMemRefType> maybeCallerType =
-            bufferization::getBufferType(opOperand->get(), options,
+            bufferization::getBufferType(opOperand->get(), options, state,
                                          invocationStack);
         if (failed(maybeCallerType))
           return failure();
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 70e3defee0867..c1f5654abbf9b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -62,7 +62,8 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
 /// `BufferizableOpInterface`. The buffer types of tensor block arguments are
 /// computed with `BufferizableOpIntercace::getBufferType`.
 LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
-                                      const BufferizationOptions &options);
+                                      const BufferizationOptions &options,
+                                      BufferizationState &state);
 
 } // namespace bufferization
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index a4ee893ca5341..e587753ddebee 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -75,12 +75,15 @@ void hoistBuffersFromLoops(Operation *op);
 /// additional buffer allocations.
 LogicalResult insertTensorCopies(Operation *op,
                                  const OneShotBufferizationOptions &options,
+                                 BufferizationState &bufferizationState,
                                  BufferizationStatistics *statistics = nullptr);
 
 /// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
 /// After applying this transform, the IR can be bufferized without inserting
 /// additional buffer allocations.
-LogicalResult insertTensorCopies(Operation *op, const AnalysisState &state);
+LogicalResult insertTensorCopies(Operation *op,
+                                 const AnalysisState &analysisState,
+                                 BufferizationState &bufferizationState);
 
 /// Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor
 /// ops.
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f646326ffc58f..0389a984e169c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -90,7 +90,8 @@ struct IndexCastOpInterface
     auto castOp = cast<arith::IndexCastOp>(op);
     auto resultTensorType = cast<TensorType>(castOp.getType());
 
-    FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
+    FailureOr<Value> source =
+        getBuffer(rewriter, castOp.getIn(), options, state);
     if (failed(source))
       return failure();
     auto sourceType = cast<BaseMemRefType>(source->getType());
@@ -151,9 +152,9 @@ struct SelectOpInterface
     // the moment (one for each tensor). When copying the op result, only one
     // copy would be needed.
     FailureOr<Value> maybeTrueBuffer =
-        getBuffer(rewriter, selectOp.getTrueValue(), options);
+        getBuffer(rewriter, selectOp.getTrueValue(), options, state);
     FailureOr<Value> maybeFalseBuffer =
-        getBuffer(rewriter, selectOp.getFalseValue(), options);
+        getBuffer(rewriter, selectOp.getFalseValue(), options, state);
     if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
       return failure();
     Value trueBuffer = *maybeTrueBuffer;
@@ -164,7 +165,7 @@ struct SelectOpInterface
     // both of them to the most dynamic MemRef type.
     if (trueBuffer.getType() != falseBuffer.getType()) {
       auto targetType =
-          bufferization::getBufferType(selectOp.getResult(), options);
+          bufferization::getBufferType(selectOp.getResult(), options, state);
       if (failed(targetType))
         return failure();
       if (trueBuffer.getType() != *targetType)
@@ -182,13 +183,14 @@ struct SelectOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto selectOp = cast<arith::SelectOp>(op);
     assert(value == selectOp.getResult() && "invalid value");
-    auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
-                                                 options, invocationStack);
-    auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
-                                                  options, invocationStack);
+    auto trueType = bufferization::getBufferType(
+        selectOp.getTrueValue(), options, state, invocationStack);
+    auto falseType = bufferization::getBufferType(
+        selectOp.getFalseValue(), options, state, invocationStack);
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 14fa4c1ed8159..7d67d4a33ac32 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -165,7 +165,7 @@ Operation *bufferization::getOwnerOfValue(Value value) {
 /// allocated.
 FailureOr<Value> bufferization::allocateTensorForShapedValue(
     OpBuilder &b, Location loc, Value shapedValue,
-    const BufferizationOptions &options, bool copy) {
+    const BufferizationOptions &options, BufferizationState &state, bool copy) {
   Value tensor;
   if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
     tensor = shapedValue;
@@ -210,7 +210,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
   // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
   if (copy)
     return allocTensorOp.getResult();
-  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
+  FailureOr<BaseMemRefType> copyBufferType =
+      getBufferType(tensor, options, state);
   if (failed(copyBufferType))
     return failure();
   std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -222,7 +223,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
 }
 
 LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
-    RewriterBase &rewriter, const AnalysisState &state) {
+    RewriterBase &rewriter, const AnalysisState &analysisState,
+    BufferizationState &bufferizationState) {
   OpBuilder::InsertionGuard g(rewriter);
   Operation *op = getOperation();
   SmallVector<OpOperand *> outOfPlaceOpOperands;
@@ -235,16 +237,18 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
     Type operandType = opOperand.get().getType();
     if (!llvm::isa<TensorType>(operandType))
       continue;
-    if (state.isInPlace(opOperand))
+    if (analysisState.isInPlace(opOperand))
       continue;
     if (llvm::isa<UnrankedTensorType>(operandType))
       return op->emitError("copying of unranked tensors is not implemented");
 
-    AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
+    AliasingValueList aliasingValues =
+        analysisState.getAliasingValues(opOperand);
     if (aliasingValues.getNumAliases() == 1 &&
         isa<OpResult>(aliasingValues.getAliases()[0].value) &&
-        !state.bufferizesToMemoryWrite(opOperand) &&
-        state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
+        !analysisState.bufferizesToMemoryWrite(opOperand) &&
+        analysisState
+                .getAliasingOpOperands(aliasingValues.getAliases()[0].value)
                 .getNumAliases() == 1 &&
         !isa<UnrankedTensorType>(
             aliasingValues.getAliases()[0].value.getType())) {
@@ -256,12 +260,12 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
       // cannot be copied at the moment).
       Value value = aliasingValues.getAliases()[0].value;
       outOfPlaceValues.push_back(value);
-      if (!state.canOmitTensorCopy(opOperand))
+      if (!analysisState.canOmitTensorCopy(opOperand))
         copiedOpValues.insert(value);
     } else {
       // In all other cases, make a copy of the OpOperand.
       outOfPlaceOpOperands.push_back(&opOperand);
-      if (!state.canOmitTensorCopy(opOperand))
+      if (!analysisState.canOmitTensorCopy(opOperand))
         copiedOpOperands.insert(&opOperand);
     }
   }
@@ -270,8 +274,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
   rewriter.setInsertionPoint(op);
   for (OpOperand *opOperand : outOfPlaceOpOperands) {
     FailureOr<Value> copy = allocateTensorForShapedValue(
-        rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
-        copiedOpOperands.contains(opOperand));
+        rewriter, op->getLoc(), opOperand->get(), analysisState.getOptions(),
+        bufferizationState, copiedOpOperands.contains(opOperand));
     if (failed(copy))
       return failure();
     rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
@@ -281,8 +285,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
   rewriter.setInsertionPointAfter(op);
   for (Value value : outOfPlaceValues) {
     FailureOr<Value> copy = allocateTensorForShapedValue(
-        rewriter, op->getLoc(), value, state.getOptions(),
-        copiedOpValues.count(value));
+        rewriter, op->getLoc(), value, analysisState.getOptions(),
+        bufferizationState, copiedOpValues.count(value));
     if (failed(copy))
       return failure();
     SmallVector<OpOperand *> uses = llvm::to_vector(
@@ -665,7 +669,8 @@ static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
 }
 
 FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
-                                          const BufferizationOptions &options) {
+                                          const BufferizationOptions &options,
+                                          BufferizationState &state) {
 #ifndef NDEBUG
   auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
   assert(tensorType && "unexpected non-tensor type");
@@ -678,7 +683,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
   // Insert to_buffer op.
   OpBuilder::InsertionGuard g(rewriter);
   setInsertionPointAfter(rewriter, value);
-  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
+  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
   if (failed(memrefType))
     return failure();
   ensureToBufferOpIsValid(value, *memrefType);
@@ -689,14 +694,16 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
 FailureOr<BaseMemRefType>
-bufferization::getBufferType(Value value, const BufferizationOptions &options) {
+bufferization::getBufferType(Value value, const BufferizationOptions &options,
+                             BufferizationState &state) {
   SmallVector<Value> invocationStack;
-  return getBufferType(value, options, invocationStack);
+  return getBufferType(value, options, state, invocationStack);
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
 FailureOr<BaseMemRefType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options,
+                             BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
   assert(llvm::isa<TensorType>(value.getType()) &&
          "unexpected non-tensor type");
@@ -708,7 +715,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
   Operation *op = getOwnerOfValue(value);
   auto bufferizableOp = options.dynCastBufferizableOp(op);
   if (bufferizableOp)
-    return bufferizableOp.getBufferType(value, options, invocationStack);
+    return bufferizableOp.getBufferType(value, options, state, invocationStack);
 
   // Op is not bufferizable.
   auto memSpace =
@@ -944,6 +951,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
 
 Failur...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 26, 2025

@llvm/pr-subscribers-mlir-shape

Author: Michele Scuttari (mscuttari)

Changes

The PR continues the work started in #141019 by adding the BufferizationState class also to the getBufferType and resolveConflicts interface methods, together with the additional support functions that are used throughout the bufferization infrastructure.


Patch is 76.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141466.diff

22 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+7-3)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+7-4)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+2-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+3-2)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h (+2-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h (+4-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+10-8)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+30-21)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+10-7)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+6-5)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+8-5)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp (+12-9)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp (+13-13)
  • (modified) mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+61-44)
  • (modified) mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+46-37)
  • (modified) mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp (+11-6)
  • (modified) mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp (+5-1)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 43c97d57e1834..328d928c9ebdb 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -598,13 +598,14 @@ class BufferizationState {
 FailureOr<Value>
 allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
                              const BufferizationOptions &options,
-                             bool copy = true);
+                             BufferizationState &state, bool copy = true);
 
 /// Lookup the buffer for the given value. If the value was not bufferized
 /// yet, wrap it in a ToBufferOp. Otherwise, it is the result of a ToTensorOp,
 /// from which the memref operand is returned.
 FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
-                           const BufferizationOptions &options);
+                           const BufferizationOptions &options,
+                           BufferizationState &state);
 
 /// Return the buffer type for a given Value (tensor) after bufferization
 /// without bufferizing any IR.
@@ -615,7 +616,8 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
 ///
 /// This function is a wrapper around BufferizableOpInterface::getBufferType.
 FailureOr<BaseMemRefType> getBufferType(Value value,
-                                        const BufferizationOptions &options);
+                                        const BufferizationOptions &options,
+                                        BufferizationState &state);
 
 /// Return the buffer type for a given Value (tensor) after bufferization
 /// without bufferizing any IR. This function (and not the other overload
@@ -629,6 +631,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
 /// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
 FailureOr<BaseMemRefType> getBufferType(Value value,
                                         const BufferizationOptions &options,
+                                        BufferizationState &state,
                                         SmallVector<Value> &invocationStack);
 
 /// Return "true" if the given op has tensor semantics and should be bufferized.
@@ -709,6 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// places.
 FailureOr<BaseMemRefType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
+                     BufferizationState &state,
                      SmallVector<Value> &invocationStack);
 
 /// This is the default implementation of
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index b599a9f053215..80f9b72531660 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -381,13 +381,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*retType=*/"::llvm::LogicalResult",
         /*methodName=*/"resolveConflicts",
         /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
-                      "const ::mlir::bufferization::AnalysisState &":$state),
+                      "const ::mlir::bufferization::AnalysisState &":$analysisState,
+                      "::mlir::bufferization::BufferizationState &":$bufferizationState),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           auto bufferizableOp =
               ::llvm::cast<BufferizableOpInterface>($_op.getOperation());
           return bufferizableOp.resolveTensorOpOperandConflicts(
-              rewriter, state);
+              rewriter, analysisState, bufferizationState);
         }]
       >,
       InterfaceMethod<
@@ -523,6 +524,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*methodName=*/"getBufferType",
         /*args=*/(ins "::mlir::Value":$value,
                       "const ::mlir::bufferization::BufferizationOptions &":$options,
+                      "::mlir::bufferization::BufferizationState &":$state,
                       "::llvm::SmallVector<::mlir::Value> &":$invocationStack),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
@@ -531,7 +533,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           assert(invocationStack.back() == value &&
                  "inconsistant invocation stack");
           return ::mlir::bufferization::detail::defaultGetBufferType(
-              value, options, invocationStack);
+              value, options, state, invocationStack);
         }]
       >,
       InterfaceMethod<
@@ -616,7 +618,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
     /// form of `bufferization.alloc_tensor` ops.
     ::llvm::LogicalResult resolveTensorOpOperandConflicts(
         ::mlir::RewriterBase &rewriter,
-        const ::mlir::bufferization::AnalysisState &state);
+        const ::mlir::bufferization::AnalysisState &analysisState,
+        ::mlir::bufferization::BufferizationState &bufferizationState);
 
     /// Return `true` if the given OpOperand creates an alias but does neither
     /// read nor write. This implies that `bufferizesToMemoryRead` and
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index dafa4b9b183f2..0ee4f79144158 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -112,6 +112,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
 
     FailureOr<BaseMemRefType> getBufferType(
         Value value, const BufferizationOptions &options,
+        BufferizationState &state,
         SmallVector<Value> &invocationStack);
 
     RankedTensorType getType() {
@@ -471,7 +472,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     FailureOr<BaseMemRefType> getBufferType(
         Value value, const BufferizationOptions &options,
-        SmallVector<Value> &invocationStack) {
+        BufferizationState &state, SmallVector<Value> &invocationStack) {
       return ::llvm::cast<BaseMemRefType>(getMemref().getType());
     }
   }];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index cf86b9a23f59e..7c07f705c8435 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -34,12 +34,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+    BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     // Note: The user may want to override this function for OpResults in
     // case the bufferized result type is different from the bufferized type of
     // the aliasing OpOperand (if any).
     if (isa<OpResult>(value))
-      return bufferization::detail::defaultGetBufferType(value, options,
+      return bufferization::detail::defaultGetBufferType(value, options, state,
                                                          invocationStack);
 
     // Compute the buffer type of the block argument by computing the bufferized
@@ -65,7 +66,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         callerType = memrefType;
       } else {
         FailureOr<BaseMemRefType> maybeCallerType =
-            bufferization::getBufferType(opOperand->get(), options,
+            bufferization::getBufferType(opOperand->get(), options, state,
                                          invocationStack);
         if (failed(maybeCallerType))
           return failure();
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 70e3defee0867..c1f5654abbf9b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -62,7 +62,8 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
 /// `BufferizableOpInterface`. The buffer types of tensor block arguments are
 /// computed with `BufferizableOpIntercace::getBufferType`.
 LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
-                                      const BufferizationOptions &options);
+                                      const BufferizationOptions &options,
+                                      BufferizationState &state);
 
 } // namespace bufferization
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index a4ee893ca5341..e587753ddebee 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -75,12 +75,15 @@ void hoistBuffersFromLoops(Operation *op);
 /// additional buffer allocations.
 LogicalResult insertTensorCopies(Operation *op,
                                  const OneShotBufferizationOptions &options,
+                                 BufferizationState &bufferizationState,
                                  BufferizationStatistics *statistics = nullptr);
 
 /// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
 /// After applying this transform, the IR can be bufferized without inserting
 /// additional buffer allocations.
-LogicalResult insertTensorCopies(Operation *op, const AnalysisState &state);
+LogicalResult insertTensorCopies(Operation *op,
+                                 const AnalysisState &analysisState,
+                                 BufferizationState &bufferizationState);
 
 /// Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor
 /// ops.
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f646326ffc58f..0389a984e169c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -90,7 +90,8 @@ struct IndexCastOpInterface
     auto castOp = cast<arith::IndexCastOp>(op);
     auto resultTensorType = cast<TensorType>(castOp.getType());
 
-    FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
+    FailureOr<Value> source =
+        getBuffer(rewriter, castOp.getIn(), options, state);
     if (failed(source))
       return failure();
     auto sourceType = cast<BaseMemRefType>(source->getType());
@@ -151,9 +152,9 @@ struct SelectOpInterface
     // the moment (one for each tensor). When copying the op result, only one
     // copy would be needed.
     FailureOr<Value> maybeTrueBuffer =
-        getBuffer(rewriter, selectOp.getTrueValue(), options);
+        getBuffer(rewriter, selectOp.getTrueValue(), options, state);
     FailureOr<Value> maybeFalseBuffer =
-        getBuffer(rewriter, selectOp.getFalseValue(), options);
+        getBuffer(rewriter, selectOp.getFalseValue(), options, state);
     if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
       return failure();
     Value trueBuffer = *maybeTrueBuffer;
@@ -164,7 +165,7 @@ struct SelectOpInterface
     // both of them to the most dynamic MemRef type.
     if (trueBuffer.getType() != falseBuffer.getType()) {
       auto targetType =
-          bufferization::getBufferType(selectOp.getResult(), options);
+          bufferization::getBufferType(selectOp.getResult(), options, state);
       if (failed(targetType))
         return failure();
       if (trueBuffer.getType() != *targetType)
@@ -182,13 +183,14 @@ struct SelectOpInterface
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto selectOp = cast<arith::SelectOp>(op);
     assert(value == selectOp.getResult() && "invalid value");
-    auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
-                                                 options, invocationStack);
-    auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
-                                                  options, invocationStack);
+    auto trueType = bufferization::getBufferType(
+        selectOp.getTrueValue(), options, state, invocationStack);
+    auto falseType = bufferization::getBufferType(
+        selectOp.getFalseValue(), options, state, invocationStack);
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 14fa4c1ed8159..7d67d4a33ac32 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -165,7 +165,7 @@ Operation *bufferization::getOwnerOfValue(Value value) {
 /// allocated.
 FailureOr<Value> bufferization::allocateTensorForShapedValue(
     OpBuilder &b, Location loc, Value shapedValue,
-    const BufferizationOptions &options, bool copy) {
+    const BufferizationOptions &options, BufferizationState &state, bool copy) {
   Value tensor;
   if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
     tensor = shapedValue;
@@ -210,7 +210,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
   // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
   if (copy)
     return allocTensorOp.getResult();
-  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
+  FailureOr<BaseMemRefType> copyBufferType =
+      getBufferType(tensor, options, state);
   if (failed(copyBufferType))
     return failure();
   std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -222,7 +223,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
 }
 
 LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
-    RewriterBase &rewriter, const AnalysisState &state) {
+    RewriterBase &rewriter, const AnalysisState &analysisState,
+    BufferizationState &bufferizationState) {
   OpBuilder::InsertionGuard g(rewriter);
   Operation *op = getOperation();
   SmallVector<OpOperand *> outOfPlaceOpOperands;
@@ -235,16 +237,18 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
     Type operandType = opOperand.get().getType();
     if (!llvm::isa<TensorType>(operandType))
       continue;
-    if (state.isInPlace(opOperand))
+    if (analysisState.isInPlace(opOperand))
       continue;
     if (llvm::isa<UnrankedTensorType>(operandType))
       return op->emitError("copying of unranked tensors is not implemented");
 
-    AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
+    AliasingValueList aliasingValues =
+        analysisState.getAliasingValues(opOperand);
     if (aliasingValues.getNumAliases() == 1 &&
         isa<OpResult>(aliasingValues.getAliases()[0].value) &&
-        !state.bufferizesToMemoryWrite(opOperand) &&
-        state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
+        !analysisState.bufferizesToMemoryWrite(opOperand) &&
+        analysisState
+                .getAliasingOpOperands(aliasingValues.getAliases()[0].value)
                 .getNumAliases() == 1 &&
         !isa<UnrankedTensorType>(
             aliasingValues.getAliases()[0].value.getType())) {
@@ -256,12 +260,12 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
       // cannot be copied at the moment).
       Value value = aliasingValues.getAliases()[0].value;
       outOfPlaceValues.push_back(value);
-      if (!state.canOmitTensorCopy(opOperand))
+      if (!analysisState.canOmitTensorCopy(opOperand))
         copiedOpValues.insert(value);
     } else {
       // In all other cases, make a copy of the OpOperand.
       outOfPlaceOpOperands.push_back(&opOperand);
-      if (!state.canOmitTensorCopy(opOperand))
+      if (!analysisState.canOmitTensorCopy(opOperand))
         copiedOpOperands.insert(&opOperand);
     }
   }
@@ -270,8 +274,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
   rewriter.setInsertionPoint(op);
   for (OpOperand *opOperand : outOfPlaceOpOperands) {
     FailureOr<Value> copy = allocateTensorForShapedValue(
-        rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
-        copiedOpOperands.contains(opOperand));
+        rewriter, op->getLoc(), opOperand->get(), analysisState.getOptions(),
+        bufferizationState, copiedOpOperands.contains(opOperand));
     if (failed(copy))
       return failure();
     rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
@@ -281,8 +285,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
   rewriter.setInsertionPointAfter(op);
   for (Value value : outOfPlaceValues) {
     FailureOr<Value> copy = allocateTensorForShapedValue(
-        rewriter, op->getLoc(), value, state.getOptions(),
-        copiedOpValues.count(value));
+        rewriter, op->getLoc(), value, analysisState.getOptions(),
+        bufferizationState, copiedOpValues.count(value));
     if (failed(copy))
       return failure();
     SmallVector<OpOperand *> uses = llvm::to_vector(
@@ -665,7 +669,8 @@ static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
 }
 
 FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
-                                          const BufferizationOptions &options) {
+                                          const BufferizationOptions &options,
+                                          BufferizationState &state) {
 #ifndef NDEBUG
   auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
   assert(tensorType && "unexpected non-tensor type");
@@ -678,7 +683,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
   // Insert to_buffer op.
   OpBuilder::InsertionGuard g(rewriter);
   setInsertionPointAfter(rewriter, value);
-  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
+  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
   if (failed(memrefType))
     return failure();
   ensureToBufferOpIsValid(value, *memrefType);
@@ -689,14 +694,16 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
 FailureOr<BaseMemRefType>
-bufferization::getBufferType(Value value, const BufferizationOptions &options) {
+bufferization::getBufferType(Value value, const BufferizationOptions &options,
+                             BufferizationState &state) {
   SmallVector<Value> invocationStack;
-  return getBufferType(value, options, invocationStack);
+  return getBufferType(value, options, state, invocationStack);
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
 FailureOr<BaseMemRefType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options,
+                             BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
   assert(llvm::isa<TensorType>(value.getType()) &&
          "unexpected non-tensor type");
@@ -708,7 +715,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
   Operation *op = getOwnerOfValue(value);
   auto bufferizableOp = options.dynCastBufferizableOp(op);
   if (bufferizableOp)
-    return bufferizableOp.getBufferType(value, options, invocationStack);
+    return bufferizableOp.getBufferType(value, options, state, invocationStack);
 
   // Op is not bufferizable.
   auto memSpace =
@@ -944,6 +951,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
 
 Failur...
[truncated]

Copy link

github-actions bot commented May 26, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@@ -615,7 +616,8 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
FailureOr<BaseMemRefType> getBufferType(Value value,
const BufferizationOptions &options);
const BufferizationOptions &options,
BufferizationState &state);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be a const reference? Here and in all the other places that this PR touches?

@@ -381,13 +381,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"resolveConflicts",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
"const ::mlir::bufferization::AnalysisState &":$state),
"const ::mlir::bufferization::AnalysisState &":$analysisState,
"::mlir::bufferization::BufferizationState &":$bufferizationState),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Const possible here?

@mscuttari mscuttari merged commit 63cb6af into llvm:main May 28, 2025
9 of 11 checks passed
@mscuttari mscuttari deleted the bufferization-state branch May 29, 2025 15:29
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Jun 3, 2025
…ts` interface methods (llvm#141466)

The PR continues the work started in llvm#141019 by adding the `BufferizationState` class also to the `getBufferType` and `resolveConflicts` interface methods, together with the additional support functions that are used throughout the bufferization infrastructure.
rolfmorel added a commit to libxsmm/tpp-mlir that referenced this pull request Jun 12, 2025
* llvm/llvm-project#139340
```
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp
```

* llvm/llvm-project#141466 &
llvm/llvm-project#141019
  * Add `BufferizationState &state` to `bufferize` and `getBuffer` 

* llvm/llvm-project#143159 &
llvm/llvm-project#142683 &
llvm/llvm-project#143779
  * Updates to `transform.apply_registered_pass` and its Python-bindings

* llvm/llvm-project#143217
* `tilingResult->mergeResult.replacements` ->
`tilingResult->replacements`

* llvm/llvm-project#140559 &
llvm/llvm-project#143871
* Change CHECK lines which expected `amx.` ops to `llvm.call_intrinsic`s
& fix which enables conversion again.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants