Skip to content

Commit 1f2a22b

Browse files
committed
Factor out Indexing.Pp_helpers, more interface files
1 parent 067169f commit 1f2a22b

File tree

7 files changed

+175
-77
lines changed

7 files changed

+175
-77
lines changed

CHANGES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
### Added
44

5+
- Interface files for `Backends` and `Low_level`.
56
- TODO: stream-to-stream synchronization functionality, with lazy per-tensor-node synchronization.
67
- Fixed #245: tracking of used memory.
78

@@ -16,7 +17,7 @@
1617
- Moved the multicore backend from a `device = stream` model to a single device model.
1718
- Got rid of `unsafe_cleanup`.
1819
- Got rid of `subordinal`.
19-
- Removed dependency on `dore`, broke up dependency on `ppx_jane`.
20+
- Removed dependency on `core`, broke up dependency on `ppx_jane`.
2021
- TODO: Built per-tensor-node stream-to-stream synchronization into device-to-device copying functions, removed obsolete blocking synchronizations.
2122

2223
### Fixed

arrayjit/lib/backends.ml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -805,15 +805,12 @@ end
805805
module Cuda_backend : Backend_types.Backend = Lowered_backend ((
806806
Cuda_backend : Backend_types.Lowered_backend))
807807

808-
(** Initializes the backend, and if it was already initialized, performs garbage collection. *)
809808
let reinitialize (module Backend : Backend_types.Backend) config =
810809
if not @@ Backend.is_initialized () then Backend.initialize config
811810
else (
812811
Stdlib.Gc.full_major ();
813812
Backend.initialize config)
814813

815-
(** Reinitializes and returns a backend corresponding to [backend_name], or if omitted, selected via
816-
the global [backend] setting. See {!reinitialize}. *)
817814
let fresh_backend ?backend_name ?(config = Only_devices_parallel) () =
818815
let backend =
819816
match

arrayjit/lib/backends.mli

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
(** {1 The collection of the execution backends} *)
2+
3+
open Base
4+
5+
val sync_suggested_num_streams : int ref
6+
7+
module Cc_backend : Backend_types.Backend
8+
module Sync_cc_backend : Backend_types.Backend
9+
module Gccjit_backend : Backend_types.Backend
10+
module Sync_gccjit_backend : Backend_types.Backend
11+
module Cuda_backend : Backend_types.Backend
12+
13+
val reinitialize : (module Backend_types.Backend) -> Backend_types.Types.config -> unit
14+
(** Initializes the backend, and if it was already initialized, performs garbage collection. *)
15+
16+
val fresh_backend :
17+
?backend_name:string ->
18+
?config:Backend_types.Types.config ->
19+
unit ->
20+
(module Backend_types.Backend)
21+
(** Reinitializes and returns a backend corresponding to [backend_name], or if omitted, selected via
22+
the global [backend] setting. See {!reinitialize}. *)

arrayjit/lib/c_syntax.ml

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,7 @@ struct
3838
let pp_zero_out ppf tn =
3939
Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " (get_ident tn) @@ Tn.size_in_bytes tn
4040

41-
(* let pp_semi ppf () = Stdlib.Format.fprintf ppf ";@ " *)
42-
let pp_comma ppf () = Stdlib.Format.fprintf ppf ",@ "
43-
44-
(* let pp_symbol ppf sym = Stdlib.Format.fprintf ppf "%s" @@ Indexing.symbol_ident sym *)
45-
let pp_index ppf sym = Stdlib.Format.fprintf ppf "%s" @@ Indexing.symbol_ident sym
46-
47-
let pp_index_axis ppf = function
48-
| Indexing.Iterator it -> pp_index ppf it
49-
| Fixed_idx i when i < 0 -> Stdlib.Format.fprintf ppf "(%d)" i
50-
| Fixed_idx i -> Stdlib.Format.fprintf ppf "%d" i
41+
open Indexing.Pp_helpers
5142

