From 6f139bceaaa308cb92e2c20cff9429fecdddca2d Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Fri, 14 Nov 2025 15:05:29 -0800 Subject: [PATCH] [analysis] Let Flat lattice take multiple types Previously, one could approximate a Flat lattice whose elements could have multiple types by creating a Flat lattice of a variant type. However, this would produce elements that were variants of variants, wasting space on an extra discriminant. To make this use case more efficient and ergonomic, support taking multiple type parameters in Flat. The multiple type parameters all become part of the element variant type. To handle the case where types are repeated, also add element accessors templatized on the type index. --- src/analysis/lattices/flat.h | 68 +++++++++++++++++++++++++----------- test/gtest/lattices.cpp | 36 +++++++++++++------ 2 files changed, 73 insertions(+), 31 deletions(-) diff --git a/src/analysis/lattices/flat.h b/src/analysis/lattices/flat.h index 39c067dc0a4..aa26b101d99 100644 --- a/src/analysis/lattices/flat.h +++ b/src/analysis/lattices/flat.h @@ -17,13 +17,15 @@ #ifndef wasm_analysis_lattices_flat_h #define wasm_analysis_lattices_flat_h +#include +#include #include #if __cplusplus >= 202002L #include #endif -#include "../lattice.h" +#include "analysis/lattice.h" #include "support/utilities.h" namespace wasm::analysis { @@ -33,27 +35,48 @@ namespace wasm::analysis { template concept Flattenable = std::copyable && std::equality_comparable; -// Given a type T, Flat is the lattice where none of the values of T are -// comparable except with themselves, but they are all greater than a common -// bottom element not in T and less than a common top element also not in T. -template +// Given types Ts..., Flat is the lattice where none of the values of any +// T are comparable except with themselves, but they are all greater than a +// common bottom element and less than a common top element. +template #else -template +template #endif struct Flat { private: - struct Bot {}; - struct Top {}; + struct Bot : std::monostate {}; + struct Top : std::monostate {}; + + template + using TI = std::tuple_element_t>; public: - struct Element : std::variant { + struct Element : std::variant { bool isBottom() const noexcept { return std::get_if(this); } bool isTop() const noexcept { return std::get_if(this); } - const T* getVal() const noexcept { return std::get_if(this); } - T* getVal() noexcept { return std::get_if(this); } + template const U* getVal() const noexcept { + return std::get_if(this); + } + template U* getVal() noexcept { + return std::get_if(this); + } + template const TI* getVal() const noexcept { + return std::get_if(this); + } + template TI* getVal() noexcept { + return std::get_if(this); + } bool operator==(const Element& other) const noexcept { - return ((isBottom() && other.isBottom()) || (isTop() && other.isTop()) || - (getVal() && other.getVal() && *getVal() == *other.getVal())); + return this->index() == other.index() && + std::visit( + [](const auto& a, const auto& b) { + if constexpr (std::is_same_v) { + return a == b; + } + return false; + }, + *this, + other); } bool operator!=(const Element& other) const noexcept { return !(*this == other); @@ -62,18 +85,21 @@ struct Flat { Element getBottom() const noexcept { return Element{Bot{}}; } Element getTop() const noexcept { return Element{Top{}}; } - Element get(T&& val) const noexcept { return Element{std::move(val)}; } + template Element get(U&& val) const noexcept { + return Element{std::move(val)}; + } LatticeComparison compare(const Element& a, const Element& b) const noexcept { - if (a.index() < b.index()) { - return LESS; - } else if (a.index() > b.index()) { - return GREATER; - } else if (auto pA = a.getVal(); pA && *pA != *b.getVal()) { - return NO_RELATION; - } else { + if (a == b) { return EQUAL; } + if (a.isTop() || b.isBottom()) { + return GREATER; + } + if (a.isBottom() || b.isTop()) { + return LESS; + } + return NO_RELATION; } bool join(Element& joinee, const Element& joiner) const noexcept { diff --git a/test/gtest/lattices.cpp b/test/gtest/lattices.cpp index 5a80994b2bd..bed52819dba 100644 --- a/test/gtest/lattices.cpp +++ b/test/gtest/lattices.cpp @@ -348,6 +348,22 @@ TEST(FlatLattice, Join) { flat, flat.getBottom(), flat.get(0), flat.get(1), flat.getTop()); } +TEST(FlatLattice, MultipleTypes) { + analysis::Flat flat; + testDiamondJoin( + flat, flat.getBottom(), flat.get(0), flat.get("foo"), flat.getTop()); + + auto stringElem = flat.get("foo"); + + EXPECT_EQ(stringElem.getVal<0>(), nullptr); + ASSERT_NE(stringElem.getVal<1>(), nullptr); + EXPECT_EQ(*stringElem.getVal<1>(), std::string("foo")); + + EXPECT_EQ(stringElem.getVal(), nullptr); + ASSERT_NE(stringElem.getVal(), nullptr); + EXPECT_EQ(*stringElem.getVal(), std::string("foo")); +} + TEST(LiftLattice, GetBottom) { analysis::Lift lift{analysis::Bool{}}; EXPECT_TRUE(lift.getBottom().isBottom()); @@ -711,9 +727,9 @@ TEST(StackLattice, Compare) { auto& flat = stack.lattice; testDiamondCompare(stack, {}, - {flat.get(0)}, - {flat.get(0), flat.get(1)}, - {flat.get(0), flat.getTop()}); + {flat.get(0u)}, + {flat.get(0u), flat.get(1u)}, + {flat.get(0u), flat.getTop()}); } TEST(StackLattice, Join) { @@ -721,9 +737,9 @@ TEST(StackLattice, Join) { auto& flat = stack.lattice; testDiamondJoin(stack, {}, - {flat.get(0)}, - {flat.get(0), flat.get(1)}, - {flat.get(0), flat.getTop()}); + {flat.get(0u)}, + {flat.get(0u), flat.get(1u)}, + {flat.get(0u), flat.getTop()}); } using OddEvenInt = analysis::Flat; @@ -815,10 +831,10 @@ TEST(AbstractionLattice, Join) { #define JOIN(a, b, c) expectJoin(__FILE__, __LINE__, a, b, c) auto bot = abstraction.getBottom(); - auto one = OddEvenAbstraction::Element(OddEvenInt{}.get(1)); - auto two = OddEvenAbstraction::Element(OddEvenInt{}.get(2)); - auto three = OddEvenAbstraction::Element(OddEvenInt{}.get(3)); - auto four = OddEvenAbstraction::Element(OddEvenInt{}.get(4)); + auto one = OddEvenAbstraction::Element(OddEvenInt{}.get(1u)); + auto two = OddEvenAbstraction::Element(OddEvenInt{}.get(2u)); + auto three = OddEvenAbstraction::Element(OddEvenInt{}.get(3u)); + auto four = OddEvenAbstraction::Element(OddEvenInt{}.get(4u)); auto even = OddEvenAbstraction::Element(OddEvenBool{}.get(true)); auto odd = OddEvenAbstraction::Element(OddEvenBool{}.get(false)); auto top = OddEvenAbstraction::Element(OddEvenBool{}.getTop());