@@ -26,6 +26,8 @@ module Types = struct
2626 }
2727 [@@ deriving sexp_of ]
2828
29+ (* * For now, we only configure a backend with regard to how many streams it should suggest using
30+ (where applicable). *)
2931 type config = Only_devices_parallel | For_parallel_copying | Most_parallel_streams
3032 [@@ deriving equal , sexp , variants ]
3133
@@ -39,12 +41,10 @@ module Types = struct
3941 [@@ deriving sexp_of ]
4042end
4143
42- module type Backend_common = sig
43- type code [@@deriving sexp_of]
44- type code_batch [@@deriving sexp_of]
44+ (* * Parts shared by both assignments-level and lowered-level backend interfaces. *)
45+ module type Backend_any_common = sig
4546 type buffer_ptr [@@deriving sexp_of]
4647 type context [@@deriving sexp_of]
47- type routine = context Types .routine [@@ deriving sexp_of ]
4848 type stream
4949
5050 type init_info
@@ -67,9 +67,15 @@ module type Backend_common = sig
6767 (* * Finalizes (just) the context. *)
6868
6969 val alloc_buffer : ?old_buffer : buffer_ptr * int -> size_in_bytes :int -> stream -> buffer_ptr
70+ end
7071
71- val get_used_memory : unit -> int
72- (* * Returns (an upper bound of) the memory used for arrays, in bytes. *)
72+ (* * Parts shared by assignments-level backend interfaces. *)
73+ module type Backend_common = sig
74+ include Backend_any_common
75+
76+ type routine = context Types .routine [@@ deriving sexp_of ]
77+ type code [@@deriving sexp_of]
78+ type code_batch [@@deriving sexp_of]
7379
7480 val compile : ?shared : bool -> ?name : string -> Indexing .unit_bindings -> Assignments .comp -> code
7581 (* * If [~shared:true] (default [false]), the backend should prefer to do more compile work in a
@@ -89,6 +95,7 @@ module type Backend_common = sig
8995 [occupancy] returns true are included. *)
9096end
9197
98+ (* * An intermediate interface for stream-agnostic (typically CPU) backend implementations. *)
9299module type No_device_backend = sig
93100 include Backend_common with type init_info := string and type stream := unit
94101
@@ -104,23 +111,21 @@ module type No_device_backend = sig
104111 downstream of all the returned routines (in particular, the routines' contexts are not
105112 independent). *)
106113
114+ val get_used_memory : unit -> int
115+ (* * Returns (an upper bound of) the memory used for arrays, in bytes. *)
116+
107117 val to_buffer : Tnode .t -> dst :buffer_ptr -> src :context -> unit
108118 val host_to_buffer : Ndarray .t -> dst :buffer_ptr -> unit
109119 val buffer_to_host : Ndarray .t -> src :buffer_ptr -> unit
110120 val get_buffer : Tnode .t -> context -> buffer_ptr option
111121end
112122
113- module type Backend = sig
123+ (* * Parts shared by both assignments-level and lowered-level backend interfaces providing streams
124+ and devices. *)
125+ module type Backend_device_common = sig
114126 type stream [@@deriving sexp_of]
115127
116- include Backend_common with type init_info := stream and type stream := stream
117-
118- val link : context -> code -> routine
119- (* * Returns the routine for the code's procedure, in a new context derived from the given context. *)
120-
121- val link_batch : context -> code_batch -> context * routine option array
122- (* * Returns the routines for the procedures included in the code batch. The returned context is
123- downstream of all the returned routines. *)
128+ include Backend_any_common with type init_info := stream and type stream := stream
124129
125130 type event
126131 (* * An event tracks if a stream finished computing past a particular point in its schedue. These
@@ -147,6 +152,51 @@ module type Backend = sig
147152 called internally when necessary. But there is one exception, see {!device_to_device} when
148153 [into_merge_buffer=Streaming]. *)
149154
155+ type device
156+
157+ val get_used_memory : device -> int
158+ (* * Returns (an upper bound of) the memory used for arrays, in bytes. *)
159+
160+ val await : stream -> unit
161+ (* * Blocks till the stream becomes idle, i.e. synchronizes the stream. *)
162+
163+ val all_work : stream -> event
164+ (* * Returns the event indicating if any currently running or scheduled computations on the stream
165+ have completed. *)
166+
167+ val is_idle : stream -> bool
168+ (* * Whether the stream is currently waiting for work. *)
169+
170+ val get_device : ordinal :int -> device
171+ val num_devices : unit -> int
172+
173+ val suggested_num_streams : device -> int
174+ (* * The optimal number of streams for the given device to follow the {!Types.config} strategy
175+ passed to {!No_device_backend.initialize}. *)
176+
177+ val new_stream : device -> stream
178+ val get_ctx_stream : context -> stream
179+ val get_stream_device : stream -> device
180+ val to_ordinal : device -> int
181+ val get_name : stream -> string
182+ end
183+
184+ module type Backend = sig
185+ include Backend_device_common
186+
187+ include
188+ Backend_common
189+ with type context := context
190+ and type init_info := stream
191+ and type stream := stream
192+
193+ val link : context -> code -> routine
194+ (* * Returns the routine for the code's procedure, in a new context derived from the given context. *)
195+
196+ val link_batch : context -> code_batch -> context * routine option array
197+ (* * Returns the routines for the procedures included in the code batch. The returned context is
198+ downstream of all the returned routines. *)
199+
150200 val from_host : context -> Tnode .t -> bool
151201 (* * If the tensor node is both hosted and in-context, schedules a copy from host to context and
152202 returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
@@ -175,47 +225,16 @@ module type Backend = sig
175225 NOTE: If [into_merge_buffer=Streaming], after scheduling the work on [dst] using the merge
176226 buffer but before scheduling work on [src] that modifies [tn], execute
177227 [will_wait_for src (all_work (get_ctx_stream dst))]. *)
178-
179- type device
180-
181- val get_used_memory : device -> int
182- (* * Returns (an upper bound of) the memory used for arrays, in bytes. *)
183-
184- val await : stream -> unit
185- (* * Blocks till the stream becomes idle, i.e. synchronizes the stream. *)
186-
187- val all_work : stream -> event
188- (* * Returns the event indicating if any currently running or scheduled computations on the stream
189- have completed. *)
190-
191- val is_idle : stream -> bool
192- (* * Whether the stream is currently waiting for work. *)
193-
194- val get_device : ordinal :int -> device
195- val num_devices : unit -> int
196-
197- val suggested_num_streams : device -> int
198- (* * The optimal number of streams for the given device to follow the {!Types.config} strategy
199- passed to {!No_device_backend.initialize}. *)
200-
201- val new_stream : device -> stream
202- val get_ctx_stream : context -> stream
203- val get_stream_device : stream -> device
204- val to_ordinal : device -> int
205- val get_name : stream -> string
206228end
207229
230+ (* * Parts shared by lowered-level backends excluding what's already in {!Backend_any_common}. *)
208231module type Lowered_backend_common = sig
209232 type context [@@deriving sexp_of]
210233 type ctx_array [@@deriving sexp_of]
211234 type ctx_arrays [@@deriving sexp_of]
212- type buffer_ptr [@@deriving sexp_of]
213- type config
214- type init_info
215- type stream
235+ type buffer_ptr
216236
217237 val buffer_ptr : ctx_array -> buffer_ptr
218- val alloc_buffer : ?old_buffer : buffer_ptr * int -> size_in_bytes :int -> stream -> buffer_ptr
219238 val ctx_arrays : context -> ctx_arrays
220239 val get_array : ctx_arrays -> Tnode .t -> ctx_array option
221240
@@ -224,20 +243,18 @@ module type Lowered_backend_common = sig
224243
225244 Should return false for nodes that are virtual, local, or which the backend prefers to access
226245 directly from the host. *)
227-
228- val initialize : config -> unit
229- val is_initialized : unit -> bool
230- val init : init_info -> context
231- val finalize : context -> unit
232- val name : string
233246end
234247
248+ (* * Lowered-level stream agnostic backend interface: implementation-facing API for CPU backends. *)
235249module type Lowered_no_device_backend = sig
250+ include Lowered_backend_common
251+
236252 include
237- Lowered_backend_common
238- with type stream := unit
239- and type config := unit
253+ Backend_any_common
254+ with type context := context
255+ and type stream := unit
240256 and type init_info := string
257+ and type buffer_ptr := buffer_ptr
241258
242259 type procedure [@@deriving sexp_of]
243260
@@ -266,27 +283,15 @@ module type Lowered_no_device_backend = sig
266283 val buffer_to_host : Ndarray .t -> src :buffer_ptr -> unit
267284end
268285
286+ (* * Lowered-level backend interface: implementation-facing API for device-based (typically GPU)
287+ backends. *)
269288module type Lowered_backend = sig
270- type stream [@@deriving sexp_of]
271-
272- include
273- Lowered_backend_common
274- with type config := Types. config
275- and type stream := stream
276- and type init_info := stream
277-
289+ include Lowered_backend_common
290+ include Backend_device_common with type context := context and type buffer_ptr := buffer_ptr
291+
278292 type code [@@deriving sexp_of]
279293 type code_batch [@@deriving sexp_of]
280- type event
281294
282- val sync : event -> unit
283- val is_done : event -> bool
284- val work_for : context -> Tnode .t -> event option
285- val will_wait_for : context -> event -> unit
286-
287- open Types
288-
289- val sexp_of_context : context -> Sexplib.Sexp .t
290295 val compile : name :string -> Indexing .unit_bindings -> Low_level .optimized -> code
291296
292297 val compile_batch :
@@ -301,34 +306,13 @@ module type Lowered_backend = sig
301306 context -> code_batch -> context * Indexing .lowered_bindings * Task .t option array
302307
303308 val from_host : context -> Tnode .t -> bool
304- (* * If the array is both hosted and in-context, copies from host to context. *)
305309
306310 val to_host : context -> Tnode .t -> bool
307- (* * If the array is both hosted and in-context, copies from context to host. *)
308311
309312 val device_to_device :
310- Tnode .t -> into_merge_buffer :merge_buffer_use -> dst :context -> src :context -> bool
311- (* * See {!Backend.device_to_device}. *)
312-
313- type device
314-
315- val get_used_memory : device -> int
316- (* * Returns (an upper bound of) the memory used for arrays, in bytes. *)
317-
318- val await : stream -> unit
319- val is_idle : stream -> bool
320- val all_work : stream -> event
313+ Tnode .t -> into_merge_buffer :Types .merge_buffer_use -> dst :context -> src :context -> bool
321314
322315 val scheduled_merge_node : stream -> Tnode .t option
323316 (* * [scheduled_merge_node stream] is the tensor node that would be in the [stream]'s merge buffer
324317 right after [await stream]. *)
325-
326- val num_devices : unit -> int
327- val suggested_num_streams : device -> int
328- val get_device : ordinal :int -> device
329- val get_stream_device : stream -> device
330- val new_stream : device -> stream
331- val get_ctx_stream : context -> stream
332- val get_name : stream -> string
333- val to_ordinal : device -> int
334318end
0 commit comments