Skip to content

Commit

Permalink
Add atomic fadd for reverse mode (#849)
Browse files Browse the repository at this point in the history
* Add atomic fadd for reverse mode

* Fix lower version

* Fix for version

* add maybealign
  • Loading branch information
wsmoses committed Sep 19, 2022
1 parent 0665cc8 commit ad0af85
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 26 deletions.
4 changes: 4 additions & 0 deletions enzyme/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,14 @@ if (NOT DEFINED LLVM_EXTERNAL_LIT)
message("found llvm match ${CMAKE_MATCH_1} dir ${LLVM_DIR}")
if (EXISTS ${LLVM_DIR}/../../../bin/llvm-lit)
set(LLVM_EXTERNAL_LIT ${LLVM_DIR}/../../../bin/llvm-lit)
else()
set(LLVM_EXTERNAL_LIT lit)
endif()
else()
if (EXISTS ${LLVM_DIR}/bin/llvm-lit)
set(LLVM_EXTERNAL_LIT ${LLVM_DIR}/bin/llvm-lit)
else()
set(LLVM_EXTERNAL_LIT lit)
endif()
endif()
endif()
Expand Down
117 changes: 91 additions & 26 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -935,12 +935,25 @@ class AdjointGenerator
}

void visitAtomicRMWInst(llvm::AtomicRMWInst &I) {
if (Mode == DerivativeMode::ForwardMode) {
IRBuilder<> BuilderZ(&I);
getForwardBuilder(BuilderZ);
switch (I.getOperation()) {
case AtomicRMWInst::FAdd:
case AtomicRMWInst::FSub: {

if (gutils->isConstantInstruction(&I) && gutils->isConstantValue(&I)) {
if (Mode == DerivativeMode::ReverseModeGradient ||
Mode == DerivativeMode::ForwardModeSplit) {
eraseIfUnused(I, /*erase*/ true, /*check*/ false);
} else {
eraseIfUnused(I);
}
return;
}

switch (I.getOperation()) {
case AtomicRMWInst::FAdd:
case AtomicRMWInst::FSub: {

if (Mode == DerivativeMode::ForwardMode ||
Mode == DerivativeMode::ForwardModeSplit) {
IRBuilder<> BuilderZ(&I);
getForwardBuilder(BuilderZ);
auto rule = [&](Value *ptr, Value *dif) -> Value * {
if (!gutils->isConstantInstruction(&I)) {
assert(ptr);
Expand Down Expand Up @@ -981,32 +994,84 @@ class AdjointGenerator
setDiffe(&I, diff, BuilderZ);
return;
}
default:
break;
if (Mode == DerivativeMode::ReverseModePrimal) {
eraseIfUnused(I);
return;
}
}
if (!gutils->isConstantInstruction(&I) || !gutils->isConstantValue(&I)) {
if (looseTypeAnalysis) {
auto &DL = gutils->newFunc->getParent()->getDataLayout();
auto valType = I.getValOperand()->getType();
auto storeSize = DL.getTypeSizeInBits(valType) / 8;
auto fp = TR.firstPointer(storeSize, I.getPointerOperand(),
/*errifnotfound*/ false,
/*pointerIntSame*/ true);
if (!fp.isKnown() && valType->isIntOrIntVectorTy()) {
goto noerror;
if ((Mode == DerivativeMode::ReverseModeCombined ||
Mode == DerivativeMode::ReverseModeGradient) &&
gutils->isConstantValue(&I)) {
if (!gutils->isConstantValue(I.getValOperand())) {
assert(!gutils->isConstantValue(I.getPointerOperand()));
IRBuilder<> Builder2(&I);
getReverseBuilder(Builder2);
Value *ip = gutils->invertPointerM(I.getPointerOperand(), Builder2);
auto order = I.getOrdering();
if (order == AtomicOrdering::Release)
order = AtomicOrdering::Monotonic;
else if (order == AtomicOrdering::AcquireRelease)
order = AtomicOrdering::Acquire;

auto rule = [&](Value *ip) -> Value * {
#if LLVM_VERSION_MAJOR > 7
LoadInst *dif1 =
Builder2.CreateLoad(I.getType(), ip, I.isVolatile());
#else
LoadInst *dif1 = Builder2.CreateLoad(ip, I.isVolatile());
#endif

#if LLVM_VERSION_MAJOR >= 11
dif1->setAlignment(I.getAlign());
#else
const DataLayout &DL = I.getModule()->getDataLayout();
auto tmpAlign = DL.getTypeStoreSize(I.getValOperand()->getType());
#if LLVM_VERSION_MAJOR >= 10
dif1->setAlignment(MaybeAlign(tmpAlign.getFixedSize()));
#else
dif1->setAlignment(tmpAlign);
#endif
#endif
dif1->setOrdering(order);
dif1->setSyncScopeID(I.getSyncScopeID());
return dif1;
};
Value *diff = applyChainRule(I.getType(), Builder2, rule, ip);

addToDiffe(I.getValOperand(), diff, Builder2,
I.getValOperand()->getType()->getScalarType());
}
if (Mode == DerivativeMode::ReverseModeGradient) {
eraseIfUnused(I, /*erase*/ true, /*check*/ false);
} else
eraseIfUnused(I);
return;
}
TR.dump();
llvm::errs() << "oldFunc: " << *gutils->newFunc << "\n";
llvm::errs() << "I: " << I << "\n";
assert(0 && "Active atomic inst not handled");
break;
}
default:
break;
}
noerror:;

if (Mode == DerivativeMode::ReverseModeGradient) {
eraseIfUnused(I, /*erase*/ true, /*check*/ false);
if (looseTypeAnalysis) {
auto &DL = gutils->newFunc->getParent()->getDataLayout();
auto valType = I.getValOperand()->getType();
auto storeSize = DL.getTypeSizeInBits(valType) / 8;
auto fp = TR.firstPointer(storeSize, I.getPointerOperand(),
/*errifnotfound*/ false,
/*pointerIntSame*/ true);
if (!fp.isKnown() && valType->isIntOrIntVectorTy()) {
if (Mode == DerivativeMode::ReverseModeGradient ||
Mode == DerivativeMode::ReverseModeGradient) {
eraseIfUnused(I, /*erase*/ true, /*check*/ false);
} else
eraseIfUnused(I);
return;
}
}
TR.dump();
llvm::errs() << "oldFunc: " << *gutils->newFunc << "\n";
llvm::errs() << "I: " << I << "\n";
llvm_unreachable("Active atomic inst not yet handled");
}

void visitStoreInst(llvm::StoreInst &SI) {
Expand Down
90 changes: 90 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/atomicfadd.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
; RUN: if [ %llvmver -ge 9 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -S | FileCheck %s; fi

; ModuleID = '<source>'
source_filename = "<source>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

define dso_local void @foo1(double* %p, double %v) {
%a10 = atomicrmw volatile fadd double* %p, double %v monotonic
ret void
}
define dso_local void @foo2(double* %p, double %v) {
%a10 = atomicrmw volatile fadd double* %p, double %v acquire
ret void
}
define dso_local void @foo3(double* %p, double %v) {
%a10 = atomicrmw volatile fadd double* %p, double %v release
ret void
}
define dso_local void @foo4(double* %p, double %v) {
%a10 = atomicrmw volatile fadd double* %p, double %v acq_rel
ret void
}
define dso_local void @foo5(double* %p, double %v) {
%a10 = atomicrmw volatile fadd double* %p, double %v seq_cst
ret void
}
define dso_local void @foo6(double* %p, double %v) {
%a10 = atomicrmw volatile fadd double* %p, double 1.000000e+00 seq_cst
ret void
}

define void @caller(double* %a, double* %b, double %v) {
%r1 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo1 to i8*), double* %a, double* %b, double %v)
%r2 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo2 to i8*), double* %a, double* %b, double %v)
%r3 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo3 to i8*), double* %a, double* %b, double %v)
%r4 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo4 to i8*), double* %a, double* %b, double %v)
%r5 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo5 to i8*), double* %a, double* %b, double %v)
%r6 = call double @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo6 to i8*), double* %a, double* %b, double %v)
ret void
}

