Skip to content

Commit e34f941

Browse files
committed
Split schedulers.ml out of backends.ml
1 parent 031fc20 commit e34f941

File tree

7 files changed

+261
-252
lines changed

7 files changed

+261
-252
lines changed

CHANGES.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
## [0.4.2] -- current
1+
## [0.5.0] -- current
22

33
### Added
44

@@ -11,7 +11,7 @@
1111
- Migrated to cudajit 0.5.
1212
- Verifying that code is linked with the right contexts, by tracking `embedded_nodes` with assignments.
1313
- Renaming: (virtual) `device` -> `stream`, `physical_device` -> `device`.
14-
- New files: split out `backend_types.ml` from `backends.ml`; moved `Tnode.task` to `task.ml`; renamed `backend_utils.ml` to `c_syntax.ml`.
14+
- New files: split out `backend_intf.ml`, `backend_impl.ml`, `schedulers.ml` from `backends.ml`; moved `Tnode.task` to `task.ml`; renamed `backend_utils.ml` to `c_syntax.ml`.
1515
- Removed half-static verification of merge buffer nodes inside `device_to_device`.
1616
- Fixed #286: cross-stream-sharing incorporated into `Tnode.memory_mode`.
1717
- Moved the multicore backend from a `device = stream` model to a single device model.

arrayjit/lib/backend_impl.ml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,17 @@ module type Lowered_backend = sig
291291
batch. The returned [ctx_arrays] will be part of a context downstream of all the tasks and the
292292
tasks' contexts are not independent (typically, they are cumulative). *)
293293
end
294+
295+
module Alloc_buffer_ignore_stream
296+
(Device_types : Device_types)
297+
(Backend : Alloc_buffer with type buffer_ptr = Device_types.buffer_ptr and type stream := unit) :
298+
Alloc_buffer with type buffer_ptr = Backend.buffer_ptr and type stream = Device_types.stream =
299+
struct
300+
include Device_types
301+
302+
let alloc_buffer ?old_buffer ~size_in_bytes _stream =
303+
Backend.alloc_buffer ?old_buffer ~size_in_bytes ()
304+
305+
let alloc_zero_init_array prec ~dims _stream = Backend.alloc_zero_init_array prec ~dims ()
306+
let free_buffer = Option.map Backend.free_buffer ~f:(fun memfree _stream ptr -> memfree () ptr)
307+
end

arrayjit/lib/backends.ml

Lines changed: 4 additions & 247 deletions
Original file line numberDiff line numberDiff line change
@@ -87,249 +87,6 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
8787
true)
8888
end
8989

