-
Notifications
You must be signed in to change notification settings - Fork 13.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Vector] Handle 0-rank case in fold instead of RewriterPattern #130168
Conversation
@llvm/pr-subscribers-mlir Author: Kunwar Grover (Groverkss) ChangesFor vector.extract, the folder always canonicalizes to a vector.extract operation, while the rewrite pattern canonicalizes to a vector.broadcast except in the case of 0-rank vectors. Remove this special casing, and instead handle the 0-rank vector case in the folder. Full diff: https://github.com/llvm/llvm-project/pull/130168.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..31f1e82ff1174 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1678,7 +1678,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
return source;
unsigned extractResultRank = getRank(extractOp.getType());
- if (extractResultRank >= broadcastSrcRank)
+ if (extractResultRank > broadcastSrcRank)
return Value();
// Check that the dimension of the result haven't been broadcasted.
auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
@@ -2159,13 +2159,11 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
// folding patterns.
if (extractResultRank < broadcastSrcRank)
return failure();
+ // For scalar result, the input can only be a zero-dim vector, which will
+ // be handled by the folder.
+ if (extractResultRank == 0)
+ return failure();
- // Special case if broadcast src is a 0D vector.
- if (extractResultRank == 0) {
- assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
- rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
- return success();
- }
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
extractOp, extractOp.getType(), source);
return success();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bf755b466c7eb..8a9204f042ff6 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -736,7 +736,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
// CHECK-LABEL: fold_extract_broadcast_0dvec_input_scalar_output
// CHECK-SAME: %[[A:.*]]: vector<f32>
-// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
+// CHECK: %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
// CHECK: return %[[B]] : f32
func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
%idx0 : index, %idx1 : index, %idx2: index) -> f32 {
@@ -2834,7 +2834,7 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
%3 = vector.extract %2[] : f32 from vector<f32>
// Broadcast 0D to 3D and extract scalar.
- // CHECK: %[[extract1:.*]] = vector.extractelement %[[b]][] : vector<f32>
+ // CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector<f32>
%4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
%5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
|
@llvm/pr-subscribers-mlir-vector Author: Kunwar Grover (Groverkss) ChangesFor vector.extract, the folder always canonicalizes to a vector.extract operation, while the rewrite pattern canonicalizes to a vector.broadcast except in the case of 0-rank vectors. Remove this special casing, and instead handle the 0-rank vector case in the folder. Full diff: https://github.com/llvm/llvm-project/pull/130168.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..31f1e82ff1174 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1678,7 +1678,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
return source;
unsigned extractResultRank = getRank(extractOp.getType());
- if (extractResultRank >= broadcastSrcRank)
+ if (extractResultRank > broadcastSrcRank)
return Value();
// Check that the dimension of the result haven't been broadcasted.
auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
@@ -2159,13 +2159,11 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
// folding patterns.
if (extractResultRank < broadcastSrcRank)
return failure();
+ // For scalar result, the input can only be a zero-dim vector, which will
+ // be handled by the folder.
+ if (extractResultRank == 0)
+ return failure();
- // Special case if broadcast src is a 0D vector.
- if (extractResultRank == 0) {
- assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
- rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
- return success();
- }
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
extractOp, extractOp.getType(), source);
return success();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bf755b466c7eb..8a9204f042ff6 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -736,7 +736,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
// CHECK-LABEL: fold_extract_broadcast_0dvec_input_scalar_output
// CHECK-SAME: %[[A:.*]]: vector<f32>
-// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
+// CHECK: %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
// CHECK: return %[[B]] : f32
func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
%idx0 : index, %idx1 : index, %idx2: index) -> f32 {
@@ -2834,7 +2834,7 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
%3 = vector.extract %2[] : f32 from vector<f32>
// Broadcast 0D to 3D and extract scalar.
- // CHECK: %[[extract1:.*]] = vector.extractelement %[[b]][] : vector<f32>
+ // CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector<f32>
%4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
%5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great, thank you! LGTM % nit
No regressions in IREE: iree-org/iree#20268 |
18f90d0
to
412fb21
Compare
For vector.extract, the folder always canonicalizes to a vector.extract operation, while the rewrite pattern canonicalizes to a vector.broadcast except in the case of 0-rank vectors.
Remove this special casing, and instead handle the 0-rank vector case in the folder.