Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Propagate vector.extract through elementwise ops #131462

Merged
merged 11 commits into from
Mar 25, 2025

Conversation

Hardcode84
Copy link
Contributor

@Hardcode84 Hardcode84 commented Mar 15, 2025

Propagate Extract(Elementwise(...)) -> Elemetwise(Extract...).

Currenly limited to the case when extract is the single use of elementwise to avoid introducing additional elementwise ops.

@llvmbot
Copy link
Member

llvmbot commented Mar 16, 2025

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Propagate Extract(Elementwise(...)) -> Elemetwise(Extract...).

Currenly limited to the case when extract is the single use of elementwise to avoid introducing additional elementwise ops.


Full diff: https://github.com/llvm/llvm-project/pull/131462.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td (+11)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+3)
  • (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp (+66)
  • (added) mlir/test/Dialect/Vector/propagate-extracts.mlir (+47)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index c973eca0132a9..7be39519c1037 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -453,4 +453,15 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyVectorPropagateExtractPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.propagate_extract",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect a set of patterns for propagating `vector.extract` through the
+    vector ops.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7de4a6a315750..16c66e078821d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -409,6 +409,9 @@ void populateVectorLinearizeShuffleLikeOpsPatterns(
     const TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, unsigned targetBitWidth);
 
+/// Populates patterns for propagating `vector.extract` through the vector ops.
+void populateVectorPropagateExtractsPatterns(RewritePatternSet &patterns);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 20c577273d786..616e563fcdc77 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -204,6 +204,11 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
 }
 
+void transform::ApplyVectorPropagateExtractPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateVectorPropagateExtractsPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..8830375f88104 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   VectorTransforms.cpp
   VectorUnroll.cpp
   VectorMaskElimination.cpp
+  VectorPropagateExtract.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp
new file mode 100644
index 0000000000000..10f578179bc94
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp
@@ -0,0 +1,66 @@
+//===- VectorPropagateExtract.cpp - vector.extract propagation - ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns for vector.extract propagation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
+class ExtractOpFromElementwise final
+    : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *eltwise = op.getVector().getDefiningOp();
+
+    // Elementwise op with single result and `extract` is single user.
+    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
+        eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
+      return failure();
+
+    // Arguments and result types must match.
+    if (!llvm::all_equal(llvm::concat<Type>(eltwise->getOperandTypes(),
+                                            eltwise->getResultTypes())))
+      return failure();
+
+    Type dstType = op.getType();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(eltwise);
+
+    IRMapping mapping;
+    Location loc = eltwise->getLoc();
+    for (auto &&[i, arg] : llvm::enumerate(eltwise->getOperands())) {
+      Value newArg =
+          rewriter.create<vector::ExtractOp>(loc, arg, op.getMixedPosition());
+      mapping.map(arg, newArg);
+    }
+
+    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
+    newEltwise->getResult(0).setType(dstType);
+
+    rewriter.replaceOp(op, newEltwise);
+    rewriter.eraseOp(eltwise);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorPropagateExtractsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ExtractOpFromElementwise>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Vector/propagate-extracts.mlir b/mlir/test/Dialect/Vector/propagate-extracts.mlir
new file mode 100644
index 0000000000000..6c6f812c8f6d2
--- /dev/null
+++ b/mlir/test/Dialect/Vector/propagate-extracts.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @extract_elementwise
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
+// CHECK:   return %[[RES]] : f32
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_vec_elementwise
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
+func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extract_elementwise_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
+// Do not propagate extract, as elementwise has other uses.
+// CHECK:   %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.propagate_extract
+    } : !transform.any_op
+    transform.yield
+  }
+}

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thank you for working on this!

[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collect a set of patterns for propagating `vector.extract` through the
vector ops.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we classify Arith ops as vector ops? Also, are you planning to add more patterns like this? Otherwise this TD Op is over-promising a bit.

On a related note, we already have similar patterns for elementwise here:

/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
/// ```
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
/// %r = arith.addi %a, %b : vector<1x4xindex>
/// ```
/// Gets converted to:
/// ```
/// %r = arith.addi %arg0, %arg1 : index
/// %b = vector.broadcast %r : index to vector<1x4xindex>
/// ```
///
/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
/// ops.
struct ReorderElementwiseOpsOnBroadcast final
: public OpTraitRewritePattern<OpTrait::Elementwise> {
using OpTraitRewritePattern::OpTraitRewritePattern;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (op->getNumResults() != 1)
return failure();
if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
return failure();
if (!OpTrait::hasElementwiseMappableTraits(op))
return rewriter.notifyMatchFailure(
op, "Op doesn't have ElementwiseMappableTraits");
if (op->getNumOperands() == 0)
return failure();
if (op->getResults()[0].getType() != op->getOperand(0).getType())
return rewriter.notifyMatchFailure(op,
"result and operand type mismatch");
if (isa<vector::FMAOp>(op)) {
return rewriter.notifyMatchFailure(
op,
"Op only accepts vector types - not supported as broadcast source "
"might be a scalar");
}
// Get the type of the lhs operand
auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
if (!lhsBcastOrSplat ||
!isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
return failure();
auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
// Make sure that all operands are broadcast from identical types:
// * scalar (`vector.broadcast` + `vector.splat`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
auto bcast = val.getDefiningOp<vector::BroadcastOp>();
if (bcast)
return (bcast.getOperand().getType() == lhsBcastOrSplatType);
auto splat = val.getDefiningOp<vector::SplatOp>();
if (splat)
return (splat.getOperand().getType() == lhsBcastOrSplatType);
return false;
})) {
return failure();
}
// Collect the source values before broadcasting
SmallVector<Value> srcValues;
srcValues.reserve(op->getNumOperands());
for (Value operand : op->getOperands()) {
srcValues.push_back(operand.getDefiningOp()->getOperand(0));
}
// Create the "elementwise" Op
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
lhsBcastOrSplatType, op->getAttrs());
// Replace the original Op with the elementwise Op
auto vectorType = op->getResultTypes()[0];
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
op, vectorType, elementwiseOp->getResults());
return success();
}
};

