Skip to content

Commit

Permalink
Improve adjoint of swap
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed May 26, 2022
1 parent 17b7f8d commit 407999b
Show file tree
Hide file tree
Showing 13 changed files with 336 additions and 804 deletions.
424 changes: 82 additions & 342 deletions enzyme/Enzyme/AdjointGenerator.h

Large diffs are not rendered by default.

82 changes: 82 additions & 0 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,88 @@ Function *getOrInsertMemcpyStrided(Module &M, PointerType *T, Type *IT,
return F;
}

/// Get or create a module-internal helper that stores one floating-point value
/// into `num` elements of a buffer, spaced `stride` elements apart.
///
/// The helper is named
/// `__enzyme_memset_<tofltstr(elem)>_<bitwidth of IT>_align<align>stride` and
/// has the signature `void(T dst, elem val, IT num, IT stride)`, where `elem`
/// is the pointee type of \p T. If a function of that name already has a body
/// in \p M it is returned unchanged.
///
/// \param M     module in which the helper is looked up / created.
/// \param T     pointer type of the destination; its pointee must be a
///              floating-point type (asserted).
/// \param IT    integer type used for the `num` and `stride` arguments.
/// \param align alignment placed on each generated store; 0 leaves the
///              store's default alignment untouched.
/// \return the (possibly newly populated) helper function.
Function *getOrInsertMemsetStrided(Module &M, PointerType *T, Type *IT,
                                   unsigned align) {
  Type *elementType = T->getPointerElementType();
  assert(elementType->isFloatingPointTy());
  std::string name = "__enzyme_memset_" + tofltstr(elementType) + "_" +
                     std::to_string(cast<IntegerType>(IT)->getBitWidth()) +
                     "_align" + std::to_string(align) + "stride";
  FunctionType *FT = FunctionType::get(Type::getVoidTy(M.getContext()),
                                       {T, elementType, IT, IT}, false);

#if LLVM_VERSION_MAJOR >= 9
  Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());
#else
  Function *F = cast<Function>(M.getOrInsertFunction(name, FT));
#endif

  // A non-empty function means the body was already emitted by an earlier
  // call with the same type/alignment combination.
  if (!F->empty())
    return F;

  F->setLinkage(Function::LinkageTypes::InternalLinkage);
  F->addFnAttr(Attribute::ArgMemOnly);
  F->addFnAttr(Attribute::NoUnwind);
  F->addFnAttr(Attribute::AlwaysInline);
  F->addParamAttr(0, Attribute::NoCapture);
  F->addParamAttr(0, Attribute::WriteOnly);

  BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F);
  BasicBlock *body = BasicBlock::Create(M.getContext(), "for.body", F);
  BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F);

  auto dst = F->arg_begin();
  dst->setName("dst");
  auto val = dst + 1;
  val->setName("val");
  auto num = val + 1;
  num->setName("num");
  auto stride = num + 1;
  stride->setName("stride");

  // entry: skip the loop entirely when there is nothing to write.
  {
    IRBuilder<> B(entry);
    B.CreateCondBr(B.CreateICmpEQ(num, ConstantInt::get(num->getType(), 0)),
                   end, body);
  }

  // for.body: dst[idx * stride] = val for idx = 0 .. num-1.
  {
    IRBuilder<> B(body);
    B.setFastMathFlags(getFast());
    PHINode *idx = B.CreatePHI(num->getType(), 2, "idx");
    idx->addIncoming(ConstantInt::get(num->getType(), 0), entry);

    // Scale the loop counter by the stride. Previously the GEP used `idx`
    // directly, so the `stride` argument was ignored and the helper always
    // wrote `num` contiguous elements.
    Value *offset = B.CreateMul(idx, stride, "dst.off");

#if LLVM_VERSION_MAJOR > 7
    Value *dsti = B.CreateInBoundsGEP(elementType, dst, offset, "dst.i");
#else
    Value *dsti = B.CreateInBoundsGEP(dst, offset, "dst.i");
#endif

    StoreInst *dsts = B.CreateStore(val, dsti);

    // NOTE(review): `align` is applied to every strided store; this assumes
    // the caller only passes an alignment that holds for each element
    // address, not just for `dst` itself -- confirm at the call sites.
    if (align) {
#if LLVM_VERSION_MAJOR >= 10
      dsts->setAlignment(Align(align));
#else
      dsts->setAlignment(align);
#endif
    }

    Value *next =
        B.CreateNUWAdd(idx, ConstantInt::get(num->getType(), 1), "idx.next");
    idx->addIncoming(next, body);
    B.CreateCondBr(B.CreateICmpEQ(num, next), end, body);
  }

  {
    IRBuilder<> B(end);
    B.CreateRetVoid();
  }

  return F;
}

