diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h
index 7ac3b5e15db10..caf026e87a9b6 100644
--- a/enzyme/Enzyme/AdjointGenerator.h
+++ b/enzyme/Enzyme/AdjointGenerator.h
@@ -9024,9 +9024,9 @@ class AdjointGenerator
 #endif
           }
         }
-        gutils->invertedPointers.erase(found);
-        bb.SetInsertPoint(placeholder->getNextNode());
+        if (&*bb.GetInsertPoint() == placeholder)
+          bb.SetInsertPoint(placeholder->getNextNode());
         gutils->replaceAWithB(placeholder, anti);
         gutils->erase(placeholder);
 
@@ -9058,7 +9058,8 @@
                    forwardsShadow) ||
                   (Mode == DerivativeMode::ReverseModeGradient &&
                    backwardsShadow)) {
-              zeroKnownAllocation(bb, anti, args, *called, gutils->TLI);
+              if (!inLoop)
+                zeroKnownAllocation(bb, anti, args, *called, gutils->TLI);
             }
           }
           gutils->invertedPointers.insert(
@@ -9515,15 +9516,6 @@ class AdjointGenerator
           return;
         } else {
           assert(Mode == DerivativeMode::ReverseModeCombined);
-          // If in a loop context, maintain the same free behavior.
-          if (auto inst = dyn_cast<Instruction>(rmat.first))
-            if (rmat.second.LI &&
-                rmat.second.LI->contains(inst->getParent())) {
-              return;
-            }
-          // In combined mode, if we don't need this allocation
-          // in the reverse, we can use the original deallocation
-          // behavior.
           std::map<UsageKey, bool> Seen;
           for (auto pair : gutils->knownRecomputeHeuristic)
             if (!pair.second)
@@ -9531,10 +9523,24 @@
           bool primalNeededInReverse = is_value_needed_in_reverse<ValueType::Primal>(
               TR, gutils, rmat.first, Mode, Seen, oldUnreachable);
+          bool cacheWholeAllocation = false;
           if (gutils->knownRecomputeHeuristic.count(rmat.first)) {
-            if (!gutils->knownRecomputeHeuristic[rmat.first])
+            if (!gutils->knownRecomputeHeuristic[rmat.first]) {
+              cacheWholeAllocation = true;
               primalNeededInReverse = true;
+            }
           }
+          // If in a loop context, maintain the same free behavior, unless
+          // caching whole allocation.
+          if (!cacheWholeAllocation)
+            if (auto inst = dyn_cast<Instruction>(rmat.first))
+              if (rmat.second.LI &&
+                  rmat.second.LI->contains(inst->getParent())) {
+                return;
+              }
+          // In combined mode, if we don't need this allocation
+          // in the reverse, we can use the original deallocation
+          // behavior.
           if (!primalNeededInReverse)
             return;
         }
diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp
index 24d6ec3c2d9da..4e1cbe526dcea 100644
--- a/enzyme/Enzyme/EnzymeLogic.cpp
+++ b/enzyme/Enzyme/EnzymeLogic.cpp
@@ -616,10 +616,25 @@ void calculateUnusedValuesInFunction(
         bool primalNeededInReverse = is_value_needed_in_reverse<ValueType::Primal>(
             TR, gutils, pair.first, mode, CacheResults, oldUnreachable);
+        bool cacheWholeAllocation = false;
         if (gutils->knownRecomputeHeuristic.count(pair.first)) {
           if (!gutils->knownRecomputeHeuristic[pair.first]) {
             primalNeededInReverse = true;
+            cacheWholeAllocation = true;
+          }
+        }
+        // If rematerializing a loop-level allocation, the primal allocation
+        // is not needed in the reverse.
+        if (!cacheWholeAllocation && primalNeededInReverse) {
+          auto found = gutils->rematerializableAllocations.find(
+              const_cast<Value *>(pair.first));
+          if (found != gutils->rematerializableAllocations.end()) {
+            if (auto inst = dyn_cast<Instruction>(pair.first))
+              if (found->second.LI &&
+                  found->second.LI->contains(inst->getParent())) {
+                primalNeededInReverse = false;
+              }
           }
         }
diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp
index de8c3aec88b1d..8aea8ffd56e48 100644
--- a/enzyme/Enzyme/GradientUtils.cpp
+++ b/enzyme/Enzyme/GradientUtils.cpp
@@ -1398,7 +1398,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
       unwrap_cache.erase(blocks[j]);
       lookup_cache.erase(blocks[j]);
       SmallVector<Instruction *, 4> toErase;
-      for (auto &I : *blocks[j]) {
+      for (auto &I : llvm::reverse(*blocks[j])) {
        toErase.push_back(&I);
       }
       for (auto I : toErase) {
@@ -2016,6 +2016,7 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
     llvm::errs() << "branchingBlock: " << *branchingBlock << "\n";
   }
   assert(reverseBlocks.find(BB) != reverseBlocks.end());
+  assert(reverseBlocks.find(branchingBlock) != reverseBlocks.end());
 
   LoopContext lc;
   bool inLoop = getContext(BB, lc);
@@ -2099,7 +2100,7 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
       } else {
         BasicBlock *enterB = BasicBlock::Create(
             BB->getContext(), "remat_enter", BB->getParent());
-        BasicBlock *exitB = resumeblock;
+        rematerializedLoops_cache[L] = enterB;
         std::map<BasicBlock *, BasicBlock *> origToNewForward;
         for (auto B : origLI->getBlocks()) {
           BasicBlock *newB = BasicBlock::Create(
@@ -2116,12 +2117,6 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
           IRBuilder<> NB(enterB);
           NB.CreateBr(origToNewForward[origLI->getHeader()]);
         }
-        {
-          llvm::SmallPtrSet<BasicBlock *, 4> origExitBlocks;
-          getExitBlocks(origLI, origExitBlocks);
-          for (auto EB : origExitBlocks)
-            origToNewForward[EB] = exitB;
-        }
 
         std::function<void(Loop *, bool)> handleLoop = [&](Loop *OL,
                                                            bool subLoop) {
@@ -2503,10 +2498,20 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
           }
         }
 
-        // Remap a branch to the header to continue to the block.
+        llvm::SmallPtrSet<BasicBlock *, 4> origExitBlocks;
+        getExitBlocks(origLI, origExitBlocks);
+        // Remap a branch to the header to enter the incremented
+        // reverse of that block.
         auto remap = [&](BasicBlock *rB) {
+          // Remap of an exit branch is to go to the reverse
+          // exiting block.
+          if (origExitBlocks.count(rB)) {
+            return reverseBlocks[getNewFromOriginal(B)].front();
+          }
+          // Reverse of an incrementing branch is to go to the
+          // reverse of the branching block.
           if (rB == origLI->getHeader())
-            return exitB;
+            return reverseBlocks[getNewFromOriginal(B)].front();
           return origToNewForward[rB];
         };
@@ -2553,7 +2558,7 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
             }
           }
         }
-        rematerializedLoops_cache[L] = resumeblock = enterB;
+        resumeblock = enterB;
       }
     }
@@ -5250,8 +5255,19 @@
       if (auto origInst = isOriginal(inst)) {
         auto found = rematerializableAllocations.find(origInst);
         if (found != rematerializableAllocations.end())
-          if (found->second.LI)
-            scope = &newFunc->getEntryBlock();
+          if (found->second.LI && found->second.LI->contains(origInst)) {
+            bool cacheWholeAllocation = false;
+            if (knownRecomputeHeuristic.count(origInst)) {
+              if (!knownRecomputeHeuristic[origInst]) {
+                cacheWholeAllocation = true;
+              }
+            }
+            // If not caching whole allocation and rematerializing the allocation
+            // within the loop, force an entry-level scope so there is no need
+            // to cache.
+            if (!cacheWholeAllocation)
+              scope = &newFunc->getEntryBlock();
+          }
       } else {
         for (auto pair : backwardsOnlyShadows) {
           if (auto pinst = dyn_cast<Instruction>(pair.first))
diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h
index 2d1b3bb5f2b5b..e11938387c087 100644
--- a/enzyme/Enzyme/GradientUtils.h
+++ b/enzyme/Enzyme/GradientUtils.h
@@ -796,7 +796,7 @@ class GradientUtils : public CacheUtility {
             EmitWarning("NotPromotable", LI->getDebugLoc(), oldFunc,
                         LI->getParent(), " Could not promote allocation ", *V,
                         " due to load ", *LI,
-                        " which does not postdominates store ");
+                        " which does not postdominates store ", *res);
             return;
           }
         }
diff --git a/enzyme/test/Enzyme/ReverseMode/infAlloc2.ll b/enzyme/test/Enzyme/ReverseMode/infAlloc2.ll
index f69c7f4d5a5d9..897d168d34bfa 100644
--- a/enzyme/test/Enzyme/ReverseMode/infAlloc2.ll
+++ b/enzyme/test/Enzyme/ReverseMode/infAlloc2.ll
@@ -113,69 +113,67 @@ attributes #3 = { nounwind }
 ; CHECK-NEXT:   call void @free(i8* %call)
 ; CHECK-NEXT:   br label %for.cond, !llvm.loop !6
 
-; CHECK: invertentry:                                     ; preds = %invertfor.body
+; CHECK: invertentry:
 ; CHECK-NEXT:   %0 = insertvalue { double } undef, double %"rho0'de.0", 0
 ; CHECK-NEXT:   ret { double } %0
 
-; CHECK: incinvertfor.cond:                               ; preds = %invertfor.body
-; CHECK-NEXT:   %1 = add nsw i64 %"iv'ac.0", -1
+; CHECK: invertfor.cond:                                  ; preds = %invertfor.body, %remat_enter
+; CHECK-NEXT:   %"mul'de.0" = phi double [ %"mul'de.1", %invertfor.body ], [ %"mul'de.2", %remat_enter ]
+; CHECK-NEXT:   %"i10'de.0" = phi double [ %"i10'de.1", %invertfor.body ], [ %"i10'de.2", %remat_enter ]
+; CHECK-NEXT:   %"rho0'de.0" = phi double [ %"rho0'de.1", %invertfor.body ], [ %"rho0'de.2", %remat_enter ]
+; CHECK-NEXT:   %1 = icmp eq i64 %"iv'ac.0", 0
+; CHECK-NEXT:   br i1 %1, label %invertentry, label %incinvertfor.cond
+
+; CHECK: incinvertfor.cond:
+; CHECK-NEXT:   %2 = add nsw i64 %"iv'ac.0", -1
 ; CHECK-NEXT:   br label %remat_enter
 
 ; CHECK: invertfor.body:                                  ; preds = %invertfor.cond1
-; CHECK-NEXT:   %"i4'ipc_unwrap" = bitcast i8* %"call'mi_cache.0" to double*
+; CHECK-NEXT:   %"i4'ipc_unwrap" = bitcast i8* %"call'mi" to double*
 ; CHECK-NEXT:   store double 0.000000e+00, double* %"i4'ipc_unwrap", align 8
-; CHECK-NEXT:   tail call void @free(i8* nonnull %"call'mi_cache.0")
-; CHECK-NEXT:   tail call void @free(i8* %call_cache.0)
-; CHECK-NEXT:   %2 = icmp eq i64 %"iv'ac.0", 0
-; CHECK-NEXT:   br i1 %2, label %invertentry, label %incinvertfor.cond
-
-; CHECK: invertfor.cond1:                                 ; preds = %invertfor.end, %incinvertfor.cond1
-; CHECK-NEXT:   %"i10'de.0" = phi double [ %"i10'de.1", %invertfor.end ], [ 0.000000e+00, %incinvertfor.cond1 ]
-; CHECK-NEXT:   %"mul'de.0" = phi double [ %"mul'de.1", %invertfor.end ], [ 0.000000e+00, %incinvertfor.cond1 ]
-; CHECK-NEXT:   %"rho0'de.0" = phi double [ %"rho0'de.1", %invertfor.end ], [ %8, %incinvertfor.cond1 ]
-; CHECK-NEXT:   %"iv1'ac.0" = phi i64 [ 999999, %invertfor.end ], [ %4, %incinvertfor.cond1 ]
+; CHECK-NEXT:   tail call void @free(i8* nonnull %"call'mi")
+; CHECK-NEXT:   tail call void @free(i8* %remat_call)
+; CHECK-NEXT:   br label %invertfor.cond
+
+; CHECK: invertfor.cond1:                                 ; preds = %remat_for.cond_for.cond1, %incinvertfor.cond1
+; CHECK-NEXT:   %"mul'de.1" = phi double [ 0.000000e+00, %incinvertfor.cond1 ], [ %"mul'de.2", %remat_for.cond_for.cond1 ]
+; CHECK-NEXT:   %"i10'de.1" = phi double [ 0.000000e+00, %incinvertfor.cond1 ], [ %"i10'de.2", %remat_for.cond_for.cond1 ]
+; CHECK-NEXT:   %"rho0'de.1" = phi double [ %8, %incinvertfor.cond1 ], [ %"rho0'de.2", %remat_for.cond_for.cond1 ]
+; CHECK-NEXT:   %"iv1'ac.0" = phi i64 [ %4, %incinvertfor.cond1 ], [ 999999, %remat_for.cond_for.cond1 ]
 ; CHECK-NEXT:   %3 = icmp eq i64 %"iv1'ac.0", 0
 ; CHECK-NEXT:   br i1 %3, label %invertfor.body, label %incinvertfor.cond1
 
 ; CHECK: incinvertfor.cond1:                              ; preds = %invertfor.cond1
 ; CHECK-NEXT:   %4 = add nsw i64 %"iv1'ac.0", -1
-; CHECK-NEXT:   %"i4'ipc_unwrap2" = bitcast i8* %"call'mi_cache.0" to double*
+; CHECK-NEXT:   %"i4'ipc_unwrap2" = bitcast i8* %"call'mi" to double*
 ; CHECK-NEXT:   %"arrayidx5'ipg_unwrap" = getelementptr inbounds double, double* %"i4'ipc_unwrap2", i64 %"iv1'ac.0"
 ; CHECK-NEXT:   %5 = load double, double* %"arrayidx5'ipg_unwrap", align 8
 ; CHECK-NEXT:   store double 0.000000e+00, double* %"arrayidx5'ipg_unwrap", align 8
-; CHECK-NEXT:   %6 = fadd fast double %"mul'de.0", %5
+; CHECK-NEXT:   %6 = fadd fast double %"mul'de.1", %5
 ; CHECK-NEXT:   %m0diffei10 = fmul fast double %6, %rho0
-; CHECK-NEXT:   %i4_unwrap3 = bitcast i8* %call_cache.0 to double*
 ; CHECK-NEXT:   %sub_unwrap4 = sub i64 %"iv1'ac.0", 1
-; CHECK-NEXT:   %arrayidx4_unwrap5 = getelementptr inbounds double, double* %i4_unwrap3, i64 %sub_unwrap4
+; CHECK-NEXT:   %arrayidx4_unwrap5 = getelementptr inbounds double, double* %i4_unwrap, i64 %sub_unwrap4
 ; CHECK-NEXT:   %i10_unwrap6 = load double, double* %arrayidx4_unwrap5, align 8, !invariant.group !7
 ; CHECK-NEXT:   %m1differho0 = fmul fast double %6, %i10_unwrap6
-; CHECK-NEXT:   %7 = fadd fast double %"i10'de.0", %m0diffei10
-; CHECK-NEXT:   %8 = fadd fast double %"rho0'de.0", %m1differho0
+; CHECK-NEXT:   %7 = fadd fast double %"i10'de.1", %m0diffei10
+; CHECK-NEXT:   %8 = fadd fast double %"rho0'de.1", %m1differho0
 ; CHECK-NEXT:   %"arrayidx4'ipg_unwrap" = getelementptr inbounds double, double* %"i4'ipc_unwrap2", i64 %sub_unwrap4
 ; CHECK-NEXT:   %9 = load double, double* %"arrayidx4'ipg_unwrap", align 8
 ; CHECK-NEXT:   %10 = fadd fast double %9, %7
 ; CHECK-NEXT:   store double %10, double* %"arrayidx4'ipg_unwrap", align 8
 ; CHECK-NEXT:   br label %invertfor.cond1
 
-; CHECK: invertfor.end:                                   ; preds = %remat_for.cond_for.cond1, %remat_enter
-; CHECK-NEXT:   %"call'mi_cache.0" = phi i8* [ %"call'mi_cache.1", %remat_enter ], [ %"call'mi", %remat_for.cond_for.cond1 ]
-; CHECK-NEXT:   %call_cache.0 = phi i8* [ %call_cache.1, %remat_enter ], [ %remat_call, %remat_for.cond_for.cond1 ]
-; CHECK-NEXT:   br label %invertfor.cond1
-
 ; CHECK: remat_enter:                                     ; preds = %for.cond, %incinvertfor.cond
-; CHECK-NEXT:   %"i10'de.1" = phi double [ %"i10'de.0", %incinvertfor.cond ], [ 0.000000e+00, %for.cond ]
-; CHECK-NEXT:   %"mul'de.1" = phi double [ %"mul'de.0", %incinvertfor.cond ], [ 0.000000e+00, %for.cond ]
-; CHECK-NEXT:   %"call'mi_cache.1" = phi i8* [ %"call'mi_cache.0", %incinvertfor.cond ], [ undef, %for.cond ]
-; CHECK-NEXT:   %call_cache.1 = phi i8* [ %call_cache.0, %incinvertfor.cond ], [ undef, %for.cond ]
-; CHECK-NEXT:   %"rho0'de.1" = phi double [ %"rho0'de.0", %incinvertfor.cond ], [ %differeturn, %for.cond ]
-; CHECK-NEXT:   %"iv'ac.0" = phi i64 [ %1, %incinvertfor.cond ], [ %numReg, %for.cond ]
+; CHECK-NEXT:   %"mul'de.2" = phi double [ %"mul'de.0", %incinvertfor.cond ], [ 0.000000e+00, %for.cond ]
+; CHECK-NEXT:   %"i10'de.2" = phi double [ %"i10'de.0", %incinvertfor.cond ], [ 0.000000e+00, %for.cond ]
+; CHECK-NEXT:   %"rho0'de.2" = phi double [ %"rho0'de.0", %incinvertfor.cond ], [ %differeturn, %for.cond ]
+; CHECK-NEXT:   %"iv'ac.0" = phi i64 [ %2, %incinvertfor.cond ], [ %numReg, %for.cond ]
 ; CHECK-NEXT:   %cmp_unwrap = icmp ne i64 %"iv'ac.0", %numReg
-; CHECK-NEXT:   br i1 %cmp_unwrap, label %remat_for.cond_for.body, label %invertfor.end
+; CHECK-NEXT:   br i1 %cmp_unwrap, label %remat_for.cond_for.body, label %invertfor.cond
 
 ; CHECK: remat_for.cond_for.body:                         ; preds = %remat_enter
-; CHECK-NEXT:   %remat_call = call noalias align 16 i8* @calloc(i64 8, i64 1000000)
-; CHECK-NEXT:   %"call'mi" = call noalias nonnull align 16 i8* @calloc(i64 8, i64 1000000)
+; CHECK-NEXT:   %remat_call = call noalias align 16 i8* @calloc(i64 8, i64 1000000)
+; CHECK-NEXT:   %"call'mi" = call noalias nonnull align 16 i8* @calloc(i64 8, i64 1000000)
 ; CHECK-NEXT:   %i4_unwrap = bitcast i8* %remat_call to double*
 ; CHECK-NEXT:   store double 1.000000e+00, double* %i4_unwrap, align 8
 ; CHECK-NEXT:   br label %remat_for.cond_for.cond1
@@ -185,7 +183,7 @@
 ; CHECK-NEXT:   %fiv = phi i64 [ %11, %remat_for.cond_for.body3 ], [ 0, %remat_for.cond_for.body ]
 ; CHECK-NEXT:   %11 = add i64 %fiv, 1
 ; CHECK-NEXT:   %cmp2_unwrap = icmp ne i64 %11, 1000000
-; CHECK-NEXT:   br i1 %cmp2_unwrap, label %remat_for.cond_for.body3, label %invertfor.end
+; CHECK-NEXT:   br i1 %cmp2_unwrap, label %remat_for.cond_for.body3, label %invertfor.cond1
 
 ; CHECK: remat_for.cond_for.body3:                        ; preds = %remat_for.cond_for.cond1
 ; CHECK-NEXT:   %arrayidx5_unwrap = getelementptr inbounds double, double* %i4_unwrap, i64 %11
diff --git a/enzyme/test/Integration/ReverseMode/remat.c b/enzyme/test/Integration/ReverseMode/remat.c
new file mode 100644
index 0000000000000..e0c625283116d
--- /dev/null
+++ b/enzyme/test/Integration/ReverseMode/remat.c
@@ -0,0 +1,65 @@
+// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
+// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
+// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
+// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
+// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
+// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
+// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
+// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
+
+// test.c
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "test_utils.h"
+
+extern void __enzyme_autodiff(void*, ...);
+void square(double** p_delv, double** p_e, int ** idx, int numReg, int numElemReg) {
+  double* delv = *p_delv;
+  double* e = *p_e;
+  for (int r = 0; r < numReg; r++) {
+    double* tmp = (double*)malloc(numElemReg * sizeof(double));
+    for (int i=0; i<numElemReg; i++) {

+#include <stdio.h>
+#include <stdlib.h>
+
+#include "test_utils.h"
+
+extern void __enzyme_autodiff(void*, ...);
+void square(double* __restrict__ delv, double* __restrict__ e, unsigned long long numReg) {
+  for (unsigned long long r = 0; r < 20; r++) {
+    double* tmp = (double*)malloc(numReg * sizeof(double));
+    for (unsigned long long i=0; i<numReg; i++) {
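
The shape that the new integration tests exercise is a loop body that mallocs a scratch buffer, fills it, consumes it, and frees it every iteration. With this patch, Enzyme rematerializes such an allocation and its stores when replaying the loop in the reverse pass (the remat_enter / remat_for.cond_* blocks checked in infAlloc2.ll above) instead of caching one buffer per iteration. Below is a minimal sketch of a function in that shape and how one would differentiate it. It is illustrative only: sumsq, its sizes, and the expected-gradient check are invented for this sketch, while __enzyme_autodiff and the enzyme_const activity marker are standard Enzyme test idioms.

// Illustrative sketch, not part of the patch: a loop-local allocation that is
// used and freed within each iteration, i.e. the kind of allocation Enzyme can
// rematerialize in the reverse pass rather than cache per iteration.
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

extern double __enzyme_autodiff(void *, ...);
extern int enzyme_const; // Enzyme activity marker (resolved by the Enzyme pass)

double sumsq(double x, unsigned long n) {
  double out = 0;
  for (unsigned long r = 0; r < 4; r++) {
    // Scratch buffer allocated, filled, consumed, and freed per iteration.
    double *tmp = (double *)malloc(n * sizeof(double));
    for (unsigned long i = 0; i < n; i++)
      tmp[i] = x * (double)(i + r);
    for (unsigned long i = 0; i < n; i++)
      out += tmp[i] * tmp[i];
    free(tmp);
  }
  return out;
}

int main() {
  double x = 2.0;
  unsigned long n = 8;
  // Differentiate sumsq with respect to x; n is marked inactive.
  double dx = __enzyme_autodiff((void *)sumsq, x, enzyme_const, n);
  // Analytically, d/dx of sum over r,i of (x*(i+r))^2 is 2*x*sum of (i+r)^2.
  double expected = 0;
  for (unsigned long r = 0; r < 4; r++)
    for (unsigned long i = 0; i < n; i++)
      expected += 2.0 * x * (double)((i + r) * (i + r));
  printf("dx = %f, expected = %f\n", dx, expected);
  return fabs(dx - expected) > 1e-8;
}

For a function like this, the reverse pass needs the primal tmp values, and the patch lets Enzyme recompute them by re-running the allocation and its stores per reverse iteration. That is what the !inLoop guard around zeroKnownAllocation, the cacheWholeAllocation checks, and the entry-block scope forced in lookupM coordinate: the whole-allocation cache path is taken only when the recompute heuristic rules out rematerialization.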