Skip to content

Commit 4f48ac8

Browse files
committed
Complete factoring out alloc_if_needed
1 parent 041bc78 commit 4f48ac8

File tree

2 files changed

+48
-32
lines changed

2 files changed

+48
-32
lines changed

arrayjit/lib/backends.ml

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,35 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
266266
}
267267
[@@deriving sexp_of]
268268

269-
let%track3_sexp _alloc_if_needed (stream : stream) ~key ~data:node ctx_arrays =
269+
let compile ?shared ?name bindings comp : code =
270+
let name, lowered = lower_assignments ?name bindings comp.Assignments.asgns in
271+
let code = compile ?shared ~name bindings lowered in
272+
let from_prior_context =
273+
Set.diff (Assignments.context_nodes ~unified_memory comp.asgns) comp.embedded_nodes
274+
in
275+
{ from_prior_context; name; lowered; code; expected_merge_node = lowered.Low_level.merge_node }
276+
277+
let compile_batch ?shared ?names ?occupancy bindings comps =
278+
let names, lowereds =
279+
lower_batch_assignments ?names ?occupancy bindings
280+
@@ Array.map comps ~f:(fun c -> c.Assignments.asgns)
281+
in
282+
let code_batch = compile_batch ?shared ~names bindings lowereds in
283+
let from_prior_context =
284+
from_prior_context_batch ~unified_memory
285+
@@ Array.mapi lowereds ~f:(fun i -> Option.map ~f:(fun _ -> comps.(i)))
286+
in
287+
{
288+
from_prior_context;
289+
names;
290+
lowereds;
291+
code_batch;
292+
expected_merge_nodes =
293+
Array.map lowereds ~f:(fun lowered ->
294+
Option.(join @@ map lowered ~f:(fun optim -> optim.Low_level.merge_node)));
295+
}
296+
297+
let%track3_sexp alloc_if_needed (stream : stream) ~key ~data:node ctx_arrays =
270298
if Tnode.is_in_context_force ~unified_memory key 345 && not (Map.mem ctx_arrays key) then (
271299
[%log2 Tn.debug_name key];
272300
[%log3 (key : Tnode.t)];
@@ -298,39 +326,14 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
298326
add_new ()))
299327
else ctx_arrays
300328

301-
let compile ?shared ?name bindings comp : code =
302-
let name, lowered = lower_assignments ?name bindings comp.Assignments.asgns in
303-
let code = compile ?shared ~name bindings lowered in
304-
let from_prior_context =
305-
Set.diff (Assignments.context_nodes ~unified_memory comp.asgns) comp.embedded_nodes
306-
in
307-
{ from_prior_context; name; lowered; code; expected_merge_node = lowered.Low_level.merge_node }
308-
309-
let compile_batch ?shared ?names ?occupancy bindings comps =
310-
let names, lowereds =
311-
lower_batch_assignments ?names ?occupancy bindings
312-
@@ Array.map comps ~f:(fun c -> c.Assignments.asgns)
313-
in
314-
let code_batch = compile_batch ?shared ~names bindings lowereds in
315-
let from_prior_context =
316-
from_prior_context_batch ~unified_memory
317-
@@ Array.mapi lowereds ~f:(fun i -> Option.map ~f:(fun _ -> comps.(i)))
318-
in
319-
{
320-
from_prior_context;
321-
names;
322-
lowereds;
323-
code_batch;
324-
expected_merge_nodes =
325-
Array.map lowereds ~f:(fun lowered ->
326-
Option.(join @@ map lowered ~f:(fun optim -> optim.Low_level.merge_node)));
327-
}
328-
329329
let link context (code : code) =
330330
verify_prior_context ~unified_memory ~ctx_arrays:context.ctx_arrays
331331
~from_prior_context:code.from_prior_context;
332332
let inputs, outputs = Low_level.input_and_output_nodes code.lowered in
333-
let ctx_arrays = failwith "NOT IMPLEMENTED YET" in
333+
let ctx_arrays =
334+
Hashtbl.fold code.lowered.traced_store ~init:context.ctx_arrays
335+
~f:(alloc_if_needed context.stream)
336+
in
334337
let bindings, schedule = link context code.code ctx_arrays in
335338
let context = make_child ~ctx_arrays context in
336339
let schedule =
@@ -344,7 +347,13 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
344347
let link_batch context code_batch =
345348
verify_prior_context ~unified_memory ~ctx_arrays:context.ctx_arrays
346349
~from_prior_context:code_batch.from_prior_context;
347-
let ctx_arrays = failwith "NOT IMPLEMENTED YET" in
350+
let ctx_arrays =
351+
Array.map code_batch.lowereds
352+
~f:
353+
(Option.map ~f:(fun l ->
354+
Hashtbl.fold l.Low_level.traced_store ~init:context.ctx_arrays
355+
~f:(alloc_if_needed context.stream)))
356+
in
348357
let bindings, schedules = link_batch context code_batch.code_batch ctx_arrays in
349358
Array.fold_mapi schedules ~init:context ~f:(fun i context -> function
350359
| None -> (context, None)

arrayjit/lib/cc_backend.ml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,14 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
138138

139139
let%diagn_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : procedure) =
140140
let name : string = code.name in
141-
List.iter code.params ~f:(function _, Param_ptr tn -> assert (Map.mem ctx_arrays tn) | _ -> ());
141+
List.iter code.params ~f:(function
142+
| _, Param_ptr tn ->
143+
if not (Map.mem ctx_arrays tn) then
144+
invalid_arg
145+
[%string
146+
"Cc_backend.link_compiled: node %{Tn.debug_name tn} missing from context: \
147+
%{Tn.debug_memory_mode tn.Tn.memory_mode}"]
148+
| _ -> ());
142149
let log_file_name = Utils.diagn_log_file [%string "debug-%{runner_label}-%{code.name}.log"] in
143150
let run_variadic =
144151
[%log_level

0 commit comments

Comments
 (0)