Skip to content

Commit f15a6c9

Browse files
authored
[mlir] [DataFlow] Fix bug in int-range-analysis (llvm#126708)
When querying the lower bound and upper bound of loop to update the value range of a loop iteration variable, the program point to depend on should be the block corresponding to the iteration variable rather than the loop operation.
1 parent 9db0f91 commit f15a6c9

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,15 +152,15 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
152152
/// on a LoopLikeInterface return the lower/upper bound for that result if
153153
/// possible.
154154
auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
155-
Type boundType, bool getUpper) {
155+
Type boundType, Block *block, bool getUpper) {
156156
unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
157157
if (loopBound.has_value()) {
158158
if (auto attr = dyn_cast<Attribute>(*loopBound)) {
159159
if (auto bound = dyn_cast_or_null<IntegerAttr>(attr))
160160
return bound.getValue();
161161
} else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
162162
const IntegerValueRangeLattice *lattice =
163-
getLatticeElementFor(getProgramPointAfter(op), value);
163+
getLatticeElementFor(getProgramPointBefore(block), value);
164164
if (lattice != nullptr && !lattice->getValue().isUninitialized())
165165
return getUpper ? lattice->getValue().getValue().smax()
166166
: lattice->getValue().getValue().smin();
@@ -180,16 +180,17 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
180180
return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
181181
op, successor, argLattices, firstIndex);
182182
}
183+
Block *block = iv->getParentBlock();
183184
std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
184185
std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
185186
std::optional<OpFoldResult> step = loop.getSingleStep();
186-
APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
187+
APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block,
187188
/*getUpper=*/false);
188-
APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
189+
APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block,
189190
/*getUpper=*/true);
190191
// Assume positivity for uniscoverable steps by way of getUpper = true.
191192
APInt stepVal =
192-
getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
193+
getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true);
193194

194195
if (stepVal.isNegative()) {
195196
std::swap(min, max);

mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,33 @@ func.func @dont_propagate_across_infinite_loop() -> index {
154154
return %2 : index
155155
}
156156

157+
// CHECK-LABEL: @propagate_from_block_to_iterarg
158+
func.func @propagate_from_block_to_iterarg(%arg0: index, %arg1: i1) {
159+
%c0 = arith.constant 0 : index
160+
%c1 = arith.constant 1 : index
161+
%0 = scf.if %arg1 -> (index) {
162+
%1 = scf.if %arg1 -> (index) {
163+
scf.yield %arg0 : index
164+
} else {
165+
scf.yield %arg0 : index
166+
}
167+
scf.yield %1 : index
168+
} else {
169+
scf.yield %c1 : index
170+
}
171+
scf.for %arg2 = %c0 to %arg0 step %c1 {
172+
scf.if %arg1 {
173+
%1 = arith.subi %0, %c1 : index
174+
%2 = arith.muli %0, %1 : index
175+
%3 = arith.addi %2, %c1 : index
176+
scf.for %arg3 = %c0 to %3 step %c1 {
177+
%4 = arith.cmpi uge, %arg3, %c1 : index
178+
// CHECK-NOT: scf.if %false
179+
scf.if %4 {
180+
"test.foo"() : () -> ()
181+
}
182+
}
183+
}
184+
}
185+
return
186+
}

0 commit comments

Comments
 (0)