Skip to content

Commit 9ca154d

Browse files
committed
Verifying merge nodes should happen at schedule time
1 parent 11102da commit 9ca154d

File tree

7 files changed: 54 additions and 39 deletions.

arrayjit/lib/backend_impl.ml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ struct
101101
dev;
102102
ordinal;
103103
shared_merge_buffer = None;
104+
scheduled_shared_merge_node = None;
104105
latest_stream_id = -1;
105106
released = Atomic.make false;
106107
cross_stream_candidates = Hashtbl.create (module Tnode);
@@ -113,6 +114,7 @@ struct
113114
device;
114115
runner;
115116
merge_buffer = ref None;
117+
scheduled_merge_node = None;
116118
stream_id;
117119
allocated_buffer = None;
118120
queried_work_for = Hashtbl.create (module Tnode);
@@ -176,7 +178,7 @@ module type Lowered_no_device_backend = sig
176178
linking the code. *)
177179

178180
val link_compiled :
179-
merge_buffer:(buffer_ptr * Tnode.t) option ref ->
181+
merge_buffer:buffer option ref ->
180182
runner_label:string ->
181183
ctx_arrays ->
182184
procedure ->

arrayjit/lib/backend_intf.ml

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ type ('buffer_ptr, 'dev, 'event) device = {
8282
dev : 'dev;
8383
ordinal : int;
8484
mutable shared_merge_buffer : 'buffer_ptr buffer option;
85+
(** Depending on backend implementations, either the currently used cross-stream merge buffer,
86+
or the one most recently scheduled. *)
87+
mutable scheduled_shared_merge_node : Tnode.t option;
88+
(** The tensor node that was most recently scheduled to be in the cross-stream merge buffer. *)
8589
mutable latest_stream_id : int;
8690
released : Utils.atomic_bool;
8791
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
@@ -99,7 +103,11 @@ type ('buffer_ptr, 'dev, 'event) device = {
99103
type ('buffer_ptr, 'dev, 'runner, 'event) stream = {
100104
device : ('buffer_ptr, 'dev, 'event) device;
101105
runner : 'runner;
102-
merge_buffer : ('buffer_ptr * Tnode.t) option ref;
106+
merge_buffer : 'buffer_ptr buffer option ref;
107+
(** Depending on backend implementations, either the currently used merge buffer, or the one
108+
most recently scheduled. *)
109+
mutable scheduled_merge_node : Tnode.t option;
110+
(** The tensor node that was most recently scheduled to be in the [stream]'s merge buffer. *)
103111
stream_id : int;
104112
mutable allocated_buffer : 'buffer_ptr buffer option;
105113
queried_work_for : 'event option Hashtbl.M(Tnode).t;
@@ -109,10 +117,6 @@ type ('buffer_ptr, 'dev, 'runner, 'event) stream = {
109117
}
110118
[@@deriving sexp_of]
111119

112-
(** [scheduled_merge_node stream] is the tensor node that would be in the [stream]'s merge buffer
113-
right after [await stream]. *)
114-
let scheduled_merge_node stream = Option.map ~f:snd !(stream.merge_buffer)
115-
116120
type ('buffer_ptr, 'stream) context = {
117121
stream : 'stream;
118122
parent : ('buffer_ptr, 'stream) context option;

arrayjit/lib/backends.ml

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
99
[%%global_debug_log_level 9]
1010
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
1111

12-
let check_merge_buffer ~scheduled_node ~code_node =
12+
let check_merge_buffer stream ~code_node =
1313
let name = function Some tn -> Tnode.debug_name tn | None -> "none" in
14-
match (scheduled_node, code_node) with
14+
match (stream.scheduled_merge_node, code_node) with
1515
| _, None -> ()
1616
| Some actual, Some expected when Tnode.equal actual expected -> ()
1717
| _ ->
1818
raise
1919
@@ Utils.User_error
20-
("Merge buffer mismatch, on stream: " ^ name scheduled_node ^ ", expected by code: "
21-
^ name code_node)
20+
("Merge buffer mismatch, on stream: " ^ name stream.scheduled_merge_node
21+
^ ", expected by code: " ^ name code_node)
2222

2323
module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing) = struct
2424
let work_for context tn =
@@ -177,11 +177,17 @@ module Add_device
177177
let link context (code : code) ctx_arrays =
178178
let runner_label = get_name context.stream in
179179
let merge_buffer = context.stream.merge_buffer in
180-
match code with
181-
| Postponed { lowered; bindings; name } ->
182-
let proc = Backend.compile ~name ~opt_ctx_arrays:(Some ctx_arrays) bindings lowered in
183-
link_compiled ~merge_buffer ~runner_label ctx_arrays proc
184-
| Compiled { proc; _ } -> link_compiled ~merge_buffer ~runner_label ctx_arrays proc
180+
let bindings, to_schedule =
181+
match code with
182+
| Postponed { lowered; bindings; name } ->
183+
let proc = Backend.compile ~name ~opt_ctx_arrays:(Some ctx_arrays) bindings lowered in
184+
link_compiled ~merge_buffer ~runner_label ctx_arrays proc
185+
| Compiled { proc; _ } -> link_compiled ~merge_buffer ~runner_label ctx_arrays proc
186+
in
187+
let schedule =
188+
Task.enschedule ~schedule_task ~get_stream_name:get_name context.stream to_schedule
189+
in
190+
(bindings, schedule)
185191

186192
let link_batch context (code_batch : code_batch) ctx_arrays =
187193
let runner_label = get_name context.stream in
@@ -196,8 +202,13 @@ module Add_device
196202
Array.fold_mapi procs ~init:None ~f:(fun i bindings -> function
197203
| Some proc ->
198204
let ctx_arrays = Option.value_exn ctx_arrays.(i) in
199-
let bindings', schedule = link_compiled ~merge_buffer ~runner_label ctx_arrays proc in
205+
let bindings', to_schedule =
206+
link_compiled ~merge_buffer ~runner_label ctx_arrays proc
207+
in
200208
Option.iter bindings ~f:(fun bindings -> assert (phys_equal bindings bindings'));
209+
let schedule =
210+
Task.enschedule ~schedule_task ~get_stream_name:get_name context.stream to_schedule
211+
in
201212
(Some bindings', Some schedule)
202213
| None -> (bindings, None))
203214
in
@@ -219,12 +230,13 @@ module Add_device
219230
let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
220231
let dev = dst.stream in
221232
let size_in_bytes = Tnode.size_in_bytes tn in
233+
(* FIXME(#290): handle shared_merge_node. *)
222234
let work =
223-
(* TODO: log the operation if [Utils.settings.with_log_level > 0]. *)
235+
(* TODO: log the operation if [Utils.settings.with_log_level > 1]. *)
224236
match (into_merge_buffer, dst_ptr) with
225237
| No, None -> invalid_arg "Multicore_scheduler.device_to_device: missing dst_ptr"
226238
| No, Some dst_ptr -> fun () -> buffer_to_buffer ~dst:dst_ptr ~src:src_ptr ~size_in_bytes
227-
| Streaming, _ -> fun () -> dev.merge_buffer := Some (src_ptr, tn)
239+
| Streaming, _ -> fun () -> dev.merge_buffer := Some { ptr = src_ptr; size_in_bytes }
228240
| Copy, _ ->
229241
fun () ->
230242
let size_in_bytes = Tnode.size_in_bytes tn in
@@ -235,13 +247,14 @@ module Add_device
235247
dev.allocated_buffer <-
236248
Some (alloc_buffer ?old_buffer:dev.allocated_buffer ~size_in_bytes dst.stream);
237249
let merge_ptr = (Option.value_exn dev.allocated_buffer).ptr in
238-
dev.merge_buffer := Some (merge_ptr, tn);
250+
dev.merge_buffer := dev.allocated_buffer;
239251
buffer_to_buffer ~dst:merge_ptr ~src:src_ptr ~size_in_bytes
240252
in
241253
let description =
242254
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ get_name dev ^ " src "
243255
^ get_name src.stream
244256
in
257+
(match into_merge_buffer with No -> () | _ -> dev.scheduled_merge_node <- Some tn);
245258
schedule_task dev (Task.Task { context_lifetime = (src, dst); description; work })
246259
end
247260

@@ -344,9 +357,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
344357
let context = make_child ~ctx_arrays context in
345358
let schedule =
346359
Task.prepend schedule ~work:(fun () ->
347-
check_merge_buffer
348-
~scheduled_node:(scheduled_merge_node context.stream)
349-
~code_node:code.expected_merge_node)
360+
check_merge_buffer context.stream ~code_node:code.expected_merge_node)
350361
in
351362
{ context; schedule; bindings; name = code.name; inputs; outputs }
352363

@@ -372,9 +383,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
372383
in
373384
let schedule =
374385
Task.prepend schedule ~work:(fun () ->
375-
check_merge_buffer
376-
~scheduled_node:(scheduled_merge_node context.stream)
377-
~code_node:expected_merge_node)
386+
check_merge_buffer context.stream ~code_node:expected_merge_node)
378387
in
379388
(context, Some { context; schedule; bindings; name; inputs; outputs }))
380389
end

arrayjit/lib/cc_backend.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro
164164
| bs, Log_file_name :: ps ->
165165
Param_1 (ref (Some log_file_name), link bs ps Ctypes.(string @-> cs))
166166
| bs, Merge_buffer :: ps ->
167-
let get_ptr (ptr, _tn) = ptr in
167+
let get_ptr buf = buf.ptr in
168168
Param_2f (get_ptr, merge_buffer, link bs ps Ctypes.(ptr void @-> cs))
169169
| bs, Param_ptr tn :: ps ->
170170
let c_ptr =

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -194,13 +194,13 @@ let to_host ~src_ptr ~src hosted =
194194
let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
195195
let dev = dst.stream.device in
196196
let same_device = dev.ordinal = src.stream.device.ordinal in
197+
let size_in_bytes = Tn.size_in_bytes tn in
197198
let memcpy ~dst_ptr =
198199
if same_device then
199-
Cu.Stream.memcpy_D_to_D ~size_in_bytes:(Tn.size_in_bytes tn) ~dst:dst_ptr ~src:src_ptr
200-
dst.stream.runner
200+
Cu.Stream.memcpy_D_to_D ~size_in_bytes ~dst:dst_ptr ~src:src_ptr dst.stream.runner
201201
else
202-
Cu.Stream.memcpy_peer ~size_in_bytes:(Tn.size_in_bytes tn) ~dst:dst_ptr ~dst_ctx:(ctx_of dst)
203-
~src:src_ptr ~src_ctx:(ctx_of src) dst.stream.runner
202+
Cu.Stream.memcpy_peer ~size_in_bytes ~dst:dst_ptr ~dst_ctx:(ctx_of dst) ~src:src_ptr
203+
~src_ctx:(ctx_of src) dst.stream.runner
204204
in
205205
match (into_merge_buffer, dst_ptr) with
206206
| No, None -> invalid_arg "Cuda_backend.device_to_device: missing dst_ptr"
@@ -209,14 +209,16 @@ let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
209209
memcpy ~dst_ptr
210210
| Streaming, _ ->
211211
assert same_device;
212-
dst.stream.merge_buffer := Some (src_ptr, tn)
212+
dst.stream.scheduled_merge_node <- Some tn;
213+
dst.stream.merge_buffer := Some { ptr = src_ptr; size_in_bytes }
213214
| Copy, _ ->
214215
set_ctx @@ ctx_of dst;
215216
let size_in_bytes = Tn.size_in_bytes tn in
216217
opt_alloc_merge_buffer ~size_in_bytes dev;
217-
(* FIXME: why use the shared buffer? *)
218+
(* FIXME(#290): why use the shared buffer? This should depend on the memory mode! *)
218219
Option.iter dev.shared_merge_buffer ~f:(fun buffer -> memcpy ~dst_ptr:buffer.ptr);
219-
dst.stream.merge_buffer := Option.map dev.shared_merge_buffer ~f:(fun buf -> (buf.ptr, tn))
220+
dst.stream.scheduled_merge_node <- Some tn;
221+
dst.stream.merge_buffer := dev.shared_merge_buffer
220222

221223
type code = {
222224
traced_store : Low_level.traced_store;
@@ -272,7 +274,6 @@ struct
272274
let procs = Input.procs
273275
let hardcoded_context_ptr = None
274276
let use_host_memory = use_host_memory
275-
276277
let logs_to_stdout = true
277278
let main_kernel_prefix = "extern \"C\" __global__"
278279

@@ -402,8 +403,8 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx
402403
S.Tensor arr
403404
| _name, Log_file_name -> S.Int log_id
404405
| _name, Merge_buffer ->
405-
let ptr = fst @@ Option.value_exn ~here:[%here] !(stream.merge_buffer) in
406-
S.Tensor ptr
406+
let buf = Option.value_exn ~here:[%here] !(stream.merge_buffer) in
407+
S.Tensor buf.ptr
407408
| _name, Static_idx s ->
408409
let i = Indexing.find_exn lowered_bindings s in
409410
if !i < 0 then

arrayjit/lib/gcc_backend.gccjit.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro
678678
in
679679
Param_2 (ref (Some c_ptr), link bs ps Ctypes.(ptr void @-> cs))
680680
| bs, Merge_buffer :: ps ->
681-
let get_ptr (ptr, _tn) = ptr in
681+
let get_ptr buf = buf.Backend_intf.ptr in
682682
Param_2f (get_ptr, merge_buffer, link bs ps Ctypes.(ptr void @-> cs))
683683
in
684684
(* Folding by [link] above reverses the input order. Important: [code.bindings] are traversed

arrayjit/lib/task.ml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,9 @@ let%track3_l_sexp enschedule ~schedule_task ~get_stream_name stream
3131
(Task { description; _ } as task) =
3232
[%log_result "enschedule", description, "on", get_stream_name stream];
3333
let work () = schedule_task stream task in
34-
(* TODO: keeping [task] in [context_lifetime] is redundant because it's captured by [work]. *)
3534
Task
3635
{
37-
context_lifetime = task;
36+
context_lifetime = ();
3837
description = "schedules {" ^ description ^ "} on " ^ get_stream_name stream;
3938
work;
4039
}

0 commit comments

Comments
 (0)