// TODO implement differential memmove
Function *getOrInsertDifferentialFloatMemmove(Module &M, Type *T,
unsigned dstalign,
Expand Down
4 changes: 4 additions & 0 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,10 @@ llvm::Function *getOrInsertMemcpyStrided(llvm::Module &M, llvm::PointerType *T,
llvm::Type *IT, unsigned dstalign,
unsigned srcalign);

/// Get or create the module-internal helper function that performs a memset
/// with a stride on floating-point memory.
///
/// \p T is the destination pointer type (its pointee must be a floating-point
/// type), \p IT is the integer type of the helper's count and stride
/// arguments, and \p align is the alignment set on the helper's store
/// (0 leaves it unset). The returned function has the signature
/// `void(T dst, <pointee of T> val, IT num, IT stride)`.
llvm::Function *getOrInsertMemsetStrided(llvm::Module &M, llvm::PointerType *T,
llvm::Type *IT, unsigned align);

/// Create function for type that performs the derivative memmove on floating
/// point memory
llvm::Function *
Expand Down
85 changes: 0 additions & 85 deletions enzyme/test/Enzyme/ReverseMode/blas/cblas_daxpy.ll

This file was deleted.

36 changes: 36 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/blas/cblas_dcopy.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s

@enzyme_const = common global i32 0, align 4

define void @wrapper(i32 %n, double* %x, i32 %incx, double* %y, i32 %incy) {
entry:
tail call void @cblas_dcopy(i32 %n, double* %x, i32 %incx, double* %y, i32 %incy)
Expand All @@ -24,6 +26,20 @@ entry:

declare void @__enzyme_autodiff(i8*, ...)

; Test driver: differentiates @wrapper via @__enzyme_autodiff, passing the
; value loaded from @enzyme_const immediately before %x and supplying no
; shadow pointer for x (%_x is unused). Presumably this marks x as inactive;
; y is still passed together with its shadow %_y.
define void @inactiveX(i32 %n, double* %x, double* nocapture readnone %_x, i32 %incx, double* %y, double* %_y, i32 %incy) {
entry:
%0 = load i32, i32* @enzyme_const, align 4
tail call void (i8*, ...) @__enzyme_autodiff(i8* bitcast (void (i32, double*, i32, double*, i32)* @wrapper to i8*), i32 %n, i32 %0, double* %x, i32 %incx, double* %y, double* %_y, i32 %incy)
ret void
}

; Test driver: differentiates @wrapper via @__enzyme_autodiff, passing the
; value loaded from @enzyme_const immediately before %y and supplying no
; shadow pointer for y (%_y is unused). Presumably this marks y as inactive;
; x is still passed together with its shadow %_x.
define void @inactiveY(i32 %n, double* %x, double* %_x, i32 %incx, double* %y, double* nocapture readnone %_y, i32 %incy) {
entry:
%0 = load i32, i32* @enzyme_const, align 4
tail call void (i8*, ...) @__enzyme_autodiff(i8* bitcast (void (i32, double*, i32, double*, i32)* @wrapper to i8*), i32 %n, double* %x, double* %_x, i32 %incx, i32 %0, double* %y, i32 %incy)
ret void
}

define void @activeMod(i32 %n, double* %x, double* %_x, i32 %incx, double* %y, double* %_y, i32 %incy) {
entry:
tail call void (i8*, ...) @__enzyme_autodiff(i8* bitcast (void (i32, double*, i32, double*, i32)* @wrapperMod to i8*), i32 %n, double* %x, double* %_x, i32 %incx, double* %y, double* %_y, i32 %incy)
Expand All @@ -34,6 +50,14 @@ entry:
;CHECK-NEXT: entry
;CHECK-NEXT: call void @[[active:.+]](

;CHECK: define void @inactiveX
;CHECK-NEXT: entry
;CHECK-NEXT: call void @[[inactiveX:.+]](

;CHECK: define void @inactiveY
;CHECK-NEXT: entry
;CHECK-NEXT: call void @[[inactiveY:.+]](

;CHECK: define void @activeMod
;CHECK-NEXT: entry
;CHECK-NEXT: call void @[[activeMod:.+]](
Expand All @@ -45,6 +69,18 @@ entry:
;CHECK-NEXT: ret void
;CHECK-NEXT:}

;CHECK:define internal void @[[inactiveX]](i32 %n, double* %x, i32 %incx, double* %y, double* %"y'", i32 %incy)
;CHECK-NEXT:entry:
;CHECK-NEXT: tail call void @cblas_dcopy(i32 %n, double* %x, i32 %incx, double* %y, i32 %incy)
;CHECK-NEXT: ret void
;CHECK-NEXT:}

;CHECK:define internal void @[[inactiveY]](i32 %n, double* %x, double* %"x'", i32 %incx, double* %y, i32 %incy)
;CHECK-NEXT:entry:
;CHECK-NEXT: tail call void @cblas_dcopy(i32 %n, double* %x, i32 %incx, double* %y, i32 %incy)
;CHECK-NEXT: ret void
;CHECK-NEXT:}

;CHECK:define internal void @[[activeMod]](i32 %n, double* %x, double* %"x'", i32 %incx, double* %y, double* %"y'", i32 %incy)
;CHECK-NEXT:entry:
;CHECK-NEXT: tail call void @cblas_dcopy(i32 %n, double* %x, i32 %incx, double* %y, i32 %incy)
Expand Down
61 changes: 0 additions & 61 deletions enzyme/test/Enzyme/ReverseMode/blas/cblas_dnrm2.ll

This file was deleted.

85 changes: 0 additions & 85 deletions enzyme/test/Enzyme/ReverseMode/blas/cblas_dscal.ll

This file was deleted.

Loading

0 comments on commit 407999b

Please sign in to comment.