Skip to content

Commit 031fc20

Browse files
committed
In progress / broken: huge overhaul of backend internal interfaces and API
All code that could reasonably be shared is shared now. It will make it easier to generically apply things like buffer-to-buffer synchronization. Bumped to cudajit 0.6.0.
1 parent b3c8920 commit 031fc20

14 files changed

+349
-422
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
- Got rid of `unsafe_cleanup`.
1919
- Got rid of `subordinal`.
2020
- Removed dependency on `core`, broke up dependency on `ppx_jane`.
21+
- Huge refactoring of backend internal interfaces and API (not repeating same code).
2122
- TODO: Built per-tensor-node stream-to-stream synchronization into copying functions, removed obsolete blocking synchronizations.
2223

2324
### Fixed

arrayjit.opam

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ depends: [
3737
"odoc" {with-doc}
3838
]
3939
depopts: [
40-
"cudajit" {>= "0.5.1"}
40+
"cudajit" {>= "0.6.0"}
4141
"gccjit" {>= "0.3.2"}
4242
]
4343
conflicts: [
44-
"cudajit" {< "0.5.1"}
44+
"cudajit" {< "0.6.0"}
4545
"gccjit" {< "0.3.2"}
4646
]
4747
build: [

arrayjit/lib/backend_impl.ml

Lines changed: 56 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@ open Backend_intf
1616
module type No_device_buffer_and_copying = sig
1717
include Alloc_buffer with type stream := unit
1818

19+
val get_used_memory : unit -> int
20+
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
21+
1922
val buffer_to_buffer : dst:buffer_ptr -> src:buffer_ptr -> size_in_bytes:int -> unit
2023
val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
2124
val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
2225
end
2326

24-
module No_device_buffer_and_copying :
27+
module No_device_buffer_and_copying () :
2528
No_device_buffer_and_copying with type buffer_ptr = unit Ctypes.ptr = struct
2629
type buffer_ptr = unit Ctypes.ptr
2730

@@ -31,18 +34,28 @@ module No_device_buffer_and_copying :
3134
type nonrec buffer_ptr = buffer_ptr [@@deriving sexp_of]
3235
end)
3336

34-
let alloc_buffer ?old_buffer ~size_in_bytes () =
35-
match old_buffer with
36-
| Some ({ size_in_bytes = old_size; _ } as buffer) when size_in_bytes <= old_size -> buffer
37-
| _ ->
38-
let ptr = Ctypes.(to_voidp @@ allocate_n int8_t ~count:size_in_bytes) in
39-
{ ptr; size_in_bytes }
37+
let used_memory = Atomic.make 0
38+
let get_used_memory () = Atomic.get used_memory
39+
40+
let alloc_impl ~size_in_bytes =
41+
let finalize _ptr = ignore (Atomic.fetch_and_add used_memory ~-size_in_bytes : int) in
42+
let ptr = Ctypes.(to_voidp @@ allocate_n int8_t ~count:size_in_bytes) in
43+
let _ : int = Atomic.fetch_and_add used_memory size_in_bytes in
44+
Stdlib.Gc.finalise finalize ptr;
45+
ptr
4046

4147
let alloc_zero_init_array prec ~dims () =
4248
let size_in_bytes =
4349
(if Array.length dims = 0 then 0 else Array.reduce_exn dims ~f:( * )) * Ops.prec_in_bytes prec
4450
in
45-
Ctypes.(to_voidp @@ allocate_n int8_t ~count:size_in_bytes)
51+
alloc_impl ~size_in_bytes
52+
53+
let alloc_buffer ?old_buffer ~size_in_bytes () =
54+
match old_buffer with
55+
| Some ({ size_in_bytes = old_size; _ } as buffer) when size_in_bytes <= old_size -> buffer
56+
| _ -> { ptr = alloc_impl ~size_in_bytes; size_in_bytes }
57+
58+
let free_buffer = None
4659

4760
let buffer_to_buffer ~dst:Ctypes_static.(CPointer dst) ~src:Ctypes_static.(CPointer src)
4861
~size_in_bytes =
@@ -127,34 +140,12 @@ module type Backend_impl_common = sig
127140
directly from the host. *)
128141
end
129142

130-
(** An intermediate interface for stream-agnostic (typically CPU) backend implementations. *)
131-
module type No_device_backend = sig
132-
include Backend_common
133-
include Backend_impl_common with type buffer_ptr := buffer_ptr
143+
(** An interface for adding schedulers to stream-agnostic (typically CPU) backend implementations. *)
144+
module type For_add_scheduler = sig
145+
include Backend_any_common
134146

135147
val name : string
136148

137-
val link :
138-
merge_buffer:(buffer_ptr * Tnode.t) option ref ->
139-
runner_label:string ->
140-
ctx_arrays ->
141-
code ->
142-
ctx_arrays routine
143-
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)
144-
145-
val link_batch :
146-
merge_buffer:(buffer_ptr * Tnode.t) option ref ->
147-
runner_label:string ->
148-
ctx_arrays ->
149-
code_batch ->
150-
ctx_arrays * ctx_arrays routine option array
151-
(** Returns the routines for the procedures included in the code batch. The returned context is
152-
downstream of all the returned routines (in particular, the routines' contexts are not
153-
independent). *)
154-
155-
val get_used_memory : unit -> int
156-
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
157-
158149
include No_device_buffer_and_copying with type buffer_ptr := buffer_ptr
159150
end
160151

@@ -227,8 +218,8 @@ module type Lowered_no_device_backend = sig
227218
runner_label:string ->
228219
ctx_arrays ->
229220
procedure ->
230-
ctx_arrays * Indexing.lowered_bindings * Task.t * string
231-
(** [runner_label] will be [get_name stream] of the stream from which the [ctx_arrays] come from. *)
221+
ctx_arrays * Indexing.lowered_bindings * Task.t
222+
(** [runner_label] will be [get_name stream] of the stream holding the resulting [ctx_arrays]. *)
232223

