Skip to content

Commit 7019e18

Browse files
committed
In progress: make Tnode.is_in_context non-forcing and more precise, pass host ptrs by params if undecided at compile time
This is still not correct or ideal. Moreover, it uncovers another potential bug involving merge buffers.
1 parent 069322a commit 7019e18

File tree

7 files changed

+95
-36
lines changed

7 files changed

+95
-36
lines changed

arrayjit/lib/assignments.ml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,17 @@ let get_name_exn asgns =
7777
let is_total ~initialize_neutral ~projections =
7878
initialize_neutral && Indexing.is_bijective projections
7979

80-
(** Returns materialized nodes in the sense of {!Tnode.is_in_context_force}. NOTE: it ideally should
81-
be called after compilation. *)
80+
(** Returns materialized nodes in the sense of {!Tnode.is_in_context}. NOTE: it should be called
81+
after compilation and ideally after linking with the relevant contexts; otherwise, it is an
82+
under-estimate. *)
8283
let context_nodes ~use_host_memory asgns =
8384
let open Utils.Set_O in
8485
let empty = Set.empty (module Tn) in
85-
let one tn = if Tnode.is_in_context_force ~use_host_memory tn 34 then Set.singleton (module Tn) tn else empty in
86+
let one tn =
87+
if Option.value ~default:false @@ Tnode.is_in_context ~use_host_memory tn then
88+
Set.singleton (module Tn) tn
89+
else empty
90+
in
8691
let of_node = function Node rhs -> one rhs | Merge_buffer _ -> empty in
8792
let rec loop = function
8893
| Noop -> empty

arrayjit/lib/backend_intf.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ type ('buffer_ptr, 'dev, 'event) device = {
8686
released : Utils.atomic_bool;
8787
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
8888
(** Freshly created arrays that might be shared across streams. The map can both grow and
89-
shrink. See the explanation on top of this file. *)
89+
shrink. *)
9090
owner_streams : int Hashtbl.M(Tnode).t;
9191
(** The streams owning the given nodes. This map can only grow. *)
9292
stream_working_on : (int * 'event) option Hashtbl.M(Tnode).t;

arrayjit/lib/backends.ml

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ let lower_batch_assignments ?names ?occupancy bindings asgns_l =
118118
let verify_prior_context ~use_host_memory ~ctx_arrays ~from_prior_context =
119119
Set.iter from_prior_context ~f:(fun tn ->
120120
if
121-
Tn.is_in_context_force ~use_host_memory tn 342
121+
(* Err on the safe side. *)
122+
Option.value ~default:false (Tn.is_in_context ~use_host_memory tn)
122123
&& not (Option.is_some @@ Map.find ctx_arrays tn)
123124
then raise @@ Utils.User_error ("The linked context lacks node " ^ Tnode.debug_name tn))
124125

@@ -295,9 +296,14 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
295296
}
296297

