Skip to content

Commit bd0dc98

Browse files
committed
Fixes #286: use Tnode.sharing in the cuda backend
1 parent be9a299 commit bd0dc98

File tree

3 files changed

+25
-24
lines changed

3 files changed

+25
-24
lines changed

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 9 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -30,10 +30,6 @@ type device = {
3030
cross_stream_candidates : ctx_array Hashtbl.M(Tn).t;
3131
(** Freshly created arrays that might be shared across streams. The map can both grow and
3232
shrink. See the explanation on top of this file. *)
33-
cross_stream_shared : Hash_set.M(Tn).t;
34-
(** Tensor nodes known to be cross-stream shared. This set can only grow. *)
35-
non_cross_stream : Hash_set.M(Tn).t;
36-
(** Tensor nodes known to not be cross-stream shared. This set can only grow. *)
3733
owner_stream_subordinal : int Hashtbl.M(Tn).t;
3834
(** The streams owning the given nodes. This map can only grow. *)
3935
}
@@ -77,7 +73,6 @@ let is_done event = Cu.Delimited_event.query event
7773
let will_wait_for context event = Cu.Delimited_event.wait context.stream.cu_stream event
7874
let sync event = Cu.Delimited_event.synchronize event
7975
let all_work stream = Cu.Delimited_event.record stream.cu_stream
80-
8176
let scheduled_merge_node stream = Option.map ~f:snd stream.merge_buffer
8277

8378
let is_initialized, initialize =
@@ -152,8 +147,6 @@ let get_device ~(ordinal : int) : device =
152147
copy_merge_buffer_capacity;
153148
released = Atomic.make false;
154149
cross_stream_candidates = (Hashtbl.create (module Tn) : ctx_array Hashtbl.M(Tn).t);
155-
cross_stream_shared = Hash_set.create (module Tn);
156-
non_cross_stream = Hash_set.create (module Tn);
157150
owner_stream_subordinal = Hashtbl.create (module Tn);
158151
}
159152
in
@@ -275,9 +268,7 @@ let%diagn2_l_sexp rec device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : co
275268
~src:s_arr.ptr ~src_ctx:src.ctx dst.stream.cu_stream
276269
in
277270
if
278-
same_device
279-
&& (src.stream.subordinal = dst.stream.subordinal
280-
|| Hash_set.mem dst.stream.device.cross_stream_shared tn)
271+
same_device && (src.stream.subordinal = dst.stream.subordinal || Tn.known_shared_cross_stream tn)
281272
then false
282273
else
283274
match Map.find src.ctx_arrays tn with
@@ -559,24 +550,25 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx
559550
work;
560551
} )
561552

562-
let%diagn2_sexp alloc_if_needed ctx stream ~key ~data:node ctx_arrays =
553+
let%track3_sexp alloc_if_needed ctx stream ~key ~data:node ctx_arrays =
563554
if is_in_context node && not (Map.mem ctx_arrays key) then (
564-
[%log Tn.debug_name key, "read_only", (node.read_only : bool)];
565-
let default () =
555+
[%log2 Tn.debug_name key, "read_only", (node.read_only : bool)];
556+
[%log3 (key : Tn.t)];
557+
let default () : ctx_array =
566558
set_ctx ctx;
567559
let ptr = Cu.Deviceptr.mem_alloc ~size_in_bytes:(Tn.size_in_bytes key) in
568560
{ ptr; tracking = None }
569561
in
570562
let add_new () = Map.add_exn ctx_arrays ~key ~data:(default ()) in
571563
let device = stream.device in
572564
if node.read_only then
573-
if Hash_set.mem device.non_cross_stream key then add_new ()
565+
if Tn.known_non_cross_stream key then add_new ()
574566
else (
575567
if Hashtbl.mem device.cross_stream_candidates key then
576-
Hash_set.add device.cross_stream_shared key;
568+
Tn.update_memory_sharing key Tn.Shared_cross_stream 40;
577569
let data = Hashtbl.find_or_add device.cross_stream_candidates key ~default in
578570
Map.add_exn ctx_arrays ~key ~data)
579-
else if Hash_set.mem device.cross_stream_shared key then (
571+
else if Tn.known_shared_cross_stream key then (
580572
if Hashtbl.mem device.owner_stream_subordinal key then
581573
if Hashtbl.find_exn device.owner_stream_subordinal key <> stream.subordinal then
582574
raise
@@ -587,7 +579,7 @@ let%diagn2_sexp alloc_if_needed ctx stream ~key ~data:node ctx_arrays =
587579
let data = Hashtbl.find_exn device.cross_stream_candidates key in
588580
Map.add_exn ctx_arrays ~key ~data)
589581
else (
590-
Hash_set.add device.non_cross_stream key;
582+
Tn.update_memory_sharing key Tn.Per_stream 41;
591583
Hashtbl.remove device.cross_stream_candidates key;
592584
add_new ()))
593585
else ctx_arrays

arrayjit/lib/tnode.ml

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,16 +132,13 @@ let log_debug_info ~from_log_level tn =
132132
| (lazy (Some nd)) -> Nd.log_debug_info ~from_log_level nd
133133
else [%log "<not-in-yet>"]]]
134134

