Skip to content

Commit

Permalink
Handle triangular and related loops
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Nov 7, 2019
1 parent 3102ace commit c17c043
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 10 deletions.
3 changes: 2 additions & 1 deletion bench/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

%-unopt.ll: %.cpp
#../build/bin/clang++ -fno-unroll-loops $^ -O3 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o $@ -ffast-math -Wno-error=non-pod-varargs -DEIGEN_UNROLLING_LIMIT=0 -I ../../adept-2.0.5/include -I../../tapenade/ADFirstAidKit
../build/bin/clang++ -fno-exceptions -fno-unroll-loops $^ -O3 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o $@ -ffast-math -Wno-error=non-pod-varargs -DEIGEN_UNROLLING_LIMIT=0 -I../../adept-2.0.5/include
#../build/bin/clang++ -fno-exceptions -fno-unroll-loops $^ -O3 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o $@ -ffast-math -Wno-error=non-pod-varargs -DEIGEN_UNROLLING_LIMIT=0 -I/home/wmoses/autodiff/adept-2.0.5/include
../build/bin/clang++ -fno-unroll-loops $^ -O3 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o $@ -ffast-math -Wno-error=non-pod-varargs -DEIGEN_UNROLLING_LIMIT=0 -I/home/wmoses/autodiff/adept-2.0.5/include

%-preopt.ll: %-unopt.ll
../build/bin/opt $^ -indvars -load=../enzyme/build/Enzyme/LLVMEnzyme-7.so -enzyme -mem2reg -sroa -early-cse-memssa -adce -bdce -simplifycfg -inline -adce -aggressive-instcombine -O2 -S -o $@
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
}

