Skip to content

Commit 772bea0

Browse files
committed
Synchronize all streams of a device, with cleanup; landmarks
1 parent 324bfc2 commit 772bea0

File tree

11 files changed

+168
-107
lines changed

11 files changed

+168
-107
lines changed

CHANGES.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
- Interface files for `Backends` and `Low_level`.
66
- Fixed #245: tracking of used memory.
7-
- TODO: stream-to-stream synchronization functionality, with lazy per-tensor-node synchronization.
7+
- Stream-to-stream synchronization functionality, with lazy per-tensor-node synchronization.
8+
- TODO: Automatic blocking on access of a host array when a scheduled `to_host` transfer has not finished.
89

910
### Changed
1011

@@ -19,7 +20,8 @@
1920
- Got rid of `subordinal`.
2021
- Removed dependency on `core`, broke up dependency on `ppx_jane`.
2122
- Huge refactoring of backend internal interfaces and API (not repeating same code).
22-
- TODO: Built per-tensor-node stream-to-stream synchronization into copying functions, removed obsolete blocking synchronizations.
23+
- Built per-tensor-node stream-to-stream synchronization into copying functions.
24+
- Re-introduced whole-device blocking synchronization, which now is just a slight optimization as it also cleans up event book-keeping.
2325

2426
### Fixed
2527

arrayjit/lib/backend_impl.ml

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -101,26 +101,27 @@ struct
101101
{
102102
dev;
103103
ordinal;
104-
latest_stream_id = -1;
105104
released = Atomic.make false;
106105
cross_stream_candidates = Hashtbl.create (module Tnode);
107106
owner_stream = Hashtbl.create (module Tnode);
108107
shared_writer_streams = Hashtbl.create (module Tnode);
109108
host_reading_streams = Hashtbl.create (module Tnode);
110109
host_writing_streams = Hashtbl.create (module Tnode);
110+
streams = Utils.weak_create ();
111111
}
112112

113-
let make_stream device runner ~stream_id =
114-
{
115-
device;
116-
runner;
117-
merge_buffer = ref None;
118-
stream_id;
119-
allocated_buffer = None;
120-
updating_for = Hashtbl.create (module Tnode);
121-
updating_for_merge_buffer = None;
122-
reader_streams = Hashtbl.create (module Tnode);
123-
}
113+
let make_stream device runner =
114+
Utils.register_new device.streams ~grow_by:8 (fun stream_id ->
115+
{
116+
device;
117+
runner;
118+
merge_buffer = ref None;
119+
stream_id;
120+
allocated_buffer = None;
121+
updating_for = Hashtbl.create (module Tnode);
122+
updating_for_merge_buffer = None;
123+
reader_streams = Hashtbl.create (module Tnode);
124+
})
124125

125126
let get_name stream = [%string "%{name}:%{stream.device.ordinal#Int}:%{stream.stream_id#Int}"]
126127

arrayjit/lib/backend_intf.ml

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ end
8282
type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
8383
dev : 'dev;
8484
ordinal : int;
85-
mutable latest_stream_id : int;
8685
released : Utils.atomic_bool;
8786
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
8887
owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Hashtbl.M(Tnode).t;
@@ -92,6 +91,7 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
9291
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
9392
host_writing_streams :
9493
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
94+
mutable streams : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Utils.weak_dynarray;
9595
}
9696

