@@ -98,39 +98,64 @@ module type Backend = sig
9898 (* * Returns the routines for the procedures included in the code batch. The returned context is
9999 downstream of all the returned routines. *)
100100
101- val from_host : context -> Tnode .t -> bool
102- (* * If the array is both hosted and in-context, schedules a copy from host to context and returns
103- true, otherwise returns false. NOTE: when run for a device, it's the caller's responsibility
104- to synchronize the device before the host's data is overwritten . *)
101+ type event
102+ (* * An event tracks if a device finished computing past a particular point in its schedule. These
103+ values are used internally for scheduling across devices of the backend, and can be used for
104+ explicit scheduling. *)
105105
106- val to_host : context -> Tnode .t -> bool
107- (* * If the array is both hosted and in-context, schedules a copy from context to host and returns
108- true, otherwise returns false. NOTE: when run for a device, it's the caller's responsibility
109- to synchronize the device before the host's data is read. *)
106+ val await_ev : event -> unit
107+ (* * Blocks till the event completes, if it's not done already. *)
110108
111- val device_to_device :
112- Tnode .t -> into_merge_buffer :merge_buffer_use -> dst :context -> src :context -> bool
113- (* * If the node is absent from the [src] context and either it is present in the [dst] context or
114- [~into_merge_buffer] is different from [No]: raises an error.
109+ val is_done : event -> bool
110+ (* * Whether the event completed. *)
111+
112+ val work_for : context -> Tnode .t -> event option
113+ (* * If the tensor node is in the context, returns the event indicating if currently running or
114+ scheduled computations modifying that node on the context's device have completed.
115+
116+ NOTE: [work_for ctx tn], if work tracking was not registered for [tn], will register work
117+ tracking for [tn] and return the event tracking all currently scheduled computations on
118+ [ctx]'s device. *)
115119
116- If [~into_merge_buffer:No]: If the node is present in the [dst] context, schedules a copy of
117- the tensor node from the device of [src] to the device of [dst] and returns true, otherwise
118- returns false.
120+ val will_wait_for : context -> event -> unit
121+ (* * Schedules waiting for the given event on the context's device.
119122
120- If [~into_merge_buffer] is different from [No]: schedules the following task and returns true.
123+ NOTE: it should rarely be needed to call [will_wait_for] explicitly, because it is typically
124+ called internally when necessary. But there is one exception, see {!device_to_device} when
125+ [into_merge_buffer=Streaming]. *)
121126
122- The merge-buffer task sets on [dst] the merge buffer source to the given node. If
123- [~into_merge_buffer:Streaming], remembers the buffer pointer of the source node to use for
124- streaming, without blocking. If [~into_merge_buffer:Copy], copies from [src] to the merge
125- buffer of [dst]'s device.
127+ val from_host : context -> Tnode .t -> bool
128+ (* * If the tensor node is both hosted and in-context, schedules a copy from host to context and
129+ returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
130+ the device (via [await ctx.device] or [Option.iter ~f:await_ev (work_for ctx tn)]) before the
131+ host's data is overwritten. *)
126132
127- If the [dst] context resulted from a compilation with [Streaming] or [Copy] specific merge
128- buffer code, the [device_to_device] call should fail immediately if there's a mismatch with
129- [~into_merge_buffer].
133+ val to_host : context -> Tnode .t -> bool
134+ (* * If the tensor node is both hosted and in-context, schedules a copy from context to host and
135+ returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
136+ the device (via [await ctx.device] or [Option.iter ~f:await_ev (work_for ctx tn)]) before the
137+ host's data is read. *)
130138
131- NOTE: it's the caller's responsibility to synchronize the [src] device, if needed, {i before}
132- calling [device_to_device], and if [~into_merge_buffer:Streaming], the [dst] device
133- {i afterward}, before any computations on the [src] device overwrite the node. *)
139+ val device_to_device :
140+ Tnode .t -> into_merge_buffer :merge_buffer_use -> dst :context -> src :context -> bool
141+ (* * [device_to_device tn ~into_merge_buffer ~dst ~src] proceeds as follows:
142+ - If the node is absent from the [src] context and either it is present in the [dst] context
143+ or [into_merge_buffer] is different from [No]: raises an error.
144+ - If the node is absent from [dst] and [into_merge_buffer=No]: returns false.
145+ - Executes [will_wait_for dst (work_for src tn)].
146+ - If [into_merge_buffer=No]: schedules a copy of the tensor node from the device of [src] to
147+ the device of [dst].
148+ - If [into_merge_buffer] is different from [No]: sets on [dst] the merge buffer source to the
149+ given node. If [into_merge_buffer=Streaming], remembers the buffer pointer of the source
150+ node to use for streaming, without blocking. If [into_merge_buffer=Copy], schedules copying
151+ from [src] to the merge buffer of [dst]'s device.
152+ - If the [dst] context resulted from a compilation with [Streaming] or [Copy] specific merge
153+ buffer code, the [device_to_device] call should fail immediately if there's a mismatch with
154+ [into_merge_buffer].
155+
156+ NOTE: If [into_merge_buffer=Streaming], after scheduling the work on [dst] using the merge
157+ buffer but before scheduling work on [src] that modifies [tn], execute
158+ [will_wait_for src (all_work (get_ctx_device dst))]. *)
134159
135160 type physical_device
136161 type device
@@ -141,6 +166,10 @@ module type Backend = sig
141166 val await : device -> unit
142167 (* * Blocks till the device becomes idle, i.e. synchronizes the device. *)
143168
169+ val all_work : device -> event
170+ (* * Returns the event indicating if any currently running or scheduled computations on the device
171+ have completed. *)
172+
144173 val is_idle : device -> bool
145174 (* * Whether the device is currently waiting for work. *)
146175
@@ -173,6 +202,25 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
173202 let sexp_of_task_queue q =
174203 Sexp. (List [ Atom " task_queue_of_size" ; Atom (Int. to_string @@ Queue. size q) ])
175204
205+ type event = Not_implemented_yet (* * TODO: NOT IMPLEMENTED YET *)
206+
207+ (* * TODO: Blocks till the event completes, if it's not done already. *)
208+ let await_ev Not_implemented_yet = ()
209+
210+ (* * TODO: Whether the event completed. *)
211+ let is_done Not_implemented_yet = true
212+
213+ (* * TODO: If the tensor node is in the context, returns the event indicating if currently running
214+ or scheduled computations modifying that node on the context's device have completed.
215+
216+ NOTE: [work_for ctx tn], if work tracking was not registered for [tn], will register work
217+ tracking for [tn] and return the event tracking all currently scheduled computations on
218+ [ctx]'s device. *)
219+ let work_for _ctx _tn = Some Not_implemented_yet
220+
221+ (* * TODO: Schedules waiting for the given event on the context's device. *)
222+ let will_wait_for _ctx Not_implemented_yet = ()
223+
176224 type device_state = {
177225 mutable keep_spinning : bool ;
178226 mutable device_error : exn option ;
@@ -222,6 +270,10 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
222270 Option. iter d.device_error ~f: (fun e ->
223271 Exn. reraise e @@ name ^ " device " ^ Int. to_string device.ordinal))
224272
273+ (* * TODO: Returns the event indicating if any currently running or scheduled computations on the
274+ device have completed. *)
275+ let all_work _device = Not_implemented_yet
276+
225277 let % track3_l_sexp schedule_task device task =
226278 assert (Domain. is_main_domain () );
227279 [% log_result " schedule_task" , Tnode. describe task, " device" , (device.ordinal : int )];
@@ -456,7 +508,7 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
456508
457509 let num_physical_devices () = Domain. recommended_domain_count () - 1
458510 let suggested_num_virtual_devices _device = 1
459- let devices = Array. create ~len: (num_physical_devices () ) None
511+ let devices : physical_device option array = Array. create ~len: (num_physical_devices () ) None
460512
461513 let % track2_sexp unsafe_cleanup () =
462514 assert (Domain. is_main_domain () );
@@ -497,8 +549,14 @@ let sync_suggested_num_virtual_devices = ref 1
497549
498550(* * A minimalistic wrapper creating backends where all calls run synchronously on the main thread.
499551 There is only one physical device, but an arbitrary number of virtual devices. *)
500- module Sync_backend (Backend : No_device_backend ) (* : Backend *) = struct
552+ module Sync_backend (Backend : No_device_backend ) : Backend = struct
501553 type buffer_ptr = Backend .buffer_ptr [@@ deriving sexp_of ]
554+ type event = unit
555+
556+ let await_ev () = ()
557+ let is_done () = true
558+ let work_for _context _tn = Some ()
559+ let will_wait_for _context () = ()
502560
503561 type device = {
504562 subordinal : int ;
@@ -516,10 +574,11 @@ module Sync_backend (Backend : No_device_backend) (* : Backend *) = struct
516574
517575 let expected_merge_node (code : code ) = Backend. expected_merge_node code
518576 let expected_merge_nodes (codes : code_batch ) = Backend. expected_merge_nodes codes
577+ let all_work _device = ()
519578 let is_idle _device = true
520579 let name = " sync " ^ Backend. name
521580 let await _device = ()
522- let global_run_no = ref 0
581+ (* let global_run_no = ref 0 *)
523582
524583 type context = { device : device ; ctx : Backend .context ; expected_merge_node : Tnode .t option }
525584 [@@ deriving sexp_of ]
@@ -934,6 +993,20 @@ module Cuda_backend : Backend = struct
934993 name;
935994 })) )
936995
996+ type event = Cudajit.Event .t
997+
998+ let work_for _ctx _tn = Some (Cudajit.Event. create () )
999+ (* TODO: NOT IMPLEMENTED YET *)
1000+
1001+ let is_done event = Cudajit.Event. query event
1002+ let will_wait_for _context _event = ()
1003+ (* Cudajit.Event.wait (get_ctx_device context.ctx).Cuda_backend.stream event *)
1004+ (* TODO: NOT IMPLEMENTED YET *)
1005+
1006+ let await_ev event = Cudajit.Event. synchronize event
1007+ let all_work _device = Cudajit.Event. create ()
1008+ (* TODO: NOT IMPLEMENTED YET *)
1009+
9371010 let init device = { ctx = init device; expected_merge_node = None }
9381011 let get_ctx_device context = get_ctx_device context.ctx
9391012 let finalize context = finalize context.ctx
0 commit comments