11open Base
22
33module Lazy = Utils. Lazy
4- (* * The code for operating on n-dimensional arrays. *)
54
65module Nd = Ndarray
76module Tn = Tnode
2625type scope_id = Scope_id .t = { tn : Tn .t ; scope_id : int }
2726[@@ deriving sexp_of , equal , hash , compare ]
2827
29- (* * *** Low-level representation. *)
30-
3128let get_scope =
3229 let uid = ref 0 in
3330 fun tn ->
3431 Int. incr uid;
3532 { tn; scope_id = ! uid }
3633
37- (* * Cases: [t] -- code, [float_t] -- single number at some precision. *)
3834type t =
3935 | Noop
4036 | Comment of string
@@ -71,12 +67,6 @@ let rec unflat_lines = function
7167 | Noop :: tl -> unflat_lines tl
7268 | llc :: tl -> Seq (llc, unflat_lines tl)
7369
74- let comment_to_name =
75- let nonliteral = Str. regexp {| [^ a- zA- Z0 -9_ ]| } in
76- Str. global_replace nonliteral " _"
77-
78- (* * *** Optimization *** *)
79-
8070type virtualize_settings = {
8171 mutable enable_device_only : bool ;
8272 mutable max_visits : int ;
@@ -102,31 +92,18 @@ let virtualize_settings =
10292type visits =
10393 | Visits of int
10494 | Recurrent
105- (* * A [Recurrent] visit is when there is an access prior to any assignment in an update. *)
10695[@@ deriving sexp , equal , variants ]
10796
10897type traced_array = {
10998 tn : Tn .t ;
11099 mutable computations : (Indexing .axis_index array option * t ) list ;
111- (* * The computations (of the tensor node) are retrieved for optimization just as they are
112- populated, so that the inlined code corresponds precisely to the changes to the arrays
113- that would happen up till that point. Within the code blocks paired with an index tuple,
114- all assignments and accesses must happen via the index tuple; if this is not the case for
115- some assignment, the node cannot be virtual. Currently, we only allow for-loop symbols in
116- assignment indices of virtual nodes. *)
117100 assignments : int array Hash_set .t ;
118101 accesses : (int array , visits ) Hashtbl .t ;
119- (* * For dynamic indexes, we take a value of 0. This leads to an overestimate of visits, which
120- is safe. *)
121102 mutable zero_initialized : bool ;
122103 mutable zeroed_out : bool ;
123104 mutable read_before_write : bool ;
124- (* * The node is read before it is written (i.e. it is recurrent). *)
125105 mutable read_only : bool ;
126106 mutable is_scalar_constexpr : bool ;
127- (* * True only if the tensor node has all axes of dimension 1, is either zeroed-out or assigned
128- before accessed, is assigned at most once, and from an expression involving only constants
129- or tensor nodes that were at the time is_scalar_constexpr. *)
130107}
131108[@@ deriving sexp_of ]
132109
@@ -144,22 +121,6 @@ let get_node store tn =
144121 is_scalar_constexpr = false ;
145122 })
146123
147- let partition_tf_with_comment cs ~f =
148- let both = Array. map cs ~f: (fun c -> if f c then Either. First c else Either. Second c) in
149- let trues =
150- Array. filter_map both ~f: (function
151- | First x -> Some x
152- | Second (Comment _ as x ) -> Some x
153- | Second _ -> None )
154- in
155- let falses =
156- Array. filter_map both ~f: (function
157- | First (Comment _ as x ) -> Some x
158- | First _ -> None
159- | Second x -> Some x)
160- in
161- (trues, falses)
162-
163124let visit ~is_assigned old =
164125 if not is_assigned then Recurrent
165126 else
@@ -801,21 +762,8 @@ let%diagn2_sexp optimize_proc static_indices llc =
801762 { traced_store; llc; merge_node }
802763
803764let code_hum_margin = ref 100
804- let pp_comma ppf () = Stdlib.Format. fprintf ppf " ,@ "
805- let pp_symbol ppf sym = Stdlib.Format. fprintf ppf " %s" @@ Indexing. symbol_ident sym
806-
807- let pp_static_symbol ppf { Indexing. static_symbol; static_range } =
808- match static_range with
809- | None -> pp_symbol ppf static_symbol
810- | Some range -> Stdlib.Format. fprintf ppf " %a : [0..%d]" pp_symbol static_symbol (range - 1 )
811-
812- let pp_index ppf idx =
813- match idx with
814- | Indexing. Iterator sym -> pp_symbol ppf sym
815- | Fixed_idx i -> Stdlib.Format. fprintf ppf " %d" i
816765
817- let pp_indices ppf idcs =
818- Stdlib.Format. pp_print_list ~pp_sep: pp_comma pp_index ppf @@ Array. to_list idcs
766+ open Indexing.Pp_helpers
819767
820768let fprint_function_header ?name ?static_indices () ppf =
821769 let open Stdlib.Format in
@@ -924,7 +872,7 @@ let fprint_hum ?name ?static_indices () ppf llc =
924872 fprintf ppf " @[<2>%a.merge[@,%a]@]" pp_ident tn pp_indices idcs
925873 | Get (tn , idcs ) -> fprintf ppf " @[<2>%a[@,%a]@]" pp_ident tn pp_indices idcs
926874 | Constant c -> fprintf ppf " %.16g" c
927- | Embed_index idx -> pp_index ppf idx
875+ | Embed_index idx -> pp_axis_index ppf idx
928876 | Binop (Arg1, v1 , _v2 ) -> pp_float prec ppf v1
929877 | Binop (Arg2, _v1 , v2 ) -> pp_float prec ppf v2
930878 | Binop (op , v1 , v2 ) ->
0 commit comments