Skip to content

Commit 041bc78

Browse files
committed
In progress step 2 of 3: factor out alloc_if_needed
Steps 1 and 2: in `compile` parameter `opt_ctx_arrays` and `link` parameter `ctx_arrays`, expect arrays of the resulting context. Step 3: compute the context arrays before calling `Backend.link` when raising a backend.
1 parent 0f6feaf commit 041bc78

File tree

7 files changed

+215
-370
lines changed

7 files changed

+215
-370
lines changed

arrayjit/lib/backend_impl.ml

Lines changed: 16 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -205,21 +205,26 @@ module type Lowered_no_device_backend = sig
205205
Indexing.unit_bindings ->
206206
Low_level.optimized ->
207207
procedure
208+
(** [opt_ctx_arrays], if any, already contain the arrays of the context that will result from
209+
linking the code. *)
208210

209211
val compile_batch :
210212
names:string option array ->
211-
opt_ctx_arrays:ctx_arrays option ->
213+
opt_ctx_arrays:ctx_arrays option array option ->
212214
Indexing.unit_bindings ->
213215
Low_level.optimized option array ->
214-
ctx_arrays option * procedure option array
216+
procedure option array
217+
(** [opt_ctx_arrays], if any, already contain the arrays of the contexts that will result from
218+
linking the code. *)
215219

216220
val link_compiled :
217221
merge_buffer:(buffer_ptr * Tnode.t) option ref ->
218222
runner_label:string ->
219223
ctx_arrays ->
220224
procedure ->
221-
ctx_arrays * Indexing.lowered_bindings * Task.t
222-
(** [runner_label] will be [get_name stream] of the stream holding the resulting [ctx_arrays]. *)
225+
Indexing.lowered_bindings * Task.t
226+
(** The [ctx_arrays] already contain the arrays of the resulting context. [runner_label] will be
227+
[get_name stream] of the stream holding the resulting [ctx_arrays]. *)
223228

224229
include No_device_buffer_and_copying with type buffer_ptr := buffer_ptr
225230
end
@@ -278,18 +283,18 @@ module type Lowered_backend = sig
278283
Low_level.optimized option array ->
279284
code_batch
280285

