Skip to content

Commit

Permalink
Extend functionality (#1875)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 11, 2024
1 parent 53a31b2 commit db5d616
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 38 deletions.
44 changes: 44 additions & 0 deletions enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,16 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $
/* y */ (b<"scal"> (Rows $transa, $m, $n), $beta, adj<"y">)
]
>;

// trmv: in-place matrix-vector product, x = A x.
// ["x"] names x as the argument listed as mutable for this pattern.
// Currently assumes, for vector dimensions, that transa = 'N'; the dimensions
// are computed incorrectly otherwise.
// NOTE(review): the reference BLAS/CBLAS trmv routine also takes an `uplo`
// argument ahead of `trans`; this pattern does not declare one. Confirm the
// argument list below matches the calls that are being differentiated.
// NOTE(review): the matrix type is mld<["diag", "n", "n"]>, so `diag` occupies
// the leading slot of the three-element form. Confirm this is intended rather
// than "transa".
def trmv : CallBlasPattern<(Op $layout, $transa, $diag, $n, $A, $lda, $x, $incx),
["x"], [cblas_layout, trans, diag, len, mld<["diag", "n", "n"]>, vinc<["n"]>],
[
// No derivative is generated for A. The commented-out ger expression that
// follows refers to $m, $alpha and adj<"y">, none of which are arguments of
// this pattern, so it cannot be enabled as written.
/* A */ (inactive), //(b<"ger"> $layout, $m, $n, $alpha, (Rows $transa, (Concat adj<"y">, $x), (Concat $x, adj<"y">)), adj<"A">),
// The adjoint of x is updated by calling trmv again with transa transposed,
// applied to adj<"x">.
/* x */ (b<"trmv"> $layout, transpose<"transa">, $diag, $n, $A, (ld $A, Char<"N">, $lda, $n, $n), adj<"x">)
]
>;
//
def ger : CallBlasPattern<(Op $layout, $m, $n, $alpha, $x, $incx, $y, $incy, $A, $lda),
["A"],[cblas_layout, len, len, fp, vinc<["m"]>, vinc<["n"]>, mld<["m", "n"]>],
Expand Down Expand Up @@ -253,6 +263,28 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A
]
>;

// syrk: BLAS symmetric rank-k update; ["C"] names C as the argument listed as
// mutable for this pattern.
// All four derivative rules below (alpha, A, beta, C) are currently
// (inactive), so this pattern only declares the call signature and argument
// types; no adjoint code is generated from it.
// NOTE(review): the commented-out drafts next to each rule refer to $transa,
// $transb, $B, $ldb and $m, which are not arguments of syrk (they match the
// gemm pattern's argument names). They need to be rewritten in terms of
// $uplo/$trans/$n/$k before they can be enabled.
def syrk : CallBlasPattern<(Op $layout, $uplo, $trans, $n, $k, $alpha, $A, $lda, $beta, $C, $ldc),
["C"],
[cblas_layout, uplo, trans, len, len, fp, mld<["trans", "n", "k"]>, fp, mld<["n", "n"]>],
[

/* alpha */ (inactive), /*(Seq<["AB", "product", "m", "n"]>
(b<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $k, $m), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n
(FrobInnerProd<""> $m, $n, adj<"C">, use<"AB">)),*/
/* A */ (inactive), /*(b<"gemm"> $layout, (Rows $transa,
(Concat $transa, transpose<"transb">, $m, $k),
(Concat $transb, $transa, $k, $m)),
$n, $alpha,
(Rows $transa,
(Concat adj<"C">, $B, (ld $B, $transb, $ldb, $n, $k)),
(Concat $B, (ld $B, $transb, $ldb, $n, $k), adj<"C">)),
Constant<"1.0">, adj<"A">),*/

/* beta */ (inactive), //(FrobInnerProd<""> $m, $n, adj<"C">, input<"C">),
/* C */ (inactive), //(b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, Alloca<1>)
]
>;

def spmv : CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $ap, $x, $incx, $beta, $y, $incy),
["y"],
[cblas_layout, uplo, len, fp, ap<["n"]>, vinc<["n"]>, fp, vinc<["n"]>],
Expand All @@ -269,6 +301,18 @@ def spmv : CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $ap, $x, $incx, $beta
]
>;

