4 changes: 2 additions & 2 deletions src/passes/GUFA.cpp
@@ -371,8 +371,8 @@ struct GUFAOptimizer
bool optimized = false;

void visitExpression(Expression* curr) {
if (!curr->type.isRef()) {
// Ignore anything we cannot infer a type for.
// Ignore anything we cannot emit a cast for.
if (!curr->type.isCastable()) {
return;
}

1 change: 1 addition & 0 deletions src/tools/fuzzing.h
@@ -542,6 +542,7 @@ class TranslateToFuzzReader {
// Getters for Types
Type getSingleConcreteType();
Type getReferenceType();
Type getCastableReferenceType();
Type getEqReferenceType();
Type getMVPType();
Type getTupleType();
46 changes: 36 additions & 10 deletions src/tools/fuzzing/fuzzing.cpp
@@ -2414,10 +2414,12 @@ Expression* TranslateToFuzzReader::_makeConcrete(Type type) {
options.add(FeatureSet::ReferenceTypes | FeatureSet::GC,
&Self::makeCompoundRef);
}
// Exact casts are only allowed with custom descriptors enabled.
if (type.isInexact() || wasm.features.hasCustomDescriptors()) {
options.add(FeatureSet::ReferenceTypes | FeatureSet::GC,
&Self::makeRefCast);
if (type.isCastable()) {
// Exact casts are only allowed with custom descriptors enabled.
if (type.isInexact() || wasm.features.hasCustomDescriptors()) {
options.add(FeatureSet::ReferenceTypes | FeatureSet::GC,
&Self::makeRefCast);
}
}
if (heapType.getDescribedType()) {
options.add(FeatureSet::ReferenceTypes | FeatureSet::GC,
@@ -5044,8 +5046,8 @@ Expression* TranslateToFuzzReader::makeRefTest(Type type) {
switch (upTo(3)) {
case 0:
// Totally random.
refType = getReferenceType();
castType = getReferenceType();
refType = getCastableReferenceType();
castType = getCastableReferenceType();
// They must share a bottom type in order to validate.
if (refType.getHeapType().getBottom() ==
castType.getHeapType().getBottom()) {
@@ -5056,12 +5058,12 @@ Expression* TranslateToFuzzReader::makeRefTest(Type type) {
[[fallthrough]];
case 1:
// Cast is a subtype of ref.
refType = getReferenceType();
refType = getCastableReferenceType();
castType = getSubType(refType);
break;
case 2:
// Ref is a subtype of cast.
castType = getReferenceType();
castType = getCastableReferenceType();
refType = getSubType(castType);
break;
default:
@@ -5085,7 +5087,7 @@ Expression* TranslateToFuzzReader::makeRefCast(Type type) {
switch (upTo(3)) {
case 0:
// Totally random.
refType = getReferenceType();
refType = getCastableReferenceType();
// They must share a bottom type in order to validate.
if (refType.getHeapType().getBottom() == type.getHeapType().getBottom()) {
break;
@@ -5190,7 +5192,11 @@ Expression* TranslateToFuzzReader::makeBrOn(Type type) {
// We are sending a reference type to the target. All other BrOn variants can
// do that.
assert(targetType.isRef());
auto op = pick(BrOnNonNull, BrOnCast, BrOnCastFail);
// BrOnNonNull can handle sending any reference. The casts are more limited.
auto op = BrOnNonNull;
if (targetType.isCastable()) {
op = pick(BrOnNonNull, BrOnCast, BrOnCastFail);
}
Type castType = Type::none;
Type refType;
switch (op) {
@@ -5635,6 +5641,26 @@ Type TranslateToFuzzReader::getReferenceType() {
Type(HeapType::string, NonNullable)));
}

Type TranslateToFuzzReader::getCastableReferenceType() {
int tries = fuzzParams->TRIES;
while (tries-- > 0) {
auto type = getReferenceType();
if (type.isCastable()) {
return type;
}
}
// We failed to find a type using fair sampling. Do something simple that must
// work.
Type type;
if (oneIn(4)) {
type = getSubType(Type(HeapType::func, Nullable));
} else {
type = getSubType(Type(HeapType::any, Nullable));
}
assert(type.isCastable());
return type;
}

Type TranslateToFuzzReader::getEqReferenceType() {
if (oneIn(2) && !interestingHeapTypes.empty()) {
// Try to find an interesting eq-compatible type.
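The new getCastableReferenceType() follows a common fuzzer pattern: sample fairly for a bounded number of tries, then fall back to a construction that is guaranteed to satisfy the predicate. Below is a generic sketch of that pattern, under the assumption that the sampler and the fallback return the same type; the helper name and template are illustrative, not part of the PR.

```cpp
#include <cassert>

// Try fair sampling up to `tries` times; if no sample satisfies `ok`, use a
// fallback that is constructed to always satisfy it.
template<typename Sample, typename Pred, typename Fallback>
auto sampleWithFallback(int tries, Sample sample, Pred ok, Fallback fallback) {
  while (tries-- > 0) {
    auto candidate = sample();
    if (ok(candidate)) {
      return candidate; // fair sampling succeeded
    }
  }
  auto result = fallback();
  assert(ok(result)); // the fallback must always pass the predicate
  return result;
}
```

In the PR, the sampler is getReferenceType(), the predicate is Type::isCastable(), and the fallback picks a subtype of nullable func or any, both of which are always castable.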
15 changes: 5 additions & 10 deletions src/wasm-stack.h
@@ -167,16 +167,11 @@ class BinaryInstWriter : public OverriddenVisitor<BinaryInstWriter> {
// when they have a value that is more refined than the wasm type system
// allows atm (and they are not dropped, in which case the type would not
// matter). See https://github.com/WebAssembly/binaryen/pull/6390 for more on
// the difference. As a result of the difference, we will insert extra casts
// to ensure validation in the wasm spec. The wasm spec will hopefully improve
// to use the more refined type as well, which would remove the need for this
// hack.
//
// Each br_if present as a key here is mapped to the unrefined type for it.
// That is, the br_if has a type in Binaryen IR that is too refined, and the
// map contains the unrefined one (which we need to know the local types, as
// we'll stash the unrefined values and then cast them).
std::unordered_map<Break*, Type> brIfsNeedingHandling;
// the difference. As a result of the difference, we must fix things up for
// the spec. (The wasm spec might - hopefully - improve to use the more
// refined type as well, which would remove the need for this hack, and
// improve code size in general.)
std::unordered_set<Break*> brIfsNeedingHandling;
};

// Takes binaryen IR and converts it to something else (binary or stack IR)
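The switch from a map (br_if to its unrefined type) to a plain set reflects the new strategy in wasm-stack.cpp: the writer now decides, from the br_if's own type, whether a trailing ref.cast suffices or whether scratch locals are needed. A simplified sketch of that decision follows, using Binaryen's wasm::Type as an assumed API and omitting the parent and break-target checks the real code performs.

```cpp
#include "wasm-type.h" // assumed include path

using namespace wasm;

enum class BrIfFixup { None, Cast, ScratchLocals };

// Decide how to reconcile a br_if whose Binaryen IR type is more refined
// than the type the wasm spec assigns to it (the break target's type).
BrIfFixup chooseFixup(Type sentType, Type targetType) {
  if (!sentType.hasRef() || sentType == targetType) {
    return BrIfFixup::None; // nothing is refined beyond the spec type
  }
  if (sentType.isTuple() || !sentType.isCastable()) {
    // A single cast cannot express the fixup: stash the values and the
    // condition in scratch locals, emit the br_if, drop its unrefined
    // results, and reload the refined values from the locals.
    return BrIfFixup::ScratchLocals;
  }
  // One castable reference: a ref.cast right after the br_if is enough.
  return BrIfFixup::Cast;
}
```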
3 changes: 3 additions & 0 deletions src/wasm-type.h
@@ -184,6 +184,8 @@ class HeapType {
return isBasic() && getBasic(Unshared) == type;
}

bool isCastable();

Signature getSignature() const;
Continuation getContinuation() const;

@@ -415,6 +417,7 @@ class Type {
return isRef() && getHeapType().isContinuation();
}
bool isDefaultable() const;
bool isCastable();

// TODO: Allow this only for reference types.
Nullability getNullability() const {
138 changes: 89 additions & 49 deletions src/wasm/wasm-stack.cpp
@@ -63,56 +63,88 @@ void BinaryInstWriter::visitLoop(Loop* curr) {
}

void BinaryInstWriter::visitBreak(Break* curr) {
auto type = curr->type;

// See comment on |brIfsNeedingHandling| for the extra handling we need to
// emit here for certain br_ifs. If we need that handling, we either use a
// cast in simple cases, or scratch locals otherwise. We use the scratch
// locals to stash the stack before the br_if (which contains the refined
// types), then restore it later from those locals.
bool needScratchLocals = false;
// If we need locals, we must track how many we've used from each type as we
// go, as a type might appear multiple times in the tuple. We know we have
// enough of a range allocated for them, so we just increment as we go.
std::unordered_map<Type, Index> scratchTypeUses;
// Logic to stash and restore the stack, given a vector of types we are
// stashing/restoring. We will first stash the entire stack, including the i32
// condition, and after the br_if, restore the value (without the condition).
auto stashStack = [&](const std::vector<Type>& types) {
for (Index i = 0; i < types.size(); i++) {
auto t = types[types.size() - i - 1];
assert(scratchLocals.find(t) != scratchLocals.end());
auto localIndex = scratchLocals[t] + scratchTypeUses[t]++;
o << int8_t(BinaryConsts::LocalSet) << U32LEB(localIndex);
}
};
auto restoreStack = [&](const std::vector<Type>& types) {
// Use a copy of this data, as we will restore twice.
auto currScratchTypeUses = scratchTypeUses;
for (Index i = 0; i < types.size(); i++) {
auto t = types[i];
auto localIndex = scratchLocals[t] + --currScratchTypeUses[t];
o << int8_t(BinaryConsts::LocalGet) << U32LEB(localIndex);
}
};

// The types on the stack before the br_if. We need this if we use locals to
// stash the stack.
std::vector<Type> typesOnStack;

auto needHandling = brIfsNeedingHandling.count(curr);
if (needHandling) {
// Tuples always need scratch locals. Uncastable types do as well, as we
// can't fix them up below with a simple cast.
needScratchLocals = type.isTuple() || !type.isCastable();
if (needScratchLocals) {
// Stash all the values on the stack to those locals, then reload them for
// the br_if to consume. Later, we can reload the refined values after the
// br_if, for its parent to consume.

typesOnStack = std::vector<Type>(type.begin(), type.end());
typesOnStack.push_back(Type::i32);

stashStack(typesOnStack);
restoreStack(typesOnStack);
// The stack is now in the same state as before, but we have copies in
// locals for later.
}
}

o << int8_t(curr->condition ? BinaryConsts::BrIf : BinaryConsts::Br)
<< U32LEB(getBreakIndex(curr->name));

// See comment on |brIfsNeedingHandling| for the extra casts we need to emit
// here for certain br_ifs.
auto iter = brIfsNeedingHandling.find(curr);
if (iter != brIfsNeedingHandling.end()) {
auto unrefinedType = iter->second;
auto type = curr->type;
assert(type.size() == unrefinedType.size());
if (needHandling) {
if (!needScratchLocals) {
// We can just cast here, avoiding scratch locals. (Casting adds overhead,
// but this is very rare, and it avoids adding locals, which would keep
// growing the wasm with each roundtrip.)

assert(curr->type.hasRef());

auto emitCast = [&](Type to) {
// Shim a tiny bit of IR, just enough to get visitRefCast to see what we
// are casting, and to emit the proper thing.
RefCast cast;
cast.type = to;
cast.type = type;
cast.ref = cast.desc = nullptr;
visitRefCast(&cast);
};

if (!type.isTuple()) {
// Simple: Just emit a cast, and then the type matches Binaryen IR's.
emitCast(type);
} else {
// Tuples are trickier to handle, and we need to use scratch locals. Stash
// all the values on the stack to those locals, then reload them, casting
// as we go.
//
// We must track how many scratch locals we've used from each type as we
// go, as a type might appear multiple times in the tuple. We allocated
// enough for each, in a contiguous range, so we just increment as we go.
std::unordered_map<Type, Index> scratchTypeUses;
for (Index i = 0; i < unrefinedType.size(); i++) {
auto t = unrefinedType[unrefinedType.size() - i - 1];
assert(scratchLocals.find(t) != scratchLocals.end());
auto localIndex = scratchLocals[t] + scratchTypeUses[t]++;
o << int8_t(BinaryConsts::LocalSet) << U32LEB(localIndex);
}
for (Index i = 0; i < unrefinedType.size(); i++) {
auto t = unrefinedType[i];
auto localIndex = scratchLocals[t] + --scratchTypeUses[t];
o << int8_t(BinaryConsts::LocalGet) << U32LEB(localIndex);
if (t.isRef()) {
// Note that we cast all types here, when perhaps only some of the
// tuple's lanes need that. This is simpler.
emitCast(type[i]);
}
// We need locals. Earlier we stashed the stack, so we just need to
// restore the value from there (note we don't restore the condition),
// after dropping the br_if's unrefined values.
for (Index i = 0; i < type.size(); ++i) {
o << int8_t(BinaryConsts::Drop);
}
assert(typesOnStack.back() == Type::i32);
typesOnStack.pop_back();
restoreStack(typesOnStack);
}
}
}
@@ -3094,8 +3126,9 @@ InsertOrderedMap<Type, Index> BinaryInstWriter::countScratchLocals() {
: writer(writer), finder(finder) {}

void visitBreak(Break* curr) {
auto type = curr->type;
// See if this is one of the dangerous br_ifs we must handle.
if (!curr->type.hasRef()) {
if (!type.hasRef()) {
// Not even a reference.
return;
}
@@ -3106,7 +3139,7 @@ InsertOrderedMap<Type, Index> BinaryInstWriter::countScratchLocals() {
return;
}
if (auto* cast = parent->dynCast<RefCast>()) {
if (Type::isSubType(cast->type, curr->type)) {
if (Type::isSubType(cast->type, type)) {
// It is cast to the same type or a better one. In particular this
// handles the case of repeated roundtripping: After the first
// roundtrip we emit a cast that we'll identify here, and not emit
@@ -3117,23 +3150,30 @@ InsertOrderedMap<Type, Index> BinaryInstWriter::countScratchLocals() {
}
auto* breakTarget = findBreakTarget(curr->name);
auto unrefinedType = breakTarget->type;
if (unrefinedType == curr->type) {
if (unrefinedType == type) {
// It has the proper type anyhow.
return;
}

// Mark the br_if as needing handling, and add the type to the set of
// types we need scratch tuple locals for (if relevant).
writer.brIfsNeedingHandling[curr] = unrefinedType;

if (unrefinedType.isTuple()) {
// We must allocate enough scratch locals for this tuple. Note that we
// may need more than one per type in the tuple, if a type appears more
// than once, so we count their appearances.
writer.brIfsNeedingHandling.insert(curr);

// Simple cases can be handled by a cast. However, tuples and uncastable
// types require us to use locals too.
if (type.isTuple() || !type.isCastable()) {
// We must allocate enough scratch locals for this tuple, plus the i32
// of the condition, as we will stash it all so that we can restore the
// fully refined value after the br_if.
//
// Note that we may need more than one per type in the tuple, if a type
// appears more than once, so we count their appearances.
InsertOrderedMap<Type, Index> scratchTypeUses;
for (auto t : unrefinedType) {
for (auto t : type) {
scratchTypeUses[t]++;
}
// The condition.
scratchTypeUses[Type::i32]++;
for (auto& [type, uses] : scratchTypeUses) {
auto& count = finder.scratches[type];
count = std::max(count, uses);
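The scratch-local accounting in countScratchLocals() has two easy-to-miss details: a type may occur several times in the sent tuple, and the i32 condition needs a slot too, while locals are shared across br_ifs by taking a per-type maximum. Here is a standalone sketch of that computation, with string stand-ins for types; the names and signature are illustrative, not the PR's code.

```cpp
#include <algorithm>
#include <map>
#include <string>
#include <vector>

using TypeName = std::string; // stand-in for wasm::Type

// For each br_if needing the local-based fixup, count one scratch local per
// occurrence of a type among its sent values, plus one i32 for the
// condition; take the per-type maximum across br_ifs, since locals are
// reused between them.
std::map<TypeName, int>
countScratch(const std::vector<std::vector<TypeName>>& brIfSentTypes) {
  std::map<TypeName, int> needed;
  for (const auto& sent : brIfSentTypes) {
    std::map<TypeName, int> uses;
    for (const auto& t : sent) {
      uses[t]++;
    }
    uses["i32"]++; // the condition is stashed alongside the values
    for (const auto& [t, n] : uses) {
      needed[t] = std::max(needed[t], n);
    }
  }
  return needed;
}
```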
7 changes: 7 additions & 0 deletions src/wasm/wasm-type.cpp
@@ -623,6 +623,8 @@ bool Type::isDefaultable() const {
return isConcrete() && !isNonNullable();
}

bool Type::isCastable() { return isRef() && getHeapType().isCastable(); }

unsigned Type::getByteSize() const {
// TODO: alignment?
auto getSingleByteSize = [](Type t) {
@@ -889,6 +891,11 @@ Shareability HeapType::getShared() const {
}
}

bool HeapType::isCastable() {
return !isContinuation() && !isMaybeShared(HeapType::cont) &&
!isMaybeShared(HeapType::nocont);
}

Signature HeapType::getSignature() const {
assert(isSignature());
return getHeapTypeInfo(*this)->signature;
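A hypothetical check of what the two new predicates are expected to report, given the definitions above; this is not a test from the PR, and the include path and exact spellings are assumptions.

```cpp
#include <cassert>
#include "wasm-type.h" // assumed include path

using namespace wasm;

void checkCastability() {
  Type anyRef(HeapType::any, Nullable);
  Type funcRef(HeapType::func, Nullable);
  Type contRef(HeapType::cont, Nullable);
  Type i32 = Type::i32;

  assert(anyRef.isCastable());   // ordinary GC references can be cast
  assert(funcRef.isCastable());  // function references too
  assert(!contRef.isCastable()); // continuation references cannot be cast
  assert(!i32.isCastable());     // non-reference types are never castable
}
```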