5243
let pp_array_offset ppf (idcs, dims) =
5344
let open Stdlib.Format in
@@ -57,9 +48,9 @@ struct
5748
done;
5849
for i = 0 to Array.length idcs - 1 do
5950
let dim = dims.(i) in
60-
if i = 0 then fprintf ppf "%a" pp_index_axis idcs.(i)
61-
else if i = Array.length idcs - 1 then fprintf ppf " * %d + %a" dim pp_index_axis idcs.(i)
62-
else fprintf ppf " * %d +@ %a@;<0 -1>)@]" dim pp_index_axis idcs.(i)
51+
if i = 0 then fprintf ppf "%a" pp_axis_index idcs.(i)
52+
else if i = Array.length idcs - 1 then fprintf ppf " * %d + %a" dim pp_axis_index idcs.(i)
53+
else fprintf ppf " * %d +@ %a@;<0 -1>)@]" dim pp_axis_index idcs.(i)
6354
done
6455

6556
let array_offset_to_string (idcs, dims) =
@@ -113,15 +104,15 @@ struct
113104
fprintf ppf "@[<v 0>%a@]" (pp_print_list pp_ll)
114105
(List.filter [ c1; c2 ] ~f:(function Noop -> false | _ -> true))
115106
| For_loop { index = i; from_; to_; body; trace_it = _ } ->
116-
fprintf ppf "@[<2>for (int@ %a = %d;@ %a <= %d;@ ++%a) {@ " pp_index i from_ pp_index i
117-
to_ pp_index i;
107+
fprintf ppf "@[<2>for (int@ %a = %d;@ %a <= %d;@ ++%a) {@ " pp_symbol i from_ pp_symbol i
108+
to_ pp_symbol i;
118109
if Utils.debug_log_from_routines () then
119110
if B.logs_to_stdout then
120111
fprintf ppf {|printf(@[<h>"%s%%d: index %a = %%d\n",@] log_id, %a);@ |}
121-
!Utils.captured_log_prefix pp_index i pp_index i
112+
!Utils.captured_log_prefix pp_symbol i pp_symbol i
122113
else
123-
fprintf ppf {|fprintf(log_file,@ @[<h>"index %a = %%d\n",@] %a);@ |} pp_index i
124-
pp_index i;
114+
fprintf ppf {|fprintf(log_file,@ @[<h>"index %a = %%d\n",@] %a);@ |} pp_symbol i
115+
pp_symbol i;
125116
fprintf ppf "%a@;<1 -2>}@]@," pp_ll body
126117
| Zero_out tn ->
127118
let traced = Low_level.(get_node traced_store tn) in
@@ -235,7 +226,7 @@ struct
235226
fprintf ppf "%s%.16g%s" prefix c postfix
236227
| Embed_index idx ->
237228
let prefix, postfix = B.convert_precision ~from:Ops.double ~to_:prec in
238-
fprintf ppf "%s%a%s" prefix pp_index_axis idx postfix
229+
fprintf ppf "%s%a%s" prefix pp_axis_index idx postfix
239230
| Binop (Arg1, v1, _v2) -> loop ppf v1
240231
| Binop (Arg2, _v1, v2) -> loop ppf v2
241232
| Binop (op, v1, v2) ->

arrayjit/lib/indexing.ml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,21 @@ let derive_index ~product_syms ~(projection : axis_index array) =
175175
| it -> Second it)
176176
in
177177
fun ~product -> Array.map positions ~f:(function First p -> product.(p) | Second it -> it)
178+
179+
module Pp_helpers = struct
180+
let pp_comma ppf () = Stdlib.Format.fprintf ppf ",@ "
181+
let pp_symbol ppf sym = Stdlib.Format.fprintf ppf "%s" @@ symbol_ident sym
182+
183+
let pp_static_symbol ppf { static_symbol; static_range } =
184+
match static_range with
185+
| None -> pp_symbol ppf static_symbol
186+
| Some range -> Stdlib.Format.fprintf ppf "%a : [0..%d]" pp_symbol static_symbol (range - 1)
187+
188+
let pp_axis_index ppf idx =
189+
match idx with
190+
| Iterator sym -> pp_symbol ppf sym
191+
| Fixed_idx i -> Stdlib.Format.fprintf ppf "%d" i
192+
193+
let pp_indices ppf idcs =
194+
Stdlib.Format.pp_print_list ~pp_sep:pp_comma pp_axis_index ppf @@ Array.to_list idcs
195+
end

