Skip to content

Commit 93d161d

Browse files
committed
Improving Total_elems shape inference: safe wrt. forcing; cover missing cases; address remaining known FIXMEs
1 parent e2ea2d5 commit 93d161d

File tree

4 files changed

+326
-193
lines changed

4 files changed

+326
-193
lines changed

arrayjit/lib/utils.ml

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -315,12 +315,8 @@ let get_local_debug_runtime =
315315
@@ "ocannl_debug_backend setting should be text, html, markdown or flushing; found: " ^ s
316316
in
317317
let hyperlink = get_global_arg ~default:"./" ~arg_name:"hyperlink_prefix" in
318-
let print_entry_ids =
319-
get_global_flag ~default:false ~arg_name:"logs_print_entry_ids"
320-
in
321-
let verbose_entry_ids =
322-
get_global_flag ~default:false ~arg_name:"logs_verbose_entry_ids"
323-
in
318+
let print_entry_ids = get_global_flag ~default:false ~arg_name:"logs_print_entry_ids" in
319+
let verbose_entry_ids = get_global_flag ~default:false ~arg_name:"logs_verbose_entry_ids" in
324320
let log_main_domain_to_stdout =
325321
get_global_flag ~default:false ~arg_name:"log_main_domain_to_stdout"
326322
in
@@ -444,9 +440,7 @@ let restore_settings () =
444440
let () = restore_settings ()
445441
let with_runtime_debug () = settings.output_debug_files_in_build_directory && settings.log_level > 1
446442
let debug_log_from_routines () = settings.debug_log_from_routines && settings.log_level > 1
447-
448-
let never_capture_stdout () =
449-
get_global_flag ~default:false ~arg_name:"never_capture_stdout"
443+
let never_capture_stdout () = get_global_flag ~default:false ~arg_name:"never_capture_stdout"
450444

451445
let enable_runtime_debug () =
452446
settings.output_debug_files_in_build_directory <- true;
@@ -942,3 +936,55 @@ let weak_iter (arr : 'a weak_dynarray) ~f =
942936
for i = 0 to W.length !arr - 1 do
943937
Option.iter (W.get !arr i) ~f
944938
done
939+
940+
type 'a safe_lazy = {
941+
mutable value : [ `Callback of unit -> 'a | `Value of 'a ];
942+
unique_id : string;
943+
}
944+
945+
let safe_lazy unique_id f = { value = `Callback f; unique_id }
946+
947+
let safe_force gated =
948+
match gated.value with
949+
| `Value v -> v
950+
| `Callback f ->
951+
let v = f () in
952+
gated.value <- `Value v;
953+
v
954+
955+
let safe_map ~upd ~f gated =
956+
let unique_id = gated.unique_id ^ "_" ^ upd in
957+
match gated.value with
958+
| `Value v -> { value = `Value (f v); unique_id }
959+
| `Callback callback -> { value = `Callback (fun () -> f (callback ())); unique_id }
960+
961+
let equal_safe_lazy equal_elem g1 g2 =
962+
match (g1.value, g2.value) with
963+
| `Value v1, `Value v2 ->
964+
(* Both values are forced - assert uniqueness *)
965+
let id_equal = String.equal g1.unique_id g2.unique_id in
966+
if id_equal then assert (equal_elem v1 v2);
967+
id_equal
968+
| _ -> String.equal g1.unique_id g2.unique_id
969+
970+
let compare_safe_lazy compare_elem g1 g2 =
971+
match (g1.value, g2.value) with
972+
| `Value v1, `Value v2 ->
973+
(* Both values are forced - assert uniqueness *)
974+
let id_cmp = String.compare g1.unique_id g2.unique_id in
975+
if id_cmp = 0 then assert (compare_elem v1 v2 = 0);
976+
id_cmp
977+
| _ -> String.compare g1.unique_id g2.unique_id
978+
979+
let hash_fold_safe_lazy _hash_elem state gated = hash_fold_string state gated.unique_id
980+
981+
let sexp_of_safe_lazy sexp_of_elem gated =
982+
let status =
983+
match gated.value with `Callback _ -> Sexp.Atom "pending" | `Value v -> sexp_of_elem v
984+
in
985+
Sexp.List
986+
[
987+
Sexp.Atom "safe_lazy";
988+
Sexp.List [ Sexp.Atom "id"; Sexp.Atom gated.unique_id ];
989+
Sexp.List [ Sexp.Atom "value"; status ];
990+
]

0 commit comments

Comments
 (0)