Skip to content

Commit aee62bc

Browse files
committed
Untested: synchronization for routines
1 parent 164e9ef commit aee62bc

File tree

4 files changed

+72
-25
lines changed

4 files changed

+72
-25
lines changed

arrayjit/lib/anatomy_of_a_backend.md

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,19 +141,38 @@ When using the default stream, CUDA would predictably write to the standard outp
141141

142142
## Synchronization and data transfers
143143

144-
OCANNL expects backends to implement FIFO queue scheduling, and an event mechanism for synchronizing between streams (and ideally devices), matching the CUDA specification. On top of events, OCANNL implements per-tensor-node synchronization, using the fields `reader_streams` and `writer_streams` of the device record, and `updating_for` of the stream record.
144+
OCANNL expects backends to implement FIFO queue scheduling, and an event mechanism for synchronizing between streams (and ideally devices), matching the CUDA specification. On top of events, OCANNL implements per-tensor-node synchronization. A third of the `device` record's fields deal with synchronization:
145+
146+
```ocaml
147+
mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option;
148+
(** The tensor node that was most recently scheduled to be in the cross-stream merge buffer,
149+
and its readiness event. *)
150+
shared_writer_streams :
151+
(('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
152+
(** The streams that most recently have been scheduled to update (write to) a
153+
cross-stream-shared node, and the associated update completion event. The completed events
154+
are removed opportunistically. *)
155+
host_reading_streams :
156+
(('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
157+
(** The streams that most recently have been reading from a node's on-host array. The
158+
completed events are removed opportunistically. *)
159+
host_writing_streams :
160+
(('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
161+
(** The streams that most recently have been writing to a node's on-host array. The completed
162+
events are removed opportunistically. *)
163+
```
164+
165+
and a third of the `stream` record's fields as well:
145166

146167
```ocaml
147-
...
148-
writer_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
149-
(** The streams that most recently have been scheduled to update (write to) the node, and the
150-
associated update completion event. The completed events are removed opportunistically. *)
151-
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
152-
(** The streams that most recently have been reading from the node, and the associated use
153-
completion events. The completed events are removed opportunistically. *)
154-
...
155168
updating_for : 'event Hashtbl.M(Tnode).t;
156169
(* The completion event for updating (writing to) a node via this stream, if any. *)
170+
mutable updating_for_merge_buffer : (Tnode.t * 'event) option;
171+
(** Like {!field-updating_for}, but for the merge buffer. *)
172+
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
173+
(** The streams, other than this stream, that most recently have been reading from a node in
174+
this stream's context, and the associated use completion events. The completed events are
175+
removed opportunistically. *)
157176
```
158177

159178
Besides routines, calling `from_host`, `to_host`, `device_to_device` from a backend puts the corresponding tasks on the device's queue. Both invoking a routine and calling these copying functions will perform the necessary event creations and synchronizations to ensure that when scheduling writing into an array precedes scheduling reading from it, the actual writing also precedes the actual reading.

