Skip to content

Commit 1416586

Browse files
committed
More fine-grained refactoring of backend APIs
1 parent 3cb9936 commit 1416586

File tree

7 files changed

+107
-101
lines changed

7 files changed

+107
-101
lines changed

CHANGES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
- Got rid of `unsafe_cleanup`.
1919
- Got rid of `subordinal`.
2020
- Removed dependency on `core`, broke up dependency on `ppx_jane`.
21-
- TODO: Built per-tensor-node stream-to-stream synchronization into device-to-device copying functions, removed obsolete blocking synchronizations.
21+
- TODO: Built per-tensor-node stream-to-stream synchronization into copying functions, removed obsolete blocking synchronizations.
2222

2323
### Fixed
2424

arrayjit/lib/backend_types.ml

Lines changed: 80 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ module Types = struct
2626
}
2727
[@@deriving sexp_of]
2828

29+
(** For now, we only configure a backend with regard to how many streams it should suggest using
30+
(where applicable). *)
2931
type config = Only_devices_parallel | For_parallel_copying | Most_parallel_streams
3032
[@@deriving equal, sexp, variants]
3133

@@ -39,12 +41,10 @@ module Types = struct
3941
[@@deriving sexp_of]
4042
end
4143

42-
module type Backend_common = sig
43-
type code [@@deriving sexp_of]
44-
type code_batch [@@deriving sexp_of]
44+
(** Parts shared by both assignments-level and lowered-level backend interfaces. *)
45+
module type Backend_any_common = sig
4546
type buffer_ptr [@@deriving sexp_of]
4647
type context [@@deriving sexp_of]
47-
type routine = context Types.routine [@@deriving sexp_of]
4848
type stream
4949

5050
type init_info
@@ -67,9 +67,15 @@ module type Backend_common = sig
6767
(** Finalizes (just) the context. *)
6868

6969
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> stream -> buffer_ptr
70+
end
7071

71-
val get_used_memory : unit -> int
72-
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
72+
(** Parts shared by assignments-level backend interfaces. *)
73+
module type Backend_common = sig
74+
include Backend_any_common
75+
76+
type routine = context Types.routine [@@deriving sexp_of]
77+
type code [@@deriving sexp_of]
78+
type code_batch [@@deriving sexp_of]
7379

7480
val compile : ?shared:bool -> ?name:string -> Indexing.unit_bindings -> Assignments.comp -> code
7581
(** If [~shared:true] (default [false]), the backend should prefer to do more compile work in a
@@ -89,6 +95,7 @@ module type Backend_common = sig
8995
[occupancy] returns true are included. *)
9096
end
9197

98+
(** An intermediate interface for stream-agnostic (typically CPU) backend implementations. *)
9299
module type No_device_backend = sig
93100
include Backend_common with type init_info := string and type stream := unit
94101

@@ -104,23 +111,21 @@ module type No_device_backend = sig
104111
downstream of all the returned routines (in particular, the routines' contexts are not
105112
independent). *)
106113

114+
val get_used_memory : unit -> int
115+
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
116+
107117
val to_buffer : Tnode.t -> dst:buffer_ptr -> src:context -> unit
108118
val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
109119
val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
110120
val get_buffer : Tnode.t -> context -> buffer_ptr option
111121
end
112122

113-
module type Backend = sig
123+
(** Parts shared by both assignments-level and lowered-level backend interfaces providing streams
124+
and devices. *)
125+
module type Backend_device_common = sig
114126
type stream [@@deriving sexp_of]
115127

116-
include Backend_common with type init_info := stream and type stream := stream
117-
118-
val link : context -> code -> routine
119-
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)
120-
121-
val link_batch : context -> code_batch -> context * routine option array
122-
(** Returns the routines for the procedures included in the code batch. The returned context is
123-
downstream of all the returned routines. *)
128+
include Backend_any_common with type init_info := stream and type stream := stream
124129

125130
type event
126131
(** An event tracks if a stream finished computing past a particular point in its schedule. These
@@ -147,6 +152,51 @@ module type Backend = sig
147152
called internally when necessary. But there is one exception, see {!device_to_device} when
148153
[into_merge_buffer=Streaming]. *)
149154

