[mlir][Vector] Support `vector.extract(xfer_read)` folding with dynamic indices #143269
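@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Diego Caballero (dcaballe)

Changes

This PR is part of the last step to remove `vector.extractelement` and `vector.insertelement` ops (RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops). It adds support for folding `vector.transfer_read(vector.extract) -> memref.load` with dynamic indices, which is currently supported by `vector.extractelement`.

Full diff: https://github.com/llvm/llvm-project/pull/143269.diff

2 Files Affected: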
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 7dbb7a334fe62..36197eb1caeb1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -886,12 +886,26 @@ class RewriteScalarExtractOfTransferRead
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
- assert(isa<Attribute>(pos) && "Unexpected non-constant index");
- int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, extractOp.getLoc(),
- rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
+
+ // Compute affine expression `newIndices[idx] + pos` where `pos` can be
+ // either a constant or a value.
+ OpFoldResult ofr;
+ if (auto attr = dyn_cast<Attribute>(pos)) {
+ int64_t offset = cast<IntegerAttr>(attr).getInt();
+ ofr = affine::makeComposedFoldedAffineApply(
+ rewriter, extractOp.getLoc(),
+ rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
+ } else {
+ Value dynamicOffset = cast<Value>(pos);
+ AffineExpr sym0, sym1;
+ bindSymbols(rewriter.getContext(), sym0, sym1);
+ ofr = affine::makeComposedFoldedAffineApply(
+ rewriter, extractOp.getLoc(), sym0 + sym1,
+ {newIndices[idx], dynamicOffset});
+ }
+
+ // Update the corresponding index with the folded result.
if (auto value = dyn_cast<Value>(ofr)) {
newIndices[idx] = value;
} else {
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index 52b0fdee184f6..9f10063a75092 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -148,3 +148,33 @@ func.func @subvector_extract(%m: memref<?x?xf32>, %idx: index) -> vector<16xf32>
return %1 : vector<16xf32>
}
+// -----
+
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @transfer_read_1d_extract_dynamic(
+// CHECK-SAME: %[[MEMREF:.*]]: memref<?xf32>, %[[M_IDX:.*]]: index, %[[E_IDX:.*]]: index
+// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[E_IDX]]]
+// CHECK: %[[RES:.*]] = memref.load %[[MEMREF]][%[[APPLY]]]
+func.func @transfer_read_1d_extract_dynamic(%m: memref<?xf32>, %idx: index,
+ %offset: index) -> f32 {
+ %cst = arith.constant 0.0 : f32
+ %vec = vector.transfer_read %m[%idx], %cst {in_bounds = [true]} : memref<?xf32>, vector<5xf32>
+ %elem = vector.extract %vec[%offset] : f32 from vector<5xf32>
+ return %elem : f32
+}
+
+// -----
+
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @transfer_read_2d_extract_dynamic(
+// CHECK-SAME: %[[MEMREF:.*]]: memref<?x?xf32>, %[[M_IDX:.*]]: index, %[[ROW:.*]]: index, %[[COL:.*]]: index
+// CHECK: %[[ROW_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[ROW]]]
+// CHECK: %[[COL_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[COL]]]
+// CHECK: %[[RES:.*]] = memref.load %[[MEMREF]][%[[ROW_APPLY]], %[[COL_APPLY]]]
+func.func @transfer_read_2d_extract_dynamic(%m: memref<?x?xf32>, %idx: index,
+ %row_offset: index, %col_offset: index) -> f32 {
+ %cst = arith.constant 0.0 : f32
+ %vec = vector.transfer_read %m[%idx, %idx], %cst {in_bounds = [true, true]} : memref<?x?xf32>, vector<10x5xf32>
+ %elem = vector.extract %vec[%row_offset, %col_offset] : f32 from vector<10x5xf32>
+ return %elem : f32
+}
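For readers less familiar with `OpFoldResult`, the index update in the diff can be modelled outside MLIR. The sketch below is a standalone C++17 approximation, not the pattern itself: `Index`, `addIndices`, and `foldExtractIntoRead` are hypothetical names, an SSA value is represented only by its name as a string, and the only simplification modelled is constant + constant (the real `affine::makeComposedFoldedAffineApply` does more than this).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for OpFoldResult: either a compile-time constant
// (int64_t) or the name of an SSA value (std::string).
using Index = std::variant<int64_t, std::string>;

// Model of `base + offset`. Two constants fold to a constant; anything else
// stays symbolic, which corresponds to emitting an affine.apply in the pass.
Index addIndices(const Index &base, const Index &offset) {
  if (const auto *b = std::get_if<int64_t>(&base))
    if (const auto *o = std::get_if<int64_t>(&offset))
      return *b + *o;
  auto toStr = [](const Index &v) -> std::string {
    if (const auto *c = std::get_if<int64_t>(&v))
      return std::to_string(*c);
    return std::get<std::string>(v);
  };
  return toStr(base) + " + " + toStr(offset);
}

// Extract positions line up with the trailing transfer_read indices, as in
// `idx = newIndices.size() - extractOp.getNumIndices() + i` from the diff.
std::vector<Index> foldExtractIntoRead(std::vector<Index> readIndices,
                                       const std::vector<Index> &extractPos) {
  assert(extractPos.size() <= readIndices.size() &&
         "model assumes no more extract positions than read indices");
  const std::size_t first = readIndices.size() - extractPos.size();
  for (std::size_t i = 0; i < extractPos.size(); ++i)
    readIndices[first + i] = addIndices(readIndices[first + i], extractPos[i]);
  return readIndices;
}
```

Under this model, folding a read at `%idx` with a dynamic extract position `%offset` produces the single symbolic index `%idx + %offset`, which mirrors the `s0 + s1` map checked in the new tests. A constant position against a constant read index collapses to a plain constant instead.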
LGTM. PR title typo? vector.extract(xfer_read)
// CHECK: %[[ROW_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[ROW]]]
// CHECK: %[[COL_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[COL]]]
// CHECK: %[[RES:.*]] = memref.load %[[MEMREF]][%[[ROW_APPLY]], %[[COL_APPLY]]]
func.func @transfer_read_2d_extract_dynamic(%m: memref<?x?xf32>, %idx: index,
nit: use `%row_idx` and `%col_idx` instead of using the same `%idx` for both indices, to ensure the algorithm is iterating through the indices as expected.
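To make the motivation for this nit concrete, here is a small sketch under assumed names. `foldCorrect` models the intended per-dimension addition, and `foldBuggy` is a hypothetical regression that always reads the first base index. Neither function is taken from the PR; they only illustrate why identical base indices can hide an indexing mistake.

```cpp
#include <array>
#include <cstdint>

using Idx2 = std::array<int64_t, 2>;

// Intended behavior: add each extract position to the matching read index.
Idx2 foldCorrect(Idx2 base, Idx2 pos) {
  return {base[0] + pos[0], base[1] + pos[1]};
}

// Hypothetical bug: the second dimension reuses base[0] instead of base[1].
Idx2 foldBuggy(Idx2 base, Idx2 pos) {
  return {base[0] + pos[0], base[0] + pos[1]};
}

// A test input detects the bug only if the two versions disagree on it.
bool inputDetectsBug(Idx2 base, Idx2 pos) {
  return foldCorrect(base, pos) != foldBuggy(base, pos);
}
```

With the same base index in both dimensions (as in the current 2-D test, where `%idx` is used twice), `inputDetectsBug` is false for every position, so the test would still pass with the buggy version. Any input with distinct base indices exposes the difference.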
LGTM, thanks!
@newling, I guess your "nit" is non-blocking? If so, this is ready to land (I see "nit"s as optional/nice-to-haves).
// Compute affine expression `newIndices[idx] + pos` where `pos` can be
// either a constant or a value.
OpFoldResult ofr;
[nit] Would you mind using this opportunity to replace `ofr` with something more descriptive? `ofr` is a bit like `int64_t i64;` 😅
Yes, that was definitely a non-blocking request.
Title changed from "`xfer_read(vector.extract))` folding with dynamic indices" to "`vector.extract(xfer_read)` folding with dynamic indices"
…mic indices

This PR is part of the step to remove `vector.extractelement` and `vector.insertelement` ops. It adds support for folding `vector.transfer_read(vector.extract) -> memref.load` with dynamic indices, which is currently supported by `vector.extractelement`.
ed7d5bd to f98e348
This PR is part of the last step to remove `vector.extractelement` and `vector.insertelement` ops (RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops).

It adds support for folding `vector.transfer_read(vector.extract) -> memref.load` with dynamic indices, which is currently supported by `vector.extractelement`.
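As a rough sanity check of why this fold preserves behavior, the sketch below models the 1-D, fully in-bounds case in plain C++. All names (`transferRead`, `readThenExtract`, `foldedLoad`) are hypothetical, and the model deliberately ignores padding values, masks, and permutation maps, so it should be read as an illustration under those assumptions rather than a statement about the full op semantics.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Model of an in-bounds 1-D transfer read: copy `len` elements starting at
// `idx`. Out-of-bounds reads (and hence the padding value) are not modelled.
std::vector<float> transferRead(const std::vector<float> &mem, std::size_t idx,
                                std::size_t len) {
  assert(idx + len <= mem.size() && "model only covers in-bounds reads");
  const auto first = mem.begin() + static_cast<std::ptrdiff_t>(idx);
  return std::vector<float>(first, first + static_cast<std::ptrdiff_t>(len));
}

// Unfolded form: read a whole vector, then extract one lane at a position
// that is only known at run time.
float readThenExtract(const std::vector<float> &mem, std::size_t idx,
                      std::size_t len, std::size_t offset) {
  assert(offset < len && "extract position must be inside the vector");
  return transferRead(mem, idx, len)[offset];
}

// Folded form: a single scalar load at `idx + offset`.
float foldedLoad(const std::vector<float> &mem, std::size_t idx,
                 std::size_t offset) {
  return mem.at(idx + offset);
}
```

For any `offset` smaller than the vector length, `readThenExtract(mem, idx, len, offset)` and `foldedLoad(mem, idx, offset)` return the same element, which is the equivalence the rewrite relies on when it replaces the pair of ops with one `memref.load`.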