arrayjit/lib/backend_intf.ml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,9 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device = {
8585
mutable shared_merge_buffer : 'buffer_ptr buffer option;
8686
(** Depending on backend implementations, either the currently used cross-stream merge buffer,
8787
or the one most recently scheduled. *)
88-
mutable scheduled_shared_merge_node : Tnode.t option;
89-
(** The tensor node that was most recently scheduled to be in the cross-stream merge buffer. *)
88+
mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option;
89+
(** The tensor node that was most recently scheduled to be in the cross-stream merge buffer,
90+
and its readiness event. *)
9091
mutable latest_stream_id : int;
9192
released : Utils.atomic_bool;
9293
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
@@ -123,7 +124,7 @@ and ('buffer_ptr, 'dev, 'runner, 'event) stream = {
123124
stream_id : int; (** An ID unique within the device. *)
124125
mutable allocated_buffer : 'buffer_ptr buffer option;
125126
updating_for : 'event Hashtbl.M(Tnode).t;
126-
(* The completion event for updating (writing to) a node via this stream, if any. *)
127+
(** The completion event for updating (writing to) a node via this stream, if any. *)
127128
mutable updating_for_merge_buffer : (Tnode.t * 'event) option;
128129
(** Like {!field-updating_for}, but for the merge buffer. *)
129130
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;

arrayjit/lib/backends.ml

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,13 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
3333
(** [wait_for_ready ~dst ~src tn] looks up the update event recorded for [tn] in the
    [updating_for] table of [src]'s stream and, if one is found, asks the backend to make [dst]
    wait for it via [Backend.will_wait_for]. Nothing is scheduled when [tn] has no recorded
    event, when [src] and [dst] are on the same stream, or when the event is already done
    according to [Backend.is_done]. *)
let wait_for_ready ~dst ~src tn =
3434
let s = src.stream in
3535
let d = dst.stream in
36+
(* TODO: maybe it's worthwhile to clean up s.updating_for every now and then. *)
3637
Hashtbl.find s.updating_for tn
3738
|> Option.iter ~f:(fun upd_e ->
3839
if not (equal_stream s d || Backend.is_done upd_e) then Backend.will_wait_for dst upd_e)
3940

40-
let update_writer_event ?from s tn =
41-
let e = Backend.all_work s in
41+
let update_writer_event ?e ?from s tn =
42+
let e = Option.value_or_thunk e ~default:(fun () -> Backend.all_work s) in
4243
let f l = (s, e) :: Option.value ~default:[] l in
4344
(match (from, tn) with
4445
| None, _ -> ()
@@ -102,15 +103,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
102103
device_to_device tn ~into_merge_buffer ~dst_ptr:(Some d_arr) ~dst ~src_ptr:s_arr
103104
~src);
104105
update_writer_event ~from:(`Src src.stream) dst.stream @@ Node tn;
105-
[%log
106-
"copying",
107-
Tn.debug_name tn,
108-
"from",
109-
name_of src,
110-
"at",
111-
(s_arr : Backend.buffer_ptr),
112-
"to",
113-
(d_arr : Backend.buffer_ptr)];
106+
[%log "copying", Tn.debug_name tn, "from", name_of src, "to", name_of dst];
114107
true)
115108
| Streaming when same_device ->
116109
Backend.(
@@ -123,6 +116,26 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
123116
update_writer_event ~from:(`Src src.stream) dst.stream @@ Merge_buffer tn;
124117
[%log "copying into merge buffer", Tn.debug_name tn, "from", name_of src];
125118
true)
119+
120+
(** [sync_routine r] returns [r] with its [schedule] wrapped in two extra steps: [pre] is
    prepended and [post] is appended, so both run as part of the routine's task rather than
    when [sync_routine] itself is called.

    - [pre] makes [r.context] wait for the events recorded in the device's
      [shared_writer_streams] for the nodes in [r.inputs], and, when [r.merge_buffer_input] is
      set, for the readiness event of the device's scheduled shared merge node.
    - [post] records a writer event on the routine's stream for every node in [r.outputs]. *)
let%track3_l_sexp sync_routine r =
121+
let s = r.context.stream in
122+
(* Wait on writers of this routine's inputs. Entries whose stream equals [s] are skipped
   -- presumably because ordering within one stream is already given by FIFO scheduling;
   NOTE(review): confirm. *)
let pre () =
123+
Hashtbl.iteri s.device.shared_writer_streams ~f:(fun ~key ~data ->
124+
if Set.mem r.inputs key then
125+
List.iter data ~f:(fun (work_stream, e) ->
126+
if not (equal_stream work_stream s) then Backend.will_wait_for r.context e));
127+
if r.merge_buffer_input then
128+
Option.iter s.device.scheduled_shared_merge_node ~f:(fun (shared_tn, e) ->
129+
(* Only wait when this stream's scheduled merge node is the same node as the device-wide
   shared one and a readiness event was recorded for it; every other combination is a
   no-op. *)
match (s.scheduled_merge_node, e) with
130+
| Some merge_tn, Some e ->
131+
if Tn.equal shared_tn merge_tn then Backend.will_wait_for r.context e
132+
| _ -> ())
133+
in
134+
(* A single event obtained from [Backend.all_work s] after the routine's own work is shared
   by all output nodes, instead of letting [update_writer_event] request one per node. *)
let post () =
135+
let e = Backend.all_work s in
136+
Set.iter r.outputs ~f:(fun tn -> update_writer_event ~e s @@ Node tn)
137+
in
138+
{ r with schedule = Task.(prepend ~work:pre @@ append ~work:post r.schedule) }
126139
end
127140

128141
let lower_assignments ?name bindings asgns =
@@ -397,7 +410,8 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
397410
Task.prepend schedule ~work:(fun () ->
398411
check_merge_buffer context.stream ~code_node:code.expected_merge_node)
399412
in
400-
{ context; schedule; bindings; name = code.name; inputs; merge_buffer_input; outputs }
413+
sync_routine
414+
{ context; schedule; bindings; name = code.name; inputs; merge_buffer_input; outputs }
401415

402416
let%debug3_sexp link_batch context code_batch =
403417
verify_prior_context ~use_host_memory ~ctx_arrays:context.ctx_arrays
@@ -423,7 +437,10 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
423437
Task.prepend schedule ~work:(fun () ->
424438
check_merge_buffer context.stream ~code_node:expected_merge_node)
425439
in
426-
(context, Some { context; schedule; bindings; name; inputs; merge_buffer_input; outputs }))
440+
let r =
441+
sync_routine { context; schedule; bindings; name; inputs; merge_buffer_input; outputs }
442+
in
443+
(context, Some r))
427444
end
428445

429446
module Cuda_backend : Backend = Raise_backend ((Cuda_backend : Lowered_backend))

arrayjit/lib/task.ml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,16 @@ let prepend ~work (Task task) =
2727
task.work ());
2828
}
2929

30+
(** [append ~work task] is [task] with its [work] replaced by a closure that first runs the
    original [task.work ()] and then [work ()]. All other fields of the task are kept as they
    are. If the original work raises, [work] is not run. *)
let append ~work (Task task) =
31+
Task
32+
{
33+
task with
34+
work =
35+
(fun () ->
36+
task.work ();
37+
work ());
38+
}
39+
3040
let%track3_l_sexp enschedule ~schedule_task ~get_stream_name stream
3141
(Task { description; _ } as task) =
3242
[%log_result "enschedule", description, "on", get_stream_name stream];

0 commit comments

Comments
 (0)