diff --git a/enzyme/Enzyme/ActiveVariable.cpp b/enzyme/Enzyme/ActiveVariable.cpp
index 1bafa198c96f..e3eab8bca3ab 100644
--- a/enzyme/Enzyme/ActiveVariable.cpp
+++ b/enzyme/Enzyme/ActiveVariable.cpp
@@ -55,6 +55,7 @@ bool isIntASecretFloat(Value* val) {
     if (auto inst = dyn_cast<Instruction>(val)) {
         bool floatingUse = false;
         bool pointerUse = false;
+        bool intUse = false;
        SmallPtrSet<Value*, 4> seen;

        std::function<void(Value*)> trackPointer = [&](Value* v) {
@@ -126,12 +127,50 @@ bool isIntASecretFloat(Value* val) {
            if (auto si = dyn_cast<StoreInst>(use)) {
                assert(inst == si->getValueOperand());
+               if (MDNode* md = si->getMetadata(LLVMContext::MD_tbaa)) {
+                 llvm::errs() << "TFKDEBUG MDNode " << *md << " Inst is " << *inst << " " << *si <<"\n";
+                 if (md->getNumOperands() == 3) {
+                   const MDOperand& accessType = md->getOperand(1);
+                   Metadata* metadata = accessType.get();
+                   if (auto mda = dyn_cast<MDNode>(metadata)) {
+                     if (mda->getNumOperands() > 0) {
+                       const MDOperand& underlyingType = mda->getOperand(0);
+                       Metadata* metadata2 = underlyingType.get();
+                       if (auto typeName = dyn_cast<MDString>(metadata2)) {
+                         auto typeNameStringRef = typeName->getString();
+                         if (typeNameStringRef == "long") {
+                           intUse = true;
+                         }
+                       }
+                     }
+                   }
+                 }
+               }
                trackPointer(si->getPointerOperand());
            }
        }

        if (auto li = dyn_cast<LoadInst>(inst)) {
-           trackPointer(li->getOperand(0));
+           if (MDNode* md = li->getMetadata(LLVMContext::MD_tbaa)) {
+             llvm::errs() << "TFKDEBUG MDNode " << *md << " Inst is " << *inst << " " << *li <<"\n";
+             if (md->getNumOperands() == 3) {
+               const MDOperand& accessType = md->getOperand(1);
+               Metadata* metadata = accessType.get();
+               if (auto mda = dyn_cast<MDNode>(metadata)) {
+                 if (mda->getNumOperands() > 0) {
+                   const MDOperand& underlyingType = mda->getOperand(0);
+                   Metadata* metadata2 = underlyingType.get();
+                   if (auto typeName = dyn_cast<MDString>(metadata2)) {
+                     auto typeNameStringRef = typeName->getString();
+                     if (typeNameStringRef == "long") {
+                       intUse = true;
+                     }
+                   }
+                 }
+               }
+             }
+           }
+           trackPointer(li->getOperand(0));
        }

        if (auto ci = dyn_cast<CastInst>(inst)) {
@@ -147,11 +186,11 @@ bool isIntASecretFloat(Value* val) {
        if (isa<IntToPtrInst>(inst)) {
            pointerUse = true;
        }
-
-       if (pointerUse && !floatingUse) return false;
-       if (!pointerUse && floatingUse) return true;
+       if (intUse && !pointerUse && !floatingUse) return false;
+       if (!intUse && pointerUse && !floatingUse) return false;
+       if (!intUse && !pointerUse && floatingUse) return true;
        llvm::errs() << *inst->getParent()->getParent() << "\n";
-       llvm::errs() << " val:" << *val << " pointer:" << pointerUse << " floating:" << floatingUse << "\n";
+       llvm::errs() << " val:" << *val << " pointer:" << pointerUse << " floating:" << floatingUse << " intuse: " << intUse << "\n";
        assert(0 && "ambiguous unsure if constant or not");
    }
@@ -296,6 +335,17 @@ bool isconstantM(Instruction* inst, SmallPtrSetImpl<Value*> &constants, SmallPtr
            }
        }
    }
+
+   if (auto op = dyn_cast<BinaryOperator>(inst)) {
+     switch(op->getOpcode()) {
+       //case BinaryOperator::Add:
+       case BinaryOperator::Mul:
+         constants.insert(inst);
+         return true;
+       default:
+         break;
+     }
+   }

    if (auto op = dyn_cast<IntrinsicInst>(inst)) {
        switch(op->getIntrinsicID()) {
diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp
index 1bd358db06aa..19fdeb5fb37c 100644
--- a/enzyme/Enzyme/Enzyme.cpp
+++ b/enzyme/Enzyme/Enzyme.cpp
@@ -155,7 +155,8 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo

   bool differentialReturn = cast<Function>(fn)->getReturnType()->isFPOrFPVectorTy();

-  auto newFunc = CreatePrimalAndGradient(cast<Function>(fn), constants, TLI, AA, /*should return*/false, differentialReturn, /*topLevel*/true, /*addedType*/nullptr);//, LI, DT);
+  std::set<unsigned> volatile_args;
+  auto newFunc = CreatePrimalAndGradient(cast<Function>(fn), constants, TLI, AA, /*should return*/false, differentialReturn, /*topLevel*/true, /*addedType*/nullptr, volatile_args);//, LI, DT);

   if (differentialReturn)
     args.push_back(ConstantFP::get(cast<Function>(fn)->getReturnType(), 1.0));
diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp
index 62af08b003b2..6f2be38f2a1b 100644
--- a/enzyme/Enzyme/EnzymeLogic.cpp
+++ b/enzyme/Enzyme/EnzymeLogic.cpp
@@ -32,6 +32,7 @@

 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueTracking.h"

 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
@@ -47,14 +48,283 @@
 llvm::cl::opt<bool> enzyme_print("enzyme_print", cl::init(false), cl::Hidden,
                 cl::desc("Print before and after fns for autodiff"));

 cl::opt<bool> cachereads(
-            "enzyme_cachereads", cl::init(false), cl::Hidden,
+            "enzyme_cachereads", cl::init(true), cl::Hidden,
             cl::desc("Force caching of all reads"));

+std::map<Instruction*, bool> compute_volatile_load_map(GradientUtils* gutils, AAResults& AA,
+    std::set<unsigned> volatile_args) {
+  std::map<Instruction*, bool> can_modref_map;
+  // NOTE(TFK): Want to construct a test case where this causes an issue.
+  for(BasicBlock* BB: gutils->originalBlocks) {
+    for (auto I = BB->begin(), E = BB->end(); I != E; I++) {
+      Instruction* inst = &*I;
+      llvm::errs() << "considering instruction " << *inst << "\n";
+      if (auto op = dyn_cast<LoadInst>(inst)) {
+        if (gutils->isConstantValue(inst) || gutils->isConstantInstruction(inst)) {
+          continue;
+        }
+
+        bool can_modref = false;
+        auto obj = GetUnderlyingObject(op->getPointerOperand(), BB->getModule()->getDataLayout(), 100);
+        if (auto arg = dyn_cast<Argument>(obj)) {
+          if (volatile_args.find(arg->getArgNo()) != volatile_args.end()) {
+            can_modref = true;
+          }
+        }
+        for (BasicBlock* BB2 : gutils->originalBlocks) {
+          llvm::errs() << "looking at basic block " << *BB2 << "\n";
+          if (can_modref) break;
+          llvm::errs() << "A1\n";
+          if (BB == BB2) {
+            llvm::errs() << "A2\n";
+            for (auto I2 = BB2->begin(), E2 = BB2->end(); I2 != E2; I2++) {
+              Instruction* inst2 = &*I2;
+              llvm::errs() << "looking at basic block instruction " << *inst2 << "\n";
+              if (inst == inst2) continue;
+              llvm::errs() << "A3\n";
+              if (!gutils->DT.dominates(inst2, inst)) {
+                llvm::errs() << "A4\n";
+                //if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::get(op)))) {
+                if (AA.canInstructionRangeModRef(*inst2, *inst2, MemoryLocation::get(op), ModRefInfo::Mod)) {
+                  llvm::errs() << "A5\n";
+                  can_modref = true;
+                }
+                //if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::get(op)))) {
+                //  can_modref = true;
+                //  //llvm::errs() << *inst << " needs to be cached due to: " << *inst2 << "\n";
+                //  break;
+                //}
+                llvm::errs() << "A6\n";
+              }
+              llvm::errs() << "A7\n";
+            }
+            llvm::errs() << "A8\n";
+          } else {
+            llvm::errs() << "A9\n";
+            if (!gutils->DT.dominates(BB2, BB)) {
+              llvm::errs() << "A10\n";
+              if (AA.canBasicBlockModify(*BB2, MemoryLocation::get(op))) {
+                llvm::errs() << "A11\n";
+                can_modref = true;
+                break;
+              }
+              llvm::errs() << "A12\n";
+            }
+            llvm::errs() << "A13\n";
+          }
+          llvm::errs() << "A14\n";
+        }
+        // NOTE(TFK): I need a testcase where this logic below fails to test correctness of logic above.
+        //for (unsigned k = 0; k < gutils->originalBlocks.size(); k++) {
+        //  if (AA.canBasicBlockModify(*(gutils->originalBlocks[k]), MemoryLocation::get(op))) {
+        //    can_modref = true;
+        //    break;
+        //  }
+        //}
+        can_modref_map[inst] = can_modref;
+        //can_modref_map[inst] = true;
+      }
+    }
+  }
+  return can_modref_map;
+}
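Stripped of the A1-A14 debug prints, the rule implemented above is: a load must be cached for the reverse pass if any instruction that is not guaranteed to execute before it may write to the loaded location. A minimal standalone sketch of that core query, assuming the LLVM-7-era API this patch targets (the helper name, and the omission of the volatile-argument seeding and the per-block fast path, are editorial simplifications):

#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

// True if some instruction that does not dominate LI may modify LI's
// location. Instructions dominating LI run before it on every path, so
// they can never invalidate a recomputation of LI in the reverse pass.
static bool mayBeOverwrittenAfter(LoadInst *LI, Function &F,
                                  DominatorTree &DT, AAResults &AA) {
  MemoryLocation Loc = MemoryLocation::get(LI);
  for (BasicBlock &BB : F)
    for (Instruction &I : BB) {
      if (&I == LI || DT.dominates(&I, LI))
        continue;
      if (isModSet(AA.getModRefInfo(&I, Loc)))
        return true; // conservatively treat a possible clobber as "must cache"
    }
  return false;
}

Loads for which this is true end up with can_modref_map[inst] = true and are stashed via addMalloc in the forward pass rather than recomputed in the reverse pass.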
+std::set<unsigned> compute_volatile_args_for_one_callsite(Instruction* callsite_inst, DominatorTree &DT,
+    TargetLibraryInfo &TLI, AAResults& AA, GradientUtils* gutils, std::set<unsigned> parent_volatile_args) {
+  CallInst* callsite_op = dyn_cast<CallInst>(callsite_inst);
+  assert(callsite_op != nullptr);
+
+  std::set<unsigned> volatile_args;
+  std::vector<Value*> args;
+  std::vector<bool> args_safe;
+  llvm::errs() << "computing volatile args for callsite " << *callsite_inst << "\n";
+  // First, we need to propagate the volatile status from the parent function to the callee,
+  // because a memory location x modified after the parent returns is also modified after the callee returns.
+  for (unsigned i = 0; i < callsite_op->getNumArgOperands(); i++) {
+    args.push_back(callsite_op->getArgOperand(i));
+    bool init_safe = true;
+
+    // If the UnderlyingObject is from one of this function's arguments, then we need to propagate the volatility.
+    Value* obj = GetUnderlyingObject(callsite_op->getArgOperand(i),
+                                     callsite_inst->getParent()->getModule()->getDataLayout(),
+                                     100);
+    // If the underlying object is an Argument, check the parent's volatility status.
+    if (auto arg = dyn_cast<Argument>(obj)) {
+      if (parent_volatile_args.find(arg->getArgNo()) != parent_volatile_args.end()) {
+        init_safe = false;
+      }
+    }
+    // TODO(TFK): Also need to check whether the underlying object traces to a load / non-allocating-call instruction.
+    args_safe.push_back(init_safe);
+  }
+
+  // Second, we check for memory modifications that can occur in the continuation of the
+  // callee inside the parent function.
+  for(BasicBlock* BB: gutils->originalBlocks) {
+    for (auto I = BB->begin(), E = BB->end(); I != E; I++) {
+      Instruction* inst = &*I;
+      if (inst == callsite_inst) continue;
+
+      // If "inst" does not dominate "callsite_inst" then we cannot prove that
+      // "inst" happens before "callsite_inst". If "inst" modifies an argument of the call,
+      // then the call needs to consider that argument volatile.
+      if (!gutils->DT.dominates(inst, callsite_inst)) {
+        // Consider Store Instructions.
+        if (auto op = dyn_cast<StoreInst>(inst)) {
+          for (unsigned i = 0; i < args.size(); i++) {
+            // If the modification flag is set, then this instruction may modify the i-th argument of the call.
+            if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(callsite_op, i, TLI)))) {
+              //llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n";
+            } else {
+              //llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n";
+              args_safe[i] = false;
+            }
+          }
+        }
+
+        // Consider Call Instructions.
+        if (auto op = dyn_cast<CallInst>(inst)) {
+          // Ignore memory allocation functions.
+          Function* called = op->getCalledFunction();
+          if (auto castinst = dyn_cast<ConstantExpr>(op->getCalledValue())) {
+            if (castinst->isCast()) {
+              if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
+                if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
+                  called = fn;
+                }
+              }
+            }
+          }
+          if (isCertainMallocOrFree(called)) {
+            continue;
+          }
+
+          // For all the arguments, perform the same check as for Stores, but ignore non-pointer arguments.
+          for (unsigned i = 0; i < args.size(); i++) {
+            if (!args[i]->getType()->isPointerTy()) continue; // Ignore non-pointer arguments.
+            if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(callsite_op, i, TLI)))) {
+              //llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n";
+            } else {
+              //llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n";
+              args_safe[i] = false;
+            }
+          }
+        }
+      }
+    }
+  }
+
+  //llvm::errs() << "CallInst: " << *callsite_op << "CALL ARGUMENT INFO: \n";
+  for (unsigned i = 0; i < args.size(); i++) {
+    if (!args_safe[i]) {
+      volatile_args.insert(i);
+    }
+    //llvm::errs() << "Arg: " << *args[i] << " STATUS: " << args_safe[i] << "\n";
+  }
+  return volatile_args;
+}
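The seeding step at the top of this function reduces to a small predicate; a sketch under the same API (the helper name is illustrative, not from the patch):

#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Instructions.h"
#include <set>

using namespace llvm;

// A call argument starts out "unsafe" (volatile) if it traces back to a
// parent-function argument that the parent's own caller already marked
// volatile: whatever may be overwritten after the parent returns may also
// be overwritten after the callee returns.
static bool argVolatileInCaller(Value *CallArg, const DataLayout &DL,
                                const std::set<unsigned> &ParentVolatileArgs) {
  Value *Obj = GetUnderlyingObject(CallArg, DL, /*MaxLookup=*/100);
  if (auto *A = dyn_cast<Argument>(Obj))
    return ParentVolatileArgs.count(A->getArgNo()) != 0;
  return false; // objects allocated in the caller are handled by the scan below
}

init_safe above is simply the negation of this predicate; the block scan that follows then clears args_safe[i] for any argument a later store or call might clobber.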
+// Given a function and the arguments passed to it by its caller that are volatile (_volatile_args), compute
+// the set of volatile arguments for each callsite inside the function. A pointer argument is volatile at
+// a callsite if the memory pointed to might be modified after that callsite.
+std::map<CallInst*, std::set<unsigned> > compute_volatile_args_for_callsites(
+    Function* F, DominatorTree &DT, TargetLibraryInfo &TLI, AAResults& AA, GradientUtils* gutils,
+    std::set<unsigned> const volatile_args) {
+  std::map<CallInst*, std::set<unsigned> > volatile_args_map;
+  for(BasicBlock* BB: gutils->originalBlocks) {
+    for (auto I = BB->begin(), E = BB->end(); I != E; I++) {
+      Instruction* inst = &*I;
+      if (auto op = dyn_cast<CallInst>(inst)) {
+
+        // We do not need volatile args for intrinsic functions, so skip such callsites.
+        if(isa<IntrinsicInst>(inst)) {
+          continue;
+        }
+
+        // We do not need volatile args for memory allocation functions, so skip such callsites.
+        Function* called = op->getCalledFunction();
+        if (auto castinst = dyn_cast<ConstantExpr>(op->getCalledValue())) {
+          if (castinst->isCast()) {
+            if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
+              if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
+                called = fn;
+              }
+            }
+          }
+        }
+        if (isCertainMallocOrFree(called)) {
+          continue;
+        }
+
+        // For all other calls, we compute the volatile args for this callsite.
+        volatile_args_map[op] = compute_volatile_args_for_one_callsite(inst,
+            DT, TLI, AA, gutils, volatile_args);
+      }
+    }
+  }
+  return volatile_args_map;
+}
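This map is threaded through CreateAugmentedPrimal and CreatePrimalAndGradient below, so volatility flows top-down through the call graph. The next function answers the complementary, bottom-up question: even if a load could be clobbered, does the reverse pass ever consume its value? A condensed sketch of that use-walk (editorial; the real skip-list of instruction kinds below is longer):

#include "llvm/IR/Instructions.h"
#include <set>
#include <vector>

using namespace llvm;

static bool primalValueNeededInReverse(Instruction *Root) {
  // Gather the transitive SSA users of Root.
  std::vector<Value *> Worklist{Root};
  std::set<Value *> Seen{Root};
  for (size_t i = 0; i < Worklist.size(); ++i)
    for (User *U : Worklist[i]->users())
      if (Seen.insert(U).second)
        Worklist.push_back(U);

  for (Value *V : Worklist) {
    if (V == Root)
      continue;
    if (auto *BO = dyn_cast<BinaryOperator>(V)) {
      // The adjoint of a +- b uses only the incoming adjoint,
      // never a or b themselves, so such uses do not force caching.
      if (BO->getOpcode() == Instruction::FAdd ||
          BO->getOpcode() == Instruction::FSub)
        continue;
    }
    return true; // fmul, fdiv, calls, ... may need the primal value
  }
  return false;
}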
+// Determine if a load is needed in the reverse pass. We only use this logic in the top level function right now.
+bool is_load_needed_in_reverse(GradientUtils* gutils, AAResults& AA, Instruction* inst) {
+
+  std::vector<Value*> uses_list;
+  std::set<Value*> uses_set;
+  uses_list.push_back(inst);
+  uses_set.insert(inst);
+
+  while (true) {
+    bool new_user_added = false;
+    for (unsigned i = 0; i < uses_list.size(); i++) {
+      for (auto use = uses_list[i]->user_begin(), end = uses_list[i]->user_end(); use != end; ++use) {
+        Value* v = (*use);
+        //llvm::errs() << "Use list: " << *v << "\n";
+        if (uses_set.find(v) == uses_set.end()) {
+          uses_set.insert(v);
+          uses_list.push_back(v);
+          new_user_added = true;
+        }
+      }
+    }
+    if (!new_user_added) break;
+  }
+  //llvm::errs() << "Analysis for load " << *inst << " which has nuses: " << inst->getNumUses() << "\n";
+  for (unsigned i = 0; i < uses_list.size(); i++) {
+    //llvm::errs() << "Considering use " << *uses_list[i] << "\n";
+    if (uses_list[i] == dyn_cast<Value>(inst)) continue;
+
+    if (isa<StoreInst>(uses_list[i]) || isa<CmpInst>(uses_list[i]) || isa<BranchInst>(uses_list[i]) || isa<SwitchInst>(uses_list[i]) || isa<CastInst>(uses_list[i]) ||
+        isa<GetElementPtrInst>(uses_list[i]) || isa<PHINode>(uses_list[i])){
+      continue;
+    }
+
+    if (auto op = dyn_cast<BinaryOperator>(uses_list[i])) {
+      if (op->getOpcode() == Instruction::FAdd || op->getOpcode() == Instruction::FSub) {
+        continue;
+      } else {
+        //llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *op << "\n";
+        return true;
+      }
+    }
+
+    //if (auto op = dyn_cast<IntrinsicInst>(uses_list[i])) {
+    //  //llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *op << "\n";
+    //  return true;
+    //}
+
+    //llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *uses_list[i] << "\n";
+    return true;
+  }
+  return false;
+}
+
+
 //! return structtype if recursive function
-std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResults &AA, const std::set<unsigned>& constant_args, TargetLibraryInfo &TLI, bool differentialReturn, bool returnUsed) {
-  static std::map<std::tuple<Function*,std::set<unsigned>, bool/*differentialReturn*/, bool/*returnUsed*/>, std::pair<Function*,StructType*>> cachedfunctions;
-  static std::map<std::tuple<Function*,std::set<unsigned>, bool/*differentialReturn*/, bool/*returnUsed*/>, bool> cachedfinished;
-  auto tup = std::make_tuple(todiff, std::set<unsigned>(constant_args.begin(), constant_args.end()), differentialReturn, returnUsed);
+std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResults &_AA, const std::set<unsigned>& constant_args, TargetLibraryInfo &TLI, bool differentialReturn, bool returnUsed, const std::set<unsigned> _volatile_args) {
+  static std::map<std::tuple<Function*,std::set<unsigned>, std::set<unsigned>, bool/*differentialReturn*/, bool/*returnUsed*/>, std::pair<Function*,StructType*>> cachedfunctions;
+  static std::map<std::tuple<Function*,std::set<unsigned>, std::set<unsigned>, bool/*differentialReturn*/, bool/*returnUsed*/>, bool> cachedfinished;
+  auto tup = std::make_tuple(todiff, std::set<unsigned>(constant_args.begin(), constant_args.end()), std::set<unsigned>(_volatile_args.begin(), _volatile_args.end()), differentialReturn, returnUsed);
   if (cachedfunctions.find(tup) != cachedfunctions.end()) {
     return cachedfunctions[tup];
   }
@@ -104,15 +374,26 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
        //assert(st->getNumElements() > 0);
        return cachedfunctions[tup] = std::pair<Function*,StructType*>(foundcalled, nullptr); //dyn_cast<StructType>(st->getElementType(0)));
   }
+
+
+
+
+
   if (todiff->empty()) {
     llvm::errs() << *todiff << "\n";
   }
   assert(!todiff->empty());
-
+  AAResults AA(TLI);
   GradientUtils *gutils = GradientUtils::CreateFromClone(todiff, AA, TLI, constant_args, /*returnValue*/returnUsed ? 
ReturnType::TapeAndReturns : ReturnType::Tape, /*differentialReturn*/differentialReturn); cachedfunctions[tup] = std::pair(gutils->newFunc, nullptr); cachedfinished[tup] = false; + std::map > volatile_args_map;/* = + compute_volatile_args_for_callsites(gutils->oldFunc, gutils->DT, TLI, AA, gutils, _volatile_args);*/ + + std::map can_modref_map = compute_volatile_load_map(gutils, AA, _volatile_args); + gutils->can_modref_map = &can_modref_map; + gutils->forceContexts(); gutils->forceAugmentedReturns(); @@ -170,6 +451,7 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul if (gutils->originalInstructions.find(inst) == gutils->originalInstructions.end()) continue; if(auto op = dyn_cast_or_null(inst)) { + llvm::errs() << "TFKDEBUG OP " << *op << " WITH INTRINSIC ID " << op->getIntrinsicID() << "\n"; switch(op->getIntrinsicID()) { case Intrinsic::memcpy: { if (gutils->isConstantInstruction(inst)) continue; @@ -330,6 +612,7 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul auto argType = op->getArgOperand(i)->getType(); if (argType->isPointerTy() || argType->isIntegerTy()) { + llvm::errs() << "TFKDEBUG " << *op << "\n"; argsInverted.push_back(DIFFE_TYPE::DUP_ARG); args.push_back(gutils->invertPointerM(op->getArgOperand(i), BuilderZ)); @@ -364,7 +647,7 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul } } - auto newcalled = CreateAugmentedPrimal(dyn_cast(called), AA, subconstant_args, TLI, /*differentialReturn*/subdifferentialreturn, /*return is used*/subretused).first; + auto newcalled = CreateAugmentedPrimal(dyn_cast(called), AA, subconstant_args, TLI, /*differentialReturn*/subdifferentialreturn, /*return is used*/subretused, volatile_args_map[op]).first; auto augmentcall = BuilderZ.CreateCall(newcalled, args); assert(augmentcall->getType()->isStructTy()); augmentcall->setCallingConv(op->getCallingConv()); @@ -395,20 +678,21 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul gutils->addMalloc(BuilderZ, rv); } - if ((op->getType()->isPointerTy() || op->getType()->isIntegerTy()) && subdifferentialreturn) { + if ((op->getType()->isPointerTy() || op->getType()->isIntegerTy()) && gutils->invertedPointers.count(op) != 0) { auto placeholder = cast(gutils->invertedPointers[op]); if (I != E && placeholder == &*I) I++; gutils->invertedPointers.erase(op); - assert(cast(augmentcall->getType())->getNumElements() == 3); - auto antiptr = cast(BuilderZ.CreateExtractValue(augmentcall, {2}, "antiptr_" + op->getName() )); - gutils->invertedPointers[rv] = antiptr; - placeholder->replaceAllUsesWith(antiptr); - - if (shouldCache) { - gutils->addMalloc(BuilderZ, antiptr); + if (subdifferentialreturn) { + assert(cast(augmentcall->getType())->getNumElements() == 3); + auto antiptr = cast(BuilderZ.CreateExtractValue(augmentcall, {2}, "antiptr_" + op->getName() )); + gutils->invertedPointers[rv] = antiptr; + placeholder->replaceAllUsesWith(antiptr); + + if (shouldCache) { + gutils->addMalloc(BuilderZ, antiptr); + } } - gutils->erase(placeholder); } else { if (cast(augmentcall->getType())->getNumElements() != 2) { @@ -422,12 +706,19 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul } gutils->replaceAWithB(op,rv); + } else { + if ((op->getType()->isPointerTy() || op->getType()->isIntegerTy()) && gutils->invertedPointers.count(op) != 0) { + auto placeholder = cast(gutils->invertedPointers[op]); + if (I != E && placeholder == &*I) I++; + gutils->invertedPointers.erase(op); + gutils->erase(placeholder); + } } gutils->erase(op); } else if(LoadInst* li = dyn_cast(inst)) { if 
(gutils->isConstantInstruction(inst) || gutils->isConstantValue(inst)) continue; - if (cachereads) { + if (/*true || */(cachereads && can_modref_map[inst])) { llvm::errs() << "Forcibly caching reads " << *li << "\n"; IRBuilder<> BuilderZ(li); gutils->addMalloc(BuilderZ, li); @@ -901,7 +1192,7 @@ std::pair,SmallVector> getDefaultFunctionTypeForGr return std::pair,SmallVector>(args, outs); } -void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::reverse_iterator &E, IRBuilder <>& Builder2, CallInst* op, DiffeGradientUtils* const gutils, TargetLibraryInfo &TLI, AAResults &AA, const bool topLevel, const std::map &replacedReturns) { +void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::reverse_iterator &E, IRBuilder <>& Builder2, CallInst* op, DiffeGradientUtils* const gutils, TargetLibraryInfo &TLI, AAResults &AA, const bool topLevel, const std::map &replacedReturns, std::set volatile_args) { Function *called = op->getCalledFunction(); if (auto castinst = dyn_cast(op->getCalledValue())) { @@ -1242,7 +1533,7 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r if (modifyPrimal && called) { bool subretused = op->getNumUses() != 0; bool subdifferentialreturn = (!gutils->isConstantValue(op)) && subretused; - auto fnandtapetype = CreateAugmentedPrimal(cast(called), AA, subconstant_args, TLI, /*differentialReturns*/subdifferentialreturn, /*return is used*/subretused); + auto fnandtapetype = CreateAugmentedPrimal(cast(called), AA, subconstant_args, TLI, /*differentialReturns*/subdifferentialreturn, /*return is used*/subretused, volatile_args); if (topLevel) { Function* newcalled = fnandtapetype.first; augmentcall = BuilderZ.CreateCall(newcalled, pre_args); @@ -1314,7 +1605,7 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r bool subdiffereturn = (!gutils->isConstantValue(op)) && !( op->getType()->isPointerTy() || op->getType()->isIntegerTy() || op->getType()->isEmptyTy() ); llvm::errs() << "subdifferet:" << subdiffereturn << " " << *op << "\n"; if (called) { - newcalled = CreatePrimalAndGradient(cast(called), subconstant_args, TLI, AA, /*returnValue*/retUsed, /*subdiffereturn*/subdiffereturn, /*topLevel*/replaceFunction, tape ? tape->getType() : nullptr);//, LI, DT); + newcalled = CreatePrimalAndGradient(cast(called), subconstant_args, TLI, AA, /*returnValue*/retUsed, /*subdiffereturn*/subdiffereturn, /*topLevel*/replaceFunction, tape ? 
tape->getType() : nullptr, volatile_args);//, LI, DT); } else { newcalled = gutils->invertPointerM(op->getCalledValue(), Builder2); auto ft = cast(cast(op->getCalledValue()->getType())->getElementType()); @@ -1424,7 +1715,7 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r } } -Function* CreatePrimalAndGradient(Function* todiff, const std::set& constant_args, TargetLibraryInfo &TLI, AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg) { +Function* CreatePrimalAndGradient(Function* todiff, const std::set& constant_args, TargetLibraryInfo &TLI, AAResults &_AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg, std::set _volatile_args) { if (differentialReturn) { if(!todiff->getReturnType()->isFPOrFPVectorTy()) { llvm::errs() << *todiff << "\n"; @@ -1436,13 +1727,17 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co llvm::errs() << "addl arg: " << *additionalArg << "\n"; } if (additionalArg) assert(additionalArg->isStructTy()); - - static std::map, bool/*retval*/, bool/*differentialReturn*/, bool/*topLevel*/, llvm::Type*>, Function*> cachedfunctions; - auto tup = std::make_tuple(todiff, std::set(constant_args.begin(), constant_args.end()), returnValue, differentialReturn, topLevel, additionalArg); + static std::map, std::set, bool/*retval*/, bool/*differentialReturn*/, bool/*topLevel*/, llvm::Type*>, Function*> cachedfunctions; + auto tup = std::make_tuple(todiff, std::set(constant_args.begin(), constant_args.end()), std::set(_volatile_args.begin(), _volatile_args.end()), returnValue, differentialReturn, topLevel, additionalArg); if (cachedfunctions.find(tup) != cachedfunctions.end()) { return cachedfunctions[tup]; } + + + + bool hasTape = false; + if (constant_args.size() == 0 && !topLevel && !returnValue && hasMetadata(todiff, "enzyme_gradient")) { auto md = todiff->getMetadata("enzyme_gradient"); @@ -1458,7 +1753,6 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co auto res = getDefaultFunctionTypeForGradient(todiff->getFunctionType(), /*has return value*/!todiff->getReturnType()->isVoidTy(), differentialReturn); - bool hasTape = false; if (foundcalled->arg_size() == res.first.size() + 1 /*tape*/) { auto lastarg = foundcalled->arg_end(); @@ -1527,9 +1821,26 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co auto& Context = M->getContext(); + AAResults AA(TLI); DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone(todiff, AA, TLI, constant_args, returnValue ? ReturnType::ArgsWithReturn : ReturnType::Args, differentialReturn, additionalArg); cachedfunctions[tup] = gutils->newFunc; + std::map > volatile_args_map = + compute_volatile_args_for_callsites(gutils->oldFunc, gutils->DT, TLI, AA, gutils, _volatile_args); + + std::map can_modref_map = compute_volatile_load_map(gutils, AA, _volatile_args); + // NOTE(TFK): Sanity check this decision. + // Is it always possibly to recompute the result of loads at top level? 
+ if (topLevel) { + for (auto iter = can_modref_map.begin(); iter != can_modref_map.end(); iter++) { + if (iter->second) { + bool is_needed = is_load_needed_in_reverse(gutils, AA, iter->first); + can_modref_map[iter->first] = is_needed; + } + } + } + gutils->can_modref_map = &can_modref_map; + gutils->forceContexts(true); gutils->forceAugmentedReturns(); @@ -1602,7 +1913,6 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co } } - for(BasicBlock* BB: gutils->originalBlocks) { auto BB2 = gutils->reverseBlocks[BB]; assert(BB2); @@ -1648,6 +1958,8 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co assert(0 && "unknown terminator inst"); } + + for (BasicBlock::reverse_iterator I = BB->rbegin(), E = BB->rend(); I != E;) { Instruction* inst = &*I; assert(inst); @@ -1696,8 +2008,12 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co break; } default: + //continue; assert(op); - llvm::errs() << *gutils->newFunc << "\n"; + llvm::errs() << "OLDFUNC:\n" << *gutils->oldFunc << "\n"; + llvm::errs() << "NEWFUNC:\n" << *gutils->newFunc << "\n"; + + llvm::errs() << "cannot handle unknown binary operator: " << *op << "\n"; report_fatal_error("unknown binary operator"); } @@ -1932,7 +2248,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co if (dif0) addToDiffe(op->getOperand(0), dif0); if (dif1) addToDiffe(op->getOperand(1), dif1); } else if(auto op = dyn_cast_or_null(inst)) { - handleGradientCallInst(I, E, Builder2, op, gutils, TLI, AA, topLevel, replacedReturns); + handleGradientCallInst(I, E, Builder2, op, gutils, TLI, AA, topLevel, replacedReturns, volatile_args_map[op]); } else if(auto op = dyn_cast_or_null(inst)) { if (gutils->isConstantValue(inst)) continue; if (op->getType()->isPointerTy()) continue; @@ -1949,24 +2265,23 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co if (dif1) addToDiffe(op->getOperand(1), dif1); if (dif2) addToDiffe(op->getOperand(2), dif2); } else if(auto op = dyn_cast(inst)) { - if (gutils->isConstantValue(inst)) continue; - - + if (gutils->isConstantValue(inst) || gutils->isConstantInstruction(inst)) continue; auto op_operand = op->getPointerOperand(); auto op_type = op->getType(); if (cachereads) { - llvm::errs() << "Forcibly loading cached reads " << *op << "\n"; - IRBuilder<> BuilderZ(op->getNextNode()); - inst = cast(gutils->addMalloc(BuilderZ, inst)); - if (inst != op) { - // Set to nullptr since op should never be used after invalidated through addMalloc. - op = nullptr; - gutils->nonconstant_values.insert(inst); - gutils->nonconstant.insert(inst); - gutils->originalInstructions.insert(inst); - assert(inst->getType() == op_type); + if (can_modref_map[inst]) { + IRBuilder<> BuilderZ(op->getNextNode()); + inst = cast(gutils->addMalloc(BuilderZ, inst)); + if (inst != op) { + // Set to nullptr since op should never be used after invalidated through addMalloc. + op = nullptr; + gutils->nonconstant_values.insert(inst); + gutils->nonconstant.insert(inst); + gutils->originalInstructions.insert(inst); + assert(inst->getType() == op_type); + } } } diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index ac65e7734432..ec54b19e4b77 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -36,6 +36,6 @@ extern llvm::cl::opt enzyme_print; //! 
return structtype if recursive function std::pair CreateAugmentedPrimal(llvm::Function* todiff, llvm::AAResults &AA, const std::set& constant_args, llvm::TargetLibraryInfo &TLI, bool differentialReturn); -llvm::Function* CreatePrimalAndGradient(llvm::Function* todiff, const std::set& constant_args, llvm::TargetLibraryInfo &TLI, llvm::AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg); +llvm::Function* CreatePrimalAndGradient(llvm::Function* todiff, const std::set& constant_args, llvm::TargetLibraryInfo &TLI, llvm::AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg, std::set volatile_args); #endif diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 38d9a177a4d4..c8dd75d0b4ee 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -439,7 +439,7 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI) FunctionAnalysisManager AM; AM.registerPass([] { return AAManager(); }); AM.registerPass([] { return ScalarEvolutionAnalysis(); }); - AM.registerPass([] { return AssumptionAnalysis(); }); + //AM.registerPass([] { return AssumptionAnalysis(); }); AM.registerPass([] { return TargetLibraryAnalysis(); }); AM.registerPass([] { return TargetIRAnalysis(); }); AM.registerPass([] { return LoopAnalysis(); }); @@ -458,13 +458,22 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI) MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(AM); }); //Alias analysis is necessary to ensure can query whether we can move a forward pass function - BasicAA ba; - auto baa = new BasicAAResult(ba.run(*NewF, AM)); + //BasicAA ba; + //auto baa = new BasicAAResult(ba.run(*NewF, AM)); + AssumptionCache* AC = new AssumptionCache(*NewF); + TargetLibraryInfo* TLI = new TargetLibraryInfo(AM.getResult(*NewF)); + auto baa = new BasicAAResult(NewF->getParent()->getDataLayout(), + *NewF, + *TLI, + *AC, + &AM.getResult(*NewF), + AM.getCachedResult(*NewF), + AM.getCachedResult(*NewF)); AA.addAAResult(*baa); - ScopedNoAliasAA sa; - auto saa = new ScopedNoAliasAAResult(sa.run(*NewF, AM)); - AA.addAAResult(*saa); + //ScopedNoAliasAA sa; + //auto saa = new ScopedNoAliasAAResult(sa.run(*NewF, AM)); + //AA.addAAResult(*saa); } diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index a99159c92574..09e1f0ee7e08 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -164,6 +164,10 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos // Case 2: The correct exiting block terminator unconditionally branches a different block, change to a conditional branch depending on if we are the first iteration } else if (succ.size() == 1) { + lc.latchMerge->getTerminator()->eraseFromParent(); + mergeBuilder.SetInsertPoint(lc.latchMerge); + + assert(mergeBuilder.GetInsertBlock()->size() == 0 || !isa(mergeBuilder.GetInsertBlock()->back())); // If first iteration, branch to the exiting block, otherwise the backlatch mergeBuilder.CreateCondBr(firstiter, succ[0], reverseBlocks[backlatch]); @@ -187,6 +191,8 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos lc.latchMerge->getTerminator()->eraseFromParent(); mergeBuilder.SetInsertPoint(lc.latchMerge); + + assert(mergeBuilder.GetInsertBlock()->size() == 0 || !isa(mergeBuilder.GetInsertBlock()->back())); mergeBuilder.CreateCondBr(firstiter, splitBlock, 
reverseBlocks[backlatch]); } @@ -310,13 +316,13 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) { } if(isConstantValue(val)) { - llvm::errs() << *oldFunc << "\n"; - llvm::errs() << *newFunc << "\n"; + llvm::errs() << "TFKDEBUG OLD:\n" << *oldFunc << "\n"; + llvm::errs() << "TFKDEBUG NEW:\n" << *newFunc << "\n"; dumpSet(this->originalInstructions); if (auto arg = dyn_cast(val)) { - llvm::errs() << *arg->getParent()->getParent() << "\n"; + llvm::errs() << "TFKDEBUG arg getparent getparent: \n" << *arg->getParent()->getParent() << "\n"; } - llvm::errs() << *val << "\n"; + llvm::errs() << "TFKDEBUG val: "<< *val << "\n"; } assert(!isConstantValue(val)); @@ -345,7 +351,8 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) { return invertedPointers[val] = cs; } else if (auto fn = dyn_cast(val)) { //! Todo allow tape propagation - auto newf = CreatePrimalAndGradient(fn, /*constant_args*/{}, TLI, AA, /*returnValue*/false, /*differentialReturn*/fn->getReturnType()->isFPOrFPVectorTy(), /*topLevel*/false, /*additionalArg*/nullptr); + std::set volatile_args; + auto newf = CreatePrimalAndGradient(fn, /*constant_args*/{}, TLI, AA, /*returnValue*/false, /*differentialReturn*/fn->getReturnType()->isFPOrFPVectorTy(), /*topLevel*/false, /*additionalArg*/nullptr, volatile_args); return BuilderM.CreatePointerCast(newf, fn->getType()); } else if (auto arg = dyn_cast(val)) { auto result = BuilderM.CreateCast(arg->getOpcode(), invertPointerM(arg->getOperand(0), BuilderM), arg->getDestTy(), arg->getName()+"'ipc"); @@ -377,7 +384,22 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) { } else if (auto arg = dyn_cast(val)) { assert(arg->getType()->isIntOrIntVectorTy()); IRBuilder <> bb(arg); - auto li = bb.CreateBinOp(arg->getOpcode(), invertPointerM(arg->getOperand(0), bb), invertPointerM(arg->getOperand(1), bb), arg->getName()); + Value* val0 = nullptr; + Value* val1 = nullptr; + + if (isa(arg->getOperand(0))) { + val0 = arg->getOperand(0); + val1 = invertPointerM(arg->getOperand(1), bb); + } else if (isa(arg->getOperand(1))) { + val0 = invertPointerM(arg->getOperand(0), bb); + val1 = arg->getOperand(1); + } else { + val0 = invertPointerM(arg->getOperand(0), bb); + val1 = invertPointerM(arg->getOperand(1), bb); + } + + + auto li = bb.CreateBinOp(arg->getOpcode(), val0, val1, arg->getName()); invertedPointers[arg] = li; return lookupM(invertedPointers[arg], BuilderM); } else if (auto arg = dyn_cast(val)) { @@ -818,10 +840,12 @@ Value* GradientUtils::lookupM(Value* val, IRBuilder<>& BuilderM) { } } - if (!shouldRecompute(inst, available)) { - auto op = unwrapM(inst, BuilderM, available, /*lookupIfAble*/true); - assert(op); - return op; + if (!(*(this->can_modref_map))[inst]) { + if (!shouldRecompute(inst, available)) { + auto op = unwrapM(inst, BuilderM, available, /*lookupIfAble*/true); + assert(op); + return op; + } } /* if (!inLoop) { @@ -858,6 +882,7 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B if (targetToPreds.size() == 1) { if (replacePHIs == nullptr) { + assert(BuilderM.GetInsertBlock()->size() == 0 || !isa(BuilderM.GetInsertBlock()->back())); BuilderM.CreateBr( targetToPreds.begin()->first ); } else { for (auto pair : *replacePHIs) { @@ -962,6 +987,7 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B Value* phi = lookupValueFromCache(BuilderM, ctx, cache); if (replacePHIs == nullptr) { + assert(BuilderM.GetInsertBlock()->size() == 0 || 
!isa(BuilderM.GetInsertBlock()->back())); BuilderM.CreateCondBr(phi, *done[std::make_pair(block, branch->getSuccessor(0))].begin(), *done[std::make_pair(block, branch->getSuccessor(1))].begin()); } else { for (auto pair : *replacePHIs) { @@ -1076,6 +1102,7 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B if (replacePHIs == nullptr) { if (targetToPreds.size() == 2) { + assert(BuilderM.GetInsertBlock()->size() == 0 || !isa(BuilderM.GetInsertBlock()->back())); BuilderM.CreateCondBr(which, /*true*/targets[1], /*false*/targets[0]); } else { auto swit = BuilderM.CreateSwitch(which, targets.back(), targets.size()-1); diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index 7b71c48195ee..7efa4df80f24 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -89,6 +89,9 @@ class GradientUtils { ValueToValueMapTy scopeFrees; ValueToValueMapTy originalToNewFn; + std::map* can_modref_map; + + Value* getNewFromOriginal(Value* originst) { assert(originst); auto f = originalToNewFn.find(originst); @@ -507,7 +510,7 @@ class GradientUtils { } assert(lastScopeAlloc.find(malloc) == lastScopeAlloc.end()); cast(malloc)->replaceAllUsesWith(ret); - auto n = malloc->getName(); + std::string n = malloc->getName().str(); erase(cast(malloc)); ret->setName(n); } @@ -709,7 +712,7 @@ class GradientUtils { IRBuilder<> BuilderZ(getNextNonDebugInstruction(op)); BuilderZ.setFastMathFlags(getFast()); - this->invertedPointers[op] = BuilderZ.CreatePHI(op->getType(), 1); + this->invertedPointers[op] = BuilderZ.CreatePHI(op->getType(), 1, op->getName() + "_fa"); if ( called && (called->getName() == "malloc" || called->getName() == "_Znwm")) { this->invertedPointers[op]->setName(op->getName()+"'mi"); @@ -894,98 +897,176 @@ class GradientUtils { return nullptr; report_fatal_error("unable to unwrap"); } + + //! 
returns true indices + std::vector>>> getSubLimits(BasicBlock* ctx) { + std::vector contexts; + for (BasicBlock* blk = ctx; blk != nullptr; ) { + LoopContext idx; + if (!getContext(blk, idx)) { + break; + } + llvm::errs() << " adding to contexts: " << idx.header->getName() << " starting ctx=" << ctx->getName() << "\n"; + contexts.emplace_back(idx); + blk = idx.preheader; + } + + std::vector allocationPreheaders(contexts.size(), nullptr); + std::vector limits(contexts.size(), nullptr); + for(int i=contexts.size()-1; i >= 0; i--) { + if ((unsigned)i == contexts.size() - 1) { + allocationPreheaders[i] = contexts[i].preheader; + } else if (contexts[i].dynamic) { + allocationPreheaders[i] = contexts[i].preheader; + } else { + allocationPreheaders[i] = allocationPreheaders[i+1]; + } + + if (contexts[i].dynamic) { + limits[i] = ConstantInt::get(Type::getInt64Ty(ctx->getContext()), 1); + } else { + //while (limits[i] == nullptr) { + ValueToValueMapTy emptyMap; + IRBuilder <> allocationBuilder(&allocationPreheaders[i]->back()); + Value* limitMinus1 = unwrapM(contexts[i].limit, allocationBuilder, emptyMap, /*lookupIfAble*/false); + if (limitMinus1 == nullptr) { + assert(allocationPreheaders[i]); + llvm::errs() << *oldFunc << "\n"; + llvm::errs() << *newFunc << "\n"; + llvm::errs() << "needed value " << *contexts[i].limit << " at " << allocationPreheaders[i]->getName() << "\n"; + } + assert(limitMinus1 != nullptr); + limits[i] = allocationBuilder.CreateNUWAdd(limitMinus1, ConstantInt::get(limitMinus1->getType(), 1)); + //TODO allow triangular arrays per above + /* + if (limits[i] == nullptr) { + int firstDifferent = j+1; + while (allocationPreheaders[firstDifferent] == allocationPreheaders[i]) { + firstDifferent++; + assert(firstDifferent < contexts.size()); + } + allocationPreheaders[i] = allocationPreheaders[firstDifferent]; + } + }*/ + } + } + + std::vector>>> sublimits; + + Value* size = nullptr; + std::vector> lims; + for(unsigned i=0; i < contexts.size(); i++) { + IRBuilder <> allocationBuilder(&allocationPreheaders[i]->back()); + lims.push_back(std::make_pair(contexts[i], limits[i])); + if (size == nullptr) { + size = limits[i]; + } else { + size = allocationBuilder.CreateNUWMul(size, limits[i]); + } + + llvm::errs() << "considering ctx " << ctx->getName() << " alph=" << allocationPreheaders[i]->getName() << " ctxheader=" << contexts[i].header->getName() << "\n"; + if (contexts[i].dynamic) { + llvm::errs() << "starting outermost ph at " << allocationPreheaders[i]->getName() << "|ctx=" << ctx->getName() <<"\n"; + sublimits.push_back(std::make_pair(size, lims)); + size = nullptr; + lims.clear(); + } + } + + if (size != nullptr) { + llvm::errs() << "starting final outermost ph at " << allocationPreheaders[contexts.size()-1]->getName()<<"|ctx=" << ctx->getName() << "\n"; + sublimits.push_back(std::make_pair(size, lims)); + lims.clear(); + } + return sublimits; + } //! 
Caching mechanism: creates a cache of type T in a scope given by ctx (where if ctx is in a loop there will be a corresponding number of slots) AllocaInst* createCacheForScope(BasicBlock* ctx, Type* T, StringRef name, CallInst** freeLocation, Instruction** lastScopeAllocLocation) { assert(ctx); assert(T); - LoopContext lc; - bool inLoop = getContext(ctx, lc); + + auto sublimits = getSubLimits(ctx); + + /* goes from inner loop to outer loop*/ + std::vector types = {T}; + for(const auto sublimit: sublimits) { + types.push_back(PointerType::getUnqual(types.back())); + } assert(inversionAllocs && "must be able to allocate inverted caches"); IRBuilder<> entryBuilder(inversionAllocs); entryBuilder.setFastMathFlags(getFast()); - - if (!inLoop) { - return entryBuilder.CreateAlloca(T, nullptr, name+"_cache"); - } else { - - BasicBlock* outermostPreheader = nullptr; - - for(LoopContext idx = lc; ; getContext(idx.parent->getHeader(), idx) ) { - if (idx.parent == nullptr) { - outermostPreheader = idx.preheader; - } - if (idx.parent == nullptr) break; + AllocaInst* alloc = entryBuilder.CreateAlloca(types.back(), nullptr, name+"_cache"); + llvm::errs() << "alloc: "<< *alloc << "\n"; + + Type *BPTy = Type::getInt8PtrTy(ctx->getContext()); + auto realloc = newFunc->getParent()->getOrInsertFunction("realloc", BPTy, BPTy, Type::getInt64Ty(ctx->getContext())); + + Value* storeInto = alloc; + ValueToValueMapTy antimap; + + for(int i=sublimits.size()-1; i>=0; i--) { + const auto& containedloops = sublimits[i].second; + for(auto riter = containedloops.rbegin(), rend = containedloops.rend(); riter != rend; riter++) { + const auto& idx = riter->first; + antimap[idx.var] = idx.antivar; } - assert(outermostPreheader); - IRBuilder <> allocationBuilder(&outermostPreheader->back()); + Value* size = sublimits[i].first; + Type* myType = types[i]; + + IRBuilder <> allocationBuilder(&containedloops.back().first.preheader->back()); + if (!sublimits[i].second.back().first.dynamic) { + auto firstallocation = CallInst::CreateMalloc( + &allocationBuilder.GetInsertBlock()->back(), + size->getType(), + myType, + ConstantInt::get(size->getType(), allocationBuilder.GetInsertBlock()->getParent()->getParent()->getDataLayout().getTypeAllocSizeInBits(myType)/8), size, nullptr, name+"_malloccache"); + CallInst* malloccall = dyn_cast(firstallocation); + if (malloccall == nullptr) { + malloccall = cast(cast(firstallocation)->getOperand(0)); + } + malloccall->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias); + malloccall->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); + + allocationBuilder.CreateStore(firstallocation, storeInto); + + if (lastScopeAllocLocation) { + *lastScopeAllocLocation = cast(firstallocation); + } - Value* size = nullptr; - static std::map sizecache; - if (sizecache.find(lc.header) != sizecache.end()) { - size = sizecache[lc.header]; + //allocationBuilder.GetInsertBlock()->getInstList().push_back(cast(allocation)); + //cast(firstallocation)->moveBefore(allocationBuilder.GetInsertBlock()->getTerminator()); + //mallocs.push_back(firstallocation); } else { - for(LoopContext idx = lc; ; getContext(idx.parent->getHeader(), idx) ) { - //TODO handle allocations for dynamic loops - if (idx.dynamic && idx.parent != nullptr) { - assert(idx.var); - assert(idx.var->getParent()); - assert(idx.var->getParent()->getParent()); - llvm::errs() << *idx.var->getParent()->getParent() << "\n" - << "idx.var=" <<*idx.var << "\n" - << "idx.limit=" <<*idx.limit << "\n"; - llvm::errs() << "cannot handle non-outermost 
dynamic loop\n"; - assert(0 && "cannot handle non-outermost dynamic loop"); - } - Value* ns = nullptr; - Type* intT = idx.dynamic ? cast(idx.limit->getType())->getElementType() : idx.limit->getType(); - if (idx.dynamic) { - ns = ConstantInt::get(intT, 1); - } else { - Value* limitm1 = nullptr; - ValueToValueMapTy emptyMap; - limitm1 = unwrapM(idx.limit, allocationBuilder, emptyMap, /*lookupIfAble*/false); - if (limitm1 == nullptr) { - assert(outermostPreheader); - assert(outermostPreheader->getParent()); - llvm::errs() << *outermostPreheader->getParent() << "\n"; - llvm::errs() << "needed value " << *idx.limit << " at " << allocationBuilder.GetInsertBlock()->getName() << "\n"; - } - assert(limitm1); - ns = allocationBuilder.CreateNUWAdd(limitm1, ConstantInt::get(intT, 1)); - } - if (size == nullptr) size = ns; - else size = allocationBuilder.CreateNUWMul(size, ns); - if (idx.parent == nullptr) break; - } - sizecache[lc.header] = size; - } + llvm::errs() << "storeInto: " << *storeInto << "\n"; + llvm::errs() << "myType: " << *myType << "\n"; + allocationBuilder.CreateStore(ConstantPointerNull::get(PointerType::getUnqual(myType)), storeInto); + + IRBuilder <> build(containedloops.back().first.header->getFirstNonPHI()); + Value* allocation = build.CreateLoad(storeInto); + Value* foo = build.CreateNUWAdd(containedloops.back().first.var, ConstantInt::get(Type::getInt64Ty(ctx->getContext()), 1)); + Value* realloc_size = build.CreateNUWMul(foo, sublimits[i].first); + Value* idxs[2] = { + build.CreatePointerCast(allocation, BPTy), + build.CreateNUWMul( + ConstantInt::get(size->getType(), newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(myType)/8), realloc_size + ) + }; - auto firstallocation = CallInst::CreateMalloc( - &allocationBuilder.GetInsertBlock()->back(), - size->getType(), - T, - ConstantInt::get(size->getType(), allocationBuilder.GetInsertBlock()->getParent()->getParent()->getDataLayout().getTypeAllocSizeInBits(T)/8), size, nullptr, name+"_malloccache"); - CallInst* malloccall = dyn_cast(firstallocation); - if (malloccall == nullptr) { - malloccall = cast(cast(firstallocation)->getOperand(0)); + Value* realloccall = nullptr; + allocation = build.CreatePointerCast(realloccall = build.CreateCall(realloc, idxs, name+"_realloccache"), allocation->getType()); + if (lastScopeAllocLocation) { + *lastScopeAllocLocation = cast(allocation); + } + build.CreateStore(allocation, storeInto); } - malloccall->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias); - malloccall->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); - //allocationBuilder.GetInsertBlock()->getInstList().push_back(cast(allocation)); - cast(firstallocation)->moveBefore(allocationBuilder.GetInsertBlock()->getTerminator()); - AllocaInst* holderAlloc = entryBuilder.CreateAlloca(firstallocation->getType(), nullptr, name+"_mdyncache"); - if (lastScopeAllocLocation) - *lastScopeAllocLocation = firstallocation; - allocationBuilder.CreateStore(firstallocation, holderAlloc); - if (freeLocation) { assert(reverseBlocks.size()); - IRBuilder<> tbuild(reverseBlocks[outermostPreheader]); + IRBuilder<> tbuild(reverseBlocks[containedloops.back().first.preheader]); tbuild.setFastMathFlags(getFast()); // ensure we are before the terminator if it exists @@ -993,165 +1074,116 @@ class GradientUtils { tbuild.SetInsertPoint(tbuild.GetInsertBlock()->getFirstNonPHI()); } - auto ci = cast(CallInst::CreateFree(tbuild.CreatePointerCast(tbuild.CreateLoad(holderAlloc), Type::getInt8PtrTy(outermostPreheader->getContext())), 
tbuild.GetInsertBlock())); + auto ci = cast(CallInst::CreateFree(tbuild.CreatePointerCast(tbuild.CreateLoad(unwrapM(storeInto, tbuild, antimap, /*lookup*/false)), Type::getInt8PtrTy(ctx->getContext())), tbuild.GetInsertBlock())); ci->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull); if (ci->getParent()==nullptr) { tbuild.Insert(ci); } *freeLocation = ci; } - - IRBuilder <> v(ctx->getFirstNonPHI()); - v.setFastMathFlags(getFast()); - - SmallVector indices; - SmallVector limits; - PHINode* dynamicPHI = nullptr; - - for(LoopContext idx = lc; ; getContext(idx.parent->getHeader(), idx) ) { - indices.push_back(idx.var); - - if (idx.dynamic) { - dynamicPHI = idx.var; - assert(dynamicPHI); - llvm::errs() << "saw idx.dynamic:" << *dynamicPHI << "\n"; - assert(idx.parent == nullptr); - break; - } - - if (idx.parent == nullptr) break; - ValueToValueMapTy emptyMap; - auto limitm1 = unwrapM(idx.limit, v, emptyMap, /*lookupIfAble*/false); - assert(limitm1); - Type* intT = idx.dynamic ? cast(idx.limit->getType())->getElementType() : idx.limit->getType(); - auto lim = v.CreateNUWAdd(limitm1, ConstantInt::get(intT, 1)); - if (limits.size() != 0) { - lim = v.CreateNUWMul(lim, limits.back()); - } - limits.push_back(lim); - } - - /* - Value* idx = nullptr; - for(unsigned i=0; igetContext()); - auto realloc = newFunc->getParent()->getOrInsertFunction("realloc", BPTy, BPTy, size->getType()); - Value* allocation = v.CreateLoad(holderAlloc); - Value* foo = v.CreateNUWAdd(dynamicPHI, ConstantInt::get(dynamicPHI->getType(), 1)); - Value* realloc_size = v.CreateNUWMul(size, foo); - Value* idxs[2] = { - v.CreatePointerCast(allocation, BPTy), - v.CreateNUWMul( - ConstantInt::get(size->getType(), newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(T)/8), realloc_size - ) - }; - - Value* realloccall = nullptr; - allocation = v.CreatePointerCast(realloccall = v.CreateCall(realloc, idxs, name+"_realloccache"), allocation->getType()); - if (lastScopeAllocLocation) { - *lastScopeAllocLocation = cast(allocation); + + if (i != 0) { + IRBuilder <>v(&sublimits[i-1].second.back().first.preheader->back()); + //TODO + if (!sublimits[i].second.back().first.dynamic) { + storeInto = v.CreateGEP(v.CreateLoad(storeInto), sublimits[i].second.back().first.var); + } else { + storeInto = v.CreateGEP(v.CreateLoad(storeInto), sublimits[i].second.back().first.var); } - v.CreateStore(allocation, holderAlloc); } - return holderAlloc; } + return alloc; } - void storeInstructionInCache(BasicBlock* ctx, IRBuilder <>& BuilderM, Value* val, AllocaInst* cache) { + Value* getCachePointer(IRBuilder <>& BuilderM, BasicBlock* ctx, Value* cache) { assert(ctx); - assert(val); assert(cache); - LoopContext lc; - bool inLoop = getContext(ctx, lc); + + auto sublimits = getSubLimits(ctx); + + ValueToValueMapTy available; + + Value* next = cache; + for(int i=sublimits.size()-1; i>=0; i--) { + next = BuilderM.CreateLoad(next); - if (!inLoop) { - BuilderM.CreateStore(val, cache); - } else { - IRBuilder <> v(BuilderM); - v.setFastMathFlags(getFast()); - - //Note for dynamic loops where the allocation is stored somewhere inside the loop, - // we must ensure that we load the allocation after the store ensuring memory exists - // This does not need to occur (and will find no such store) for nondynamic loops - // as memory is statically allocated in the preheader - for (auto I = BuilderM.GetInsertBlock()->rbegin(), E = BuilderM.GetInsertBlock()->rend(); I != E; I++) { - if (&*I == &*BuilderM.GetInsertPoint()) break; - if (auto si = dyn_cast(&*I)) 
{ - if (si->getPointerOperand() == cache) { - v.SetInsertPoint(getNextNonDebugInstruction(si)); - } - } - } + const auto& containedloops = sublimits[i].second; SmallVector indices; SmallVector limits; - PHINode* dynamicPHI = nullptr; - - for(LoopContext idx = lc; ; getContext(idx.parent->getHeader(), idx) ) { - indices.push_back(idx.var); - - if (idx.dynamic) { - dynamicPHI = idx.var; - assert(dynamicPHI); - llvm::errs() << "saw idx.dynamic:" << *dynamicPHI << "\n"; - assert(idx.parent == nullptr); - break; - } - - if (idx.parent == nullptr) break; - ValueToValueMapTy emptyMap; - auto limitm1 = unwrapM(idx.limit, v, emptyMap, /*lookupIfAble*/false); - assert(limitm1); - auto lim = v.CreateNUWAdd(limitm1, ConstantInt::get(idx.limit->getType(), 1)); - if (limits.size() != 0) { - lim = v.CreateNUWMul(lim, limits.back()); + for(auto riter = containedloops.rbegin(), rend = containedloops.rend(); riter != rend; riter++) { + // Only include dynamic index on last iteration (== skip dynamic index on non-last iterations) + //if (i != 0 && riter+1 == rend) break; + + const auto &idx = riter->first; + if (!isOriginalBlock(*BuilderM.GetInsertBlock())) { + indices.push_back(idx.antivar); + available[idx.var] = idx.antivar; + } else { + indices.push_back(idx.var); + available[idx.var] = idx.var; } - limits.push_back(lim); - } + llvm::errs() << "W sl idx=" << i << " " << *idx.var << " header=" << idx.header->getName() << "\n"; - Value* idx = nullptr; - for(unsigned i=0; isecond, BuilderM, available, /*lookupIfAble*/true); + assert(lim); + if (limits.size() == 0) { + limits.push_back(lim); } else { - auto mul = v.CreateNUWMul(indices[i], limits[i-1]); - idx = v.CreateNUWAdd(idx, mul); + limits.push_back(BuilderM.CreateNUWMul(lim, limits.back())); } } - Value* allocation = nullptr; - if (dynamicPHI == nullptr) { - BasicBlock* outermostPreheader = nullptr; + if (indices.size() > 0) { + llvm::errs() << "sl idx=" << i << " " << *indices[0] << "\n"; + Value* idx = indices[0]; + for(unsigned ind=1; ind& BuilderM, BasicBlock* ctx, Value* cache) { + auto result = BuilderM.CreateLoad(getCachePointer(BuilderM, ctx, cache)); + result->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(ctx->getContext(), {})); + return result; + } - for(LoopContext idx = lc; ; getContext(idx.parent->getHeader(), idx) ) { - if (idx.parent == nullptr) { - outermostPreheader = idx.preheader; - } - if (idx.parent == nullptr) break; - } - assert(outermostPreheader); + void storeInstructionInCache(BasicBlock* ctx, IRBuilder <>& BuilderM, Value* val, AllocaInst* cache) { + IRBuilder <> v(BuilderM); + v.setFastMathFlags(getFast()); - IRBuilder<> outerBuilder(&outermostPreheader->back()); - allocation = outerBuilder.CreateLoad(cache); - } else { - allocation = v.CreateLoad(cache); + //Note for dynamic loops where the allocation is stored somewhere inside the loop, + // we must ensure that we load the allocation after the store ensuring memory exists + // This does not need to occur (and will find no such store) for nondynamic loops + // as memory is statically allocated in the preheader + for (auto I = BuilderM.GetInsertBlock()->rbegin(), E = BuilderM.GetInsertBlock()->rend(); I != E; I++) { + if (&*I == &*BuilderM.GetInsertPoint()) break; + if (auto si = dyn_cast(&*I)) { + if (si->getPointerOperand() == cache) { + v.SetInsertPoint(getNextNonDebugInstruction(si)); + } } - - Value* idxs[] = {idx}; - auto gep = v.CreateGEP(allocation, idxs); - v.CreateStore(val, gep); } + v.CreateStore(val, getCachePointer(v, ctx, cache)); } + void 
storeInstructionInCache(BasicBlock* ctx, Instruction* inst, AllocaInst* cache) { assert(ctx); assert(inst); @@ -1186,60 +1218,6 @@ class GradientUtils { storeInstructionInCache(inst->getParent(), inst, cache); } - LoadInst* lookupValueFromCache(IRBuilder<>& BuilderM, BasicBlock* ctx, Value* cache) { - assert(ctx); - assert(cache); - LoopContext lc; - bool inLoop = getContext(ctx, lc); - - if (!inLoop) { - auto result = BuilderM.CreateLoad(cache); - result->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(ctx->getContext(), {})); - return result; - } else { - - ValueToValueMapTy available; - for(LoopContext idx = lc; ; getContext(idx.parent->getHeader(), idx)) { - if (!isOriginalBlock(*BuilderM.GetInsertBlock())) { - available[idx.var] = idx.antivar; - } else { - available[idx.var] = idx.var; - } - if (idx.parent == nullptr) break; - } - - SmallVector indices; - SmallVector limits; - for(LoopContext idx = lc; ; getContext(idx.parent->getHeader(), idx) ) { - indices.push_back(unwrapM(idx.var, BuilderM, available, /*lookupIfAble*/false)); - if (idx.parent == nullptr) break; - - auto limitm1 = unwrapM(idx.limit, BuilderM, available, /*lookupIfAble*/true); - assert(limitm1); - auto lim = BuilderM.CreateNUWAdd(limitm1, ConstantInt::get(idx.limit->getType(), 1)); - if (limits.size() != 0) { - lim = BuilderM.CreateNUWMul(lim, limits.back()); - } - limits.push_back(lim); - } - - Value* idx = nullptr; - for(unsigned i=0; isetMetadata(LLVMContext::MD_invariant_load, MDNode::get(result->getContext(), {})); - return result; - } - } - Instruction* fixLCSSA(Instruction* inst, const IRBuilder <>& BuilderM) { LoopContext lc; bool inLoop = getContext(inst->getParent(), lc); diff --git a/enzyme/functional_tests_c/Makefile b/enzyme/functional_tests_c/Makefile index 8d1c98051e0d..1e806d499aed 100644 --- a/enzyme/functional_tests_c/Makefile +++ b/enzyme/functional_tests_c/Makefile @@ -18,8 +18,12 @@ OBJ := $(wildcard *.c) all: $(patsubst %.c,build/%-enzyme0,$(OBJ)) $(patsubst %.c,build/%-enzyme1,$(OBJ)) $(patsubst %.c,build/%-enzyme2,$(OBJ)) $(patsubst %.c,build/%-enzyme3,$(OBJ)) -POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg -enzyme_cachereads=true +POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg -enzyme_cachereads=true -enzyme_print=true +#CC := clang +#CSTD := -std=c11 +CC := clang +CSTD := -std=c11 #all: $(patsubst %.c,build/%-enzyme1,$(OBJ)) $(patsubst %.c,build/%-enzyme2,$(OBJ)) $(patsubst %.c,build/%-enzyme3,$(OBJ)) #clean: # rm -f main main-* main.ll @@ -31,30 +35,30 @@ POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg -enzyme_cachereads=true #EXTRA_FLAGS = -indvars -loop-simplify -loop-rotate -# NOTE(TFK): Optimization level 0 is broken right now. 
diff --git a/enzyme/functional_tests_c/Makefile b/enzyme/functional_tests_c/Makefile
index 8d1c98051e0d..1e806d499aed 100644
--- a/enzyme/functional_tests_c/Makefile
+++ b/enzyme/functional_tests_c/Makefile
@@ -18,8 +18,12 @@ OBJ := $(wildcard *.c)
 
 all: $(patsubst %.c,build/%-enzyme0,$(OBJ)) $(patsubst %.c,build/%-enzyme1,$(OBJ)) $(patsubst %.c,build/%-enzyme2,$(OBJ)) $(patsubst %.c,build/%-enzyme3,$(OBJ))
 
-POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg -enzyme_cachereads=true
+POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg -enzyme_cachereads=true -enzyme_print=true
+#CC := clang
+#CSTD := -std=c11
+CC := clang
+CSTD := -std=c11
 
 #all: $(patsubst %.c,build/%-enzyme1,$(OBJ)) $(patsubst %.c,build/%-enzyme2,$(OBJ)) $(patsubst %.c,build/%-enzyme3,$(OBJ))
 
 #clean:
 #	rm -f main main-* main.ll
@@ -31,30 +35,30 @@ POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg -enzyme_cachereads=true
 
 #EXTRA_FLAGS = -indvars -loop-simplify -loop-rotate
 
-# NOTE(TFK): Optimization level 0 is broken right now.
+# /efs/home/tfk/valgrind-3.12.0/vg-in-place
 build/%-enzyme0: %.c
-	@./setup.sh $(CLANG_BIN_PATH)/clang -std=c11 -O1 $(patsubst %.c,%,$<).c -S -emit-llvm -o $@.ll
+	@./setup.sh $(CLANG_BIN_PATH)/$(CC) $(CSTD) -O1 $(patsubst %.c,%,$<).c -S -emit-llvm -o $@.ll
 	@./setup.sh $(CLANG_BIN_PATH)/opt $@.ll $(EXTRA_FLAGS) -load=$(ENZYME_PLUGIN) -enzyme $(POST_ENZYME_FLAGS) -o $@.bc
-	@./setup.sh $(CLANG_BIN_PATH)/clang -std=c11 $@.bc -S -emit-llvm -o $@-final.ll
-	@./setup.sh $(CLANG_BIN_PATH)/clang -std=c11 $@.bc -o $@
+	@./setup.sh $(CLANG_BIN_PATH)/$(CC) $(CSTD) $@.bc -S -emit-llvm -o $@-final.ll
+	@./setup.sh $(CLANG_BIN_PATH)/$(CC) $(CSTD) $@.bc -o $@
 
 build/%-enzyme1: %.c
-	@./setup.sh $(CLANG_BIN_PATH)/clang -std=c11 -O1 $(patsubst %.c,%,$<).c -S -emit-llvm -o $@.ll
+	@./setup.sh $(CLANG_BIN_PATH)/$(CC) $(CSTD) -O1 $(patsubst %.c,%,$<).c -S -emit-llvm -o $@.ll
 	@./setup.sh $(CLANG_BIN_PATH)/opt $@.ll $(EXTRA_FLAGS) -load=$(ENZYME_PLUGIN) -enzyme $(POST_ENZYME_FLAGS) -o $@.bc
-	@./setup.sh $(CLANG_BIN_PATH)/clang -std=c11 $@.bc -S -emit-llvm -o $@-final.ll
-	@./setup.sh $(CLANG_BIN_PATH)/clang -std=c11 $@.bc -o $@
+	@./setup.sh $(CLANG_BIN_PATH)/$(CC) $(CSTD) $@.bc -S -emit-llvm -o $@-final.ll
+	@./setup.sh $(CLANG_BIN_PATH)/$(CC) $(CSTD) $@.bc -o $@
 
 build/%-enzyme2: %.c
-	@./setup.sh $(CLANG_BIN_PATH)/clang -std=c11 -O2 $(patsubst %.c,%,$<).c -S -emit-llvm -o $@.ll
+	@./setup.sh $(CLANG_BIN_PATH)/$(CC) $(CSTD) -O2 $(patsubst %.c,%,$<).c -S -emit-llvm -o $@.ll
 	@./setup.sh $(CLANG_BIN_PATH)/opt $@.ll $(EXTRA_FLAGS) -load=$(ENZYME_PLUGIN) -enzyme $(POST_ENZYME_FLAGS) -o $@.bc
-	@./setup.sh $(CLANG_BIN_PATH)/clang -std=c11 $@.bc -S -emit-llvm -o $@-final.ll
-	@./setup.sh $(CLANG_BIN_PATH)/clang -std=c11 $@.bc -o $@
+	@./setup.sh $(CLANG_BIN_PATH)/$(CC) $(CSTD) $@.bc -S -emit-llvm -o $@-final.ll
+	@./setup.sh $(CLANG_BIN_PATH)/$(CC) $(CSTD) $@.bc -o $@
 
 build/%-enzyme3: %.c
-	@./setup.sh $(CLANG_BIN_PATH)/clang -std=c11 -O3 $(patsubst %.c,%,$<).c -S -emit-llvm -o $@.ll
+	@./setup.sh $(CLANG_BIN_PATH)/$(CC) $(CSTD) -O3 $(patsubst %.c,%,$<).c -S -emit-llvm -o $@.ll
 	@./setup.sh $(CLANG_BIN_PATH)/opt $@.ll $(EXTRA_FLAGS) -load=$(ENZYME_PLUGIN) -enzyme $(POST_ENZYME_FLAGS) -o $@.bc
-	@./setup.sh $(CLANG_BIN_PATH)/clang -std=c11 $@.bc -S -emit-llvm -o $@-final.ll
-	@./setup.sh $(CLANG_BIN_PATH)/clang -std=c11 $@.bc -o $@
+	@./setup.sh $(CLANG_BIN_PATH)/$(CC) $(CSTD) $@.bc -S -emit-llvm -o $@-final.ll
+	@./setup.sh $(CLANG_BIN_PATH)/$(CC) $(CSTD) $@.bc -o $@
 
 %-enzyme-test0: build/%-enzyme0
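Each of the four rules is the same three-stage pipeline, differing only in the optimization level handed to the first $(CC) invocation: compile the C file to LLVM IR, run the Enzyme plugin under opt together with the POST_ENZYME_FLAGS cleanup passes, then lower the transformed bitcode to a binary. Spelled out for one hypothetical input test.c (paths illustrative):

    $CLANG_BIN_PATH/clang -std=c11 -O1 test.c -S -emit-llvm -o test.ll
    $CLANG_BIN_PATH/opt test.ll -load=$ENZYME_PLUGIN -enzyme \
        -mem2reg -sroa -adce -simplifycfg -enzyme_cachereads=true -o test.bc
    $CLANG_BIN_PATH/clang -std=c11 test.bc -o test
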
diff --git a/enzyme/functional_tests_c/FAIL_insertsort_sum.c b/enzyme/functional_tests_c/insertsort_sum.c
similarity index 92%
rename from enzyme/functional_tests_c/FAIL_insertsort_sum.c
rename to enzyme/functional_tests_c/insertsort_sum.c
index 875bf620077c..15556e37af8d 100644
--- a/enzyme/functional_tests_c/FAIL_insertsort_sum.c
+++ b/enzyme/functional_tests_c/insertsort_sum.c
@@ -17,7 +17,8 @@ float* unsorted_array_init(int N) {
 }
 
 // sums the first half of a sorted array.
-void insertsort_sum (float* array, int N, float* ret) {
+//__attribute__((noinline))
+void insertsort_sum (float*__restrict array, int N, float*__restrict ret) {
   float sum = 0;
 
   //qsort(array, N, sizeof(float), cmp);
@@ -31,17 +32,14 @@ void insertsort_sum (float* array, int N, float* ret) {
     }
   }
 
-
   for (int i = 0; i < N/2; i++) {
-    printf("Val: %f\n", array[i]);
+    //printf("Val: %f\n", array[i]);
     sum += array[i];
   }
+
   *ret = sum;
 }
-
-
-
 int main(int argc, char** argv) {
diff --git a/enzyme/functional_tests_c/loops.c b/enzyme/functional_tests_c/loops.c
index 0c1ffcea6f6d..c5d6e4409d6c 100644
--- a/enzyme/functional_tests_c/loops.c
+++ b/enzyme/functional_tests_c/loops.c
@@ -4,7 +4,6 @@
 
 #define __builtin_autodiff __enzyme_autodiff
 
-
 double __enzyme_autodiff(void*, ...);
 
 //float man_max(float* a, float* b) {
diff --git a/enzyme/functional_tests_c/readwriteread.c b/enzyme/functional_tests_c/readwriteread.c
new file mode 100644
index 000000000000..355c632190a2
--- /dev/null
+++ b/enzyme/functional_tests_c/readwriteread.c
@@ -0,0 +1,46 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <math.h>
+#include <assert.h>
+#define __builtin_autodiff __enzyme_autodiff
+double __enzyme_autodiff(void*, ...);
+
+double f_read(double* x) {
+  double product = (*x) * (*x);
+  return product;
+}
+
+void g_write(double* x, double product) {
+  *x = (*x) * product;
+}
+
+double h_read(double* x) {
+  return *x;
+}
+
+double readwriteread_helper(double* x) {
+  double product = f_read(x);
+  g_write(x, product);
+  double ret = h_read(x);
+  return ret;
+}
+
+void readwriteread(double*__restrict x, double*__restrict ret) {
+  *ret = readwriteread_helper(x);
+}
+
+int main(int argc, char** argv) {
+  double ret = 0;
+  double dret = 1.0;
+  double* x = (double*) malloc(sizeof(double));
+  double* dx = (double*) malloc(sizeof(double));
+  *x = 2.0;
+  *dx = 0.0;
+
+  __builtin_autodiff(readwriteread, x, dx, &ret, &dret);
+
+
+  printf("dx is %f ret is %f\n", *dx, ret);
+  assert(*dx == 3*2.0*2.0);
+  return 0;
+}
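The expected gradient in the assert follows directly from the composition: f_read yields (*x)*(*x) = x^2, g_write then stores x * x^2 = x^3, and h_read returns that, so readwriteread computes x^3 and its derivative is 3x^2, i.e. 3*2*2 = 12 at x = 2. A standalone finite-difference cross-check (not part of the test suite):

    #include <stdio.h>

    /* Forward-difference check that d/dx x^3 = 3x^2 at x = 2 (~12). */
    int main(void) {
        double x = 2.0, h = 1e-6;
        double fd = ((x+h)*(x+h)*(x+h) - x*x*x) / h;
        printf("finite difference %f vs analytic %f\n", fd, 3*x*x);
        return 0;
    }
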
diff --git a/enzyme/functional_tests_c/setup.sh b/enzyme/functional_tests_c/setup.sh
index 2d35d8a84745..c5c86df2fbb9 100755
--- a/enzyme/functional_tests_c/setup.sh
+++ b/enzyme/functional_tests_c/setup.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 # NOTE(TFK): Uncomment for local testing.
-export CLANG_BIN_PATH=./../../llvm/build/bin
+export CLANG_BIN_PATH=./../../llvm/build/bin/
 export ENZYME_PLUGIN=./../build/Enzyme/LLVMEnzyme-7.so
 
 mkdir -p build
diff --git a/enzyme/functional_tests_c/testfiles/FAIL_insertsort_sum-enzyme0.test b/enzyme/functional_tests_c/testfiles/FAIL_insertsort_sum-enzyme0.test
deleted file mode 100644
index 7e1a286d0c28..000000000000
--- a/enzyme/functional_tests_c/testfiles/FAIL_insertsort_sum-enzyme0.test
+++ /dev/null
@@ -1,6 +0,0 @@
-; RUN: cd %desired_wd
-; RUN: make clean-FAIL_insertsort_sum-enzyme0 ENZYME_PLUGIN=%loadEnzyme
-; RUN: make build/FAIL_insertsort_sum-enzyme0 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
-; RUN: build/FAIL_insertsort_sum-enzyme0
-; RUN: make clean-FAIL_insertsort_sum-enzyme0 ENZYME_PLUGIN=%loadEnzyme
-; XFAIL: *
diff --git a/enzyme/functional_tests_c/testfiles/FAIL_insertsort_sum-enzyme1.test b/enzyme/functional_tests_c/testfiles/FAIL_insertsort_sum-enzyme1.test
deleted file mode 100644
index 43264f17d626..000000000000
--- a/enzyme/functional_tests_c/testfiles/FAIL_insertsort_sum-enzyme1.test
+++ /dev/null
@@ -1,6 +0,0 @@
-; RUN: cd %desired_wd
-; RUN: make clean-FAIL_insertsort_sum-enzyme1 ENZYME_PLUGIN=%loadEnzyme
-; RUN: make build/FAIL_insertsort_sum-enzyme1 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
-; RUN: build/FAIL_insertsort_sum-enzyme1
-; RUN: make clean-FAIL_insertsort_sum-enzyme1 ENZYME_PLUGIN=%loadEnzyme
-; XFAIL: *
diff --git a/enzyme/functional_tests_c/testfiles/FAIL_insertsort_sum-enzyme2.test b/enzyme/functional_tests_c/testfiles/FAIL_insertsort_sum-enzyme2.test
deleted file mode 100644
index 7713b9d48281..000000000000
--- a/enzyme/functional_tests_c/testfiles/FAIL_insertsort_sum-enzyme2.test
+++ /dev/null
@@ -1,6 +0,0 @@
-; RUN: cd %desired_wd
-; RUN: make clean-FAIL_insertsort_sum-enzyme2 ENZYME_PLUGIN=%loadEnzyme
-; RUN: make build/FAIL_insertsort_sum-enzyme2 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
-; RUN: build/FAIL_insertsort_sum-enzyme2
-; RUN: make clean-FAIL_insertsort_sum-enzyme2 ENZYME_PLUGIN=%loadEnzyme
-; XFAIL: *
diff --git a/enzyme/functional_tests_c/testfiles/FAIL_insertsort_sum-enzyme3.test b/enzyme/functional_tests_c/testfiles/FAIL_insertsort_sum-enzyme3.test
deleted file mode 100644
index d8057b28fa26..000000000000
--- a/enzyme/functional_tests_c/testfiles/FAIL_insertsort_sum-enzyme3.test
+++ /dev/null
@@ -1,6 +0,0 @@
-; RUN: cd %desired_wd
-; RUN: make clean-FAIL_insertsort_sum-enzyme3 ENZYME_PLUGIN=%loadEnzyme
-; RUN: make build/FAIL_insertsort_sum-enzyme3 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
-; RUN: build/FAIL_insertsort_sum-enzyme3
-; RUN: make clean-FAIL_insertsort_sum-enzyme3 ENZYME_PLUGIN=%loadEnzyme
-; XFAIL: *
diff --git a/enzyme/functional_tests_c/testfiles/insertsort_sum-enzyme0.test b/enzyme/functional_tests_c/testfiles/insertsort_sum-enzyme0.test
new file mode 100644
index 000000000000..3a7577863354
--- /dev/null
+++ b/enzyme/functional_tests_c/testfiles/insertsort_sum-enzyme0.test
@@ -0,0 +1,6 @@
+; RUN: cd %desired_wd
+; RUN: make clean-insertsort_sum-enzyme0 ENZYME_PLUGIN=%loadEnzyme
+; RUN: make build/insertsort_sum-enzyme0 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
+; RUN: build/insertsort_sum-enzyme0
+; RUN: make clean-insertsort_sum-enzyme0 ENZYME_PLUGIN=%loadEnzyme
+
diff --git a/enzyme/functional_tests_c/testfiles/insertsort_sum-enzyme1.test b/enzyme/functional_tests_c/testfiles/insertsort_sum-enzyme1.test
new file mode 100644
index 000000000000..648763c9d6cc
--- /dev/null
+++ b/enzyme/functional_tests_c/testfiles/insertsort_sum-enzyme1.test
@@ -0,0 +1,6 @@
+; RUN: cd %desired_wd
+; RUN: make clean-insertsort_sum-enzyme1 ENZYME_PLUGIN=%loadEnzyme
+; RUN: make build/insertsort_sum-enzyme1 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
+; RUN: build/insertsort_sum-enzyme1
+; RUN: make clean-insertsort_sum-enzyme1 ENZYME_PLUGIN=%loadEnzyme
+
diff --git a/enzyme/functional_tests_c/testfiles/insertsort_sum-enzyme2.test b/enzyme/functional_tests_c/testfiles/insertsort_sum-enzyme2.test
new file mode 100644
index 000000000000..880946e59032
--- /dev/null
+++ b/enzyme/functional_tests_c/testfiles/insertsort_sum-enzyme2.test
@@ -0,0 +1,6 @@
+; RUN: cd %desired_wd
+; RUN: make clean-insertsort_sum-enzyme2 ENZYME_PLUGIN=%loadEnzyme
+; RUN: make build/insertsort_sum-enzyme2 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
+; RUN: build/insertsort_sum-enzyme2
+; RUN: make clean-insertsort_sum-enzyme2 ENZYME_PLUGIN=%loadEnzyme
+
diff --git a/enzyme/functional_tests_c/testfiles/insertsort_sum-enzyme3.test b/enzyme/functional_tests_c/testfiles/insertsort_sum-enzyme3.test
new file mode 100644
index 000000000000..96f4c23d025d
--- /dev/null
+++ b/enzyme/functional_tests_c/testfiles/insertsort_sum-enzyme3.test
@@ -0,0 +1,6 @@
+; RUN: cd %desired_wd
+; RUN: make clean-insertsort_sum-enzyme3 ENZYME_PLUGIN=%loadEnzyme
+; RUN: make build/insertsort_sum-enzyme3 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
+; RUN: build/insertsort_sum-enzyme3
+; RUN: make clean-insertsort_sum-enzyme3 ENZYME_PLUGIN=%loadEnzyme
+
diff --git a/enzyme/functional_tests_c/testfiles/readwriteread-enzyme0.test b/enzyme/functional_tests_c/testfiles/readwriteread-enzyme0.test
new file mode 100644
index 000000000000..14a037d8426b
--- /dev/null
+++ b/enzyme/functional_tests_c/testfiles/readwriteread-enzyme0.test
@@ -0,0 +1,6 @@
+; RUN: cd %desired_wd
+; RUN: make clean-readwriteread-enzyme0 ENZYME_PLUGIN=%loadEnzyme
+; RUN: make build/readwriteread-enzyme0 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
+; RUN: build/readwriteread-enzyme0
+; RUN: make clean-readwriteread-enzyme0 ENZYME_PLUGIN=%loadEnzyme
+
diff --git a/enzyme/functional_tests_c/testfiles/readwriteread-enzyme1.test b/enzyme/functional_tests_c/testfiles/readwriteread-enzyme1.test
new file mode 100644
index 000000000000..9dc3174b8435
--- /dev/null
+++ b/enzyme/functional_tests_c/testfiles/readwriteread-enzyme1.test
@@ -0,0 +1,6 @@
+; RUN: cd %desired_wd
+; RUN: make clean-readwriteread-enzyme1 ENZYME_PLUGIN=%loadEnzyme
+; RUN: make build/readwriteread-enzyme1 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
+; RUN: build/readwriteread-enzyme1
+; RUN: make clean-readwriteread-enzyme1 ENZYME_PLUGIN=%loadEnzyme
+
diff --git a/enzyme/functional_tests_c/testfiles/readwriteread-enzyme2.test b/enzyme/functional_tests_c/testfiles/readwriteread-enzyme2.test
new file mode 100644
index 000000000000..e03f5242726c
--- /dev/null
+++ b/enzyme/functional_tests_c/testfiles/readwriteread-enzyme2.test
@@ -0,0 +1,6 @@
+; RUN: cd %desired_wd
+; RUN: make clean-readwriteread-enzyme2 ENZYME_PLUGIN=%loadEnzyme
+; RUN: make build/readwriteread-enzyme2 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
+; RUN: build/readwriteread-enzyme2
+; RUN: make clean-readwriteread-enzyme2 ENZYME_PLUGIN=%loadEnzyme
+
diff --git a/enzyme/functional_tests_c/testfiles/readwriteread-enzyme3.test b/enzyme/functional_tests_c/testfiles/readwriteread-enzyme3.test
new file mode 100644
index 000000000000..40efc5f2c7e7
--- /dev/null
+++ b/enzyme/functional_tests_c/testfiles/readwriteread-enzyme3.test
@@ -0,0 +1,6 @@
+; RUN: cd %desired_wd
+; RUN: make clean-readwriteread-enzyme3 ENZYME_PLUGIN=%loadEnzyme
+; RUN: make build/readwriteread-enzyme3 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
+; RUN: build/readwriteread-enzyme3
+; RUN: make clean-readwriteread-enzyme3 ENZYME_PLUGIN=%loadEnzyme
+
diff --git a/enzyme/test/Enzyme/badcall.ll b/enzyme/test/Enzyme/badcall.ll
index 9672654917b2..15518f2ebd1d 100644
--- a/enzyme/test/Enzyme/badcall.ll
+++ b/enzyme/test/Enzyme/badcall.ll
@@ -42,11 +42,12 @@ attributes #1 = { noinline nounwind uwtable }
 
 ; CHECK: define internal {{(dso_local )?}}{} @diffef(double* nocapture %x, double* %"x'")
 ; CHECK-NEXT: entry:
-; CHECK-NEXT: %0 = call { { {} } } @augmented_subf(double* %x, double* %"x'")
-; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8
-; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
-; CHECK-NEXT: %1 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {} } undef)
-; CHECK-NEXT: ret {} undef
+; CHECK-NEXT: %0 = call { { {}, double } } @augmented_subf(double* %x, double* %"x'")
+; CHECK-NEXT: %1 = extractvalue { { {}, double } } %0, 0
+; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8
+; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
+; CHECK-NEXT: %2 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, double } %1)
+; CHECK-NEXT: ret {} undef
 ; CHECK-NEXT: }
 
 ; CHECK: define internal {{(dso_local )?}}{ {} } @augmented_metasubf(double* nocapture %x, double* %"x'")
@@ -56,16 +57,21 @@ attributes #1 = { noinline nounwind uwtable }
 ; CHECK-NEXT: ret { {} } undef
 ; CHECK-NEXT: }
 
-; CHECK: define internal {{(dso_local )?}}{ { {} } } @augmented_subf(double* nocapture %x, double* %"x'")
+; CHECK: define internal {{(dso_local )?}}{ { {}, double } } @augmented_subf(double* nocapture %x, double* %"x'")
 ; CHECK-NEXT: entry:
-; CHECK-NEXT: %0 = load double, double* %x, align 8
-; CHECK-NEXT: %mul = fmul fast double %0, 2.000000e+00
-; CHECK-NEXT: store double %mul, double* %x, align 8
-; CHECK-NEXT: %1 = call { {} } @augmented_metasubf(double* %x, double* %"x'")
-; CHECK-NEXT: ret { { {} } } undef
+; CHECK-NEXT: %0 = alloca { { {}, double } }
+; CHECK-NEXT: %1 = getelementptr { { {}, double } }, { { {}, double } }* %0, i32 0, i32 0
+; CHECK-NEXT: %2 = load double, double* %x, align 8
+; CHECK-NEXT: %3 = getelementptr { {}, double }, { {}, double }* %1, i32 0, i32 1
+; CHECK-NEXT: store double %2, double* %3
+; CHECK-NEXT: %mul = fmul fast double %2, 2.000000e+00
+; CHECK-NEXT: store double %mul, double* %x, align 8
+; CHECK-NEXT: %4 = call { {} } @augmented_metasubf(double* %x, double* %"x'")
+; CHECK-NEXT: %5 = load { { {}, double } }, { { {}, double } }* %0
+; CHECK-NEXT: ret { { {}, double } } %5
 ; CHECK-NEXT: }
 
-; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {} } %tapeArg)
+; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, double } %tapeArg)
 ; CHECK-NEXT: entry:
 ; CHECK-NEXT: %0 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef)
 ; CHECK-NEXT: %1 = load double, double* %"x'"
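The pattern in this and the following CHECK updates: because subf overwrites *x after loading it, the loaded double is now spilled into the augmented function's tape struct ({ {}, double } instead of { {} }), returned to the caller, and threaded into diffesubf in place of the previous undef tape. A loose C analogy of that contract (hypothetical, not Enzyme's generated code):

    /* The augmented forward pass saves the primal value it is about to
       destroy; the reverse pass receives it through the tape argument. */
    typedef struct { double loaded; } tape_t;

    static tape_t augmented_subf(double* x) {
        tape_t t;
        t.loaded = *x;            /* value the following store overwrites */
        *x = t.loaded * 2.0;      /* %mul = fmul fast double %0, 2.0; store */
        return t;
    }

    static double diffe_subf(double dret, tape_t t) {
        (void)t.loaded;           /* available if the derivative rule needs it */
        return 2.0 * dret;        /* d(2x)/dx = 2 */
    }
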
diff --git a/enzyme/test/Enzyme/badcall2.ll b/enzyme/test/Enzyme/badcall2.ll
index 10a46708f25f..0f47f7f1435e 100644
--- a/enzyme/test/Enzyme/badcall2.ll
+++ b/enzyme/test/Enzyme/badcall2.ll
@@ -50,10 +50,11 @@ declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed
 
 ; CHECK: define internal {{(dso_local )?}}{} @diffef(double* nocapture %x, double* %"x'")
 ; CHECK-NEXT: entry:
-; CHECK-NEXT: %0 = call { { {}, {} } } @augmented_subf(double* %x, double* %"x'")
+; CHECK-NEXT: %0 = call { { {}, {}, double } } @augmented_subf(double* %x, double* %"x'")
+; CHECK-NEXT: %1 = extractvalue { { {}, {}, double } } %0, 0
 ; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8
 ; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8
-; CHECK-NEXT: %1 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, {} } undef)
+; CHECK-NEXT: %2 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, {}, double } %1)
 ; CHECK-NEXT: ret {} undef
 ; CHECK-NEXT: }
 
@@ -71,17 +72,23 @@ declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed
 ; CHECK-NEXT: ret { {} } undef
 ; CHECK-NEXT: }
 
-; CHECK: define internal {{(dso_local )?}}{ { {}, {} } } @augmented_subf(double* nocapture %x, double* %"x'")
+; CHECK: define internal {{(dso_local )?}}{ { {}, {}, double } } @augmented_subf(double* nocapture %x, double* %"x'")
 ; CHECK-NEXT: entry:
-; CHECK-NEXT: %0 = load double, double* %x, align 8
-; CHECK-NEXT: %mul = fmul fast double %0, 2.000000e+00
+; CHECK-NEXT: %0 = alloca { { {}, {}, double } }
+; CHECK-NEXT: %1 = getelementptr { { {}, {}, double } }, { { {}, {}, double } }* %0, i32 0, i32 0
+; CHECK-NEXT: %2 = load double, double* %x, align 8
+; CHECK-NEXT: %3 = getelementptr { {}, {}, double }, { {}, {}, double }* %1, i32 0, i32 2
+; CHECK-NEXT: store double %2, double* %3
+; CHECK-NEXT: %mul = fmul fast double %2, 2.000000e+00
 ; CHECK-NEXT: store double %mul, double* %x, align 8
-; CHECK-NEXT: %1 = call { {} } @augmented_metasubf(double* %x, double* %"x'")
-; CHECK-NEXT: %2 = call { {} } @augmented_othermetasubf(double* %x, double* %"x'")
-; CHECK-NEXT: ret { { {}, {} } } undef
+; CHECK-NEXT: %4 = call { {} } @augmented_metasubf(double* %x, double* %"x'")
+; CHECK-NEXT: %5 = call { {} } @augmented_othermetasubf(double* %x, double* %"x'")
+; CHECK-NEXT: %6 = load { { {}, {}, double } }, { { {}, {}, double } }* %0
+; CHECK-NEXT: ret { { {}, {}, double } } %6
 ; CHECK-NEXT: }
 
-; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, {} } %tapeArg)
+
+; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, {}, double } %tapeArg)
 ; CHECK-NEXT: entry:
 ; CHECK-NEXT: %0 = call {} @diffeothermetasubf(double* %x, double* %"x'", {} undef)
 ; CHECK-NEXT: %1 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef)
diff --git a/enzyme/test/Enzyme/badcall3.ll b/enzyme/test/Enzyme/badcall3.ll
index 86fb9083359b..0d0b936da2fe 100644
--- a/enzyme/test/Enzyme/badcall3.ll
+++ b/enzyme/test/Enzyme/badcall3.ll
@@ -50,10 +50,11 @@ declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed
 
 ; CHECK: define internal {{(dso_local )?}}{} @diffef(double* nocapture %x, double* %"x'")
 ; CHECK-NEXT: entry:
-; CHECK-NEXT: %0 = call { { {}, {} } } @augmented_subf(double* %x, double* %"x'")
+; CHECK-NEXT: %0 = call { { {}, {}, double } } @augmented_subf(double* %x, double* %"x'")
+; CHECK-NEXT: %1 = extractvalue { { {}, {}, double } } %0, 0
 ; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8
 ; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8
-; CHECK-NEXT: %1 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, {} } undef)
+; CHECK-NEXT: %2 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, {}, double } %1)
 ; CHECK-NEXT: ret {} undef
 ; CHECK-NEXT: }
 
@@ -71,17 +72,23 @@ declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed
 ; CHECK-NEXT: ret { {} } undef
 ; CHECK-NEXT: }
 
-; CHECK: define internal {{(dso_local )?}}{ { {}, {} } } @augmented_subf(double* nocapture %x, double* %"x'")
+; CHECK: define internal {{(dso_local )?}}{ { {}, {}, double } } @augmented_subf(double* nocapture %x, double* %"x'")
 ; CHECK-NEXT: entry:
-; CHECK-NEXT: %0 = load double, double* %x, align 8
-; CHECK-NEXT: %mul = fmul fast double %0, 2.000000e+00
-; CHECK-NEXT: store double %mul, double* %x, align 8
-; CHECK-NEXT: %1 = call { {} } @augmented_metasubf(double* %x, double* %"x'")
-; CHECK-NEXT: %2 = call { {} } @augmented_othermetasubf(double* %x, double* %"x'")
-; CHECK-NEXT: ret { { {}, {} } } undef
+; CHECK-NEXT: %0 = alloca { { {}, {}, double } }
+; CHECK-NEXT: %1 = getelementptr { { {}, {}, double } }, { { {}, {}, double } }* %0, i32 0, i32 0
+; CHECK-NEXT: %2 = load double, double* %x, align 8
+; CHECK-NEXT: %3 = getelementptr { {}, {}, double }, { {}, {}, double }* %1, i32 0, i32 2
+; CHECK-NEXT: store double %2, double* %3
+; CHECK-NEXT: %mul = fmul fast double %2, 2.000000e+00
+; CHECK-NEXT: store double %mul, double* %x, align 8
+; CHECK-NEXT: %4 = call { {} } @augmented_metasubf(double* %x, double* %"x'")
+; CHECK-NEXT: %5 = call { {} } @augmented_othermetasubf(double* %x, double* %"x'")
+; CHECK-NEXT: %6 = load { { {}, {}, double } }, { { {}, {}, double } }* %0
+; CHECK-NEXT: ret { { {}, {}, double } } %6
 ; CHECK-NEXT: }
 
-; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, {} } %tapeArg)
+
+; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, {}, double } %tapeArg)
 ; CHECK-NEXT: entry:
 ; CHECK-NEXT: %0 = call {} @diffeothermetasubf(double* %x, double* %"x'", {} undef)
 ; CHECK-NEXT: %1 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef)
diff --git a/enzyme/test/Enzyme/badcall4.ll b/enzyme/test/Enzyme/badcall4.ll
index b7183c501717..b099fac3c2e9 100644
--- a/enzyme/test/Enzyme/badcall4.ll
+++ b/enzyme/test/Enzyme/badcall4.ll
@@ -51,11 +51,11 @@ declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed
 
 ; CHECK: define internal {{(dso_local )?}}{} @diffef(double* nocapture %x, double* %"x'")
 ; CHECK-NEXT: entry:
-; CHECK-NEXT: %0 = call { { {}, i1, {}, i1 } } @augmented_subf(double* %x, double* %"x'")
-; CHECK-NEXT: %1 = extractvalue { { {}, i1, {}, i1 } } %0, 0
+; CHECK-NEXT: %0 = call { { {}, i1, {}, i1, double } } @augmented_subf(double* %x, double* %"x'")
+; CHECK-NEXT: %1 = extractvalue { { {}, i1, {}, i1, double } } %0, 0
 ; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8
 ; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8
-; CHECK-NEXT: %2 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, i1, {}, i1 } %1)
+; CHECK-NEXT: %2 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, i1, {}, i1, double } %1)
 ; CHECK-NEXT: ret {} undef
 ; CHECK-NEXT: }
 
@@ -63,7 +63,7 @@ declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed
 
 ; CHECK: define internal {{(dso_local )?}}{ {}, i1 } @augmented_metasubf(double* nocapture %x, double* %"x'")
 
-; CHECK: define internal {{(dso_local )?}}{ { {}, i1, {}, i1 } } @augmented_subf(double* nocapture %x, double* %"x'")
+; CHECK: define internal {{(dso_local )?}}{ { {}, i1, {}, i1, double } } @augmented_subf(double* nocapture %x, double* %"x'")
 ; CHECK-NEXT: entry:
 ; CHECK-NEXT: %0 = load double, double* %x, align 8
 ; CHECK-NEXT: %mul = fmul fast double %0, 2.000000e+00
@@ -72,12 +72,13 @@ declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed
 ; CHECK-NEXT: %2 = extractvalue { {}, i1 } %1, 1
 ; CHECK-NEXT: %3 = call { {}, i1 } @augmented_othermetasubf(double* %x, double* %"x'")
 ; CHECK-NEXT: %4 = extractvalue { {}, i1 } %3, 1
-; CHECK-NEXT: %[[iv1:.+]] = insertvalue { { {}, i1, {}, i1 } } undef, i1 %4, 0, 1
-; CHECK-NEXT: %[[iv2:.+]] = insertvalue { { {}, i1, {}, i1 } } %[[iv1]], i1 %2, 0, 3
-; CHECK-NEXT: ret { { {}, i1, {}, i1 } } %[[iv2]]
+; CHECK-NEXT: %.fca.0.1.insert = insertvalue { { {}, i1, {}, i1, double } } undef, i1 %4, 0, 1
+; CHECK-NEXT: %.fca.0.3.insert = insertvalue { { {}, i1, {}, i1, double } } %.fca.0.1.insert, i1 %2, 0, 3
+; CHECK-NEXT: %.fca.0.4.insert = insertvalue { { {}, i1, {}, i1, double } } %.fca.0.3.insert, double %0, 0, 4
+; CHECK-NEXT: ret { { {}, i1, {}, i1, double } } %.fca.0.4.insert
 ; CHECK-NEXT: }
 
-; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, i1, {}, i1 } %tapeArg)
+; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, i1, {}, i1, double } %tapeArg)
 ; CHECK-NEXT: entry:
 ; CHECK-NEXT: %0 = call {} @diffeothermetasubf(double* %x, double* %"x'", {} undef)
 ; CHECK-NEXT: %1 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef)
diff --git a/enzyme/test/Enzyme/badcallused.ll b/enzyme/test/Enzyme/badcallused.ll
index 51f8b2b915ed..e39062a1751e 100644
--- a/enzyme/test/Enzyme/badcallused.ll
+++ b/enzyme/test/Enzyme/badcallused.ll
@@ -43,12 +43,13 @@ attributes #1 = { noinline nounwind uwtable }
 
 ; CHECK: define internal {{(dso_local )?}}{} @diffef(double* nocapture %x, double* %"x'")
 ; CHECK-NEXT: entry:
-; CHECK-NEXT: %0 = call { { {} }, i1, i1 } @augmented_subf(double* %x, double* %"x'")
-; CHECK-NEXT: %1 = extractvalue { { {} }, i1, i1 } %0, 1
-; CHECK-NEXT: %sel = select i1 %1, double 2.000000e+00, double 3.000000e+00
+; CHECK-NEXT: %0 = call { { {}, double }, i1, i1 } @augmented_subf(double* %x, double* %"x'")
+; CHECK-NEXT: %1 = extractvalue { { {}, double }, i1, i1 } %0, 0
+; CHECK-NEXT: %2 = extractvalue { { {}, double }, i1, i1 } %0, 1
+; CHECK-NEXT: %sel = select i1 %2, double 2.000000e+00, double 3.000000e+00
 ; CHECK-NEXT: store double %sel, double* %x, align 8
-; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
-; CHECK-NEXT: %[[dsubf:.+]] = call {} @diffesubf(double* nonnull %x, double* %"x'", { {} } undef)
+; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8
+; CHECK-NEXT: %3 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, double } %1)
 ; CHECK-NEXT: ret {} undef
 ; CHECK-NEXT: }
 
@@ -65,24 +66,29 @@ attributes #1 = { noinline nounwind uwtable }
 ; CHECK-NEXT: ret { {}, i1, i1 } %3
 ; CHECK-NEXT: }
 
-; CHECK: define internal {{(dso_local )?}}{ { {} }, i1, i1 } @augmented_subf(double* nocapture %x, double* %"x'")
+; CHECK: define internal {{(dso_local )?}}{ { {}, double }, i1, i1 } @augmented_subf(double* nocapture %x, double* %"x'")
 ; CHECK-NEXT: entry:
-; CHECK-NEXT: %0 = alloca { { {} }, i1, i1 }
-; CHECK-NEXT: %1 = load double, double* %x, align 8
-; CHECK-NEXT: %mul = fmul fast double %1, 2.000000e+00
+; CHECK-NEXT: %0 = alloca { { {}, double }, i1, i1 }
+; CHECK-NEXT: %1 = getelementptr { { {}, double }, i1, i1 }
+; CHECK-NEXT: %2 = load double, double* %x, align 8
+; CHECK-NEXT: %3 = getelementptr { {}, double }, { {}, double }* %1, i32 0, i32 1
+; CHECK-NEXT: store double %2, double* %3
+; CHECK-NEXT: %mul = fmul fast double %2, 2.000000e+00
 ; CHECK-NEXT: store double %mul, double* %x, align 8
-; CHECK-NEXT: %2 = call { {}, i1, i1 } @augmented_metasubf(double* %x, double* %"x'")
-; CHECK-NEXT: %3 = extractvalue { {}, i1, i1 } %2, 1
-; CHECK-NEXT: %antiptr_call = extractvalue { {}, i1, i1 } %2, 2
-; CHECK-NEXT: %4 = getelementptr { { {} }, i1, i1 }, { { {} }, i1, i1 }* %0, i32 0, i32 1
-; CHECK-NEXT: store i1 %3, i1* %4
-; CHECK-NEXT: %5 = getelementptr { { {} }, i1, i1 }, { { {} }, i1, i1 }* %0, i32 0, i32 2
-; CHECK-NEXT: store i1 %antiptr_call, i1* %5
-; CHECK-NEXT: %[[toret:.+]] = load { { {} }, i1, i1 }, { { {} }, i1, i1 }* %0
-; CHECK-NEXT: ret { { {} }, i1, i1 } %[[toret]]
+; CHECK-NEXT: %4 = call { {}, i1, i1 } @augmented_metasubf(double* %x, double* %"x'")
+; CHECK-NEXT: %5 = extractvalue { {}, i1, i1 } %4, 1
+; CHECK-NEXT: %antiptr_call = extractvalue { {}, i1, i1 } %4, 2
+
+
+; CHECK-NEXT: %6 = getelementptr { { {}, double }, i1, i1 }, { { {}, double }, i1, i1 }* %0, i32 0, i32 1
+; CHECK-NEXT: store i1 %5, i1* %6
+; CHECK-NEXT: %7 = getelementptr { { {}, double }, i1, i1 }, { { {}, double }, i1, i1 }* %0, i32 0, i32 2
+; CHECK-NEXT: store i1 %antiptr_call, i1* %7
+; CHECK-NEXT: %[[toret:.+]] = load { { {}, double }, i1, i1 }, { { {}, double }, i1, i1 }* %0
+; CHECK-NEXT: ret { { {}, double }, i1, i1 } %[[toret]]
 ; CHECK-NEXT: }
 
-; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {} } %tapeArg)
+; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, double } %tapeArg)
 ; CHECK-NEXT: entry:
 ; CHECK-NEXT: %0 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef)
 ; CHECK-NEXT: %1 = load double, double* %"x'"
diff --git a/enzyme/test/Enzyme/badcallused2.ll b/enzyme/test/Enzyme/badcallused2.ll
index 92069b003948..0513dde7ad9f 100644
--- a/enzyme/test/Enzyme/badcallused2.ll
+++ b/enzyme/test/Enzyme/badcallused2.ll
@@ -53,12 +53,13 @@ attributes #1 = { noinline nounwind uwtable }
 
 ; CHECK: define internal {{(dso_local )?}}{} @diffef(double* nocapture %x, double* %"x'")
 ; CHECK-NEXT: entry:
-; CHECK-NEXT: %0 = call { { {}, {} }, i1, i1 } @augmented_subf(double* %x, double* %"x'")
-; CHECK-NEXT: %1 = extractvalue { { {}, {} }, i1, i1 } %0, 1
-; CHECK-NEXT: %sel = select i1 %1, double 2.000000e+00, double 3.000000e+00
+; CHECK-NEXT: %0 = call { { {}, {}, double }, i1, i1 } @augmented_subf(double* %x, double* %"x'")
+; CHECK-NEXT: %1 = extractvalue { { {}, {}, double }, i1, i1 } %0, 0
+; CHECK-NEXT: %2 = extractvalue { { {}, {}, double }, i1, i1 } %0, 1
+; CHECK-NEXT: %sel = select i1 %2, double 2.000000e+00, double 3.000000e+00
 ; CHECK-NEXT: store double %sel, double* %x, align 8
 ; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
-; CHECK-NEXT: %[[dsubf:.+]] = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, {} } undef)
+; CHECK-NEXT: %[[dsubf:.+]] = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, {}, double } %1)
 ; CHECK-NEXT: ret {} undef
 ; CHECK-NEXT: }
 
@@ -82,25 +83,28 @@ attributes #1 = { noinline nounwind uwtable }
 ; CHECK-NEXT: ret { {} } undef
 ; CHECK-NEXT: }
 
-; CHECK: define internal {{(dso_local )?}}{ { {}, {} }, i1, i1 } @augmented_subf(double* nocapture %x, double* %"x'")
+; CHECK: define internal {{(dso_local )?}}{ { {}, {}, double }, i1, i1 } @augmented_subf(double* nocapture %x, double* %"x'")
 ; CHECK-NEXT: entry:
-; CHECK-NEXT: %0 = alloca { { {}, {} }, i1, i1 }
-; CHECK-NEXT: %1 = load double, double* %x, align 8
-; CHECK-NEXT: %mul = fmul fast double %1, 2.000000e+00
+; CHECK-NEXT: %0 = alloca { { {}, {}, double }, i1, i1 }
+; CHECK-NEXT: %1 = getelementptr { { {}, {}, double }, i1, i1 }, { { {}, {}, double }, i1, i1 }* %0, i32 0, i32 0
+; CHECK-NEXT: %2 = load double, double* %x, align 8
+; CHECK-NEXT: %3 = getelementptr { {}, {}, double }, { {}, {}, double }* %1, i32 0, i32 2
+; CHECK-NEXT: store double %2, double* %3
+; CHECK-NEXT: %mul = fmul fast double %2, 2.000000e+00
 ; CHECK-NEXT: store double %mul, double* %x, align 8
-; CHECK-NEXT: %2 = call { {} } @augmented_omegasubf(double* %x, double* %"x'")
-; CHECK-NEXT: %3 = call { {}, i1, i1 } @augmented_metasubf(double* %x, double* %"x'")
-; CHECK-NEXT: %4 = extractvalue { {}, i1, i1 } %3, 1
-; CHECK-NEXT: %antiptr_call2 = extractvalue { {}, i1, i1 } %3, 2
-; CHECK-NEXT: %5 = getelementptr { { {}, {} }, i1, i1 }, { { {}, {} }, i1, i1 }* %0, i32 0, i32 1
-; CHECK-NEXT: store i1 %4, i1* %5
-; CHECK-NEXT: %6 = getelementptr { { {}, {} }, i1, i1 }, { { {}, {} }, i1, i1 }* %0, i32 0, i32 2
-; CHECK-NEXT: store i1 %antiptr_call2, i1* %6
-; CHECK-NEXT: %[[toret:.+]] = load { { {}, {} }, i1, i1 }, { { {}, {} }, i1, i1 }* %0
-; CHECK-NEXT: ret { { {}, {} }, i1, i1 } %[[toret]]
+; CHECK-NEXT: %4 = call { {} } @augmented_omegasubf(double* %x, double* %"x'")
+; CHECK-NEXT: %5 = call { {}, i1, i1 } @augmented_metasubf(double* %x, double* %"x'")
+; CHECK-NEXT: %6 = extractvalue { {}, i1, i1 } %5, 1
+; CHECK-NEXT: %antiptr_call2 = extractvalue { {}, i1, i1 } %5, 2
+; CHECK-NEXT: %7 = getelementptr { { {}, {}, double }, i1, i1 }, { { {}, {}, double }, i1, i1 }* %0, i32 0, i32 1
+; CHECK-NEXT: store i1 %6, i1* %7
+; CHECK-NEXT: %8 = getelementptr { { {}, {}, double }, i1, i1 }, { { {}, {}, double }, i1, i1 }* %0, i32 0, i32 2
+; CHECK-NEXT: store i1 %antiptr_call2, i1* %8
+; CHECK-NEXT: %[[toret:.+]] = load { { {}, {}, double }, i1, i1 }, { { {}, {}, double }, i1, i1 }* %0
+; CHECK-NEXT: ret { { {}, {}, double }, i1, i1 } %[[toret]]
 ; CHECK-NEXT: }
 
-; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, {} } %tapeArg)
+; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, {}, double } %tapeArg)
 ; CHECK-NEXT: entry:
 ; CHECK-NEXT: %0 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef)
 ; CHECK-NEXT: %1 = call {} @diffeomegasubf(double* %x, double* %"x'", {} undef)
diff --git a/enzyme/test/Enzyme/cppllist.ll b/enzyme/test/Enzyme/cppllist.ll
index 0c8c645b68d5..a775af61a868 100644
--- a/enzyme/test/Enzyme/cppllist.ll
+++ b/enzyme/test/Enzyme/cppllist.ll
@@ -1,4 +1,4 @@
-; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -adce -instcombine -instsimplify -early-cse-memssa -simplifycfg -correlated-propagation -adce -jump-threading -instsimplify -early-cse -simplifycfg -S | FileCheck %s
+; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -adce -instcombine -instsimplify -early-cse-memssa -simplifycfg -correlated-propagation -adce -loop-simplify -jump-threading -instsimplify -early-cse -simplifycfg -S | FileCheck %s
 
 ; #include <stdio.h>
 ; #include <stdlib.h>
@@ -233,17 +233,13 @@ attributes #8 = { builtin nounwind }
 ; CHECK: define internal {{(dso_local )?}}{} @diffe_Z8sum_listPK4node(%class.node* noalias readonly %node, %class.node* %"node'", double %[[differet:.+]])
 ; CHECK-NEXT: entry:
 ; CHECK-NEXT: %[[cmp:.+]] = icmp eq %class.node* %node, null
-; CHECK-NEXT: br i1 %[[cmp]], label %invertentry, label %for.body.preheader
-
-; CHECK: for.body.preheader:
-; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 8)
-; CHECK-NEXT: br label %for.body
+; CHECK-NEXT: br i1 %[[cmp]], label %invertentry, label %for.body
 
 ; CHECK: for.body:
-; CHECK-NEXT: %[[rawcache:.+]] = phi i8* [ %malloccall, %for.body.preheader ], [ %_realloccache, %for.body ]
-; CHECK-NEXT: %[[preidx:.+]] = phi i64 [ 0, %for.body.preheader ], [ %[[postidx:.+]], %for.body ]
-; CHECK-NEXT: %[[cur:.+]] = phi %class.node* [ %"node'", %for.body.preheader ], [ %"'ipl", %for.body ]
-; CHECK-NEXT: %val.08 = phi %class.node* [ %node, %for.body.preheader ], [ %[[nextload:.+]], %for.body ]
+; CHECK-NEXT: %[[rawcache:.+]] = phi i8* [ %_realloccache, %for.body ], [ null, %entry ]
+; CHECK-NEXT: %[[preidx:.+]] = phi i64 [ %[[postidx:.+]], %for.body ], [ 0, %entry ]
+; CHECK-NEXT: %[[cur:.+]] = phi %class.node* [ %"'ipl", %for.body ], [ %"node'", %entry ]
+; CHECK-NEXT: %val.08 = phi %class.node* [ %[[loadst:.+]], %for.body ], [ %node, %entry ]
 ; CHECK-NEXT: %[[idx8:.+]] = shl i64 %[[preidx]], 3
 ; CHECK-NEXT: %[[nextrealloc:.+]] = add i64 %[[idx8]], 8
 ; CHECK-NEXT: %_realloccache = call i8* @realloc(i8* %[[rawcache]], i64 %[[nextrealloc]])
@@ -254,7 +250,7 @@ attributes #8 = { builtin nounwind }
 ; CHECK-NEXT: %next = getelementptr inbounds %class.node, %class.node* %val.08, i64 0, i32 1
 ; CHECK-NEXT: %"next'ipg" = getelementptr %class.node, %class.node* %[[cur]], i64 0, i32 1
 ; CHECK-NEXT: %"'ipl" = load %class.node*, %class.node** %"next'ipg", align 8
-; CHECK-NEXT: %[[nextload]] = load %class.node*, %class.node** %next, align 8, !tbaa !8
+; CHECK-NEXT: %[[nextload:.+]] = load %class.node*, %class.node** %next, align 8, !tbaa !8
 ; CHECK-NEXT: %[[lcmp:.+]] = icmp eq %class.node* %[[nextload]], null
 ; CHECK-NEXT: br i1 %[[lcmp]], label %[[antiloop:.+]], label %for.body
diff --git a/enzyme/test/Enzyme/llist.ll b/enzyme/test/Enzyme/llist.ll
index 82ff07892eff..4300c046dd94 100644
--- a/enzyme/test/Enzyme/llist.ll
+++ b/enzyme/test/Enzyme/llist.ll
@@ -1,4 +1,4 @@
-; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -adce -instcombine -instsimplify -early-cse-memssa -simplifycfg -correlated-propagation -adce -S -jump-threading -instsimplify -early-cse -simplifycfg | FileCheck %s
+; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -adce -instcombine -instsimplify -early-cse-memssa -simplifycfg -correlated-propagation -adce -S -loop-simplify -jump-threading -instsimplify -early-cse -simplifycfg | FileCheck %s
 
 %struct.n = type { double, %struct.n* }
 
@@ -150,15 +150,11 @@ attributes #4 = { nounwind }
 ; CHECK-NEXT: %cmp6 = icmp eq %struct.n* %node, null
 ; CHECK-NEXT: br i1 %cmp6, label %invertentry, label %for.body
 
-; CHECK: for.body.preheader:
-; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 8)
-; CHECK-NEXT: br label %for.body
-
 ; CHECK: for.body:
-; CHECK-NEXT: %[[rawcache:.+]] = phi i8* [ %malloccall, %for.body.preheader ], [ %_realloccache, %for.body ]
-; CHECK-NEXT: %[[preidx:.+]] = phi i64 [ 0, %for.body.preheader ], [ %[[postidx:.+]], %for.body ]
-; CHECK-NEXT: %[[cur:.+]] = phi %struct.n* [ %"node'", %for.body.preheader ], [ %"'ipl", %for.body ]
-; CHECK-NEXT: %val.08 = phi %struct.n* [ %node, %for.body.preheader ], [ %[[loadst:.+]], %for.body ]
+; CHECK-NEXT: %[[rawcache:.+]] = phi i8* [ %_realloccache, %for.body ], [ null, %entry ]
+; CHECK-NEXT: %[[preidx:.+]] = phi i64 [ %[[postidx:.+]], %for.body ], [ 0, %entry ]
+; CHECK-NEXT: %[[cur:.+]] = phi %struct.n* [ %"'ipl", %for.body ], [ %"node'", %entry ]
+; CHECK-NEXT: %val.08 = phi %struct.n* [ %[[loadst:.+]], %for.body ], [ %node, %entry ]
 ; CHECK-NEXT: %[[idx8:.+]] = shl i64 %[[preidx]], 3
 ; CHECK-NEXT: %[[addalloc:.+]] = add i64 %[[idx8]], 8
 ; CHECK-NEXT: %_realloccache = call i8* @realloc(i8* %[[rawcache]], i64 %[[addalloc]])
diff --git a/enzyme/test/Enzyme/nllist.ll b/enzyme/test/Enzyme/nllist.ll
index bcec9ec7f1e7..97b519efa5cb 100644
--- a/enzyme/test/Enzyme/nllist.ll
+++ b/enzyme/test/Enzyme/nllist.ll
@@ -305,17 +305,13 @@ attributes #4 = { nounwind }
 ; CHECK: define internal {{(dso_local )?}}{} @diffesum_list(%struct.n* noalias readonly %node, %struct.n* %"node'", i64 %times, double %differeturn)
 ; CHECK-NEXT: entry:
 ; CHECK-NEXT: %[[firstcmp:.+]] = icmp eq %struct.n* %node, null
-; CHECK-NEXT: br i1 %[[firstcmp]], label %invertentry, label %for.cond1.preheader.preheader
-
-; CHECK: for.cond1.preheader.preheader:  ; preds = %entry
-; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 8)
-; CHECK-NEXT: br label %for.cond1.preheader
+; CHECK-NEXT: br i1 %[[firstcmp]], label %invertentry, label %for.cond1.preheader
 
 ; CHECK: for.cond1.preheader:
-; CHECK-NEXT: %[[phirealloc:.+]] = phi i8* [ %malloccall, %for.cond1.preheader.preheader ], [ %[[postrealloc:.+]], %for.cond.cleanup4 ]
-; CHECK-NEXT: %[[preidx:.+]] = phi i64 [ 0, %for.cond1.preheader.preheader ], [ %[[postidx:.+]], %for.cond.cleanup4 ]
-; CHECK-NEXT: %[[valstruct:.+]] = phi %struct.n* [ %"node'", %for.cond1.preheader.preheader ], [ %[[dstructload:.+]], %for.cond.cleanup4 ]
-; CHECK-NEXT: %val.020 = phi %struct.n* [ %node, %for.cond1.preheader.preheader ], [ %[[nextstruct:.+]], %for.cond.cleanup4 ]
+; CHECK-NEXT: %[[phirealloc:.+]] = phi i8* [ %[[postrealloc:.+]], %for.cond.cleanup4 ], [ null, %entry ]
+; CHECK-NEXT: %[[preidx:.+]] = phi i64 [ %[[postidx:.+]], %for.cond.cleanup4 ], [ 0, %entry ]
+; CHECK-NEXT: %[[valstruct:.+]] = phi %struct.n* [ %[[dstructload:.+]], %for.cond.cleanup4 ], [ %"node'", %entry ]
+; CHECK-NEXT: %val.020 = phi %struct.n* [ %[[nextstruct:.+]], %for.cond.cleanup4 ], [ %node, %entry ]
 ; CHECK-NEXT: %[[postidx]] = add nuw i64 %[[preidx]], 1
 ; CHECK-NEXT: %[[added:.+]] = shl nuw i64 %[[postidx]], 3
 ; CHECK-NEXT: %[[postrealloc]] = call i8* @realloc(i8* %[[phirealloc]], i64 %[[added]])
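The dropped for.body.preheader/for.cond1.preheader.preheader blocks in these three tests all reflect the same change: with -loop-simplify in the RUN pipeline, the reverse-pass cache buffer now enters the loop as null from %entry and is grown by realloc on every iteration; since realloc(NULL, n) behaves as malloc(n), the separate 8-byte preheader malloc is no longer emitted. A standalone sketch of that growth pattern (illustrative C, not generated code):

    #include <stdio.h>
    #include <stdlib.h>

    int main(void) {
        double* cache = NULL;                 /* phi [ null, %entry ] */
        for (size_t i = 0; i < 5; i++) {
            /* shl/add compute (i+1)*8; realloc(NULL, 8) acts as malloc(8) */
            double* grown = realloc(cache, (i + 1) * sizeof(double));
            if (!grown) { free(cache); return 1; }
            cache = grown;
            cache[i] = (double)i;
        }
        printf("last cached value: %f\n", cache[4]);
        free(cache);
        return 0;
    }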