Skip to content

Commit 58ecce7

Browse files
committed
Refactor backend types into a separate file
1 parent 7060a2f commit 58ecce7

File tree

11 files changed

+349
-408
lines changed

11 files changed

+349
-408
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ A possible route to learning OCANNL:
4646
2. Get some basic grasp of the aims and design of the project by reading or skimming files in [test/](test/) and [bin/](bin/).
4747
3. Read the syntax extensions documentation [lib/syntax_extensions.md](lib/syntax_extensions.md).
4848
4. Read the introductory part of the shape inference documentation [lib/shape_inference.md](lib/shape_inference.md).
49-
5. Improve your understanding by reading or skimming [lib/shape.mli](lib/shape.mli), [lib/tensor.mli](lib/tensor.mli), [lib/operation.ml](lib/operation.ml), [lib/train.ml](lib/train.ml), and (since 0.4.1) [lib/nn_blocks.ml](lib/nn_blocks.ml).
49+
5. Improve your understanding by reading or skimming: [lib/shape.mli](lib/shape.mli), [lib/tensor.mli](lib/tensor.mli), [lib/operation.ml](lib/operation.ml), [arrayjit/lib/backend_types.ml](arrayjit/lib/backend_types.ml), [lib/train.ml](lib/train.ml), and [lib/nn_blocks.ml](lib/nn_blocks.ml).
5050
6. Read [arrayjit/lib/writing_a_backend.md](arrayjit/lib/writing_a_backend.md).
5151
7. Read the implementation overview:
5252
1. Shape inference details [lib/shape_inference.md](lib/shape_inference.md).

