Skip to content

Commit 1bc58a2

Browse files
authored
Extend getBackwardSlice to track values captured from above (llvm#113478)
This change modifies `getBackwardSlice` to track values captured by the regions of each operation that it traverses. Ignoring values captured from a parent region may lead to an incomplete program slice. However, there seems to be logic that depends on not traversing captured values, so this change preserves the default behavior by hiding this logic behind the `omitUsesFromAbove` flag.
1 parent 41448c1 commit 1bc58a2

File tree

4 files changed

+47
-6
lines changed

4 files changed

+47
-6
lines changed

mlir/include/mlir/Analysis/SliceAnalysis.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ struct BackwardSliceOptions : public SliceOptions {
4747
/// backward slice computation traverses block arguments and asserts that the
4848
/// parent op has a single region with a single block.
4949
bool omitBlockArguments = false;
50+
51+
/// When omitUsesFromAbove is true, the backward slice computation omits
52+
/// traversing values that are captured from above.
53+
/// TODO: this should default to `false` after users have been updated.
54+
bool omitUsesFromAbove = true;
5055
};
5156

5257
using ForwardSliceOptions = SliceOptions;

mlir/lib/Analysis/SliceAnalysis.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include "mlir/IR/Operation.h"
1717
#include "mlir/Interfaces/SideEffectInterfaces.h"
1818
#include "mlir/Support/LLVM.h"
19+
#include "mlir/Transforms/RegionUtils.h"
20+
#include "llvm/ADT/STLExtras.h"
1921
#include "llvm/ADT/SetVector.h"
2022
#include "llvm/ADT/SmallPtrSet.h"
2123

@@ -91,14 +93,13 @@ static void getBackwardSliceImpl(Operation *op,
9193
if (options.filter && !options.filter(op))
9294
return;
9395

94-
for (const auto &en : llvm::enumerate(op->getOperands())) {
95-
auto operand = en.value();
96-
if (auto *definingOp = operand.getDefiningOp()) {
96+
auto processValue = [&](Value value) {
97+
if (auto *definingOp = value.getDefiningOp()) {
9798
if (backwardSlice->count(definingOp) == 0)
9899
getBackwardSliceImpl(definingOp, backwardSlice, options);
99-
} else if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
100+
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
100101
if (options.omitBlockArguments)
101-
continue;
102+
return;
102103

103104
Block *block = blockArg.getOwner();
104105
Operation *parentOp = block->getParentOp();
@@ -113,7 +114,14 @@ static void getBackwardSliceImpl(Operation *op,
113114
} else {
114115
llvm_unreachable("No definingOp and not a block argument.");
115116
}
117+
};
118+
119+
if (!options.omitUsesFromAbove) {
120+
visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) {
121+
processValue(operand->get());
122+
});
116123
}
124+
llvm::for_each(op->getOperands(), processValue);
117125

118126
backwardSlice->insert(op);
119127
}

mlir/test/IR/slice.mlir

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -slice-analysis-test %s | FileCheck %s
1+
// RUN: mlir-opt -slice-analysis-test -split-input-file %s | FileCheck %s
22

33
func.func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
44
%a = memref.alloc(%arg0, %arg2) : memref<?x?xf32>
@@ -33,3 +33,29 @@ func.func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
3333
// CHECK-DAG: %[[B:.+]] = memref.alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
3434
// CHECK-DAG: %[[C:.+]] = memref.alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
3535
// CHECK: return
36+
37+
// -----
38+
39+
#map = affine_map<(d0, d1) -> (d0, d1)>
40+
func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
41+
%0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
42+
^bb0(%in: f32, %out: f32):
43+
%2 = arith.addf %in, %in : f32
44+
linalg.yield %2 : f32
45+
} -> tensor<5x5xf32>
46+
%collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
47+
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
48+
^bb0(%in: f32, %out: f32):
49+
%c2 = arith.constant 2 : index
50+
%extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
51+
%2 = arith.addf %extracted, %extracted : f32
52+
linalg.yield %2 : f32
53+
} -> tensor<5x5xf32>
54+
return
55+
}
56+
57+
// CHECK-LABEL: func @slice_use_from_above__backward_slice__0
58+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor
59+
// CHECK: %[[A:.+]] = linalg.generic {{.*}} ins(%[[ARG0]]
60+
// CHECK: %[[B:.+]] = tensor.collapse_shape %[[A]]
61+
// CHECK: return

mlir/test/lib/IR/TestSlicing.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ static LogicalResult createBackwardSliceFunction(Operation *op,
3939
SetVector<Operation *> slice;
4040
BackwardSliceOptions options;
4141
options.omitBlockArguments = omitBlockArguments;
42+
// TODO: Make this default.
43+
options.omitUsesFromAbove = false;
4244
getBackwardSlice(op, &slice, options);
4345
for (Operation *slicedOp : slice)
4446
builder.clone(*slicedOp, mapper);

0 commit comments

Comments
 (0)