9797
and ('buffer_ptr, 'dev, 'runner, 'event) stream_ref = {
@@ -114,7 +114,6 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device =
114114
('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
115115
dev : 'dev;
116116
ordinal : int;
117-
mutable latest_stream_id : int;
118117
released : Utils.atomic_bool;
119118
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
120119
(** Freshly created arrays that might be shared across streams. The map can both grow and
@@ -136,6 +135,7 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device =
136135
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
137136
(** The streams that most recently have been writing to a node's on-host array. The completed
138137
events are removed opportunistically. *)
138+
mutable streams : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Utils.weak_dynarray; (** The streams of this device, held weakly so they can be reclaimed; used e.g. to synchronize the whole device. *)
139139
}
140140
[@@deriving sexp_of]
141141

@@ -147,7 +147,7 @@ type ('buffer_ptr, 'dev, 'runner, 'event) stream =
147147
(** Depending on backend implementations, either the currently used merge buffer, or the one
148148
most recently scheduled. Note that the pointer can be reused for nodes that fit in an
149149
already allocated buffer. *)
150-
stream_id : int; (** An ID unique within the device. *)
150+
stream_id : int; (** An ID unique within the device for the lifetime of the stream. *)
151151
mutable allocated_buffer : 'buffer_ptr buffer option;
152152
updating_for : 'event Hashtbl.M(Tnode).t;
153153
(* The completion event for the most recent updating (writing to) a node via this stream. *)
@@ -188,7 +188,7 @@ module type Device = sig
188188
include Alloc_buffer with type buffer_ptr := buffer_ptr and type stream := stream
189189

190190
val make_device : dev -> ordinal:int -> device
191-
val make_stream : device -> runner -> stream_id:int -> stream
191+
val make_stream : device -> runner -> stream
192192

193193
val make_context : ?ctx_arrays:ctx_arrays -> stream -> context
194194
(** Returns a context without a parent. *)
@@ -291,6 +291,7 @@ module type Backend_device_common = sig
291291
end
292292

293293
module type With_buffer_retrieval_and_syncing = sig
294+
type device
294295
type context
295296
type event
296297

@@ -318,6 +319,9 @@ module type With_buffer_retrieval_and_syncing = sig
318319
buffer, and initializes the merge buffer's streaming event.
319320
- If [into_merge_buffer=Copy], schedules copying from [src] to the merge buffer of [dst]'s
320321
stream, and updates the writer event for the merge buffer. *)
322+
323+
val sync_device : device -> unit
324+
(** Synchronizes all the streams on a device, and cleans up (removes) all associated events. *)
321325
end
322326

323327
module type Backend = sig
@@ -331,5 +335,9 @@ module type Backend = sig
331335
(** Returns the routines for the procedures included in the code batch. The returned context is
332336
downstream of all the returned routines. *)
333337

334-
include With_buffer_retrieval_and_syncing with type context := context and type event := event
338+
include
339+
With_buffer_retrieval_and_syncing
340+
with type device := device
341+
and type context := context
342+
and type event := event
335343
end

arrayjit/lib/backends.ml

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ let check_merge_buffer stream ~code_node =
2222
^ ", expected by code: " ^ name code_node)
2323

2424
module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing) = struct
25-
let wait_for_all ctx streams tn =
25+
let[@landmark] wait_for_all ctx streams tn =
2626
let s = ctx.stream in
2727
Hashtbl.update_and_return streams tn
2828
~f:
@@ -31,15 +31,15 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
3131
|> List.iter ~f:(fun (work_stream, e) ->
3232
if not (equal_stream work_stream s) then Backend.will_wait_for ctx e)
3333

