Skip to content

Commit 0f6feaf

Browse files
committed
In progress: factor out alloc_if_needed
1 parent c42347d commit 0f6feaf

File tree

10 files changed

+58
-65
lines changed

10 files changed

+58
-65
lines changed

arrayjit/lib/assignments.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,10 @@ let is_total ~initialize_neutral ~projections =
7979

8080
(** Returns materialized nodes in the sense of {!Tnode.is_in_context_force}. NOTE: it ideally should
8181
be called after compilation. *)
82-
let context_nodes asgns =
82+
let context_nodes ~unified_memory asgns =
8383
let open Utils.Set_O in
8484
let empty = Set.empty (module Tn) in
85-
let one tn = if Tnode.is_in_context_force tn 34 then Set.singleton (module Tn) tn else empty in
85+
let one tn = if Tnode.is_in_context_force ~unified_memory tn 34 then Set.singleton (module Tn) tn else empty in
8686
let of_node = function Node rhs -> one rhs | Merge_buffer _ -> empty in
8787
let rec loop = function
8888
| Noop -> empty

arrayjit/lib/backend_impl.ml

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ end
133133
module type Backend_impl_common = sig
134134
include Buffer
135135

136-
val is_in_context : Low_level.traced_array -> bool
136+
val unified_memory : bool
137137
(** If true, the backend's memory is unified with host memory: nodes hosted as constants or
138138
139139
    volatiles need not be stored in the contexts linked with code that uses them. *)
@@ -305,3 +305,36 @@ struct
305305
let alloc_zero_init_array prec ~dims _stream = Backend.alloc_zero_init_array prec ~dims ()
306306
let free_buffer = Option.map Backend.free_buffer ~f:(fun memfree _stream ptr -> memfree () ptr)
307307
end
308+
309+
let%track3_sexp alloc_if_needed (type buffer_ptr) ~unified_memory ctx stream ~key ~data:node ctx_arrays =
310+
if Tnode.is_in_context_force ~unified_memory key 33 && not (Map.mem ctx_arrays key) then (
311+
[%log2 Tn.debug_name key, "read_only", (node.read_only : bool)];
312+
[%log3 (key : Tn.t)];
313+
let default () : buffer_ptr =
314+
set_ctx ctx;
315+
(* FIXME(review): [set_ctx] and [Cu.Deviceptr.mem_alloc] are CUDA-specific, but this now lives in backend-generic backend_impl.ml with a locally abstract [buffer_ptr] — pass the allocator in as an argument instead. *)
Cu.Deviceptr.mem_alloc ~size_in_bytes:(Tn.size_in_bytes key)
316+
in
317+
let add_new () = Map.add_exn ctx_arrays ~key ~data:(default ()) in
318+
let device = stream.device in
319+
if node.read_only then
320+
if Tn.known_non_cross_stream key then add_new ()
321+
else (
322+
if Hashtbl.mem device.cross_stream_candidates key then
323+
Tn.update_memory_sharing key Tn.Shared_cross_stream 40;
324+
let data = Hashtbl.find_or_add device.cross_stream_candidates key ~default in
325+
Map.add_exn ctx_arrays ~key ~data)
326+
else if Tn.known_shared_cross_stream key then (
327+
if Hashtbl.mem device.owner_streams key then
328+
if not (stream.stream_id = Hashtbl.find_exn device.owner_streams key) then
329+
raise
330+
@@ Utils.User_error
331+
("Cuda_backend.alloc_if_needed: node " ^ Tn.debug_name key
332+
^ " assumed to be cross-stream-shared but then written to on multiple devices")
333+
else Hashtbl.add_exn device.owner_streams ~key ~data:stream.stream_id;
334+
let data = Hashtbl.find_exn device.cross_stream_candidates key in
335+
Map.add_exn ctx_arrays ~key ~data)
336+
else (
337+
Tn.update_memory_sharing key Tn.Per_stream 41;
338+
Hashtbl.remove device.cross_stream_candidates key;
339+
add_new ()))
340+
else ctx_arrays

