diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index 93e6ba44b26..c1a6689c976 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -188,6 +188,11 @@ struct DAE : public Pass { bool optimize = false; + Index numFunctions; + + // Map of function names to indexes. This lets us use indexes below for speed. + std::unordered_map indexes; + void run(Module* module) override { DAEFunctionInfoMap infoMap; // Ensure all entries exist so the parallel threads don't modify the data @@ -198,6 +203,12 @@ struct DAE : public Pass { // The null name represents module-level code (not in a function). infoMap[Name()]; + numFunctions = module->functions.size(); + + for (Index i = 0; i < numFunctions; i++) { + indexes[module->functions[i]->name] = i; + } + // Iterate to convergence. while (1) { if (!iteration(module, infoMap)) { @@ -233,34 +244,36 @@ struct DAE : public Pass { Call* call; Function* func; }; - std::map> allCalls; - std::unordered_set tailCallees; - std::unordered_set hasUnseenCalls; + + std::vector> allCalls(numFunctions); + std::vector tailCallees(numFunctions); + std::vector hasUnseenCalls(numFunctions); + // Track the function in which relevant expressions exist. When we modify // those expressions we will need to mark the function's info as stale. std::unordered_map expressionFuncs; for (auto& [func, info] : infoMap) { for (auto& [name, calls] : info.calls) { - auto& allCallsToName = allCalls[name]; + auto& allCallsToName = allCalls[indexes[name]]; allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end()); for (auto* call : calls) { expressionFuncs[call] = func; } } for (auto& callee : info.tailCallees) { - tailCallees.insert(callee); + tailCallees[indexes[callee]] = true; } for (auto& [call, dropp] : info.droppedCalls) { allDroppedCalls[call] = dropp; } for (auto& name : info.hasUnseenCalls) { - hasUnseenCalls.insert(name); + hasUnseenCalls[indexes[name]] = true; } } // Exports are considered unseen calls. for (auto& curr : module->exports) { if (curr->kind == ExternalKind::Function) { - hasUnseenCalls.insert(*curr->getInternalName()); + hasUnseenCalls[indexes[*curr->getInternalName()]] = true; } } @@ -299,15 +312,24 @@ struct DAE : public Pass { // We now have a mapping of all call sites for each function, and can look // for optimization opportunities. - for (auto& [name, calls] : allCalls) { + for (Index index = 0; index < numFunctions; index++) { + auto* func = module->functions[index].get(); + if (func->imported()) { + continue; + } // We can only optimize if we see all the calls and can modify them. - if (hasUnseenCalls.count(name)) { + if (hasUnseenCalls[index]) { + continue; + } + auto& calls = allCalls[index]; + if (calls.empty()) { + // Nothing calls this, so it is not worth optimizing. continue; } - auto* func = module->getFunction(name); // Refine argument types before doing anything else. This does not // affect whether an argument is used or not, it just refines the type // where possible. + auto name = func->name; if (refineArgumentTypes(func, calls, module, infoMap[name])) { worthOptimizing.insert(func); markStale(func->name); @@ -315,7 +337,7 @@ struct DAE : public Pass { // Refine return types as well. if (refineReturnTypes(func, calls, module)) { refinedReturnTypes = true; - markStale(func->name); + markStale(name); markCallersStale(calls); } auto optimizedIndexes = @@ -336,21 +358,29 @@ struct DAE : public Pass { ReFinalize().run(getPassRunner(), module); } // We now know which parameters are unused, and can potentially remove them. - for (auto& [name, calls] : allCalls) { - if (hasUnseenCalls.count(name)) { + for (Index index = 0; index < numFunctions; index++) { + auto* func = module->functions[index].get(); + if (func->imported()) { + continue; + } + if (hasUnseenCalls[index]) { continue; } - auto* func = module->getFunction(name); auto numParams = func->getNumParams(); if (numParams == 0) { continue; } + auto& calls = allCalls[index]; + if (calls.empty()) { + continue; + } + auto name = func->name; auto [removedIndexes, outcome] = ParamUtils::removeParameters( {func}, infoMap[name].unusedParams, calls, {}, module, getPassRunner()); if (!removedIndexes.empty()) { // Success! worthOptimizing.insert(func); - markStale(func->name); + markStale(name); markCallersStale(calls); } if (outcome == ParamUtils::RemovalOutcome::Failure) { @@ -362,25 +392,28 @@ struct DAE : public Pass { // modified allCalls (we can't modify a call site twice in one iteration, // once to remove a param, once to drop the return value). if (worthOptimizing.empty()) { - for (auto& func : module->functions) { + for (Index index = 0; index < numFunctions; index++) { + auto& func = module->functions[index]; + if (func->imported()) { + continue; + } if (func->getResults() == Type::none) { continue; } - auto name = func->name; - if (hasUnseenCalls.count(name)) { + if (hasUnseenCalls[index]) { continue; } + auto name = func->name; if (infoMap[name].hasTailCalls) { continue; } - if (tailCallees.count(name)) { + if (tailCallees[index]) { continue; } - auto iter = allCalls.find(name); - if (iter == allCalls.end()) { + auto& calls = allCalls[index]; + if (calls.empty()) { continue; } - auto& calls = iter->second; bool allDropped = std::all_of(calls.begin(), calls.end(), [&](Call* call) { return allDroppedCalls.count(call); @@ -397,7 +430,7 @@ struct DAE : public Pass { // TODO Removing a drop may also open optimization opportunities in the // callers. worthOptimizing.insert(func.get()); - markStale(func->name); + markStale(name); markCallersStale(calls); } }