Skip to content

Commit a91751b

Browse files
committed
Fix auto transfer for constants
Note: auto transfers currently don't handle multi-device, will need fixing.
1 parent a58eabb commit a91751b

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

arrayjit/lib/backends.ml

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -373,12 +373,22 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
373373
Option.(join @@ map lowered ~f:(fun optim -> optim.Low_level.merge_node)));
374374
}
375375

376-
let%track3_sexp alloc_if_needed (stream : stream) ~key ~data:node ctx_arrays =
376+
let%track3_sexp alloc_if_needed parent_context ~key ~data:node ctx_arrays =
377377
if Tnode.is_in_context_force ~use_host_memory key 43 && not (Map.mem ctx_arrays key) then (
378+
let stream = parent_context.stream in
378379
[%log Tn.debug_name key];
379380
[%log (key : Tnode.t)];
380381
let default () =
381-
alloc_zero_init_array (Lazy.force key.prec) ~dims:(Lazy.force key.dims) stream
382+
let dst_ptr =
383+
alloc_zero_init_array (Lazy.force key.prec) ~dims:(Lazy.force key.dims) stream
384+
in
385+
(if Utils.settings.automatic_host_transfers && Tn.known_constant key then
386+
match key.array with
387+
| (lazy (Some hosted)) ->
388+
Device.from_host ~dst_ptr ~dst:parent_context hosted;
389+
key.host_modified <- false
390+
| _ -> ());
391+
dst_ptr
382392
in
383393
let add_new () = Map.add_exn ctx_arrays ~key ~data:(default ()) in
384394
let device = stream.device in
@@ -423,8 +433,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
423433
~from_prior_context:code.from_prior_context;
424434
let (inputs, outputs), merge_buffer_input = Low_level.input_and_output_nodes code.lowered in
425435
let ctx_arrays =
426-
Hashtbl.fold code.lowered.traced_store ~init:context.ctx_arrays
427-
~f:(alloc_if_needed context.stream)
436+
Hashtbl.fold code.lowered.traced_store ~init:context.ctx_arrays ~f:(alloc_if_needed context)
428437
in
429438
let bindings, schedule = link context code.code ctx_arrays in
430439
let context = make_child ~ctx_arrays context in
@@ -443,7 +452,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
443452
~f:
444453
(Option.map ~f:(fun l ->
445454
Hashtbl.fold l.Low_level.traced_store ~init:context.ctx_arrays
446-
~f:(alloc_if_needed context.stream)))
455+
~f:(alloc_if_needed context)))
447456
in
448457
let bindings, schedules = link_batch context code_batch.code_batch ctx_arrays in
449458
Array.fold_mapi schedules ~init:context ~f:(fun i context -> function

0 commit comments

Comments
 (0)