Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle Triangular Loops #22

Merged
merged 2 commits into from
Nov 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion bench/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

%-unopt.ll: %.cpp
#../build/bin/clang++ -fno-unroll-loops $^ -O3 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o $@ -ffast-math -Wno-error=non-pod-varargs -DEIGEN_UNROLLING_LIMIT=0 -I ../../adept-2.0.5/include -I../../tapenade/ADFirstAidKit
../build/bin/clang++ -fno-exceptions -fno-unroll-loops $^ -O3 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o $@ -ffast-math -Wno-error=non-pod-varargs -DEIGEN_UNROLLING_LIMIT=0 -I../../adept-2.0.5/include
#../build/bin/clang++ -fno-exceptions -fno-unroll-loops $^ -O3 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o $@ -ffast-math -Wno-error=non-pod-varargs -DEIGEN_UNROLLING_LIMIT=0 -I/home/wmoses/autodiff/adept-2.0.5/include
../build/bin/clang++ -fno-unroll-loops $^ -O3 -fno-vectorize -fno-slp-vectorize -S -emit-llvm -o $@ -ffast-math -Wno-error=non-pod-varargs -DEIGEN_UNROLLING_LIMIT=0 -I/home/wmoses/autodiff/adept-2.0.5/include

%-preopt.ll: %-unopt.ll
../build/bin/opt $^ -indvars -load=../enzyme/build/Enzyme/LLVMEnzyme-7.so -enzyme -mem2reg -sroa -early-cse-memssa -adce -bdce -simplifycfg -inline -adce -aggressive-instcombine -O2 -S -o $@
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
}