arrayjit/lib/low_level.ml

Lines changed: 2 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
open Base
22

33
module Lazy = Utils.Lazy
4-
(** The code for operating on n-dimensional arrays. *)
54

65
module Nd = Ndarray
76
module Tn = Tnode
@@ -26,15 +25,12 @@ end
2625
type scope_id = Scope_id.t = { tn : Tn.t; scope_id : int }
2726
[@@deriving sexp_of, equal, hash, compare]
2827

29-
(** *** Low-level representation. *)
30-
3128
let get_scope =
3229
let uid = ref 0 in
3330
fun tn ->
3431
Int.incr uid;
3532
{ tn; scope_id = !uid }
3633

37-
(** Cases: [t] -- code, [float_t] -- single number at some precision. *)
3834
type t =
3935
| Noop
4036
| Comment of string
@@ -71,12 +67,6 @@ let rec unflat_lines = function
7167
| Noop :: tl -> unflat_lines tl
7268
| llc :: tl -> Seq (llc, unflat_lines tl)
7369

74-
let comment_to_name =
75-
let nonliteral = Str.regexp {|[^a-zA-Z0-9_]|} in
76-
Str.global_replace nonliteral "_"
77-
78-
(** *** Optimization *** *)
79-
8070
type virtualize_settings = {
8171
mutable enable_device_only : bool;
8272
mutable max_visits : int;
@@ -102,31 +92,18 @@ let virtualize_settings =
10292
type visits =
10393
| Visits of int
10494
| Recurrent
105-
(** A [Recurrent] visit is when there is an access prior to any assignment in an update. *)
10695
[@@deriving sexp, equal, variants]
10796

10897
type traced_array = {
10998
tn : Tn.t;
11099
mutable computations : (Indexing.axis_index array option * t) list;
111-
(** The computations (of the tensor node) are retrieved for optimization just as they are
112-
populated, so that the inlined code corresponds precisely to the changes to the arrays
113-
that would happen up till that point. Within the code blocks paired with an index tuple,
114-
all assignments and accesses must happen via the index tuple; if this is not the case for
115-
some assignment, the node cannot be virtual. Currently, we only allow for-loop symbols in
116-
assignment indices of virtual nodes. *)
117100
assignments : int array Hash_set.t;
118101
accesses : (int array, visits) Hashtbl.t;
119-
(** For dynamic indexes, we take a value of 0. This leads to an overestimate of visits, which
120-
is safe. *)
121102
mutable zero_initialized : bool;
122103
mutable zeroed_out : bool;
123104
mutable read_before_write : bool;
124-
(** The node is read before it is written (i.e. it is recurrent). *)
125105
mutable read_only : bool;
126106
mutable is_scalar_constexpr : bool;
127-
(** True only if the tensor node has all axes of dimension 1, is either zeroed-out or assigned
128-
before accessed, is assigned at most once, and from an expression involving only constants
129-
or tensor nodes that were at the time is_scalar_constexpr. *)
130107
}
131108
[@@deriving sexp_of]
132109

@@ -144,22 +121,6 @@ let get_node store tn =
144121
is_scalar_constexpr = false;
145122
})
146123

