Skip to content

Commit e4b82ab

Browse files
committed
Finalize the transition to using local debug runtimes
1 parent e245c50 commit e4b82ab

File tree

8 files changed

+130
-225
lines changed

8 files changed

+130
-225
lines changed

arrayjit/lib/backends.ml

Lines changed: 71 additions & 127 deletions
Large diffs are not rendered by default.

arrayjit/lib/cc_backend.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@ let alloc_buffer ?old_buffer ~size_in_bytes () =
4141
| Some (_old_ptr, _old_size) -> assert false
4242
| None -> assert false
4343

44-
let to_buffer ?rt:_ tn ~dst ~src =
44+
let to_buffer tn ~dst ~src =
4545
let src = Map.find_exn src.arrays tn in
4646
Ndarray.map2 { f2 = Ndarray.A.blit } src dst
4747

48-
let host_to_buffer ?rt:_ src ~dst = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
49-
let buffer_to_host ?rt:_ dst ~src = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
48+
let host_to_buffer src ~dst = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
49+
let buffer_to_host dst ~src = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
5050
let unsafe_cleanup () = Stdlib.Gc.compact ()
5151

5252
let is_initialized, initialize =

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 26 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -192,39 +192,30 @@ let unsafe_cleanup () =
192192
done;
193193
Core.Weak.fill !devices 0 len None
194194

195-
let%diagn_sexp from_host ?(rt : (module Minidebug_runtime.Debug_runtime) option) (ctx : context) tn
196-
=
195+
let%diagn_l_sexp from_host (ctx : context) tn =
197196
match (tn, Map.find ctx.global_arrays tn) with
198197
| { Tn.array = (lazy (Some hosted)); _ }, Some dst ->
199198
set_ctx ctx.ctx;
200-
(if Utils.settings.with_debug_level > 0 then
201-
let module Debug_runtime =
202-
(val Option.value_or_thunk rt ~default:(fun () -> (module Debug_runtime)))
203-
in
204-
[%log "copying", Tn.debug_name tn, "to", (dst : ctx_array), "from host"]);
199+
if Utils.settings.with_debug_level > 0 then
200+
[%log "copying", Tn.debug_name tn, "to", (dst : ctx_array), "from host"];
205201
let f src = Cudajit.memcpy_H_to_D_async ~dst ~src ctx.device.stream in
206202
Ndarray.map { f } hosted;
207203
true
208204
| _ -> false
209205

210-
let%track_sexp to_host ?(rt : (module Minidebug_runtime.Debug_runtime) option) (ctx : context)
211-
(tn : Tn.t) =
206+
let%track_l_sexp to_host (ctx : context) (tn : Tn.t) =
212207
match (tn, Map.find ctx.global_arrays tn) with
213208
| { Tn.array = (lazy (Some hosted)); _ }, Some src ->
214209
set_ctx ctx.ctx;
215-
(if Utils.settings.with_debug_level > 0 then
216-
let module Debug_runtime =
217-
(val Option.value_or_thunk rt ~default:(fun () ->
218-
(module Debug_runtime : Minidebug_runtime.Debug_runtime)))
219-
in
220-
[%log "copying", Tn.debug_name tn, "at", (src : ctx_array), "to host"]);
210+
if Utils.settings.with_debug_level > 0 then
211+
[%log "copying", Tn.debug_name tn, "at", (src : ctx_array), "to host"];
221212
let f dst = Cudajit.memcpy_D_to_H_async ~dst ~src ctx.device.stream in
222213
Ndarray.map { f } hosted;
223214
true
224215
| _ -> false
225216

