
Commit d16b54f

(1) Get rid of the option to share merge buffers, (2) refactor tracking of merge buffer events
-- formerly `~into_merge_buffer:Streaming` would not generate an event, but it should, to prevent overwriting the source. (2) will be continued by prohibiting overwriting until the routine using the streamed merge buffer finishes.
1 parent b9987fa commit d16b54f

File tree: 8 files changed (+65, -73 lines changed)

arrayjit/lib/anatomy_of_a_backend.md

Lines changed: 3 additions & 4 deletions
@@ -144,9 +144,6 @@ When using the default stream, CUDA would predictably write to the standard outp
 OCANNL expects backends to implement FIFO queue scheduling, and an event mechanism for synchronizing between streams (and ideally devices), matching the CUDA specification. On top of events, OCANNL implements per-tensor-node synchronization. 1/3rd of the `device` fields have to do with synchronization:
 
 ```ocaml
-  mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option;
-  (** The tensor node that was most recently scheduled to be in the cross-stream merge buffer,
-      and its readiness event. *)
   shared_writer_streams :
     (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
   (** The streams that most recently have been scheduled to update (write to) a
@@ -162,7 +159,7 @@ OCANNL expects backends to implement FIFO queue scheduling, and an event mechani
      events are removed opportunistically. *)
 ```
 
-and 1/3rd of the stream fields also:
+and some stream fields also:
 
 ```ocaml
   updating_for : 'event Hashtbl.M(Tnode).t;
@@ -175,6 +172,8 @@ and 1/3rd of the stream fields also:
      removed opportunistically. *)
 ```
 
+While we never share merge buffers across streams, there is always an event associated with an occupied merge buffer. Its primary use is for tracking the merge buffer's stream as a reader on the source stream.
+
 Besides routines, calling `from_host`, `to_host`, `device_to_device` from a backend puts the corresponding tasks on the device's queue. Both invoking a routine and calling these copying functions will perform the necessary event creations and synchronizations to ensure that when scheduling writing into an array precedes scheduling reading from it, the actual writing also precedes the actual reading.
 
 ### Data transfers
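
The invariant documented above (an occupied, per-stream merge buffer always carries a readiness event, and that same event registers the consuming stream as a reader on the source stream) can be illustrated with a minimal, self-contained OCaml sketch. The types and the `schedule_merge_copy` function below are simplified stand-ins, not OCANNL's actual API.

```ocaml
(* Illustrative stand-ins only: simplified models of the bookkeeping described above. *)
type event = { mutable is_done : bool }

type stream = {
  mutable updating_for_merge_buffer : (string * event) option;
      (* the node occupying this stream's merge buffer, with its readiness event *)
  mutable reader_streams : (string * event) list;
      (* streams still scheduled to read data produced on this stream *)
}

(* Scheduling a copy into [dst]'s merge buffer records one event in two places:
   on [dst], to know what occupies the buffer and when it becomes ready, and on
   [src], so the source knows [dst] still reads from it and must not be overwritten. *)
let schedule_merge_copy ~(src : stream) ~(dst : stream) ~dst_name node =
  let e = { is_done = false } in
  dst.updating_for_merge_buffer <- Some (node, e);
  src.reader_streams <- (dst_name, e) :: src.reader_streams;
  e
```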

arrayjit/lib/backend_impl.ml

Lines changed: 3 additions & 6 deletions
@@ -101,8 +101,6 @@ struct
     {
       dev;
       ordinal;
-      shared_merge_buffer = None;
-      scheduled_shared_merge_node = None;
       latest_stream_id = -1;
       released = Atomic.make false;
       cross_stream_candidates = Hashtbl.create (module Tnode);
@@ -117,7 +115,6 @@
       device;
       runner;
       merge_buffer = ref None;
-      scheduled_merge_node = None;
       stream_id;
       allocated_buffer = None;
       updating_for = Hashtbl.create (module Tnode);
@@ -202,7 +199,7 @@ module type No_buffer_retrieval_or_syncing = sig
   (** Like {!Backend.from_host}, but without synchronization and buffer retrieval. *)
 
   val to_host : src_ptr:buffer_ptr -> src:context -> Ndarray.t -> unit
-  (** Like {!Backend.to_host}, but without synchronization and buffer retrieval. *)
+  (** Like {!Backend.to_host}, but without synchronization events and buffer retrieval. *)
 
   val device_to_device :
     Tnode.t ->
@@ -212,8 +209,8 @@ module type No_buffer_retrieval_or_syncing = sig
     src_ptr:buffer_ptr ->
     src:context ->
     unit
-  (** Like {!Backend.device_to_device}, but without synchronization and buffer retrieval. Raises
-      [Invalid_argument] if [into_merge_buffer = No] and [dst_ptr = None]. *)
+  (** Like {!Backend.device_to_device}, but without synchronization events and buffer retrieval.
+      Raises [Invalid_argument] if [into_merge_buffer = No] and [dst_ptr = None]. *)
 end
 
 (** An intermediate stage for converting {!Lowered_no_device_backend} backends into

arrayjit/lib/backend_intf.ml

Lines changed: 6 additions & 14 deletions
@@ -57,7 +57,7 @@ type 'context routine = {
   inputs : Set.M(Tnode).t;
   (** The materialized read-only and read-before-write (within the routine) non-constant nodes.
       They are inputs in a broad sense, as they could be recurrent nodes or parameters. *)
-  merge_buffer_input : bool; (** Similar to {!field-inputs}, for the merge buffer. *)
+  merge_buffer_input : Tnode.t option; (** Similar to {!field-inputs}, for the merge buffer. *)
   outputs : Set.M(Tnode).t; (** All the materialized nodes written-to by the routine. *)
 }
 [@@deriving sexp_of]
@@ -82,8 +82,6 @@ end
 type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
   dev : 'dev;
   ordinal : int;
-  mutable shared_merge_buffer : 'buffer_ptr buffer option;
-  mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option;
   mutable latest_stream_id : int;
   released : Utils.atomic_bool;
   cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
@@ -100,7 +98,6 @@ and ('buffer_ptr, 'dev, 'runner, 'event) stream_ref = {
   device : ('buffer_ptr, 'dev, 'runner, 'event) device_ref;
   runner : 'runner;
   merge_buffer : 'buffer_ptr buffer option ref;
-  mutable scheduled_merge_node : Tnode.t option;
   stream_id : int;
   mutable allocated_buffer : 'buffer_ptr buffer option;
   updating_for : 'event Hashtbl.M(Tnode).t;
@@ -117,11 +114,6 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device =
   ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
   dev : 'dev;
   ordinal : int;
-  mutable shared_merge_buffer : 'buffer_ptr buffer option;
-  (** Depending on backend implementations, either the currently used cross-stream merge buffer,
-      or the one most recently scheduled. *)
-  mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option;
-  (** The tensor node that was most recently scheduled to be in the cross-stream merge buffer. *)
   mutable latest_stream_id : int;
   released : Utils.atomic_bool;
   cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
@@ -153,15 +145,15 @@ type ('buffer_ptr, 'dev, 'runner, 'event) stream =
   runner : 'runner;
   merge_buffer : 'buffer_ptr buffer option ref;
   (** Depending on backend implementations, either the currently used merge buffer, or the one
-      most recently scheduled. *)
-  mutable scheduled_merge_node : Tnode.t option;
-  (** The tensor node that was most recently scheduled to be in the [stream]'s merge buffer. *)
+      most recently scheduled. Note that the pointer can be reused for nodes that fit in an
+      already allocated buffer. *)
   stream_id : int; (** An ID unique within the device. *)
   mutable allocated_buffer : 'buffer_ptr buffer option;
   updating_for : 'event Hashtbl.M(Tnode).t;
-  (* The completion event for updating (writing to) a node via this stream, if any. *)
+  (* The completion event for the most recent updating (writing to) a node via this stream. *)
   mutable updating_for_merge_buffer : (Tnode.t * 'event) option;
-  (** Like {!field-updating_for}, but for the merge buffer. *)
+  (** The tensor node that was most recently scheduled to be in the [stream]'s merge buffer and
+      its updating completion event. See also {!field-updating_for}. *)
   reader_streams :
     (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
   (** The streams, other than this stream, that most recently have been reading from a node in
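
With `merge_buffer_input` now carrying the expected tensor node and `updating_for_merge_buffer` pairing the occupying node with its completion event, validation and synchronization can happen in one place. The sketch below is a simplified, hypothetical illustration with stand-in types, not the library's code.

```ocaml
(* Stand-in types; the real OCANNL types are parameterized over backends. *)
type tnode = { id : int; name : string }
type event = { is_done : unit -> bool; wait : unit -> unit }

type routine = { merge_buffer_input : tnode option }
type stream = { mutable updating_for_merge_buffer : (tnode * event) option }

(* Before running [r] on [s]: if the routine reads the merge buffer, the node
   recorded on the stream must match the expected one, and we wait until the
   task that filled the buffer has completed. *)
let check_and_wait (s : stream) (r : routine) =
  match (r.merge_buffer_input, s.updating_for_merge_buffer) with
  | None, _ -> ()
  | Some expected, Some (actual, e) when expected.id = actual.id ->
      if not (e.is_done ()) then e.wait ()
  | Some expected, occupant ->
      let got = match occupant with Some (tn, _) -> tn.name | None -> "none" in
      invalid_arg ("merge buffer mismatch: expected " ^ expected.name ^ ", got " ^ got)
```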

arrayjit/lib/backends.ml

Lines changed: 23 additions & 24 deletions
@@ -11,14 +11,15 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
 
 let check_merge_buffer stream ~code_node =
   let name = function Some tn -> Tnode.debug_name tn | None -> "none" in
-  match (stream.scheduled_merge_node, code_node) with
+  match (stream.updating_for_merge_buffer, code_node) with
   | _, None -> ()
-  | Some actual, Some expected when Tnode.equal actual expected -> ()
+  | Some (actual, _), Some expected when Tnode.equal actual expected -> ()
   | _ ->
       raise
       @@ Utils.User_error
-           ("Merge buffer mismatch, on stream: " ^ name stream.scheduled_merge_node
-           ^ ", expected by code: " ^ name code_node)
+           ("Merge buffer mismatch, on stream: "
+           ^ name (Option.map ~f:fst stream.updating_for_merge_buffer)
+           ^ ", expected by code: " ^ name code_node)
 
 module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing) = struct
   let wait_for_all ctx streams tn =
@@ -54,8 +55,7 @@ module Add_buffer_retrieval_and_syncin
             (s, e) :: Option.value ~default:[] l);
         Hashtbl.update s.updating_for tn ~f:(fun _ -> e)
     | Merge_buffer tn ->
-        Option.iter s.updating_for_merge_buffer ~f:(fun (_old_tn, old_e) ->
-            assert (Backend.is_done old_e));
+        (* Note: the previous event does not need to be done! *)
         s.updating_for_merge_buffer <- Some (tn, e)
 
   let%diagn2_l_sexp from_host (ctx : Backend.context) tn =
@@ -105,31 +105,32 @@ module Add_buffer_retrieval_and_syncin
          update_writer_event ~from:(`Src src.stream) dst.stream @@ Node tn;
          [%log "copying", Tn.debug_name tn, "from", name_of src, "to", name_of dst];
          true)
-      | Streaming when same_device ->
-          Backend.(
-            device_to_device tn ~into_merge_buffer ~dst_ptr:None ~dst ~src_ptr:s_arr ~src);
-          [%log "using merge buffer for", Tn.debug_name tn, "from", name_of src];
-          true
      | Copy | Streaming ->
          Backend.(
            device_to_device tn ~into_merge_buffer ~dst_ptr:None ~dst ~src_ptr:s_arr ~src);
          update_writer_event ~from:(`Src src.stream) dst.stream @@ Merge_buffer tn;
-          [%log "copying into merge buffer", Tn.debug_name tn, "from", name_of src];
+          let use =
+            match into_merge_buffer with
+            | Copy -> "copying"
+            | Streaming -> "streaming"
+            | No -> assert false
+          in
+          [%log use, "into merge buffer", Tn.debug_name tn, "from", name_of src];
          true)
 
  let%track3_l_sexp sync_routine r =
    let s = r.context.stream in
    let pre () =
-      Hashtbl.iteri s.device.shared_writer_streams ~f:(fun ~key ~data ->
-          if Set.mem r.inputs key then
-            List.iter data ~f:(fun (work_stream, e) ->
-                if not (equal_stream work_stream s) then Backend.will_wait_for r.context e));
-      if r.merge_buffer_input then
-        Option.iter s.device.scheduled_shared_merge_node ~f:(fun (shared_tn, e) ->
-            match (s.scheduled_merge_node, e) with
-            | Some merge_tn, Some e ->
-                if Tn.equal shared_tn merge_tn then Backend.will_wait_for r.context e
-            | _ -> ())
+      Hashtbl.filter_mapi_inplace s.device.shared_writer_streams ~f:(fun ~key ~data ->
+          if Tn.potentially_cross_stream key then
+            if Set.mem r.inputs key then (
+              let data = List.filter data ~f:(fun (_, e) -> Backend.is_done e) in
+              List.iter data ~f:(fun (work_stream, e) ->
+                  if not (equal_stream work_stream s) then Backend.will_wait_for r.context e);
+              Some data)
+            else Some data
+          else None)
+      (* Since merge buffers are always per-stream, no need to check r.merge_buffer_input. *)
    in
    let post () =
      let e = Backend.all_work s in
@@ -281,7 +282,6 @@ module Add_device
    let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
      let dev = dst.stream in
      let size_in_bytes = Tnode.size_in_bytes tn in
-      (* FIXME(#290): handle shared_merge_node. *)
      let work =
        (* TODO: log the operation if [Utils.settings.with_log_level > 1]. *)
        match (into_merge_buffer, dst_ptr) with
@@ -305,7 +305,6 @@ module Add_device
        "device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ get_name dev ^ " src "
        ^ get_name src.stream
      in
-      (match into_merge_buffer with No -> () | _ -> dev.scheduled_merge_node <- Some tn);
      schedule_task dev (Task.Task { context_lifetime = (src, dst); description; work })
 end
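
The core behavioral change in `backends.ml` is that a `Streaming` transfer into the merge buffer now records a writer event just like `Copy`, so the source cannot be overwritten while it is still being streamed from. Below is a simplified sketch of that dispatch; `copy`, `alias_source`, and `record_event` are hypothetical stand-ins for the backend operations, not real API.

```ocaml
(* Simplified sketch of the dispatch; all parameters are hypothetical stand-ins. *)
type into_merge_buffer = No | Streaming | Copy
type target = Node of string | Merge_buffer of string

let schedule_transfer ~copy ~alias_source ~record_event ~into_merge_buffer node =
  match into_merge_buffer with
  | No ->
      (* Ordinary device-to-device copy into a destination array. *)
      copy node;
      record_event (Node node)
  | Copy ->
      (* Physical copy into the destination stream's merge buffer. *)
      copy node;
      record_event (Merge_buffer node)
  | Streaming ->
      (* No physical copy: the merge buffer aliases the source pointer.
         The event is recorded all the same, which is what this commit changes. *)
      alias_source node;
      record_event (Merge_buffer node)
```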

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 8 additions & 15 deletions
@@ -101,22 +101,20 @@ let get_used_memory (device : device) =
   let free, total = Cudajit.Device.get_free_and_total_mem () in
   total - free
 
-let opt_alloc_merge_buffer ~size_in_bytes device : unit =
+let opt_alloc_merge_buffer ~size_in_bytes dev stream : unit =
   if
-    Option.value_map ~default:true device.shared_merge_buffer ~f:(fun buffer ->
+    Option.value_map ~default:true !(stream.merge_buffer) ~f:(fun buffer ->
         buffer.size_in_bytes < size_in_bytes)
   then (
-    set_ctx device.dev.primary_context;
-    Option.iter device.shared_merge_buffer ~f:(fun buffer -> Cu.Deviceptr.mem_free buffer.ptr);
-    device.shared_merge_buffer <-
-      Some { ptr = Cu.Deviceptr.mem_alloc ~size_in_bytes; size_in_bytes })
+    set_ctx dev.primary_context;
+    Option.iter !(stream.merge_buffer) ~f:(fun buffer -> Cu.Deviceptr.mem_free buffer.ptr);
+    stream.merge_buffer := Some { ptr = Cu.Deviceptr.mem_alloc ~size_in_bytes; size_in_bytes })
 
 let%track3_sexp cleanup_device (device : device) =
   Cu.Context.set_current device.dev.primary_context;
   Cu.Context.synchronize ();
   Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ());
   (* Note: this is not necessary as releasing the primary context by GC will reset the context. *)
-  Option.iter ~f:(fun buffer -> Cu.Deviceptr.mem_free buffer.ptr) device.shared_merge_buffer;
   Hashtbl.iter device.cross_stream_candidates ~f:(fun buffer_ptr ->
       Cu.Deviceptr.mem_free buffer_ptr)
 
@@ -138,8 +136,6 @@ let%track3_sexp get_device ~(ordinal : int) : device =
   if Utils.debug_log_from_routines () && not (Hash_set.mem initialized_devices ordinal) then
     Option.iter Utils.settings.cuda_printf_fifo_size ~f:Cu.Context.(set_limit PRINTF_FIFO_SIZE);
   Hash_set.add initialized_devices ordinal;
-  (* let size_in_bytes = 8 in let shared_merge_buffer = { ptr = Cu.Deviceptr.mem_alloc
-     ~size_in_bytes; size_in_bytes } in *)
   let result = make_device dev ~ordinal in
   Stdlib.Gc.finalise finalize_device result;
   Stdlib.Weak.set !devices ordinal (Some result);
@@ -209,16 +205,13 @@ let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
      memcpy ~dst_ptr
  | Streaming, _ ->
      assert same_device;
-      dst.stream.scheduled_merge_node <- Some tn;
      dst.stream.merge_buffer := Some { ptr = src_ptr; size_in_bytes }
  | Copy, _ ->
      set_ctx @@ ctx_of dst;
      let size_in_bytes = Tn.size_in_bytes tn in
-      opt_alloc_merge_buffer ~size_in_bytes dev;
-      (* FIXME(#290): why use the shared buffer? This should depend on the memory mode! *)
-      Option.iter dev.shared_merge_buffer ~f:(fun buffer -> memcpy ~dst_ptr:buffer.ptr);
-      dst.stream.scheduled_merge_node <- Some tn;
-      dst.stream.merge_buffer := dev.shared_merge_buffer
+      opt_alloc_merge_buffer ~size_in_bytes dev.dev dst.stream;
+      let buffer = Option.value_exn ~here:[%here] !(dst.stream.merge_buffer) in
+      memcpy ~dst_ptr:buffer.ptr
 
 type code = {
   traced_store : Low_level.traced_store;
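
On the CUDA side, `opt_alloc_merge_buffer` now manages a per-stream buffer that is reallocated only when the requested size exceeds the current allocation, which is why the same pointer can serve successive nodes (as noted in the `merge_buffer` doc comment). Below is a minimal sketch of that grow-only-when-too-small policy; `alloc` and `free` are hypothetical stand-ins for the CUDA allocation primitives.

```ocaml
(* Sketch of the reallocation policy; [alloc] and [free] are hypothetical
   stand-ins for device-pointer allocation and release. *)
type 'ptr buffer = { ptr : 'ptr; size_in_bytes : int }

let opt_alloc_buffer ~alloc ~free ~size_in_bytes merge_buffer =
  let too_small =
    match !merge_buffer with
    | None -> true
    | Some b -> b.size_in_bytes < size_in_bytes
  in
  if too_small then (
    (* Release the old allocation, if any, and grow to the requested size;
       an allocation that is already big enough is reused as-is. *)
    (match !merge_buffer with Some b -> free b.ptr | None -> ());
    merge_buffer := Some { ptr = alloc ~size_in_bytes; size_in_bytes })
```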

arrayjit/lib/low_level.ml

Lines changed: 10 additions & 2 deletions
@@ -199,7 +199,15 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
        ~f:(visit ~is_assigned:(traced.zeroed_out || Hash_set.mem traced.assignments at_pos))
  | Local_scope { body; _ } -> loop_proc env body
  | Get_local _ -> ()
-  | Get_global (Ops.Merge_buffer { source_node_id }, _) -> merge_node_id := Some source_node_id
+  | Get_global (Ops.Merge_buffer { source_node_id }, _) ->
+      Option.iter !merge_node_id ~f:(fun merge_node_id ->
+          if merge_node_id <> source_node_id then
+            raise
+            @@ Utils.User_error
+                 [%string
+                   "Low_level.optimize_proc: currently only one merge buffer per routine is \
+                    allowed, found node ids %{source_node_id#Int} and %{merge_node_id#Int}"]);
+      merge_node_id := Some source_node_id
  | Get_global _ -> ()
  | Embed_index _ -> ()
  | Binop (Arg1, llv1, _llv2) -> loop llv1
@@ -752,7 +760,7 @@ let input_and_output_nodes optimized =
        else outputs
      in
      (inputs, outputs)),
-    Option.is_some optimized.merge_node )
+    optimized.merge_node )
 
 let%diagn2_sexp optimize_proc static_indices llc =
   let traced_store = Hashtbl.create (module Tnode) in
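
The new check in `visit_llc` enforces that a routine references at most one merge-buffer source: the first `source_node_id` encountered is recorded, and any later, different id is a user error. A minimal stand-alone illustration of the same pattern follows (names are illustrative, and `failwith` stands in for `Utils.User_error`).

```ocaml
(* Record the single allowed merge-buffer source for a routine; a second,
   different source id raises. *)
let record_merge_node (merge_node_id : int option ref) ~source_node_id =
  (match !merge_node_id with
   | Some existing when existing <> source_node_id ->
       failwith
         (Printf.sprintf
            "only one merge buffer per routine is allowed, found node ids %d and %d"
            source_node_id existing)
   | _ -> ());
  merge_node_id := Some source_node_id
```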

arrayjit/lib/low_level.mli

Lines changed: 4 additions & 4 deletions
@@ -99,11 +99,11 @@ val optimize :
   t ->
   optimized
 
-val input_and_output_nodes : optimized -> (Set.M(Tnode).t * Set.M(Tnode).t) * bool
+val input_and_output_nodes : optimized -> (Set.M(Tnode).t * Set.M(Tnode).t) * Tnode.t option
 (** Inputs are the materialized read-only and read-before-write (within the code) non-constant
-    nodes. They are inputs in a broad sense, as they could be recurrent nodes or parameters.
-
-    Outputs are all the materialized nodes written-to by the code. *)
+    non-merge nodes. They are inputs in a broad sense, as they could be recurrent nodes or
+    parameters. Outputs are all the materialized nodes written-to by the code. The last returned
+    component is the input merge node, if used in the code. *)
 
 (** {2 Printing} *)
 
lib/train.ml

Lines changed: 8 additions & 4 deletions
@@ -512,6 +512,8 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
        assert (Backend.to_host sgd_update.context learning_rate.value);
        (* scalar_loss is not in the sgd_update context. *)
        assert (Backend.to_host grad_updates.(0).context scalar_loss.value);
+        (* TODO: syncing callbacks should be integrated into Tensor. *)
+        Backend.await grad_updates.(0).context.stream;
        let batch_loss = scalar_loss.@[0] in
        epoch_loss := !epoch_loss +. batch_loss;
        batch_losses := batch_loss :: !batch_losses;
@@ -530,14 +532,16 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
      Option.iter per_epoch_callback ~f:(fun f ->
          f ~at_step:!step_ref ~at_epoch:epoch ~learning_rate:learning_rate.@[0]
            ~epoch_loss:!epoch_loss);
-      let debug_at pos =
+      let _debug_at pos =
        Array.iter streams ~f:(fun s ->
            Stdlib.Format.printf "Stream %d debug %s:@ %a\n%!" s.stream_id pos Sexp.pp_hum
            @@ Backend.get_debug_info s)
      in
-      if per_epoch_debug_streams then debug_at "before sync";
-      Array.iter streams ~f:Backend.await;
-      if per_epoch_debug_streams then debug_at "after sync"
+      if per_epoch_debug_streams then _debug_at "before sync";
+      (* TODO: there should be nothing pending left to sync. *)
+      Array.iter streams ~f:Backend.await
+      (* This is now cleaned up by await. *)
+      (* if per_epoch_debug_streams then _debug_at "after sync" *)
    done;
    let%op model_result = model "infer" in
    let infer_fwd =
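
The `train.ml` change reflects that `Backend.to_host` only schedules the device-to-host copy: the stream must be awaited before the host-side scalar is read. A hypothetical sketch of the pattern, with stand-in parameters rather than the real training-loop API:

```ocaml
(* [to_host] only enqueues the copy on the stream's queue, so reading the host
   value is safe only after awaiting the stream. All parameters are stand-ins. *)
let read_host_scalar ~to_host ~await ~stream ~read tensor =
  assert (to_host tensor);  (* schedule the copy; asserts it was applicable *)
  await stream;             (* block until all queued work, including the copy, is done *)
  read tensor               (* the host-side array is now up to date *)
```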