147-
let partition_tf_with_comment cs ~f =
148-
let both = Array.map cs ~f:(fun c -> if f c then Either.First c else Either.Second c) in
149-
let trues =
150-
Array.filter_map both ~f:(function
151-
| First x -> Some x
152-
| Second (Comment _ as x) -> Some x
153-
| Second _ -> None)
154-
in
155-
let falses =
156-
Array.filter_map both ~f:(function
157-
| First (Comment _ as x) -> Some x
158-
| First _ -> None
159-
| Second x -> Some x)
160-
in
161-
(trues, falses)
162-
163124
let visit ~is_assigned old =
164125
if not is_assigned then Recurrent
165126
else
@@ -801,21 +762,8 @@ let%diagn2_sexp optimize_proc static_indices llc =
801762
{ traced_store; llc; merge_node }
802763

803764
let code_hum_margin = ref 100
804-
let pp_comma ppf () = Stdlib.Format.fprintf ppf ",@ "
805-
let pp_symbol ppf sym = Stdlib.Format.fprintf ppf "%s" @@ Indexing.symbol_ident sym
806-
807-
let pp_static_symbol ppf { Indexing.static_symbol; static_range } =
808-
match static_range with
809-
| None -> pp_symbol ppf static_symbol
810-
| Some range -> Stdlib.Format.fprintf ppf "%a : [0..%d]" pp_symbol static_symbol (range - 1)
811-
812-
let pp_index ppf idx =
813-
match idx with
814-
| Indexing.Iterator sym -> pp_symbol ppf sym
815-
| Fixed_idx i -> Stdlib.Format.fprintf ppf "%d" i
816765

817-
let pp_indices ppf idcs =
818-
Stdlib.Format.pp_print_list ~pp_sep:pp_comma pp_index ppf @@ Array.to_list idcs
766+
open Indexing.Pp_helpers
819767

820768
let fprint_function_header ?name ?static_indices () ppf =
821769
let open Stdlib.Format in
@@ -924,7 +872,7 @@ let fprint_hum ?name ?static_indices () ppf llc =
924872
fprintf ppf "@[<2>%a.merge[@,%a]@]" pp_ident tn pp_indices idcs
925873
| Get (tn, idcs) -> fprintf ppf "@[<2>%a[@,%a]@]" pp_ident tn pp_indices idcs
926874
| Constant c -> fprintf ppf "%.16g" c
927-
| Embed_index idx -> pp_index ppf idx
875+
| Embed_index idx -> pp_axis_index ppf idx
928876
| Binop (Arg1, v1, _v2) -> pp_float prec ppf v1
929877
| Binop (Arg2, _v1, v2) -> pp_float prec ppf v2
930878
| Binop (op, v1, v2) ->

