Skip to content

Commit 4d339ec

Browse files
committed
[mlir][Vector] Add pattern to reorder elementwise and broadcast ops
The new pattern will replace elementwise(broadcast) with broadcast(elementwise) when safe. This change affects tests for vectorising nD-extract. In one case ("vectorize_nd_tensor_extract_with_tensor_extract") I just trimmed the test and only preserved the key parts (scalar and contiguous load from the original Op). We could do the same with some other tests if that helps maintainability. Differential Revision: https://reviews.llvm.org/D152812
1 parent e9d77cd commit 4d339ec

File tree

6 files changed

+200
-48
lines changed

6 files changed

+200
-48
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,10 @@ void populateVectorTransferFullPartialPatterns(
137137
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
138138
RewritePatternSet &patterns, PatternBenefit benefit = 1);
139139

140+
/// Patterns that remove redundant vector broadcasts.
141+
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
142+
PatternBenefit benefit = 1);
143+
140144
/// Populate `patterns` with the following patterns.
141145
///
142146
/// [DecomposeDifferentRankInsertStridedSlice]

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3066,6 +3066,8 @@ transform::VectorizeOp::applyToOne(Operation *target,
30663066
if (!getDisableMultiReductionToContractPatterns())
30673067
vector::populateVectorReductionToContractPatterns(patterns);
30683068

3069+
vector::populateSinkVectorBroadcastPatterns(patterns);
3070+
30693071
patterns.add<linalg::LinalgCopyVTRForwardingPattern,
30703072
linalg::LinalgCopyVTWForwardingPattern>(ctx,
30713073
/*benefit=*/2);

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,66 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
885885
std::function<bool(BitCastOp)> controlFn;
886886
};
887887

888+
/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
889+
/// ```
890+
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
891+
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
892+
/// %r = arith.addi %a, %b : vector<1x4xindex>
893+
/// ```
894+
/// Gets converted to:
895+
/// ```
896+
/// %r = arith.addi %arg1, %arg2 : index
897+
/// %b = vector.broadcast %r : index to vector<1x4xindex>
898+
/// ```
899+
struct ReorderElementwiseOpsOnBroadcast final
900+
: public OpTraitRewritePattern<OpTrait::Elementwise> {
901+
using OpTraitRewritePattern::OpTraitRewritePattern;
902+
LogicalResult matchAndRewrite(Operation *op,
903+
PatternRewriter &rewriter) const override {
904+
if (op->getNumResults() != 1)
905+
return failure();
906+
if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
907+
return failure();
908+
if (!OpTrait::hasElementwiseMappableTraits(op))
909+
return failure();
910+
911+
// Get the type of the first operand
912+
auto firstBcast = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
913+
if (!firstBcast)
914+
return failure();
915+
auto firstOpType = firstBcast.getOperand().getType();
916+
917+
// Make sure that operands are "broadcast"ed from identical (scalar or
918+
// vector) types. That indicates that it's safe to skip the broadcasting of
919+
// operands.
920+
if (!llvm::all_of(op->getOperands(), [&firstOpType](Value val) {
921+
auto bcast = val.getDefiningOp<vector::BroadcastOp>();
922+
return (bcast && (bcast.getOperand().getType() == firstOpType));
923+
})) {
924+
return failure();
925+
}
926+
927+
// Collect the source values
928+
SmallVector<Value> srcValues;
929+
srcValues.reserve(op->getNumOperands());
930+
931+
for (Value operand : op->getOperands()) {
932+
srcValues.push_back(
933+
operand.getDefiningOp<vector::BroadcastOp>().getOperand());
934+
}
935+
936+
Operation *elementwiseOp =
937+
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
938+
firstOpType, op->getAttrs());
939+
940+
auto vectorType = op->getResultTypes()[0];
941+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
942+
op, vectorType, elementwiseOp->getResults());
943+
944+
return success();
945+
}
946+
};
947+
888948
// Helper that returns a vector comparison that constructs a mask:
889949
// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
890950
//
@@ -1311,6 +1371,12 @@ void mlir::vector::
13111371
patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
13121372
}
13131373

1374+
void mlir::vector::populateSinkVectorBroadcastPatterns(
1375+
RewritePatternSet &patterns, PatternBenefit benefit) {
1376+
patterns.add<ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
1377+
benefit);
1378+
}
1379+
13141380
//===----------------------------------------------------------------------===//
13151381
// TableGen'd enum attribute definitions
13161382
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 23 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -130,27 +130,29 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16
130130
return %25 : tensor<1x4xf32>
131131
}
132132

