@@ -87,249 +87,6 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
8787 true )
8888end
8989
(** Adapts a [Backend] whose allocation functions take [()] in the stream position to the
    [Alloc_buffer] signature over [Device_types.stream]: the stream argument is accepted and
    ignored, and every call is forwarded to [Backend] with [()]. *)
module Alloc_buffer_ignore_stream
    (Device_types : Device_types)
    (Backend : Alloc_buffer with type buffer_ptr = Device_types.buffer_ptr and type stream := unit) :
  Alloc_buffer with type buffer_ptr = Backend.buffer_ptr and type stream = Device_types.stream =
struct
  include Device_types

  (** Forwards to [Backend.alloc_buffer]; the stream is not consulted. *)
  let alloc_buffer ?old_buffer ~size_in_bytes _ignored_stream =
    Backend.alloc_buffer ?old_buffer ~size_in_bytes ()

  (** Forwards to [Backend.alloc_zero_init_array]; the stream is not consulted. *)
  let alloc_zero_init_array prec ~dims _ignored_stream =
    Backend.alloc_zero_init_array prec ~dims ()

  (** [None] when the backend provides no [free_buffer]; otherwise the backend's function with the
      stream argument dropped. *)
  let free_buffer =
    match Backend.free_buffer with
    | None -> None
    | Some release -> Some (fun _ignored_stream ptr -> release () ptr)
end
103-
(** Wraps [Backend] so that each stream runs its tasks on a dedicated worker domain.

    There is a single device (ordinal 0). The host hands tasks to a stream's worker through a
    single-producer single-consumer queue; a mutex with two condition variables parks the worker
    while it has no work and parks the host while it waits for the worker. [await],
    [schedule_task], [cleanup_stream] and [new_stream] assert that they run on the main domain. *)