if (called == nullptr) {
llvm::errs() << gutils->newFunc << "\n";
assert(op);
llvm::errs() << "cannot handle augment non constant function\n" << *op << "\n";
report_fatal_error("unknown augment non constant function");
Expand Down
1 change: 0 additions & 1 deletion enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,6 @@ void removeRedundantIVs(const Loop* L, BasicBlock* Header, BasicBlock* Preheader
assert(cmp->getOperand(0) == increment);

auto scv = SE.getSCEVAtScope(cmp->getOperand(1), L);
llvm::errs() << "coing to think about " << *cmp << "\n";
if (cmp->isUnsigned() || (scv != SE.getCouldNotCompute() && SE.isKnownNonNegative(scv)) ) {

// valid replacements (since unsigned comparison and i starts at 0 counting up)
Expand Down
49 changes: 19 additions & 30 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,6 @@ class GradientUtils {
if (!getContext(blk, idx)) {
break;
}
llvm::errs() << " adding to contexts: " << idx.header->getName() << " starting ctx=" << ctx->getName() << "\n";
contexts.emplace_back(idx);
blk = idx.preheader;
}
Expand All @@ -925,29 +924,19 @@ class GradientUtils {
if (contexts[i].dynamic) {
limits[i] = ConstantInt::get(Type::getInt64Ty(ctx->getContext()), 1);
} else {
//while (limits[i] == nullptr) {
ValueToValueMapTy emptyMap;
IRBuilder <> allocationBuilder(&allocationPreheaders[i]->back());
Value* limitMinus1 = unwrapM(contexts[i].limit, allocationBuilder, emptyMap, /*lookupIfAble*/false);
if (limitMinus1 == nullptr) {
assert(allocationPreheaders[i]);
llvm::errs() << *oldFunc << "\n";
llvm::errs() << *newFunc << "\n";
llvm::errs() << "needed value " << *contexts[i].limit << " at " << allocationPreheaders[i]->getName() << "\n";
}
assert(limitMinus1 != nullptr);
limits[i] = allocationBuilder.CreateNUWAdd(limitMinus1, ConstantInt::get(limitMinus1->getType(), 1));
//TODO allow triangular arrays per above
/*
if (limits[i] == nullptr) {
int firstDifferent = j+1;
while (allocationPreheaders[firstDifferent] == allocationPreheaders[i]) {
firstDifferent++;
assert(firstDifferent < contexts.size());
}
allocationPreheaders[i] = allocationPreheaders[firstDifferent];
}
}*/
ValueToValueMapTy emptyMap;
IRBuilder <> allocationBuilder(&allocationPreheaders[i]->back());
Value* limitMinus1 = unwrapM(contexts[i].limit, allocationBuilder, emptyMap, /*lookupIfAble*/false);

// We have a loop with static bounds, but whose limit is not available to be computed at the current loop preheader (such as the innermost loop of triangular iteration domain)
// Handle this case like a dynamic loop
if (limitMinus1 == nullptr) {
allocationPreheaders[i] = contexts[i].preheader;
allocationBuilder.SetInsertPoint(&allocationPreheaders[i]->back());
limitMinus1 = unwrapM(contexts[i].limit, allocationBuilder, emptyMap, /*lookupIfAble*/false);
}
assert(limitMinus1 != nullptr);
limits[i] = allocationBuilder.CreateNUWAdd(limitMinus1, ConstantInt::get(limitMinus1->getType(), 1));
}
}

Expand All @@ -964,17 +953,15 @@ class GradientUtils {
size = allocationBuilder.CreateNUWMul(size, limits[i]);
}

llvm::errs() << "considering ctx " << ctx->getName() << " alph=" << allocationPreheaders[i]->getName() << " ctxheader=" << contexts[i].header->getName() << "\n";
if (contexts[i].dynamic) {
llvm::errs() << "starting outermost ph at " << allocationPreheaders[i]->getName() << "|ctx=" << ctx->getName() <<"\n";
// We are now starting a new allocation context
if ( (i+1 < contexts.size()) && (allocationPreheaders[i] != allocationPreheaders[i+1]) ) {
sublimits.push_back(std::make_pair(size, lims));
size = nullptr;
lims.clear();
}
}

if (size != nullptr) {
llvm::errs() << "starting final outermost ph at " << allocationPreheaders[contexts.size()-1]->getName()<<"|ctx=" << ctx->getName() << "\n";
sublimits.push_back(std::make_pair(size, lims));
lims.clear();
}
Expand Down Expand Up @@ -1073,8 +1060,9 @@ class GradientUtils {
if (tbuild.GetInsertBlock()->size()) {
tbuild.SetInsertPoint(tbuild.GetInsertBlock()->getFirstNonPHI());
}

auto ci = cast<CallInst>(CallInst::CreateFree(tbuild.CreatePointerCast(tbuild.CreateLoad(unwrapM(storeInto, tbuild, antimap, /*lookup*/false)), Type::getInt8PtrTy(ctx->getContext())), tbuild.GetInsertBlock()));
auto forfree = cast<LoadInst>(tbuild.CreateLoad(unwrapM(storeInto, tbuild, antimap, /*lookup*/false)));
forfree->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(forfree->getContext(), {}));
auto ci = cast<CallInst>(CallInst::CreateFree(tbuild.CreatePointerCast(forfree, Type::getInt8PtrTy(ctx->getContext())), tbuild.GetInsertBlock()));
ci->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull);
if (ci->getParent()==nullptr) {
tbuild.Insert(ci);
Expand Down Expand Up @@ -1106,6 +1094,7 @@ class GradientUtils {
Value* next = cache;
for(int i=sublimits.size()-1; i>=0; i--) {
next = BuilderM.CreateLoad(next);
cast<LoadInst>(next)->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(next->getContext(), {}));

const auto& containedloops = sublimits[i].second;

Expand Down
44 changes: 44 additions & 0 deletions enzyme/test/Enzyme/globalptr.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -mem2reg -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s

; XFAIL: *
; a function returning a pointer/float with no arguments is mistakenly marked as constant in spite of accessing a global

@global = external dso_local local_unnamed_addr global double*, align 8

; Function Attrs: noinline norecurse nounwind readonly uwtable
define dso_local double* @myglobal() local_unnamed_addr #0 {
entry:
%0 = load double*, double** @global, align 8
ret double* %0
}

; Function Attrs: noinline norecurse nounwind readonly uwtable
define dso_local double @mulglobal(double %x) #0 {
entry:
%call = tail call double* @myglobal()
%arrayidx = getelementptr inbounds double, double* %call, i64 2
%0 = load double, double* %arrayidx, align 8
%mul = fmul fast double %0, %x
ret double %mul
}

; Function Attrs: noinline nounwind uwtable
define dso_local double @derivative(double %x) local_unnamed_addr #1 {
entry:
%0 = tail call double (...) @__enzyme_autodiff.f64(double (double)* nonnull @mulglobal, double %x) #2
ret double %0
}

declare double @__enzyme_autodiff.f64(...) local_unnamed_addr

attributes #0 = { noinline norecurse nounwind readonly uwtable }
attributes #1 = { noinline nounwind uwtable }
attributes #2 = { nounwind }

; CHECK: define internal { double } @diffemulglobal(double %x, double %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %call = tail call double* @myglobal()
; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %call, i64 2
; CHECK-NEXT: %0 = load double, double* %arrayidx, align 8
; CHECK-NEXT: %[[tmul:.+]] = fmul fast double %0, %x
; CHECK-NEXT: %[[tcall.+]] = call {} @diffemyglobal(double %x)
116 changes: 116 additions & 0 deletions enzyme/test/Enzyme/triangular.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -licm -early-cse -simplifycfg -instsimplify -S | FileCheck %s

; Function Attrs: noinline norecurse nounwind uwtable
define dso_local double @get(double* nocapture %x, i64 %i, i64 %j) local_unnamed_addr #0 {
entry:
%arrayidx = getelementptr inbounds double, double* %x, i64 %i
%0 = load double, double* %arrayidx, align 8
ret double %0
}

; Function Attrs: noinline norecurse nounwind uwtable
define dso_local double @f(double* nocapture %x, i64 %n) #0 {
entry:
br label %for.cond3.preheader

for.cond3.preheader: ; preds = %entry, %for.cond.cleanup6
%i = phi i64 [ %i_inc, %for.cond.cleanup6 ], [ 0, %entry ]
%outersum = phi double [ %sum.1.lcssa, %for.cond.cleanup6 ], [ 0.000000e+00, %entry ]
%i_inc = add nuw i64 %i, 1
; note this is now technically not exactly triangular
; %cmp423 = icmp eq i64 %i, 0
; br i1 %cmp423, label %for.cond.cleanup6, label %for.body7
br label %for.body7

for.body7: ; preds = %for.cond3.preheader, %for.body7
%j = phi i64 [ %j_inc, %for.body7 ], [ 0, %for.cond3.preheader ]
%innersum = phi double [ %add, %for.body7 ], [ %outersum, %for.cond3.preheader ]
%call = tail call fast double @get(double* %x, i64 undef, i64 %j)
%mul = fmul fast double %call, %call
%add = fadd fast double %mul, %innersum
%j_inc = add nuw i64 %j, 1
%exitcond = icmp eq i64 %j, %i
br i1 %exitcond, label %for.cond.cleanup6, label %for.body7

for.cond.cleanup6: ; preds = %for.body7
%sum.1.lcssa = phi double [ %add, %for.body7 ]
%cmp1 = icmp eq i64 %i, %n
br i1 %cmp1, label %return, label %for.cond3.preheader

return: ; preds = %for.cond.cleanup6
%retval.0 = phi double [ %sum.1.lcssa, %for.cond.cleanup6 ]
ret double %retval.0
}

