Skip to content

Commit

Permalink
wasm-metadce all the things (WebAssembly#6142)
Browse files Browse the repository at this point in the history
Remove hardcoded paths for globals/functions/etc. in favor of general code
paths that support all the module elements uniformly. As a result of that, we
now support all parts of wasm, such as tables and element segments, that
we didn't before.

This refactoring is NFC aside from adding functionality. Note that this reduces
the size of wasm-metadce by 10% while increasing its functionality - the
benefits of writing generic code.

To support this, add some trivial generic helpers to get or iterate over module
elements using their kind in a dynamic manner. Using them might make
wasm-metadce slightly slower, but I can't measure any difference.
  • Loading branch information
kripken committed Nov 30, 2023
1 parent a191d66 commit 42cddbf
Show file tree
Hide file tree
Showing 8 changed files with 222 additions and 141 deletions.
26 changes: 26 additions & 0 deletions src/ir/module-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,32 @@ template<typename T> inline void iterImportable(Module& wasm, T visitor) {
}
}

// Iterates over all module items. The visitor provided should have signature
// void(ModuleItemKind, Named*).
template<typename T> inline void iterModuleItems(Module& wasm, T visitor) {
for (auto& curr : wasm.functions) {
visitor(ModuleItemKind::Function, curr.get());
}
for (auto& curr : wasm.tables) {
visitor(ModuleItemKind::Table, curr.get());
}
for (auto& curr : wasm.memories) {
visitor(ModuleItemKind::Memory, curr.get());
}
for (auto& curr : wasm.globals) {
visitor(ModuleItemKind::Global, curr.get());
}
for (auto& curr : wasm.tags) {
visitor(ModuleItemKind::Tag, curr.get());
}
for (auto& curr : wasm.dataSegments) {
visitor(ModuleItemKind::DataSegment, curr.get());
}
for (auto& curr : wasm.elementSegments) {
visitor(ModuleItemKind::ElementSegment, curr.get());
}
}

