@@ -30,10 +30,6 @@ type device = {
3030 cross_stream_candidates : ctx_array Hashtbl .M (Tn ).t;
3131 (* * Freshly created arrays that might be shared across streams. The map can both grow and
3232 shrink. See the explanation on top of this file. *)
33- cross_stream_shared : Hash_set .M (Tn ).t;
34- (* * Tensor nodes known to be cross-stream shared. This set can only grow. *)
35- non_cross_stream : Hash_set .M (Tn ).t;
36- (* * Tensor nodes known to not be cross-stream shared. This set can only grow. *)
3733 owner_stream_subordinal : int Hashtbl .M (Tn ).t;
3834 (* * The streams owning the given nodes. This map can only grow. *)
3935}
@@ -77,7 +73,6 @@ let is_done event = Cu.Delimited_event.query event
7773let will_wait_for context event = Cu.Delimited_event. wait context.stream.cu_stream event
7874let sync event = Cu.Delimited_event. synchronize event
7975let all_work stream = Cu.Delimited_event. record stream.cu_stream
80-
8176let scheduled_merge_node stream = Option. map ~f: snd stream.merge_buffer
8277
8378let is_initialized, initialize =
@@ -152,8 +147,6 @@ let get_device ~(ordinal : int) : device =
152147 copy_merge_buffer_capacity;
153148 released = Atomic. make false ;
154149 cross_stream_candidates = (Hashtbl. create (module Tn ) : ctx_array Hashtbl. M (Tn ).t);
155- cross_stream_shared = Hash_set. create (module Tn );
156- non_cross_stream = Hash_set. create (module Tn );
157150 owner_stream_subordinal = Hashtbl. create (module Tn );
158151 }
159152 in
@@ -275,9 +268,7 @@ let%diagn2_l_sexp rec device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : co
275268 ~src: s_arr.ptr ~src_ctx: src.ctx dst.stream.cu_stream
276269 in
277270 if
278- same_device
279- && (src.stream.subordinal = dst.stream.subordinal
280- || Hash_set. mem dst.stream.device.cross_stream_shared tn)
271+ same_device && (src.stream.subordinal = dst.stream.subordinal || Tn. known_shared_cross_stream tn)
281272 then false
282273 else
283274 match Map. find src.ctx_arrays tn with
@@ -559,24 +550,25 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx
559550 work;
560551 } )
561552
562- let % diagn2_sexp alloc_if_needed ctx stream ~key ~data: node ctx_arrays =
553+ let % track3_sexp alloc_if_needed ctx stream ~key ~data: node ctx_arrays =
563554 if is_in_context node && not (Map. mem ctx_arrays key) then (
564- [% log Tn. debug_name key, " read_only" , (node.read_only : bool )];
565- let default () =
555+ [% log2 Tn. debug_name key, " read_only" , (node.read_only : bool )];
556+ [% log3 (key : Tn.t )];
557+ let default () : ctx_array =
566558 set_ctx ctx;
567559 let ptr = Cu.Deviceptr. mem_alloc ~size_in_bytes: (Tn. size_in_bytes key) in
568560 { ptr; tracking = None }
569561 in
570562 let add_new () = Map. add_exn ctx_arrays ~key ~data: (default () ) in
571563 let device = stream.device in
572564 if node.read_only then
573- if Hash_set. mem device.non_cross_stream key then add_new ()
565+ if Tn. known_non_cross_stream key then add_new ()
574566 else (
575567 if Hashtbl. mem device.cross_stream_candidates key then
576- Hash_set. add device.cross_stream_shared key ;
568+ Tn. update_memory_sharing key Tn. Shared_cross_stream 40 ;
577569 let data = Hashtbl. find_or_add device.cross_stream_candidates key ~default in
578570 Map. add_exn ctx_arrays ~key ~data )
579- else if Hash_set. mem device.cross_stream_shared key then (
571+ else if Tn. known_shared_cross_stream key then (
580572 if Hashtbl. mem device.owner_stream_subordinal key then
581573 if Hashtbl. find_exn device.owner_stream_subordinal key <> stream.subordinal then
582574 raise
@@ -587,7 +579,7 @@ let%diagn2_sexp alloc_if_needed ctx stream ~key ~data:node ctx_arrays =
587579 let data = Hashtbl. find_exn device.cross_stream_candidates key in
588580 Map. add_exn ctx_arrays ~key ~data )
589581 else (
590- Hash_set. add device.non_cross_stream key ;
582+ Tn. update_memory_sharing key Tn. Per_stream 41 ;
591583 Hashtbl. remove device.cross_stream_candidates key;
592584 add_new () ))
593585 else ctx_arrays
0 commit comments