34-
let wait_for_ready ~dst ~src tn =
34+
let[@landmark] wait_for_ready ~dst ~src tn =
3535
let s = src.stream in
3636
let d = dst.stream in
3737
(* TODO: maybe it's worthwhile to clean up s.updating_for every now and then. *)
3838
Hashtbl.find s.updating_for tn
3939
|> Option.iter ~f:(fun upd_e ->
4040
if not (equal_stream s d || Backend.is_done upd_e) then Backend.will_wait_for dst upd_e)
4141

42-
let update_writer_event ?e ?from s tn =
42+
let[@landmark] update_writer_event ?e ?from s tn =
4343
let e = Option.value_or_thunk e ~default:(fun () -> Backend.all_work s) in
4444
let f l = (s, e) :: Option.value ~default:[] l in
4545
(match (from, tn) with
@@ -52,13 +52,14 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
5252
| Node tn ->
5353
if Tn.potentially_cross_stream tn then
5454
Hashtbl.update s.device.shared_writer_streams tn ~f:(fun l ->
55-
(s, e) :: Option.value ~default:[] l);
55+
(s, e) :: Option.value ~default:[] l)
56+
else Hashtbl.remove s.device.shared_writer_streams tn;
5657
Hashtbl.update s.updating_for tn ~f:(fun _ -> e)
5758
| Merge_buffer tn ->
5859
(* Note: the previous event does not need to be done! *)
5960
s.updating_for_merge_buffer <- Some (tn, Some e)
6061

61-
let%diagn2_l_sexp from_host (ctx : Backend.context) tn =
62+
let%track2_l_sexp[@landmark] from_host (ctx : Backend.context) tn =
6263
match (tn, Map.find ctx.ctx_arrays tn) with
6364
| { Tn.array = (lazy (Some hosted)); _ }, Some dst ->
6465
wait_for_all ctx ctx.stream.reader_streams tn;
@@ -68,7 +69,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
6869
true
6970
| _ -> false
7071

71-
let%diagn2_l_sexp to_host (ctx : Backend.context) (tn : Tn.t) =
72+
let%track2_l_sexp[@landmark] to_host (ctx : Backend.context) (tn : Tn.t) =
7273
match (tn, Map.find ctx.ctx_arrays tn) with
7374
| { Tn.array = (lazy (Some hosted)); _ }, Some src ->
7475
if Tn.potentially_cross_stream tn then
@@ -82,8 +83,8 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
8283
true
8384
| _ -> false
8485

85-
let%diagn2_l_sexp device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : Backend.context)
86-
~(src : Backend.context) =
86+
let%diagn2_l_sexp[@landmark] device_to_device (tn : Tn.t) ~into_merge_buffer
87+
~(dst : Backend.context) ~(src : Backend.context) =
8788
let ordinal_of ctx = ctx.stream.device.ordinal in
8889
let name_of ctx = Backend.(get_name ctx.stream) in
8990
let same_device = ordinal_of dst = ordinal_of src in
@@ -115,30 +116,40 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
115116
Backend.(
116117
device_to_device tn ~into_merge_buffer ~dst_ptr:None ~dst ~src_ptr:s_arr ~src);
117118
dst.stream.updating_for_merge_buffer <- Some (tn, None);
118-
Task.run task;
119+
let[@landmark] merge_task () = Task.run task in
120+
merge_task ();
119121
update_writer_event ~from:(`Src src.stream) dst.stream @@ Merge_buffer tn;
120122
[%log "streaming into merge buffer", Tn.debug_name tn, "from", name_of src];
121123
true)
122124

123-
let%track3_l_sexp sync_routine r =
125+
let%track2_l_sexp sync_routine r =
124126
let s = r.context.stream in
125-
let pre () =
126-
Hashtbl.filter_mapi_inplace s.device.shared_writer_streams ~f:(fun ~key ~data ->
127-
if Tn.potentially_cross_stream key then
128-
if Set.mem r.inputs key then (
129-
let data = List.filter data ~f:(fun (_, e) -> Backend.is_done e) in
130-
List.iter data ~f:(fun (work_stream, e) ->
131-
if not (equal_stream work_stream s) then Backend.will_wait_for r.context e);
132-
Some data)
133-
else Some data
134-
else None)
127+
let[@landmark] pre () =
128+
Set.iter r.inputs ~f:(fun tn ->
129+
if Tn.potentially_cross_stream tn then
130+
Option.iter (Hashtbl.find s.device.shared_writer_streams tn) ~f:(fun data ->
131+
let data = List.filter data ~f:(fun (_, e) -> not (Backend.is_done e)) in
132+
Hashtbl.set s.device.shared_writer_streams ~key:tn ~data;
133+
List.iter data ~f:(fun (work_stream, e) ->
134+
if not (equal_stream work_stream s) then Backend.will_wait_for r.context e))
135+
else Hashtbl.remove s.device.shared_writer_streams tn)
135136
(* Since merge buffers are always per-stream, no need to check r.merge_buffer_input. *)
136137
in
137-
let post () =
138+
let[@landmark] post () =
138139
let e = Backend.all_work s in
139140
Set.iter r.outputs ~f:(fun tn -> update_writer_event ~e s @@ Node tn)
140141
in
141142
{ r with schedule = Task.(prepend ~work:pre @@ append ~work:post r.schedule) }
143+
144+
let[@landmark] sync_device device =
145+
Utils.weak_iter device.streams ~f:Backend.await;
146+
Hashtbl.clear device.host_writing_streams;
147+
Hashtbl.clear device.host_reading_streams;
148+
Hashtbl.clear device.shared_writer_streams;
149+
Utils.weak_iter device.streams ~f:(fun s ->
150+
Hashtbl.clear s.reader_streams;
151+
s.updating_for_merge_buffer <- None;
152+
Hashtbl.clear s.updating_for)
142153
end
143154

144155
let lower_assignments ?name bindings asgns =
@@ -268,20 +279,20 @@ module Add_device
268279
in
269280
(Option.value_exn ~here:[%here] bindings, schedules)
270281

271-
let from_host ~dst_ptr ~dst hosted =
282+
let[@landmark] from_host ~dst_ptr ~dst hosted =
272283
let work () = host_to_buffer hosted ~dst:dst_ptr in
273284
(* TODO: pass description to from_host. *)
274285
schedule_task dst.stream
275286
(Task.Task
276287
{ context_lifetime = dst; description = "from_host on " ^ get_name dst.stream; work })
277288

278-
let to_host ~src_ptr ~src hosted =
289+
let[@landmark] to_host ~src_ptr ~src hosted =
279290
let work () = buffer_to_host hosted ~src:src_ptr in
280291
(* TODO: pass description to to_host. *)
281292
schedule_task src.stream
282293
(Task.Task { context_lifetime = src; description = "to_host on " ^ get_name src.stream; work })
283294

284-
let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
295+
let[@landmark] device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
285296
let s = dst.stream in
286297
let size_in_bytes = Tnode.size_in_bytes tn in
287298
let work =
@@ -468,7 +479,7 @@ let reinitialize (module Backend : Backend) config =
468479
Stdlib.Gc.full_major ();
469480
Backend.initialize config)
470481

471-
let%track3_sexp finalize (type buffer_ptr dev runner event)
482+
let[@landmark] finalize (type buffer_ptr dev runner event)
472483
(module Backend : Backend
473484
with type buffer_ptr = buffer_ptr
474485
and type dev = dev

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
1313
let () =
1414
Cu.cuda_call_hook :=
1515
Some
16-
(fun ~message ~status ->
16+
(fun ~message:_message ~status:_status ->
1717
[%debug_l_sexp
1818
[%log5_block
19-
message;
20-
if not @@ Cu.is_success status then [%log (status : Cu.result)]]])
19+
_message;
20+
if not @@ Cu.is_success _status then [%log (_status : Cu.result)]]])
2121

2222
let _suspended () =
2323
Cu.cuda_call_hook := Some (fun ~message ~status:_ -> Stdlib.Printf.printf "CUDA %s\n" message)
@@ -149,11 +149,10 @@ let%track3_sexp get_device ~(ordinal : int) : device =
149149
if Atomic.get result.released then default () else result
150150

151151
let%track3_sexp new_stream (device : device) : stream =
152-
device.latest_stream_id <- device.latest_stream_id + 1;
153152
(* Strange that we need ctx_set_current even with a single device! *)
154153
set_ctx device.dev.primary_context;
155154
let cu_stream = Cu.Stream.create ~non_blocking:true () in
156-
make_stream device cu_stream ~stream_id:device.latest_stream_id
155+
make_stream device cu_stream
157156

158157
let cuda_properties =
159158
let cache =
@@ -173,24 +172,24 @@ let suggested_num_streams device =
173172
| For_parallel_copying -> 1 + (cuda_properties device).async_engine_count
174173
| Most_parallel_streams -> (cuda_properties device).multiprocessor_count
175174

176-
let await stream : unit =
175+
let[@landmark] await stream : unit =
177176
set_ctx stream.device.dev.primary_context;
178177
Cu.Stream.synchronize stream.runner;
179178
Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ())
180179

181180
let is_idle stream = Cu.Stream.is_ready stream.runner
182181

183-
let from_host ~dst_ptr ~dst hosted =
182+
let[@landmark] from_host ~dst_ptr ~dst hosted =
184183
set_ctx @@ ctx_of dst;
185184
let f src = Cu.Stream.memcpy_H_to_D ~dst:dst_ptr ~src dst.stream.runner in
186185
Ndarray.map { f } hosted
187186

188-
let to_host ~src_ptr ~src hosted =
187+
let[@landmark] to_host ~src_ptr ~src hosted =
189188
set_ctx @@ ctx_of src;
190189
let f dst = Cu.Stream.memcpy_D_to_H ~dst ~src:src_ptr src.stream.runner in
191190
Ndarray.map { f } hosted
192191

193-
let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
192+
let[@landmark] device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
194193
let dev = dst.stream.device in
195194
let same_device = dev.ordinal = src.stream.device.ordinal in
196195
let size_in_bytes = Tn.size_in_bytes tn in
@@ -248,6 +247,8 @@ let%diagn2_sexp cuda_to_ptx ~name cu_src =
248247
let options =
249248
"--use_fast_math" :: (if Utils.with_runtime_debug () then [ "--device-debug" ] else [])
250249
in
250+
(* FIXME: every now and then the compilation crashes because the options are garbled. *)
251+
(* Stdio.printf "PTX options %s\n%!" @@ String.concat ~sep:", " options; *)
251252
let ptx = Cu.Nvrtc.compile_to_ptx ~cu_src ~name:name_cu ~options ~with_debug in
252253
if Utils.settings.output_debug_files_in_build_directory then (
253254
let oc = Out_channel.open_text @@ Utils.build_file @@ name ^ ".ptx" in

arrayjit/lib/dune

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
ctypes
1313
ctypes.foreign
1414
saturn_lockfree
15+
landmarks
1516
(select
1617
gcc_backend.ml
1718
from
@@ -31,6 +32,7 @@
3132
ppx_sexp_conv
3233
ppx_string
3334
ppx_variants_conv
35+
landmarks-ppx
3436
ppx_minidebug))
3537
(modules
3638
utils

0 commit comments

Comments
 (0)