Skip to content

Commit

Permalink
Handle Enzyme tape size of zero (#368)
Browse files Browse the repository at this point in the history
  • Loading branch information
LeilaGhaffari committed Oct 28, 2021
1 parent 36f3923 commit df488bd
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
6 changes: 6 additions & 0 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,12 @@ class Enzyme : public ModulePass {
? aug->fn->getReturnType()
: cast<StructType>(aug->fn->getReturnType())
->getElementType(tapeIdx);
} else {
if (sizeOnly) {
CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0, false));
CI->eraseFromParent();
return true;
}
}
if (sizeOnly) {
auto size = DL.getTypeSizeInBits(tapeType) / 8;
Expand Down
32 changes: 32 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/splitSize5.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s

; Function Attrs: noinline nounwind readnone uwtable
define double @tester(double* %x) {
entry:
%gep = getelementptr double, double* %x, i32 1
%y = load double, double* %x
%z = load double, double* %gep
%res = fadd fast double %y, %z
ret double %res
}

define void @test_derivative(double* %x, double* %dx) {
entry:
%size = call i64 (double (double*)*, ...) @__enzyme_augmentsize(double (double*)* nonnull @tester, metadata !"enzyme_dup")
%cache = alloca i8, i64 %size, align 1
call void (double (double*)*, ...) @__enzyme_augmentfwd(double (double*)* nonnull @tester, metadata !"enzyme_allocated", i64 %size, metadata !"enzyme_tape", i8* %cache, double* %x, double* %dx)
tail call void (double (double*)*, ...) @__enzyme_reverse(double (double*)* nonnull @tester, metadata !"enzyme_allocated", i64 %size, metadata !"enzyme_tape", i8* %cache, double* %x, double* %dx)
ret void
}

; Function Attrs: nounwind
declare void @__enzyme_augmentfwd(double (double*)*, ...)
declare i64 @__enzyme_augmentsize(double (double*)*, ...)
declare void @__enzyme_reverse(double (double*)*, ...)

; CHECK: define void @test_derivative(double* %x, double* %dx)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call fast double @augmented_tester(double* %x, double* %dx)
; CHECK-NEXT: call void @diffetester(double* %x, double* %dx, double 1.000000e+00)
; CHECK-NEXT: ret void
; CHECK-NEXT:}

0 comments on commit df488bd

Please sign in to comment.