281-
val link : context -> code -> ctx_arrays * Indexing.lowered_bindings * Task.t
282-
(** The results correspond to the fields {!field-Backend_intf.ctx_arrays} of
283-
{!field-Backend_intf.context}, {!field-Backend_intf.bindings} and
286+
val link : context -> code -> ctx_arrays -> Indexing.lowered_bindings * Task.t
287+
(** [context] is the prior context, while [ctx_arrays] are the arrays of the resulting context.
288+
The results correspond to the fields {!field-Backend_intf.bindings} and
284289
{!field-Backend_intf.schedule} of {!Backend_intf.routine}. *)
285290

286291
val link_batch :
287292
context ->
288293
code_batch ->
289-
ctx_arrays * Indexing.lowered_bindings * (ctx_arrays * Task.t) option array
290-
(** Returns the schedule tasks and their [ctx_arrays] for the procedures included in the code
291-
batch. The returned [ctx_arrays] will be part of a context downstream of all the tasks and the
292-
tasks' contexts are not independent (typically, they are cumulative). *)
294+
ctx_arrays option array ->
295+
Indexing.lowered_bindings * Task.t option array
296+
(** [context] is the prior context, while the [ctx_arrays] are the arrays of the resulting
297+
contexts. Returns the schedule tasks for the procedures included in the code batch. *)
293298
end
294299

295300
module Alloc_buffer_ignore_stream
@@ -305,36 +310,3 @@ struct
305310
let alloc_zero_init_array prec ~dims _stream = Backend.alloc_zero_init_array prec ~dims ()
306311
let free_buffer = Option.map Backend.free_buffer ~f:(fun memfree _stream ptr -> memfree () ptr)
307312
end
308-
309-
let%track3_sexp alloc_if_needed (type buffer_ptr) ~ ~unified_memory ctx stream ~key ~data:node ctx_arrays =
310-
if Tnode.is_in_context ~unified_memory node && not (Map.mem ctx_arrays key) then (
311-
[%log2 Tn.debug_name key, "read_only", (node.read_only : bool)];
312-
[%log3 (key : Tn.t)];
313-
let default () : buffer_ptr =
314-
set_ctx ctx;
315-
Cu.Deviceptr.mem_alloc ~size_in_bytes:(Tn.size_in_bytes key)
316-
in
317-
let add_new () = Map.add_exn ctx_arrays ~key ~data:(default ()) in
318-
let device = stream.device in
319-
if node.read_only then
320-
if Tn.known_non_cross_stream key then add_new ()
321-
else (
322-
if Hashtbl.mem device.cross_stream_candidates key then
323-
Tn.update_memory_sharing key Tn.Shared_cross_stream 40;
324-
let data = Hashtbl.find_or_add device.cross_stream_candidates key ~default in
325-
Map.add_exn ctx_arrays ~key ~data)
326-
else if Tn.known_shared_cross_stream key then (
327-
if Hashtbl.mem device.owner_streams key then
328-
if not (stream.stream_id = Hashtbl.find_exn device.owner_streams key) then
329-
raise
330-
@@ Utils.User_error
331-
("Cuda_backend.alloc_if_needed: node " ^ Tn.debug_name key
332-
^ " assumed to be cross-stream-shared but then written to on multiple devices")
333-
else Hashtbl.add_exn device.owner_streams ~key ~data:stream.stream_id;
334-
let data = Hashtbl.find_exn device.cross_stream_candidates key in
335-
Map.add_exn ctx_arrays ~key ~data)
336-
else (
337-
Tn.update_memory_sharing key Tn.Per_stream 41;
338-
Hashtbl.remove device.cross_stream_candidates key;
339-
add_new ()))
340-
else ctx_arrays

arrayjit/lib/backends.ml

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,14 @@ let lower_batch_assignments ?names ?occupancy bindings asgns_l =
115115
Some (Assignments.lower ~unoptim_ll_source ~ll_source ~cd_source ~name bound asgns) )
116116
else (None, None))
117117

118-
let verify_prior_context ~unified_memory ~ctx_arrays ~from_prior_context traced_stores =
118+
let verify_prior_context ~unified_memory ~ctx_arrays ~from_prior_context =
119119
Set.iter from_prior_context ~f:(fun tn ->
120-
let node = Array.find_map traced_stores ~f:(fun store -> Hashtbl.find store tn) in
121120
if
122-
Option.value_map node ~default:false ~f:(fun node ->
123-
Tn.is_in_context ~unified_memory node && not (Option.is_some @@ Map.find ctx_arrays tn))
121+
Tn.is_in_context_force ~unified_memory tn 342
122+
&& not (Option.is_some @@ Map.find ctx_arrays tn)
124123
then raise @@ Utils.User_error ("The linked context lacks node " ^ Tnode.debug_name tn))
125124

