Skip to content

Commit

Permalink
Fix forward fmuladd (rust-lang#668)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 28, 2022
1 parent e290aab commit ca37b19
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 3 deletions.
6 changes: 3 additions & 3 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -3792,8 +3792,8 @@ class AdjointGenerator
if (gutils->isConstantInstruction(&I))
return;

Value *op0 = gutils->getNewFromOriginal(orig_ops[0]);
Value *op1 = gutils->getNewFromOriginal(orig_ops[1]);
Value *op2 = gutils->getNewFromOriginal(orig_ops[2]);

Type *opType0 = gutils->getShadowType(orig_ops[0]->getType());
Type *opType1 = gutils->getShadowType(orig_ops[1]->getType());
Expand All @@ -3810,8 +3810,8 @@ class AdjointGenerator
: diffe(orig_ops[2], Builder2);

auto rule = [&](Value *dif0, Value *dif1, Value *dif2) {
Value *dif = Builder2.CreateFAdd(Builder2.CreateFMul(op1, dif2),
Builder2.CreateFMul(dif1, op2));
Value *dif = Builder2.CreateFAdd(Builder2.CreateFMul(op0, dif1),
Builder2.CreateFMul(op1, dif0));
return Builder2.CreateFAdd(dif, dif0);
};

Expand Down
28 changes: 28 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/fmuladd.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -gvn -simplifycfg -instcombine -S | FileCheck %s

; Function Attrs: nounwind readnone uwtable
define double @tester(double %x, double %y, double %z) {
entry:
%0 = tail call fast double @llvm.fmuladd.f64(double %x, double %y, double %z)
ret double %0
}

define double @test_derivative(double %x, double %y, double %z) {
entry:
%0 = tail call double (double (double, double, double)*, ...) @__enzyme_fwddiff(double (double, double, double)* nonnull @tester, double %x, double %x, double %y, double %y, double %z, double %z)
ret double %0
}

declare double @llvm.fmuladd.f64(double %a, double %b, double %c)

; Function Attrs: nounwind
declare double @__enzyme_fwddiff(double (double, double, double)*, ...)

; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'", double %z, double %"z'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = fmul fast double %y, %"x'"
; CHECK-NEXT: %1 = fmul fast double %x, %"y'"
; CHECK-NEXT: %2 = fadd fast double %1, %0
; CHECK-NEXT: %3 = fadd fast double %2, %"x'"
; CHECK-NEXT: ret double %3
; CHECK-NEXT: }
29 changes: 29 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/fmuladd.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -gvn -simplifycfg -instcombine -S | FileCheck %s

; Function Attrs: nounwind readnone uwtable
define double @tester(double %x, double %y, double %z) {
entry:
%0 = tail call fast double @llvm.fmuladd.f64(double %x, double %y, double %z)
ret double %0
}

define double @test_derivative(double %x, double %y, double %z) {
entry:
%0 = tail call double (double (double, double, double)*, ...) @__enzyme_autodiff(double (double, double, double)* nonnull @tester, double %x, double %y, double %z)
ret double %0
}

declare double @llvm.fmuladd.f64(double %a, double %b, double %c)

; Function Attrs: nounwind
declare double @__enzyme_autodiff(double (double, double, double)*, ...)

; CHECK: define internal { double, double, double } @diffetester(double %x, double %y, double %z, double %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = fmul fast double %differeturn, %y
; CHECK-NEXT: %1 = fmul fast double %differeturn, %x
; CHECK-NEXT: %2 = insertvalue { double, double, double } undef, double %0, 0
; CHECK-NEXT: %3 = insertvalue { double, double, double } %2, double %1, 1
; CHECK-NEXT: %4 = insertvalue { double, double, double } %3, double %differeturn, 2
; CHECK-NEXT: ret { double, double, double } %4
; CHECK-NEXT: }

0 comments on commit ca37b19

Please sign in to comment.