Skip to content

Commit 2806622

Browse files
committed
Go back to using ints to identify streams
1 parent 1775098 commit 2806622

File tree

5 files changed

+43
-56
lines changed

5 files changed

+43
-56
lines changed

arrayjit/lib/backend_types.ml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ type ('buffer_ptr, 'device, 'stream_state, 'runner, 'event) stream = {
127127
device : 'device;
128128
state : 'stream_state;
129129
merge_buffer : ('buffer_ptr * Tnode.t) option ref;
130-
unique_name : string;
130+
stream_id : int;
131131
mutable allocated_buffer : 'buffer_ptr buffer option;
132132
runner : 'runner;
133133
requested_work_for : 'event option Hashtbl.M(Tnode).t;
@@ -155,8 +155,7 @@ module type Device = sig
155155
include Device_types
156156
include Alloc_buffer with type buffer_ptr := buffer_ptr and type stream := stream
157157

158-
val make_stream :
159-
device:device -> state:stream_state -> unique_name:string -> runner:runner -> stream
158+
val make_stream : device:device -> state:stream_state -> stream_id:int -> runner:runner -> stream
160159
end
161160

162161
module Device_types (Device_config : Device_config) = struct
@@ -173,14 +172,14 @@ struct
173172
include Device_types
174173
include Alloc_buffer
175174

176-
let make_stream ~device ~state ~unique_name ~runner =
175+
let make_stream ~device ~state ~stream_id ~runner =
177176
{
178177
device;
179178
state;
180179
merge_buffer = ref None;
181-
unique_name : string;
180+
stream_id;
182181
allocated_buffer = None;
183-
runner : 'runner;
182+
runner;
184183
requested_work_for = Hashtbl.create (module Tnode);
185184
}
186185
end

arrayjit/lib/backends.ml

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ module Multicore_backend (Backend : No_device_backend) = struct
149149

150150
let is_dev_queue_empty state = Queue.size state.Device_config.queue = 0
151151
let is_idle stream = is_dev_queue_empty stream.state && stream.state.is_ready
152-
let name = "multicore " ^ name
152+
let name = "multicore_" ^ name
153+
let get_name stream = [%string "%{name}:0:%{stream.stream_id#Int}"]
153154

154155
let%track3_l_sexp await stream =
155156
assert (Domain.is_main_domain ());
@@ -162,18 +163,17 @@ module Multicore_backend (Backend : No_device_backend) = struct
162163
Stdlib.Condition.wait d.host_wait_for_idle d.mut
163164
done;
164165
Mut.unlock d.mut;
165-
Option.iter d.stream_error ~f:(fun e -> Exn.reraise e @@ name ^ " " ^ stream.unique_name))
166+
Option.iter d.stream_error ~f:(fun e -> Exn.reraise e @@ get_name stream))
166167

167168
(** TODO: Returns the event indicating if any currently running or scheduled computations on the
168169
stream have completed. *)
169170
let all_work _stream = Device_config.Not_implemented_yet
170171