133-
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex
133+
134+
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex(
134135
// CHECK-SAME: %[[VAL_0:.*]]: tensor<45x80x16xf32>,
135-
// CHECK-SAME: {{.*}}: index,
136+
// CHECK-SAME: %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
136137
// CHECK-SAME: %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
137138
// CHECK: %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
138139
// CHECK: %[[VAL_7:.*]] = arith.constant 0 : i32
139140
// CHECK: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
140141
// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
141142
// CHECK: %[[VAL_10:.*]] = arith.constant 79 : index
142-
// CHECK: %[[VAL_11:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
143-
// CHECK: %[[VAL_12:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
144-
// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : vector<1x4xindex>
145-
// CHECK: %[[VAL_14:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
146-
// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_6]] : vector<4xindex>
147-
// CHECK: %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
148-
// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<4xindex>
149-
// CHECK: %[[VAL_18:.*]] = vector.shape_cast %[[VAL_13]] : vector<1x4xindex> to vector<4xindex>
150-
// CHECK: %[[VAL_19:.*]] = vector.extractelement %[[VAL_18]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
151-
// CHECK: %[[VAL_20:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
152-
// CHECK: %[[VAL_21:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_19]], %[[VAL_10]], %[[VAL_20]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
153-
// CHECK: %[[VAL_22:.*]] = vector.transfer_write %[[VAL_21]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
143+
// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
144+
// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : index to vector<1x4xindex>
145+
// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex>
146+
// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex>
147+
// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex>
148+
// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : vector<4xindex>
149+
// CHECK: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_12]] : vector<1x4xindex> to vector<4xindex>
150+
// CHECK: %[[VAL_18:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
151+
// CHECK: %[[VAL_19:.*]] = vector.extractelement %[[VAL_16]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
152+
// CHECK: %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_18]], %[[VAL_10]], %[[VAL_19]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
153+
// CHECK: %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
154+
// CHECK: return %[[VAL_21]] : tensor<1x4xf32>
155+
// CHECK: }
154156

155157
transform.sequence failures(propagate) {
156158
^bb1(%arg1: !transform.any_op):
@@ -317,43 +319,16 @@ func.func @vectorize_nd_tensor_extract_with_tensor_extract(%input_1: tensor<1x20
317319
}
318320

319321
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_tensor_extract(
320-
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x20xi32>,
321-
// CHECK-SAME: %[[VAL_1:.*]]: tensor<257x24xf32>,
322-
// CHECK-SAME: %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index) -> tensor<1x1x4xf32> {
323-
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<0> : vector<1x1x4xindex>
324-
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
325-
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : i32
326-
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant dense<256> : vector<1x1x4xindex>
327-
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
328-
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0 : index
329-
// CHECK: %[[VAL_12:.*]] = tensor.empty() : tensor<1x1x4xf32>
330-
// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_2]] : index to vector<1x1x4xindex>
331-
// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_4]] : index to vector<1x1x4xindex>
332-
// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : vector<1x1x4xindex>
333-
// CHECK: %[[VAL_16:.*]] = vector.broadcast %[[VAL_3]] : index to vector<1x1x4xindex>
334-
// CHECK: %[[VAL_17:.*]] = vector.broadcast %[[VAL_7]] : vector<4xindex> to vector<1x1x4xindex>
335-
// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : vector<1x1x4xindex>
336-
// CHECK: %[[VAL_19:.*]] = vector.broadcast %[[VAL_5]] : index to vector<1x1x4xindex>
337-
// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_19]] : vector<1x1x4xindex>
338-
// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
339-
// CHECK: %[[VAL_22:.*]] = vector.extractelement %[[VAL_21]][%[[VAL_8]] : i32] : vector<4xindex>
322+
// CHECK-SAME: %[[INPUT_1:.*]]: tensor<1x20xi32>,
323+
// CHECK-SAME: %[[INPUT_2:.*]]: tensor<257x24xf32>,
324+
// CHECK: %[[EXTRACTED_0_IDX_0:.*]] = arith.constant 0 : index
325+
// CHECK: %[[EXTRACTED_0_IDX_1:.*]] = vector.extractelement %{{.*}}[%{{.*}} : i32] : vector<4xindex>
340326
// First `tensor.extract` from the generic Op - loop invariant scalar load.
341-
// CHECK: %[[VAL_23:.*]] = tensor.extract %[[VAL_0]][%[[VAL_11]], %[[VAL_22]]] : tensor<1x20xi32>
342-
// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : i32 to index
343-
// CHECK: %[[VAL_25:.*]] = vector.broadcast %[[VAL_24]] : index to vector<1x1x4xindex>
344-
// CHECK: %[[VAL_26:.*]] = arith.maxsi %[[VAL_25]], %[[VAL_6]] : vector<1x1x4xindex>
345-
// CHECK: %[[VAL_27:.*]] = arith.minsi %[[VAL_26]], %[[VAL_9]] : vector<1x1x4xindex>
346-
// CHECK: %[[VAL_28:.*]] = vector.shape_cast %[[VAL_27]] : vector<1x1x4xindex> to vector<4xindex>
347-
// CHECK: %[[VAL_29:.*]] = vector.extractelement %[[VAL_28]][%[[VAL_8]] : i32] : vector<4xindex>
348-
// CHECK: %[[VAL_30:.*]] = vector.shape_cast %[[VAL_20]] : vector<1x1x4xindex> to vector<4xindex>
349-
// CHECK: %[[VAL_31:.*]] = vector.extractelement %[[VAL_30]][%[[VAL_8]] : i32] : vector<4xindex>
327+
// CHECK: tensor.extract %[[INPUT_1]][%[[EXTRACTED_0_IDX_0]], %[[EXTRACTED_0_IDX_1]]] : tensor<1x20xi32>
350328
// The following `tensor.extract` from the generic Op is a contiguous load (all Ops used
351329
// for address calculation also satisfy the required conditions).
352-
// CHECK: %[[VAL_32:.*]] = vector.transfer_read %[[VAL_1]][%[[VAL_29]], %[[VAL_31]]], %[[VAL_10]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
353-
// CHECK: %[[VAL_33:.*]] = vector.broadcast %[[VAL_32]] : vector<1x4xf32> to vector<1x1x4xf32>
354-
// CHECK: %[[VAL_34:.*]] = vector.transfer_write %[[VAL_33]], %[[VAL_12]][%[[VAL_11]], %[[VAL_11]], %[[VAL_11]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32>
355-
// CHECK: return %[[VAL_34]] : tensor<1x1x4xf32>
356-
// CHECK: }
330+
// CHECK: vector.transfer_read %[[INPUT_2]][%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
331+
357332

358333
transform.sequence failures(propagate) {
359334
^bb1(%arg1: !transform.any_op):
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @broadcast_scalar(
4+
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
5+
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
6+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
7+
// CHECK: return %[[BCAST]] : vector<1x4xindex>
8+
// CHECK: }
9+
10+
func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> {
11+
%0 = vector.broadcast %arg1 : index to vector<1x4xindex>
12+
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
13+
%2 = arith.addi %0, %1 : vector<1x4xindex>
14+
return %2 : vector<1x4xindex>
15+
}
16+
17+
// -----
18+
19+
// CHECK-LABEL: func.func @broadcast_vector(
20+
// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>,
21+
// CHECK-SAME: %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
22+
// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
23+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
24+
// CHECK: return %[[BCAST]] : vector<3x4xf32>
25+
// CHECK: }
26+
27+
func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
28+
%arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
29+
%arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32>
30+
%2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
31+
return %2 : vector<3x4xf32>
32+
}
33+
// -----
34+
35+
// CHECK-LABEL: func.func @broadcast_vector_and_scalar(
36+
// CHECK-SAME: %[[ARG_0:.*]]: i32,
37+
// CHECK-SAME: %[[ARG_1:.*]]: vector<4xi32>) -> vector<4xi32> {
38+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32>
39+
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
40+
// CHECK: return %[[ADD]] : vector<4xi32>
41+
// CHECK: }
42+
43+
func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
44+
%arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>
45+
%2 = arith.addi %arg1_bcast, %arg2 : vector<4xi32>
46+
return %2 : vector<4xi32>
47+
}
48+
49+
// -----
50+
51+
#matmat_accesses = [
52+
affine_map<(i, j, k) -> (i, k)>,
53+
affine_map<(i, j, k) -> (k, j)>,
54+
affine_map<(i, j, k) -> (i, j)>
55+
]
56+
#matmat_trait = {
57+
indexing_maps = #matmat_accesses,
58+
iterator_types = ["parallel", "parallel", "reduction"]
59+
}
60+
61+
// CHECK-LABEL: func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
62+
// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32>
63+
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32>
64+
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32>
65+
// CHECK: %[[VAL_3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
66+
func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
67+
%f1 = arith.constant 1.0: f32
68+
%f2 = arith.constant 2.0: f32
69+
%f3 = arith.constant 3.0: f32
70+
71+
%A = vector.broadcast %f1 : f32 to vector<2x2xf32>
72+
%B = vector.broadcast %f2 : f32 to vector<2x2xf32>
73+
%C = vector.broadcast %f3 : f32 to vector<2x2xf32>
74+
%mm1 = vector.contract #matmat_trait %A, %B, %C
75+
: vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
76+
77+
return %mm1 : vector<2x2xf32>
78+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,31 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
374374
}
375375
};
376376

