@@ -373,12 +373,22 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
373373 Option. (join @@ map lowered ~f: (fun optim -> optim.Low_level. merge_node)));
374374 }
375375
376- let % track3_sexp alloc_if_needed (stream : stream ) ~key ~data: node ctx_arrays =
376+ let % track3_sexp alloc_if_needed parent_context ~key ~data: node ctx_arrays =
377377 if Tnode. is_in_context_force ~use_host_memory key 43 && not (Map. mem ctx_arrays key) then (
378+ let stream = parent_context.stream in
378379 [% log Tn. debug_name key];
379380 [% log (key : Tnode.t )];
380381 let default () =
381- alloc_zero_init_array (Lazy. force key.prec) ~dims: (Lazy. force key.dims) stream
382+ let dst_ptr =
383+ alloc_zero_init_array (Lazy. force key.prec) ~dims: (Lazy. force key.dims) stream
384+ in
385+ (if Utils. settings.automatic_host_transfers && Tn. known_constant key then
386+ match key.array with
387+ | (lazy (Some hosted )) ->
388+ Device. from_host ~dst_ptr ~dst: parent_context hosted;
389+ key.host_modified < - false
390+ | _ -> () );
391+ dst_ptr
382392 in
383393 let add_new () = Map. add_exn ctx_arrays ~key ~data: (default () ) in
384394 let device = stream.device in
@@ -423,8 +433,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
423433 ~from_prior_context: code.from_prior_context;
424434 let (inputs, outputs), merge_buffer_input = Low_level. input_and_output_nodes code.lowered in
425435 let ctx_arrays =
426- Hashtbl. fold code.lowered.traced_store ~init: context.ctx_arrays
427- ~f: (alloc_if_needed context.stream)
436+ Hashtbl. fold code.lowered.traced_store ~init: context.ctx_arrays ~f: (alloc_if_needed context)
428437 in
429438 let bindings, schedule = link context code.code ctx_arrays in
430439 let context = make_child ~ctx_arrays context in
@@ -443,7 +452,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
443452 ~f:
444453 (Option. map ~f: (fun l ->
445454 Hashtbl. fold l.Low_level. traced_store ~init: context.ctx_arrays
446- ~f: (alloc_if_needed context.stream )))
455+ ~f: (alloc_if_needed context)))
447456 in
448457 let bindings, schedules = link_batch context code_batch.code_batch ctx_arrays in
449458 Array. fold_mapi schedules ~init: context ~f: (fun i context -> function
0 commit comments