Skip to content

Commit

Permalink
HolSmt: track Z3-defined variables across proof replay
Browse files Browse the repository at this point in the history
It turns out that there are multiple reasons why keeping track of
Z3-defined variables is desirable (or perhaps even necessary) during
proof replay.

One of the issues is that when removing definitions, term unification
may end up creating new definitions of this form:

var1 = var2

... where `var1` was not in the previous set of variables that we were
keeping track of for removal. This would cause such definitions to be
ignored when calculating the new set of definitions to remove.

Another reason is that we want to avoid ending up with circular
definitions such as:

var1 = var2
var2 = var1

... where `var1` and `var2` are both Z3-defined variables. To prevent
this, we can orient such definitions so that `var2 = var1` is always
translated into `var1 = var2` (where `var1` <= `var2`, for some
definition of `<=`), i.e. we can always create them in a canonical
orientation.

Keeping track of Z3-defined variables also allows us to orient
definitions created due to term unification (e.g. as part of
`rewrite` proof rules) such that they end up as `var = x` instead
of `x = var`, where `x` is a user-defined variable and `var` is a
Z3-defined variable.

Therefore, this commit adds code to keep track of which variables have
been defined by Z3. As a side effect, it fixes the first issue
mentioned above.

A subsequent commit will use this functionality to orient the
definitions appropriately during proof replay, which will fix the
remaining two issues.
  • Loading branch information
someplaceguy authored and mn200 committed Feb 21, 2024
1 parent 313365a commit 4b76320
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 56 deletions.
4 changes: 2 additions & 2 deletions src/HolSmt/SmtLib_Parser.sml
Expand Up @@ -399,7 +399,7 @@ local
else
raise ERR ("<" ^ name ^ ">") "wrong number of arguments"
in
Library.extend_dict ((name, parsefn), tmdict)
(tm, Library.extend_dict ((name, parsefn), tmdict))
end

(* returns an extended 'tmdict', and the definition (as a formula) *)
Expand Down Expand Up @@ -474,7 +474,7 @@ local
| "declare-fun" =>
let
val (logic, tydict, tmdict, asserted) = dest_state "declare-fun"
val tmdict = parse_declare_fun get_token (tydict, tmdict)
val (_, tmdict) = parse_declare_fun get_token (tydict, tmdict)
in
parse_commands get_token (SOME (logic, tydict, tmdict, asserted))
end
Expand Down
46 changes: 26 additions & 20 deletions src/HolSmt/Unittest.sml
Expand Up @@ -37,14 +37,14 @@ end
(*****************************************************************************)

(* Test: `Z3_ProofReplay.remove_definitions` works without any definitions *)
fun remove_defs_no_defs () = []
fun remove_defs_no_defs () = ([], [])

(* Test: `Z3_ProofReplay.remove_definitions` works with a duplicate definition *)
fun remove_defs_dup () = [
fun remove_defs_dup () = ([
``(z1:num) = x + 1``,
``(z2:num) = z1 + 2``,
``(z2:num) = (x + 1) + 2``
]
], [``(z1:num)``, ``(z2:num)``])

(* Test: `Z3_ProofReplay.remove_definitions` works with the following set of
(tricky) definitions, which can cause an exponential term blow-up in a
Expand All @@ -67,28 +67,31 @@ let
``(x:num) = a128``,
``(x:num) = b128``
]
val varl = [``(a1:num)``, ``(b1:num)``, ``(a128:num)``, ``(b128:num)``,
``(x:num)``]

(* `gen_def` creates a definition of the form ``si = s(i-1) + s(i-1)`` *)
fun gen_def (i, s) =
let
val si = Term.mk_var (s ^ Int.toString i, numSyntax.num)
val si_1 = Term.mk_var (s ^ Int.toString (i - 1), numSyntax.num)
in
boolSyntax.mk_eq (si, numSyntax.mk_plus (si_1, si_1))
(si, boolSyntax.mk_eq (si, numSyntax.mk_plus (si_1, si_1)))
end

