@@ -16,12 +16,15 @@ open Backend_intf
(* Signature of the allocation and copying operations of a stream-agnostic backend: the included
   [Alloc_buffer] signature is specialized with [stream := unit]. *)
1616module type No_device_buffer_and_copying = sig
1717 include Alloc_buffer with type stream := unit
1818
19+ val get_used_memory : unit -> int
20+ (** Returns (an upper bound of) the memory used for arrays, in bytes. *)
21+
1922 val buffer_to_buffer : dst :buffer_ptr -> src :buffer_ptr -> size_in_bytes :int -> unit
(* NOTE(review): judging by the labels, copies [size_in_bytes] bytes from [src] to [dst] --
   confirm against the implementation. *)
2023 val host_to_buffer : Ndarray .t -> dst :buffer_ptr -> unit
2124 val buffer_to_host : Ndarray .t -> src :buffer_ptr -> unit
(* NOTE(review): presumably [host_to_buffer] copies a host array into [dst] and [buffer_to_host]
   copies [src] into a host array; how the transfer size is determined is not visible here. *)
2225end
2326
24- module No_device_buffer_and_copying :
27+ module No_device_buffer_and_copying () :
2528 No_device_buffer_and_copying with type buffer_ptr = unit Ctypes. ptr = struct
2629 type buffer_ptr = unit Ctypes .ptr
2730
@@ -31,18 +34,28 @@ module No_device_buffer_and_copying :
3134 type nonrec buffer_ptr = buffer_ptr [@@ deriving sexp_of ]
3235 end )
3336
34- let alloc_buffer ?old_buffer ~size_in_bytes () =
35- match old_buffer with
36- | Some ({ size_in_bytes = old_size ; _ } as buffer ) when size_in_bytes < = old_size -> buffer
37- | _ ->
38- let ptr = Ctypes. (to_voidp @@ allocate_n int8_t ~count: size_in_bytes) in
39- { ptr; size_in_bytes }
(* Running total, in bytes, of the allocations made through [alloc_impl]; it starts at 0 and is
   only accessed with [Atomic] operations. *)
37+ let used_memory = Atomic. make 0
(* Reads the current value of the [used_memory] counter. *)
38+ let get_used_memory () = Atomic. get used_memory
39+
(* [alloc_impl ~size_in_bytes] allocates [size_in_bytes] elements of type [int8_t] with
   [Ctypes.allocate_n], converts the result to a void pointer and returns it. The size is added
   to [used_memory] right after the allocation; a GC finaliser registered on the returned
   pointer subtracts the same amount, so the counter only drops once the GC has run that
   finaliser (which is why [get_used_memory] is an upper bound).
   NOTE(review): the finaliser is attached to the value produced by [to_voidp]; confirm that
   this value's lifetime matches that of the underlying allocation ([Ctypes.allocate_n] also
   accepts a [?finalise] argument). *)
40+ let alloc_impl ~size_in_bytes =
41+ let finalize _ptr = ignore (Atomic. fetch_and_add used_memory ~- size_in_bytes : int ) in
42+ let ptr = Ctypes. (to_voidp @@ allocate_n int8_t ~count: size_in_bytes) in
43+ let _ : int = Atomic. fetch_and_add used_memory size_in_bytes in
44+ Stdlib.Gc. finalise finalize ptr;
45+ ptr
4046
(* [alloc_zero_init_array prec ~dims ()] allocates a buffer for an array of precision [prec].
   The byte size is the product of [dims] multiplied by [Ops.prec_in_bytes prec], or 0 when
   [dims] is empty. The allocation goes through [alloc_impl] (the line marked [-] is the
   previous direct [Ctypes] allocation), so it is counted in [used_memory].
   NOTE(review): an empty [dims] yields a zero-byte allocation -- confirm that scalars are never
   passed with empty [dims].
   NOTE(review): zero-initialization is presumably provided by [Ctypes.allocate_n] -- confirm
   against the Ctypes documentation. *)
4147 let alloc_zero_init_array prec ~dims () =
4248 let size_in_bytes =
4349 (if Array. length dims = 0 then 0 else Array. reduce_exn dims ~f: ( * )) * Ops. prec_in_bytes prec
4450 in
45- Ctypes. (to_voidp @@ allocate_n int8_t ~count: size_in_bytes)
51+ alloc_impl ~size_in_bytes
52+
(* [alloc_buffer ?old_buffer ~size_in_bytes ()] returns [old_buffer] unchanged when it is present
   and its recorded size is at least [size_in_bytes]; otherwise it builds a new buffer record
   around a fresh [alloc_impl] allocation of exactly [size_in_bytes] bytes. A too-small old
   buffer is not explicitly released here. *)
53+ let alloc_buffer ?old_buffer ~size_in_bytes () =
54+ match old_buffer with
55+ | Some ({ size_in_bytes = old_size ; _ } as buffer ) when size_in_bytes < = old_size -> buffer
56+ | _ -> { ptr = alloc_impl ~size_in_bytes ; size_in_bytes }
57+
(* No explicit deallocation function is provided by this implementation.
   NOTE(review): presumably because the buffers are reclaimed by the GC -- confirm. *)