171172
let%track3_l_sexp schedule_task stream task =
172173
assert (Domain.is_main_domain ());
173-
[%log_result "schedule_task", Task.describe task, stream.unique_name];
174+
[%log_result "schedule_task", Task.describe task, get_name stream];
174175
let d = stream.state in
175-
Option.iter d.Device_config.stream_error ~f:(fun e ->
176-
Exn.reraise e @@ name ^ " " ^ stream.unique_name);
176+
Option.iter d.Device_config.stream_error ~f:(fun e -> Exn.reraise e @@ get_name stream);
177177
if not d.keep_spinning then invalid_arg "Multicore_backend: stream not available";
178178
if not @@ Queue.try_push d.queue task then (
179179
await stream;
@@ -185,7 +185,7 @@ module Multicore_backend (Backend : No_device_backend) = struct
185185

186186
let global_run_no = ref 0
187187

188-
let%track3_l_sexp spinup_stream ~unique_name : stream =
188+
let%track3_l_sexp spinup_stream ~stream_id : stream =
189189
Int.incr global_run_no;
190190
let state =
191191
{
@@ -217,17 +217,17 @@ module Multicore_backend (Backend : No_device_backend) = struct
217217
with e ->
218218
state.stream_error <- Some e;
219219
state.keep_spinning <- false;
220-
[%log1 unique_name, "exception", Exn.to_string e];
220+
[%log1 "stream", (stream_id : int), "exception", Exn.to_string e];
221221
(* TODO: we risk raising this error multiple times because await and schedule_task raise
222222
stream_error. But this is fine if we assume all exceptions are fatal. *)
223223
raise e
224224
in
225-
make_stream ~device:Device_config.CPU ~state ~unique_name ~runner:(Domain.spawn worker)
225+
make_stream ~device:Device_config.CPU ~state ~stream_id ~runner:(Domain.spawn worker)
226226

227227
type nonrec context = { stream : stream; ctx : context } [@@deriving sexp_of]
228228

229229
let ctx_arrays context = ctx_arrays context.ctx
230-
let init stream = { stream; ctx = init (name ^ " " ^ stream.unique_name) }
230+
let init stream = { stream; ctx = init (get_name stream) }
231231
let initialize = initialize
232232
let is_initialized = is_initialized
233233

@@ -237,7 +237,6 @@ module Multicore_backend (Backend : No_device_backend) = struct
237237

238238
let compile = compile
239239
let compile_batch = compile_batch
240-
let get_name stream = stream.unique_name
241240

242241
let link { ctx; stream } code =
243242
let task = link ~merge_buffer:stream.merge_buffer ctx code in
@@ -264,30 +263,25 @@ module Multicore_backend (Backend : No_device_backend) = struct
264263

265264
let num_devices () = 1
266265
let suggested_num_streams Device_config.CPU = Domain.recommended_domain_count () - 1
267-
let used_names = Hash_set.create (module String)
268266

269267
let cleanup_stream stream =
270268
assert (Domain.is_main_domain ());
271269
await stream;
272270
stream.state.keep_spinning <- false;
273271
Stdlib.Condition.broadcast stream.state.dev_wait_for_work;
274-
Hash_set.remove used_names stream.unique_name;
275272
Domain.join stream.runner
276273

277274
let get_device ~ordinal =
278275
if ordinal <> 0 then
279276
invalid_arg [%string "Multicore_backend.get_device %{ordinal#Int}: only device 0 exists"];
280277
Device_config.CPU
281278

279+
let latest_stream_id = ref (-1)
280+
282281
let new_stream Device_config.CPU =
283282
assert (Domain.is_main_domain ());
284-
let rec unique_name suffix =
285-
let name = "stream " ^ Int.to_string suffix in
286-
if Hash_set.mem used_names name then unique_name (suffix + 1) else name
287-
in
288-
let unique_name = unique_name 0 in
289-
Hash_set.add used_names unique_name;
290-
let stream = spinup_stream ~unique_name in
283+
Int.incr latest_stream_id;
284+
let stream = spinup_stream ~stream_id:!latest_stream_id in
291285
Stdlib.Gc.finalise cleanup_stream stream;
292286
stream
293287

@@ -300,14 +294,13 @@ module Multicore_backend (Backend : No_device_backend) = struct
300294
(* TODO: pass description to from_host. *)
301295
schedule_task dst.stream
302296
(Task.Task
303-
{ context_lifetime = dst; description = "from_host on " ^ dst.stream.unique_name; work })
297+
{ context_lifetime = dst; description = "from_host on " ^ get_name dst.stream; work })
304298

305299
let to_host ~src_ptr ~src hosted =
306300
let work () = buffer_to_host hosted ~src:src_ptr in
307301
(* TODO: pass description to to_host. *)
308302
schedule_task src.stream
309-
(Task.Task
310-
{ context_lifetime = src; description = "to_host on " ^ src.stream.unique_name; work })
303+
(Task.Task { context_lifetime = src; description = "to_host on " ^ get_name src.stream; work })
311304

312305
let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
313306
let dev = dst.stream in
@@ -332,8 +325,8 @@ module Multicore_backend (Backend : No_device_backend) = struct
332325
buffer_to_buffer ~dst:merge_ptr ~src:src_ptr ~size_in_bytes
333326
in
334327
let description =
335-
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ dev.unique_name ^ " src "
336-
^ src.stream.unique_name
328+
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ get_name dev ^ " src "
329+
^ get_name src.stream
337330
in
338331
schedule_task dev (Task.Task { context_lifetime = (src, dst); description; work })
339332
end
@@ -383,20 +376,18 @@ module Sync_backend (Backend : No_device_backend) = struct
383376
let num_devices () = 1
384377
let suggested_num_streams Device_config.CPU = !sync_suggested_num_streams
385378
let get_used_memory Device_config.CPU = Backend.get_used_memory ()
386-
let next_stream = ref 0
379+
let latest_stram_id = ref (-1)
387380

388381
let new_stream Device_config.CPU : stream =
389-
Int.incr next_stream;
390-
make_stream ~device:Device_config.CPU ~state:()
391-
~unique_name:("stream " ^ Int.to_string (!next_stream - 1))
392-
~runner:()
382+
Int.incr latest_stram_id;
383+
make_stream ~device:Device_config.CPU ~state:() ~stream_id:!latest_stram_id ~runner:()
393384

394385
type code = Backend.code [@@deriving sexp_of]
395386
type code_batch = Backend.code_batch [@@deriving sexp_of]
396387

397388
let all_work _stream = ()
398389
let is_idle _stream = true
399-
let name = "sync " ^ Backend.name
390+
let name = "sync_" ^ Backend.name
400391
let await _stream = ()
401392
(* let global_run_no = ref 0 *)
402393

@@ -422,7 +413,7 @@ module Sync_backend (Backend : No_device_backend) = struct
422413
Array.map routines
423414
~f:(Option.map ~f:(fun task -> { task with context = { ctx = task.context; stream } })) )
424415

