From 120e7bb43bc3ffcca4efb2cafbeed0605d620bff Mon Sep 17 00:00:00 2001
From: someplaceguy
Date: Wed, 21 Feb 2024 14:01:25 +0000
Subject: [PATCH] HolSmt: orient definitions in a canonical form

In Z3 proof replay, sometimes we may end up with definitions of the
form:

    var1 = var2

This kind of definition may theoretically arise in three different
places:

1. As an explicit Z3-provided definition in `intro-def` proof rules
2. As an implicit Z3-provided definition which we instantiate as part
   of `rewrite` proof rules (after term unification)
3. When unifying terms to add new definitions in the process of
   removing the definitions from the final theorem.

This can be problematic if `var1` is not a Z3-defined variable (which
means `var2` is, and thus the definition should be reversed) or if both
`var1` and `var2` are Z3-defined variables.

If both vars are Z3-defined variables, we may end up with the following
two hypotheses in the final theorem, which `remove_definitions` can't
handle due to the circularity:

    var1 = var2
    var2 = var1

To avoid ending up with such problematic definitions, this commit
introduces code to change `var1 = var2` into `var2 = var1`, so that the
left-hand side is always a Z3-defined variable and, if the right-hand
side is also a Z3-defined variable, the former is not "greater" than the
latter (for some definition of "greater").
--- src/HolSmt/Library.sml | 46 +++++++++++++++++++++++++++++++---- src/HolSmt/Unittest.sml | 5 +++- src/HolSmt/Z3_ProofReplay.sml | 39 +++++++++++++++++++++++------ 3 files changed, 77 insertions(+), 13 deletions(-) diff --git a/src/HolSmt/Library.sml b/src/HolSmt/Library.sml index d388bef73a..06210af02d 100644 --- a/src/HolSmt/Library.sml +++ b/src/HolSmt/Library.sml @@ -176,6 +176,32 @@ struct Redblackmap.dict = List.foldl extend_dict (Redblackmap.mkDict String.compare) xs + (***************************************************************************) + (* auxiliary functions *) + (***************************************************************************) + + (* `is_def_oriented` must return false when: + 1. `lhs` is not a variable in `var_set` but `rhs` is, or + 2. `lhs` and `rhs` are both variables in `var_set` but `rhs` is smaller + than `lhs` (for some definition of "smaller"). + Otherwise, it must return true. *) + fun is_def_oriented var_set (lhs, rhs) = + (not (HOLset.member (var_set, rhs))) orelse + (HOLset.member (var_set, lhs) andalso + Term.var_compare (rhs, lhs) <> LESS) + + (* Orient a definition of the form `lhs = rhs` into `rhs = lhs`, if necessary. + This is used to ensure that the `lhs` is a variable in `var_set`. Also, if + both the `lhs` and the `rhs` are variables in `var_set`, then the `rhs` + must not be "smaller" than the `lhs`. This is done to avoid ending up with + circular definitions after they are all aggregated into the final theorem, + i.e. we want to avoid ending up with both `var1 = var2` and `var2 = var1`. *) + fun orient_def var_set (lhs, rhs) = + if is_def_oriented var_set (lhs, rhs) then + (lhs, rhs) + else + (rhs, lhs) + (***************************************************************************) (* Derived rules *) (***************************************************************************) @@ -268,14 +294,24 @@ struct returned. The instantiations become hypotheses of the returned theorem. 
e.g.: - gen_instantiation ``x+1+z2`` ``z1+2`` returns the theorem: + gen_instantiation (``x+1+z2``, ``z1+2``, {``z1``, ``z2``, ``z3``}) + returns the theorem: + + { z1 = x+1, z2 = 2 } |- x+1+z2 = z1+2 - { z1 = x+1, z2 = 2 } |- x+1+z2 = z1+2 *) - fun gen_instantiation (lhs, rhs) = + In cases where we end up with an hypothesis of the form `var1 = var2`, + we might orient the hypothesis to become `var2 = var1` to make sure that + the left-hand side is a variable in `var_set`. If both `var1` and `var2` + are in `var_set`, then we make sure that the left-hand side contains the + "smaller" variable. This avoids creating circular definitions across + multiple calls of this function (i.e. one call instantiating `z1 = z2` + and another instantiating `z2 = z1`). *) + fun gen_instantiation (lhs, rhs, var_set) = let val substs = Unify.simp_unify_terms [] lhs rhs - val asl = List.map (fn {redex, residue} => boolSyntax.mk_eq(redex, residue)) - substs + fun orient {redex, residue} = orient_def var_set (redex, residue) + val oriented_substs = List.map orient substs + val asl = List.map boolSyntax.mk_eq oriented_substs val thms = List.map Thm.ASSUME asl val concl = boolSyntax.mk_eq (lhs, rhs) in diff --git a/src/HolSmt/Unittest.sml b/src/HolSmt/Unittest.sml index 3da2a63faf..3c46e5c636 100644 --- a/src/HolSmt/Unittest.sml +++ b/src/HolSmt/Unittest.sml @@ -183,8 +183,11 @@ let val thm = Drule.ADD_ASSUM ``(j:num) = i + 3`` thm (* Add definitions (which should be removed) *) val (asl, varl) = get_definitions_fn () - val defs = List.foldl (Lib.flip HOLset.add) Term.empty_tmset asl val vars = List.foldl (Lib.flip HOLset.add) Term.empty_tmset varl + (* Let's orient definitions in the same way we do during proof replay *) + val orient = boolSyntax.mk_eq o (Library.orient_def vars) o boolSyntax.dest_eq + val asl = List.map orient asl + val defs = List.foldl (Lib.flip HOLset.add) Term.empty_tmset asl val thm_with_defs = List.foldl (Lib.uncurry Drule.ADD_ASSUM) thm asl (* Remove definitions 
*) val final_thm = Z3_ProofReplay.remove_definitions (defs, vars, thm_with_defs) diff --git a/src/HolSmt/Z3_ProofReplay.sml b/src/HolSmt/Z3_ProofReplay.sml index 42d0a015a9..48f0a9c5e9 100644 --- a/src/HolSmt/Z3_ProofReplay.sml +++ b/src/HolSmt/Z3_ProofReplay.sml @@ -636,15 +636,39 @@ local used in these definitions are local names introduced by Z3 for the purposes of completing the proof and should not otherwise be relevant in either the remaining hypotheses or the conclusion of the final theorem, - we can remove all such definitions at the end of the proof. *) + we can remove all such definitions at the end of the proof. + + We must take an additional precaution: if `term` is a Z3-defined variable + and it is "smaller" than `name`, then we must actually return the theorem: + + term = name |- t + + This is done to avoid ending up with circular definitions in the final + theorem. *) fun z3_intro_def (state, t) = let val thm = List.hd (Net.match t Z3_ProformaThms.intro_def_thms) - val inst_thm = Drule.INST_TY_TERM (Term.match_term (Thm.concl thm) t) thm - val asm = List.hd (Thm.hyp inst_thm) + val substs = Term.match_term (Thm.concl thm) t + val term_substs = Lib.fst substs + (* Check if the hypothesis should be changed from `name = term` to + `term = name`. Note that `name` and `term` are actually called `n` and + `t` in `intro_def_thms`, except for the 4th schematic form which doesn't + have `t` (nor does it need to be oriented). 
*) + fun is_varname s tm = Lib.fst (Term.dest_var tm) = s + val name = Option.valOf (Lib.subst_assoc (is_varname "n") term_substs) + val term_opt = Lib.subst_assoc (is_varname "t") term_substs + val is_oriented = + case term_opt of + NONE => true (* `term_opt` will be NONE in the 4th schematic form *) + | SOME term => Library.is_def_oriented (#var_set state) (name, term) + (* Orient the hypothesis if necessary *) + val thm = if is_oriented then thm else + Conv.HYP_CONV_RULE (fn _ => true) Conv.SYM_CONV thm + val inst_thm = Drule.INST_TY_TERM substs thm + val asl = Thm.hyp inst_thm in - (state_define state [asm], inst_thm) + (state_define state asl, inst_thm) end (* [l1, ..., ln] |- F @@ -885,8 +909,9 @@ local removed from the set of hypotheses of the final theorem. *) let + val (lhs, rhs) = boolSyntax.dest_eq t val thm = profile "rewrite(12)(unification)" Library.gen_instantiation - (boolSyntax.dest_eq t) + (lhs, rhs, #var_set state) val asl = Thm.hyp thm in (state_define (state_cache_thm state thm) asl, thm) @@ -1305,8 +1330,8 @@ local (* For each definition corresponding to this variable, create a theorem that can eliminate the definition from the set of hypotheses of `thm` *) - val hyp_thms = List.map (fn def => Library.gen_instantiation (inst, def)) - defs_to_remove + val hyp_thms = List.map (fn def => Library.gen_instantiation (inst, def, + var_set)) defs_to_remove (* Remove all the definitions corresponding to this variable *) fun remove_hyp (hyp_thm, thm) = Drule.PROVE_HYP hyp_thm thm