Skip to content

Commit

Permalink
user solver (#4709)
Browse files Browse the repository at this point in the history
* user solver

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* na

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* na

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* na

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
  • Loading branch information
NikolajBjorner committed Sep 24, 2020
1 parent 7c2bdfe commit 43db7df
Show file tree
Hide file tree
Showing 19 changed files with 420 additions and 41 deletions.
2 changes: 1 addition & 1 deletion scripts/mk_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def init_project_def():
add_lib('core_tactics', ['tactic', 'macros', 'normal_forms', 'rewriter', 'pattern'], 'tactic/core')
add_lib('arith_tactics', ['core_tactics', 'sat'], 'tactic/arith')

add_lib('sat_smt', ['sat', 'euf', 'tactic', 'smt_params', 'bit_blaster'], 'sat/smt')
add_lib('sat_smt', ['sat', 'euf', 'tactic', 'solver', 'smt_params', 'bit_blaster'], 'sat/smt')
add_lib('sat_tactic', ['tactic', 'sat', 'solver', 'sat_smt'], 'sat/tactic')
add_lib('nlsat_tactic', ['nlsat', 'sat_tactic', 'arith_tactics'], 'nlsat/tactic')
add_lib('subpaving_tactic', ['core_tactics', 'subpaving'], 'math/subpaving/tactic')
Expand Down
2 changes: 1 addition & 1 deletion src/api/api_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,7 @@ extern "C" {
Z3_TRY;
LOG_Z3_solver_propagate_consequence(c, s, num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, conseq);
RESET_ERROR_CODE();
reinterpret_cast<solver::propagate_callback*>(s)->propagate(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, to_expr(conseq));
reinterpret_cast<solver::propagate_callback*>(s)->propagate_cb(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, to_expr(conseq));
Z3_CATCH;
}

Expand Down
6 changes: 3 additions & 3 deletions src/math/lp/lar_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,6 @@ class lar_solver : public column_namer {
void add_row_from_term_no_constraint(const lar_term * term, unsigned term_ext_index);
void add_basic_var_to_core_fields();
bool compare_values(impq const& lhs, lconstraint_kind k, const mpq & rhs);
// columns
bool column_is_int(column_index const& j) const { return column_is_int((unsigned)j); }
const impq& get_value(column_index const& j) const { return get_column_value(j); }


void update_column_type_and_bound_check_on_equal(unsigned j, lconstraint_kind kind, const mpq & right_side, constraint_index constr_index, unsigned&);
Expand Down Expand Up @@ -626,6 +623,9 @@ class lar_solver : public column_namer {
inline bool column_value_is_int(unsigned j) const { return m_mpq_lar_core_solver.m_r_x[j].is_int(); }
inline static_matrix<mpq, impq> & A_r() { return m_mpq_lar_core_solver.m_r_A; }
inline const static_matrix<mpq, impq> & A_r() const { return m_mpq_lar_core_solver.m_r_A; }
// columns
bool column_is_int(column_index const& j) const { return column_is_int((unsigned)j); }
const impq& get_value(column_index const& j) const { return get_column_value(j); }
const impq& get_column_value(unsigned j) const { return m_mpq_lar_core_solver.m_r_x[j]; }
inline
var_index external_to_local(unsigned j) const {
Expand Down
16 changes: 10 additions & 6 deletions src/sat/sat_extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,15 @@ namespace sat {
protected:
bool m_drating { false };
int m_id { 0 };
solver* m_solver { nullptr };
public:
extension(int id): m_id(id) {}
virtual ~extension() {}
virtual int get_id() const { return m_id; }
virtual void set_solver(solver* s) = 0;
int get_id() const { return m_id; }
void set_solver(solver* s) { m_solver = s; }
solver& s() { return *m_solver; }
solver const& s() const { return *m_solver; }

virtual void set_lookahead(lookahead* s) {};
class scoped_drating {
extension& ext;
Expand All @@ -70,13 +74,13 @@ namespace sat {
~scoped_drating() { ext.m_drating = false; }
};
virtual void init_search() {}
virtual bool propagate(literal l, ext_constraint_idx idx) = 0;
virtual bool unit_propagate() = 0;
virtual bool is_external(bool_var v) = 0;
virtual bool propagate(sat::literal l, sat::ext_constraint_idx idx) { UNREACHABLE(); return false; }
virtual bool unit_propagate() = 0;
virtual bool is_external(bool_var v) { return false; }
virtual double get_reward(literal l, ext_constraint_idx idx, literal_occs_fun& occs) const { return 0; }
virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r, bool probing) = 0;
virtual bool is_extended_binary(ext_justification_idx idx, literal_vector & r) { return false; }
virtual void asserted(literal l) = 0;
virtual void asserted(literal l) {};
virtual check_result check() = 0;
virtual lbool resolve_conflict() { return l_undef; } // stores result in sat::solver::m_lemma
virtual void push() = 0;
Expand Down
33 changes: 33 additions & 0 deletions src/sat/sat_solver/inc_sat_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,39 @@ class inc_sat_solver : public solver {
m_preprocess->reset();
}

euf::solver* ensure_euf() {
auto* ext = dynamic_cast<euf::solver*>(m_solver.get_extension());
return ext;
}

void user_propagate_init(
void* ctx,
solver::push_eh_t& push_eh,
solver::pop_eh_t& pop_eh,
solver::fresh_eh_t& fresh_eh) override {
ensure_euf()->user_propagate_init(ctx, push_eh, pop_eh, fresh_eh);
}

void user_propagate_register_fixed(solver::fixed_eh_t& fixed_eh) override {
ensure_euf()->user_propagate_register_fixed(fixed_eh);
}

void user_propagate_register_final(solver::final_eh_t& final_eh) override {
ensure_euf()->user_propagate_register_final(final_eh);
}

void user_propagate_register_eq(solver::eq_eh_t& eq_eh) override {
ensure_euf()->user_propagate_register_eq(eq_eh);
}

void user_propagate_register_diseq(solver::eq_eh_t& diseq_eh) override {
ensure_euf()->user_propagate_register_diseq(diseq_eh);
}

unsigned user_propagate_register(expr* e) override {
return ensure_euf()->user_propagate_register(e);
}

private:

lbool internalize_goal(goal_ref& g) {
Expand Down
1 change: 1 addition & 0 deletions src/sat/smt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ z3_add_component(sat_smt
euf_solver.cpp
sat_dual_solver.cpp
sat_th.cpp
user_solver.cpp
xor_solver.cpp
COMPONENT_DEPENDENCIES
sat
Expand Down
3 changes: 0 additions & 3 deletions src/sat/smt/array_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,12 @@ namespace array {

array_util a;
stats m_stats;
sat::solver* m_solver{ nullptr };
scoped_ptr_vector<var_data> m_var_data;
ast2ast_trailmap<sort, app> m_sort2epsilon;
ast2ast_trailmap<sort, func_decl> m_sort2diag;
obj_map<sort, func_decl_ref_vector*> m_sort2diff;
array_union_find m_find;

sat::solver& s() { return *m_solver; }
theory_var find(theory_var v) { return m_find.find(v); }

// internalize
Expand Down Expand Up @@ -187,7 +185,6 @@ namespace array {
public:
solver(euf::solver& ctx, theory_id id);
~solver() override {}
void set_solver(sat::solver* s) override { m_solver = s; }
bool is_external(bool_var v) override { return false; }
bool propagate(literal l, sat::ext_constraint_idx idx) override { UNREACHABLE(); return false; }
void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) override {}
Expand Down
6 changes: 3 additions & 3 deletions src/sat/smt/ba_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1365,7 +1365,7 @@ namespace sat {
ba_solver::ba_solver(ast_manager& m, sat::sat_internalizer& si, euf::theory_id id)
: euf::th_solver(m, id),
si(si), m_pb(m),
m_solver(nullptr), m_lookahead(nullptr),
m_lookahead(nullptr),
m_constraint_id(0), m_ba(*this), m_sort(m_ba) {
TRACE("ba", tout << this << "\n";);
m_num_propagations_since_pop = 0;
Expand Down Expand Up @@ -3579,14 +3579,14 @@ namespace sat {
switch (cnstr.tag()) {
case ba::tag_t::card_t: {
card& c = cnstr.to_card();
ineq.reset(offset*c.k());
ineq.reset(static_cast<uint64_t>(offset)*c.k());
for (literal l : c) ineq.push(l, offset);
if (c.lit() != null_literal) ineq.push(~c.lit(), offset*c.k());
break;
}
case ba::tag_t::pb_t: {
pb& p = cnstr.to_pb();
ineq.reset(offset * p.k());
ineq.reset(static_cast<uint64_t>(offset) * p.k());
for (wliteral wl : p) ineq.push(wl.second, offset * wl.first);
if (p.lit() != null_literal) ineq.push(~p.lit(), offset * p.k());
break;
Expand Down
5 changes: 0 additions & 5 deletions src/sat/smt/ba_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ namespace sat {
sat_internalizer& si;
pb_util m_pb;

solver* m_solver{ nullptr };
lookahead* m_lookahead{ nullptr };
stats m_stats;
small_object_allocator m_allocator;
Expand Down Expand Up @@ -140,9 +139,6 @@ namespace sat {
void inc_parity(bool_var v);
void reset_parity(bool_var v);

solver& s() const { return *m_solver; }


// simplification routines

vector<svector<constraint*>> m_cnstr_use_list;
Expand Down Expand Up @@ -400,7 +396,6 @@ namespace sat {
ba_solver(euf::solver& ctx, euf::theory_id id);
ba_solver(ast_manager& m, sat::sat_internalizer& si, euf::theory_id id);
~ba_solver() override;
void set_solver(solver* s) override { m_solver = s; }
void set_lookahead(lookahead* l) override { m_lookahead = l; }
void add_at_least(bool_var v, literal_vector const& lits, unsigned k);
void add_pb_ge(bool_var v, svector<wliteral> const& wlits, unsigned k);
Expand Down
9 changes: 7 additions & 2 deletions src/sat/smt/bv_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,14 @@ namespace bv {
void solver::fixed_var_eh(theory_var v1) {
numeral val1, val2;
VERIFY(get_fixed_value(v1, val1));
euf::enode* n1 = var2enode(v1);
unsigned sz = m_bits[v1].size();
value_sort_pair key(val1, sz);
theory_var v2;
if (ctx.watches_fixed(n1)) {
expr_ref value(bv.mk_numeral(val1, sz), m);
ctx.assign_fixed(n1, value, m_bits[v1]);
}
bool is_current =
m_fixed_var_table.find(key, v2) &&
v2 < static_cast<int>(get_num_vars()) &&
Expand All @@ -74,12 +79,12 @@ namespace bv {

if (!is_current)
m_fixed_var_table.insert(key, v1);
else if (var2enode(v1)->get_root() != var2enode(v2)->get_root()) {
else if (n1->get_root() != var2enode(v2)->get_root()) {
SASSERT(get_bv_size(v1) == get_bv_size(v2));
TRACE("bv", tout << "detected equality: v" << v1 << " = v" << v2 << "\n" << pp(v1) << pp(v2););
m_stats.m_num_bit2eq++;
add_fixed_eq(v1, v2);
ctx.propagate(var2enode(v1), var2enode(v2), mk_bit2eq_justification(v1, v2));
ctx.propagate(n1, var2enode(v2), mk_bit2eq_justification(v1, v2));
}
}

Expand Down
6 changes: 0 additions & 6 deletions src/sat/smt/bv_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,6 @@ namespace bv {
unsigned_vector m_prop_queue_lim;
unsigned m_prop_queue_head { 0 };


sat::solver* m_solver;
sat::solver& s() { return *m_solver; }
sat::solver const& s() const { return *m_solver; }

// internalize
void insert_bv2a(bool_var bv, atom * a) { m_bool_var2atom.setx(bv, a, 0); }
void erase_bv2a(bool_var bv) { m_bool_var2atom[bv] = 0; }
Expand Down Expand Up @@ -327,7 +322,6 @@ namespace bv {
public:
solver(euf::solver& ctx, theory_id id);
~solver() override {}
void set_solver(sat::solver* s) override { m_solver = s; }
void set_lookahead(sat::lookahead* s) override { }
void init_search() override {}
double get_reward(literal l, sat::ext_constraint_idx idx, sat::literal_occs_fun& occs) const override;
Expand Down
24 changes: 23 additions & 1 deletion src/sat/smt/euf_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ namespace euf {
m_trail(*this),
m_rewriter(m),
m_unhandled_functions(m),
m_solver(nullptr),
m_lookahead(nullptr),
m_to_m(&m),
m_to_si(&si),
Expand Down Expand Up @@ -670,4 +669,27 @@ namespace euf {
return true;
}

void solver::user_propagate_init(
void* ctx,
::solver::push_eh_t& push_eh,
::solver::pop_eh_t& pop_eh,
::solver::fresh_eh_t& fresh_eh) {
m_user_propagator = alloc(user::solver, *this);
m_user_propagator->add(ctx, push_eh, pop_eh, fresh_eh);
for (unsigned i = m_scopes.size(); i-- > 0; )
m_user_propagator->push();
m_solvers.push_back(m_user_propagator);
m_id2solver.setx(m_user_propagator->get_id(), m_user_propagator, nullptr);
}

bool solver::watches_fixed(enode* n) const {
return m_user_propagator && m_user_propagator->has_fixed() && n->get_th_var(m_user_propagator->get_id()) != null_theory_var;
}

void solver::assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain) {
theory_var v = n->get_th_var(m_user_propagator->get_id());
m_user_propagator->new_fixed_eh(v, val, sz, explain);
}


}
52 changes: 45 additions & 7 deletions src/sat/smt/euf_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Module Name:
#include "sat/smt/sat_th.h"
#include "sat/smt/sat_dual_solver.h"
#include "sat/smt/euf_ackerman.h"
#include "sat/smt/user_solver.h"
#include "smt/params/smt_params.h"

namespace euf {
Expand Down Expand Up @@ -85,13 +86,12 @@ namespace euf {
stats m_stats;
th_rewriter m_rewriter;
func_decl_ref_vector m_unhandled_functions;

sat::solver* m_solver{ nullptr };
sat::lookahead* m_lookahead{ nullptr };
ast_manager* m_to_m;
sat::lookahead* m_lookahead{ nullptr };
ast_manager* m_to_m;
sat::sat_internalizer* m_to_si;
scoped_ptr<euf::ackerman> m_ackerman;
scoped_ptr<sat::dual_solver> m_dual_solver;
user::solver* m_user_propagator{ nullptr };

ptr_vector<expr> m_var2expr;
ptr_vector<size_t> m_explain;
Expand Down Expand Up @@ -174,6 +174,12 @@ namespace euf {
constraint& eq_constraint() { return mk_constraint(m_eq, constraint::kind_t::eq); }
constraint& lit_constraint() { return mk_constraint(m_lit, constraint::kind_t::lit); }

// user propagator
void check_for_user_propagator() {
if (!m_user_propagator)
throw default_exception("user propagator must be initialized");
}

public:
solver(ast_manager& m, sat::sat_internalizer& si, params_ref const& p = params_ref());

Expand All @@ -197,8 +203,7 @@ namespace euf {
};

// accessors
sat::solver& s() { return *m_solver; }
sat::solver const& s() const { return *m_solver; }

sat::sat_internalizer& get_si() { return si; }
ast_manager& get_manager() { return m; }
enode* get_enode(expr* e) { return m_egraph.find(e); }
Expand All @@ -212,7 +217,6 @@ namespace euf {
euf_trail_stack& get_trail_stack() { return m_trail; }

void updt_params(params_ref const& p);
void set_solver(sat::solver* s) override { m_solver = s; }
void set_lookahead(sat::lookahead* s) override { m_lookahead = s; }
void init_search() override;
double get_reward(literal l, ext_constraint_idx idx, sat::literal_occs_fun& occs) const override;
Expand Down Expand Up @@ -285,6 +289,40 @@ namespace euf {

// diagnostics
func_decl_ref_vector const& unhandled_functions() { return m_unhandled_functions; }

// user propagator
void user_propagate_init(
void* ctx,
::solver::push_eh_t& push_eh,
::solver::pop_eh_t& pop_eh,
::solver::fresh_eh_t& fresh_eh);
bool watches_fixed(enode* n) const;
void assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain);
void assign_fixed(enode* n, expr* val, literal_vector const& explain) { assign_fixed(n, val, explain.size(), explain.c_ptr()); }
void assign_fixed(enode* n, expr* val, literal explain) { assign_fixed(n, val, 1, &explain); }

void user_propagate_register_final(::solver::final_eh_t& final_eh) {
check_for_user_propagator();
m_user_propagator->register_final(final_eh);
}
void user_propagate_register_fixed(::solver::fixed_eh_t& fixed_eh) {
check_for_user_propagator();
m_user_propagator->register_fixed(fixed_eh);
}
void user_propagate_register_eq(::solver::eq_eh_t& eq_eh) {
check_for_user_propagator();
m_user_propagator->register_eq(eq_eh);
}
void user_propagate_register_diseq(::solver::eq_eh_t& diseq_eh) {
check_for_user_propagator();
m_user_propagator->register_diseq(diseq_eh);
}
unsigned user_propagate_register(expr* e) {
check_for_user_propagator();
return m_user_propagator->add_expr(e);
}


};
};

Expand Down
Loading

0 comments on commit 43db7df

Please sign in to comment.