@@ -33,9 +33,10 @@ type buffer_ptr = Cu.Deviceptr.t
3333
(** Converts a device buffer pointer to a sexp atom whose text is produced by
    [Cu.Deviceptr.string_of]. *)
let sexp_of_buffer_ptr ptr =
  let repr = Cu.Deviceptr.string_of ptr in
  Sexp.Atom repr
3535
36- type ctx_array = Cu.Deviceptr .t
(** Alias for [Cu.Delimited_event.t]; the functions in this file record such events on a device's
    stream to keep track of scheduled work. *)
type event = Cu.Delimited_event.t
3737
38- let sexp_of_ctx_array = sexp_of_buffer_ptr
(** An array held by a context: the device buffer together with optional event tracking. *)
type ctx_array = {
  ptr : buffer_ptr;  (** The device buffer holding the array's data. *)
  mutable tracking : (event [@sexp.opaque]) option;
      (** The event currently associated with this array, if any; [None] until some code asks to
          track the array. Rendered opaquely by the derived [sexp_of_ctx_array]. *)
}
[@@deriving sexp_of]
3940
4041type physical_device = {
4142 dev : (Cu.Device .t [@ sexp.opaque]);
@@ -83,6 +84,19 @@ and context = {
(** Projects the [ctx_arrays] field out of a context. *)
let ctx_arrays { ctx_arrays; _ } = ctx_arrays

(** Mutable global configuration, initially [For_parallel_copying]. *)
let global_config = ref For_parallel_copying
8586
(** [work_for ctx tn] returns the event associated with the array of [tn] in [ctx], or [None] when
    [tn] has no entry in [ctx.ctx_arrays].

    If the array has no event yet, a new one is recorded on [ctx.device.stream] and stored in the
    array's [tracking] field before being returned; an already stored event is returned as is. *)
let work_for ctx tn =
  Option.map (Map.find ctx.ctx_arrays tn) ~f:(fun (arr : ctx_array) ->
      match arr.tracking with
      | Some event -> event
      | None ->
          (* First request for this array: start tracking it from the current point of the stream. *)
          let event = Cu.Delimited_event.record ctx.device.stream in
          arr.tracking <- Some event;
          event)
94+
(** [is_done event] is the result of [Cu.Delimited_event.query event] -- presumably [true] once
    the work captured by [event] has completed (confirm against the cudajit documentation). *)
let is_done event = Cu.Delimited_event.query event

(** [will_wait_for context event] passes [event] to [Cu.Delimited_event.wait] on the stream of
    [context.device] -- presumably making later work on that stream wait for [event] (confirm
    against the cudajit documentation). *)
let will_wait_for context event =
  let stream = context.device.stream in
  Cu.Delimited_event.wait stream event

(** [sync event] calls [Cu.Delimited_event.synchronize event] -- presumably blocking the caller
    until [event] completes (confirm against the cudajit documentation). *)
let sync event = Cu.Delimited_event.synchronize event

(** [all_work device] records and returns a new event on [device.stream]. *)
let all_work device =
  let stream = device.stream in
  Cu.Delimited_event.record stream
99+
86100let is_initialized, initialize =
87101 let initialized = ref false in
88102 let init (config : config ) : unit =
@@ -146,7 +160,7 @@ let get_device ~(ordinal : int) : physical_device =
146160 copy_merge_buffer;
147161 copy_merge_buffer_capacity;
148162 released = Atomic. make false ;
149- cross_device_candidates = Hashtbl. create (module Tn );
163+ cross_device_candidates = ( Hashtbl. create (module Tn ) : ctx_array Hashtbl. M ( Tn ).t );
150164 cross_device_shared = Hash_set. create (module Tn );
151165 non_cross_device = Hash_set. create (module Tn );
152166 owner_device_subordinal = Hashtbl. create (module Tn );
@@ -203,13 +217,13 @@ let finalize ctx =
203217 then (
204218 (* await does this: set_ctx ctx.device.physical.primary_context; *)
205219 await ctx.device;
220+ (* Cudajit's contexts, streams and events are destroyed by their respective finalizers. *)
206221 Option. iter ctx.run_module ~f: Cu.Module. unload;
207- Map. iteri ctx.ctx_arrays ~f: (fun ~key ~data : ptr ->
222+ Map. iteri ctx.ctx_arrays ~f: (fun ~key ~data ->
208223 if
209224 (not (Option. exists ctx.parent ~f: (fun pc -> Map. mem pc.ctx_arrays key)))
210225 && not (Hashtbl. mem ctx.device.physical.cross_device_candidates key)
211- then Cu.Deviceptr. mem_free ptr);
212- if Option. is_none ctx.parent then Cu.Stream. destroy ctx.device.stream)
226+ then Cu.Deviceptr. mem_free data.ptr))
213227
214228let init device =
215229 let ctx =
@@ -236,7 +250,8 @@ let unsafe_cleanup () =
236250 Cu.Context. set_current device.primary_context;
237251 Cu.Context. synchronize () ;
238252 Option. iter ! Utils. advance_captured_logs ~f: (fun callback -> callback () );
239- Hashtbl. iter device.cross_device_candidates ~f: Cu.Deviceptr. mem_free;
253+ Hashtbl. iter device.cross_device_candidates ~f: (fun ctx_array ->
254+ Cu.Deviceptr. mem_free ctx_array.ptr);
240255 Cu.Device. primary_ctx_release device.dev))
241256 done ;
242257 Core.Weak. fill ! devices 0 len None
@@ -246,7 +261,7 @@ let%diagn_l_sexp from_host (ctx : context) tn =
246261 | { Tn. array = (lazy (Some hosted )); _ } , Some dst ->
247262 set_ctx ctx.ctx;
248263 [% log " copying" , Tn. debug_name tn, " to" , (dst : ctx_array ), " from host" ];
249- let f src = Cu.Stream. memcpy_H_to_D ~dst ~src ctx.device.stream in
264+ let f src = Cu.Stream. memcpy_H_to_D ~dst: dst.ptr ~src ctx.device.stream in
250265 Ndarray. map { f } hosted;
251266 true
252267 | _ -> false
@@ -256,7 +271,7 @@ let%track_l_sexp to_host (ctx : context) (tn : Tn.t) =
256271 | { Tn. array = (lazy (Some hosted )); _ } , Some src ->
257272 set_ctx ctx.ctx;
258273 [% log " copying" , Tn. debug_name tn, " at" , (src : ctx_array ), " to host" ];
259- let f dst = Cu.Stream. memcpy_D_to_H ~dst ~src ctx.device.stream in
274+ let f dst = Cu.Stream. memcpy_D_to_H ~dst ~src: src.ptr ctx.device.stream in
260275 Ndarray. map { f } hosted;
261276 true
262277 | _ -> false
@@ -266,11 +281,11 @@ let%track_l_sexp rec device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : con
266281 let same_physical = phys_equal dst.device.physical src.device.physical in
267282 let memcpy ~d_arr ~s_arr =
268283 if same_physical then
269- Cu.Stream. memcpy_D_to_D ~size_in_bytes: (Tn. size_in_bytes tn) ~dst: d_arr ~src: s_arr
284+ Cu.Stream. memcpy_D_to_D ~size_in_bytes: (Tn. size_in_bytes tn) ~dst: d_arr.ptr ~src: s_arr.ptr
270285 dst.device.stream
271286 else
272- Cu.Stream. memcpy_peer ~size_in_bytes: (Tn. size_in_bytes tn) ~dst: d_arr ~dst_ctx: dst.ctx
273- ~src: s_arr ~src_ctx: src.ctx dst.device.stream
287+ Cu.Stream. memcpy_peer ~size_in_bytes: (Tn. size_in_bytes tn) ~dst: d_arr.ptr ~dst_ctx: dst.ctx
288+ ~src: s_arr.ptr ~src_ctx: src.ctx dst.device.stream
274289 in
275290 if
276291 same_physical
@@ -300,7 +315,7 @@ let%track_l_sexp rec device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : con
300315 true )
301316 | Streaming ->
302317 if phys_equal dst.device.physical src.device.physical then (
303- dst.device.merge_buffer < - Some (s_arr, tn);
318+ dst.device.merge_buffer < - Some (s_arr.ptr , tn);
304319 [% log " using merge buffer for" , Tn. debug_name tn, " from" , src.label];
305320 true )
306321 else
@@ -310,7 +325,7 @@ let%track_l_sexp rec device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : con
310325 set_ctx dst.ctx;
311326 let size_in_bytes = Tn. size_in_bytes tn in
312327 opt_alloc_merge_buffer ~size_in_bytes dst.device.physical;
313- memcpy ~d_arr: dst.device.physical.copy_merge_buffer ~s_arr ;
328+ memcpy ~d_arr: { ptr = dst.device.physical.copy_merge_buffer; tracking = None } ~s_arr ;
314329 dst.device.merge_buffer < - Some (dst.device.physical.copy_merge_buffer, tn);
315330 [% log " copied into merge buffer" , Tn. debug_name tn, " from" , src.label];
316331 true )
@@ -354,7 +369,8 @@ let%diagn_sexp cuda_to_ptx ~name cu_src =
354369 Stdio.Out_channel. flush oc;
355370 Stdio.Out_channel. close oc;
356371 let oc = Out_channel. open_text @@ Utils. build_file @@ name ^ " .cu_log" in
357- Stdio.Out_channel. output_string oc @@ Option. value_exn ~here: [% here] (Cu.Nvrtc. compilation_log ptx);
372+ Stdio.Out_channel. output_string oc
373+ @@ Option. value_exn ~here: [% here] (Cu.Nvrtc. compilation_log ptx);
358374 Stdio.Out_channel. flush oc;
359375 Stdio.Out_channel. close oc);
360376 ptx
@@ -486,7 +502,7 @@ let get_global_run_id =
486502 if ! next_id < 0 then next_id := 0 ;
487503 ! next_id
488504
489- let link_proc ~prior_context ~name ~(params : (string * param_source) list ) ~ctx_arrays
505+ let link_proc ~prior_context ~name ~(params : (string * param_source) list ) ~ctx_arrays traced_store
490506 lowered_bindings run_module =
491507 let module Cu = Cudajit in
492508 let func = Cu.Module. get_function run_module ~name in
@@ -505,8 +521,8 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx
505521 prior_context.ctx_arrays? *)
506522 List. map params ~f: (function
507523 | _name , Param_ptr tn ->
508- let ptr = Option. value_exn ~here: [% here] @@ Map. find ctx_arrays tn in
509- S. Tensor ptr
524+ let arr = Option. value_exn ~here: [% here] @@ Map. find ctx_arrays tn in
525+ S. Tensor arr. ptr
510526 | _name , Log_file_name -> S. Int log_id
511527 | _name , Merge_buffer ->
512528 let ptr = fst @@ Option. value_exn ~here: [% here] context.device.merge_buffer in
@@ -533,15 +549,18 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx
533549 (* Map.iteri ctx_arrays ~f:(fun ~key ~data:ptr -> if key.Low_level.zero_initialized then
534550 Cu.Stream.memset_d8 ptr Unsigned.UChar.zero ~length:(Tn.size_in_bytes key.Low_level.tn)); *)
535551 [% log " launching the kernel" ];
536- (* TODO: This doesn't help. *)
537- (* Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ()); *)
538552 (if Utils. debug_log_from_routines () then
539553 Utils. add_log_processor ~prefix: log_id_prefix @@ fun _output ->
540554 [% log_block
541555 context.label;
542556 Utils. log_trace_tree _output]);
543- S. launch_kernel func ~grid_dim_x: 1 ~block_dim_x: 1 ~shared_mem_bytes: 0 context.device.stream
544- args;
557+ S. launch_kernel func ~grid_dim_x: 1 ~block_dim_x: 1 ~shared_mem_bytes: 0 context.device.stream args;
558+ Map. iteri ctx_arrays ~f: (fun ~key ~data ->
559+ (* Note: a tensor node can only be a context array if it is materialized. *)
560+ if Option. is_some data.tracking then
561+ let traced = Low_level. get_node traced_store key in
562+ if not traced.read_only then
563+ data.tracking < - Some (Cu.Delimited_event. record context.device.stream));
545564 [% log " kernel launched" ]
546565 in
547566 ( context,
@@ -557,7 +576,8 @@ let%diagn_sexp alloc_if_needed ctx device ~key ~data:node ctx_arrays =
557576 [% log Tn. debug_name key, " read_only" , (node.read_only : bool )];
558577 let default () =
559578 set_ctx ctx;
560- Cu.Deviceptr. mem_alloc ~size_in_bytes: (Tn. size_in_bytes key)
579+ let ptr = Cu.Deviceptr. mem_alloc ~size_in_bytes: (Tn. size_in_bytes key) in
580+ { ptr; tracking = None }
561581 in
562582 let add_new () = Map. add_exn ctx_arrays ~key ~data: (default () ) in
563583 let physical = device.physical in
@@ -600,8 +620,8 @@ let%track3_sexp link prior_context (code : code) : context * _ * _ =
600620 let idx_params = Indexing. bound_symbols code.bindings in
601621 let lowered_bindings : Indexing.lowered_bindings = List. map idx_params ~f: (fun s -> (s, ref 0 )) in
602622 let context, task =
603- link_proc ~prior_context ~name: code.name ~params: code.params ~ctx_arrays lowered_bindings
604- run_module
623+ link_proc ~prior_context ~name: code.name ~params: code.params ~ctx_arrays code.traced_store
624+ lowered_bindings run_module
605625 in
606626 (context, lowered_bindings, task)
607627
@@ -622,8 +642,8 @@ let%track3_sexp link_batch prior_context (code_batch : code_batch) : context * _
622642 ~f: (alloc_if_needed ctx prior_context.device)
623643 in
624644 let context, task =
625- link_proc ~prior_context: context ~name ~params ~ctx_arrays lowered_bindings
626- run_module
645+ link_proc ~prior_context: context ~name ~params ~ctx_arrays traced_store
646+ lowered_bindings run_module
627647 in
628648 ((context, ctx_arrays), Some task)))
629649 in
0 commit comments