Skip to content

Commit 7d333cd

Browse files
committed
Fix auto transfer from/to host in presence of multiple devices
1 parent a91751b commit 7d333cd

File tree

4 files changed

+21
-9
lines changed

4 files changed

+21
-9
lines changed

arrayjit/lib/backend_impl.ml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ module Device_types (Device_config : Device_config) = struct
9393
type nonrec context = (buffer_ptr, stream) context [@@deriving sexp_of]
9494
end
9595

96+
let next_global_device_id : Utils.atomic_int = Atomic.make 0
97+
9698
module Device
9799
(Device_types : Device_types)
98100
(Alloc_buffer :
@@ -104,9 +106,11 @@ struct
104106
include Alloc_buffer
105107

106108
let make_device dev ~ordinal =
109+
let device_id = Atomic.fetch_and_add next_global_device_id 1 in
107110
{
108111
dev;
109112
ordinal;
113+
device_id;
110114
cross_stream_candidates = Hashtbl.create (module Tnode);
111115
owner_stream = Hashtbl.create (module Tnode);
112116
shared_writer_streams = Hashtbl.create (module Tnode);

arrayjit/lib/backend_intf.ml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ end
8080
type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
8181
dev : 'dev;
8282
ordinal : int;
83+
device_id : int;
8384
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
8485
owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Hashtbl.M(Tnode).t;
8586
shared_writer_streams :
@@ -111,6 +112,11 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device =
111112
('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
112113
dev : 'dev;
113114
ordinal : int;
115+
(** The number of the represented backend's device, in the range from 0 to the number of the
116+
backend's devices - 1. *)
117+
device_id : int;
118+
(** A unique identifier among all device instances of all backends. Note that multiple
119+
[device_id] (distinct device instances) might refer to the same physical device. *)
114120
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
115121
(** Freshly created arrays that might be shared across streams. The map can both grow and
116122
shrink. *)
@@ -248,8 +254,8 @@ module type Backend_device_common = sig
248254
val sync : event -> unit
249255
(** Blocks till the event completes, if it's not done already.
250256
251-
It is rarely needed to call [sync] explicitly, because it should always be
252-
called internally when necessary, in particular before extracting values from host. *)
257+
It is rarely needed to call [sync] explicitly, because it should always be called internally
258+
when necessary, in particular before extracting values from host. *)
253259

254260
val is_done : event -> bool
255261
(** Whether the event completed. *)

arrayjit/lib/backends.ml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
9494
(* Stdio.printf "copying: %s from_host\n" (Tn.debug_name tn); *)
9595
Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
9696
update_writer_event ~from:`Host ctx @@ Node tn;
97-
tn.host_modified <- false;
97+
Hash_set.add tn.host_read_by_devices ctx.stream.device.device_id;
9898
true
9999
| _ -> false
100100

@@ -146,7 +146,8 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
146146
assert (Domain.is_main_domain ());
147147
if Utils.settings.automatic_host_transfers then
148148
Set.iter hosted_inputs ~f:(fun tn ->
149-
if tn.host_modified then assert (from_host r.context tn));
149+
if not (Hash_set.mem tn.host_read_by_devices s.device.device_id) then
150+
assert (from_host r.context tn));
150151
Set.iter r.inputs ~f:(fun tn ->
151152
if Tn.potentially_cross_stream tn then
152153
Option.iter (Hashtbl.find s.device.shared_writer_streams tn) ~f:(fun data ->
@@ -386,7 +387,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
386387
match key.array with
387388
| (lazy (Some hosted)) ->
388389
Device.from_host ~dst_ptr ~dst:parent_context hosted;
389-
key.host_modified <- false
390+
Hash_set.add key.host_read_by_devices stream.device.device_id
390391
| _ -> ());
391392
dst_ptr
392393
in

arrayjit/lib/tnode.ml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ type t = {
8484
mutable code_name : string option;
8585
mutable prepare_read : prepare option;
8686
mutable prepare_write : prepare option;
87-
mutable host_modified : bool;
87+
mutable host_read_by_devices : Hash_set.M(Int).t;
88+
(** The unique ids of devices that read the most recent modification of the host array. *)
8889
}
8990
[@@deriving sexp_of]
9091

@@ -554,7 +555,7 @@ let create ?default_prec ~id ~label ~dims init_op =
554555
code_name = None;
555556
prepare_read = None;
556557
prepare_write = None;
557-
host_modified = true;
558+
host_read_by_devices = Hash_set.create (module Int);
558559
}
559560
in
560561
(* Note: if tensor nodes get non-trivial finalizers, remember to either add an is_finalized flag
@@ -578,7 +579,7 @@ let find =
578579
code_name = None;
579580
prepare_read = None;
580581
prepare_write = None;
581-
host_modified = false;
582+
host_read_by_devices = Hash_set.create (module Int);
582583
}
583584
in
584585
fun ~id -> Registry.find_opt registry { mock with id }
@@ -596,7 +597,7 @@ let do_read tn =
596597
let do_write tn =
597598
Option.iter ~f:(fun p -> p.sync ()) tn.prepare_write;
598599
tn.prepare_write <- None;
599-
tn.host_modified <- true
600+
Hash_set.clear tn.host_read_by_devices
600601

601602
let points_1d ?from_axis ~xdim tn =
602603
do_read tn;

0 commit comments

Comments
 (0)