// trtrs: triangular solve; ["b"] names b as the argument listed as mutable.
// Intended derivative, with B the input and B2 the overwritten output:
//   B2 = inv(A^T) B
//   dB = inv(A^T) dB2
//   d(A^T) -= dB B2^T
// Both derivative rules below are currently (inactive), so the formulas above
// are not yet implemented by this pattern.
// NOTE(review): in the type list, $b/$ldb are declared as vinc<["n"]> (a
// vector with an increment) and $info as len. Presumably these are
// placeholders until a matrix type with n x nrhs dimensions and an
// info-specific type are used; verify before enabling any rule.
def trtrs : CallBlasPattern<(Op $layout, $uplo, $trans, $diag, $n, $nrhs, $a, $lda, $b, $ldb, $info),
["b"],
[cblas_layout, uplo, trans, diag, len, len, mld<["n", "n"]>, vinc<["n"]>, len],
[
/* a */ (inactive),
/* b */ (inactive),
]
>;

def spr2 : CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $x, $incx, $y, $incy, $ap),
["ap"],
[cblas_layout, uplo, len, fp, vinc<["n"]>, vinc<["n"]>, ap<["n"]>],
Expand Down
3 changes: 2 additions & 1 deletion enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,8 @@ bool attributeKnownFunctions(llvm::Function &F) {
F.getName() ==
"_ZNKSt8__detail20_Prime_rehash_policy14_M_need_rehashEmmm" ||
F.getName() == "fprintf" || F.getName() == "fwrite" ||
F.getName() == "strtol" || F.getName() == "getenv") {
F.getName() == "strtol" || F.getName() == "getenv" ||
F.getName() == "memchr") {
changed = true;
F.addAttribute(
AttributeList::FunctionIndex,
Expand Down
27 changes: 26 additions & 1 deletion enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1613,10 +1613,31 @@ static inline bool isNoEscapingAllocation(const llvm::Function *F) {
case Intrinsic::prefetch:
case Intrinsic::trap:
case Intrinsic::is_constant:
#if LLVM_VERSION_MAJOR >= 12
case Intrinsic::smax:
case Intrinsic::smin:
case Intrinsic::umax:
case Intrinsic::umin:
#endif
case Intrinsic::ctlz:
case Intrinsic::cttz:
case Intrinsic::sadd_with_overflow:
case Intrinsic::ssub_with_overflow:
#if LLVM_VERSION_MAJOR >= 12
case Intrinsic::abs:
#endif
case Intrinsic::sqrt:
case Intrinsic::exp:
case Intrinsic::cos:
case Intrinsic::sin:
case Intrinsic::copysign:
case Intrinsic::fabs:
return true;
default:
break;
}
// if (F->empty())
// llvm::errs() << " may escape:" << F->getName() << "\n";
return false;
}
static inline bool isNoEscapingAllocation(const llvm::CallBase *call) {
Expand All @@ -1625,7 +1646,11 @@ static inline bool isNoEscapingAllocation(const llvm::CallBase *call) {
if (AttrList.hasAttribute("enzyme_no_escaping_allocation"))
return true;
if (auto F = getFunctionFromCall(call)) {
return isNoEscapingAllocation(F);
auto res = isNoEscapingAllocation(F);
// if (!res && F->empty()) {
// llvm::errs() << " may escape:" << *call << "\n";
//}
return res;
}
return false;
}
Expand Down
25 changes: 14 additions & 11 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) {
return;
}
}
PrintFatalError("Blas function without vector or matrix?");
PrintFatalError(pattern.getLoc(), "Blas function without vector or matrix?");
}

