Skip to content
79 changes: 56 additions & 23 deletions src/passes/DeadArgumentElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ struct DAE : public Pass {

bool optimize = false;

// Number of functions in the module, captured in run() before iterating.
// Default-initialized so the member never holds an indeterminate value
// before run() assigns it.
Index numFunctions = 0;

// Map of function names to indexes, where a function's index is its position
// in module->functions (filled in by run()). This lets us use indexes below
// for speed.
// NOTE(review): both members are computed once per run(); presumably the
// module's function list is not reordered while the pass iterates - confirm.
std::unordered_map<Name, Index> indexes;

void run(Module* module) override {
DAEFunctionInfoMap infoMap;
// Ensure all entries exist so the parallel threads don't modify the data
Expand All @@ -198,6 +203,12 @@ struct DAE : public Pass {
// The null name represents module-level code (not in a function).
infoMap[Name()];

numFunctions = module->functions.size();

for (Index i = 0; i < numFunctions; i++) {
indexes[module->functions[i]->name] = i;
}

// Iterate to convergence.
while (1) {
if (!iteration(module, infoMap)) {
Expand Down Expand Up @@ -233,34 +244,36 @@ struct DAE : public Pass {
Call* call;
Function* func;
};
std::map<Name, std::vector<Call*>> allCalls;
std::unordered_set<Name> tailCallees;
std::unordered_set<Name> hasUnseenCalls;

std::vector<std::vector<Call*>> allCalls(numFunctions);
std::vector<bool> tailCallees(numFunctions);
std::vector<bool> hasUnseenCalls(numFunctions);

// Track the function in which relevant expressions exist. When we modify
// those expressions we will need to mark the function's info as stale.
std::unordered_map<Expression*, Name> expressionFuncs;
for (auto& [func, info] : infoMap) {
for (auto& [name, calls] : info.calls) {
auto& allCallsToName = allCalls[name];
auto& allCallsToName = allCalls[indexes[name]];
allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end());
for (auto* call : calls) {
expressionFuncs[call] = func;
}
}
for (auto& callee : info.tailCallees) {
tailCallees.insert(callee);
tailCallees[indexes[callee]] = true;
}
for (auto& [call, dropp] : info.droppedCalls) {
allDroppedCalls[call] = dropp;
}
for (auto& name : info.hasUnseenCalls) {
hasUnseenCalls.insert(name);
hasUnseenCalls[indexes[name]] = true;
}
}
// Exports are considered unseen calls.
for (auto& curr : module->exports) {
if (curr->kind == ExternalKind::Function) {
hasUnseenCalls.insert(*curr->getInternalName());
hasUnseenCalls[indexes[*curr->getInternalName()]] = true;
}
}

Expand Down Expand Up @@ -299,23 +312,32 @@ struct DAE : public Pass {

// We now have a mapping of all call sites for each function, and can look
// for optimization opportunities.
for (auto& [name, calls] : allCalls) {
for (Index index = 0; index < numFunctions; index++) {
auto* func = module->functions[index].get();
if (func->imported()) {
continue;
}
// We can only optimize if we see all the calls and can modify them.
if (hasUnseenCalls.count(name)) {
if (hasUnseenCalls[index]) {
continue;
}
auto& calls = allCalls[index];
if (calls.empty()) {
// Nothing calls this, so it is not worth optimizing.
continue;
}
auto* func = module->getFunction(name);
// Refine argument types before doing anything else. This does not
// affect whether an argument is used or not, it just refines the type
// where possible.
auto name = func->name;
if (refineArgumentTypes(func, calls, module, infoMap[name])) {
worthOptimizing.insert(func);
markStale(func->name);
}
// Refine return types as well.
if (refineReturnTypes(func, calls, module)) {
refinedReturnTypes = true;
markStale(func->name);
markStale(name);
markCallersStale(calls);
}
auto optimizedIndexes =
Expand All @@ -336,21 +358,29 @@ struct DAE : public Pass {
ReFinalize().run(getPassRunner(), module);
}
// We now know which parameters are unused, and can potentially remove them.
for (auto& [name, calls] : allCalls) {
if (hasUnseenCalls.count(name)) {
for (Index index = 0; index < numFunctions; index++) {
auto* func = module->functions[index].get();
if (func->imported()) {
continue;
}
if (hasUnseenCalls[index]) {
continue;
}
auto* func = module->getFunction(name);
auto numParams = func->getNumParams();
if (numParams == 0) {
continue;
}
auto& calls = allCalls[index];
if (calls.empty()) {
continue;
}
auto name = func->name;
auto [removedIndexes, outcome] = ParamUtils::removeParameters(
{func}, infoMap[name].unusedParams, calls, {}, module, getPassRunner());
if (!removedIndexes.empty()) {
// Success!
worthOptimizing.insert(func);
markStale(func->name);
markStale(name);
markCallersStale(calls);
}
if (outcome == ParamUtils::RemovalOutcome::Failure) {
Expand All @@ -362,25 +392,28 @@ struct DAE : public Pass {
// modified allCalls (we can't modify a call site twice in one iteration,
// once to remove a param, once to drop the return value).
if (worthOptimizing.empty()) {
for (auto& func : module->functions) {
for (Index index = 0; index < numFunctions; index++) {
auto& func = module->functions[index];
if (func->imported()) {
continue;
}
if (func->getResults() == Type::none) {
continue;
}
auto name = func->name;
if (hasUnseenCalls.count(name)) {
if (hasUnseenCalls[index]) {
continue;
}
auto name = func->name;
if (infoMap[name].hasTailCalls) {
continue;
}
if (tailCallees.count(name)) {
if (tailCallees[index]) {
continue;
}
auto iter = allCalls.find(name);
if (iter == allCalls.end()) {
auto& calls = allCalls[index];
if (calls.empty()) {
continue;
}
auto& calls = iter->second;
bool allDropped =
std::all_of(calls.begin(), calls.end(), [&](Call* call) {
return allDroppedCalls.count(call);
Expand All @@ -397,7 +430,7 @@ struct DAE : public Pass {
// TODO Removing a drop may also open optimization opportunities in the
// callers.
worthOptimizing.insert(func.get());
markStale(func->name);
markStale(name);
markCallersStale(calls);
}
}
Expand Down
Loading