Skip to content

Commit 4214f83

Browse files
committed
backends: Get rid of subordinal
1 parent 702d540 commit 4214f83

File tree

9 files changed

+69
-98
lines changed

9 files changed

+69
-98
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
- Fixed #286: cross-stream-sharing incorporated into `Tnode.memory_mode`.
1616
- Moved the multicore backend from a `device = stream` model to a single device model.
1717
- Got rid of `unsafe_cleanup`.
18+
- Got rid of `subordinal`.
1819
- TODO: Built per-tensor-node stream-to-stream synchronization into device-to-device copying functions, removed obsolete blocking synchronizations.
1920

2021
### Fixed

arrayjit/lib/backend_types.ml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,6 @@ module type Backend = sig
202202
val get_ctx_stream : context -> stream
203203
val get_stream_device : stream -> device
204204
val to_ordinal : device -> int
205-
val to_subordinal : stream -> int
206205
val get_name : stream -> string
207206
end
208207

@@ -332,5 +331,4 @@ module type Lowered_backend = sig
332331
val get_ctx_stream : context -> stream
333332
val get_name : stream -> string
334333
val to_ordinal : device -> int
335-
val to_subordinal : stream -> int
336334
end

arrayjit/lib/backends.ml

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ let check_merge_buffer ~scheduled_node ~code_node =
1818
("Merge buffer mismatch, on stream: " ^ name scheduled_node ^ ", expected by code: "
1919
^ name code_node)
2020

21-
module Multicore_backend (Backend : Backend_types.No_device_backend) (* : Backend_types.Backend *) =
21+
module Multicore_backend (Backend : Backend_types.No_device_backend) : Backend_types.Backend =
2222
struct
2323
module Domain = Domain [@warning "-3"]
2424

@@ -68,7 +68,7 @@ struct
6868
state : stream_state;
6969
merge_buffer : (buffer_ptr * Tnode.t) option ref;
7070
mutable allocated_buffer : (buffer_ptr * int) option;
71-
subordinal : int;
71+
unique_name : string;
7272
domain : (unit Domain.t[@sexp.opaque]);
7373
}
7474
[@@deriving sexp_of]
@@ -97,19 +97,17 @@ struct
9797
Stdlib.Condition.wait d.host_wait_for_idle d.mut
9898
done;
9999
Mut.unlock d.mut;
100-
Option.iter d.stream_error ~f:(fun e ->
101-
Exn.reraise e @@ name ^ " stream " ^ Int.to_string stream.subordinal))
100+
Option.iter d.stream_error ~f:(fun e -> Exn.reraise e @@ name ^ " " ^ stream.unique_name))
102101

103102
(** TODO: not implemented yet. Should return the event indicating whether any currently running or scheduled computations on the
104103
stream have completed. *)
105104
let all_work _stream = Not_implemented_yet
106105

