Skip to content

Commit 651f631

Browse files
committed
Refactor the multicore device from multi-device to multi-stream; tiny cleanups
1 parent a071197 commit 651f631

File tree

3 files changed

+45
-45
lines changed

3 files changed

+45
-45
lines changed

arrayjit/lib/backends.ml

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ let check_merge_buffer ~scheduled_node ~code_node =
1818
("Merge buffer mismatch, on stream: " ^ name scheduled_node ^ ", expected by code: "
1919
^ name code_node)
2020

21-
module Multicore_backend (Backend : Backend_types.No_device_backend) : Backend_types.Backend =
21+
module Multicore_backend (Backend : Backend_types.No_device_backend) (* : Backend_types.Backend *) =
2222
struct
2323
module Domain = Domain [@warning "-3"]
2424

@@ -68,7 +68,7 @@ struct
6868
state : stream_state;
6969
merge_buffer : (buffer_ptr * Tnode.t) option ref;
7070
mutable allocated_buffer : (buffer_ptr * int) option;
71-
ordinal : int;
71+
subordinal : int;
7272
domain : (unit Domain.t[@sexp.opaque]);
7373
}
7474
[@@deriving sexp_of]
@@ -78,7 +78,7 @@ struct
7878

7979
let get_used_memory _device = Backend.get_used_memory ()
8080

81-
type device = stream [@@deriving sexp_of]
81+
type device = CPU [@@deriving sexp_of]
8282
type code = Backend.code [@@deriving sexp_of]
8383
type code_batch = Backend.code_batch [@@deriving sexp_of]
8484

@@ -98,18 +98,18 @@ struct
9898
done;
9999
Mut.unlock d.mut;
100100
Option.iter d.stream_error ~f:(fun e ->
101-
Exn.reraise e @@ name ^ " stream " ^ Int.to_string stream.ordinal))
101+
Exn.reraise e @@ name ^ " stream " ^ Int.to_string stream.subordinal))
102102

103103
(** TODO: Returns the event indicating if any currently running or scheduled computations on the
104104
stream have completed. *)
105105
let all_work _stream = Not_implemented_yet
106106

107107
let%track3_l_sexp schedule_task stream task =
108108
assert (Domain.is_main_domain ());
109-
[%log_result "schedule_task", Task.describe task, "stream", (stream.ordinal : int)];
109+
[%log_result "schedule_task", Task.describe task, "stream", (stream.subordinal : int)];
110110
let d = stream.state in
111111
Option.iter d.stream_error ~f:(fun e ->
112-
Exn.reraise e @@ name ^ " stream " ^ Int.to_string stream.ordinal);
112+
Exn.reraise e @@ name ^ " stream " ^ Int.to_string stream.subordinal);
113113
if not d.keep_spinning then invalid_arg "Multicore_backend: stream not available";
114114
if not @@ Queue.try_push d.queue task then (
115115
await stream;
@@ -121,7 +121,7 @@ struct
121121

122122
let global_run_no = ref 0
123123

124-
let%track3_l_sexp spinup_stream ~(ordinal : int) : stream =
124+
let%track3_l_sexp spinup_stream ~(subordinal : int) : stream =
125125
Int.incr global_run_no;
126126
let state =
127127
{
@@ -153,14 +153,14 @@ struct
153153
with e ->
154154
state.stream_error <- Some e;
155155
state.keep_spinning <- false;
156-
[%log1 "Stream", (ordinal : int), "exception", Exn.to_string e];
156+
[%log1 "Stream", (subordinal : int), "exception", Exn.to_string e];
157157
(* TODO: we risk raising this error multiple times because await and schedule_task raise
158158
stream_error. But this is fine if we assume all exceptions are fatal. *)
159159
raise e
160160
in
161161
{
162162
state;
163-
ordinal;
163+
subordinal;
164164
domain = Domain.spawn worker;
165165
merge_buffer = ref None;
166166
allocated_buffer = None;
@@ -169,7 +169,7 @@ struct
169169
type context = { stream : stream; ctx : Backend.context } [@@deriving sexp_of]
170170
type nonrec routine = context routine [@@deriving sexp_of]
171171

172-
let init stream = { stream; ctx = Backend.init (name ^ " " ^ Int.to_string stream.ordinal) }
172+
let init stream = { stream; ctx = Backend.init (name ^ " " ^ Int.to_string stream.subordinal) }
173173
let initialize = Backend.initialize
174174
let is_initialized = Backend.is_initialized
175175

@@ -179,7 +179,7 @@ struct
179179

180180
let compile = Backend.compile
181181
let compile_batch = Backend.compile_batch
182-
let get_stream_name s = "stream " ^ Int.to_string s.ordinal
182+
let get_stream_name s = "stream " ^ Int.to_string s.subordinal
183183

184184
let link { ctx; stream } code =
185185
let task = Backend.link ~merge_buffer:stream.merge_buffer ctx code in
@@ -224,7 +224,7 @@ struct
224224
context_lifetime = context;
225225
description =
226226
"from_host " ^ Tnode.debug_name tn ^ " dst "
227-
^ Int.to_string context.stream.ordinal;
227+
^ Int.to_string context.stream.subordinal;
228228
work;
229229
});
230230
true
@@ -258,7 +258,7 @@ struct
258258
context_lifetime = context;
259259
description =
260260
"from_host " ^ Tnode.debug_name tn ^ " dst "
261-
^ Int.to_string context.stream.ordinal;
261+
^ Int.to_string context.stream.subordinal;
262262
work;
263263
});
264264
true
@@ -296,8 +296,8 @@ struct
296296
Backend.to_buffer tn ~dst:merge_ptr ~src:src.ctx
297297
in
298298
let description =
299-
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ Int.to_string dev.ordinal ^ " src "
300-
^ Int.to_string src.stream.ordinal
299+
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ Int.to_string dev.subordinal ^ " src "
300+
^ Int.to_string src.stream.subordinal
301301
in
302302
schedule_task dev (Task.Task { context_lifetime = (src, dst); description; work })
303303
in
@@ -307,37 +307,40 @@ struct
307307
true
308308
| _ -> false
309309