155+
type device
156+
157+
val get_used_memory : device -> int
158+
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
159+
160+
val await : stream -> unit
161+
(** Blocks till the stream becomes idle, i.e. synchronizes the stream. *)
162+
163+
val all_work : stream -> event
164+
(** Returns the event indicating if any currently running or scheduled computations on the stream
165+
have completed. *)
166+
167+
val is_idle : stream -> bool
168+
(** Whether the stream is currently waiting for work. *)
169+
170+
val get_device : ordinal:int -> device
171+
val num_devices : unit -> int
172+
173+
val suggested_num_streams : device -> int
174+
(** The optimal number of streams for the given device to follow the {!Types.config} strategy
175+
passed to {!No_device_backend.initialize}. *)
176+
177+
val new_stream : device -> stream
178+
val get_ctx_stream : context -> stream
179+
val get_stream_device : stream -> device
180+
val to_ordinal : device -> int
181+
val get_name : stream -> string
182+
end
183+
184+
module type Backend = sig
185+
include Backend_device_common
186+
187+
include
188+
Backend_common
189+
with type context := context
190+
and type init_info := stream
191+
and type stream := stream
192+
193+
val link : context -> code -> routine
194+
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)
195+
196+
val link_batch : context -> code_batch -> context * routine option array
197+
(** Returns the routines for the procedures included in the code batch. The returned context is
198+
downstream of all the returned routines. *)
199+
150200
val from_host : context -> Tnode.t -> bool
151201
(** If the tensor node is both hosted and in-context, schedules a copy from host to context and
152202
returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
@@ -175,47 +225,16 @@ module type Backend = sig
175225
NOTE: If [into_merge_buffer=Streaming], after scheduling the work on [dst] using the merge
176226
buffer but before scheduling work on [src] that modifies [tn], execute
177227
[will_wait_for src (all_work (get_ctx_stream dst))]. *)
178-
179-
type device
180-
181-
val get_used_memory : device -> int
182-
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
183-
184-
val await : stream -> unit
185-
(** Blocks till the stream becomes idle, i.e. synchronizes the stream. *)
186-
187-
val all_work : stream -> event
188-
(** Returns the event indicating if any currently running or scheduled computations on the stream
189-
have completed. *)
190-
191-
val is_idle : stream -> bool
192-
(** Whether the stream is currently waiting for work. *)
193-
194-
val get_device : ordinal:int -> device
195-
val num_devices : unit -> int
196-
197-
val suggested_num_streams : device -> int
198-
(** The optimal number of streams for the given device to follow the {!Types.config} strategy
199-
passed to {!No_device_backend.initialize}. *)
200-
201-
val new_stream : device -> stream
202-
val get_ctx_stream : context -> stream
203-
val get_stream_device : stream -> device
204-
val to_ordinal : device -> int
205-
val get_name : stream -> string
206228
end
207229

230+
(** Parts shared by lowered-level backends excluding what's already in {!Backend_any_common}. *)
208231
module type Lowered_backend_common = sig
209232
type context [@@deriving sexp_of]
210233
type ctx_array [@@deriving sexp_of]
211234
type ctx_arrays [@@deriving sexp_of]
212-
type buffer_ptr [@@deriving sexp_of]
213-
type config
214-
type init_info
215-
type stream
235+
type buffer_ptr
216236

217237
val buffer_ptr : ctx_array -> buffer_ptr
218-
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> stream -> buffer_ptr
219238
val ctx_arrays : context -> ctx_arrays
220239
val get_array : ctx_arrays -> Tnode.t -> ctx_array option
221240

@@ -224,20 +243,18 @@ module type Lowered_backend_common = sig
224243
225244
Should return false for nodes that are virtual, local, or which the backend prefers to access
226245
directly from the host. *)
227-
228-
val initialize : config -> unit
229-
val is_initialized : unit -> bool
230-
val init : init_info -> context
231-
val finalize : context -> unit
232-
val name : string
233246
end
234247

248+
(** Lowered-level stream-agnostic backend interface: implementation-facing API for CPU backends. *)
235249
module type Lowered_no_device_backend = sig
250+
include Lowered_backend_common
251+
236252
include
237-
Lowered_backend_common
238-
with type stream := unit
239-
and type config := unit
253+
Backend_any_common
254+
with type context := context
255+
and type stream := unit
240256
and type init_info := string
257+
and type buffer_ptr := buffer_ptr
241258

242259
type procedure [@@deriving sexp_of]
243260

@@ -266,27 +283,15 @@ module type Lowered_no_device_backend = sig
266283
val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
267284
end
268285

286+
(** Lowered-level backend interface: implementation-facing API for device-based (typically GPU)
287+
backends. *)
269288
module type Lowered_backend = sig
270-
type stream [@@deriving sexp_of]
271-
272-
include
273-
Lowered_backend_common
274-
with type config := Types.config
275-
and type stream := stream
276-
and type init_info := stream
277-
289+
include Lowered_backend_common
290+
include Backend_device_common with type context := context and type buffer_ptr := buffer_ptr
291+
278292
type code [@@deriving sexp_of]
279293
type code_batch [@@deriving sexp_of]
280-
type event
281294

282-
val sync : event -> unit
283-
val is_done : event -> bool
284-
val work_for : context -> Tnode.t -> event option
285-
val will_wait_for : context -> event -> unit
286-
287-
open Types
288-
289-
val sexp_of_context : context -> Sexplib.Sexp.t
290295
val compile : name:string -> Indexing.unit_bindings -> Low_level.optimized -> code
291296

