[mlir][linalg] Simplify createWriteOrMaskedWrite (NFC) #141567
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes
Full diff: https://github.com/llvm/llvm-project/pull/141567.diff

1 file affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0113ba86a5ae3..2abb2f0ea467c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1590,61 +1590,46 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
/// Creates an optionally masked TransferWriteOp
///
/// Generates the following operation:
-/// %res = vector.transfer_write %vectorToStore into %dest
+/// %res = vector.transfer_write %vecToStore into %dest
///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
///
-/// %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+/// %mask = vector.create_mask(%destShape) : %vecToStoreShape
/// %res = vector.mask %mask {
-/// vector.transfer_write %vectorToStore into %dest
+/// vector.transfer_write %vecToStore into %dest
/// }
///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
/// i1), and the mask values are based on the shape of the `dest` tensor.
///
/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
/// is used instead of masking:
///
-/// %write = vector.transfer_write %vectorToStore into %dest
+/// %write = vector.transfer_write %vecToStore into %dest
/// in_bounds_flags = (...)
/// %res = vector.transfer_write %input into %dest
/// {in_bounds = in_bounds_flags}
///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
- Value dest,
- ArrayRef<int64_t> inputVecSizesForLeadingDims,
- SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+ Value dest, SmallVector<Value> writeIndices = {},
bool useInBoundsInsteadOfMasking = false) {
ShapedType destType = cast<ShapedType>(dest.getType());
int64_t destRank = destType.getRank();
auto destShape = destType.getShape();
- VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+ VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
int64_t vecToStoreRank = vecToStoreType.getRank();
auto vecToStoreShape = vecToStoreType.getShape();
// Compute the in_bounds attribute
SmallVector<bool> inBoundsVal(vecToStoreRank, true);
if (useInBoundsInsteadOfMasking) {
- // In this case, assume that all the required vector sizes have been
- // provided.
- assert(inputVecSizesForLeadingDims.size() ==
- static_cast<size_t>(vecToStoreType.getRank()) &&
- "Insufficient number of input vector sizes!");
- // Update the inBounds attribute.
for (unsigned i = 0; i < destRank; i++)
- inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
+ inBoundsVal[i] = (destShape[i] == vecToStoreShape[i]) &&
!ShapedType::isDynamic(destShape[i]);
}
@@ -1660,7 +1645,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
// Generate the xfer_write Op
Operation *write =
builder.create<vector::TransferWriteOp>(loc,
- /*vector=*/vectorToStore,
+ /*vector=*/vecToStore,
/*source=*/dest,
/*indices=*/writeIndices,
/*inBounds=*/inBoundsVal);
@@ -1669,46 +1654,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
if (useInBoundsInsteadOfMasking)
return write;
- assert(llvm::none_of(
- destShape.drop_front(inputVecSizesForLeadingDims.size()),
- [](int64_t size) { return size == ShapedType::kDynamic; }) &&
- "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
- // Check if masking is needed.
- bool needMaskForWrite =
- !llvm::equal(inputVecSizesForLeadingDims,
- destShape.take_front(destRank - vecToStoreRank +
- inputVecSizesForLeadingDims.size()));
-
- // If masking is needed, generate the mask and mask the operation.
- if (needMaskForWrite) {
- // Get the mask shape + type. Missing mask dimensions are taken from
- // `vectorToStore`.
- SmallVector<int64_t> writeMaskShape;
- writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
- inputVecSizesForLeadingDims.end());
- if (vecToStoreRank >
- static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
- writeMaskShape.append(vecToStoreShape.begin() +
- inputVecSizesForLeadingDims.size(),
- vecToStoreShape.end());
- auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
- SmallVector<OpFoldResult> destSizes =
- tensor::getMixedSizes(builder, loc, dest);
- SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
- destSizes.end());
-
- if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
- writeMaskShape))
- return write;
-
- Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
- loc, writeMaskType, maskSizes);
- write = mlir::vector::maskOperation(builder, write, maskForWrite);
- }
+ // Check if masking is needed. If not, exit.
+ if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+ return write;
+
+ // Compute the mask and mask the write Op.
+ auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+ SmallVector<OpFoldResult> destSizes =
+ tensor::getMixedSizes(builder, loc, dest);
+ SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+ destSizes.end());
+
+ if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+ vecToStoreShape))
+ return write;
- return write;
+ Value maskForWrite =
+ builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+ return mlir::vector::maskOperation(builder, write, maskForWrite);
}
/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1808,10 +1772,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
Value dest = rewriter.create<tensor::EmptyOp>(
loc, reifiedReturnShapes[0],
transposeOp.getResult().getType().getElementType());
- Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, transposeOp.getResult(), dest,
- /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
- /*useInBoundsInsteadOfMasking=*/false);
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
+ /*writeIndices=*/{},
+ /*useInBoundsInsteadOfMasking=*/false);
newResults.push_back(write->getResult(0));
return success();
}
@@ -1949,7 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
shapeCastOp.getResult().getType().getElementType());
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, shapeCastOp.getResult(), dest,
- /*inputVecSizesForLeadingDims=*/writeVectorSizes,
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
newResults.push_back(write->getResult(0));
return success();
@@ -1982,10 +1945,9 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
// Create Xfer write Op
Value dest = rewriter.create<tensor::EmptyOp>(
loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
- Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, maskedRead, dest,
- /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
- /*useInBoundsInsteadOfMasking=*/false);
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest, {},
+ /*useInBoundsInsteadOfMasking=*/false);
newResults.push_back(write->getResult(0));
return success();
}
@@ -3041,8 +3003,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
// Create write
auto writeIndices =
getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
- Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
+ Operation *write = createWriteOrMaskedWrite(rewriter, loc, read,
+ sliceOp.getDest(), writeIndices);
// 4. Finalize
newResults.push_back(write->getResult(0));
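To make the behaviour described in the doc-comment concrete, here is a hand-written MLIR sketch of the two main paths. All names (`%vec`, `%dest0`, `%dest1`) and shapes are hypothetical, chosen purely for illustration; this is not output copied from a test.

```mlir
// Case 1: no mask needed. shape(vecToStore) equals the trailing dims of
// the static destination (destShape.take_back(vecToStoreRank)), so the
// helper returns a plain transfer_write.
%c0 = arith.constant 0 : index
%w = vector.transfer_write %vec, %dest0[%c0, %c0, %c0]
    : vector<4x8xf32>, tensor<2x4x8xf32>

// Case 2: mask needed. The destination's trailing dim is dynamic, so a
// mask with the vector's shape (element type i1) is built from the
// destination's runtime sizes and the write is wrapped in vector.mask.
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%dim = tensor.dim %dest1, %c1 : tensor<4x?xf32>
%mask = vector.create_mask %c4, %dim : vector<4x8xi1>
%res = vector.mask %mask {
  vector.transfer_write %vec, %dest1[%c0, %c0]
      : vector<4x8xf32>, tensor<4x?xf32>
} : vector<4x8xi1> -> tensor<4x?xf32>
```

If `isMaskTriviallyFoldable` can prove from static sizes and indices that the mask would be all-true, the unmasked write from Case 1 is returned even on this second path.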
Force-pushed from 53d435d to ca24a26.
Nice clean up, thanks!
Force-pushed from 42b1783 to 373036e.
Force-pushed from ca24a26 to edcc604.
✅ With the latest revision this PR passed the C/C++ code formatter.
This patch removes `inputVecSizesForLeadingDims` from the parameter list of `createWriteOrMaskedWrite`. That argument is unnecessary: the vector sizes can be obtained from the `vecToStore` parameter. Since this doesn't change behavior or test results, it's marked as NFC.

Additional cleanups:

* Renamed `vectorToStore` to `vecToStore` for consistency and brevity.
* Rewrote the conditional at the end of the function to use an early exit, improving readability:

```cpp
// BEFORE:
if (maskingRequired) {
  Value maskForWrite = ...;
  write = maskOperation(write, maskForWrite);
}
return write;

// AFTER:
if (!maskingRequired)
  return write;

Value maskForWrite = ...;
return vector::maskOperation(builder, write, maskForWrite);
```

This change addresses a TODO from #141244.
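For contrast with the masked path shown earlier, a minimal sketch (hypothetical values and shapes) of the `useInBoundsInsteadOfMasking` path, where the function exits early and records per-dimension bounds information in the `in_bounds` attribute instead of building a mask:

```mlir
// in_bounds[i] is true only when destShape[i] is static and equals the
// corresponding vector size: here dim 0 matches (4 == 4) but dim 1 does
// not (8 > 5), so the flags are [true, false] and no mask is generated.
// Out-of-bounds lanes of a transfer_write are simply not stored.
%c0 = arith.constant 0 : index
%res = vector.transfer_write %vec, %dest[%c0, %c0]
    {in_bounds = [true, false]}
    : vector<4x8xf32>, tensor<4x5xf32>
```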
Force-pushed from edcc604 to cef5067.