Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SESE canonicalization: unroll loop to eliminate undefs. #19811

Merged
merged 11 commits into from Oct 12, 2018
181 changes: 171 additions & 10 deletions lib/SILOptimizer/Mandatory/TFCanonicalizeCFG.cpp
Expand Up @@ -335,8 +335,8 @@ class BasicBlockCloner : public SILClonerWithScopes<BasicBlockCloner> {

bool hasCloned() const { return cloned; }

/// Return a cloned block.
SILBasicBlock *cloneBlock(SILBasicBlock *bb) {
/// Create a block and clone everything except the instructions.
bgogul marked this conversation as resolved.
Show resolved Hide resolved
SILBasicBlock *initBlock(SILBasicBlock *bb) {
auto bbIt = BBMap.find(bb);
if (bbIt != BBMap.end())
return bbIt->second;
Expand All @@ -354,13 +354,26 @@ class BasicBlockCloner : public SILClonerWithScopes<BasicBlockCloner> {
ValueMap[arg] = newBB->createPhiArgument(
arg->getType(), arg->getOwnershipKind(), arg->getDecl());
}
// Clone all the instructions.
return newBB;
}

/// Clone all the instructions of `bb` into the block that BBMap already
/// associates with it, and return that block.
///
/// `bb` must already have an entry in BBMap (presumably created by
/// initBlock); this is asserted rather than handled.
SILBasicBlock *cloneBlock(SILBasicBlock *bb) {
  auto bbIt = BBMap.find(bb);
  assert(bbIt != BBMap.end() && "Block is not initialized before cloning.");
  SILBasicBlock *newBB = bbIt->second;
  // Direct the builder at the new block before visiting the instructions.
  getBuilder().setInsertionPoint(newBB);
  for (auto &inst : *bb) {
    visit(&inst);
  }
  return newBB;
}

/// Convenience wrapper: run initBlock on `bb`, then clone its instructions
/// with cloneBlock and return the block cloneBlock produced.
SILBasicBlock *initAndCloneBlock(SILBasicBlock *bb) {
  initBlock(bb);
  SILBasicBlock *clonedBB = cloneBlock(bb);
  return clonedBB;
}

/// Handle references to basic blocks when cloning.
SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) {
// If the block was not cloned by this cloner, directly reference it.
Expand All @@ -370,6 +383,18 @@ class BasicBlockCloner : public SILClonerWithScopes<BasicBlockCloner> {
return bbIt->second;
return bb;
}

/// Return the value that ValueMap associates with `Value`, or `Value` itself
/// when ValueMap has no entry for it.
SILValue remapValue(SILValue Value) {
  auto mapped = ValueMap.find(Value);
  if (mapped == ValueMap.end())
    return Value;
  return mapped->second;
}

/// Record in ValueMap that `oldValue` maps to `newValue`, so that a later
/// remapValue(oldValue) returns `newValue`. Asserts that `oldValue` did not
/// already have an entry.
void updateValueMap(SILValue oldValue, SILValue newValue) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we find better names for the param names? like oldValue -> key, newValue -> value

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think oldValue and newValue are better. It just says that when we are cloning replace occurrences of oldValue with new Value. Updated the function comment.

// try_emplace leaves an existing entry untouched; `.second` is false then.
auto emplaceResult = ValueMap.try_emplace(oldValue, newValue);
assert(emplaceResult.second && "Remapping value multiple times during SESE cloning.");
bgogul marked this conversation as resolved.
Show resolved Hide resolved
}
};

} // namespace
Expand All @@ -382,7 +407,8 @@ class SingleExitLoopTransformer {
PostDominanceInfo *PDI)
: deviceInfo(deviceInfo), DI(DI), PDI(PDI), LI(LI), loop(loop),
header(loop->getHeader()), preheader(loop->getLoopPreheader()),
latch(loop->getLoopLatch()), currentFn(header->getParent()) {
latch(loop->getLoopLatch()), currentFn(header->getParent()),
oldHeaderNumArgs(header->getNumArguments()), hasUndefsAtPreheader(false) {
assert(preheader && "Canonicalization should have given us one preheader");
assert(latch && "Canonicalization should have given us one latch block");
initialize();
Expand Down Expand Up @@ -421,6 +447,9 @@ class SingleExitLoopTransformer {
/// we will get a single exit block.
void ensureSingleExitBlock();

/// Unroll the body of the loop once.
void unrollLoopBody();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider renaming to unrollLoopBodyOnce(), and remove the comment.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


/// Compute escaping values and what values to use as arguments at preheader.
llvm::DenseMap<SILValue, SILValue> computeEscapingValuesSubstMap() const;

Expand Down Expand Up @@ -471,6 +500,10 @@ class SingleExitLoopTransformer {
SILBasicBlock *preheader;
SILBasicBlock *latch;
SILFunction *currentFn;
unsigned oldHeaderNumArgs;
/// Flag to track if we have undefs at preheader corresponding to escaping
/// values and exit args.
bool hasUndefsAtPreheader;
/// Equivalence classes induced by argument passing.
llvm::EquivalenceClasses<SILValue> equivalentValues;
/// exit blocks of the loop.
Expand Down Expand Up @@ -606,7 +639,7 @@ void SingleExitLoopTransformer::ensureSingleExitBlock() {
if (DI->properlyDominates(succ, header)) continue;

// Clone the block and rewire the edge.
SILBasicBlock *clonedSucc = cloner.cloneBlock(succ);
SILBasicBlock *clonedSucc = cloner.initAndCloneBlock(succ);
changeBranchTarget(current->getTerminator(), edgeIdx, clonedSucc,
/*preserveArgs*/ true);
worklist.insert(clonedSucc);
Expand Down Expand Up @@ -839,10 +872,12 @@ void SingleExitLoopTransformer::patchPreheader(SILBasicBlock *newHeader) {
// Simply pass in an undef. This will never be accessed at runtime.
SmallVector<SILValue, 8> newArgs;
for (const auto &kv : escapingValueSubstMap) {
hasUndefsAtPreheader |= isa<SILUndef>(kv.second);
newArgs.push_back(kv.second);
}
if (TFNoUndefsInSESE) {
for (const auto &kv : exitArgSubstMap) {
hasUndefsAtPreheader |= isa<SILUndef>(kv.second);
newArgs.push_back(kv.second);
}
}
Expand All @@ -862,11 +897,6 @@ SingleExitLoopTransformer::patchEdges(SILBasicBlock *newHeader,

llvm::DenseMap<SILBasicBlock *, intmax_t> exitIndices;

unsigned oldHeaderNumArgs =
newHeader->getNumArguments() -
(escapingValueSubstMap.size() + exitArgSubstMap.size() +
/* exitIndex, stayInLoop*/ 2);

// Identify the exit from the header (if any) and assign '0' as its index.
SILBasicBlock *headerExit = nullptr;
for (SILBasicBlock *succ : header->getSuccessorBlocks()) {
Expand Down Expand Up @@ -1103,6 +1133,15 @@ bool SingleExitLoopTransformer::transform() {

// Update the loop header to newHeader.
loop->moveToHeader(newHeader);

if (TFNoUndefsInSESE) {
// If we still have undefs at preheader, simply clone the loop body once
// before the actual loop.
if (hasUndefsAtPreheader) {
unrollLoopBody();
}
}

return true;
}

Expand Down Expand Up @@ -1150,6 +1189,128 @@ void SESERegionBuilder::ensureSingleExitFromLoops() {
}
}

void SingleExitLoopTransformer::unrollLoopBody() {
BasicBlockCloner cloner(*currentFn);
// Set up the cloner so that newHeader's arguments are replaced with values in
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: argument's -> arguments?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// preheader.
SILBasicBlock *newHeader = loop->getHeader();
auto preheaderTermInst = dyn_cast<BranchInst>(preheader->getTerminator());
assert(preheaderTermInst && "Preheader of a loop has a non-branch terminator");
for (unsigned argIndex = 0; argIndex < oldHeaderNumArgs; ++argIndex) {
auto preHeaderArg = preheaderTermInst->getArg(argIndex);
auto newHeaderArg = newHeader->getArgument(argIndex);
cloner.updateValueMap(newHeaderArg, preHeaderArg);
}
// Clone everything except the new header. We should traverse the
// blocks in depth first order to ensure values are cloned before they are used.
SmallPtrSet<SILBasicBlock *, 32> worklist;
SmallVector<SILBasicBlock *, 32> initializedBlocks;
worklist.insert(header);
while (!worklist.empty()) {
SILBasicBlock *current = *worklist.begin();
worklist.erase(current);
cloner.initBlock(current);
initializedBlocks.push_back(current);
for (SILBasicBlock *succ : current->getSuccessorBlocks()) {
// Skip if succ is not a part of the loop, is already cloned, or
// is the new header.
if (!loop->contains(succ) || cloner.remapBasicBlock(succ) != succ ||
succ == newHeader) {
continue;
}
worklist.insert(succ);
}
}
for (SILBasicBlock *bb : initializedBlocks) {
cloner.cloneBlock(bb);
}

// Get the clone for the original and new header.
SILBasicBlock *clonedHeader = cloner.remapBasicBlock(header);
bgogul marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given "newHeader", it'd be less confusing if we change this to "oldHeader", and same for clonedHeader.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to clonedOldHeader. header is a class field.

replaceBranchTarget(preheader->getTerminator(), newHeader, clonedHeader,
/*preserveArgs*/ false);

// Along a path in the loop body where an escaping value or an exit argument
// is not defined, the SESE loop canonicalization would have propagated the
// corresponding loop carried state that was added to the new header. However,
// these are not remapped when the loop body is unrolled (as we won't know
// what value to use in the unrolled body as it is undefined along that path).
// This following code patches these arguments by picking a value that
// dominates `pred` and is equivalent to the corresponding argument in the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mentioning pred here makes the comment block hard to read and confusing, as that name does not appear in the large block of CFG example below (I was initially wondering if there's a typo), but seems to instead refer to a var in the code that's many lines down below.

consider first explaining the rationale/benefit/mechanics (the why and what) of unrolling in terms of the example below. we can then add another comment block right above the code below, to describe how that is in general implemented in terms of variables in the code, so that the code-related comment would echo / reiterate on what the example has illustrated for the readers.

// cloned block. e.g.,
//
// do {
// if (...) break;
// i += 1
// } while(...)
// return i
//
// --CFG--
// preheader: i0 = 0; br header(i0)
//
// header(i0): cond ??, break, body
//
// break: br exit(i0)
//
// body: i1 = i0 + 1; cond ??, header(i1), exit(i1)
//
// exit(i2): return i2
//
// --Canonicalized CFG (not everything is shown)--
// preheader: i0 = 0; br newHeader(i0, undef)
//
// newHeader(i0, i3): cond stayInLoop, header, exit(i3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm not able to follow this example overall. the entire textual representation is hard to read. for example, where is stayInLoop updated?

it might help if:

  1. we give some high level textual description on why undef is present in the Canonicalized CFG (is it along the lines of: "the second bb arg of newHeader is the updated value i that we want to return; when we enter newHeader for the first time from preheader, we don't know that value, but we won't ever return it, so setting it to undef is safe"), and why the undef can be removed after unrolling.

Also, to make things simpler, is it possible to avoid generating undefs in the first place, vs first generating it, and then try eliminating it? specifically, can we achieve that by moving loop rotation earlier?

  1. we add some comments on the semantics of the bb args. e.g. what do i4 and i5 represent in newLatch(i4, i5)

//
// header: cond ??, break, body
//
// break: br newLatch(i0, i3)
//
// body: i1 = i0 + 1; cond ??, newHeader(i1, i1), newLatch(i1, i1)
//
// newLatch(i4, i5): br newHeader(i4, i5)
//
// exit(i2): return i2
//
// In the unrolled body of the loop, break will be cloned as follows:
// (prime refers to the cloned version):
// break': br newLatch'(i0', i3)
//
// Note that i3 is not cloned, which is patched here as follows:
// break': br newLatch'(i0', i1')
bgogul marked this conversation as resolved.
Show resolved Hide resolved
// `i1` is equivalent to `i3` as they both flow into the argument `i5` of
// `newLatch`.
SILBasicBlock *newLatch = loop->getLoopLatch();
SILBasicBlock *clonedNewLatch = cloner.remapBasicBlock(newLatch);
for (SILBasicBlock *pred : newLatch->getPredecessorBlocks()) {
auto predTermInst = dyn_cast<BranchInst>(pred->getTerminator());
assert(predTermInst && "Preheader of a loop has a non-branch terminator");
for (unsigned argIndex = 0; argIndex < predTermInst->getNumArgs(); ++argIndex) {
auto arg = predTermInst->getArg(argIndex);
// Skip if this is not an uncloned argument as illustrated above.
bgogul marked this conversation as resolved.
Show resolved Hide resolved
if (!isa<SILArgument>(arg) ||
cast<SILArgument>(arg)->getParent() != newHeader) {
continue;
}
// Iterate over the incoming values of the corresponding argument in the
// latch block and pick one that is suitable to be used here.
auto destBBArg = newLatch->getArgument(argIndex);
SmallVector<SILValue, 8> incomingValues;
destBBArg->getIncomingPhiValues(incomingValues);
for (auto value : incomingValues) {
if (value != arg && DI->properlyDominates(value, predTermInst)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if we cannot find such a value? should we assert this won't ever happen?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is guaranteed to find a value. SIL verification will fail otherwise. No need to add another check that will essentially replicate SIL verification.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Debugging a SIL verifier failure usually takes more work because the context is non-local. It'd usually be preferable to have a local check so that we can fail fast. Also, the check serves as documentation for this important invariant.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Added an assert to check that we patch such arguments.

// A suitable value is found. Update the edge value in the unrolled
// loop with the corresponding cloned value.
SILBasicBlock *clonedPred = cloner.remapBasicBlock(pred);
changeEdgeValue(clonedPred->getTerminator(), clonedNewLatch, argIndex,
cloner.remapValue(value));
break;
}
}
}
}

}

/// Process the specified loop, collapsing it into a SESE region node. This
/// forms a WhileLoopSESERegion node and puts it into the loopPreheaders data
/// structure, allowing the outer level's acyclic region handling to pick it
Expand Down
31 changes: 18 additions & 13 deletions test/TensorFlow/sese_loop_canonicalization.sil
Expand Up @@ -231,6 +231,11 @@ public func nestedLoopWithBreak(breakCount:Int32) -> Int32 {

// CHECK-LABEL: --- XLA CFG Canonicalize: $doWhileLoop
// CHECK: [sequence
// CHECK: {condition Header: {{bb[0-9]+}}
// CHECK: {condition Header: {{bb[0-9]+}}
// CHECK: block {{bb[0-9]+}}
// CHECK: block {{bb[0-9]+}}}
// CHECK: block {{bb[0-9]+}}}
// CHECK: <while Preheader: [[PHDR:bb[0-9]+]], Header: [[HDR:bb[0-9]+]], exit: [[EXIT:bb[0-9]+]]
// CHECK: [sequence
// CHECK: {condition Header: {{bb[0-9]+}}
Expand All @@ -243,8 +248,8 @@ public func nestedLoopWithBreak(breakCount:Int32) -> Int32 {

// The loop body is now unrolled once before the loop, so the preheader is no
// longer expected to pass undef to the header.
// CHECK: sil @$doWhileLoop : {{.*}} (Builtin.Int32) -> Builtin.Int32 {
// CHECK: [[PHDR]]({{.*}} : $Builtin.Int32):
// CHECK: br [[HDR]]({{.*}} : $Builtin.Int32, undef : $Builtin.Int32, {{.*}} : $TensorHandle<Builtin.Int32>, {{.*}} : $TensorHandle<Builtin.Int1>)
// CHECK: [[PHDR]]({{.*}} : $TensorHandle<Builtin.Int1>):
// CHECK: br [[HDR]]({{.*}} : $Builtin.Int32, {{.*}} : $Builtin.Int32, {{.*}} : $TensorHandle<Builtin.Int32>, {{.*}} : $TensorHandle<Builtin.Int1>)

sil @$doWhileLoop : $@convention(thin) (Builtin.Int32) -> Builtin.Int32 {
bb0(%0 : $Builtin.Int32):
Expand Down Expand Up @@ -299,19 +304,19 @@ bb3 (%9 : $Builtin.Int32):
//
// CHECK-LABEL: --- XLA CFG Canonicalize: $loopThatRequiresNodeCloning
// CHECK: [sequence
// CHECK: {condition Header: bb0
// CHECK: block bb1
// CHECK: {condition Header: {{bb[0-9]+}}
// CHECK: block {{bb[0-9]+}}
// CHECK: [sequence
// CHECK: <while Preheader: bb2, Header: bb9, exit: bb11
// CHECK: <while Preheader: {{bb[0-9]+}}, Header: {{bb[0-9]+}}, exit: {{bb[0-9]+}}
// CHECK: [sequence
// CHECK: {condition Header: bb3
// CHECK: block bb4
// CHECK: {condition Header: bb5
// CHECK: block bb7
// CHECK: block bb6}}
// CHECK: block bb10]>
// CHECK: block bb11]}
// CHECK: block bb8]
// CHECK: {condition Header: {{bb[0-9]+}}
// CHECK: block {{bb[0-9]+}}
// CHECK: {condition Header: {{bb[0-9]+}}
// CHECK: block {{bb[0-9]+}}
// CHECK: block {{bb[0-9]+}}}}
// CHECK: block {{bb[0-9]+}}]>
// CHECK: block {{bb[0-9]+}}]}
// CHECK: block {{bb[0-9]+}}]
// CHECK: --- XLA CFG Canonicalize end
sil @$loopThatRequiresNodeCloning : $@convention(thin) (Builtin.Int32, Builtin.Int32) -> Builtin.Int32 {
bb0(%0 : $Builtin.Int32, %1 : $Builtin.Int32):
Expand Down
24 changes: 24 additions & 0 deletions test/TensorFlowRuntime/sese_loop_canonicalization.swift
Expand Up @@ -89,6 +89,30 @@ ControlFlowTests.testAllBackends("sumOfProductsWithBound") {
// Effectively no bound as natSum(3) * natSum(3) is 36.
expectNearlyEqualWithScalarTensor(36, sumOfProductsWithBound(3, 3, 100))
}


/// Accumulates i = 1, 2, ... into a scalar tensor using a repeat-while loop
/// that runs while i <= 100, breaking out right after adding `breakIndex`.
/// The add happens before the break check, so `breakIndex` itself is included
/// in the returned sum.
func doWhileLoopWithBreak(_ breakIndex:Int32) -> Tensor<Int32> {
var i: Int32 = 1
var sum = Tensor<Int32>(0)
let maxCount: Int32 = 100
repeat {
bgogul marked this conversation as resolved.
Show resolved Hide resolved
sum += i
if (i == breakIndex) {
bgogul marked this conversation as resolved.
Show resolved Hide resolved
break
}
i += 1
} while i <= maxCount
return sum
}

// Exercise doWhileLoopWithBreak itself: break taken early (2, 10), break never
// taken (-300, 200), and break taken on the last iteration (100).
ControlFlowTests.testAllBackends("doWhileLoopWithBreak") {
  expectEqualWithScalarTensor(3, doWhileLoopWithBreak(2))
  expectEqualWithScalarTensor(55, doWhileLoopWithBreak(10))
  expectEqualWithScalarTensor(5050, doWhileLoopWithBreak(-300))
  expectEqualWithScalarTensor(5050, doWhileLoopWithBreak(100))
  expectEqualWithScalarTensor(5050, doWhileLoopWithBreak(200))
}

#endif // CUDA

runAllTests()