292297
val compile_batch :
@@ -301,34 +306,13 @@ module type Lowered_backend = sig
301306
context -> code_batch -> context * Indexing.lowered_bindings * Task.t option array
302307

303308
val from_host : context -> Tnode.t -> bool
304-
(** If the array is both hosted and in-context, copies from host to context. *)
305309

306310
val to_host : context -> Tnode.t -> bool
307-
(** If the array is both hosted and in-context, copies from context to host. *)
308311

309312
val device_to_device :
310-
Tnode.t -> into_merge_buffer:merge_buffer_use -> dst:context -> src:context -> bool
311-
(** See {!Backend.device_to_device}. *)
312-
313-
type device
314-
315-
val get_used_memory : device -> int
316-
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
317-
318-
val await : stream -> unit
319-
val is_idle : stream -> bool
320-
val all_work : stream -> event
313+
Tnode.t -> into_merge_buffer:Types.merge_buffer_use -> dst:context -> src:context -> bool
321314

322315
val scheduled_merge_node : stream -> Tnode.t option
323316
(** [scheduled_merge_node stream] is the tensor node that would be in the [stream]'s merge buffer
324317
right after [await stream]. *)
325-
326-
val num_devices : unit -> int
327-
val suggested_num_streams : device -> int
328-
val get_device : ordinal:int -> device
329-
val get_stream_device : stream -> device
330-
val new_stream : device -> stream
331-
val get_ctx_stream : context -> stream
332-
val get_name : stream -> string
333-
val to_ordinal : device -> int
334318
end

arrayjit/lib/backends.ml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ let check_merge_buffer ~scheduled_node ~code_node =
1818
("Merge buffer mismatch, on stream: " ^ name scheduled_node ^ ", expected by code: "
1919
^ name code_node)
2020

21+
(* module *)
22+
2123
module Multicore_backend (Backend : Backend_types.No_device_backend) : Backend_types.Backend =
2224
struct
2325
module Domain = Domain [@warning "-3"]
@@ -581,7 +583,7 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back
581583

582584
let initialize config =
583585
global_config := config;
584-
initialize ()
586+
initialize config
585587

586588
type nonrec routine = context routine [@@deriving sexp_of]
587589

arrayjit/lib/cc_backend.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ let buffer_to_host dst ~src = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
4343

4444
let is_initialized, initialize =
4545
let initialized = ref false in
46-
((fun () -> !initialized), fun () -> initialized := true)
46+
((fun () -> !initialized), fun _config -> initialized := true)
4747

4848
let finalize _ctx = ()
4949

arrayjit/lib/cuda_backend.missing.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ type stream = Unimplemented_stream [@@deriving sexp_of]
6060
type device = Unimplemented_device [@@deriving sexp_of]
6161

6262
let init Unimplemented_stream = Unimplemented_ctx
63+
let buffer_ptr _ctx_array = Unimplemented_buffer_ptr
6364
let alloc_buffer ?old_buffer:_ ~size_in_bytes:_ Unimplemented_stream = Unimplemented_buffer_ptr
6465
let await _stream = ()
6566
let is_idle _stream = true

arrayjit/lib/gcc_backend.gccjit.ml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ type mem_properties =
2525
let root_ctx = ref None
2626

2727
module Tn = Tnode
28-
include Backend_types.No_device_types
2928

3029
type buffer_ptr = ctx_array [@@deriving sexp_of]
3130
(** Alternative approach:
@@ -57,7 +56,7 @@ let buffer_to_host dst ~src = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
5756

5857
let is_initialized () = Option.is_some !root_ctx
5958

60-
let initialize () =
59+
let initialize _config =
6160
if Option.is_none !root_ctx then (
6261
let open Gccjit in
6362
let ctx = Context.create () in

lib/attic.mld

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,4 +337,24 @@ let input_or_recurrent_nodes asgns =
337337
in
338338
loop asgns
339339

340+
]}
341+
342+
Upcoming in backend_types.ml:
343+
{[
344+
345+
val from_host : dst_ptr:buffer_ptr -> dst:context -> Tnode.t -> unit
346+
(** Like {!Backend.from_host}, but without synchronization and buffer retrieval. *)
347+
348+
val to_host : src_ptr:buffer_ptr -> src:context -> Tnode.t -> unit
349+
(** Like {!Backend.to_host}, but without synchronization and buffer retrieval. *)
350+
351+
val device_to_device :
352+
Tnode.t ->
353+
into_merge_buffer:merge_buffer_use ->
354+
dst_ptr:buffer_ptr ->
355+
dst:context ->
356+
src_ptr:buffer_ptr ->
357+
src:context ->
358+
unit
359+
(** Like {!Backend.device_to_device}, but without synchronization and buffer retrieval. *)
340360
]}

0 commit comments

Comments
 (0)