Commit 58b4a60

Rename await_ev -> sync; cuda backend: event API functions
(but not yet syncing on copy).
1 parent 112d458 · commit 58b4a60

4 files changed: +71 -50 lines changed


arrayjit/lib/backends.ml

Lines changed: 10 additions & 22 deletions
@@ -103,7 +103,7 @@ module type Backend = sig
       values are used internally for scheduling across devices of the backend, and can be used for
       explicit scheduling. *)

-  val await_ev : event -> unit
+  val sync : event -> unit
   (** Blocks till the event completes, if it's not done already. *)

   val is_done : event -> bool
@@ -113,9 +113,8 @@ module type Backend = sig
   (** If the tensor node is in the context, returns the event indicating if currently running or
       scheduled computations modifying that node on the context's device have completed.

-      NOTE: [work_for ctx tn], if work tracking was not registered for [tn], will register work
-      tracking for [tn] and return the event tracking all currently scheduled computations on
-      [ctx]'s device. *)
+      NOTE: [work_for ctx tn], if work tracking was not yet registered for [tn], will register work
+      tracking for [tn] and return the [all_work] event for [ctx]'s device. *)

   val will_wait_for : context -> event -> unit
   (** Schedules waiting for the given event on the context's device.
@@ -127,13 +126,13 @@ module type Backend = sig
   val from_host : context -> Tnode.t -> bool
   (** If the tensor node is both hosted and in-context, schedules a copy from host to context and
       returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
-      the device (via [await ctx.device] or [await_ev (work_for ctx tn)]) before the host's data is
+      the device (via [await ctx.device] or [sync (work_for ctx tn)]) before the host's data is
      overwritten. *)

   val to_host : context -> Tnode.t -> bool
   (** If the tensor node is both hosted and in-context, schedules a copy from context to host and
       returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
-      the device (via [await ctx.device] or [await_ev (work_for ctx tn)]) before the host's data is
+      the device (via [await ctx.device] or [sync (work_for ctx tn)]) before the host's data is
      read. *)

   val device_to_device :
@@ -205,7 +204,7 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
   type event = Not_implemented_yet  (** TODO: NOT IMPLEMENTED YET *)

   (** TODO: Blocks till the event completes, if it's not done already. *)
-  let await_ev Not_implemented_yet = ()
+  let sync Not_implemented_yet = ()

   (** TODO: Whether the event completed. *)
   let is_done Not_implemented_yet = true
@@ -553,7 +552,7 @@ module Sync_backend (Backend : No_device_backend) : Backend = struct
   type buffer_ptr = Backend.buffer_ptr [@@deriving sexp_of]
   type event = unit

-  let await_ev () = ()
+  let sync () = ()
   let is_done () = true
   let work_for _context _tn = Some ()
   let will_wait_for _context () = ()
@@ -954,6 +953,9 @@ module Cuda_backend : Backend = struct
   type nonrec context = { ctx : context; expected_merge_node : Tnode.t option } [@@deriving sexp_of]
   type nonrec routine = context routine [@@deriving sexp_of]

+  let work_for context = work_for context.ctx
+  let will_wait_for context = will_wait_for context.ctx
+
   let compile ?shared:_ ?name bindings asgns : code =
     let name, lowered = lower_assignments ?name bindings asgns in
     {
@@ -993,20 +995,6 @@ module Cuda_backend : Backend = struct
           name;
         })) )

-  type event = Cudajit.Event.t
-
-  let work_for _ctx _tn = Some (Cudajit.Event.create ())
-  (* TODO: NOT IMPLEMENTED YET *)
-
-  let is_done event = Cudajit.Event.query event
-  let will_wait_for _context _event = ()
-  (* Cudajit.Event.wait (get_ctx_device context.ctx).Cuda_backend.stream event *)
-  (* TODO: NOT IMPLEMENTED YET *)
-
-  let await_ev event = Cudajit.Event.synchronize event
-  let all_work _device = Cudajit.Event.create ()
-  (* TODO: NOT IMPLEMENTED YET *)
-
   let init device = { ctx = init device; expected_merge_node = None }
   let get_ctx_device context = get_ctx_device context.ctx
   let finalize context = finalize context.ctx
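
Note: the doc comments above spell out the caller-side contract for [to_host]/[from_host]. A minimal usage sketch with the renamed [sync], assuming a module [B] implementing the [Backend] signature and a context [ctx] that already holds the hosted tensor node [tn] (placeholder names, not part of this commit):

```ocaml
(* Assumes module B : Backend is in scope, with ctx : B.context holding tn. *)
let read_back (ctx : B.context) (tn : Tnode.t) : unit =
  if B.to_host ctx tn then
    (* [work_for] returns None when [tn] is not in the context; otherwise
       [sync] blocks until the scheduled copy (and any earlier work touching
       [tn]) has completed, after which the host array is safe to read. *)
    Option.iter (B.work_for ctx tn) ~f:B.sync
```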

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 47 additions & 27 deletions
@@ -33,9 +33,10 @@ type buffer_ptr = Cu.Deviceptr.t

 let sexp_of_buffer_ptr ptr = Sexp.Atom (Cu.Deviceptr.string_of ptr)

-type ctx_array = Cu.Deviceptr.t
+type event = Cu.Delimited_event.t

-let sexp_of_ctx_array = sexp_of_buffer_ptr
+type ctx_array = { ptr : buffer_ptr; mutable tracking : (event[@sexp.opaque]) option }
+[@@deriving sexp_of]

 type physical_device = {
   dev : (Cu.Device.t[@sexp.opaque]);
@@ -83,6 +84,19 @@ and context = {
 let ctx_arrays ctx = ctx.ctx_arrays
 let global_config = ref For_parallel_copying

+let work_for ctx tn =
+  match Map.find ctx.ctx_arrays tn with
+  | None -> None
+  | Some { tracking = Some event; _ } -> Some event
+  | Some ctx_array ->
+      ctx_array.tracking <- Some (Cu.Delimited_event.record ctx.device.stream);
+      ctx_array.tracking
+
+let is_done event = Cu.Delimited_event.query event
+let will_wait_for context event = Cu.Delimited_event.wait context.device.stream event
+let sync event = Cu.Delimited_event.synchronize event
+let all_work device = Cu.Delimited_event.record device.stream
+
 let is_initialized, initialize =
   let initialized = ref false in
   let init (config : config) : unit =
@@ -146,7 +160,7 @@ let get_device ~(ordinal : int) : physical_device =
       copy_merge_buffer;
       copy_merge_buffer_capacity;
       released = Atomic.make false;
-      cross_device_candidates = Hashtbl.create (module Tn);
+      cross_device_candidates = (Hashtbl.create (module Tn) : ctx_array Hashtbl.M(Tn).t);
       cross_device_shared = Hash_set.create (module Tn);
       non_cross_device = Hash_set.create (module Tn);
       owner_device_subordinal = Hashtbl.create (module Tn);
@@ -203,13 +217,13 @@ let finalize ctx =
   then (
     (* await does this: set_ctx ctx.device.physical.primary_context; *)
     await ctx.device;
+    (* Cudajit's contexts, streams and events are destroyed by their respective finalizers. *)
     Option.iter ctx.run_module ~f:Cu.Module.unload;
-    Map.iteri ctx.ctx_arrays ~f:(fun ~key ~data:ptr ->
+    Map.iteri ctx.ctx_arrays ~f:(fun ~key ~data ->
       if
         (not (Option.exists ctx.parent ~f:(fun pc -> Map.mem pc.ctx_arrays key)))
         && not (Hashtbl.mem ctx.device.physical.cross_device_candidates key)
-      then Cu.Deviceptr.mem_free ptr);
-    if Option.is_none ctx.parent then Cu.Stream.destroy ctx.device.stream)
+      then Cu.Deviceptr.mem_free data.ptr))

 let init device =
   let ctx =
@@ -236,7 +250,8 @@ let unsafe_cleanup () =
       Cu.Context.set_current device.primary_context;
       Cu.Context.synchronize ();
       Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ());
-      Hashtbl.iter device.cross_device_candidates ~f:Cu.Deviceptr.mem_free;
+      Hashtbl.iter device.cross_device_candidates ~f:(fun ctx_array ->
+          Cu.Deviceptr.mem_free ctx_array.ptr);
       Cu.Device.primary_ctx_release device.dev))
   done;
   Core.Weak.fill !devices 0 len None
@@ -246,7 +261,7 @@ let%diagn_l_sexp from_host (ctx : context) tn =
   | { Tn.array = (lazy (Some hosted)); _ }, Some dst ->
       set_ctx ctx.ctx;
       [%log "copying", Tn.debug_name tn, "to", (dst : ctx_array), "from host"];
-      let f src = Cu.Stream.memcpy_H_to_D ~dst ~src ctx.device.stream in
+      let f src = Cu.Stream.memcpy_H_to_D ~dst:dst.ptr ~src ctx.device.stream in
       Ndarray.map { f } hosted;
       true
   | _ -> false
@@ -256,7 +271,7 @@ let%track_l_sexp to_host (ctx : context) (tn : Tn.t) =
   | { Tn.array = (lazy (Some hosted)); _ }, Some src ->
       set_ctx ctx.ctx;
       [%log "copying", Tn.debug_name tn, "at", (src : ctx_array), "to host"];
-      let f dst = Cu.Stream.memcpy_D_to_H ~dst ~src ctx.device.stream in
+      let f dst = Cu.Stream.memcpy_D_to_H ~dst ~src:src.ptr ctx.device.stream in
       Ndarray.map { f } hosted;
       true
   | _ -> false
@@ -266,11 +281,11 @@ let%track_l_sexp rec device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : con
   let same_physical = phys_equal dst.device.physical src.device.physical in
   let memcpy ~d_arr ~s_arr =
     if same_physical then
-      Cu.Stream.memcpy_D_to_D ~size_in_bytes:(Tn.size_in_bytes tn) ~dst:d_arr ~src:s_arr
+      Cu.Stream.memcpy_D_to_D ~size_in_bytes:(Tn.size_in_bytes tn) ~dst:d_arr.ptr ~src:s_arr.ptr
         dst.device.stream
     else
-      Cu.Stream.memcpy_peer ~size_in_bytes:(Tn.size_in_bytes tn) ~dst:d_arr ~dst_ctx:dst.ctx
-        ~src:s_arr ~src_ctx:src.ctx dst.device.stream
+      Cu.Stream.memcpy_peer ~size_in_bytes:(Tn.size_in_bytes tn) ~dst:d_arr.ptr ~dst_ctx:dst.ctx
+        ~src:s_arr.ptr ~src_ctx:src.ctx dst.device.stream
   in
   if
     same_physical
@@ -300,7 +315,7 @@ let%track_l_sexp rec device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : con
       true)
   | Streaming ->
       if phys_equal dst.device.physical src.device.physical then (
-        dst.device.merge_buffer <- Some (s_arr, tn);
+        dst.device.merge_buffer <- Some (s_arr.ptr, tn);
         [%log "using merge buffer for", Tn.debug_name tn, "from", src.label];
         true)
       else
@@ -310,7 +325,7 @@ let%track_l_sexp rec device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : con
       set_ctx dst.ctx;
       let size_in_bytes = Tn.size_in_bytes tn in
       opt_alloc_merge_buffer ~size_in_bytes dst.device.physical;
-      memcpy ~d_arr:dst.device.physical.copy_merge_buffer ~s_arr;
+      memcpy ~d_arr:{ ptr = dst.device.physical.copy_merge_buffer; tracking = None } ~s_arr;
       dst.device.merge_buffer <- Some (dst.device.physical.copy_merge_buffer, tn);
       [%log "copied into merge buffer", Tn.debug_name tn, "from", src.label];
       true)
@@ -354,7 +369,8 @@ let%diagn_sexp cuda_to_ptx ~name cu_src =
     Stdio.Out_channel.flush oc;
     Stdio.Out_channel.close oc;
     let oc = Out_channel.open_text @@ Utils.build_file @@ name ^ ".cu_log" in
-    Stdio.Out_channel.output_string oc @@ Option.value_exn ~here:[%here] (Cu.Nvrtc.compilation_log ptx);
+    Stdio.Out_channel.output_string oc
+    @@ Option.value_exn ~here:[%here] (Cu.Nvrtc.compilation_log ptx);
     Stdio.Out_channel.flush oc;
     Stdio.Out_channel.close oc);
   ptx
@@ -486,7 +502,7 @@ let get_global_run_id =
   if !next_id < 0 then next_id := 0;
   !next_id

-let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx_arrays
+let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx_arrays traced_store
     lowered_bindings run_module =
   let module Cu = Cudajit in
   let func = Cu.Module.get_function run_module ~name in
@@ -505,8 +521,8 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx
      prior_context.ctx_arrays? *)
   List.map params ~f:(function
     | _name, Param_ptr tn ->
-        let ptr = Option.value_exn ~here:[%here] @@ Map.find ctx_arrays tn in
-        S.Tensor ptr
+        let arr = Option.value_exn ~here:[%here] @@ Map.find ctx_arrays tn in
+        S.Tensor arr.ptr
     | _name, Log_file_name -> S.Int log_id
     | _name, Merge_buffer ->
         let ptr = fst @@ Option.value_exn ~here:[%here] context.device.merge_buffer in
@@ -533,15 +549,18 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx
     (* Map.iteri ctx_arrays ~f:(fun ~key ~data:ptr -> if key.Low_level.zero_initialized then
        Cu.Stream.memset_d8 ptr Unsigned.UChar.zero ~length:(Tn.size_in_bytes key.Low_level.tn)); *)
     [%log "launching the kernel"];
-    (* TODO: This doesn't help. *)
-    (* Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ()); *)
     (if Utils.debug_log_from_routines () then
        Utils.add_log_processor ~prefix:log_id_prefix @@ fun _output ->
        [%log_block
          context.label;
          Utils.log_trace_tree _output]);
-    S.launch_kernel func ~grid_dim_x:1 ~block_dim_x:1 ~shared_mem_bytes:0 context.device.stream
-      args;
+    S.launch_kernel func ~grid_dim_x:1 ~block_dim_x:1 ~shared_mem_bytes:0 context.device.stream args;
+    Map.iteri ctx_arrays ~f:(fun ~key ~data ->
+        (* Note: a tensor node can only be a context array if it is materialized. *)
+        if Option.is_some data.tracking then
+          let traced = Low_level.get_node traced_store key in
+          if not traced.read_only then
+            data.tracking <- Some (Cu.Delimited_event.record context.device.stream));
     [%log "kernel launched"]
   in
   ( context,
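
Note: the post-launch recording above completes the lazy scheme started by [work_for] at the top of the file. A condensed model of the tracking lifecycle, with simplified types and a hypothetical [record_event] helper (not the backend's actual code):

```ocaml
(* - tracking is None when the array is allocated;
   - the first work_for query records an event on the device's stream, stores
     it, and thereby opts the array into tracking;
   - after each later kernel launch that may write the array (not read_only),
     link_proc overwrites tracking with an event recorded right after launch;
   - sync / is_done / will_wait_for then operate on that latest event. *)
type 'event tracked = { mutable tracking : 'event option }

let refresh_after_launch ~record_event ~read_only (arr : _ tracked) =
  if Option.is_some arr.tracking && not read_only then
    arr.tracking <- Some (record_event ())
```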
@@ -557,7 +576,8 @@ let%diagn_sexp alloc_if_needed ctx device ~key ~data:node ctx_arrays =
   [%log Tn.debug_name key, "read_only", (node.read_only : bool)];
   let default () =
     set_ctx ctx;
-    Cu.Deviceptr.mem_alloc ~size_in_bytes:(Tn.size_in_bytes key)
+    let ptr = Cu.Deviceptr.mem_alloc ~size_in_bytes:(Tn.size_in_bytes key) in
+    { ptr; tracking = None }
   in
   let add_new () = Map.add_exn ctx_arrays ~key ~data:(default ()) in
   let physical = device.physical in
@@ -600,8 +620,8 @@ let%track3_sexp link prior_context (code : code) : context * _ * _ =
   let idx_params = Indexing.bound_symbols code.bindings in
   let lowered_bindings : Indexing.lowered_bindings = List.map idx_params ~f:(fun s -> (s, ref 0)) in
   let context, task =
-    link_proc ~prior_context ~name:code.name ~params:code.params ~ctx_arrays lowered_bindings
-      run_module
+    link_proc ~prior_context ~name:code.name ~params:code.params ~ctx_arrays code.traced_store
+      lowered_bindings run_module
   in
   (context, lowered_bindings, task)

@@ -622,8 +642,8 @@ let%track3_sexp link_batch prior_context (code_batch : code_batch) : context * _
           ~f:(alloc_if_needed ctx prior_context.device)
       in
       let context, task =
-        link_proc ~prior_context:context ~name ~params ~ctx_arrays lowered_bindings
-          run_module
+        link_proc ~prior_context:context ~name ~params ~ctx_arrays traced_store
+          lowered_bindings run_module
       in
       ((context, ctx_arrays), Some task)))
   in
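
Note: the new primitives map one-to-one onto cudajit's Delimited_event API ([record], [query], [wait], [synchronize]). A rough sketch of ordering one virtual device after another's work on a tensor node, using only the functions defined in this file ([ctx_producer], [ctx_consumer] and [tn] are hypothetical):

```ocaml
(* Sketch only: relies on the work_for / will_wait_for / sync functions above. *)
let order_consumer_after_producer ctx_producer ctx_consumer tn =
  match work_for ctx_producer tn with
  | None -> () (* [tn] has no array in the producer's context. *)
  | Some ev ->
      (* Device-side ordering: the consumer's stream will not run subsequently
         scheduled work until the event has elapsed; the host is not blocked.
         For host-side blocking one would call [sync ev] instead. *)
      will_wait_for ctx_consumer ev
```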

arrayjit/lib/cuda_backend.missing.ml

Lines changed: 6 additions & 0 deletions
@@ -5,7 +5,12 @@ type context = Unimplemented_ctx [@@deriving sexp_of]
 type code = Indexing.unit_bindings [@@deriving sexp_of]
 type code_batch = Indexing.unit_bindings array [@@deriving sexp_of]
 type ctx_array = | [@@deriving sexp_of]
+type event = unit

+let sync () = ()
+let is_done () = true
+let work_for _ctx _tn = Some ()
+let will_wait_for _ctx () = ()
 let initialize (_config : Backend_utils.Types.config) = ()
 let is_initialized () = true
 let finalize _context = ()
@@ -58,6 +63,7 @@ let init Unimplemented_dev = Unimplemented_ctx
 let alloc_buffer ?old_buffer:_ ~size_in_bytes:_ Unimplemented_dev = Unimplemented_buffer_ptr
 let await _device = ()
 let is_idle _device = true
+let all_work _device = ()
 let get_device ~ordinal:_ = failwith "CUDA missing: install cudajit"
 let new_virtual_device Unimplemented_phys_dev = Unimplemented_dev
 let get_physical_device Unimplemented_dev = Unimplemented_phys_dev

arrayjit/lib/cuda_backend.mli

Lines changed: 8 additions & 1 deletion
@@ -3,7 +3,13 @@ open Base
 type context [@@deriving sexp_of]
 type code [@@deriving sexp_of]
 type code_batch [@@deriving sexp_of]
-type ctx_array
+type ctx_array [@@deriving sexp_of]
+type event
+
+val sync : event -> unit
+val is_done : event -> bool
+val work_for : context -> Tnode.t -> event option
+val will_wait_for : context -> event -> unit

 open Backend_utils.Types

@@ -52,6 +58,7 @@ val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> device -
 val init : device -> context
 val await : device -> unit
 val is_idle : device -> bool
+val all_work : device -> event
 val sexp_of_device : device -> Sexplib.Sexp.t
 val num_physical_devices : unit -> int
 val suggested_num_virtual_devices : physical_device -> int
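
Note: together with the existing [await] and [is_idle], the extended interface lets clients choose between device-wide and per-event synchronization. A hedged sketch against this .mli, assuming it is exposed as module [Cuda_backend] and that [dev] and [ctx] are an initialized device and a context on another virtual device (placeholder names):

```ocaml
let device_barrier (dev : Cuda_backend.device) (ctx : Cuda_backend.context) : unit =
  (* Block the host on everything scheduled on [dev] so far, comparable in
     effect to [await dev] but expressed through an event value: *)
  Cuda_backend.sync (Cuda_backend.all_work dev);
  (* Or, without blocking the host, only order [ctx]'s device after [dev]'s
     currently pending work: *)
  Cuda_backend.will_wait_for ctx (Cuda_backend.all_work dev)
```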