90-
module Alloc_buffer_ignore_stream
91-
(Device_types : Device_types)
92-
(Backend : Alloc_buffer with type buffer_ptr = Device_types.buffer_ptr and type stream := unit) :
93-
Alloc_buffer with type buffer_ptr = Backend.buffer_ptr and type stream = Device_types.stream =
94-
struct
95-
include Device_types
96-
97-
let alloc_buffer ?old_buffer ~size_in_bytes _stream =
98-
Backend.alloc_buffer ?old_buffer ~size_in_bytes ()
99-
100-
let alloc_zero_init_array prec ~dims _stream = Backend.alloc_zero_init_array prec ~dims ()
101-
let free_buffer = Option.map Backend.free_buffer ~f:(fun memfree _stream ptr -> memfree () ptr)
102-
end
103-
104-
module Multicore_scheduler (Backend : For_add_scheduler) :
105-
With_scheduler with type buffer_ptr = Backend.buffer_ptr = struct
106-
include Backend
107-
module Domain = Domain [@warning "-3"]
108-
109-
let global_config = ref Only_devices_parallel
110-
111-
let initialize config =
112-
global_config := config;
113-
initialize config
114-
115-
let is_initialized = is_initialized
116-
117-
type task_list = Task.t Utils.mutable_list [@@deriving sexp_of]
118-
119-
module Mut = Stdlib.Mutex
120-
module Queue = Saturn_lockfree.Single_prod_single_cons_queue
121-
122-
type task_queue = Task.t Queue.t
123-
124-
let sexp_of_task_queue q =
125-
Sexp.(List [ Atom "task_queue_of_size"; Atom (Int.to_string @@ Queue.size q) ])
126-
127-
module Device_config = struct
128-
include (
129-
Backend : Buffer with type buffer_ptr = Backend.buffer_ptr and type buffer = Backend.buffer)
130-
131-
type dev = CPU [@@deriving sexp_of]
132-
133-
type stream_state = {
134-
mutable keep_spinning : bool;
135-
mutable stream_error : exn option;
136-
queue : task_queue;
137-
mut : (Mut.t[@sexp.opaque]);
138-
host_wait_for_idle : (Stdlib.Condition.t[@sexp.opaque]);
139-
dev_wait_for_work : (Stdlib.Condition.t[@sexp.opaque]);
140-
mutable is_ready : bool;
141-
}
142-
[@@deriving sexp_of]
143-
144-
type domain = unit Domain.t
145-
146-
let sexp_of_domain (d : domain) = Sexp.Atom ("domain-" ^ Int.to_string (Domain.get_id d :> int))
147-
148-
type runner = { state : stream_state; domain : domain } [@@deriving sexp_of]
149-
type event = Not_implemented_yet [@@deriving sexp_of]
150-
151-
let name = "multicore_" ^ Backend.name
152-
end
153-
154-
module Device_types = Device_types (Device_config)
155-
include Device (Device_types) (Alloc_buffer_ignore_stream (Device_types) (Backend))
156-
open Device_config
157-
158-
(** TODO: Blocks till the event completes, if it's not done already. *)
159-
let sync Not_implemented_yet = ()
160-
161-
(** TODO: Whether the event completed. *)
162-
let is_done Not_implemented_yet = true
163-
164-
(** TODO: Schedules waiting for the given event on the context's stream. *)
165-
let will_wait_for _ctx Not_implemented_yet = ()
166-
167-
let get_used_memory _device = get_used_memory ()
168-
let is_dev_queue_empty state = Queue.size state.queue = 0
169-
let is_idle stream = is_dev_queue_empty stream.runner.state && stream.runner.state.is_ready
170-
let name = "multicore_" ^ name
171-
let get_name stream = [%string "%{name}:0:%{stream.stream_id#Int}"]
172-
173-
let%track3_l_sexp await stream =
174-
assert (Domain.is_main_domain ());
175-
let d = stream.runner.state in
176-
if (not @@ is_idle stream) && d.keep_spinning then (
177-
Mut.lock d.mut;
178-
while (not @@ is_idle stream) && d.keep_spinning do
179-
(* If the stream "is ready", it needs to be woken up first to finish the work. *)
180-
if d.is_ready then Stdlib.Condition.broadcast d.dev_wait_for_work;
181-
Stdlib.Condition.wait d.host_wait_for_idle d.mut
182-
done;
183-
Mut.unlock d.mut;
184-
Option.iter d.stream_error ~f:(fun e -> Exn.reraise e @@ get_name stream))
185-
186-
(** TODO: Returns the event indicating if any currently running or scheduled computations on the
187-
stream have completed. *)
188-
let all_work _stream = Not_implemented_yet
189-
190-
let%track3_l_sexp schedule_task stream task =
191-
assert (Domain.is_main_domain ());
192-
[%log_result "schedule_task", Task.describe task, get_name stream];
193-
let d = stream.runner.state in
194-
Option.iter d.stream_error ~f:(fun e -> Exn.reraise e @@ get_name stream);
195-
if not d.keep_spinning then invalid_arg "Multicore_scheduler: stream not available";
196-
if not @@ Queue.try_push d.queue task then (
197-
await stream;
198-
Queue.push_exn d.queue task);
199-
if d.is_ready then (
200-
Mut.lock d.mut;
201-
Stdlib.Condition.broadcast d.dev_wait_for_work;
202-
Mut.unlock d.mut)
203-
204-
let global_run_no = ref 0
205-
let device : device = make_device CPU ~ordinal:0
206-
207-
let%track3_l_sexp spinup_stream ~stream_id : stream =
208-
Int.incr global_run_no;
209-
let state =
210-
{
211-
keep_spinning = true;
212-
stream_error = None;
213-
queue = Queue.create ~size_exponent:12;
214-
mut = Mut.create ();
215-
is_ready = false;
216-
host_wait_for_idle = Stdlib.Condition.create ();
217-
dev_wait_for_work = Stdlib.Condition.create ();
218-
}
219-
in
220-
let%track3_l_sexp worker (() : unit) : unit =
221-
assert (not @@ Domain.is_main_domain ());
222-
try
223-
while state.keep_spinning do
224-
match Queue.pop_opt state.queue with
225-
| None ->
226-
Mut.lock state.mut;
227-
state.is_ready <- true;
228-
Stdlib.Condition.broadcast state.host_wait_for_idle;
229-
while is_dev_queue_empty state && state.keep_spinning do
230-
Stdlib.Condition.wait state.dev_wait_for_work state.mut
231-
done;
232-
state.is_ready <- false;
233-
Mut.unlock state.mut
234-
| Some task -> Task.run task
235-
done
236-
with e ->
237-
state.stream_error <- Some e;
238-
state.keep_spinning <- false;
239-
[%log1 "stream", (stream_id : int), "exception", Exn.to_string e];
240-
(* TODO: we risk raising this error multiple times because await and schedule_task raise
241-
stream_error. But this is fine if we assume all exceptions are fatal. *)
242-
raise e
243-
in
244-
make_stream device { state; domain = Domain.spawn worker } ~stream_id
245-
246-
module Dynarr = Stdlib.Dynarray
247-
248-
let num_devices () = 1
249-
let suggested_num_streams _device = Domain.recommended_domain_count () - 1
250-
251-
let cleanup_stream stream =
252-
assert (Domain.is_main_domain ());
253-
await stream;
254-
let r = stream.runner in
255-
r.state.keep_spinning <- false;
256-
Stdlib.Condition.broadcast r.state.dev_wait_for_work;
257-
Domain.join r.domain
258-
259-
let get_device ~ordinal =
260-
if ordinal <> 0 then
261-
invalid_arg [%string "Multicore_scheduler.get_device %{ordinal#Int}: only device 0 exists"];
262-
device
263-
264-
let latest_stream_id = ref (-1)
265-
266-
let new_stream _device =
267-
assert (Domain.is_main_domain ());
268-
Int.incr latest_stream_id;
269-
let stream = spinup_stream ~stream_id:!latest_stream_id in
270-
Stdlib.Gc.finalise cleanup_stream stream;
271-
stream
272-
end
273-
274-
(** For debugging, allow [Sync_scheduler(...).suggested_num_streams] calls to return >1 numbers. *)
275-
let sync_suggested_num_streams = ref 1
276-
277-
(** A minimalisitc wrapper creating backends where all calls run synchronously on the main thread.
278-
There is only one device, but an arbitrary number of streams. *)
279-
module Sync_scheduler (Backend : For_add_scheduler) = struct
280-
include Backend
281-
282-
module Device_config = struct
283-
include (
284-
Backend : Buffer with type buffer_ptr = Backend.buffer_ptr and type buffer = Backend.buffer)
285-
286-
type dev = CPU [@@deriving sexp_of]
287-
type runner = unit [@@deriving sexp_of]
288-
type event = unit [@@deriving sexp_of]
289-
290-
let name = "sync_" ^ Backend.name
291-
end
292-
293-
module Device_types = Device_types (Device_config)
294-
include Device (Device_types) (Alloc_buffer_ignore_stream (Device_types) (Backend))
295-
open Device_config
296-
297-
let sync () = ()
298-
let is_done () = true
299-
let will_wait_for _context () = ()
300-
301-
let alloc_buffer ?old_buffer ~size_in_bytes _stream =
302-
Backend.alloc_buffer ?old_buffer ~size_in_bytes ()
303-
304-
let device : device = make_device CPU ~ordinal:0
305-
306-
let get_device ~ordinal =
307-
if ordinal <> 0 then
308-
invalid_arg @@ "Sync_scheduler.get_device: there is only one device, but ordinal="
309-
^ Int.to_string ordinal;
310-
device
311-
312-
let num_devices () = 1
313-
let suggested_num_streams _ = !sync_suggested_num_streams
314-
let get_used_memory _ = Backend.get_used_memory ()
315-
let latest_stram_id = ref (-1)
316-
317-
let new_stream device =
318-
Int.incr latest_stram_id;
319-
make_stream device () ~stream_id:!latest_stram_id
320-
321-
let all_work _stream = ()
322-
let is_idle _stream = true
323-
let name = "sync_" ^ Backend.name
324-
let await _stream = ()
325-
(* let global_run_no = ref 0 *)
326-
327-
let initialize = Backend.initialize
328-
let is_initialized = Backend.is_initialized
329-
let get_name stream = [%string "%{name}:0:%{stream.stream_id#Int}"]
330-
let schedule_task _stream task = Task.run task
331-
end
332-
33390
let lower_assignments ?name bindings asgns =
33491
let name = Option.value_or_thunk name ~default:(fun () -> Assignments.get_name_exn asgns) in
33592
let unoptim_ll_source = Utils.get_debug_formatter ~fname:(name ^ "-unoptimized.ll") in
@@ -587,10 +344,10 @@ struct
587344
include Backend_device
588345
end
589346

590-
module Cc_multicore = Make_device_backend_from_lowered (Multicore_scheduler) (Cc_backend)
591-
module Gcc_multicore = Make_device_backend_from_lowered (Multicore_scheduler) (Gcc_backend)
592-
module Cc_sync = Make_device_backend_from_lowered (Sync_scheduler) (Cc_backend)
593-
module Gcc_sync = Make_device_backend_from_lowered (Sync_scheduler) (Gcc_backend)
347+
module Cc_multicore = Make_device_backend_from_lowered (Schedulers.Multicore) (Cc_backend)
348+
module Gcc_multicore = Make_device_backend_from_lowered (Schedulers.Multicore) (Gcc_backend)
349+
module Cc_sync = Make_device_backend_from_lowered (Schedulers.Sync) (Cc_backend)
350+
module Gcc_sync = Make_device_backend_from_lowered (Schedulers.Sync) (Gcc_backend)
594351

595352
let reinitialize (module Backend : Backend) config =
596353
if not @@ Backend.is_initialized () then Backend.initialize config

arrayjit/lib/backends.mli

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
open Base
44

5-
val sync_suggested_num_streams : int ref
6-
75
val reinitialize : (module Backend_intf.Backend) -> Backend_intf.config -> unit
86
(** Initializes the backend, and if it was already initialized, performs garbage collection. *)
97

arrayjit/lib/dune

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
cc_backend
4949
gcc_backend
5050
cuda_backend
51+
schedulers
5152
backends)
5253
(modes byte native))
5354

0 commit comments

Comments
 (0)