126-
let from_prior_context_batch comps =
125+
let from_prior_context_batch ~unified_memory comps =
127126
Array.filter_map comps ~f:(fun comp ->
128127
Option.map comp ~f:(fun comp ->
129128
Set.diff
@@ -156,7 +155,7 @@ module Add_device
156155
}
157156
| Compiled of {
158157
lowereds : Low_level.optimized option array;
159-
procs : ctx_arrays option * Backend.procedure option array;
158+
procs : Backend.procedure option array;
160159
}
161160
[@@deriving sexp_of]
162161

@@ -174,38 +173,34 @@ module Add_device
174173

175174
include Add_scheduler (Backend)
176175

177-
let link context (code : code) =
176+
let link context (code : code) ctx_arrays =
178177
let runner_label = get_name context.stream in
179-
let ctx_arrays = context.ctx_arrays in
180178
let merge_buffer = context.stream.merge_buffer in
181179
match code with
182180
| Postponed { lowered; bindings; name } ->
183181
let proc = Backend.compile ~name ~opt_ctx_arrays:(Some ctx_arrays) bindings lowered in
184182
link_compiled ~merge_buffer ~runner_label ctx_arrays proc
185183
| Compiled { proc; _ } -> link_compiled ~merge_buffer ~runner_label ctx_arrays proc
186184

187-
let link_batch context (code_batch : code_batch) =
185+
let link_batch context (code_batch : code_batch) ctx_arrays =
188186
let runner_label = get_name context.stream in
189-
let ctx_arrays = context.ctx_arrays in
190187
let merge_buffer = context.stream.merge_buffer in
191-
(* FIXME: why are we getting and ignoring opt_ctx_arrays here? *)
192-
let _opt_ctx_arrays, procs =
188+
let procs =
193189
match code_batch with
194190
| Postponed { lowereds; bindings; names } ->
195191
Backend.compile_batch ~names ~opt_ctx_arrays:(Some ctx_arrays) bindings lowereds
196192
| Compiled { procs; _ } -> procs
197193
in
198-
let (ctx_arrays, bindings), schedules =
199-
Array.fold_map procs ~init:(ctx_arrays, None) ~f:(fun (ctx_arrays, bindings) -> function
194+
let bindings, schedules =
195+
Array.fold_mapi procs ~init:None ~f:(fun i bindings -> function
200196
| Some proc ->
201-
let ctx_arrays, bindings', schedule =
202-
link_compiled ~merge_buffer ~runner_label ctx_arrays proc
203-
in
197+
let ctx_arrays = Option.value_exn ctx_arrays.(i) in
198+
let bindings', schedule = link_compiled ~merge_buffer ~runner_label ctx_arrays proc in
204199
Option.iter bindings ~f:(fun bindings -> assert (phys_equal bindings bindings'));
205-
((ctx_arrays, Some bindings'), Some (ctx_arrays, schedule))
206-
| None -> ((ctx_arrays, bindings), None))
200+
(Some bindings', Some schedule)
201+
| None -> (bindings, None))
207202
in
208-
(ctx_arrays, Option.value_exn ~here:[%here] bindings, schedules)
203+
(Option.value_exn ~here:[%here] bindings, schedules)
209204

210205
let from_host ~dst_ptr ~dst hosted =
211206
let work () = host_to_buffer hosted ~dst:dst_ptr in
@@ -271,10 +266,44 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
271266
}
272267
[@@deriving sexp_of]
273268

269+
let%track3_sexp _alloc_if_needed (stream : stream) ~key ~data:node ctx_arrays =
270+
if Tnode.is_in_context_force ~unified_memory key 345 && not (Map.mem ctx_arrays key) then (
271+
[%log2 Tn.debug_name key];
272+
[%log3 (key : Tnode.t)];
273+
let default () =
274+
alloc_zero_init_array (Lazy.force key.prec) ~dims:(Lazy.force key.dims) stream
275+
in
276+
let add_new () = Map.add_exn ctx_arrays ~key ~data:(default ()) in
277+
let device = stream.device in
278+
if node.Low_level.read_only then
279+
if Tn.known_non_cross_stream key then add_new ()
280+
else (
281+
if Hashtbl.mem device.cross_stream_candidates key then
282+
Tn.update_memory_sharing key Tn.Shared_cross_stream 40;
283+
let data = Hashtbl.find_or_add device.cross_stream_candidates key ~default in
284+
Map.add_exn ctx_arrays ~key ~data)
285+
else if Tn.known_shared_cross_stream key then (
286+
if Hashtbl.mem device.owner_streams key then
287+
if not (stream.stream_id = Hashtbl.find_exn device.owner_streams key) then
288+
raise
289+
@@ Utils.User_error
290+
("Cuda_backend.alloc_if_needed: node " ^ Tn.debug_name key
291+
^ " assumed to be cross-stream-shared but then written to on multiple devices")
292+
else Hashtbl.add_exn device.owner_streams ~key ~data:stream.stream_id;
293+
let data = Hashtbl.find_exn device.cross_stream_candidates key in
294+
Map.add_exn ctx_arrays ~key ~data)
295+
else (
296+
Tn.update_memory_sharing key Tn.Per_stream 41;
297+
Hashtbl.remove device.cross_stream_candidates key;
298+
add_new ()))
299+
else ctx_arrays
300+
274301
let compile ?shared ?name bindings comp : code =
275302
let name, lowered = lower_assignments ?name bindings comp.Assignments.asgns in
276303
let code = compile ?shared ~name bindings lowered in
277-
let from_prior_context = Set.diff (Assignments.context_nodes comp.asgns) comp.embedded_nodes in
304+
let from_prior_context =
305+
Set.diff (Assignments.context_nodes ~unified_memory comp.asgns) comp.embedded_nodes
306+
in
278307
{ from_prior_context; name; lowered; code; expected_merge_node = lowered.Low_level.merge_node }
279308

280309
let compile_batch ?shared ?names ?occupancy bindings comps =
@@ -284,7 +313,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
284313
in
285314
let code_batch = compile_batch ?shared ~names bindings lowereds in
286315
let from_prior_context =
287-
from_prior_context_batch
316+
from_prior_context_batch ~unified_memory
288317
@@ Array.mapi lowereds ~f:(fun i -> Option.map ~f:(fun _ -> comps.(i)))
289318
in
290319
{
@@ -299,9 +328,10 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
299328

300329
let link context (code : code) =
301330
verify_prior_context ~unified_memory ~ctx_arrays:context.ctx_arrays
302-
~from_prior_context:code.from_prior_context [| code.lowered.traced_store |];
331+
~from_prior_context:code.from_prior_context;
303332
let inputs, outputs = Low_level.input_and_output_nodes code.lowered in
304-
let ctx_arrays, bindings, schedule = link context code.code in
333+
let ctx_arrays = failwith "NOT IMPLEMENTED YET" in
334+
let bindings, schedule = link context code.code ctx_arrays in
305335
let context = make_child ~ctx_arrays context in
306336
let schedule =
307337
Task.prepend schedule ~work:(fun () ->
@@ -313,12 +343,13 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
313343

314344
let link_batch context code_batch =
315345
verify_prior_context ~unified_memory ~ctx_arrays:context.ctx_arrays
316-
~from_prior_context:code_batch.from_prior_context
317-
@@ Array.filter_map code_batch.lowereds ~f:(Option.map ~f:(fun l -> l.Low_level.traced_store));
318-
let _ctx_arrays, bindings, schedules = link_batch context code_batch.code_batch in
346+
~from_prior_context:code_batch.from_prior_context;
347+
let ctx_arrays = failwith "NOT IMPLEMENTED YET" in
348+
let bindings, schedules = link_batch context code_batch.code_batch ctx_arrays in
319349
Array.fold_mapi schedules ~init:context ~f:(fun i context -> function
320350
| None -> (context, None)
321-
| Some (ctx_arrays, schedule) ->
351+
| Some schedule ->
352+
let ctx_arrays = Option.value_exn ctx_arrays.(i) in
322353
let context = make_child ~ctx_arrays context in
323354
let expected_merge_node = code_batch.expected_merge_nodes.(i) in
324355
let inputs, outputs =

arrayjit/lib/c_syntax.ml

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
1111
module Tn = Tnode
1212

1313
module C_syntax (B : sig
14-
val for_lowereds : Low_level.optimized array
15-
1614
type buffer_ptr
1715

18-
val opt_ctx_arrays : buffer_ptr Map.M(Tnode).t option
16+
val procs : (Low_level.optimized * buffer_ptr ctx_arrays option) array
17+
(** The low-level prcedure to compile, and the arrays of the context it will be linked to if not
18+
shared and already known. *)
19+
1920
val hardcoded_context_ptr : (buffer_ptr -> Ops.prec -> string) option
2021
val unified_memory : bool
2122
val host_ptrs_for_readonly : bool
@@ -30,7 +31,9 @@ module C_syntax (B : sig
3031
end) =
3132
struct
3233
let get_ident =
33-
Low_level.get_ident_within_code ~no_dots:true @@ Array.map B.for_lowereds ~f:(fun l -> l.llc)
34+
Low_level.get_ident_within_code ~no_dots:true @@ Array.map B.procs ~f:(fun (l, _) -> l.llc)
35+
36+
let in_ctx tn = B.(Tn.is_in_context_force ~unified_memory tn 341)
3437

3538
let pp_zero_out ppf tn =
3639
Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " (get_ident tn) @@ Tn.size_in_bytes tn
@@ -64,15 +67,14 @@ struct
6467
let is_global = Hash_set.create (module Tn) in
6568
fprintf ppf {|@[<v 0>%a@,/* Global declarations. */@,|} (pp_print_list pp_print_string)
6669
B.include_lines;
67-
Array.iter B.for_lowereds ~f:(fun l ->
70+
Array.iter B.procs ~f:(fun (l, ctx_arrays) ->
6871
Hashtbl.iter l.Low_level.traced_store ~f:(fun (node : Low_level.traced_array) ->
6972
let tn = node.tn in
7073
if not @@ Hash_set.mem is_global tn then
71-
let in_ctx : bool = B.unified_memory node in
7274
let ctx_ptr = B.hardcoded_context_ptr in
7375
let mem : (Tn.memory_mode * int) option = tn.memory_mode in
7476
match
75-
(in_ctx, ctx_ptr, B.opt_ctx_arrays, B.host_ptrs_for_readonly, mem, node.read_only)
77+
(in_ctx tn, ctx_ptr, ctx_arrays, B.host_ptrs_for_readonly, mem, node.read_only)
7678
with
7779
| true, Some get_ptr, Some ctx_arrays, _, _, _ ->
7880
let ident = get_ident tn in
@@ -292,18 +294,18 @@ struct
292294
let params : (string * param_source) list =
293295
(* Preserve the order in the hashtable, so it's the same as e.g. in compile_globals. *)
294296
List.rev
295-
@@ Hashtbl.fold traced_store ~init:[] ~f:(fun ~key:tn ~data:node params ->
297+
@@ Hashtbl.fold traced_store ~init:[] ~f:(fun ~key:tn ~data:_ params ->
296298
(* A rough approximation to the type Gccjit_backend.mem_properties. *)
297299
let backend_info =
298300
Sexp.Atom
299-
(if B.unified_memory node then "From_context"
300-
else if Hash_set.mem is_global tn then "Constant_from_host"
301-
else if Tn.is_virtual_force tn 3331 then "Virtual"
302-
else "Local_only")
301+
(if in_ctx tn then "Ctx"
302+
else if Hash_set.mem is_global tn then "Host"
303+
else if Tn.is_virtual_force tn 3331 then "Virt"
304+
else "Local")
303305
in
304306
if not @@ Utils.sexp_mem ~elem:backend_info tn.backend_info then
305307
tn.backend_info <- Utils.sexp_append ~elem:backend_info tn.backend_info;
306-
if B.unified_memory node && not (Hash_set.mem is_global tn) then
308+
if in_ctx tn && not (Hash_set.mem is_global tn) then
307309
(B.typ_of_prec (Lazy.force tn.Tn.prec) ^ " *" ^ get_ident tn, Param_ptr tn) :: params
308310
else params)
309311
in
@@ -369,8 +371,7 @@ struct
369371
params);
370372
fprintf ppf "/* Local declarations and initialization. */@ ";
371373
Hashtbl.iteri traced_store ~f:(fun ~key:tn ~data:node ->
372-
if not (Tn.is_virtual_force tn 333 || B.unified_memory node || Hash_set.mem is_global tn)
373-
then
374+
if not (Tn.is_virtual_force tn 333 || in_ctx tn || Hash_set.mem is_global tn) then
374375
fprintf ppf "%s %s[%d]%s;@ "
375376
(B.typ_of_prec @@ Lazy.force tn.prec)
376377
(get_ident tn) (Tn.num_elems tn)

0 commit comments

Comments
 (0)