; Function Attrs: noinline nounwind uwtable
define dso_local double @dsumsquare(double* %x, double* %xp, i64 %n) local_unnamed_addr #1 {
entry:
%call = tail call fast double @__enzyme_autodiff(i8* bitcast (double (double*, i64)* @f to i8*), double* %x, double* %xp, i64 %n)
ret double %call
}

declare dso_local double @__enzyme_autodiff(i8*, double*, double*, i64) local_unnamed_addr

attributes #0 = { noinline norecurse nounwind uwtable }
attributes #1 = { noinline nounwind uwtable }

; CHECK: define internal {} @diffef(double* nocapture %x, double* %"x'", i64 %n, double %differeturn) #0 {
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = add nuw i64 %n, 1
; CHECK-NEXT: %mallocsize = mul i64 %0, 8
; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 %mallocsize)
; CHECK-NEXT: %call_malloccache = bitcast i8* %malloccall to double**
; CHECK-NEXT: br label %for.cond3.preheader

; CHECK: for.cond3.preheader: ; preds = %for.cond.cleanup6, %entry
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.cond.cleanup6 ], [ 0, %entry ]
; CHECK-NEXT: %iv.next = add nuw i64 %iv, 1
; CHECK-NEXT: %[[mallocgep1:.+]] = getelementptr double*, double** %call_malloccache, i64 %iv
; CHECK-NEXT: %mallocsize3 = mul i64 %iv.next, 8
; CHECK-NEXT: %malloccall4 = tail call noalias nonnull i8* @malloc(i64 %mallocsize3)
; CHECK-NEXT: %call_malloccache5 = bitcast i8* %malloccall4 to double*
; CHECK-NEXT: store double* %call_malloccache5, double** %[[mallocgep1]]
; CHECK-NEXT: br label %for.body7

; CHECK: for.body7: ; preds = %for.body7, %for.cond3.preheader
; CHECK-NEXT: %iv1 = phi i64 [ %iv.next2, %for.body7 ], [ 0, %for.cond3.preheader ]
; CHECK-NEXT: %iv.next2 = add nuw i64 %iv1, 1
; CHECK-NEXT: %[[augmented:.+]] = call { {}, double } @augmented_get(double* %x, double* %"x'", i64 undef, i64 %iv1)
; CHECK-NEXT: %[[retval:.+]] = extractvalue { {}, double } %[[augmented]], 1
; CHECK-NEXT: %[[mallocgep2:.+]] = getelementptr double, double* %call_malloccache5, i64 %iv1
; CHECK-NEXT: store double %[[retval]], double* %[[mallocgep2]]
; CHECK-NEXT: %exitcond = icmp eq i64 %iv1, %iv
; CHECK-NEXT: br i1 %exitcond, label %for.cond.cleanup6, label %for.body7