declare double @_Z17__enzyme_autodiffPviRdS0_(i8*, double*, double*, double)


; CHECK: define internal { double } @diffefoo1(double* %p, double* %"p'", double %v)
; CHECK-NEXT: invert:
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v monotonic
; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" monotonic, align 8
; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0
; CHECK-NEXT: ret { double } %1
; CHECK-NEXT: }

; CHECK: define internal { double } @diffefoo2(double* %p, double* %"p'", double %v)
; CHECK-NEXT: invert:
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v acquire
; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" acquire, align 8
; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0
; CHECK-NEXT: ret { double } %1
; CHECK-NEXT: }

; CHECK: define internal { double } @diffefoo3(double* %p, double* %"p'", double %v)
; CHECK-NEXT: invert:
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v release
; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" monotonic, align 8
; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0
; CHECK-NEXT: ret { double } %1
; CHECK-NEXT: }

; CHECK: define internal { double } @diffefoo4(double* %p, double* %"p'", double %v)
; CHECK-NEXT: invert:
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v acq_rel
; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" acquire, align 8
; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0
; CHECK-NEXT: ret { double } %1
; CHECK-NEXT: }

; CHECK: define internal { double } @diffefoo5(double* %p, double* %"p'", double %v)
; CHECK-NEXT: invert:
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v seq_cst
; CHECK-NEXT: %0 = load atomic volatile double, double* %"p'" seq_cst, align 8
; CHECK-NEXT: %1 = insertvalue { double } {{(undef|poison)}}, double %0, 0
; CHECK-NEXT: ret { double } %1
; CHECK-NEXT: }

