Skip to content

Commit 03c7989

Browse files
committed
Factor out a shared device record, include stream_state in runner
1 parent c41f23a commit 03c7989

File tree

4 files changed

+123
-114
lines changed

4 files changed

+123
-114
lines changed

arrayjit/lib/backend_types.ml

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -113,48 +113,67 @@ type 'context routine = {
113113
module type Device_config = sig
114114
include Buffer
115115

116-
type device [@@deriving sexp_of]
117-
type stream_state [@@deriving sexp_of]
116+
type dev [@@deriving sexp_of]
117+
(** Interface to a device driver. *)
118+
118119
type runner [@@deriving sexp_of]
120+
(** Interface to a stream driver. *)
119121

120122
type event [@@deriving sexp_of]
121123
(** An event tracks if a stream finished computing past a particular point in its schedue. These
122124
values are used internally for scheduling across streams of the backend, and can be used for
123125
explicit scheduling. *)
124126
end
125127

126-
type ('buffer_ptr, 'device, 'stream_state, 'runner, 'event) stream = {
127-
device : 'device;
128-
state : 'stream_state;
128+
type ('buffer_ptr, 'dev, 'event) device = {
129+
dev : 'dev;
130+
ordinal : int;
131+
mutable shared_merge_buffer : 'buffer_ptr buffer option;
132+
mutable latest_stream_id : int;
133+
released : Utils.atomic_bool;
134+
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
135+
(** Freshly created arrays that might be shared across streams. The map can both grow and
136+
shrink. See the explanation on top of this file. *)
137+
owner_streams : int Hashtbl.M(Tnode).t;
138+
(** The streams owning the given nodes. This map can only grow. *)
139+
stream_working_on : (int * 'event) option Hashtbl.M(Tnode).t;
140+
(** The stream that most recently has been updating the node, and the associated update
141+
completion event. Only populated when {!field-queried_work_for} is populated. *)
142+
}
143+
[@@deriving sexp_of]
144+
145+
type ('buffer_ptr, 'dev, 'runner, 'event) stream = {
146+
device : ('buffer_ptr, 'dev, 'event) device;
147+
runner : 'runner;
129148
merge_buffer : ('buffer_ptr * Tnode.t) option ref;
130149
stream_id : int;
131150
mutable allocated_buffer : 'buffer_ptr buffer option;
132-
queried_work_for : 'event option Hashtbl.M(Tnode).t;
151+
queried_work_for : 'event option Hashtbl.M(Tnode).t; (* The completion event for updating the node via this stream. Only populated after the first time {!} *)
133152
}
134153
[@@deriving sexp_of]
135154

136155
module type Device_types = sig
137156
include Device_config
138157

139-
type nonrec stream = (buffer_ptr, device, stream_state, runner, event) stream [@@deriving sexp_of]
158+
type nonrec device = (buffer_ptr, dev, event) device [@@deriving sexp_of]
159+
type nonrec stream = (buffer_ptr, dev, runner, event) stream [@@deriving sexp_of]
140160
end
141161

142162
module Stream (Device_config : Device_config) = struct
163+
type nonrec device = (Device_config.buffer_ptr, Device_config.dev, Device_config.event) device
164+
[@@deriving sexp_of]
165+
143166
type nonrec stream =
144-
( Device_config.buffer_ptr,
145-
Device_config.device,
146-
Device_config.stream_state,
147-
Device_config.runner,
148-
Device_config.event )
149-
stream
167+
(Device_config.buffer_ptr, Device_config.dev, Device_config.runner, Device_config.event) stream
150168
[@@deriving sexp_of]
151169
end
152170

153171
module type Device = sig
154172
include Device_types
155173
include Alloc_buffer with type buffer_ptr := buffer_ptr and type stream := stream
156174

157-
val make_stream : device:device -> state:stream_state -> stream_id:int -> runner:runner -> stream
175+
val make_device : dev -> ordinal:int -> device
176+
val make_stream : device -> runner -> stream_id:int -> stream
158177
end
159178

160179
module Device_types (Device_config : Device_config) = struct
@@ -171,14 +190,25 @@ struct
171190
include Device_types
172191
include Alloc_buffer
173192

174-
let make_stream ~device ~state ~stream_id ~runner =
193+
let make_device dev ~ordinal =
194+
{
195+
dev;
196+
ordinal;
197+
shared_merge_buffer = None;
198+
latest_stream_id = -1;
199+
released = Atomic.make false;
200+
cross_stream_candidates = Hashtbl.create (module Tnode);
201+
owner_streams = Hashtbl.create (module Tnode);
202+
stream_working_on = Hashtbl.create (module Tnode);
203+
}
204+
205+
let make_stream device runner ~stream_id =
175206
{
176207
device;
177-
state;
208+
runner;
178209
merge_buffer = ref None;
179210
stream_id;
180211
allocated_buffer = None;
181-
runner;
182212
queried_work_for = Hashtbl.create (module Tnode);
183213
}
184214
end

arrayjit/lib/backends.ml

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ module Multicore_backend (Backend : No_device_backend) = struct
104104
include (
105105
Backend : Buffer with type buffer_ptr = Backend.buffer_ptr and type buffer = Backend.buffer)
106106

107-
type device = CPU [@@deriving sexp_of]
107+
type dev = CPU [@@deriving sexp_of]
108108

109109
type stream_state = {
110110
mutable keep_spinning : bool;
@@ -117,10 +117,11 @@ module Multicore_backend (Backend : No_device_backend) = struct
117117
}
118118
[@@deriving sexp_of]
119119

120-
type runner = unit Domain.t
120+
type domain = unit Domain.t
121121

122-
let sexp_of_runner (d : runner) = Sexp.Atom ("domain-" ^ Int.to_string (Domain.get_id d :> int))
122+
let sexp_of_domain (d : domain) = Sexp.Atom ("domain-" ^ Int.to_string (Domain.get_id d :> int))
123123

124+
type runner = { state : stream_state; domain : domain } [@@deriving sexp_of]
124125
type event = Not_implemented_yet [@@deriving sexp_of]
125126
end
126127

@@ -132,29 +133,30 @@ module Multicore_backend (Backend : No_device_backend) = struct
132133
end
133134

134135
include Device (Device_types (Device_config)) (Alloc_buffer)
136+
open Device_config
135137

136138
(** TODO: Blocks till the event completes, if it's not done already. *)
137-
let sync Device_config.Not_implemented_yet = ()
139+
let sync Not_implemented_yet = ()
138140

139141
(** TODO: Whether the event completed. *)
140-
let is_done Device_config.Not_implemented_yet = true
142+
let is_done Not_implemented_yet = true
141143

142144
(** TODO: Schedules waiting for the given event on the context's stream. *)
143-
let will_wait_for _ctx Device_config.Not_implemented_yet = ()
145+
let will_wait_for _ctx Not_implemented_yet = ()
144146

145147
let get_used_memory _device = get_used_memory ()
146148

147149
type nonrec code = code [@@deriving sexp_of]
148150
type nonrec code_batch = code_batch [@@deriving sexp_of]
149151

150-
let is_dev_queue_empty state = Queue.size state.Device_config.queue = 0
151-
let is_idle stream = is_dev_queue_empty stream.state && stream.state.is_ready
152+
let is_dev_queue_empty state = Queue.size state.queue = 0
153+
let is_idle stream = is_dev_queue_empty stream.runner.state && stream.runner.state.is_ready
152154
let name = "multicore_" ^ name
153155
let get_name stream = [%string "%{name}:0:%{stream.stream_id#Int}"]
154156

155157
let%track3_l_sexp await stream =
156158
assert (Domain.is_main_domain ());
157-
let d = stream.state in
159+
let d = stream.runner.state in
158160
if (not @@ is_idle stream) && d.keep_spinning then (
159161
Mut.lock d.mut;
160162
while (not @@ is_idle stream) && d.keep_spinning do
@@ -167,13 +169,13 @@ module Multicore_backend (Backend : No_device_backend) = struct
167169

168170
(** TODO: Returns the event indicating if any currently running or scheduled computations on the
169171
stream have completed. *)
170-
let all_work _stream = Device_config.Not_implemented_yet
172+
let all_work _stream = Not_implemented_yet
171173

172174
let%track3_l_sexp schedule_task stream task =
173175
assert (Domain.is_main_domain ());
174176
[%log_result "schedule_task", Task.describe task, get_name stream];
175-
let d = stream.state in
176-
Option.iter d.Device_config.stream_error ~f:(fun e -> Exn.reraise e @@ get_name stream);
177+
let d = stream.runner.state in
178+
Option.iter d.stream_error ~f:(fun e -> Exn.reraise e @@ get_name stream);
177179
if not d.keep_spinning then invalid_arg "Multicore_backend: stream not available";
178180
if not @@ Queue.try_push d.queue task then (
179181
await stream;
@@ -184,12 +186,13 @@ module Multicore_backend (Backend : No_device_backend) = struct
184186
Mut.unlock d.mut)
185187

186188
let global_run_no = ref 0
189+
let device : device = make_device CPU ~ordinal:0
187190

188191
let%track3_l_sexp spinup_stream ~stream_id : stream =
189192
Int.incr global_run_no;
190193
let state =
191194
{
192-
Device_config.keep_spinning = true;
195+
keep_spinning = true;
193196
stream_error = None;
194197
queue = Queue.create ~size_exponent:12;
195198
mut = Mut.create ();
@@ -222,7 +225,7 @@ module Multicore_backend (Backend : No_device_backend) = struct
222225
stream_error. But this is fine if we assume all exceptions are fatal. *)
223226
raise e
224227
in
225-
make_stream ~device:Device_config.CPU ~state ~stream_id ~runner:(Domain.spawn worker)
228+
make_stream device { state; domain = Domain.spawn worker } ~stream_id
226229

227230
type nonrec context = { stream : stream; ctx : context } [@@deriving sexp_of]
228231

@@ -262,30 +265,31 @@ module Multicore_backend (Backend : No_device_backend) = struct
262265
module Dynarr = Stdlib.Dynarray
263266

264267
let num_devices () = 1
265-
let suggested_num_streams Device_config.CPU = Domain.recommended_domain_count () - 1
268+
let suggested_num_streams _device = Domain.recommended_domain_count () - 1
266269

267270
let cleanup_stream stream =
268271
assert (Domain.is_main_domain ());
269272
await stream;
270-
stream.state.keep_spinning <- false;
271-
Stdlib.Condition.broadcast stream.state.dev_wait_for_work;
272-
Domain.join stream.runner
273+
let r = stream.runner in
274+
r.state.keep_spinning <- false;
275+
Stdlib.Condition.broadcast r.state.dev_wait_for_work;
276+
Domain.join r.domain
273277

274278
let get_device ~ordinal =
275279
if ordinal <> 0 then
276280
invalid_arg [%string "Multicore_backend.get_device %{ordinal#Int}: only device 0 exists"];
277-
Device_config.CPU
281+
device
278282

279283
let latest_stream_id = ref (-1)
280284

281-
let new_stream Device_config.CPU =
285+
let new_stream _device =
282286
assert (Domain.is_main_domain ());
283287
Int.incr latest_stream_id;
284288
let stream = spinup_stream ~stream_id:!latest_stream_id in
285289
Stdlib.Gc.finalise cleanup_stream stream;
286290
stream
287291

288-
let get_stream_device _stream = Device_config.CPU
292+
let get_stream_device stream = stream.device
289293
let get_ctx_stream { stream; _ } = stream
290294
let to_ordinal _ = 0
291295

@@ -343,8 +347,7 @@ module Sync_backend (Backend : No_device_backend) = struct
343347
include (
344348
Backend : Buffer with type buffer_ptr = Backend.buffer_ptr and type buffer = Backend.buffer)
345349

346-
type device = CPU [@@deriving sexp_of]
347-
type stream_state = unit [@@deriving sexp_of]
350+
type dev = CPU [@@deriving sexp_of]
348351
type runner = unit [@@deriving sexp_of]
349352
type event = unit [@@deriving sexp_of]
350353
end
@@ -357,6 +360,7 @@ module Sync_backend (Backend : No_device_backend) = struct
357360
end
358361

359362
include Device (Device_types (Device_config)) (Alloc_buffer)
363+
open Device_config
360364

361365
let sync () = ()
362366
let is_done () = true
@@ -365,22 +369,23 @@ module Sync_backend (Backend : No_device_backend) = struct
365369
let alloc_buffer ?old_buffer ~size_in_bytes _stream =
366370
Backend.alloc_buffer ?old_buffer ~size_in_bytes ()
367371

368-
let to_ordinal Device_config.CPU = 0
372+
let device : device = make_device CPU ~ordinal:0
373+
let to_ordinal device = device.ordinal
369374

370375
let get_device ~ordinal =
371376
if ordinal <> 0 then
372377
invalid_arg @@ "Sync_backend.get_device: there is only one device, but ordinal="
373378
^ Int.to_string ordinal;
374-
Device_config.CPU
379+
device
375380

376381
let num_devices () = 1
377-
let suggested_num_streams Device_config.CPU = !sync_suggested_num_streams
378-
let get_used_memory Device_config.CPU = Backend.get_used_memory ()
382+
let suggested_num_streams _ = !sync_suggested_num_streams
383+
let get_used_memory _ = Backend.get_used_memory ()
379384
let latest_stram_id = ref (-1)
380385

381-
let new_stream Device_config.CPU : stream =
386+
let new_stream device =
382387
Int.incr latest_stram_id;
383-
make_stream ~device:Device_config.CPU ~state:() ~stream_id:!latest_stram_id ~runner:()
388+
make_stream device () ~stream_id:!latest_stram_id
384389

385390
type code = Backend.code [@@deriving sexp_of]
386391
type code_batch = Backend.code_batch [@@deriving sexp_of]
@@ -394,7 +399,7 @@ module Sync_backend (Backend : No_device_backend) = struct
394399
type context = { stream : stream; ctx : Backend.context } [@@deriving sexp_of]
395400

396401
let get_ctx_stream context = context.stream
397-
let get_stream_device _stream = Device_config.CPU
402+
let get_stream_device stream = stream.device
398403
let ctx_arrays context = ctx_arrays context.ctx
399404
let init stream = { stream; ctx = Backend.init name }
400405
let initialize = Backend.initialize

0 commit comments

Comments
 (0)