Skip to content

Commit c1ac52a

Browse files
committed
further simplifications & memory reductions to refinement check
1 parent 939646d commit c1ac52a

File tree

6 files changed

+113
-140
lines changed

6 files changed

+113
-140
lines changed

ir/state.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ class State {
220220
void finishInitializer();
221221

222222
auto& getFn() const { return f; }
223+
auto& getMemory() const { return memory; }
223224
auto& getMemory() { return memory; }
224225
auto& getAxioms() const { return axioms; }
225226
auto& getPre() const { return precondition; }

ir/type.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ void VoidType::fixup(const Model &m) {
289289
}
290290

291291
pair<expr, expr>
292-
VoidType::refines(State &src_s, State &tgt_s, const StateValue &src,
292+
VoidType::refines(const State &src_s, const State &tgt_s, const StateValue &src,
293293
const StateValue &tgt) const {
294294
return { true, true };
295295
}
@@ -299,7 +299,7 @@ expr VoidType::mkInput(State &s, const char *name,
299299
UNREACHABLE();
300300
}
301301

302-
void VoidType::printVal(ostream &os, State &s, const expr &e) const {
302+
void VoidType::printVal(ostream &os, const State &s, const expr &e) const {
303303
UNREACHABLE();
304304
}
305305

@@ -347,7 +347,7 @@ expr IntType::enforceIntType(unsigned bits) const {
347347
}
348348

349349
pair<expr, expr>
350-
IntType::refines(State &src_s, State &tgt_s, const StateValue &src,
350+
IntType::refines(const State &src_s, const State &tgt_s, const StateValue &src,
351351
const StateValue &tgt) const {
352352
return { src.non_poison.implies(tgt.non_poison),
353353
(src.non_poison && tgt.non_poison).implies(src.value == tgt.value) };
@@ -358,7 +358,7 @@ expr IntType::mkInput(State &s, const char *name,
358358
return expr::mkVar(name, bits());
359359
}
360360

361-
void IntType::printVal(ostream &os, State &s, const expr &e) const {
361+
void IntType::printVal(ostream &os, const State &s, const expr &e) const {
362362
e.printHexadecimal(os);
363363
os << " (";
364364
e.printUnsigned(os);
@@ -533,8 +533,8 @@ expr FloatType::enforceFloatType() const {
533533
}
534534

535535
pair<expr, expr>
536-
FloatType::refines(State &src_s, State &tgt_s, const StateValue &src,
537-
const StateValue &tgt) const {
536+
FloatType::refines(const State &src_s, const State &tgt_s,
537+
const StateValue &src, const StateValue &tgt) const {
538538
expr non_poison = src.non_poison && tgt.non_poison;
539539
return { src.non_poison.implies(tgt.non_poison),
540540
(src.non_poison && tgt.non_poison).implies(src.value == tgt.value) };
@@ -552,7 +552,7 @@ expr FloatType::mkInput(State &s, const char *name,
552552
UNREACHABLE();
553553
}
554554

555-
void FloatType::printVal(ostream &os, State &s, const expr &e) const {
555+
void FloatType::printVal(ostream &os, const State &s, const expr &e) const {
556556
if (e.isNaN().isTrue()) {
557557
os << "NaN";
558558
return;
@@ -652,7 +652,7 @@ StateValue PtrType::fromInt(StateValue v) const {
652652
}
653653

654654
pair<expr, expr>
655-
PtrType::refines(State &src_s, State &tgt_s, const StateValue &src,
655+
PtrType::refines(const State &src_s, const State &tgt_s, const StateValue &src,
656656
const StateValue &tgt) const {
657657
auto sm = src_s.returnMemory(), tm = tgt_s.returnMemory();
658658
Pointer p(sm, src.value);
@@ -672,7 +672,7 @@ PtrType::mkUndefInput(State &s, const ParamAttrs &attrs) const {
672672
return s.getMemory().mkUndefInput(attrs);
673673
}
674674

675-
void PtrType::printVal(ostream &os, State &s, const expr &e) const {
675+
void PtrType::printVal(ostream &os, const State &s, const expr &e) const {
676676
os << Pointer(s.getMemory(), e);
677677
}
678678

@@ -926,8 +926,8 @@ StateValue AggregateType::fromInt(StateValue v) const {
926926
}
927927

928928
pair<expr, expr>
929-
AggregateType::refines(State &src_s, State &tgt_s, const StateValue &src,
930-
const StateValue &tgt) const {
929+
AggregateType::refines(const State &src_s, const State &tgt_s,
930+
const StateValue &src, const StateValue &tgt) const {
931931
set<expr> poison, value;
932932
for (unsigned i = 0; i < elements; ++i) {
933933
auto [p, v] = children[i]->refines(src_s, tgt_s, extract(src, i),
@@ -954,7 +954,7 @@ unsigned AggregateType::numPointerElements() const {
954954
return count;
955955
}
956956

957-
void AggregateType::printVal(ostream &os, State &s, const expr &e) const {
957+
void AggregateType::printVal(ostream &os, const State &s, const expr &e) const {
958958
UNREACHABLE();
959959
}
960960

@@ -1360,8 +1360,8 @@ StateValue SymbolicType::fromInt(StateValue val) const {
13601360
}
13611361

13621362
pair<expr, expr>
1363-
SymbolicType::refines(State &src_s, State &tgt_s, const StateValue &src,
1364-
const StateValue &tgt) const {
1363+
SymbolicType::refines(const State &src_s, const State &tgt_s,
1364+
const StateValue &src, const StateValue &tgt) const {
13651365
DISPATCH(refines(src_s, tgt_s, src, tgt), UNREACHABLE());
13661366
}
13671367

@@ -1375,7 +1375,7 @@ SymbolicType::mkUndefInput(State &st, const ParamAttrs &attrs) const {
13751375
DISPATCH(mkUndefInput(st, attrs), UNREACHABLE());
13761376
}
13771377

1378-
void SymbolicType::printVal(ostream &os, State &st, const expr &e) const {
1378+
void SymbolicType::printVal(ostream &os, const State &st, const expr &e) const {
13791379
DISPATCH(printVal(os, st, e), UNREACHABLE());
13801380
}
13811381

ir/type.h

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,15 @@ class Type {
107107

108108
// returns pair of refinement constraints for <poison, !poison && value>
109109
virtual std::pair<smt::expr, smt::expr>
110-
refines(State &src_s, State &tgt_s, const StateValue &src,
110+
refines(const State &src_s, const State &tgt_s, const StateValue &src,
111111
const StateValue &tgt) const = 0;
112112

113113
virtual smt::expr
114114
mkInput(State &s, const char *name, const ParamAttrs &attrs) const = 0;
115115
virtual std::pair<smt::expr, smt::expr>
116116
mkUndefInput(State &s, const ParamAttrs &attrs) const;
117117

118-
virtual void printVal(std::ostream &os, State &s,
118+
virtual void printVal(std::ostream &os, const State &s,
119119
const smt::expr &e) const = 0;
120120

121121
virtual void print(std::ostream &os) const = 0;
@@ -136,11 +136,12 @@ class VoidType final : public Type {
136136
smt::expr getTypeConstraints() const override;
137137
void fixup(const smt::Model &m) override;
138138
std::pair<smt::expr, smt::expr>
139-
refines(State &src_s, State &tgt_s, const StateValue &src,
139+
refines(const State &src_s, const State &tgt_s, const StateValue &src,
140140
const StateValue &tgt) const override;
141141
smt::expr
142142
mkInput(State &s, const char *name, const ParamAttrs &attrs) const override;
143-
void printVal(std::ostream &os, State &s, const smt::expr &e) const override;
143+
void printVal(std::ostream &os, const State &s,
144+
const smt::expr &e) const override;
144145
void print(std::ostream &os) const override;
145146
};
146147

@@ -163,11 +164,12 @@ class IntType final : public Type {
163164
bool isIntType() const override;
164165
smt::expr enforceIntType(unsigned bits = 0) const override;
165166
std::pair<smt::expr, smt::expr>
166-
refines(State &src_s, State &tgt_s, const StateValue &src,
167+
refines(const State &src_s, const State &tgt_s, const StateValue &src,
167168
const StateValue &tgt) const override;
168169
smt::expr
169170
mkInput(State &s, const char *name, const ParamAttrs &attrs) const override;
170-
void printVal(std::ostream &os, State &s, const smt::expr &e) const override;
171+
void printVal(std::ostream &os, const State &s,
172+
const smt::expr &e) const override;
171173
void print(std::ostream &os) const override;
172174
};
173175

@@ -208,11 +210,12 @@ class FloatType final : public Type {
208210
smt::expr fromInt(smt::expr v) const override;
209211
IR::StateValue fromInt(IR::StateValue v) const override;
210212
std::pair<smt::expr, smt::expr>
211-
refines(State &src_s, State &tgt_s, const StateValue &src,
213+
refines(const State &src_s, const State &tgt_s, const StateValue &src,
212214
const StateValue &tgt) const override;
213215
smt::expr
214216
mkInput(State &s, const char *name, const ParamAttrs &attrs) const override;
215-
void printVal(std::ostream &os, State &s, const smt::expr &e) const override;
217+
void printVal(std::ostream &os, const State &s,
218+
const smt::expr &e) const override;
216219
void print(std::ostream &os) const override;
217220
};
218221

@@ -240,13 +243,14 @@ class PtrType final : public Type {
240243
smt::expr fromInt(smt::expr v) const override;
241244
IR::StateValue fromInt(IR::StateValue v) const override;
242245
std::pair<smt::expr, smt::expr>
243-
refines(State &src_s, State &tgt_s, const StateValue &src,
246+
refines(const State &src_s, const State &tgt_s, const StateValue &src,
244247
const StateValue &tgt) const override;
245248
smt::expr
246249
mkInput(State &s, const char *name, const ParamAttrs &attrs) const override;
247250
std::pair<smt::expr, smt::expr>
248251
mkUndefInput(State &s, const ParamAttrs &attrs) const override;
249-
void printVal(std::ostream &os, State &s, const smt::expr &e) const override;
252+
void printVal(std::ostream &os, const State &s,
253+
const smt::expr &e) const override;
250254
void print(std::ostream &os) const override;
251255
};
252256

@@ -296,12 +300,13 @@ class AggregateType : public Type {
296300
smt::expr fromInt(smt::expr v) const override;
297301
IR::StateValue fromInt(IR::StateValue v) const override;
298302
std::pair<smt::expr, smt::expr>
299-
refines(State &src_s, State &tgt_s, const StateValue &src,
303+
refines(const State &src_s, const State &tgt_s, const StateValue &src,
300304
const StateValue &tgt) const override;
301305
smt::expr
302306
mkInput(State &s, const char *name, const ParamAttrs &attrs) const override;
303307
unsigned numPointerElements() const;
304-
void printVal(std::ostream &os, State &s, const smt::expr &e) const override;
308+
void printVal(std::ostream &os, const State &s,
309+
const smt::expr &e) const override;
305310
const AggregateType* getAsAggregateType() const override;
306311
};
307312

@@ -402,14 +407,15 @@ class SymbolicType final : public Type {
402407
smt::expr fromInt(smt::expr v) const override;
403408
IR::StateValue fromInt(IR::StateValue v) const override;
404409
std::pair<smt::expr, smt::expr>
405-
refines(State &src_s, State &tgt_s, const StateValue &src,
410+
refines(const State &src_s, const State &tgt_s, const StateValue &src,
406411
const StateValue &tgt) const override;
407412
smt::expr
408413
mkInput(State &s, const char *name, const ParamAttrs &attrs)
409414
const override;
410415
std::pair<smt::expr, smt::expr>
411416
mkUndefInput(State &s, const ParamAttrs &attrs) const override;
412-
void printVal(std::ostream &os, State &s, const smt::expr &e) const override;
417+
void printVal(std::ostream &os, const State &s,
418+
const smt::expr &e) const override;
413419
void print(std::ostream &os) const override;
414420
};
415421

smt/solver.cpp

Lines changed: 26 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <iomanip>
1010
#include <iostream>
1111
#include <optional>
12+
#include <string_view>
1213
#include <utility>
1314
#include <vector>
1415
#include <z3.h>
@@ -208,15 +209,6 @@ ostream& operator<<(ostream &os, const Model &m) {
208209
}
209210

210211

211-
SolverPush::SolverPush(Solver &s) : s(s) {
212-
Z3_solver_push(ctx(), s.s);
213-
}
214-
215-
SolverPush::~SolverPush() {
216-
Z3_solver_pop(ctx(), s.s, 1);
217-
}
218-
219-
220212
static bool print_queries = false;
221213
void solver_print_queries(bool yes) {
222214
print_queries = yes;
@@ -238,7 +230,9 @@ Solver::~Solver() {
238230
}
239231

240232
void Solver::add(const expr &e) {
241-
if (e.isValid()) {
233+
if (e.isFalse()) {
234+
is_unsat = true;
235+
} else if (e.isValid()) {
242236
auto ast = e();
243237
Z3_solver_assert(ctx(), s, ast);
244238
tactic->add(ast);
@@ -288,16 +282,21 @@ expr Solver::assertions() const {
288282
}
289283

290284
Result Solver::check() const {
291-
if (config::skip_smt) {
292-
++num_skips;
293-
return Result::SKIP;
294-
}
295-
296285
if (!valid) {
297286
++num_invalid;
298287
return Result::INVALID;
299288
}
300289

290+
if (is_unsat) {
291+
++num_trivial;
292+
return Result::UNSAT;
293+
}
294+
295+
if (config::skip_smt) {
296+
++num_skips;
297+
return Result::SKIP;
298+
}
299+
301300
++num_queries;
302301
if (print_queries)
303302
cout << "\nSMT query:\n" << Z3_solver_to_string(ctx(), s) << endl;
@@ -312,48 +311,35 @@ Result Solver::check() const {
312311
++num_sats;
313312
return Z3_solver_get_model(ctx(), s);
314313
case Z3_L_UNDEF: {
315-
string reason = Z3_solver_get_reason_unknown(ctx(), s);
314+
string_view reason = Z3_solver_get_reason_unknown(ctx(), s);
316315
if (reason == "timeout") {
317316
++num_timeout;
318317
return Result::TIMEOUT;
319318
}
320319
++num_errors;
321-
return { Result::ERROR, move(reason) };
320+
return { Result::ERROR, string(reason) };
322321
}
323322
default:
324323
UNREACHABLE();
325324
}
326325
}
327326

328-
bool Solver::check(expr &&q, std::function<void(const Result &r)> &&error) {
329-
if (!q.isValid()) {
330-
++num_invalid;
331-
error(Result::INVALID);
332-
return false;
333-
}
334-
335-
if (q.isFalse()) {
336-
++num_trivial;
337-
return true;
338-
}
339-
340-
// TODO: benchmark: reset() or new solver every time?
341-
Solver s;
342-
s.add(q);
343-
auto res = s.check();
344-
if (!res.isUnsat()) {
345-
error(res);
346-
return false;
347-
}
348-
return true;
349-
}
350-
351327
Result check_expr(const expr &e) {
352328
Solver s;
353329
s.add(e);
354330
return s.check();
355331
}
356332

333+
334+
SolverPush::SolverPush(Solver &s) : s(s) {
335+
Z3_solver_push(ctx(), s.s);
336+
}
337+
338+
SolverPush::~SolverPush() {
339+
Z3_solver_pop(ctx(), s.s, 1);
340+
}
341+
342+
357343
void solver_print_stats(ostream &os) {
358344
float total = num_queries / 100.0;
359345
float trivial_pc = num_queries == 0 ? 0 :

0 commit comments

Comments
 (0)