Skip to content

Commit

Permalink
Fixed shadow return mixed (#1936)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jun 23, 2024
1 parent be1d460 commit 59e4408
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 51 deletions.
20 changes: 11 additions & 9 deletions enzyme/Enzyme/DiffeGradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ DiffeGradientUtils::DiffeGradientUtils(
ValueToValueMapTy &invertedPointers_,
const SmallPtrSetImpl<Value *> &constantvalues_,
const SmallPtrSetImpl<Value *> &returnvals_, DIFFE_TYPE ActiveReturn,
ArrayRef<DIFFE_TYPE> constant_values,
bool shadowReturnUsed, ArrayRef<DIFFE_TYPE> constant_values,
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> &origToNew_,
DerivativeMode mode, unsigned width, bool omp)
: GradientUtils(Logic, newFunc_, oldFunc_, TLI, TA, TR, invertedPointers_,
constantvalues_, returnvals_, ActiveReturn, constant_values,
origToNew_, mode, width, omp) {
constantvalues_, returnvals_, ActiveReturn,
shadowReturnUsed, constant_values, origToNew_, mode, width,
omp) {
if (oldFunc_->empty())
return;
assert(reverseBlocks.size() == 0);
Expand All @@ -85,8 +86,9 @@ DiffeGradientUtils::DiffeGradientUtils(
DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
EnzymeLogic &Logic, DerivativeMode mode, unsigned width, Function *todiff,
TargetLibraryInfo &TLI, TypeAnalysis &TA, FnTypeInfo &oldTypeInfo,
DIFFE_TYPE retType, bool diffeReturnArg, ArrayRef<DIFFE_TYPE> constant_args,
ReturnType returnValue, Type *additionalArg, bool omp) {
DIFFE_TYPE retType, bool shadowReturn, bool diffeReturnArg,
ArrayRef<DIFFE_TYPE> constant_args, ReturnType returnValue,
Type *additionalArg, bool omp) {
Function *oldFunc = todiff;
assert(mode == DerivativeMode::ReverseModeGradient ||
mode == DerivativeMode::ReverseModeCombined ||
Expand Down Expand Up @@ -157,10 +159,10 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
if (!oldFunc->empty())
assert(TR.getFunction() == oldFunc);

auto res = new DiffeGradientUtils(Logic, newFunc, oldFunc, TLI, TA, TR,
invertedPointers, constant_values,
nonconstant_values, retType, constant_args,
originalToNew, mode, width, omp);
auto res = new DiffeGradientUtils(
Logic, newFunc, oldFunc, TLI, TA, TR, invertedPointers, constant_values,
nonconstant_values, retType, shadowReturn, constant_args, originalToNew,
mode, width, omp);

return res;
}
Expand Down
6 changes: 4 additions & 2 deletions enzyme/Enzyme/DiffeGradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ class DiffeGradientUtils final : public GradientUtils {
llvm::ValueToValueMapTy &invertedPointers_,
const llvm::SmallPtrSetImpl<llvm::Value *> &constantvalues_,
const llvm::SmallPtrSetImpl<llvm::Value *> &returnvals_,
DIFFE_TYPE ActiveReturn, llvm::ArrayRef<DIFFE_TYPE> constant_values,
DIFFE_TYPE ActiveReturn, bool shadowReturnUsed,
llvm::ArrayRef<DIFFE_TYPE> constant_values,
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> &origToNew_,
DerivativeMode mode, unsigned width, bool omp);

Expand All @@ -79,7 +80,8 @@ class DiffeGradientUtils final : public GradientUtils {
CreateFromClone(EnzymeLogic &Logic, DerivativeMode mode, unsigned width,
llvm::Function *todiff, llvm::TargetLibraryInfo &TLI,
TypeAnalysis &TA, FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType,
bool diffeReturnArg, llvm::ArrayRef<DIFFE_TYPE> constant_args,
bool shadowReturnArg, bool diffeReturnArg,
llvm::ArrayRef<DIFFE_TYPE> constant_args,
ReturnType returnValue, llvm::Type *additionalArg, bool omp);

llvm::AllocaInst *getDifferential(llvm::Value *val);
Expand Down
4 changes: 2 additions & 2 deletions enzyme/Enzyme/DifferentialUseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -675,8 +675,8 @@ bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse(

if (shadow) {
if (isa<ReturnInst>(user)) {
if (gutils->ATA->ActiveReturns == DIFFE_TYPE::DUP_ARG ||
gutils->ATA->ActiveReturns == DIFFE_TYPE::DUP_NONEED) {
bool notrev = mode != DerivativeMode::ReverseModeGradient;
if (gutils->shadowReturnUsed && notrev) {

bool inst_cv = gutils->isConstantValue(const_cast<Value *>(val));

Expand Down
8 changes: 5 additions & 3 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4135,7 +4135,8 @@ Function *EnzymeLogic::CreatePrimalAndGradient(

DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone(
*this, key.mode, key.width, key.todiff, TLI, TA, oldTypeInfo, key.retType,
diffeReturnArg, key.constant_args, retVal, key.additionalType, omp);
key.shadowReturnUsed, diffeReturnArg, key.constant_args, retVal,
key.additionalType, omp);

gutils->AtomicAdd = key.AtomicAdd;
gutils->FreeMemory = key.freeMemory;
Expand Down Expand Up @@ -4787,8 +4788,9 @@ Function *EnzymeLogic::CreateForwardDiff(
bool diffeReturnArg = false;

DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone(
*this, mode, width, todiff, TLI, TA, oldTypeInfo, retType, diffeReturnArg,
constant_args, retVal, additionalArg, omp);
*this, mode, width, todiff, TLI, TA, oldTypeInfo, retType,
/*shadowReturn*/ retActive, diffeReturnArg, constant_args, retVal,
additionalArg, omp);

insert_or_assign2<ForwardCacheKey, Function *>(ForwardCachedFunctions, tup,
gutils->newFunc);
Expand Down
9 changes: 5 additions & 4 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ GradientUtils::GradientUtils(
ValueToValueMapTy &invertedPointers_,
const SmallPtrSetImpl<Value *> &constantvalues_,
const SmallPtrSetImpl<Value *> &activevals_, DIFFE_TYPE ReturnActivity,
ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
bool shadowReturnUsed_, ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> &originalToNewFn_,
DerivativeMode mode, unsigned width, bool omp)
: CacheUtility(TLI_, newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_),
Expand Down Expand Up @@ -194,7 +194,8 @@ GradientUtils::GradientUtils(
tid(nullptr), numThreads(nullptr),
OrigAA(oldFunc_->empty() ? ((AAResults *)nullptr)
: &Logic.PPC.getAAResultsFromFunction(oldFunc_)),
TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_) {
TA(TA_), TR(TR_), omp(omp), width(width),
shadowReturnUsed(shadowReturnUsed_), ArgDiffeTypes(ArgDiffeTypes_) {
if (oldFunc_->empty())
return;
if (oldFunc_->getSubprogram()) {
Expand Down Expand Up @@ -4342,8 +4343,8 @@ GradientUtils *GradientUtils::CreateFromClone(

auto res = new GradientUtils(
Logic, newFunc, oldFunc, TLI, TA, TR, invertedPointers, constant_values,
nonconstant_values, retType, constant_args, originalToNew,
DerivativeMode::ReverseModePrimal, width, omp);
nonconstant_values, retType, shadowReturnUsed, constant_args,
originalToNew, DerivativeMode::ReverseModePrimal, width, omp);
return res;
}

Expand Down
4 changes: 3 additions & 1 deletion enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,8 @@ class GradientUtils : public CacheUtility {
public:
unsigned getWidth() { return width; }

bool shadowReturnUsed;

llvm::ArrayRef<DIFFE_TYPE> ArgDiffeTypes;

public:
Expand All @@ -379,7 +381,7 @@ class GradientUtils : public CacheUtility {
llvm::ValueToValueMapTy &invertedPointers_,
const llvm::SmallPtrSetImpl<llvm::Value *> &constantvalues_,
const llvm::SmallPtrSetImpl<llvm::Value *> &activevals_,
DIFFE_TYPE ReturnActivity,
DIFFE_TYPE ReturnActivity, bool shadowReturnUsed,
llvm::ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH>
&originalToNewFn_,
Expand Down
24 changes: 8 additions & 16 deletions enzyme/test/Enzyme/ReverseMode/duplicatemallocptrloop3.ll
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ attributes #9 = { nounwind }
; CHECK-NEXT: %r_malloccache = bitcast i8* %malloccall to double*
; CHECK-NEXT: %[[malloccall4:.+]] = tail call noalias nonnull dereferenceable(80) dereferenceable_or_null(80) i8* bitcast (i8* (i32)* @malloc to i8* (i64)*)(i64 80)
; CHECK-NEXT: %"a4'ip_phi_malloccache" = bitcast i8* %[[malloccall4]] to double**
; CHECK-NEXT: %[[malloccall8:.+]] = tail call noalias nonnull dereferenceable(80) dereferenceable_or_null(80) i8* bitcast (i8* (i32)* @malloc to i8* (i64)*)(i64 80)
; CHECK-NEXT: %subcache_malloccache = bitcast i8* %[[malloccall8]] to double**
; CHECK-NEXT: br label %loop

; CHECK: loop: ; preds = %loop, %entry
Expand All @@ -78,20 +76,17 @@ attributes #9 = { nounwind }
; CHECK-NEXT: %a10 = getelementptr inbounds double, double* %a0, i32 %0
; CHECK-NEXT: store double* %"a10'ipg", double** %"p3'ipc", align 8
; CHECK-NEXT: store double* %a10, double** %p3, align 8, !alias.scope !{{[0-9]+}}, !noalias !{{[0-9]+}}
; CHECK-NEXT: %a4_augmented = call { double*, double*, double* } @augmented_f(double** %p3, double** %"p3'ipc")
; CHECK-NEXT: %subcache = extractvalue { double*, double*, double* } %a4_augmented, 0
; CHECK-NEXT: %a4 = extractvalue { double*, double*, double* } %a4_augmented, 1
; CHECK-NEXT: %"a4'ac" = extractvalue { double*, double*, double* } %a4_augmented, 2
; CHECK-NEXT: %a4_augmented = call { double*, double* } @augmented_f(double** %p3, double** %"p3'ipc")
; CHECK-NEXT: %a4 = extractvalue { double*, double* } %a4_augmented, 0
; CHECK-NEXT: %"a4'ac" = extractvalue { double*, double* } %a4_augmented, 1
; CHECK-NEXT: %r = load double, double* %a4
; CHECK-NEXT: %m2 = fmul double %r, %r
; CHECK-NEXT: %a13 = getelementptr inbounds double, double* %out, i32 %0
; CHECK-NEXT: store double %m2, double* %a13, align 8
; CHECK-NEXT: %1 = getelementptr inbounds double*, double** %subcache_malloccache, i64 %iv
; CHECK-NEXT: store double* %subcache, double** %1, align 8
; CHECK-NEXT: %2 = getelementptr inbounds double*, double** %"a4'ip_phi_malloccache", i64 %iv
; CHECK-NEXT: store double* %"a4'ac", double** %2, align 8
; CHECK-NEXT: %3 = getelementptr inbounds double, double* %r_malloccache, i64 %iv
; CHECK-NEXT: store double %r, double* %3, align 8
; CHECK-NEXT: %[[i2:.+]] = getelementptr inbounds double*, double** %"a4'ip_phi_malloccache", i64 %iv
; CHECK-NEXT: store double* %"a4'ac", double** %[[i2]], align 8
; CHECK-NEXT: %[[i3:.+]] = getelementptr inbounds double, double* %r_malloccache, i64 %iv
; CHECK-NEXT: store double %r, double* %[[i3]], align 8
; CHECK-NEXT: %a14 = add nuw nsw i32 %0, 1
; CHECK-NEXT: %a15 = icmp eq i32 %a14, 10
; CHECK-NEXT: br i1 %a15, label %invertloop, label %loop
Expand All @@ -101,7 +96,6 @@ attributes #9 = { nounwind }
; CHECK-NEXT: call void @free(i8* nonnull %p2)
; CHECK-NEXT: call void @free(i8* nonnull %malloccall)
; CHECK-NEXT: call void @free(i8* nonnull %[[malloccall4]])
; CHECK-NEXT: call void @free(i8* nonnull %[[malloccall8]])
; CHECK-NEXT: ret void

; CHECK: invertloop: ; preds = %loop, %incinvertloop
Expand All @@ -122,9 +116,7 @@ attributes #9 = { nounwind }
; CHECK-NEXT: store double %[[i11]], double* %[[i9]]
; CHECK-NEXT: %p3_unwrap = bitcast i8* %p2 to double**
; CHECK-NEXT: %"p3'ipc_unwrap" = bitcast i8* %"p2'mi" to double**
; CHECK-NEXT: %[[i12:.+]] = getelementptr inbounds double*, double** %subcache_malloccache, i64 %"iv'ac.0"
; CHECK-NEXT: %[[i13:.+]] = load double*, double** %[[i12]], align 8
; CHECK-NEXT: call void @diffef(double** %p3_unwrap, double** %"p3'ipc_unwrap", double* %[[i13]])
; CHECK-NEXT: call void @diffef(double** %p3_unwrap, double** %"p3'ipc_unwrap")
; CHECK-NEXT: %[[i14:.+]] = icmp eq i64 %"iv'ac.0", 0
; CHECK-NEXT: br i1 %[[i14]], label %invertentry, label %incinvertloop

Expand Down
26 changes: 12 additions & 14 deletions enzyme/test/Enzyme/ReverseMode/globalptr.ll
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ attributes #2 = { nounwind }

; CHECK: define internal { double } @diffemulglobal(double %x, double %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %call_augmented = call { double*, double*, double* } @augmented_myglobal()
; CHECK: %call = extractvalue { double*, double*, double* } %call_augmented, 1
; CHECK: %"call'ac" = extractvalue { double*, double*, double* } %call_augmented, 2
; CHECK-NEXT: %call_augmented = call { double*, double* } @augmented_myglobal()
; CHECK: %call = extractvalue { double*, double* } %call_augmented, 0
; CHECK: %"call'ac" = extractvalue { double*, double* } %call_augmented, 1
; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"call'ac", i64 2
; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %call, i64 2
; CHECK-NEXT: %0 = load double, double* %arrayidx, align 8
Expand All @@ -53,22 +53,20 @@ attributes #2 = { nounwind }
; CHECK-NEXT: ret { double } %[[i3]]
; CHECK-NEXT: }

; CHECK: define internal { double*, double*, double* } @augmented_myglobal()
; CHECK: define internal { double*, double* } @augmented_myglobal()
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = alloca { double*, double*, double* }
; CHECK-NEXT: %1 = getelementptr inbounds { double*, double*, double* }, { double*, double*, double* }* %0, i32 0, i32 0
; CHECK-NEXT: %0 = alloca { double*, double* }
; CHECK-NEXT: %"ptr'ipl" = load double*, double** @dglobal, align 8
; CHECK-NEXT: store double* %"ptr'ipl", double** %1
; CHECK-NEXT: %ptr = load double*, double** @global, align 8
; CHECK-NEXT: %2 = getelementptr inbounds { double*, double*, double* }, { double*, double*, double* }* %0, i32 0, i32 1
; CHECK-NEXT: store double* %ptr, double** %2
; CHECK-NEXT: %3 = getelementptr inbounds { double*, double*, double* }, { double*, double*, double* }* %0, i32 0, i32 2
; CHECK-NEXT: store double* %"ptr'ipl", double** %3
; CHECK-NEXT: %4 = load { double*, double*, double* }, { double*, double*, double* }* %0
; CHECK-NEXT: ret { double*, double*, double* } %4
; CHECK-NEXT: %[[a2:.+]] = getelementptr inbounds { double*, double* }, { double*, double* }* %0, i32 0, i32 0
; CHECK-NEXT: store double* %ptr, double** %[[a2]]
; CHECK-NEXT: %[[a3:.+]] = getelementptr inbounds { double*, double* }, { double*, double* }* %0, i32 0, i32 1
; CHECK-NEXT: store double* %"ptr'ipl", double** %[[a3]]
; CHECK-NEXT: %[[a4:.+]] = load { double*, double* }, { double*, double* }* %0
; CHECK-NEXT: ret { double*, double* } %[[a4]]
; CHECK-NEXT: }

; CHECK: define internal void @diffemyglobal(double* %"ptr'il_phi")
; CHECK: define internal void @diffemyglobal()
; CHECK-NEXT: entry:
; CHECK-NEXT: ret void
; CHECK-NEXT: }

0 comments on commit 59e4408

Please sign in to comment.