Skip to content

Commit

Permalink
Fix shadow return usage (#1939)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 23, 2024
1 parent 2a5a879 commit 7fee772
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 15 deletions.
27 changes: 14 additions & 13 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2017,7 +2017,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
return insert_or_assign<AugmentedCacheKey, AugmentedReturn>(
AugmentedCachedFunctions, tup,
AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {},
constant_args))
constant_args, shadowReturnUsed))
->second;
}
if (context.req) {
Expand All @@ -2028,7 +2028,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
return insert_or_assign<AugmentedCacheKey, AugmentedReturn>(
AugmentedCachedFunctions, tup,
AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {},
constant_args))
constant_args, shadowReturnUsed))
->second;
}
llvm::errs() << "mod: " << *todiff->getParent() << "\n";
Expand Down Expand Up @@ -2149,7 +2149,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
AugmentedCachedFunctions, tup,
AugmentedReturn(NewF, aug.tapeType, aug.tapeIndices,
aug.returns, aug.overwritten_args_map,
aug.can_modref_map, next_constant_args))
aug.can_modref_map, next_constant_args,
shadowReturnUsed))
->second;
}

Expand Down Expand Up @@ -2213,7 +2214,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
return insert_or_assign<AugmentedCacheKey, AugmentedReturn>(
AugmentedCachedFunctions, tup,
AugmentedReturn(foundcalled, nullptr, {}, returnMapping, {},
{}, constant_args))
{}, constant_args, shadowReturnUsed))
->second;
}

Expand Down Expand Up @@ -2271,7 +2272,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
return insert_or_assign<AugmentedCacheKey, AugmentedReturn>(
AugmentedCachedFunctions, tup,
AugmentedReturn(foundcalled, nullptr, {}, returnMapping, {},
{}, constant_args))
{}, constant_args, shadowReturnUsed))
->second;
}
if (ST->getNumElements() == 2 &&
Expand All @@ -2282,7 +2283,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
return insert_or_assign<AugmentedCacheKey, AugmentedReturn>(
AugmentedCachedFunctions, tup,
AugmentedReturn(foundcalled, nullptr, {}, returnMapping, {},
{}, constant_args))
{}, constant_args, shadowReturnUsed))
->second;
}
if (ST->getNumElements() == 2) {
Expand Down Expand Up @@ -2336,7 +2337,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
return insert_or_assign<AugmentedCacheKey, AugmentedReturn>(
AugmentedCachedFunctions, tup,
AugmentedReturn(foundcalled, nullptr, {}, returnMapping, {},
{}, constant_args))
{}, constant_args, shadowReturnUsed))
->second;
}
}
Expand All @@ -2348,7 +2349,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
return insert_or_assign<AugmentedCacheKey, AugmentedReturn>(
AugmentedCachedFunctions, tup,
AugmentedReturn(foundcalled, nullptr, {}, returnMapping, {}, {},
constant_args))
constant_args, shadowReturnUsed))
->second; // dyn_cast<StructType>(st->getElementType(0)));
}

Expand Down Expand Up @@ -2393,7 +2394,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
return insert_or_assign<AugmentedCacheKey, AugmentedReturn>(
AugmentedCachedFunctions, tup,
AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {},
constant_args))
constant_args, shadowReturnUsed))
->second;
}
if (context.req) {
Expand All @@ -2404,7 +2405,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
return insert_or_assign<AugmentedCacheKey, AugmentedReturn>(
AugmentedCachedFunctions, tup,
AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {},
constant_args))
constant_args, shadowReturnUsed))
->second;
}
llvm::errs() << "mod: " << *todiff->getParent() << "\n";
Expand Down Expand Up @@ -2455,7 +2456,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
insert_or_assign(AugmentedCachedFunctions, tup,
AugmentedReturn(gutils->newFunc, nullptr, {}, returnMapping,
overwritten_args_map, can_modref_map,
constant_args));
constant_args, shadowReturnUsed));