arrayjit/lib/backends.ml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,18 +115,20 @@ let lower_batch_assignments ?names ?occupancy bindings asgns_l =
115115
Some (Assignments.lower ~unoptim_ll_source ~ll_source ~cd_source ~name bound asgns) )
116116
else (None, None))
117117

118-
let verify_prior_context ~is_in_context ~ctx_arrays ~from_prior_context traced_stores =
118+
let verify_prior_context ~unified_memory ~ctx_arrays ~from_prior_context traced_stores =
119119
Set.iter from_prior_context ~f:(fun tn ->
120120
let node = Array.find_map traced_stores ~f:(fun store -> Hashtbl.find store tn) in
121121
if
122122
Option.value_map node ~default:false ~f:(fun node ->
123-
is_in_context node && not (Option.is_some @@ Map.find ctx_arrays tn))
123+
Tnode.is_in_context_force ~unified_memory node.Low_level.tn 33 && not (Option.is_some @@ Map.find ctx_arrays tn))
124124
then raise @@ Utils.User_error ("The linked context lacks node " ^ Tnode.debug_name tn))
125125

126126
let from_prior_context_batch comps =
127127
Array.filter_map comps ~f:(fun comp ->
128128
Option.map comp ~f:(fun comp ->
129-
Set.diff (Assignments.context_nodes comp.Assignments.asgns) comp.embedded_nodes))
129+
Set.diff
130+
(Assignments.context_nodes ~unified_memory comp.Assignments.asgns)
131+
comp.embedded_nodes))
130132
|> Array.fold ~init:(Set.empty (module Tnode)) ~f:Set.union
131133

132134
(** Adds a scheduler and brings a lowered no-device backend on par with lowered device backends. *)
@@ -296,7 +298,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
296298
}
297299

298300
let link context (code : code) =
299-
verify_prior_context ~is_in_context ~ctx_arrays:context.ctx_arrays
301+
verify_prior_context ~unified_memory ~ctx_arrays:context.ctx_arrays
300302
~from_prior_context:code.from_prior_context [| code.lowered.traced_store |];
301303
let inputs, outputs = Low_level.input_and_output_nodes code.lowered in
302304
let ctx_arrays, bindings, schedule = link context code.code in
@@ -310,7 +312,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
310312
{ context; schedule; bindings; name = code.name; inputs; outputs }
311313

312314
let link_batch context code_batch =
313-
verify_prior_context ~is_in_context ~ctx_arrays:context.ctx_arrays
315+
verify_prior_context ~unified_memory ~ctx_arrays:context.ctx_arrays
314316
~from_prior_context:code_batch.from_prior_context
315317
@@ Array.filter_map code_batch.lowereds ~f:(Option.map ~f:(fun l -> l.Low_level.traced_store));
316318
let _ctx_arrays, bindings, schedules = link_batch context code_batch.code_batch in

