@@ -115,15 +115,14 @@ let lower_batch_assignments ?names ?occupancy bindings asgns_l =
115115 Some (Assignments. lower ~unoptim_ll_source ~ll_source ~cd_source ~name bound asgns) )
116116 else (None , None ))
117117
118- let verify_prior_context ~unified_memory ~ctx_arrays ~from_prior_context traced_stores =
118+ let verify_prior_context ~unified_memory ~ctx_arrays ~from_prior_context =
119119 Set. iter from_prior_context ~f: (fun tn ->
120- let node = Array. find_map traced_stores ~f: (fun store -> Hashtbl. find store tn) in
121120 if
122- Option. value_map node ~default: false ~f: ( fun node ->
123- Tn. is_in_context ~unified_memory node && not (Option. is_some @@ Map. find ctx_arrays tn) )
121+ Tn. is_in_context_force ~unified_memory tn 342
122+ && not (Option. is_some @@ Map. find ctx_arrays tn)
124123 then raise @@ Utils. User_error (" The linked context lacks node " ^ Tnode. debug_name tn))
125124
126- let from_prior_context_batch comps =
125+ let from_prior_context_batch ~ unified_memory comps =
127126 Array. filter_map comps ~f: (fun comp ->
128127 Option. map comp ~f: (fun comp ->
129128 Set. diff
@@ -156,7 +155,7 @@ module Add_device
156155 }
157156 | Compiled of {
158157 lowereds : Low_level .optimized option array ;
159- procs : ctx_arrays option * Backend .procedure option array ;
158+ procs : Backend .procedure option array ;
160159 }
161160 [@@ deriving sexp_of ]
162161
@@ -174,38 +173,34 @@ module Add_device
174173
175174 include Add_scheduler (Backend )
176175
177- let link context (code : code ) =
176+ let link context (code : code ) ctx_arrays =
178177 let runner_label = get_name context.stream in
179- let ctx_arrays = context.ctx_arrays in
180178 let merge_buffer = context.stream.merge_buffer in
181179 match code with
182180 | Postponed { lowered; bindings; name } ->
183181 let proc = Backend. compile ~name ~opt_ctx_arrays: (Some ctx_arrays) bindings lowered in
184182 link_compiled ~merge_buffer ~runner_label ctx_arrays proc
185183 | Compiled { proc; _ } -> link_compiled ~merge_buffer ~runner_label ctx_arrays proc
186184
187- let link_batch context (code_batch : code_batch ) =
185+ let link_batch context (code_batch : code_batch ) ctx_arrays =
188186 let runner_label = get_name context.stream in
189- let ctx_arrays = context.ctx_arrays in
190187 let merge_buffer = context.stream.merge_buffer in
191- (* FIXME: why are we getting and ignoring opt_ctx_arrays here? *)
192- let _opt_ctx_arrays, procs =
188+ let procs =
193189 match code_batch with
194190 | Postponed { lowereds; bindings; names } ->
195191 Backend. compile_batch ~names ~opt_ctx_arrays: (Some ctx_arrays) bindings lowereds
196192 | Compiled { procs; _ } -> procs
197193 in
198- let (ctx_arrays, bindings) , schedules =
199- Array. fold_map procs ~init: (ctx_arrays, None ) ~f: (fun ( ctx_arrays , bindings ) -> function
194+ let bindings, schedules =
195+ Array. fold_mapi procs ~init: None ~f: (fun i bindings -> function
200196 | Some proc ->
201- let ctx_arrays, bindings', schedule =
202- link_compiled ~merge_buffer ~runner_label ctx_arrays proc
203- in
197+ let ctx_arrays = Option. value_exn ctx_arrays.(i) in
198+ let bindings', schedule = link_compiled ~merge_buffer ~runner_label ctx_arrays proc in
204199 Option. iter bindings ~f: (fun bindings -> assert (phys_equal bindings bindings'));
205- ((ctx_arrays, Some bindings') , Some (ctx_arrays, schedule) )
206- | None -> ((ctx_arrays, bindings) , None ))
200+ (Some bindings', Some schedule)
201+ | None -> (bindings, None ))
207202 in
208- (ctx_arrays, Option. value_exn ~here: [% here] bindings, schedules)
203+ (Option. value_exn ~here: [% here] bindings, schedules)
209204
210205 let from_host ~dst_ptr ~dst hosted =
211206 let work () = host_to_buffer hosted ~dst: dst_ptr in
@@ -271,10 +266,44 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
271266 }
272267 [@@ deriving sexp_of ]
273268
269+ let % track3_sexp _alloc_if_needed (stream : stream ) ~key ~data: node ctx_arrays =
270+ if Tnode. is_in_context_force ~unified_memory key 345 && not (Map. mem ctx_arrays key) then (
271+ [% log2 Tn. debug_name key];
272+ [% log3 (key : Tnode.t )];
273+ let default () =
274+ alloc_zero_init_array (Lazy. force key.prec) ~dims: (Lazy. force key.dims) stream
275+ in
276+ let add_new () = Map. add_exn ctx_arrays ~key ~data: (default () ) in
277+ let device = stream.device in
278+ if node.Low_level. read_only then
279+ if Tn. known_non_cross_stream key then add_new ()
280+ else (
281+ if Hashtbl. mem device.cross_stream_candidates key then
282+ Tn. update_memory_sharing key Tn. Shared_cross_stream 40 ;
283+ let data = Hashtbl. find_or_add device.cross_stream_candidates key ~default in
284+ Map. add_exn ctx_arrays ~key ~data )
285+ else if Tn. known_shared_cross_stream key then (
286+ if Hashtbl. mem device.owner_streams key then
287+ if not (stream.stream_id = Hashtbl. find_exn device.owner_streams key) then
288+ raise
289+ @@ Utils. User_error
290+ (" Cuda_backend.alloc_if_needed: node " ^ Tn. debug_name key
291+ ^ " assumed to be cross-stream-shared but then written to on multiple devices" )
292+ else Hashtbl. add_exn device.owner_streams ~key ~data: stream.stream_id;
293+ let data = Hashtbl. find_exn device.cross_stream_candidates key in
294+ Map. add_exn ctx_arrays ~key ~data )
295+ else (
296+ Tn. update_memory_sharing key Tn. Per_stream 41 ;
297+ Hashtbl. remove device.cross_stream_candidates key;
298+ add_new () ))
299+ else ctx_arrays
300+
274301 let compile ?shared ?name bindings comp : code =
275302 let name, lowered = lower_assignments ?name bindings comp.Assignments. asgns in
276303 let code = compile ?shared ~name bindings lowered in
277- let from_prior_context = Set. diff (Assignments. context_nodes comp.asgns) comp.embedded_nodes in
304+ let from_prior_context =
305+ Set. diff (Assignments. context_nodes ~unified_memory comp.asgns) comp.embedded_nodes
306+ in
278307 { from_prior_context; name; lowered; code; expected_merge_node = lowered.Low_level. merge_node }
279308
280309 let compile_batch ?shared ?names ?occupancy bindings comps =
@@ -284,7 +313,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
284313 in
285314 let code_batch = compile_batch ?shared ~names bindings lowereds in
286315 let from_prior_context =
287- from_prior_context_batch
316+ from_prior_context_batch ~unified_memory
288317 @@ Array. mapi lowereds ~f: (fun i -> Option. map ~f: (fun _ -> comps.(i)))
289318 in
290319 {
@@ -299,9 +328,10 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
299328
300329 let link context (code : code ) =
301330 verify_prior_context ~unified_memory ~ctx_arrays: context.ctx_arrays
302- ~from_prior_context: code.from_prior_context [| code.lowered.traced_store |] ;
331+ ~from_prior_context: code.from_prior_context;
303332 let inputs, outputs = Low_level. input_and_output_nodes code.lowered in
304- let ctx_arrays, bindings, schedule = link context code.code in
333+ let ctx_arrays = failwith " NOT IMPLEMENTED YET" in
334+ let bindings, schedule = link context code.code ctx_arrays in
305335 let context = make_child ~ctx_arrays context in
306336 let schedule =
307337 Task. prepend schedule ~work: (fun () ->
@@ -313,12 +343,13 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
313343
314344 let link_batch context code_batch =
315345 verify_prior_context ~unified_memory ~ctx_arrays: context.ctx_arrays
316- ~from_prior_context: code_batch.from_prior_context
317- @@ Array. filter_map code_batch.lowereds ~f: ( Option. map ~f: ( fun l -> l. Low_level. traced_store));
318- let _ctx_arrays, bindings, schedules = link_batch context code_batch.code_batch in
346+ ~from_prior_context: code_batch.from_prior_context;
347+ let ctx_arrays = failwith " NOT IMPLEMENTED YET " in
348+ let bindings, schedules = link_batch context code_batch.code_batch ctx_arrays in
319349 Array. fold_mapi schedules ~init: context ~f: (fun i context -> function
320350 | None -> (context, None )
321- | Some (ctx_arrays , schedule ) ->
351+ | Some schedule ->
352+ let ctx_arrays = Option. value_exn ctx_arrays.(i) in
322353 let context = make_child ~ctx_arrays context in
323354 let expected_merge_node = code_batch.expected_merge_nodes.(i) in
324355 let inputs, outputs =
0 commit comments