Skip to content

Commit

Permalink
fix bounds parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
Brant-Skywalker committed Jul 4, 2024
1 parent ecf1aa1 commit 32dbbd5
Showing 1 changed file with 53 additions and 36 deletions.
89 changes: 53 additions & 36 deletions enzyme/Enzyme/Herbie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ double stringToDouble(const std::string &str) {

class FPConst : public FPNode {
std::string strValue;
double loggedValue = std::numeric_limits<double>::quiet_NaN();

public:
FPConst(std::string strValue) : FPNode("__const"), strValue(strValue) {}
Expand All @@ -251,23 +250,34 @@ class FPConst : public FPNode {
return strValue;
}

void updateBounds(double lower, double upper) override {
assert(lower == upper && "logged bounds for constant are not the same");
loggedValue = lower;
llvm::errs() << "Updated bounds for " << strValue << ": [" << lower << ", "
<< upper << "]\n";
}
void updateBounds(double lower, double upper) override { return; }

double getLowerBound() const override {
assert(!std::isnan(loggedValue));
return loggedValue;
}
if (strValue == "+inf.0") {
return std::numeric_limits<double>::infinity();
} else if (strValue == "-inf.0") {
return -std::numeric_limits<double>::infinity();
}

double getUpperBound() const override {
assert(!std::isnan(loggedValue));
return loggedValue;
double constantValue;
size_t div = strValue.find('/');

if (div != std::string::npos) {
std::string numerator = strValue.substr(0, div);
std::string denominator = strValue.substr(div + 1);
double num = stringToDouble(numerator);
double denom = stringToDouble(denominator);

constantValue = num / denom;
} else {
constantValue = stringToDouble(strValue);
}

return constantValue;
}

double getUpperBound() const override { return getLowerBound(); }

virtual Value *getValue(IRBuilder<> &builder) override {
if (strValue == "+inf.0") {
return ConstantFP::getInfinity(builder.getDoubleTy(), false);
Expand Down Expand Up @@ -514,7 +524,7 @@ struct ErrorLogData {
double maxRes;
double minError;
double maxError;
long executions;
unsigned executions;
SmallVector<double, 2> lower; // Known bounds of operands
SmallVector<double, 2> upper;
};
Expand Down Expand Up @@ -763,24 +773,6 @@ bool fpOptimize(Function &F) {
operation_seen.insert(I2);
component_seen.insert(cur);

ErrorLogData errorLogData;
if (!ErrorLogPath.empty()) {
auto blockIt = std::find_if(
I2->getFunction()->begin(), I2->getFunction()->end(),
[&](const auto &block) { return &block == I2->getParent(); });
assert(blockIt != I2->getFunction()->end() && "Block not found");
size_t blockIdx = std::distance(I2->getFunction()->begin(), blockIt);
auto instIt =
std::find_if(I2->getParent()->begin(), I2->getParent()->end(),
[&](const auto &curr) { return &curr == I2; });
assert(instIt != I2->getParent()->end() && "Instruction not found");
size_t instIdx = std::distance(I2->getParent()->begin(), instIt);
if (!extractErrorLogData(ErrorLogPath, functionName, blockIdx,
instIdx, errorLogData)) {
assert(0 && "Failed to extract error log data");
}
}

auto operands =
isa<CallInst>(I2) ? cast<CallInst>(I2)->args() : I2->operands();

Expand All @@ -794,11 +786,36 @@ bool fpOptimize(Function &F) {

// look up error log to get bounds of the operand of I2
if (!ErrorLogPath.empty()) {
ErrorLogData errorLogData;
auto blockIt = std::find_if(
I2->getFunction()->begin(), I2->getFunction()->end(),
[&](const auto &block) { return &block == I2->getParent(); });
assert(blockIt != I2->getFunction()->end() && "Block not found");
size_t blockIdx =
std::distance(I2->getFunction()->begin(), blockIt);
auto instIt =
std::find_if(I2->getParent()->begin(), I2->getParent()->end(),
[&](const auto &curr) { return &curr == I2; });
assert(instIt != I2->getParent()->end() &&
"Instruction not found");
size_t instIdx = std::distance(I2->getParent()->begin(), instIt);
bool logFound = extractErrorLogData(
ErrorLogPath, functionName, blockIdx, instIdx, errorLogData);

auto *node = valueToNodeMap[operand];
node->updateBounds(errorLogData.lower[i], errorLogData.upper[i]);
llvm::errs() << "Bounds of " << *operand
<< " are: " << errorLogData.lower[i] << " and "
<< errorLogData.upper[i] << "\n";
if (logFound) {
node->updateBounds(errorLogData.lower[i],
errorLogData.upper[i]);
llvm::errs() << "Bounds of " << *operand
<< " are: " << errorLogData.lower[i] << " and "
<< errorLogData.upper[i] << "\n";
} else { // Unknown bounds
node->updateBounds(-std::numeric_limits<double>::infinity(),
std::numeric_limits<double>::infinity());
llvm::errs() << "Bounds of " << *operand
<< " are not found in the log\n";
}

llvm::errs() << "Node bounds of " << *operand << " are: "
<< valueToNodeMap[operand]->getLowerBound()
<< " and "
Expand Down

0 comments on commit 32dbbd5

Please sign in to comment.