Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 47 additions & 21 deletions src/analysis/lattices/flat.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
#ifndef wasm_analysis_lattices_flat_h
#define wasm_analysis_lattices_flat_h

#include <tuple>
#include <type_traits>
#include <variant>

#if __cplusplus >= 202002L
#include <concepts>
#endif

#include "../lattice.h"
#include "analysis/lattice.h"
#include "support/utilities.h"

namespace wasm::analysis {
Expand All @@ -33,27 +35,48 @@ namespace wasm::analysis {
template<typename T>
concept Flattenable = std::copyable<T> && std::equality_comparable<T>;

// Given a type T, Flat<T> is the lattice where none of the values of T are
// comparable except with themselves, but they are all greater than a common
// bottom element not in T and less than a common top element also not in T.
template<Flattenable T>
// Given types Ts..., Flat<T...> is the lattice where none of the values of any
// T are comparable except with themselves, but they are all greater than a
// common bottom element and less than a common top element.
template<Flattenable T, Flattenable... Ts>
#else
template<typename T>
template<typename T, typename... Ts>
#endif
struct Flat {
private:
struct Bot {};
struct Top {};
struct Bot : std::monostate {};
struct Top : std::monostate {};

template<std::size_t I>
using TI = std::tuple_element_t<I, std::tuple<T, Ts...>>;

public:
struct Element : std::variant<Bot, T, Top> {
struct Element : std::variant<T, Ts..., Bot, Top> {
bool isBottom() const noexcept { return std::get_if<Bot>(this); }
bool isTop() const noexcept { return std::get_if<Top>(this); }
const T* getVal() const noexcept { return std::get_if<T>(this); }
T* getVal() noexcept { return std::get_if<T>(this); }
template<typename U = T> const U* getVal() const noexcept {
return std::get_if<U>(this);
}
template<typename U = T> U* getVal() noexcept {
return std::get_if<U>(this);
}
template<std::size_t I> const TI<I>* getVal() const noexcept {
return std::get_if<I>(this);
}
template<std::size_t I> TI<I>* getVal() noexcept {
return std::get_if<I>(this);
}
bool operator==(const Element& other) const noexcept {
return ((isBottom() && other.isBottom()) || (isTop() && other.isTop()) ||
(getVal() && other.getVal() && *getVal() == *other.getVal()));
return this->index() == other.index() &&
std::visit(
[](const auto& a, const auto& b) {
if constexpr (std::is_same_v<decltype(a), decltype(b)>) {
return a == b;
}
return false;
},
*this,
other);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can guess at what this line does, but where in the docs is it described? I can't seem to find that (the variant of visit with two objects after the lambda)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::visit is fun because it can take an arbitrary number of std::variant arguments as long as the visitor functions can take that same number of parameters. See the Variants&&... values variadic parameter given in https://en.cppreference.com/w/cpp/utility/variant/visit2.html.

}
bool operator!=(const Element& other) const noexcept {
return !(*this == other);
Expand All @@ -62,18 +85,21 @@ struct Flat {

Element getBottom() const noexcept { return Element{Bot{}}; }
Element getTop() const noexcept { return Element{Top{}}; }
Element get(T&& val) const noexcept { return Element{std::move(val)}; }
template<typename U> Element get(U&& val) const noexcept {
return Element{std::move(val)};
}

LatticeComparison compare(const Element& a, const Element& b) const noexcept {
if (a.index() < b.index()) {
return LESS;
} else if (a.index() > b.index()) {
return GREATER;
} else if (auto pA = a.getVal(); pA && *pA != *b.getVal()) {
return NO_RELATION;
} else {
if (a == b) {
return EQUAL;
}
if (a.isTop() || b.isBottom()) {
return GREATER;
}
if (a.isBottom() || b.isTop()) {
return LESS;
}
return NO_RELATION;
}

bool join(Element& joinee, const Element& joiner) const noexcept {
Expand Down
36 changes: 26 additions & 10 deletions test/gtest/lattices.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,22 @@ TEST(FlatLattice, Join) {
flat, flat.getBottom(), flat.get(0), flat.get(1), flat.getTop());
}

TEST(FlatLattice, MultipleTypes) {
analysis::Flat<int, std::string> flat;
testDiamondJoin(
flat, flat.getBottom(), flat.get(0), flat.get("foo"), flat.getTop());

auto stringElem = flat.get("foo");

EXPECT_EQ(stringElem.getVal<0>(), nullptr);
ASSERT_NE(stringElem.getVal<1>(), nullptr);
EXPECT_EQ(*stringElem.getVal<1>(), std::string("foo"));

EXPECT_EQ(stringElem.getVal<int>(), nullptr);
ASSERT_NE(stringElem.getVal<std::string>(), nullptr);
EXPECT_EQ(*stringElem.getVal<std::string>(), std::string("foo"));
}

TEST(LiftLattice, GetBottom) {
analysis::Lift lift{analysis::Bool{}};
EXPECT_TRUE(lift.getBottom().isBottom());
Expand Down Expand Up @@ -711,19 +727,19 @@ TEST(StackLattice, Compare) {
auto& flat = stack.lattice;
testDiamondCompare(stack,
{},
{flat.get(0)},
{flat.get(0), flat.get(1)},
{flat.get(0), flat.getTop()});
{flat.get(0u)},
{flat.get(0u), flat.get(1u)},
{flat.get(0u), flat.getTop()});
}

TEST(StackLattice, Join) {
analysis::Stack stack{analysis::Flat<uint32_t>{}};
auto& flat = stack.lattice;
testDiamondJoin(stack,
{},
{flat.get(0)},
{flat.get(0), flat.get(1)},
{flat.get(0), flat.getTop()});
{flat.get(0u)},
{flat.get(0u), flat.get(1u)},
{flat.get(0u), flat.getTop()});
}

using OddEvenInt = analysis::Flat<uint32_t>;
Expand Down Expand Up @@ -815,10 +831,10 @@ TEST(AbstractionLattice, Join) {
#define JOIN(a, b, c) expectJoin(__FILE__, __LINE__, a, b, c)

auto bot = abstraction.getBottom();
auto one = OddEvenAbstraction::Element(OddEvenInt{}.get(1));
auto two = OddEvenAbstraction::Element(OddEvenInt{}.get(2));
auto three = OddEvenAbstraction::Element(OddEvenInt{}.get(3));
auto four = OddEvenAbstraction::Element(OddEvenInt{}.get(4));
auto one = OddEvenAbstraction::Element(OddEvenInt{}.get(1u));
auto two = OddEvenAbstraction::Element(OddEvenInt{}.get(2u));
auto three = OddEvenAbstraction::Element(OddEvenInt{}.get(3u));
auto four = OddEvenAbstraction::Element(OddEvenInt{}.get(4u));
auto even = OddEvenAbstraction::Element(OddEvenBool{}.get(true));
auto odd = OddEvenAbstraction::Element(OddEvenBool{}.get(false));
auto top = OddEvenAbstraction::Element(OddEvenBool{}.getTop());
Expand Down
Loading