Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Vector] Handle 0-rank case in fold instead of RewriterPattern #130168

Merged
merged 1 commit into from
Mar 24, 2025

Conversation

Groverkss
Copy link
Member

For vector.extract, the folder always canonicalizes to a vector.extract operation, while the rewrite pattern canonicalizes to a vector.broadcast except in the case of 0-rank vectors.

Remove this special casing, and instead handle the 0-rank vector case in the folder.

@llvmbot
Copy link
Member

llvmbot commented Mar 6, 2025

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

For vector.extract, the folder always canonicalizes to a vector.extract operation, while the rewrite pattern canonicalizes to a vector.broadcast except in the case of 0-rank vectors.

Remove this special casing, and instead handle the 0-rank vector case in the folder.


Full diff: https://github.com/llvm/llvm-project/pull/130168.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+5-7)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+2-2)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..31f1e82ff1174 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1678,7 +1678,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
     return source;
 
   unsigned extractResultRank = getRank(extractOp.getType());
-  if (extractResultRank >= broadcastSrcRank)
+  if (extractResultRank > broadcastSrcRank)
     return Value();
   // Check that the dimension of the result haven't been broadcasted.
   auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
@@ -2159,13 +2159,11 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
     // folding patterns.
     if (extractResultRank < broadcastSrcRank)
       return failure();
+    // For scalar result, the input can only be a zero-dim vector, which will
+    // be handled by the folder.
+    if (extractResultRank == 0)
+      return failure();
 
-    // Special case if broadcast src is a 0D vector.
-    if (extractResultRank == 0) {
-      assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
-      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
-      return success();
-    }
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         extractOp, extractOp.getType(), source);
     return success();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bf755b466c7eb..8a9204f042ff6 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -736,7 +736,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
 
 // CHECK-LABEL: fold_extract_broadcast_0dvec_input_scalar_output
 //  CHECK-SAME:   %[[A:.*]]: vector<f32>
-//       CHECK:   %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
+//       CHECK:   %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
 //       CHECK:   return %[[B]] : f32
 func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>, 
   %idx0 : index, %idx1 : index, %idx2: index) -> f32 {
@@ -2834,7 +2834,7 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
   %3 = vector.extract %2[] : f32 from vector<f32>
 
   // Broadcast 0D to 3D and extract scalar.
-  // CHECK: %[[extract1:.*]] = vector.extractelement %[[b]][] : vector<f32>
+  // CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector<f32>
   %4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
   %5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
 

@llvmbot
Copy link
Member

llvmbot commented Mar 6, 2025

@llvm/pr-subscribers-mlir-vector

Author: Kunwar Grover (Groverkss)

Changes

For vector.extract, the folder always canonicalizes to a vector.extract operation, while the rewrite pattern canonicalizes to a vector.broadcast except in the case of 0-rank vectors.

Remove this special casing, and instead handle the 0-rank vector case in the folder.


Full diff: https://github.com/llvm/llvm-project/pull/130168.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+5-7)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+2-2)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..31f1e82ff1174 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1678,7 +1678,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
     return source;
 
   unsigned extractResultRank = getRank(extractOp.getType());
-  if (extractResultRank >= broadcastSrcRank)
+  if (extractResultRank > broadcastSrcRank)
     return Value();
   // Check that the dimension of the result haven't been broadcasted.
   auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
@@ -2159,13 +2159,11 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
     // folding patterns.
     if (extractResultRank < broadcastSrcRank)
       return failure();
+    // For scalar result, the input can only be a zero-dim vector, which will
+    // be handled by the folder.
+    if (extractResultRank == 0)
+      return failure();
 
-    // Special case if broadcast src is a 0D vector.
-    if (extractResultRank == 0) {
-      assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
-      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
-      return success();
-    }
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         extractOp, extractOp.getType(), source);
     return success();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bf755b466c7eb..8a9204f042ff6 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -736,7 +736,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
 
 // CHECK-LABEL: fold_extract_broadcast_0dvec_input_scalar_output
 //  CHECK-SAME:   %[[A:.*]]: vector<f32>
-//       CHECK:   %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
+//       CHECK:   %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
 //       CHECK:   return %[[B]] : f32
 func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>, 
   %idx0 : index, %idx1 : index, %idx2: index) -> f32 {
@@ -2834,7 +2834,7 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
   %3 = vector.extract %2[] : f32 from vector<f32>
 
   // Broadcast 0D to 3D and extract scalar.
-  // CHECK: %[[extract1:.*]] = vector.extractelement %[[b]][] : vector<f32>
+  // CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector<f32>
   %4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
   %5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
 

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thank you! LGTM % nit

@Groverkss
Copy link
Member Author

No regressions in IREE: iree-org/iree#20268

@Groverkss Groverkss force-pushed the handle-0rank-case-in-fold branch from 18f90d0 to 412fb21 Compare March 24, 2025 13:02
@Groverkss Groverkss merged commit f3fa54a into llvm:main Mar 24, 2025
8 of 10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants