Fix loop materialization branching (rust-lang#522)
* Better print message

* Fix cache whole allocation

* Fix loop rematerialization

* Fix flakey alloc order

* Reverse erasure
wsmoses committed Feb 14, 2022
1 parent 8827a70 commit 377f2c7
Showing 7 changed files with 219 additions and 62 deletions.
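The common thread in these changes is how Enzyme handles an allocation that is created, used, and freed inside a loop of the primal. Instead of caching one copy of such a buffer per iteration for the reverse pass, Enzyme can rematerialize it: re-run the allocation and the code that fills it once per reverse iteration, and free it again. A minimal sketch of the pattern this commit exercises, modeled loosely on the new remat.c test below (the function name is illustrative, not part of the commit):

    #include <stdlib.h>

    // Per-iteration scratch buffer: allocated, filled, read, and freed inside
    // the loop body. The reverse pass may re-create ("rematerialize") this
    // buffer each reverse iteration rather than cache every copy of it.
    void scratch_loop(double *out, const double *in, int n, int m) {
      for (int r = 0; r < n; r++) {
        double *tmp = (double *)malloc((size_t)m * sizeof(double));
        for (int i = 0; i < m; i++)
          tmp[i] = in[i];
        for (int i = 0; i < m; i++)
          out[i] += tmp[i] * tmp[i];
        free(tmp);
      }
    }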
32 changes: 19 additions & 13 deletions enzyme/Enzyme/AdjointGenerator.h
@@ -9024,9 +9024,9 @@ class AdjointGenerator
#endif
}
}

gutils->invertedPointers.erase(found);
bb.SetInsertPoint(placeholder->getNextNode());
if (&*bb.GetInsertPoint() == placeholder)
bb.SetInsertPoint(placeholder->getNextNode());
gutils->replaceAWithB(placeholder, anti);
gutils->erase(placeholder);

@@ -9058,7 +9058,8 @@
forwardsShadow) ||
(Mode == DerivativeMode::ReverseModeGradient &&
backwardsShadow)) {
zeroKnownAllocation(bb, anti, args, *called, gutils->TLI);
if (!inLoop)
zeroKnownAllocation(bb, anti, args, *called, gutils->TLI);
}
}
gutils->invertedPointers.insert(
@@ -9515,26 +9516,31 @@
return;
} else {
assert(Mode == DerivativeMode::ReverseModeCombined);
// If in a loop context, maintain the same free behavior.
if (auto inst = dyn_cast<Instruction>(rmat.first))
if (rmat.second.LI &&
rmat.second.LI->contains(inst->getParent())) {
return;
}
// In combined mode, if we don't need this allocation
// in the reverse, we can use the original deallocation
// behavior.
std::map<UsageKey, bool> Seen;
for (auto pair : gutils->knownRecomputeHeuristic)
if (!pair.second)
Seen[UsageKey(pair.first, ValueType::Primal)] = false;
bool primalNeededInReverse =
is_value_needed_in_reverse<ValueType::Primal>(
TR, gutils, rmat.first, Mode, Seen, oldUnreachable);
bool cacheWholeAllocation = false;
if (gutils->knownRecomputeHeuristic.count(rmat.first)) {
if (!gutils->knownRecomputeHeuristic[rmat.first])
if (!gutils->knownRecomputeHeuristic[rmat.first]) {
cacheWholeAllocation = true;
primalNeededInReverse = true;
}
}
// If in a loop context, maintain the same free behavior, unless
// caching whole allocation.
if (!cacheWholeAllocation)
if (auto inst = dyn_cast<Instruction>(rmat.first))
if (rmat.second.LI &&
rmat.second.LI->contains(inst->getParent())) {
return;
}
// In combined mode, if we don't need this allocation
// in the reverse, we can use the original deallocation
// behavior.
if (!primalNeededInReverse)
return;
}
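Taken together, the AdjointGenerator.h hunks adjust two decisions for such allocations: the shadow of an allocation that lives inside a rematerialized loop is no longer zero-initialized at the allocation site (the new !inLoop guard around zeroKnownAllocation), and in combined mode the original free is kept for a loop-level allocation only when the allocation is not being cached whole. A rough sketch of that second decision, with an invented helper name (not Enzyme API, just a restatement of the branches above):

    // cacheWholeAllocation: the recompute heuristic marked this allocation as
    // "do not recompute", so the entire buffer is cached for the reverse pass.
    bool keepOriginalFree(bool insideRematLoop, bool cacheWholeAllocation,
                          bool primalNeededInReverse) {
      if (cacheWholeAllocation)
        return false;                 // buffer must outlive the loop; freed later
      if (insideRematLoop)
        return true;                  // rematerialized per reverse iteration
      return !primalNeededInReverse;  // otherwise reuse the primal's own free
    }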
15 changes: 15 additions & 0 deletions enzyme/Enzyme/EnzymeLogic.cpp
@@ -616,10 +616,25 @@ void calculateUnusedValuesInFunction(

bool primalNeededInReverse = is_value_needed_in_reverse<ValueType::Primal>(
TR, gutils, pair.first, mode, CacheResults, oldUnreachable);
bool cacheWholeAllocation = false;

if (gutils->knownRecomputeHeuristic.count(pair.first)) {
if (!gutils->knownRecomputeHeuristic[pair.first]) {
primalNeededInReverse = true;
cacheWholeAllocation = true;
}
}
// If rematerializing a loop-level allocation, the primal allocation
// is not needed in the reverse.
if (!cacheWholeAllocation && primalNeededInReverse) {
auto found = gutils->rematerializableAllocations.find(
const_cast<CallInst *>(pair.first));
if (found != gutils->rematerializableAllocations.end()) {
if (auto inst = dyn_cast<Instruction>(pair.first))
if (found->second.LI &&
found->second.LI->contains(inst->getParent())) {
primalNeededInReverse = false;
}
}
}

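The EnzymeLogic.cpp hunk records the observation that makes this legal: if a loop-level allocation is rematerialized, its primal copy does not need to be preserved for the reverse pass even when its contents feed the gradient, because each reverse iteration recomputes them. For the remat.c test added below, a conceptual reverse of the outer loop might look like the following (hand-written sketch, not Enzyme-generated code; d_delv and d_e are the shadow buffers passed to __enzyme_autodiff):

    #include <stdlib.h>

    void square_reverse_sketch(double *delv, double *d_delv, double *d_e,
                               int **idx, int numReg, int numElemReg) {
      for (int r = numReg - 1; r >= 0; r--) {
        double *tmp = (double *)malloc((size_t)numElemReg * sizeof(double));
        for (int i = 0; i < numElemReg; i++)          // rematerialized forward body
          tmp[i] = delv[idx[r][i]];
        for (int i = numElemReg - 1; i >= 0; i--) {   // reverse of e[off] = tmp[i] * tmp[i]
          int off = idx[r][i];
          d_delv[off] += 2.0 * tmp[i] * d_e[off];
          d_e[off] = 0.0;                             // the store to e[off] consumed its adjoint
        }
        free(tmp);
      }
    }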
42 changes: 29 additions & 13 deletions enzyme/Enzyme/GradientUtils.cpp
@@ -1398,7 +1398,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
unwrap_cache.erase(blocks[j]);
lookup_cache.erase(blocks[j]);
SmallVector<Instruction *, 4> toErase;
for (auto &I : *blocks[j]) {
for (auto &I : llvm::reverse(*blocks[j])) {
toErase.push_back(&I);
}
for (auto I : toErase) {
@@ -2016,6 +2016,7 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
llvm::errs() << "branchingBlock: " << *branchingBlock << "\n";
}
assert(reverseBlocks.find(BB) != reverseBlocks.end());
assert(reverseBlocks.find(branchingBlock) != reverseBlocks.end());
LoopContext lc;
bool inLoop = getContext(BB, lc);

@@ -2099,7 +2100,7 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
} else {
BasicBlock *enterB = BasicBlock::Create(
BB->getContext(), "remat_enter", BB->getParent());
BasicBlock *exitB = resumeblock;
rematerializedLoops_cache[L] = enterB;
std::map<BasicBlock *, BasicBlock *> origToNewForward;
for (auto B : origLI->getBlocks()) {
BasicBlock *newB = BasicBlock::Create(
@@ -2116,12 +2117,6 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
IRBuilder<> NB(enterB);
NB.CreateBr(origToNewForward[origLI->getHeader()]);
}
{
llvm::SmallPtrSet<llvm::BasicBlock *, 8> origExitBlocks;
getExitBlocks(origLI, origExitBlocks);
for (auto EB : origExitBlocks)
origToNewForward[EB] = exitB;
}

std::function<void(Loop *, bool)> handleLoop = [&](Loop *OL,
bool subLoop) {
@@ -2503,10 +2498,20 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
}
}

// Remap a branch to the header to continue to the block.
llvm::SmallPtrSet<llvm::BasicBlock *, 8> origExitBlocks;
getExitBlocks(origLI, origExitBlocks);
// Remap a branch to the header to enter the incremented
// reverse of that block.
auto remap = [&](BasicBlock *rB) {
// Remap of an exit branch is to go to the reverse
// exiting block.
if (origExitBlocks.count(rB)) {
return reverseBlocks[getNewFromOriginal(B)].front();
}
// Reverse of an incrementing branch is go to the
// reverse of the branching block.
if (rB == origLI->getHeader())
return exitB;
return reverseBlocks[getNewFromOriginal(B)].front();
return origToNewForward[rB];
};

@@ -2553,7 +2558,7 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
}
}
}
rematerializedLoops_cache[L] = resumeblock = enterB;
resumeblock = enterB;
}
}

@@ -5250,8 +5255,19 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
if (auto origInst = isOriginal(inst)) {
auto found = rematerializableAllocations.find(origInst);
if (found != rematerializableAllocations.end())
if (found->second.LI)
scope = &newFunc->getEntryBlock();
if (found->second.LI && found->second.LI->contains(origInst)) {
bool cacheWholeAllocation = false;
if (knownRecomputeHeuristic.count(origInst)) {
if (!knownRecomputeHeuristic[origInst]) {
cacheWholeAllocation = true;
}
}
// If not caching whole allocation and rematerializing the allocation
// within the loop, force an entry-level scope so there is no need
// to cache.
if (!cacheWholeAllocation)
scope = &newFunc->getEntryBlock();
}
} else {
for (auto pair : backwardsOnlyShadows) {
if (auto pinst = dyn_cast<Instruction>(pair.first))
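The lookupM change at the end of this file is the caching side of the same rule: a value coming from a rematerializable loop allocation is given an entry-level scope, so lookups do not try to cache it across iterations, unless the whole allocation is being cached anyway. A sketch of just that guard (illustrative only; allocationScope is not a real Enzyme function):

    #include "llvm/IR/BasicBlock.h"

    // Returns the scope lookupM would use for a rematerializable allocation.
    llvm::BasicBlock *allocationScope(bool rematerializedInLoop,
                                      bool cacheWholeAllocation,
                                      llvm::BasicBlock *defaultScope,
                                      llvm::BasicBlock *entryBlock) {
      if (rematerializedInLoop && !cacheWholeAllocation)
        return entryBlock;  // recomputed each reverse iteration; no need to cache
      return defaultScope;
    }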
2 changes: 1 addition & 1 deletion enzyme/Enzyme/GradientUtils.h
@@ -796,7 +796,7 @@ class GradientUtils : public CacheUtility {
EmitWarning("NotPromotable", LI->getDebugLoc(), oldFunc,
LI->getParent(), " Could not promote allocation ", *V,
" due to load ", *LI,
" which does not postdominates store ");
" which does not postdominates store ", *res);
return;
}
}
68 changes: 33 additions & 35 deletions enzyme/test/Enzyme/ReverseMode/infAlloc2.ll
@@ -113,69 +113,67 @@ attributes #3 = { nounwind }
; CHECK-NEXT: call void @free(i8* %call)
; CHECK-NEXT: br label %for.cond, !llvm.loop !6

; CHECK: invertentry: ; preds = %invertfor.body
; CHECK: invertentry:
; CHECK-NEXT: %0 = insertvalue { double } undef, double %"rho0'de.0", 0
; CHECK-NEXT: ret { double } %0

; CHECK: incinvertfor.cond: ; preds = %invertfor.body
; CHECK-NEXT: %1 = add nsw i64 %"iv'ac.0", -1
; CHECK: invertfor.cond: ; preds = %invertfor.body, %remat_enter
; CHECK-NEXT: %"mul'de.0" = phi double [ %"mul'de.1", %invertfor.body ], [ %"mul'de.2", %remat_enter ]
; CHECK-NEXT: %"i10'de.0" = phi double [ %"i10'de.1", %invertfor.body ], [ %"i10'de.2", %remat_enter ]
; CHECK-NEXT: %"rho0'de.0" = phi double [ %"rho0'de.1", %invertfor.body ], [ %"rho0'de.2", %remat_enter ]
; CHECK-NEXT: %1 = icmp eq i64 %"iv'ac.0", 0
; CHECK-NEXT: br i1 %1, label %invertentry, label %incinvertfor.cond

; CHECK: incinvertfor.cond:
; CHECK-NEXT: %2 = add nsw i64 %"iv'ac.0", -1
; CHECK-NEXT: br label %remat_enter

; CHECK: invertfor.body: ; preds = %invertfor.cond1
; CHECK-NEXT: %"i4'ipc_unwrap" = bitcast i8* %"call'mi_cache.0" to double*
; CHECK-NEXT: %"i4'ipc_unwrap" = bitcast i8* %"call'mi" to double*
; CHECK-NEXT: store double 0.000000e+00, double* %"i4'ipc_unwrap", align 8
; CHECK-NEXT: tail call void @free(i8* nonnull %"call'mi_cache.0")
; CHECK-NEXT: tail call void @free(i8* %call_cache.0)
; CHECK-NEXT: %2 = icmp eq i64 %"iv'ac.0", 0
; CHECK-NEXT: br i1 %2, label %invertentry, label %incinvertfor.cond

; CHECK: invertfor.cond1: ; preds = %invertfor.end, %incinvertfor.cond1
; CHECK-NEXT: %"i10'de.0" = phi double [ %"i10'de.1", %invertfor.end ], [ 0.000000e+00, %incinvertfor.cond1 ]
; CHECK-NEXT: %"mul'de.0" = phi double [ %"mul'de.1", %invertfor.end ], [ 0.000000e+00, %incinvertfor.cond1 ]
; CHECK-NEXT: %"rho0'de.0" = phi double [ %"rho0'de.1", %invertfor.end ], [ %8, %incinvertfor.cond1 ]
; CHECK-NEXT: %"iv1'ac.0" = phi i64 [ 999999, %invertfor.end ], [ %4, %incinvertfor.cond1 ]
; CHECK-NEXT: tail call void @free(i8* nonnull %"call'mi")
; CHECK-NEXT: tail call void @free(i8* %remat_call)
; CHECK-NEXT: br label %invertfor.cond

; CHECK: invertfor.cond1: ; preds = %remat_for.cond_for.cond1, %incinvertfor.cond1
; CHECK-NEXT: %"mul'de.1" = phi double [ 0.000000e+00, %incinvertfor.cond1 ], [ %"mul'de.2", %remat_for.cond_for.cond1 ]
; CHECK-NEXT: %"i10'de.1" = phi double [ 0.000000e+00, %incinvertfor.cond1 ], [ %"i10'de.2", %remat_for.cond_for.cond1 ]
; CHECK-NEXT: %"rho0'de.1" = phi double [ %8, %incinvertfor.cond1 ], [ %"rho0'de.2", %remat_for.cond_for.cond1 ]
; CHECK-NEXT: %"iv1'ac.0" = phi i64 [ %4, %incinvertfor.cond1 ], [ 999999, %remat_for.cond_for.cond1 ]
; CHECK-NEXT: %3 = icmp eq i64 %"iv1'ac.0", 0
; CHECK-NEXT: br i1 %3, label %invertfor.body, label %incinvertfor.cond1

; CHECK: incinvertfor.cond1: ; preds = %invertfor.cond1
; CHECK-NEXT: %4 = add nsw i64 %"iv1'ac.0", -1
; CHECK-NEXT: %"i4'ipc_unwrap2" = bitcast i8* %"call'mi_cache.0" to double*
; CHECK-NEXT: %"i4'ipc_unwrap2" = bitcast i8* %"call'mi" to double*
; CHECK-NEXT: %"arrayidx5'ipg_unwrap" = getelementptr inbounds double, double* %"i4'ipc_unwrap2", i64 %"iv1'ac.0"
; CHECK-NEXT: %5 = load double, double* %"arrayidx5'ipg_unwrap", align 8
; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx5'ipg_unwrap", align 8
; CHECK-NEXT: %6 = fadd fast double %"mul'de.0", %5
; CHECK-NEXT: %6 = fadd fast double %"mul'de.1", %5
; CHECK-NEXT: %m0diffei10 = fmul fast double %6, %rho0
; CHECK-NEXT: %i4_unwrap3 = bitcast i8* %call_cache.0 to double*
; CHECK-NEXT: %sub_unwrap4 = sub i64 %"iv1'ac.0", 1
; CHECK-NEXT: %arrayidx4_unwrap5 = getelementptr inbounds double, double* %i4_unwrap3, i64 %sub_unwrap4
; CHECK-NEXT: %arrayidx4_unwrap5 = getelementptr inbounds double, double* %i4_unwrap, i64 %sub_unwrap4
; CHECK-NEXT: %i10_unwrap6 = load double, double* %arrayidx4_unwrap5, align 8, !invariant.group !7
; CHECK-NEXT: %m1differho0 = fmul fast double %6, %i10_unwrap6
; CHECK-NEXT: %7 = fadd fast double %"i10'de.0", %m0diffei10
; CHECK-NEXT: %8 = fadd fast double %"rho0'de.0", %m1differho0
; CHECK-NEXT: %7 = fadd fast double %"i10'de.1", %m0diffei10
; CHECK-NEXT: %8 = fadd fast double %"rho0'de.1", %m1differho0
; CHECK-NEXT: %"arrayidx4'ipg_unwrap" = getelementptr inbounds double, double* %"i4'ipc_unwrap2", i64 %sub_unwrap4
; CHECK-NEXT: %9 = load double, double* %"arrayidx4'ipg_unwrap", align 8
; CHECK-NEXT: %10 = fadd fast double %9, %7
; CHECK-NEXT: store double %10, double* %"arrayidx4'ipg_unwrap", align 8
; CHECK-NEXT: br label %invertfor.cond1

; CHECK: invertfor.end: ; preds = %remat_for.cond_for.cond1, %remat_enter
; CHECK-NEXT: %"call'mi_cache.0" = phi i8* [ %"call'mi_cache.1", %remat_enter ], [ %"call'mi", %remat_for.cond_for.cond1 ]
; CHECK-NEXT: %call_cache.0 = phi i8* [ %call_cache.1, %remat_enter ], [ %remat_call, %remat_for.cond_for.cond1 ]
; CHECK-NEXT: br label %invertfor.cond1

; CHECK: remat_enter: ; preds = %for.cond, %incinvertfor.cond
; CHECK-NEXT: %"i10'de.1" = phi double [ %"i10'de.0", %incinvertfor.cond ], [ 0.000000e+00, %for.cond ]
; CHECK-NEXT: %"mul'de.1" = phi double [ %"mul'de.0", %incinvertfor.cond ], [ 0.000000e+00, %for.cond ]
; CHECK-NEXT: %"call'mi_cache.1" = phi i8* [ %"call'mi_cache.0", %incinvertfor.cond ], [ undef, %for.cond ]
; CHECK-NEXT: %call_cache.1 = phi i8* [ %call_cache.0, %incinvertfor.cond ], [ undef, %for.cond ]
; CHECK-NEXT: %"rho0'de.1" = phi double [ %"rho0'de.0", %incinvertfor.cond ], [ %differeturn, %for.cond ]
; CHECK-NEXT: %"iv'ac.0" = phi i64 [ %1, %incinvertfor.cond ], [ %numReg, %for.cond ]
; CHECK-NEXT: %"mul'de.2" = phi double [ %"mul'de.0", %incinvertfor.cond ], [ 0.000000e+00, %for.cond ]
; CHECK-NEXT: %"i10'de.2" = phi double [ %"i10'de.0", %incinvertfor.cond ], [ 0.000000e+00, %for.cond ]
; CHECK-NEXT: %"rho0'de.2" = phi double [ %"rho0'de.0", %incinvertfor.cond ], [ %differeturn, %for.cond ]
; CHECK-NEXT: %"iv'ac.0" = phi i64 [ %2, %incinvertfor.cond ], [ %numReg, %for.cond ]
; CHECK-NEXT: %cmp_unwrap = icmp ne i64 %"iv'ac.0", %numReg
; CHECK-NEXT: br i1 %cmp_unwrap, label %remat_for.cond_for.body, label %invertfor.end
; CHECK-NEXT: br i1 %cmp_unwrap, label %remat_for.cond_for.body, label %invertfor.cond

; CHECK: remat_for.cond_for.body: ; preds = %remat_enter
; CHECK-NEXT: %remat_call = call noalias align 16 i8* @calloc(i64 8, i64 1000000)
; CHECK-NEXT: %"call'mi" = call noalias nonnull align 16 i8* @calloc(i64 8, i64 1000000)
; CHECK-NEXT: %remat_call = call noalias align 16 i8* @calloc(i64 8, i64 1000000)
; CHECK-NEXT: %"call'mi" = call noalias nonnull align 16 i8* @calloc(i64 8, i64 1000000)
; CHECK-NEXT: %i4_unwrap = bitcast i8* %remat_call to double*
; CHECK-NEXT: store double 1.000000e+00, double* %i4_unwrap, align 8
; CHECK-NEXT: br label %remat_for.cond_for.cond1
@@ -185,7 +183,7 @@ attributes #3 = { nounwind }
; CHECK-NEXT: %fiv = phi i64 [ %11, %remat_for.cond_for.body3 ], [ 0, %remat_for.cond_for.body ]
; CHECK-NEXT: %11 = add i64 %fiv, 1
; CHECK-NEXT: %cmp2_unwrap = icmp ne i64 %11, 1000000
; CHECK-NEXT: br i1 %cmp2_unwrap, label %remat_for.cond_for.body3, label %invertfor.end
; CHECK-NEXT: br i1 %cmp2_unwrap, label %remat_for.cond_for.body3, label %invertfor.cond1

; CHECK: remat_for.cond_for.body3: ; preds = %remat_for.cond_for.cond1
; CHECK-NEXT: %arrayidx5_unwrap = getelementptr inbounds double, double* %i4_unwrap, i64 %11
65 changes: 65 additions & 0 deletions enzyme/test/Integration/ReverseMode/remat.c
@@ -0,0 +1,65 @@
// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -

// test.c
#include <stdio.h>
#include <stdlib.h>

#include "test_utils.h"

extern void __enzyme_autodiff(void*, ...);
void square(double** p_delv, double** p_e, int ** idx, int numReg, int numElemReg) {
double* delv = *p_delv;
double* e = *p_e;
for (int r = 0; r < numReg; r++) {
double* tmp = (double*)malloc(numElemReg * sizeof(double));
for (int i=0; i<numElemReg; i++) {
int off = idx[r][i];
tmp[i] = delv[off];
}
for (int i=0; i<numElemReg; i++) {
int off = idx[r][i];
e[off] = tmp[i] * tmp[i];
}
free(tmp);
}
}
int main() {
int numReg = 100;
double *delv = (double*)malloc(sizeof(double)*numReg);
double *e = (double*)malloc(sizeof(double)*numReg);
double *d_delv = (double*)malloc(sizeof(double)*numReg);
double *d_e = (double*)malloc(sizeof(double)*numReg);
int* idxs[numReg];
int numRegElem = 200;
for (int i=0; i<numReg; i++) {
int* data = (int*)malloc(sizeof(int)*numRegElem);
for (int j=0; j<numRegElem; j++) {
data[j] = j % numReg;
}
idxs[i] = data;
delv[i] = i;
d_delv[i] = 0;
e[i] = 0;
d_e[i] = 1;
}

square(&delv, &e, idxs, numReg, numRegElem);
for (int i=0; i<numReg; i++) {
printf("e=%f delv=%f\n", e[i], delv[i]);
}

__enzyme_autodiff((void*)square, &delv, &d_delv, &e, &d_e, idxs, numReg, numRegElem);
for (int i=0; i<numReg; i++) {
printf("d_e=%f d_delv=%f\n", d_e[i], d_delv[i]);
APPROX_EQ(d_e[i], 0.0, 1e-10);
APPROX_EQ(d_delv[i], 2.0 * i, 1e-10);
}
}
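The APPROX_EQ expectations follow from the chain rule: e[off] = tmp[i] * tmp[i] with tmp[i] = delv[off], so d_delv[off] picks up 2 * delv[off] * d_e[off] while the store to e[off] consumes its adjoint. With d_e seeded to 1.0 and delv[i] initialized to i, that gives d_delv[i] == 2.0 * i and d_e[i] == 0.0. A standalone check could spell this out as follows (hypothetical helper, not part of the test; assumes the APPROX_EQ macro from test_utils.h):

    static void check_expected(const double *d_delv, const double *d_e, int n) {
      for (int i = 0; i < n; i++) {
        APPROX_EQ(d_e[i], 0.0, 1e-10);         // adjoint of e consumed by the overwrite
        APPROX_EQ(d_delv[i], 2.0 * i, 1e-10);  // d/d(delv[i]) of delv[i] * delv[i]
      }
    }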