226-
let%track_sexp rec device_to_device ?(rt : (module Minidebug_runtime.Debug_runtime) option)
227-
(tn : Tn.t) ~into_merge_buffer ~(dst : context) ~(src : context) =
217+
let%track_l_sexp rec device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : context)
218+
~(src : context) =
228219
let memcpy ~d_arr ~s_arr =
229220
if phys_equal dst.device.physical src.device.physical then
230221
Cudajit.memcpy_D_to_D_async ~size_in_bytes:(Tn.size_in_bytes tn) ~dst:d_arr ~src:s_arr
@@ -243,46 +234,34 @@ let%track_sexp rec device_to_device ?(rt : (module Minidebug_runtime.Debug_runti
243234
| Some d_arr ->
244235
set_ctx dst.ctx;
245236
memcpy ~d_arr ~s_arr;
246-
(if Utils.settings.with_debug_level > 0 then
247-
let module Debug_runtime =
248-
(val Option.value_or_thunk rt ~default:(fun () ->
249-
(module Debug_runtime : Minidebug_runtime.Debug_runtime)))
250-
in
251-
[%log
252-
"copied",
253-
Tn.debug_name tn,
254-
"from",
255-
src.label,
256-
"at",
257-
(s_arr : ctx_array),
258-
"to",
259-
(d_arr : ctx_array)]);
237+
if Utils.settings.with_debug_level > 0 then
238+
[%log
239+
"copied",
240+
Tn.debug_name tn,
241+
"from",
242+
src.label,
243+
"at",
244+
(s_arr : ctx_array),
245+
"to",
246+
(d_arr : ctx_array)];
260247
true)
261248
| Streaming ->
262249
if phys_equal dst.device.physical src.device.physical then (
263250
dst.device.merge_buffer <- Some (s_arr, tn);
264-
(if Utils.settings.with_debug_level > 0 then
265-
let module Debug_runtime =
266-
(val Option.value_or_thunk rt ~default:(fun () ->
267-
(module Debug_runtime : Minidebug_runtime.Debug_runtime)))
268-
in
269-
[%log "using merge buffer for", Tn.debug_name tn, "from", src.label]);
251+
if Utils.settings.with_debug_level > 0 then
252+
[%log "using merge buffer for", Tn.debug_name tn, "from", src.label];
270253
true)
271254
else
272255
(* TODO: support proper streaming, but it might be difficult. *)
273-
device_to_device ?rt tn ~into_merge_buffer:Copy ~dst ~src
256+
device_to_device tn ~into_merge_buffer:Copy ~dst ~src
274257
| Copy ->
275258
set_ctx dst.ctx;
276259
let size_in_bytes = Tn.size_in_bytes tn in
277260
opt_alloc_merge_buffer ~size_in_bytes dst.device.physical;
278261
memcpy ~d_arr:dst.device.physical.copy_merge_buffer ~s_arr;
279262
dst.device.merge_buffer <- Some (dst.device.physical.copy_merge_buffer, tn);
280-
(if Utils.settings.with_debug_level > 0 then
281-
let module Debug_runtime =
282-
(val Option.value_or_thunk rt ~default:(fun () ->
283-
(module Debug_runtime : Minidebug_runtime.Debug_runtime)))
284-
in
285-
[%log "copied into merge buffer", Tn.debug_name tn, "from", src.label]);
263+
if Utils.settings.with_debug_level > 0 then
264+
[%log "copied into merge buffer", Tn.debug_name tn, "from", src.label];
286265
true)
287266

288267
type code = {
@@ -522,7 +501,7 @@ let%track_sexp link_batch prior_context (code_batch : code_batch) =
522501
in
523502
(context, lowered_bindings, procs)
524503

525-
let to_buffer ?rt:_ _tn ~dst:_ ~src:_ = failwith "CUDA low-level: NOT IMPLEMENTED YET"
526-
let host_to_buffer ?rt:_ _tn ~dst:_ = failwith "CUDA low-level: NOT IMPLEMENTED YET"
527-
let buffer_to_host ?rt:_ _tn ~src:_ = failwith "CUDA low-level: NOT IMPLEMENTED YET"
504+
let to_buffer _tn ~dst:_ ~src:_ = failwith "CUDA low-level: NOT IMPLEMENTED YET"
505+
let host_to_buffer _tn ~dst:_ = failwith "CUDA low-level: NOT IMPLEMENTED YET"
506+
let buffer_to_host _tn ~src:_ = failwith "CUDA low-level: NOT IMPLEMENTED YET"
528507
let get_buffer _tn _context = failwith "CUDA low-level: NOT IMPLEMENTED YET"

arrayjit/lib/cuda_backend.missing.ml

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@ let ctx_arrays Unimplemented_ctx = Map.empty (module Tnode)
1919

2020
let link (Unimplemented_ctx : context) (code : code) =
2121
let lowered_bindings = List.map ~f:(fun s -> (s, ref 0)) @@ Indexing.bound_symbols code in
22-
let task =
23-
Tnode.{ description = "CUDA missing: install cudajit"; work = (fun _debug_runtime () -> ()) }
24-
in
22+
let task = Tnode.{ description = "CUDA missing: install cudajit"; work = (fun () -> ()) } in
2523
((Unimplemented_ctx : context), lowered_bindings, task)
2624

2725
let link_batch (Unimplemented_ctx : context) (code_batch : code_batch) =
@@ -31,16 +29,14 @@ let link_batch (Unimplemented_ctx : context) (code_batch : code_batch) =
3129
in
3230
let task =
3331
Array.map code_batch ~f:(fun _ ->
34-
Some
35-
Tnode.
36-
{ description = "CUDA missing: install cudajit"; work = (fun _debug_runtime () -> ()) })
32+
Some Tnode.{ description = "CUDA missing: install cudajit"; work = (fun () -> ()) })
3733
in
3834
((Unimplemented_ctx : context), lowered_bindings, task)
3935