377+
struct TestSinkVectorBroadcast
378+
: public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> {
379+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast)
380+
381+
TestSinkVectorBroadcast() = default;
382+
TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default;
383+
384+
void getDependentDialects(DialectRegistry &registry) const override {
385+
registry.insert<memref::MemRefDialect, affine::AffineDialect>();
386+
}
387+
388+
StringRef getArgument() const final { return "test-sink-vector-broadcast"; }
389+
390+
StringRef getDescription() const final {
391+
return "Test lowering patterns that eliminate redundant broadcast "
392+
"operations.";
393+
}
394+
395+
void runOnOperation() override {
396+
RewritePatternSet patterns(&getContext());
397+
populateSinkVectorBroadcastPatterns(patterns);
398+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
399+
}
400+
};
401+
377402
struct TestVectorReduceToContractPatternsPatterns
378403
: public PassWrapper<TestVectorReduceToContractPatternsPatterns,
379404
OperationPass<func::FuncOp>> {
@@ -735,6 +760,8 @@ void registerTestVectorLowerings() {
735760

736761
PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
737762

763+
PassRegistration<TestSinkVectorBroadcast>();
764+
738765
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
739766

740767
PassRegistration<TestFlattenVectorTransferPatterns>();

0 commit comments

Comments
 (0)