Skip to content

Commit 665299e

Browse files
[mlir][Transforms] Add a utility method to move value definitions. (llvm#130874)
llvm@205c532 added a transform utility that moved all SSA dependences of an operation before an insertion point. Similar to that, this PR adds a transform utility function, `moveValueDefinitions` to move the slice of operations that define all values in a `ValueRange` before the insertion point. While very similar to `moveOperationDependencies`, this method differs in a few ways 1. When computing the backward slice since the start of the slice is value, the slice computed needs to be inclusive. 2. The combined backward slice needs to be sorted topologically before moving them to avoid SSA use-def violations while moving individual ops. The PR also adds a new transform op to test this new utility function. --------- Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
1 parent e11ede5 commit 665299e

File tree

5 files changed

+344
-1
lines changed

5 files changed

+344
-1
lines changed

mlir/include/mlir/Transforms/RegionUtils.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "mlir/IR/Region.h"
1313
#include "mlir/IR/Value.h"
14+
#include "mlir/IR/ValueRange.h"
1415

1516
#include "llvm/ADT/SetVector.h"
1617

@@ -80,6 +81,16 @@ LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
8081
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
8182
Operation *insertionPoint);
8283

84+
/// Move definitions of `values` before an insertion point. Current support is
85+
/// only for movement of definitions within the same basic block. Note that this
86+
/// is an all-or-nothing approach. Either definitions of all values are moved
87+
/// before insertion point, or none of them are.
88+
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
89+
Operation *insertionPoint,
90+
DominanceInfo &dominance);
91+
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
92+
Operation *insertionPoint);
93+
8394
/// Run a set of structural simplifications over the given regions. This
8495
/// includes transformations like unreachable block elimination, dead argument
8596
/// elimination, as well as some other DCE. This function returns success if any

mlir/lib/Transforms/Utils/RegionUtils.cpp

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1070,7 +1070,7 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
10701070
// in different basic blocks.
10711071
if (op->getBlock() != insertionPoint->getBlock()) {
10721072
return rewriter.notifyMatchFailure(
1073-
op, "unsupported caes where operation and insertion point are not in "
1073+
op, "unsupported case where operation and insertion point are not in "
10741074
"the same basic block");
10751075
}
10761076
// If `insertionPoint` does not dominate `op`, do nothing
@@ -1115,3 +1115,70 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
11151115
DominanceInfo dominance(op);
11161116
return moveOperationDependencies(rewriter, op, insertionPoint, dominance);
11171117
}
1118+
1119+
LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
1120+
ValueRange values,
1121+
Operation *insertionPoint,
1122+
DominanceInfo &dominance) {
1123+
// Remove the values that already dominate the insertion point.
1124+
SmallVector<Value> prunedValues;
1125+
for (auto value : values) {
1126+
if (dominance.properlyDominates(value, insertionPoint)) {
1127+
continue;
1128+
}
1129+
// Block arguments are not supported.
1130+
if (isa<BlockArgument>(value)) {
1131+
return rewriter.notifyMatchFailure(
1132+
insertionPoint,
1133+
"unsupported case of moving block argument before insertion point");
1134+
}
1135+
// Check for currently unsupported case if the insertion point is in a
1136+
// different block.
1137+
if (value.getDefiningOp()->getBlock() != insertionPoint->getBlock()) {
1138+
return rewriter.notifyMatchFailure(
1139+
insertionPoint,
1140+
"unsupported case of moving definition of value before an insertion "
1141+
"point in a different basic block");
1142+
}
1143+
prunedValues.push_back(value);
1144+
}
1145+
1146+
// Find the backward slice of operation for each `Value` the operation
1147+
// depends on. Prune the slice to only include operations not already
1148+
// dominated by the `insertionPoint`
1149+
BackwardSliceOptions options;
1150+
options.inclusive = true;
1151+
options.omitUsesFromAbove = false;
1152+
// Since current support is to only move within a same basic block,
1153+
// the slices dont need to look past block arguments.
1154+
options.omitBlockArguments = true;
1155+
options.filter = [&](Operation *sliceBoundaryOp) {
1156+
return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
1157+
};
1158+
llvm::SetVector<Operation *> slice;
1159+
for (auto value : prunedValues) {
1160+
getBackwardSlice(value, &slice, options);
1161+
}
1162+
1163+
// If the slice contains `insertionPoint` cannot move the dependencies.
1164+
if (slice.contains(insertionPoint)) {
1165+
return rewriter.notifyMatchFailure(
1166+
insertionPoint,
1167+
"cannot move dependencies before operation in backward slice of op");
1168+
}
1169+
1170+
// Sort operations topologically before moving.
1171+
mlir::topologicalSort(slice);
1172+
1173+
for (Operation *op : slice) {
1174+
rewriter.moveOpBefore(op, insertionPoint);
1175+
}
1176+
return success();
1177+
}
1178+
1179+
LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
1180+
ValueRange values,
1181+
Operation *insertionPoint) {
1182+
DominanceInfo dominance(insertionPoint);
1183+
return moveValueDefinitions(rewriter, values, insertionPoint, dominance);
1184+
}

