@@ -9,16 +9,16 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
99[%% global_debug_log_level 9 ]
1010[%% global_debug_log_level_from_env_var " OCANNL_LOG_LEVEL" ]
1111
12- let check_merge_buffer ~ scheduled_node ~code_node =
12+ let check_merge_buffer stream ~code_node =
1313 let name = function Some tn -> Tnode. debug_name tn | None -> " none" in
14- match (scheduled_node , code_node) with
14+ match (stream.scheduled_merge_node , code_node) with
1515 | _ , None -> ()
1616 | Some actual , Some expected when Tnode. equal actual expected -> ()
1717 | _ ->
1818 raise
1919 @@ Utils. User_error
20- (" Merge buffer mismatch, on stream: " ^ name scheduled_node ^ " , expected by code: "
21- ^ name code_node)
20+ (" Merge buffer mismatch, on stream: " ^ name stream.scheduled_merge_node
21+ ^ " , expected by code: " ^ name code_node)
2222
2323module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing ) = struct
2424 let work_for context tn =
@@ -177,11 +177,17 @@ module Add_device
177177 let link context (code : code ) ctx_arrays =
178178 let runner_label = get_name context.stream in
179179 let merge_buffer = context.stream.merge_buffer in
180- match code with
181- | Postponed { lowered; bindings; name } ->
182- let proc = Backend. compile ~name ~opt_ctx_arrays: (Some ctx_arrays) bindings lowered in
183- link_compiled ~merge_buffer ~runner_label ctx_arrays proc
184- | Compiled { proc; _ } -> link_compiled ~merge_buffer ~runner_label ctx_arrays proc
180+ let bindings, to_schedule =
181+ match code with
182+ | Postponed { lowered; bindings; name } ->
183+ let proc = Backend. compile ~name ~opt_ctx_arrays: (Some ctx_arrays) bindings lowered in
184+ link_compiled ~merge_buffer ~runner_label ctx_arrays proc
185+ | Compiled { proc; _ } -> link_compiled ~merge_buffer ~runner_label ctx_arrays proc
186+ in
187+ let schedule =
188+ Task. enschedule ~schedule_task ~get_stream_name: get_name context.stream to_schedule
189+ in
190+ (bindings, schedule)
185191
186192 let link_batch context (code_batch : code_batch ) ctx_arrays =
187193 let runner_label = get_name context.stream in
@@ -196,8 +202,13 @@ module Add_device
196202 Array. fold_mapi procs ~init: None ~f: (fun i bindings -> function
197203 | Some proc ->
198204 let ctx_arrays = Option. value_exn ctx_arrays.(i) in
199- let bindings', schedule = link_compiled ~merge_buffer ~runner_label ctx_arrays proc in
205+ let bindings', to_schedule =
206+ link_compiled ~merge_buffer ~runner_label ctx_arrays proc
207+ in
200208 Option. iter bindings ~f: (fun bindings -> assert (phys_equal bindings bindings'));
209+ let schedule =
210+ Task. enschedule ~schedule_task ~get_stream_name: get_name context.stream to_schedule
211+ in
201212 (Some bindings', Some schedule)
202213 | None -> (bindings, None ))
203214 in
@@ -219,12 +230,13 @@ module Add_device
219230 let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
220231 let dev = dst.stream in
221232 let size_in_bytes = Tnode. size_in_bytes tn in
233+ (* FIXME(#290): handle shared_merge_node. *)
222234 let work =
223- (* TODO: log the operation if [Utils.settings.with_log_level > 0 ]. *)
235+ (* TODO: log the operation if [Utils.settings.with_log_level > 1 ]. *)
224236 match (into_merge_buffer, dst_ptr) with
225237 | No , None -> invalid_arg " Multicore_scheduler.device_to_device: missing dst_ptr"
226238 | No , Some dst_ptr -> fun () -> buffer_to_buffer ~dst: dst_ptr ~src: src_ptr ~size_in_bytes
227- | Streaming , _ -> fun () -> dev.merge_buffer := Some ( src_ptr, tn)
239+ | Streaming , _ -> fun () -> dev.merge_buffer := Some { ptr = src_ptr; size_in_bytes }
228240 | Copy , _ ->
229241 fun () ->
230242 let size_in_bytes = Tnode. size_in_bytes tn in
@@ -235,13 +247,14 @@ module Add_device
235247 dev.allocated_buffer < -
236248 Some (alloc_buffer ?old_buffer:dev.allocated_buffer ~size_in_bytes dst.stream );
237249 let merge_ptr = (Option. value_exn dev.allocated_buffer).ptr in
238- dev.merge_buffer := Some (merge_ptr, tn) ;
250+ dev.merge_buffer := dev.allocated_buffer ;
239251 buffer_to_buffer ~dst: merge_ptr ~src: src_ptr ~size_in_bytes
240252 in
241253 let description =
242254 " device_to_device " ^ Tnode. debug_name tn ^ " dst " ^ get_name dev ^ " src "
243255 ^ get_name src.stream
244256 in
257+ (match into_merge_buffer with No -> () | _ -> dev.scheduled_merge_node < - Some tn);
245258 schedule_task dev (Task. Task { context_lifetime = (src, dst); description; work })
246259end
247260
@@ -344,9 +357,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
344357 let context = make_child ~ctx_arrays context in
345358 let schedule =
346359 Task. prepend schedule ~work: (fun () ->
347- check_merge_buffer
348- ~scheduled_node: (scheduled_merge_node context.stream)
349- ~code_node: code.expected_merge_node)
360+ check_merge_buffer context.stream ~code_node: code.expected_merge_node)
350361 in
351362 { context; schedule; bindings; name = code.name; inputs; outputs }
352363
@@ -372,9 +383,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
372383 in
373384 let schedule =
374385 Task. prepend schedule ~work: (fun () ->
375- check_merge_buffer
376- ~scheduled_node: (scheduled_merge_node context.stream)
377- ~code_node: expected_merge_node)
386+ check_merge_buffer context.stream ~code_node: expected_merge_node)
378387 in
379388 (context, Some { context; schedule; bindings; name; inputs; outputs }))
380389end
0 commit comments