@@ -104,7 +104,7 @@ module Multicore_backend (Backend : No_device_backend) = struct
104104 include (
105105 Backend : Buffer with type buffer_ptr = Backend. buffer_ptr and type buffer = Backend. buffer)
106106
107- type device = CPU [@@ deriving sexp_of ]
107+ type dev = CPU [@@ deriving sexp_of ]
108108
109109 type stream_state = {
110110 mutable keep_spinning : bool ;
@@ -117,10 +117,11 @@ module Multicore_backend (Backend : No_device_backend) = struct
117117 }
118118 [@@ deriving sexp_of ]
119119
120- type runner = unit Domain .t
120+ type domain = unit Domain .t
121121
122- let sexp_of_runner (d : runner ) = Sexp. Atom (" domain-" ^ Int. to_string (Domain. get_id d :> int ))
122+ let sexp_of_domain (d : domain ) = Sexp. Atom (" domain-" ^ Int. to_string (Domain. get_id d :> int ))
123123
124+ type runner = { state : stream_state ; domain : domain } [@@ deriving sexp_of ]
124125 type event = Not_implemented_yet [@@ deriving sexp_of ]
125126 end
126127
@@ -132,29 +133,30 @@ module Multicore_backend (Backend : No_device_backend) = struct
132133 end
133134
134135 include Device (Device_types (Device_config )) (Alloc_buffer )
136+ open Device_config
135137
136138 (* * TODO: Blocks till the event completes, if it's not done already. *)
137- let sync Device_config. Not_implemented_yet = ()
139+ let sync Not_implemented_yet = ()
138140
139141 (* * TODO: Whether the event completed. *)
140- let is_done Device_config. Not_implemented_yet = true
142+ let is_done Not_implemented_yet = true
141143
142144 (* * TODO: Schedules waiting for the given event on the context's stream. *)
143- let will_wait_for _ctx Device_config. Not_implemented_yet = ()
145+ let will_wait_for _ctx Not_implemented_yet = ()
144146
145147 let get_used_memory _device = get_used_memory ()
146148
147149 type nonrec code = code [@@ deriving sexp_of ]
148150 type nonrec code_batch = code_batch [@@ deriving sexp_of ]
149151
150- let is_dev_queue_empty state = Queue. size state.Device_config. queue = 0
151- let is_idle stream = is_dev_queue_empty stream.state && stream.state.is_ready
152+ let is_dev_queue_empty state = Queue. size state.queue = 0
153+ let is_idle stream = is_dev_queue_empty stream.runner. state && stream.runner .state.is_ready
152154 let name = " multicore_" ^ name
153155 let get_name stream = [% string " %{name}:0:%{stream.stream_id#Int}" ]
154156
155157 let % track3_l_sexp await stream =
156158 assert (Domain. is_main_domain () );
157- let d = stream.state in
159+ let d = stream.runner. state in
158160 if (not @@ is_idle stream) && d.keep_spinning then (
159161 Mut. lock d.mut;
160162 while (not @@ is_idle stream) && d.keep_spinning do
@@ -167,13 +169,13 @@ module Multicore_backend (Backend : No_device_backend) = struct
167169
168170 (* * TODO: Returns the event indicating if any currently running or scheduled computations on the
169171 stream have completed. *)
170- let all_work _stream = Device_config. Not_implemented_yet
172+ let all_work _stream = Not_implemented_yet
171173
172174 let % track3_l_sexp schedule_task stream task =
173175 assert (Domain. is_main_domain () );
174176 [% log_result " schedule_task" , Task. describe task, get_name stream];
175- let d = stream.state in
176- Option. iter d.Device_config. stream_error ~f: (fun e -> Exn. reraise e @@ get_name stream);
177+ let d = stream.runner. state in
178+ Option. iter d.stream_error ~f: (fun e -> Exn. reraise e @@ get_name stream);
177179 if not d.keep_spinning then invalid_arg " Multicore_backend: stream not available" ;
178180 if not @@ Queue. try_push d.queue task then (
179181 await stream;
@@ -184,12 +186,13 @@ module Multicore_backend (Backend : No_device_backend) = struct
184186 Mut. unlock d.mut)
185187
186188 let global_run_no = ref 0
189+ let device : device = make_device CPU ~ordinal: 0
187190
188191 let % track3_l_sexp spinup_stream ~stream_id : stream =
189192 Int. incr global_run_no;
190193 let state =
191194 {
192- Device_config. keep_spinning = true ;
195+ keep_spinning = true ;
193196 stream_error = None ;
194197 queue = Queue. create ~size_exponent: 12 ;
195198 mut = Mut. create () ;
@@ -222,7 +225,7 @@ module Multicore_backend (Backend : No_device_backend) = struct
222225 stream_error. But this is fine if we assume all exceptions are fatal. *)
223226 raise e
224227 in
225- make_stream ~ device: Device_config. CPU ~ state ~stream_id ~runner: ( Domain. spawn worker)
228+ make_stream device { state; domain = Domain. spawn worker } ~stream_id
226229
227230 type nonrec context = { stream : stream ; ctx : context } [@@ deriving sexp_of ]
228231
@@ -262,30 +265,31 @@ module Multicore_backend (Backend : No_device_backend) = struct
262265 module Dynarr = Stdlib. Dynarray
263266
264267 let num_devices () = 1
265- let suggested_num_streams Device_config. CPU = Domain. recommended_domain_count () - 1
268+ let suggested_num_streams _device = Domain. recommended_domain_count () - 1
266269
267270 let cleanup_stream stream =
268271 assert (Domain. is_main_domain () );
269272 await stream;
270- stream.state.keep_spinning < - false ;
271- Stdlib.Condition. broadcast stream.state.dev_wait_for_work;
272- Domain. join stream.runner
273+ let r = stream.runner in
274+ r.state.keep_spinning < - false ;
275+ Stdlib.Condition. broadcast r.state.dev_wait_for_work;
276+ Domain. join r.domain
273277
274278 let get_device ~ordinal =
275279 if ordinal <> 0 then
276280 invalid_arg [% string " Multicore_backend.get_device %{ordinal#Int}: only device 0 exists" ];
277- Device_config. CPU
281+ device
278282
279283 let latest_stream_id = ref (- 1 )
280284
281- let new_stream Device_config. CPU =
285+ let new_stream _device =
282286 assert (Domain. is_main_domain () );
283287 Int. incr latest_stream_id;
284288 let stream = spinup_stream ~stream_id: ! latest_stream_id in
285289 Stdlib.Gc. finalise cleanup_stream stream;
286290 stream
287291
288- let get_stream_device _stream = Device_config. CPU
292+ let get_stream_device stream = stream.device
289293 let get_ctx_stream { stream; _ } = stream
290294 let to_ordinal _ = 0
291295
@@ -343,8 +347,7 @@ module Sync_backend (Backend : No_device_backend) = struct
343347 include (
344348 Backend : Buffer with type buffer_ptr = Backend. buffer_ptr and type buffer = Backend. buffer)
345349
346- type device = CPU [@@ deriving sexp_of ]
347- type stream_state = unit [@@ deriving sexp_of ]
350+ type dev = CPU [@@ deriving sexp_of ]
348351 type runner = unit [@@ deriving sexp_of ]
349352 type event = unit [@@ deriving sexp_of ]
350353 end
@@ -357,6 +360,7 @@ module Sync_backend (Backend : No_device_backend) = struct
357360 end
358361
359362 include Device (Device_types (Device_config )) (Alloc_buffer )
363+ open Device_config
360364
361365 let sync () = ()
362366 let is_done () = true
@@ -365,22 +369,23 @@ module Sync_backend (Backend : No_device_backend) = struct
365369 let alloc_buffer ?old_buffer ~size_in_bytes _stream =
366370 Backend. alloc_buffer ?old_buffer ~size_in_bytes ()
367371
368- let to_ordinal Device_config. CPU = 0
372+ let device : device = make_device CPU ~ordinal: 0
373+ let to_ordinal device = device.ordinal
369374
370375 let get_device ~ordinal =
371376 if ordinal <> 0 then
372377 invalid_arg @@ " Sync_backend.get_device: there is only one device, but ordinal="
373378 ^ Int. to_string ordinal;
374- Device_config. CPU
379+ device
375380
376381 let num_devices () = 1
377- let suggested_num_streams Device_config. CPU = ! sync_suggested_num_streams
378- let get_used_memory Device_config. CPU = Backend. get_used_memory ()
382+ let suggested_num_streams _ = ! sync_suggested_num_streams
383+ let get_used_memory _ = Backend. get_used_memory ()
379384 let latest_stram_id = ref (- 1 )
380385
381- let new_stream Device_config. CPU : stream =
386+ let new_stream device =
382387 Int. incr latest_stram_id;
383- make_stream ~ device: Device_config. CPU ~state: () ~stream_id: ! latest_stram_id ~runner: ()
388+ make_stream device () ~stream_id: ! latest_stram_id
384389
385390 type code = Backend .code [@@ deriving sexp_of ]
386391 type code_batch = Backend .code_batch [@@ deriving sexp_of ]
@@ -394,7 +399,7 @@ module Sync_backend (Backend : No_device_backend) = struct
394399 type context = { stream : stream ; ctx : Backend .context } [@@ deriving sexp_of ]
395400
396401 let get_ctx_stream context = context.stream
397- let get_stream_device _stream = Device_config. CPU
402+ let get_stream_device stream = stream.device
398403 let ctx_arrays context = ctx_arrays context.ctx
399404 let init stream = { stream; ctx = Backend. init name }
400405 let initialize = Backend. initialize
0 commit comments