Skip to content

Commit 1775098

Browse files
committed
A big refactoring of backend_types.ml; CPU allocation and copying not using bigarrays
I'm unsure about this "busy-work" refactoring, but it does give some insight into the structure of backends.
1 parent e0515ef commit 1775098

14 files changed

+446
-334
lines changed

arrayjit/lib/backend_types.ml

Lines changed: 177 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -7,39 +7,171 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
77
[%%global_debug_log_level 9]
88
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
99

10-
module type Buffer_ptr = sig
10+
type 'buffer_ptr buffer = { ptr : 'buffer_ptr; size_in_bytes : int } [@@deriving sexp_of]
11+
12+
module Buffer_types (Buffer_ptr : sig
1113
type buffer_ptr [@@deriving sexp_of]
12-
type ctx_arrays = buffer_ptr Map.M(Tnode).t [@@deriving sexp_of]
14+
end) =
15+
struct
16+
type nonrec buffer = Buffer_ptr.buffer_ptr buffer [@@deriving sexp_of]
17+
type ctx_arrays = Buffer_ptr.buffer_ptr Map.M(Tnode).t [@@deriving sexp_of]
1318
end
1419

15-
module No_device_buffer_ptr : Buffer_ptr with type buffer_ptr = Ndarray.t = struct
16-
type buffer_ptr = Ndarray.t [@@deriving sexp_of]
17-
type ctx_arrays = buffer_ptr Map.M(Tnode).t [@@deriving sexp_of]
20+
module type Buffer = sig
21+
type buffer_ptr [@@deriving sexp_of]
22+
23+
val c_ptr_to_string : (buffer_ptr -> Ops.prec -> string) option
24+
25+
include module type of Buffer_types (struct
26+
type nonrec buffer_ptr = buffer_ptr [@@deriving sexp_of]
27+
end)
1828
end
1929

