From 32d8f3ee5920a0a4fc20846efd9c33778529d931 Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Fri, 14 Nov 2025 11:49:52 -0800 Subject: [PATCH 1/3] [analysis] Add a ConeType lattice ConeType is an important part of PossibleContents. Implement it as a standalone lattice that can be combined into a larger lattice if we rewrite PossibleContents in terms of the lattice framework. --- src/analysis/lattices/conetype.h | 178 +++++++++++ test/gtest/lattices.cpp | 499 +++++++++++++++++++++++++++++++ 2 files changed, 677 insertions(+) create mode 100644 src/analysis/lattices/conetype.h diff --git a/src/analysis/lattices/conetype.h b/src/analysis/lattices/conetype.h new file mode 100644 index 00000000000..913c2a01567 --- /dev/null +++ b/src/analysis/lattices/conetype.h @@ -0,0 +1,178 @@ +/* + * Copyright 2025 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "analysis/lattice.h" +#include "wasm-type.h" + +#ifndef wasm_analysis_lattices_conetype_h +#define wasm_analysis_lattices_conetype_h + +namespace wasm::analysis { + +struct ConeType { + struct Element { + Type type; + Index depth; + bool operator==(const Element& other) const { + return type == other.type && depth == other.depth; + } + bool operator!=(const Element& other) const { return !(*this == other); } + bool isBottom() const { return type == Type::unreachable; } + bool isTop() const { return type == Type::none; } + }; + + std::unordered_map typeDepths; + + ConeType(std::unordered_map&& typeDepths) + : typeDepths(std::move(typeDepths)) {} + + Element get(Type type) const noexcept { + assert(!type.isTuple()); + if (!type.isRef() || type.isExact()) { + return Element{type, 0}; + } + if (auto it = typeDepths.find(type.getHeapType()); it != typeDepths.end()) { + return Element{type, it->second}; + } + return Element{type, 0}; + } + + Element getBottom() const noexcept { return Element{Type::unreachable, 0}; } + + Element getTop() const noexcept { return Element{Type::none, 0}; } + + bool join(Element& joinee, const Element& joiner) const noexcept { + auto lub = Type::getLeastUpperBound(joinee.type, joiner.type); + bool changed = lub != joinee.type; + if (!lub.isRef()) { + joinee.type = lub; + joinee.depth = 0; + return changed; + } + Index joineeToLub = 0, joinerToLub = 0; + if (!joinee.isBottom() && !joinee.type.getHeapType().isBottom()) { + joineeToLub = depthToSuper(joinee, lub); + } + if (!joiner.isBottom() && !joiner.type.getHeapType().isBottom()) { + joinerToLub = depthToSuper(joiner, lub); + } + Index newDepth = + std::max(joinee.depth + joineeToLub, joiner.depth + joinerToLub); + changed = changed || newDepth != joinee.depth; + joinee.type = lub; + joinee.depth = newDepth; + return changed; + } + + bool meet(Element& meetee, const Element& meeter) const noexcept { + // Type::none does not behave like the top type in + // Type::getGreatestLowerBound, so handle it separately first. Also handle + // unreachables so we don't have to worry about them later. + if (meetee.isBottom() || meeter.isTop()) { + return false; + } + if (meetee.isTop() || meeter.isBottom()) { + meetee = meeter; + return true; + } + if (meetee.type == meeter.type) { + auto newDepth = std::min(meetee.depth, meeter.depth); + bool changed = newDepth != meetee.depth; + meetee.depth = newDepth; + return changed; + } + Index newDepth; + auto glb = Type::getGreatestLowerBound(meetee.type, meeter.type); + if (glb == Type::unreachable || glb.getHeapType().isBottom()) { + newDepth = 0; + } else if (HeapType::isSubType(meetee.type.getHeapType(), + meeter.type.getHeapType())) { + auto diff = depthToSuper(meetee, meeter.type); + if (meeter.depth < diff) { + glb = glb.with(glb.getHeapType().getBottom()); + newDepth = 0; + } else { + newDepth = std::min(meeter.depth - diff, meetee.depth); + } + } else if (HeapType::isSubType(meeter.type.getHeapType(), + meetee.type.getHeapType())) { + auto diff = depthToSuper(meeter, meetee.type); + if (meetee.depth < diff) { + glb = glb.with(glb.getHeapType().getBottom()); + newDepth = 0; + } else { + newDepth = std::min(meetee.depth - diff, meeter.depth); + } + } else { + WASM_UNREACHABLE("unexpected case"); + } + bool changed = glb != meetee.type || newDepth != meetee.depth; + meetee.type = glb; + meetee.depth = newDepth; + return changed; + } + + analysis::LatticeComparison compare(const Element& a, + const Element& b) const noexcept { + if (a == b) { + return analysis::EQUAL; + } + if (a.isBottom() || b.isTop()) { + return analysis::LESS; + } + if (a.isTop() || b.isBottom()) { + return analysis::GREATER; + } + if (a.type == b.type) { + return a.depth < b.depth ? analysis::LESS : analysis::GREATER; + } + if (Type::isSubType(a.type, b.type)) { + if (a.type.getHeapType().isBottom()) { + return analysis::LESS; + } + Index diff = depthToSuper(a, b.type); + return a.depth + diff <= b.depth ? analysis::LESS : analysis::NO_RELATION; + } + if (Type::isSubType(b.type, a.type)) { + if (b.type.getHeapType().isBottom()) { + return analysis::GREATER; + } + Index diff = depthToSuper(b, a.type); + return b.depth + diff <= a.depth ? analysis::GREATER + : analysis::NO_RELATION; + } + return analysis::NO_RELATION; + } + +private: + Index depthToSuper(const Element& e, Type super) const noexcept { + Index depth = 0; + for (HeapType type = e.type.getHeapType(); type != super.getHeapType(); + type = *type.getSuperType()) { + ++depth; + } + return depth; + } +}; + +#if __cplusplus >= 202002L +static_assert(Lattice); +static_assert(FullLattice); +#endif + +} // namespace wasm::analysis + +#endif // wasm_analysis_lattices_conetype_h \ No newline at end of file diff --git a/test/gtest/lattices.cpp b/test/gtest/lattices.cpp index 5a80994b2bd..d5b55b50dd0 100644 --- a/test/gtest/lattices.cpp +++ b/test/gtest/lattices.cpp @@ -18,6 +18,7 @@ #include "analysis/lattices/abstraction.h" #include "analysis/lattices/array.h" #include "analysis/lattices/bool.h" +#include "analysis/lattices/conetype.h" #include "analysis/lattices/flat.h" #include "analysis/lattices/int.h" #include "analysis/lattices/inverted.h" @@ -27,6 +28,8 @@ #include "analysis/lattices/tuple.h" #include "analysis/lattices/valtype.h" #include "analysis/lattices/vector.h" +#include "ir/subtypes.h" +#include "wasm-type.h" #include "gtest/gtest.h" using namespace wasm; @@ -852,3 +855,499 @@ TEST(AbstractionLattice, Join) { #undef JOIN } + +class ConeTypeLatticeTest : public ::testing::Test { +protected: + HeapType super; + HeapType sub1; + HeapType sub2; + HeapType other; + analysis::ConeType lattice; + + using Element = analysis::ConeType::Element; + Element bot; + Element top; + Element i32; + Element i64; + Element eqNull3; + Element eqNonNull3; + Element structNull2; + Element structNonNull2; + Element i31Null; + Element i31NonNull; + Element noneNull; + Element noneNonNull; + Element superNull1; + Element superNonNull1; + Element superNullExact; + Element superNonNullExact; + Element sub1Null0; + Element sub1NonNull0; + Element sub1NullExact; + Element sub1NonNullExact; + Element sub2Null0; + Element sub2NonNull0; + Element sub2NullExact; + Element sub2NonNullExact; + Element otherNull0; + Element otherNonNull0; + Element otherNullExact; + Element otherNonNullExact; + + // Element created by combining other elements. + Element eqNull2; + Element eqNonNull2; + Element structNull1; + Element structNonNull1; + + void checkJoin(const Element& a, + const Element& b, + const Element& join, + const char* file, + int line) { + testing::ScopedTrace trace(file, line, "check join"); + Element copy = a; + EXPECT_EQ(lattice.join(copy, b), a != join); + EXPECT_EQ(copy, join); + copy = b; + EXPECT_EQ(lattice.join(copy, a), b != join); + EXPECT_EQ(copy, join); + } +#define CHECK_JOIN(a, b, join) checkJoin(a, b, join, __FILE__, __LINE__) + + void checkMeet(const Element& a, + const Element& b, + const Element& meet, + const char* file, + int line) { + testing::ScopedTrace trace(file, line, "check meet"); + Element copy = a; + EXPECT_EQ(lattice.meet(copy, b), a != meet); + EXPECT_EQ(copy, meet); + copy = b; + EXPECT_EQ(lattice.meet(copy, a), b != meet); + EXPECT_EQ(copy, meet); + } +#define CHECK_MEET(a, b, join) checkMeet(a, b, join, __FILE__, __LINE__) + + void + checkLess(const Element& a, const Element& b, const char* file, int line) { + testing::ScopedTrace trace(file, line, "check less"); + EXPECT_EQ(lattice.compare(a, b), analysis::LESS); + EXPECT_EQ(lattice.compare(b, a), analysis::GREATER); + } +#define CHECK_LESS(a, b) checkLess(a, b, __FILE__, __LINE__) + + void + checkGreater(const Element& a, const Element& b, const char* file, int line) { + testing::ScopedTrace trace(file, line, "check greater"); + EXPECT_EQ(lattice.compare(a, b), analysis::GREATER); + EXPECT_EQ(lattice.compare(b, a), analysis::LESS); + } +#define CHECK_GREATER(a, b) checkGreater(a, b, __FILE__, __LINE__) + + void checkUnrelated(const Element& a, + const Element& b, + const char* file, + int line) { + testing::ScopedTrace trace(file, line, "check unrelated"); + EXPECT_EQ(lattice.compare(a, b), analysis::NO_RELATION); + EXPECT_EQ(lattice.compare(b, a), analysis::NO_RELATION); + } +#define CHECK_UNRELATED(a, b) checkUnrelated(a, b, __FILE__, __LINE__) + + void + checkEqual(const Element& a, const Element& b, const char* file, int line) { + testing::ScopedTrace trace(file, line, "check equal"); + EXPECT_EQ(lattice.compare(a, b), analysis::EQUAL); + EXPECT_EQ(lattice.compare(b, a), analysis::EQUAL); + } +#define CHECK_EQUAL(a, b) checkEqual(a, b, __FILE__, __LINE__) + + void checkPair(const Element& a, + const Element& b, + const Element& join, + const Element& meet, + const char* file, + int line) { + testing::ScopedTrace trace(file, line, "check pair"); + CHECK_JOIN(a, b, join); + CHECK_MEET(a, b, meet); + switch (lattice.compare(a, b)) { + case analysis::NO_RELATION: + CHECK_UNRELATED(a, b); + CHECK_LESS(a, join); + CHECK_LESS(b, join); + CHECK_GREATER(a, meet); + CHECK_GREATER(b, meet); + CHECK_LESS(meet, join); + break; + case analysis::EQUAL: + CHECK_EQUAL(a, b); + CHECK_EQUAL(a, join); + CHECK_EQUAL(b, meet); + break; + case analysis::LESS: + CHECK_LESS(a, b); + CHECK_EQUAL(b, join); + CHECK_EQUAL(a, meet); + break; + case analysis::GREATER: + CHECK_GREATER(a, b); + CHECK_EQUAL(a, join); + CHECK_EQUAL(b, meet); + break; + } + } +#define CHECK_PAIR(a, b, join, meet) \ + checkPair(a, b, join, meet, __FILE__, __LINE__) + +public: + ConeTypeLatticeTest() : lattice(initTypes(this)) { + bot = lattice.getBottom(); + top = lattice.getTop(); + i32 = lattice.get(Type::i32); + i64 = lattice.get(Type::i64); + eqNull3 = lattice.get(Type(HeapType::eq, Nullable)); + eqNonNull3 = lattice.get(Type(HeapType::eq, NonNullable)); + structNull2 = lattice.get(Type(HeapType::struct_, Nullable)); + structNonNull2 = lattice.get(Type(HeapType::struct_, NonNullable)); + i31Null = lattice.get(Type(HeapType::i31, Nullable)); + i31NonNull = lattice.get(Type(HeapType::i31, NonNullable)); + noneNull = lattice.get(Type(HeapType::none, Nullable)); + noneNonNull = lattice.get(Type(HeapType::none, NonNullable)); + superNull1 = lattice.get(Type(super, Nullable)); + superNonNull1 = lattice.get(Type(super, NonNullable)); + superNullExact = lattice.get(Type(super, Nullable, Exact)); + superNonNullExact = lattice.get(Type(super, NonNullable, Exact)); + sub1Null0 = lattice.get(Type(sub1, Nullable)); + sub1NonNull0 = lattice.get(Type(sub1, NonNullable)); + sub1NullExact = lattice.get(Type(sub1, Nullable, Exact)); + sub1NonNullExact = lattice.get(Type(sub1, NonNullable, Exact)); + sub2Null0 = lattice.get(Type(sub2, Nullable)); + sub2NonNull0 = lattice.get(Type(sub2, NonNullable)); + sub2NullExact = lattice.get(Type(sub2, Nullable, Exact)); + sub2NonNullExact = lattice.get(Type(sub2, NonNullable, Exact)); + otherNull0 = lattice.get(Type(other, Nullable)); + otherNonNull0 = lattice.get(Type(other, NonNullable)); + otherNullExact = lattice.get(Type(other, Nullable, Exact)); + otherNonNullExact = lattice.get(Type(other, NonNullable, Exact)); + + eqNull2 = Element{eqNull3.type, 2}; + eqNonNull2 = Element{eqNonNull3.type, 2}; + structNull1 = Element{structNull2.type, 1}; + structNonNull1 = Element{structNonNull2.type, 1}; + } + +private: + static std::unordered_map + initTypes(ConeTypeLatticeTest* self); +}; + +std::unordered_map +ConeTypeLatticeTest::initTypes(ConeTypeLatticeTest* self) { + // 0 3 + // |\ + // 1 2 + TypeBuilder builder(4); + builder.createRecGroup(0, 4); + builder[0] = Struct{}; + builder[1] = Struct{}; + builder[2] = Struct{}; + builder[3] = Struct{}; + builder[0].setOpen(); + builder[1].subTypeOf(builder[0]); + builder[2].subTypeOf(builder[0]); + auto types = *builder.build(); + + self->super = types[0]; + self->sub1 = types[1]; + self->sub2 = types[2]; + self->other = types[3]; + + SubTypes subtypes(types); + return subtypes.getMaxDepths(); +} + +TEST_F(ConeTypeLatticeTest, GetBottom) { + EXPECT_TRUE(lattice.getBottom().isBottom()); + EXPECT_EQ(lattice.getBottom().type, Type(Type::unreachable)); + EXPECT_EQ(lattice.getBottom().depth, 0); +} + +TEST_F(ConeTypeLatticeTest, GetTop) { + EXPECT_TRUE(lattice.getTop().isTop()); + EXPECT_EQ(lattice.getTop().type, Type(Type::none)); + EXPECT_EQ(lattice.getTop().depth, 0); +} + +TEST_F(ConeTypeLatticeTest, Relations) { + CHECK_PAIR(bot, bot, bot, bot); + CHECK_PAIR(bot, top, top, bot); + CHECK_PAIR(bot, i32, i32, bot); + CHECK_PAIR(bot, i64, i64, bot); + CHECK_PAIR(bot, eqNull3, eqNull3, bot); + CHECK_PAIR(bot, eqNonNull3, eqNonNull3, bot); + CHECK_PAIR(bot, structNull2, structNull2, bot); + CHECK_PAIR(bot, structNonNull2, structNonNull2, bot); + CHECK_PAIR(bot, i31Null, i31Null, bot); + CHECK_PAIR(bot, i31NonNull, i31NonNull, bot); + CHECK_PAIR(bot, noneNull, noneNull, bot); + CHECK_PAIR(bot, noneNonNull, noneNonNull, bot); + CHECK_PAIR(bot, superNull1, superNull1, bot); + CHECK_PAIR(bot, superNonNull1, superNonNull1, bot); + CHECK_PAIR(bot, superNullExact, superNullExact, bot); + CHECK_PAIR(bot, superNonNullExact, superNonNullExact, bot); + + CHECK_PAIR(top, top, top, top); + CHECK_PAIR(top, i32, top, i32); + CHECK_PAIR(top, i64, top, i64); + CHECK_PAIR(top, eqNull3, top, eqNull3); + CHECK_PAIR(top, eqNonNull3, top, eqNonNull3); + CHECK_PAIR(top, structNull2, top, structNull2); + CHECK_PAIR(top, structNonNull2, top, structNonNull2); + CHECK_PAIR(top, i31Null, top, i31Null); + CHECK_PAIR(top, i31NonNull, top, i31NonNull); + CHECK_PAIR(top, noneNull, top, noneNull); + CHECK_PAIR(top, noneNonNull, top, noneNonNull); + CHECK_PAIR(top, superNull1, top, superNull1); + CHECK_PAIR(top, superNonNull1, top, superNonNull1); + CHECK_PAIR(top, superNullExact, top, superNullExact); + CHECK_PAIR(top, superNonNullExact, top, superNonNullExact); + + CHECK_PAIR(i32, i32, i32, i32); + CHECK_PAIR(i32, i64, top, bot); + CHECK_PAIR(i32, eqNull3, top, bot); + CHECK_PAIR(i32, eqNonNull3, top, bot); + CHECK_PAIR(i32, structNull2, top, bot); + CHECK_PAIR(i32, structNonNull2, top, bot); + CHECK_PAIR(i32, i31Null, top, bot); + CHECK_PAIR(i32, i31NonNull, top, bot); + CHECK_PAIR(i32, noneNull, top, bot); + CHECK_PAIR(i32, noneNonNull, top, bot); + CHECK_PAIR(i32, superNull1, top, bot); + CHECK_PAIR(i32, superNonNull1, top, bot); + CHECK_PAIR(i32, superNullExact, top, bot); + CHECK_PAIR(i32, superNonNullExact, top, bot); + + CHECK_PAIR(eqNull3, eqNull3, eqNull3, eqNull3); + CHECK_PAIR(eqNull3, eqNonNull3, eqNull3, eqNonNull3); + CHECK_PAIR(eqNull3, structNull2, eqNull3, structNull2); + CHECK_PAIR(eqNull3, structNonNull2, eqNull3, structNonNull2); + CHECK_PAIR(eqNull3, i31Null, eqNull3, i31Null); + CHECK_PAIR(eqNull3, i31NonNull, eqNull3, i31NonNull); + CHECK_PAIR(eqNull3, noneNull, eqNull3, noneNull); + CHECK_PAIR(eqNull3, noneNonNull, eqNull3, noneNonNull); + CHECK_PAIR(eqNull3, superNull1, eqNull3, superNull1); + CHECK_PAIR(eqNull3, superNonNull1, eqNull3, superNonNull1); + CHECK_PAIR(eqNull3, superNullExact, eqNull3, superNullExact); + CHECK_PAIR(eqNull3, superNonNullExact, eqNull3, superNonNullExact); + + CHECK_PAIR(eqNonNull3, eqNonNull3, eqNonNull3, eqNonNull3); + CHECK_PAIR(eqNonNull3, structNull2, eqNull3, structNonNull2); + CHECK_PAIR(eqNonNull3, structNonNull2, eqNonNull3, structNonNull2); + CHECK_PAIR(eqNonNull3, i31Null, eqNull3, i31NonNull); + CHECK_PAIR(eqNonNull3, i31NonNull, eqNonNull3, i31NonNull); + CHECK_PAIR(eqNonNull3, noneNull, eqNull3, noneNonNull); + CHECK_PAIR(eqNonNull3, noneNonNull, eqNonNull3, noneNonNull); + CHECK_PAIR(eqNonNull3, superNull1, eqNull3, superNonNull1); + CHECK_PAIR(eqNonNull3, superNonNull1, eqNonNull3, superNonNull1); + CHECK_PAIR(eqNonNull3, superNullExact, eqNull3, superNonNullExact); + CHECK_PAIR(eqNonNull3, superNonNullExact, eqNonNull3, superNonNullExact); + + CHECK_PAIR(structNull2, structNull2, structNull2, structNull2); + CHECK_PAIR(structNull2, structNonNull2, structNull2, structNonNull2); + CHECK_PAIR(structNull2, i31Null, eqNull3, noneNull); + CHECK_PAIR(structNull2, i31NonNull, eqNull3, noneNonNull); + CHECK_PAIR(structNull2, noneNull, structNull2, noneNull); + CHECK_PAIR(structNull2, noneNonNull, structNull2, noneNonNull); + CHECK_PAIR(structNull2, superNull1, structNull2, superNull1); + CHECK_PAIR(structNull2, superNonNull1, structNull2, superNonNull1); + CHECK_PAIR(structNull2, superNullExact, structNull2, superNullExact); + CHECK_PAIR(structNull2, superNonNullExact, structNull2, superNonNullExact); + + CHECK_PAIR(structNonNull2, structNonNull2, structNonNull2, structNonNull2); + CHECK_PAIR(structNonNull2, i31Null, eqNull3, noneNonNull); + CHECK_PAIR(structNonNull2, i31NonNull, eqNonNull3, noneNonNull); + CHECK_PAIR(structNonNull2, noneNull, structNull2, noneNonNull); + CHECK_PAIR(structNonNull2, noneNonNull, structNonNull2, noneNonNull); + CHECK_PAIR(structNonNull2, superNull1, structNull2, superNonNull1); + CHECK_PAIR(structNonNull2, superNonNull1, structNonNull2, superNonNull1); + CHECK_PAIR(structNonNull2, superNullExact, structNull2, superNonNullExact); + CHECK_PAIR( + structNonNull2, superNonNullExact, structNonNull2, superNonNullExact); + + CHECK_PAIR(i31Null, i31Null, i31Null, i31Null); + CHECK_PAIR(i31Null, i31NonNull, i31Null, i31NonNull); + CHECK_PAIR(i31Null, noneNull, i31Null, noneNull); + CHECK_PAIR(i31Null, noneNonNull, i31Null, noneNonNull); + CHECK_PAIR(i31Null, superNull1, eqNull3, noneNull); + CHECK_PAIR(i31Null, superNonNull1, eqNull3, noneNonNull); + CHECK_PAIR(i31Null, superNullExact, eqNull2, noneNull); + CHECK_PAIR(i31Null, superNonNullExact, eqNull2, noneNonNull); + + CHECK_PAIR(i31NonNull, i31NonNull, i31NonNull, i31NonNull); + CHECK_PAIR(i31NonNull, noneNull, i31Null, noneNonNull); + CHECK_PAIR(i31NonNull, noneNonNull, i31NonNull, noneNonNull); + CHECK_PAIR(i31NonNull, superNull1, eqNull3, noneNonNull); + CHECK_PAIR(i31NonNull, superNonNull1, eqNonNull3, noneNonNull); + CHECK_PAIR(i31NonNull, superNullExact, eqNull2, noneNonNull); + CHECK_PAIR(i31NonNull, superNonNullExact, eqNonNull2, noneNonNull); + + CHECK_PAIR(noneNull, noneNull, noneNull, noneNull); + CHECK_PAIR(noneNull, noneNonNull, noneNull, noneNonNull); + CHECK_PAIR(noneNull, superNull1, superNull1, noneNull); + CHECK_PAIR(noneNull, superNonNull1, superNull1, noneNonNull); + CHECK_PAIR(noneNull, superNullExact, superNullExact, noneNull); + CHECK_PAIR(noneNull, superNonNullExact, superNullExact, noneNonNull); + + CHECK_PAIR(noneNonNull, noneNonNull, noneNonNull, noneNonNull); + CHECK_PAIR(noneNonNull, superNull1, superNull1, noneNonNull); + CHECK_PAIR(noneNonNull, superNonNull1, superNonNull1, noneNonNull); + CHECK_PAIR(noneNonNull, superNullExact, superNullExact, noneNonNull); + CHECK_PAIR(noneNonNull, superNonNullExact, superNonNullExact, noneNonNull); + + CHECK_PAIR(superNull1, superNull1, superNull1, superNull1); + CHECK_PAIR(superNull1, superNonNull1, superNull1, superNonNull1); + CHECK_PAIR(superNull1, superNullExact, superNull1, superNullExact); + CHECK_PAIR(superNull1, superNonNullExact, superNull1, superNonNullExact); + CHECK_PAIR(superNull1, sub1Null0, superNull1, sub1Null0); + CHECK_PAIR(superNull1, sub1NonNull0, superNull1, sub1NonNull0); + CHECK_PAIR(superNull1, sub1NullExact, superNull1, sub1NullExact); + CHECK_PAIR(superNull1, sub1NonNullExact, superNull1, sub1NonNullExact); + CHECK_PAIR(superNull1, otherNull0, structNull2, noneNull); + CHECK_PAIR(superNull1, otherNonNull0, structNull2, noneNonNull); + CHECK_PAIR(superNull1, otherNullExact, structNull2, noneNull); + CHECK_PAIR(superNull1, otherNonNullExact, structNull2, noneNonNull); + + CHECK_PAIR(superNonNull1, superNonNull1, superNonNull1, superNonNull1); + CHECK_PAIR(superNonNull1, superNullExact, superNull1, superNonNullExact); + CHECK_PAIR( + superNonNull1, superNonNullExact, superNonNull1, superNonNullExact); + CHECK_PAIR(superNonNull1, sub1Null0, superNull1, sub1NonNull0); + CHECK_PAIR(superNonNull1, sub1NonNull0, superNonNull1, sub1NonNull0); + CHECK_PAIR(superNonNull1, sub1NullExact, superNull1, sub1NonNullExact); + CHECK_PAIR(superNonNull1, sub1NonNullExact, superNonNull1, sub1NonNullExact); + CHECK_PAIR(superNonNull1, otherNull0, structNull2, noneNonNull); + CHECK_PAIR(superNonNull1, otherNonNull0, structNonNull2, noneNonNull); + CHECK_PAIR(superNonNull1, otherNullExact, structNull2, noneNonNull); + CHECK_PAIR(superNonNull1, otherNonNullExact, structNonNull2, noneNonNull); + + CHECK_PAIR(superNullExact, superNullExact, superNullExact, superNullExact); + CHECK_PAIR( + superNullExact, superNonNullExact, superNullExact, superNonNullExact); + CHECK_PAIR(superNullExact, sub1Null0, superNull1, noneNull); + CHECK_PAIR(superNullExact, sub1NonNull0, superNull1, noneNonNull); + CHECK_PAIR(superNullExact, sub1NullExact, superNull1, noneNull); + CHECK_PAIR(superNullExact, sub1NonNullExact, superNull1, noneNonNull); + CHECK_PAIR(superNullExact, otherNull0, structNull1, noneNull); + CHECK_PAIR(superNullExact, otherNonNull0, structNull1, noneNonNull); + CHECK_PAIR(superNullExact, otherNullExact, structNull1, noneNull); + CHECK_PAIR(superNullExact, otherNonNullExact, structNull1, noneNonNull); + + CHECK_PAIR( + superNonNullExact, superNonNullExact, superNonNullExact, superNonNullExact); + CHECK_PAIR(superNonNullExact, sub1Null0, superNull1, noneNonNull); + CHECK_PAIR(superNonNullExact, sub1NonNull0, superNonNull1, noneNonNull); + CHECK_PAIR(superNonNullExact, sub1NullExact, superNull1, noneNonNull); + CHECK_PAIR(superNonNullExact, sub1NonNullExact, superNonNull1, noneNonNull); + CHECK_PAIR(superNonNullExact, otherNull0, structNull1, noneNonNull); + CHECK_PAIR(superNonNullExact, otherNonNull0, structNonNull1, noneNonNull); + CHECK_PAIR(superNonNullExact, otherNullExact, structNull1, noneNonNull); + CHECK_PAIR(superNonNullExact, otherNonNullExact, structNonNull1, noneNonNull); + + CHECK_PAIR(sub1Null0, sub1Null0, sub1Null0, sub1Null0); + CHECK_PAIR(sub1Null0, sub1NonNull0, sub1Null0, sub1NonNull0); + CHECK_PAIR(sub1Null0, sub1NullExact, sub1Null0, sub1NullExact); + CHECK_PAIR(sub1Null0, sub1NonNullExact, sub1Null0, sub1NonNullExact); + CHECK_PAIR(sub1Null0, sub2Null0, superNull1, noneNull); + CHECK_PAIR(sub1Null0, sub2NonNull0, superNull1, noneNonNull); + CHECK_PAIR(sub1Null0, sub2NullExact, superNull1, noneNull); + CHECK_PAIR(sub1Null0, sub2NonNullExact, superNull1, noneNonNull); + CHECK_PAIR(sub1Null0, otherNull0, structNull2, noneNull); + CHECK_PAIR(sub1Null0, otherNonNull0, structNull2, noneNonNull); + CHECK_PAIR(sub1Null0, otherNullExact, structNull2, noneNull); + CHECK_PAIR(sub1Null0, otherNonNullExact, structNull2, noneNonNull); + + CHECK_PAIR(sub1NonNull0, sub1NonNull0, sub1NonNull0, sub1NonNull0); + CHECK_PAIR(sub1NonNull0, sub1NullExact, sub1Null0, sub1NonNullExact); + CHECK_PAIR(sub1NonNull0, sub1NonNullExact, sub1NonNull0, sub1NonNullExact); + CHECK_PAIR(sub1NonNull0, sub2Null0, superNull1, noneNonNull); + CHECK_PAIR(sub1NonNull0, sub2NonNull0, superNonNull1, noneNonNull); + CHECK_PAIR(sub1NonNull0, sub2NullExact, superNull1, noneNonNull); + CHECK_PAIR(sub1NonNull0, sub2NonNullExact, superNonNull1, noneNonNull); + CHECK_PAIR(sub1NonNull0, otherNull0, structNull2, noneNonNull); + CHECK_PAIR(sub1NonNull0, otherNonNull0, structNonNull2, noneNonNull); + CHECK_PAIR(sub1NonNull0, otherNullExact, structNull2, noneNonNull); + CHECK_PAIR(sub1NonNull0, otherNonNullExact, structNonNull2, noneNonNull); + + CHECK_PAIR(sub1NullExact, sub1NullExact, sub1NullExact, sub1NullExact); + CHECK_PAIR(sub1NullExact, sub1NonNullExact, sub1NullExact, sub1NonNullExact); + CHECK_PAIR(sub1NullExact, sub2Null0, superNull1, noneNull); + CHECK_PAIR(sub1NullExact, sub2NonNull0, superNull1, noneNonNull); + CHECK_PAIR(sub1NullExact, sub2NullExact, superNull1, noneNull); + CHECK_PAIR(sub1NullExact, sub2NonNullExact, superNull1, noneNonNull); + CHECK_PAIR(sub1NullExact, otherNull0, structNull2, noneNull); + CHECK_PAIR(sub1NullExact, otherNonNull0, structNull2, noneNonNull); + CHECK_PAIR(sub1NullExact, otherNullExact, structNull2, noneNull); + CHECK_PAIR(sub1NullExact, otherNonNullExact, structNull2, noneNonNull); + + CHECK_PAIR( + sub1NonNullExact, sub1NonNullExact, sub1NonNullExact, sub1NonNullExact); + CHECK_PAIR(sub1NonNullExact, sub2Null0, superNull1, noneNonNull); + CHECK_PAIR(sub1NonNullExact, sub2NonNull0, superNonNull1, noneNonNull); + CHECK_PAIR(sub1NonNullExact, sub2NullExact, superNull1, noneNonNull); + CHECK_PAIR(sub1NonNullExact, sub2NonNullExact, superNonNull1, noneNonNull); + CHECK_PAIR(sub1NonNullExact, otherNull0, structNull2, noneNonNull); + CHECK_PAIR(sub1NonNullExact, otherNonNull0, structNonNull2, noneNonNull); + CHECK_PAIR(sub1NonNullExact, otherNullExact, structNull2, noneNonNull); + CHECK_PAIR(sub1NonNullExact, otherNonNullExact, structNonNull2, noneNonNull); +} + +TEST_F(ConeTypeLatticeTest, Depths) { + TypeBuilder builder(3); + builder[0].setOpen() = Struct{}; + builder[1].setOpen().subTypeOf(builder[0]) = Struct{}; + builder[2].setOpen().subTypeOf(builder[1]) = Struct{}; + auto built = builder.build(); + + HeapType a = (*built)[0]; + HeapType b = (*built)[1]; + HeapType c = (*built)[2]; + + Element none{Type(HeapType::none, Nullable), 0}; + + Element a0{Type(a, Nullable), 0}; + Element a1{Type(a, Nullable), 1}; + Element a2{Type(a, Nullable), 2}; + Element a3{Type(a, Nullable), 3}; + + Element b0{Type(b, Nullable), 0}; + Element b1{Type(b, Nullable), 1}; + Element b2{Type(b, Nullable), 2}; + + Element c0{Type(c, Nullable), 0}; + Element c1{Type(c, Nullable), 1}; + + CHECK_PAIR(a0, a0, a0, a0); + CHECK_PAIR(a0, a1, a1, a0); + CHECK_PAIR(a0, b0, a1, none); + CHECK_PAIR(a0, b1, a2, none); + CHECK_PAIR(a0, c0, a2, none); + CHECK_PAIR(a0, c1, a3, none); + + CHECK_PAIR(a1, a1, a1, a1); + CHECK_PAIR(a1, b0, a1, b0); + CHECK_PAIR(a1, b1, a2, b0); + CHECK_PAIR(a1, c0, a2, none); + CHECK_PAIR(a1, c1, a3, none); + + CHECK_PAIR(b0, b0, b0, b0); + CHECK_PAIR(b0, b1, b1, b0); + CHECK_PAIR(b0, c0, b1, none); + CHECK_PAIR(b0, c1, b2, none); + + CHECK_PAIR(b1, b1, b1, b1); + CHECK_PAIR(b1, c0, b1, c0); + CHECK_PAIR(b1, c1, b2, c0); +} From 776472c2d05ea879c12e33c0a06a84eea9378e2e Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Fri, 14 Nov 2025 13:58:14 -0800 Subject: [PATCH 2/3] use isNull and make get more robust --- src/analysis/lattices/conetype.h | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/analysis/lattices/conetype.h b/src/analysis/lattices/conetype.h index 913c2a01567..5883d8e343d 100644 --- a/src/analysis/lattices/conetype.h +++ b/src/analysis/lattices/conetype.h @@ -41,13 +41,13 @@ struct ConeType { Element get(Type type) const noexcept { assert(!type.isTuple()); - if (!type.isRef() || type.isExact()) { + if (!type.isRef() || type.isExact() || type.isNull() || + type.getHeapType().isMaybeShared(HeapType::i31)) { return Element{type, 0}; } - if (auto it = typeDepths.find(type.getHeapType()); it != typeDepths.end()) { - return Element{type, it->second}; - } - return Element{type, 0}; + auto it = typeDepths.find(type.getHeapType()); + assert(it != typeDepths.end()); + return Element{type, it->second}; } Element getBottom() const noexcept { return Element{Type::unreachable, 0}; } @@ -63,10 +63,10 @@ struct ConeType { return changed; } Index joineeToLub = 0, joinerToLub = 0; - if (!joinee.isBottom() && !joinee.type.getHeapType().isBottom()) { + if (!joinee.isBottom() && !joinee.type.isNull()) { joineeToLub = depthToSuper(joinee, lub); } - if (!joiner.isBottom() && !joiner.type.getHeapType().isBottom()) { + if (!joiner.isBottom() && !joiner.type.isNull()) { joinerToLub = depthToSuper(joiner, lub); } Index newDepth = @@ -96,7 +96,7 @@ struct ConeType { } Index newDepth; auto glb = Type::getGreatestLowerBound(meetee.type, meeter.type); - if (glb == Type::unreachable || glb.getHeapType().isBottom()) { + if (glb == Type::unreachable || glb.isNull()) { newDepth = 0; } else if (HeapType::isSubType(meetee.type.getHeapType(), meeter.type.getHeapType())) { @@ -140,14 +140,14 @@ struct ConeType { return a.depth < b.depth ? analysis::LESS : analysis::GREATER; } if (Type::isSubType(a.type, b.type)) { - if (a.type.getHeapType().isBottom()) { + if (a.type.isNull()) { return analysis::LESS; } Index diff = depthToSuper(a, b.type); return a.depth + diff <= b.depth ? analysis::LESS : analysis::NO_RELATION; } if (Type::isSubType(b.type, a.type)) { - if (b.type.getHeapType().isBottom()) { + if (b.type.isNull()) { return analysis::GREATER; } Index diff = depthToSuper(b, a.type); From ff10ec548b2d3bb44e9bc4585ee67e605aa29af5 Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Fri, 14 Nov 2025 14:13:07 -0800 Subject: [PATCH 3/3] add comments --- src/analysis/lattices/conetype.h | 8 +++++++- test/gtest/lattices.cpp | 8 ++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/analysis/lattices/conetype.h b/src/analysis/lattices/conetype.h index 5883d8e343d..5e5484fc630 100644 --- a/src/analysis/lattices/conetype.h +++ b/src/analysis/lattices/conetype.h @@ -22,6 +22,11 @@ namespace wasm::analysis { +// The value type lattice augmented with subtyping depths on reference types. An +// element {(ref $foo), 1}, for example, represents the set of values that are +// exactly $foo or exactly one of $foo's immediate subtypes, but not any deeper +// type. Non-reference types and bottom references types always have a depth of +// 0. struct ConeType { struct Element { Type type; @@ -34,7 +39,8 @@ struct ConeType { bool isTop() const { return type == Type::none; } }; - std::unordered_map typeDepths; + // Used only for initializing depths for new elements. + const std::unordered_map typeDepths; ConeType(std::unordered_map&& typeDepths) : typeDepths(std::move(typeDepths)) {} diff --git a/test/gtest/lattices.cpp b/test/gtest/lattices.cpp index d5b55b50dd0..043caa683f1 100644 --- a/test/gtest/lattices.cpp +++ b/test/gtest/lattices.cpp @@ -869,6 +869,8 @@ class ConeTypeLatticeTest : public ::testing::Test { Element top; Element i32; Element i64; + + // The number at the end is the depth of the element. Element eqNull3; Element eqNonNull3; Element structNull2; @@ -975,6 +977,8 @@ class ConeTypeLatticeTest : public ::testing::Test { CHECK_MEET(a, b, meet); switch (lattice.compare(a, b)) { case analysis::NO_RELATION: + // This first check looks redundant, but it's also checking the opposite + // direction. CHECK_UNRELATED(a, b); CHECK_LESS(a, join); CHECK_LESS(b, join); @@ -1046,8 +1050,8 @@ class ConeTypeLatticeTest : public ::testing::Test { std::unordered_map ConeTypeLatticeTest::initTypes(ConeTypeLatticeTest* self) { - // 0 3 - // |\ + // 0 3 + // /| // 1 2 TypeBuilder builder(4); builder.createRecGroup(0, 4);