arrayjit/lib/low_level.mli

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
(** {1 A for-loop-based array language and backend-agnostic optimization} *)
2+
3+
open Base
4+
5+
module Scope_id : sig
6+
type t = { tn : Tnode.t; scope_id : int } [@@deriving sexp_of, equal, hash, compare]
7+
type comparator_witness
8+
9+
val comparator : (t, comparator_witness) Base.Comparator.t
10+
end
11+
12+
type scope_id = Scope_id.t = { tn : Tnode.t; scope_id : int }
13+
[@@deriving sexp_of, equal, hash, compare]
14+
15+
(** {2 Low-level representation} *)
16+
17+
(** Cases: [t] -- code, [float_t] -- single number at some precision. *)
18+
type t =
19+
| Noop
20+
| Comment of string
21+
| Staged_compilation of (unit -> unit)
22+
| Seq of t * t
23+
| For_loop of { index : Indexing.symbol; from_ : int; to_ : int; body : t; trace_it : bool }
24+
| Zero_out of Tnode.t
25+
| Set of { tn : Tnode.t; idcs : Indexing.axis_index array; llv : float_t; mutable debug : string }
26+
| Set_local of scope_id * float_t
27+
[@@deriving sexp_of, equal]
28+
29+
and float_t =
30+
| Local_scope of { id : scope_id; body : t; orig_indices : Indexing.axis_index array }
31+
| Get_local of scope_id
32+
| Get_global of Ops.global_identifier * Indexing.axis_index array option
33+
| Get of Tnode.t * Indexing.axis_index array
34+
| Binop of Ops.binop * float_t * float_t
35+
| Unop of Ops.unop * float_t
36+
| Constant of float
37+
| Embed_index of Indexing.axis_index
38+
[@@deriving sexp_of, equal, compare]
39+
40+
val binop : op:Ops.binop -> rhs1:float_t -> rhs2:float_t -> float_t
41+
val unop : op:Ops.unop -> rhs:float_t -> float_t
42+
val flat_lines : t list -> t list
43+
val unflat_lines : t list -> t
44+
val loop_over_dims : int array -> body:(Indexing.axis_index array -> t) -> t
45+
46+
(** {2 Optimization} *)
47+
48+
type virtualize_settings = {
49+
mutable enable_device_only : bool;
50+
mutable max_visits : int;
51+
mutable max_tracing_dim : int;
52+
mutable inline_scalar_constexprs : bool;
53+
}
54+
55+
val virtualize_settings : virtualize_settings
56+
57+
type visits =
58+
| Visits of int
59+
| Recurrent
60+
(** A [Recurrent] visit is when there is an access prior to any assignment in an update. *)
61+
[@@deriving sexp, equal, variants]
62+
63+
type traced_array = {
64+
tn : Tnode.t;
65+
mutable computations : (Indexing.axis_index array option * t) list;
66+
(** The computations (of the tensor node) are retrieved for optimization just as they are
67+
populated, so that the inlined code corresponds precisely to the changes to the arrays
68+
that would happen up till that point. Within the code blocks paired with an index tuple,
69+
all assignments and accesses must happen via the index tuple; if this is not the case for
70+
some assignment, the node cannot be virtual. Currently, we only allow for-loop symbols in
71+
assignment indices of virtual nodes. *)
72+
assignments : int array Base.Hash_set.t;
73+
accesses : (int array, visits) Base.Hashtbl.t;
74+
mutable zero_initialized : bool;
75+
mutable zeroed_out : bool;
76+
mutable read_before_write : bool;
77+
(** The node is read before it is written (i.e. it is recurrent). *)
78+
mutable read_only : bool;
79+
mutable is_scalar_constexpr : bool;
80+
(** True only if the tensor node has all axes of dimension 1, is either zeroed-out or assigned
81+
before accessed, is assigned at most once, and from an expression involving only constants
82+
or tensor nodes that were themselves [is_scalar_constexpr] at that time. *)
83+
}
84+
[@@deriving sexp_of]
85+
86+
val get_node : (Tnode.t, traced_array) Base.Hashtbl.t -> Tnode.t -> traced_array
87+
val optimize_integer_pow : bool ref
88+
89+
type traced_store = (Tnode.t, traced_array) Base.Hashtbl.t [@@deriving sexp_of]
90+
91+
type optimized = { traced_store : traced_store; llc : t; merge_node : Tnode.t option }
92+
[@@deriving sexp_of]
93+
94+
val optimize_proc :
95+
unoptim_ll_source:Stdlib.Format.formatter option ->
96+
ll_source:Stdlib.Format.formatter option ->
97+
name:string ->
98+
Indexing.static_symbol list ->
99+
t ->
100+
optimized
101+
102+
(** {2 Printing} *)
103+
104+
val code_hum_margin : int ref
105+
106+
val fprint_function_header :
107+
?name:string ->
108+
?static_indices:Indexing.static_symbol list ->
109+
unit ->
110+
Stdlib.Format.formatter ->
111+
unit
112+
113+
val get_ident_within_code : ?no_dots:bool -> t array -> Tnode.t -> string
114+
115+
val fprint_hum :
116+
?name:string ->
117+
?static_indices:Indexing.static_symbol list ->
118+
unit ->
119+
Stdlib.Format.formatter ->
120+
t ->
121+
unit

0 commit comments

Comments
 (0)