void emit_scalar_types(const TGPattern &pattern, raw_ostream &os) {
Expand Down Expand Up @@ -1045,9 +1045,12 @@ void emit_deriv_rule(const StringMap<TGPattern> &patternMap, Rule &rule,
// nothing to prepare
} else if (Def->isSubClassOf("DiffeRetIndex")) {
// nothing to prepare
} else if (Def->getName() == "inactive") {
} else if (Def->isSubClassOf("Inst")) {
PrintFatalError("Unhandled Inst Rule!");
// TODO:
std::string str;
llvm::raw_string_ostream os(str);
os << "Unhandled Inst Rule: " << *Def;
PrintFatalError(Def->getLoc(), os.str());
return;
} else if (Def->isSubClassOf("Seq")) {
// handle seq rules
Expand Down Expand Up @@ -1076,7 +1079,7 @@ void emit_deriv_rule(const StringMap<TGPattern> &patternMap, Rule &rule,
// nothing to prepare
assert(ruleDag->getNumArgs() == 6);
} else {
PrintFatalError("Unhandled deriv Rule!");
PrintFatalError(Def->getLoc(), "Unhandled deriv Rule!");
}
}

Expand Down Expand Up @@ -1148,7 +1151,7 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos,
}

errs() << Def->getName() << "\n";
PrintFatalError("Dag/Def that isn't a DiffeRet!!");
PrintFatalError(Def->getLoc(), "Dag/Def that isn't a DiffeRet!!");
} else if (DefInit *DefArg = dyn_cast<DefInit>(arg)) {
auto Def = DefArg->getDef();
if (Def->isSubClassOf("DiffeRetIndex")) {
Expand All @@ -1166,7 +1169,7 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos,
if (argPosition == (size_t)(-1)) {
errs() << "couldn't find name: " << name << " ap=" << argPosition
<< "\n";
PrintFatalError("arg not in inverted nameMap!");
PrintFatalError(Def->getLoc(), "arg not in inverted nameMap!");
}
auto ty = rule.argTypesFull.lookup(argPosition);
auto incName = rule.nameVec[argPosition + 1];
Expand All @@ -1188,7 +1191,7 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos,
if (argPosition == (size_t)(-1)) {
errs() << "couldn't find name: " << name << " ap=" << argPosition
<< "\n";
PrintFatalError("arg not in inverted nameMap!");
PrintFatalError(Def->getLoc(), "arg not in inverted nameMap!");
}
auto ty = rule.argTypesFull.lookup(argPosition);
auto incName = rule.nameVec[argPosition + 1];
Expand Down Expand Up @@ -1221,7 +1224,7 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos,
//} else if (val == "C") {
} else {
errs() << "unknown char: " << val << "\n";
PrintFatalError("unknown char");
PrintFatalError(Def->getLoc(), "unknown char");
}
} else if (Def->isSubClassOf("Alloca")) {
auto val = Def->getValueAsInt("value");
Expand All @@ -1242,18 +1245,18 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos,
<< "\"))}";
} else {
errs() << Def->getName() << "\n";
PrintFatalError("Def that isn't a DiffeRet!");
PrintFatalError(Def->getLoc(), "Def that isn't a DiffeRet!");
}
} else {
auto name = ruleDag->getArgNameStr(pos);
if (name == "") {
PrintFatalError("arg has no name!" + std::to_string(pos));
PrintFatalError(rule.getLoc(), "arg has no name!" + std::to_string(pos));
assert(name != "");
}
// get the position of the argument in the primary blas call
if (nameMap.count(name) != 1) {
errs() << "couldn't find name: " << name << "\n";
PrintFatalError("arg not in nameMap!");
PrintFatalError(rule.getLoc(), "arg not in nameMap!");
}
assert(nameMap.count(name) == 1);
auto argPosition = nameMap.lookup(name);
Expand Down
3 changes: 2 additions & 1 deletion enzyme/tools/enzyme-tblgen/blasDeclUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) {
"llvm::Attribute::MustProgress);\n"
<< "#endif\n"
<< " F->addFnAttr(llvm::Attribute::NoFree);\n"
<< " F->addFnAttr(llvm::Attribute::NoSync);\n";
<< " F->addFnAttr(llvm::Attribute::NoSync);\n"
<< " F->addFnAttr(\"enzyme_no_escaping_allocation\");\n";

auto argTypeMap = pattern.getArgTypeMap();
DenseSet<size_t> mutableArgs = pattern.getMutableArgs();
Expand Down
7 changes: 7 additions & 0 deletions enzyme/tools/enzyme-tblgen/blasTAUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ void emit_BLASTA(TGPattern &pattern, raw_ostream &os) {
// TODO: clean this up. Effectively, skip arg 0 for lv23,
// because we have the cblas layout in the .td declaration
size_t i = (lv23 ? j - 1 : j);
if (pattern.getArgNames().size() <= j) {
PrintFatalError(pattern.getLoc(),
Twine("Too few argnames for pattern '") + name +
"' found " +
std::to_string(pattern.getArgNames().size()) +
" expected " + std::to_string(argTypeMap.size()));
}
os << " // " << currentType << " " << pattern.getArgNames()[j] << "\n";
switch (currentType) {
case ArgType::len:
Expand Down
55 changes: 37 additions & 18 deletions enzyme/tools/enzyme-tblgen/datastructures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ bool isVecLikeArg(ArgType ty) {
return false;
}

bool isArgUsed(StringRef toFind, const DagInit *toSearch,
bool isArgUsed(Rule *rule, StringRef toFind, const DagInit *toSearch,
ArrayRef<std::string> nameVec,
const DenseMap<size_t, ArgType> &argTypesFull) {
for (size_t i = 0; i < toSearch->getNumArgs(); i++) {
if (DagInit *arg = dyn_cast<DagInit>(toSearch->getArg(i))) {
// os << " Recursing. Magic!\n";
if (isArgUsed(toFind, arg, nameVec, argTypesFull))
if (isArgUsed(rule, toFind, arg, nameVec, argTypesFull))
return true;
} else {
auto name = toSearch->getArgNameStr(i);
Expand Down Expand Up @@ -111,9 +111,10 @@ bool isArgUsed(StringRef toFind, const DagInit *toSearch,
}
}
if (argPosition == (size_t)(-1)) {
errs() << "couldn't find name: " << name << " ap=" << argPosition
<< "\n";
PrintFatalError("arg not in inverted nameMap!");
PrintFatalError(rule->getLoc(),
Twine("arg '") + name +
"' (pos=" + std::to_string(argPosition) +
") not in inverted nameMap isArgUsed(1)!");
}
auto ty = argTypesFull.lookup(argPosition);
if (ty == ArgType::vincData ||
Expand All @@ -136,9 +137,10 @@ bool isArgUsed(StringRef toFind, const DagInit *toSearch,
}
}
if (argPosition == (size_t)(-1)) {
errs() << "couldn't find name: " << name << " ap=" << argPosition
<< "\n";
PrintFatalError("arg not in inverted nameMap!");
PrintFatalError(rule->getLoc(),
Twine("arg '") + name +
"' (pos=" + std::to_string(argPosition) +
") not in inverted nameMap isArgUsed(2)!");
}
auto ty = argTypesFull.lookup(argPosition);
if (ty == ArgType::vincData || ty == ArgType::mldData) {
Expand All @@ -152,23 +154,31 @@ bool isArgUsed(StringRef toFind, const DagInit *toSearch,
return false;
}

Rule::Rule(ArrayRef<std::string> nameVec, DagInit *dag, size_t activeArgIdx,
const StringMap<size_t> &patternArgs,
Rule::Rule(TGPattern *pattern, ArrayRef<std::string> nameVec, DagInit *dag,
size_t activeArgIdx, const StringMap<size_t> &patternArgs,
const DenseMap<size_t, ArgType> &patternTypes,
const DenseSet<size_t> &patternMutables)
: rewriteRule(dag), activeArg(activeArgIdx),
: pattern(pattern), rewriteRule(dag), activeArg(activeArgIdx),
nameVec(nameVec.begin(), nameVec.end()) {
// For each arg found in the dag:
// 1) copy patternArgs to ruleArgs if arg shows up in this rule
for (auto argName : patternArgs.keys()) {
assert(patternArgs.count(argName) == 1);
size_t argPos = patternArgs.lookup(argName);
argTypesFull.insert(*patternTypes.find(argPos));
auto found = patternTypes.find(argPos);
if (found == patternTypes.end()) {
PrintFatalError(getLoc(), Twine("Could not successfully find argName '") +
argName + " (index " +
std::to_string(argPos) +
") in patternTypes");
}
argTypesFull.insert(*found);
}
for (auto argName : patternArgs.keys()) {
assert(patternArgs.count(argName) == 1);
size_t argPos = patternArgs.lookup(argName);
bool argUsedInRule = isArgUsed(argName, rewriteRule, nameVec, argTypesFull);
bool argUsedInRule =
isArgUsed(this, argName, rewriteRule, nameVec, argTypesFull);
if (argUsedInRule) {
argNameToPos.insert(std::pair<std::string, size_t>(argName, argPos));
// 2) look up and copy the corresponding argType
Expand All @@ -193,6 +203,10 @@ Rule::Rule(ArrayRef<std::string> nameVec, DagInit *dag, size_t activeArgIdx,
assert(argTypes.size() == argNameToPos.size());
}

/// Returns the pattern pointer this rule was constructed with.
///
/// NOTE(review): the pointer is the `this` that the TGPattern constructor
/// passes when it builds its rules. If TGPattern objects are copied or moved
/// afterwards (e.g. when stored by value in a container), this pointer would
/// still refer to the original object. Confirm the pattern outlives its rules
/// at a stable address.
TGPattern *Rule::getPattern() const {
  return pattern;
}

/// Returns the source location(s) used when reporting errors for this rule.
/// A rule has no location of its own here; it forwards to the pattern it
/// belongs to.
ArrayRef<SMLoc> Rule::getLoc() const {
  const TGPattern *owner = getPattern();
  return owner->getLoc();
}

/// Returns the stored BLASLevel2or3 flag for this rule.
bool Rule::isBLASLevel2or3() const {
  return BLASLevel2or3;
}

/// Returns the DAG holding this rule's rewrite (derivative) expression, as
/// passed to the constructor.
DagInit *Rule::getRuleDag() {
  return rewriteRule;
}
Expand Down Expand Up @@ -346,7 +360,8 @@ void fillRelatedLenghts(
assert(argTypes.lookup(lengths[0]) == ArgType::len);
assert(argTypes.lookup(lengths[1]) == ArgType::len);
} else {
assert(argTypes.lookup(lengths[0]) == ArgType::trans);
assert(argTypes.lookup(lengths[0]) == ArgType::trans ||
argTypes.lookup(lengths[0]) == ArgType::diag);
assert(argTypes.lookup(lengths[1]) == ArgType::len);
assert(argTypes.lookup(lengths[2]) == ArgType::len);
}
Expand Down Expand Up @@ -375,7 +390,10 @@ void fillArgUserMap(ArrayRef<Rule> rules, ArrayRef<std::string> nameVec,
}
}

TGPattern::TGPattern(Record *r) : blasName(r->getNameInitAsString()) {
/// Returns the source location(s) of the TableGen record this pattern was
/// built from, for use in PrintFatalError diagnostics.
ArrayRef<SMLoc> TGPattern::getLoc() const {
  return record->getLoc();
}

TGPattern::TGPattern(Record *r)
: record(r), blasName(r->getNameInitAsString()) {
fillArgs(r, args, argNameToPos);
fillArgTypes(r, argTypes);
fillRelatedLenghts(r, argNameToPos, argTypes, relatedLengths);
Expand All @@ -395,8 +413,8 @@ TGPattern::TGPattern(Record *r) : blasName(r->getNameInitAsString()) {
for (auto &&derivOp : enumerate(*derivOps)) {
DagInit *derivRule = cast<DagInit>(derivOp.value());
size_t actIdx = posActArgs[derivOp.index()];
rules.push_back(
Rule(args, derivRule, actIdx, argNameToPos, argTypes, mutables));
rules.push_back(Rule(this, args, derivRule, actIdx, argNameToPos,
argTypes, mutables));
}
}

Expand All @@ -413,7 +431,8 @@ SmallVector<size_t, 3> TGPattern::getRelatedLengthArgs(size_t arg) const {
auto related = relatedLengths.lookup(arg);

if (related.size() == 3) {
assert(argTypes.lookup(related[0]) == ArgType::trans);
auto argTy = argTypes.lookup(related[0]);
assert(argTy == ArgType::trans || argTy == ArgType::diag);
}

return related;
Expand Down
Loading

0 comments on commit db5d616

Please sign in to comment.