4036
let unsafe_cleanup () = ()
41-
let from_host ?rt:_ _context _tn = false
42-
let to_host ?rt:_ _context _tn = false
43-
let device_to_device ?rt:_ _tn ~into_merge_buffer:_ ~dst:_ ~src:_ = false
37+
let from_host _context _tn = false
38+
let to_host _context _tn = false
39+
let device_to_device _tn ~into_merge_buffer:_ ~dst:_ ~src:_ = false
4440

4541
type device = Unimplemented_dev [@@deriving sexp_of]
4642
type physical_device = Unimplemented_phys_dev [@@deriving sexp_of]
@@ -58,7 +54,7 @@ let get_ctx_device Unimplemented_ctx = Unimplemented_dev
5854
let get_name Unimplemented_dev : string = failwith "CUDA missing: install cudajit"
5955
let to_ordinal _device = 0
6056
let to_subordinal _device = 0
61-
let to_buffer ?rt:_ _tn ~dst:_ ~src:_ = failwith "CUDA missing: install cudajit"
62-
let host_to_buffer ?rt:_ _tn ~dst:_ = failwith "CUDA missing: install cudajit"
63-
let buffer_to_host ?rt:_ _tn ~src:_ = failwith "CUDA missing: install cudajit"
57+
let to_buffer _tn ~dst:_ ~src:_ = failwith "CUDA missing: install cudajit"
58+
let host_to_buffer _tn ~dst:_ = failwith "CUDA missing: install cudajit"
59+
let buffer_to_host _tn ~src:_ = failwith "CUDA missing: install cudajit"
6460
let get_buffer _tn _context = failwith "CUDA missing: install cudajit"

arrayjit/lib/cuda_backend.mli

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,32 +28,21 @@ val link_batch :
2828

2929
val unsafe_cleanup : unit -> unit
3030

31-
val from_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> context -> Tnode.t -> bool
31+
val from_host : context -> Tnode.t -> bool
3232
(** If the array is both hosted and in-context, copies from host to context. *)
3333

34-
val to_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> context -> Tnode.t -> bool
34+
val to_host : context -> Tnode.t -> bool
3535
(** If the array is both hosted and in-context, copies from context to host. *)
3636

3737
val device_to_device :
38-
?rt:(module Minidebug_runtime.Debug_runtime) ->
39-
Tnode.t ->
40-
into_merge_buffer:merge_buffer_use ->
41-
dst:context ->
42-
src:context ->
43-
bool
38+
Tnode.t -> into_merge_buffer:merge_buffer_use -> dst:context -> src:context -> bool
4439
(** If the array is in both contexts, copies from [src] to [dst]. *)
4540

4641
type buffer_ptr [@@deriving sexp_of]
4742

48-
val to_buffer :
49-
?rt:(module Minidebug_runtime.Debug_runtime) -> Tnode.t -> dst:buffer_ptr -> src:context -> unit
50-
51-
val host_to_buffer :
52-
?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> dst:buffer_ptr -> unit
53-
54-
val buffer_to_host :
55-
?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> src:buffer_ptr -> unit
56-
43+
val to_buffer : Tnode.t -> dst:buffer_ptr -> src:context -> unit
44+
val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
45+
val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
5746
val get_buffer : Tnode.t -> context -> buffer_ptr option
5847

5948
type physical_device