if (called == nullptr) {
llvm::errs() << gutils->newFunc << "\n";
assert(op);
llvm::errs() << "cannot handle augment non constant function\n" << *op << "\n";
report_fatal_error("unknown augment non constant function");
Expand Down
1 change: 0 additions & 1 deletion enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,6 @@ void removeRedundantIVs(const Loop* L, BasicBlock* Header, BasicBlock* Preheader
assert(cmp->getOperand(0) == increment);

auto scv = SE.getSCEVAtScope(cmp->getOperand(1), L);
llvm::errs() << "coing to think about " << *cmp << "\n";
if (cmp->isUnsigned() || (scv != SE.getCouldNotCompute() && SE.isKnownNonNegative(scv)) ) {

// valid replacements (since unsigned comparison and i starts at 0 counting up)
Expand Down
22 changes: 15 additions & 7 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,7 @@ class GradientUtils {
if (!getContext(blk, idx)) {
break;
}
llvm::errs() << " adding to contexts: " << idx.header->getName() << " starting ctx=" << ctx->getName() << "\n";
//llvm::errs() << " adding to contexts: " << idx.header->getName() << " starting ctx=" << ctx->getName() << "\n";
contexts.emplace_back(idx);
blk = idx.preheader;
}
Expand All @@ -926,11 +926,17 @@ class GradientUtils {
ValueToValueMapTy emptyMap;
IRBuilder <> allocationBuilder(&allocationPreheaders[i]->back());
Value* limitMinus1 = unwrapM(contexts[i].limit, allocationBuilder, emptyMap, /*lookupIfAble*/false);

if (limitMinus1 == nullptr) {
allocationPreheaders[i] = contexts[i].preheader;
allocationBuilder.SetInsertPoint(&allocationPreheaders[i]->back());
limitMinus1 = unwrapM(contexts[i].limit, allocationBuilder, emptyMap, /*lookupIfAble*/false);
/*
assert(allocationPreheaders[i]);
llvm::errs() << *oldFunc << "\n";
llvm::errs() << *newFunc << "\n";
llvm::errs() << "needed value " << *contexts[i].limit << " at " << allocationPreheaders[i]->getName() << "\n";
*/
}
assert(limitMinus1 != nullptr);
limits[i] = allocationBuilder.CreateNUWAdd(limitMinus1, ConstantInt::get(limitMinus1->getType(), 1));
Expand Down Expand Up @@ -961,17 +967,17 @@ class GradientUtils {
size = allocationBuilder.CreateNUWMul(size, limits[i]);
}

llvm::errs() << "considering ctx " << ctx->getName() << " alph=" << allocationPreheaders[i]->getName() << " ctxheader=" << contexts[i].header->getName() << "\n";
if (contexts[i].dynamic) {
llvm::errs() << "starting outermost ph at " << allocationPreheaders[i]->getName() << "|ctx=" << ctx->getName() <<"\n";
//llvm::errs() << "considering ctx " << ctx->getName() << " alph=" << allocationPreheaders[i]->getName() << " ctxheader=" << contexts[i].header->getName() << "\n";
if ( (i+1 < contexts.size()) && (allocationPreheaders[i] != allocationPreheaders[i+1]) ) {
//llvm::errs() << "starting outermost ph at " << allocationPreheaders[i]->getName() << "|ctx=" << ctx->getName() <<"\n";
sublimits.push_back(std::make_pair(size, lims));
size = nullptr;
lims.clear();
}
}

if (size != nullptr) {
llvm::errs() << "starting final outermost ph at " << allocationPreheaders[contexts.size()-1]->getName()<<"|ctx=" << ctx->getName() << "\n";
//llvm::errs() << "starting final outermost ph at " << allocationPreheaders[contexts.size()-1]->getName()<<"|ctx=" << ctx->getName() << "\n";
sublimits.push_back(std::make_pair(size, lims));
lims.clear();
}
Expand Down Expand Up @@ -1070,8 +1076,9 @@ class GradientUtils {
if (tbuild.GetInsertBlock()->size()) {
tbuild.SetInsertPoint(tbuild.GetInsertBlock()->getFirstNonPHI());
}

auto ci = cast<CallInst>(CallInst::CreateFree(tbuild.CreatePointerCast(tbuild.CreateLoad(unwrapM(storeInto, tbuild, antimap, /*lookup*/false)), Type::getInt8PtrTy(ctx->getContext())), tbuild.GetInsertBlock()));
auto forfree = cast<LoadInst>(tbuild.CreateLoad(unwrapM(storeInto, tbuild, antimap, /*lookup*/false)));
forfree->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(forfree->getContext(), {}));
auto ci = cast<CallInst>(CallInst::CreateFree(tbuild.CreatePointerCast(forfree, Type::getInt8PtrTy(ctx->getContext())), tbuild.GetInsertBlock()));
ci->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull);
if (ci->getParent()==nullptr) {
tbuild.Insert(ci);
Expand Down Expand Up @@ -1103,6 +1110,7 @@ class GradientUtils {
Value* next = cache;
for(int i=sublimits.size()-1; i>=0; i--) {
next = BuilderM.CreateLoad(next);
cast<LoadInst>(next)->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(next->getContext(), {}));

const auto& containedloops = sublimits[i].second;

Expand Down
117 changes: 117 additions & 0 deletions enzyme/test/Enzyme/triangular.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -licm -early-cse -simplifycfg -instsimplify -S | FileCheck %s

; Function Attrs: noinline norecurse nounwind uwtable
define dso_local double @get(double* nocapture %x, i64 %i, i64 %j) local_unnamed_addr #0 {
entry:
%arrayidx = getelementptr inbounds double, double* %x, i64 %i
%0 = load double, double* %arrayidx, align 8
store double 0.000000e+00, double* %arrayidx, align 8
ret double %0
}

; Function Attrs: noinline norecurse nounwind uwtable
define dso_local double @f(double* nocapture %x, i64 %n) #0 {
entry:
br label %for.cond3.preheader

for.cond3.preheader: ; preds = %entry, %for.cond.cleanup6
%i = phi i64 [ %i_inc, %for.cond.cleanup6 ], [ 0, %entry ]
%outersum = phi double [ %sum.1.lcssa, %for.cond.cleanup6 ], [ 0.000000e+00, %entry ]
%i_inc = add nuw i64 %i, 1
; note this is now technically not exactly triangular
; %cmp423 = icmp eq i64 %i, 0
; br i1 %cmp423, label %for.cond.cleanup6, label %for.body7
br label %for.body7

for.body7: ; preds = %for.cond3.preheader, %for.body7
%j = phi i64 [ %j_inc, %for.body7 ], [ 0, %for.cond3.preheader ]
%innersum = phi double [ %add, %for.body7 ], [ %outersum, %for.cond3.preheader ]
%call = tail call fast double @get(double* %x, i64 undef, i64 %j)
%mul = fmul fast double %call, %call
%add = fadd fast double %mul, %innersum
%j_inc = add nuw i64 %j, 1
%exitcond = icmp eq i64 %j, %i
br i1 %exitcond, label %for.cond.cleanup6, label %for.body7

for.cond.cleanup6: ; preds = %for.body7
%sum.1.lcssa = phi double [ %add, %for.body7 ]
%cmp1 = icmp eq i64 %i, %n
br i1 %cmp1, label %return, label %for.cond3.preheader

return: ; preds = %for.cond.cleanup6
%retval.0 = phi double [ %sum.1.lcssa, %for.cond.cleanup6 ]
ret double %retval.0
}

; Function Attrs: noinline nounwind uwtable
define dso_local double @dsumsquare(double* %x, double* %xp, i64 %n) local_unnamed_addr #1 {
entry:
%call = tail call fast double @__enzyme_autodiff(i8* bitcast (double (double*, i64)* @f to i8*), double* %x, double* %xp, i64 %n)
ret double %call
}

declare dso_local double @__enzyme_autodiff(i8*, double*, double*, i64) local_unnamed_addr

attributes #0 = { noinline norecurse nounwind uwtable }
attributes #1 = { noinline nounwind uwtable }

; CHECK: define internal {} @diffef(double* nocapture %x, double* %"x'", i64 %n, double %differeturn) #0 {
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = add nuw i64 %n, 1
; CHECK-NEXT: %mallocsize = mul i64 %0, 8
; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 %mallocsize)
; CHECK-NEXT: %call_malloccache = bitcast i8* %malloccall to double**
; CHECK-NEXT: br label %for.cond3.preheader

; CHECK: for.cond3.preheader: ; preds = %for.cond.cleanup6, %entry
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.cond.cleanup6 ], [ 0, %entry ]
; CHECK-NEXT: %iv.next = add nuw i64 %iv, 1
; CHECK-NEXT: %[[mallocgep1:.+]] = getelementptr double*, double** %call_malloccache, i64 %iv
; CHECK-NEXT: %mallocsize3 = mul i64 %iv.next, 8
; CHECK-NEXT: %malloccall4 = tail call noalias nonnull i8* @malloc(i64 %mallocsize3)
; CHECK-NEXT: %call_malloccache5 = bitcast i8* %malloccall4 to double*
; CHECK-NEXT: store double* %call_malloccache5, double** %[[mallocgep1]]
; CHECK-NEXT: br label %for.body7