(* Add ``ai = a(i-1) + a(i-1)`` and the same for ``bi``, for all 1 < i <= n *)
fun add_defs (n, l) =
fun add_defs (n, asl, varl) =
if n = 1 then
l
(asl, varl)
else
let
val an_def = gen_def (n, "a")
val bn_def = gen_def (n, "b")
val (an, an_def) = gen_def (n, "a")
val (bn, bn_def) = gen_def (n, "b")
in
add_defs (n - 1, an_def :: bn_def :: l)
add_defs (n - 1, an_def :: bn_def :: asl, an :: bn :: varl)
end
in
add_defs (128, asl)
add_defs (128, asl, varl)
end

(* Test: `Z3_ProofReplay.remove_definitions` works with the following set of
Expand Down Expand Up @@ -122,6 +125,8 @@ let
``(z3:num) = z2 + (y + 1) + z1``,
``(z3:num) = (y + 1) + z2 + (x + 1)``
]
val varl = List.map (fn t => Lib.fst (boolSyntax.dest_eq t)) asl

fun add3 (a, b, c) = numSyntax.mk_plus (numSyntax.mk_plus (a, b), c)

(* `gen_def1` creates a definition of the form:
Expand All @@ -135,7 +140,7 @@ let
val middle_addend = add3 (zi_2, zi_2, z1)
val sum = add3 (zi_1, middle_addend, z1)
in
boolSyntax.mk_eq (zi, sum)
(zi, boolSyntax.mk_eq (zi, sum))
end

(* `gen_def2` creates a definition of the form:
Expand All @@ -149,22 +154,22 @@ let
val first_addend = add3 (zi_2, zi_2, xp1)
val sum = add3 (first_addend, zi_1, xp1)
in
boolSyntax.mk_eq (zi, sum)
(zi, boolSyntax.mk_eq (zi, sum))
end

(* Add the definitions `gen_def1 i` and `gen_def2 i`, for all 3 < i <= n *)
fun add_defs (n, l) =
fun add_defs (n, asl, varl) =
if n = 3 then
l
(asl, varl)
else
let
val def1 = gen_def1 n
val def2 = gen_def2 n
val (v1, def1) = gen_def1 n
val (v2, def2) = gen_def2 n
in
add_defs (n - 1, def1 :: def2 :: l)
add_defs (n - 1, def1 :: def2 :: asl, v1 :: v2 :: varl)
end
in
add_defs (128, asl)
add_defs (128, asl, varl)
end

(* Wrapper for `Z3_ProofReplay.remove_definitions` unit tests *)
Expand All @@ -177,11 +182,12 @@ let
val thm = Drule.ADD_ASSUM ``(i:num) = 7`` thm
val thm = Drule.ADD_ASSUM ``(j:num) = i + 3`` thm
(* Add definitions (which should be removed) *)
val asl = get_definitions_fn ()
val (asl, varl) = get_definitions_fn ()
val defs = List.foldl (Lib.flip HOLset.add) Term.empty_tmset asl
val vars = List.foldl (Lib.flip HOLset.add) Term.empty_tmset varl
val thm_with_defs = List.foldl (Lib.uncurry Drule.ADD_ASSUM) thm asl
(* Remove definitions *)
val final_thm = Z3_ProofReplay.remove_definitions (defs, thm_with_defs)
val final_thm = Z3_ProofReplay.remove_definitions (defs, vars, thm_with_defs)
in
(* Make sure the resulting theorem is equal to the one before definitions
were added *)
Expand Down
7 changes: 5 additions & 2 deletions src/HolSmt/Z3_Proof.sml
Expand Up @@ -60,8 +60,11 @@ struct
(* The Z3 proof is a directed acyclic graph of inference steps. A
unique integer ID is assigned to each inference step. Note that
Z3 assigns no ID to the proof's root node, which derives the
final theorem "... |- F". We will use ID 0 for the root node. *)
final theorem "... |- F". We will use ID 0 for the root node.
type proof = (int, proofterm) Redblackmap.dict
Additionally, Z3 also defines variables in proofs, which we
keep track of in a set, so we can properly replay the proof. *)

type proof = (int, proofterm) Redblackmap.dict * Term.term HOLset.set

end
12 changes: 7 additions & 5 deletions src/HolSmt/Z3_ProofParser.sml
Expand Up @@ -321,15 +321,15 @@ local
(***************************************************************************)

(* returns an extended proof; 't' must encode a proofterm *)
fun extend_proof proof (id, t) =
fun extend_proof (steps, vars) (id, t) =
let
val _ = if !Library.trace > 0 andalso
Option.isSome (Redblackmap.peek (proof, id)) then
Option.isSome (Redblackmap.peek (steps, id)) then
WARNING "extend_proof"
("proofterm ID " ^ Int.toString id ^ " defined more than once")
else ()
in
Redblackmap.insert (proof, id, proofterm_of_term t)
(Redblackmap.insert (steps, id, proofterm_of_term t), vars)
end

(* distinguishes between a term definition and a proofterm
Expand Down Expand Up @@ -378,7 +378,8 @@ local
parse_proof_inner get_token (tydict, tmdict, proof) (rpars + 1)
else if head = "declare-fun" then
let
val tmdict = SmtLib_Parser.parse_declare_fun get_token (tydict, tmdict)
val (tm, tmdict) = SmtLib_Parser.parse_declare_fun get_token (tydict, tmdict)
val proof = Lib.apsnd (fn set => HOLset.add (set, tm)) proof
in
parse_proof_inner get_token (tydict, tmdict, proof) rpars
end
Expand Down Expand Up @@ -464,8 +465,9 @@ in
Feedback.HOL_MESG "HolSmtLib: parsing Z3 proof"
else ()
val get_token = Library.get_token (Library.get_buffered_char instream)
val empty_proof = (Redblackmap.mkDict Int.compare, Term.empty_tmset)
val proof = parse_proof get_token
(tydict, tmdict, Redblackmap.mkDict Int.compare)
(tydict, tmdict, empty_proof)
val _ = if !Library.trace > 0 then
WARNING "parse_stream" ("ignoring token '" ^ get_token () ^
"' (and perhaps others) after proof")
Expand Down
68 changes: 41 additions & 27 deletions src/HolSmt/Z3_ProofReplay.sml
Expand Up @@ -62,28 +62,33 @@ local
definition_hyps : Term.term HOLset.set,
(* stores certain theorems (proved by 'rewrite' or 'th_lemma') for
later retrieval, to avoid re-proving them *)
thm_cache : Thm.thm Net.net
thm_cache : Thm.thm Net.net,
(* contains all of the variables that Z3 has defined *)
var_set : Term.term HOLset.set
}

fun state_assert (s : state) (t : Term.term) : state =
{
asserted_hyps = HOLset.add (#asserted_hyps s, t),
definition_hyps = #definition_hyps s,
thm_cache = #thm_cache s
thm_cache = #thm_cache s,
var_set = #var_set s
}

fun state_define (s : state) (terms : Term.term list) : state =
{
asserted_hyps = #asserted_hyps s,
definition_hyps = HOLset.addList (#definition_hyps s, terms),
thm_cache = #thm_cache s
thm_cache = #thm_cache s,
var_set = #var_set s
}

fun state_cache_thm (s : state) (thm : Thm.thm) : state =
{
asserted_hyps = #asserted_hyps s,
definition_hyps = #definition_hyps s,
thm_cache = Net.insert (Thm.concl thm, thm) (#thm_cache s)
thm_cache = Net.insert (Thm.concl thm, thm) (#thm_cache s),
var_set = #var_set s
}

fun state_inst_cached_thm (s : state) (t : Term.term) : Thm.thm =
Expand Down Expand Up @@ -1165,20 +1170,20 @@ local
list_prems state_proof "unit_resolution" z3_unit_resolution x
continuation []
| thm_of_proofterm ((state, proof), ID id) continuation =
(case Redblackmap.peek (proof, id) of
(case Redblackmap.peek (Lib.fst proof, id) of
SOME (THEOREM thm) =>
continuation ((state, proof), thm)
| SOME pt =>
thm_of_proofterm ((state, proof), pt) (continuation o
(* update the proof, replacing the original proofterm with
the theorem just derived *)
(fn ((state, proof), thm) =>
(fn ((state, (steps, vars)), thm) =>
(
if !Library.trace > 2 then
Feedback.HOL_MESG
("HolSmtLib: updating proof at ID " ^ Int.toString id)
else ();
((state, Redblackmap.insert (proof, id, THEOREM thm)), thm)
((state, (Redblackmap.insert (steps, id, THEOREM thm), vars)), thm)
)))
| NONE =>
raise ERR "thm_of_proofterm"
Expand All @@ -1190,11 +1195,11 @@ local
returning the resulting theorem, i.e.:
A u defs |- t
------------- remove_definitions defs
------------- remove_definitions (defs, var_set)
A |- t
Each definition in `defs` must be of the form ``var = term`` and `var` must
not be free in `t` nor in `A`.
Each definition in `defs` must be of the form ``var = term``, where `var`
must not be free in `t` nor in `A` and must be in `var_set`.
There is a major complication: some definitions reference variables in
other definitions and they may even be duplicated (with and without
Expand All @@ -1211,7 +1216,7 @@ local
definitions), which might occur in a naive attempt at removing these
definitions. Therefore, a more careful implementation is warranted.
In general, the variable references can form an acyclic graph. For
In general, the variable references can form a directed acyclic graph. For
efficiency purposes (explained later), we first find a variable that is not
referenced in any definition of the other variables.
Expand Down Expand Up @@ -1259,30 +1264,38 @@ local
at no point did we need to fully expand a definition (unless it's already
expanded). *)

fun remove_definitions (defs: Term.term HOLset.set, thm: Thm.thm): Thm.thm =
fun remove_definitions (defs, var_set, thm): Thm.thm =
if HOLset.isEmpty defs then
thm
else
let
(* For convenience, `pvar_defs` will contain a list of (var, def)
pairs *)
val pvar_defs = List.map boolSyntax.dest_eq (HOLset.listItems defs)
(* `var_set` will contain the set of all variables being defined *)
fun add_var ((var, def), set) = HOLset.add (set, var)
val var_set = List.foldl add_var Term.empty_varset pvar_defs
(* For convenience, `dest_defs` will contain a list of `(lhs, rhs)`
pairs, where `lhs` is the var being defined and `rhs` its
definition. *)
val dest_defs = List.map boolSyntax.dest_eq (HOLset.listItems defs)
val (lhs_l, rhs_l) = ListPair.unzip dest_defs
(* `ref_set` will contain the set of all variables being referenced *)
val all_defs = List.map Lib.snd pvar_defs
val ref_set = Term.FVL all_defs Term.empty_varset
(* `unref_set` will contain the set of all variables not being
referenced *)
val unref_set = HOLset.difference (var_set, ref_set)
val ref_set = Term.FVL rhs_l Term.empty_tmset
(* `def_set` will contain the set of all variables being defined.
It should always be a subset of `var_set`. *)
val def_set = List.foldl (Lib.flip HOLset.add) Term.empty_tmset lhs_l

(* `unref_set` will contain the set of all the variables being defined
but not being referenced *)
val unref_set = HOLset.difference (def_set, ref_set)

val () =
if HOLset.isEmpty unref_set then
raise ERR "remove_definitions" "no unreferenced variables"
else
()

(* Pick an arbitrary variable from `unref_set` *)
val var = Option.valOf (HOLset.find (fn _ => true) unref_set)

(* Get all the variable's definitions *)
fun filter_def (v, d) = if Term.term_eq v var then SOME d else NONE
val defs_to_remove = List.mapPartial filter_def pvar_defs
val defs_to_remove = List.mapPartial filter_def dest_defs

(* Pick an arbitrary definition for instantiation *)
val inst = List.hd defs_to_remove
Expand All @@ -1309,7 +1322,7 @@ local
val new_defs = HOLset.foldl add_def Term.empty_tmset (Thm.hypset thm)
in
(* Recurse to remove the remaining variables' definitions *)
remove_definitions (new_defs, thm)
remove_definitions (new_defs, var_set, thm)
end
in

Expand All @@ -1328,7 +1341,8 @@ in
val state = {
asserted_hyps = Term.empty_tmset,
definition_hyps = Term.empty_tmset,
thm_cache = Net.empty
thm_cache = Net.empty,
var_set = Lib.snd proof
}

(* ID 0 denotes the proof's root node *)
Expand All @@ -1339,7 +1353,7 @@ in

(* remove the definitions introduced by Z3 from the set of hypotheses *)
val final_thm = profile "check_proof(remove_definitions)" remove_definitions
(#definition_hyps state, thm)
(#definition_hyps state, #var_set state, thm)

(* check that the final theorem contains no hyps other than those
that have been asserted *)
Expand Down

0 comments on commit 4b76320

Please sign in to comment.