
Commit 7510e76

Fix: host_read_by_devices -> devices_not_lagging_host with the corresponding change of semantics
This fixes changes being overwritten by transferring a stale version from the host: previously, a device that had itself computed (or scheduled computing) the latest value of a node was never tagged, so an automatic `from_host` could clobber the freshly computed data with the stale host copy.
1 parent: 38eebf6

File tree: 3 files changed, +20 −18 lines

arrayjit/lib/anatomy_of_a_backend.md

Lines changed: 5 additions & 5 deletions
@@ -204,22 +204,22 @@ Unless disabled via setting `automatic_host_transfers` to false, `arrayjit` auto
 
 - `prepare_read` for synchronization and `to_host` transfers right before a host array is read,
 - `prepare_write` for synchronization right before a host array is written to,
-- `host_read_by_devices` for tracking which devices have scheduled transferring the data already.
+- `devices_not_lagging_host` for tracking which devices have scheduled transferring the data already, or don't need transferring because they computed or scheduled computing the data themselves.
 
-Since currently the tagging is per-device, for per-stream tensor nodes might need supplementary `from_host` (or `device_to_device`) calls in rare situations.
+Since currently the tagging is per-device, for per-stream, tensor nodes might need supplementary `from_host` (or `device_to_device`) calls in rare situations.
 
 There are three code components to the automation.
 
 - Within `Tnode`:
-  - The helper function `do_read` unconditionally invokes synchronization code, and if `automatic_host_transfers` invokes data transfer code, as stored in the `prepare_read` field of a node; then clears the field.
+  - The helper function `do_read` unconditionally invokes synchronization code, and if `automatic_host_transfers` it invokes data transfer code, as stored in the `prepare_read` field of a node; then clears the field.
   - The helper function `do_write` unconditionally invokes synchronization code as stored in the `prepare_write` field of a node, then clears the field.
   - `do_read` is invoked from `points_1d`, `points_2d`, `get_value`, `get_values` of `Tnode`; and also from `to_dag` and `print` of `Tensor`.
   - `do_write` is invoked from `set_value`, `set_values`.
   - `Tnode` exposes `prepare_read` and `prepare_write` for updating the fields: only the new data transfer is preserved, but the synchronization codes are combined.
 - Within `Backends.Add_buffer_retrieval_and_syncing`:
   - The `update_writer_event` helper adds the after-modification event to synchronization and sets data transfer to `to_host` from the stream, using `prepare_read`. This happens for `device_to_device` and `sync_routine` (after scheduling the routine) scheduling calls, and independently of `automatic_host_transfers`.
-  - Moreover, `sync_routine`, before scheduling the routine and only if `automatic_host_transfers`, directly schedules `from_host` for input nodes that are not tagged with the device (via `host_read_by_devices`). Note that input nodes are the "read only" and "read before write" nodes that are not constants.
+  - Moreover, `sync_routine`, before scheduling the routine and only if `automatic_host_transfers`, directly schedules `from_host` for input nodes that are not tagged with the device (via `devices_not_lagging_host`). Note that input nodes are the "read only" and "read before write" nodes that are not constants.
 - Within `Backends.Raise_backend.alloc_if_needed`:
-  - If `automatic_host_transfers` and the node allocated for the context is a constant, `alloc_if_needed` directly schedules `from_host` for the node regardless of whether it is tagged with the device (via `host_read_by_devices`); it does add the device tag to the node (if missing).
+  - If `automatic_host_transfers` and the node allocated for the context is a constant, `alloc_if_needed` directly schedules `from_host` for the node regardless of whether it is tagged with the device (via `devices_not_lagging_host`); it does add the device tag to the node (if missing).
 
 **Note:** we do **not** invoke `Tnode.do_read` from within `Backends.Add_buffer_retrieval_and_syncing.from_host`, since to adequately handle such transfers one should deliberately use `device_to_device` functions. This can lead to confusing behavior, in particular observing (or not) a tensor node (on host) can change later computations by inserting (or not) an additional `to_host` before a `from_host`. This aspect of the design might change in the future.
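
To make the `Tnode` helpers described above concrete, here is a minimal sketch of `do_read` and `do_write`, consistent with this documentation and with the tnode.ml hunk further down. The `sync` field and the `do_write` body match the diff; the `transfer` field name and the exact shape of the `prepare` record are assumptions for illustration:

```ocaml
(* Sketch only: the real [prepare] record lives in tnode.ml and may differ.
   [sync] waits on pending events; [transfer] schedules the host transfer. *)
type prepare = { sync : unit -> unit; transfer : unit -> unit }

(* Before a host array is read: always synchronize, but copy device -> host
   only under [automatic_host_transfers]; then clear the field. *)
let do_read tn =
  Option.iter tn.prepare_read ~f:(fun p ->
      p.sync ();
      if Utils.settings.automatic_host_transfers then p.transfer ());
  tn.prepare_read <- None

(* Before a host array is written: synchronize, clear the field, and mark
   every device as lagging the host by emptying the tag set. *)
let do_write tn =
  Option.iter ~f:(fun p -> p.sync ()) tn.prepare_write;
  tn.prepare_write <- None;
  Hash_set.clear tn.devices_not_lagging_host
```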

arrayjit/lib/backends.ml

Lines changed: 3 additions & 4 deletions
@@ -81,6 +81,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
         Hashtbl.update s.device.shared_writer_streams tn ~f:(fun l ->
             (s, e) :: Option.value ~default:[] l)
       else Hashtbl.remove s.device.shared_writer_streams tn;
+      Hash_set.add tn.devices_not_lagging_host ctx.stream.device.device_id;
       Hashtbl.update s.updating_for tn ~f:(fun _ -> e)
   | Merge_buffer tn ->
       (* Note: the previous event does not need to be done! *)
@@ -94,7 +95,6 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
         (* Stdio.printf "copying: %s from_host\n" (Tn.debug_name tn); *)
         Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
         update_writer_event ~from:`Host ctx @@ Node tn;
-        Hash_set.add tn.host_read_by_devices ctx.stream.device.device_id;
         true
     | _ -> false
 
@@ -109,7 +109,6 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
         (* Stdio.printf "copying: %s from_host\n" (Tn.debug_name tn); *)
         Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
         update_writer_event ~from:`Host ctx @@ Node tn;
-        Hash_set.add tn.host_read_by_devices ctx.stream.device.device_id;
         { ctx with ctx_arrays = Map.add_exn ctx.ctx_arrays ~key:tn ~data:dst }
     | _, Some _ ->
         raise
@@ -211,7 +210,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
     assert (Domain.is_main_domain ());
     if Utils.settings.automatic_host_transfers then
       Set.iter hosted_inputs ~f:(fun tn ->
-          if not (Hash_set.mem tn.host_read_by_devices s.device.device_id) then
+          if not (Hash_set.mem tn.devices_not_lagging_host s.device.device_id) then
             assert (from_host r.context tn));
     Set.iter r.inputs ~f:(fun tn ->
         if Tn.potentially_cross_stream tn then
@@ -480,7 +479,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
           match key.array with
          | (lazy (Some hosted)) ->
              Device.from_host ~dst_ptr ~dst:parent_context hosted;
-              Hash_set.add key.host_read_by_devices stream.device.device_id
+              Hash_set.add key.devices_not_lagging_host stream.device.device_id
          | _ -> ());
      dst_ptr
    in
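
The hunks above are the heart of the fix: `update_writer_event` now tags the device of any stream that writes `tn` (first hunk), whether by computing it or by transferring it, and `sync_routine` consults that tag before forcing a `from_host` (fourth hunk). Below is a self-contained model of the new semantics, runnable with Base; the type and helper names are illustrative, not the arrayjit API:

```ocaml
open Base

(* Toy stand-in for a tensor node; mirrors the field from tnode.ml. *)
type toy_node = { devices_not_lagging_host : Hash_set.M(Int).t }

let create () = { devices_not_lagging_host = Hash_set.create (module Int) }

(* update_writer_event analogue: a device that computed (or scheduled
   computing) the node is, by definition, not lagging the host. *)
let computed_on tn ~device_id = Hash_set.add tn.devices_not_lagging_host device_id

(* sync_routine analogue: schedule from_host only for lagging devices. *)
let needs_from_host tn ~device_id =
  not (Hash_set.mem tn.devices_not_lagging_host device_id)

(* do_write analogue: a host-side write invalidates every device tag. *)
let host_write tn = Hash_set.clear tn.devices_not_lagging_host

let () =
  let tn = create () in
  computed_on tn ~device_id:0;
  (* The fix: device 0 just computed tn, so no from_host is scheduled that
     would overwrite the fresh data with the stale host copy. Under the old
     host_read_by_devices semantics, only from_host itself tagged the
     device, so the stale transfer fired here. *)
  assert (not (needs_from_host tn ~device_id:0));
  host_write tn;
  (* After the host copy changes, every device must re-fetch. *)
  assert (needs_from_host tn ~device_id:0)
```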

arrayjit/lib/tnode.ml

Lines changed: 12 additions & 9 deletions
@@ -87,8 +87,10 @@ type t = {
   mutable code_name : string option;
   mutable prepare_read : prepare option;
   mutable prepare_write : prepare option;
-  mutable host_read_by_devices : Hash_set.M(Int).t;
-      (** The unique ids of devices that read the most recent modification of the host array. *)
+  mutable devices_not_lagging_host : Hash_set.M(Int).t;
+      (** The unique ids of devices that either read the most recent modification of the host
+          buffer, or computed the most recent modification of the node themselves, whether or not it
+          has been transferred to the host yet. *)
 }
 [@@deriving sexp_of]
 
@@ -458,7 +460,7 @@ let has a = match a.array with (lazy (Some _)) -> true | _ -> false
 
 let dims_to_string ?(with_axis_numbers = false) arr =
   let dims_s =
-    if Lazy.is_val arr.dims then
+    if Lazy.is_val arr.dims then
       let padding = Option.map ~f:fst (Lazy.force arr.padding) in
       Nd.int_dims_to_string ~with_axis_numbers ?padding @@ Lazy.force arr.dims
     else "<not-in-yet>"
@@ -576,7 +578,7 @@ let create ?default_prec ~id ~label ~dims ~padding () =
       code_name = None;
       prepare_read = None;
       prepare_write = None;
-      host_read_by_devices = Hash_set.create (module Int);
+      devices_not_lagging_host = Hash_set.create (module Int);
     }
   in
   (* Note: if tensor nodes get non-trivial finalizers, remember to either add an is_finalized flag
@@ -604,7 +606,7 @@ let create_from_padded ~id ~label ~ndarray ~padding () =
       code_name = None;
       prepare_read = None;
       prepare_write = None;
-      host_read_by_devices = Hash_set.create (module Int);
+      devices_not_lagging_host = Hash_set.create (module Int);
     }
   in
   Registry.add registry tn;
@@ -678,7 +680,7 @@ let create_with_reshape ~id ~label ~base_ndarray ~dims ~padding ~from_padded ()
       code_name = None;
       prepare_read = None;
       prepare_write = None;
-      host_read_by_devices = Hash_set.create (module Int);
+      devices_not_lagging_host = Hash_set.create (module Int);
     }
   in
   Registry.add registry tn;
@@ -703,7 +705,7 @@ let find =
       code_name = None;
       prepare_read = None;
       prepare_write = None;
-      host_read_by_devices = Hash_set.create (module Int);
+      devices_not_lagging_host = Hash_set.create (module Int);
     }
   in
   fun ~id -> Registry.find_opt registry { mock with id }
@@ -721,7 +723,7 @@ let do_read tn =
 let do_write tn =
   Option.iter ~f:(fun p -> p.sync ()) tn.prepare_write;
   tn.prepare_write <- None;
-  Hash_set.clear tn.host_read_by_devices
+  Hash_set.clear tn.devices_not_lagging_host
 
 let points_1d ?from_axis ~xdim tn =
   do_read tn;
@@ -732,7 +734,8 @@ let points_1d ?from_axis ~xdim tn =
 let points_2d ?from_axis ~xdim ~ydim tn =
   do_read tn;
   let padding = Option.map ~f:fst (Lazy.force tn.padding) in
-  Option.value_map ~default:[||] ~f:(fun arr -> Nd.retrieve_2d_points ?from_axis ?padding ~xdim ~ydim arr)
+  Option.value_map ~default:[||] ~f:(fun arr ->
+      Nd.retrieve_2d_points ?from_axis ?padding ~xdim ~ydim arr)
   @@ Lazy.force tn.array
 
 let set_value tn =
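
One more piece worth illustrating: the documentation above says that when `prepare_read` is updated, "only the new data transfer is preserved, but the synchronization codes are combined". A hedged sketch of such a combinator, reusing the hypothetical `prepare` record from the first sketch; the function name and signature are assumptions, not the actual `Tnode` interface:

```ocaml
(* Sketch: chain the old sync thunk before the new one, keep only the
   newest transfer thunk. *)
let update_prepare_read tn ~sync ~transfer =
  let sync =
    match tn.prepare_read with
    | None -> sync
    | Some old ->
        fun () ->
          old.sync ();
          sync ()
  in
  tn.prepare_read <- Some { sync; transfer }
```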
