-
Notifications
You must be signed in to change notification settings - Fork 13.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Propagate vector.extract
through elementwise ops
#131462
Conversation
c8b7a74
to
ac16b30
Compare
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Ivan Butygin (Hardcode84) ChangesPropagate Currenly limited to the case when extract is the single use of elementwise to avoid introducing additional elementwise ops. Full diff: https://github.com/llvm/llvm-project/pull/131462.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index c973eca0132a9..7be39519c1037 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -453,4 +453,15 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyVectorPropagateExtractPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.propagate_extract",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Collect a set of patterns for propagating `vector.extract` through the
+ vector ops.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
#endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7de4a6a315750..16c66e078821d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -409,6 +409,9 @@ void populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned targetBitWidth);
+/// Populates patterns for propagating `vector.extract` through the vector ops.
+void populateVectorPropagateExtractsPatterns(RewritePatternSet &patterns);
+
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 20c577273d786..616e563fcdc77 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -204,6 +204,11 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
}
+void transform::ApplyVectorPropagateExtractPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateVectorPropagateExtractsPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..8830375f88104 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
VectorTransforms.cpp
VectorUnroll.cpp
VectorMaskElimination.cpp
+ VectorPropagateExtract.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp
new file mode 100644
index 0000000000000..10f578179bc94
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp
@@ -0,0 +1,66 @@
+//===- VectorPropagateExtract.cpp - vector.extract propagation - ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns for vector.extract propagation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
+class ExtractOpFromElementwise final
+ : public OpRewritePattern<vector::ExtractOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractOp op,
+ PatternRewriter &rewriter) const override {
+ Operation *eltwise = op.getVector().getDefiningOp();
+
+ // Elementwise op with single result and `extract` is single user.
+ if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
+ eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
+ return failure();
+
+ // Arguments and result types must match.
+ if (!llvm::all_equal(llvm::concat<Type>(eltwise->getOperandTypes(),
+ eltwise->getResultTypes())))
+ return failure();
+
+ Type dstType = op.getType();
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(eltwise);
+
+ IRMapping mapping;
+ Location loc = eltwise->getLoc();
+ for (auto &&[i, arg] : llvm::enumerate(eltwise->getOperands())) {
+ Value newArg =
+ rewriter.create<vector::ExtractOp>(loc, arg, op.getMixedPosition());
+ mapping.map(arg, newArg);
+ }
+
+ Operation *newEltwise = rewriter.clone(*eltwise, mapping);
+ newEltwise->getResult(0).setType(dstType);
+
+ rewriter.replaceOp(op, newEltwise);
+ rewriter.eraseOp(eltwise);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorPropagateExtractsPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ExtractOpFromElementwise>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Vector/propagate-extracts.mlir b/mlir/test/Dialect/Vector/propagate-extracts.mlir
new file mode 100644
index 0000000000000..6c6f812c8f6d2
--- /dev/null
+++ b/mlir/test/Dialect/Vector/propagate-extracts.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @extract_elementwise
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
+// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
+// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
+// CHECK: return %[[RES]] : f32
+ %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+ %1 = vector.extract %0[1] : f32 from vector<4xf32>
+ return %1 : f32
+}
+
+// CHECK-LABEL: @extract_vec_elementwise
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
+func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
+// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+ %0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
+ %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+ return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extract_elementwise_use
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
+// Do not propagate extract, as elementwise has other uses.
+// CHECK: %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
+// CHECK: return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
+ %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+ %1 = vector.extract %0[1] : f32 from vector<4xf32>
+ return %1, %0 : f32, vector<4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.propagate_extract
+ } : !transform.any_op
+ transform.yield
+ }
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thank you for working on this!
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { | ||
let description = [{ | ||
Collect a set of patterns for propagating `vector.extract` through the | ||
vector ops. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we classify Arith ops as vector ops? Also, are you planning to add more patterns like this? Otherwise this TD Op is over-promising a bit.
On a related note, we already have similar patterns for elementwise here:
llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Lines 964 to 1044 in d928a67
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex: | |
/// ``` | |
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex> | |
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex> | |
/// %r = arith.addi %a, %b : vector<1x4xindex> | |
/// ``` | |
/// Gets converted to: | |
/// ``` | |
/// %r = arith.addi %arg0, %arg1 : index | |
/// %b = vector.broadcast %r : index to vector<1x4xindex> | |
/// ``` | |
/// | |
/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting | |
/// ops. | |
struct ReorderElementwiseOpsOnBroadcast final | |
: public OpTraitRewritePattern<OpTrait::Elementwise> { | |
using OpTraitRewritePattern::OpTraitRewritePattern; | |
LogicalResult matchAndRewrite(Operation *op, | |
PatternRewriter &rewriter) const override { | |
if (op->getNumResults() != 1) | |
return failure(); | |
if (!llvm::isa<ShapedType>(op->getResults()[0].getType())) | |
return failure(); | |
if (!OpTrait::hasElementwiseMappableTraits(op)) | |
return rewriter.notifyMatchFailure( | |
op, "Op doesn't have ElementwiseMappableTraits"); | |
if (op->getNumOperands() == 0) | |
return failure(); | |
if (op->getResults()[0].getType() != op->getOperand(0).getType()) | |
return rewriter.notifyMatchFailure(op, | |
"result and operand type mismatch"); | |
if (isa<vector::FMAOp>(op)) { | |
return rewriter.notifyMatchFailure( | |
op, | |
"Op only accepts vector types - not supported as broadcast source " | |
"might be a scalar"); | |
} | |
// Get the type of the lhs operand | |
auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp(); | |
if (!lhsBcastOrSplat || | |
!isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat)) | |
return failure(); | |
auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType(); | |
// Make sure that all operands are broadcast from identical types: | |
// * scalar (`vector.broadcast` + `vector.splat`), or | |
// * vector (`vector.broadcast`). | |
// Otherwise the re-ordering wouldn't be safe. | |
if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) { | |
auto bcast = val.getDefiningOp<vector::BroadcastOp>(); | |
if (bcast) | |
return (bcast.getOperand().getType() == lhsBcastOrSplatType); | |
auto splat = val.getDefiningOp<vector::SplatOp>(); | |
if (splat) | |
return (splat.getOperand().getType() == lhsBcastOrSplatType); | |
return false; | |
})) { | |
return failure(); | |
} | |
// Collect the source values before broadcasting | |
SmallVector<Value> srcValues; | |
srcValues.reserve(op->getNumOperands()); | |
for (Value operand : op->getOperands()) { | |
srcValues.push_back(operand.getDefiningOp()->getOperand(0)); | |
} | |
// Create the "elementwise" Op | |
Operation *elementwiseOp = | |
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, | |
lhsBcastOrSplatType, op->getAttrs()); | |
// Replace the original Op with the elementwise Op | |
auto vectorType = op->getResultTypes()[0]; | |
rewriter.replaceOpWithNewOp<vector::BroadcastOp>( | |
op, vectorType, elementwiseOp->getResults()); | |
return success(); | |
} | |
}; |
We should move this pattern there unless there's expectation for this to grow?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Most arith/math ops already have elementwise and vectorizable traits (which is checked by hasElementwiseMappableTraits
) so this pattern will cover them. I missed populateSinkVectorOpsPatterns
, so I will move pattern there (and add transform, exposing them). Those should cover most of my cases, the other pattern I may add later is vector.extract(vector.load(data))
-> memref.load(data)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
moved to populateSinkVectorOpsPatterns
@banach-space I've updated the PR, some changes:
|
Cool. With a dedicated TD op, we should split
Could you add a TODO for me? Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates and sorry for delay, Ivan!
This is already nicely polished. I've left some minor comments to improve consistency + re-use, nothing major 🙏🏻
Thanks for contributing this!
|
||
// ----- | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a block comment documenting which pattern is being tested.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index | ||
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 | ||
// CHECK-DAG: %[[C79:.*]] = arith.constant 79 : index | ||
// CHECK: %[[VAL6:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index | ||
// CHECK: %[[VAL7:.*]] = arith.addi %[[VAL_3]], %[[VAL_4]] : index | ||
|
||
// CHECK: %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL6]], %[[C79]], %[[VAL7]]], %[[CST]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32> | ||
// CHECK: %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32> | ||
// CHECK: return %[[VAL_21]] : tensor<1x4xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since you have already started the trend, would you mind replacing VAL_N
with meaningful LIT variables names? Thanks!
/// This may result in more efficient code when we extracting a single value | ||
/// from multi-element vector and also to help canonicalize 1-element vectors to | ||
/// scalars. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am wondering about this ... The tests for the vectorizer became "cleaner" (nice!), but if a value is already in a vector register, and then it's extracted ... I would actually expect the original code to be faster. However, the overall performance will depend on the surrounding code 🤔
I am mostly waving my hands here, but perhaps refrain from saying that this "may result in more efficient code" and instead say "replaces elementwise computation on N vectors with an elementwise on M elements (M < N) and an additional vector.extract
". Basically, something a bit more neutral.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed to "cleaner" )
if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) || | ||
eltwise->getNumResults() != 1 || !eltwise->hasOneUse()) | ||
return rewriter.notifyMatchFailure(op, "not a suitable op"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] Any objections to splitting this into multiple conditions and adding more descriptive failure msg?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of adding a new file, could you try adding a new RUN
line here: https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Vector/vector-sink.mlir. You will probably have to put the TD sequence into a separate file. Here's an example:
// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/transpose-matmul-a.mlir' -transform-interpreter -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-A
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
||
// CHECK-LABEL: @extract_elementwise_no_single_use | ||
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>) | ||
func.func @extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you also add some "prefix" to mark negative tests? https://mlir.llvm.org/getting_started/TestingGuide/#step-3-add-the-newly-identified-missing-case
Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
Propagate `Extract(Elementwise(...))` -> `Elemetwise(Extract...)`. Currenly limited to the case when extract is the single use of elementwise to avoid introducing additional computations.
a52b8f3
to
64ea88d
Compare
@banach-space addressed review comments |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great, thank you for addressing my comments!
LGTM
@banach-space I found another issue while testing it:
But doesn't actually support scalar inputs. |
|
Propagate
Extract(Elementwise(...))
->Elemetwise(Extract...)
.Currenly limited to the case when extract is the single use of elementwise to avoid introducing additional elementwise ops.