Skip to content

Commit 069322a

Browse files
committed
Rename unified_memory -> use_host_memory and add more debugging
1 parent 4f48ac8 commit 069322a

File tree

11 files changed

+35
-45
lines changed

11 files changed

+35
-45
lines changed

arrayjit/lib/assignments.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,10 @@ let is_total ~initialize_neutral ~projections =
7979

8080
(** Returns materialized nodes in the sense of {!Tnode.is_in_context_force}. NOTE: it ideally should
8181
be called after compilation. *)
82-
let context_nodes ~unified_memory asgns =
82+
let context_nodes ~use_host_memory asgns =
8383
let open Utils.Set_O in
8484
let empty = Set.empty (module Tn) in
85-
let one tn = if Tnode.is_in_context_force ~unified_memory tn 34 then Set.singleton (module Tn) tn else empty in
85+
let one tn = if Tnode.is_in_context_force ~use_host_memory tn 34 then Set.singleton (module Tn) tn else empty in
8686
let of_node = function Node rhs -> one rhs | Merge_buffer _ -> empty in
8787
let rec loop = function
8888
| Noop -> empty

arrayjit/lib/backend_impl.ml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,10 @@ end
133133
module type Backend_impl_common = sig
134134
include Buffer
135135

136-
val unified_memory : bool
137-
(** If true, the node is required to be in the contexts linked with code that uses it.
136+
val use_host_memory : bool
137+
(** If true, the backend reads from and writes to host memory directly whenever possible.
138138
139-
Should return false for nodes that are virtual, local, or which the backend prefers to access
140-
directly from the host. *)
139+
[use_host_memory] can only be true on unified-memory devices, such as CPU and Apple Metal. *)
141140
end
142141

143142
(** An interface to adding schedulers for stream-agnostic (typically CPU) backend implementations. *)

arrayjit/lib/backends.ml

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -115,18 +115,18 @@ let lower_batch_assignments ?names ?occupancy bindings asgns_l =
115115
Some (Assignments.lower ~unoptim_ll_source ~ll_source ~cd_source ~name bound asgns) )
116116
else (None, None))
117117

118-
let verify_prior_context ~unified_memory ~ctx_arrays ~from_prior_context =
118+
let verify_prior_context ~use_host_memory ~ctx_arrays ~from_prior_context =
119119
Set.iter from_prior_context ~f:(fun tn ->
120120
if
121-
Tn.is_in_context_force ~unified_memory tn 342
121+
Tn.is_in_context_force ~use_host_memory tn 342
122122
&& not (Option.is_some @@ Map.find ctx_arrays tn)
123123
then raise @@ Utils.User_error ("The linked context lacks node " ^ Tnode.debug_name tn))
124124

125-
let from_prior_context_batch ~unified_memory comps =
125+
let from_prior_context_batch ~use_host_memory comps =
126126
Array.filter_map comps ~f:(fun comp ->
127127
Option.map comp ~f:(fun comp ->
128128
Set.diff
129-
(Assignments.context_nodes ~unified_memory comp.Assignments.asgns)
129+
(Assignments.context_nodes ~use_host_memory comp.Assignments.asgns)
130130
comp.embedded_nodes))
131131
|> Array.fold ~init:(Set.empty (module Tnode)) ~f:Set.union
132132

@@ -270,7 +270,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
270270
let name, lowered = lower_assignments ?name bindings comp.Assignments.asgns in
271271
let code = compile ?shared ~name bindings lowered in
272272
let from_prior_context =
273-
Set.diff (Assignments.context_nodes ~unified_memory comp.asgns) comp.embedded_nodes
273+
Set.diff (Assignments.context_nodes ~use_host_memory comp.asgns) comp.embedded_nodes
274274
in
275275
{ from_prior_context; name; lowered; code; expected_merge_node = lowered.Low_level.merge_node }
276276

@@ -281,7 +281,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
281281
in
282282
let code_batch = compile_batch ?shared ~names bindings lowereds in
283283
let from_prior_context =
284-
from_prior_context_batch ~unified_memory
284+
from_prior_context_batch ~use_host_memory
285285
@@ Array.mapi lowereds ~f:(fun i -> Option.map ~f:(fun _ -> comps.(i)))
286286
in
287287
{
@@ -295,7 +295,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
295295
}
296296

