Skip to content

[mlir][Vector] Support vector.extract(xfer_read) folding with dynamic indices #143269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

dcaballe
Copy link
Contributor

@dcaballe dcaballe commented Jun 7, 2025

This PR is part of the last step to remove vector.extractelement and vector.insertelement ops (RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops).
It adds support for folding vector.transfer_read(vector.extract) -> memref.load with dynamic indices, which is currently supported by vector.extractelement.

@llvmbot
Copy link
Member

llvmbot commented Jun 7, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Diego Caballero (dcaballe)

Changes

This PR is part of the last step to remove vector.extractelement and vector.insertelement ops. It adds support for folding vector.transfer_read(vector.extract) -> memref.load with dynamic indices, which is currently supported by vector.extractelement.


Full diff: https://github.com/llvm/llvm-project/pull/143269.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+19-5)
  • (modified) mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir (+30)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 7dbb7a334fe62..36197eb1caeb1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -886,12 +886,26 @@ class RewriteScalarExtractOfTransferRead
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                   xferOp.getIndices().end());
     for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
-      assert(isa<Attribute>(pos) && "Unexpected non-constant index");
-      int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
       int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, extractOp.getLoc(),
-          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
+
+      // Compute affine expression `newIndices[idx] + pos` where `pos` can be
+      // either a constant or a value.
+      OpFoldResult ofr;
+      if (auto attr = dyn_cast<Attribute>(pos)) {
+        int64_t offset = cast<IntegerAttr>(attr).getInt();
+        ofr = affine::makeComposedFoldedAffineApply(
+            rewriter, extractOp.getLoc(),
+            rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
+      } else {
+        Value dynamicOffset = cast<Value>(pos);
+        AffineExpr sym0, sym1;
+        bindSymbols(rewriter.getContext(), sym0, sym1);
+        ofr = affine::makeComposedFoldedAffineApply(
+            rewriter, extractOp.getLoc(), sym0 + sym1,
+            {newIndices[idx], dynamicOffset});
+      }
+
+      // Update the corresponding index with the folded result.
       if (auto value = dyn_cast<Value>(ofr)) {
         newIndices[idx] = value;
       } else {
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index 52b0fdee184f6..9f10063a75092 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -148,3 +148,33 @@ func.func @subvector_extract(%m: memref<?x?xf32>, %idx: index) -> vector<16xf32>
   return %1 : vector<16xf32>
 }
 
+// -----
+
+//       CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @transfer_read_1d_extract_dynamic(
+//  CHECK-SAME:     %[[MEMREF:.*]]: memref<?xf32>, %[[M_IDX:.*]]: index, %[[E_IDX:.*]]: index
+//       CHECK:   %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[E_IDX]]]
+//       CHECK:   %[[RES:.*]] = memref.load %[[MEMREF]][%[[APPLY]]]
+func.func @transfer_read_1d_extract_dynamic(%m: memref<?xf32>, %idx: index,
+                                            %offset: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %vec = vector.transfer_read %m[%idx], %cst {in_bounds = [true]} : memref<?xf32>, vector<5xf32>
+  %elem = vector.extract %vec[%offset] : f32 from vector<5xf32>
+  return %elem : f32
+}
+
+// -----
+
+//       CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @transfer_read_2d_extract_dynamic(
+//  CHECK-SAME:     %[[MEMREF:.*]]: memref<?x?xf32>, %[[M_IDX:.*]]: index, %[[ROW:.*]]: index, %[[COL:.*]]: index
+//       CHECK:   %[[ROW_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[ROW]]]
+//       CHECK:   %[[COL_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[COL]]]
+//       CHECK:   %[[RES:.*]] = memref.load %[[MEMREF]][%[[ROW_APPLY]], %[[COL_APPLY]]]
+func.func @transfer_read_2d_extract_dynamic(%m: memref<?x?xf32>, %idx: index,
+                                            %row_offset: index, %col_offset: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %vec = vector.transfer_read %m[%idx, %idx], %cst {in_bounds = [true, true]} : memref<?x?xf32>, vector<10x5xf32>
+  %elem = vector.extract %vec[%row_offset, %col_offset] : f32 from vector<10x5xf32>
+  return %elem : f32
+}

Copy link
Contributor

@newling newling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. PR title typo? vector.extract(xfer_read)

// CHECK: %[[ROW_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[ROW]]]
// CHECK: %[[COL_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[COL]]]
// CHECK: %[[RES:.*]] = memref.load %[[MEMREF]][%[[ROW_APPLY]], %[[COL_APPLY]]]
func.func @transfer_read_2d_extract_dynamic(%m: memref<?x?xf32>, %idx: index,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use %row_idx and %col_idx instead of using the same %idx` for both indices, to ensure algorithm is iterating through indices as expected.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@newling , I guess your "nit" is non-blocking? If yes then this is ready to land (I see "nit"s as optional/nice-to-haves)


// Compute affine expression `newIndices[idx] + pos` where `pos` can be
// either a constant or a value.
OpFoldResult ofr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Would you mind using this opportunity to replace ofr with something more descriptive? ofr is a bit like int64_t i64; 😅

@newling
Copy link
Contributor

newling commented Jun 9, 2025

LGTM, thanks!

@newling , I guess your "nit" is non-blocking? If yes then this is ready to land (I see "nit"s as optional/nice-to-haves)

Yes, was definitely a non-blocking request

@dcaballe dcaballe changed the title [mlir][Vector] Support xfer_read(vector.extract)) folding with dynamic indices [mlir][Vector] Support vector.extract(xfer_read) folding with dynamic indices Jun 16, 2025
dcaballe added 2 commits June 16, 2025 18:36
…mic indices

This PR is part of the step to remove `vector.extractelement` and
`vector.insertelement` ops. It adds support for folding
`vector.transfer_read(vector.extract) -> memref.load` with dynamic
indices, which is currently supported by `vector.extractelement`.
@dcaballe dcaballe force-pushed the remove-vector-extractelement-insertelement-1 branch from ed7d5bd to f98e348 Compare June 16, 2025 18:59
@dcaballe dcaballe merged commit a00b736 into llvm:main Jun 16, 2025
5 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants