diff --git a/scripts/update_api.py b/scripts/update_api.py index 59811bf4c7b..00251765be9 100755 --- a/scripts/update_api.py +++ b/scripts/update_api.py @@ -1826,7 +1826,7 @@ def _to_pystr(s): push_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p) pop_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_uint) -fixed_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_uint, ctypes.c_void_p) +fixed_eh_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.c_void_p) fresh_eh_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_void_p) _lib.Z3_solver_propagate_init.restype = None diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index afcaa29db9a..726271131b4 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -899,7 +899,7 @@ extern "C" { init_solver(c, s); std::function _push = push_eh; std::function _pop = pop_eh; - std::function _fixed = (void(*)(void*,unsigned,expr*))fixed_eh; + std::function _fixed = (void(*)(void*,solver::propagate_callback*,unsigned,expr*))fixed_eh; std::function _fresh = fresh_eh; to_solver_ref(s)->user_propagate_init(user_context, _fixed, _push, _pop, _fresh); Z3_CATCH; @@ -913,11 +913,11 @@ extern "C" { Z3_CATCH_RETURN(0); } - void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver s, unsigned sz, unsigned const* ids, Z3_ast conseq) { + void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback s, unsigned sz, unsigned const* ids, Z3_ast conseq) { Z3_TRY; LOG_Z3_solver_propagate_consequence(c, s, sz, ids, conseq); RESET_ERROR_CODE(); - to_solver_ref(s)->user_propagate_consequence(sz, ids, to_expr(conseq)); + reinterpret_cast(s)->propagate(sz, ids, to_expr(conseq)); Z3_CATCH; } diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index c9760b05976..5e595b63bc0 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -10513,13 +10513,16 @@ def user_prop_push(ctx): def user_prop_pop(ctx, num_scopes): _user_propagate_bases[ctx].pop(num_scopes) -def user_prop_fixed(ctx, id, value): +def user_prop_fixed(ctx, cb, id, value): prop = _user_propagate_bases[ctx] + prop.cb = cb prop.fixed(id, _to_expr_ref(ctypes.c_void_p(value), prop.ctx)) + prop.cb = None def user_prop_fresh(ctx): prop = _user_propagate_bases[ctx] - new_prop = prop.fresh() + new_prop = UsePropagateBase(None, prop.ctx) + _user_prop_bases[new_prop.id] = new_prop.fresh() return ctypes.c_void_p(new_prop.id) @@ -10530,18 +10533,20 @@ def user_prop_fresh(ctx): class UserPropagateBase: - def __init__(self, s): + def __init__(self, s, ctx = None): self.id = len(_user_propagate_bases) + 3 self.solver = s - self.ctx = s.ctx + self.ctx = s.ctx if s is not None else ctx + self.cb = None _user_propagate_bases[self.id] = self - Z3_solver_propagate_init(s.ctx.ref(), - s.solver, - ctypes.c_void_p(self.id), - _user_prop_push, - _user_prop_pop, - _user_prop_fixed, - _user_prop_fresh) + if s: + Z3_solver_propagate_init(s.ctx.ref(), + s.solver, + ctypes.c_void_p(self.id), + _user_prop_push, + _user_prop_pop, + _user_prop_fixed, + _user_prop_fresh) def push(self): raise Z3Exception("push has not been overwritten") @@ -10551,19 +10556,23 @@ def pop(self, num_scopes): def fixed(self, id, e): raise Z3Exception("fixed has not been overwritten") - - def fresh(self): + + def fresh(self, prop_base): raise Z3Exception("fresh has not been overwritten") def add(self, e): + assert self.solver return Z3_solver_propagate_register(self.ctx.ref(), self.solver.solver, e.ast) + # + # Propagation can only be invoked as during a fixed-callback. + # def propagate(self, ids, e): sz = len(ids) _ids = (ctypes.c_uint * sz)() for i in range(sz): _ids[i] = ids[i] - Z3_solver_propagate_consequence(self.ctx.ref(), self.solver.solver, sz, _ids, e.ast) + Z3_solver_propagate_consequence(self.ctx.ref(), self.cb, sz, _ids, e.ast) def conflict(self, ids): self.propagate(ids, BoolVal(False, self.ctx)) diff --git a/src/api/python/z3/z3types.py b/src/api/python/z3/z3types.py index d52a7914e2c..5c93b27a26e 100644 --- a/src/api/python/z3/z3types.py +++ b/src/api/python/z3/z3types.py @@ -82,6 +82,10 @@ class SolverObj(ctypes.c_void_p): def __init__(self, solver): self._as_parameter_ = solver def from_param(obj): return obj +class SolverCallbackObj(ctypes.c_void_p): + def __init__(self, solver): self._as_parameter_ = solver + def from_param(obj): return obj + class FixedpointObj(ctypes.c_void_p): def __init__(self, fixedpoint): self._as_parameter_ = fixedpoint def from_param(obj): return obj diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 127bb90b1f3..921ce372d8d 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -25,6 +25,7 @@ DEFINE_TYPE(Z3_tactic); DEFINE_TYPE(Z3_probe); DEFINE_TYPE(Z3_stats); DEFINE_TYPE(Z3_solver); +DEFINE_TYPE(Z3_solver_callback); DEFINE_TYPE(Z3_ast_vector); DEFINE_TYPE(Z3_ast_map); DEFINE_TYPE(Z3_apply_result); @@ -1391,6 +1392,7 @@ typedef enum def_Type('CONSTRUCTOR', 'Z3_constructor', 'Constructor') def_Type('CONSTRUCTOR_LIST', 'Z3_constructor_list', 'ConstructorList') def_Type('SOLVER', 'Z3_solver', 'SolverObj') + def_Type('SOLVER_CALLBACK', 'Z3_solver_callback', 'SolverCallbackObj') def_Type('GOAL', 'Z3_goal', 'GoalObj') def_Type('TACTIC', 'Z3_tactic', 'TacticObj') def_Type('PARAMS', 'Z3_params', 'Params') @@ -1418,7 +1420,7 @@ typedef void Z3_error_handler(Z3_context c, Z3_error_code e); */ typedef void Z3_push_eh(void* ctx); typedef void Z3_pop_eh(void* ctx, unsigned num_scopes); -typedef void Z3_fixed_eh(void* ctx, unsigned id, Z3_ast value); +typedef void Z3_fixed_eh(void* ctx, Z3_solver_callback cb, unsigned id, Z3_ast value); typedef void* Z3_fresh_eh(void* ctx); /** @@ -6553,10 +6555,10 @@ extern "C" { The callback adds a propagation consequence based on the fixed values of the \c ids. - def_API('Z3_solver_propagate_consequence', VOID, (_in(CONTEXT), _in(SOLVER), _in(UINT), _in_array(2, UINT), _in(AST))) + def_API('Z3_solver_propagate_consequence', VOID, (_in(CONTEXT), _in(SOLVER_CALLBACK), _in(UINT), _in_array(2, UINT), _in(AST))) */ - void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver, unsigned sz, unsigned const* ids, Z3_ast conseq); + void Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback, unsigned sz, unsigned const* ids, Z3_ast conseq); /** \brief Check whether the assertions in a given solver are consistent or not. diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index dca77ed29e1..77bab98d902 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -2951,7 +2951,7 @@ namespace smt { void context::user_propagate_init( void* ctx, - std::function& fixed_eh, + std::function& fixed_eh, std::function& push_eh, std::function& pop_eh, std::function& fresh_eh) { diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index ebef0e80711..4bacac97bf1 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -1689,7 +1689,7 @@ namespace smt { */ void user_propagate_init( void* ctx, - std::function& fixed_eh, + std::function& fixed_eh, std::function& push_eh, std::function& pop_eh, std::function& fresh_eh); @@ -1700,13 +1700,6 @@ namespace smt { return m_user_propagator->add_expr(e); } - void user_propagate_consequence(unsigned sz, unsigned const* ids, expr* conseq) { - if (!m_user_propagator) - throw default_exception("user propagator must be initialized"); - m_user_propagator->add_propagation(sz, ids, conseq); - } - - bool watches_fixed(enode* n) const; void assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain); diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index ca8ea5a376b..567f0eef593 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -235,7 +235,7 @@ namespace smt { void user_propagate_init( void* ctx, - std::function& fixed_eh, + std::function& fixed_eh, std::function& push_eh, std::function& pop_eh, std::function& fresh_eh) { @@ -246,9 +246,6 @@ namespace smt { return m_kernel.user_propagate_register(e); } - void user_propagate_consequence(unsigned sz, unsigned const* ids, expr* conseq) { - m_kernel.user_propagate_consequence(sz, ids, conseq); - } }; kernel::kernel(ast_manager & m, smt_params & fp, params_ref const & p) { @@ -463,7 +460,7 @@ namespace smt { void kernel::user_propagate_init( void* ctx, - std::function& fixed_eh, + std::function& fixed_eh, std::function& push_eh, std::function& pop_eh, std::function& fresh_eh) { @@ -472,12 +469,6 @@ namespace smt { unsigned kernel::user_propagate_register(expr* e) { return m_imp->user_propagate_register(e); - } - - void kernel::user_propagate_consequence(unsigned sz, unsigned const* ids, expr* conseq) { - m_imp->user_propagate_consequence(sz, ids, conseq); - } - - + } }; diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index 126e4c20f33..3eac0bc188f 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -26,11 +26,12 @@ Revision History: --*/ #pragma once -#include "ast/ast.h" #include "util/params.h" -#include "model/model.h" #include "util/lbool.h" #include "util/statistics.h" +#include "ast/ast.h" +#include "model/model.h" +#include "solver/solver.h" #include "smt/smt_failure.h" struct smt_params; @@ -289,7 +290,7 @@ namespace smt { */ void user_propagate_init( void* ctx, - std::function& fixed_eh, + std::function& fixed_eh, std::function& push_eh, std::function& pop_eh, std::function& fresh_eh); @@ -298,13 +299,6 @@ namespace smt { \brief register an expression to be tracked fro user propagation. */ unsigned user_propagate_register(expr* e); - - - /** - \brief accept a user-propagation callback (issued during fixed_he). - */ - - void user_propagate_consequence(unsigned sz, unsigned const* ids, expr* conseq); /** diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index 382bc9d3e4e..eddc0e0d35b 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -210,7 +210,7 @@ namespace { void user_propagate_init( void* ctx, - std::function& fixed_eh, + std::function& fixed_eh, std::function& push_eh, std::function& pop_eh, std::function& fresh_eh) override { @@ -221,10 +221,6 @@ namespace { return m_context.user_propagate_register(e); } - void user_propagate_consequence(unsigned sz, unsigned const* ids, expr* conseq) override { - m_context.user_propagate_consequence(sz, ids, conseq); - } - struct scoped_minimize_core { smt_solver& s; expr_ref_vector m_assumptions; diff --git a/src/smt/user_propagator.cpp b/src/smt/user_propagator.cpp index 114d3fd7e75..5c401fffffd 100644 --- a/src/smt/user_propagator.cpp +++ b/src/smt/user_propagator.cpp @@ -49,7 +49,7 @@ unsigned user_propagator::add_expr(expr* e) { return v; } -void user_propagator::add_propagation(unsigned sz, unsigned const* ids, expr* conseq) { +void user_propagator::propagate(unsigned sz, unsigned const* ids, expr* conseq) { m_prop.push_back(prop_info(sz, ids, expr_ref(conseq, m))); } @@ -63,7 +63,7 @@ theory * user_propagator::mk_fresh(context * new_ctx) { void user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits) { force_push(); m_id2justification.setx(v, literal_vector(num_lits, jlits), literal_vector()); - m_fixed_eh(m_user_context, v, value); + m_fixed_eh(m_user_context, this, v, value); } void user_propagator::push_scope_eh() { diff --git a/src/smt/user_propagator.h b/src/smt/user_propagator.h index 208e9447326..db845086566 100644 --- a/src/smt/user_propagator.h +++ b/src/smt/user_propagator.h @@ -25,11 +25,12 @@ Module Name: #pragma once #include "smt/smt_theory.h" +#include "solver/solver.h" namespace smt { - class user_propagator : public theory { + class user_propagator : public theory, public solver::propagate_callback { void* m_user_context; - std::function m_fixed_eh; + std::function m_fixed_eh; std::function m_push_eh; std::function m_pop_eh; std::function m_fresh_eh; @@ -60,7 +61,7 @@ namespace smt { */ void add( void* ctx, - std::function& fixed_eh, + std::function& fixed_eh, std::function& push_eh, std::function& pop_eh, std::function& fresh_eh) { @@ -73,7 +74,7 @@ namespace smt { unsigned add_expr(expr* e); - void add_propagation(unsigned sz, unsigned const* ids, expr* conseq); + void propagate(unsigned sz, unsigned const* ids, expr* conseq) override; void new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits); diff --git a/src/solver/solver.h b/src/solver/solver.h index 14379e2f6e3..f5452029d5f 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -238,9 +238,14 @@ class solver : public check_sat_result { virtual expr_ref get_implied_upper_bound(expr* e) = 0; + class propagate_callback { + public: + virtual void propagate(unsigned sz, unsigned const* ids, expr* conseq) = 0; + }; + virtual void user_propagate_init( void* ctx, - std::function& fixed_eh, + std::function& fixed_eh, std::function& push_eh, std::function& pop_eh, std::function& fresh_eh) { @@ -249,8 +254,6 @@ class solver : public check_sat_result { virtual unsigned user_propagate_register(expr* e) { return 0; } - virtual void user_propagate_consequence(unsigned sz, unsigned const* ids, expr* conseq) {} - /** \brief Display the content of this solver.