Skip to content

Commit be354cf

Browse files
authored
[mlir][transforms] Process RegionBranchOp with empty region (llvm#123895)
This PR adds process for RegionBranchOp with empty region, such as 'else' region of `scf.if`. Fixes llvm#123246.
1 parent 0d8d354 commit be354cf

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

mlir/lib/Transforms/RemoveDeadValues.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
375375
// Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
376376
auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
377377
for (Region &region : regionBranchOp->getRegions()) {
378+
if (region.empty())
379+
continue;
378380
SmallVector<Value> arguments(region.front().getArguments());
379381
BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
380382
liveArgs[&region] = regionLiveArgs;
@@ -420,6 +422,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
420422
auto markNonForwardedReturnValues =
421423
[&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
422424
for (Region &region : regionBranchOp->getRegions()) {
425+
if (region.empty())
426+
continue;
423427
Operation *terminator = region.front().getTerminator();
424428
nonForwardedRets[terminator] =
425429
BitVector(terminator->getNumOperands(), true);
@@ -499,6 +503,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
499503
// Recompute `resultsToKeep` and `argsToKeep` based on
500504
// `terminatorOperandsToKeep`.
501505
for (Region &region : regionBranchOp->getRegions()) {
506+
if (region.empty())
507+
continue;
502508
Operation *terminator = region.front().getTerminator();
503509
for (const RegionSuccessor &successor : getSuccessors(&region)) {
504510
Region *successorRegion = successor.getSuccessor();
@@ -547,6 +553,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
547553

548554
// Update the terminator operands that need to be kept.
549555
for (Region &region : regionBranchOp->getRegions()) {
556+
if (region.empty())
557+
continue;
550558
updateOperandsOrTerminatorOperandsToKeep(
551559
terminatorOperandsToKeep[region.back().getTerminator()],
552560
resultsToKeep, argsToKeep, &region);
@@ -611,8 +619,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
611619

612620
// Do (2.a) and (2.b).
613621
for (Region &region : regionBranchOp->getRegions()) {
614-
assert(!region.empty() && "expected a non-empty region in an op "
615-
"implementing `RegionBranchOpInterface`");
622+
if (region.empty())
623+
continue;
616624
BitVector argsToRemove = argsToKeep[&region].flip();
617625
cl.blocks.push_back({&region.front(), argsToRemove});
618626
collectNonLiveValues(nonLiveSet, region.front().getArguments(),
@@ -621,6 +629,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
621629

622630
// Do (2.c).
623631
for (Region &region : regionBranchOp->getRegions()) {
632+
if (region.empty())
633+
continue;
624634
Operation *terminator = region.front().getTerminator();
625635
cl.operands.push_back(
626636
{terminator, terminatorOperandsToKeep[terminator].flip()});

mlir/test/Transforms/remove-dead-values.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,22 @@ func.func @main(%arg3 : i32, %arg4 : i1) {
408408

409409
// -----
410410

411+
// The scf.if operation represents an if-then-else construct for conditionally
412+
// executing two regions of code. The 'the' region has exactly 1 block, and
413+
// the 'else' region may have 0 or 1 block. This case is to ensure 'else' region
414+
// with 0 block not crash.
415+
416+
// CHECK-LABEL: func.func @clean_region_branch_op_with_empty_region
417+
func.func @clean_region_branch_op_with_empty_region(%arg0: i1, %arg1: memref<f32>) {
418+
%cst = arith.constant 1.000000e+00 : f32
419+
scf.if %arg0 {
420+
memref.store %cst, %arg1[] : memref<f32>
421+
}
422+
return
423+
}
424+
425+
// -----
426+
411427
#map = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
412428
func.func @kernel(%arg0: memref<18xf32>) {
413429
%c1 = arith.constant 1 : index

0 commit comments

Comments
 (0)