mlir/test/Transforms/move-operation-deps.mlir

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,229 @@ module attributes {transform.with_named_sequence} {
234234
transform.yield
235235
}
236236
}
237+
238+
// -----
239+
240+
// Check simple move value definitions before insertion operation.
241+
func.func @simple_move_values() -> f32 {
242+
%0 = "before"() : () -> (f32)
243+
%1 = "moved_op_1"() : () -> (f32)
244+
%2 = "moved_op_2"() : () -> (f32)
245+
%3 = "foo"(%1, %2) : (f32, f32) -> (f32)
246+
return %3 : f32
247+
}
248+
// CHECK-LABEL: func @simple_move_values()
249+
// CHECK: %[[MOVED1:.+]] = "moved_op_1"
250+
// CHECK: %[[MOVED2:.+]] = "moved_op_2"
251+
// CHECK: %[[BEFORE:.+]] = "before"
252+
// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED1]], %[[MOVED2]])
253+
// CHECK: return %[[FOO]]
254+
255+
module attributes {transform.with_named_sequence} {
256+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
257+
%op1 = transform.structured.match ops{["moved_op_1"]} in %arg0
258+
: (!transform.any_op) -> !transform.any_op
259+
%op2 = transform.structured.match ops{["moved_op_2"]} in %arg0
260+
: (!transform.any_op) -> !transform.any_op
261+
%op3 = transform.structured.match ops{["before"]} in %arg0
262+
: (!transform.any_op) -> !transform.any_op
263+
%v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
264+
%v2 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
265+
transform.test.move_value_defns %v1, %v2 before %op3
266+
: (!transform.any_value, !transform.any_value), !transform.any_op
267+
transform.yield
268+
}
269+
}
270+
271+
// -----
272+
273+
// Compute slice including the implicitly captured values.
274+
func.func @move_region_dependencies_values() -> f32 {
275+
%0 = "before"() : () -> (f32)
276+
%1 = "moved_op_1"() : () -> (f32)
277+
%2 = "moved_op_2"() ({
278+
%3 = "inner_op"(%1) : (f32) -> (f32)
279+
"yield"(%3) : (f32) -> ()
280+
}) : () -> (f32)
281+
return %2 : f32
282+
}
283+
// CHECK-LABEL: func @move_region_dependencies_values()
284+
// CHECK: %[[MOVED1:.+]] = "moved_op_1"
285+
// CHECK: %[[MOVED2:.+]] = "moved_op_2"
286+
// CHECK: %[[BEFORE:.+]] = "before"
287+
288+
module attributes {transform.with_named_sequence} {
289+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
290+
%op1 = transform.structured.match ops{["moved_op_2"]} in %arg0
291+
: (!transform.any_op) -> !transform.any_op
292+
%op2 = transform.structured.match ops{["before"]} in %arg0
293+
: (!transform.any_op) -> !transform.any_op
294+
%v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
295+
transform.test.move_value_defns %v1 before %op2
296+
: (!transform.any_value), !transform.any_op
297+
transform.yield
298+
}
299+
}
300+
301+
// -----
302+
303+
// Move operations in toplogical sort order
304+
func.func @move_values_in_topological_sort_order() -> f32 {
305+
%0 = "before"() : () -> (f32)
306+
%1 = "moved_op_1"() : () -> (f32)
307+
%2 = "moved_op_2"() : () -> (f32)
308+
%3 = "moved_op_3"(%1) : (f32) -> (f32)
309+
%4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32)
310+
%5 = "moved_op_5"(%2) : (f32) -> (f32)
311+
%6 = "foo"(%4, %5) : (f32, f32) -> (f32)
312+
return %6 : f32
313+
}
314+
// CHECK-LABEL: func @move_values_in_topological_sort_order()
315+
// CHECK: %[[MOVED_1:.+]] = "moved_op_1"
316+
// CHECK-DAG: %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]])
317+
// CHECK-DAG: %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]])
318+
// CHECK-DAG: %[[MOVED_4:.+]] = "moved_op_2"
319+
// CHECK-DAG: %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]])
320+
// CHECK: %[[BEFORE:.+]] = "before"
321+
// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]])
322+
// CHECK: return %[[FOO]]
323+
324+
module attributes {transform.with_named_sequence} {
325+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
326+
%op1 = transform.structured.match ops{["moved_op_4"]} in %arg0
327+
: (!transform.any_op) -> !transform.any_op
328+
%op2 = transform.structured.match ops{["moved_op_5"]} in %arg0
329+
: (!transform.any_op) -> !transform.any_op
330+
%op3 = transform.structured.match ops{["before"]} in %arg0
331+
: (!transform.any_op) -> !transform.any_op
332+
%v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
333+
%v2 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
334+
transform.test.move_value_defns %v1, %v2 before %op3
335+
: (!transform.any_value, !transform.any_value), !transform.any_op
336+
transform.yield
337+
}
338+
}
339+
340+
// -----
341+
342+
// Move only those value definitions that are not dominated by insertion point
343+
344+
func.func @move_only_required_defns() -> (f32, f32, f32, f32) {
345+
%0 = "unmoved_op"() : () -> (f32)
346+
%1 = "dummy_op"() : () -> (f32)
347+
%2 = "before"() : () -> (f32)
348+
%3 = "moved_op"() : () -> (f32)
349+
return %0, %1, %2, %3 : f32, f32, f32, f32
350+
}
351+
// CHECK-LABEL: func @move_only_required_defns()
352+
// CHECK: %[[UNMOVED:.+]] = "unmoved_op"
353+
// CHECK: %[[DUMMY:.+]] = "dummy_op"
354+
// CHECK: %[[MOVED:.+]] = "moved_op"
355+
// CHECK: %[[BEFORE:.+]] = "before"
356+
357+
module attributes {transform.with_named_sequence} {
358+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
359+
%op1 = transform.structured.match ops{["unmoved_op"]} in %arg0
360+
: (!transform.any_op) -> !transform.any_op
361+
%op2 = transform.structured.match ops{["dummy_op"]} in %arg0
362+
: (!transform.any_op) -> !transform.any_op
363+
%op3 = transform.structured.match ops{["before"]} in %arg0
364+
: (!transform.any_op) -> !transform.any_op
365+
%op4 = transform.structured.match ops{["moved_op"]} in %arg0
366+
: (!transform.any_op) -> !transform.any_op
367+
%v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
368+
%v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value
369+
transform.test.move_value_defns %v1, %v2 before %op3
370+
: (!transform.any_value, !transform.any_value), !transform.any_op
371+
transform.yield
372+
}
373+
}
374+
375+
// -----
376+
377+
// Move only those value definitions that are not dominated by insertion point
378+
379+
func.func @move_only_required_defns() -> (f32, f32, f32, f32) {
380+
%0 = "unmoved_op"() : () -> (f32)
381+
%1 = "dummy_op"() : () -> (f32)
382+
%2 = "before"() : () -> (f32)
383+
%3 = "moved_op"() : () -> (f32)
384+
return %0, %1, %2, %3 : f32, f32, f32, f32
385+
}
386+
// CHECK-LABEL: func @move_only_required_defns()
387+
// CHECK: %[[UNMOVED:.+]] = "unmoved_op"
388+
// CHECK: %[[DUMMY:.+]] = "dummy_op"
389+
// CHECK: %[[MOVED:.+]] = "moved_op"
390+
// CHECK: %[[BEFORE:.+]] = "before"
391+
392+
module attributes {transform.with_named_sequence} {
393+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
394+
%op1 = transform.structured.match ops{["unmoved_op"]} in %arg0
395+
: (!transform.any_op) -> !transform.any_op
396+
%op2 = transform.structured.match ops{["dummy_op"]} in %arg0
397+
: (!transform.any_op) -> !transform.any_op
398+
%op3 = transform.structured.match ops{["before"]} in %arg0
399+
: (!transform.any_op) -> !transform.any_op
400+
%op4 = transform.structured.match ops{["moved_op"]} in %arg0
401+
: (!transform.any_op) -> !transform.any_op
402+
%v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
403+
%v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value
404+
transform.test.move_value_defns %v1, %v2 before %op3
405+
: (!transform.any_value, !transform.any_value), !transform.any_op
406+
transform.yield
407+
}
408+
}
409+
410+
// -----
411+
412+
// Check handling of block arguments
413+
func.func @move_only_required_defns() -> (f32, f32) {
414+
%0 = "unmoved_op"() : () -> (f32)
415+
cf.br ^bb0(%0 : f32)
416+
^bb0(%arg0 : f32) :
417+
%1 = "before"() : () -> (f32)
418+
%2 = "moved_op"(%arg0) : (f32) -> (f32)
419+
return %1, %2 : f32, f32
420+
}
421+
// CHECK-LABEL: func @move_only_required_defns()
422+
// CHECK: %[[MOVED:.+]] = "moved_op"
423+
// CHECK: %[[BEFORE:.+]] = "before"
424+
425+
module attributes {transform.with_named_sequence} {
426+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
427+
%op1 = transform.structured.match ops{["before"]} in %arg0
428+
: (!transform.any_op) -> !transform.any_op
429+
%op2 = transform.structured.match ops{["moved_op"]} in %arg0
430+
: (!transform.any_op) -> !transform.any_op
431+
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
432+
transform.test.move_value_defns %v1 before %op1
433+
: (!transform.any_value), !transform.any_op
434+
transform.yield
435+
}
436+
}
437+
438+
// -----
439+
440+
// Do not move across basic blocks
441+
func.func @no_move_across_basic_blocks() -> (f32, f32) {
442+
%0 = "unmoved_op"() : () -> (f32)
443+
%1 = "before"() : () -> (f32)
444+
cf.br ^bb0(%0 : f32)
445+
^bb0(%arg0 : f32) :
446+
%2 = "moved_op"(%arg0) : (f32) -> (f32)
447+
return %1, %2 : f32, f32
448+
}
449+
450+
module attributes {transform.with_named_sequence} {
451+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
452+
%op1 = transform.structured.match ops{["before"]} in %arg0
453+
: (!transform.any_op) -> !transform.any_op
454+
%op2 = transform.structured.match ops{["moved_op"]} in %arg0
455+
: (!transform.any_op) -> !transform.any_op
456+
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
457+
// expected-remark@+1{{unsupported case of moving definition of value before an insertion point in a different basic block}}
458+
transform.test.move_value_defns %v1 before %op1
459+
: (!transform.any_value), !transform.any_op
460+
transform.yield
461+
}
462+
}