arrayjit/lib/c_syntax.ml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ module C_syntax (B : sig
1717

1818
val opt_ctx_arrays : buffer_ptr Map.M(Tnode).t option
1919
val hardcoded_context_ptr : (buffer_ptr -> Ops.prec -> string) option
20-
val is_in_context : Low_level.traced_array -> bool
20+
val unified_memory : bool
2121
val host_ptrs_for_readonly : bool
2222
val logs_to_stdout : bool
2323
val main_kernel_prefix : string
@@ -68,7 +68,7 @@ struct
6868
Hashtbl.iter l.Low_level.traced_store ~f:(fun (node : Low_level.traced_array) ->
6969
let tn = node.tn in
7070
if not @@ Hash_set.mem is_global tn then
71-
let in_ctx : bool = B.is_in_context node in
71+
let in_ctx : bool = Tn.is_in_context_force ~unified_memory:B.unified_memory tn 33 in
7272
let ctx_ptr = B.hardcoded_context_ptr in
7373
let mem : (Tn.memory_mode * int) option = tn.memory_mode in
7474
match
@@ -296,14 +296,14 @@ struct
296296
(* A rough approximation to the type Gccjit_backend.mem_properties. *)
297297
let backend_info =
298298
Sexp.Atom
299-
(if B.is_in_context node then "From_context"
299+
(if Tn.is_in_context_force ~unified_memory:B.unified_memory tn 33 then "From_context"
300300
else if Hash_set.mem is_global tn then "Constant_from_host"
301301
else if Tn.is_virtual_force tn 3331 then "Virtual"
302302
else "Local_only")
303303
in
304304
if not @@ Utils.sexp_mem ~elem:backend_info tn.backend_info then
305305
tn.backend_info <- Utils.sexp_append ~elem:backend_info tn.backend_info;
306-
if B.is_in_context node && not (Hash_set.mem is_global tn) then
306+
if Tn.is_in_context_force ~unified_memory:B.unified_memory tn 33 && not (Hash_set.mem is_global tn) then
307307
(B.typ_of_prec (Lazy.force tn.Tn.prec) ^ " *" ^ get_ident tn, Param_ptr tn) :: params
308308
else params)
309309
in
@@ -369,7 +369,7 @@ struct
369369
params);
370370
fprintf ppf "/* Local declarations and initialization. */@ ";
371371
Hashtbl.iteri traced_store ~f:(fun ~key:tn ~data:node ->
372-
if not (Tn.is_virtual_force tn 333 || B.is_in_context node || Hash_set.mem is_global tn)
372+
if not (Tn.is_virtual_force tn 333 || B.unified_memory node || Hash_set.mem is_global tn)
373373
then
374374
fprintf ppf "%s %s[%d]%s;@ "
375375
(B.typ_of_prec @@ Lazy.force tn.prec)

arrayjit/lib/cc_backend.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ type procedure = {
3434
}
3535
[@@deriving sexp_of]
3636

37-
let is_in_context node = Tnode.is_in_context_force node.Low_level.tn 33
37+
let unified_memory = true
3838

3939
let get_global_run_id =
4040
let next_id = ref 0 in
@@ -82,7 +82,7 @@ struct
8282
let for_lowereds = Input.for_lowereds
8383
let opt_ctx_arrays = Input.opt_ctx_arrays
8484
let hardcoded_context_ptr = c_ptr_to_string
85-
let is_in_context = is_in_context
85+
let unified_memory = unified_memory
8686
let host_ptrs_for_readonly = true
8787
let logs_to_stdout = false
8888
let main_kernel_prefix = ""

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 3 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ module Backend_buffer = struct
3030
end)
3131
end
3232

33+
let unified_memory = false
34+
3335
module Device_config = struct
3436
include Backend_buffer
3537

@@ -261,11 +263,6 @@ let%diagn2_sexp cuda_to_ptx ~name cu_src =
261263
Stdio.Out_channel.close oc);
262264
ptx
263265

264-
let is_in_context node =
265-
(* FIXME: shouldn't we use Tnode.is_in_context_force? *)
266-
Tnode.default_to_most_local node.Low_level.tn 33;
267-
match node.tn.memory_mode with Some ((Virtual | Local), _) -> false | _ -> true
268-
269266
module C_syntax_config (Input : sig
270267
val for_lowereds : Low_level.optimized array
271268
end) =
@@ -276,7 +273,7 @@ struct
276273

277274
let opt_ctx_arrays = None
278275
let hardcoded_context_ptr = None
279-
let is_in_context = is_in_context
276+
let unified_memory = unified_memory
280277
let host_ptrs_for_readonly = false
281278
(* GPUs cannot access host memory pointers directly. *)
282279

@@ -449,39 +446,6 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx
449446
work;
450447
}
451448

