Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][gpu] Allow WarpOpDeadResult, WarpOpForwardOperand patterns to be used in isolation. #132860

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

charithaintc
Copy link
Contributor

This PR move WarpOpDeadResult, WarpOpForwardOperand patterns into populateWarpSimplificationPatterns so that they can be reused without having to reuse all other vector distribution patterns inside populatePropagateWarpVectorDistributionPatterns.

@llvmbot
Copy link
Member

llvmbot commented Mar 25, 2025

@llvm/pr-subscribers-mlir

Author: Charitha Saumya (charithaintc)

Changes

This PR move WarpOpDeadResult, WarpOpForwardOperand patterns into populateWarpSimplificationPatterns so that they can be reused without having to reuse all other vector distribution patterns inside populatePropagateWarpVectorDistributionPatterns.


Full diff: https://github.com/llvm/llvm-project/pull/132860.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h (+4)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+11-5)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+2)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index dda45219b2acc..082d990cee8a4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -98,6 +98,10 @@ void populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
     PatternBenefit benefit = 1, PatternBenefit readBenefit = 0);
 
+/// Patterns for simplification of WarpExecuteOnLane0Op during distribution.
+void populateWarpSimplificationPatterns(RewritePatternSet &pattern,
+                                        PatternBenefit benefit = 1);
+
 /// Lambda signature to compute a reduction of a distributed value for the given
 /// reduction kind and size.
 using DistributedReductionFn =
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e214257de2cdf..f0d771142e307 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1761,17 +1761,23 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
-  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
-               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
-      patterns.getContext(), benefit);
+  patterns
+      .add<WarpOpElementwise, WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
+           WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
+           WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+          patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                                benefit);
 }
 
+void mlir::vector::populateWarpSimplificationPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<WarpOpDeadResult, WarpOpForwardOperand>(patterns.getContext(),
+                                                       benefit);
+}
+
 void mlir::vector::populateDistributeReduction(
     RewritePatternSet &patterns,
     const DistributedReductionFn &distributedReductionFn,
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..feec10e6492f7 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -660,6 +660,7 @@ struct TestVectorDistribution
       vector::populatePropagateWarpVectorDistributionPatterns(
           patterns, distributionFn, shuffleFn, /*benefit=*/1,
           /*readBenefit=*/0);
+      vector::populateWarpSimplificationPatterns(patterns);
       vector::populateDistributeReduction(patterns, warpReduction, 1);
       populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
       (void)applyPatternsGreedily(getOperation(), std::move(patterns));
@@ -672,6 +673,7 @@ struct TestVectorDistribution
       RewritePatternSet patterns(ctx);
       vector::populatePropagateWarpVectorDistributionPatterns(
           patterns, distributionFn, shuffleFn);
+      vector::populateWarpSimplificationPatterns(patterns);
       vector::populateDistributeReduction(patterns, warpReduction);
       (void)applyPatternsGreedily(getOperation(), std::move(patterns));
     }

@llvmbot
Copy link
Member

llvmbot commented Mar 25, 2025

@llvm/pr-subscribers-mlir-vector

Author: Charitha Saumya (charithaintc)

Changes

This PR move WarpOpDeadResult, WarpOpForwardOperand patterns into populateWarpSimplificationPatterns so that they can be reused without having to reuse all other vector distribution patterns inside populatePropagateWarpVectorDistributionPatterns.


Full diff: https://github.com/llvm/llvm-project/pull/132860.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h (+4)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+11-5)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+2)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index dda45219b2acc..082d990cee8a4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -98,6 +98,10 @@ void populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
     PatternBenefit benefit = 1, PatternBenefit readBenefit = 0);
 
+/// Patterns for simplification of WarpExecuteOnLane0Op during distribution.
+void populateWarpSimplificationPatterns(RewritePatternSet &pattern,
+                                        PatternBenefit benefit = 1);
+
 /// Lambda signature to compute a reduction of a distributed value for the given
 /// reduction kind and size.
 using DistributedReductionFn =
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e214257de2cdf..f0d771142e307 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1761,17 +1761,23 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
-  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
-               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
-      patterns.getContext(), benefit);
+  patterns
+      .add<WarpOpElementwise, WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
+           WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
+           WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+          patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                                benefit);
 }
 
+void mlir::vector::populateWarpSimplificationPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<WarpOpDeadResult, WarpOpForwardOperand>(patterns.getContext(),
+                                                       benefit);
+}
+
 void mlir::vector::populateDistributeReduction(
     RewritePatternSet &patterns,
     const DistributedReductionFn &distributedReductionFn,
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..feec10e6492f7 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -660,6 +660,7 @@ struct TestVectorDistribution
       vector::populatePropagateWarpVectorDistributionPatterns(
           patterns, distributionFn, shuffleFn, /*benefit=*/1,
           /*readBenefit=*/0);
+      vector::populateWarpSimplificationPatterns(patterns);
       vector::populateDistributeReduction(patterns, warpReduction, 1);
       populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
       (void)applyPatternsGreedily(getOperation(), std::move(patterns));
@@ -672,6 +673,7 @@ struct TestVectorDistribution
       RewritePatternSet patterns(ctx);
       vector::populatePropagateWarpVectorDistributionPatterns(
           patterns, distributionFn, shuffleFn);
+      vector::populateWarpSimplificationPatterns(patterns);
       vector::populateDistributeReduction(patterns, warpReduction);
       (void)applyPatternsGreedily(getOperation(), std::move(patterns));
     }

Copy link
Contributor

@kurapov-peter kurapov-peter left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you could move the pattern to the canonicalization. There's nothing specific to vector there iirc.

@charithaintc
Copy link
Contributor Author

I think you could move the pattern to the canonicalization. There's nothing specific to vector there iirc.

Sure. I will try that approach and get back to you.

@charithaintc
Copy link
Contributor Author

Hi @Groverkss, do you have any suggestions regarding this change?

@charithaintc
Copy link
Contributor Author

I think you could move the pattern to the canonicalization. There's nothing specific to vector there iirc.

I considered this approach now, but the issue is then at op level we need to rely on WarpDistributionPattern or otherwise reimplement moveRegionToNewWarpOpAndReplaceReturns. Does that sound ok to you?

I think moving these simplification patterns to a separate function is minimally invasive and gets the job done.

@charithaintc
Copy link
Contributor Author

Hi @banach-space @dcaballe, sorry to bug you on this.

Could you please comment on this change or suggest any alternatives?

@dcaballe
Copy link
Contributor

Hi there! I'm not too familiar with vector distribution. Probably @Groverkss and @kuhar can help.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants