@@ -18,7 +18,7 @@ let check_merge_buffer ~scheduled_node ~code_node =
1818 (" Merge buffer mismatch, on stream: " ^ name scheduled_node ^ " , expected by code: "
1919 ^ name code_node)
2020
21- module Multicore_backend (Backend : Backend_types.No_device_backend ) : Backend_types. Backend =
21+ module Multicore_backend (Backend : Backend_types.No_device_backend ) (* : Backend_types.Backend *) =
2222struct
2323 module Domain = Domain [@ warning " -3" ]
2424
6868 state : stream_state ;
6969 merge_buffer : (buffer_ptr * Tnode .t ) option ref ;
7070 mutable allocated_buffer : (buffer_ptr * int ) option ;
71- ordinal : int ;
71+ subordinal : int ;
7272 domain : (unit Domain .t [@ sexp.opaque]);
7373 }
7474 [@@ deriving sexp_of ]
7878
7979 let get_used_memory _device = Backend. get_used_memory ()
8080
81- type device = stream [@@ deriving sexp_of ]
81+ type device = CPU [@@ deriving sexp_of ]
8282 type code = Backend .code [@@ deriving sexp_of ]
8383 type code_batch = Backend .code_batch [@@ deriving sexp_of ]
8484
@@ -98,18 +98,18 @@ struct
9898 done ;
9999 Mut. unlock d.mut;
100100 Option. iter d.stream_error ~f: (fun e ->
101- Exn. reraise e @@ name ^ " stream " ^ Int. to_string stream.ordinal ))
101+ Exn. reraise e @@ name ^ " stream " ^ Int. to_string stream.subordinal ))
102102
103103 (* * TODO: Returns the event indicating if any currently running or scheduled computations on the
104104 stream have completed. *)
105105 let all_work _stream = Not_implemented_yet
106106
107107 let % track3_l_sexp schedule_task stream task =
108108 assert (Domain. is_main_domain () );
109- [% log_result " schedule_task" , Task. describe task, " stream" , (stream.ordinal : int )];
109+ [% log_result " schedule_task" , Task. describe task, " stream" , (stream.subordinal : int )];
110110 let d = stream.state in
111111 Option. iter d.stream_error ~f: (fun e ->
112- Exn. reraise e @@ name ^ " stream " ^ Int. to_string stream.ordinal );
112+ Exn. reraise e @@ name ^ " stream " ^ Int. to_string stream.subordinal );
113113 if not d.keep_spinning then invalid_arg " Multicore_backend: stream not available" ;
114114 if not @@ Queue. try_push d.queue task then (
115115 await stream;
@@ -121,7 +121,7 @@ struct
121121
122122 let global_run_no = ref 0
123123
124- let % track3_l_sexp spinup_stream ~(ordinal : int ) : stream =
124+ let % track3_l_sexp spinup_stream ~(subordinal : int ) : stream =
125125 Int. incr global_run_no;
126126 let state =
127127 {
@@ -153,14 +153,14 @@ struct
153153 with e ->
154154 state.stream_error < - Some e;
155155 state.keep_spinning < - false ;
156- [% log1 " Stream" , (ordinal : int ), " exception" , Exn. to_string e];
156+ [% log1 " Stream" , (subordinal : int ), " exception" , Exn. to_string e];
157157 (* TODO: we risk raising this error multiple times because await and schedule_task raise
158158 stream_error. But this is fine if we assume all exceptions are fatal. *)
159159 raise e
160160 in
161161 {
162162 state;
163- ordinal ;
163+ subordinal ;
164164 domain = Domain. spawn worker;
165165 merge_buffer = ref None ;
166166 allocated_buffer = None ;
@@ -169,7 +169,7 @@ struct
169169 type context = { stream : stream ; ctx : Backend .context } [@@ deriving sexp_of ]
170170 type nonrec routine = context routine [@@ deriving sexp_of ]
171171
172- let init stream = { stream; ctx = Backend. init (name ^ " " ^ Int. to_string stream.ordinal ) }
172+ let init stream = { stream; ctx = Backend. init (name ^ " " ^ Int. to_string stream.subordinal ) }
173173 let initialize = Backend. initialize
174174 let is_initialized = Backend. is_initialized
175175
@@ -179,7 +179,7 @@ struct
179179
180180 let compile = Backend. compile
181181 let compile_batch = Backend. compile_batch
182- let get_stream_name s = " stream " ^ Int. to_string s.ordinal
182+ let get_stream_name s = " stream " ^ Int. to_string s.subordinal
183183
184184 let link { ctx; stream } code =
185185 let task = Backend. link ~merge_buffer: stream.merge_buffer ctx code in
@@ -224,7 +224,7 @@ struct
224224 context_lifetime = context;
225225 description =
226226 " from_host " ^ Tnode. debug_name tn ^ " dst "
227- ^ Int. to_string context.stream.ordinal ;
227+ ^ Int. to_string context.stream.subordinal ;
228228 work;
229229 });
230230 true
@@ -258,7 +258,7 @@ struct
258258 context_lifetime = context;
259259 description =
260260 " from_host " ^ Tnode. debug_name tn ^ " dst "
261- ^ Int. to_string context.stream.ordinal ;
261+ ^ Int. to_string context.stream.subordinal ;
262262 work;
263263 });
264264 true
@@ -296,8 +296,8 @@ struct
296296 Backend. to_buffer tn ~dst: merge_ptr ~src: src.ctx
297297 in
298298 let description =
299- " device_to_device " ^ Tnode. debug_name tn ^ " dst " ^ Int. to_string dev.ordinal ^ " src "
300- ^ Int. to_string src.stream.ordinal
299+ " device_to_device " ^ Tnode. debug_name tn ^ " dst " ^ Int. to_string dev.subordinal ^ " src "
300+ ^ Int. to_string src.stream.subordinal
301301 in
302302 schedule_task dev (Task. Task { context_lifetime = (src, dst); description; work })
303303 in
@@ -307,37 +307,40 @@ struct
307307 true
308308 | _ -> false
309309
310- let num_devices () = Domain. recommended_domain_count () - 1
311- let suggested_num_streams _device = 1
312- let devices : device option array = Array. create ~len: (num_devices () ) None
310+ module Dynarr = Stdlib. Dynarray
313311
314- let % track2_sexp unsafe_cleanup () =
312+ let num_devices () = 1
313+ let suggested_num_streams CPU = Domain. recommended_domain_count () - 1
314+ let latest_subordinal = ref 0
315+
316+ let cleanup_stream stream =
315317 assert (Domain. is_main_domain () );
316- let wait_for_finish stream =
317- await stream;
318- stream.state.keep_spinning < - false ;
319- Stdlib.Condition. broadcast stream.state.dev_wait_for_work
320- in
321- Array. iter devices ~f: (Option. iter ~f: wait_for_finish);
322- let cleanup ordinal device =
323- Domain. join device.domain;
324- devices.(ordinal) < - None
325- in
326- Array. iteri devices ~f: (fun ordinal -> Option. iter ~f: (cleanup ordinal));
318+ await stream;
319+ stream.state.keep_spinning < - false ;
320+ Stdlib.Condition. broadcast stream.state.dev_wait_for_work;
321+ Domain. join stream.domain
322+
323+ let % track2_sexp unsafe_cleanup () =
324+ latest_subordinal := 0 ;
327325 Backend. unsafe_cleanup ()
328326
329327 let get_device ~ordinal =
330- Option. value_or_thunk devices.(ordinal) ~default: (fun () ->
331- let dev = spinup_stream ~ordinal in
332- devices.(ordinal) < - Some dev;
333- dev)
328+ if ordinal <> 0 then
329+ invalid_arg [% string " Multicore_backend.get_device %{ordinal#Int}: only device 0 exists" ];
330+ CPU
331+
332+ let new_stream CPU =
333+ let subordinal = ! latest_subordinal in
334+ Int. incr latest_subordinal;
335+ let stream = spinup_stream ~subordinal in
336+ Stdlib.Gc. finalise cleanup_stream stream;
337+ stream
334338
335- let new_stream device = device
336- let get_stream_device stream = stream
339+ let get_stream_device _stream = CPU
337340 let get_ctx_stream { stream; _ } = stream
338- let get_name device = Int. to_string device.ordinal
339- let to_ordinal { ordinal; _ } = ordinal
340- let to_subordinal _ = 0
341+ let get_name stream = Int. to_string stream.subordinal
342+ let to_ordinal _ = 0
343+ let to_subordinal { subordinal; _ } = subordinal
341344end
342345
343346(* * For debugging, allow [Sync_backend(...).suggested_num_streams] calls to return >1 numbers. *)
799802module Cuda_backend : Backend_types .Backend = Lowered_backend ((
800803 Cuda_backend : Backend_types. Lowered_backend ))
801804
805+ (* * Initializes the backend, and if it was already initialized, performs garbage collection. *)
802806let reinitialize (module Backend : Backend_types.Backend ) config =
803807 if not @@ Backend. is_initialized () then Backend. initialize config
804808 else (
@@ -807,7 +811,7 @@ let reinitialize (module Backend : Backend_types.Backend) config =
807811 Backend. initialize config)
808812
809813(* * Reinitializes and returns a backend corresponding to [backend_name], or if omitted, selected via
810- the global [backend] setting. *)
814+ the global [backend] setting. See {!reinitialize}. *)
811815let fresh_backend ?backend_name ?(config = Only_devices_parallel ) () =
812816 let backend =
813817 match
0 commit comments