@@ -233,7 +233,10 @@ module Add_device
233233 (Add_scheduler : functor
234234 (Impl : For_add_scheduler )
235235 -> With_scheduler with type buffer_ptr = Impl. buffer_ptr)
236- (Backend : Lowered_no_device_backend ) : Lowered_backend = struct
236+ (Backend : Lowered_no_device_backend )
237+ (Config : sig
238+ val config : config
239+ end ) : Lowered_backend = struct
237240 include Backend
238241
239242 type code = { lowered : Low_level .optimized ; proc : Backend .procedure } [@@ deriving sexp_of ]
@@ -252,7 +255,10 @@ module Add_device
252255 let procs = compile_batch ~names bindings lowereds in
253256 { lowereds; procs }
254257
255- include Add_scheduler (Backend )
258+ include Add_scheduler (struct
259+ include Backend
260+ include Config
261+ end )
256262
257263 let link context (code : code ) ctx_arrays =
258264 let runner_label = get_name context.stream in
@@ -481,9 +487,12 @@ module Make_device_backend_from_lowered
481487 (Add_scheduler : functor
482488 (Impl : For_add_scheduler )
483489 -> With_scheduler with type buffer_ptr = Impl. buffer_ptr)
484- (Backend_impl : Lowered_no_device_backend ) =
490+ (Backend_impl : Lowered_no_device_backend )
491+ (Config : sig
492+ val config : config
493+ end ) =
485494struct
486- module Lowered_device = Add_device (Add_scheduler ) (Backend_impl )
495+ module Lowered_device = Add_device (Add_scheduler ) (Backend_impl ) ( Config )
487496 module Backend_device = Raise_backend (Lowered_device )
488497 include Backend_device
489498end
@@ -503,23 +512,31 @@ let finalize (type buffer_ptr dev runner event)
503512 && not (Hashtbl. mem ctx.stream.device.cross_stream_candidates key)
504513 then mem_free ctx.stream data)))
505514
506- let % track5_sexp fresh_backend ?backend_name () =
515+ let % track5_sexp fresh_backend ?backend_name ?(config = For_parallel_copying ) () =
507516 Stdlib.Gc. full_major () ;
508517 (* TODO: is running again needed to give time to weak arrays to become empty? *)
509518 Stdlib.Gc. full_major () ;
510519 (* Note: we invoke functors from within fresh_backend to fully isolate backends from distinct
511520 calls to fresh_backend. *)
521+ let module Config = struct
522+ let config = config
523+ end in
512524 match
513525 Option. value_or_thunk backend_name ~default: (fun () ->
514526 Utils. get_global_arg ~arg_name: " backend" ~default: " cc" )
515527 |> String. lowercase
516528 with
517- | "cc" -> (module Make_device_backend_from_lowered (Schedulers. Multicore ) (Cc_backend ) : Backend )
529+ | "cc" ->
530+ (module Make_device_backend_from_lowered (Schedulers. Multicore ) (Cc_backend ) (Config )
531+ : Backend )
518532 | "gccjit" ->
519- (module Make_device_backend_from_lowered (Schedulers. Multicore ) (Gcc_backend_impl ) : Backend )
520- | "sync_cc" -> (module Make_device_backend_from_lowered (Schedulers. Sync ) (Cc_backend ) : Backend )
533+ (module Make_device_backend_from_lowered (Schedulers. Multicore ) (Gcc_backend_impl ) (Config )
534+ : Backend )
535+ | "sync_cc" ->
536+ (module Make_device_backend_from_lowered (Schedulers. Sync ) (Cc_backend ) (Config ) : Backend )
521537 | "sync_gccjit" ->
522- (module Make_device_backend_from_lowered (Schedulers. Sync ) (Gcc_backend_impl ) : Backend )
523- | "cuda" -> (module Raise_backend ((Cuda_backend_impl. Fresh () : Lowered_backend )) : Backend )
524- | "metal" -> (module Raise_backend ((Metal_backend_impl. Fresh () : Lowered_backend )) : Backend )
538+ (module Make_device_backend_from_lowered (Schedulers. Sync ) (Gcc_backend_impl ) (Config )
539+ : Backend )
540+ | "cuda" -> (module Raise_backend ((Cuda_backend_impl. Fresh (Config ) : Lowered_backend )) : Backend )
541+ | "metal" -> (module Raise_backend ((Metal_backend_impl. Fresh (Config ) : Lowered_backend )) : Backend )
525542 | backend -> invalid_arg [% string " Backends.fresh_backend: unknown backend %{backend}" ]
0 commit comments