Commit 112d458

Specification of device-to-device synchronization via API and doc-comments

1 parent 8ad8054 commit 112d458

File tree

3 files changed: +106 -31 lines changed

CHANGES.md

Lines changed: 1 addition & 2 deletions

@@ -2,8 +2,7 @@
 
 ### Added
 
-- TODO: (Virtual) device-to-device synchronization functionality.
-- TODO: lazy per-tensor-node synchronization functionality.
+- TODO: (Virtual) device-to-device synchronization functionality, with lazy per-tensor-node synchronization.
 
 ### Changed
 

arrayjit/lib/backends.ml

Lines changed: 102 additions & 29 deletions
@@ -98,39 +98,64 @@ module type Backend = sig
   (** Returns the routines for the procedures included in the code batch. The returned context is
       downstream of all the returned routines. *)
 
-  val from_host : context -> Tnode.t -> bool
-  (** If the array is both hosted and in-context, schedules a copy from host to context and returns
-      true, otherwise returns false. NOTE: when run for a device, it's the caller's responsibility
-      to synchronize the device before the host's data is overwritten. *)
+  type event
+  (** An event tracks whether a device finished computing past a particular point in its schedule.
+      These values are used internally for scheduling across devices of the backend, and can also
+      be used for explicit scheduling. *)
 
-  val to_host : context -> Tnode.t -> bool
-  (** If the array is both hosted and in-context, schedules a copy from context to host and returns
-      true, otherwise returns false. NOTE: when run for a device, it's the caller's responsibility
-      to synchronize the device before the host's data is read. *)
+  val await_ev : event -> unit
+  (** Blocks till the event completes, if it's not done already. *)
 
-  val device_to_device :
-    Tnode.t -> into_merge_buffer:merge_buffer_use -> dst:context -> src:context -> bool
-  (** If the node is absent from the [src] context and either it is present in the [dst] context or
-      [~into_merge_buffer] is different from [No]: raises an error.
+  val is_done : event -> bool
+  (** Whether the event completed. *)
+
+  val work_for : context -> Tnode.t -> event option
+  (** If the tensor node is in the context, returns the event indicating whether the currently
+      running or scheduled computations modifying that node on the context's device have completed.
+
+      NOTE: if work tracking was not yet registered for [tn], [work_for ctx tn] registers work
+      tracking for [tn] and returns the event tracking all computations currently scheduled on
+      [ctx]'s device. *)
 
-      If [~into_merge_buffer:No]: If the node is present in the [dst] context, schedules a copy of
-      the tensor node from the device of [src] to the device of [dst] and returns true, otherwise
-      returns false.
+  val will_wait_for : context -> event -> unit
+  (** Schedules waiting for the given event on the context's device.
 
-      If [~into_merge_buffer] is different from [No]: schedules the following task and returns true.
+      NOTE: it should rarely be necessary to call [will_wait_for] explicitly, because it is
+      typically called internally when needed. The one exception is {!device_to_device} with
+      [into_merge_buffer=Streaming]. *)
 
-      The merge-buffer task sets on [dst] the merge buffer source to the given node. If
-      [~into_merge_buffer:Streaming], remembers the buffer pointer of the source node to use for
-      streaming, without blocking. If [~into_merge_buffer:Copy], copies from [src] to the merge
-      buffer of [dst]'s device.
+  val from_host : context -> Tnode.t -> bool
+  (** If the tensor node is both hosted and in-context, schedules a copy from host to context and
+      returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
+      the device (via [await ctx.device] or [await_ev (work_for ctx tn)]) before the host's data
+      is overwritten. *)
 
-      If the [dst] context resulted from a compilation with [Streaming] or [Copy] specific merge
-      buffer code, the [device_to_device] call should fail immediately if there's a mismatch with
-      [~into_merge_buffer].
+  val to_host : context -> Tnode.t -> bool
+  (** If the tensor node is both hosted and in-context, schedules a copy from context to host and
+      returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
+      the device (via [await ctx.device] or [await_ev (work_for ctx tn)]) before the host's data
+      is read. *)
 
-      NOTE: it's the caller's responsibility to synchronize the [src] device, if needed, {i before}
-      calling [device_to_device], and if [~into_merge_buffer:Streaming], the [dst] device
-      {i afterward}, before any computations on the [src] device overwrite the node. *)
+  val device_to_device :
+    Tnode.t -> into_merge_buffer:merge_buffer_use -> dst:context -> src:context -> bool
+  (** [device_to_device tn ~into_merge_buffer ~dst ~src] proceeds as follows:
+      - If the node is absent from the [src] context, and either it is present in the [dst]
+        context or [into_merge_buffer] is different from [No]: raises an error.
+      - If the node is absent from [dst] and [into_merge_buffer=No]: returns false.
+      - Executes [will_wait_for dst (work_for src tn)].
+      - If [into_merge_buffer=No]: schedules a copy of the tensor node from the device of [src]
+        to the device of [dst].
+      - If [into_merge_buffer] is different from [No]: sets on [dst] the merge buffer source to
+        the given node. If [into_merge_buffer=Streaming], remembers the buffer pointer of the
+        source node to use for streaming, without blocking. If [into_merge_buffer=Copy],
+        schedules copying from [src] to the merge buffer of [dst]'s device.
+      - If the [dst] context resulted from a compilation with [Streaming] or [Copy] specific
+        merge buffer code, the [device_to_device] call should fail immediately if there's a
+        mismatch with [into_merge_buffer].
+
+      NOTE: if [into_merge_buffer=Streaming], after scheduling the work on [dst] that uses the
+      merge buffer, but before scheduling work on [src] that modifies [tn], execute
+      [will_wait_for src (all_work (get_ctx_device dst))]. *)
 
   type physical_device
   type device
@@ -141,6 +166,10 @@ module type Backend = sig
   val await : device -> unit
   (** Blocks till the device becomes idle, i.e. synchronizes the device. *)
 
+  val all_work : device -> event
+  (** Returns the event indicating whether all currently running or scheduled computations on the
+      device have completed. *)
+
   val is_idle : device -> bool
   (** Whether the device is currently waiting for work. *)

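As a reading aid for the signature entries above, here is a minimal, fully synchronous sketch of the event API, in the spirit of the `Sync_backend` instantiation later in this commit. The `context` and `device` records, and the use of strings in place of `Tnode.t`, are hypothetical stand-ins, not the library's types:

```ocaml
(* Toy synchronous model of the event API: every event is already complete,
   because all "scheduled" work runs immediately on the main thread. *)
type event = unit

let await_ev () = ()   (* blocks till the event completes: trivially done *)
let is_done () = true  (* a synchronous event is always complete *)

type device = { ordinal : int }
type context = { device : device }

let work_for (_ctx : context) (_tn : string) : event option = Some ()
let will_wait_for (_ctx : context) (_ev : event) = ()
let all_work (_dev : device) : event = ()

let () =
  let ctx = { device = { ordinal = 0 } } in
  (* Caller-side synchronization before reading host data, as the [to_host]
     doc-comment prescribes: *)
  (match work_for ctx "tn1" with Some e -> await_ev e | None -> ());
  will_wait_for ctx (all_work ctx.device);
  assert (is_done (all_work ctx.device));
  print_endline "synchronized"
```

A real backend would back `event` with driver-level completion markers; the point here is only the call pattern `await_ev (work_for ctx tn)` versus the coarser `await device`.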
@@ -173,6 +202,25 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
   let sexp_of_task_queue q =
     Sexp.(List [ Atom "task_queue_of_size"; Atom (Int.to_string @@ Queue.size q) ])
 
+  type event = Not_implemented_yet  (** TODO: NOT IMPLEMENTED YET *)
+
+  (** TODO: Blocks till the event completes, if it's not done already. *)
+  let await_ev Not_implemented_yet = ()
+
+  (** TODO: Whether the event completed. *)
+  let is_done Not_implemented_yet = true
+
+  (** TODO: If the tensor node is in the context, returns the event indicating whether the
+      currently running or scheduled computations modifying that node on the context's device
+      have completed.
+
+      NOTE: if work tracking was not yet registered for [tn], [work_for ctx tn] registers work
+      tracking for [tn] and returns the event tracking all computations currently scheduled on
+      [ctx]'s device. *)
+  let work_for _ctx _tn = Some Not_implemented_yet
+
+  (** TODO: Schedules waiting for the given event on the context's device. *)
+  let will_wait_for _ctx Not_implemented_yet = ()
+
   type device_state = {
     mutable keep_spinning : bool;
     mutable device_error : exn option;
@@ -222,6 +270,10 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
         Option.iter d.device_error ~f:(fun e ->
             Exn.reraise e @@ name ^ " device " ^ Int.to_string device.ordinal))
 
+  (** TODO: Returns the event indicating whether all currently running or scheduled computations
+      on the device have completed. *)
+  let all_work _device = Not_implemented_yet
+
   let%track3_l_sexp schedule_task device task =
     assert (Domain.is_main_domain ());
     [%log_result "schedule_task", Tnode.describe task, "device", (device.ordinal : int)];
@@ -456,7 +508,7 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
 
   let num_physical_devices () = Domain.recommended_domain_count () - 1
   let suggested_num_virtual_devices _device = 1
-  let devices = Array.create ~len:(num_physical_devices ()) None
+  let devices : physical_device option array = Array.create ~len:(num_physical_devices ()) None
 
   let%track2_sexp unsafe_cleanup () =
     assert (Domain.is_main_domain ());
@@ -497,8 +549,14 @@ let sync_suggested_num_virtual_devices = ref 1
 
 (** A minimalistic wrapper creating backends where all calls run synchronously on the main thread.
     There is only one physical device, but an arbitrary number of virtual devices. *)
-module Sync_backend (Backend : No_device_backend) (* : Backend *) = struct
+module Sync_backend (Backend : No_device_backend) : Backend = struct
   type buffer_ptr = Backend.buffer_ptr [@@deriving sexp_of]
+  type event = unit
+
+  let await_ev () = ()
+  let is_done () = true
+  let work_for _context _tn = Some ()
+  let will_wait_for _context () = ()
 
   type device = {
     subordinal : int;
@@ -516,10 +574,11 @@ module Sync_backend (Backend : No_device_backend) (* : Backend *) = struct
 
   let expected_merge_node (code : code) = Backend.expected_merge_node code
   let expected_merge_nodes (codes : code_batch) = Backend.expected_merge_nodes codes
+  let all_work _device = ()
   let is_idle _device = true
   let name = "sync " ^ Backend.name
   let await _device = ()
-  let global_run_no = ref 0
+  (* let global_run_no = ref 0 *)
 
   type context = { device : device; ctx : Backend.context; expected_merge_node : Tnode.t option }
   [@@deriving sexp_of]
@@ -934,6 +993,20 @@ module Cuda_backend : Backend = struct
             name;
           })) )
 
+  type event = Cudajit.Event.t
+
+  (* TODO: NOT IMPLEMENTED YET *)
+  let work_for _ctx _tn = Some (Cudajit.Event.create ())
+
+  let is_done event = Cudajit.Event.query event
+
+  (* TODO: NOT IMPLEMENTED YET *)
+  let will_wait_for _context _event = ()
+  (* Cudajit.Event.wait (get_ctx_device context.ctx).Cuda_backend.stream event *)
+
+  let await_ev event = Cudajit.Event.synchronize event
+
+  (* TODO: NOT IMPLEMENTED YET *)
+  let all_work _device = Cudajit.Event.create ()
+
   let init device = { ctx = init device; expected_merge_node = None }
   let get_ctx_device context = get_ctx_device context.ctx
   let finalize context = finalize context.ctx
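The `Streaming` ordering requirement from the `device_to_device` doc-comment can be made concrete with a toy model that uses completion flags in place of real driver events (a real backend would use e.g. CUDA events); all names below are illustrative, not the library's API:

```ocaml
(* Toy events: a mutable completion flag per synchronization point. *)
type event = { mutable fired : bool }

let record () = { fired = false }
let complete e = e.fired <- true
let is_done e = e.fired

let () =
  (* 1. device_to_device makes dst wait for src's pending writes to tn. *)
  let src_writes_tn = record () in
  complete src_writes_tn;          (* src finishes writing tn *)
  assert (is_done src_writes_tn);  (* dst may now read tn via src's buffer *)
  (* 2. dst's consumers stream from src's buffer; dst's work completes. *)
  let dst_all_work = record () in
  complete dst_all_work;
  (* 3. Only now may src schedule work that overwrites tn, per the NOTE:
     will_wait_for src (all_work (get_ctx_device dst)). *)
  assert (is_done dst_all_work);
  print_endline "streaming order respected"
```

The asymmetry is the point: with `Copy` the data is duplicated into `dst`'s merge buffer and `src` is free immediately, while with `Streaming` `src`'s buffer stays live until `dst`'s consumers finish, so the caller must insert the reverse wait.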

arrayjit/lib/writing_a_backend.md

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@
 
 <!-- /TOC -->
 
+NOTE: these are outdated.
+TODO: update regarding events and device-to-device synchronization.
+
 ## Design around compiling and running code, backend interfaces
 
Currently, OCANNL integrates new backends via code in [Backends](backends.ml), so it's the "sink" of backend module dependencies; [Backend_utils](backend_utils.ml) is the "source". `Backend_utils.Types` introduces the context-specific `routine` type, for code executable on a backend. The interface `Backends.No_device_backend` has `compile` functions that take `Assignments.t` as input, to allow full flexibility in backend implementations. There is a helper `Backends.lower_assignments` that wraps `Assignments.lower` and `Low_level.optimize_proc`, since currently all backends use the optimized C-like representation `Low_level.t`. The user-facing interface `Backends.Backend` builds on top of `No_device_backend`, providing multi-device functionality. The functor `Multicore_backend` converts a `No_device_backend` targeting the CPU into a `Backend` whose devices are parallel threads (and ultimately the CPU cores).
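The functor pattern described above can be sketched with toy signatures; these are simplified stand-ins for `No_device_backend` and `Backend`, not the real interfaces:

```ocaml
(* A device-less base backend: just runs code, no notion of devices. *)
module type No_device_backend = sig
  val name : string
  val run : string -> unit
end

(* The multi-device interface layered on top of it. *)
module type Backend = sig
  include No_device_backend
  type device
  val get_device : ordinal:int -> device
  val run_on : device -> string -> unit
end

(* The functor lifting a base backend into a multi-device one. *)
module Multicore (B : No_device_backend) : Backend = struct
  let name = "multicore " ^ B.name
  let run = B.run

  type device = { ordinal : int }

  let get_device ~ordinal = { ordinal }

  (* A real implementation would enqueue the task on the device's worker
     domain; this sketch runs it synchronously. *)
  let run_on _device code = B.run code
end

module Toy = Multicore (struct
  let name = "toy"
  let run code = print_endline ("running: " ^ code)
end)

let () =
  let d = Toy.get_device ~ordinal:0 in
  Toy.run_on d "kernel";
  print_endline Toy.name
```

This mirrors how `Multicore_backend` (and `Sync_backend`) wrap a CPU-targeting `No_device_backend` without that base backend knowing anything about devices or scheduling.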
