Skip to content

Commit

Permalink
Merge pull request #961 from CakeML/flat_pattern-improvements
Browse files Browse the repository at this point in the history
Improvements to the flatLang pattern compiler
  • Loading branch information
myreen committed Jul 15, 2023
2 parents c530285 + 45a4953 commit 950fef4
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 68 deletions.
61 changes: 49 additions & 12 deletions compiler/backend/flatLangScript.sml
Expand Up @@ -24,7 +24,7 @@ val _ = set_grammar_ancestry ["ast", "backend_common"];

(* Copied from the semantics, but with AallocEmpty missing. GlobalVar ops have
* been added, also TagLenEq and El for pattern match compilation. *)
val _ = Datatype `
Datatype:
op =
(* Operations on integers *)
Opn opn
Expand Down Expand Up @@ -101,21 +101,23 @@ val _ = Datatype `
| LenEq num
| El num
(* No-op step for a single value *)
| Id`;
| Id
End

Type ctor_id = ``:num``
(* NONE represents the exception type *)
Type type_id = ``:num option``
Type type_group_id = ``:(num # (ctor_id # num) list) option``

val _ = Datatype `
Datatype:
pat =
| Pany
| Pvar varN
| Plit lit
| Pcon ((ctor_id # type_group_id) option) (pat list)
| Pas pat varN
| Pref pat`;
| Pref pat
End

Definition pat_bindings_def:
(pat_bindings Pany already_bound = already_bound) ∧
Expand All @@ -128,7 +130,7 @@ Definition pat_bindings_def:
(pats_bindings (p::ps) already_bound = pats_bindings ps (pat_bindings p already_bound))
End

val _ = Datatype`
Datatype:
exp =
Raise tra exp
| Handle tra exp ((pat # exp) list)
Expand All @@ -140,7 +142,8 @@ val _ = Datatype`
| If tra exp exp exp
| Mat tra exp ((pat # exp) list)
| Let tra (varN option) exp exp
| Letrec varN ((varN # varN # exp) list) exp`;
| Letrec varN ((varN # varN # exp) list) exp
End

val exp_size_def = definition"exp_size_def";

Expand Down Expand Up @@ -207,20 +210,54 @@ Proof
\\ decide_tac
QED

val _ = Datatype`
Datatype:
dec =
Dlet exp
(* The first number is the identity for the type. The sptree maps arities to
* how many constructors have that arity *)
| Dtype num (num spt)
(* The first number is the identity of the exception. The second number is the
* constructor's arity *)
| Dexn num num`;
| Dexn num num
End

Definition bool_id_def:
bool_id = 0n
End

val bool_id_def = Define `
bool_id = 0n`;
Definition Bool_def:
Bool t b = Con t (SOME (backend_common$bool_to_tag b, SOME bool_id)) []
End

Definition SmartIf_def:
SmartIf t e p q =
case e of
Con _ (SOME (tag, SOME id)) [] =>
if id = bool_id then
if tag = backend_common$true_tag then p
else if tag = backend_common$false_tag then q
else If t e p q
else If t e p q
| _ => If t e p q
End

val Bool_def = Define`
Bool t b = Con t (SOME (backend_common$bool_to_tag b, SOME bool_id)) []`;
val _ = patternMatchesLib.ENABLE_PMATCH_CASES();

Theorem SmartIf_PMATCH:
!t e p q.
SmartIf t e p q =
case e of
Con _ (SOME (tag, SOME id)) [] =>
if id = bool_id then
if tag = backend_common$true_tag then p
else if tag = backend_common$false_tag then q
else If t e p q
else If t e p q
| _ => If t e p q
Proof
rpt strip_tac
\\ CONV_TAC (RAND_CONV patternMatchesLib.PMATCH_ELIM_CONV)
\\ rw [SmartIf_def]
QED

val _ = export_theory ();
46 changes: 29 additions & 17 deletions compiler/backend/flat_patternScript.sml
Expand Up @@ -107,24 +107,36 @@ Definition decode_test_def:
End

Definition simp_guard_def:
simp_guard (Conj x y) = (if x = True then simp_guard y
else if y = True then simp_guard x
else if x = Not True \/ y = Not True then Not True
else Conj (simp_guard x) (simp_guard y)) /\
simp_guard (Disj x y) = (if x = True \/ y = True then True
else if x = Not True then simp_guard y
else if y = Not True then simp_guard x
else Disj (simp_guard x) (simp_guard y)) /\
simp_guard (Not (Not x)) = simp_guard x /\
simp_guard (Not x) = Not (simp_guard x) /\
simp_guard (Conj x y) =
(let v = simp_guard x in
let w = simp_guard y in
if v = Not True \/ w = Not True then
Not True
else if v = True then w
else if w = True then v
else Conj v w) /\
simp_guard (Disj x y) =
(let v = simp_guard x in
let w = simp_guard y in
if v = True \/ w = True then
True
else if v = Not True then w
else if w = Not True then v
else Disj v w) /\
simp_guard (Not x) =
(let v = simp_guard x in
case v of
Not True => True
| Not w => w
| _ => Not v) /\
simp_guard x = x
End

Definition decode_guard_def:
decode_guard t v (Not gd) = App t Equality [decode_guard t v gd; Bool t F] /\
decode_guard t v (Conj gd1 gd2) = If t (decode_guard t v gd1)
decode_guard t v (Conj gd1 gd2) = SmartIf t (decode_guard t v gd1)
(decode_guard t v gd2) (Bool t F) /\
decode_guard t v (Disj gd1 gd2) = If t (decode_guard t v gd1) (Bool t T)
decode_guard t v (Disj gd1 gd2) = SmartIf t (decode_guard t v gd1) (Bool t T)
(decode_guard t v gd2) /\
decode_guard t v True = Bool t T /\
decode_guard t v (PosTest pos test) = decode_test t test (decode_pos t v pos)
Expand All @@ -141,7 +153,7 @@ Definition decode_dtree_def:
let dec2 = decode_dtree t br_spt v df dt2 in
if guard = True then dec1
else if guard = Not True then dec2
else If t (decode_guard t v guard) dec1 dec2
else SmartIf t (decode_guard t v guard) dec1 dec2
End

Definition encode_pat_def:
Expand All @@ -165,14 +177,14 @@ Definition naive_pattern_match_def:
naive_pattern_match t ((flatLang$Pany, _) :: mats) = naive_pattern_match t mats
/\
naive_pattern_match t ((Pvar _, _) :: mats) = naive_pattern_match t mats /\
naive_pattern_match t ((Plit l, v) :: mats) = If t
naive_pattern_match t ((Plit l, v) :: mats) = SmartIf t
(App t Equality [v; Lit t l]) (naive_pattern_match t mats) (Bool t F) /\
naive_pattern_match t ((Pcon NONE ps, v) :: mats) =
naive_pattern_match t (MAPi (\i p. (p, App t (El i) [v])) ps ++ mats) /\
naive_pattern_match t ((Pas p i, v) :: mats) =
naive_pattern_match t ((p, v) :: mats) /\
naive_pattern_match t ((Pcon (SOME stmp) ps, v) :: mats) =
If t (App t (TagLenEq (FST stmp) (LENGTH ps)) [v])
SmartIf t (App t (TagLenEq (FST stmp) (LENGTH ps)) [v])
(naive_pattern_match t (MAPi (\i p. (p, App t (El i) [v])) ps ++ mats))
(Bool t F)
/\
Expand All @@ -188,7 +200,7 @@ End
Definition naive_pattern_matches_def:
naive_pattern_matches t v [] dflt_x = dflt_x /\
naive_pattern_matches t v ((p, x) :: ps) dflt_x =
If t (naive_pattern_match t [(p, v)]) x (naive_pattern_matches t v ps dflt_x)
SmartIf t (naive_pattern_match t [(p, v)]) x (naive_pattern_matches t v ps dflt_x)
End

Definition compile_pats_def:
Expand Down Expand Up @@ -266,7 +278,7 @@ Definition compile_exp_def:
let (i, sg1, y1) = compile_exp cfg x1 in
let (j, sg2, y2) = compile_exp cfg x2 in
let (k, sg3, y3) = compile_exp cfg x3 in
(MAX i (MAX j k), sg1 \/ sg2 \/ sg3, If t y1 y2 y3)) /\
(MAX i (MAX j k), sg1 \/ sg2 \/ sg3, SmartIf t y1 y2 y3)) /\
(compile_exp cfg exp = (0, F, exp)) /\
(compile_exps cfg [] = (0, F, [])) /\
(compile_exps cfg (x::xs) =
Expand Down

0 comments on commit 950fef4

Please sign in to comment.