297298
let%track3_sexp alloc_if_needed (stream : stream) ~key ~data:node ctx_arrays =
298-
if Tnode.is_in_context_force ~use_host_memory key 345 && not (Map.mem ctx_arrays key) then (
299-
[%log2 Tn.debug_name key];
300-
[%log3 (key : Tnode.t)];
299+
(* TODO: do we need this? *)
300+
(* Tn.default_to_most_local key 345; *)
301+
if
302+
Option.value ~default:true (Tnode.is_in_context ~use_host_memory key)
303+
&& not (Map.mem ctx_arrays key)
304+
then (
305+
[%log Tn.debug_name key];
306+
[%log (key : Tnode.t)];
301307
let default () =
302308
alloc_zero_init_array (Lazy.force key.prec) ~dims:(Lazy.force key.dims) stream
303309
in
@@ -311,13 +317,13 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
311317
let data = Hashtbl.find_or_add device.cross_stream_candidates key ~default in
312318
Map.add_exn ctx_arrays ~key ~data)
313319
else if Tn.known_shared_cross_stream key then (
314-
if Hashtbl.mem device.owner_streams key then
320+
if Hashtbl.mem device.owner_streams key then (
315321
if not (stream.stream_id = Hashtbl.find_exn device.owner_streams key) then
316322
raise
317323
@@ Utils.User_error
318324
("Cuda_backend.alloc_if_needed: node " ^ Tn.debug_name key
319-
^ " assumed to be cross-stream-shared but then written to on multiple devices")
320-
else Hashtbl.add_exn device.owner_streams ~key ~data:stream.stream_id;
325+
^ " assumed to be cross-stream-shared but then written to on multiple devices"))
326+
else Hashtbl.add_exn device.owner_streams ~key ~data:stream.stream_id;
321327
let data = Hashtbl.find_exn device.cross_stream_candidates key in
322328
Map.add_exn ctx_arrays ~key ~data)
323329
else (

arrayjit/lib/c_syntax.ml

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ struct
3232
let get_ident =
3333
Low_level.get_ident_within_code ~no_dots:true @@ Array.map B.procs ~f:(fun (l, _) -> l.llc)
3434

35-
let in_ctx tn = B.(Tn.is_in_context_force ~use_host_memory tn 341)
35+
let in_ctx tn = B.(Tn.is_in_context ~use_host_memory tn)
3636

3737
let pp_zero_out ppf tn =
3838
Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " (get_ident tn) @@ Tn.size_in_bytes tn
@@ -72,15 +72,16 @@ struct
7272
if not @@ Hash_set.mem is_global tn then
7373
let ctx_ptr = B.hardcoded_context_ptr in
7474
let mem : (Tn.memory_mode * int) option = tn.memory_mode in
75-
match (in_ctx tn, ctx_ptr, ctx_arrays, B.use_host_memory, mem) with
76-
| true, Some get_ptr, Some ctx_arrays, _, _ ->
75+
match (in_ctx tn, ctx_ptr, ctx_arrays, mem) with
76+
| Some true, Some get_ptr, Some ctx_arrays, _ ->
7777
let ident = get_ident tn in
7878
let ctx_array =
7979
Option.value_exn ~here:[%here] ~message:ident @@ Map.find ctx_arrays tn
8080
in
8181
fprintf ppf "#define %s (%s)@," ident @@ get_ptr ctx_array (Lazy.force tn.prec);
8282
Hash_set.add is_global tn
83-
| false, _, _, true, Some (Hosted _, _) ->
83+
| Some false, _, _, Some (Hosted _, _)
84+
when B.(Tn.known_shared_with_host ~use_host_memory tn) ->
8485
let nd = Option.value_exn ~here:[%here] @@ Lazy.force tn.array in
8586
fprintf ppf "#define %s (%s)@," (get_ident tn) (Ndarray.c_ptr_to_string nd);
8687
Hash_set.add is_global tn
@@ -294,14 +295,19 @@ struct
294295
(* A rough approximation to the type Gccjit_backend.mem_properties. *)
295296
let backend_info =
296297
Sexp.Atom
297-
(if in_ctx tn then "Ctx"
298-
else if Hash_set.mem is_global tn then "Host"
298+
(if Hash_set.mem is_global tn then "Host"
299299
else if Tn.is_virtual_force tn 3331 then "Virt"
300-
else "Local")
300+
else
301+
match in_ctx tn with
302+
| Some true -> "Ctx"
303+
| Some false -> "Local"
304+
| None -> "Unk")
301305
in
302306
if not @@ Utils.sexp_mem ~elem:backend_info tn.backend_info then
303307
tn.backend_info <- Utils.sexp_append ~elem:backend_info tn.backend_info;
304-
if in_ctx tn && not (Hash_set.mem is_global tn) then
308+
(* We often don't know ahead of linking with relevant contexts what the stream sharing
309+
mode of the node will become. Conservatively, pass the node as a parameter. *)
310+
if Option.value ~default:true (in_ctx tn) && not (Hash_set.mem is_global tn) then
305311
(B.typ_of_prec (Lazy.force tn.Tn.prec) ^ " *" ^ get_ident tn, Param_ptr tn) :: params
306312
else params)
307313
in
@@ -367,7 +373,12 @@ struct
367373
params);
368374
fprintf ppf "/* Local declarations and initialization. */@ ";
369375
Hashtbl.iteri traced_store ~f:(fun ~key:tn ~data:node ->
370-
if not (Tn.is_virtual_force tn 333 || in_ctx tn || Hash_set.mem is_global tn) then
376+
if
377+
not
378+
(Tn.is_virtual_force tn 333
379+
|| Option.value ~default:true (in_ctx tn)
380+
|| Hash_set.mem is_global tn)
381+
then
371382
fprintf ppf "%s %s[%d]%s;@ "
372383
(B.typ_of_prec @@ Lazy.force tn.prec)
373384
(get_ident tn) (Tn.num_elems tn)

arrayjit/lib/cc_backend.ml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
138138
let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : procedure) =
139139
let name : string = code.name in
140140
List.iter code.params ~f:(function
141-
| _, Param_ptr tn ->
141+
| _, Param_ptr tn when not @@ Tn.known_shared_with_host ~use_host_memory tn ->
142142
if not (Map.mem ctx_arrays tn) then
143143
invalid_arg
144144
[%string
@@ -167,7 +167,13 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro
167167
let get_ptr (ptr, _tn) = ptr in
168168
Param_2f (get_ptr, merge_buffer, link bs ps Ctypes.(ptr void @-> cs))
169169
| bs, Param_ptr tn :: ps ->
170-
let c_ptr = Map.find_exn ctx_arrays tn in
170+
let c_ptr =
171+
if Tn.known_shared_with_host ~use_host_memory tn then
172+
Ndarray.get_voidptr_not_managed
173+
@@ Option.value_exn ~here:[%here]
174+
@@ Lazy.force tn.array
175+
else Map.find_exn ctx_arrays tn
176+
in
171177
Param_2 (ref (Some c_ptr), link bs ps Ctypes.(ptr void @-> cs))
172178
in
173179
(* Reverse the input order because [Indexing.apply] will reverse it again. Important:

arrayjit/lib/gcc_backend.gccjit.ml

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,17 @@ let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~opt_ctx_arrays
113113
let num_typ = Type.(get ctx c_typ) in
114114
let ptr_typ = Type.pointer num_typ in
115115
let ident = get_ident tn in
116-
let in_ctx = Tn.is_in_context_force ~use_host_memory tn 343 in
116+
let hosted = Tn.is_hosted_force tn 344 in
117+
let in_ctx = Tn.is_in_context ~use_host_memory tn in
117118
let ptr =
118-
match (in_ctx, opt_ctx_arrays, Tn.is_hosted_force tn 344) with
119-
| true, Some ctx_arrays, _ ->
119+
match (in_ctx, opt_ctx_arrays, hosted) with
120+
| Some true, Some ctx_arrays, _ ->
120121
Lazy.from_val @@ get_c_ptr ctx num_typ @@ Map.find_exn ctx_arrays tn
121-
| true, None, _ ->
122+
| (Some true | None), None, _ ->
122123
let p = Param.create ctx ptr_typ ident in
123124
param_ptrs := (p, Param_ptr tn) :: !param_ptrs;
124125
Lazy.from_val (RValue.param p)
125-
| false, _, true -> (
126+
| (Some false | None), _, true -> (
126127
let addr arr =
127128
Lazy.from_val @@ get_c_ptr ctx num_typ @@ Ctypes.bigarray_start Ctypes_static.Genarray arr
128129
in
@@ -132,7 +133,7 @@ let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~opt_ctx_arrays
132133
| Some (Single_nd arr) -> addr arr
133134
| Some (Double_nd arr) -> addr arr
134135
| None -> assert false)
135-
| false, _, false ->
136+
| (Some false | None), _, false ->
136137
let arr_typ = Type.array ctx num_typ size_in_elems in
137138
let v = ref None in
138139
let initialize _init_block func = v := Some (Function.local func arr_typ ident) in
@@ -645,7 +646,9 @@ let%diagn_sexp compile_batch ~(names : string option array) ~opt_ctx_arrays bind
645646

646647
let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : procedure) =
647648
let name : string = code.name in
648-
List.iter code.params ~f:(function Param_ptr tn -> assert (Map.mem ctx_arrays tn) | _ -> ());
649+
List.iter code.params ~f:(function
650+
| Param_ptr tn when not (Tn.known_shared_cross_stream tn) -> assert (Map.mem ctx_arrays tn)
651+
| _ -> ());
649652
let log_file_name = Utils.diagn_log_file [%string "debug-%{runner_label}-%{code.name}.log"] in
650653
let run_variadic =
651654
[%log_level
@@ -666,7 +669,13 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro
666669
| bs, Log_file_name :: ps ->
667670
Param_1 (ref (Some log_file_name), link bs ps Ctypes.(string @-> cs))
668671
| bs, Param_ptr tn :: ps ->
669-
let c_ptr = Map.find_exn ctx_arrays tn in
672+
let c_ptr =
673+
if Tn.known_shared_with_host ~use_host_memory tn then
674+
Ndarray.get_voidptr_not_managed
675+
@@ Option.value_exn ~here:[%here]
676+
@@ Lazy.force tn.array
677+
else Map.find_exn ctx_arrays tn
678+
in
670679
Param_2 (ref (Some c_ptr), link bs ps Ctypes.(ptr void @-> cs))
671680
| bs, Merge_buffer :: ps ->
672681
let get_ptr (ptr, _tn) = ptr in

arrayjit/lib/tnode.ml

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
2424
type sharing =
2525
| Unset
2626
| Per_stream (** The tensor node has separate arrays for each stream. *)
27-
| Shared_cross_stream (** The tensor node has a single array per device. *)
27+
| Shared_cross_stream
28+
(** The tensor node has a single array per device that can appear in multiple contexts, except
29+
for backends with [use_host_memory = true] and nodes with memory mode
30+
[Hosted (Changed_on_devices Shared_cross_stream)], where it only has the on-host array and
31+
does not appear in any contexts. *)
2832
[@@deriving sexp, compare, equal]
2933

3034
type memory_type =
@@ -183,14 +187,32 @@ let is_materialized_force tn provenance =
183187
| Some ((On_device _ | Hosted _ | Materialized), _) -> true
184188
| Some ((Never_virtual | Device_only | Effectively_constant), _) -> assert false
185189

186-
let is_in_context_force ~use_host_memory tn provenance =
187-
default_to_most_local tn provenance;
190+
(* Unlike the [known_] functions which can only change from [false] to [true], [is_in_context
191+
~use_host_memory tn] is more precise. Generally, it can only change away from [None], but there
192+
is one exception. When [use_host_memory = true], it can change from [Some false] to [Some true]
193+
if the memory mode changes from [Hosted (Changed_on_devices Shared_cross_stream)] to [Hosted
194+
(Changed_on_devices Per_stream)]. *)
195+
let is_in_context ~use_host_memory tn =
188196
match tn.memory_mode with
197+
| Some (Hosted (Changed_on_devices Per_stream), _) -> Some true
198+
| Some ((Materialized | Hosted Nonconstant), _) when not use_host_memory -> Some true
189199
| Some (Hosted (Constant | Volatile | Changed_on_devices Shared_cross_stream), _)
190200
when use_host_memory ->
191-
false
192-
| Some ((Virtual | Local), _) -> false
193-
| _ -> true
201+
Some false
202+
| Some (Hosted Nonconstant, _) when use_host_memory -> None
203+
| Some (Hosted _, _) -> Some true
204+
| Some ((Virtual | Local), _) -> Some false
205+
| None | Some ((Materialized | Effectively_constant | Never_virtual | Device_only), _) -> None
206+
| Some (On_device _, _) -> Some true
207+
208+
(** The opposite of [is_in_context] for hosted tensor nodes. False if [use_host_memory = false] or
209+
for non-hosted tensor nodes. *)
210+
let known_shared_with_host ~use_host_memory tn =
211+
match tn.memory_mode with
212+
| Some (Hosted (Constant | Volatile | Changed_on_devices Shared_cross_stream), _)
213+
when use_host_memory ->
214+
true
215+
| _ -> false
194216

195217
let known_not_materialized tn =
196218
match tn.memory_mode with Some ((Virtual | Local), _) -> true | _ -> false

0 commit comments

Comments
 (0)