Skip to content

Commit 7d020ca

Browse files
committed
Backends: Factor out the shared stream fields
1 parent 8ccd035 commit 7d020ca

File tree

3 files changed

+54
-54
lines changed

3 files changed

+54
-54
lines changed

arrayjit/lib/backend_types.ml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,28 @@ module Types = struct
3030
}
3131
[@@deriving sexp_of]
3232

33+
type ('buffer_ptr, 'event, 'device, 'state, 'runner) stream = {
34+
device : 'device;
35+
state : 'state;
36+
merge_buffer : ('buffer_ptr * Tnode.t) option ref;
37+
unique_name : string;
38+
mutable allocated_buffer : ('buffer_ptr * int) option;
39+
runner : 'runner;
40+
requested_work_for : 'event option Hashtbl.M(Tnode).t;
41+
}
42+
[@@deriving sexp_of]
43+
44+
let make_stream ~device ~state ~unique_name ~runner =
45+
{
46+
device;
47+
state;
48+
merge_buffer = ref None;
49+
unique_name : string;
50+
allocated_buffer = None;
51+
runner : 'runner;
52+
requested_work_for = Hashtbl.create (module Tnode);
53+
}
54+
3355
(** For now, we only configure a backend with regard to how many streams it should suggest using
3456
(where applicable). *)
3557
type config = Only_devices_parallel | For_parallel_copying | Most_parallel_streams

arrayjit/lib/backends.ml

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -109,16 +109,13 @@ module Multicore_backend (Backend : Backend_types.No_device_backend) = struct
109109
}
110110
[@@deriving sexp_of]
111111

112-
type stream = {
113-
state : stream_state;
114-
merge_buffer : (buffer_ptr * Tnode.t) option ref;
115-
mutable allocated_buffer : (buffer_ptr * int) option;
116-
unique_name : string;
117-
domain : (unit Domain.t[@sexp.opaque]);
118-
}
119-
[@@deriving sexp_of]
112+
type domain = unit Domain.t
120113

121-
type event = Not_implemented_yet
114+
let sexp_of_domain (d : domain) = Sexp.Atom ("domain-" ^ Int.to_string (Domain.get_id d :> int))
115+
116+
type device = CPU [@@deriving sexp_of]
117+
type event = Not_implemented_yet [@@deriving sexp_of]
118+
type nonrec stream = (buffer_ptr, event, device, stream_state, domain) stream [@@deriving sexp_of]
122119

123120
(** TODO: Blocks till the event completes, if it's not done already. *)
124121
let sync Not_implemented_yet = ()
@@ -132,7 +129,6 @@ module Multicore_backend (Backend : Backend_types.No_device_backend) = struct
132129
let alloc_buffer ?old_buffer ~size_in_bytes _stream = alloc_buffer ?old_buffer ~size_in_bytes ()
133130
let get_used_memory _device = get_used_memory ()
134131

135-
type device = CPU [@@deriving sexp_of]
136132
type nonrec code = code [@@deriving sexp_of]
137133
type nonrec code_batch = code_batch [@@deriving sexp_of]
138134

@@ -210,19 +206,14 @@ module Multicore_backend (Backend : Backend_types.No_device_backend) = struct
210206
stream_error. But this is fine if we assume all exceptions are fatal. *)
211207
raise e
212208
in
213-
{
214-
state;
215-
unique_name;
216-
domain = Domain.spawn worker;
217-
merge_buffer = ref None;
218-
allocated_buffer = None;
219-
}
209+
make_stream ~device:CPU ~state ~unique_name ~runner:(Domain.spawn worker)
220210

221211
type nonrec context = { stream : stream; ctx : context } [@@deriving sexp_of]
222212

223213
let ctx_arrays context = ctx_arrays context.ctx
224214

225215
type nonrec routine = context Backend_types.Types.routine [@@deriving sexp_of]
216+
(** This overrides the routine type from [Backend]. *)
226217

227218
let init stream = { stream; ctx = init (name ^ " " ^ stream.unique_name) }
228219
let initialize = initialize
@@ -269,7 +260,7 @@ module Multicore_backend (Backend : Backend_types.No_device_backend) = struct
269260
stream.state.keep_spinning <- false;
270261
Stdlib.Condition.broadcast stream.state.dev_wait_for_work;
271262
Hash_set.remove used_names stream.unique_name;
272-
Domain.join stream.domain
263+
Domain.join stream.runner
273264

274265
let get_device ~ordinal =
275266
if ordinal <> 0 then
@@ -344,19 +335,12 @@ let sync_suggested_num_streams = ref 1
344335
module Sync_backend (Backend : Backend_types.No_device_backend) = struct
345336
include Backend
346337

347-
type event = unit
338+
type event = unit [@@deriving sexp_of]
348339

349340
let sync () = ()
350341
let is_done () = true
351342
let will_wait_for _context () = ()
352343

353-
type stream = {
354-
unique_name : string;
355-
merge_buffer : (buffer_ptr * Tnode.t) option ref;
356-
mutable allocated_buffer : (buffer_ptr * int) option;
357-
}
358-
[@@deriving sexp_of]
359-
360344
let alloc_buffer ?old_buffer ~size_in_bytes _stream =
361345
Backend.alloc_buffer ?old_buffer ~size_in_bytes ()
362346

@@ -375,13 +359,13 @@ module Sync_backend (Backend : Backend_types.No_device_backend) = struct
375359
let get_used_memory CPU = Backend.get_used_memory ()
376360
let next_stream = ref 0
377361

378-
let new_stream CPU =
362+
type nonrec stream = (buffer_ptr, event, device, unit, unit) stream [@@deriving sexp_of]
363+
364+
let new_stream CPU : stream =
379365
Int.incr next_stream;
380-
{
381-
unique_name = "stream " ^ Int.to_string (!next_stream - 1);
382-
merge_buffer = ref None;
383-
allocated_buffer = None;
384-
}
366+
make_stream ~device:CPU ~state:()
367+
~unique_name:("stream " ^ Int.to_string (!next_stream - 1))
368+
~runner:()
385369

386370
type code = Backend.code [@@deriving sexp_of]
387371
type code_batch = Backend.code_batch [@@deriving sexp_of]
@@ -399,6 +383,7 @@ module Sync_backend (Backend : Backend_types.No_device_backend) = struct
399383
let ctx_arrays context = ctx_arrays context.ctx
400384

401385
type nonrec routine = context Backend_types.Types.routine [@@deriving sexp_of]
386+
(** This overrides the routine type from [Backend]. *)
402387

403388
let init stream = { stream; ctx = Backend.init name }
404389
let initialize = Backend.initialize

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,7 @@ type device = {
4242
}
4343
[@@deriving sexp_of]
4444

45-
type stream = {
46-
device : device;
47-
cu_stream : Cu.Stream.t;
48-
unique_name : string;
49-
mutable merge_buffer : (buffer_ptr * Tn.t) option;
50-
}
51-
[@@deriving sexp_of]
45+
type nonrec stream = (buffer_ptr, event, device, unit, Cu.Stream.t) stream [@@deriving sexp_of]
5246

5347
type context = {
5448
label : string;
@@ -68,10 +62,10 @@ type context = {
6862
let ctx_arrays ctx = ctx.ctx_arrays
6963
let global_config = ref For_parallel_copying
7064
let is_done event = Cu.Delimited_event.query event
71-
let will_wait_for context event = Cu.Delimited_event.wait context.stream.cu_stream event
65+
let will_wait_for context event = Cu.Delimited_event.wait context.stream.runner event
7266
let sync event = Cu.Delimited_event.synchronize event
73-
let all_work stream = Cu.Delimited_event.record stream.cu_stream
74-
let scheduled_merge_node stream = Option.map ~f:snd stream.merge_buffer
67+
let all_work stream = Cu.Delimited_event.record stream.runner
68+
let scheduled_merge_node stream = Option.map ~f:snd !(stream.merge_buffer)
7569

7670
let is_initialized, initialize =
7771
let initialized = ref false in
@@ -172,7 +166,7 @@ let%track3_sexp new_stream (device : device) : stream =
172166
(* Strange that we need ctx_set_current even with a single device! *)
173167
set_ctx device.primary_context;
174168
let cu_stream = Cu.Stream.create ~non_blocking:true () in
175-
{ device; cu_stream; unique_name; merge_buffer = None }
169+
make_stream ~device ~state:() ~unique_name ~runner:cu_stream
176170

177171
let cuda_properties =
178172
let cache =
@@ -199,10 +193,10 @@ let get_name stream = stream.unique_name
199193

200194
let await stream : unit =
201195
set_ctx stream.device.primary_context;
202-
Cu.Stream.synchronize stream.cu_stream;
196+
Cu.Stream.synchronize stream.runner;
203197
Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ())
204198

205-
let is_idle stream = Cu.Stream.is_ready stream.cu_stream
199+
let is_idle stream = Cu.Stream.is_ready stream.runner
206200

207201
let%track3_sexp finalize (ctx : context) : unit =
208202
if
@@ -235,23 +229,23 @@ let init stream =
235229

236230
let from_host ~dst_ptr ~dst hosted =
237231
set_ctx dst.ctx;
238-
let f src = Cu.Stream.memcpy_H_to_D ~dst:dst_ptr ~src dst.stream.cu_stream in
232+
let f src = Cu.Stream.memcpy_H_to_D ~dst:dst_ptr ~src dst.stream.runner in
239233
Ndarray.map { f } hosted
240234

241235
let to_host ~src_ptr ~src hosted =
242236
set_ctx src.ctx;
243-
let f dst = Cu.Stream.memcpy_D_to_H ~dst ~src:src_ptr src.stream.cu_stream in
237+
let f dst = Cu.Stream.memcpy_D_to_H ~dst ~src:src_ptr src.stream.runner in
244238
Ndarray.map { f } hosted
245239

246240
let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
247241
let same_device = dst.stream.device.ordinal = src.stream.device.ordinal in
248242
let memcpy ~dst_ptr =
249243
if same_device then
250244
Cu.Stream.memcpy_D_to_D ~size_in_bytes:(Tn.size_in_bytes tn) ~dst:dst_ptr ~src:src_ptr
251-
dst.stream.cu_stream
245+
dst.stream.runner
252246
else
253247
Cu.Stream.memcpy_peer ~size_in_bytes:(Tn.size_in_bytes tn) ~dst:dst_ptr ~dst_ctx:dst.ctx
254-
~src:src_ptr ~src_ctx:src.ctx dst.stream.cu_stream
248+
~src:src_ptr ~src_ctx:src.ctx dst.stream.runner
255249
in
256250
match (into_merge_buffer, dst_ptr) with
257251
| No, None -> invalid_arg "Cuda_backend.device_to_device: missing dst_ptr"
@@ -260,13 +254,13 @@ let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
260254
memcpy ~dst_ptr
261255
| Streaming, _ ->
262256
assert same_device;
263-
dst.stream.merge_buffer <- Some (src_ptr, tn)
257+
dst.stream.merge_buffer := Some (src_ptr, tn)
264258
| Copy, _ ->
265259
set_ctx dst.ctx;
266260
let size_in_bytes = Tn.size_in_bytes tn in
267261
opt_alloc_merge_buffer ~size_in_bytes dst.stream.device;
268262
memcpy ~dst_ptr:dst.stream.device.copy_merge_buffer;
269-
dst.stream.merge_buffer <- Some (dst.stream.device.copy_merge_buffer, tn)
263+
dst.stream.merge_buffer := Some (dst.stream.device.copy_merge_buffer, tn)
270264

271265
type code = {
272266
traced_store : Low_level.traced_store;
@@ -463,7 +457,7 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx
463457
S.Tensor arr
464458
| _name, Log_file_name -> S.Int log_id
465459
| _name, Merge_buffer ->
466-
let ptr = fst @@ Option.value_exn ~here:[%here] context.stream.merge_buffer in
460+
let ptr = fst @@ Option.value_exn ~here:[%here] !(context.stream.merge_buffer) in
467461
S.Tensor ptr
468462
| _name, Static_idx s ->
469463
let i = Indexing.find_exn lowered_bindings s in
@@ -492,8 +486,7 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx
492486
[%log_block
493487
context.label;
494488
Utils.log_trace_tree _output]);
495-
S.launch_kernel func ~grid_dim_x:1 ~block_dim_x:1 ~shared_mem_bytes:0 context.stream.cu_stream
496-
args;
489+
S.launch_kernel func ~grid_dim_x:1 ~block_dim_x:1 ~shared_mem_bytes:0 context.stream.runner args;
497490
[%log "kernel launched"]
498491
in
499492
( context,

0 commit comments

Comments
 (0)