mlir/test/lib/Transforms/TestTransformsOps.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,23 @@ transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
3939
return DiagnosedSilenceableFailure::success();
4040
}
4141

42+
DiagnosedSilenceableFailure
43+
transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
44+
TransformResults &TransformResults,
45+
TransformState &state) {
46+
SmallVector<Value> values;
47+
for (auto tdValue : getValues()) {
48+
values.push_back(*state.getPayloadValues(tdValue).begin());
49+
}
50+
Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin();
51+
if (failed(moveValueDefinitions(rewriter, values, moveBefore))) {
52+
auto listener = cast<ErrorCheckingTrackingListener>(rewriter.getListener());
53+
std::string errorMsg = listener->getLatestMatchFailureMessage();
54+
(void)emitRemark(errorMsg);
55+
}
56+
return DiagnosedSilenceableFailure::success();
57+
}
58+
4259
namespace {
4360

4461
class TestTransformsDialectExtension

mlir/test/lib/Transforms/TestTransformsOps.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,26 @@ def TestMoveOperandDeps :
3838
}];
3939
}
4040

41+
def TestMoveValueDefns :
42+
Op<Transform_Dialect, "test.move_value_defns",
43+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
44+
DeclareOpInterfaceMethods<TransformOpInterface>,
45+
ReportTrackingListenerFailuresOpTrait]> {
46+
let description = [{
47+
Moves all dependencies of on operation before another operation.
48+
}];
49+
50+
let arguments =
51+
(ins Variadic<TransformValueHandleTypeInterface>:$values,
52+
TransformHandleTypeInterface:$insertion_point);
53+
54+
let results = (outs);
55+
56+
let assemblyFormat = [{
57+
$values `before` $insertion_point attr-dict
58+
`:` `(` type($values) `)` `` `,` type($insertion_point)
59+
}];
60+
}
61+
62+
4163
#endif // TEST_TRANSFORM_OPS

0 commit comments

Comments
 (0)