diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index 605b69eeb86..3de03a4ffe4 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -240,26 +240,21 @@ struct DAE : public Pass { scanner.walkModuleCode(module); // Scan all the functions. scanner.run(getPassRunner(), module); - // Combine all the info. - struct CallContext { - Call* call; - Function* func; - }; + // Combine all the info from the scan. std::vector> allCalls(numFunctions); std::vector tailCallees(numFunctions); std::vector hasUnseenCalls(numFunctions); - // Track the function in which relevant expressions exist. When we modify - // those expressions we will need to mark the function's info as stale. - std::unordered_map expressionFuncs; + // For each function, the set of callers. + std::vector> callers(numFunctions); + for (auto& [func, info] : infoMap) { for (auto& [name, calls] : info.calls) { - auto& allCallsToName = allCalls[indexes[name]]; + auto targetIndex = indexes[name]; + auto& allCallsToName = allCalls[targetIndex]; allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end()); - for (auto* call : calls) { - expressionFuncs[call] = func; - } + callers[targetIndex].insert(func); } for (auto& callee : info.tailCallees) { tailCallees[indexes[callee]] = true; @@ -305,9 +300,9 @@ struct DAE : public Pass { assert(func.is()); infoMap[func].markStale(); }; - auto markCallersStale = [&](const std::vector& calls) { - for (auto* call : calls) { - markStale(expressionFuncs[call]); + auto markCallersStale = [&](Index index) { + for (auto caller : callers[index]) { + markStale(caller); } }; @@ -339,7 +334,7 @@ struct DAE : public Pass { if (refineReturnTypes(func, calls, module)) { refinedReturnTypes = true; markStale(name); - markCallersStale(calls); + markCallersStale(index); } auto optimizedIndexes = ParamUtils::applyConstantValues({func}, calls, {}, module); @@ -382,7 +377,7 @@ struct DAE : public Pass { // Success! worthOptimizing.insert(func); markStale(name); - markCallersStale(calls); + markCallersStale(index); } if (outcome == ParamUtils::RemovalOutcome::Failure) { callTargetsToLocalize.insert(name); @@ -424,15 +419,15 @@ struct DAE : public Pass { } if (removeReturnValue(func.get(), calls, module)) { // We should optimize the callers. - for (auto* call : calls) { - worthOptimizing.insert(module->getFunction(expressionFuncs[call])); + for (auto caller : callers[index]) { + worthOptimizing.insert(module->getFunction(caller)); } } // TODO Removing a drop may also open optimization opportunities in the // callers. worthOptimizing.insert(func.get()); markStale(name); - markCallersStale(calls); + markCallersStale(index); } } if (!callTargetsToLocalize.empty()) {