Skip to content

Commit

Permalink
Add jl simplify pass (#1880)
Browse files Browse the repository at this point in the history
* Add jl simplify pass

* Freeze fix
  • Loading branch information
wsmoses committed May 14, 2024
1 parent de30014 commit 759d272
Show file tree
Hide file tree
Showing 10 changed files with 420 additions and 19 deletions.
5 changes: 5 additions & 0 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3274,6 +3274,7 @@ class EnzymeNewPM final : public EnzymeBase,
AnalysisKey EnzymeNewPM::Key;

#include "ActivityAnalysisPrinter.h"
#include "JLInstSimplify.h"
#include "PreserveNVVM.h"
#include "TypeAnalysis/TypeAnalysisPrinter.h"
#include "llvm/Passes/PassBuilder.h"
Expand Down Expand Up @@ -3833,6 +3834,10 @@ void registerEnzyme(llvm::PassBuilder &PB) {
FPM.addPass(ActivityAnalysisPrinterNewPM());
return true;
}
if (Name == "jl-inst-simplify") {
FPM.addPass(JLInstSimplifyNewPM());
return true;
}
return false;
});
}
Expand Down
201 changes: 201 additions & 0 deletions enzyme/Enzyme/JLInstSimplify.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
//=- JLInstSimplify.h - Additional instsimplifyrules for julia programs =//
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
// Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===----------------------------------------------------------------------===//
//
// This file contains a utility LLVM pass for printing derived Activity Analysis
// results of a given function.
//
//===----------------------------------------------------------------------===//
#include <llvm/Config/llvm-config.h>

#include "llvm/ADT/SmallVector.h"

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"

#include "llvm/IR/LegacyPassManager.h"

#include "llvm/Support/Debug.h"

#include "llvm/Analysis/TargetLibraryInfo.h"

#include "llvm-c/Core.h"
#include "llvm-c/DataTypes.h"

#include "llvm-c/ExternC.h"
#include "llvm-c/Types.h"

#include "JLInstSimplify.h"
#include "Utils.h"

using namespace llvm;
#ifdef DEBUG_TYPE
#undef DEBUG_TYPE
#endif
#define DEBUG_TYPE "jl-inst-simplify"
namespace {

bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
llvm::AAResults &AA, llvm::LoopInfo &LI) {
bool changed = false;

for (auto &BB : F)
for (auto &I : BB) {
if (auto FI = dyn_cast<FreezeInst>(&I)) {
if (FI->hasOneUse()) {
bool allBranch = true;
for (auto user : FI->users()) {
if (!isa<BranchInst>(user)) {
allBranch = false;
break;
}
}
if (allBranch) {
FI->replaceAllUsesWith(FI->getOperand(0));
changed = true;
continue;
}
}
}
if (auto cmp = dyn_cast<ICmpInst>(&I)) {
if (cmp->use_empty())
continue;
auto lhs = getBaseObject(cmp->getOperand(0), /*offsetAllowed*/ false);
auto rhs = getBaseObject(cmp->getOperand(1), /*offsetAllowed*/ false);
if (lhs == rhs) {
cmp->replaceAllUsesWith(cmp->isTrueWhenEqual()
? ConstantInt::getTrue(F.getContext())
: ConstantInt::getFalse(F.getContext()));
changed = true;
continue;
}
if ((isNoAlias(lhs) && (isNoAlias(rhs) || isa<Argument>(rhs))) ||
(isNoAlias(rhs) && isa<Argument>(lhs))) {
cmp->replaceAllUsesWith(cmp->isTrueWhenEqual()
? ConstantInt::getFalse(F.getContext())
: ConstantInt::getTrue(F.getContext()));
changed = true;
continue;
}
auto llhs = dyn_cast<LoadInst>(lhs);
auto lrhs = dyn_cast<LoadInst>(rhs);
if (llhs && lrhs) {
auto lhsv =
getBaseObject(llhs->getOperand(0), /*offsetAllowed*/ false);
auto rhsv =
getBaseObject(lrhs->getOperand(0), /*offsetAllowed*/ false);
if ((isNoAlias(lhsv) && (isNoAlias(rhsv) || isa<Argument>(rhsv))) ||
(isNoAlias(rhsv) && isa<Argument>(lhsv))) {
bool legal = false;
for (int i = 0; i < 2; i++) {
Value *start = (i == 0) ? lhsv : rhsv;
Instruction *starti = dyn_cast<Instruction>(start);
if (!starti) {
assert(isa<Argument>(starti));
starti = &cast<Argument>(starti)
->getParent()
->getEntryBlock()
.front();
}

bool overwritten = false;
allInstructionsBetween(
LI, starti, cmp, [&](Instruction *I) -> bool {
if (!I->mayWriteToMemory())
return /*earlyBreak*/ false;

for (auto LI : {llhs, lrhs})
if (writesToMemoryReadBy(AA, TLI,
/*maybeReader*/ LI,
/*maybeWriter*/ I)) {
overwritten = true;
return /*earlyBreak*/ true;
}
return /*earlyBreak*/ false;
});
if (!overwritten) {
legal = true;
break;
}
}

if (legal && lhsv != rhsv) {
cmp->replaceAllUsesWith(
cmp->isTrueWhenEqual()
? ConstantInt::getFalse(F.getContext())
: ConstantInt::getTrue(F.getContext()));
changed = true;
continue;
}
}
}
}
}
return changed;
}