; CHECK: for.body7: ; preds = %for.body7, %for.cond3.preheader
; CHECK-NEXT: %iv1 = phi i64 [ %iv.next2, %for.body7 ], [ 0, %for.cond3.preheader ]
; CHECK-NEXT: %iv.next2 = add nuw i64 %iv1, 1
; CHECK-NEXT: %[[augmented:.+]] = call { {}, double } @augmented_get(double* %x, double* %"x'", i64 undef, i64 %iv1)
; CHECK-NEXT: %[[retval:.+]] = extractvalue { {}, double } %[[augmented]], 1
; CHECK-NEXT: %[[mallocgep2:.+]] = getelementptr double, double* %call_malloccache5, i64 %iv1
; CHECK-NEXT: store double %[[retval]], double* %[[mallocgep2]]
; CHECK-NEXT: %exitcond = icmp eq i64 %iv1, %iv
; CHECK-NEXT: br i1 %exitcond, label %for.cond.cleanup6, label %for.body7

; CHECK: for.cond.cleanup6: ; preds = %for.body7
; CHECK-NEXT: %cmp1 = icmp eq i64 %iv, %n
; CHECK-NEXT: br i1 %cmp1, label %invertfor.cond.cleanup6, label %for.cond3.preheader

; CHECK: invertentry: ; preds = %invertfor.cond3.preheader
; CHECK-NEXT: tail call void @free(i8* nonnull %malloccall)
; CHECK-NEXT: ret {} undef