20-
module Types = struct
21-
type 'context routine = {
22-
context : 'context;
23-
schedule : Task.t;
24-
bindings : Indexing.lowered_bindings;
25-
name : string;
26-
inputs : Set.M(Tnode).t;
27-
(** The materialized read-only and read-before-write (within the routine) non-constant
28-
nodes. They are inputs in a broad sense, as they could be recurrent nodes or parameters. *)
29-
outputs : Set.M(Tnode).t; (** All the materialized nodes written-to by the routine. *)
30-
}
31-
[@@deriving sexp_of]
30+
module type Alloc_buffer = sig
31+
include Buffer
32+
33+
type stream
34+
35+
val alloc_buffer : ?old_buffer:buffer -> size_in_bytes:int -> stream -> buffer
36+
val alloc_zero_init_array : Ops.prec -> dims:int array -> stream -> buffer_ptr
37+
end
38+
39+
module type No_device_buffer_and_copying = sig
40+
include Alloc_buffer with type stream := unit
41+
42+
val buffer_to_buffer : dst:buffer_ptr -> src:buffer_ptr -> size_in_bytes:int -> unit
43+
val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
44+
val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
45+
end
46+
47+
module No_device_buffer_and_copying :
48+
No_device_buffer_and_copying with type buffer_ptr = unit Ctypes.ptr = struct
49+
type buffer_ptr = unit Ctypes.ptr
50+
51+
let sexp_of_buffer_ptr = Ops.sexp_of_voidptr
52+
53+
include Buffer_types (struct
54+
type nonrec buffer_ptr = buffer_ptr [@@deriving sexp_of]
55+
end)
56+
57+
let alloc_buffer ?old_buffer ~size_in_bytes () =
58+
match old_buffer with
59+
| Some ({ size_in_bytes = old_size; _ } as buffer) when size_in_bytes <= old_size -> buffer
60+
| _ ->
61+
let ptr = Ctypes.(to_voidp @@ allocate_n int8_t ~count:size_in_bytes) in
62+
{ ptr; size_in_bytes }
63+
64+
let alloc_zero_init_array prec ~dims () =
65+
let size_in_bytes =
66+
(if Array.length dims = 0 then 0 else Array.reduce_exn dims ~f:( * )) * Ops.prec_in_bytes prec
67+
in
68+
Ctypes.(to_voidp @@ allocate_n int8_t ~count:size_in_bytes)
69+
70+
let buffer_to_buffer ~dst:Ctypes_static.(CPointer dst) ~src:Ctypes_static.(CPointer src)
71+
~size_in_bytes =
72+
Ctypes_memory_stubs.memcpy ~dst ~src ~size:size_in_bytes
73+
74+
let host_to_buffer src ~dst:Ctypes_static.(CPointer dst) =
75+
Ctypes_memory_stubs.memcpy ~dst
76+
~src:(Ndarray.get_fatptr_not_managed src)
77+
~size:(Ndarray.size_in_bytes src)
78+
79+
let buffer_to_host dst ~src:Ctypes_static.(CPointer src) =
80+
Ctypes_memory_stubs.memcpy
81+
~dst:(Ndarray.get_fatptr_not_managed dst)
82+
~src ~size:(Ndarray.size_in_bytes dst)
83+
84+
let c_ptr_to_string = Some Ops.c_ptr_to_string
85+
end
86+
87+
(** For now, we only configure a backend with regard to how many streams it should suggest using
88+
(where applicable). *)
89+
type config = Only_devices_parallel | For_parallel_copying | Most_parallel_streams
90+
[@@deriving equal, sexp, variants]
91+
92+
type merge_buffer_use = No | Streaming | Copy [@@deriving equal, sexp]
93+
94+
type param_source =
95+
| Log_file_name
96+
| Merge_buffer
97+
| Param_ptr of Tnode.t
98+
| Static_idx of Indexing.static_symbol
99+
[@@deriving sexp_of]
100+
101+
type 'context routine = {
102+
context : 'context;
103+
schedule : Task.t;
104+
bindings : Indexing.lowered_bindings;
105+
name : string;
106+
inputs : Set.M(Tnode).t;
107+
(** The materialized read-only and read-before-write (within the routine) non-constant nodes.
108+
They are inputs in a broad sense, as they could be recurrent nodes or parameters. *)
109+
outputs : Set.M(Tnode).t; (** All the materialized nodes written-to by the routine. *)
110+
}
111+
[@@deriving sexp_of]
112+
113+
module type Device_config = sig
114+
include Buffer
115+
116+
type device [@@deriving sexp_of]
117+
type stream_state [@@deriving sexp_of]
118+
type runner [@@deriving sexp_of]
119+
120+
type event [@@deriving sexp_of]
121+
(** An event tracks if a stream finished computing past a particular point in its schedue. These
122+
values are used internally for scheduling across streams of the backend, and can be used for
123+
explicit scheduling. *)
124+
end
32125

