-
Notifications
You must be signed in to change notification settings - Fork 14.2k
[MLIR] Add bufferization state to getBufferType
and resolveConflicts
interface methods
#141466
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…ts` interface methods
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-mlprogram Author: Michele Scuttari (mscuttari) ChangesThe PR continues the work started in #141019 by adding the Patch is 76.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141466.diff 22 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 43c97d57e1834..328d928c9ebdb 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -598,13 +598,14 @@ class BufferizationState {
FailureOr<Value>
allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
const BufferizationOptions &options,
- bool copy = true);
+ BufferizationState &state, bool copy = true);
/// Lookup the buffer for the given value. If the value was not bufferized
/// yet, wrap it in a ToBufferOp. Otherwise, it is the result of a ToTensorOp,
/// from which the memref operand is returned.
FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
/// Return the buffer type for a given Value (tensor) after bufferization
/// without bufferizing any IR.
@@ -615,7 +616,8 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
FailureOr<BaseMemRefType> getBufferType(Value value,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
/// Return the buffer type for a given Value (tensor) after bufferization
/// without bufferizing any IR. This function (and not the other overload
@@ -629,6 +631,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
FailureOr<BaseMemRefType> getBufferType(Value value,
const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack);
/// Return "true" if the given op has tensor semantics and should be bufferized.
@@ -709,6 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// places.
FailureOr<BaseMemRefType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack);
/// This is the default implementation of
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index b599a9f053215..80f9b72531660 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -381,13 +381,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"resolveConflicts",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
- "const ::mlir::bufferization::AnalysisState &":$state),
+ "const ::mlir::bufferization::AnalysisState &":$analysisState,
+ "::mlir::bufferization::BufferizationState &":$bufferizationState),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto bufferizableOp =
::llvm::cast<BufferizableOpInterface>($_op.getOperation());
return bufferizableOp.resolveTensorOpOperandConflicts(
- rewriter, state);
+ rewriter, analysisState, bufferizationState);
}]
>,
InterfaceMethod<
@@ -523,6 +524,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
+ "::mlir::bufferization::BufferizationState &":$state,
"::llvm::SmallVector<::mlir::Value> &":$invocationStack),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -531,7 +533,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
assert(invocationStack.back() == value &&
"inconsistant invocation stack");
return ::mlir::bufferization::detail::defaultGetBufferType(
- value, options, invocationStack);
+ value, options, state, invocationStack);
}]
>,
InterfaceMethod<
@@ -616,7 +618,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/// form of `bufferization.alloc_tensor` ops.
::llvm::LogicalResult resolveTensorOpOperandConflicts(
::mlir::RewriterBase &rewriter,
- const ::mlir::bufferization::AnalysisState &state);
+ const ::mlir::bufferization::AnalysisState &analysisState,
+ ::mlir::bufferization::BufferizationState &bufferizationState);
/// Return `true` if the given OpOperand creates an alias but does neither
/// read nor write. This implies that `bufferizesToMemoryRead` and
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index dafa4b9b183f2..0ee4f79144158 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -112,6 +112,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack);
RankedTensorType getType() {
@@ -471,7 +472,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
- SmallVector<Value> &invocationStack) {
+ BufferizationState &state, SmallVector<Value> &invocationStack) {
return ::llvm::cast<BaseMemRefType>(getMemref().getType());
}
}];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index cf86b9a23f59e..7c07f705c8435 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -34,12 +34,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
// Note: The user may want to override this function for OpResults in
// case the bufferized result type is different from the bufferized type of
// the aliasing OpOperand (if any).
if (isa<OpResult>(value))
- return bufferization::detail::defaultGetBufferType(value, options,
+ return bufferization::detail::defaultGetBufferType(value, options, state,
invocationStack);
// Compute the buffer type of the block argument by computing the bufferized
@@ -65,7 +66,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
callerType = memrefType;
} else {
FailureOr<BaseMemRefType> maybeCallerType =
- bufferization::getBufferType(opOperand->get(), options,
+ bufferization::getBufferType(opOperand->get(), options, state,
invocationStack);
if (failed(maybeCallerType))
return failure();
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 70e3defee0867..c1f5654abbf9b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -62,7 +62,8 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
/// `BufferizableOpInterface`. The buffer types of tensor block arguments are
/// computed with `BufferizableOpIntercace::getBufferType`.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index a4ee893ca5341..e587753ddebee 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -75,12 +75,15 @@ void hoistBuffersFromLoops(Operation *op);
/// additional buffer allocations.
LogicalResult insertTensorCopies(Operation *op,
const OneShotBufferizationOptions &options,
+ BufferizationState &bufferizationState,
BufferizationStatistics *statistics = nullptr);
/// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
/// After applying this transform, the IR can be bufferized without inserting
/// additional buffer allocations.
-LogicalResult insertTensorCopies(Operation *op, const AnalysisState &state);
+LogicalResult insertTensorCopies(Operation *op,
+ const AnalysisState &analysisState,
+ BufferizationState &bufferizationState);
/// Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor
/// ops.
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f646326ffc58f..0389a984e169c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -90,7 +90,8 @@ struct IndexCastOpInterface
auto castOp = cast<arith::IndexCastOp>(op);
auto resultTensorType = cast<TensorType>(castOp.getType());
- FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
+ FailureOr<Value> source =
+ getBuffer(rewriter, castOp.getIn(), options, state);
if (failed(source))
return failure();
auto sourceType = cast<BaseMemRefType>(source->getType());
@@ -151,9 +152,9 @@ struct SelectOpInterface
// the moment (one for each tensor). When copying the op result, only one
// copy would be needed.
FailureOr<Value> maybeTrueBuffer =
- getBuffer(rewriter, selectOp.getTrueValue(), options);
+ getBuffer(rewriter, selectOp.getTrueValue(), options, state);
FailureOr<Value> maybeFalseBuffer =
- getBuffer(rewriter, selectOp.getFalseValue(), options);
+ getBuffer(rewriter, selectOp.getFalseValue(), options, state);
if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
return failure();
Value trueBuffer = *maybeTrueBuffer;
@@ -164,7 +165,7 @@ struct SelectOpInterface
// both of them to the most dynamic MemRef type.
if (trueBuffer.getType() != falseBuffer.getType()) {
auto targetType =
- bufferization::getBufferType(selectOp.getResult(), options);
+ bufferization::getBufferType(selectOp.getResult(), options, state);
if (failed(targetType))
return failure();
if (trueBuffer.getType() != *targetType)
@@ -182,13 +183,14 @@ struct SelectOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
assert(value == selectOp.getResult() && "invalid value");
- auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
- options, invocationStack);
- auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
- options, invocationStack);
+ auto trueType = bufferization::getBufferType(
+ selectOp.getTrueValue(), options, state, invocationStack);
+ auto falseType = bufferization::getBufferType(
+ selectOp.getFalseValue(), options, state, invocationStack);
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 14fa4c1ed8159..7d67d4a33ac32 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -165,7 +165,7 @@ Operation *bufferization::getOwnerOfValue(Value value) {
/// allocated.
FailureOr<Value> bufferization::allocateTensorForShapedValue(
OpBuilder &b, Location loc, Value shapedValue,
- const BufferizationOptions &options, bool copy) {
+ const BufferizationOptions &options, BufferizationState &state, bool copy) {
Value tensor;
if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
tensor = shapedValue;
@@ -210,7 +210,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
if (copy)
return allocTensorOp.getResult();
- FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
+ FailureOr<BaseMemRefType> copyBufferType =
+ getBufferType(tensor, options, state);
if (failed(copyBufferType))
return failure();
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -222,7 +223,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
}
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
- RewriterBase &rewriter, const AnalysisState &state) {
+ RewriterBase &rewriter, const AnalysisState &analysisState,
+ BufferizationState &bufferizationState) {
OpBuilder::InsertionGuard g(rewriter);
Operation *op = getOperation();
SmallVector<OpOperand *> outOfPlaceOpOperands;
@@ -235,16 +237,18 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
Type operandType = opOperand.get().getType();
if (!llvm::isa<TensorType>(operandType))
continue;
- if (state.isInPlace(opOperand))
+ if (analysisState.isInPlace(opOperand))
continue;
if (llvm::isa<UnrankedTensorType>(operandType))
return op->emitError("copying of unranked tensors is not implemented");
- AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
+ AliasingValueList aliasingValues =
+ analysisState.getAliasingValues(opOperand);
if (aliasingValues.getNumAliases() == 1 &&
isa<OpResult>(aliasingValues.getAliases()[0].value) &&
- !state.bufferizesToMemoryWrite(opOperand) &&
- state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
+ !analysisState.bufferizesToMemoryWrite(opOperand) &&
+ analysisState
+ .getAliasingOpOperands(aliasingValues.getAliases()[0].value)
.getNumAliases() == 1 &&
!isa<UnrankedTensorType>(
aliasingValues.getAliases()[0].value.getType())) {
@@ -256,12 +260,12 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
// cannot be copied at the moment).
Value value = aliasingValues.getAliases()[0].value;
outOfPlaceValues.push_back(value);
- if (!state.canOmitTensorCopy(opOperand))
+ if (!analysisState.canOmitTensorCopy(opOperand))
copiedOpValues.insert(value);
} else {
// In all other cases, make a copy of the OpOperand.
outOfPlaceOpOperands.push_back(&opOperand);
- if (!state.canOmitTensorCopy(opOperand))
+ if (!analysisState.canOmitTensorCopy(opOperand))
copiedOpOperands.insert(&opOperand);
}
}
@@ -270,8 +274,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
rewriter.setInsertionPoint(op);
for (OpOperand *opOperand : outOfPlaceOpOperands) {
FailureOr<Value> copy = allocateTensorForShapedValue(
- rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
- copiedOpOperands.contains(opOperand));
+ rewriter, op->getLoc(), opOperand->get(), analysisState.getOptions(),
+ bufferizationState, copiedOpOperands.contains(opOperand));
if (failed(copy))
return failure();
rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
@@ -281,8 +285,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
rewriter.setInsertionPointAfter(op);
for (Value value : outOfPlaceValues) {
FailureOr<Value> copy = allocateTensorForShapedValue(
- rewriter, op->getLoc(), value, state.getOptions(),
- copiedOpValues.count(value));
+ rewriter, op->getLoc(), value, analysisState.getOptions(),
+ bufferizationState, copiedOpValues.count(value));
if (failed(copy))
return failure();
SmallVector<OpOperand *> uses = llvm::to_vector(
@@ -665,7 +669,8 @@ static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
}
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
#ifndef NDEBUG
auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
assert(tensorType && "unexpected non-tensor type");
@@ -678,7 +683,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
// Insert to_buffer op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, value);
- FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
+ FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
if (failed(memrefType))
return failure();
ensureToBufferOpIsValid(value, *memrefType);
@@ -689,14 +694,16 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
-bufferization::getBufferType(Value value, const BufferizationOptions &options) {
+bufferization::getBufferType(Value value, const BufferizationOptions &options,
+ BufferizationState &state) {
SmallVector<Value> invocationStack;
- return getBufferType(value, options, invocationStack);
+ return getBufferType(value, options, state, invocationStack);
}
/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) &&
"unexpected non-tensor type");
@@ -708,7 +715,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
Operation *op = getOwnerOfValue(value);
auto bufferizableOp = options.dynCastBufferizableOp(op);
if (bufferizableOp)
- return bufferizableOp.getBufferType(value, options, invocationStack);
+ return bufferizableOp.getBufferType(value, options, state, invocationStack);
// Op is not bufferizable.
auto memSpace =
@@ -944,6 +951,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
Failur...
[truncated]
|
@llvm/pr-subscribers-mlir-shape Author: Michele Scuttari (mscuttari) ChangesThe PR continues the work started in #141019 by adding the Patch is 76.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141466.diff 22 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 43c97d57e1834..328d928c9ebdb 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -598,13 +598,14 @@ class BufferizationState {
FailureOr<Value>
allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
const BufferizationOptions &options,
- bool copy = true);
+ BufferizationState &state, bool copy = true);
/// Lookup the buffer for the given value. If the value was not bufferized
/// yet, wrap it in a ToBufferOp. Otherwise, it is the result of a ToTensorOp,
/// from which the memref operand is returned.
FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
/// Return the buffer type for a given Value (tensor) after bufferization
/// without bufferizing any IR.
@@ -615,7 +616,8 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
FailureOr<BaseMemRefType> getBufferType(Value value,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
/// Return the buffer type for a given Value (tensor) after bufferization
/// without bufferizing any IR. This function (and not the other overload
@@ -629,6 +631,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
FailureOr<BaseMemRefType> getBufferType(Value value,
const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack);
/// Return "true" if the given op has tensor semantics and should be bufferized.
@@ -709,6 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// places.
FailureOr<BaseMemRefType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack);
/// This is the default implementation of
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index b599a9f053215..80f9b72531660 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -381,13 +381,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"resolveConflicts",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
- "const ::mlir::bufferization::AnalysisState &":$state),
+ "const ::mlir::bufferization::AnalysisState &":$analysisState,
+ "::mlir::bufferization::BufferizationState &":$bufferizationState),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto bufferizableOp =
::llvm::cast<BufferizableOpInterface>($_op.getOperation());
return bufferizableOp.resolveTensorOpOperandConflicts(
- rewriter, state);
+ rewriter, analysisState, bufferizationState);
}]
>,
InterfaceMethod<
@@ -523,6 +524,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
+ "::mlir::bufferization::BufferizationState &":$state,
"::llvm::SmallVector<::mlir::Value> &":$invocationStack),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -531,7 +533,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
assert(invocationStack.back() == value &&
"inconsistant invocation stack");
return ::mlir::bufferization::detail::defaultGetBufferType(
- value, options, invocationStack);
+ value, options, state, invocationStack);
}]
>,
InterfaceMethod<
@@ -616,7 +618,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/// form of `bufferization.alloc_tensor` ops.
::llvm::LogicalResult resolveTensorOpOperandConflicts(
::mlir::RewriterBase &rewriter,
- const ::mlir::bufferization::AnalysisState &state);
+ const ::mlir::bufferization::AnalysisState &analysisState,
+ ::mlir::bufferization::BufferizationState &bufferizationState);
/// Return `true` if the given OpOperand creates an alias but does neither
/// read nor write. This implies that `bufferizesToMemoryRead` and
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index dafa4b9b183f2..0ee4f79144158 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -112,6 +112,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack);
RankedTensorType getType() {
@@ -471,7 +472,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
- SmallVector<Value> &invocationStack) {
+ BufferizationState &state, SmallVector<Value> &invocationStack) {
return ::llvm::cast<BaseMemRefType>(getMemref().getType());
}
}];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index cf86b9a23f59e..7c07f705c8435 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -34,12 +34,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
// Note: The user may want to override this function for OpResults in
// case the bufferized result type is different from the bufferized type of
// the aliasing OpOperand (if any).
if (isa<OpResult>(value))
- return bufferization::detail::defaultGetBufferType(value, options,
+ return bufferization::detail::defaultGetBufferType(value, options, state,
invocationStack);
// Compute the buffer type of the block argument by computing the bufferized
@@ -65,7 +66,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
callerType = memrefType;
} else {
FailureOr<BaseMemRefType> maybeCallerType =
- bufferization::getBufferType(opOperand->get(), options,
+ bufferization::getBufferType(opOperand->get(), options, state,
invocationStack);
if (failed(maybeCallerType))
return failure();
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 70e3defee0867..c1f5654abbf9b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -62,7 +62,8 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
/// `BufferizableOpInterface`. The buffer types of tensor block arguments are
/// computed with `BufferizableOpIntercace::getBufferType`.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index a4ee893ca5341..e587753ddebee 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -75,12 +75,15 @@ void hoistBuffersFromLoops(Operation *op);
/// additional buffer allocations.
LogicalResult insertTensorCopies(Operation *op,
const OneShotBufferizationOptions &options,
+ BufferizationState &bufferizationState,
BufferizationStatistics *statistics = nullptr);
/// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
/// After applying this transform, the IR can be bufferized without inserting
/// additional buffer allocations.
-LogicalResult insertTensorCopies(Operation *op, const AnalysisState &state);
+LogicalResult insertTensorCopies(Operation *op,
+ const AnalysisState &analysisState,
+ BufferizationState &bufferizationState);
/// Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor
/// ops.
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f646326ffc58f..0389a984e169c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -90,7 +90,8 @@ struct IndexCastOpInterface
auto castOp = cast<arith::IndexCastOp>(op);
auto resultTensorType = cast<TensorType>(castOp.getType());
- FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
+ FailureOr<Value> source =
+ getBuffer(rewriter, castOp.getIn(), options, state);
if (failed(source))
return failure();
auto sourceType = cast<BaseMemRefType>(source->getType());
@@ -151,9 +152,9 @@ struct SelectOpInterface
// the moment (one for each tensor). When copying the op result, only one
// copy would be needed.
FailureOr<Value> maybeTrueBuffer =
- getBuffer(rewriter, selectOp.getTrueValue(), options);
+ getBuffer(rewriter, selectOp.getTrueValue(), options, state);
FailureOr<Value> maybeFalseBuffer =
- getBuffer(rewriter, selectOp.getFalseValue(), options);
+ getBuffer(rewriter, selectOp.getFalseValue(), options, state);
if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
return failure();
Value trueBuffer = *maybeTrueBuffer;
@@ -164,7 +165,7 @@ struct SelectOpInterface
// both of them to the most dynamic MemRef type.
if (trueBuffer.getType() != falseBuffer.getType()) {
auto targetType =
- bufferization::getBufferType(selectOp.getResult(), options);
+ bufferization::getBufferType(selectOp.getResult(), options, state);
if (failed(targetType))
return failure();
if (trueBuffer.getType() != *targetType)
@@ -182,13 +183,14 @@ struct SelectOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
assert(value == selectOp.getResult() && "invalid value");
- auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
- options, invocationStack);
- auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
- options, invocationStack);
+ auto trueType = bufferization::getBufferType(
+ selectOp.getTrueValue(), options, state, invocationStack);
+ auto falseType = bufferization::getBufferType(
+ selectOp.getFalseValue(), options, state, invocationStack);
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 14fa4c1ed8159..7d67d4a33ac32 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -165,7 +165,7 @@ Operation *bufferization::getOwnerOfValue(Value value) {
/// allocated.
FailureOr<Value> bufferization::allocateTensorForShapedValue(
OpBuilder &b, Location loc, Value shapedValue,
- const BufferizationOptions &options, bool copy) {
+ const BufferizationOptions &options, BufferizationState &state, bool copy) {
Value tensor;
if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
tensor = shapedValue;
@@ -210,7 +210,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
if (copy)
return allocTensorOp.getResult();
- FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
+ FailureOr<BaseMemRefType> copyBufferType =
+ getBufferType(tensor, options, state);
if (failed(copyBufferType))
return failure();
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -222,7 +223,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
}
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
- RewriterBase &rewriter, const AnalysisState &state) {
+ RewriterBase &rewriter, const AnalysisState &analysisState,
+ BufferizationState &bufferizationState) {
OpBuilder::InsertionGuard g(rewriter);
Operation *op = getOperation();
SmallVector<OpOperand *> outOfPlaceOpOperands;
@@ -235,16 +237,18 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
Type operandType = opOperand.get().getType();
if (!llvm::isa<TensorType>(operandType))
continue;
- if (state.isInPlace(opOperand))
+ if (analysisState.isInPlace(opOperand))
continue;
if (llvm::isa<UnrankedTensorType>(operandType))
return op->emitError("copying of unranked tensors is not implemented");
- AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
+ AliasingValueList aliasingValues =
+ analysisState.getAliasingValues(opOperand);
if (aliasingValues.getNumAliases() == 1 &&
isa<OpResult>(aliasingValues.getAliases()[0].value) &&
- !state.bufferizesToMemoryWrite(opOperand) &&
- state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
+ !analysisState.bufferizesToMemoryWrite(opOperand) &&
+ analysisState
+ .getAliasingOpOperands(aliasingValues.getAliases()[0].value)
.getNumAliases() == 1 &&
!isa<UnrankedTensorType>(
aliasingValues.getAliases()[0].value.getType())) {
@@ -256,12 +260,12 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
// cannot be copied at the moment).
Value value = aliasingValues.getAliases()[0].value;
outOfPlaceValues.push_back(value);
- if (!state.canOmitTensorCopy(opOperand))
+ if (!analysisState.canOmitTensorCopy(opOperand))
copiedOpValues.insert(value);
} else {
// In all other cases, make a copy of the OpOperand.
outOfPlaceOpOperands.push_back(&opOperand);
- if (!state.canOmitTensorCopy(opOperand))
+ if (!analysisState.canOmitTensorCopy(opOperand))
copiedOpOperands.insert(&opOperand);
}
}
@@ -270,8 +274,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
rewriter.setInsertionPoint(op);
for (OpOperand *opOperand : outOfPlaceOpOperands) {
FailureOr<Value> copy = allocateTensorForShapedValue(
- rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
- copiedOpOperands.contains(opOperand));
+ rewriter, op->getLoc(), opOperand->get(), analysisState.getOptions(),
+ bufferizationState, copiedOpOperands.contains(opOperand));
if (failed(copy))
return failure();
rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
@@ -281,8 +285,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
rewriter.setInsertionPointAfter(op);
for (Value value : outOfPlaceValues) {
FailureOr<Value> copy = allocateTensorForShapedValue(
- rewriter, op->getLoc(), value, state.getOptions(),
- copiedOpValues.count(value));
+ rewriter, op->getLoc(), value, analysisState.getOptions(),
+ bufferizationState, copiedOpValues.count(value));
if (failed(copy))
return failure();
SmallVector<OpOperand *> uses = llvm::to_vector(
@@ -665,7 +669,8 @@ static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
}
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
#ifndef NDEBUG
auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
assert(tensorType && "unexpected non-tensor type");
@@ -678,7 +683,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
// Insert to_buffer op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, value);
- FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
+ FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
if (failed(memrefType))
return failure();
ensureToBufferOpIsValid(value, *memrefType);
@@ -689,14 +694,16 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
-bufferization::getBufferType(Value value, const BufferizationOptions &options) {
+bufferization::getBufferType(Value value, const BufferizationOptions &options,
+ BufferizationState &state) {
SmallVector<Value> invocationStack;
- return getBufferType(value, options, invocationStack);
+ return getBufferType(value, options, state, invocationStack);
}
/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) &&
"unexpected non-tensor type");
@@ -708,7 +715,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
Operation *op = getOwnerOfValue(value);
auto bufferizableOp = options.dynCastBufferizableOp(op);
if (bufferizableOp)
- return bufferizableOp.getBufferType(value, options, invocationStack);
+ return bufferizableOp.getBufferType(value, options, state, invocationStack);
// Op is not bufferizable.
auto memSpace =
@@ -944,6 +951,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
Failur...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
@@ -615,7 +616,8 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value, | |||
/// | |||
/// This function is a wrapper around BufferizableOpInterface::getBufferType. | |||
FailureOr<BaseMemRefType> getBufferType(Value value, | |||
const BufferizationOptions &options); | |||
const BufferizationOptions &options, | |||
BufferizationState &state); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be a const reference? Here and in all the other places that this PR touches?
@@ -381,13 +381,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> { | |||
/*retType=*/"::llvm::LogicalResult", | |||
/*methodName=*/"resolveConflicts", | |||
/*args=*/(ins "::mlir::RewriterBase &":$rewriter, | |||
"const ::mlir::bufferization::AnalysisState &":$state), | |||
"const ::mlir::bufferization::AnalysisState &":$analysisState, | |||
"::mlir::bufferization::BufferizationState &":$bufferizationState), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Const possible here?
…ts` interface methods (llvm#141466) The PR continues the work started in llvm#141019 by adding the `BufferizationState` class also to the `getBufferType` and `resolveConflicts` interface methods, together with the additional support functions that are used throughout the bufferization infrastructure.
* llvm/llvm-project#139340 ``` sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp ``` * llvm/llvm-project#141466 & llvm/llvm-project#141019 * Add `BufferizationState &state` to `bufferize` and `getBuffer` * llvm/llvm-project#143159 & llvm/llvm-project#142683 & llvm/llvm-project#143779 * Updates to `transform.apply_registered_pass` and its Python-bindings * llvm/llvm-project#143217 * `tilingResult->mergeResult.replacements` -> `tilingResult->replacements` * llvm/llvm-project#140559 & llvm/llvm-project#143871 * Change CHECK lines which expected `amx.` ops to `llvm.call_intrinsic`s & fix which enables conversion again.
The PR continues the work started in #141019 by adding the
BufferizationState
class also to thegetBufferType
andresolveConflicts
interface methods, together with the additional support functions that are used throughout the bufferization infrastructure.