Skip to content

Commit

Permalink
sls updates
Browse files Browse the repository at this point in the history
- add SINGLE_THREAD mode
- add interface to retrieve "best" model so far
  • Loading branch information
NikolajBjorner committed Apr 13, 2024
1 parent 43dd6a5 commit 2682c2e
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 66 deletions.
18 changes: 17 additions & 1 deletion src/ast/sls/bv_sls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ namespace bv {
}
}


void sls::set_model() {
if (!m_set_model)
return;
if (m_repair_roots.size() >= m_min_repair_size)
return;
m_min_repair_size = m_repair_roots.size();
IF_VERBOSE(2, verbose_stream() << "(sls-update-model :num-unsat " << m_min_repair_size << ")\n");
m_set_model(*get_model());
}

void sls::init_repair_goal(app* t) {
m_eval.init_eval(t);
}
Expand Down Expand Up @@ -94,6 +105,9 @@ namespace bv {
if (m_to_repair.empty())
return;

// refresh the best model so far to a callback
set_model();

// add fresh units, if any
bool new_assertion = false;
while (m_get_unit) {
Expand Down Expand Up @@ -130,7 +144,7 @@ namespace bv {
return m_rand() % 2 == 0;
};
m_eval.init_eval(m_terms.assertions(), eval);
init_repair();
init_repair();
// m_engine_init = false;
}

Expand Down Expand Up @@ -295,10 +309,12 @@ namespace bv {
model_ref mdl = alloc(model, m);
auto& terms = m_eval.sort_assertions(m_terms.assertions());
for (expr* e : terms) {
#if 0
if (!m_eval.re_eval_is_correct(to_app(e))) {
verbose_stream() << "missed evaluation #" << e->get_id() << " " << mk_bounded_pp(e, m) << "\n";
m_eval.display_value(verbose_stream(), e) << "\n";
}
#endif
if (!is_uninterp_const(e))
continue;

Expand Down
8 changes: 8 additions & 0 deletions src/ast/sls/bv_sls.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,13 @@ namespace bv {
bool m_engine_model = false;
bool m_engine_init = false;
std::function<expr_ref()> m_get_unit;
std::function<void(model& mdl)> m_set_model;
unsigned m_min_repair_size = UINT_MAX;

std::pair<bool, app*> next_to_repair();

void init_repair_goal(app* e);
void set_model();
void try_repair_down(app* e);
void try_repair_up(app* e);
void set_repair_down(expr* e) { m_repair_down = e->get_id(); }
Expand Down Expand Up @@ -96,6 +99,11 @@ namespace bv {
*/
void init_unit(std::function<expr_ref()> get_unit) { m_get_unit = get_unit; }

/**
* Add callback to set model
*/
void set_model(std::function<void(model& mdl)> f) { m_set_model = f; }

/**
* Run (bounded) local search to find feasible assignments.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/sat/smt/intblast_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,7 @@ namespace intblast {
if (e->get_family_id() != bv.get_family_id())
return false;
for (euf::enode* arg : euf::enode_args(n))
dep.add(n, arg->get_root());
dep.add(n, arg);
return true;
}

Expand Down
98 changes: 52 additions & 46 deletions src/sat/smt/sls_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,37 @@ Module Name:

namespace sls {

#ifdef SINGLE_THREAD

solver::solver(euf::solver& ctx) :
th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls"))
{}

#else
solver::solver(euf::solver& ctx):
th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls")),
m_units(m) {}
th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls"))
{}

solver::~solver() {
finalize();
}

void solver::finalize() {
if (!m_completed && m_bvsls) {
m_bvsls->cancel();
if (!m_completed && m_sls) {
m_sls->cancel();
m_thread.join();
m_bvsls->collect_statistics(m_st);
m_bvsls = nullptr;
m_sls->collect_statistics(m_st);
m_sls = nullptr;
m_shared = nullptr;
m_slsm = nullptr;
m_units = nullptr;
}
}

sat::check_result solver::check() {

return sat::check_result::CR_DONE;
}

void solver::simplify() {
}

bool solver::unit_propagate() {
force_push();
sample_local_search();
Expand All @@ -66,71 +72,70 @@ namespace sls {
return false;
}

void solver::push_core() {

}

void solver::pop_core(unsigned n) {
for (; m_trail_lim < s().init_trail_size(); ++m_trail_lim) {
auto lit = s().trail_literal(m_trail_lim);
auto e = ctx.literal2expr(lit);
if (is_unit(e)) {
// IF_VERBOSE(1, verbose_stream() << "add unit " << mk_pp(e, m) << "\n");
std::lock_guard<std::mutex> lock(m_mutex);
m_units.push_back(e);
ast_translation tr(m, *m_shared);
m_units->push_back(tr(e.get()));
m_has_units = true;
}
}
}

void solver::init_search() {
init_local_search();
}
}

void solver::init_local_search() {
if (m_bvsls) {
m_bvsls->cancel();
void solver::init_search() {
if (m_sls) {
m_sls->cancel();
m_thread.join();
m_result = l_undef;
m_completed = false;
m_has_units = false;
m_model = nullptr;
m_units.reset();
m_units = nullptr;
}
// set up state for local search solver here

m_m = alloc(ast_manager, m);
ast_translation tr(m, *m_m);
m_shared = alloc(ast_manager);
m_slsm = alloc(ast_manager);
m_units = alloc(expr_ref_vector, *m_shared);
ast_translation tr(m, *m_slsm);

params_ref p;
m_completed = false;
m_result = l_undef;
m_model = nullptr;
m_bvsls = alloc(bv::sls, *m_m, p);
m_sls = alloc(bv::sls, *m_slsm, s().params());

for (expr* a : ctx.get_assertions())
m_bvsls->assert_expr(tr(a));
m_sls->assert_expr(tr(a));

std::function<bool(expr*, unsigned)> eval = [&](expr* e, unsigned r) {
return false;
};

m_bvsls->init();
m_bvsls->init_eval(eval);
m_bvsls->updt_params(s().params());
m_bvsls->init_unit([&]() {
m_sls->init();
m_sls->init_eval(eval);
m_sls->updt_params(s().params());
m_sls->init_unit([&]() {
if (!m_has_units)
return expr_ref(*m_m);
expr_ref e(m);
return expr_ref(*m_slsm);
expr_ref e(*m_slsm);
{
std::lock_guard<std::mutex> lock(m_mutex);
if (m_units.empty())
return expr_ref(*m_m);
e = m_units.back();
m_units.pop_back();
if (m_units->empty())
return expr_ref(*m_slsm);
ast_translation tr(*m_shared, *m_slsm);
e = tr(m_units->back());
m_units->pop_back();
}
ast_translation tr(m, *m_m);
return expr_ref(tr(e.get()), *m_m);
return e;
});
m_sls->set_model([&](model& mdl) {
std::lock_guard<std::mutex> lock(m_mutex);
ast_translation tr(*m_shared, m);
m_model = mdl.translate(tr);
});

m_thread = std::thread([this]() { run_local_search(); });
Expand All @@ -141,20 +146,21 @@ namespace sls {
return;
m_thread.join();
m_completed = false;
m_bvsls->collect_statistics(m_st);
m_sls->collect_statistics(m_st);
if (m_result == l_true) {
IF_VERBOSE(2, verbose_stream() << "(sat.sls :model-completed)\n";);
auto mdl = m_bvsls->get_model();
ast_translation tr(*m_m, m);
auto mdl = m_sls->get_model();
ast_translation tr(*m_slsm, m);
m_model = mdl->translate(tr);
s().set_canceled();
}
m_bvsls = nullptr;
m_sls = nullptr;
}

void solver::run_local_search() {
m_result = (*m_bvsls)();
m_result = (*m_sls)();
m_completed = true;
}

#endif
}
68 changes: 50 additions & 18 deletions src/sat/smt/sls_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,45 @@ Module Name:
--*/
#pragma once

#include <thread>
#include <mutex>

#include "util/rlimit.h"
#include "ast/sls/bv_sls.h"
#include "sat/smt/sat_th.h"


#ifdef SINGLE_THREAD


namespace euf {
class solver;
}

namespace sls {

class solver : public euf::th_euf_solver {
public:
solver(euf::solver& ctx);

sat::literal internalize(expr* e, bool sign, bool root) override { UNREACHABLE(); return sat::null_literal; }
void internalize(expr* e) override { UNREACHABLE(); }
th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); }

model_ref get_model() { return model_ref(nullptr); }
bool unit_propagate() override { return false; }
void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override { UNREACHABLE(); }
sat::check_result check() override { return sat::check_result::CR_DONE;}
std::ostream& display(std::ostream& out) const override { return out; }
std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { UNREACHABLE(); return out; }
std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override { UNREACHABLE(); return out; }

};
}

#else

#include <thread>
#include <mutex>

namespace euf {
class solver;
}
Expand All @@ -34,38 +66,36 @@ namespace sls {
std::atomic<bool> m_completed, m_has_units;
std::thread m_thread;
std::mutex m_mutex;
scoped_ptr<ast_manager> m_m;
scoped_ptr<bv::sls> m_bvsls;
// m is accessed by the main thread
// m_slsm is accessed by the sls thread
// m_shared is only accessed at synchronization points
scoped_ptr<ast_manager> m_shared, m_slsm;
scoped_ptr<bv::sls> m_sls;
scoped_ptr<expr_ref_vector> m_units;
model_ref m_model;
unsigned m_trail_lim = 0;
expr_ref_vector m_units;
statistics m_st;

void run_local_search();
void init_local_search();
void sample_local_search();

bool is_unit(expr*);

public:
solver(euf::solver& ctx);
~solver();

void simplify() override;
void init_search() override;
model_ref get_model() { return m_model; }

void push_core() override;
void init_search() override;
void push_core() override {}
void pop_core(unsigned n) override;

sat::literal internalize(expr* e, bool sign, bool root) override { UNREACHABLE(); return sat::null_literal; }
void internalize(expr* e) override { UNREACHABLE(); }
th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); }
void collect_statistics(statistics& st) const override { st.copy(m_st); }

model_ref get_model() { return m_model; }

void collect_statistics(statistics& st) const override { st.copy(m_st); }
void finalize() override;

bool unit_propagate() override;

sat::literal internalize(expr* e, bool sign, bool root) override { UNREACHABLE(); return sat::null_literal; }
void internalize(expr* e) override { UNREACHABLE(); }
void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing) override { UNREACHABLE(); }
sat::check_result check() override;
std::ostream & display(std::ostream & out) const override { return out; }
Expand All @@ -75,3 +105,5 @@ namespace sls {
};

}

#endif

2 comments on commit 2682c2e

@nunoplopes
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hot! 🔥

thank you 🙏

@NikolajBjorner
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a lot of WIP and hidden behind some configuration parameters (examples, listed in previous commit).
It only really applies to QFBV at this point and it is unstable (untested) for any input outside of this fragment.
The general setup will evolve to become resilient and include formulas beyond QFBV.
Yet, it is already not too difficult to find examples where this feature helps closing satisfiable formulas quickly.

Please sign in to comment.