@@ -7,39 +7,171 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
77[%% global_debug_log_level 9 ]
88[%% global_debug_log_level_from_env_var " OCANNL_LOG_LEVEL" ]
99
10- module type Buffer_ptr = sig
10+ type 'buffer_ptr buffer = { ptr : 'buffer_ptr ; size_in_bytes : int } [@@ deriving sexp_of ]
11+
12+ module Buffer_types (Buffer_ptr : sig
1113 type buffer_ptr [@@deriving sexp_of]
12- type ctx_arrays = buffer_ptr Map .M (Tnode ).t [@@ deriving sexp_of ]
14+ end ) =
15+ struct
16+ type nonrec buffer = Buffer_ptr .buffer_ptr buffer [@@ deriving sexp_of ]
17+ type ctx_arrays = Buffer_ptr .buffer_ptr Map .M (Tnode ).t [@@ deriving sexp_of ]
1318end
1419
15- module No_device_buffer_ptr : Buffer_ptr with type buffer_ptr = Ndarray. t = struct
16- type buffer_ptr = Ndarray .t [@@ deriving sexp_of ]
17- type ctx_arrays = buffer_ptr Map .M (Tnode ).t [@@ deriving sexp_of ]
20+ module type Buffer = sig
21+ type buffer_ptr [@@deriving sexp_of]
22+
23+ val c_ptr_to_string : (buffer_ptr -> Ops .prec -> string ) option
24+
25+ include module type of Buffer_types (struct
26+ type nonrec buffer_ptr = buffer_ptr [@@ deriving sexp_of ]
27+ end )
1828end
1929
20- module Types = struct
21- type 'context routine = {
22- context : 'context ;
23- schedule : Task .t ;
24- bindings : Indexing .lowered_bindings ;
25- name : string ;
26- inputs : Set .M (Tnode ).t;
27- (* * The materialized read-only and read-before-write (within the routine) non-constant
28- nodes. They are inputs in a broad sense, as they could be recurrent nodes or parameters. *)
29- outputs : Set .M (Tnode ).t; (* * All the materialized nodes written-to by the routine. *)
30- }
31- [@@ deriving sexp_of ]
30+ module type Alloc_buffer = sig
31+ include Buffer
32+
33+ type stream
34+
35+ val alloc_buffer : ?old_buffer : buffer -> size_in_bytes :int -> stream -> buffer
36+ val alloc_zero_init_array : Ops .prec -> dims :int array -> stream -> buffer_ptr
37+ end
38+
39+ module type No_device_buffer_and_copying = sig
40+ include Alloc_buffer with type stream := unit
41+
42+ val buffer_to_buffer : dst :buffer_ptr -> src :buffer_ptr -> size_in_bytes :int -> unit
43+ val host_to_buffer : Ndarray .t -> dst :buffer_ptr -> unit
44+ val buffer_to_host : Ndarray .t -> src :buffer_ptr -> unit
45+ end
46+
47+ module No_device_buffer_and_copying :
48+ No_device_buffer_and_copying with type buffer_ptr = unit Ctypes. ptr = struct
49+ type buffer_ptr = unit Ctypes .ptr
50+
51+ let sexp_of_buffer_ptr = Ops. sexp_of_voidptr
52+
53+ include Buffer_types (struct
54+ type nonrec buffer_ptr = buffer_ptr [@@ deriving sexp_of ]
55+ end )
56+
57+ let alloc_buffer ?old_buffer ~size_in_bytes () =
58+ match old_buffer with
59+ | Some ({ size_in_bytes = old_size ; _ } as buffer ) when size_in_bytes < = old_size -> buffer
60+ | _ ->
61+ let ptr = Ctypes. (to_voidp @@ allocate_n int8_t ~count: size_in_bytes) in
62+ { ptr; size_in_bytes }
63+
64+ let alloc_zero_init_array prec ~dims () =
65+ let size_in_bytes =
66+ (if Array. length dims = 0 then 0 else Array. reduce_exn dims ~f: ( * )) * Ops. prec_in_bytes prec
67+ in
68+ Ctypes. (to_voidp @@ allocate_n int8_t ~count: size_in_bytes)
69+
70+ let buffer_to_buffer ~dst :Ctypes_static. (CPointer dst ) ~src :Ctypes_static. (CPointer src )
71+ ~size_in_bytes =
72+ Ctypes_memory_stubs. memcpy ~dst ~src ~size: size_in_bytes
73+
74+ let host_to_buffer src ~dst :Ctypes_static. (CPointer dst ) =
75+ Ctypes_memory_stubs. memcpy ~dst
76+ ~src: (Ndarray. get_fatptr_not_managed src)
77+ ~size: (Ndarray. size_in_bytes src)
78+
79+ let buffer_to_host dst ~src :Ctypes_static. (CPointer src ) =
80+ Ctypes_memory_stubs. memcpy
81+ ~dst: (Ndarray. get_fatptr_not_managed dst)
82+ ~src ~size: (Ndarray. size_in_bytes dst)
83+
84+ let c_ptr_to_string = Some Ops. c_ptr_to_string
85+ end
86+
87+ (* * For now, we only configure a backend with regard to how many streams it should suggest using
88+ (where applicable). *)
89+ type config = Only_devices_parallel | For_parallel_copying | Most_parallel_streams
90+ [@@ deriving equal , sexp , variants ]
91+
92+ type merge_buffer_use = No | Streaming | Copy [@@ deriving equal , sexp ]
93+
94+ type param_source =
95+ | Log_file_name
96+ | Merge_buffer
97+ | Param_ptr of Tnode .t
98+ | Static_idx of Indexing .static_symbol
99+ [@@ deriving sexp_of ]
100+
101+ type 'context routine = {
102+ context : 'context ;
103+ schedule : Task .t ;
104+ bindings : Indexing .lowered_bindings ;
105+ name : string ;
106+ inputs : Set .M (Tnode ).t;
107+ (* * The materialized read-only and read-before-write (within the routine) non-constant nodes.
108+ They are inputs in a broad sense, as they could be recurrent nodes or parameters. *)
109+ outputs : Set .M (Tnode ).t; (* * All the materialized nodes written-to by the routine. *)
110+ }
111+ [@@ deriving sexp_of ]
112+
113+ module type Device_config = sig
114+ include Buffer
115+
116+ type device [@@deriving sexp_of]
117+ type stream_state [@@deriving sexp_of]
118+ type runner [@@deriving sexp_of]
119+
120+ type event [@@deriving sexp_of]
121+ (* * An event tracks if a stream finished computing past a particular point in its schedue. These
122+ values are used internally for scheduling across streams of the backend, and can be used for
123+ explicit scheduling. *)
124+ end
32125
33- type ('buffer_ptr, 'event, 'device, 'state, 'runner) stream = {
34- device : 'device ;
35- state : 'state ;
36- merge_buffer : ('buffer_ptr * Tnode .t ) option ref ;
37- unique_name : string ;
38- mutable allocated_buffer : ('buffer_ptr * int ) option ;
39- runner : 'runner ;
40- requested_work_for : 'event option Hashtbl .M (Tnode ).t;
41- }
126+ type ('buffer_ptr, 'device, 'stream_state, 'runner, 'event) stream = {
127+ device : 'device ;
128+ state : 'stream_state ;
129+ merge_buffer : ('buffer_ptr * Tnode .t ) option ref ;
130+ unique_name : string ;
131+ mutable allocated_buffer : 'buffer_ptr buffer option ;
132+ runner : 'runner ;
133+ requested_work_for : 'event option Hashtbl .M (Tnode ).t;
134+ }
135+ [@@ deriving sexp_of ]
136+
137+ module type Device_types = sig
138+ include Device_config
139+
140+ type nonrec stream = (buffer_ptr , device , stream_state , runner , event ) stream [@@ deriving sexp_of ]
141+ end
142+
143+ module Stream (Device_config : Device_config ) = struct
144+ type nonrec stream =
145+ ( Device_config .buffer_ptr ,
146+ Device_config .device ,
147+ Device_config .stream_state ,
148+ Device_config .runner ,
149+ Device_config .event )
150+ stream
42151 [@@ deriving sexp_of ]
152+ end
153+
154+ module type Device = sig
155+ include Device_types
156+ include Alloc_buffer with type buffer_ptr := buffer_ptr and type stream := stream
157+
158+ val make_stream :
159+ device :device -> state :stream_state -> unique_name :string -> runner :runner -> stream
160+ end
161+
162+ module Device_types (Device_config : Device_config ) = struct
163+ include Device_config
164+ include Stream (Device_config )
165+ end
166+
167+ module Device
168+ (Device_types : Device_types )
169+ (Alloc_buffer : Alloc_buffer
170+ with type buffer_ptr := Device_types. buffer_ptr
171+ and type stream := Device_types. stream) =
172+ struct
173+ include Device_types
174+ include Alloc_buffer
43175
44176 let make_stream ~device ~state ~unique_name ~runner =
45177 {
@@ -51,25 +183,12 @@ module Types = struct
51183 runner : 'runner ;
52184 requested_work_for = Hashtbl. create (module Tnode );
53185 }
54-
55- (* * For now, we only configure a backend with regard to how many streams it should suggest using
56- (where applicable). *)
57- type config = Only_devices_parallel | For_parallel_copying | Most_parallel_streams
58- [@@ deriving equal , sexp , variants ]
59-
60- type merge_buffer_use = No | Streaming | Copy [@@ deriving equal , sexp ]
61-
62- type param_source =
63- | Log_file_name
64- | Merge_buffer
65- | Param_ptr of Tnode .t
66- | Static_idx of Indexing .static_symbol
67- [@@ deriving sexp_of ]
68186end
69187
70188(* * Parts shared by both assignments-level and lowered-level backend interfaces. *)
71189module type Backend_any_common = sig
72- type buffer_ptr [@@deriving sexp_of]
190+ include Buffer
191+
73192 type context [@@deriving sexp_of]
74193 type stream
75194
@@ -79,7 +198,7 @@ module type Backend_any_common = sig
79198
80199 val name : string
81200
82- val initialize : Types . config -> unit
201+ val initialize : config -> unit
83202 (* * Initializes a backend before first use. Typically does nothing if the backend is already
84203 initialized, but some backends can do some safe cleanups. *)
85204
@@ -91,15 +210,12 @@ module type Backend_any_common = sig
91210
92211 val finalize : context -> unit
93212 (* * Finalizes (just) the context. *)
94-
95- val alloc_buffer : ?old_buffer : buffer_ptr * int -> size_in_bytes :int -> stream -> buffer_ptr
96213end
97214
98215(* * Parts shared by assignments-level backend interfaces. *)
99216module type Backend_common = sig
100217 include Backend_any_common
101218
102- type routine = context Types .routine [@@ deriving sexp_of ]
103219 type code [@@deriving sexp_of]
104220 type code_batch [@@deriving sexp_of]
105221
@@ -121,11 +237,12 @@ module type Backend_common = sig
121237 [occupancy] returns true are included. *)
122238end
123239
124- (* * Parts shared by backend implementations excluding what's already in {!Backend_any_common}. *)
240+ (* * Parts shared by backend implementations excluding what's already in {!Backend_any_common},
241+ except for {!Buffer} which is duplicated for technical reasons. *)
125242module type Backend_impl_common = sig
126243 type context [@@deriving sexp_of]
127244
128- include Buffer_ptr
245+ include Buffer
129246
130247 val ctx_arrays : context -> ctx_arrays
131248
@@ -136,51 +253,33 @@ module type Backend_impl_common = sig
136253 directly from the host. *)
137254end
138255
139- module type No_device_copying = sig
140- type buffer_ptr
141-
142- val buffer_to_buffer : dst :buffer_ptr -> src :buffer_ptr -> unit
143- val host_to_buffer : Ndarray .t -> dst :buffer_ptr -> unit
144- val buffer_to_host : Ndarray .t -> src :buffer_ptr -> unit
145- end
146-
147256(* * An intermediate interface for stream-agnostic (typically CPU) backend implementations. *)
148257module type No_device_backend = sig
149258 include Backend_common with type init_info := string and type stream := unit
150259 include Backend_impl_common with type context := context and type buffer_ptr := buffer_ptr
151260
152- val link : merge_buffer :(buffer_ptr * Tnode .t ) option ref -> context -> code -> routine
261+ val link : merge_buffer :(buffer_ptr * Tnode .t ) option ref -> context -> code -> context routine
153262 (* * Returns the routine for the code's procedure, in a new context derived from the given context. *)
154263
155264 val link_batch :
156265 merge_buffer :(buffer_ptr * Tnode .t ) option ref ->
157266 context ->
158267 code_batch ->
159- context * routine option array
268+ context * context routine option array
160269 (* * Returns the routines for the procedures included in the code batch. The returned context is
161270 downstream of all the returned routines (in particular, the routines' contexts are not
162271 independent). *)
163272
164273 val get_used_memory : unit -> int
165274 (* * Returns (an upper bound of) the memory used for arrays, in bytes. *)
166275
167- include No_device_copying with type buffer_ptr := buffer_ptr
276+ include No_device_buffer_and_copying with type buffer_ptr := buffer_ptr
168277end
169278
170279(* * Parts shared by both assignments-level and lowered-level backend interfaces providing streams
171280 and devices. *)
172281module type Backend_device_common = sig
173- type buffer_ptr
174- type device
175-
176- type event
177- (* * An event tracks if a stream finished computing past a particular point in its schedue. These
178- values are used internally for scheduling across streams of the backend, and can be used for
179- explicit scheduling. *)
180-
181- type stream_state [@@deriving sexp_of]
182- type runner [@@deriving sexp_of]
183- type stream = (buffer_ptr , event , device , stream_state , runner ) Types .stream [@@ deriving sexp_of ]
282+ include Device
184283
185284 include
186285 Backend_any_common
@@ -218,8 +317,8 @@ module type Backend_device_common = sig
218317 val num_devices : unit -> int
219318
220319 val suggested_num_streams : device -> int
221- (* * The optimal number of streams for the given device to follow the {!Types. config} strategy
222- passed to {!No_device_backend.initialize}. *)
320+ (* * The optimal number of streams for the given device to follow the {!config} strategy passed to
321+ {!No_device_backend.initialize}. *)
223322
224323 val new_stream : device -> stream
225324 val get_ctx_stream : context -> stream
@@ -252,7 +351,7 @@ module type With_buffer_retrieval_and_syncing = sig
252351 read. *)
253352
254353 val device_to_device :
255- Tnode .t -> into_merge_buffer :Types . merge_buffer_use -> dst :context -> src :context -> bool
354+ Tnode .t -> into_merge_buffer :merge_buffer_use -> dst :context -> src :context -> bool
256355 (* * [device_to_device tn ~into_merge_buffer ~dst ~src] proceeds as follows:
257356 - If the node is absent from the [src] context and either it is present in the [dst] context
258357 or [into_merge_buffer] is different from [No]: raises an error.
@@ -274,14 +373,15 @@ module type Backend = sig
274373
275374 include
276375 Backend_common
277- with type context := context
376+ with type buffer_ptr := buffer_ptr
377+ and type context := context
278378 and type init_info := stream
279379 and type stream := stream
280380
281- val link : context -> code -> routine
381+ val link : context -> code -> context routine
282382 (* * Returns the routine for the code's procedure, in a new context derived from the given context. *)
283383
284- val link_batch : context -> code_batch -> context * routine option array
384+ val link_batch : context -> code_batch -> context * context routine option array
285385 (* * Returns the routines for the procedures included in the code batch. The returned context is
286386 downstream of all the returned routines. *)
287387
@@ -321,7 +421,7 @@ module type Lowered_no_device_backend = sig
321421 procedure ->
322422 context * Indexing .lowered_bindings * Task .t * string
323423
324- include No_device_copying with type buffer_ptr := buffer_ptr
424+ include No_device_buffer_and_copying with type buffer_ptr := buffer_ptr
325425end
326426
327427module type No_buffer_retrieval_or_syncing = sig
@@ -336,7 +436,7 @@ module type No_buffer_retrieval_or_syncing = sig
336436
337437 val device_to_device :
338438 Tnode .t ->
339- into_merge_buffer :Types . merge_buffer_use ->
439+ into_merge_buffer :merge_buffer_use ->
340440 dst_ptr :buffer_ptr option ->
341441 dst :context ->
342442 src_ptr :buffer_ptr ->
0 commit comments