107106
let%track3_l_sexp schedule_task stream task =
108107
assert (Domain.is_main_domain ());
109-
[%log_result "schedule_task", Task.describe task, "stream", (stream.subordinal : int)];
108+
[%log_result "schedule_task", Task.describe task, stream.unique_name];
110109
let d = stream.state in
111-
Option.iter d.stream_error ~f:(fun e ->
112-
Exn.reraise e @@ name ^ " stream " ^ Int.to_string stream.subordinal);
110+
Option.iter d.stream_error ~f:(fun e -> Exn.reraise e @@ name ^ " " ^ stream.unique_name);
113111
if not d.keep_spinning then invalid_arg "Multicore_backend: stream not available";
114112
if not @@ Queue.try_push d.queue task then (
115113
await stream;
@@ -121,7 +119,7 @@ struct
121119

122120
let global_run_no = ref 0
123121

124-
let%track3_l_sexp spinup_stream ~(subordinal : int) : stream =
122+
let%track3_l_sexp spinup_stream ~unique_name : stream =
125123
Int.incr global_run_no;
126124
let state =
127125
{
@@ -153,14 +151,14 @@ struct
153151
with e ->
154152
state.stream_error <- Some e;
155153
state.keep_spinning <- false;
156-
[%log1 "Stream", (subordinal : int), "exception", Exn.to_string e];
154+
[%log1 unique_name, "exception", Exn.to_string e];
157155
(* TODO: we risk raising this error multiple times because await and schedule_task raise
158156
stream_error. But this is fine if we assume all exceptions are fatal. *)
159157
raise e
160158
in
161159
{
162160
state;
163-
subordinal;
161+
unique_name;
164162
domain = Domain.spawn worker;
165163
merge_buffer = ref None;
166164
allocated_buffer = None;
@@ -169,7 +167,7 @@ struct
169167
type context = { stream : stream; ctx : Backend.context } [@@deriving sexp_of]
170168
type nonrec routine = context routine [@@deriving sexp_of]
171169

172-
let init stream = { stream; ctx = Backend.init (name ^ " " ^ Int.to_string stream.subordinal) }
170+
let init stream = { stream; ctx = Backend.init (name ^ " " ^ stream.unique_name) }
173171
let initialize = Backend.initialize
174172
let is_initialized = Backend.is_initialized
175173

@@ -179,14 +177,14 @@ struct
179177

180178
let compile = Backend.compile
181179
let compile_batch = Backend.compile_batch
182-
let get_stream_name s = "stream " ^ Int.to_string s.subordinal
180+
let get_name stream = stream.unique_name
183181

184182
let link { ctx; stream } code =
185183
let task = Backend.link ~merge_buffer:stream.merge_buffer ctx code in
186184
{
187185
task with
188186
context = { ctx = task.context; stream };
189-
schedule = Task.enschedule ~schedule_task ~get_stream_name stream task.schedule;
187+
schedule = Task.enschedule ~schedule_task ~get_stream_name:get_name stream task.schedule;
190188
}
191189

192190
let link_batch { ctx; stream } code_batch =
@@ -198,7 +196,8 @@ struct
198196
{
199197
task with
200198
context = { ctx = task.context; stream };
201-
schedule = Task.enschedule ~schedule_task ~get_stream_name stream task.schedule;
199+
schedule =
200+
Task.enschedule ~schedule_task ~get_stream_name:get_name stream task.schedule;
202201
})) )
203202

204203
let from_host (context : context) (tn : Tnode.t) =
@@ -223,8 +222,7 @@ struct
223222
{
224223
context_lifetime = context;
225224
description =
226-
"from_host " ^ Tnode.debug_name tn ^ " dst "
227-
^ Int.to_string context.stream.subordinal;
225+
"from_host " ^ Tnode.debug_name tn ^ " dst " ^ context.stream.unique_name;
228226
work;
229227
});
230228
true
@@ -257,8 +255,7 @@ struct
257255
{
258256
context_lifetime = context;
259257
description =
260-
"from_host " ^ Tnode.debug_name tn ^ " dst "
261-
^ Int.to_string context.stream.subordinal;
258+
"from_host " ^ Tnode.debug_name tn ^ " dst " ^ context.stream.unique_name;
262259
work;
263260
});
264261
true
@@ -296,8 +293,8 @@ struct
296293
Backend.to_buffer tn ~dst:merge_ptr ~src:src.ctx
297294
in
298295
let description =
299-
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ Int.to_string dev.subordinal ^ " src "
300-
^ Int.to_string src.stream.subordinal
296+
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ dev.unique_name ^ " src "
297+
^ src.stream.unique_name
301298
in
302299
schedule_task dev (Task.Task { context_lifetime = (src, dst); description; work })
303300
in
@@ -311,13 +308,14 @@ struct
311308

312309
let num_devices () = 1
313310
let suggested_num_streams CPU = Domain.recommended_domain_count () - 1
314-
let latest_subordinal = ref 0
311+
let used_names = Hash_set.create (module String)
315312

316313
let cleanup_stream stream =
317314
assert (Domain.is_main_domain ());
318315
await stream;
319316
stream.state.keep_spinning <- false;
320317
Stdlib.Condition.broadcast stream.state.dev_wait_for_work;
318+
Hash_set.remove used_names stream.unique_name;
321319
Domain.join stream.domain
322320

323321
let get_device ~ordinal =
@@ -326,17 +324,20 @@ struct
326324
CPU
327325