We should move this pattern there unless there's expectation for this to grow?

Copy link
Contributor Author

@Hardcode84 Hardcode84 Mar 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most arith/math ops already have elementwise and vectorizable traits (which is checked by hasElementwiseMappableTraits) so this pattern will cover them. I missed populateSinkVectorOpsPatterns, so I will move pattern there (and add transform, exposing them). Those should cover most of my cases, the other pattern I may add later is vector.extract(vector.load(data)) -> memref.load(data).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to populateSinkVectorOpsPatterns

@Hardcode84
Copy link
Contributor Author

Hardcode84 commented Mar 18, 2025

@banach-space I've updated the PR, some changes:

  • Moved to populateSinkVectorOpsPatterns, populateSinkVectorOpsPatterns is used in linalg vectorizer so there are changes in related tests, but I think they are beneficial.
  • Exposed populateSinkVectorOpsPatterns through the transform interface and added a smoke test
  • Relaxed type comparison between args and result, to allow conversion ops like arith.index_cast

@banach-space
Copy link
Contributor

  • Exposed populateSinkVectorOpsPatterns through the transform interface and added a smoke test

Cool. With a dedicated TD op, we should split

void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorReductionToContractPatterns(patterns);
vector::populateSinkVectorOpsPatterns(patterns);
}

Could you add a TODO for me? Thanks!

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates and sorry for delay, Ivan!

This is already nicely polished. I've left some minor comments to improve consistency + re-use, nothing major 🙏🏻

Thanks for contributing this!

Comment on lines +426 to +429

// -----

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a block comment documenting which pattern is being tested.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines 66 to 74
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C79:.*]] = arith.constant 79 : index
// CHECK: %[[VAL6:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
// CHECK: %[[VAL7:.*]] = arith.addi %[[VAL_3]], %[[VAL_4]] : index

// CHECK: %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL6]], %[[C79]], %[[VAL7]]], %[[CST]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
// CHECK: %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_21]] : tensor<1x4xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you have already started the trend, would you mind replacing VAL_N with meaningful LIT variables names? Thanks!

Comment on lines 1047 to 1049
/// This may result in more efficient code when we extracting a single value
/// from multi-element vector and also to help canonicalize 1-element vectors to
/// scalars.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering about this ... The tests for the vectorizer became "cleaner" (nice!), but if a value is already in a vector register, and then it's extracted ... I would actually expect the original code to be faster. However, the overall performance will depend on the surrounding code 🤔

I am mostly waving my hands here, but perhaps refrain from saying that this "may result in more efficient code" and instead say "replaces elementwise computation on N vectors with an elementwise on M elements (M < N) and an additional vector.extract". Basically, something a bit more neutral.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed to "cleaner" )

Comment on lines 1070 to 1072
if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
return rewriter.notifyMatchFailure(op, "not a suitable op");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Any objections to splitting this into multiple conditions and adding more descriptive failure msg?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of adding a new file, could you try adding a new RUN line here: https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Vector/vector-sink.mlir. You will probably have to put the TD sequence into a separate file. Here's an example:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


// CHECK-LABEL: @extract_elementwise_no_single_use
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
func.func @extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add some "prefix" to mark negative tests? https://mlir.llvm.org/getting_started/TestingGuide/#step-3-add-the-newly-identified-missing-case

Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Propagate `Extract(Elementwise(...))` -> `Elemetwise(Extract...)`.

Currenly limited to the case when extract is the single use of elementwise to avoid introducing additional computations.
@Hardcode84 Hardcode84 force-pushed the vec-extract-propagate branch from a52b8f3 to 64ea88d Compare March 22, 2025 23:14
@Hardcode84
Copy link
Contributor Author

@banach-space addressed review comments

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thank you for addressing my comments!

LGTM

@Hardcode84
Copy link
Contributor Author

@banach-space I found another issue while testing it: vector.fma has the ElementwiseMappable trait, which includes

// Elementwise op can be applied to scalars instead tensor/vector operands.
def Scalarizable : NativeOpTrait<"Scalarizable", [Elementwise]>;

But doesn't actually support scalar inputs.

@Hardcode84
Copy link
Contributor Author

@banach-space I found another issue while testing it: vector.fma has the ElementwiseMappable trait, which includes

// Elementwise op can be applied to scalars instead tensor/vector operands.
def Scalarizable : NativeOpTrait<"Scalarizable", [Elementwise]>;

But doesn't actually support scalar inputs.

#132611

@Hardcode84 Hardcode84 merged commit 9b02222 into llvm:main Mar 25, 2025
11 checks passed
@Hardcode84 Hardcode84 deleted the vec-extract-propagate branch March 25, 2025 11:07
basioli-k added a commit that referenced this pull request Mar 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants