Skip to content

Commit 164e9ef

Browse files
committed
Proper syncing for from_host, to_host and device_to_device
1 parent b116838 commit 164e9ef

File tree

5 files changed

+63
-41
lines changed

5 files changed

+63
-41
lines changed

arrayjit/lib/backend_impl.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ struct
121121
stream_id;
122122
allocated_buffer = None;
123123
updating_for = Hashtbl.create (module Tnode);
124+
updating_for_merge_buffer = None;
124125
reader_streams = Hashtbl.create (module Tnode);
125126
}
126127

arrayjit/lib/backend_intf.ml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ type 'context routine = {
5757
inputs : Set.M(Tnode).t;
5858
(** The materialized read-only and read-before-write (within the routine) non-constant nodes.
5959
They are inputs in a broad sense, as they could be recurrent nodes or parameters. *)
60+
merge_buffer_input : bool; (** Similar to {!field-inputs}, for the merge buffer. *)
6061
outputs : Set.M(Tnode).t; (** All the materialized nodes written-to by the routine. *)
6162
}
6263
[@@deriving sexp_of]
@@ -123,6 +124,8 @@ and ('buffer_ptr, 'dev, 'runner, 'event) stream = {
123124
mutable allocated_buffer : 'buffer_ptr buffer option;
124125
updating_for : 'event Hashtbl.M(Tnode).t;
125126
(* The completion event for updating (writing to) a node via this stream, if any. *)
127+
mutable updating_for_merge_buffer : (Tnode.t * 'event) option;
128+
(** Like {!field-updating_for}, but for the merge buffer. *)
126129
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
127130
(** The streams, other than this stream, that most recently have been reading from a node in
128131
this stream's context, and the associated use completion events. The completed events are
@@ -274,7 +277,7 @@ module type With_buffer_retrieval_and_syncing = sig
274277
- If [into_merge_buffer=Streaming], remembers the buffer pointer of the source node to use for
275278
streaming.
276279
- If [into_merge_buffer=Copy], schedules copying from [src] to the merge buffer of [dst]'s
277-
stream, and registers [dst.stream] with a reader event for the node.
280+
stream, and updates the writer event for the merge buffer.
278281
279282
NOTE: If [into_merge_buffer=Streaming], after scheduling the work on [dst] using the merge
280283
buffer but before scheduling work on [src] that modifies [tn], execute

arrayjit/lib/backends.ml

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,32 +30,40 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
3030
|> List.iter ~f:(fun (work_stream, e) ->
3131
if not (equal_stream work_stream s) then Backend.will_wait_for ctx e)
3232

33-
let update_writer_event ?(from_host = false) s tn =
34-
let e = Backend.all_work s in
35-
if from_host then
36-
Hashtbl.update s.device.host_writing_streams tn ~f:(fun l ->
37-
(s, e) :: Option.value ~default:[] l);
38-
(* To be on the safe side, record events for potentially cross-stream nodes. *)
39-
if Tn.potentially_cross_stream tn then
40-
Hashtbl.update s.device.shared_writer_streams tn ~f:(fun l ->
41-
(s, e) :: Option.value ~default:[] l);
42-
Hashtbl.update s.updating_for tn ~f:(fun _ -> e)
43-
44-
let add_reader s tn from =
33+
let wait_for_ready ~dst ~src tn =
34+
let s = src.stream in
35+
let d = dst.stream in
36+
Hashtbl.find s.updating_for tn
37+
|> Option.iter ~f:(fun upd_e ->
38+
if not (equal_stream s d || Backend.is_done upd_e) then Backend.will_wait_for dst upd_e)
39+
40+
let update_writer_event ?from s tn =
4541
let e = Backend.all_work s in
4642
let f l = (s, e) :: Option.value ~default:[] l in
47-
match from with
48-
| `Host -> Hashtbl.update s.device.host_reading_streams tn ~f
49-
| `Src src -> Hashtbl.update src.reader_streams tn ~f
43+
(match (from, tn) with
44+
| None, _ -> ()
45+
| Some `Host, Assignments.(Node tn | Merge_buffer tn) ->
46+
Hashtbl.update s.device.host_reading_streams tn ~f
47+
| Some (`Src src), (Node tn | Merge_buffer tn) -> Hashtbl.update src.reader_streams tn ~f);
48+
(* To be on the safe side, record events for potentially cross-stream nodes. *)
49+
match tn with
50+
| Node tn ->
51+
if Tn.potentially_cross_stream tn then
52+
Hashtbl.update s.device.shared_writer_streams tn ~f:(fun l ->
53+
(s, e) :: Option.value ~default:[] l);
54+
Hashtbl.update s.updating_for tn ~f:(fun _ -> e)
55+
| Merge_buffer tn ->
56+
Option.iter s.updating_for_merge_buffer ~f:(fun (_old_tn, old_e) ->
57+
assert (Backend.is_done old_e));
58+
s.updating_for_merge_buffer <- Some (tn, e)
5059

5160
let%diagn2_l_sexp from_host (ctx : Backend.context) tn =
5261
match (tn, Map.find ctx.ctx_arrays tn) with
5362
| { Tn.array = (lazy (Some hosted)); _ }, Some dst ->
5463
wait_for_all ctx ctx.stream.reader_streams tn;
5564
[%log "copying", Tn.debug_name tn, "to", (dst : Backend.buffer_ptr), "from host"];
5665
Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
57-
update_writer_event ~from_host:true ctx.stream tn;
58-
add_reader ctx.stream tn @@ `Host;
66+
update_writer_event ~from:`Host ctx.stream @@ Node tn;
5967
true
6068
| _ -> false
6169

@@ -66,6 +74,10 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
6674
wait_for_all ctx ctx.stream.device.shared_writer_streams tn;
6775
[%log "copying", Tn.debug_name tn, "at", (src : Backend.buffer_ptr), "to host"];
6876
Backend.to_host ~src_ptr:src ~src:ctx hosted;
77+
let s = ctx.stream in
78+
let e = Backend.all_work s in
79+
Hashtbl.update s.device.host_writing_streams tn ~f:(fun l ->
80+
(s, e) :: Option.value ~default:[] l);
6981
true
7082
| _ -> false
7183

@@ -80,6 +92,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
8092
match Map.find src.ctx_arrays tn with
8193
| None -> false
8294
| Some s_arr -> (
95+
wait_for_ready ~dst ~src tn;
8396
match into_merge_buffer with
8497
| No -> (
8598
match Map.find dst.ctx_arrays tn with
@@ -88,8 +101,9 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
88101
Backend.(
89102
device_to_device tn ~into_merge_buffer ~dst_ptr:(Some d_arr) ~dst ~src_ptr:s_arr
90103
~src);
104+
update_writer_event ~from:(`Src src.stream) dst.stream @@ Node tn;
91105
[%log
92-
"copied",
106+
"copying",
93107
Tn.debug_name tn,
94108
"from",
95109
name_of src,
@@ -106,7 +120,8 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
106120
| Copy | Streaming ->
107121
Backend.(
108122
device_to_device tn ~into_merge_buffer ~dst_ptr:None ~dst ~src_ptr:s_arr ~src);
109-
[%log "copied into merge buffer", Tn.debug_name tn, "from", name_of src];
123+
update_writer_event ~from:(`Src src.stream) dst.stream @@ Merge_buffer tn;
124+
[%log "copying into merge buffer", Tn.debug_name tn, "from", name_of src];
110125
true)
111126
end
112127

@@ -371,7 +386,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
371386
let%debug3_sexp link context (code : code) =
372387
verify_prior_context ~use_host_memory ~ctx_arrays:context.ctx_arrays
373388
~from_prior_context:code.from_prior_context;
374-
let inputs, outputs = Low_level.input_and_output_nodes code.lowered in
389+
let (inputs, outputs), merge_buffer_input = Low_level.input_and_output_nodes code.lowered in
375390
let ctx_arrays =
376391
Hashtbl.fold code.lowered.traced_store ~init:context.ctx_arrays
377392
~f:(alloc_if_needed context.stream)
@@ -382,7 +397,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
382397
Task.prepend schedule ~work:(fun () ->
383398
check_merge_buffer context.stream ~code_node:code.expected_merge_node)
384399
in
385-
{ context; schedule; bindings; name = code.name; inputs; outputs }
400+
{ context; schedule; bindings; name = code.name; inputs; merge_buffer_input; outputs }
386401

387402
let%debug3_sexp link_batch context code_batch =
388403
verify_prior_context ~use_host_memory ~ctx_arrays:context.ctx_arrays
@@ -401,14 +416,14 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
401416
let ctx_arrays = Option.value_exn ctx_arrays.(i) in
402417
let context = make_child ~ctx_arrays context in
403418
let expected_merge_node = code_batch.expected_merge_nodes.(i) in
404-
let inputs, outputs =
419+
let (inputs, outputs), merge_buffer_input =
405420
Low_level.input_and_output_nodes @@ Option.value_exn code_batch.lowereds.(i)
406421
in
407422
let schedule =
408423
Task.prepend schedule ~work:(fun () ->
409424
check_merge_buffer context.stream ~code_node:expected_merge_node)
410425
in
411-
(context, Some { context; schedule; bindings; name; inputs; outputs }))
426+
(context, Some { context; schedule; bindings; name; inputs; merge_buffer_input; outputs }))
412427
end
413428

414429
module Cuda_backend : Backend = Raise_backend ((Cuda_backend : Lowered_backend))

arrayjit/lib/low_level.ml

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -738,22 +738,25 @@ type optimized = { traced_store : traced_store; llc : t; merge_node : Tn.t optio
738738
[@@deriving sexp_of]
739739

740740
let input_and_output_nodes optimized =
741-
Hashtbl.fold optimized.traced_store
742-
~init:(Set.empty (module Tn), Set.empty (module Tn))
743-
~f:(fun ~key ~data (inputs, outputs) ->
744-
let materialized = Tn.is_materialized_force key 50 in
745-
let inputs =
746-
if
747-
materialized && (not (Tn.known_constant key)) && (data.read_only || data.read_before_write)
748-
then Set.add inputs key
749-
else inputs
750-
in
751-
let outputs =
752-
if materialized && (data.zeroed_out || not (Hash_set.is_empty data.assignments)) then
753-
Set.add outputs key
754-
else outputs
755-
in
756-
(inputs, outputs))
741+
( Hashtbl.fold optimized.traced_store
742+
~init:(Set.empty (module Tn), Set.empty (module Tn))
743+
~f:(fun ~key ~data (inputs, outputs) ->
744+
let materialized = Tn.is_materialized_force key 50 in
745+
let inputs =
746+
if
747+
materialized
748+
&& (not (Tn.known_constant key))
749+
&& (data.read_only || data.read_before_write)
750+
then Set.add inputs key
751+
else inputs
752+
in
753+
let outputs =
754+
if materialized && (data.zeroed_out || not (Hash_set.is_empty data.assignments)) then
755+
Set.add outputs key
756+
else outputs
757+
in
758+
(inputs, outputs)),
759+
Option.is_some optimized.merge_node )
757760

758761
let%diagn2_sexp optimize_proc static_indices llc =
759762
let traced_store = Hashtbl.create (module Tnode) in

arrayjit/lib/low_level.mli

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ val optimize :
9999
t ->
100100
optimized
101101

102-
val input_and_output_nodes : optimized -> Set.M(Tnode).t * Set.M(Tnode).t
102+
val input_and_output_nodes : optimized -> (Set.M(Tnode).t * Set.M(Tnode).t) * bool
103103
(** Inputs are the materialized read-only and read-before-write (within the code) non-constant
104104
nodes. They are inputs in a broad sense, as they could be recurrent nodes or parameters.
105105

0 commit comments

Comments (0)