Skip to content

Commit

Permalink
Optimize away unused direct recursive tapes (#1874)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 11, 2024
1 parent e96ccd2 commit 53a31b2
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 3 deletions.
32 changes: 29 additions & 3 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2678,9 +2678,23 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(

SmallVector<Type *, 4> MallocTypes;

bool nonRecursiveUse = false;

for (auto a : gutils->getTapeValues()) {
MallocTypes.push_back(a->getType());
if (auto ei = dyn_cast<ExtractValueInst>(a)) {
auto tidx = returnMapping.find(AugmentedStruct::Tape)->second;
if (ei->getIndices().size() == 1 && ei->getIndices()[0] == (unsigned)tidx)
if (auto cb = dyn_cast<CallBase>(ei->getOperand(0)))
if (gutils->newFunc == cb->getCalledFunction())
continue;
}
nonRecursiveUse = true;
}
if (MallocTypes.size() == 0)
nonRecursiveUse = true;
if (!nonRecursiveUse)
MallocTypes.clear();

Type *tapeType = StructType::get(nf->getContext(), MallocTypes);

Expand Down Expand Up @@ -2930,6 +2944,17 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
}
++i;
}
} else if (!nonRecursiveUse) {
for (auto v : gutils->getTapeValues()) {
if (isa<UndefValue>(v))
continue;
auto EV = cast<ExtractValueInst>(v);
auto EV2 = cast<ExtractValueInst>(VMap[v]);
assert(EV->use_empty());
EV->eraseFromParent();
assert(EV2->use_empty());
EV2->eraseFromParent();
}
}

for (BasicBlock &BB : *nf) {
Expand Down Expand Up @@ -3047,11 +3072,12 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
}
}
for (auto user : fnusers) {
if (removeStruct) {
if (removeStruct || !nonRecursiveUse) {
IRBuilder<> B(user);
SmallVector<Value *, 4> args(user->arg_begin(), user->arg_end());
auto rep = B.CreateCall(NewF, args);
rep->takeName(user);
if (!rep->getType()->isVoidTy())
rep->takeName(user);
rep->copyIRFlags(user);
rep->setAttributes(user->getAttributes());
rep->setCallingConv(user->getCallingConv());
Expand Down Expand Up @@ -3083,7 +3109,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
PPC.ReplaceReallocs(NewF, /*mem2reg*/ true);

AugmentedCachedFunctions.find(tup)->second.fn = NewF;
if (recursive || (omp && !noTape))
if ((recursive && nonRecursiveUse) || (omp && !noTape))
AugmentedCachedFunctions.find(tup)->second.tapeType = tapeType;
AugmentedCachedFunctions.find(tup)->second.isComplete = true;

Expand Down
76 changes: 76 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/dacsum.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -loop-deletion -correlated-propagation -adce -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,loop(loop-deletion),correlated-propagation,adce,%simplifycfg)" -S | FileCheck %s

declare void @__enzyme_autodiff(...)

define void @dsquare(double* %arg, double* %arg1) {
bb:
tail call void (...) @__enzyme_autodiff(double (double*, i64, i64)* nonnull @sum, metadata !"enzyme_dup", double* %arg, double* %arg1, i64 1, i64 10)
ret void
}

; Function Attrs: nofree noinline
define internal fastcc double @sum(double* nocapture readonly %arg, i64 "enzyme_inactive" "enzyme_type"="{[-1]:Integer}" %arg1, i64 "enzyme_inactive" "enzyme_type"="{[-1]:Integer}" %arg2) {
bb:
%i8 = icmp eq i64 %arg2, %arg1
br i1 %i8, label %bb11, label %bb17

bb11: ; preds = %bb
%i12 = add i64 %arg2, -1
%i15 = getelementptr inbounds double, double* %arg, i64 %i12
%i16 = load double, double* %i15, align 8
br label %bb9

bb17: ; preds = %bb
%i18 = sub i64 %arg2, %arg1
%i45 = ashr i64 %i18, 1
%i46 = add i64 %i45, %arg1
%i47 = call fastcc double @sum(double* %arg, i64 signext %arg1, i64 signext %i46)
%i48 = add i64 %i46, 1
%i49 = call fastcc double @sum(double* %arg, i64 signext %i48, i64 signext %arg2)
%i50 = fadd double %i47, %i49
br label %bb9

bb9: ; preds = %bb44, %bb35, %bb20, %bb11
%i10 = phi double [ %i16, %bb11 ], [ %i50, %bb17 ]
ret double %i10
}


; CHECK: define internal fastcc void @diffesum(double* nocapture readonly %arg, double* nocapture %"arg'", i64 "enzyme_inactive" "enzyme_type"="{[-1]:Integer}" %arg1, i64 "enzyme_inactive" "enzyme_type"="{[-1]:Integer}" %arg2, double %differeturn)
; CHECK-NEXT: bb:
; CHECK-NEXT: %i8 = icmp eq i64 %arg2, %arg1
; CHECK-NEXT: br i1 %i8, label %invertbb9, label %bb17

; CHECK: bb17: ; preds = %bb
; CHECK-NEXT: %i18 = sub i64 %arg2, %arg1
; CHECK-NEXT: %i45 = ashr i64 %i18, 1
; CHECK-NEXT: %i46 = add i64 %i45, %arg1
; CHECK-NEXT: call fastcc void @augmented_sum(double* %arg, double* %"arg'", i64 signext %arg1, i64 signext %i46)
; CHECK-NEXT: br label %invertbb9

; CHECK: invertbb: ; preds = %invertbb17, %invertbb11
; CHECK-NEXT: ret void

; CHECK: invertbb11: ; preds = %invertbb9
; CHECK-NEXT: %i12_unwrap = add i64 %arg2, -1
; CHECK-NEXT: %"i15'ipg_unwrap" = getelementptr inbounds double, double* %"arg'", i64 %i12_unwrap
; CHECK-NEXT: %0 = load double, double* %"i15'ipg_unwrap", align 8
; CHECK-NEXT: %1 = fadd fast double %0, %3
; CHECK-NEXT: store double %1, double* %"i15'ipg_unwrap", align 8
; CHECK-NEXT: br label %invertbb

; CHECK: invertbb17: ; preds = %invertbb9
; CHECK-NEXT: %i18_unwrap = sub i64 %arg2, %arg1
; CHECK-NEXT: %i45_unwrap = ashr i64 %i18_unwrap, 1
; CHECK-NEXT: %i46_unwrap = add i64 %i45_unwrap, %arg1
; CHECK-NEXT: %i48_unwrap = add i64 %i46_unwrap, 1
; CHECK-NEXT: call fastcc void @diffesum(double* %arg, double* %"arg'", i64 signext %i48_unwrap, i64 signext %arg2, double %2)
; CHECK-NEXT: call fastcc void @diffesum.2(double* %arg, double* %"arg'", i64 signext %arg1, i64 signext %i46_unwrap, double %2)
; CHECK-NEXT: br label %invertbb

; CHECK: invertbb9: ; preds = %bb17, %bb
; CHECK-NEXT: %2 = select fast i1 %i8, double 0.000000e+00, double %differeturn
; CHECK-NEXT: %3 = select fast i1 %i8, double %differeturn, double 0.000000e+00
; CHECK-NEXT: br i1 %i8, label %invertbb11, label %invertbb17
; CHECK-NEXT: }

0 comments on commit 53a31b2

Please sign in to comment.