diff --git a/src/ir/effects.h b/src/ir/effects.h index 56cea0ad4fc..f7c0489dd4b 100644 --- a/src/ir/effects.h +++ b/src/ir/effects.h @@ -758,7 +758,8 @@ class EffectAnalyzer { parent.implicitTrap = true; const EffectAnalyzer* callTargetEffects = nullptr; - if (auto it = parent.module.indirectCallEffects.find(curr->heapType); + if (auto it = + parent.module.indirectCallEffects.find({curr->heapType, Inexact}); it != parent.module.indirectCallEffects.end()) { callTargetEffects = it->second.get(); } @@ -771,7 +772,8 @@ class EffectAnalyzer { const EffectAnalyzer* callTargetEffects = nullptr; if (auto it = parent.module.indirectCallEffects.find( - curr->target->type.getHeapType()); + {curr->target->type.getHeapType(), + curr->target->type.getExactness()}); it != parent.module.indirectCallEffects.end()) { callTargetEffects = it->second.get(); } diff --git a/src/ir/linear-execution.h b/src/ir/linear-execution.h index 9e69405ff7c..f67851d7e76 100644 --- a/src/ir/linear-execution.h +++ b/src/ir/linear-execution.h @@ -185,8 +185,10 @@ struct LinearExecutionWalker : public PostWalker { return true; } - auto* effects = find_or_null(self->getModule()->indirectCallEffects, - callRef->target->type.getHeapType()); + auto* effects = + find_or_null(self->getModule()->indirectCallEffects, + std::pair(callRef->target->type.getHeapType(), + callRef->target->type.getExactness())); if (!effects) { return false; } @@ -201,8 +203,9 @@ struct LinearExecutionWalker : public PostWalker { bool refutesThrowEffect = false; if (self->getModule()) { - if (auto* effects = find_or_null( - self->getModule()->indirectCallEffects, callIndirect->heapType); + if (auto* effects = + find_or_null(self->getModule()->indirectCallEffects, + std::pair(callIndirect->heapType, Inexact)); effects) { refutesThrowEffect = !(*effects)->throws_; } diff --git a/src/ir/type-updating.cpp b/src/ir/type-updating.cpp index be9f125e451..8dd30fbf956 100644 --- a/src/ir/type-updating.cpp +++ b/src/ir/type-updating.cpp @@ -329,16 +329,18 @@ void GlobalTypeRewriter::mapTypes(const TypeMap& oldToNewTypes) { // Update indirect call effects per type. // When A is rewritten to B, B inherits the effects of A and A loses its // effects. - std::unordered_map> + std::unordered_map, + std::shared_ptr> newTypeEffects; - for (auto& [oldType, oldEffects] : wasm.indirectCallEffects) { + for (auto& [oldTypeExact, oldEffects] : wasm.indirectCallEffects) { if (!oldEffects) { continue; } + auto [oldType, exactness] = oldTypeExact; auto newType = updater.getNew(oldType); std::shared_ptr& targetEffects = - newTypeEffects[newType]; + newTypeEffects[{newType, exactness}]; if (!targetEffects) { targetEffects = oldEffects; } else { diff --git a/src/passes/GlobalEffects.cpp b/src/passes/GlobalEffects.cpp index fd05d5afd6c..b3c5f8465a8 100644 --- a/src/passes/GlobalEffects.cpp +++ b/src/passes/GlobalEffects.cpp @@ -41,7 +41,7 @@ struct FuncInfo { std::unordered_set calledFunctions; // Types that are targets of indirect calls. - std::unordered_set indirectCalledTypes; + std::unordered_set> indirectCalledTypes; }; // Only funcs that are referenced may be the target of an indirect call. A @@ -161,19 +161,20 @@ std::map analyzeFuncs(Module& module, funcInfo.calledFunctions.insert(call->target); } else if (effects.calls && options.worldMode == WorldMode::Closed) { - HeapType type; + std::pair typeExact; if (auto* callRef = curr->dynCast()) { // call_ref on unreachable does not have a call effect, // so this must be a HeapType. - type = callRef->target->type.getHeapType(); + typeExact = {callRef->target->type.getHeapType(), + callRef->target->type.getExactness()}; } else if (auto* callIndirect = curr->dynCast()) { - type = callIndirect->heapType; + typeExact = {callIndirect->heapType, Inexact}; } else { funcInfo.effects = std::nullopt; return; } - funcInfo.indirectCalledTypes.insert(type); + funcInfo.indirectCalledTypes.insert(typeExact); } else if (effects.calls) { assert(options.worldMode == WorldMode::Open); funcInfo.effects = std::nullopt; @@ -195,19 +196,28 @@ std::map analyzeFuncs(Module& module, return std::move(analysis.map); } -using CallGraphNode = std::variant; +using CallGraphNode = std::variant>; // Call graph for indirect and direct calls. // // key (caller) -> value (callee) -// Function -> Function : direct call -// Function -> HeapType : indirect call to the given HeapType -// HeapType -> Function : The function `callee` has the type `caller`. The -// HeapType may essentially 'call' any of its -// potential implementations. -// HeapType -> HeapType : `callee` is a subtype of `caller`. A call_ref -// could target any subtype of the ref, so we need to -// aggregate effects of subtypes of the target type. +// Function -> Function : +// direct call +// Function -> HeapType : +// indirect call to the given HeapType (exact or inexact). +// HeapType -> Function : +// The function `callee` has the type `caller`. The HeapType may essentially +// 'call' any of its potential implementations. The HeapType is always Exact +// for these edges. +// HeapType -> HeapType : +// `callee` is a subtype of `caller`. An indirect call with an Inexact type +// could target any subtype of the ref, so we aggregate effects of subtypes of +// the target type. If B is a subtype of A, then we have edges: +// A (inexact) -> B (inexact) +// A (inexact) -> A (exact) +// B (inexact) -> B (exact) +// As a result, calls to (inexact A) include B's effects, and calls to +// (exact A) only include A's effects. // // If we're running in an open world, we only include Function -> Function // edges, and don't compute effects for indirect calls, conservatively assuming @@ -233,7 +243,7 @@ CallGraph buildCallGraph(const Module& module, return callGraph; } - std::unordered_set allFunctionTypes; + std::unordered_set> allFunctionTypes; for (const auto& [caller, callerInfo] : funcInfos) { auto& callees = callGraph[caller]; @@ -243,18 +253,19 @@ CallGraph buildCallGraph(const Module& module, } // Function -> Type - allFunctionTypes.insert(caller->type.getHeapType()); - for (HeapType calleeType : callerInfo.indirectCalledTypes) { - callees.insert(calleeType); + allFunctionTypes.insert(std::pair(caller->type.getHeapType(), Exact)); + for (auto calleeTypeExact : callerInfo.indirectCalledTypes) { + callees.insert(calleeTypeExact); // Add the key to ensure the lookup doesn't fail for indirect calls to // uninhabited types. - callGraph[calleeType]; + callGraph[calleeTypeExact]; + allFunctionTypes.insert(calleeTypeExact); } // Type -> Function if (referencedFuncs.contains(caller)) { - callGraph[caller->type.getHeapType()].insert(caller); + callGraph[std::pair(caller->type.getHeapType(), Exact)].insert(caller); } } @@ -262,18 +273,34 @@ CallGraph buildCallGraph(const Module& module, // Do a DFS up the type hierarchy for all function implementations. // We are essentially walking up each supertype chain and adding edges from // super -> subtype, but doing it via DFS to avoid repeated work. - Graph superTypeGraph(allFunctionTypes.begin(), - allFunctionTypes.end(), - [&callGraph](const auto& push, HeapType t) { - // Not needed except that during lookup we expect the - // key to exist. - callGraph[t]; - - if (auto super = t.getDeclaredSuperType()) { - callGraph[*super].insert(t); - push(*super); - } - }); + Graph superTypeGraph( + allFunctionTypes.begin(), + allFunctionTypes.end(), + [&callGraph](const auto& push, + std::pair typeAndExactness) { + // Not needed except that during lookup we expect the + // key to exist. + callGraph[typeAndExactness]; + + auto [type, exactness] = typeAndExactness; + + // The supertype of an exact type is its inexact type. + // The supertype of an inexact type is its normal inexact supertype. + switch (exactness) { + case Exact: { + callGraph[std::pair(type, Inexact)].insert(typeAndExactness); + push(std::pair(type, Inexact)); + break; + } + case Inexact: { + if (auto super = type.getDeclaredSuperType()) { + callGraph[std::pair(*super, Inexact)].insert(typeAndExactness); + push(std::pair(*super, Inexact)); + } + break; + } + } + }); (void)superTypeGraph.traverseDepthFirst(); return callGraph; @@ -310,8 +337,8 @@ void propagateEffects( const Module& module, const PassOptions& passOptions, std::map& funcInfos, - std::unordered_map>& - typeEffects, + std::unordered_map, + std::shared_ptr>& typeEffects, const CallGraph& callGraph) { // We only care about Functions that are roots, not types. // A type would be a root if a function exists with that type, but no-one @@ -410,7 +437,7 @@ void propagateEffects( // Assign each function's effects to its CC effects. for (auto node : cc) { - std::visit(overloaded{[&](HeapType type) { + std::visit(overloaded{[&](std::pair type) { if (ccEffects != UnknownEffects) { typeEffects[type] = ccEffects; } diff --git a/src/wasm.h b/src/wasm.h index 9f81336144d..5321eef18e2 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -2743,8 +2743,8 @@ class Module { // This data is only meaningful for indirect calls. If no indirect call // exists to a function, the data can be out of date (no effort is made to // clean up the data if e.g. all indirect calls to a function are removed). - // TODO: Account for exactness here. - std::unordered_map> + std::unordered_map, + std::shared_ptr> indirectCallEffects; MixedArena allocator; diff --git a/test/lit/passes/global-effects-closed-world.wast b/test/lit/passes/global-effects-closed-world.wast index 774a72d054f..041e391bd84 100644 --- a/test/lit/passes/global-effects-closed-world.wast +++ b/test/lit/passes/global-effects-closed-world.wast @@ -167,14 +167,12 @@ ) ;; CHECK: (func $calls-ref-with-exact-supertype (type $3) (param $func (ref (exact $super))) - ;; CHECK-NEXT: (call_ref $super - ;; CHECK-NEXT: (local.get $func) - ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (nop) ;; CHECK-NEXT: ) (func $calls-ref-with-exact-supertype (param $func (ref (exact $super))) ;; Same as above but this time our reference is the exact supertype - ;; so we know not to aggregate effects from the subtype. - ;; TODO: this case doesn't optimize today. Add exact ref support in the pass. + ;; so we know not to aggregate effects from the subtype. This can only + ;; call $nop-with-supertype which has no effects. (call_ref $super (local.get $func)) ) )