425-
let get_name stream = stream.unique_name
416+
let get_name stream = [%string "%{name}:0:%{stream.stream_id#Int}"]
426417
let from_host ~dst_ptr ~dst:_ hosted = host_to_buffer hosted ~dst:dst_ptr
427418
let to_host ~src_ptr ~src:_ hosted = buffer_to_host hosted ~src:src_ptr
428419

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ module Device_config = struct
3939
primary_context : Cu.Context.t;
4040
mutable copy_merge_buffer : buffer_ptr;
4141
mutable copy_merge_buffer_capacity : int;
42-
used_names : Hash_set.M(String).t; (** Unique names of streams. *)
42+
mutable latest_stream_id : int;
4343
released : Utils.atomic_bool;
4444
cross_stream_candidates : buffer_ptr Hashtbl.M(Tn).t;
4545
(** Freshly created arrays that might be shared across streams. The map can both grow and
4646
shrink. See the explanation on top of this file. *)
47-
owner_streams : string Hashtbl.M(Tn).t;
47+
owner_streams : int Hashtbl.M(Tn).t;
4848
(** The streams owning the given nodes. This map can only grow. *)
4949
}
5050
[@@deriving sexp_of]
@@ -166,7 +166,7 @@ let%track3_sexp get_device ~(ordinal : int) : device =
166166
{
167167
dev;
168168
ordinal;
169-
used_names = Hash_set.create (module String);
169+
latest_stream_id = -1;
170170
primary_context;
171171
copy_merge_buffer;
172172
copy_merge_buffer_capacity;
@@ -184,16 +184,11 @@ let%track3_sexp get_device ~(ordinal : int) : device =
184184
if Atomic.get result.released then default () else result
185185

186186
let%track3_sexp new_stream (device : device) : stream =
187-
let rec unique_name suffix =
188-
let name = "stream " ^ Int.to_string suffix in
189-
if Hash_set.mem device.used_names name then unique_name (suffix + 1) else name
190-
in
191-
let unique_name = unique_name 0 in
192-
Hash_set.add device.used_names unique_name;
187+
device.latest_stream_id <- device.latest_stream_id + 1;
193188
(* Strange that we need ctx_set_current even with a single device! *)
194189
set_ctx device.primary_context;
195190
let cu_stream = Cu.Stream.create ~non_blocking:true () in
196-
make_stream ~device ~state:() ~unique_name ~runner:cu_stream
191+
make_stream ~device ~state:() ~stream_id:device.latest_stream_id ~runner:cu_stream
197192

198193
let cuda_properties =
199194
let cache =
@@ -216,7 +211,10 @@ let suggested_num_streams device =
216211
let get_ctx_stream { stream; _ } = stream
217212
let get_stream_device { device; _ } = device
218213
let to_ordinal Device_config.{ ordinal; _ } = ordinal
219-
let get_name stream = stream.unique_name
214+
let name = "cuda"
215+
216+
let get_name stream =
217+
[%string "%{name}:%{stream.device.Device_config.ordinal#Int}:%{stream.stream_id#Int}"]
220218

221219
let await stream : unit =
222220
set_ctx stream.device.Device_config.primary_context;
@@ -543,12 +541,12 @@ let%track3_sexp alloc_if_needed ctx stream ~key ~data:node ctx_arrays =
543541
Map.add_exn ctx_arrays ~key ~data)
544542
else if Tn.known_shared_cross_stream key then (
545543
if Hashtbl.mem device.owner_streams key then
546-
if not @@ String.equal stream.unique_name @@ Hashtbl.find_exn device.owner_streams key then
544+
if not (stream.stream_id = Hashtbl.find_exn device.owner_streams key) then
547545
raise
548546
@@ Utils.User_error
549547
("Cuda_backend.alloc_if_needed: node " ^ Tn.debug_name key
550548
^ " assumed to be cross-stream-shared but then written to on multiple devices")
551-
else Hashtbl.add_exn device.owner_streams ~key ~data:stream.unique_name;
549+
else Hashtbl.add_exn device.owner_streams ~key ~data:stream.stream_id;
552550
let data = Hashtbl.find_exn device.cross_stream_candidates key in
553551
Map.add_exn ctx_arrays ~key ~data)
554552
else (
@@ -601,5 +599,3 @@ let%track3_sexp link_batch prior_context (code_batch : code_batch) : context * _
601599
((context, ctx_arrays), Some task)))
602600
in
603601
(context, lowered_bindings, procs)
604-
605-
let name = "cuda"