arrayjit/lib/gcc_backend.gccjit.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,12 @@ type context = {
4747

4848
let ctx_arrays context = context.arrays
4949

50-
let to_buffer ?rt:_ tn ~dst ~src =
50+
let to_buffer tn ~dst ~src =
5151
let src = Map.find_exn src.arrays tn in
5252
Ndarray.map2 { f2 = Ndarray.A.blit } src dst
5353

54-
let host_to_buffer ?rt:_ src ~dst = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
55-
let buffer_to_host ?rt:_ dst ~src = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
54+
let host_to_buffer src ~dst = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
55+
let buffer_to_host dst ~src = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
5656

5757
let unsafe_cleanup () =
5858
let open Gccjit in

arrayjit/lib/gcc_backend.missing.ml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ let expected_merge_node Unimplemented_proc =
1515

1616
let is_in_context _node = failwith "gcc backend missing: install the optional dependency gccjit"
1717

18-
let to_buffer ?rt:_ _tn ~dst:_ ~src:_ =
18+
let to_buffer _tn ~dst:_ ~src:_ =
1919
failwith "gcc backend missing: install the optional dependency gccjit"
2020

21-
let host_to_buffer ?rt:_ _src ~dst:_ =
21+
let host_to_buffer _src ~dst:_ =
2222
failwith "gcc backend missing: install the optional dependency gccjit"
2323

24-
let buffer_to_host ?rt:_ _dst ~src:_ =
24+
let buffer_to_host _dst ~src:_ =
2525
failwith "gcc backend missing: install the optional dependency gccjit"
2626

2727
let alloc_buffer ?old_buffer:_ ~size_in_bytes:_ () =
@@ -35,13 +35,13 @@ let compile_batch ~names:_ ~opt_ctx_arrays:_ _bindings _codes =
3535
let link_compiled ~merge_buffer:_ Unimplemented_ctx Unimplemented_proc =
3636
failwith "gcc backend missing: install the optional dependency gccjit"
3737

38-
let from_host ?rt:_ Unimplemented_ctx _tn =
38+
let from_host Unimplemented_ctx _tn =
3939
failwith "gcc backend missing: install the optional dependency gccjit"
4040

41-
let to_host ?rt:_ Unimplemented_ctx _tn =
41+
let to_host Unimplemented_ctx _tn =
4242
failwith "gcc backend missing: install the optional dependency gccjit"
4343

44-
let device_to_device ?rt:_ _tn ~into_merge_buffer:_ ~dst:_ ~src:_ =
44+
let device_to_device _tn ~into_merge_buffer:_ ~dst:_ ~src:_ =
4545
failwith "gcc backend missing: install the optional dependency gccjit"
4646

4747
let physical_merge_buffers = false

arrayjit/lib/writing_a_backend.md

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ Currently, OCANNL integrates new backends via code in [Backends](backends.ml), s
2222
```ocaml
2323
type lowered_bindings = (static_symbol, int ref) List.Assoc.t (* in indexing.ml *)
2424
25-
type task = Work of ((module Debug_runtime) -> unit -> unit) (* in tnode.ml *)
25+
type task =
26+
| Task : { context_lifetime : 'a; description : string; work : unit -> unit; } -> task (* in tnode.ml *)
2627
2728
type 'context routine = {
2829
context : 'context;
@@ -253,33 +254,29 @@ module type No_device_backend = sig
253254
...
254255
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> unit -> buffer_ptr
255256
...
256-
val to_buffer :
257-
?rt:(module Minidebug_runtime.Debug_runtime) -> Tnode.t -> dst:buffer_ptr -> src:context -> unit
257+
val to_buffer : Tnode.t -> dst:buffer_ptr -> src:context -> unit
258258
259-
val host_to_buffer :
260-
?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> dst:buffer_ptr -> unit
259+
val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
261260
262-
val buffer_to_host :
263-
?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> src:buffer_ptr -> unit
261+
val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
264262
265263
val get_buffer : Tnode.t -> context -> buffer_ptr option
266264
end
267265
module type Backend = sig
268266
include No_device_backend
269267
...
270268
271-
val from_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> context -> Tnode.t -> bool
269+
val from_host : context -> Tnode.t -> bool
272270
(** If the array is both hosted and in-context, schedules a copy from host to context and returns
273271
true, otherwise returns false. NOTE: when run for a device, it's the caller's responsibility
274272
to synchronize the device before the host's data is overwritten. *)
275273
276-
val to_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> context -> Tnode.t -> bool
274+
val to_host : context -> Tnode.t -> bool
277275
(** If the array is both hosted and in-context, schedules a copy from context to host and returns
278276
true, otherwise returns false. NOTE: when run for a device, it's the caller's responsibility
279277
to synchronize the device before the host's data is read. *)
280278
281279
val device_to_device :
282-
?rt:(module Minidebug_runtime.Debug_runtime) ->
283280
Tnode.t ->
284281
into_merge_buffer:merge_buffer_use ->
285282
dst:context ->

0 commit comments

Comments
 (0)