@@ -266,7 +266,35 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
266266 }
267267 [@@ deriving sexp_of ]
268268
269- let % track3_sexp _alloc_if_needed (stream : stream ) ~key ~data: node ctx_arrays =
269+ let compile ?shared ?name bindings comp : code =
270+ let name, lowered = lower_assignments ?name bindings comp.Assignments. asgns in
271+ let code = compile ?shared ~name bindings lowered in
272+ let from_prior_context =
273+ Set. diff (Assignments. context_nodes ~unified_memory comp.asgns) comp.embedded_nodes
274+ in
275+ { from_prior_context; name; lowered; code; expected_merge_node = lowered.Low_level. merge_node }
276+
277+ let compile_batch ?shared ?names ?occupancy bindings comps =
278+ let names, lowereds =
279+ lower_batch_assignments ?names ?occupancy bindings
280+ @@ Array. map comps ~f: (fun c -> c.Assignments. asgns)
281+ in
282+ let code_batch = compile_batch ?shared ~names bindings lowereds in
283+ let from_prior_context =
284+ from_prior_context_batch ~unified_memory
285+ @@ Array. mapi lowereds ~f: (fun i -> Option. map ~f: (fun _ -> comps.(i)))
286+ in
287+ {
288+ from_prior_context;
289+ names;
290+ lowereds;
291+ code_batch;
292+ expected_merge_nodes =
293+ Array. map lowereds ~f: (fun lowered ->
294+ Option. (join @@ map lowered ~f: (fun optim -> optim.Low_level. merge_node)));
295+ }
296+
297+ let % track3_sexp alloc_if_needed (stream : stream ) ~key ~data: node ctx_arrays =
270298 if Tnode. is_in_context_force ~unified_memory key 345 && not (Map. mem ctx_arrays key) then (
271299 [% log2 Tn. debug_name key];
272300 [% log3 (key : Tnode.t )];
@@ -298,39 +326,14 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
298326 add_new () ))
299327 else ctx_arrays
300328
301- let compile ?shared ?name bindings comp : code =
302- let name, lowered = lower_assignments ?name bindings comp.Assignments. asgns in
303- let code = compile ?shared ~name bindings lowered in
304- let from_prior_context =
305- Set. diff (Assignments. context_nodes ~unified_memory comp.asgns) comp.embedded_nodes
306- in
307- { from_prior_context; name; lowered; code; expected_merge_node = lowered.Low_level. merge_node }
308-
309- let compile_batch ?shared ?names ?occupancy bindings comps =
310- let names, lowereds =
311- lower_batch_assignments ?names ?occupancy bindings
312- @@ Array. map comps ~f: (fun c -> c.Assignments. asgns)
313- in
314- let code_batch = compile_batch ?shared ~names bindings lowereds in
315- let from_prior_context =
316- from_prior_context_batch ~unified_memory
317- @@ Array. mapi lowereds ~f: (fun i -> Option. map ~f: (fun _ -> comps.(i)))
318- in
319- {
320- from_prior_context;
321- names;
322- lowereds;
323- code_batch;
324- expected_merge_nodes =
325- Array. map lowereds ~f: (fun lowered ->
326- Option. (join @@ map lowered ~f: (fun optim -> optim.Low_level. merge_node)));
327- }
328-
329329 let link context (code : code ) =
330330 verify_prior_context ~unified_memory ~ctx_arrays: context.ctx_arrays
331331 ~from_prior_context: code.from_prior_context;
332332 let inputs, outputs = Low_level. input_and_output_nodes code.lowered in
333- let ctx_arrays = failwith " NOT IMPLEMENTED YET" in
333+ let ctx_arrays =
334+ Hashtbl. fold code.lowered.traced_store ~init: context.ctx_arrays
335+ ~f: (alloc_if_needed context.stream)
336+ in
334337 let bindings, schedule = link context code.code ctx_arrays in
335338 let context = make_child ~ctx_arrays context in
336339 let schedule =
@@ -344,7 +347,13 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
344347 let link_batch context code_batch =
345348 verify_prior_context ~unified_memory ~ctx_arrays: context.ctx_arrays
346349 ~from_prior_context: code_batch.from_prior_context;
347- let ctx_arrays = failwith " NOT IMPLEMENTED YET" in
350+ let ctx_arrays =
351+ Array. map code_batch.lowereds
352+ ~f:
353+ (Option. map ~f: (fun l ->
354+ Hashtbl. fold l.Low_level. traced_store ~init: context.ctx_arrays
355+ ~f: (alloc_if_needed context.stream)))
356+ in
348357 let bindings, schedules = link_batch context code_batch.code_batch ctx_arrays in
349358 Array. fold_mapi schedules ~init: context ~f: (fun i context -> function
350359 | None -> (context, None )
0 commit comments