33-
type ('buffer_ptr, 'event, 'device, 'state, 'runner) stream = {
34-
device : 'device;
35-
state : 'state;
36-
merge_buffer : ('buffer_ptr * Tnode.t) option ref;
37-
unique_name : string;
38-
mutable allocated_buffer : ('buffer_ptr * int) option;
39-
runner : 'runner;
40-
requested_work_for : 'event option Hashtbl.M(Tnode).t;
41-
}
126+
type ('buffer_ptr, 'device, 'stream_state, 'runner, 'event) stream = {
127+
device : 'device;
128+
state : 'stream_state;
129+
merge_buffer : ('buffer_ptr * Tnode.t) option ref;
130+
unique_name : string;
131+
mutable allocated_buffer : 'buffer_ptr buffer option;
132+
runner : 'runner;
133+
requested_work_for : 'event option Hashtbl.M(Tnode).t;
134+
}
135+
[@@deriving sexp_of]
136+
137+
module type Device_types = sig
138+
include Device_config
139+
140+
type nonrec stream = (buffer_ptr, device, stream_state, runner, event) stream [@@deriving sexp_of]
141+
end
142+
143+
module Stream (Device_config : Device_config) = struct
144+
type nonrec stream =
145+
( Device_config.buffer_ptr,
146+
Device_config.device,
147+
Device_config.stream_state,
148+
Device_config.runner,
149+
Device_config.event )
150+
stream
42151
[@@deriving sexp_of]
152+
end
153+
154+
module type Device = sig
155+
include Device_types
156+
include Alloc_buffer with type buffer_ptr := buffer_ptr and type stream := stream
157+
158+
val make_stream :
159+
device:device -> state:stream_state -> unique_name:string -> runner:runner -> stream
160+
end
161+
162+
module Device_types (Device_config : Device_config) = struct
163+
include Device_config
164+
include Stream (Device_config)
165+
end
166+
167+
module Device
168+
(Device_types : Device_types)
169+
(Alloc_buffer : Alloc_buffer
170+
with type buffer_ptr := Device_types.buffer_ptr
171+
and type stream := Device_types.stream) =
172+
struct
173+
include Device_types
174+
include Alloc_buffer
43175

44176
let make_stream ~device ~state ~unique_name ~runner =
45177
{
@@ -51,25 +183,12 @@ module Types = struct
51183
runner : 'runner;
52184
requested_work_for = Hashtbl.create (module Tnode);
53185
}
54-
55-
(** For now, we only configure a backend with regard to how many streams it should suggest using
56-
(where applicable). *)
57-
type config = Only_devices_parallel | For_parallel_copying | Most_parallel_streams
58-
[@@deriving equal, sexp, variants]
59-
60-
type merge_buffer_use = No | Streaming | Copy [@@deriving equal, sexp]
61-
62-
type param_source =
63-
| Log_file_name
64-
| Merge_buffer
65-
| Param_ptr of Tnode.t
66-
| Static_idx of Indexing.static_symbol
67-
[@@deriving sexp_of]
68186
end
69187

70188
(** Parts shared by both assignments-level and lowered-level backend interfaces. *)
71189
module type Backend_any_common = sig
72-
type buffer_ptr [@@deriving sexp_of]
190+
include Buffer
191+
73192
type context [@@deriving sexp_of]
74193
type stream
75194

@@ -79,7 +198,7 @@ module type Backend_any_common = sig
79198

80199
val name : string
81200

82-
val initialize : Types.config -> unit
201+
val initialize : config -> unit
83202
(** Initializes a backend before first use. Typically does nothing if the backend is already
84203
initialized, but some backends can do some safe cleanups. *)
85204

@@ -91,15 +210,12 @@ module type Backend_any_common = sig
91210

92211
val finalize : context -> unit
93212
(** Finalizes (just) the context. *)
94-
95-
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> stream -> buffer_ptr
96213
end
97214

98215
(** Parts shared by assignments-level backend interfaces. *)
99216
module type Backend_common = sig
100217
include Backend_any_common
101218

102-
type routine = context Types.routine [@@deriving sexp_of]
103219
type code [@@deriving sexp_of]
104220
type code_batch [@@deriving sexp_of]
105221

@@ -121,11 +237,12 @@ module type Backend_common = sig
121237
[occupancy] returns true are included. *)
122238
end
123239

124-
(** Parts shared by backend implementations excluding what's already in {!Backend_any_common}. *)
240+
(** Parts shared by backend implementations excluding what's already in {!Backend_any_common},
241+
except for {!Buffer} which is duplicated for technical reasons. *)
125242
module type Backend_impl_common = sig
126243
type context [@@deriving sexp_of]
127244

128-
include Buffer_ptr
245+
include Buffer
129246

130247
val ctx_arrays : context -> ctx_arrays
131248

@@ -136,51 +253,33 @@ module type Backend_impl_common = sig
136253
directly from the host. *)
137254
end
138255

139-
module type No_device_copying = sig
140-
type buffer_ptr
141-
142-
val buffer_to_buffer : dst:buffer_ptr -> src:buffer_ptr -> unit
143-
val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
144-
val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
145-
end
146-
147256
(** An intermediate interface for stream-agnostic (typically CPU) backend implementations. *)
148257
module type No_device_backend = sig
149258
include Backend_common with type init_info := string and type stream := unit
150259
include Backend_impl_common with type context := context and type buffer_ptr := buffer_ptr
151260

152-
val link : merge_buffer:(buffer_ptr * Tnode.t) option ref -> context -> code -> routine
261+
val link : merge_buffer:(buffer_ptr * Tnode.t) option ref -> context -> code -> context routine
153262
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)
154263

155264
val link_batch :
156265
merge_buffer:(buffer_ptr * Tnode.t) option ref ->
157266
context ->
158267
code_batch ->
159-
context * routine option array
268+
context * context routine option array
160269
(** Returns the routines for the procedures included in the code batch. The returned context is
161270
downstream of all the returned routines (in particular, the routines' contexts are not
162271
independent). *)
163272

164273
val get_used_memory : unit -> int
165274
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
166275

167-
include No_device_copying with type buffer_ptr := buffer_ptr
276+
include No_device_buffer_and_copying with type buffer_ptr := buffer_ptr
168277
end
169278

170279
(** Parts shared by both assignments-level and lowered-level backend interfaces providing streams
171280
and devices. *)
172281
module type Backend_device_common = sig
173-
type buffer_ptr
174-
type device
175-
176-
type event
177-
(** An event tracks if a stream finished computing past a particular point in its schedue. These
178-
values are used internally for scheduling across streams of the backend, and can be used for
179-
explicit scheduling. *)
180-
181-
type stream_state [@@deriving sexp_of]
182-
type runner [@@deriving sexp_of]
183-
type stream = (buffer_ptr, event, device, stream_state, runner) Types.stream [@@deriving sexp_of]
282+
include Device
184283

185284
include
186285
Backend_any_common
@@ -218,8 +317,8 @@ module type Backend_device_common = sig
218317
val num_devices : unit -> int
219318

220319
val suggested_num_streams : device -> int
221-
(** The optimal number of streams for the given device to follow the {!Types.config} strategy
222-
passed to {!No_device_backend.initialize}. *)
320+
(** The optimal number of streams for the given device to follow the {!config} strategy passed to
321+
{!No_device_backend.initialize}. *)
223322

224323
val new_stream : device -> stream
225324
val get_ctx_stream : context -> stream
@@ -252,7 +351,7 @@ module type With_buffer_retrieval_and_syncing = sig
252351
read. *)
253352