310-
let num_devices () = Domain.recommended_domain_count () - 1
311-
let suggested_num_streams _device = 1
312-
let devices : device option array = Array.create ~len:(num_devices ()) None
310+
module Dynarr = Stdlib.Dynarray
313311

314-
let%track2_sexp unsafe_cleanup () =
312+
let num_devices () = 1
313+
let suggested_num_streams CPU = Domain.recommended_domain_count () - 1
314+
let latest_subordinal = ref 0
315+
316+
let cleanup_stream stream =
315317
assert (Domain.is_main_domain ());
316-
let wait_for_finish stream =
317-
await stream;
318-
stream.state.keep_spinning <- false;
319-
Stdlib.Condition.broadcast stream.state.dev_wait_for_work
320-
in
321-
Array.iter devices ~f:(Option.iter ~f:wait_for_finish);
322-
let cleanup ordinal device =
323-
Domain.join device.domain;
324-
devices.(ordinal) <- None
325-
in
326-
Array.iteri devices ~f:(fun ordinal -> Option.iter ~f:(cleanup ordinal));
318+
await stream;
319+
stream.state.keep_spinning <- false;
320+
Stdlib.Condition.broadcast stream.state.dev_wait_for_work;
321+
Domain.join stream.domain
322+
323+
let%track2_sexp unsafe_cleanup () =
324+
latest_subordinal := 0;
327325
Backend.unsafe_cleanup ()
328326

329327
let get_device ~ordinal =
330-
Option.value_or_thunk devices.(ordinal) ~default:(fun () ->
331-
let dev = spinup_stream ~ordinal in
332-
devices.(ordinal) <- Some dev;
333-
dev)
328+
if ordinal <> 0 then
329+
invalid_arg [%string "Multicore_backend.get_device %{ordinal#Int}: only device 0 exists"];
330+
CPU
331+
332+
let new_stream CPU =
333+
let subordinal = !latest_subordinal in
334+
Int.incr latest_subordinal;
335+
let stream = spinup_stream ~subordinal in
336+
Stdlib.Gc.finalise cleanup_stream stream;
337+
stream
334338

335-
let new_stream device = device
336-
let get_stream_device stream = stream
339+
let get_stream_device _stream = CPU
337340
let get_ctx_stream { stream; _ } = stream
338-
let get_name device = Int.to_string device.ordinal
339-
let to_ordinal { ordinal; _ } = ordinal
340-
let to_subordinal _ = 0
341+
let get_name stream = Int.to_string stream.subordinal
342+
let to_ordinal _ = 0
343+
let to_subordinal { subordinal; _ } = subordinal
341344
end
342345

343346
(** For debugging, allow [Sync_backend(...).suggested_num_streams] calls to return >1 numbers. *)
@@ -799,6 +802,7 @@ end
799802
module Cuda_backend : Backend_types.Backend = Lowered_backend ((
800803
Cuda_backend : Backend_types.Lowered_backend))
801804

805+
(** Initializes the backend, and if it was already initialized, performs garbage collection. *)
802806
let reinitialize (module Backend : Backend_types.Backend) config =
803807
if not @@ Backend.is_initialized () then Backend.initialize config
804808
else (
@@ -807,7 +811,7 @@ let reinitialize (module Backend : Backend_types.Backend) config =
807811
Backend.initialize config)
808812

809813
(** Reinitializes and returns a backend corresponding to [backend_name], or if omitted, selected via
810-
the global [backend] setting. *)
814+
the global [backend] setting. See {!reinitialize}. *)
811815
let fresh_backend ?backend_name ?(config = Only_devices_parallel) () =
812816
let backend =
813817
match

arrayjit/lib/cuda_backend.missing.ml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,4 @@ let get_ctx_stream Unimplemented_ctx = Unimplemented_stream
7474
let get_name Unimplemented_stream : string = failwith "CUDA missing: install cudajit"
7575
let to_ordinal _stream = 0
7676
let to_subordinal _stream = 0
77-
let to_buffer _tn ~dst:_ ~src:_ = failwith "CUDA missing: install cudajit"
78-
let host_to_buffer _tn ~dst:_ = failwith "CUDA missing: install cudajit"
79-
let buffer_to_host _tn ~src:_ = failwith "CUDA missing: install cudajit"
80-
let get_buffer _tn _context = failwith "CUDA missing: install cudajit"
8177
let name = "cuda"

test/ocannl_config

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
randomness_lib=for_tests
22
log_main_domain_to_stdout=true
3-
backend=cuda
3+
backend=cc
44
log_level=0

0 commit comments

Comments
 (0)