// Helper class for performing an operation on all the functions in the module,
// in parallel, with an Info object for each one that can contain results of
// some computation that the operation performs.
Expand Down
223 changes: 82 additions & 141 deletions src/tools/wasm-metadce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,12 @@ struct MetaDCEGraph {
std::unordered_map<Name, DCENode> nodes;
std::unordered_set<Name> roots;

// export exported name => DCE name
std::unordered_map<Name, Name> exportToDCENode;
std::unordered_map<Name, Name> functionToDCENode; // function name => DCE name
std::unordered_map<Name, Name> globalToDCENode; // global name => DCE name
std::unordered_map<Name, Name> tagToDCENode; // tag name => DCE name

std::unordered_map<Name, Name> DCENodeToExport; // reverse maps
std::unordered_map<Name, Name> DCENodeToFunction;
std::unordered_map<Name, Name> DCENodeToGlobal;
std::unordered_map<Name, Name> DCENodeToTag;
using KindName = std::pair<ModuleItemKind, Name>;

// Kind and exported name => DCE name
std::unordered_map<KindName, Name> itemToDCENode;

// imports are not mapped 1:1 to DCE nodes in the wasm, since env.X might
// be imported twice, for example. So we don't map a DCE node to an Import,
Expand All @@ -85,18 +81,8 @@ struct MetaDCEGraph {
return std::string(module.str) + " (*) " + std::string(base.str);
}

ImportId getFunctionImportId(Name name) {
auto* imp = wasm.getFunction(name);
return getImportId(imp->module, imp->base);
}

ImportId getGlobalImportId(Name name) {
auto* imp = wasm.getGlobal(name);
return getImportId(imp->module, imp->base);
}

ImportId getTagImportId(Name name) {
auto* imp = wasm.getTag(name);
ImportId getImportId(ModuleItemKind kind, Name name) {
auto* imp = wasm.getImport(kind, name);
return getImportId(imp->module, imp->base);
}

Expand All @@ -107,79 +93,44 @@ struct MetaDCEGraph {

MetaDCEGraph(Module& wasm) : wasm(wasm) {}

std::unordered_map<ModuleItemKind, std::string> kindPrefixes = {
{ModuleItemKind::Function, "func"},
{ModuleItemKind::Table, "table"},
{ModuleItemKind::Memory, "memory"},
{ModuleItemKind::Global, "global"},
{ModuleItemKind::Tag, "tag"},
{ModuleItemKind::DataSegment, "dseg"},
{ModuleItemKind::ElementSegment, "eseg"}};

// populate the graph with info from the wasm, integrating with
// potentially-existing nodes for imports and exports that the graph may
// already contain.
void scanWebAssembly() {
// Add an entry for everything we might need ahead of time, so parallel work
// does not alter parent state, just adds to things pointed by it,
// independently (each thread will add for one function, etc.)
ModuleUtils::iterDefinedFunctions(wasm, [&](Function* func) {
auto dceName = getName("func", func->name.toString());
DCENodeToFunction[dceName] = func->name;
functionToDCENode[func->name] = dceName;
nodes[dceName] = DCENode(dceName);
});
ModuleUtils::iterDefinedGlobals(wasm, [&](Global* global) {
auto dceName = getName("global", global->name.toString());
DCENodeToGlobal[dceName] = global->name;
globalToDCENode[global->name] = dceName;
nodes[dceName] = DCENode(dceName);
});
ModuleUtils::iterDefinedTags(wasm, [&](Tag* tag) {
auto dceName = getName("tag", tag->name.toString());
DCENodeToTag[dceName] = tag->name;
tagToDCENode[tag->name] = dceName;
nodes[dceName] = DCENode(dceName);
});
// only process function, global, and tag imports - the table and memory are
// always there
ModuleUtils::iterImportedFunctions(wasm, [&](Function* import) {
auto id = getImportId(import->module, import->base);
if (importIdToDCENode.find(id) == importIdToDCENode.end()) {
auto dceName = getName("importId", import->name.toString());
importIdToDCENode[id] = dceName;
}
});
ModuleUtils::iterImportedGlobals(wasm, [&](Global* import) {
auto id = getImportId(import->module, import->base);
if (importIdToDCENode.find(id) == importIdToDCENode.end()) {
auto dceName = getName("importId", import->name.toString());
importIdToDCENode[id] = dceName;
}
});
ModuleUtils::iterImportedTags(wasm, [&](Tag* import) {
auto id = getImportId(import->module, import->base);
if (importIdToDCENode.find(id) == importIdToDCENode.end()) {
auto dceName = getName("importId", import->name.toString());
importIdToDCENode[id] = dceName;
ModuleUtils::iterModuleItems(wasm, [&](ModuleItemKind kind, Named* item) {
if (auto* import = wasm.getImportOrNull(kind, item->name)) {
auto id = getImportId(import->module, import->base);
if (importIdToDCENode.find(id) == importIdToDCENode.end()) {
auto dceName = getName("importId", import->name.toString());
importIdToDCENode[id] = dceName;
}
return;
}
auto dceName = getName(kindPrefixes[kind], item->name.toString());
itemToDCENode[{kind, item->name}] = dceName;
nodes[dceName] = DCENode(dceName);
});
for (auto& exp : wasm.exports) {
if (exportToDCENode.find(exp->name) == exportToDCENode.end()) {
auto dceName = getName("export", exp->name.toString());
DCENodeToExport[dceName] = exp->name;
exportToDCENode[exp->name] = dceName;
nodes[dceName] = DCENode(dceName);
}
// we can also link the export to the thing being exported
auto& node = nodes[exportToDCENode[exp->name]];
if (exp->kind == ExternalKind::Function) {
node.reaches.push_back(getFunctionDCEName(exp->value));
} else if (exp->kind == ExternalKind::Global) {
if (!wasm.getGlobal(exp->value)->imported()) {
node.reaches.push_back(globalToDCENode[exp->value]);
} else {
node.reaches.push_back(
importIdToDCENode[getGlobalImportId(exp->value)]);
}
} else if (exp->kind == ExternalKind::Tag) {
if (!wasm.getTag(exp->value)->imported()) {
node.reaches.push_back(tagToDCENode[exp->value]);
} else {
node.reaches.push_back(importIdToDCENode[getTagImportId(exp->value)]);
}
}
node.reaches.push_back(getDCEName(ModuleItemKind(exp->kind), exp->value));
}
// Add initializer dependencies
// if we provide a parent DCE name, that is who can reach what we see
Expand All @@ -193,7 +144,7 @@ struct MetaDCEGraph {
void visitRefFunc(RefFunc* curr) {
assert(!parentDceName.isNull());
parent->nodes[parentDceName].reaches.push_back(
parent->getFunctionDCEName(curr->func));
parent->getDCEName(ModuleItemKind::Function, curr->func));
}

private:
Expand All @@ -204,10 +155,11 @@ struct MetaDCEGraph {
Name dceName;
if (!getModule()->getGlobal(name)->imported()) {
// its a defined global
dceName = parent->globalToDCENode[name];
dceName = parent->itemToDCENode[{ModuleItemKind::Global, name}];
} else {
// it's an import.
dceName = parent->importIdToDCENode[parent->getGlobalImportId(name)];
dceName = parent->importIdToDCENode[parent->getImportId(
ModuleItemKind::Global, name)];
}
if (parentDceName.isNull()) {
parent->roots.insert(dceName);
Expand All @@ -217,26 +169,33 @@ struct MetaDCEGraph {
}
};
ModuleUtils::iterDefinedGlobals(wasm, [&](Global* global) {
InitScanner scanner(this, globalToDCENode[global->name]);
InitScanner scanner(
this, itemToDCENode[{ModuleItemKind::Global, global->name}]);
scanner.setModule(&wasm);
scanner.walk(global->init);
});
// we can't remove segments, so root what they need
// We can't remove active segments, so root them and what they use.
// TODO: treat them as in a cycle with their parent memory/table
InitScanner rooter(this, Name());
rooter.setModule(&wasm);
ModuleUtils::iterActiveElementSegments(wasm, [&](ElementSegment* segment) {
// TODO: currently, all functions in the table are roots, but we
// should add an option to refine that
ElementUtils::iterElementSegmentFunctionNames(
segment,
[&](Name name, Index) { roots.insert(getFunctionDCEName(name)); });
segment, [&](Name name, Index) {
roots.insert(getDCEName(ModuleItemKind::Function, name));
});
rooter.walk(segment->offset);
roots.insert(getDCEName(ModuleItemKind::ElementSegment, segment->name));
});
ModuleUtils::iterActiveDataSegments(wasm, [&](DataSegment* segment) {
rooter.walk(segment->offset);
roots.insert(getDCEName(ModuleItemKind::DataSegment, segment->name));
});
ModuleUtils::iterActiveDataSegments(
wasm, [&](DataSegment* segment) { rooter.walk(segment->offset); });

// A parallel scanner for function bodies
struct Scanner : public WalkerPass<PostWalker<Scanner>> {
struct Scanner : public WalkerPass<
PostWalker<Scanner, UnifiedExpressionVisitor<Scanner>>> {
bool isFunctionParallel() override { return true; }

Scanner(MetaDCEGraph* parent) : parent(parent) {}
Expand All @@ -245,62 +204,58 @@ struct MetaDCEGraph {
return std::make_unique<Scanner>(parent);
}

void visitCall(Call* curr) { handleFunction(curr->target); }
void visitRefFunc(RefFunc* curr) { handleFunction(curr->func); }
void visitGlobalGet(GlobalGet* curr) { handleGlobal(curr->name); }
void visitGlobalSet(GlobalSet* curr) { handleGlobal(curr->name); }
void visitThrow(Throw* curr) { handleTag(curr->tag); }
void visitTry(Try* curr) {
for (auto tag : curr->catchTags) {
handleTag(tag);
}
void visitExpression(Expression* curr) {
#define DELEGATE_ID curr->_id

#define DELEGATE_START(id) [[maybe_unused]] auto* cast = curr->cast<id>();

#define DELEGATE_GET_FIELD(id, field) cast->field

#define DELEGATE_FIELD_TYPE(id, field)
#define DELEGATE_FIELD_HEAPTYPE(id, field)
#define DELEGATE_FIELD_CHILD(id, field)
#define DELEGATE_FIELD_OPTIONAL_CHILD(id, field)
#define DELEGATE_FIELD_INT(id, field)
#define DELEGATE_FIELD_INT_ARRAY(id, field)
#define DELEGATE_FIELD_LITERAL(id, field)
#define DELEGATE_FIELD_NAME(id, field)
#define DELEGATE_FIELD_NAME_VECTOR(id, field)
#define DELEGATE_FIELD_SCOPE_NAME_DEF(id, field)
#define DELEGATE_FIELD_SCOPE_NAME_USE(id, field)
#define DELEGATE_FIELD_SCOPE_NAME_USE_VECTOR(id, field)
#define DELEGATE_FIELD_ADDRESS(id, field)

#define DELEGATE_FIELD_NAME_KIND(id, field, kind) \
if (cast->field.is()) { \
handle(kind, cast->field); \
}

#include "wasm-delegations-fields.def"
}

private:
MetaDCEGraph* parent;

void handleFunction(Name name) {
parent->nodes[parent->functionToDCENode[getFunction()->name]]
.reaches.push_back(parent->getFunctionDCEName(name));
}

void handleGlobal(Name name) {
if (!getFunction()) {
return; // non-function stuff (initializers) are handled separately
}
Name dceName;
if (!getModule()->getGlobal(name)->imported()) {
// its a global
dceName = parent->globalToDCENode[name];
} else {
// it's an import.
dceName = parent->importIdToDCENode[parent->getGlobalImportId(name)];
}
parent->nodes[parent->functionToDCENode[getFunction()->name]]
.reaches.push_back(dceName);
void handle(ModuleItemKind kind, Name name) {
getCurrentFunctionDCENode().reaches.push_back(
parent->getDCEName(kind, name));
}

void handleTag(Name name) {
Name dceName;
if (!getModule()->getTag(name)->imported()) {
dceName = parent->tagToDCENode[name];
} else {
dceName = parent->importIdToDCENode[parent->getTagImportId(name)];
}
parent->nodes[parent->functionToDCENode[getFunction()->name]]
.reaches.push_back(dceName);
DCENode& getCurrentFunctionDCENode() {
return parent->nodes[parent->itemToDCENode[{ModuleItemKind::Function,
getFunction()->name}]];
}
};

PassRunner runner(&wasm);
Scanner(this).run(&runner, &wasm);
}

Name getFunctionDCEName(Name name) {
if (!wasm.getFunction(name)->imported()) {
return functionToDCENode[name];
Name getDCEName(ModuleItemKind kind, Name name) {
if (wasm.getImportOrNull(kind, name)) {
return importIdToDCENode[getImportId(kind, name)];
} else {
return importIdToDCENode[getFunctionImportId(name)];
return itemToDCENode[{kind, name}];
}
}

Expand Down Expand Up @@ -394,19 +349,6 @@ struct MetaDCEGraph {
if (importMap.find(name) != importMap.end()) {
std::cout << " is import " << importMap[name] << '\n';
}
if (DCENodeToExport.find(name) != DCENodeToExport.end()) {
std::cout << " is export " << DCENodeToExport[name] << ", "
<< wasm.getExport(DCENodeToExport[name])->value << '\n';
}
if (DCENodeToFunction.find(name) != DCENodeToFunction.end()) {
std::cout << " is function " << DCENodeToFunction[name] << '\n';
}
if (DCENodeToGlobal.find(name) != DCENodeToGlobal.end()) {
std::cout << " is global " << DCENodeToGlobal[name] << '\n';
}
if (DCENodeToTag.find(name) != DCENodeToTag.end()) {
std::cout << " is tag " << DCENodeToTag[name] << '\n';
}
for (auto target : node.reaches) {
std::cout << " reaches: " << target << '\n';
}
Expand Down Expand Up @@ -606,7 +548,6 @@ int main(int argc, const char* argv[]) {
"for the form";
}
graph.exportToDCENode[exp->getIString()] = node.name;
graph.DCENodeToExport[node.name] = exp->getIString();
}
if (ref->has(IMPORT)) {
json::Ref imp = ref[IMPORT];
Expand Down
7 changes: 7 additions & 0 deletions src/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -2310,6 +2310,13 @@ class Module {
Global* getGlobalOrNull(Name name);
Tag* getTagOrNull(Name name);

// get* methods that are generic over the kind, that is, items are identified
// by their kind and their name. Otherwise, they are similar to the above
// get* methods. These return items that can be imports.
// TODO: Add methods for things that cannot be imports (segments).
Importable* getImport(ModuleItemKind kind, Name name);
Importable* getImportOrNull(ModuleItemKind kind, Name name);

Export* addExport(Export* curr);
Function* addFunction(Function* curr);
Global* addGlobal(Global* curr);
Expand Down
Loading

0 comments on commit 42cddbf

Please sign in to comment.