@@ -22,7 +22,7 @@ let check_merge_buffer stream ~code_node =
2222 ^ " , expected by code: " ^ name code_node)
2323
2424module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing ) = struct
25- let wait_for_all ctx streams tn =
25+ let [ @ landmark] wait_for_all ctx streams tn =
2626 let s = ctx.stream in
2727 Hashtbl. update_and_return streams tn
2828 ~f:
@@ -31,15 +31,15 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
3131 |> List. iter ~f: (fun (work_stream , e ) ->
3232 if not (equal_stream work_stream s) then Backend. will_wait_for ctx e)
3333
34- let wait_for_ready ~dst ~src tn =
34+ let [ @ landmark] wait_for_ready ~dst ~src tn =
3535 let s = src.stream in
3636 let d = dst.stream in
3737 (* TODO: maybe it's worthwhile to clean up s.updating_for every now and then. *)
3838 Hashtbl. find s.updating_for tn
3939 |> Option. iter ~f: (fun upd_e ->
4040 if not (equal_stream s d || Backend. is_done upd_e) then Backend. will_wait_for dst upd_e)
4141
42- let update_writer_event ?e ?from s tn =
42+ let [ @ landmark] update_writer_event ?e ?from s tn =
4343 let e = Option. value_or_thunk e ~default: (fun () -> Backend. all_work s) in
4444 let f l = (s, e) :: Option. value ~default: [] l in
4545 (match (from, tn) with
@@ -52,13 +52,14 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
5252 | Node tn ->
5353 if Tn. potentially_cross_stream tn then
5454 Hashtbl. update s.device.shared_writer_streams tn ~f: (fun l ->
55- (s, e) :: Option. value ~default: [] l);
55+ (s, e) :: Option. value ~default: [] l)
56+ else Hashtbl. remove s.device.shared_writer_streams tn;
5657 Hashtbl. update s.updating_for tn ~f: (fun _ -> e)
5758 | Merge_buffer tn ->
5859 (* Note: the previous event does not need to be done! *)
5960 s.updating_for_merge_buffer < - Some (tn, Some e)
6061
61- let % diagn2_l_sexp from_host (ctx : Backend.context ) tn =
62+ let % track2_l_sexp[ @ landmark] from_host (ctx : Backend.context ) tn =
6263 match (tn, Map. find ctx.ctx_arrays tn) with
6364 | { Tn. array = (lazy (Some hosted )); _ } , Some dst ->
6465 wait_for_all ctx ctx.stream.reader_streams tn;
@@ -68,7 +69,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
6869 true
6970 | _ -> false
7071
71- let % diagn2_l_sexp to_host (ctx : Backend.context ) (tn : Tn.t ) =
72+ let % track2_l_sexp[ @ landmark] to_host (ctx : Backend.context ) (tn : Tn.t ) =
7273 match (tn, Map. find ctx.ctx_arrays tn) with
7374 | { Tn. array = (lazy (Some hosted )); _ } , Some src ->
7475 if Tn. potentially_cross_stream tn then
@@ -82,8 +83,8 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
8283 true
8384 | _ -> false
8485
85- let % diagn2_l_sexp device_to_device (tn : Tn.t ) ~into_merge_buffer ~(dst : Backend.context )
86- ~(src : Backend.context ) =
86+ let % diagn2_l_sexp[ @ landmark] device_to_device (tn : Tn.t ) ~into_merge_buffer
87+ ~(dst : Backend.context ) ~( src : Backend.context ) =
8788 let ordinal_of ctx = ctx.stream.device.ordinal in
8889 let name_of ctx = Backend. (get_name ctx.stream) in
8990 let same_device = ordinal_of dst = ordinal_of src in
@@ -115,30 +116,40 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
115116 Backend. (
116117 device_to_device tn ~into_merge_buffer ~dst_ptr: None ~dst ~src_ptr: s_arr ~src );
117118 dst.stream.updating_for_merge_buffer < - Some (tn, None );
118- Task. run task;
119+ let [@ landmark] merge_task () = Task. run task in
120+ merge_task () ;
119121 update_writer_event ~from: (`Src src.stream) dst.stream @@ Merge_buffer tn;
120122 [% log " streaming into merge buffer" , Tn. debug_name tn, " from" , name_of src];
121123 true )
122124
123- let % track3_l_sexp sync_routine r =
125+ let % track2_l_sexp sync_routine r =
124126 let s = r.context.stream in
125- let pre () =
126- Hashtbl. filter_mapi_inplace s.device.shared_writer_streams ~f: (fun ~key ~data ->
127- if Tn. potentially_cross_stream key then
128- if Set. mem r.inputs key then (
129- let data = List. filter data ~f: (fun (_ , e ) -> Backend. is_done e) in
130- List. iter data ~f: (fun (work_stream , e ) ->
131- if not (equal_stream work_stream s) then Backend. will_wait_for r.context e);
132- Some data)
133- else Some data
134- else None )
127+ let [@ landmark] pre () =
128+ Set. iter r.inputs ~f: (fun tn ->
129+ if Tn. potentially_cross_stream tn then
130+ Option. iter (Hashtbl. find s.device.shared_writer_streams tn) ~f: (fun data ->
131+ let data = List. filter data ~f: (fun (_ , e ) -> not (Backend. is_done e)) in
132+ Hashtbl. set s.device.shared_writer_streams ~key: tn ~data ;
133+ List. iter data ~f: (fun (work_stream , e ) ->
134+ if not (equal_stream work_stream s) then Backend. will_wait_for r.context e))
135+ else Hashtbl. remove s.device.shared_writer_streams tn)
135136 (* Since merge buffers are always per-stream, no need to check r.merge_buffer_input. *)
136137 in
137- let post () =
138+ let [ @ landmark] post () =
138139 let e = Backend. all_work s in
139140 Set. iter r.outputs ~f: (fun tn -> update_writer_event ~e s @@ Node tn)
140141 in
141142 { r with schedule = Task. (prepend ~work: pre @@ append ~work: post r.schedule) }
143+
144+ let [@ landmark] sync_device device =
145+ Utils. weak_iter device.streams ~f: Backend. await;
146+ Hashtbl. clear device.host_writing_streams;
147+ Hashtbl. clear device.host_reading_streams;
148+ Hashtbl. clear device.shared_writer_streams;
149+ Utils. weak_iter device.streams ~f: (fun s ->
150+ Hashtbl. clear s.reader_streams;
151+ s.updating_for_merge_buffer < - None ;
152+ Hashtbl. clear s.updating_for)
142153end
143154
144155let lower_assignments ?name bindings asgns =
@@ -268,20 +279,20 @@ module Add_device
268279 in
269280 (Option. value_exn ~here: [% here] bindings, schedules)
270281
271- let from_host ~dst_ptr ~dst hosted =
282+ let [ @ landmark] from_host ~dst_ptr ~dst hosted =
272283 let work () = host_to_buffer hosted ~dst: dst_ptr in
273284 (* TODO: pass description to from_host. *)
274285 schedule_task dst.stream
275286 (Task. Task
276287 { context_lifetime = dst; description = " from_host on " ^ get_name dst.stream; work })
277288
278- let to_host ~src_ptr ~src hosted =
289+ let [ @ landmark] to_host ~src_ptr ~src hosted =
279290 let work () = buffer_to_host hosted ~src: src_ptr in
280291 (* TODO: pass description to to_host. *)
281292 schedule_task src.stream
282293 (Task. Task { context_lifetime = src; description = " to_host on " ^ get_name src.stream; work })
283294
284- let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
295+ let [ @ landmark] device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
285296 let s = dst.stream in
286297 let size_in_bytes = Tnode. size_in_bytes tn in
287298 let work =
@@ -468,7 +479,7 @@ let reinitialize (module Backend : Backend) config =
468479 Stdlib.Gc. full_major () ;
469480 Backend. initialize config)
470481
471- let % track3_sexp finalize (type buffer_ptr dev runner event)
482+ let [ @ landmark] finalize (type buffer_ptr dev runner event)
472483 (module Backend : Backend
473484 with type buffer_ptr = buffer_ptr
474485 and type dev = dev
0 commit comments