Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tx2 #32

Merged
merged 2 commits into from
Nov 12, 2019
Merged

Tx2 #32

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 76 additions & 5 deletions enzyme/Enzyme/ActiveVariable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,24 @@ cl::opt<bool> nonmarkedglobals_inactive(
"enzyme_nonmarkedglobals_inactive", cl::init(false), cl::Hidden,
cl::desc("Consider all nonmarked globals to be inactive"));

bool isKnownIntegerTBAA(Instruction* inst) {
if (MDNode* md = inst->getMetadata(LLVMContext::MD_tbaa)) {
if (md->getNumOperands() != 3) return false;
Metadata* metadata = md->getOperand(1).get();
if (auto mda = dyn_cast<MDNode>(metadata)) {
if (mda->getNumOperands() == 0) return false;
Metadata* metadata2 = mda->getOperand(0).get();
if (auto typeName = dyn_cast<MDString>(metadata2)) {
auto typeNameStringRef = typeName->getString();
if (typeNameStringRef == "long") {
return true;
}
}
}
}
return false;
}

bool isIntASecretFloat(Value* val) {
assert(val->getType()->isIntegerTy());

Expand All @@ -55,10 +73,10 @@ bool isIntASecretFloat(Value* val) {
//if (cint->isOne()) return cint;
}


if (auto inst = dyn_cast<Instruction>(val)) {
bool floatingUse = false;
bool pointerUse = false;
bool intUse = false;
SmallPtrSet<Value*, 4> seen;

std::function<void(Value*)> trackPointer = [&](Value* v) {
Expand Down Expand Up @@ -129,12 +147,13 @@ bool isIntASecretFloat(Value* val) {

if (auto si = dyn_cast<StoreInst>(use)) {
assert(inst == si->getValueOperand());

if (isKnownIntegerTBAA(si)) intUse = true;
trackPointer(si->getPointerOperand());
}
}

if (auto li = dyn_cast<LoadInst>(inst)) {
if (isKnownIntegerTBAA(li)) intUse = true;
trackPointer(li->getOperand(0));
}

Expand All @@ -152,10 +171,11 @@ bool isIntASecretFloat(Value* val) {
pointerUse = true;
}

if (pointerUse && !floatingUse) return false;
if (!pointerUse && floatingUse) return true;
if (intUse && !pointerUse && !floatingUse) return false;
if (!intUse && pointerUse && !floatingUse) return false;
if (!intUse && !pointerUse && floatingUse) return true;
llvm::errs() << *inst->getParent()->getParent() << "\n";
llvm::errs() << " val:" << *val << " pointer:" << pointerUse << " floating:" << floatingUse << "\n";
llvm::errs() << " val:" << *val << " pointer:" << pointerUse << " floating:" << floatingUse << " int:" << intUse << "\n";
assert(0 && "ambiguous unsure if constant or not");
}

Expand Down Expand Up @@ -267,6 +287,24 @@ Type* isIntPointerASecretFloat(Value* val) {
assert(0 && "unsure if constant or not");
}

cl::opt<bool> ipoconst(
"enzyme_ipoconst", cl::init(false), cl::Hidden,
cl::desc("Interprocedural constant detection"));

/*
bool isFunctionArgumentConstant(CallInst* CI, Value* arg, SmallPtrSetImpl<Value*> &constants, SmallPtrSetImpl<Value*> &nonconstant, const SmallPtrSetImpl<Value*> &retvals, const SmallPtrSetImpl<Instruction*> &originalInstructions) {
Function* F = CI->getCalledFunction();
if (F == nullptr) return false;
if (F->isEmpty()) return false;

for(unsigned i=0; i<CI->getArgOperands(); i++) {

}

assert(ipoconst);
}
*/

// TODO separate if the instruction is constant (i.e. could change things)
// from if the value is constant (the value is something that could be differentiated)
bool isconstantM(Instruction* inst, SmallPtrSetImpl<Value*> &constants, SmallPtrSetImpl<Value*> &nonconstant, const SmallPtrSetImpl<Value*> &retvals, const SmallPtrSetImpl<Instruction*> &originalInstructions, uint8_t directions) {
Expand Down Expand Up @@ -308,6 +346,7 @@ bool isconstantM(Instruction* inst, SmallPtrSetImpl<Value*> &constants, SmallPtr
switch(op->getIntrinsicID()) {
case Intrinsic::assume:
case Intrinsic::stacksave:
case Intrinsic::prefetch:
case Intrinsic::stackrestore:
case Intrinsic::lifetime_start:
case Intrinsic::lifetime_end:
Expand Down Expand Up @@ -335,6 +374,36 @@ bool isconstantM(Instruction* inst, SmallPtrSetImpl<Value*> &constants, SmallPtr
constants.insert(inst);
return true;
}

if (isa<LoadInst>(inst) || isa<StoreInst>(inst)) {
if (isKnownIntegerTBAA(inst)) {
if (printconst)
llvm::errs() << " constant instruction from TBAA " << *inst << "\n";
constants.insert(inst);
return true;
}
}

/* TODO consider constant stores
if (auto si = dyn_cast<StoreInst>(inst)) {
SmallPtrSet<Value*, 20> constants2;
constants2.insert(constants.begin(), constants.end());
SmallPtrSet<Value*, 20> nonconstant2;
nonconstant2.insert(nonconstant.begin(), nonconstant.end());
constants2.insert(inst);
if (isconstantValueM(si->getValueOperand(), constants2, nonconstant2, retvals, originalInstructions, directions)) {
constants.insert(inst);
constants.insert(constants2.begin(), constants2.end());
constants.insert(constants_tmp.begin(), constants_tmp.end());

// not here since if had full updown might not have been nonconstant
//nonconstant.insert(nonconstant2.begin(), nonconstant2.end());
if (printconst)
llvm::errs() << "constant(" << (int)directions << ") store:" << *inst << "\n";
return true;
}
}
*/

if (printconst)
llvm::errs() << "checking if is constant[" << (int)directions << "] " << *inst << "\n";
Expand All @@ -354,6 +423,7 @@ bool isconstantM(Instruction* inst, SmallPtrSetImpl<Value*> &constants, SmallPtr

for (const auto &a:inst->users()) {
if(auto store = dyn_cast<StoreInst>(a)) {

if (inst == store->getPointerOperand() && !isconstantValueM(store->getValueOperand(), constants2, nonconstant2, retvals, originalInstructions, directions)) {
if (directions == 3)
nonconstant.insert(inst);
Expand Down Expand Up @@ -420,6 +490,7 @@ bool isconstantM(Instruction* inst, SmallPtrSetImpl<Value*> &constants, SmallPtr
continue;
if (fnp->getIntrinsicID() == Intrinsic::memmove && call->getArgOperand(0) != inst && call->getArgOperand(1) != inst)
continue;

}
}

Expand Down
11 changes: 7 additions & 4 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,7 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
break;
}
case Intrinsic::stacksave:
case Intrinsic::prefetch:
case Intrinsic::stackrestore:
case Intrinsic::dbg_declare:
case Intrinsic::dbg_value:
Expand Down Expand Up @@ -783,7 +784,7 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
} else if(auto op = dyn_cast<StoreInst>(inst)) {
if (gutils->isConstantInstruction(inst)) continue;

if ( op->getValueOperand()->getType()->isPointerTy() || (op->getValueOperand()->getType()->isIntegerTy() && !isIntASecretFloat(op->getValueOperand()) ) ) {
if ( op->getValueOperand()->getType()->isPointerTy() || (op->getValueOperand()->getType()->isIntegerTy() && !gutils->isConstantValue(op->getValueOperand()) && !isIntASecretFloat(op->getValueOperand()) ) ) {
IRBuilder <> storeBuilder(op);
//llvm::errs() << "a op value: " << *op->getValueOperand() << "\n";
Value* valueop = gutils->invertPointerM(op->getValueOperand(), storeBuilder);
Expand Down Expand Up @@ -2232,6 +2233,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
}
case Intrinsic::assume:
case Intrinsic::stacksave:
case Intrinsic::prefetch:
case Intrinsic::stackrestore:
case Intrinsic::dbg_declare:
case Intrinsic::dbg_value:
Expand Down Expand Up @@ -2472,9 +2474,10 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
ts->setSyncScopeID(op->getSyncScopeID());
} else if (topLevel) {
IRBuilder <> storeBuilder(op);
//llvm::errs() << "op value: " << *op->getValueOperand() << "\n";
llvm::errs() << "store op pointer: " << *op << "\n";
llvm::errs() << "op value: " << *op->getValueOperand() << "\n";
Value* valueop = gutils->invertPointerM(op->getValueOperand(), storeBuilder);
//llvm::errs() << "op pointer: " << *op->getPointerOperand() << "\n";
llvm::errs() << "op pointer: " << *op->getPointerOperand() << "\n";
Value* pointerop = gutils->invertPointerM(op->getPointerOperand(), storeBuilder);
storeBuilder.CreateStore(valueop, pointerop);
//llvm::errs() << "ignoring store bc pointer of " << *op << "\n";
Expand Down Expand Up @@ -2555,7 +2558,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
setDiffe(inst, Constant::getNullValue(inst->getType()));
} else if(auto op = dyn_cast<CastInst>(inst)) {
if (gutils->isConstantValue(inst)) continue;
if (op->getType()->isPointerTy()) continue;
if (op->getType()->isPointerTy() || op->getOpcode() == CastInst::CastOps::PtrToInt) continue;

if (!gutils->isConstantValue(op->getOperand(0))) {
if (op->getOpcode()==CastInst::CastOps::FPTrunc || op->getOpcode()==CastInst::CastOps::FPExt) {
Expand Down
3 changes: 2 additions & 1 deletion enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,13 +467,14 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI)
//auto baa = new BasicAAResult(ba.run(*NewF, AM));
AssumptionCache* AC = new AssumptionCache(*NewF);
TargetLibraryInfo* TLI = new TargetLibraryInfo(AM.getResult<TargetLibraryAnalysis>(*NewF));
DominatorTree* DTL = new DominatorTree(*NewF);
auto baa = new BasicAAResult(NewF->getParent()->getDataLayout(),
#if LLVM_VERSION_MAJOR > 6
*NewF,
#endif
*TLI,
*AC,
&AM.getResult<DominatorTreeAnalysis>(*NewF),
DTL/*&AM.getResult<DominatorTreeAnalysis>(*NewF)*/,
AM.getCachedResult<LoopAnalysis>(*NewF)
#if LLVM_VERSION_MAJOR > 6
,AM.getCachedResult<PhiValuesAnalysis>(*NewF)
Expand Down
36 changes: 28 additions & 8 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos
return lc.latchMerge;
}

//llvm::errs() << " BB:" << BB->getName() << " branchingBlock:" << branchingBlock->getName() << "\n";
llvm::errs() << " BB:" << BB->getName() << " branchingBlock:" << branchingBlock->getName() << "\n";
return reverseBlocks[BB];
}

Expand Down Expand Up @@ -133,8 +133,8 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos
for(auto pred : predecessors(exit)) {
auto fd = std::find(latches.begin(), latches.end(), pred);
if ( fd != latches.end()) {
auto latch = *fd;
targetToPreds[reverseBlocks[latch]].push_back(std::make_pair(pred, exit));
targetToPreds[reverseBlocks[pred]].push_back(std::make_pair(pred, exit));
//targetToPreds[getReverseOrLatchMerge(pred, exit)].push_back(std::make_pair(pred, exit));
}
}
}
Expand Down Expand Up @@ -396,9 +396,23 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) {
invertedPointers[arg] = li;
return lookupM(invertedPointers[arg], BuilderM);
} else if (auto arg = dyn_cast<BinaryOperator>(val)) {
assert(arg->getType()->isIntOrIntVectorTy());
assert(arg->getType()->isIntOrIntVectorTy());
IRBuilder <> bb(arg);
auto li = bb.CreateBinOp(arg->getOpcode(), invertPointerM(arg->getOperand(0), bb), invertPointerM(arg->getOperand(1), bb), arg->getName());
Value* val0 = nullptr;
Value* val1 = nullptr;

if (isa<ConstantInt>(arg->getOperand(0))) {
val0 = arg->getOperand(0);
val1 = invertPointerM(arg->getOperand(1), bb);
} else if (isa<ConstantInt>(arg->getOperand(1))) {
val0 = invertPointerM(arg->getOperand(0), bb);
val1 = arg->getOperand(1);
} else {
val0 = invertPointerM(arg->getOperand(0), bb);
val1 = invertPointerM(arg->getOperand(1), bb);
}

auto li = bb.CreateBinOp(arg->getOpcode(), val0, val1, arg->getName());
invertedPointers[arg] = li;
return lookupM(invertedPointers[arg], BuilderM);
} else if (auto arg = dyn_cast<GetElementPtrInst>(val)) {
Expand Down Expand Up @@ -539,6 +553,7 @@ void removeRedundantIVs(const Loop* L, BasicBlock* Header, BasicBlock* Preheader
for (BasicBlock::iterator II = Header->begin(); isa<PHINode>(II); ++II) {
PHINode *PN = cast<PHINode>(II);
if (PN == CanonicalIV) continue;
if (PN->getType()->isPointerTy()) continue;
if (!SE.isSCEVable(PN->getType())) continue;
const SCEV *S = SE.getSCEV(PN);
if (SE.getCouldNotCompute() == S) continue;
Expand Down Expand Up @@ -937,7 +952,6 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B

IntegerType* T = (targetToPreds.size() == 2) ? Type::getInt1Ty(BuilderM.getContext()) : Type::getInt8Ty(BuilderM.getContext());
CallInst* freeLocation;
AllocaInst* cache = createCacheForScope(ctx, T, "", /*shouldFree*/&freeLocation, /*lastAlloca*/nullptr);

Instruction* equivalentTerminator = nullptr;

Expand Down Expand Up @@ -986,6 +1000,9 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B
BasicBlock* block = equivalentTerminator->getParent();
assert(branch->getCondition());

assert(branch->getCondition()->getType() == T);

AllocaInst* cache = createCacheForScope(ctx, T, "", /*shouldFree*/&freeLocation, /*lastAlloca*/nullptr);
IRBuilder<> pbuilder(equivalentTerminator);
pbuilder.setFastMathFlags(getFast());
storeInstructionInCache(ctx, pbuilder, branch->getCondition(), cache);
Expand Down Expand Up @@ -1019,12 +1036,14 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B
}
}
} else if (auto si = dyn_cast<SwitchInst>(equivalentTerminator)) {
assert(branch->getCondition());
BasicBlock* block = equivalentTerminator->getParent();

IRBuilder<> pbuilder(equivalentTerminator);
pbuilder.setFastMathFlags(getFast());
storeInstructionInCache(ctx, pbuilder, branch->getCondition(), cache);

AllocaInst* cache = createCacheForScope(ctx, si->getCondition()->getType(), "", /*shouldFree*/&freeLocation, /*lastAlloca*/nullptr);
Value* condition = si->getCondition();
storeInstructionInCache(ctx, pbuilder, condition, cache);

Value* phi = lookupValueFromCache(BuilderM, ctx, cache);

Expand Down Expand Up @@ -1068,6 +1087,7 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B

nofast:;

AllocaInst* cache = createCacheForScope(ctx, T, "", /*shouldFree*/&freeLocation, /*lastAlloca*/nullptr);
std::vector<BasicBlock*> targets;
{
size_t idx = 0;
Expand Down
23 changes: 20 additions & 3 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1165,8 +1165,9 @@ class GradientUtils {
}
}
}

v.CreateStore(val, getCachePointer(v, ctx, cache));
Value* loc = getCachePointer(v, ctx, cache);
assert(cast<PointerType>(loc->getType())->getElementType() == val->getType());
v.CreateStore(val, loc);
}

void storeInstructionInCache(BasicBlock* ctx, Instruction* inst, AllocaInst* cache) {
Expand Down Expand Up @@ -1381,11 +1382,26 @@ class DiffeGradientUtils : public GradientUtils {
sv.push_back(i);
Value* ptr = BuilderM.CreateGEP(getDifferential(val), sv);
Value* old = BuilderM.CreateLoad(ptr);

Value* res = nullptr;

if (old->getType()->isIntOrIntVectorTy()) {
res = BuilderM.CreateFAdd(BuilderM.CreateBitCast(old, IntToFloatTy(old->getType())), BuilderM.CreateBitCast(dif, IntToFloatTy(dif->getType())));
res = BuilderM.CreateBitCast(res, old->getType());
} else if(old->getType()->isFPOrFPVectorTy()) {
res = BuilderM.CreateFAdd(old, dif);
} else {
assert(old);
assert(dif);
llvm::errs() << *newFunc << "\n" << "cannot handle type " << *old << "\n" << *dif;
report_fatal_error("cannot handle type");
}


Value* res = BuilderM.CreateFAdd(old, dif);
SelectInst* addedSelect = nullptr;

//! optimize fadd of select to select of fadd
// TODO: Handle Selects of ints
if (SelectInst* select = dyn_cast<SelectInst>(dif)) {
if (ConstantFP* ci = dyn_cast<ConstantFP>(select->getTrueValue())) {
if (ci->isZero()) {
Expand Down Expand Up @@ -1426,6 +1442,7 @@ class DiffeGradientUtils : public GradientUtils {

Value* res;
Value* old = BuilderM.CreateLoad(ptr);

if (old->getType()->isIntOrIntVectorTy()) {
res = BuilderM.CreateFAdd(BuilderM.CreateBitCast(old, IntToFloatTy(old->getType())), BuilderM.CreateBitCast(dif, IntToFloatTy(dif->getType())));
res = BuilderM.CreateBitCast(res, old->getType());
Expand Down