Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 15 additions & 20 deletions src/passes/DeadArgumentElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,26 +240,21 @@ struct DAE : public Pass {
scanner.walkModuleCode(module);
// Scan all the functions.
scanner.run(getPassRunner(), module);
// Combine all the info.
struct CallContext {
Call* call;
Function* func;
};

// Combine all the info from the scan.
std::vector<std::vector<Call*>> allCalls(numFunctions);
std::vector<bool> tailCallees(numFunctions);
std::vector<bool> hasUnseenCalls(numFunctions);

// Track the function in which relevant expressions exist. When we modify
// those expressions we will need to mark the function's info as stale.
std::unordered_map<Expression*, Name> expressionFuncs;
// For each function, the set of callers.
std::vector<std::unordered_set<Name>> callers(numFunctions);

for (auto& [func, info] : infoMap) {
for (auto& [name, calls] : info.calls) {
auto& allCallsToName = allCalls[indexes[name]];
auto targetIndex = indexes[name];
auto& allCallsToName = allCalls[targetIndex];
allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end());
for (auto* call : calls) {
expressionFuncs[call] = func;
}
callers[targetIndex].insert(func);
}
for (auto& callee : info.tailCallees) {
tailCallees[indexes[callee]] = true;
Expand Down Expand Up @@ -305,9 +300,9 @@ struct DAE : public Pass {
assert(func.is());
infoMap[func].markStale();
};
auto markCallersStale = [&](const std::vector<Call*>& calls) {
for (auto* call : calls) {
markStale(expressionFuncs[call]);
auto markCallersStale = [&](Index index) {
for (auto caller : callers[index]) {
markStale(caller);
}
};

Expand Down Expand Up @@ -339,7 +334,7 @@ struct DAE : public Pass {
if (refineReturnTypes(func, calls, module)) {
refinedReturnTypes = true;
markStale(name);
markCallersStale(calls);
markCallersStale(index);
}
auto optimizedIndexes =
ParamUtils::applyConstantValues({func}, calls, {}, module);
Expand Down Expand Up @@ -382,7 +377,7 @@ struct DAE : public Pass {
// Success!
worthOptimizing.insert(func);
markStale(name);
markCallersStale(calls);
markCallersStale(index);
}
if (outcome == ParamUtils::RemovalOutcome::Failure) {
callTargetsToLocalize.insert(name);
Expand Down Expand Up @@ -424,15 +419,15 @@ struct DAE : public Pass {
}
if (removeReturnValue(func.get(), calls, module)) {
// We should optimize the callers.
for (auto* call : calls) {
worthOptimizing.insert(module->getFunction(expressionFuncs[call]));
for (auto caller : callers[index]) {
worthOptimizing.insert(module->getFunction(caller));
}
}
// TODO Removing a drop may also open optimization opportunities in the
// callers.
worthOptimizing.insert(func.get());
markStale(name);
markCallersStale(calls);
markCallersStale(index);
}
}
if (!callTargetsToLocalize.empty()) {
Expand Down
Loading