lib/attic.mld

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ Old copying mechanism in backends.ml Multicore_backend:
382382
{
383383
context_lifetime = context;
384384
description =
385-
"from_host " ^ Tnode.debug_name tn ^ " dst " ^ context.stream.unique_name;
385+
"from_host " ^ Tnode.debug_name tn ^ " dst " ^ context.stream.stream_id;
386386
work;
387387
});
388388
true
@@ -415,7 +415,7 @@ Old copying mechanism in backends.ml Multicore_backend:
415415
{
416416
context_lifetime = context;
417417
description =
418-
"from_host " ^ Tnode.debug_name tn ^ " dst " ^ context.stream.unique_name;
418+
"from_host " ^ Tnode.debug_name tn ^ " dst " ^ context.stream.stream_id;
419419
work;
420420
});
421421
true
@@ -453,8 +453,8 @@ Old copying mechanism in backends.ml Multicore_backend:
453453
Backend.to_buffer tn ~dst:merge_ptr ~src:src.ctx
454454
in
455455
let description =
456-
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ dev.unique_name ^ " src "
457-
^ src.stream.unique_name
456+
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ dev.stream_id ^ " src "
457+
^ src.stream.stream_id
458458
in
459459
schedule_task dev (Task.Task { context_lifetime = (src, dst); description; work })
460460
in

lib/train.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ let%track3_sexp parallel_update (type context)
404404
let fs = [%debug_notrace Array.map grad_updates ~f:(fun upd () -> Task.run upd.schedule)] in
405405
fun () -> round_robin fs lowered_bindings sgd_update.bindings ~sync
406406

407+
(* Note: this type signature looks ugly, but it will get simple again with modular explicits. *)
407408
let get_all_suggested_streams ?(max_num_streams : int option)
408409
(type buffer_ptr device stream_state runner event)
409410
(module Backend : Backend_type

0 commit comments

Comments
 (0)