Skip to content

Commit

Permalink
Ensure custom derivative functions aren't deleted
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 12, 2024
1 parent 6b8960d commit 0a9f18d
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 16 deletions.
22 changes: 22 additions & 0 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4473,7 +4473,29 @@ Function *EnzymeLogic::CreateForwardDiff(
"unknown derivative for function -- metadata incorrect");
}
auto md2 = cast<MDTuple>(md);
assert(md2);
assert(md2->getNumOperands() == 1);
if (!md2->getOperand(0)) {
std::string s;
llvm::raw_string_ostream ss(s);
ss << "Failed to use custom forward mode derivative for "
<< todiff->getName() << "\n";
ss << " found metadata (but null op0) " << *md2 << "\n";
EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req,
ss.str());
return ForwardCachedFunctions[tup] = nullptr;
}
if (!isa<ConstantAsMetadata>(md2->getOperand(0))) {
std::string s;
llvm::raw_string_ostream ss(s);
ss << "Failed to use custom forward mode derivative for "
<< todiff->getName() << "\n";
ss << " found metadata (but not constantasmetadata) "
<< *md2->getOperand(0) << "\n";
EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req,
ss.str());
return ForwardCachedFunctions[tup] = nullptr;
}
auto gvemd = cast<ConstantAsMetadata>(md2->getOperand(0));
auto foundcalled = cast<Function>(gvemd->getValue());

Expand Down
40 changes: 24 additions & 16 deletions enzyme/Enzyme/PreserveNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,25 @@ using namespace llvm;
#define addAttribute addAttributeAtIndex
#endif

//! Returns whether changed.
bool preserveLinkage(bool Begin, Function &F, bool Inlining = true) {
if (Begin && !F.hasFnAttribute("prev_fixup")) {
F.addFnAttr("prev_fixup");
if (Inlining) {
if (F.hasFnAttribute(Attribute::AlwaysInline))
F.addFnAttr("prev_always_inline");
F.removeFnAttr(Attribute::AlwaysInline);
if (F.hasFnAttribute(Attribute::NoInline))
F.addFnAttr("prev_no_inline");
F.addFnAttr(Attribute::NoInline);
}
F.addFnAttr("prev_linkage", std::to_string(F.getLinkage()));
F.setLinkage(Function::LinkageTypes::ExternalLinkage);
return true;
}
return false;
}

template <const char *handlername, DerivativeMode Mode, int numargs>
static void
handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g,
Expand Down Expand Up @@ -237,26 +256,31 @@ handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g,
Fs[fn] = NewF;
}

preserveLinkage(true, *Fs[1], false);
Fs[0]->setMetadata(
"enzyme_augment",
llvm::MDTuple::get(Fs[0]->getContext(),
{llvm::ValueAsMetadata::get(Fs[1])}));
preserveLinkage(true, *Fs[2], false);
Fs[0]->setMetadata(
"enzyme_gradient",
llvm::MDTuple::get(Fs[0]->getContext(),
{llvm::ValueAsMetadata::get(Fs[2])}));
} else if (Mode == DerivativeMode::ForwardMode) {
assert(numargs == 2);
preserveLinkage(true, *Fs[1], false);
Fs[0]->setMetadata(
"enzyme_derivative",
llvm::MDTuple::get(Fs[0]->getContext(),
{llvm::ValueAsMetadata::get(Fs[1])}));
} else if (Mode == DerivativeMode::ForwardModeSplit) {
assert(numargs == 3);
preserveLinkage(true, *Fs[1], false);
Fs[0]->setMetadata(
"enzyme_augment",
llvm::MDTuple::get(Fs[0]->getContext(),
{llvm::ValueAsMetadata::get(Fs[1])}));
preserveLinkage(true, *Fs[2], false);
Fs[0]->setMetadata(
"enzyme_splitderivative",
llvm::MDTuple::get(Fs[0]->getContext(),
Expand All @@ -282,22 +306,6 @@ handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g,
}
globalsToErase.push_back(&g);
}
//! Returns whether changed.
bool preserveLinkage(bool Begin, Function &F) {
if (Begin && !F.hasFnAttribute("prev_fixup")) {
F.addFnAttr("prev_fixup");
if (F.hasFnAttribute(Attribute::AlwaysInline))
F.addFnAttr("prev_always_inline");
if (F.hasFnAttribute(Attribute::NoInline))
F.addFnAttr("prev_no_inline");
F.addFnAttr("prev_linkage", std::to_string(F.getLinkage()));
F.setLinkage(Function::LinkageTypes::ExternalLinkage);
F.addFnAttr(Attribute::NoInline);
F.removeFnAttr(Attribute::AlwaysInline);
return true;
}
return false;
}

bool preserveNVVM(bool Begin, Function &F) {
bool changed = false;
Expand Down

0 comments on commit 0a9f18d

Please sign in to comment.