Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for WebAssembly GC recursion groups #740

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
235 changes: 235 additions & 0 deletions JSTests/wasm/gc/rec.js
@@ -0,0 +1,235 @@
//@ runWebAssemblySuite("--useWebAssemblyTypedFunctionReferences=true", "--useWebAssemblyGC=true")

import * as assert from "../assert.js";
import { compile, instantiate } from "./wast-wrapper.js";

// Turns a "binary string" (one character per byte, as written with \xNN
// escapes) into a WebAssembly.Module. Each character's code unit is stored
// into a byte array, and that array's backing ArrayBuffer is compiled.
// The `valid` parameter is accepted but is not consulted by this helper.
function module(bytes, valid = true) {
    const octets = Uint8Array.from(
        { length: bytes.length },
        (_unused, position) => bytes.charCodeAt(position)
    );
    return new WebAssembly.Module(octets.buffer);
}

// Exercises declaration and use of WebAssembly GC recursion groups (`rec`).
// Each case either instantiates a module that is expected to be accepted, or
// wraps compile/instantiate in assert.throws with the expected error class and
// message. Type indices in the WAT text count every type in every rec group in
// declaration order (e.g. with two 2-element groups, indices are 0..3).
function testRecDeclaration() {
// A rec group holding a function type and a struct type is accepted.
instantiate(`
(module
(rec (type (func)) (type (struct)))
)
`);

// A function may be declared with a function type that lives in a rec group.
instantiate(`
(module
(rec (type (func)) (type (struct)))
(func (type 0))
)
`);

// Here type 0 is the struct member of the group, so using it as a function's
// type is expected to be rejected at compile time.
assert.throws(
() => compile(`
(module
(rec (type (struct)) (type (func)))
(func (type 0))
)
`),
WebAssembly.CompileError,
"type signature was not a function signature"
);

// Two function types in one group that refer to each other (mutual recursion).
instantiate(`
(module
(rec
(type (func (result (ref 1))))
(type (func (result (ref 0)))))
)
`);

// A group member referring to an earlier member of the same group.
instantiate(`
(module
(rec
(type (func))
(type (func (result (ref 0)))))
)
`);

// Two structurally identical rec groups: function 0 has type 0, and is returned
// where (ref 2) is expected, so this module only validates if types 0 and 2 are
// treated as the same type.
instantiate(`
(module
(rec (type (func)) (type (struct)))
(rec (type (func)) (type (struct)))
(elem declare funcref (ref.func 0))
(func (type 0))
(func (result (ref 2)) (ref.func 0))
)
`);

// Cross-module linking: the importer declares an identical rec group and imports
// the function at the same position within it, which is expected to link.
{
let m1 = instantiate(`
(module
(rec (type (func)) (type (struct)))
(func (export "f") (type 0))
)
`);
instantiate(`
(module
(rec (type (func)) (type (struct)))
(func (import "m" "f") (type 0))
(start 0)
)
`, { m: { f: m1.exports.f } });
}

// Cross-module linking: the importer's rec group lists its members in the
// opposite order, so the import is expected to fail with a LinkError even though
// both sides name a `(func)` type.
{
let m1 = instantiate(`
(module
(rec (type (func)) (type (struct)))
(func (export "f") (type 0))
)
`);
assert.throws(
() => instantiate(`
(module
(rec (type (struct)) (type (func)))
(func (import "m" "f") (type 1))
(start 0)
)
`, { m: { f: m1.exports.f } }),
WebAssembly.LinkError,
"imported function m:f signature doesn't match the provided WebAssembly function's signature"
);
}

// Cross-module linking: the exported function returns (ref 0), while the
// importer declares it with a plain `(func)` type from a differently ordered
// group; a LinkError is expected.
{
let m1 = instantiate(`
(module
(rec (type (func)) (type (struct)))
(elem declare funcref (ref.func 0))
(func)
(func (export "f") (result (ref 0)) (ref.func 0))
)
`);
assert.throws(
() => instantiate(`
(module
(rec (type (struct)) (type (func)))
(func (import "m" "f") (type 1))
)
`, { m: { f: m1.exports.f } }),
WebAssembly.LinkError,
"imported function m:f signature doesn't match the provided WebAssembly function's signature"
);
}

// Function 0 is declared with type 3 (the func in the second, differently
// ordered group), but the global expects (ref 0); a CompileError is expected.
assert.throws(
() => compile(`
(module
(rec (type (func)) (type (struct)))
(rec (type (struct)) (type (func)))
(global (ref 0) (ref.func 0))
(func (type 3))
)
`),
WebAssembly.CompileError,
"Global init_expr opcode of type Ref doesn't match global's type Ref"
);

// Mutually recursive function types where each function returns a reference to
// the other, matching the declared result types.
instantiate(`
(module
(rec (type (func (result (ref 1))))
(type (func (result (ref 0)))))
(elem declare funcref (ref.func 0))
(elem declare funcref (ref.func 1))
(func (type 0) (ref.func 1))
(func (type 1) (ref.func 0))
)
`);

// Same module as above except function 1 returns a reference to itself (type 1)
// where its signature requires (ref 0); a CompileError is expected.
assert.throws(
() => compile(`
(module
(rec (type (func (result (ref 1))))
(type (func (result (ref 0)))))
(elem declare funcref (ref.func 0))
(elem declare funcref (ref.func 1))
(func (type 0) (ref.func 1))
(func (type 1) (ref.func 1))
)
`),
WebAssembly.CompileError,
"control flow returns with unexpected type. Ref is not a Ref, in function at index 1"
);

// call_ref on a function whose type is a member of a rec group; run via start.
instantiate(`
(module
(rec (type (func (param i32))) (type (struct)))
(elem declare funcref (ref.func 0))
(func (type 0))
(func (call_ref (i32.const 42) (ref.func 0)))
(start 1)
)
`);

// The next three cases use a rec-group function type (type 3) as the block type
// of `block`, `loop`, and `if` respectively.
instantiate(`
(module
(rec (type (func (result i32))) (type (struct)))
(rec (type (struct)) (type (func (result i32))))
(func (type 0)
(block (type 3)
(i32.const 42)))
)
`);

instantiate(`
(module
(rec (type (func (result i32))) (type (struct)))
(rec (type (struct)) (type (func (result i32))))
(func (type 0)
(loop (type 3)
(i32.const 42)))
)
`);

instantiate(`
(module
(rec (type (func (result i32))) (type (struct)))
(rec (type (struct)) (type (func (result i32))))
(func (type 0)
(i32.const 1)
(if (type 3) (then (i32.const 42)) (else (i32.const 43))))
)
`);

// call_indirect through a table, using a rec-group function type as the
// expected signature; run via start.
instantiate(`
(module
(rec (type (func)) (type (struct)))
(table 5 funcref)
(elem (offset (i32.const 0)) funcref (ref.func 0))
(func (type 0))
(func (call_indirect (type 0) (i32.const 0)))
(start 1)
)
`);

// Ensure implicit rec groups are accounted for, and treated
// correctly with regards to equality.
instantiate(`
(module
(type $a (struct (field i32)))
(rec (type $b (struct (field i32))))
(type $c (struct (field i32)))

(func (result (ref null $a)) (ref.null $b))
(func (result (ref null $a)) (ref.null $c))
(func (result (ref null $b)) (ref.null $a))
(func (result (ref null $b)) (ref.null $c))
(func (result (ref null $c)) (ref.null $a))
(func (result (ref null $c)) (ref.null $b)))
`);

// This is the same test as above, but using a particular binary encoding.
// The encoding for this test specifically uses both shorthand and the full
// rec form to test the equivalence of the two.
new WebAssembly.Instance(module("\x00\x61\x73\x6d\x01\x00\x00\x00\x01\x9e\x80\x80\x80\x00\x06\x5f\x01\x7f\x00\x4f\x01\x5f\x01\x7f\x00\x5f\x01\x7f\x00\x60\x00\x01\x6c\x00\x60\x00\x01\x6c\x01\x60\x00\x01\x6c\x02\x03\x87\x80\x80\x80\x00\x06\x03\x03\x04\x04\x05\x05\x0a\xb7\x80\x80\x80\x00\x06\x84\x80\x80\x80\x00\x00\xd0\x01\x0b\x84\x80\x80\x80\x00\x00\xd0\x02\x0b\x84\x80\x80\x80\x00\x00\xd0\x00\x0b\x84\x80\x80\x80\x00\x00\xd0\x02\x0b\x84\x80\x80\x80\x00\x00\xd0\x00\x0b\x84\x80\x80\x80\x00\x00\xd0\x01\x0b"));
}

