-
Notifications
You must be signed in to change notification settings - Fork 13.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][gpu] Allow WarpOpDeadResult, WarpOpForwardOperand patterns to be used in isolation. #132860
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir Author: Charitha Saumya (charithaintc) ChangesThis PR move Full diff: https://github.com/llvm/llvm-project/pull/132860.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index dda45219b2acc..082d990cee8a4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -98,6 +98,10 @@ void populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
PatternBenefit benefit = 1, PatternBenefit readBenefit = 0);
+/// Patterns for simplification of WarpExecuteOnLane0Op during distribution.
+void populateWarpSimplificationPatterns(RewritePatternSet &pattern,
+ PatternBenefit benefit = 1);
+
/// Lambda signature to compute a reduction of a distributed value for the given
/// reduction kind and size.
using DistributedReductionFn =
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e214257de2cdf..f0d771142e307 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1761,17 +1761,23 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
- patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
- WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
- WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
- WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
- patterns.getContext(), benefit);
+ patterns
+ .add<WarpOpElementwise, WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
+ WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
+ WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+ patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
benefit);
}
+void mlir::vector::populateWarpSimplificationPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<WarpOpDeadResult, WarpOpForwardOperand>(patterns.getContext(),
+ benefit);
+}
+
void mlir::vector::populateDistributeReduction(
RewritePatternSet &patterns,
const DistributedReductionFn &distributedReductionFn,
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..feec10e6492f7 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -660,6 +660,7 @@ struct TestVectorDistribution
vector::populatePropagateWarpVectorDistributionPatterns(
patterns, distributionFn, shuffleFn, /*benefit=*/1,
/*readBenefit=*/0);
+ vector::populateWarpSimplificationPatterns(patterns);
vector::populateDistributeReduction(patterns, warpReduction, 1);
populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
@@ -672,6 +673,7 @@ struct TestVectorDistribution
RewritePatternSet patterns(ctx);
vector::populatePropagateWarpVectorDistributionPatterns(
patterns, distributionFn, shuffleFn);
+ vector::populateWarpSimplificationPatterns(patterns);
vector::populateDistributeReduction(patterns, warpReduction);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
|
@llvm/pr-subscribers-mlir-vector Author: Charitha Saumya (charithaintc) ChangesThis PR move Full diff: https://github.com/llvm/llvm-project/pull/132860.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index dda45219b2acc..082d990cee8a4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -98,6 +98,10 @@ void populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
PatternBenefit benefit = 1, PatternBenefit readBenefit = 0);
+/// Patterns for simplification of WarpExecuteOnLane0Op during distribution.
+void populateWarpSimplificationPatterns(RewritePatternSet &pattern,
+ PatternBenefit benefit = 1);
+
/// Lambda signature to compute a reduction of a distributed value for the given
/// reduction kind and size.
using DistributedReductionFn =
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e214257de2cdf..f0d771142e307 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1761,17 +1761,23 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
- patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
- WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
- WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
- WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
- patterns.getContext(), benefit);
+ patterns
+ .add<WarpOpElementwise, WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
+ WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
+ WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+ patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
benefit);
}
+void mlir::vector::populateWarpSimplificationPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<WarpOpDeadResult, WarpOpForwardOperand>(patterns.getContext(),
+ benefit);
+}
+
void mlir::vector::populateDistributeReduction(
RewritePatternSet &patterns,
const DistributedReductionFn &distributedReductionFn,
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..feec10e6492f7 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -660,6 +660,7 @@ struct TestVectorDistribution
vector::populatePropagateWarpVectorDistributionPatterns(
patterns, distributionFn, shuffleFn, /*benefit=*/1,
/*readBenefit=*/0);
+ vector::populateWarpSimplificationPatterns(patterns);
vector::populateDistributeReduction(patterns, warpReduction, 1);
populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
@@ -672,6 +673,7 @@ struct TestVectorDistribution
RewritePatternSet patterns(ctx);
vector::populatePropagateWarpVectorDistributionPatterns(
patterns, distributionFn, shuffleFn);
+ vector::populateWarpSimplificationPatterns(patterns);
vector::populateDistributeReduction(patterns, warpReduction);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you could move the pattern to the canonicalization. There's nothing specific to vector there iirc.
Sure. I will try that approach and get back to you. |
Hi @Groverkss, do you have any suggestions regarding this change? |
I considered this approach now, but the issue is then at op level we need to rely on I think moving these simplification patterns to a separate function is minimally invasive and gets the job done. |
Hi @banach-space @dcaballe, sorry to bug you on this. Could you please comment on this change or suggest any alternatives? |
Hi there! I'm not too familiar with vector distribution. Probably @Groverkss and @kuhar can help. |
This PR move
WarpOpDeadResult
,WarpOpForwardOperand
patterns intopopulateWarpSimplificationPatterns
so that they can be reused without having to reuse all other vector distribution patterns insidepopulatePropagateWarpVectorDistributionPatterns
.