module Multicore_scheduler (Backend : For_add_scheduler) :
  With_scheduler with type buffer_ptr = Backend.buffer_ptr = struct
  include Backend
  module Domain = Domain [@warning "-3"]

  let global_config = ref Only_devices_parallel

  (** Records [config] in [global_config], then initializes the wrapped backend. *)
  let initialize config =
    global_config := config;
    initialize config

  let is_initialized = is_initialized

  type task_list = Task.t Utils.mutable_list [@@deriving sexp_of]

  module Mut = Stdlib.Mutex
  module Queue = Saturn_lockfree.Single_prod_single_cons_queue

  type task_queue = Task.t Queue.t

  (** Only the current size of the queue is rendered, not its contents. *)
  let sexp_of_task_queue q =
    Sexp.(List [ Atom "task_queue_of_size"; Atom (Int.to_string @@ Queue.size q) ])

  module Device_config = struct
    include (
      Backend : Buffer with type buffer_ptr = Backend.buffer_ptr and type buffer = Backend.buffer)

    type dev = CPU [@@deriving sexp_of]

    type stream_state = {
      mutable keep_spinning : bool;
          (** The worker loop runs while this is [true]; cleared by [cleanup_stream] and by the
              worker itself when a task raises. *)
      mutable stream_error : exn option;
          (** The exception that terminated the worker, if any; re-raised on the host by [await]
              and [schedule_task]. *)
      queue : task_queue;
      mut : (Mut.t[@sexp.opaque]);  (** Guards the two condition variables below. *)
      host_wait_for_idle : (Stdlib.Condition.t[@sexp.opaque]);
          (** Signalled by the worker; the host waits on it in [await]. *)
      dev_wait_for_work : (Stdlib.Condition.t[@sexp.opaque]);
          (** Signalled by the host; the worker waits on it when its queue is empty. *)
      mutable is_ready : bool;
          (** Set by the worker when it found the queue empty and is about to park, and cleared
              when it resumes. *)
    }
    [@@deriving sexp_of]

    type domain = unit Domain.t

    let sexp_of_domain (d : domain) = Sexp.Atom ("domain-" ^ Int.to_string (Domain.get_id d :> int))

    type runner = { state : stream_state; domain : domain } [@@deriving sexp_of]
    type event = Not_implemented_yet [@@deriving sexp_of]

    let name = "multicore_" ^ Backend.name
  end

  module Device_types = Device_types (Device_config)
  include Device (Device_types) (Alloc_buffer_ignore_stream (Device_types) (Backend))
  open Device_config

  (** TODO: Blocks till the event completes, if it's not done already. *)
  let sync Not_implemented_yet = ()

  (** TODO: Whether the event completed. *)
  let is_done Not_implemented_yet = true

  (** TODO: Schedules waiting for the given event on the context's stream. *)
  let will_wait_for _ctx Not_implemented_yet = ()

  let get_used_memory _device = get_used_memory ()
  let is_dev_queue_empty state = Queue.size state.queue = 0

  (** A stream is idle when its queue is empty and its worker has marked itself ready. *)
  let is_idle stream = is_dev_queue_empty stream.runner.state && stream.runner.state.is_ready

  (* [open Device_config] above already brings the prefixed name into scope, so prepending
     "multicore_" again here would produce "multicore_multicore_<backend>". *)
  let name = Device_config.name
  let get_name stream = [%string "%{name}:0:%{stream.stream_id#Int}"]

  (** Blocks the host until the stream is idle or its worker has stopped, then re-raises the
      worker's exception, if any, annotated with the stream name. *)
  let%track3_l_sexp await stream =
    assert (Domain.is_main_domain ());
    let d = stream.runner.state in
    if (not @@ is_idle stream) && d.keep_spinning then (
      Mut.lock d.mut;
      while (not @@ is_idle stream) && d.keep_spinning do
        (* If the stream "is ready", it needs to be woken up first to finish the work. *)
        if d.is_ready then Stdlib.Condition.broadcast d.dev_wait_for_work;
        Stdlib.Condition.wait d.host_wait_for_idle d.mut
      done;
      Mut.unlock d.mut;
      Option.iter d.stream_error ~f:(fun e -> Exn.reraise e @@ get_name stream))

  (** TODO: Returns the event indicating if any currently running or scheduled computations on the
      stream have completed. *)
  let all_work _stream = Not_implemented_yet

  (** Enqueues [task] on the stream's worker. Re-raises a previously recorded worker exception.
      @raise Invalid_argument if the stream's worker is no longer running. *)
  let%track3_l_sexp schedule_task stream task =
    assert (Domain.is_main_domain ());
    [%log_result "schedule_task", Task.describe task, get_name stream];
    let d = stream.runner.state in
    Option.iter d.stream_error ~f:(fun e -> Exn.reraise e @@ get_name stream);
    if not d.keep_spinning then invalid_arg "Multicore_scheduler: stream not available";
    (* When the queue is full, drain it by waiting for the worker, then push again. *)
    if not @@ Queue.try_push d.queue task then (
      await stream;
      Queue.push_exn d.queue task);
    (* A ready worker is parked (or about to park) on [dev_wait_for_work]: wake it up. *)
    if d.is_ready then (
      Mut.lock d.mut;
      Stdlib.Condition.broadcast d.dev_wait_for_work;
      Mut.unlock d.mut)

  let global_run_no = ref 0
  let device : device = make_device CPU ~ordinal:0

  (** Creates the stream state and spawns the worker domain that pops and runs tasks until
      [keep_spinning] is cleared. *)
  let%track3_l_sexp spinup_stream ~stream_id : stream =
    Int.incr global_run_no;
    let state =
      {
        keep_spinning = true;
        stream_error = None;
        queue = Queue.create ~size_exponent:12;
        mut = Mut.create ();
        is_ready = false;
        host_wait_for_idle = Stdlib.Condition.create ();
        dev_wait_for_work = Stdlib.Condition.create ();
      }
    in
    let%track3_l_sexp worker (() : unit) : unit =
      assert (not @@ Domain.is_main_domain ());
      try
        while state.keep_spinning do
          match Queue.pop_opt state.queue with
          | None ->
              (* Out of work: announce readiness to the host, then park until there is work or
                 the stream is shut down. *)
              Mut.lock state.mut;
              state.is_ready <- true;
              Stdlib.Condition.broadcast state.host_wait_for_idle;
              while is_dev_queue_empty state && state.keep_spinning do
                Stdlib.Condition.wait state.dev_wait_for_work state.mut
              done;
              state.is_ready <- false;
              Mut.unlock state.mut
          | Some task -> Task.run task
        done
      with e ->
        (* Publish the failure under the mutex and wake the host: a host blocked in [await] only
           re-checks [keep_spinning] after [host_wait_for_idle] is signalled, so without the
           broadcast it would wait forever.
           NOTE(review): assumes the exception comes from [Task.run], i.e. that the mutex is not
           held at this point. *)
        Mut.lock state.mut;
        state.stream_error <- Some e;
        state.keep_spinning <- false;
        Stdlib.Condition.broadcast state.host_wait_for_idle;
        Mut.unlock state.mut;
        [%log1 "stream", (stream_id : int), "exception", Exn.to_string e];
        (* TODO: we risk raising this error multiple times because await and schedule_task raise
           stream_error. But this is fine if we assume all exceptions are fatal. *)
        raise e
    in
    make_stream device { state; domain = Domain.spawn worker } ~stream_id

  module Dynarr = Stdlib.Dynarray

  let num_devices () = 1
  let suggested_num_streams _device = Domain.recommended_domain_count () - 1

  (** Waits for the stream to become idle, stops its worker and joins the worker domain. *)
  let cleanup_stream stream =
    assert (Domain.is_main_domain ());
    await stream;
    let r = stream.runner in
    (* Clear the flag and signal while holding the mutex. The worker checks its wait predicate
       under the same mutex, so an unlocked broadcast could land between the worker's check and
       its [Condition.wait], leaving it parked forever and [Domain.join] hanging. *)
    Mut.lock r.state.mut;
    r.state.keep_spinning <- false;
    Stdlib.Condition.broadcast r.state.dev_wait_for_work;
    Mut.unlock r.state.mut;
    Domain.join r.domain

  (** @raise Invalid_argument if [ordinal] is not 0. *)
  let get_device ~ordinal =
    if ordinal <> 0 then
      invalid_arg [%string "Multicore_scheduler.get_device %{ordinal#Int}: only device 0 exists"];
    device

  let latest_stream_id = ref (-1)

  (** Spins up a stream with the next stream id and registers [cleanup_stream] as its GC
      finaliser. *)
  let new_stream _device =
    assert (Domain.is_main_domain ());
    Int.incr latest_stream_id;
    let stream = spinup_stream ~stream_id:!latest_stream_id in
    Stdlib.Gc.finalise cleanup_stream stream;
    stream
end
273-
(** The number that [Sync_scheduler(...).suggested_num_streams] reports. It starts at 1; for
    debugging it can be set to a larger value. *)
let sync_suggested_num_streams = ref 1
276-
(** A minimalistic wrapper creating backends where every scheduled task runs immediately and
    synchronously: [schedule_task] simply runs the task, and events are [unit] values that are
    always done. There is only one device (ordinal 0), but an arbitrary number of streams. *)
module Sync_scheduler (Backend : For_add_scheduler) = struct
  include Backend

  module Device_config = struct
    include (
      Backend : Buffer with type buffer_ptr = Backend.buffer_ptr and type buffer = Backend.buffer)

    type dev = CPU [@@deriving sexp_of]
    type runner = unit [@@deriving sexp_of]
    type event = unit [@@deriving sexp_of]

    let name = "sync_" ^ Backend.name
  end

  module Device_types = Device_types (Device_config)
  include Device (Device_types) (Alloc_buffer_ignore_stream (Device_types) (Backend))
  open Device_config

  (* Events carry no information: there is never anything to wait for. *)
  let sync () = ()
  let is_done () = true
  let will_wait_for _context () = ()

  (** Forwards to the wrapped backend; the stream is not consulted. *)
  let alloc_buffer ?old_buffer ~size_in_bytes _stream =
    Backend.alloc_buffer ?old_buffer ~size_in_bytes ()

  let device : device = make_device CPU ~ordinal:0

  (** @raise Invalid_argument if [ordinal] is not 0. *)
  let get_device ~ordinal =
    match ordinal with
    | 0 -> device
    | _ ->
        invalid_arg
          ("Sync_scheduler.get_device: there is only one device, but ordinal="
          ^ Int.to_string ordinal)

  let num_devices () = 1
  let suggested_num_streams _device = !sync_suggested_num_streams
  let get_used_memory _device = Backend.get_used_memory ()
  let latest_stram_id = ref (-1)

  (** Creates a stream with the next consecutive stream id (starting from 0). *)
  let new_stream device =
    let stream_id = !latest_stram_id + 1 in
    latest_stram_id := stream_id;
    make_stream device () ~stream_id

  let all_work _stream = ()
  let is_idle _stream = true
  let name = "sync_" ^ Backend.name
  let await _stream = ()
  let initialize = Backend.initialize
  let is_initialized = Backend.is_initialized
  let get_name stream = [%string "%{name}:0:%{stream.stream_id#Int}"]

  (** Runs [task] right away, before returning. *)
  let schedule_task _stream task = Task.run task
end
332-
33390let lower_assignments ?name bindings asgns =
33491 let name = Option. value_or_thunk name ~default: (fun () -> Assignments. get_name_exn asgns) in
33592 let unoptim_ll_source = Utils. get_debug_formatter ~fname: (name ^ " -unoptimized.ll" ) in
@@ -587,10 +344,10 @@ struct
587344 include Backend_device
588345end
589346
590- module Cc_multicore = Make_device_backend_from_lowered (Multicore_scheduler ) (Cc_backend )
591- module Gcc_multicore = Make_device_backend_from_lowered (Multicore_scheduler ) (Gcc_backend )
592- module Cc_sync = Make_device_backend_from_lowered (Sync_scheduler ) (Cc_backend )
593- module Gcc_sync = Make_device_backend_from_lowered (Sync_scheduler ) (Gcc_backend )
347+ module Cc_multicore = Make_device_backend_from_lowered (Schedulers. Multicore ) (Cc_backend )
348+ module Gcc_multicore = Make_device_backend_from_lowered (Schedulers. Multicore ) (Gcc_backend )
349+ module Cc_sync = Make_device_backend_from_lowered (Schedulers. Sync ) (Cc_backend )
350+ module Gcc_sync = Make_device_backend_from_lowered (Schedulers. Sync ) (Gcc_backend )
594351
595352let reinitialize (module Backend : Backend ) config =
596353 if not @@ Backend. is_initialized () then Backend. initialize config
0 commit comments