// Run the recursion-group checks as soon as this test file is evaluated.
testRecDeclaration();
1 change: 1 addition & 0 deletions JSTests/wasm/wasm.json
Expand Up @@ -18,6 +18,7 @@
"i31ref": { "type": "varint7", "value": -22, "b3type": "B3::Void" },
"func": { "type": "varint7", "value": -32, "b3type": "B3::Void" },
"struct": { "type": "varint7", "value": -33, "b3type": "B3::Void" },
"rec": { "type": "varint7", "value": -49, "b3type": "B3::Void" },
"void": { "type": "varint7", "value": -64, "b3type": "B3::Void" }
},
"value_type": ["i32", "i64", "f32", "f64", "externref", "funcref"],
Expand Down
8 changes: 5 additions & 3 deletions Source/JavaScriptCore/wasm/WasmAirIRGenerator.cpp
Expand Up @@ -948,7 +948,7 @@ void AirIRGenerator::restoreWasmContextInstance(BasicBlock* block, TypedTmp inst
emitPatchpoint(block, patchpoint, Tmp(), instance);
}

AirIRGenerator::AirIRGenerator(const ModuleInformation& info, B3::Procedure& procedure, InternalFunction* compilation, Vector<UnlinkedWasmToWasmCall>& unlinkedWasmToWasmCalls, MemoryMode mode, unsigned functionIndex, std::optional<bool> hasExceptionHandlers, TierUpCount* tierUp, const TypeDefinition& signature, unsigned& osrEntryScratchBufferSize)
AirIRGenerator::AirIRGenerator(const ModuleInformation& info, B3::Procedure& procedure, InternalFunction* compilation, Vector<UnlinkedWasmToWasmCall>& unlinkedWasmToWasmCalls, MemoryMode mode, unsigned functionIndex, std::optional<bool> hasExceptionHandlers, TierUpCount* tierUp, const TypeDefinition& originalSignature, unsigned& osrEntryScratchBufferSize)
: m_info(info)
, m_mode(mode)
, m_functionIndex(functionIndex)
Expand Down Expand Up @@ -1056,6 +1056,7 @@ AirIRGenerator::AirIRGenerator(const ModuleInformation& info, B3::Procedure& pro
m_mainEntrypointStart = m_code.addBlock();
m_currentBlock = m_mainEntrypointStart;

const TypeDefinition& signature = originalSignature.expand();
ASSERT(!m_locals.size());
m_locals.grow(signature.as<FunctionSignature>()->argumentCount());
for (unsigned i = 0; i < signature.as<FunctionSignature>()->argumentCount(); ++i) {
Expand Down Expand Up @@ -3780,9 +3781,10 @@ auto AirIRGenerator::addCall(uint32_t functionIndex, const TypeDefinition& signa
return { };
}

auto AirIRGenerator::addCallIndirect(unsigned tableIndex, const TypeDefinition& signature, Vector<ExpressionType>& args, ResultList& results) -> PartialResult
auto AirIRGenerator::addCallIndirect(unsigned tableIndex, const TypeDefinition& originalSignature, Vector<ExpressionType>& args, ResultList& results) -> PartialResult
{
ExpressionType calleeIndex = args.takeLast();
const TypeDefinition& signature = originalSignature.expand();
ASSERT(signature.as<FunctionSignature>()->argumentCount() == args.size());
ASSERT(m_info.tableCount() > tableIndex);
ASSERT(m_info.tables[tableIndex].type() == TableElementType::Funcref);
Expand Down Expand Up @@ -3849,7 +3851,7 @@ auto AirIRGenerator::addCallIndirect(unsigned tableIndex, const TypeDefinition&
});

ExpressionType expectedSignatureIndex = g64();
append(Move, Arg::bigImm(TypeInformation::get(signature)), expectedSignatureIndex);
append(Move, Arg::bigImm(TypeInformation::get(originalSignature)), expectedSignatureIndex);
emitCheck([&] {
return Inst(Branch64, nullptr, Arg::relCond(MacroAssembler::NotEqual), calleeSignatureIndex, expectedSignatureIndex);
}, [=, this] (CCallHelpers& jit, const B3::StackmapGenerationParams&) {
Expand Down
5 changes: 3 additions & 2 deletions Source/JavaScriptCore/wasm/WasmB3IRGenerator.cpp
Expand Up @@ -3066,9 +3066,10 @@ auto B3IRGenerator::addCall(uint32_t functionIndex, const TypeDefinition& signat
return { };
}

auto B3IRGenerator::addCallIndirect(unsigned tableIndex, const TypeDefinition& signature, Vector<ExpressionType>& args, ResultList& results) -> PartialResult
auto B3IRGenerator::addCallIndirect(unsigned tableIndex, const TypeDefinition& originalSignature, Vector<ExpressionType>& args, ResultList& results) -> PartialResult
{
Value* calleeIndex = get(args.takeLast());
const TypeDefinition& signature = originalSignature.expand();
ASSERT(signature.as<FunctionSignature>()->argumentCount() == args.size());

m_makesCalls = true;
Expand Down Expand Up @@ -3127,7 +3128,7 @@ auto B3IRGenerator::addCallIndirect(unsigned tableIndex, const TypeDefinition& s

// Check the signature matches the value we expect.
{
Value* expectedSignatureIndex = m_currentBlock->appendNew<Const64Value>(m_proc, origin(), TypeInformation::get(signature));
Value* expectedSignatureIndex = m_currentBlock->appendNew<Const64Value>(m_proc, origin(), TypeInformation::get(originalSignature));
CheckValue* check = m_currentBlock->appendNew<CheckValue>(m_proc, Check, origin(),
m_currentBlock->appendNew<Value>(m_proc, NotEqual, origin(), calleeSignatureIndex, expectedSignatureIndex));

Expand Down
4 changes: 2 additions & 2 deletions Source/JavaScriptCore/wasm/WasmBBQPlan.cpp
Expand Up @@ -137,7 +137,7 @@ void BBQPlan::work(CompilationEffort effort)

size_t functionIndexSpace = m_functionIndex + m_moduleInformation->importFunctionCount();
TypeIndex typeIndex = m_moduleInformation->internalFunctionTypeIndices[m_functionIndex];
const TypeDefinition& signature = TypeInformation::get(typeIndex);
const TypeDefinition& signature = TypeInformation::get(typeIndex).expand();
function->entrypoint.compilation = makeUnique<Compilation>(
FINALIZE_WASM_CODE_FOR_MODE(CompilationMode::BBQMode, linkBuffer, JITCompilationPtrTag, "WebAssembly BBQ function[%i] %s name %s", m_functionIndex, signature.toString().ascii().data(), makeString(IndexOrName(functionIndexSpace, m_moduleInformation->nameSection->get(functionIndexSpace))).ascii().data()),
WTFMove(context.wasmEntrypointByproducts));
Expand Down Expand Up @@ -206,7 +206,7 @@ void BBQPlan::compileFunction(uint32_t functionIndex)
if (m_exportedFunctionIndices.contains(functionIndex) || m_moduleInformation->referencedFunctions().contains(functionIndex)) {
Locker locker { m_lock };
TypeIndex typeIndex = m_moduleInformation->internalFunctionTypeIndices[functionIndex];
const TypeDefinition& signature = TypeInformation::get(typeIndex);
const TypeDefinition& signature = TypeInformation::get(typeIndex).expand();

m_compilationContexts[functionIndex].embedderEntrypointJIT = makeUnique<CCallHelpers>();
auto embedderToWasmInternalFunction = createJSToWasmWrapper(*m_compilationContexts[functionIndex].embedderEntrypointJIT, signature, &m_unlinkedWasmToWasmCalls[functionIndex], m_moduleInformation.get(), m_mode, functionIndex);
Expand Down
4 changes: 2 additions & 2 deletions Source/JavaScriptCore/wasm/WasmCallee.h
Expand Up @@ -265,7 +265,7 @@ class LLIntCallee final : public Callee {

LLIntTierUpCounter& tierUpCounter() { return m_tierUpCounter; }

const FunctionSignature& signature(unsigned index) const
const TypeDefinition& signature(unsigned index) const
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is wrong. I believe all wasm callees are still FunctionSignatures.

Copy link
Contributor Author

@takikawa takikawa May 23, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAICT this is required due to the semantics of recursion groups. The reason is that with the addition of recursion groups, in general all type signatures have to be compared by projection, even if you know it will be a function signature underneath.

For example, you may have function types like the following:

(rec (type $f1 (func)) (type (struct ...)))

(rec (type $f2 (func)))

Even though $f1 and $f2 look the same, and are indeed the same if you unfold the projections, they are not equal types because they are in different recursion groups.

(It seems a bit silly in this very simple case, but you can construct more interesting examples where $f1 and $f2 would unfold to the same type that, say, references another member of the recursion group, and where these other members are not equal types.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kmiller68 was trying to explain this to me, so I am wondering if I can ask a few clarifying questions:

(rec $r1
    (type $tree (struct (field i32) (field (ref $forest))))
    (type $forest (struct (field (ref $tree)) (field (ref $forest)))))
 
(type $s1 (struct (field (ref $tree))))
  1. Types are stored as instances of TypeInformation, and expanded to produce a FunctionSignature (including for struct types)?

  2. Projection in this case means taking a TypeDefinition, and selecting an element of it without any substitution? So:

$r1 is given a type index
$s1 is represented by TypeDefinition(StructType(Projection($r1, 0)))
$tree is TypeDefinition(Projection($r1, 0))

  3. Expanding replaces a projection with a new projection that has had one step of unrolling applied, with the recursion group as the target of the unrolling

  4. Unrolling is applying one step of substitution, replacing references to the recursion group with its expansion. So:

expand($tree) <=>
unroll(Projection($r1, 0) with respect to $r1) <=>
TypeDefinition(StructType(i32), TypeDefinition(Projection($r1, 0)))

This new expansion is given its own type index, and comparisons are made by type index?

We need to compare types by type index because the type definitions are iso-recursive? So a TypeDefinition and its .expand() are not the same type?

If they were equirecursive, then we would have to perform a dfs expanding until we could prove that the types did/did not match?

Sorry to bug you with all these questions, I will continue playing with this patch to see how many of these questions I can answer by myself.

{
return *m_signatures[index];
}
Expand Down Expand Up @@ -304,7 +304,7 @@ class LLIntCallee final : public Callee {
std::unique_ptr<WasmInstructionStream> m_instructions;
const void* m_instructionsRawPointer { nullptr };
FixedVector<WasmInstructionStream::Offset> m_jumpTargets;
FixedVector<const FunctionSignature*> m_signatures;
FixedVector<const TypeDefinition*> m_signatures;
OutOfLineJumpTargets m_outOfLineJumpTargets;
LLIntTierUpCounter m_tierUpCounter;
FixedVector<JumpTable> m_jumpTables;
Expand Down
4 changes: 4 additions & 0 deletions Source/JavaScriptCore/wasm/WasmFormat.h
Expand Up @@ -74,6 +74,10 @@ inline bool isValueType(Type type)
case TypeKind::Ref:
case TypeKind::RefNull:
return Options::useWebAssemblyTypedFunctionReferences();
// Rec type kinds are used internally to represent `rec.<i>` references
// within recursion groups. They are invalid in other contexts.
case TypeKind::Rec:
return Options::useWebAssemblyGC();
default:
break;
}
Expand Down
Expand Up @@ -51,7 +51,7 @@ WasmInstructionStream::Offset FunctionCodeBlockGenerator::outOfLineJumpOffset(Wa
return m_outOfLineJumpTargets.get(bytecodeOffset);
}

unsigned FunctionCodeBlockGenerator::addSignature(const FunctionSignature& signature)
unsigned FunctionCodeBlockGenerator::addSignature(const TypeDefinition& signature)
{
unsigned index = m_signatures.size();
m_signatures.append(&signature);
Expand Down
6 changes: 3 additions & 3 deletions Source/JavaScriptCore/wasm/WasmFunctionCodeBlockGenerator.h
Expand Up @@ -48,7 +48,7 @@ class BytecodeGeneratorBase;
namespace Wasm {

class LLIntCallee;
class FunctionSignature;
class TypeDefinition;
struct GeneratorTraits;

struct JumpTableEntry {
Expand Down Expand Up @@ -116,7 +116,7 @@ class FunctionCodeBlockGenerator {

HashMap<WasmInstructionStream::Offset, LLIntTierUpCounter::OSREntryData>& tierUpCounter() { return m_tierUpCounter; }

unsigned addSignature(const FunctionSignature&);
unsigned addSignature(const TypeDefinition&);

JumpTable& addJumpTable(size_t numberOfEntries);
unsigned numberOfJumpTables() const;
Expand All @@ -140,7 +140,7 @@ class FunctionCodeBlockGenerator {
std::unique_ptr<WasmInstructionStream> m_instructions;
const void* m_instructionsRawPointer { nullptr };
Vector<WasmInstructionStream::Offset> m_jumpTargets;
Vector<const FunctionSignature*> m_signatures;
Vector<const TypeDefinition*> m_signatures;
OutOfLineJumpTargets m_outOfLineJumpTargets;
HashMap<WasmInstructionStream::Offset, LLIntTierUpCounter::OSREntryData> m_tierUpCounter;
Vector<JumpTable> m_jumpTables;
Expand Down