Skip to content

Commit

Permalink
import updates from poly branch
Browse files Browse the repository at this point in the history
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
  • Loading branch information
NikolajBjorner committed Jan 11, 2024
1 parent 2ca1187 commit 955c80e
Show file tree
Hide file tree
Showing 8 changed files with 408 additions and 217 deletions.
1 change: 1 addition & 0 deletions src/ast/arith_decl_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ class arith_util : public arith_recognizers {

// return true if \c n is a term of the form (* -1 r)
bool is_zero(expr const* n) const { rational val; return is_numeral(n, val) && val.is_zero(); }
bool is_one(expr const* n) const{ rational val; return is_numeral(n, val) && val.is_one(); }
bool is_minus_one(expr* n) const { rational tmp; return is_numeral(n, tmp) && tmp.is_minus_one(); }
bool is_times_minus_one(expr* n, expr*& r) const {
if (is_mul(n) && to_app(n)->get_num_args() == 2 && is_minus_one(to_app(n)->get_arg(0))) {
Expand Down
162 changes: 111 additions & 51 deletions src/ast/euf/euf_bv_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,15 @@ The formal properties of saturation have to be established.
- Saturation does not complete with respect to associativity.
Instead the claim is along the lines that the resulting E-graph can be used as a canonizer.
If given a set of equations E that are saturated, and terms t1, t2 that are
both simplified with respect to left-associativity of concatenation, and t1, t2 belong to the E-graph,
both simplified with respect to left-associativity of concatentation, and t1, t2 belong to the E-graph,
then t1 = t2 iff t1 ~ t2 in the E-graph.
TODO: Is saturation for (7) overkill for the purpose of canonization?
TODO: revisit re-entrancy during register_node. It can be called when creating internal extract terms.
Instead of allowing re-entrancy we can accumulate nodes that are registered during recursive calls
and have the main call perform recursive slicing.
--*/


#include "ast/ast_pp.h"
#include "ast/euf/euf_bv_plugin.h"
#include "ast/euf/euf_egraph.h"

Expand All @@ -91,19 +88,22 @@ namespace euf {
bv(g.get_manager())
{}

enode* bv_plugin::mk_value_concat(enode* a, enode* b) {
auto v1 = get_value(a);
auto v2 = get_value(b);
auto v3 = v1 + v2 * power(rational(2), width(a));
return mk_value(v3, width(a) + width(b));
enode* bv_plugin::mk_value_concat(enode* hi, enode* lo) {
auto v1 = get_value(hi);
auto v2 = get_value(lo);
auto v3 = v2 + v1 * rational::power_of_two(width(lo));
return mk_value(v3, width(lo) + width(hi));
}

enode* bv_plugin::mk_value(rational const& v, unsigned sz) {
auto e = bv.mk_numeral(v, sz);
return mk(e, 0, nullptr);
auto n = mk(e, 0, nullptr);
if (m_ensure_th_var)
m_ensure_th_var(n);
return n;
}

void bv_plugin::merge_eh(enode* x, enode* y) {
void bv_plugin::propagate_merge(enode* x, enode* y) {
if (!bv.is_bv(x->get_expr()))
return;

Expand All @@ -128,7 +128,36 @@ namespace euf {
propagate_extract(n);
}

// enforce concat(v1, v2) = v2*2^|v1| + v1
void bv_plugin::register_node(enode* n) {
m_queue.push_back(n);
m_trail.push_back(new (get_region()) push_back_vector(m_queue));
push_plugin_undo(bv.get_family_id());
}

void bv_plugin::merge_eh(enode* n1, enode* n2) {
m_queue.push_back(enode_pair(n1, n2));
m_trail.push_back(new (get_region()) push_back_vector(m_queue));
push_plugin_undo(bv.get_family_id());
}

void bv_plugin::propagate() {
if (m_qhead == m_queue.size())
return;
m_trail.push_back(new (get_region()) value_trail(m_qhead));
push_plugin_undo(bv.get_family_id());
for (; m_qhead < m_queue.size(); ++m_qhead) {
if (std::holds_alternative<enode*>(m_queue[m_qhead])) {
auto n = *std::get_if<enode*>(&m_queue[m_qhead]);
propagate_register_node(n);
}
else {
auto [a, b] = *std::get_if<enode_pair>(&m_queue[m_qhead]);
propagate_merge(a, b);
}
}
}

// enforce concat(v1, v2) = v1*2^|v2| + v2
void bv_plugin::propagate_values(enode* x) {
if (!is_value(x))
return;
Expand All @@ -142,9 +171,9 @@ namespace euf {
if (is_concat(sib, a, b)) {
if (!is_value(a) || !is_value(b)) {
auto val = get_value(x);
auto v1 = mod2k(val, width(a));
auto v2 = machine_div2k(val, width(a));
push_merge(mk_concat(mk_value(v1, width(a)), mk_value(v2, width(b))), x->get_interpreted());
auto val_a = machine_div2k(val, width(b));
auto val_b = mod2k(val, width(b));
push_merge(mk_concat(mk_value(val_a, width(a)), mk_value(val_b, width(b))), x->get_interpreted());
}
}
}
Expand Down Expand Up @@ -176,18 +205,18 @@ namespace euf {
if (is_extract(p1, lo_, hi_) && lo_ == lo && hi_ == hi && p1->get_arg(0)->get_root() == arg_r)
return;
// add the axiom instead of merge(p, mk_extract(arg, lo, hi)), which would require tracking justifications
push_merge(mk_concat(mk_extract(arg, lo, mid), mk_extract(arg, mid + 1, hi)), mk_extract(arg, lo, hi));
push_merge(mk_concat(mk_extract(arg, mid + 1, hi), mk_extract(arg, lo, mid)), mk_extract(arg, lo, hi));
};

auto propagate_left = [&](enode* b) {
TRACE("bv", tout << "propagate-left " << g.bpp(b) << "\n");
auto propagate_above = [&](enode* b) {
TRACE("bv", tout << "propagate-above " << g.bpp(b) << "\n");
for (enode* sib : enode_class(b))
if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi1 + 1 == lo2)
ensure_concat(lo1, hi1, hi2);
};

auto propagate_right = [&](enode* a) {
TRACE("bv", tout << "propagate-right " << g.bpp(a) << "\n");
auto propagate_below = [&](enode* a) {
TRACE("bv", tout << "propagate-below " << g.bpp(a) << "\n");
for (enode* sib : enode_class(a))
if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi2 + 1 == lo1)
ensure_concat(lo2, hi2, hi1);
Expand All @@ -196,46 +225,65 @@ namespace euf {
for (enode* p : enode_parents(n)) {
if (is_concat(p, a, b)) {
if (a->get_root() == n_r)
propagate_left(b);
propagate_below(b);
if (b->get_root() == n_r)
propagate_right(a);
propagate_above(a);
}
}
}

void bv_plugin::push_undo_split(enode* n) {
m_undo_split.push_back(n);
class bv_plugin::undo_split : public trail {
bv_plugin& p;
enode* n;
public:
undo_split(bv_plugin& p, enode* n): p(p), n(n) {}
void undo() override {
auto& i = p.info(n);
i.value = nullptr;
i.lo = nullptr;
i.hi = nullptr;
i.cut = null_cut;
}
};

void bv_plugin::push_undo_split(enode* n) {
m_trail.push_back(new (get_region()) undo_split(*this, n));
push_plugin_undo(bv.get_family_id());
}

void bv_plugin::undo() {
enode* n = m_undo_split.back();
m_undo_split.pop_back();
auto& i = info(n);
i.lo = nullptr;
i.hi = nullptr;
i.cut = null_cut;
m_trail.back()->undo();
m_trail.pop_back();
}


void bv_plugin::register_node(enode* n) {
void bv_plugin::propagate_register_node(enode* n) {
TRACE("bv", tout << "register " << g.bpp(n) << "\n");
auto& i = info(n);
i.value = n;
enode* a, * b;
unsigned lo, hi;
if (is_concat(n, a, b)) {
i.lo = a;
i.hi = b;
i.cut = width(a);
auto& i = info(n);
i.value = n;
i.hi = a;
i.lo = b;
i.cut = width(b);
push_undo_split(n);
}
unsigned lo, hi;
if (is_extract(n, lo, hi) && (lo != 0 || hi + 1 != width(n->get_arg(0)))) {
else if (is_concat(n) && n->num_args() != 2) {
SASSERT(n->num_args() != 0);
auto last = n->get_arg(n->num_args() - 1);
for (unsigned i = n->num_args() - 1; i-- > 0;)
last = mk_concat(n->get_arg(i), last);
push_merge(last, n);
}
else if (is_extract(n, lo, hi) && (lo != 0 || hi + 1 != width(n->get_arg(0)))) {
enode* arg = n->get_arg(0);
unsigned w = width(arg);
if (all_of(enode_parents(arg), [&](enode* p) { unsigned _lo, _hi; return !is_extract(p, _lo, _hi) || _lo != 0 || _hi + 1 != w; }))
push_merge(mk_extract(arg, 0, w - 1), arg);
ensure_slice(arg, lo, hi);
}
TRACE("bv", tout << "done register " << g.bpp(n) << "\n");
}

//
Expand All @@ -250,7 +298,8 @@ namespace euf {
SASSERT(ub - lb + 1 == width(r));
if (lb == lo && ub == hi)
return;
slice_info& i = info(r);
slice_info const& i = info(r);

if (!i.lo) {
if (lo > lb) {
split(r, lo - lb);
Expand Down Expand Up @@ -287,19 +336,28 @@ namespace euf {
hi += lo1;
n = n->get_arg(0);
}
if (n->interpreted()) {
auto v = get_value(n);
if (lo > 0)
v = div(v, rational::power_of_two(lo));
if (hi + 1 != width(n))
v = mod(v, rational::power_of_two(hi + 1));
return mk_value(v, hi - lo + 1);
}
return mk(bv.mk_extract(hi, lo, n->get_expr()), 1, &n);
}

enode* bv_plugin::mk_concat(enode* lo, enode* hi) {
enode* args[2] = { lo, hi };
return mk(bv.mk_concat(lo->get_expr(), hi->get_expr()), 2, args);
enode* bv_plugin::mk_concat(enode* hi, enode* lo) {
enode* args[2] = { hi, lo };
return mk(bv.mk_concat(hi->get_expr(), lo->get_expr()), 2, args);
}

void bv_plugin::merge(enode_vector& xs, enode_vector& ys, justification dep) {
while (!xs.empty()) {
SASSERT(!ys.empty());
auto x = xs.back();
auto y = ys.back();
TRACE("bv", tout << "merge " << g.bpp(x) << " " << g.bpp(y) << "\n");
if (unfold_sub(x, xs))
continue;
else if (unfold_sub(y, ys))
Expand Down Expand Up @@ -342,14 +400,13 @@ namespace euf {
SASSERT(0 < cut && cut < w);
enode* hi = mk_extract(n, cut, w - 1);
enode* lo = mk_extract(n, 0, cut - 1);
auto& i = info(n);
if (!i.value)
i.value = n;
auto& i = info(n);
i.value = n;
i.hi = hi;
i.lo = lo;
i.cut = cut;
push_undo_split(n);
push_merge(mk_concat(lo, hi), n);
push_merge(mk_concat(hi, lo), n);
}

void bv_plugin::sub_slices(enode* n, std::function<bool(enode*, unsigned)>& consumer) {
Expand Down Expand Up @@ -442,9 +499,12 @@ namespace euf {
continue;
offsets.push_back(offs);
if (n->get_root() == b->get_root() && offs == offset) {
if (n != b)
consumer(n, b);
while (j != UINT_MAX) {
auto [x, y, j2] = just[j];
consumer(x, y);
if (x != y)
consumer(x, y);
j = j2;
}
for (auto const& [n, offset, j] : m_jtodo) {
Expand Down Expand Up @@ -487,10 +547,10 @@ namespace euf {
}

std::ostream& bv_plugin::display(std::ostream& out) const {
out << "bv\n";
out << "bv\n";
for (auto const& i : m_info)
if (i.lo)
out << g.bpp(i.value) << " cut " << i.cut << " lo " << g.bpp(i.lo) << " hi " << g.bpp(i.hi) << "\n";
if (i.lo)
out << g.bpp(i.value) << " cut " << i.cut << " lo " << g.bpp(i.lo) << " hi " << g.bpp(i.hi) << "\n";
return out;
}
}
27 changes: 18 additions & 9 deletions src/ast/euf/euf_bv_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Module Name:

#pragma once

#include "util/trail.h"
#include "ast/bv_decl_plugin.h"
#include "ast/euf/euf_plugin.h"

Expand All @@ -40,26 +41,24 @@ namespace euf {

bv_util bv;
slice_info_vector m_info; // indexed by enode::get_id()



enode_vector m_xs, m_ys;

std::function<void(enode*)> m_ensure_th_var;

bool is_concat(enode* n) const { return bv.is_concat(n->get_expr()); }
bool is_concat(enode* n, enode*& a, enode*& b) { return is_concat(n) && (a = n->get_arg(0), b = n->get_arg(1), true); }
bool is_concat(enode* n, enode*& a, enode*& b) { return is_concat(n) && n->num_args() == 2 && (a = n->get_arg(0), b = n->get_arg(1), true); }
bool is_extract(enode* n, unsigned& lo, unsigned& hi) { expr* body; return bv.is_extract(n->get_expr(), lo, hi, body); }
bool is_extract(enode* n) const { return bv.is_extract(n->get_expr()); }
unsigned width(enode* n) const { return bv.get_bv_size(n->get_expr()); }

enode* mk_extract(enode* n, unsigned lo, unsigned hi);
enode* mk_concat(enode* lo, enode* hi);
enode* mk_value_concat(enode* lo, enode* hi);
enode* mk_concat(enode* hi, enode* lo);
enode* mk_value_concat(enode* hi, enode* lo);
enode* mk_value(rational const& v, unsigned sz);
unsigned width(enode* n) { return bv.get_bv_size(n->get_expr()); }
bool is_value(enode* n) { return n->get_root()->interpreted(); }
rational get_value(enode* n) { rational val; VERIFY(bv.is_numeral(n->get_interpreted()->get_expr(), val)); return val; }
slice_info& info(enode* n) { unsigned id = n->get_id(); m_info.reserve(id + 1); return m_info[id]; }
slice_info& root_info(enode* n) { unsigned id = n->get_root_id(); m_info.reserve(id + 1); return m_info[id]; }
bool has_sub(enode* n) { return !!info(n).lo; }
enode* sub_lo(enode* n) { return info(n).lo; }
enode* sub_hi(enode* n) { return info(n).hi; }
Expand All @@ -81,8 +80,16 @@ namespace euf {
svector<std::tuple<enode*, unsigned, unsigned>> m_jtodo;
void clear_offsets();

enode_vector m_undo_split;

ptr_vector<trail> m_trail;

class undo_split;
void push_undo_split(enode* n);

vector<std::variant<enode*, enode_pair>> m_queue;
unsigned m_qhead = 0;
void propagate_register_node(enode* n);
void propagate_merge(enode* a, enode* b);

public:
bv_plugin(egraph& g);
Expand All @@ -97,9 +104,11 @@ namespace euf {

void diseq_eh(enode* eq) override {}

void propagate() override {}
void propagate() override;

void undo() override;

void set_ensure_th_var(std::function<void(enode*)>& f) { m_ensure_th_var = f; }

std::ostream& display(std::ostream& out) const override;

Expand Down
Loading

0 comments on commit 955c80e

Please sign in to comment.