class JLInstSimplify final : public FunctionPass {
public:
static char ID;
JLInstSimplify() : FunctionPass(ID) {}

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<TargetLibraryInfoWrapperPass>();
AU.addRequired<AAResultsWrapperPass>();
AU.addRequired<LoopInfoWrapperPass>();
}

bool runOnFunction(Function &F) override {
auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
return jlInstSimplify(F, TLI, AA, LI);
}
};

} // namespace

FunctionPass *createJLInstSimplifyPass() { return new JLInstSimplify(); }

extern "C" void LLVMAddJLInstSimplifyPass(LLVMPassManagerRef PM) {
unwrap(PM)->add(createJLInstSimplifyPass());
}

char JLInstSimplify::ID = 0;

static RegisterPass<JLInstSimplify> X("jl-inst-simplify",
"JL instruction simplification");

JLInstSimplifyNewPM::Result
JLInstSimplifyNewPM::run(llvm::Function &F,
llvm::FunctionAnalysisManager &FAM) {
bool changed = false;
changed = jlInstSimplify(F, FAM.getResult<TargetLibraryAnalysis>(F),
FAM.getResult<AAManager>(F),
FAM.getResult<LoopAnalysis>(F));
return changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
llvm::AnalysisKey JLInstSimplifyNewPM::Key;
43 changes: 43 additions & 0 deletions enzyme/Enzyme/JLInstSimplify.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//=- JLInstSimplify.h - Additional instsimplifyrules for julia programs =//
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
// Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===----------------------------------------------------------------------===//
#include <llvm/Config/llvm-config.h>

#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassPlugin.h"

namespace llvm {
class FunctionPass;
}

class JLInstSimplifyNewPM final
: public llvm::AnalysisInfoMixin<JLInstSimplifyNewPM> {
friend struct llvm::AnalysisInfoMixin<JLInstSimplifyNewPM>;

private:
static llvm::AnalysisKey Key;

public:
using Result = llvm::PreservedAnalyses;
JLInstSimplifyNewPM() {}

Result run(llvm::Function &M, llvm::FunctionAnalysisManager &MAM);

static bool isRequired() { return true; }
};
6 changes: 6 additions & 0 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2043,6 +2043,12 @@ bool writesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
if (funcName == "jl_array_copy" || funcName == "ijl_array_copy")
return false;

if (funcName == "jl_new_array" || funcName == "ijl_new_array")
return false;

if (funcName == "julia.safepoint")
return false;

if (funcName == "jl_idtable_rehash" || funcName == "ijl_idtable_rehash")
return false;

Expand Down
66 changes: 47 additions & 19 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1335,18 +1335,23 @@ static inline bool isPointerArithmeticInst(const llvm::Value *V,
return false;
}

static inline llvm::Value *getBaseObject(llvm::Value *V) {
static inline llvm::Value *getBaseObject(llvm::Value *V,
bool offsetAllowed = true) {
while (true) {
if (auto CI = llvm::dyn_cast<llvm::CastInst>(V)) {
V = CI->getOperand(0);
continue;
} else if (auto CI = llvm::dyn_cast<llvm::GetElementPtrInst>(V)) {
V = CI->getOperand(0);
continue;
if (offsetAllowed || CI->hasAllZeroIndices()) {
V = CI->getOperand(0);
continue;
}
} else if (auto II = llvm::dyn_cast<llvm::IntrinsicInst>(V);
II && isIntelSubscriptIntrinsic(*II)) {
V = II->getOperand(3);
continue;
if (offsetAllowed) {
V = II->getOperand(3);
continue;
}
} else if (auto CI = llvm::dyn_cast<llvm::PHINode>(V)) {
if (CI->getNumIncomingValues() == 1) {
V = CI->getOperand(0);
Expand All @@ -1366,7 +1371,7 @@ static inline llvm::Value *getBaseObject(llvm::Value *V) {
auto funcName = getFuncNameFromCall(Call);
auto AttrList = Call->getAttributes().getAttributes(
llvm::AttributeList::FunctionIndex);
if (AttrList.hasAttribute("enzyme_pointermath")) {
if (AttrList.hasAttribute("enzyme_pointermath") && offsetAllowed) {
size_t res = 0;
bool failed = AttrList.getAttribute("enzyme_pointermath")
.getValueAsString()
Expand Down Expand Up @@ -1398,7 +1403,7 @@ static inline llvm::Value *getBaseObject(llvm::Value *V) {
if (auto fn = getFunctionFromCall(Call)) {
auto AttrList = fn->getAttributes().getAttributes(
llvm::AttributeList::FunctionIndex);
if (AttrList.hasAttribute("enzyme_pointermath")) {
if (AttrList.hasAttribute("enzyme_pointermath") && offsetAllowed) {
size_t res = 0;
bool failed = AttrList.getAttribute("enzyme_pointermath")
.getValueAsString()
Expand Down Expand Up @@ -1428,24 +1433,27 @@ static inline llvm::Value *getBaseObject(llvm::Value *V) {
// because it should be in sync with CaptureTracking. Not using it may
// cause weird miscompilations where 2 aliasing pointers are assumed to
// noalias.
if (auto *RP = llvm::getArgumentAliasingToReturnedPointer(Call, false)) {
V = RP;
continue;
}
if (offsetAllowed)
if (auto *RP =
llvm::getArgumentAliasingToReturnedPointer(Call, false)) {
V = RP;
continue;
}
}

if (auto I = llvm::dyn_cast<llvm::Instruction>(V)) {
if (offsetAllowed)
if (auto I = llvm::dyn_cast<llvm::Instruction>(V)) {
#if LLVM_VERSION_MAJOR >= 12
auto V2 = llvm::getUnderlyingObject(I, 100);
auto V2 = llvm::getUnderlyingObject(I, 100);
#else
auto V2 = llvm::GetUnderlyingObject(
I, I->getParent()->getParent()->getParent()->getDataLayout(), 100);
auto V2 = llvm::GetUnderlyingObject(
I, I->getParent()->getParent()->getParent()->getDataLayout(), 100);
#endif
if (V2 != V) {
V = V2;
break;
if (V2 != V) {
V = V2;
break;
}
}
}
break;
}
return V;
Expand Down Expand Up @@ -1560,6 +1568,26 @@ static inline bool isNoCapture(const llvm::CallBase *call, size_t idx) {
return false;
}

static inline bool isNoAlias(const llvm::CallBase *call) {
if (call->returnDoesNotAlias())
return true;

if (auto F = getFunctionFromCall(call)) {
if (F->returnDoesNotAlias())
return true;
}
return false;
}

static inline bool isNoAlias(const llvm::Value *val) {
if (auto CB = llvm::dyn_cast<llvm::CallBase>(val))
return isNoAlias(CB);
if (auto arg = llvm::dyn_cast<llvm::Argument>(val)) {
arg->hasNoAliasAttr();
}
return false;
}

static inline bool isNoEscapingAllocation(const llvm::Function *F) {
if (F->hasFnAttribute("enzyme_no_escaping_allocation"))
return true;
Expand Down
1 change: 1 addition & 0 deletions enzyme/test/Enzyme/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_subdirectory(ForwardModeSplit)
add_subdirectory(ForwardModeVector)
add_subdirectory(BatchMode)
add_subdirectory(ProbProg)
add_subdirectory(JLSimplify)

# Run regression and unit tests
add_lit_testsuite(check-enzyme "Running enzyme regression tests"
Expand Down
Loading

0 comments on commit 759d272

Please sign in to comment.