auto getIndex = [&](Instruction *I, CacheType u, IRBuilder<> &B) -> unsigned {
return gutils->getIndex(
Expand Down Expand Up @@ -4135,8 +4136,8 @@ Function *EnzymeLogic::CreatePrimalAndGradient(

DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone(
*this, key.mode, key.width, key.todiff, TLI, TA, oldTypeInfo, key.retType,
key.shadowReturnUsed, diffeReturnArg, key.constant_args, retVal,
key.additionalType, omp);
augmenteddata ? augmenteddata->shadowReturnUsed : key.shadowReturnUsed,
diffeReturnArg, key.constant_args, retVal, key.additionalType, omp);

gutils->AtomicAdd = key.AtomicAdd;
gutils->FreeMemory = key.freeMemory;
Expand Down
6 changes: 4 additions & 2 deletions enzyme/Enzyme/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ class AugmentedReturn {

const std::vector<DIFFE_TYPE> constant_args;

bool shadowReturnUsed;

bool isComplete;

AugmentedReturn(
Expand All @@ -129,11 +131,11 @@ class AugmentedReturn {
std::map<AugmentedStruct, int> returns,
std::map<llvm::CallInst *, const std::vector<bool>> overwritten_args_map,
std::map<llvm::Instruction *, bool> can_modref_map,
const std::vector<DIFFE_TYPE> &constant_args)
const std::vector<DIFFE_TYPE> &constant_args, bool shadowReturnUsed)
: fn(fn), tapeType(tapeType), tapeIndices(tapeIndices), returns(returns),
overwritten_args_map(overwritten_args_map),
can_modref_map(can_modref_map), constant_args(constant_args),
isComplete(false) {}
shadowReturnUsed(shadowReturnUsed), isComplete(false) {}
};

/// \p todiff is the function to differentiate
Expand Down
80 changes: 80 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/shadowret.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -early-cse -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,early-cse,instsimplify,%simplifycfg)" -S | FileCheck %s

; Innermost callee of the test. Loads the double stored at %a12, obtains a
; 24-byte buffer from @malloc, writes the loaded value into the buffer inside
; loop L8, and returns the buffer pointer.
; The loop index %a21 starts at 0 and the loop exits as soon as the incremented
; index %a24 equals 1, so the store runs exactly once (element 0).
define internal fastcc nonnull double* @julia_mygetindex_2341(double* %a12) {
top:
%a13 = load double, double* %a12, align 8
%a22 = call noalias nonnull double* @malloc(i64 24)
br label %L8

L8: ; preds = %top, %L8
%a21 = phi i64 [ 0, %top ], [ %a24, %L8 ]
%a23 = getelementptr inbounds double, double* %a22, i64 %a21
store double %a13, double* %a23, align 8
%a24 = add nuw nsw i64 %a21, 1
%.not5 = icmp ne i64 1, %a24
br i1 %.not5, label %L8, label %L28

L28: ; preds = %L8
ret double* %a22
}

; Pass-through wrapper: forwards %a12 to @julia_mygetindex_2341 and returns the
; resulting pointer unchanged.
; NOTE(review): presumably exists so the returned pointer (and its shadow)
; crosses an extra call boundary, per the commit title "Fix shadow return
; usage" -- confirm against the Enzyme change.
define internal fastcc nonnull double* @mydiag(double* %a12) {
top:
%a4 = call fastcc nonnull double* @julia_mygetindex_2341(double* %a12)
ret double* %a4
}

declare void @__enzyme_autodiff(...)

; Test entry point: requests differentiation of @f through the variadic
; @__enzyme_autodiff intrinsic-style call. The "enzyme_dup" metadata precedes
; the pair (%x, %dx), i.e. the argument and the pointer passed alongside it.
; NOTE(review): %dx is presumably the shadow/derivative buffer for %x -- see
; Enzyme's calling-convention documentation.
define void @caller(double* %x, double* %dx) {
entry:
call void (...) @__enzyme_autodiff(double (double*)* @f, metadata !"enzyme_dup", double* %x, double* %dx)
ret void
}

; Function being differentiated. Calls @mydiag to get a pointer, loads the
; double it points to, and returns that value. The returned pointer is only
; used as the address of this single load.
define double @f(double* %a13) {
top:
%b13 = call fastcc nonnull double* @mydiag(double* %a13)
%a14 = load double, double* %b13, align 8
ret double %a14
}

declare double* @malloc(i64) local_unnamed_addr

; CHECK: define internal fastcc void @diffejulia_mygetindex_2341(double* %a12, double* %"a12'", { double*, double* } %tapeArg)
; CHECK-NEXT: top:
; CHECK-NEXT: %"a22'mi" = extractvalue { double*, double* } %tapeArg, 0
; CHECK-NEXT: %a22 = extractvalue { double*, double* } %tapeArg, 1
; CHECK-NEXT: br label %L8

; CHECK: L8: ; preds = %L8, %top
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %L8 ], [ 0, %top ]
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
; CHECK-NEXT: %.not5 = icmp ne i64 %iv.next, 1
; CHECK-NEXT: br i1 %.not5, label %L8, label %invertL8

; CHECK: inverttop: ; preds = %invertL8
; CHECK-NEXT: %[[a0:.+]] = bitcast double* %"a22'mi" to i8*
; CHECK-NEXT: call void @free(i8* nonnull %[[a0]])
; CHECK-NEXT: %[[a1:.+]] = bitcast double* %a22 to i8*
; CHECK-NEXT: call void @free(i8* %[[a1]])
; CHECK-NEXT: %[[a2:.+]] = load double, double* %"a12'", align 8, !alias.scope !5, !noalias !8
; CHECK-NEXT: %[[a3:.+]] = fadd fast double %[[a2]], %[[a5:.+]]
; CHECK-NEXT: store double %[[a3]], double* %"a12'", align 8, !alias.scope !5, !noalias !8
; CHECK-NEXT: ret void

; CHECK: invertL8: ; preds = %L8, %incinvertL8
; CHECK-NEXT: %"a13'de.0" = phi double [ %[[a5]], %incinvertL8 ], [ 0.000000e+00, %L8 ]
; CHECK-NEXT: %"iv'ac.0" = phi i64 [ %[[a7:.+]], %incinvertL8 ], [ 0, %L8 ]
; CHECK-NEXT: %"a23'ipg_unwrap" = getelementptr inbounds double, double* %"a22'mi", i64 %"iv'ac.0"
; CHECK-NEXT: %[[a4:.+]] = load double, double* %"a23'ipg_unwrap", align 8
; CHECK-NEXT: store double 0.000000e+00, double* %"a23'ipg_unwrap", align 8
; CHECK-NEXT: %[[a5]] = fadd fast double %"a13'de.0", %[[a4]]
; CHECK-NEXT: %[[a6:.+]] = icmp eq i64 %"iv'ac.0", 0
; CHECK-NEXT: br i1 %[[a6]], label %inverttop, label %incinvertL8

; CHECK: incinvertL8: ; preds = %invertL8
; CHECK-NEXT: %[[a7]] = add nsw i64 %"iv'ac.0", -1
; CHECK-NEXT: br label %invertL8
; CHECK-NEXT: }

0 comments on commit 7fee772

Please sign in to comment.