diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 93f2e9eb434c2..c998cf2962d32 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -112,10 +112,13 @@ const std::map MPIInactiveCommAllocators = { }; const std::set KnownInactiveFunctions = { + "abort", "__assert_fail", "__cxa_guard_acquire", "__cxa_guard_release", "__cxa_guard_abort", + "snprintf", + "sprintf", "printf", "putchar", "fprintf", @@ -777,7 +780,6 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) { // of the global auto res = TR.query(GI).Data0(); auto dt = res[{-1}]; - dt |= res[{0}]; if (dt.isIntegral()) { if (EnzymePrintActivity) llvm::errs() << " VALUE const as global int pointer " << *Val diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 105d48221c997..057444bbc37ad 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -2294,6 +2294,11 @@ class AdjointGenerator } } } + EmitWarning("CannotDeduceType", MTI.getDebugLoc(), gutils->oldFunc, + MTI.getParent(), &MTI, "failed to deduce type of copy ", + MTI); + vd = TypeTree(BaseType::Pointer).Only(0); + goto known; } EmitFailure("CannotDeduceType", MTI.getDebugLoc(), &MTI, "failed to deduce type of copy ", MTI); diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.h b/enzyme/Enzyme/DifferentialUseAnalysis.h index 9cde3b66080a9..ea57f62b208b9 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.h +++ b/enzyme/Enzyme/DifferentialUseAnalysis.h @@ -259,6 +259,17 @@ static inline bool is_value_needed_in_reverse( return seen[idx] = true; } } +#if LLVM_VERSION_MAJOR >= 11 + const Value *F = CI->getCalledOperand(); +#else + const Value *F = CI->getCalledValue(); +#endif + if (F == inst) { + if (!gutils->isConstantInstruction(const_cast(user)) || + !gutils->isConstantValue(const_cast((Value *)user))) { + return seen[idx] = true; + } + } } if (isa(user)) { diff --git a/enzyme/Enzyme/Enzyme.cpp 
b/enzyme/Enzyme/Enzyme.cpp index 23aa2c7ad35da..04b70585ae9ee 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -570,7 +570,14 @@ class Enzyme : public ModulePass { } } } - if (!res->getType()->canLosslesslyBitCastTo(PTy)) { + if (res->getType()->canLosslesslyBitCastTo(PTy)) { + res = Builder.CreateBitCast(res, PTy); + } + if (res->getType() != PTy && res->getType()->isIntegerTy() && + PTy->isIntegerTy(1)) { + res = Builder.CreateTrunc(res, PTy); + } + if (res->getType() != PTy) { auto loc = CI->getDebugLoc(); if (auto arg = dyn_cast(res)) { loc = arg->getDebugLoc(); @@ -581,7 +588,6 @@ class Enzyme : public ModulePass { " - to arg ", truei, " ", *PTy); return false; } - res = Builder.CreateBitCast(res, PTy); } #if LLVM_VERSION_MAJOR >= 9 if (CI->isByValArgument(i)) { diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 8e40e684d8fce..91d506ca37e93 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -73,6 +73,11 @@ llvm::cl::opt EnzymePrint("enzyme-print", cl::init(false), cl::Hidden, cl::desc("Print before and after fns for autodiff")); +llvm::cl::opt + EnzymePrintUnnecessary("enzyme-print-unnecessary", cl::init(false), + cl::Hidden, + cl::desc("Print unnecessary values in function")); + cl::opt looseTypeAnalysis("enzyme-loose-types", cl::init(false), cl::Hidden, cl::desc("Allow looser use of types")); @@ -910,16 +915,27 @@ void calculateUnusedValuesInFunction( } return UseReq::Recur; }); -#if 0 - llvm::errs() << "unnecessaryValues of " << func.getName() << ": mode=" << to_string(mode) << "\n"; - for (auto a : unnecessaryValues) { - llvm::errs() << *a << "\n"; - } - llvm::errs() << "unnecessaryInstructions " << func.getName() << ":\n"; - for (auto a : unnecessaryInstructions) { - llvm::errs() << *a << "\n"; + + if (EnzymePrintUnnecessary) { + llvm::errs() << "unnecessaryValues of " << func.getName() + << ": mode=" << to_string(mode) << "\n"; + for (auto a : unnecessaryValues) { + bool 
ivn = is_value_needed_in_reverse( + TR, gutils, a, mode, PrimalSeen, oldUnreachable); + bool isn = is_value_needed_in_reverse( + TR, gutils, a, mode, PrimalSeen, oldUnreachable); + llvm::errs() << *a << " ivn=" << (int)ivn << " isn: " << (int)isn; + auto found = gutils->knownRecomputeHeuristic.find(a); + if (found != gutils->knownRecomputeHeuristic.end()) { + llvm::errs() << " krc=" << (int)found->second; + } + llvm::errs() << "\n"; + } + llvm::errs() << "unnecessaryInstructions " << func.getName() << ":\n"; + for (auto a : unnecessaryInstructions) { + llvm::errs() << *a << "\n"; + } } -#endif } void calculateUnusedStoresInFunction( diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 3f76c06580cd0..d1fb3091489f2 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -2639,8 +2639,47 @@ Constant *GradientUtils::GetOrCreateShadowFunction(EnzymeLogic &Logic, // indirect augmented calls), topLevel MUST be true otherwise subcalls will // not be able to lookup the augmenteddata/subdata (triggering an assertion // failure, among much worse) + bool isRealloc = false; + if (fn->empty()) { + if (hasMetadata(fn, "enzyme_callwrapper")) { + auto md = fn->getMetadata("enzyme_callwrapper"); + if (!isa(md)) { + llvm::errs() << *fn << "\n"; + llvm::errs() << *md << "\n"; + assert(0 && "callwrapper of incorrect type"); + report_fatal_error("callwrapper of incorrect type"); + } + auto md2 = cast(md); + assert(md2->getNumOperands() == 1); + auto gvemd = cast(md2->getOperand(0)); + fn = cast(gvemd->getValue()); + } else { + auto oldfn = fn; + fn = Function::Create(oldfn->getFunctionType(), Function::InternalLinkage, + "callwrap_" + oldfn->getName(), oldfn->getParent()); + BasicBlock *entry = BasicBlock::Create(fn->getContext(), "entry", fn); + IRBuilder<> B(entry); + SmallVector args; + for (auto &a : fn->args()) + args.push_back(&a); + auto res = B.CreateCall(oldfn, args); + if (fn->getReturnType()->isVoidTy()) + 
B.CreateRetVoid(); + else + B.CreateRet(res); + oldfn->setMetadata( + "enzyme_callwrapper", + MDTuple::get(oldfn->getContext(), {ConstantAsMetadata::get(fn)})); + if (oldfn->getName() == "realloc") + isRealloc = true; + } + } std::map uncacheable_args; FnTypeInfo type_args(fn); + if (isRealloc) { + llvm::errs() << "warning: assuming realloc only creates pointers\n"; + type_args.Return.insert({-1, -1}, BaseType::Pointer); + } // conservatively assume that we can only cache existing floating types // (i.e. that all args are uncacheable) diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index 76a9e926934d0..231bea75c4718 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -434,6 +434,8 @@ class GradientUtils : public CacheUtility { Value *getNewFromOriginal(const Value *originst) const { assert(originst); + if (isa(originst)) + return const_cast(originst); auto f = originalToNewFn.find(originst); if (f == originalToNewFn.end()) { llvm::errs() << *oldFunc << "\n"; @@ -691,6 +693,20 @@ class GradientUtils : public CacheUtility { placeholder->setName(""); IRBuilder<> bb(placeholder); + Function *Fn = orig->getCalledFunction(); + +#if LLVM_VERSION_MAJOR >= 11 + if (auto castinst = dyn_cast(orig->getCalledOperand())) +#else + if (auto castinst = dyn_cast(orig->getCalledValue())) +#endif + { + if (castinst->isCast()) + if (auto fn = dyn_cast(castinst->getOperand(0))) + Fn = fn; + } + assert(Fn); + SmallVector args; #if LLVM_VERSION_MAJOR >= 14 for (auto &arg : orig->args()) @@ -701,14 +717,12 @@ class GradientUtils : public CacheUtility { args.push_back(getNewFromOriginal(arg)); } - if (shadowHandlers.find(orig->getCalledFunction()->getName().str()) != - shadowHandlers.end()) { + if (shadowHandlers.find(Fn->getName().str()) != shadowHandlers.end()) { bb.SetInsertPoint(placeholder); Value *anti = placeholder; if (mode != DerivativeMode::ReverseModeGradient) { - anti = 
shadowHandlers[orig->getCalledFunction()->getName().str()]( - bb, orig, args); + anti = shadowHandlers[Fn->getName().str()](bb, orig, args); invertedPointers.erase(found); bb.SetInsertPoint(placeholder); @@ -726,8 +740,14 @@ class GradientUtils : public CacheUtility { return anti; } +#if LLVM_VERSION_MAJOR >= 11 Value *anti = - bb.CreateCall(orig->getCalledFunction(), args, orig->getName() + "'mi"); + bb.CreateCall(orig->getFunctionType(), orig->getCalledOperand(), args, + orig->getName() + "'mi"); +#else + Value *anti = + bb.CreateCall(orig->getCalledValue(), args, orig->getName() + "'mi"); +#endif cast(anti)->setAttributes(orig->getAttributes()); cast(anti)->setCallingConv(orig->getCallingConv()); cast(anti)->setTailCallKind(orig->getTailCallKind()); @@ -745,8 +765,7 @@ class GradientUtils : public CacheUtility { Attribute::NonNull); #endif unsigned derefBytes = 0; - if (orig->getCalledFunction()->getName() == "malloc" || - orig->getCalledFunction()->getName() == "_Znwm") { + if (Fn->getName() == "malloc" || Fn->getName() == "_Znwm") { if (auto ci = dyn_cast(args[0])) { derefBytes = ci->getLimitedValue(); CallInst *cal = cast(getNewFromOriginal(orig)); @@ -789,7 +808,7 @@ class GradientUtils : public CacheUtility { std::make_pair((const Value *)orig, InvertedPointerVH(this, anti))); if (tape == nullptr) { - if (orig->getCalledFunction()->getName() == "julia.gc_alloc_obj") { + if (Fn->getName() == "julia.gc_alloc_obj") { Type *tys[] = { PointerType::get(StructType::get(orig->getContext()), 10)}; FunctionType *FT = @@ -799,7 +818,7 @@ class GradientUtils : public CacheUtility { anti); } - if (orig->getCalledFunction()->getName() == "swift_allocObject") { + if (Fn->getName() == "swift_allocObject") { EmitFailure( "SwiftShadowAllocation", orig->getDebugLoc(), orig, "Haven't implemented shadow allocator for `swift_allocObject`", @@ -817,7 +836,7 @@ class GradientUtils : public CacheUtility { auto val_arg = ConstantInt::get(Type::getInt8Ty(orig->getContext()), 0); 
Value *size; // todo check if this memset is legal and if a write barrier is needed - if (orig->getCalledFunction()->getName() == "julia.gc_alloc_obj") { + if (Fn->getName() == "julia.gc_alloc_obj") { size = args[1]; } else { size = args[0]; @@ -1667,7 +1686,7 @@ class DiffeGradientUtils : public GradientUtils { llvm::errs() << "module: " << *oldFunc->getParent() << "\n"; llvm::errs() << "oldFunc: " << *oldFunc << "\n"; llvm::errs() << "newFunc: " << *newFunc << "\n"; - llvm::errs() << "val: " << *val << " old: " << old << "\n"; + llvm::errs() << "val: " << *val << " old: " << *old << "\n"; } assert(addingType); assert(addingType->isFPOrFPVectorTy()); diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index eb8b52e0429d7..3a4908b6e0f32 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -2489,8 +2489,9 @@ void TypeAnalyzer::visitMemTransferCommon(llvm::CallInst &MTI) { size_t sz = 1; for (auto val : fntypeinfo.knownIntegralValues(MTI.getArgOperand(2), *DT, intseen)) { - assert(val >= 0); - sz = max(sz, (size_t)val); + if (val >= 0) { + sz = max(sz, (size_t)val); + } } TypeTree res = getAnalysis(MTI.getArgOperand(0)).AtMost(sz).PurgeAnything(); diff --git a/enzyme/test/Integration/ReverseMode/metamalloc.c b/enzyme/test/Integration/ReverseMode/metamalloc.c new file mode 100644 index 0000000000000..7eb554bd4cc3d --- /dev/null +++ b/enzyme/test/Integration/ReverseMode/metamalloc.c @@ -0,0 +1,58 @@ +// RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli - +// RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -fno-unroll-loops -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli - +// RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli - +// RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O3 %s -S -emit-llvm -o - | %opt 
- %loadEnzyme -enzyme -S | %lli -
+// RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
+// RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
+// RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
+// RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
+
+#include <stdio.h>  // NOTE(review): header name was missing here; presumably stdio for APPROX_EQ output - confirm
+#include <stdlib.h> // malloc
+#include <math.h>   // NOTE(review): header name was missing here; presumably math for APPROX_EQ - confirm
+
+#include "test_utils.h"
+
+double __enzyme_autodiff(void*, ...);
+// Allocator table: metamalloc() reaches malloc only through the allocfn pointer stored in this global.
+struct {
+    int count;
+    void* (*allocfn)(long int);
+} tup = {0, malloc};
+__attribute__((noinline))
+void* metamalloc(long int size) {  // Allocates `size` bytes via tup.allocfn.
+    void* ret = tup.allocfn(size);
+    //if (ret != 0)
+    //  tup.count++;
+    return ret;
+}
+__attribute__((noinline))
+void square(double* x) {  // Squares *x in place.
+    *x *= *x;
+}
+double alldiv(double x) {  // Returns x*x, computed in memory obtained from metamalloc().
+    double* mem = (double*)metamalloc(8);
+    *mem = x;
+    square(mem);
+    return mem[0];
+}
+// Static variant: sallocfn is declared with an `int` parameter although it is initialised with malloc.
+// NOTE(review): presumably deliberate, to exercise a call through a cast function pointer - confirm.
+static void* (*sallocfn)(int) = malloc;
+__attribute__((noinline))
+void* smetamalloc(int size) {  // Allocates `size` bytes via the file-local sallocfn pointer.
+    return sallocfn(size);
+}
+double salldiv(double x) {  // Returns x*x stored in memory from smetamalloc(); previously called metamalloc(), leaving smetamalloc unused.
+    double* mem = (double*)smetamalloc(8);
+    *mem = x * x;
+    return mem[0];
+}
+// Differentiates both entry points at x = 3.14 and compares each result against 6.28 (d/dx x*x = 2x).
+int main(int argc, char** argv) {
+    double res = __enzyme_autodiff((void*)alldiv, 3.14);
+    APPROX_EQ(res, 6.28, 1e-6);
+    double sres = __enzyme_autodiff((void*)salldiv, 3.14);
+    APPROX_EQ(sres, 6.28, 1e-6);
+    return 0;
+}