Skip to content

Commit

Permalink
Fixup diff malloc fb (#1872)
Browse files Browse the repository at this point in the history
* Diffuse malloc fb

* fixup

* fixup
  • Loading branch information
wsmoses committed May 10, 2024
1 parent 2250522 commit e96ccd2
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 135 deletions.
135 changes: 8 additions & 127 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -6288,134 +6288,15 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
return;

bool useConstantFallback =
gutils->isConstantInstruction(&call) &&
(gutils->isConstantValue(&call) || !shadowReturnUsed);
if (useConstantFallback && Mode != DerivativeMode::ForwardMode &&
Mode != DerivativeMode::ForwardModeError) {
// if there is an escaping allocation, which is deduced needed in
// reverse pass, we need to do the recursive procedure to perform the
// free.

// First test if the return is a potential pointer and needed for the
// reverse pass
bool escapingNeededAllocation = false;

if (!isNoEscapingAllocation(&call)) {
escapingNeededAllocation = EnzymeGlobalActivity;

std::map<UsageKey, bool> CacheResults;
for (auto pair : gutils->knownRecomputeHeuristic) {
if (!pair.second || gutils->unnecessaryIntermediates.count(
cast<Instruction>(pair.first))) {
CacheResults[UsageKey(pair.first, QueryType::Primal)] = false;
}
}

if (!escapingNeededAllocation &&
!(EnzymeJuliaAddrLoad && isSpecialPtr(call.getType()))) {
if (TR.query(&call)[{-1}].isPossiblePointer()) {
auto found = gutils->knownRecomputeHeuristic.find(&call);
if (found != gutils->knownRecomputeHeuristic.end()) {
if (!found->second) {
CacheResults.erase(UsageKey(&call, QueryType::Primal));
escapingNeededAllocation =
DifferentialUseAnalysis::is_value_needed_in_reverse<
QueryType::Primal>(gutils, &call,
DerivativeMode::ReverseModeGradient,
CacheResults, oldUnreachable);
}
} else {
escapingNeededAllocation =
DifferentialUseAnalysis::is_value_needed_in_reverse<
QueryType::Primal>(gutils, &call,
DerivativeMode::ReverseModeGradient,
CacheResults, oldUnreachable);
}
}
}

// Next test if any allocation could be stored into one of the
// arguments.
if (!escapingNeededAllocation)
#if LLVM_VERSION_MAJOR >= 14
for (unsigned i = 0; i < call.arg_size(); ++i)
#else
for (unsigned i = 0; i < call.getNumArgOperands(); ++i)
#endif
{
Value *a = call.getOperand(i);

if (EnzymeJuliaAddrLoad && isSpecialPtr(a->getType()))
continue;

auto vd = TR.query(a);
if (!vd[{-1}].isPossiblePointer())
continue;

if (!vd[{-1, -1}].isPossiblePointer())
continue;

if (isReadOnly(&call, i))
continue;

// An allocation could only be needed in the reverse pass if it
// escapes into an argument. However, is the parameter by which it
// escapes could capture the pointer, the rest of Enzyme's caching
// mechanisms cannot assume that the allocation itself is
// reloadable, since it may have been captured and overwritten
// elsewhere.
// TODO: this justification will need revisiting in the future as
// the caching algorithm becomes increasingly sophisticated.
if (!isNoCapture(&call, i))
continue;

escapingNeededAllocation = true;
}
}

// If desired this can become even more aggressive by looking through the
// called function for any allocations.
if (auto F = getFunctionFromCall(&call)) {
SmallVector<Function *, 1> todo = {F};
SmallPtrSet<Function *, 1> done;
bool seenAllocation = false;
while (todo.size() && !seenAllocation) {
auto cur = todo.pop_back_val();
if (done.count(cur))
continue;
done.insert(cur);
// assume empty functions allocate.
if (cur->empty()) {
// unless they are marked
if (isNoEscapingAllocation(cur))
continue;
seenAllocation = true;
break;
}
for (auto &BB : *cur)
for (auto &I : BB)
if (auto CB = dyn_cast<CallBase>(&I)) {
if (isNoEscapingAllocation(CB))
continue;
if (isAllocationCall(CB, gutils->TLI)) {
seenAllocation = true;
goto finish;
}
if (auto F = getFunctionFromCall(CB)) {
todo.push_back(F);
continue;
}
// Conservatively assume indirect functions allocate.
seenAllocation = true;
goto finish;
}
finish:;
}
if (!seenAllocation)
escapingNeededAllocation = false;
DifferentialUseAnalysis::callShouldNotUseDerivative(gutils, call);
if (!useConstantFallback) {
if (gutils->isConstantInstruction(&call) &&
gutils->isConstantValue(&call)) {
EmitWarning("ConstnatFallback", call,
"Call was deduced inactive but still doing differential "
"rewrite as it may escape an allocation",
call);
}
if (escapingNeededAllocation)
useConstantFallback = false;
}
if (useConstantFallback) {
if (!gutils->isConstantValue(&call)) {
Expand Down
156 changes: 154 additions & 2 deletions enzyme/Enzyme/DifferentialUseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <set>

#include "DifferentialUseAnalysis.h"
#include "Utils.h"

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"
Expand Down Expand Up @@ -721,8 +722,13 @@ bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse(
return false;
}

bool neededFB = !gutils->isConstantInstruction(user) ||
!gutils->isConstantValue(const_cast<Instruction *>(user));
bool neededFB = false;
if (auto CB = dyn_cast<CallBase>(const_cast<Instruction *>(user))) {
neededFB = !callShouldNotUseDerivative(gutils, *CB);
} else {
neededFB = !gutils->isConstantInstruction(user) ||
!gutils->isConstantValue(const_cast<Instruction *>(user));
}
if (neededFB) {
if (EnzymePrintDiffUse)
llvm::errs() << " Need direct primal of " << *val
Expand Down Expand Up @@ -960,3 +966,149 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI,
}
return;
}

bool DifferentialUseAnalysis::callShouldNotUseDerivative(
const GradientUtils *gutils, CallBase &call) {
bool shadowReturnUsed = false;
auto smode = gutils->mode;
if (smode == DerivativeMode::ReverseModeGradient)
smode = DerivativeMode::ReverseModePrimal;
(void)gutils->getReturnDiffeType(&call, nullptr, &shadowReturnUsed, smode);

bool useConstantFallback =
gutils->isConstantInstruction(&call) &&
(gutils->isConstantValue(&call) || !shadowReturnUsed);
if (useConstantFallback && gutils->mode != DerivativeMode::ForwardMode &&
gutils->mode != DerivativeMode::ForwardModeError) {
// if there is an escaping allocation, which is deduced needed in
// reverse pass, we need to do the recursive procedure to perform the
// free.

// First test if the return is a potential pointer and needed for the
// reverse pass
bool escapingNeededAllocation = false;

if (!isNoEscapingAllocation(&call)) {
escapingNeededAllocation = EnzymeGlobalActivity;

std::map<UsageKey, bool> CacheResults;
for (auto pair : gutils->knownRecomputeHeuristic) {
if (!pair.second || gutils->unnecessaryIntermediates.count(
cast<Instruction>(pair.first))) {
CacheResults[UsageKey(pair.first, QueryType::Primal)] = false;
}
}

if (!escapingNeededAllocation &&
!(EnzymeJuliaAddrLoad && isSpecialPtr(call.getType()))) {
if (gutils->TR.anyPointer(&call)) {
auto found = gutils->knownRecomputeHeuristic.find(&call);
if (found != gutils->knownRecomputeHeuristic.end()) {
if (!found->second) {
CacheResults.erase(UsageKey(&call, QueryType::Primal));
escapingNeededAllocation =
DifferentialUseAnalysis::is_value_needed_in_reverse<
QueryType::Primal>(gutils, &call,
DerivativeMode::ReverseModeGradient,
CacheResults, gutils->notForAnalysis);
}
} else {
escapingNeededAllocation =
DifferentialUseAnalysis::is_value_needed_in_reverse<
QueryType::Primal>(gutils, &call,
DerivativeMode::ReverseModeGradient,
CacheResults, gutils->notForAnalysis);
}
}
}

// Next test if any allocation could be stored into one of the
// arguments.
if (!escapingNeededAllocation)
#if LLVM_VERSION_MAJOR >= 14
for (unsigned i = 0; i < call.arg_size(); ++i)
#else
for (unsigned i = 0; i < call.getNumArgOperands(); ++i)
#endif
{
Value *a = call.getOperand(i);

if (EnzymeJuliaAddrLoad && isSpecialPtr(a->getType()))
continue;

if (!gutils->TR.anyPointer(a))
continue;

auto vd = gutils->TR.query(a);

if (!vd[{-1, -1}].isPossiblePointer())
continue;

if (isReadOnly(&call, i))
continue;

// An allocation could only be needed in the reverse pass if it
// escapes into an argument. However, is the parameter by which it
// escapes could capture the pointer, the rest of Enzyme's caching
// mechanisms cannot assume that the allocation itself is
// reloadable, since it may have been captured and overwritten
// elsewhere.
// TODO: this justification will need revisiting in the future as
// the caching algorithm becomes increasingly sophisticated.
if (!isNoCapture(&call, i))
continue;

escapingNeededAllocation = true;
}
}

// If desired this can become even more aggressive by looking through the
// called function for any allocations.
if (auto F = getFunctionFromCall(&call)) {
SmallVector<Function *, 1> todo = {F};
SmallPtrSet<Function *, 1> done;
bool seenAllocation = false;
while (todo.size() && !seenAllocation) {
auto cur = todo.pop_back_val();
if (done.count(cur))
continue;
done.insert(cur);
// assume empty functions allocate.
if (cur->empty()) {
// unless they are marked
if (isNoEscapingAllocation(cur))
continue;
seenAllocation = true;
break;
}
auto UR = getGuaranteedUnreachable(cur);
for (auto &BB : *cur) {
if (UR.count(&BB))
continue;
for (auto &I : BB)
if (auto CB = dyn_cast<CallBase>(&I)) {
if (isNoEscapingAllocation(CB))
continue;
if (isAllocationCall(CB, gutils->TLI)) {
seenAllocation = true;
goto finish;
}
if (auto F = getFunctionFromCall(CB)) {
todo.push_back(F);
continue;
}
// Conservatively assume indirect functions allocate.
seenAllocation = true;
goto finish;
}
}
finish:;
}
if (!seenAllocation)
escapingNeededAllocation = false;
}
if (escapingNeededAllocation)
useConstantFallback = false;
}
return useConstantFallback;
}
5 changes: 5 additions & 0 deletions enzyme/Enzyme/DifferentialUseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,11 @@ forEachDifferentialUser(llvm::function_ref<void(llvm::Value *)> f,
}
}
}

//! Decide whether the call \p orig can be handled by the constant fallback
//! instead of a differential rewrite.
//!
//! Returns true when the call is a constant instruction and either its value
//! is constant or its shadow return is unused. Outside of ForwardMode and
//! ForwardModeError the result is additionally forced to false when the call
//! is judged to possibly escape an allocation needed in the reverse pass
//! (through its returned pointer or a written, non-capturing pointer
//! argument) and the callee is not shown to be free of allocations.
bool callShouldNotUseDerivative(const GradientUtils *gutils,
llvm::CallBase &orig);

}; // namespace DifferentialUseAnalysis

#endif
4 changes: 2 additions & 2 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5922,7 +5922,7 @@ bool TypeResults::anyFloat(Value *val, bool anythingIsFloat) const {
if (dt != BaseType::Anything && dt != BaseType::Unknown)
return dt.isFloat();

if (val->getType()->isTokenTy())
if (val->getType()->isTokenTy() || val->getType()->isVoidTy())
return false;
auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout();
SmallSet<size_t, 8> offs;
Expand Down Expand Up @@ -5958,7 +5958,7 @@ bool TypeResults::anyPointer(Value *val) const {
auto dt = q[{-1}];
if (dt != BaseType::Anything && dt != BaseType::Unknown)
return dt == BaseType::Pointer;
if (val->getType()->isTokenTy())
if (val->getType()->isTokenTy() || val->getType()->isVoidTy())
return false;

auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout();
Expand Down
Loading

0 comments on commit e96ccd2

Please sign in to comment.