; CHECK: invertfor.cond3.preheader: ; preds = %invertfor.body7
; CHECK-NEXT: %[[innerdatai8:.+]] = bitcast double* %[[innerdata:.+]] to i8*
; CHECK-NEXT: tail call void @free(i8* nonnull %[[innerdatai8]])
; CHECK-NEXT: %[[done1:.+]] = icmp eq i64 %"iv'phi", 0
; CHECK-NEXT: br i1 %[[done1]], label %invertentry, label %invertfor.cond.cleanup6

; CHECK: invertfor.body7: ; preds = %invertfor.cond.cleanup6, %invertfor.body7
; CHECK-NEXT: %"iv1'phi" = phi i64 [ %"iv'phi", %invertfor.cond.cleanup6 ], [ %[[subinner:.+]], %invertfor.body7 ]
; CHECK-NEXT: %[[subinner]] = sub i64 %"iv1'phi", 1
; CHECK-NEXT: %[[invertedgep2:.+]] = getelementptr double, double* %[[innerdata]], i64 %"iv1'phi"
; CHECK-NEXT: %[[cached:.+]] = load double, double* %[[invertedgep2]], !invariant.load !0
; CHECK-NEXT: %m0diffecall = fmul fast double %differeturn, %[[cached]]
; CHECK-NEXT: %[[innerdiffe:.+]] = fadd fast double %m0diffecall, %m0diffecall
; CHECK-NEXT: %[[dcall:.+]] = call {} @diffeget(double* %x, double* %"x'", i64 undef, i64 %"iv1'phi", double %[[innerdiffe]], {} undef)
; CHECK-NEXT: %[[done2:.+]] = icmp eq i64 %"iv1'phi", 0
; CHECK-NEXT: br i1 %[[done2]], label %invertfor.cond3.preheader, label %invertfor.body7

; CHECK: invertfor.cond.cleanup6: ; preds = %for.cond.cleanup6, %invertfor.cond3.preheader
; CHECK-NEXT: %"iv'phi" = phi i64 [ %[[subouter:.+]], %invertfor.cond3.preheader ], [ %n, %for.cond.cleanup6 ]
; CHECK-NEXT: %[[subouter]] = sub i64 %"iv'phi", 1
; CHECK-NEXT: %[[invertedgep1:.+]] = getelementptr double*, double** %call_malloccache, i64 %"iv'phi"
; CHECK-NEXT: %[[innerdata]] = load double*, double** %[[invertedgep1]]
; CHECK-NEXT: br label %invertfor.body7
; CHECK-NEXT: }
2 changes: 1 addition & 1 deletion tests/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ eigen.ll: eigen.cpp
../build/bin/opt -indvars -instcombine -functionattrs -instnamer eigen0.ll -S -o eigen.ll

eigen2.ll: eigen.ll
../build/bin/opt -always-inline -inline -lower-autodiff eigen.ll -S -o eigen1.ll
../build/bin/opt -always-inline -inline -load=../enzyme/build/Enzyme/LLVMEnzyme-7.so -enzyme eigen.ll -S -o eigen1.ll
#../build/bin/opt -always-inline -inline -lower-autodiff eigen.ll -S -o eigen1.ll -autodiff_inline
../build/bin/opt -mem2reg eigen1.ll -S -o eigena.ll
../build/bin/opt -sroa eigena.ll -S -o eigenb.ll
Expand Down
3 changes: 3 additions & 0 deletions tests/eigen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ static void matvec(const MatrixXd& __restrict W,
//printf("foo.rows()=%ld foo.cols()=%ld\n", foo.rows(), foo.cols());
//printf("b.rows()=%ld b.cols()=%ld\n", b.rows(), b.cols());
//printf("W.rows()=%ld W.cols()=%ld\n", W.rows(), W.cols());

/*
auto wr = W.rows();
__builtin_assume(wr > 0);
auto wr8 = wr << 3;
Expand All @@ -80,6 +82,7 @@ static void matvec(const MatrixXd& __restrict W,
__builtin_assume(fr == wr);
__builtin_assume(wc == br);
*/
foo = W * b;
//printf("r foo.rows()=%ld foo.cols()=%ld\n", foo.rows(), foo.cols());
//printf("r b.rows()=%ld b.cols()=%ld\n", b.rows(), b.cols());
Expand Down

0 comments on commit c17c043

Please sign in to comment.