Skip to content

Commit

Permalink
Logging for error estimate (#1859)
Browse files Browse the repository at this point in the history
* Add logging func call

* Add test

* Add original value

* func & bb name (requires -fno-discard-value-names)

* improve

* indices

* improve

* use std::distance instead

* fix private method call

* improve

* fix format

* add counter test

* Add test eq mechanism

---------

Co-authored-by: William S. Moses <gh@wsmoses.com>
  • Loading branch information
Brant-Skywalker and wsmoses committed May 11, 2024
1 parent db5d616 commit 75363f7
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 1 deletion.
14 changes: 13 additions & 1 deletion enzyme/test/Integration/ForwardError/binops.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,24 @@ double fabs(double);

extern double __enzyme_error_estimate(void *, ...);

int errorLogCount = 0;

void enzymeLogError(double res, double err, const char *opcodeName,
const char *calleeName, const char *moduleName,
const char *functionName, const char *blockName) {
++errorLogCount;
printf("Res = %e, Error = %e, Op = %s, Callee = %s, Module = %s, Function = "
"%s, BasicBlock = %s\n",
res, err, opcodeName, calleeName, moduleName, functionName, blockName);
}

// An example from https://dl.acm.org/doi/10.1145/3371128
double fun(double x) {
double v1 = cos(x);
double v2 = 1 - v1;
double v3 = x * x;
double v4 = v2 / v3;
double v5 = sin(v4);
double v5 = sin(v4); // Inactive -- logger is not invoked.

printf("v1 = %.18e, v2 = %.18e, v3 = %.18e, v4 = %.18e, v5 = %.18e\n", v1, v2,
v3, v4, v5);
Expand All @@ -31,4 +42,5 @@ int main() {
printf("res = %.18e, abs error = %.18e, rel error = %.18e\n", res, error,
fabs(error / res));
APPROX_EQ(error, 2.2222222222e-2, 1e-4);
TEST_EQ(errorLogCount, 4);
}
9 changes: 9 additions & 0 deletions enzyme/test/Integration/test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,12 @@ static inline bool approx_fp_equality_double(double f1, double f2, double thresh
abort(); \
} \
};

#define TEST_EQ(LHS, RHS) \
{ \
if ((LHS) != (RHS)) {\
fprintf(stderr, "Assertion Failed: [%s = %d] != [%s = %d] at %s:%d (%s)\n", #LHS, (int)(LHS), #RHS, (int)(RHS), \
__FILE__, __LINE__, __PRETTY_FUNCTION__); \
abort(); \
} \
};
85 changes: 85 additions & 0 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2169,6 +2169,91 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os,
<< origName << ")), res);\n";

os << " assert(res);\n";

// Insert logging function call (optional)
os << " Function *logFunc = " << origName
<< ".getModule()->getFunction(\"enzymeLogError\");\n";
os << " if (logFunc) {\n"
<< " std::string moduleName = " << origName
<< ".getModule()->getModuleIdentifier() ;\n"
<< " std::string functionName = " << origName
<< ".getFunction()->getName().str();\n"
<< " std::string blockName = " << origName
<< ".getParent()->getName().str();\n"
<< " int funcIdx = -1, blockIdx = -1, instIdx = -1;\n"
<< " auto funcIt = std::find_if(" << origName
<< ".getModule()->begin(), " << origName
<< ".getModule()->end(),\n"
" [&](const auto& func) { return &func == "
<< origName
<< ".getFunction(); });\n"
" if (funcIt != "
<< origName
<< ".getModule()->end()) {\n"
" funcIdx = "
"std::distance("
<< origName << ".getModule()->begin(), funcIt);\n"
<< " }\n"
<< " auto blockIt = std::find_if(" << origName
<< ".getFunction()->begin(), " << origName
<< ".getFunction()->end(),\n"
" [&](const auto& block) { return &block == "
<< origName
<< ".getParent(); });\n"
" if (blockIt != "
<< origName
<< ".getFunction()->end()) {\n"
" blockIdx = std::distance("
<< origName << ".getFunction()->begin(), blockIt);\n"
<< " }\n"
<< " auto instIt = std::find_if(" << origName
<< ".getParent()->begin(), " << origName
<< ".getParent()->end(),\n"
" [&](const auto& curr) { return &curr == &"
<< origName
<< "; });\n"
" if (instIt != "
<< origName
<< ".getParent()->end()) {\n"
" instIdx = std::distance("
<< origName << ".getParent()->begin(), instIt);\n"
<< " }\n"
<< " Value *origValue = "
"Builder2.CreateFPExt(gutils->getNewFromOriginal(&"
<< origName << "), Type::getDoubleTy(" << origName
<< ".getContext()));\n"
<< " Value *errValue = Builder2.CreateFPExt(res, "
"Type::getDoubleTy("
<< origName << ".getContext()));\n"
<< " std::string opcodeName = " << origName
<< ".getOpcodeName();\n"
<< " std::string calleeName = \"<N/A>\";\n"
<< " if (auto CI = dyn_cast<CallInst>(&" << origName
<< ")) {\n"
<< " if (Function *fn = CI->getCalledFunction()) {\n"
<< " calleeName = fn->getName();\n"
<< " } else {\n"
<< " calleeName = \"<Unknown>\";\n"
<< " }\n"
<< " }\n"
<< " Value *moduleNameValue = "
"Builder2.CreateGlobalStringPtr(moduleName);\n"
<< " Value *functionNameValue = "
"Builder2.CreateGlobalStringPtr(functionName + \" (\" +"
"std::to_string(funcIdx) + \")\");\n"
<< " Value *blockNameValue = "
"Builder2.CreateGlobalStringPtr(blockName + \" (\" +"
"std::to_string(blockIdx) + \")\");\n"
<< " Value *opcodeNameValue = "
"Builder2.CreateGlobalStringPtr(opcodeName + \" (\" "
"+std::to_string(instIdx) + \")\");\n"
<< " Value *calleeNameValue = "
"Builder2.CreateGlobalStringPtr(calleeName);\n"
<< " Builder2.CreateCall(logFunc, {origValue, "
"errValue, opcodeNameValue, calleeNameValue, moduleNameValue, "
"functionNameValue, blockNameValue});\n"
<< " }\n";

os << " setDiffe(&" << origName << ", res, Builder2);\n";
os << " break;\n";
os << " }\n";
Expand Down

0 comments on commit 75363f7

Please sign in to comment.