diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index a8f39c5abdc..bc16fe08db3 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -4973,7 +4973,7 @@ static BinaryenFunctionRef addFunctionInternal(BinaryenModuleRef module, BinaryenExpressionRef body) { auto* ret = new Function; ret->setExplicitName(name); - ret->type = type; + ret->type = Type(type, NonNullable, Exact); for (BinaryenIndex i = 0; i < numVarTypes; i++) { ret->vars.push_back(Type(varTypes[i])); } @@ -5097,7 +5097,8 @@ void BinaryenAddFunctionImport(BinaryenModuleRef module, func->module = externalModuleName; func->base = externalBaseName; // TODO: Take a HeapType rather than params and results. - func->type = Signature(Type(params), Type(results)); + func->type = + Type(Signature(Type(params), Type(results)), NonNullable, Exact); ((Module*)module)->addFunction(std::move(func)); } else { // already exists so just set module and base @@ -5285,7 +5286,8 @@ BinaryenAddActiveElementSegment(BinaryenModuleRef module, Fatal() << "invalid function '" << funcNames[i] << "'."; } segment->data.push_back( - Builder(*(Module*)module).makeRefFunc(funcNames[i], func->type)); + Builder(*(Module*)module) + .makeRefFunc(funcNames[i], func->type.getHeapType())); } return ((Module*)module)->addElementSegment(std::move(segment)); } @@ -5302,7 +5304,8 @@ BinaryenAddPassiveElementSegment(BinaryenModuleRef module, Fatal() << "invalid function '" << funcNames[i] << "'."; } segment->data.push_back( - Builder(*(Module*)module).makeRefFunc(funcNames[i], func->type)); + Builder(*(Module*)module) + .makeRefFunc(funcNames[i], func->type.getHeapType())); } return ((Module*)module)->addElementSegment(std::move(segment)); } @@ -6017,10 +6020,10 @@ void BinaryenFunctionSetBody(BinaryenFunctionRef func, ((Function*)func)->body = (Expression*)body; } BinaryenHeapType BinaryenFunctionGetType(BinaryenFunctionRef func) { - return ((Function*)func)->type.getID(); + return ((Function*)func)->type.getHeapType().getID(); } void BinaryenFunctionSetType(BinaryenFunctionRef func, BinaryenHeapType type) { - ((Function*)func)->type = HeapType(type); + ((Function*)func)->type = Type(HeapType(type), NonNullable, Exact); } void BinaryenFunctionOptimize(BinaryenFunctionRef func, BinaryenModuleRef module) { diff --git a/src/ir/module-splitting.cpp b/src/ir/module-splitting.cpp index 43ad6ccde58..887d21b11cc 100644 --- a/src/ir/module-splitting.cpp +++ b/src/ir/module-splitting.cpp @@ -370,9 +370,10 @@ void ModuleSplitter::setupJSPI() { primary.removeExport(LOAD_SECONDARY_MODULE); } else { // Add an imported function to load the secondary module. - auto import = Builder::makeFunction(ModuleSplitting::LOAD_SECONDARY_MODULE, - Signature(Type::none, Type::none), - {}); + auto import = Builder::makeFunction( + ModuleSplitting::LOAD_SECONDARY_MODULE, + Type(Signature(Type::none, Type::none), NonNullable, Exact), + {}); import->module = ENV; import->base = ModuleSplitting::LOAD_SECONDARY_MODULE; primary.addFunction(std::move(import)); @@ -689,14 +690,15 @@ void ModuleSplitter::indirectCallsToSecondaryFunctions() { Builder builder(*getModule()); Index secIndex = parent.funcToSecondaryIndex.at(curr->target); auto* func = parent.secondaries.at(secIndex)->getFunction(curr->target); - auto tableSlot = parent.tableManager.getSlot(curr->target, func->type); + auto tableSlot = + parent.tableManager.getSlot(curr->target, func->type.getHeapType()); replaceCurrent(parent.maybeLoadSecondary( builder, builder.makeCallIndirect(tableSlot.tableName, tableSlot.makeExpr(parent.primary), curr->operands, - func->type, + func->type.getHeapType(), curr->isReturn))); } }; @@ -786,7 +788,8 @@ void ModuleSplitter::setupTablePatching() { primary, std::string("placeholder_") + placeholder->base.toString()); placeholder->hasExplicitName = true; placeholder->type = secondaryFunc->type; - elem = Builder(primary).makeRefFunc(placeholder->name, placeholder->type); + elem = Builder(primary).makeRefFunc(placeholder->name, + placeholder->type.getHeapType()); primary.addFunction(std::move(placeholder)); }); @@ -827,7 +830,8 @@ void ModuleSplitter::setupTablePatching() { // primarySeg->data[i] is a placeholder, so use the secondary // function. auto* func = replacement->second; - auto* ref = Builder(secondary).makeRefFunc(func->name, func->type); + auto* ref = Builder(secondary).makeRefFunc(func->name, + func->type.getHeapType()); secondaryElems.push_back(ref); ++replacement; } else if (auto* get = primarySeg->data[i]->dynCast()) { @@ -869,7 +873,7 @@ void ModuleSplitter::setupTablePatching() { } auto* func = curr->second; currData.push_back( - Builder(secondary).makeRefFunc(func->name, func->type)); + Builder(secondary).makeRefFunc(func->name, func->type.getHeapType())); } if (currData.size()) { finishSegment(); diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp index 582a2ad82df..5abbbecc01d 100644 --- a/src/ir/module-utils.cpp +++ b/src/ir/module-utils.cpp @@ -657,14 +657,14 @@ std::vector getPublicHeapTypes(Module& wasm) { // We can ignore call.without.effects, which is implemented as an import but // functionally is a call within the module. if (!Intrinsics(wasm).isCallWithoutEffects(func)) { - notePublic(func->type); + notePublic(func->type.getHeapType()); } }); for (auto& ex : wasm.exports) { switch (ex->kind) { case ExternalKind::Function: { auto* func = wasm.getFunction(*ex->getInternalName()); - notePublic(func->type); + notePublic(func->type.getHeapType()); continue; } case ExternalKind::Table: { diff --git a/src/ir/possible-contents.cpp b/src/ir/possible-contents.cpp index aba87d79db2..b9f2213fc9a 100644 --- a/src/ir/possible-contents.cpp +++ b/src/ir/possible-contents.cpp @@ -651,12 +651,13 @@ struct InfoCollector // actually have a RefFunc. auto* func = getModule()->getFunction(curr->func); for (Index i = 0; i < func->getParams().size(); i++) { - info.links.push_back( - {SignatureParamLocation{func->type, i}, ParamLocation{func, i}}); + info.links.push_back({SignatureParamLocation{func->type.getHeapType(), i}, + ParamLocation{func, i}}); } for (Index i = 0; i < func->getResults().size(); i++) { info.links.push_back( - {ResultLocation{func, i}, SignatureResultLocation{func->type, i}}); + {ResultLocation{func, i}, + SignatureResultLocation{func->type.getHeapType(), i}}); } if (!options.closedWorld) { @@ -1759,9 +1760,9 @@ void TNHOracle::infer() { continue; } while (1) { - typeFunctions[type].push_back(func.get()); - if (auto super = type.getDeclaredSuperType()) { - type = *super; + typeFunctions[type.getHeapType()].push_back(func.get()); + if (auto super = type.getHeapType().getDeclaredSuperType()) { + type = type.with(*super); } else { break; } @@ -1859,8 +1860,9 @@ void TNHOracle::infer() { // as other opts will make this call direct later, after which a // lot of other optimizations become possible anyhow. auto target = possibleTargets[0]->name; - info.inferences[call->target] = PossibleContents::literal( - Literal::makeFunc(target, wasm.getFunction(target)->type)); + info.inferences[call->target] = + PossibleContents::literal(Literal::makeFunc( + target, wasm.getFunction(target)->type.getHeapType())); continue; } diff --git a/src/ir/table-utils.h b/src/ir/table-utils.h index 76cc9f47951..6b9f6b5a8d0 100644 --- a/src/ir/table-utils.h +++ b/src/ir/table-utils.h @@ -92,7 +92,8 @@ inline Index append(Table& table, Name name, Module& wasm) { auto* func = wasm.getFunctionOrNull(name); assert(func != nullptr && "Cannot append non-existing function to a table."); - segment->data.push_back(Builder(wasm).makeRefFunc(name, func->type)); + segment->data.push_back( + Builder(wasm).makeRefFunc(name, func->type.getHeapType())); table.initial++; return tableIndex; } diff --git a/src/parser/contexts.h b/src/parser/contexts.h index ed8bebc02d8..122aaf0e1d6 100644 --- a/src/parser/contexts.h +++ b/src/parser/contexts.h @@ -1418,7 +1418,7 @@ struct ParseModuleTypesCtx : TypeParserCtx, if (!type.type.isSignature()) { return in.err(pos, "expected signature type"); } - f->type = type.type; + f->type = f->type.with(type.type); // If we are provided with too many names (more than the function has), we // will error on that later when we check the signature matches the type. // For now, avoid asserting in setLocalName. @@ -1601,7 +1601,7 @@ struct ParseDefsCtx : TypeParserCtx, AnnotationParserCtx { elems.push_back(expr); } void appendFuncElem(std::vector& elems, Name func) { - auto type = wasm.getFunction(func)->type; + auto type = wasm.getFunction(func)->type.getHeapType(); elems.push_back(builder.makeRefFunc(func, type)); } diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp index 5b393a66900..519049456cb 100644 --- a/src/passes/Directize.cpp +++ b/src/passes/Directize.cpp @@ -153,7 +153,7 @@ struct FunctionDirectizer : public WalkerPass> { return CallUtils::Trap{}; } auto* func = getModule()->getFunction(name); - if (!HeapType::isSubType(func->type, original->heapType)) { + if (!HeapType::isSubType(func->type.getHeapType(), original->heapType)) { return CallUtils::Trap{}; } return CallUtils::Known{name}; diff --git a/src/passes/FuncCastEmulation.cpp b/src/passes/FuncCastEmulation.cpp index c534e6e647e..17c12a5aebd 100644 --- a/src/passes/FuncCastEmulation.cpp +++ b/src/passes/FuncCastEmulation.cpp @@ -178,7 +178,7 @@ struct FuncCastEmulation : public Pass { } auto* thunk = iter->second; ref->func = thunk->name; - ref->finalize(thunk->type); + ref->finalize(thunk->type.getHeapType()); } } @@ -209,11 +209,11 @@ struct FuncCastEmulation : public Pass { for (Index i = 0; i < numParams; i++) { thunkParams.push_back(Type::i64); } - auto thunkFunc = - builder.makeFunction(thunk, - Signature(Type(thunkParams), Type::i64), - {}, // no vars - toABI(call, module)); + auto thunkFunc = builder.makeFunction( + thunk, + Type(Signature(Type(thunkParams), Type::i64), NonNullable, Exact), + {}, // no vars + toABI(call, module)); thunkFunc->hasExplicitName = true; return module->addFunction(std::move(thunkFunc)); } diff --git a/src/passes/GenerateDynCalls.cpp b/src/passes/GenerateDynCalls.cpp index c13656d4837..a49ff1408d9 100644 --- a/src/passes/GenerateDynCalls.cpp +++ b/src/passes/GenerateDynCalls.cpp @@ -61,7 +61,7 @@ struct GenerateDynCalls : public WalkerPass> { std::vector tableSegmentData; ElementUtils::iterElementSegmentFunctionNames( it->get(), [&](Name name, Index) { - generateDynCallThunk(wasm->getFunction(name)->type); + generateDynCallThunk(wasm->getFunction(name)->type.getHeapType()); }); } } @@ -70,7 +70,7 @@ struct GenerateDynCalls : public WalkerPass> { // Generate dynCalls for invokes if (func->imported() && func->module == ENV && func->base.startsWith("invoke_")) { - Signature sig = func->type.getSignature(); + Signature sig = func->type.getHeapType().getSignature(); // The first parameter is a pointer to the original function that's called // by the invoke, so skip it std::vector newParams(sig.params.begin() + 1, sig.params.end()); @@ -155,7 +155,11 @@ void GenerateDynCalls::generateDynCallThunk(HeapType funcType) { params.push_back(param); } auto f = builder.makeFunction( - name, std::move(namedParams), Signature(Type(params), sig.results), {}); + name, + std::move(namedParams), + Type(Signature(Type(params), sig.results), NonNullable, Exact), + {}, + nullptr); f->hasExplicitName = true; Expression* fptr = builder.makeLocalGet(0, table->addressType); std::vector args; diff --git a/src/passes/JSPI.cpp b/src/passes/JSPI.cpp index 2ff8aa6ca47..4f77074d2ab 100644 --- a/src/passes/JSPI.cpp +++ b/src/passes/JSPI.cpp @@ -97,10 +97,10 @@ struct JSPI : public Pass { if (wasmSplit) { // Make an import for the load secondary module function so a JSPI wrapper // version will be created. - auto import = - Builder::makeFunction(ModuleSplitting::LOAD_SECONDARY_MODULE, - Signature(Type::none, Type::none), - {}); + auto import = Builder::makeFunction( + ModuleSplitting::LOAD_SECONDARY_MODULE, + Type(Signature(Type::none, Type::none), NonNullable, Exact), + {}); import->module = ENV; import->base = ModuleSplitting::LOAD_SECONDARY_MODULE; module->addFunction(std::move(import)); @@ -152,7 +152,8 @@ struct JSPI : public Pass { continue; } auto* replacementRef = builder.makeRefFunc( - iter->second, module->getFunction(iter->second)->type); + iter->second, + module->getFunction(iter->second)->type.getHeapType()); segment->data[i] = replacementRef; } } @@ -213,12 +214,12 @@ struct JSPI : public Pass { block->list.push_back(builder.makeConst(0)); } block->finalize(); - auto wrapperFunc = - Builder::makeFunction(wrapperName, - std::move(namedWrapperParams), - Signature(Type(wrapperParams), resultsType), - {}, - block); + auto wrapperFunc = Builder::makeFunction( + wrapperName, + std::move(namedWrapperParams), + Type(Signature(Type(wrapperParams), resultsType), NonNullable, Exact), + {}, + block); return module->addFunction(std::move(wrapperFunc))->name; } @@ -276,7 +277,8 @@ struct JSPI : public Pass { block->finalize(); call->type = im->getResults(); stub->body = block; - wrapperIm->type = Signature(Type(params), call->type); + wrapperIm->type = + Type(Signature(Type(params), call->type), NonNullable, Exact); if (wasmSplit && im->name == ModuleSplitting::LOAD_SECONDARY_MODULE) { // In non-debug builds the name of the JSPI wrapper function for loading diff --git a/src/passes/LegalizeJSInterface.cpp b/src/passes/LegalizeJSInterface.cpp index 4509f280b23..b3525b3d5a8 100644 --- a/src/passes/LegalizeJSInterface.cpp +++ b/src/passes/LegalizeJSInterface.cpp @@ -148,7 +148,7 @@ struct LegalizeJSInterface : public Pass { } curr->func = iter->second->name; - curr->finalize(iter->second->type); + curr->finalize(iter->second->type.getHeapType()); } }; @@ -248,7 +248,8 @@ struct LegalizeJSInterface : public Pass { } Type resultsType = func->getResults() == Type::i64 ? Type::i32 : func->getResults(); - legal->type = Signature(Type(legalParams), resultsType); + legal->type = + Type(Signature(Type(legalParams), resultsType), NonNullable, Exact); if (func->getResults() == Type::i64) { auto index = Builder::addVar(legal, Name(), Type::i64); auto* block = builder.makeBlock(); @@ -307,7 +308,8 @@ struct LegalizeJSInterface : public Pass { call->type = im->getResults(); stub->body = call; } - legalIm->type = Signature(Type(params), call->type); + legalIm->type = + Type(Signature(Type(params), call->type), NonNullable, Exact); auto* stubPtr = stub.get(); if (!module->getFunctionOrNull(stub->name)) { @@ -331,7 +333,8 @@ struct LegalizeJSInterface : public Pass { return f; } // Failing that create a new function import. - auto import = Builder::makeFunction(name, Signature(params, results), {}); + auto import = Builder::makeFunction( + name, Type(Signature(params, results), NonNullable, Exact), {}); import->module = ENV; import->base = name; auto* ret = import.get(); @@ -374,7 +377,7 @@ struct LegalizeAndPruneJSInterface : public LegalizeJSInterface { // The params are allowed to be multivalue, but not the results. Otherwise // look for SIMD etc. - auto sig = func->type.getSignature(); + auto sig = func->getSig(); auto illegal = isIllegal(sig.results); illegal = illegal || std::any_of(sig.params.begin(), diff --git a/src/passes/MergeSimilarFunctions.cpp b/src/passes/MergeSimilarFunctions.cpp index 54dcd4034f0..1e908f49abb 100644 --- a/src/passes/MergeSimilarFunctions.cpp +++ b/src/passes/MergeSimilarFunctions.cpp @@ -119,7 +119,7 @@ struct ParamInfo { return (*literals)[0].type; } else if (auto callees = std::get_if>(&values)) { auto* callee = module->getFunction((*callees)[0]); - return Type(callee->type, NonNullable); + return Type(callee->type.getHeapType(), NonNullable); } else { WASM_UNREACHABLE("unexpected const value type"); } @@ -132,7 +132,7 @@ struct ParamInfo { return builder.makeConst((*literals)[index]); } else if (auto callees = std::get_if>(&values)) { auto fnName = (*callees)[index]; - auto heapType = module->getFunction(fnName)->type; + auto heapType = module->getFunction(fnName)->type.getHeapType(); return builder.makeRefFunc(fnName, heapType); } else { WASM_UNREACHABLE("unexpected const value type"); @@ -613,8 +613,8 @@ Function* EquivalentClass::createShared(Module* module, Expression* body = ExpressionManipulator::flexibleCopy(primaryFunction->body, *module, copier); auto vars = primaryFunction->vars; - std::unique_ptr f = - builder.makeFunction(fnName, sig, std::move(vars), body); + std::unique_ptr f = builder.makeFunction( + fnName, Type(sig, NonNullable, Exact), std::move(vars), body); return module->addFunction(std::move(f)); } diff --git a/src/passes/Monomorphize.cpp b/src/passes/Monomorphize.cpp index 2459acf1796..9f02d00402e 100644 --- a/src/passes/Monomorphize.cpp +++ b/src/passes/Monomorphize.cpp @@ -792,7 +792,8 @@ struct Monomorphize : public Pass { // If we were dropped then we are pulling the drop into the monomorphized // function, which means we return nothing. auto newResults = context.dropped ? Type::none : func->getResults(); - newFunc->type = Signature(Type(newParams), newResults); + newFunc->type = + Type(Signature(Type(newParams), newResults), NonNullable, Exact); // We must update local indexes: the new function has a potentially // different number of parameters, and parameters are at the very bottom of diff --git a/src/passes/Outlining.cpp b/src/passes/Outlining.cpp index 768cdc8c186..9bf2f190252 100644 --- a/src/passes/Outlining.cpp +++ b/src/passes/Outlining.cpp @@ -595,7 +595,7 @@ struct ReconstructStringifyWalker } // Add a local.get instruction for every parameter of the outlined function. - Signature sig = outlinedFunc->type.getSignature(); + Signature sig = outlinedFunc->getSig(); ODBG(std::cerr << outlinedFunc->name << " takes " << sig.params.size() << " parameters\n"); for (Index i = 0; i < sig.params.size(); i++) { @@ -725,8 +725,8 @@ struct Outlining : public Pass { exprIdx++) { sig += StackSignature(exprs[exprIdx]); } - module->addFunction( - Builder::makeFunction(func, Signature(sig.params, sig.results), {})); + module->addFunction(Builder::makeFunction( + func, Type(Signature(sig.params, sig.results), NonNullable, Exact), {})); return func; } diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index d4adeaa4636..19e8d9b3ee6 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -3079,9 +3079,9 @@ void PrintSExpression::handleSignature(Function* curr, printMajor(o, "func "); curr->name.print(o); if ((currModule && currModule->features.hasGC()) || - requiresExplicitFuncType(curr->type)) { + requiresExplicitFuncType(curr->type.getHeapType())) { o << " (type "; - printHeapTypeName(curr->type) << ')'; + printHeapTypeName(curr->type.getHeapType()) << ')'; } bool inParam = false; Index i = 0; diff --git a/src/passes/RemoveUnusedModuleElements.cpp b/src/passes/RemoveUnusedModuleElements.cpp index 44c1848355b..cf933c639f4 100644 --- a/src/passes/RemoveUnusedModuleElements.cpp +++ b/src/passes/RemoveUnusedModuleElements.cpp @@ -378,7 +378,7 @@ struct Analyzer { for (auto* item : segment->data) { if (auto* refFunc = item->dynCast()) { auto* func = module->getFunction(refFunc->func); - if (HeapType::isSubType(func->type, type)) { + if (HeapType::isSubType(func->type.getHeapType(), type)) { use({ModuleElementKind::Function, refFunc->func}); segmentReferenced = true; } @@ -402,7 +402,7 @@ struct Analyzer { // case where the target function is referenced but not used. auto element = ModuleElement{ModuleElementKind::Function, func}; - auto type = module->getFunction(func)->type; + auto type = module->getFunction(func)->type.getHeapType(); if (calledSignatures.count(type)) { // We must not have a type in both calledSignatures and // uncalledRefFuncMap: once it is called, we do not track RefFuncs for it diff --git a/src/passes/SignaturePruning.cpp b/src/passes/SignaturePruning.cpp index 274b6edcf1e..2330720b622 100644 --- a/src/passes/SignaturePruning.cpp +++ b/src/passes/SignaturePruning.cpp @@ -125,7 +125,8 @@ struct SignaturePruning : public Pass { // For direct calls, add each call to the type of the function being // called. for (auto* call : info.calls) { - allInfo[module->getFunction(call->target)->type].calls.push_back(call); + allInfo[module->getFunction(call->target)->type.getHeapType()] + .calls.push_back(call); // Intrinsics limit our ability to optimize in some cases. We will avoid // modifying any type that is used by call.without.effects, to avoid @@ -150,16 +151,16 @@ struct SignaturePruning : public Pass { // A parameter used in this function is used in the heap type - just one // function is enough to prevent the parameter from being removed. - auto& allUsedParams = allInfo[func->type].usedParams; + auto& allUsedParams = allInfo[func->type.getHeapType()].usedParams; for (auto index : info.usedParams) { allUsedParams.insert(index); } if (!info.optimizable) { - allInfo[func->type].optimizable = false; + allInfo[func->type.getHeapType()].optimizable = false; } - sigFuncs[func->type].push_back(func); + sigFuncs[func->type.getHeapType()].push_back(func); } // Find the public types, which cannot be modified. @@ -292,7 +293,7 @@ struct SignaturePruning : public Pass { // that, which would add more complexity in that method, undo the change // here. for (auto* func : funcs) { - func->type = type; + func->type = func->type.with(type); } } @@ -311,7 +312,7 @@ struct SignaturePruning : public Pass { for (auto* call : callTargetsToLocalize) { HeapType type; if (auto* c = call->dynCast()) { - type = module->getFunction(c->target)->type; + type = module->getFunction(c->target)->type.getHeapType(); } else if (auto* c = call->dynCast()) { type = c->target->type.getHeapType(); } else { diff --git a/src/passes/SignatureRefining.cpp b/src/passes/SignatureRefining.cpp index 270035ef680..bb45939fa3b 100644 --- a/src/passes/SignatureRefining.cpp +++ b/src/passes/SignatureRefining.cpp @@ -114,7 +114,8 @@ struct SignatureRefining : public Pass { // For direct calls, add each call to the type of the function being // called. for (auto* call : info.calls) { - allInfo[module->getFunction(call->target)->type].calls.push_back(call); + allInfo[module->getFunction(call->target)->type.getHeapType()] + .calls.push_back(call); // For call.without.effects, we also add the effective function being // called as well. The final operand is the function reference being @@ -138,11 +139,11 @@ struct SignatureRefining : public Pass { // Add the function's return LUB to the one for the heap type of that // function. - allInfo[func->type].resultsLUB.combine(info.resultsLUB); + allInfo[func->type.getHeapType()].resultsLUB.combine(info.resultsLUB); // If one function cannot be modified, that entire type cannot be. if (!info.canModify) { - allInfo[func->type].canModify = false; + allInfo[func->type.getHeapType()].canModify = false; } } @@ -180,16 +181,16 @@ struct SignatureRefining : public Pass { std::unordered_set seen; for (auto& func : module->functions) { auto type = func->type; - if (!seen.insert(type).second) { + if (!seen.insert(type.getHeapType()).second) { continue; } - auto& info = allInfo[type]; + auto& info = allInfo[type.getHeapType()]; if (!info.canModify) { continue; } - auto sig = type.getSignature(); + auto sig = type.getHeapType().getSignature(); auto numParams = sig.params.size(); std::vector paramLUBs(numParams); @@ -246,7 +247,7 @@ struct SignatureRefining : public Pass { } // We found an improvement! - newSignatures[type] = Signature(newParams, newResults); + newSignatures[type.getHeapType()] = Signature(newParams, newResults); if (newResults != func->getResults()) { // Update the types of calls using the signature. @@ -287,7 +288,7 @@ struct SignatureRefining : public Pass { } void doWalkFunction(Function* func) { - auto iter = parent.newSignatures.find(func->type); + auto iter = parent.newSignatures.find(func->type.getHeapType()); if (iter != parent.newSignatures.end()) { std::vector newParamsTypes; for (auto param : iter->second.params) { @@ -343,8 +344,8 @@ struct SignatureRefining : public Pass { } auto name = Names::getValidFunctionName(*module, import->name); - auto newImport = - module->addFunction(Builder(*module).makeFunction(name, newType, {})); + auto newImport = module->addFunction(Builder(*module).makeFunction( + name, Type(newType, NonNullable, Exact), {})); // Copy the binaryen intrinsic module.base import names. newImport->module = import->module; diff --git a/src/passes/StringLifting.cpp b/src/passes/StringLifting.cpp index 464ce7eafdb..7b8f456217f 100644 --- a/src/passes/StringLifting.cpp +++ b/src/passes/StringLifting.cpp @@ -131,63 +131,65 @@ struct StringLifting : public Pass { if (!func->imported() || func->module != WasmStringsModule) { continue; } + // TODO: Check exactness here too. auto type = func->type; if (func->base == "fromCharCodeArray") { - if (type != Signature({array16, i32, i32}, refExtern)) { + if (type.getHeapType() != Signature({array16, i32, i32}, refExtern)) { Fatal() << "StringLifting: bad type for fromCharCodeArray: " << type; } fromCharCodeArrayImport = func->name; found = true; } else if (func->base == "fromCodePoint") { - if (type != Signature(i32, refExtern)) { + if (type.getHeapType() != Signature(i32, refExtern)) { Fatal() << "StringLifting: bad type for fromCodePoint: " << type; } fromCodePointImport = func->name; found = true; } else if (func->base == "concat") { - if (type != Signature({externref, externref}, refExtern)) { + if (type.getHeapType() != + Signature({externref, externref}, refExtern)) { Fatal() << "StringLifting: bad type for concat: " << type; } concatImport = func->name; found = true; } else if (func->base == "intoCharCodeArray") { - if (type != Signature({externref, array16, i32}, i32)) { + if (type.getHeapType() != Signature({externref, array16, i32}, i32)) { Fatal() << "StringLifting: bad type for intoCharCodeArray: " << type; } intoCharCodeArrayImport = func->name; found = true; } else if (func->base == "equals") { - if (type != Signature({externref, externref}, i32)) { + if (type.getHeapType() != Signature({externref, externref}, i32)) { Fatal() << "StringLifting: bad type for equals: " << type; } equalsImport = func->name; found = true; } else if (func->base == "test") { - if (type != Signature({externref}, i32)) { + if (type.getHeapType() != Signature({externref}, i32)) { Fatal() << "StringLifting: bad type for test: " << type; } testImport = func->name; found = true; } else if (func->base == "compare") { - if (type != Signature({externref, externref}, i32)) { + if (type.getHeapType() != Signature({externref, externref}, i32)) { Fatal() << "StringLifting: bad type for compare: " << type; } compareImport = func->name; found = true; } else if (func->base == "length") { - if (type != Signature({externref}, i32)) { + if (type.getHeapType() != Signature({externref}, i32)) { Fatal() << "StringLifting: bad type for length: " << type; } lengthImport = func->name; found = true; } else if (func->base == "charCodeAt") { - if (type != Signature({externref, i32}, i32)) { + if (type.getHeapType() != Signature({externref, i32}, i32)) { Fatal() << "StringLifting: bad type for charCodeAt: " << type; } charCodeAtImport = func->name; found = true; } else if (func->base == "substring") { - if (type != Signature({externref, i32, i32}, refExtern)) { + if (type.getHeapType() != Signature({externref, i32, i32}, refExtern)) { Fatal() << "StringLifting: bad type for substring: " << type; } substringImport = func->name; diff --git a/src/passes/StringLowering.cpp b/src/passes/StringLowering.cpp index b19ccbc9d6e..fc93fdc17b3 100644 --- a/src/passes/StringLowering.cpp +++ b/src/passes/StringLowering.cpp @@ -305,7 +305,7 @@ struct StringLowering : public StringGathering { // function, which must be modified either in TypeMapper - but as just // explained we cannot do that - or before it, which is what we do here). for (auto& func : module->functions) { - if (func->type.getRecGroup().size() != 1 || + if (func->type.getHeapType().getRecGroup().size() != 1 || !func->type.getFeatures().hasStrings()) { continue; } @@ -320,18 +320,18 @@ struct StringLowering : public StringGathering { } return t; }; - for (auto param : func->type.getSignature().params) { + for (auto param : func->type.getHeapType().getSignature().params) { params.push_back(fix(param)); } - for (auto result : func->type.getSignature().results) { + for (auto result : func->type.getHeapType().getSignature().results) { results.push_back(fix(result)); } // In addition to doing the update, mark it in the map of updates for // TypeMapper, so RefFuncs with this type get updated. auto old = func->type; - func->type = Signature(params, results); - updates[old] = func->type; + func->type = func->type.with(Signature(params, results)); + updates[old.getHeapType()] = func->type.getHeapType(); } // Strings turn into externref. @@ -358,7 +358,8 @@ struct StringLowering : public StringGathering { auto name = Names::getValidFunctionName(*module, trueName); auto sig = Signature(params, results); Builder builder(*module); - auto* func = module->addFunction(builder.makeFunction(name, sig, {})); + auto* func = module->addFunction( + builder.makeFunction(name, Type(sig, NonNullable, Exact), {})); func->module = WasmStringsModule; func->base = trueName; return name; diff --git a/src/passes/TraceCalls.cpp b/src/passes/TraceCalls.cpp index 44bc16e95e6..9fa4e374dfa 100644 --- a/src/passes/TraceCalls.cpp +++ b/src/passes/TraceCalls.cpp @@ -95,7 +95,7 @@ struct AddTraceWrappers : public WalkerPass> { trackerCallParams.push_back(builder.makeLocalGet(localVar, op->type)); } - auto resultType = target->type.getSignature().results; + auto resultType = target->type.getHeapType().getSignature().results; auto realCall = builder.makeCall(target->name, realCallParams, resultType); if (resultType.isConcrete()) { @@ -146,7 +146,7 @@ struct TraceCalls : public Pass { private: Type getTracerParamsType(ImportInfo& info, const Function& func) { - auto resultsType = func.type.getSignature().results; + auto resultsType = func.type.getHeapType().getSignature().results; if (resultsType.isTuple()) { Fatal() << "Failed to instrument function '" << func.name << "': Multi-value result type is not supported"; @@ -156,7 +156,7 @@ struct TraceCalls : public Pass { if (resultsType.isConcrete()) { tracerParamTypes.push_back(resultsType); } - for (auto& op : func.type.getSignature().params) { + for (auto& op : func.type.getHeapType().getSignature().params) { tracerParamTypes.push_back(op); } @@ -205,7 +205,11 @@ struct TraceCalls : public Pass { if (!info.getImportedFunction(ENV, tracerName)) { auto import = Builder::makeFunction( - tracerName, Signature(getTracerParamsType(info, f), Type::none), {}); + tracerName, + Type(Signature(getTracerParamsType(info, f), Type::none), + NonNullable, + Exact), + {}); import->module = ENV; import->base = tracerName; wasm->addFunction(std::move(import)); diff --git a/src/passes/TrapMode.cpp b/src/passes/TrapMode.cpp index f269f081fd4..5c0bb708ec7 100644 --- a/src/passes/TrapMode.cpp +++ b/src/passes/TrapMode.cpp @@ -133,7 +133,8 @@ Function* generateBinaryFunc(Module& wasm, Binary* curr) { result); } auto funcSig = Signature({type, type}, type); - auto func = Builder::makeFunction(getBinaryFuncName(curr), funcSig, {}); + auto func = Builder::makeFunction( + getBinaryFuncName(curr), Type(funcSig, NonNullable, Exact), {}); func->body = builder.makeIf(builder.makeUnary(eqZOp, builder.makeLocalGet(1, type)), builder.makeConst(zeroLit), @@ -194,7 +195,9 @@ Function* generateUnaryFunc(Module& wasm, Unary* curr) { } auto func = - Builder::makeFunction(getUnaryFuncName(curr), Signature(type, retType), {}); + Builder::makeFunction(getUnaryFuncName(curr), + Type(Signature(type, retType), NonNullable, Exact), + {}); func->body = builder.makeUnary(truncOp, builder.makeLocalGet(0, type)); // too small XXX this is different than asm.js, which does frem. here we // clamp, which is much simpler/faster, and similar to native builds @@ -251,7 +254,7 @@ void ensureF64ToI64JSImport(TrappingFunctionContainer& trappingFunctions) { import->name = F64_TO_INT; import->module = ASM2WASM; import->base = F64_TO_INT; - import->type = Signature(Type::f64, Type::i32); + import->type = import->type.with(Signature(Type::f64, Type::i32)); trappingFunctions.addImport(import); } diff --git a/src/passes/param-utils.cpp b/src/passes/param-utils.cpp index e63606195ad..0b7aee94344 100644 --- a/src/passes/param-utils.cpp +++ b/src/passes/param-utils.cpp @@ -351,7 +351,8 @@ void localizeCallsTo(const std::unordered_set& callTargets, : callTargets(callTargets) {} void visitCall(Call* curr) { - handleCall(curr, getModule()->getFunction(curr->target)->type); + handleCall(curr, + getModule()->getFunction(curr->target)->type.getHeapType()); } void visitCallRef(CallRef* curr) { diff --git a/src/shell-interface.h b/src/shell-interface.h index 2c38f6456d4..8c45efff7d9 100644 --- a/src/shell-interface.h +++ b/src/shell-interface.h @@ -148,7 +148,7 @@ struct ShellExternalInterface : ModuleRunner::ExternalInterface { } return Flow(); }), - import->type); + import->type.getHeapType()); } else if (import->module == ENV && import->base == EXIT) { return Literal(std::make_shared(import->name, nullptr, @@ -157,7 +157,7 @@ struct ShellExternalInterface : ModuleRunner::ExternalInterface { std::cout << "exit()\n"; throw ExitException(); }), - import->type); + import->type.getHeapType()); } else if (auto* inst = getImportInstance(import)) { return inst->getExportedFunction(import->base); } diff --git a/src/tools/execution-results.h b/src/tools/execution-results.h index 8f30cf245a8..722050afa73 100644 --- a/src/tools/execution-results.h +++ b/src/tools/execution-results.h @@ -211,7 +211,7 @@ struct LoggingExternalInterface : public ShellExternalInterface { }; // Use a null instance because this is a host function. return Literal(std::make_shared(import->name, nullptr, f), - import->type); + import->type.getHeapType()); } void throwJSException() { diff --git a/src/tools/fuzzing/fuzzing.cpp b/src/tools/fuzzing/fuzzing.cpp index 8f3775f3a10..60c3d4edcba 100644 --- a/src/tools/fuzzing/fuzzing.cpp +++ b/src/tools/fuzzing/fuzzing.cpp @@ -928,7 +928,7 @@ void TranslateToFuzzReader::addImportLoggingSupport() { // simpler than avoiding calls to logging in all the rest of the logic). func->body = builder.makeNop(); } - func->type = Signature(type, Type::none); + func->type = Type(Signature(type, Type::none), NonNullable, Exact); wasm.addFunction(std::move(func)); } } @@ -976,7 +976,8 @@ void TranslateToFuzzReader::addImportCallingSupport() { func->name = callExportImportName; func->module = "fuzzing-support"; func->base = "call-export"; - func->type = Signature({Type::i32, Type::i32}, Type::none); + func->type = + Type(Signature({Type::i32, Type::i32}, Type::none), NonNullable, Exact); wasm.addFunction(std::move(func)); } @@ -990,7 +991,7 @@ void TranslateToFuzzReader::addImportCallingSupport() { func->name = callExportCatchImportName; func->module = "fuzzing-support"; func->base = "call-export-catch"; - func->type = Signature(Type::i32, Type::i32); + func->type = Type(Signature(Type::i32, Type::i32), NonNullable, Exact); wasm.addFunction(std::move(func)); } @@ -1007,7 +1008,9 @@ void TranslateToFuzzReader::addImportCallingSupport() { // As call-export, there is a flags param that allows us to catch+rethrow // all exceptions. func->type = - Signature({Type(HeapType::func, Nullable), Type::i32}, Type::none); + Type(Signature({Type(HeapType::func, Nullable), Type::i32}, Type::none), + NonNullable, + Exact); wasm.addFunction(std::move(func)); } @@ -1020,7 +1023,9 @@ void TranslateToFuzzReader::addImportCallingSupport() { func->name = callRefCatchImportName; func->module = "fuzzing-support"; func->base = "call-ref-catch"; - func->type = Signature(Type(HeapType::func, Nullable), Type::i32); + func->type = Type(Signature(Type(HeapType::func, Nullable), Type::i32), + NonNullable, + Exact); wasm.addFunction(std::move(func)); } } @@ -1042,7 +1047,7 @@ void TranslateToFuzzReader::addImportThrowingSupport() { // As with logging, implement in a trivial way when we cannot add imports. func->body = builder.makeNop(); } - func->type = Signature(Type::i32, Type::none); + func->type = Type(Signature(Type::i32, Type::none), NonNullable, Exact); wasm.addFunction(std::move(func)); } @@ -1081,7 +1086,9 @@ void TranslateToFuzzReader::addImportTableSupport() { func->name = tableGetImportName; func->module = "fuzzing-support"; func->base = "table-get"; - func->type = Signature({Type::i32}, Type(HeapType::func, Nullable)); + func->type = Type(Signature({Type::i32}, Type(HeapType::func, Nullable)), + NonNullable, + Exact); wasm.addFunction(std::move(func)); } @@ -1093,7 +1100,9 @@ void TranslateToFuzzReader::addImportTableSupport() { func->module = "fuzzing-support"; func->base = "table-set"; func->type = - Signature({Type::i32, Type(HeapType::func, Nullable)}, Type::none); + Type(Signature({Type::i32, Type(HeapType::func, Nullable)}, Type::none), + NonNullable, + Exact); wasm.addFunction(std::move(func)); } } @@ -1113,7 +1122,8 @@ void TranslateToFuzzReader::addImportSleepSupport() { func->name = sleepImportName; func->module = "fuzzing-support"; func->base = "sleep"; - func->type = Signature({Type::i32, Type::i32}, Type::i32); + func->type = + Type(Signature({Type::i32, Type::i32}, Type::i32), NonNullable, Exact); wasm.addFunction(std::move(func)); } @@ -1159,7 +1169,10 @@ void TranslateToFuzzReader::addHashMemorySupport() { auto* body = builder.makeBlock(contents); hashMemoryName = Names::getValidFunctionName(wasm, "hashMemory"); auto* hasher = wasm.addFunction(builder.makeFunction( - hashMemoryName, Signature(Type::none, Type::i32), {Type::i32}, body)); + hashMemoryName, + Type(Signature(Type::none, Type::i32), NonNullable, Exact), + {Type::i32}, + body)); if (!preserveImportsAndExports && !wasm.getExportOrNull("hashMemory")) { wasm.addExport( @@ -1192,7 +1205,8 @@ void TranslateToFuzzReader::useImportedModule() { auto name = Names::getValidFunctionName(wasm, "primary_" + exp->name.toString()); // We can import it as its own type, or any (declared) supertype. - auto type = getSuperType(func->type); + // TODO: this will be inexact eventually + auto type = getSuperType(func->type).with(NonNullable).with(Exact); auto import = builder.makeFunction(name, type, {}); import->module = "primary"; import->base = exp->name; @@ -1557,7 +1571,7 @@ Function* TranslateToFuzzReader::addFunction() { auto resultType = getControlFlowType(); funcType = Signature(paramType, resultType); } - func->type = *funcType; + func->type = Type(*funcType, NonNullable, Exact); Index numVars = upToSquared(fuzzParams->MAX_VARS); for (Index i = 0; i < numVars; i++) { @@ -1602,7 +1616,7 @@ Function* TranslateToFuzzReader::addFunction() { } // add some to an elem segment TODO we could do this for imported funcs too while (oneIn(3) && !random.finished()) { - auto type = Type(func->type, NonNullable); + auto type = func->type; std::vector compatibleSegments; ModuleUtils::iterActiveElementSegments(wasm, [&](ElementSegment* segment) { if (Type::isSubType(type, segment->type)) { @@ -1610,7 +1624,8 @@ Function* TranslateToFuzzReader::addFunction() { } }); auto& randomElem = compatibleSegments[upTo(compatibleSegments.size())]; - randomElem->data.push_back(builder.makeRefFunc(func->name, func->type)); + randomElem->data.push_back( + builder.makeRefFunc(func->name, func->type.getHeapType())); } numAddedFunctions++; return func; @@ -2184,7 +2199,8 @@ void TranslateToFuzzReader::addInvocations(Function* func) { if (wasm.getFunctionOrNull(name) || wasm.getExportOrNull(name)) { return; } - auto invoker = builder.makeFunction(name, Signature(), {}); + auto invoker = + builder.makeFunction(name, Type(Signature(), NonNullable, Exact), {}); Block* body = builder.makeBlock(); invoker->body = body; FunctionCreationContext context(*this, invoker.get()); @@ -2858,7 +2874,7 @@ Expression* TranslateToFuzzReader::makeCallIndirect(Type type) { } // TODO: use a random table return builder.makeCallIndirect( - funcrefTableName, target, args, targetFn->type, isReturn); + funcrefTableName, target, args, targetFn->type.getHeapType(), isReturn); } Expression* TranslateToFuzzReader::makeCallRef(Type type) { @@ -2886,7 +2902,10 @@ Expression* TranslateToFuzzReader::makeCallRef(Type type) { } // TODO: half the time make a completely random item with that type. return builder.makeCallRef( - builder.makeRefFunc(target->name, target->type), args, type, isReturn); + builder.makeRefFunc(target->name, target->type.getHeapType()), + args, + type, + isReturn); } Expression* TranslateToFuzzReader::makeLocalGet(Type type) { @@ -3507,10 +3526,11 @@ Expression* TranslateToFuzzReader::makeRefFuncConst(Type type) { assert(heapType.getBasic(Unshared) == HeapType::func); // With high probability, use the last created function if possible. // Otherwise, continue on to select some other function. - if (funcContext && funcContext->func->type.getShared() == share && + if (funcContext && + funcContext->func->type.getHeapType().getShared() == share && !oneIn(4)) { auto* target = funcContext->func; - return builder.makeRefFunc(target->name, target->type); + return builder.makeRefFunc(target->name, target->type.getHeapType()); } } // Look for a proper function starting from a random location, and loop from @@ -3520,8 +3540,8 @@ Expression* TranslateToFuzzReader::makeRefFuncConst(Type type) { Index i = start; do { auto& func = wasm.functions[i]; - if (Type::isSubType(Type(func->type, NonNullable), type)) { - return builder.makeRefFunc(func->name, func->type); + if (Type::isSubType(func->type, type)) { + return builder.makeRefFunc(func->name, func->type.getHeapType()); } i = (i + 1) % wasm.functions.size(); } while (i != start); @@ -3555,8 +3575,11 @@ Expression* TranslateToFuzzReader::makeRefFuncConst(Type type) { auto* body = heapType.getSignature().results == Type::none ? (Expression*)builder.makeNop() : (Expression*)builder.makeUnreachable(); - auto* func = wasm.addFunction(builder.makeFunction( - Names::getValidFunctionName(wasm, "ref_func_target"), heapType, {}, body)); + auto* func = wasm.addFunction( + builder.makeFunction(Names::getValidFunctionName(wasm, "ref_func_target"), + Type(heapType, NonNullable, Exact), + {}, + body)); return builder.makeRefFunc(func->name, heapType); } diff --git a/src/tools/wasm-ctor-eval.cpp b/src/tools/wasm-ctor-eval.cpp index 52cb9ba7039..427ed2d5dcb 100644 --- a/src/tools/wasm-ctor-eval.cpp +++ b/src/tools/wasm-ctor-eval.cpp @@ -320,7 +320,7 @@ struct CtorEvalExternalInterface : EvallingModuleRunner::ExternalInterface { // Use a null instance because these are either host functions or imported // from unknown sources. return Literal(std::make_shared(import->name, nullptr, f), - import->type); + import->type.getHeapType()); } Tag* getImportedTag(Tag* tag) override { @@ -1301,7 +1301,8 @@ EvalCtorOutcome evalCtor(EvallingModuleRunner& instance, // signature. If there is a mismatch, shift the local indices to make room // for the unused parameters. std::vector localTypes; - auto originalParams = originalFuncType.getSignature().params; + auto originalParams = + originalFuncType.getHeapType().getSignature().params; if (originalParams != func->getParams()) { // Add locals for the body to use instead of using the params. for (auto type : func->getParams()) { diff --git a/src/tools/wasm-merge.cpp b/src/tools/wasm-merge.cpp index bfce9f0cdc2..958c45fce3c 100644 --- a/src/tools/wasm-merge.cpp +++ b/src/tools/wasm-merge.cpp @@ -480,7 +480,9 @@ void fuseImportsAndExports(const PassOptions& options) { [import->module][import->base]; if (internalName.is()) { auto* export_ = merged.getFunction(internalName); - if (!HeapType::isSubType(export_->type, import->type)) { + // TODO: use Type subtyping when exactness handling is complete. + if (!HeapType::isSubType(export_->type.getHeapType(), + import->type.getHeapType())) { reportTypeMismatch(valid, "function", import); std::cerr << "type " << export_->type << " is not a subtype of " << import->type << ".\n"; @@ -569,7 +571,7 @@ void updateTypes(Module& wasm) { } void visitRefFunc(RefFunc* curr) { - curr->finalize(getModule()->getFunction(curr->func)->type); + curr->finalize(getModule()->getFunction(curr->func)->type.getHeapType()); } void visitFunction(Function* curr) { diff --git a/src/tools/wasm-reduce/wasm-reduce.cpp b/src/tools/wasm-reduce/wasm-reduce.cpp index d69d698b601..79696d6a7ba 100644 --- a/src/tools/wasm-reduce/wasm-reduce.cpp +++ b/src/tools/wasm-reduce/wasm-reduce.cpp @@ -1033,7 +1033,11 @@ struct Reducer } // Try to replace the body with the child, fixing up the function // to accept it. - func->type = Signature(funcType.getSignature().params, child->type); + func->type = + Type(Signature(funcType.getHeapType().getSignature().params, + child->type), + NonNullable, + Exact); func->body = child; if (writeAndTestReduction()) { // great, we succeeded! diff --git a/src/wasm-builder.h b/src/wasm-builder.h index c919bc29486..5689df5e2be 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -44,7 +44,7 @@ class Builder { // make* functions create an expression instance. static std::unique_ptr makeFunction(Name name, - HeapType type, + Type type, std::vector&& vars, Expression* body = nullptr) { assert(type.isSignature()); @@ -57,8 +57,16 @@ class Builder { } static std::unique_ptr makeFunction(Name name, - std::vector&& params, HeapType type, + std::vector&& vars, + Expression* body = nullptr) { + return makeFunction( + name, Type(type, NonNullable, Exact), std::move(vars), body); + } + + static std::unique_ptr makeFunction(Name name, + std::vector&& params, + Type type, std::vector&& vars, Expression* body = nullptr) { assert(type.isSignature()); @@ -82,6 +90,18 @@ class Builder { return func; } + static std::unique_ptr makeFunction(Name name, + std::vector&& params, + HeapType type, + std::vector&& vars, + Expression* body = nullptr) { + return makeFunction(name, + std::move(params), + Type(type, NonNullable, Exact), + std::move(vars), + body); + } + static std::unique_ptr makeTable(Name name, Type type = Type(HeapType::func, Nullable), @@ -1389,7 +1409,7 @@ class Builder { Signature sig = func->getSig(); std::vector params(sig.params.begin(), sig.params.end()); params.push_back(type); - func->type = Signature(Type(params), sig.results); + func->type = func->type.with(Signature(Type(params), sig.results)); Index index = func->localNames.size(); func->localIndices[name] = index; func->localNames[index] = name; diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 918e7f83674..724d73dfb29 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -3202,7 +3202,7 @@ class ModuleRunnerBase : public ExpressionRunner { [this, func](const Literals& arguments) -> Flow { return callFunction(func->name, arguments); }), - func->type); + func->type.getHeapType()); } // get an exported global @@ -3504,14 +3504,14 @@ class ModuleRunnerBase : public ExpressionRunner { // The call.without.effects intrinsic is a call to an import that actually // calls the given function reference that is the final argument. target = arguments.back().getFunc(); - funcType = arguments.back().type.getHeapType(); + funcType = arguments.back().type; arguments.pop_back(); } if (curr->isReturn) { // Return calls are represented by their arguments followed by a reference // to the function to be called. - arguments.push_back(self()->makeFuncData(target, funcType)); + arguments.push_back(self()->makeFuncData(target, funcType.getHeapType())); return Flow(RETURN_CALL_FLOW, std::move(arguments)); } @@ -4470,7 +4470,8 @@ class ModuleRunnerBase : public ExpressionRunner { auto funcName = funcValue.getFunc(); auto* func = self()->getModule()->getFunction(funcName); return Literal(std::make_shared( - self()->makeFuncData(func->name, func->type), curr->type.getHeapType())); + self()->makeFuncData(func->name, func->type.getHeapType()), + curr->type.getHeapType())); } Flow visitContBind(ContBind* curr) { Literals arguments; @@ -4815,7 +4816,7 @@ class ModuleRunnerBase : public ExpressionRunner { // not the original function that was called, and the original has been // returned from already; we should call the last return_called // function). - auto target = self()->makeFuncData(name, function->type); + auto target = self()->makeFuncData(name, function->type.getHeapType()); self()->pushResumeEntry({target}, "function-target"); } diff --git a/src/wasm-ir-builder.h b/src/wasm-ir-builder.h index c1f01e955a0..b79061d8ab9 100644 --- a/src/wasm-ir-builder.h +++ b/src/wasm-ir-builder.h @@ -537,7 +537,7 @@ class IRBuilder : public UnifiedExpressionVisitor> { } Type getResultType() { if (auto* func = getFunction()) { - return func->type.getSignature().results; + return func->getResults(); } if (auto* block = getBlock()) { return block->type; diff --git a/src/wasm.h b/src/wasm.h index 8ae52168304..e1bec596219 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -2224,7 +2224,9 @@ class EffectAnalyzer; class Function : public Importable { public: - HeapType type = HeapType(Signature()); // parameters and return value + // A non-nullable reference to a function type. Exact for defined functions. + // TODO: Inexact for imported functions. + Type type = Type(Signature(), NonNullable, Exact); IRProfile profile = IRProfile::Normal; std::vector vars; // non-param locals @@ -2307,11 +2309,15 @@ class Function : public Importable { bool noPartialInline = false; // Methods - Signature getSig() { return type.getSignature(); } + Signature getSig() { return type.getHeapType().getSignature(); } Type getParams() { return getSig().params; } Type getResults() { return getSig().results; } - void setParams(Type params) { type = Signature(params, getResults()); } - void setResults(Type results) { type = Signature(getParams(), results); } + void setParams(Type params) { + type = type.with(Signature(params, getResults())); + } + void setResults(Type results) { + type = type.with(Signature(getParams(), results)); + } size_t getNumParams(); size_t getNumVars(); diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 848f8e7473e..ebab1eb24b2 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -333,7 +333,7 @@ void WasmBinaryWriter::writeImports() { ModuleUtils::iterImportedFunctions(*wasm, [&](Function* func) { writeImportHeader(func); o << U32LEB(int32_t(ExternalKind::Function)); - o << U32LEB(getTypeIndex(func->type)); + o << U32LEB(getTypeIndex(func->type.getHeapType())); }); ModuleUtils::iterImportedGlobals(*wasm, [&](Global* global) { writeImportHeader(global); @@ -375,8 +375,9 @@ void WasmBinaryWriter::writeFunctionSignatures() { } auto start = startSection(BinaryConsts::Section::Function); o << U32LEB(importInfo->getNumDefinedFunctions()); - ModuleUtils::iterDefinedFunctions( - *wasm, [&](Function* func) { o << U32LEB(getTypeIndex(func->type)); }); + ModuleUtils::iterDefinedFunctions(*wasm, [&](Function* func) { + o << U32LEB(getTypeIndex(func->type.getHeapType())); + }); finishSection(start); } @@ -2893,7 +2894,8 @@ void WasmBinaryReader::readImports() { '.' + base.toString() + "'s type must be a signature. Given: " + type.toString()); } - auto curr = builder.makeFunction(name, type, {}); + auto curr = + builder.makeFunction(name, Type(type, NonNullable, Exact), {}); curr->hasExplicitName = isExplicit; curr->module = module; curr->base = base; @@ -3028,7 +3030,8 @@ void WasmBinaryReader::readFunctionSignatures() { functionTypes.push_back(type); // Check that the type is a signature. getSignatureByTypeIndex(index); - auto func = Builder(wasm).makeFunction(name, type, {}, nullptr); + auto func = Builder(wasm).makeFunction( + name, Type(type, NonNullable, Exact), {}, nullptr); func->hasExplicitName = isExplicit; wasm.addFunction(std::move(func)); } diff --git a/src/wasm/wasm-ir-builder.cpp b/src/wasm/wasm-ir-builder.cpp index 32042f1c902..512551b1eb7 100644 --- a/src/wasm/wasm-ir-builder.cpp +++ b/src/wasm/wasm-ir-builder.cpp @@ -1736,7 +1736,7 @@ Result<> IRBuilder::makeRefIsNull() { } Result<> IRBuilder::makeRefFunc(Name func) { - push(builder.makeRefFunc(func, wasm.getFunction(func)->type)); + push(builder.makeRefFunc(func, wasm.getFunction(func)->type.getHeapType())); return Ok{}; } diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 576cc476934..c8466b8c792 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -965,7 +965,7 @@ void FunctionValidator::visitCall(Call* curr) { if (!shouldBeTrue(!!target, curr, "call target must exist")) { return; } - validateCallParamsAndResult(curr, target->type); + validateCallParamsAndResult(curr, target->type.getHeapType()); if (Intrinsics(*getModule()).isCallWithoutEffects(curr)) { // call.without.effects has the specific form of the last argument being a @@ -2385,9 +2385,10 @@ void FunctionValidator::visitRefFunc(RefFunc* curr) { if (!shouldBeTrue(!!func, curr, "function argument of ref.func must exist")) { return; } - shouldBeTrue(func->type == curr->type.getHeapType(), - curr, - "function reference type must match referenced function type"); + shouldBeEqual(curr->type, + func->type, + curr, + "function reference type must match referenced function type"); shouldBeTrue( curr->type.isExact(), curr, "function reference should be exact"); } diff --git a/test/example/local-graph.cpp b/test/example/local-graph.cpp index 5f82982e3ae..10e59487651 100644 --- a/test/example/local-graph.cpp +++ b/test/example/local-graph.cpp @@ -12,7 +12,7 @@ int main() { { Function foo; - foo.type = Signature(Type::none, Type::none); + foo.type = Type(Signature(Type::none, Type::none), NonNullable, Exact); foo.vars = {Type::i32}; auto* get1 = builder.makeLocalGet(0, Type::i32); auto* get2 = builder.makeLocalGet(0, Type::i32); @@ -28,7 +28,7 @@ int main() { { Function foo; - foo.type = Signature(Type::none, Type::none); + foo.type = Type(Signature(Type::none, Type::none), NonNullable, Exact); foo.vars = {Type::i32}; auto* get1 = builder.makeLocalGet(0, Type::i32); auto* get2 = builder.makeLocalGet(0, Type::i32); @@ -44,7 +44,7 @@ int main() { { Function foo; - foo.type = Signature({Type::i32}, Type::none); + foo.type = Type(Signature({Type::i32}, Type::none), NonNullable, Exact); auto* get1 = builder.makeLocalGet(0, Type::i32); auto* get2 = builder.makeLocalGet(0, Type::i32); foo.body = builder.makeBlock({ @@ -58,7 +58,7 @@ int main() { { Function foo; - foo.type = Signature({Type::i32}, Type::none); + foo.type = Type(Signature({Type::i32}, Type::none), NonNullable, Exact); auto* get1 = builder.makeLocalGet(0, Type::i32); auto* get2 = builder.makeLocalGet(0, Type::i32); foo.body = builder.makeBlock({ @@ -73,7 +73,8 @@ int main() { { Function foo; - foo.type = Signature({Type::i32, Type::i32}, Type::none); + foo.type = + Type(Signature({Type::i32, Type::i32}, Type::none), NonNullable, Exact); auto* get1 = builder.makeLocalGet(0, Type::i32); auto* get2 = builder.makeLocalGet(1, Type::i32); foo.body = builder.makeBlock({ @@ -87,7 +88,7 @@ int main() { { Function foo; - foo.type = Signature(Type::none, Type::none); + foo.type = Type(Signature(Type::none, Type::none), NonNullable, Exact); foo.vars = {Type::i32, Type::i32}; auto* get1 = builder.makeLocalGet(0, Type::i32); auto* get2 = builder.makeLocalGet(1, Type::i32); @@ -103,7 +104,7 @@ int main() { { Function foo; - foo.type = Signature(Type::none, Type::none); + foo.type = Type(Signature(Type::none, Type::none), NonNullable, Exact); foo.vars = {Type::i32, Type::f64}; auto* get1 = builder.makeLocalGet(0, Type::i32); auto* get2 = builder.makeLocalGet(1, Type::f64); diff --git a/test/lit/merge/types.wat b/test/lit/merge/types.wat index 956d885e034..5ab6bad9786 100644 --- a/test/lit/merge/types.wat +++ b/test/lit/merge/types.wat @@ -2,8 +2,8 @@ ;; Test of exports / imports type matching -;; CHECK: Type mismatch when importing function f1 from module env ($bad1): type (type $func.0 (func)) is not a subtype of (type $func.0 (func (param (ref eq)))). -;; CHECK-NEXT: Type mismatch when importing function f3 from module env ($bad2): type (type $func.0 (sub (func (result anyref)))) is not a subtype of (type $func.0 (sub $func.1 (func (result eqref)))). +;; CHECK: Type mismatch when importing function f1 from module env ($bad1): +;; CHECK-NEXT: Type mismatch when importing function f3 from module env ($bad2): ;; CHECK-NEXT: Type mismatch when importing table t1 from module env ($bad1): minimal size 10 is smaller than expected minimal size 11. ;; CHECK-NEXT: Type mismatch when importing table t1 from module env ($bad2): maximal size 100 is larger than expected maximal size 99. ;; CHECK-NEXT: Type mismatch when importing table t2 from module env ($bad3): expecting a bounded table but the imported table is unbounded. diff --git a/test/lit/passes/string-lifting-validation.wast b/test/lit/passes/string-lifting-validation.wast index f91dfedff48..9f11b801c4e 100644 --- a/test/lit/passes/string-lifting-validation.wast +++ b/test/lit/passes/string-lifting-validation.wast @@ -18,5 +18,4 @@ ) ;; RUN: not wasm-opt %s --string-lifting -all 2>&1 | filecheck %s -;; CHECK: Fatal: StringLifting: bad type for fromCharCodeArray: (type $func.0 (func (param (ref null $array.0) i32 i64) (result (ref extern)))) - +;; CHECK: Fatal: StringLifting: bad type for fromCharCodeArray diff --git a/test/lit/passes/string-lifting-wrong-type.wast b/test/lit/passes/string-lifting-wrong-type.wast index bcc01d0f056..b8fa2c99355 100644 --- a/test/lit/passes/string-lifting-wrong-type.wast +++ b/test/lit/passes/string-lifting-wrong-type.wast @@ -21,4 +21,4 @@ ) ;; RUN: not wasm-opt %s --string-lifting -all 2>&1 | filecheck %s -;; CHECK: Fatal: StringLifting: bad type for fromCharCodeArray: (type $func.0 (sub (func (param (ref null $array.0) i32 i32) (result (ref extern))))) +;; CHECK: Fatal: StringLifting: bad type for fromCharCodeArray