@@ -30,32 +30,40 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
3030 |> List. iter ~f: (fun (work_stream , e ) ->
3131 if not (equal_stream work_stream s) then Backend. will_wait_for ctx e)
3232
33- let update_writer_event ?(from_host = false ) s tn =
34- let e = Backend. all_work s in
35- if from_host then
36- Hashtbl. update s.device.host_writing_streams tn ~f: (fun l ->
37- (s, e) :: Option. value ~default: [] l);
38- (* To be on the safe side, record events for potentially cross-stream nodes. *)
39- if Tn. potentially_cross_stream tn then
40- Hashtbl. update s.device.shared_writer_streams tn ~f: (fun l ->
41- (s, e) :: Option. value ~default: [] l);
42- Hashtbl. update s.updating_for tn ~f: (fun _ -> e)
43-
44- let add_reader s tn from =
33+ let wait_for_ready ~dst ~src tn =
34+ let s = src.stream in
35+ let d = dst.stream in
36+ Hashtbl. find s.updating_for tn
37+ |> Option. iter ~f: (fun upd_e ->
38+ if not (equal_stream s d || Backend. is_done upd_e) then Backend. will_wait_for dst upd_e)
39+
40+ let update_writer_event ?from s tn =
4541 let e = Backend. all_work s in
4642 let f l = (s, e) :: Option. value ~default: [] l in
47- match from with
48- | `Host -> Hashtbl. update s.device.host_reading_streams tn ~f
49- | `Src src -> Hashtbl. update src.reader_streams tn ~f
43+ (match (from, tn) with
44+ | None , _ -> ()
45+ | Some `Host , Assignments. (Node tn | Merge_buffer tn ) ->
46+ Hashtbl. update s.device.host_reading_streams tn ~f
47+ | Some (`Src src ), (Node tn | Merge_buffer tn ) -> Hashtbl. update src.reader_streams tn ~f );
48+ (* To be on the safe side, record events for potentially cross-stream nodes. *)
49+ match tn with
50+ | Node tn ->
51+ if Tn. potentially_cross_stream tn then
52+ Hashtbl. update s.device.shared_writer_streams tn ~f: (fun l ->
53+ (s, e) :: Option. value ~default: [] l);
54+ Hashtbl. update s.updating_for tn ~f: (fun _ -> e)
55+ | Merge_buffer tn ->
56+ Option. iter s.updating_for_merge_buffer ~f: (fun (_old_tn , old_e ) ->
57+ assert (Backend. is_done old_e));
58+ s.updating_for_merge_buffer < - Some (tn, e)
5059
5160 let % diagn2_l_sexp from_host (ctx : Backend.context ) tn =
5261 match (tn, Map. find ctx.ctx_arrays tn) with
5362 | { Tn. array = (lazy (Some hosted )); _ } , Some dst ->
5463 wait_for_all ctx ctx.stream.reader_streams tn;
5564 [% log " copying" , Tn. debug_name tn, " to" , (dst : Backend.buffer_ptr ), " from host" ];
5665 Backend. from_host ~dst_ptr: dst ~dst: ctx hosted;
57- update_writer_event ~from_host: true ctx.stream tn;
58- add_reader ctx.stream tn @@ `Host ;
66+ update_writer_event ~from: `Host ctx.stream @@ Node tn;
5967 true
6068 | _ -> false
6169
@@ -66,6 +74,10 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
6674 wait_for_all ctx ctx.stream.device.shared_writer_streams tn;
6775 [% log " copying" , Tn. debug_name tn, " at" , (src : Backend.buffer_ptr ), " to host" ];
6876 Backend. to_host ~src_ptr: src ~src: ctx hosted;
77+ let s = ctx.stream in
78+ let e = Backend. all_work s in
79+ Hashtbl. update s.device.host_writing_streams tn ~f: (fun l ->
80+ (s, e) :: Option. value ~default: [] l);
6981 true
7082 | _ -> false
7183
@@ -80,6 +92,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
8092 match Map. find src.ctx_arrays tn with
8193 | None -> false
8294 | Some s_arr -> (
95+ wait_for_ready ~dst ~src tn;
8396 match into_merge_buffer with
8497 | No -> (
8598 match Map. find dst.ctx_arrays tn with
@@ -88,8 +101,9 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
88101 Backend. (
89102 device_to_device tn ~into_merge_buffer ~dst_ptr: (Some d_arr) ~dst ~src_ptr: s_arr
90103 ~src );
104+ update_writer_event ~from: (`Src src.stream) dst.stream @@ Node tn;
91105 [% log
92- " copied " ,
106+ " copying " ,
93107 Tn. debug_name tn,
94108 " from" ,
95109 name_of src,
@@ -106,7 +120,8 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
106120 | Copy | Streaming ->
107121 Backend. (
108122 device_to_device tn ~into_merge_buffer ~dst_ptr: None ~dst ~src_ptr: s_arr ~src );
109- [% log " copied into merge buffer" , Tn. debug_name tn, " from" , name_of src];
123+ update_writer_event ~from: (`Src src.stream) dst.stream @@ Merge_buffer tn;
124+ [% log " copying into merge buffer" , Tn. debug_name tn, " from" , name_of src];
110125 true )
111126end
112127
@@ -371,7 +386,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
371386 let % debug3_sexp link context (code : code ) =
372387 verify_prior_context ~use_host_memory ~ctx_arrays: context.ctx_arrays
373388 ~from_prior_context: code.from_prior_context;
374- let inputs, outputs = Low_level. input_and_output_nodes code.lowered in
389+ let ( inputs, outputs), merge_buffer_input = Low_level. input_and_output_nodes code.lowered in
375390 let ctx_arrays =
376391 Hashtbl. fold code.lowered.traced_store ~init: context.ctx_arrays
377392 ~f: (alloc_if_needed context.stream)
@@ -382,7 +397,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
382397 Task. prepend schedule ~work: (fun () ->
383398 check_merge_buffer context.stream ~code_node: code.expected_merge_node)
384399 in
385- { context; schedule; bindings; name = code.name; inputs; outputs }
400+ { context; schedule; bindings; name = code.name; inputs; merge_buffer_input; outputs }
386401
387402 let % debug3_sexp link_batch context code_batch =
388403 verify_prior_context ~use_host_memory ~ctx_arrays: context.ctx_arrays
@@ -401,14 +416,14 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
401416 let ctx_arrays = Option. value_exn ctx_arrays.(i) in
402417 let context = make_child ~ctx_arrays context in
403418 let expected_merge_node = code_batch.expected_merge_nodes.(i) in
404- let inputs, outputs =
419+ let ( inputs, outputs), merge_buffer_input =
405420 Low_level. input_and_output_nodes @@ Option. value_exn code_batch.lowereds.(i)
406421 in
407422 let schedule =
408423 Task. prepend schedule ~work: (fun () ->
409424 check_merge_buffer context.stream ~code_node: expected_merge_node)
410425 in
411- (context, Some { context; schedule; bindings; name; inputs; outputs }))
426+ (context, Some { context; schedule; bindings; name; inputs; merge_buffer_input; outputs }))
412427end
413428
414429module Cuda_backend : Backend = Raise_backend ((Cuda_backend : Lowered_backend ))
0 commit comments