@@ -60,7 +60,7 @@ type 'context routine = {
6060}
6161[@@ deriving sexp_of ]
6262
63- module type Device_config = sig
63+ module type Device_config_common = sig
6464 include Buffer
6565
6666 type dev [@@deriving sexp_of]
@@ -73,10 +73,19 @@ module type Device_config = sig
7373 (* * An event tracks if a stream finished computing past a particular point in its schedue. These
7474 values are used internally for scheduling across streams of the backend, and can be used for
7575 explicit scheduling. *)
76-
7776 val name : string
7877end
7978
79+ module type Device_config = sig
80+ include Device_config_common
81+
82+ type optimize_ctx [@@deriving sexp_of]
83+ (* * The optimization context for compiling code, in particular {!Low_level.optimize_ctx} for
84+ low-level backends. *)
85+
86+ val empty_optimize_ctx : optimize_ctx
87+ end
88+
8089type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
8190 dev : 'dev ;
8291 ordinal : int ;
@@ -171,13 +180,14 @@ type ('buffer_ptr, 'dev, 'runner, 'event) stream =
171180
172181let equal_stream = equal_stream_ref
173182
174- type ('buffer_ptr, 'stream) context = {
183+ type ('buffer_ptr, 'stream, 'optimize_ctx ) context = {
175184 stream : 'stream ;
176- parent : ('buffer_ptr , 'stream ) context option ;
185+ parent : ('buffer_ptr , 'stream , 'optimize_ctx ) context option ;
177186 ctx_arrays : 'buffer_ptr ctx_arrays ;
178187 (* * This map contains arrays used in this context or an ancestor context (they might be unique
179188 but might also be cross-stream shared. *)
180189 finalized : Utils .atomic_bool ;
190+ optimize_ctx : 'optimize_ctx ;
181191}
182192[@@ deriving sexp_of ]
183193
@@ -186,7 +196,7 @@ module type Device_types = sig
186196
187197 type nonrec device = (buffer_ptr , dev , runner , event ) device [@@ deriving sexp_of ]
188198 type nonrec stream = (buffer_ptr , dev , runner , event ) stream [@@ deriving sexp_of ]
189- type nonrec context = (buffer_ptr , stream ) context [@@ deriving sexp_of ]
199+ type nonrec context = (buffer_ptr , stream , optimize_ctx ) context [@@ deriving sexp_of ]
190200end
191201
192202module type Device = sig
@@ -196,10 +206,10 @@ module type Device = sig
196206 val make_device : dev -> ordinal :int -> device
197207 val make_stream : device -> runner -> stream
198208
199- val make_context : ?ctx_arrays : ctx_arrays -> stream -> context
209+ val make_context : ?ctx_arrays : ctx_arrays -> ? optimize_ctx : optimize_ctx -> stream -> context
200210 (* * Returns a context without a parent. *)
201211
202- val make_child : ?ctx_arrays : ctx_arrays -> context -> context
212+ val make_child : ?ctx_arrays : ctx_arrays -> ? optimize_ctx : optimize_ctx -> context -> context
203213 (* * Returns a context with the same {!field:Backend_intf.context.stream}, and
204214 {!field:Backend_intf.context.ctx_arrays} if omitted, as the given context's, which is also the
205215 {!field:Backend_intf.context.parent}. *)
@@ -213,13 +223,18 @@ module type Backend_common = sig
213223
214224 type code [@@deriving sexp_of]
215225 type code_batch [@@deriving sexp_of]
226+ type optimize_ctx [@@deriving sexp_of]
227+
228+ val empty_optimize_ctx : optimize_ctx
229+ val get_optimize_ctx : code -> optimize_ctx
230+ val get_optimize_ctx_batch : code_batch -> optimize_ctx
216231
217- val compile : Low_level . optimize_ctx -> ?name : string -> Indexing .unit_bindings -> Assignments .comp -> code
232+ val compile : optimize_ctx -> ?name : string -> Indexing .unit_bindings -> Assignments .comp -> code
218233 (* * [name] is used to derive names for compilation artifacts. If omitted, it's derived via
219234 {!Assignments.get_name_exn}. *)
220235
221236 val compile_batch :
222- Low_level . optimize_ctx ->
237+ optimize_ctx ->
223238 ?names : string array ->
224239 ?occupancy : (name :string -> src_n :int -> bool ) ->
225240 Indexing .unit_bindings ->
330345
331346module type Backend = sig
332347 include Backend_common
333- include Backend_device_common with type buffer_ptr := buffer_ptr
348+
349+ include
350+ Backend_device_common with type buffer_ptr := buffer_ptr and type optimize_ctx := optimize_ctx
334351
335352 val link : context -> code -> context routine
336353 (* * Returns the routine for the code's procedure, in a new context derived from the given context.
0 commit comments