Skip to content

Commit

Permalink
Improving block argument folding to handle more cases. (iree-org#13631)
Browse files Browse the repository at this point in the history
The existing code would give up on particular args if multiple branch
sites had non-identical duplicate arg sets for that arg.

Fixes iree-org#13543.
  • Loading branch information
benvanik authored and NatashaKnk committed Jul 6, 2023
1 parent cfa2c99 commit ebc8cf4
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 27 deletions.
46 changes: 27 additions & 19 deletions compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
Expand All @@ -32,7 +33,7 @@ static void eraseOperands(MutableOperandRange &operands,
}

// Folds block arguments that are always known to have the same value at all
// branch source sites. This is like CSE applied to blocks.
// branch source sites. This is like CSE applied to block arguments.
//
// Example:
// br ^bb1(%0, %0 : index, index)
Expand All @@ -56,12 +57,11 @@ struct FoldBlockArgumentsPattern
mutable BranchOpInterface branchOp;
// Which successor this source represents.
unsigned successorIndex;
// Successor operand index -> index mapping for duplicates.
// Base/non-duplicated values will be identity.
// Example: (%a, %b, %a, %b) -> (0, 1, 0, 1)
SmallVector<int> dupeIndexMap;
// Equivalence classes for all arguments indicating which have the same
// value at the source. Base/non-duplicated values will be identity.
// Example: (%a, %b, %a, %b, %c) -> (0, 2), (1, 3), (4)
llvm::EquivalenceClasses<unsigned> duplicates;
};
static const int kUnassigned = -1;
DenseMap<Block *, SmallVector<BlockSource>> blockSourceMap;
bool hasAnyDupes = false;
for (auto branchOp : region.getOps<BranchOpInterface>()) {
Expand All @@ -72,18 +72,17 @@ struct FoldBlockArgumentsPattern
BlockSource blockSource;
blockSource.branchOp = branchOp;
blockSource.successorIndex = successorIndex;
blockSource.dupeIndexMap.resize(operands.size(), kUnassigned);
for (int i = 0; i < operands.size(); ++i) {
blockSource.dupeIndexMap[i] = i;
blockSource.duplicates.insert(i);
for (int j = 0; j < i; ++j) {
if (operands[j] == operands[i]) {
blockSource.dupeIndexMap[i] = j;
blockSource.duplicates.unionSets(i, j);
hasAnyDupes |= true;
break;
}
}
}
blockSourceMap[block].push_back(blockSource);
blockSourceMap[block].push_back(std::move(blockSource));
}
}
if (!hasAnyDupes) {
Expand All @@ -109,9 +108,10 @@ struct FoldBlockArgumentsPattern
llvm::BitVector elidedArgs(numArgs);

// See if each block argument is foldable across all block sources.
// In order to fold we need each source to have the same index in its
// duplication map refering back to the given block argument.
llvm::BitVector sameValues(numArgs);
// In order to fold we need each source to share some duplicates but note
// that the sources may not have identical sets.
llvm::BitVector sameValues(numArgs); // reused
llvm::BitVector sourceValues(numArgs); // reused
for (unsigned argIndex = 0; argIndex < numArgs; ++argIndex) {
// Each bit represents an argument that duplicates the arg at argIndex.
// We walk all the sources and AND their masks together to get the safe
Expand All @@ -120,17 +120,25 @@ struct FoldBlockArgumentsPattern
// Example for %1: (%a, %b, %a) -> b000
sameValues.set(); // note reused
for (auto &blockSource : blockSources) {
for (unsigned i = 0; i < numArgs; ++i) {
if (i == argIndex || blockSource.dupeIndexMap[i] != argIndex) {
sameValues.reset(i);
}
sourceValues.reset();
for (auto mit = blockSource.duplicates.findLeader(argIndex);
mit != blockSource.duplicates.member_end(); ++mit) {
sourceValues.set(*mit);
}
sameValues &= sourceValues;
}
if (sameValues.none()) continue;
if (sameValues.none()) {
continue; // arg unused/not duplicated
}

// Remove the base argument from the set so we don't erase it and can
// point all duplicate args at it.
int baseArgIndex = sameValues.find_first();
sameValues.reset(baseArgIndex);
elidedArgs |= sameValues;

// Replace all of the subsequent duplicate arguments with the first.
auto baseArg = block.getArgument(argIndex);
auto baseArg = block.getArgument(baseArgIndex);
for (unsigned dupeIndex : sameValues.set_bits()) {
rewriter.replaceAllUsesWith(block.getArgument(dupeIndex), baseArg);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ func.func @foldBrArguments(%cond: i1, %arg1: index) -> index {
^bb1:
// CHECK: %[[OP1:.+]] = "some.op1"
%0 = "some.op1"() : () -> index
// CHECK: cf.br ^bb3(%[[OP1]], %[[ARG1]] : index, index)
cf.br ^bb3(%0, %arg1, %0 : index, index, index)
// CHECK: cf.br ^bb3(%[[OP1]], %[[OP1]], %[[ARG1]] : index, index, index)
cf.br ^bb3(%0, %0, %arg1, %0 : index, index, index, index)
^bb2:
// CHECK: %[[OP2:.+]] = "some.op2"
%1 = "some.op2"() : () -> index
// CHECK: cf.br ^bb3(%[[OP2]], %[[OP2]] : index, index)
cf.br ^bb3(%1, %1, %1 : index, index, index)
// CHECK: ^bb3(%[[BB3_ARG0:.+]]: index, %[[BB3_ARG1:.+]]: index):
^bb3(%bb3_0: index, %bb3_1: index, %bb3_2: index):
// CHECK: %[[OP3:.+]] = "some.op3"(%[[BB3_ARG0]], %[[BB3_ARG1]], %[[BB3_ARG0]])
%2 = "some.op3"(%bb3_0, %bb3_1, %bb3_2) : (index, index, index) -> index
// CHECK: cf.br ^bb3(%[[ARG1]], %[[OP2]], %[[OP2]] : index, index, index)
cf.br ^bb3(%arg1, %1, %1, %1 : index, index, index, index)
// CHECK: ^bb3(%[[BB3_ARG0:.+]]: index, %[[BB3_ARG1:.+]]: index, %[[BB3_ARG2:.+]]: index):
^bb3(%bb3_0: index, %bb3_1: index, %bb3_2: index, %bb3_3: index):
// CHECK: %[[OP3:.+]] = "some.op3"(%[[BB3_ARG0]], %[[BB3_ARG1]], %[[BB3_ARG2]], %[[BB3_ARG1]])
%2 = "some.op3"(%bb3_0, %bb3_1, %bb3_2, %bb3_3) : (index, index, index, index) -> index
// CHECK: return %[[OP3]]
return %2 : index
}
Expand Down

0 comments on commit ebc8cf4

Please sign in to comment.