Skip to content

Commit 9fddcea

Browse files
authored
[MLIR][Affine] Fix getSliceBounds for missing handling of no lower/upper bound in certain cases (llvm#127192)
Fix `FlatLinearValueConstraints::getSliceBounds` for missing checks on no lower/upper bound bound. Obvious bug. Fixes: llvm#119525 Fixes: llvm#108374
1 parent 5866032 commit 9fddcea

File tree

2 files changed

+72
-2
lines changed

2 files changed

+72
-2
lines changed

mlir/lib/Analysis/FlatLinearValueConstraints.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ void FlatLinearConstraints::getSliceBounds(unsigned offset, unsigned num,
679679
// TODO: being conservative for the moment in cases that
680680
// lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
681681
// fixed (b/126426796).
682-
if (!lbMap || lbMap.getNumResults() > 1) {
682+
if (!lbMap || lbMap.getNumResults() != 1) {
683683
LLVM_DEBUG(llvm::dbgs()
684684
<< "WARNING: Potentially over-approximating slice lb\n");
685685
auto lbConst = getConstantBound64(BoundType::LB, pos + offset);
@@ -688,7 +688,7 @@ void FlatLinearConstraints::getSliceBounds(unsigned offset, unsigned num,
688688
getAffineConstantExpr(*lbConst, context));
689689
}
690690
}
691-
if (!ubMap || ubMap.getNumResults() > 1) {
691+
if (!ubMap || ubMap.getNumResults() != 1) {
692692
LLVM_DEBUG(llvm::dbgs()
693693
<< "WARNING: Potentially over-approximating slice ub\n");
694694
auto ubConst = getConstantBound64(BoundType::UB, pos + offset);

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,3 +425,73 @@ func.func @non_int_memory_space() {
425425
// PRODUCER-CONSUMER-NEXT: affine.for %{{.*}} = 0 to 8
426426
return
427427
}
428+
429+
// -----
430+
431+
#map = affine_map<(d0) -> (d0)>
432+
#map1 = affine_map<(d0) -> (d0 + 1)>
433+
434+
// Exercises fix for crash reported at https://github.com/llvm/llvm-project/issues/119525
435+
436+
// No fusion of producer into consumer happens here as the slice is determined
437+
// to be invalid. This is a limitation and it is possible to compute a slice
438+
// (reduction along %arg4) and fuse.
439+
440+
// PRODUCER-CONSUMER-LABEL: func @slice_compute_check
441+
func.func @slice_compute_check(%arg0: memref<1x8x26xi32, strided<[?, ?, ?], offset: ?>>, %arg1: memref<1x8x26xi32, strided<[?, ?, ?], offset: ?>>, %arg2: memref<1x8x26xi32, strided<[?, ?, ?], offset: ?>>) {
442+
%alloc_14 = memref.alloc() : memref<1x8x26xi32>
443+
%alloc_15 = memref.alloc() : memref<1x26xi32>
444+
affine.for %arg3 = 0 to 1 {
445+
affine.for %arg4 = 0 to 8 {
446+
affine.for %arg5 = 0 to 26 {
447+
affine.for %arg6 = #map(%arg3) to #map1(%arg3) {
448+
affine.for %arg7 = #map(%arg4) to #map1(%arg4) {
449+
affine.for %arg8 = #map(%arg5) to #map1(%arg5) {
450+
%61 = affine.load %alloc_14[%arg6, %arg7, %arg8] : memref<1x8x26xi32>
451+
%62 = affine.load %alloc_15[%arg6, %arg8] : memref<1x26xi32>
452+
%63 = llvm.intr.smin(%61, %62) : (i32, i32) -> i32
453+
affine.store %63, %alloc_15[%arg6, %arg8] : memref<1x26xi32>
454+
}
455+
}
456+
}
457+
}
458+
}
459+
}
460+
affine.for %arg3 = 0 to 26 {
461+
%61 = affine.load %alloc_15[0, %arg3] : memref<1x26xi32>
462+
}
463+
memref.dealloc %alloc_15 : memref<1x26xi32>
464+
memref.dealloc %alloc_14 : memref<1x8x26xi32>
465+
return
466+
}
467+
468+
// -----
469+
470+
// Exercises fix for crash reported at https://github.com/llvm/llvm-project/issues/108374
471+
472+
// No fusion of producer into consumer happens here. The slice will not be
473+
// valid as the producer doesn't supply to all of the consumer.
474+
475+
#map = affine_map<(d0) -> (d0)>
476+
#map1 = affine_map<(d0) -> (d0 + 1)>
477+
// PRODUCER-CONSUMER-LABEL: func @test_add_slice_bounds
478+
func.func @test_add_slice_bounds() {
479+
%alloc = memref.alloc() : memref<10xf32>
480+
%cst = arith.constant 0.619152 : f32
481+
affine.for %arg0 = 0 to 10 {
482+
affine.for %arg1 = #map(%arg0) to #map1(%arg0) {
483+
affine.store %cst, %alloc[%arg1] : memref<10xf32>
484+
}
485+
}
486+
affine.for %arg0 = 0 to 3 {
487+
affine.for %arg1 = 0 to 10 {
488+
affine.for %arg2 = #map(%arg0) to #map1(%arg0) {
489+
affine.for %arg3 = #map(%arg1) to #map1(%arg1) {
490+
%0 = affine.apply #map1(%arg3)
491+
%1 = affine.load %alloc[%0] : memref<10xf32>
492+
}
493+
}
494+
}
495+
}
496+
return
497+
}

0 commit comments

Comments
 (0)