254353
val device_to_device :
255-
Tnode.t -> into_merge_buffer:Types.merge_buffer_use -> dst:context -> src:context -> bool
354+
Tnode.t -> into_merge_buffer:merge_buffer_use -> dst:context -> src:context -> bool
256355
(** [device_to_device tn ~into_merge_buffer ~dst ~src] proceeds as follows:
257356
- If the node is absent from the [src] context and either it is present in the [dst] context
258357
or [into_merge_buffer] is different from [No]: raises an error.
@@ -274,14 +373,15 @@ module type Backend = sig
274373

275374
include
276375
Backend_common
277-
with type context := context
376+
with type buffer_ptr := buffer_ptr
377+
and type context := context
278378
and type init_info := stream
279379
and type stream := stream
280380

281-
val link : context -> code -> routine
381+
val link : context -> code -> context routine
282382
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)
283383

284-
val link_batch : context -> code_batch -> context * routine option array
384+
val link_batch : context -> code_batch -> context * context routine option array
285385
(** Returns the routines for the procedures included in the code batch. The returned context is
286386
downstream of all the returned routines. *)
287387

@@ -321,7 +421,7 @@ module type Lowered_no_device_backend = sig
321421
procedure ->
322422
context * Indexing.lowered_bindings * Task.t * string
323423

324-
include No_device_copying with type buffer_ptr := buffer_ptr
424+
include No_device_buffer_and_copying with type buffer_ptr := buffer_ptr
325425
end
326426

327427
module type No_buffer_retrieval_or_syncing = sig
@@ -336,7 +436,7 @@ module type No_buffer_retrieval_or_syncing = sig
336436

337437
val device_to_device :
338438
Tnode.t ->
339-
into_merge_buffer:Types.merge_buffer_use ->
439+
into_merge_buffer:merge_buffer_use ->
340440
dst_ptr:buffer_ptr option ->
341441
dst:context ->
342442
src_ptr:buffer_ptr ->

0 commit comments

Comments
 (0)