Skip to content

Commit

Permalink
Fix cublas return in reverse (#1882)
Browse files Browse the repository at this point in the history
* Fix cublas return in reverse

* Fixup reverse caching
  • Loading branch information
wsmoses authored May 15, 2024
1 parent f5e66b2 commit 6875a9e
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,20 +147,30 @@ void emit_handleBLAS(ArrayRef<TGPattern> blasPatterns, raw_ostream &os) {
<< " return false; \n"
<< " } \n"
<< " } else { \n"
<< " if (gutils->knownRecomputeHeuristic.find(&call) !=\n"
<< " gutils->knownRecomputeHeuristic.end()) {\n"
<< " if (!gutils->knownRecomputeHeuristic[&call]) {\n"
<< " auto found = gutils->knownRecomputeHeuristic.find(&call); \n"
<< " auto end = gutils->knownRecomputeHeuristic.end(); \n"
<< " bool shouldErase = true;\n"
<< " if (found != end) {\n"
<< " if (!found->second) { \n"
<< " auto newCall = gutils->getNewFromOriginal(&call);\n"
<< " llvm::IRBuilder<> BuilderZ(newCall);\n"
<< " gutils->cacheForReverse(BuilderZ, newCall,\n"
<< " getIndex(&call, CacheType::Self, BuilderZ));\n"
<< " shouldErase = false;\n"
<< " }\n"
<< " }\n"
<< " if (Mode == DerivativeMode::ReverseModeGradient) { \n"
<< " eraseIfUnused(call, /*erase*/ true, /*check*/ false); \n"
<< " } else { \n"
<< " eraseIfUnused(call); \n"
<< " } \n"
<< " if (shouldErase) {\n"
<< " if (Mode == DerivativeMode::ReverseModeGradient) { "
"\n"
<< " eraseIfUnused(call, /*erase*/ true, /*check*/ false); "
"\n"
<< " } else { "
"\n"
<< " eraseIfUnused(call); "
"\n"
<< " } "
"\n"
<< " }\n"
<< " }\n"
<< " return result; \n"
<< "} \n";
Expand Down Expand Up @@ -237,8 +247,15 @@ void emit_free_and_ending(const TGPattern &pattern, raw_ostream &os) {

os << " }\n"
<< " }\n"
<< " \n"
<< " if (gutils->knownRecomputeHeuristic.find(&call) !=\n"
<< " \n";

os << " if (cublas && Mode == DerivativeMode::ReverseModeGradient && "
"call.getType()->isIntegerTy()) { \n"
<< " gutils->replaceAWithB(gutils->getNewFromOriginal(&call), "
"Constant::getNullValue(call.getType()));\n"
<< " }\n";

os << " if (gutils->knownRecomputeHeuristic.find(&call) !=\n"
<< " gutils->knownRecomputeHeuristic.end()) {\n"
<< " if (!gutils->knownRecomputeHeuristic[&call]) {\n"
<< " auto cv = gutils->cacheForReverse(BuilderZ, newCall,\n"
Expand Down

0 comments on commit 6875a9e

Please sign in to comment.