Skip to content

Commit efc31ec

Browse files
authored
[mlir][LICM] Restrict LICM to pure tensor semantics (llvm#129673)
This PR fixes a bug where LICM incorrectly allowed buffer semantics, which could lead to a crash. Fixes llvm#129416.
1 parent 5728813 commit efc31ec

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -280,11 +280,11 @@ MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
280280
if (auto insertionOp =
281281
dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
282282
// Current implementation expects that the insertionOp implement
283-
// the destinationStyleOpInterface as well. Abort if that tha is not
284-
// the case
285-
if (!isa<DestinationStyleOpInterface>(use.getOwner())) {
283+
// the DestinationStyleOpInterface and with pure tensor semantics
284+
// as well. Abort if that is not the case.
285+
auto dstOp = dyn_cast<DestinationStyleOpInterface>(use.getOwner());
286+
if (!dstOp || !dstOp.hasPureTensorSemantics())
286287
return failure();
287-
}
288288

289289
// The value must be used as a destination. (In case of a source, the
290290
// entire tensor would be read, which would prevent any hoisting.)

mlir/test/Transforms/loop-invariant-subset-hoisting.mlir

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -595,3 +595,21 @@ func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
595595
}
596596
return %1 : tensor<?x?xf32>
597597
}
598+
599+
// -----
600+
601+
// Ensure that cases with buffer semantics exit gracefully.
602+
603+
// CHECK-LABEL: @hoist_buffer
604+
func.func @hoist_buffer(%arg0: memref<7x7xf16>) {
605+
%c0 = arith.constant 0 : index
606+
%c1 = arith.constant 1 : index
607+
%alloc = memref.alloc() : memref<7x7xf16>
608+
// CHECK: scf.for
609+
// CHECK: linalg.copy
610+
%0 = scf.for %arg1 = %c0 to %c1 step %c1 iter_args(%arg2 = %alloc) -> (memref<7x7xf16>) {
611+
linalg.copy ins(%arg0 : memref<7x7xf16>) outs(%arg2 : memref<7x7xf16>)
612+
scf.yield %alloc : memref<7x7xf16>
613+
}
614+
return
615+
}

0 commit comments

Comments
 (0)