arrayjit/lib/backend_types.ml

Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
open Base
2+
module Lazy = Utils.Lazy
3+
module Debug_runtime = Utils.Debug_runtime
4+
5+
let _get_local_debug_runtime = Utils._get_local_debug_runtime
6+
7+
[%%global_debug_log_level 9]
8+
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
9+
10+
module Types = struct
11+
type 'context routine = {
12+
context : 'context;
13+
schedule : Tnode.task;
14+
bindings : Indexing.lowered_bindings;
15+
name : string;
16+
}
17+
[@@deriving sexp_of]
18+
19+
type config = Physical_devices_only | For_parallel_copying | Most_parallel_devices
20+
[@@deriving equal, sexp, variants]
21+
22+
type merge_buffer_use = No | Streaming | Copy [@@deriving equal, sexp]
23+
24+
type param_source =
25+
| Log_file_name
26+
| Merge_buffer
27+
| Param_ptr of Tnode.t
28+
| Static_idx of Indexing.static_symbol
29+
[@@deriving sexp_of]
30+
end
31+
32+
module type No_device_backend = sig
33+
type code [@@deriving sexp_of]
34+
type code_batch [@@deriving sexp_of]
35+
type buffer_ptr [@@deriving sexp_of]
36+
type context [@@deriving sexp_of]
37+
type routine = context Types.routine [@@deriving sexp_of]
38+
39+
val name : string
40+
41+
val initialize : Types.config -> unit
42+
(** Initializes a backend before first use or (on some backends) after {!unsafe_cleanup}. Does
43+
nothing if the backend is already initialized. *)
44+
45+
val is_initialized : unit -> bool
46+
(** Returns false if there was no previous {!initialize} call, or, on some backends, the most
47+
recent call was followed by {!unsafe_cleanup}. If it returns false, one must call
48+
{!initialize} before using the backend. *)
49+
50+
val init : label:string -> context
51+
(** [label] is usually the backend name concatenated with the device number. *)
52+
53+
val finalize : context -> unit
54+
(** Finalizes (just) the context. *)
55+
56+
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> unit -> buffer_ptr
57+
val expected_merge_node : code -> Tnode.t option
58+
val expected_merge_nodes : code_batch -> Tnode.t option array
59+
60+
val compile : ?shared:bool -> ?name:string -> Indexing.unit_bindings -> Assignments.comp -> code
61+
(** If [~shared:true] (default [false]), the backend should prefer to do more compile work in a
62+
device-agnostic way. If [~shared:false], the backend can opt to postpone compiling altogether
63+
until [link] is called, to benefit from more optimizations. *)
64+
65+
val compile_batch :
66+
?shared:bool ->
67+
?names:string array ->
68+
?occupancy:(name:string -> src_n:int -> bool) ->
69+
Indexing.unit_bindings ->
70+
Assignments.comp array ->
71+
code_batch
72+
(** Unlike the [~shared] parameter, [compile_batch] vs. [compile] is mostly about improving the
73+
compile time and debugging convenience by generating fewer files -- ideally does not affect
74+
execution, but there can be backend-specific differences. Only array entries for which
75+
[occupancy] returns true are included. *)
76+
77+
val link : merge_buffer:(buffer_ptr * Tnode.t) option ref -> context -> code -> routine
78+
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)
79+
80+
val link_batch :
81+
merge_buffer:(buffer_ptr * Tnode.t) option ref ->
82+
context ->
83+
code_batch ->
84+
context * routine option array
85+
(** Returns the routines for the procedures included in the code batch. The returned context is
86+
downstream of all the returned routines (in particular, the routines' contexts are not
87+
independent). *)
88+
89+
val unsafe_cleanup : unit -> unit
90+
(** Cleans up all work on a backend, releases resources. All previously retrieved values
91+
(contexts, virtual and physical devices) become invalid. The backend needs to be initialized
92+
again to be used again. *)
93+
94+
val to_buffer : Tnode.t -> dst:buffer_ptr -> src:context -> unit
95+
val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
96+
val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
97+
val get_buffer : Tnode.t -> context -> buffer_ptr option
98+
end
99+
100+
module type Backend = sig
101+
include No_device_backend
102+
103+
val link : context -> code -> routine
104+
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)
105+
106+
val link_batch : context -> code_batch -> context * routine option array
107+
(** Returns the routines for the procedures included in the code batch. The returned context is
108+
downstream of all the returned routines. *)
109+
110+
type event
111+
(** An event tracks if a device finished computing past a particular point in its schedue. These
112+
values are used internally for scheduling across devices of the backend, and can be used for
113+
explicit scheduling. *)
114+
115+
val sync : event -> unit
116+
(** Blocks till the event completes, if it's not done already. *)
117+
118+
val is_done : event -> bool
119+
(** Whether the event completed. *)
120+
121+
val work_for : context -> Tnode.t -> event option
122+
(** If the tensor node is in the context, returns the event indicating if currently running or
123+
scheduled computations modifying that node on the context's device have completed.
124+
125+
NOTE: [work_for ctx tn], if work tracking was not yet registered for [tn], will register work
126+
tracking for [tn] and return the [all_work] event for [ctx]'s device. *)
127+
128+
val will_wait_for : context -> event -> unit
129+
(** Schedules waiting for the given event on the context's device.
130+
131+
NOTE: it should rarely be needed to call [will_wait_for] explicitly, because it is typically
132+
called internally when necessary. But there is one exception, see {!device_to_device} when
133+
[into_merge_buffer=Streaming]. *)
134+
135+
val from_host : context -> Tnode.t -> bool
136+
(** If the tensor node is both hosted and in-context, schedules a copy from host to context and
137+
returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
138+
the device (via [await ctx.device] or [sync (work_for ctx tn)]) before the host's data is
139+
overwritten. *)
140+
141+
val to_host : context -> Tnode.t -> bool
142+
(** If the tensor node is both hosted and in-context, schedules a copy from context to host and
143+
returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
144+
the device (via [await ctx.device] or [sync (work_for ctx tn)]) before the host's data is
145+
read. *)
146+
147+
val device_to_device :
148+
Tnode.t -> into_merge_buffer:Types.merge_buffer_use -> dst:context -> src:context -> bool
149+
(** [device_to_device tn ~into_merge_buffer ~dst ~src] proceeds as follows:
150+
- If the node is absent from the [src] context and either it is present in the [dst] context
151+
or [into_merge_buffer] is different from [No]: raises an error.
152+
- If the node is absent from [dst] and [into_merge_buffer=No]: returns false.
153+
- Executes [will_wait_for dst (work_for src tn)].
154+
- If [into_merge_buffer=No]: schedules a copy of the tensor node from the device of [src] to
155+
the device of [dst].
156+
- If [into_merge_buffer] is different from [No]: sets on [dst] the merge buffer source to the
157+
given node. If [into_merge_buffer=Streaming], remembers the buffer pointer of the source
158+
node to use for streaming, without blocking. If [into_merge_buffer=Copy], schedules copying
159+
from [src] to the merge buffer of [dst]'s device.
160+
- If the [dst] context resulted from a compilation with [Streaming] or [Copy] specific merge
161+
buffer code, the [device_to_device] call should fail immediately if there's a mismatch with
162+
[into_merge_buffer].
163+
164+
NOTE: If [into_merge_buffer:Streaming], after scheduling the work on [dst] using the merge
165+
buffer but before scheduling work on [src] that modifies [tn], execute
166+
[will_wait_for src (all_work (get_ctx_device dst))]. *)
167+
168+
type physical_device
169+
type device
170+
171+
val init : device -> context
172+
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> device -> buffer_ptr
173+
174+
val await : device -> unit
175+
(** Blocks till the device becomes idle, i.e. synchronizes the device. *)
176+
177+
val all_work : device -> event
178+
(** Returns the event indicating if any currently running or scheduled computations on the device
179+
have completed. *)
180+
181+
val is_idle : device -> bool
182+
(** Whether the device is currently waiting for work. *)
183+
184+
val sexp_of_device : device -> Sexp.t
185+
val get_device : ordinal:int -> physical_device
186+
val num_physical_devices : unit -> int
187+
188+
val suggested_num_virtual_devices : physical_device -> int
189+
(** The optimal number of virtual devices for the given physical device to follow the
190+
{!Types.config} strategy passed to {!No_device_backend.initialize}. *)
191+
192+
val new_virtual_device : physical_device -> device
193+
val get_ctx_device : context -> device
194+
val get_physical_device : device -> physical_device
195+
val to_ordinal : physical_device -> int
196+
val to_subordinal : device -> int
197+
val get_name : device -> string
198+
end
199+
200+
module type Simple_backend = sig
201+
type context [@@deriving sexp_of]
202+
type procedure [@@deriving sexp_of]
203+
type ctx_array [@@deriving sexp_of]
204+
type buffer_ptr [@@deriving sexp_of]
205+
type ctx_arrays = ctx_array Map.M(Tnode).t [@@deriving sexp_of]
206+
207+
val buffer_ptr : ctx_array -> buffer_ptr
208+
val ctx_arrays : context -> ctx_arrays
209+
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> unit -> buffer_ptr
210+
val expected_merge_node : procedure -> Tnode.t option
211+
212+
val is_in_context : Low_level.traced_array -> bool
213+
(** If true, the node is required to be in the contexts linked with code that uses it.
214+
215+
Should return false for nodes that are virtual, local, or which the backend prefers to access
216+
directly from the host. *)
217+
218+
val compile :
219+
name:string ->
220+
opt_ctx_arrays:ctx_arrays option ->
221+
Indexing.unit_bindings ->
222+
Low_level.optimized ->
223+
procedure
224+
225+
val compile_batch :
226+
names:string option array ->
227+
opt_ctx_arrays:ctx_arrays option ->
228+
Indexing.unit_bindings ->
229+
Low_level.optimized option array ->
230+
ctx_arrays option * procedure option array
231+
232+
val link_compiled :
233+
merge_buffer:(buffer_ptr * Tnode.t) option ref ->
234+
context ->
235+
procedure ->
236+
context * Indexing.lowered_bindings * Tnode.task * string
237+
238+
val name : string
239+
val initialize : unit -> unit
240+
val is_initialized : unit -> bool
241+
val init : label:string -> context
242+
val finalize : context -> unit
243+
val unsafe_cleanup : unit -> unit
244+
val to_buffer : Tnode.t -> dst:buffer_ptr -> src:context -> unit
245+
val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
246+
val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
247+
end
248+
249+
module type Lowered_backend = sig
250+
type context [@@deriving sexp_of]
251+
type code [@@deriving sexp_of]
252+
type code_batch [@@deriving sexp_of]
253+
type ctx_array [@@deriving sexp_of]
254+
type event
255+
256+
val sync : event -> unit
257+
val is_done : event -> bool
258+
val work_for : context -> Tnode.t -> event option
259+
val will_wait_for : context -> event -> unit
260+
261+
open Types
262+
263+
val initialize : config -> unit
264+
val is_initialized : unit -> bool
265+
val finalize : context -> unit
266+
val sexp_of_context : context -> Sexplib.Sexp.t
267+
val compile : name:string -> Indexing.unit_bindings -> Low_level.optimized -> code
268+
269+
val compile_batch :
270+
names:string option array ->
271+
Indexing.unit_bindings ->
272+
Low_level.optimized option array ->
273+
code_batch
274+
275+
val is_in_context : Low_level.traced_array -> bool
276+
val ctx_arrays : context -> ctx_array Map.M(Tnode).t
277+
val link : context -> code -> context * Indexing.lowered_bindings * Tnode.task
278+
279+
val link_batch :
280+
context -> code_batch -> context * Indexing.lowered_bindings * Tnode.task option array
281+
282+
val unsafe_cleanup : unit -> unit
283+
284+
val from_host : context -> Tnode.t -> bool
285+
(** If the array is both hosted and in-context, copies from host to context. *)
286+
287+
val to_host : context -> Tnode.t -> bool
288+
(** If the array is both hosted and in-context, copies from context to host. *)
289+
290+
val device_to_device :
291+
Tnode.t -> into_merge_buffer:merge_buffer_use -> dst:context -> src:context -> bool
292+
(** If the array is in both contexts, copies from [dst] to [src]. *)
293+
294+
type buffer_ptr [@@deriving sexp_of]
295+
296+
val to_buffer : Tnode.t -> dst:buffer_ptr -> src:context -> unit
297+
val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
298+
val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
299+
val get_buffer : Tnode.t -> context -> buffer_ptr option
300+
301+
type physical_device
302+
type device
303+
304+
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> device -> buffer_ptr
305+
val init : device -> context
306+
val await : device -> unit
307+
val is_idle : device -> bool
308+
val all_work : device -> event
309+
val sexp_of_device : device -> Sexplib.Sexp.t
310+
val num_physical_devices : unit -> int
311+
val suggested_num_virtual_devices : physical_device -> int
312+
val get_device : ordinal:int -> physical_device
313+
val get_physical_device : device -> physical_device
314+
val new_virtual_device : physical_device -> device
315+
val get_ctx_device : context -> device
316+
val get_name : device -> string
317+
val to_ordinal : physical_device -> int
318+
val to_subordinal : device -> int
319+
val name : string
320+
end

arrayjit/lib/backend_utils.ml

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,6 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
77
[%%global_debug_log_level 9]
88
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
99

10-
module Types = struct
11-
type 'context routine = {
12-
context : 'context;
13-
schedule : Tnode.task;
14-
bindings : Indexing.lowered_bindings;
15-
name : string;
16-
}
17-
[@@deriving sexp_of]
18-
19-
type config = Physical_devices_only | For_parallel_copying | Most_parallel_devices
20-
[@@deriving equal, sexp, variants]
21-
22-
type merge_buffer_use = No | Streaming | Copy [@@deriving equal, sexp]
23-
24-
type param_source =
25-
| Log_file_name
26-
| Merge_buffer
27-
| Param_ptr of Tnode.t
28-
| Static_idx of Indexing.static_symbol
29-
[@@deriving sexp_of]
30-
end
31-
3210
module Tn = Tnode
3311

3412
module C_syntax (B : sig
@@ -50,7 +28,7 @@ module C_syntax (B : sig
5028
val convert_precision : from:Ops.prec -> to_:Ops.prec -> string * string
5129
end) =
5230
struct
53-
open Types
31+
open Backend_types.Types
5432

5533
let get_ident =
5634
Low_level.get_ident_within_code ~no_dots:true @@ Array.map B.for_lowereds ~f:(fun l -> l.llc)

0 commit comments

Comments
 (0)