@@ -16,7 +16,7 @@ module Types = struct
1616 }
1717 [@@ deriving sexp_of ]
1818
19- type config = Physical_devices_only | For_parallel_copying | Most_parallel_devices
19+ type config = Only_devices_parallel | For_parallel_copying | Most_parallel_streams
2020 [@@ deriving equal , sexp , variants ]
2121
2222 type merge_buffer_use = No | Streaming | Copy [@@ deriving equal , sexp ]
@@ -48,7 +48,7 @@ module type No_device_backend = sig
4848 {!initialize} before using the backend. *)
4949
5050 val init : label :string -> context
51- (* * [label] is usually the backend name concatenated with the device number. *)
51+ (* * [label] is usually the backend name concatenated with the device or stream number. *)
5252
5353 val finalize : context -> unit
5454 (* * Finalizes (just) the context. *)
@@ -59,8 +59,8 @@ module type No_device_backend = sig
5959
6060 val compile : ?shared : bool -> ?name : string -> Indexing .unit_bindings -> Assignments .comp -> code
6161 (* * If [~shared:true] (default [false]), the backend should prefer to do more compile work in a
62- device-agnostic way. If [~shared:false], the backend can opt to postpone compiling altogether
63- until [link] is called, to benefit from more optimizations. *)
62+ device-and-stream- agnostic way. If [~shared:false], the backend can opt to postpone compiling
63+ altogether until [link] is called, to benefit from more optimizations. *)
6464
6565 val compile_batch :
6666 ?shared : bool ->
@@ -88,8 +88,8 @@ module type No_device_backend = sig
8888
8989 val unsafe_cleanup : unit -> unit
9090 (* * Cleans up all work on a backend, releases resources. All previously retrieved values
91- (contexts, virtual and physical devices) become invalid. The backend needs to be initialized
92- again to be used again. *)
91+ (contexts, streams and devices) become invalid. The backend needs to be initialized again to
92+ be used again. *)
9393
9494 val to_buffer : Tnode .t -> dst :buffer_ptr -> src :context -> unit
9595 val host_to_buffer : Ndarray .t -> dst :buffer_ptr -> unit
@@ -108,8 +108,8 @@ module type Backend = sig
108108 downstream of all the returned routines. *)
109109
110110 type event
111- (* * An event tracks if a device finished computing past a particular point in its schedue. These
112- values are used internally for scheduling across devices of the backend, and can be used for
111+ (* * An event tracks if a stream finished computing past a particular point in its schedue. These
112+ values are used internally for scheduling across streams of the backend, and can be used for
113113 explicit scheduling. *)
114114
115115 val sync : event -> unit
@@ -120,13 +120,13 @@ module type Backend = sig
120120
121121 val work_for : context -> Tnode .t -> event option
122122 (* * If the tensor node is in the context, returns the event indicating if currently running or
123- scheduled computations modifying that node on the context's device have completed.
123+ scheduled computations modifying that node on the context's stream have completed.
124124
125125 NOTE: [work_for ctx tn], if work tracking was not yet registered for [tn], will register work
126- tracking for [tn] and return the [all_work] event for [ctx]'s device . *)
126+ tracking for [tn] and return the [all_work] event for [ctx]'s stream . *)
127127
128128 val will_wait_for : context -> event -> unit
129- (* * Schedules waiting for the given event on the context's device .
129+ (* * Schedules waiting for the given event on the context's stream .
130130
131131 NOTE: it should rarely be needed to call [will_wait_for] explicitly, because it is typically
132132 called internally when necessary. But there is one exception, see {!device_to_device} when
@@ -135,13 +135,13 @@ module type Backend = sig
135135 val from_host : context -> Tnode .t -> bool
136136 (* * If the tensor node is both hosted and in-context, schedules a copy from host to context and
137137 returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
138- the device (via [await ctx.device ] or [sync (work_for ctx tn)]) before the host's data is
138+ the stream (via [await ctx.stream ] or [sync (work_for ctx tn)]) before the host's data is
139139 overwritten. *)
140140
141141 val to_host : context -> Tnode .t -> bool
142142 (* * If the tensor node is both hosted and in-context, schedules a copy from context to host and
143143 returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
144- the device (via [await ctx.device ] or [sync (work_for ctx tn)]) before the host's data is
144+ the stream (via [await ctx.stream ] or [sync (work_for ctx tn)]) before the host's data is
145145 read. *)
146146
147147 val device_to_device :
@@ -151,50 +151,49 @@ module type Backend = sig
151151 or [into_merge_buffer] is different from [No]: raises an error.
152152 - If the node is absent from [dst] and [into_merge_buffer=No]: returns false.
153153 - Executes [will_wait_for dst (work_for src tn)].
154- - If [into_merge_buffer=No]: schedules a copy of the tensor node from the device of [src] to
155- the device of [dst].
154+ - If [into_merge_buffer=No]: schedules a copy of the tensor node from [src] to [dst].
156155 - If [into_merge_buffer] is different from [No]: sets on [dst] the merge buffer source to the
157156 given node. If [into_merge_buffer=Streaming], remembers the buffer pointer of the source
158157 node to use for streaming, without blocking. If [into_merge_buffer=Copy], schedules copying
159- from [src] to the merge buffer of [dst]'s device .
158+ from [src] to the merge buffer of [dst]'s stream .
160159 - If the [dst] context resulted from a compilation with [Streaming] or [Copy] specific merge
161160 buffer code, the [device_to_device] call should fail immediately if there's a mismatch with
162161 [into_merge_buffer].
163162
164163 NOTE: If [into_merge_buffer:Streaming], after scheduling the work on [dst] using the merge
165164 buffer but before scheduling work on [src] that modifies [tn], execute
166- [will_wait_for src (all_work (get_ctx_device dst))]. *)
165+ [will_wait_for src (all_work (get_ctx_stream dst))]. *)
167166
168- type physical_device
169167 type device
168+ type stream
170169
171- val init : device -> context
172- val alloc_buffer : ?old_buffer : buffer_ptr * int -> size_in_bytes :int -> device -> buffer_ptr
170+ val init : stream -> context
171+ val alloc_buffer : ?old_buffer : buffer_ptr * int -> size_in_bytes :int -> stream -> buffer_ptr
173172
174- val await : device -> unit
175- (* * Blocks till the device becomes idle, i.e. synchronizes the device . *)
173+ val await : stream -> unit
174+ (* * Blocks till the stream becomes idle, i.e. synchronizes the stream . *)
176175
177- val all_work : device -> event
178- (* * Returns the event indicating if any currently running or scheduled computations on the device
176+ val all_work : stream -> event
177+ (* * Returns the event indicating if any currently running or scheduled computations on the stream
179178 have completed. *)
180179
181- val is_idle : device -> bool
182- (* * Whether the device is currently waiting for work. *)
180+ val is_idle : stream -> bool
181+ (* * Whether the stream is currently waiting for work. *)
183182
184- val sexp_of_device : device -> Sexp .t
185- val get_device : ordinal :int -> physical_device
186- val num_physical_devices : unit -> int
183+ val sexp_of_stream : stream -> Sexp .t
184+ val get_device : ordinal :int -> device
185+ val num_devices : unit -> int
187186
188- val suggested_num_virtual_devices : physical_device -> int
189- (* * The optimal number of virtual devices for the given physical device to follow the
190- {!Types.config} strategy passed to {!No_device_backend.initialize}. *)
187+ val suggested_num_streams : device -> int
188+ (* * The optimal number of streams for the given device to follow the {!Types.config} strategy
189+ passed to {!No_device_backend.initialize}. *)
191190
192- val new_virtual_device : physical_device -> device
193- val get_ctx_device : context -> device
194- val get_physical_device : device -> physical_device
195- val to_ordinal : physical_device -> int
196- val to_subordinal : device -> int
197- val get_name : device -> string
191+ val new_stream : device -> stream
192+ val get_ctx_stream : context -> stream
193+ val get_stream_device : stream -> device
194+ val to_ordinal : device -> int
195+ val to_subordinal : stream -> int
196+ val get_name : stream -> string
198197end
199198
200199module type Simple_backend = sig
@@ -289,7 +288,7 @@ module type Lowered_backend = sig
289288
290289 val device_to_device :
291290 Tnode .t -> into_merge_buffer :merge_buffer_use -> dst :context -> src :context -> bool
292- (* * If the array is in both contexts, copies from [dst] to [src]. *)
291+ (* * If the tensor node is in both contexts, copies from [dst] to [src]. *)
293292
294293 type buffer_ptr [@@deriving sexp_of]
295294
@@ -298,23 +297,23 @@ module type Lowered_backend = sig
298297 val buffer_to_host : Ndarray .t -> src :buffer_ptr -> unit
299298 val get_buffer : Tnode .t -> context -> buffer_ptr option
300299
301- type physical_device
302300 type device
303-
304- val alloc_buffer : ?old_buffer : buffer_ptr * int -> size_in_bytes :int -> device -> buffer_ptr
305- val init : device -> context
306- val await : device -> unit
307- val is_idle : device -> bool
308- val all_work : device -> event
309- val sexp_of_device : device -> Sexplib.Sexp .t
310- val num_physical_devices : unit -> int
311- val suggested_num_virtual_devices : physical_device -> int
312- val get_device : ordinal :int -> physical_device
313- val get_physical_device : device -> physical_device
314- val new_virtual_device : physical_device -> device
315- val get_ctx_device : context -> device
316- val get_name : device -> string
317- val to_ordinal : physical_device -> int
318- val to_subordinal : device -> int
301+ type stream
302+
303+ val alloc_buffer : ?old_buffer : buffer_ptr * int -> size_in_bytes :int -> stream -> buffer_ptr
304+ val init : stream -> context
305+ val await : stream -> unit
306+ val is_idle : stream -> bool
307+ val all_work : stream -> event
308+ val sexp_of_stream : stream -> Sexplib.Sexp .t
309+ val num_devices : unit -> int
310+ val suggested_num_streams : device -> int
311+ val get_device : ordinal :int -> device
312+ val get_stream_device : stream -> device
313+ val new_stream : device -> stream
314+ val get_ctx_stream : context -> stream
315+ val get_name : stream -> string
316+ val to_ordinal : device -> int
317+ val to_subordinal : stream -> int
319318 val name : string
320319end
0 commit comments