Skip to content

Commit

Permalink
move theory_var_list into id_var_list and utilities from smt-enode in…
Browse files Browse the repository at this point in the history
…to it, prepare for theory variables in egraph

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
  • Loading branch information
NikolajBjorner committed Sep 1, 2020
1 parent fa9cf0f commit d4e92d4
Show file tree
Hide file tree
Showing 11 changed files with 332 additions and 193 deletions.
57 changes: 54 additions & 3 deletions src/ast/euf/euf_egraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,28 @@ Module Name:

namespace euf {

/**
\brief Trail for add_th_var
*/
class add_th_var_trail : public trail<egraph> {
enode * m_enode;
theory_id m_th_id;
public:
add_th_var_trail(enode * n, theory_id th_id):
m_enode(n),
m_th_id(th_id) {
}

void undo(egraph & ctx) override {
theory_var v = m_enode->get_th_var(m_th_id);
SASSERT(v != null_var);
m_enode->del_th_var(m_th_id);
enode * root = m_enode->get_root();
if (root != m_enode && root->get_th_var(m_th_id) == v)
root->del_th_var(m_th_id);
}
};

void egraph::undo_eq(enode* r1, enode* n1, unsigned r2_num_parents) {
enode* r2 = r1->get_root();
r2->dec_class_size(r1->class_size());
Expand Down Expand Up @@ -89,6 +111,7 @@ namespace euf {
s.m_inconsistent = m_inconsistent;
s.m_num_eqs = m_eqs.size();
s.m_num_nodes = m_nodes.size();
s.m_trail_sz = m_trail.size();
m_scopes.push_back(s);
m_region.push_scope();
}
Expand Down Expand Up @@ -135,6 +158,14 @@ namespace euf {
n->m_parents.finalize();
}

void egraph::add_th_var(enode* n, theory_var v, theory_id id) {
force_push();
SASSERT(null_var == n->get_th_var(id));
SASSERT(n->class_size() == 1);
n->add_th_var(v, id, m_region);
m_trail.push_back(new (m_region) add_th_var_trail(n, id));
}

void egraph::pop(unsigned num_scopes) {
if (num_scopes <= m_num_scopes) {
m_num_scopes -= num_scopes;
Expand All @@ -154,6 +185,7 @@ namespace euf {
m_expr2enode[n->get_owner_id()] = nullptr;
n->~enode();
}
undo_trail_stack<egraph>(*this, m_trail, s.m_trail_sz);
m_inconsistent = s.m_inconsistent;
m_eqs.shrink(s.m_num_eqs);
m_nodes.shrink(s.m_num_nodes);
Expand Down Expand Up @@ -194,12 +226,30 @@ namespace euf {
std::swap(r1->m_next, r2->m_next);
r2->inc_class_size(r1->class_size());
r2->m_parents.append(r1->m_parents);
merge_th_eq(r1, r2);
m_worklist.push_back(r2);
}

void egraph::merge_th_eq(enode* n, enode* root) {
SASSERT(n != root);
for (auto iv : enode_th_vars(n)) {
theory_id id = iv.get_id();
theory_var v = root->get_th_var(id);
if (v == null_var) {
root->add_th_var(iv.get_var(), id, m_region);
m_trail.push_back(new (m_region) add_th_var_trail(root, id));
}
else {
SASSERT(v != iv.get_var());
m_new_th_eqs.push_back(th_eq(id, v, iv.get_var(), n, root));
}
}
}

void egraph::propagate() {
m_new_eqs.reset();
m_new_lits.reset();
m_new_th_eqs.reset();
SASSERT(m_num_scopes == 0 || m_worklist.empty());
unsigned head = 0, tail = m_worklist.size();
while (head < tail && m.limit().inc() && !inconsistent()) {
Expand Down Expand Up @@ -315,7 +365,7 @@ namespace euf {
}

template <typename T>
void egraph::explain_eq(ptr_vector<T>& justifications, enode* a, enode* b, bool comm) {
void egraph::explain_eq(ptr_vector<T>& justifications, enode* a, enode* b) {
SASSERT(m_todo.empty());
SASSERT(a->get_root() == b->get_root());
enode* lca = find_lca(a, b);
Expand Down Expand Up @@ -394,6 +444,7 @@ namespace euf {
for (unsigned i = 0; i < src.m_nodes.size(); ++i) {
enode* n1 = src.m_nodes[i];
expr* e1 = src.m_exprs[i];
SASSERT(!n1->has_th_vars());
args.reset();
for (unsigned j = 0; j < n1->num_args(); ++j)
args.push_back(old_expr2new_enode[n1->get_arg(j)->get_owner_id()]);
Expand All @@ -418,9 +469,9 @@ namespace euf {

template void euf::egraph::explain(ptr_vector<int>& justifications);
template void euf::egraph::explain_todo(ptr_vector<int>& justifications);
template void euf::egraph::explain_eq(ptr_vector<int>& justifications, enode* a, enode* b, bool comm);
template void euf::egraph::explain_eq(ptr_vector<int>& justifications, enode* a, enode* b);

template void euf::egraph::explain(ptr_vector<unsigned>& justifications);
template void euf::egraph::explain_todo(ptr_vector<unsigned>& justifications);
template void euf::egraph::explain_eq(ptr_vector<unsigned>& justifications, enode* a, enode* b, bool comm);
template void euf::egraph::explain_eq(ptr_vector<unsigned>& justifications, enode* a, enode* b);

23 changes: 21 additions & 2 deletions src/ast/euf/euf_egraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Module Name:

#pragma once
#include "util/statistics.h"
#include "util/trail.h"
#include "ast/euf/euf_enode.h"
#include "ast/euf/euf_etable.h"

Expand All @@ -37,12 +38,24 @@ namespace euf {
add_eq_record(enode* r1, enode* n1, unsigned r2_num_parents):
r1(r1), n1(n1), r2_num_parents(r2_num_parents) {}
};

struct th_eq {
theory_id m_id;
theory_var m_v1;
theory_var m_v2;
enode* m_child;
enode* m_root;
th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r) :
m_id(id), m_v1(v1), m_v2(v2), m_child(c), m_root(r) {}
};

class egraph {
typedef ptr_vector<trail<egraph> > trail_stack;
struct scope {
bool m_inconsistent;
unsigned m_num_eqs;
unsigned m_num_nodes;
unsigned m_trail_sz;
};
struct stats {
unsigned m_num_merge;
Expand All @@ -53,6 +66,7 @@ namespace euf {
void reset() { memset(this, 0, sizeof(*this)); }
};
ast_manager& m;
trail_stack m_trail;
region m_region;
enode_vector m_worklist;
etable m_table;
Expand All @@ -68,10 +82,11 @@ namespace euf {
justification m_justification;
enode_vector m_new_eqs;
enode_vector m_new_lits;
svector<th_eq> m_new_th_eqs;
enode_vector m_todo;
stats m_stats;
std::function<void(expr*,expr*,expr*)> m_used_eq;
std::function<void(app*,app*)> m_used_cc;
std::function<void(app*,app*)> m_used_cc;

void push_eq(enode* r1, enode* n1, unsigned r2_num_parents) {
m_eqs.push_back(add_eq_record(r1, n1, r2_num_parents));
Expand All @@ -82,6 +97,7 @@ namespace euf {
void force_push();
void set_conflict(enode* n1, enode* n2, justification j);
void merge(enode* n1, enode* n2, justification j);
void merge_th_eq(enode* n, enode* root);
void merge_justification(enode* n1, enode* n2, justification j);
void unmerge_justification(enode* n1);
void dedup_equalities();
Expand Down Expand Up @@ -132,14 +148,17 @@ namespace euf {
bool inconsistent() const { return m_inconsistent; }
enode_vector const& new_eqs() const { return m_new_eqs; }
enode_vector const& new_lits() const { return m_new_lits; }
svector<th_eq> const& new_th_eqs() const { return m_new_th_eqs; }

void add_th_var(enode* n, theory_var v, theory_id id);

void set_used_eq(std::function<void(expr*,expr*,expr*)>& used_eq) { m_used_eq = used_eq; }
void set_used_cc(std::function<void(app*,app*)>& used_cc) { m_used_cc = used_cc; }

template <typename T>
void explain(ptr_vector<T>& justifications);
template <typename T>
void explain_eq(ptr_vector<T>& justifications, enode* a, enode* b, bool comm);
void explain_eq(ptr_vector<T>& justifications, enode* a, enode* b);
enode_vector const& nodes() const { return m_nodes; }
void invariant();
void copy_from(egraph const& src, std::function<void*(void*)>& copy_justification);
Expand Down
37 changes: 37 additions & 0 deletions src/ast/euf/euf_enode.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Module Name:
--*/

#include "util/vector.h"
#include "util/id_var_list.h"
#include "ast/ast.h"
#include "ast/euf/euf_justification.h"

Expand All @@ -28,6 +29,11 @@ namespace euf {
typedef ptr_vector<enode> enode_vector;
typedef std::pair<enode*,enode*> enode_pair;
typedef svector<enode_pair> enode_pair_vector;
typedef id_var_list<> th_var_list;
typedef int theory_var;
typedef int theory_id;
const theory_var null_var = -1;
const theory_id null_id = -1;

class enode {
expr* m_owner;
Expand All @@ -42,13 +48,15 @@ namespace euf {
enode* m_next;
enode* m_root;
enode* m_target { nullptr };
th_var_list m_th_vars;
justification m_justification;
unsigned m_num_args;
enode* m_args[0];

friend class enode_args;
friend class enode_parents;
friend class enode_class;
friend class enode_th_vars;
friend class etable;
friend class egraph;

Expand All @@ -73,6 +81,12 @@ namespace euf {
}

void set_update_children() { m_update_children = true; }

friend class add_th_var_trail;
void add_th_var(theory_var v, theory_id id, region & r) { m_th_vars.add_var(v, id, r); }
void replace_th_var(theory_var v, theory_id id) { m_th_vars.replace(v, id); }
void del_th_var(theory_id id) { m_th_vars.del_var(id); }

public:
~enode() {
SASSERT(m_root == this);
Expand Down Expand Up @@ -127,6 +141,9 @@ namespace euf {
expr* get_owner() const { return m_owner; }
unsigned get_owner_id() const { return m_owner->get_id(); }
unsigned get_root_id() const { return m_root->m_owner->get_id(); }
theory_var get_th_var(theory_id id) const { return m_th_vars.find(id); }
bool has_th_vars() const { return !m_th_vars.empty(); }

void inc_class_size(unsigned n) { m_class_size += n; }
void dec_class_size(unsigned n) { m_class_size -= n; }

Expand Down Expand Up @@ -177,4 +194,24 @@ namespace euf {
iterator begin() const { return iterator(&n, nullptr); }
iterator end() const { return iterator(&n, &n); }
};

class enode_th_vars {
enode& n;
public:
class iterator {
th_var_list* m_th_vars;
public:
iterator(th_var_list* n) : m_th_vars(n) {}
th_var_list operator*() { return *m_th_vars; }
iterator& operator++() { m_th_vars = m_th_vars->get_next(); return *this; }
iterator operator++(int) { iterator tmp = *this; ++* this; return tmp; }
bool operator==(iterator const& other) const { return m_th_vars == other.m_th_vars; }
bool operator!=(iterator const& other) const { return !(*this == other); }
};
enode_th_vars(enode& _n) :n(_n) {}
enode_th_vars(enode* _n) :n(*_n) {}
iterator begin() const { return iterator(n.m_th_vars.empty() ? nullptr : &n.m_th_vars); }
iterator end() const { return iterator(nullptr); }
};

}
4 changes: 2 additions & 2 deletions src/sat/smt/euf_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,13 @@ namespace euf {
SASSERT(n);
SASSERT(m_egraph.is_equality(n));
SASSERT(!l.sign());
m_egraph.explain_eq<unsigned>(m_explain, n->get_arg(0), n->get_arg(1), n->commutative());
m_egraph.explain_eq<unsigned>(m_explain, n->get_arg(0), n->get_arg(1));
break;
case constraint::kind_t::lit:
n = m_var2node[l.var()];
SASSERT(n);
SASSERT(m.is_bool(n->get_owner()));
m_egraph.explain_eq<unsigned>(m_explain, n, (l.sign() ? mk_false() : mk_true()), false);
m_egraph.explain_eq<unsigned>(m_explain, n, (l.sign() ? mk_false() : mk_true()));
break;
default:
IF_VERBOSE(0, verbose_stream() << (unsigned)j.kind() << "\n");
Expand Down
Loading

0 comments on commit d4e92d4

Please sign in to comment.