Skip to content

Commit

Permalink
fix macos specialfunc (#1907)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 26, 2024
1 parent 2c136a0 commit 877b901
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 18 deletions.
9 changes: 8 additions & 1 deletion enzyme/Enzyme/InstructionDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,19 @@ class InsertValue<list<int> indices_> : Inst<"InsertValue"> {

class MultiReturn : Operation</*primal*/1,/*shadow*/0> {
bit struct;
bit useRetType;
}
def RetMultiReturnRet : MultiReturn {
bit struct = 0;
bit useRetType = 1;
}
def StructRet : MultiReturn {
bit struct = 1;
bit useRetType = 0;
}
def ArrayRet : MultiReturn {
bit struct = 0;
bit useRetType = 0;
}

def CFAdd : SubRoutine<(Op (Op $re1, $im1):$z1, (Op $re2, $im2):$z2),
Expand Down Expand Up @@ -650,7 +657,7 @@ def : CallPattern<(Op $x),
>;

def ToStruct2 : SubRoutine<(Op (Op $re, $im):$z),
(StructRet $re, $im)
(RetMultiReturnRet $re, $im)
>;
def : CallPattern<(Op $x, $tbd),
["Faddeeva_erf"],
Expand Down
4 changes: 4 additions & 0 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3142,3 +3142,7 @@ llvm::Value *get1ULP(llvm::IRBuilder<> &builder, llvm::Value *res) {

return absres;
}

void dumpValue(llvm::Value *val) { llvm::errs() << *val << "\n"; }

void dumpType(llvm::Type *ty) { llvm::errs() << *ty << "\n"; }
22 changes: 22 additions & 0 deletions enzyme/test/Enzyme/ReverseModeVector/fad_erf.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,%simplifycfg)" -enzyme-preopt=false -S | FileCheck %s

declare [3 x [2 x double]] @__enzyme_autodiff(...)

declare [2 x double] @Faddeeva_erf([2 x double], double)

define double @test([2 x double] %x) {
entry:
%f = call [2 x double] @Faddeeva_erf([2 x double] %x, double noundef 0.000000e+00)
%y = extractvalue [2 x double] %f, 1
ret double %y
}

define [3 x [2 x double]] @test_derivative([2 x double] %x) {
entry:
%call = call [3 x [2 x double]] (...) @__enzyme_autodiff(double ([2 x double])* @test, metadata !"enzyme_width", i64 3, [2 x double] %x)
ret [3 x [2 x double]] %call
}


; CHECK: define internal { [3 x [2 x double]] } @diffe3test([2 x double] %x, [3 x double] %differeturn)
39 changes: 22 additions & 17 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,7 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,
os << "({\n";

bool useStruct = Def->getValueAsBit("struct");
bool useRetType = Def->getValueAsBit("useRetType");

SmallVector<bool, 1> vectorValued = prepareArgs(
curIndent + INDENT, os, argPattern, pattern, resultRoot, builder,
Expand All @@ -756,25 +757,29 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,
if (anyVector)
os << "gutils->getShadowType(";

if (useStruct)
os << "StructType::get(gutils->newFunc->getContext(), "
"std::vector<llvm::Type*>({";
else
os << "ArrayType::get(";
for (size_t i = 0; i < (useStruct ? vectorValued.size() : 1); i++) {
if (i != 0)
os << ", ";
if (!vectorValued[i])
os << argPattern << "_" << i << "->getType()";
if (useRetType) {
os << (origName == "<ILLEGAL>" ? "call" : origName) << ".getType()";
} else {
if (useStruct)
os << "StructType::get(gutils->newFunc->getContext(), "
"std::vector<llvm::Type*>({";
else
os << "(gutils->getWidth() == 1) ? " << argPattern << "_" << i
<< "->getType() : getSubType(" << argPattern << "_" << i
<< "->getType(), -1)";
os << "ArrayType::get(";
for (size_t i = 0; i < (useStruct ? vectorValued.size() : 1); i++) {
if (i != 0)
os << ", ";
if (!vectorValued[i])
os << argPattern << "_" << i << "->getType()";
else
os << "(gutils->getWidth() == 1) ? " << argPattern << "_" << i
<< "->getType() : getSubType(" << argPattern << "_" << i
<< "->getType(), -1)";
}
if (useStruct)
os << "}))";
else
os << ", " << vectorValued.size() << ")";
}
if (useStruct)
os << "}))";
else
os << ", " << vectorValued.size() << ")";

if (anyVector)
os << ")";
Expand Down

0 comments on commit 877b901

Please sign in to comment.