58+ let free_buffer = None
4659
4760 let buffer_to_buffer ~dst :Ctypes_static. (CPointer dst ) ~src :Ctypes_static. (CPointer src )
4861 ~size_in_bytes =
@@ -127,34 +140,12 @@ module type Backend_impl_common = sig
127140 directly from the host. *)
128141end
129142
130- (* * An intermediate interface for stream-agnostic (typically CPU) backend implementations. *)
131- module type No_device_backend = sig
132- include Backend_common
133- include Backend_impl_common with type buffer_ptr := buffer_ptr
143+ (** An interface for adding schedulers to stream-agnostic (typically CPU) backend implementations. *)
144+ module type For_add_scheduler = sig
145+ include Backend_any_common
134146
135147 val name : string
136148
(* The declarations marked [-] below ([link], [link_batch], [get_used_memory]) are the members
   removed from this signature; [get_used_memory] is still declared by the
   [No_device_buffer_and_copying] signature included at the end. *)
137- val link :
138- merge_buffer :(buffer_ptr * Tnode .t ) option ref ->
139- runner_label :string ->
140- ctx_arrays ->
141- code ->
142- ctx_arrays routine
143- (* * Returns the routine for the code's procedure, in a new context derived from the given context. *)
144-
145- val link_batch :
146- merge_buffer :(buffer_ptr * Tnode .t ) option ref ->
147- runner_label :string ->
148- ctx_arrays ->
149- code_batch ->
150- ctx_arrays * ctx_arrays routine option array
151- (* * Returns the routines for the procedures included in the code batch. The returned context is
152- downstream of all the returned routines (in particular, the routines' contexts are not
153- independent). *)
154-
155- val get_used_memory : unit -> int
156- (* * Returns (an upper bound of) the memory used for arrays, in bytes. *)
157-
158149 include No_device_buffer_and_copying with type buffer_ptr := buffer_ptr
159150end
160151
@@ -227,8 +218,8 @@ module type Lowered_no_device_backend = sig
227218 runner_label :string ->
228219 ctx_arrays ->
229220 procedure ->
230- ctx_arrays * Indexing .lowered_bindings * Task .t * string
231- (* * [runner_label] will be [get_name stream] of the stream from which the [ctx_arrays] come from . *)
221+ ctx_arrays * Indexing .lowered_bindings * Task .t
222+ (* * [runner_label] will be [get_name stream] of the stream holding the resulting [ctx_arrays]. *)
232223
233224 include No_device_buffer_and_copying with type buffer_ptr := buffer_ptr
234225end
@@ -255,28 +246,48 @@ module type No_buffer_retrieval_or_syncing = sig
255246 [Invalid_argument] if [into_merge_buffer = No] and [dst_ptr = None]. *)
256247end
257248
258- (* * Lowered-level backend interface: implementation-facing API for device-based (typically GPU)
259- backends. *)
249+ (** A compilation-agnostic backend API -- {!Lowered_backend} instantiates it, but
250+ {!Lowered_no_device_backend} backends are also converted to its instantiations. *)
251+ module type With_scheduler = sig
252+ include Backend_device_common
253+
254+ val schedule_task : stream -> Task .t -> unit
(* NOTE(review): presumably schedules the task for execution on the given stream -- the
   execution and ordering guarantees are not visible here; confirm with the implementations. *)
255+ end
256+
257+ (** Lowered-level backend interface: implementation-facing API for device-based (GPU, or CPU after
258+ adding a scheduler) backends. *)
260259module type Lowered_backend = sig
261- include No_buffer_retrieval_or_syncing
260+ include Backend_device_common
261+
(* The types of [No_buffer_retrieval_or_syncing] are substituted with the ones already brought
   in by [Backend_device_common], so both includes share the same types. *)
262+ include
263+ No_buffer_retrieval_or_syncing
264+ with type buffer_ptr := buffer_ptr
265+ and type dev := dev
266+ and type runner := runner
267+ and type event := event
262268
(* [code] and [code_batch] are the abstract results of [compile] and [compile_batch]. *)
263269 type code [@@deriving sexp_of]
264270 type code_batch [@@deriving sexp_of]
265271
(* NOTE(review): the meaning of the optional [shared] flag of [compile] and [compile_batch] is
   not documented in this signature -- TODO document it. *)
266- val compile : name :string -> Indexing .unit_bindings -> Low_level .optimized -> code
272+ val compile : ? shared : bool -> name :string -> Indexing .unit_bindings -> Low_level .optimized -> code
267273
268274 val compile_batch :
275+ ?shared : bool ->
269276 names :string option array ->
270277 Indexing .unit_bindings ->
271278 Low_level .optimized option array ->
272279 code_batch
273280
274- val link : context -> code -> context * Indexing .lowered_bindings * Task .t
281+ val link : context -> code -> ctx_arrays * Indexing .lowered_bindings * Task .t
282+ (** The results correspond to the fields {!field-Backend_intf.ctx_arrays} of
283+ {!field-Backend_intf.context}, {!field-Backend_intf.bindings} and
284+ {!field-Backend_intf.schedule} of {!Backend_intf.routine}. *)
275285
276286 val link_batch :
277- context -> code_batch -> context * Indexing .lowered_bindings * Task .t option array
278-
279- val scheduled_merge_node : stream -> Tnode .t option
280- (* * [scheduled_merge_node stream] is the tensor node that would be in the [stream]'s merge buffer
281- right after [await stream]. *)
287+ context ->
288+ code_batch ->
289+ ctx_arrays * Indexing .lowered_bindings * (ctx_arrays * Task .t ) option array
290+ (** Returns the schedule tasks and their [ctx_arrays] for the procedures included in the code
291+ batch. The returned [ctx_arrays] will be part of a context downstream of all the tasks and the
292+ tasks' contexts are not independent (typically, they are cumulative). *)
282293end
0 commit comments