233224
include No_device_buffer_and_copying with type buffer_ptr := buffer_ptr
234225
end
@@ -255,28 +246,48 @@ module type No_buffer_retrieval_or_syncing = sig
255246
[Invalid_argument] if [into_merge_buffer = No] and [dst_ptr = None]. *)
256247
end
257248

258-
(** Lowered-level backend interface: implementation-facing API for device-based (typically GPU)
259-
backends. *)
249+
(** A compilation-agnostic backend API -- {!Lowered_backend} instantiates it, but
250+
{!Lowered_no_device_backend} backends are also converted to its instantiations. *)
251+
module type With_scheduler = sig
252+
include Backend_device_common
253+
254+
val schedule_task : stream -> Task.t -> unit
255+
end
256+
257+
(** Lowered-level backend interface: implementation-facing API for device-based (GPU, or CPU after
258+
adding a scheduler) backends. *)
260259
module type Lowered_backend = sig
261-
include No_buffer_retrieval_or_syncing
260+
include Backend_device_common
261+
262+
include
263+
No_buffer_retrieval_or_syncing
264+
with type buffer_ptr := buffer_ptr
265+
and type dev := dev
266+
and type runner := runner
267+
and type event := event
262268

263269
type code [@@deriving sexp_of]
264270
type code_batch [@@deriving sexp_of]
265271

266-
val compile : name:string -> Indexing.unit_bindings -> Low_level.optimized -> code
272+
val compile : ?shared:bool -> name:string -> Indexing.unit_bindings -> Low_level.optimized -> code
267273

268274
val compile_batch :
275+
?shared:bool ->
269276
names:string option array ->
270277
Indexing.unit_bindings ->
271278
Low_level.optimized option array ->
272279
code_batch
273280

274-
val link : context -> code -> context * Indexing.lowered_bindings * Task.t
281+
val link : context -> code -> ctx_arrays * Indexing.lowered_bindings * Task.t
282+
(** The results correspond to the fields {!field-Backend_intf.ctx_arrays} of
283+
{!field-Backend_intf.context}, {!field-Backend_intf.bindings} and
284+
{!field-Backend_intf.schedule} of {!Backend_intf.routine}. *)
275285

276286
val link_batch :
277-
context -> code_batch -> context * Indexing.lowered_bindings * Task.t option array
278-
279-
val scheduled_merge_node : stream -> Tnode.t option
280-
(** [scheduled_merge_node stream] is the tensor node that would be in the [stream]'s merge buffer
281-
right after [await stream]. *)
287+
context ->
288+
code_batch ->
289+
ctx_arrays * Indexing.lowered_bindings * (ctx_arrays * Task.t) option array
290+
(** Returns the schedule tasks and their [ctx_arrays] for the procedures included in the code
291+
batch. The returned [ctx_arrays] will be part of a context downstream of all the tasks and the
292+
tasks' contexts are not independent (typically, they are cumulative). *)
282293
end

arrayjit/lib/backend_intf.ml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ module type Alloc_buffer = sig
3232

3333
val alloc_buffer : ?old_buffer:buffer -> size_in_bytes:int -> stream -> buffer
3434
val alloc_zero_init_array : Ops.prec -> dims:int array -> stream -> buffer_ptr
35+
val free_buffer : (stream -> buffer_ptr -> unit) option
3536
end
3637

3738
(** For now, we only configure a backend with regard to how many streams it should suggest using
@@ -106,6 +107,10 @@ type ('buffer_ptr, 'dev, 'runner, 'event) stream = {
106107
}
107108
[@@deriving sexp_of]
108109

110+
(** [scheduled_merge_node stream] is the tensor node that would be in the [stream]'s merge buffer
111+
right after [await stream]. *)
112+
let scheduled_merge_node stream = Option.map ~f:snd !(stream.merge_buffer)
113+
109114
type ('buffer_ptr, 'stream) context = {
110115
stream : 'stream;
111116
parent : ('buffer_ptr, 'stream) context option;
@@ -180,7 +185,10 @@ module type Backend_common = sig
180185
end
181186

182187
(** Parts shared by both assignments-level and lowered-level backend interfaces providing streams
183-
and devices. *)
188+
and devices, both user-facing and implementation-facing. Does not include: compilation and
189+
linking (different for assignments-level and lowered-level); copying and tensor-node-level
190+
synchronization (copying is different for user-facing and implementation-facing APIs,
191+
synchronization is provided by a component outside of backend implementations). *)
184192
module type Backend_device_common = sig
185193
include Device
186194
include Backend_any_common with type buffer_ptr := buffer_ptr
@@ -215,8 +223,7 @@ module type Backend_device_common = sig
215223
val num_devices : unit -> int
216224

217225
val suggested_num_streams : device -> int
218-
(** The optimal number of streams for the given device to follow the {!config} strategy passed to
219-
{!No_device_backend.initialize}. *)
226+
(** The optimal number of streams for the given device to follow the {!config} strategy. *)
220227

221228
val new_stream : device -> stream
222229
end
@@ -263,8 +270,8 @@ module type With_buffer_retrieval_and_syncing = sig
263270
end
264271

265272
module type Backend = sig
266-
include Backend_device_common
267-
include Backend_common with type buffer_ptr := buffer_ptr
273+
include Backend_common
274+
include Backend_device_common with type buffer_ptr := buffer_ptr
268275

269276
val link : context -> code -> context routine
270277
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)

0 commit comments

Comments
 (0)