297297
let%track3_sexp alloc_if_needed (stream : stream) ~key ~data:node ctx_arrays =
298-
if Tnode.is_in_context_force ~unified_memory key 345 && not (Map.mem ctx_arrays key) then (
298+
if Tnode.is_in_context_force ~use_host_memory key 345 && not (Map.mem ctx_arrays key) then (
299299
[%log2 Tn.debug_name key];
300300
[%log3 (key : Tnode.t)];
301301
let default () =
@@ -307,7 +307,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
307307
if Tn.known_non_cross_stream key then add_new ()
308308
else (
309309
if Hashtbl.mem device.cross_stream_candidates key then
310-
Tn.update_memory_sharing key Tn.Shared_cross_stream 40;
310+
Tn.update_memory_sharing key Tn.Shared_cross_stream 39;
311311
let data = Hashtbl.find_or_add device.cross_stream_candidates key ~default in
312312
Map.add_exn ctx_arrays ~key ~data)
313313
else if Tn.known_shared_cross_stream key then (
@@ -326,8 +326,8 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
326326
add_new ()))
327327
else ctx_arrays
328328

329-
let link context (code : code) =
330-
verify_prior_context ~unified_memory ~ctx_arrays:context.ctx_arrays
329+
let%debug3_sexp link context (code : code) =
330+
verify_prior_context ~use_host_memory ~ctx_arrays:context.ctx_arrays
331331
~from_prior_context:code.from_prior_context;
332332
let inputs, outputs = Low_level.input_and_output_nodes code.lowered in
333333
let ctx_arrays =
@@ -344,8 +344,8 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
344344
in
345345
{ context; schedule; bindings; name = code.name; inputs; outputs }
346346

347-
let link_batch context code_batch =
348-
verify_prior_context ~unified_memory ~ctx_arrays:context.ctx_arrays
347+
let%debug3_sexp link_batch context code_batch =
348+
verify_prior_context ~use_host_memory ~ctx_arrays:context.ctx_arrays
349349
~from_prior_context:code_batch.from_prior_context;
350350
let ctx_arrays =
351351
Array.map code_batch.lowereds

arrayjit/lib/c_syntax.ml

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ module C_syntax (B : sig
1818
shared and already known. *)
1919

2020
val hardcoded_context_ptr : (buffer_ptr -> Ops.prec -> string) option
21-
val unified_memory : bool
22-
val host_ptrs_for_readonly : bool
21+
val use_host_memory : bool
2322
val logs_to_stdout : bool
2423
val main_kernel_prefix : string
2524
val kernel_prep_line : string
@@ -33,7 +32,7 @@ struct
3332
let get_ident =
3433
Low_level.get_ident_within_code ~no_dots:true @@ Array.map B.procs ~f:(fun (l, _) -> l.llc)
3534

36-
let in_ctx tn = B.(Tn.is_in_context_force ~unified_memory tn 341)
35+
let in_ctx tn = B.(Tn.is_in_context_force ~use_host_memory tn 341)
3736

3837
let pp_zero_out ppf tn =
3938
Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " (get_ident tn) @@ Tn.size_in_bytes tn
@@ -73,18 +72,15 @@ struct
7372
if not @@ Hash_set.mem is_global tn then
7473
let ctx_ptr = B.hardcoded_context_ptr in
7574
let mem : (Tn.memory_mode * int) option = tn.memory_mode in
76-
match
77-
(in_ctx tn, ctx_ptr, ctx_arrays, B.host_ptrs_for_readonly, mem, node.read_only)
78-
with
79-
| true, Some get_ptr, Some ctx_arrays, _, _, _ ->
75+
match (in_ctx tn, ctx_ptr, ctx_arrays, B.use_host_memory, mem) with
76+
| true, Some get_ptr, Some ctx_arrays, _, _ ->
8077
let ident = get_ident tn in
8178
let ctx_array =
8279
Option.value_exn ~here:[%here] ~message:ident @@ Map.find ctx_arrays tn
8380
in
8481
fprintf ppf "#define %s (%s)@," ident @@ get_ptr ctx_array (Lazy.force tn.prec);
8582
Hash_set.add is_global tn
86-
| false, _, _, true, Some (Hosted _, _), true ->
87-
(* In-context nodes to read directly from host would be error prone. *)
83+
| false, _, _, true, Some (Hosted _, _) ->
8884
let nd = Option.value_exn ~here:[%here] @@ Lazy.force tn.array in
8985
fprintf ppf "#define %s (%s)@," (get_ident tn) (Ndarray.c_ptr_to_string nd);
9086
Hash_set.add is_global tn
@@ -288,7 +284,7 @@ struct
288284
in
289285
pp_ll ppf llc
290286
291-
let%diagn_sexp compile_proc ~name ppf idx_params ~is_global
287+
let%track3_sexp compile_proc ~name ppf idx_params ~is_global
292288
Low_level.{ traced_store; llc; merge_node } =
293289
let open Stdlib.Format in
294290
let params : (string * param_source) list =

arrayjit/lib/cc_backend.ml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type procedure = {
3333
}
3434
[@@deriving sexp_of]
3535

36-
let unified_memory = true
36+
let use_host_memory = true
3737

3838
let get_global_run_id =
3939
let next_id = ref 0 in
@@ -79,8 +79,7 @@ struct
7979

8080
let procs = Input.procs
8181
let hardcoded_context_ptr = c_ptr_to_string
82-
let unified_memory = unified_memory
83-
let host_ptrs_for_readonly = true
82+
let use_host_memory = use_host_memory
8483
let logs_to_stdout = false
8584
let main_kernel_prefix = ""
8685
let kernel_prep_line = ""
@@ -136,7 +135,7 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
136135
Option.map names.(i) ~f:(fun name ->
137136
{ result; params = Option.value_exn ~here:[%here] params; bindings; name }))
138137

