Better alias info
wsmoses committed May 17, 2020
1 parent 3973cfb commit 4ea5488
Showing 13 changed files with 898 additions and 135 deletions.
2 changes: 1 addition & 1 deletion enzyme/Enzyme/CMakeLists.txt
@@ -7,7 +7,7 @@ file(GLOB ENZYME_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
)

list(APPEND ENZYME_SRC SCEV/ScalarEvolutionExpander.cpp) # Attributor/Attributor.cpp)

message("found enzyme sources " ${ENZYME_SRC})

if (${LLVM_VERSION_MAJOR} LESS 8)
199 changes: 95 additions & 104 deletions enzyme/Enzyme/EnzymeLogic.cpp
@@ -116,7 +116,6 @@ bool is_value_mustcache_from_origin(Value* obj, AAResults& AA, GradientUtils* gu
mustcache = true;
}
//llvm::errs() << " + argument (mustcache=" << mustcache << ") " << " object: " << *obj << " arg: " << *arg << "e\n";
//TODO this case (an alloca goes out of scope or an allocation is freed and we don't force it to persist) needs to be forcibly cached
} else {

// Pointer operands originating from call instructions that are not malloc/free are conservatively considered uncacheable.
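(Editor's aside: a minimal, hypothetical C++ sketch of why this conservatism is needed. get_ptr and unrelated_work are invented stand-ins for calls the analysis cannot see through; they are not part of this commit.)

```cpp
#include <cstdio>

static double shared = 1.0;

// Hypothetical stand-ins for calls whose effects are unknown to the analysis.
double* get_ptr() { return &shared; }
void unrelated_work() { shared = 5.0; }

int main() {
  double* p = get_ptr();   // pointer originates from a non-malloc call
  double v0 = *p;          // forward-pass load
  unrelated_work();        // a later call may write through an alias,
  double v1 = *p;          // so re-reading is unsound; v0 must be cached
  std::printf("%f %f\n", v0, v1); // prints: 1.000000 5.000000
  return 0;
}
```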
@@ -158,34 +157,36 @@ bool is_load_uncacheable(LoadInst& li, AAResults& AA, GradientUtils* gutils, Tar
// Find the underlying object for the pointer operand of the load instruction.
auto obj = GetUnderlyingObject(li.getPointerOperand(), gutils->oldFunc->getParent()->getDataLayout(), 100);

//llvm::errs() << "underlying object for load " << li << " is " << *obj << "\n";

bool can_modref = is_value_mustcache_from_origin(obj, AA, gutils, TLI, uncacheable_args);

allFollowersOf(&li, [&](Instruction* inst2) {
// Don't consider modref from malloc/free as a need to cache
if (auto obj_op = dyn_cast<CallInst>(inst2)) {
Function* called = obj_op->getCalledFunction();
if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
if (castinst->isCast()) {
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
called = fn;
//llvm::errs() << "underlying object for load " << li << " is " << *obj << " fromorigin: " << can_modref << "\n";

if (!can_modref) {
allFollowersOf(&li, [&](Instruction* inst2) {
// Don't consider modref from malloc/free as a need to cache
if (auto obj_op = dyn_cast<CallInst>(inst2)) {
Function* called = obj_op->getCalledFunction();
if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
if (castinst->isCast()) {
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
called = fn;
}
}
}
}
if (called && isCertainMallocOrFree(called)) {
return;
}
}
if (called && isCertainMallocOrFree(called)) {

if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::get(&li)))) {
can_modref = true;
return;
}
}

if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::get(&li)))) {
can_modref = true;
//llvm::errs() << li << " needs to be cached due to: " << *inst2 << "\n";
return;
}
});
});
}

//llvm::errs() << "F - " << li << " can_modref" << can_modref << "\n";
return can_modref;
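(Editor's aside: the helper allFollowersOf used above is not shown in this diff. The toy C++ sketch below illustrates its assumed semantics — visit every instruction that may execute after a given one — over an invented CFG representation, not LLVM's types.)

```cpp
#include <cstdio>
#include <functional>
#include <queue>
#include <set>
#include <vector>

struct Block; // toy CFG types, invented for this sketch
struct Inst { const char* name; Block* parent; };
struct Block {
  std::vector<Inst*> insts;
  std::vector<Block*> succs;
};

// Visit every instruction that may execute after `start`: the remainder of
// its block, then all instructions in transitively reachable successors.
void allFollowersOf(Inst* start, const std::function<void(Inst*)>& f) {
  bool after = false;
  for (Inst* i : start->parent->insts) {
    if (after) f(i);
    if (i == start) after = true;
  }
  std::set<Block*> seen;
  std::queue<Block*> work;
  for (Block* s : start->parent->succs) work.push(s);
  while (!work.empty()) {
    Block* b = work.front(); work.pop();
    if (!seen.insert(b).second) continue;
    for (Inst* i : b->insts) f(i);
    for (Block* s : b->succs) work.push(s);
  }
}

int main() {
  Block entry, next;
  Inst a{"load", &entry}, b{"store", &entry}, c{"call", &next};
  entry.insts = {&a, &b};
  entry.succs = {&next};
  next.insts = {&c};
  allFollowersOf(&a, [](Inst* i) { std::printf("%s\n", i->name); });
  // prints: store, call
  return 0;
}
```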
@@ -210,9 +211,13 @@ std::map<Instruction*, bool> compute_uncacheable_load_map(GradientUtils* gutils,
std::map<Argument*, bool> compute_uncacheable_args_for_one_callsite(CallInst* callsite_op, DominatorTree &DT,
TargetLibraryInfo &TLI, AAResults& AA, GradientUtils* gutils, const std::map<Argument*, bool> parent_uncacheable_args) {

if (!callsite_op->getCalledFunction()) return {};

std::vector<Value*> args;
std::vector<bool> args_safe;

//llvm::errs() << "CallInst: " << *callsite_op<< "CALL ARGUMENT INFO: \n";

// First, we need to propagate the uncacheable status from the parent function to the callee,
// because a memory location modified after the parent returns is also modified after the callee returns.
for (unsigned i = 0; i < callsite_op->getNumArgOperands(); i++) {
@@ -231,78 +236,45 @@ std::map<Argument*, bool> compute_uncacheable_args_for_one_callsite(CallInst* ca

// Second, we check for memory modifications that can occur in the continuation of the
// callee inside the parent function.
for (inst_iterator I = inst_begin(*gutils->oldFunc), E = inst_end(*gutils->oldFunc); I != E; ++I) {
Instruction* inst = &*I;
assert(inst->getParent()->getParent() == callsite_op->getParent()->getParent());

if (inst == callsite_op) continue;

// If the "inst" does not dominate "callsite_op" then we cannot prove that
// "inst" happens before "callsite_op". If "inst" modifies an argument of the call,
// then that call needs to consider the argument uncacheable.
// To correctly handle case where inst == callsite_op, we need to look at next instruction after callsite_op.
if (!gutils->OrigDT.dominates(inst, callsite_op)) {
//llvm::errs() << "Instruction " << *inst << " DOES NOT dominates " << *callsite_op << "\n";
// Consider Store Instructions.
if (auto op = dyn_cast<StoreInst>(inst)) {
for (unsigned i = 0; i < args.size(); i++) {
// If the modification flag is set, then this instruction may modify the $i$th argument of the call.
if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(callsite_op, i, TLI)))) {
//llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n";
} else {
//llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n";
args_safe[i] = false;
allFollowersOf(callsite_op, [&](Instruction* inst2) {
// Don't consider modref from malloc/free as a need to cache
if (auto obj_op = dyn_cast<CallInst>(inst2)) {
Function* called = obj_op->getCalledFunction();
if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
if (castinst->isCast()) {
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
called = fn;
}
}
}
}
if (called && isCertainMallocOrFree(called)) {
return;
}
}

// Consider Call Instructions.
if (auto op = dyn_cast<CallInst>(inst)) {
//llvm::errs() << "OP is call inst: " << *op << "\n";
// Ignore memory allocation functions.
Function* called = op->getCalledFunction();
if (auto castinst = dyn_cast<ConstantExpr>(op->getCalledValue())) {
if (castinst->isCast()) {
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
called = fn;
}
}
}
}
if (isCertainMallocOrFree(called)) {
//llvm::errs() << "OP is certain malloc or free: " << *op << "\n";
continue;
}

// For all the arguments, perform same check as for Stores, but ignore non-pointer arguments.
for (unsigned i = 0; i < args.size(); i++) {
if (!args[i]->getType()->isPointerTy()) continue; // Ignore non-pointer arguments.
if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(callsite_op, i, TLI)))) {
//llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n";
} else {
//llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n";
args_safe[i] = false;
}
}
for (unsigned i = 0; i < args.size(); i++) {
if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::getForArgument(callsite_op, i, TLI)))) {
args_safe[i] = false;
//llvm::errs() << "Instruction " << *inst2 << " is maybe ModRef with call argument " << *args[i] << "\n";
}
} else {
//llvm::errs() << "Instruction " << *inst << " DOES dominates " << *callsite_op << "\n";
}
}
});

std::map<Argument*, bool> uncacheable_args;
//llvm::errs() << "CallInst: " << *callsite_op<< "CALL ARGUMENT INFO: \n";
if (callsite_op->getCalledFunction()) {

auto arg = callsite_op->getCalledFunction()->arg_begin();
for (unsigned i = 0; i < args.size(); i++) {
uncacheable_args[arg] = !args_safe[i];
//llvm::errs() << "callArg: " << *args[i] << " arg:" << *arg << " STATUS: " << args_safe[i] << "\n";
//llvm::errs() << "callArg: " << *args[i] << " arg:" << *arg << " uncacheable: " << uncacheable_args[arg] << "\n";
arg++;
if (arg == callsite_op->getCalledFunction()->arg_end()) {
break;
}
}

}
return uncacheable_args;
}
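
(Editor's aside: a hypothetical source-level illustration of what makes a call argument uncacheable; square_first and buf are invented for this example, not taken from the commit.)

```cpp
#include <cstdio>

double square_first(const double* p) { return p[0] * p[0]; }

int main() {
  double buf[1] = {2.0};
  double y = square_first(buf); // the callee reads buf[0] in the forward pass
  buf[0] = 7.0;                 // buf is written after the call returns, so
                                // the argument is uncacheable: a reverse pass
                                // cannot re-read buf[0] and must use a cached
                                // copy of the original value
  std::printf("%f\n", y);       // prints: 4.000000
  return 0;
}
```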

@@ -661,38 +633,57 @@ bool legalCombinedForwardReverse(CallInst &ci, const std::map<ReturnInst*,StoreI
return false;
}

auto getMRI = [&](Instruction* inst, Instruction* inst2) {
if (auto call = dyn_cast<CallInst>(inst)) {
auto writesToMemoryReadBy = [&](Instruction* maybeReader, Instruction* maybeWriter) -> bool {
if (auto call = dyn_cast<CallInst>(maybeWriter)) {
if (call->getCalledFunction() && isCertainMallocOrFree(call->getCalledFunction())) {
return ModRefInfo::NoModRef;
return false;
}
}
if (auto call = dyn_cast<CallInst>(maybeReader)) {
if (call->getCalledFunction() && isCertainMallocOrFree(call->getCalledFunction())) {
return false;
}
}
if (auto call = dyn_cast<InvokeInst>(maybeWriter)) {
if (call->getCalledFunction() && isCertainMallocOrFree(call->getCalledFunction())) {
return false;
}
}
if (auto call = dyn_cast<InvokeInst>(maybeReader)) {
if (call->getCalledFunction() && isCertainMallocOrFree(call->getCalledFunction())) {
return false;
}
}
assert(maybeWriter->mayWriteToMemory());
assert(maybeReader->mayReadFromMemory());

if (auto li = dyn_cast<LoadInst>(inst2)) {
return gutils->AA.getModRefInfo(inst, MemoryLocation::get(li));
if (auto li = dyn_cast<LoadInst>(maybeReader)) {
return isModSet(gutils->AA.getModRefInfo(maybeWriter, MemoryLocation::get(li)));
}
if (auto si = dyn_cast<StoreInst>(inst2)) {
return gutils->AA.getModRefInfo(inst, MemoryLocation::get(si));
if (auto rmw = dyn_cast<AtomicRMWInst>(maybeReader)) {
return isModSet(gutils->AA.getModRefInfo(maybeWriter, MemoryLocation::get(rmw)));
}
if (auto rmw = dyn_cast<AtomicRMWInst>(inst2)) {
return gutils->AA.getModRefInfo(inst, MemoryLocation::get(rmw));
if (auto xch = dyn_cast<AtomicCmpXchgInst>(maybeReader)) {
return isModSet(gutils->AA.getModRefInfo(maybeWriter, MemoryLocation::get(xch)));
}
if (auto xch = dyn_cast<AtomicCmpXchgInst>(inst2)) {
return gutils->AA.getModRefInfo(inst, MemoryLocation::get(xch));

if (auto si = dyn_cast<StoreInst>(maybeWriter)) {
return isRefSet(gutils->AA.getModRefInfo(maybeReader, MemoryLocation::get(si)));
}
if (auto cb = dyn_cast<CallInst>(inst2)) {
if (cb->getCalledFunction() && isCertainMallocOrFree(cb->getCalledFunction())) {
return ModRefInfo::NoModRef;
}
return gutils->AA.getModRefInfo(inst, cb);
if (auto rmw = dyn_cast<AtomicRMWInst>(maybeWriter)) {
return isRefSet(gutils->AA.getModRefInfo(maybeReader, MemoryLocation::get(rmw)));
}
if (auto cb = dyn_cast<InvokeInst>(inst2)) {
if (cb->getCalledFunction() && isCertainMallocOrFree(cb->getCalledFunction())) {
return ModRefInfo::NoModRef;
}
return gutils->AA.getModRefInfo(inst, cb);
if (auto xch = dyn_cast<AtomicCmpXchgInst>(maybeWriter)) {
return isRefSet(gutils->AA.getModRefInfo(maybeReader, MemoryLocation::get(xch)));
}

if (auto cb = dyn_cast<CallInst>(maybeReader)) {
return isModOrRefSet(gutils->AA.getModRefInfo(maybeWriter, cb));
}
llvm::errs() << " inst2: " << *inst2 << "\n";
if (auto cb = dyn_cast<InvokeInst>(maybeReader)) {
return isModOrRefSet(gutils->AA.getModRefInfo(maybeWriter, cb));
}
llvm::errs() << " maybeReader: " << *maybeReader << " maybeWriter: " << *maybeWriter << "\n";
llvm_unreachable("unknown inst2");
};
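
(Editor's aside: the lambda above flips the query direction depending on which side has a precise MemoryLocation — for a load/atomic reader it asks whether the writer Mods the reader's location; for a store/atomic writer it asks whether the reader Refs the writer's location; opaque calls fall back to isModOrRefSet. The sketch below mocks that decision logic; ModRefInfo and the predicates mirror LLVM names but the definitions here are simplified assumptions, not LLVM's actual ones.)

```cpp
#include <cassert>

// Assumed simplification of llvm::ModRefInfo and its predicate helpers.
enum class ModRefInfo { NoModRef, Ref, Mod, ModRef };

bool isModSet(ModRefInfo i) { return i == ModRefInfo::Mod || i == ModRefInfo::ModRef; }
bool isRefSet(ModRefInfo i) { return i == ModRefInfo::Ref || i == ModRefInfo::ModRef; }
bool isModOrRefSet(ModRefInfo i) { return i != ModRefInfo::NoModRef; }

int main() {
  // Reader is a load: ask whether the writer Mods the load's location.
  assert(isModSet(ModRefInfo::Mod));
  // Writer is a store: ask whether the reader Refs the store's location.
  assert(isRefSet(ModRefInfo::ModRef));
  // Reader is an opaque call: any Mod or Ref against the writer counts.
  assert(!isModOrRefSet(ModRefInfo::NoModRef));
  return 0;
}
```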

@@ -806,12 +797,11 @@ bool legalCombinedForwardReverse(CallInst &ci, const std::map<ReturnInst*,StoreI
if (inst->mayWriteToMemory()) {
auto consider = [&](Instruction* user) {
if (!user->mayReadFromMemory()) return;
auto mri = getMRI(user, inst);
//llvm::errs() << " checking if need follower of " << *inst << " - " << *user << " : mri " << mri << "\n";
if (isRefSet(mri)) {
if (writesToMemoryReadBy(/*maybeReader*/user, /*maybeWriter*/inst)) {
//llvm::errs() << " memory deduced need follower of " << *inst << " - " << *user << "\n";
propagate(user);
if (!legal) return;
}
}
};
allFollowersOf(inst, consider);
if (!legal) return false;
@@ -826,13 +816,14 @@ bool legalCombinedForwardReverse(CallInst &ci, const std::map<ReturnInst*,StoreI
// llvm::errs() << " + " << *u << "\n";

// Check if any of the unmoved operations will make it illegal to move the instruction

for (auto inst : usetree) {
if (!inst->mayReadFromMemory()) continue;
allFollowersOf(inst, [&](Instruction* post) {
if (unnecessaryInstructions.count(post)) return;
if (!post->mayWriteToMemory()) return;
//llvm::errs() << " checking if illegal move of " << *inst << " due to " << *post << "\n";
auto mri = getMRI(inst, post);
if (isModSet(mri)) {
if (writesToMemoryReadBy(/*maybeReader*/inst, /*maybeWriter*/post)) {
if (called)
llvm::errs() << " failed to replace function " << (called->getName()) << " due to " << *post << " usetree: " << *inst << "\n";
else
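(Editor's aside: a minimal, hypothetical example of the reordering hazard this legality check guards against; the variables are invented for illustration.)

```cpp
#include <cstdio>

int main() {
  double a = 1.0;
  double* p = &a;
  double read = *p; // candidate to be moved later in the schedule
  *p = 2.0;         // a follower that mayWriteToMemory and aliases p
  // Moving the read below the store would change its value from 1.0 to 2.0,
  // so the check above must refuse the move.
  std::printf("%f\n", read); // prints: 1.000000
  return 0;
}
```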