; CHECK: define internal { double } @diffefoo6(double* %p, double* %"p'", double %v)
; CHECK-NEXT: invert:
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double 1.000000e+00 seq_cst
; CHECK-NEXT: ret { double } zeroinitializer
; CHECK-NEXT: }
42 changes: 42 additions & 0 deletions enzyme/test/Enzyme/ReverseModeVector/atomicadd.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; RUN: if [ %llvmver -ge 9 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -S | FileCheck %s; fi

; ModuleID = '<source>'
source_filename = "<source>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

define dso_local void @foo1(double* %p, double %v) {
%a10 = atomicrmw volatile fadd double* %p, double %v monotonic
ret void
}
define dso_local void @foo6(double* %p, double %v) {
%a10 = atomicrmw volatile fadd double* %p, double 1.000000e+00 seq_cst
ret void
}

define void @caller(double* %a, double* %b, double %v) {
%r1 = call [2 x double] (...) @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo1 to i8*), metadata !"enzyme_width", i64 2, double* %a, double* %b, double* %b, double %v)
%r6 = call [2 x double] (...) @_Z17__enzyme_autodiffPviRdS0_(i8* bitcast (void (double*, double)* @foo6 to i8*), metadata !"enzyme_width", i64 2, double* %a, double* %b, double* %b, double %v)
ret void
}

declare [2 x double] @_Z17__enzyme_autodiffPviRdS0_(...)

; CHECK: define internal { [2 x double] } @diffe2foo1(double* %p, [2 x double*] %"p'", double %v)
; CHECK-NEXT: invert:
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double %v monotonic
; CHECK-NEXT: %0 = extractvalue [2 x double*] %"p'", 0
; CHECK-NEXT: %1 = load atomic volatile double, double* %0 monotonic, align 8
; CHECK-NEXT: %2 = extractvalue [2 x double*] %"p'", 1
; CHECK-NEXT: %3 = load atomic volatile double, double* %2 monotonic, align 8
; CHECK-NEXT: %.fca.0.insert5 = insertvalue [2 x double] {{(undef|poison)}}, double %1, 0
; CHECK-NEXT: %.fca.1.insert8 = insertvalue [2 x double] %.fca.0.insert5, double %3, 1
; CHECK-NEXT: %4 = insertvalue { [2 x double] } undef, [2 x double] %.fca.1.insert8, 0
; CHECK-NEXT: ret { [2 x double] } %4
; CHECK-NEXT: }

; CHECK: define internal { [2 x double] } @diffe2foo6(double* %p, [2 x double*] %"p'", double %v)
; CHECK-NEXT: invert:
; CHECK-NEXT: %a10 = atomicrmw volatile fadd double* %p, double 1.000000e+00 seq_cst
; CHECK-NEXT: ret { [2 x double] } zeroinitializer
; CHECK-NEXT: }

0 comments on commit ad0af85

Please sign in to comment.