139-
let%diagn_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : procedure) =
138+
let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : procedure) =
140139
let name : string = code.name in
141140
List.iter code.params ~f:(function
142141
| _, Param_ptr tn ->

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ module Backend_buffer = struct
3030
end)
3131
end
3232

33-
let unified_memory = false
33+
let use_host_memory = false
3434

3535
module Device_config = struct
3636
include Backend_buffer
@@ -271,9 +271,7 @@ struct
271271

272272
let procs = Input.procs
273273
let hardcoded_context_ptr = None
274-
let unified_memory = unified_memory
275-
let host_ptrs_for_readonly = false
276-
(* GPUs cannot access host memory pointers directly. *)
274+
let use_host_memory = use_host_memory
277275

278276
let logs_to_stdout = true
279277
let main_kernel_prefix = "extern \"C\" __global__"

arrayjit/lib/cuda_backend.missing.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ let compile ?shared:_ ~name:_ bindings _optimized = bindings
1919
let compile_batch ?shared:_ ~names:_ (bindings : Indexing.unit_bindings) optimized : code_batch =
2020
Array.map optimized ~f:(fun _ -> bindings)
2121

22-
let unified_memory = false
22+
let use_host_memory = false
2323
let ctx_arrays Unimplemented_ctx = Map.empty (module Tnode)
2424

2525
let link (Unimplemented_ctx : context) (code : code) =

arrayjit/lib/gcc_backend.gccjit.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ type procedure = {
6666
}
6767
[@@deriving sexp_of]
6868

69-
let unified_memory = true
69+
let use_host_memory = true
7070

7171
let gcc_typ_of_prec =
7272
let open Gccjit in
@@ -113,7 +113,7 @@ let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~opt_ctx_arrays
113113
let num_typ = Type.(get ctx c_typ) in
114114
let ptr_typ = Type.pointer num_typ in
115115
let ident = get_ident tn in
116-
let in_ctx = Tn.is_in_context_force ~unified_memory tn 343 in
116+
let in_ctx = Tn.is_in_context_force ~use_host_memory tn 343 in
117117
let ptr =
118118
match (in_ctx, opt_ctx_arrays, Tn.is_hosted_force tn 344) with
119119
| true, Some ctx_arrays, _ ->

arrayjit/lib/gcc_backend.missing.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ include Backend_impl.No_device_buffer_and_copying ()
55
let expected_merge_node Unimplemented_proc =
66
failwith "gcc backend missing: install the optional dependency gccjit"
77

8-
let unified_memory = true
8+
let use_host_memory = true
99

1010
let to_buffer _tn ~dst:_ ~src:_ =
1111
failwith "gcc backend missing: install the optional dependency gccjit"

arrayjit/lib/tnode.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,11 @@ let is_materialized_force tn provenance =
183183
| Some ((On_device _ | Hosted _ | Materialized), _) -> true
184184
| Some ((Never_virtual | Device_only | Effectively_constant), _) -> assert false
185185

186-
let is_in_context_force ~unified_memory tn provenance =
186+
let is_in_context_force ~use_host_memory tn provenance =
187187
default_to_most_local tn provenance;
188188
match tn.memory_mode with
189189
| Some (Hosted (Constant | Volatile | Changed_on_devices Shared_cross_stream), _)
190-
when unified_memory ->
190+
when use_host_memory ->
191191
false
192192
| Some ((Virtual | Local), _) -> false
193193
| _ -> true

0 commit comments

Comments
 (0)