; CHECK: for.cond.cleanup6: ; preds = %for.body7
; CHECK-NEXT: %cmp1 = icmp eq i64 %iv, %n
; CHECK-NEXT: br i1 %cmp1, label %invertfor.cond.cleanup6, label %for.cond3.preheader

; CHECK: invertentry: ; preds = %invertfor.cond3.preheader
; CHECK-NEXT: tail call void @free(i8* nonnull %malloccall)
; CHECK-NEXT: ret {} undef

; CHECK: invertfor.cond3.preheader: ; preds = %invertfor.body7
; CHECK-NEXT: %[[innerdatai8:.+]] = bitcast double* %[[innerdata:.+]] to i8*
; CHECK-NEXT: tail call void @free(i8* nonnull %[[innerdatai8]])
; CHECK-NEXT: %[[done1:.+]] = icmp eq i64 %"iv'phi", 0
; CHECK-NEXT: br i1 %[[done1]], label %invertentry, label %invertfor.cond.cleanup6

; CHECK: invertfor.body7: ; preds = %invertfor.cond.cleanup6, %invertfor.body7
; CHECK-NEXT: %"iv1'phi" = phi i64 [ %"iv'phi", %invertfor.cond.cleanup6 ], [ %[[subinner:.+]], %invertfor.body7 ]
; CHECK-NEXT: %[[subinner]] = sub i64 %"iv1'phi", 1
; CHECK-NEXT: %[[invertedgep2:.+]] = getelementptr double, double* %[[innerdata]], i64 %"iv1'phi"
; CHECK-NEXT: %[[cached:.+]] = load double, double* %[[invertedgep2]], !invariant.load !0
; CHECK-NEXT: %m0diffecall = fmul fast double %differeturn, %[[cached]]
; CHECK-NEXT: %[[innerdiffe:.+]] = fadd fast double %m0diffecall, %m0diffecall
; CHECK-NEXT: %[[dcall:.+]] = call {} @diffeget(double* %x, double* %"x'", i64 undef, i64 %"iv1'phi", double %[[innerdiffe]], {} undef)
; CHECK-NEXT: %[[done2:.+]] = icmp eq i64 %"iv1'phi", 0
; CHECK-NEXT: br i1 %[[done2]], label %invertfor.cond3.preheader, label %invertfor.body7

; CHECK: invertfor.cond.cleanup6: ; preds = %for.cond.cleanup6, %invertfor.cond3.preheader
; CHECK-NEXT: %"iv'phi" = phi i64 [ %[[subouter:.+]], %invertfor.cond3.preheader ], [ %n, %for.cond.cleanup6 ]
; CHECK-NEXT: %[[subouter]] = sub i64 %"iv'phi", 1
; CHECK-NEXT: %[[invertedgep1:.+]] = getelementptr double*, double** %call_malloccache, i64 %"iv'phi"
; CHECK-NEXT: %[[innerdata]] = load double*, double** %[[invertedgep1]]
; CHECK-NEXT: br label %invertfor.body7
; CHECK-NEXT: }
2 changes: 1 addition & 1 deletion tests/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ eigen.ll: eigen.cpp
../build/bin/opt -indvars -instcombine -functionattrs -instnamer eigen0.ll -S -o eigen.ll

eigen2.ll: eigen.ll
../build/bin/opt -always-inline -inline -lower-autodiff eigen.ll -S -o eigen1.ll
../build/bin/opt -always-inline -inline -load=../enzyme/build/Enzyme/LLVMEnzyme-7.so -enzyme eigen.ll -S -o eigen1.ll
#../build/bin/opt -always-inline -inline -lower-autodiff eigen.ll -S -o eigen1.ll -autodiff_inline
../build/bin/opt -mem2reg eigen1.ll -S -o eigena.ll
../build/bin/opt -sroa eigena.ll -S -o eigenb.ll
Expand Down
3 changes: 3 additions & 0 deletions tests/eigen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ static void matvec(const MatrixXd& __restrict W,
//printf("foo.rows()=%ld foo.cols()=%ld\n", foo.rows(), foo.cols());
//printf("b.rows()=%ld b.cols()=%ld\n", b.rows(), b.cols());
//printf("W.rows()=%ld W.cols()=%ld\n", W.rows(), W.cols());

/*
auto wr = W.rows();
__builtin_assume(wr > 0);
auto wr8 = wr << 3;
Expand All @@ -80,6 +82,7 @@ static void matvec(const MatrixXd& __restrict W,

__builtin_assume(fr == wr);
__builtin_assume(wc == br);
*/
foo = W * b;
//printf("r foo.rows()=%ld foo.cols()=%ld\n", foo.rows(), foo.cols());
//printf("r b.rows()=%ld b.cols()=%ld\n", b.rows(), b.cols());
Expand Down