328326
let new_stream CPU =
329-
let subordinal = !latest_subordinal in
330-
Int.incr latest_subordinal;
331-
let stream = spinup_stream ~subordinal in
327+
assert (Domain.is_main_domain ());
328+
let rec unique_name suffix =
329+
let name = "stream " ^ Int.to_string suffix in
330+
if Hash_set.mem used_names name then unique_name (suffix + 1) else name
331+
in
332+
let unique_name = unique_name 0 in
333+
Hash_set.add used_names unique_name;
334+
let stream = spinup_stream ~unique_name in
332335
Stdlib.Gc.finalise cleanup_stream stream;
333336
stream
334337

335338
let get_stream_device _stream = CPU
336339
let get_ctx_stream { stream; _ } = stream
337-
let get_name stream = Int.to_string stream.subordinal
338340
let to_ordinal _ = 0
339-
let to_subordinal { subordinal; _ } = subordinal
340341
end
341342

342343
(** For debugging, allow [Sync_backend(...).suggested_num_streams] calls to return numbers greater than 1. *)
@@ -354,7 +355,7 @@ module Sync_backend (Backend : Backend_types.No_device_backend) : Backend_types.
354355
let will_wait_for _context () = ()
355356

356357
type stream = {
357-
subordinal : int;
358+
unique_name : string;
358359
merge_buffer : (buffer_ptr * Tnode.t) option ref;
359360
mutable allocated_buffer : (buffer_ptr * int) option;
360361
}
@@ -396,7 +397,7 @@ module Sync_backend (Backend : Backend_types.No_device_backend) : Backend_types.
396397
Array.map routines
397398
~f:(Option.map ~f:(fun task -> { task with context = { ctx = task.context; stream } })) )
398399

399-
let get_name stream = Int.to_string stream.subordinal
400+
let get_name stream = stream.unique_name
400401

401402
let from_host (context : context) (tn : Tnode.t) =
402403
Option.value ~default:false
@@ -472,22 +473,32 @@ module Sync_backend (Backend : Backend_types.No_device_backend) : Backend_types.
472473

473474
let num_devices () = 1
474475
let suggested_num_streams _device = !sync_suggested_num_streams
475-
let next_stream_id = ref 0
476476

477477
let get_device ~ordinal =
478478
if ordinal <> 0 then invalid_arg "Sync_backend backends only have device number 0";
479479
CPU
480480

481+
let used_names = Hash_set.create (module String)
482+
483+
let cleanup_stream stream =
484+
assert (Domain.is_main_domain ());
485+
await stream;
486+
Hash_set.remove used_names stream.unique_name
487+
481488
let new_stream CPU =
482-
let result =
483-
{ subordinal = !next_stream_id; merge_buffer = ref None; allocated_buffer = None }
489+
let rec unique_name suffix =
490+
let name = "stream " ^ Int.to_string suffix in
491+
if Hash_set.mem used_names name then unique_name (suffix + 1) else name
484492
in
493+
let unique_name = unique_name 0 in
494+
Hash_set.add used_names unique_name;
495+
let result = { unique_name; merge_buffer = ref None; allocated_buffer = None } in
496+
Stdlib.Gc.finalise cleanup_stream result;
485497
result
486498

487499
let get_stream_device _stream = CPU
488500
let get_ctx_stream { stream; _ } = stream
489501
let to_ordinal _ = 0
490-
let to_subordinal stream = stream.subordinal
491502
end
492503

493504
let lower_assignments ?name bindings asgns =

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,20 @@ type device = {
3838
primary_context : Cu.Context.t;
3939
mutable copy_merge_buffer : buffer_ptr;
4040
mutable copy_merge_buffer_capacity : int;
41-
mutable latest_subordinal : int;
41+
used_names : Hash_set.M(String).t; (** Unique names of streams. *)
4242
released : Utils.atomic_bool;
4343
cross_stream_candidates : ctx_array Hashtbl.M(Tn).t;
4444
(** Freshly created arrays that might be shared across streams. The map can both grow and
4545
shrink. See the explanation at the top of this file. *)
46-
owner_stream_subordinal : int Hashtbl.M(Tn).t;
46+
owner_streams : string Hashtbl.M(Tn).t;
4747
(** The unique names of the streams owning the given nodes. This map can only grow. *)
4848
}
4949
[@@deriving sexp_of]
5050

