diff --git a/src/analysis/lattices/flat.h b/src/analysis/lattices/flat.h index 39c067dc0a4..aa26b101d99 100644 --- a/src/analysis/lattices/flat.h +++ b/src/analysis/lattices/flat.h @@ -17,13 +17,15 @@ #ifndef wasm_analysis_lattices_flat_h #define wasm_analysis_lattices_flat_h +#include +#include #include #if __cplusplus >= 202002L #include #endif -#include "../lattice.h" +#include "analysis/lattice.h" #include "support/utilities.h" namespace wasm::analysis { @@ -33,27 +35,48 @@ namespace wasm::analysis { template concept Flattenable = std::copyable && std::equality_comparable; -// Given a type T, Flat is the lattice where none of the values of T are -// comparable except with themselves, but they are all greater than a common -// bottom element not in T and less than a common top element also not in T. -template +// Given types Ts..., Flat is the lattice where none of the values of any +// T are comparable except with themselves, but they are all greater than a +// common bottom element and less than a common top element. +template #else -template +template #endif struct Flat { private: - struct Bot {}; - struct Top {}; + struct Bot : std::monostate {}; + struct Top : std::monostate {}; + + template + using TI = std::tuple_element_t>; public: - struct Element : std::variant { + struct Element : std::variant { bool isBottom() const noexcept { return std::get_if(this); } bool isTop() const noexcept { return std::get_if(this); } - const T* getVal() const noexcept { return std::get_if(this); } - T* getVal() noexcept { return std::get_if(this); } + template const U* getVal() const noexcept { + return std::get_if(this); + } + template U* getVal() noexcept { + return std::get_if(this); + } + template const TI* getVal() const noexcept { + return std::get_if(this); + } + template TI* getVal() noexcept { + return std::get_if(this); + } bool operator==(const Element& other) const noexcept { - return ((isBottom() && other.isBottom()) || (isTop() && other.isTop()) || - (getVal() && other.getVal() && *getVal() == *other.getVal())); + return this->index() == other.index() && + std::visit( + [](const auto& a, const auto& b) { + if constexpr (std::is_same_v) { + return a == b; + } + return false; + }, + *this, + other); } bool operator!=(const Element& other) const noexcept { return !(*this == other); @@ -62,18 +85,21 @@ struct Flat { Element getBottom() const noexcept { return Element{Bot{}}; } Element getTop() const noexcept { return Element{Top{}}; } - Element get(T&& val) const noexcept { return Element{std::move(val)}; } + template Element get(U&& val) const noexcept { + return Element{std::move(val)}; + } LatticeComparison compare(const Element& a, const Element& b) const noexcept { - if (a.index() < b.index()) { - return LESS; - } else if (a.index() > b.index()) { - return GREATER; - } else if (auto pA = a.getVal(); pA && *pA != *b.getVal()) { - return NO_RELATION; - } else { + if (a == b) { return EQUAL; } + if (a.isTop() || b.isBottom()) { + return GREATER; + } + if (a.isBottom() || b.isTop()) { + return LESS; + } + return NO_RELATION; } bool join(Element& joinee, const Element& joiner) const noexcept { diff --git a/test/gtest/lattices.cpp b/test/gtest/lattices.cpp index 5a80994b2bd..bed52819dba 100644 --- a/test/gtest/lattices.cpp +++ b/test/gtest/lattices.cpp @@ -348,6 +348,22 @@ TEST(FlatLattice, Join) { flat, flat.getBottom(), flat.get(0), flat.get(1), flat.getTop()); } +TEST(FlatLattice, MultipleTypes) { + analysis::Flat flat; + testDiamondJoin( + flat, flat.getBottom(), flat.get(0), flat.get("foo"), flat.getTop()); + + auto stringElem = flat.get("foo"); + + EXPECT_EQ(stringElem.getVal<0>(), nullptr); + ASSERT_NE(stringElem.getVal<1>(), nullptr); + EXPECT_EQ(*stringElem.getVal<1>(), std::string("foo")); + + EXPECT_EQ(stringElem.getVal(), nullptr); + ASSERT_NE(stringElem.getVal(), nullptr); + EXPECT_EQ(*stringElem.getVal(), std::string("foo")); +} + TEST(LiftLattice, GetBottom) { analysis::Lift lift{analysis::Bool{}}; EXPECT_TRUE(lift.getBottom().isBottom()); @@ -711,9 +727,9 @@ TEST(StackLattice, Compare) { auto& flat = stack.lattice; testDiamondCompare(stack, {}, - {flat.get(0)}, - {flat.get(0), flat.get(1)}, - {flat.get(0), flat.getTop()}); + {flat.get(0u)}, + {flat.get(0u), flat.get(1u)}, + {flat.get(0u), flat.getTop()}); } TEST(StackLattice, Join) { @@ -721,9 +737,9 @@ TEST(StackLattice, Join) { auto& flat = stack.lattice; testDiamondJoin(stack, {}, - {flat.get(0)}, - {flat.get(0), flat.get(1)}, - {flat.get(0), flat.getTop()}); + {flat.get(0u)}, + {flat.get(0u), flat.get(1u)}, + {flat.get(0u), flat.getTop()}); } using OddEvenInt = analysis::Flat; @@ -815,10 +831,10 @@ TEST(AbstractionLattice, Join) { #define JOIN(a, b, c) expectJoin(__FILE__, __LINE__, a, b, c) auto bot = abstraction.getBottom(); - auto one = OddEvenAbstraction::Element(OddEvenInt{}.get(1)); - auto two = OddEvenAbstraction::Element(OddEvenInt{}.get(2)); - auto three = OddEvenAbstraction::Element(OddEvenInt{}.get(3)); - auto four = OddEvenAbstraction::Element(OddEvenInt{}.get(4)); + auto one = OddEvenAbstraction::Element(OddEvenInt{}.get(1u)); + auto two = OddEvenAbstraction::Element(OddEvenInt{}.get(2u)); + auto three = OddEvenAbstraction::Element(OddEvenInt{}.get(3u)); + auto four = OddEvenAbstraction::Element(OddEvenInt{}.get(4u)); auto even = OddEvenAbstraction::Element(OddEvenBool{}.get(true)); auto odd = OddEvenAbstraction::Element(OddEvenBool{}.get(false)); auto top = OddEvenAbstraction::Element(OddEvenBool{}.getTop());