Skip to content

Commit

Permalink
fixes in new solver
Browse files Browse the repository at this point in the history
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
  • Loading branch information
NikolajBjorner committed Dec 25, 2020
1 parent 21c626e commit 372e5ca
Show file tree
Hide file tree
Showing 12 changed files with 79 additions and 43 deletions.
3 changes: 3 additions & 0 deletions src/ast/euf/euf_egraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,9 @@ namespace euf {
out << " " << p->get_expr_id();
out << "] ";
}
if (n->value() != l_undef) {
out << "[v" << n->bool_var() << " := " << (n->value() == l_true ? "T":"F") << "] ";
}
if (n->has_th_vars()) {
out << "[t";
for (auto v : enode_th_vars(n))
Expand Down
2 changes: 1 addition & 1 deletion src/math/lp/lp_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ namespace lp_api {
return inf_rational(m_value + offset); // v <= value - 1 or v >= value + 1
}
else {
return inf_rational(m_value, m_bound_kind != lower_t); // v <= value - epsilon or v >= value + epsilon
return inf_rational(m_value, m_bound_kind != lower_t); // v <= value - epsilon or v >= value + epsilon
}
}

Expand Down
1 change: 0 additions & 1 deletion src/math/lp/nla_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1462,7 +1462,6 @@ void core::check_weighted(unsigned sz, std::pair<unsigned, std::function<void(vo
bound += checks[i].first;
uint_set seen;
while (bound > 0 && !done() && m_lemma_vec->empty()) {
SASSERT(bound > 0);
unsigned n = random() % bound;
for (unsigned i = 0; i < sz; ++i) {
if (seen.contains(i))
Expand Down
2 changes: 1 addition & 1 deletion src/sat/smt/arith_diagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ namespace arith {
scoped_anum an(m_nla->am());
m_nla->am().display(out << " = ", nl_value(v, an));
}
else if (m_model_is_initialized && is_registered_var(v))
else if (can_get_value(v))
out << " = " << get_value(v);
if (is_int(v))
out << ", int";
Expand Down
6 changes: 6 additions & 0 deletions src/sat/smt/arith_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ namespace arith {
}

bool solver::unit_propagate() {
m_model_is_initialized = false;
if (!m_new_eq && m_new_bounds.empty() && m_asserted_qhead == m_asserted.size())
return false;

Expand Down Expand Up @@ -572,6 +573,10 @@ namespace arith {
value = ~value;
if (!found_bad && value == get_phase(n->bool_var()))
continue;
TRACE("arith",
tout << eval << " " << value << " " << ctx.bpp(n) << "\n";
tout << mdl << "\n";
s().display(tout););
IF_VERBOSE(0,
verbose_stream() << eval << " " << value << " " << ctx.bpp(n) << "\n";
verbose_stream() << n->bool_var() << " " << n->value() << " " << get_phase(n->bool_var()) << " " << ctx.bpp(n) << "\n";
Expand Down Expand Up @@ -642,6 +647,7 @@ namespace arith {
for (unsigned i = m_bounds_trail.size(); i-- > old_size; ) {
unsigned v = m_bounds_trail[i];
api_bound* b = m_bounds[v].back();
m_bool_var2bound.erase(b->get_lit().var());
// del_use_lists(b);
dealloc(b);
m_bounds[v].pop_back();
Expand Down
3 changes: 3 additions & 0 deletions src/sat/smt/arith_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ namespace arith {
vector<constraint_bound> m_upper_terms;
vector<constraint_bound> m_history;

bool can_get_value(theory_var v) const {
return is_registered_var(v) && m_model_is_initialized;
}

// solving
void report_equality_of_fixed_vars(unsigned vi1, unsigned vi2);
Expand Down
84 changes: 48 additions & 36 deletions src/sat/smt/euf_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,49 @@ Module Name:

namespace euf {

class solver::user_sort {
ast_manager& m;
model_ref& mdl;
expr_ref_vector& values;
user_sort_factory factory;
scoped_ptr_vector<expr_ref_vector> sort_values;
obj_map<sort, expr_ref_vector*> sort2values;
public:
user_sort(solver& s, expr_ref_vector& values, model_ref& mdl) :
m(s.m), mdl(mdl), values(values), factory(m) {}

~user_sort() {
for (auto kv : sort2values)
mdl->register_usort(kv.m_key, kv.m_value->size(), kv.m_value->c_ptr());
}

void add(unsigned id, sort* srt) {
expr_ref value(factory.get_fresh_value(srt), m);
values.set(id, value);
expr_ref_vector* vals = nullptr;
if (!sort2values.find(srt, vals)) {
vals = alloc(expr_ref_vector, m);
sort2values.insert(srt, vals);
sort_values.push_back(vals);
}
vals->push_back(value);
}

void register_value(expr* val) {
factory.register_value(val);
}
};

void solver::update_model(model_ref& mdl) {
for (auto* mb : m_solvers)
mb->init_model();
deps_t deps;
m_values.reset();
m_values2root.reset();
collect_dependencies(deps);
user_sort us(*this, m_values, mdl);
collect_dependencies(us, deps);
deps.topological_sort();
dependencies2values(deps, mdl);
dependencies2values(us, deps, mdl);
values2model(deps, mdl);
for (auto* mb : m_solvers)
mb->finalize_model(*mdl);
Expand All @@ -48,13 +82,17 @@ namespace euf {
return mb && mb->include_func_interp(f);
}

void solver::collect_dependencies(deps_t& deps) {
void solver::collect_dependencies(user_sort& us, deps_t& deps) {
for (enode* n : m_egraph.nodes()) {
auto* mb = sort2solver(m.get_sort(n->get_expr()));
expr* e = n->get_expr();
sort* srt = m.get_sort(e);
auto* mb = sort2solver(srt);
if (mb)
mb->add_dep(n, deps);
else
deps.insert(n, nullptr);
if (n->is_root() && m.is_uninterp(srt) && m.is_value(e))
us.register_value(e);
}

TRACE("euf",
Expand All @@ -67,37 +105,7 @@ namespace euf {
);
}

class solver::user_sort {
ast_manager& m;
model_ref& mdl;
expr_ref_vector& values;
user_sort_factory factory;
scoped_ptr_vector<expr_ref_vector> sort_values;
obj_map<sort, expr_ref_vector*> sort2values;
public:
user_sort(solver& s, expr_ref_vector& values, model_ref& mdl):
m(s.m), mdl(mdl), values(values), factory(m) {}

~user_sort() {
for (auto kv : sort2values)
mdl->register_usort(kv.m_key, kv.m_value->size(), kv.m_value->c_ptr());
}

void add(unsigned id, sort* srt) {
expr_ref value(factory.get_fresh_value(srt), m);
values.set(id, value);
expr_ref_vector* vals = nullptr;
if (!sort2values.find(srt, vals)) {
vals = alloc(expr_ref_vector, m);
sort2values.insert(srt, vals);
sort_values.push_back(vals);
}
vals->push_back(value);
}
};

void solver::dependencies2values(deps_t& deps, model_ref& mdl) {
user_sort user_sort(*this, m_values, mdl);
void solver::dependencies2values(user_sort& us, deps_t& deps, model_ref& mdl) {
for (enode* n : deps.top_sorted()) {
unsigned id = n->get_root_id();
if (m_values.get(id, nullptr))
Expand Down Expand Up @@ -134,9 +142,13 @@ namespace euf {
}
continue;
}
if (m.is_value(n->get_root()->get_expr())) {
m_values.set(id, n->get_root()->get_expr());
continue;
}
sort* srt = m.get_sort(e);
if (m.is_uninterp(srt))
user_sort.add(id, srt);
us.add(id, srt);
else if (auto* mbS = sort2solver(srt))
mbS->add_value(n, *mdl, m_values);
else if (auto* mbE = expr2solver(e))
Expand Down
2 changes: 2 additions & 0 deletions src/sat/smt/euf_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ namespace euf {
void solver::unhandled_function(func_decl* f) {
if (m_unhandled_functions.contains(f))
return;
if (m.is_model_value(f))
return;
m_unhandled_functions.push_back(f);
m_trail.push(push_back_vector<solver, func_decl_ref_vector>(m_unhandled_functions));
IF_VERBOSE(0, verbose_stream() << mk_pp(f, m) << " not handled\n");
Expand Down
4 changes: 2 additions & 2 deletions src/sat/smt/euf_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ namespace euf {
obj_map<expr, enode*> m_values2root;
bool include_func_interp(func_decl* f);
void register_macros(model& mdl);
void dependencies2values(deps_t& deps, model_ref& mdl);
void collect_dependencies(deps_t& deps);
void dependencies2values(user_sort& us, deps_t& deps, model_ref& mdl);
void collect_dependencies(user_sort& us, deps_t& deps);
void values2model(deps_t const& deps, model_ref& mdl);
void validate_model(model& mdl);

Expand Down
4 changes: 2 additions & 2 deletions src/sat/smt/q_mbi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ namespace q {

lbool mbqi::check_forall(quantifier* q) {
quantifier* q_flat = m_qs.flatten(q);
init_solver();
::solver::scoped_push _sp(*m_solver);
auto* qb = specialize(q_flat);
if (!qb)
return l_undef;
if (m.is_false(qb->mbody))
return l_true;
init_solver();
::solver::scoped_push _sp(*m_solver);
m_solver->assert_expr(qb->mbody);
lbool r = m_solver->check_sat(0, nullptr);
if (r == l_undef)
Expand Down
10 changes: 10 additions & 0 deletions src/sat/smt/sat_dual_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,22 @@ namespace sat {
m_roots.push_scope();
m_tracked_vars.push_scope();
m_units.push_scope();
m_vars.push_scope();
}

void dual_solver::pop(unsigned num_scopes) {
m_solver.user_pop(num_scopes);
unsigned old_sz = m_tracked_vars.old_size(num_scopes);
for (unsigned i = m_tracked_vars.size(); i-- > old_sz; )
m_is_tracked[m_tracked_vars[i]] = false;
old_sz = m_vars.old_size(num_scopes);
for (unsigned i = m_vars.size(); i-- > old_sz; ) {
unsigned v = m_vars[i];
unsigned w = m_ext2var[v];
m_ext2var[v] = null_bool_var;
m_var2ext[w] = null_bool_var;
}
m_vars.pop_scope(num_scopes);
m_units.pop_scope(num_scopes);
m_roots.pop_scope(num_scopes);
m_tracked_vars.pop_scope(num_scopes);
Expand All @@ -51,6 +60,7 @@ namespace sat {
w = m_solver.mk_var();
m_ext2var.setx(v, w, null_bool_var);
m_var2ext.setx(w, v, null_bool_var);
m_vars.push_back(v);
}
return w;
}
Expand Down
1 change: 1 addition & 0 deletions src/sat/smt/sat_dual_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace sat {
bool_var_vector m_is_tracked;
unsigned_vector m_ext2var;
unsigned_vector m_var2ext;
lim_svector<unsigned> m_vars;
void add_literal(literal lit);

bool_var ext2var(bool_var v);
Expand Down

0 comments on commit 372e5ca

Please sign in to comment.