Skip to content

Commit

Permalink
use tuned gcd to compute mult inverse
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolajBjorner committed Mar 5, 2024
1 parent 4391c90 commit 7dc4ce8
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 11 deletions.
3 changes: 2 additions & 1 deletion src/ast/sls/bv_sls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ namespace bv {
void sls::reinit_eval() {
std::function<bool(expr*, unsigned)> eval = [&](expr* e, unsigned i) {
auto should_keep = [&]() {
return m_rand() % 100 >= 95;
return m_rand() % 100 >= 98;
};
if (m.is_bool(e)) {
if (m_eval.is_fixed0(e) || should_keep())
Expand Down Expand Up @@ -225,5 +225,6 @@ namespace bv {
void sls::updt_params(params_ref const& _p) {
sls_params p(_p);
m_config.m_max_restarts = p.max_restarts();
m_rand.set_seed(p.random_seed());
}
}
94 changes: 91 additions & 3 deletions src/ast/sls/bv_sls_eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ namespace bv {
m_tmp4.push_back(0);
m_zero.push_back(0);
m_one.push_back(0);
m_a.push_back(0);
m_b.push_back(0);
m_nexta.push_back(0);
m_nextb.push_back(0);
m_aux.push_back(0);
m_minus_one.push_back(~0);
m_one[0] = 1;
}
Expand Down Expand Up @@ -1011,17 +1016,98 @@ namespace bv {
}
return false;
}

unsigned parity_e = e.parity(e.bits);
unsigned parity_b = b.parity(b.bits);

#if 1

auto& x = m_tmp;
auto& y = m_tmp2;
auto& quot = m_tmp3;
auto& rem = m_tmp4;
auto& ta = m_a;
auto& tb = m_b;
auto& nexta = m_nexta;
auto& nextb = m_nextb;
auto& aux = m_aux;


// x*ta + y*tb = x
b.get(y);
if (parity_b > 0)
b.shift_right(y, parity_b);
y[a.nw] = 0;
a.nw = a.nw + 1;
a.bw = 8 * sizeof(digit_t) * a.nw;
// x = 2 ^ b.bw
a.set_zero(x);
a.set(x, b.bw, true);

a.set_one(ta);
a.set_zero(tb);
a.set_zero(nexta);
a.set_one(nextb);

rem.reserve(2 * a.nw);
SASSERT(a.le(y, x));
while (a.gt(y, m_zero)) {
SASSERT(a.le(y, x));
set_div(x, y, a.bw, quot, rem); // quot, rem := quot_rem(x, y)
SASSERT(a.le(rem, y));
a.set(x, y); // x := y
a.set(y, rem); // y := rem
a.set(aux, nexta); // aux := nexta
a.set_mul(rem, quot, nexta, false);
a.set_sub(nexta, ta, rem); // nexta := ta - quot*nexta
a.set(ta, aux); // ta := aux
a.set(aux, nextb); // aux := nextb
a.set_mul(rem, quot, nextb, false);
a.set_sub(nextb, tb, rem); // nextb := tb - quot*nextb
a.set(tb, aux); // tb := aux
}

a.bw = b.bw;
a.nw = b.nw;
// x*a + y*b = 1

#if Z3DEBUG
b.get(y);
if (parity_b > 0)
b.shift_right(y, parity_b);
a.set_mul(m_tmp, tb, y);
#if 0
for (unsigned i = a.nw; i-- > 0; )
verbose_stream() << tb[i];
verbose_stream() << "\n";
for (unsigned i = a.nw; i-- > 0; )
verbose_stream() << y[i];
verbose_stream() << "\n";
for (unsigned i = a.nw; i-- > 0; )
verbose_stream() << m_tmp[i];
verbose_stream() << "\n";
#endif
SASSERT(b.is_one(m_tmp));
#endif
e.get(m_tmp2);
if (parity_e > 0 && parity_b > 0)
b.shift_right(m_tmp2, std::min(parity_b, parity_e));
a.set_mul(m_tmp, tb, m_tmp2);
a.set_repair(random_bool(), m_tmp);

#else
rational ne, nb;
e.get_value(e.bits, ne);
b.get_value(b.bits, nb);
unsigned parity_e = e.parity(e.bits);
unsigned parity_b = b.parity(b.bits);
if (parity_b > 0)
ne /= rational::power_of_two(std::min(parity_b, parity_e));
auto inv_b = nb.pseudo_inverse(b.bw);
rational na = mod(inv_b * ne, rational::power_of_two(a.bw));
a.set_value(m_tmp, na);
a.set_repair(random_bool(), m_tmp);
#endif
return true;
}

Expand Down Expand Up @@ -1454,7 +1540,9 @@ namespace bv {
}
quot[nw - 1] = (1 << (bw % (8 * sizeof(digit_t)))) - 1;
}
else {
else {
for (unsigned i = 0; i < nw; ++i)
rem[i] = quot[i] = 0;
mpn.div(a.data(), nw, b.data(), bnw, quot.data(), rem.data());
}
}
Expand Down
1 change: 1 addition & 0 deletions src/ast/sls/bv_sls_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ namespace bv {
bool_vector m_fixed; // expr-id -> is Boolean fixed

mutable svector<digit_t> m_tmp, m_tmp2, m_tmp3, m_tmp4, m_zero, m_one, m_minus_one;
svector<digit_t> m_a, m_b, m_nextb, m_nexta, m_aux;

using bvval = sls_valuation;

Expand Down
21 changes: 17 additions & 4 deletions src/ast/sls/sls_valuation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,15 @@ namespace bv {
return value;
}

void sls_valuation::shift_right(svector<digit_t>& out, unsigned shift) const {
SASSERT(shift < bw);
unsigned n = shift / (8 * sizeof(digit_t));
unsigned s = shift % (8 * sizeof(digit_t));
for (unsigned i = 0; i < bw; ++i)
set(out, i, i + shift < bw ? get(bits, i + shift) : false);
SASSERT(!has_overflow(out));
}

void sls_valuation::add_range(rational l, rational h) {
l = mod(l, rational::power_of_two(bw));
h = mod(h, rational::power_of_two(bw));
Expand Down Expand Up @@ -427,11 +436,15 @@ namespace bv {
return ovfl;
}

bool sls_valuation::set_mul(svector<digit_t>& out, svector<digit_t> const& a, svector<digit_t> const& b) const {
bool sls_valuation::set_mul(svector<digit_t>& out, svector<digit_t> const& a, svector<digit_t> const& b, bool check_overflow) const {
mpn_manager().mul(a.data(), nw, b.data(), nw, out.data());
bool ovfl = has_overflow(out);
for (unsigned i = nw; i < 2 * nw; ++i)
ovfl |= out[i] != 0;

bool ovfl = false;
if (check_overflow) {
ovfl = has_overflow(out);
for (unsigned i = nw; i < 2 * nw; ++i)
ovfl |= out[i] != 0;
}
clear_overflow_bits(out);
return ovfl;
}
Expand Down
17 changes: 14 additions & 3 deletions src/ast/sls/sls_valuation.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,19 @@ namespace bv {
clear_overflow_bits(bits);
}

void set_zero() {
void set_zero(svector<digit_t>& out) const {
for (unsigned i = 0; i < nw; ++i)
bits[i] = 0;
out[i] = 0;
}

void set_one(svector<digit_t>& out) const {
for (unsigned i = 1; i < nw; ++i)
out[i] = 0;
out[0] = 1;
}

void set_zero() {
set_zero(bits);
}

void sub1(svector<digit_t>& out) const {
Expand All @@ -149,7 +159,8 @@ namespace bv {

void set_sub(svector<digit_t>& out, svector<digit_t> const& a, svector<digit_t> const& b) const;
bool set_add(svector<digit_t>& out, svector<digit_t> const& a, svector<digit_t> const& b) const;
bool set_mul(svector<digit_t>& out, svector<digit_t> const& a, svector<digit_t> const& b) const;
bool set_mul(svector<digit_t>& out, svector<digit_t> const& a, svector<digit_t> const& b, bool check_overflow = true) const;
void shift_right(svector<digit_t>& out, unsigned shift) const;

void set_range(svector<digit_t>& dst, unsigned lo, unsigned hi, bool b) {
for (unsigned i = lo; i < hi; ++i)
Expand Down

0 comments on commit 7dc4ce8

Please sign in to comment.