135-
(** The one exception to "most local" is the sharing property: defaults to [Shared_cross_stream]. *)
135+
(** The one exception to "most local" is that the sharing property is kept at [Unset]. *)
136136
let default_to_most_local tn provenance =
137137
match tn.memory_mode with
138138
| None | Some (Effectively_constant, _) -> tn.memory_mode <- Some (Virtual, provenance)
139139
| Some (Never_virtual, _) -> tn.memory_mode <- Some (Local, provenance)
140140
| Some (Device_only, _) -> tn.memory_mode <- Some (Local, provenance)
141-
| Some (Materialized, _) -> tn.memory_mode <- Some (On_device Shared_cross_stream, provenance)
142-
| Some (On_device Unset, _) -> tn.memory_mode <- Some (On_device Shared_cross_stream, provenance)
143-
| Some (Hosted (Changed_on_devices Unset), _) ->
144-
tn.memory_mode <- Some (Hosted (Changed_on_devices Shared_cross_stream), provenance)
141+
| Some (Materialized, _) -> tn.memory_mode <- Some (On_device Unset, provenance)
145142
| Some ((Virtual | Local | On_device _ | Hosted _), _) -> ()
146143

147144
let is_virtual_force tn provenance =
@@ -194,6 +191,18 @@ let known_not_param tn =
194191
true
195192
| _ -> false
196193

194+
let known_shared_cross_stream tn =
195+
match tn.memory_mode with
196+
| Some ((On_device Shared_cross_stream | Hosted (Changed_on_devices Shared_cross_stream)), _) ->
197+
true
198+
| _ -> false
199+
200+
let known_non_cross_stream tn =
201+
match tn.memory_mode with
202+
| Some ((On_device Per_stream | Hosted (Changed_on_devices Per_stream)), _) ->
203+
true
204+
| _ -> false
205+
197206
let mode_is_unspecified tn =
198207
match tn.memory_mode with
199208
| None | Some ((Never_virtual | Effectively_constant), _) -> true
@@ -246,7 +255,7 @@ let update_memory_sharing tn sharing provenance =
246255
[%string
247256
"Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \
248257
%{debug_name tn} -- change from non-shared to shared is currently not permitted"]
249-
| Some ((On_device _ | Device_only | Materialized), _), Shared_cross_stream ->
258+
| Some ((On_device _ | Device_only | Materialized), _), _ ->
250259
tn.memory_mode <- Some (On_device sharing, provenance)
251260
| Some (Hosted (Changed_on_devices Per_stream), prov2), Shared_cross_stream ->
252261
raise

bin/hello_world.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ module Rand = Arrayjit.Rand.Lib
1111
let hello1 () =
1212
Rand.init 0;
1313
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
14-
Utils.set_log_level 2;
14+
(* Utils.set_log_level 2; *)
1515
(* Utils.settings.output_debug_files_in_build_directory <- true; *)
1616
let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
1717
let ctx = Backend.init stream in

0 commit comments

Comments
 (0)