452-
let%track3_sexp alloc_if_needed ctx stream ~key ~data:node ctx_arrays =
453-
if is_in_context node && not (Map.mem ctx_arrays key) then (
454-
[%log2 Tn.debug_name key, "read_only", (node.read_only : bool)];
455-
[%log3 (key : Tn.t)];
456-
let default () : buffer_ptr =
457-
set_ctx ctx;
458-
Cu.Deviceptr.mem_alloc ~size_in_bytes:(Tn.size_in_bytes key)
459-
in
460-
let add_new () = Map.add_exn ctx_arrays ~key ~data:(default ()) in
461-
let device = stream.device in
462-
if node.read_only then
463-
if Tn.known_non_cross_stream key then add_new ()
464-
else (
465-
if Hashtbl.mem device.cross_stream_candidates key then
466-
Tn.update_memory_sharing key Tn.Shared_cross_stream 40;
467-
let data = Hashtbl.find_or_add device.cross_stream_candidates key ~default in
468-
Map.add_exn ctx_arrays ~key ~data)
469-
else if Tn.known_shared_cross_stream key then (
470-
if Hashtbl.mem device.owner_streams key then
471-
if not (stream.stream_id = Hashtbl.find_exn device.owner_streams key) then
472-
raise
473-
@@ Utils.User_error
474-
("Cuda_backend.alloc_if_needed: node " ^ Tn.debug_name key
475-
^ " assumed to be cross-stream-shared but then written to on multiple devices")
476-
else Hashtbl.add_exn device.owner_streams ~key ~data:stream.stream_id;
477-
let data = Hashtbl.find_exn device.cross_stream_candidates key in
478-
Map.add_exn ctx_arrays ~key ~data)
479-
else (
480-
Tn.update_memory_sharing key Tn.Per_stream 41;
481-
Hashtbl.remove device.cross_stream_candidates key;
482-
add_new ()))
483-
else ctx_arrays
484-
485449
let run_options () =
486450
if Utils.with_runtime_debug () then
487451
Cu.Module.[ GENERATE_DEBUG_INFO true; GENERATE_LINE_INFO true ]

arrayjit/lib/cuda_backend.missing.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ let compile ?shared:_ ~name:_ bindings _optimized = bindings
1919
let compile_batch ?shared:_ ~names:_ (bindings : Indexing.unit_bindings) optimized : code_batch =
2020
Array.map optimized ~f:(fun _ -> bindings)
2121

22-
let is_in_context _traced_array = false
22+
let unified_memory = false
2323
let ctx_arrays Unimplemented_ctx = Map.empty (module Tnode)
2424

2525
let link (Unimplemented_ctx : context) (code : code) =

arrayjit/lib/gcc_backend.gccjit.ml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,7 @@ type procedure = {
7474
}
7575
[@@deriving sexp_of]
7676

77-
let is_in_context node =
78-
(* FIXME: shouldn't we use Tnode.is_in_context_force? *)
79-
Tnode.default_to_most_local node.Low_level.tn 33;
80-
match node.tn.memory_mode with
81-
| Some (Hosted (Constant | Volatile), _) -> false
82-
| Some ((Virtual | Local), _) -> false
83-
| _ -> true
77+
let unified_memory = true
8478

8579
type gccjit_param = Gccjit.param
8680

arrayjit/lib/gcc_backend.missing.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ include Backend_impl.No_device_buffer_and_copying ()
55
let expected_merge_node Unimplemented_proc =
66
failwith "gcc backend missing: install the optional dependency gccjit"
77

8-
let is_in_context _node = failwith "gcc backend missing: install the optional dependency gccjit"
8+
let unified_memory = true
99

1010
let to_buffer _tn ~dst:_ ~src:_ =
1111
failwith "gcc backend missing: install the optional dependency gccjit"

arrayjit/lib/tnode.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,10 @@ let is_materialized_force tn provenance =
183183
| Some ((On_device _ | Hosted _ | Materialized), _) -> true
184184
| Some ((Never_virtual | Device_only | Effectively_constant), _) -> assert false
185185

186-
let is_in_context_force tn provenance =
186+
let is_in_context_force ~unified_memory tn provenance =
187187
default_to_most_local tn provenance;
188188
match tn.memory_mode with
189-
| Some (Hosted (Constant | Volatile), _) -> false
189+
| Some (Hosted (Constant | Volatile), _) when unified_memory -> false
190190
| Some ((Virtual | Local), _) -> false
191191
| _ -> true
192192

0 commit comments

Comments
 (0)