5151
and stream = {
5252
device : device;
5353
cu_stream : Cu.Stream.t;
54-
subordinal : int;
54+
unique_name : string;
5555
mutable merge_buffer : (buffer_ptr * Tn.t) option;
5656
}
5757

@@ -176,13 +176,17 @@ let%track3_sexp get_device ~(ordinal : int) : device =
176176
(* We need this: there can be an arbitrary gap between the finalizer run and the deallocation. *)
177177
if Atomic.get result.released then default () else result
178178

179-
let new_stream device =
180-
let subordinal = device.latest_subordinal in
181-
device.latest_subordinal <- device.latest_subordinal + 1;
179+
let%track3_sexp new_stream (device : device) : stream =
180+
let rec unique_name suffix =
181+
let name = "stream " ^ Int.to_string suffix in
182+
if Hash_set.mem device.used_names name then unique_name (suffix + 1) else name
183+
in
184+
let unique_name = unique_name 0 in
185+
Hash_set.add device.used_names unique_name;
182186
(* Strange that we need ctx_set_current even with a single device! *)
183187
set_ctx device.primary_context;
184188
let cu_stream = Cu.Stream.create ~non_blocking:true () in
185-
{ device; cu_stream; subordinal; merge_buffer = None }
189+
{ device; cu_stream; unique_name; merge_buffer = None }
186190

187191
let cuda_properties =
188192
let cache =
@@ -205,10 +209,7 @@ let suggested_num_streams device =
205209
let get_ctx_stream { stream; _ } = stream
206210
let get_stream_device { device; _ } = device
207211
let to_ordinal { ordinal; _ } = ordinal
208-
let to_subordinal { subordinal; _ } = subordinal
209-
210-
let get_name stream =
211-
Int.to_string (to_ordinal stream.device) ^ "_" ^ Int.to_string (to_subordinal stream)
212+
let get_name stream = stream.unique_name
212213

213214
let await stream : unit =
214215
set_ctx stream.device.primary_context;
@@ -278,7 +279,9 @@ let%diagn2_l_sexp rec device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : co
278279
~src:s_arr.ptr ~src_ctx:src.ctx dst.stream.cu_stream
279280
in
280281
if
281-
same_device && (src.stream.subordinal = dst.stream.subordinal || Tn.known_shared_cross_stream tn)
282+
same_device
283+
&& (Tn.known_shared_cross_stream tn
284+
|| String.equal src.stream.unique_name dst.stream.unique_name)
282285
then false
283286
else
284287
match Map.find src.ctx_arrays tn with
@@ -581,13 +584,13 @@ let%track3_sexp alloc_if_needed ctx stream ~key ~data:node ctx_arrays =
581584
let data = Hashtbl.find_or_add device.cross_stream_candidates key ~default in
582585
Map.add_exn ctx_arrays ~key ~data)
583586
else if Tn.known_shared_cross_stream key then (
584-
if Hashtbl.mem device.owner_stream_subordinal key then
585-
if Hashtbl.find_exn device.owner_stream_subordinal key <> stream.subordinal then
587+
if Hashtbl.mem device.owner_streams key then
588+
if not @@ String.equal stream.unique_name @@ Hashtbl.find_exn device.owner_streams key then
586589
raise
587590
@@ Utils.User_error
588591
("Cuda_backend.alloc_if_needed: node " ^ Tn.debug_name key
589592
^ " assumed to be cross-stream-shared but then written to on multiple devices")
590-
else Hashtbl.add_exn device.owner_stream_subordinal ~key ~data:stream.subordinal;
593+
else Hashtbl.add_exn device.owner_streams ~key ~data:stream.unique_name;
591594
let data = Hashtbl.find_exn device.cross_stream_candidates key in
592595
Map.add_exn ctx_arrays ~key ~data)
593596
else (

arrayjit/lib/cuda_backend.missing.ml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,5 +72,4 @@ let suggested_num_streams Unimplemented_device = 0
7272
let get_ctx_stream Unimplemented_ctx = Unimplemented_stream
7373
let get_name Unimplemented_stream : string = failwith "CUDA missing: install cudajit"
7474
let to_ordinal _stream = 0
75-
let to_subordinal _stream = 0
7675
let name = "cuda"

0 commit comments

Comments
 (0)