Skip to content

Commit a6a7fda

Browse files
committed
Refactor backends interface to initialization on module creation
1 parent fd67b3c commit a6a7fda

19 files changed

+82
-80
lines changed

CHANGES.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
## [0.5.3] -- current
2+
3+
### Added
4+
5+
- The Metal framework backend (Apple Silicon).
6+
7+
### Changed
8+
9+
- Removed `initialize` and `is_initialized` from the backend API; instead, backends should be initialized on functor application. The functors now take `config` as an argument.
10+
111
## [0.5.2] -- 2025-04-07
212

313
### Added

arrayjit/lib/backend_impl.ml

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,7 @@ struct
141141
{ stream = parent.stream; parent = Some parent; ctx_arrays; finalized = Atomic.make false }
142142
end
143143

144-
(** Parts shared by backend implementations excluding what's already in
145-
{!Backend_intf.Backend_any_common}, except for {!Backend_intf.Buffer} which is duplicated for
146-
technical reasons. *)
144+
(** Parts shared by backend implementations. *)
147145
module type Backend_impl_common = sig
148146
include Backend_intf.Buffer
149147

@@ -157,17 +155,15 @@ end
157155
(** An interface for adding schedulers to stream-agnostic (typically CPU) backend implementations.
158156
*)
159157
module type For_add_scheduler = sig
160-
include Backend_any_common
161-
162158
val name : string
159+
val config : config
163160

164-
include No_device_buffer_and_copying with type buffer_ptr := buffer_ptr
161+
include No_device_buffer_and_copying
165162
end
166163

167164
(** Lowered-level stream agnostic backend interface: implementation-facing API for CPU backends. *)
168165
module type Lowered_no_device_backend = sig
169166
include Backend_impl_common
170-
include Backend_any_common with type buffer_ptr := buffer_ptr
171167

172168
val name : string
173169

arrayjit/lib/backend_intf.ml

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -206,22 +206,10 @@ module type Device = sig
206206
val get_name : stream -> string
207207
end
208208

209-
(** Parts shared by both assignments-level and lowered-level backend interfaces. *)
210-
module type Backend_any_common = sig
211-
include Buffer
212-
213-
val initialize : config -> unit
214-
(** Initializes a backend before first use. Typically does nothing if the backend is already
215-
initialized, but some backends can do some safe cleanups. *)
216-
217-
val is_initialized : unit -> bool
218-
(** Returns false if there was no previous {!initialize} call. If it returns false, one must call
219-
{!initialize} before using the backend. *)
220-
end
221209

222210
(** Parts shared by assignments-level backend interfaces. *)
223211
module type Backend_common = sig
224-
include Backend_any_common
212+
include Buffer
225213

226214
type code [@@deriving sexp_of]
227215
type code_batch [@@deriving sexp_of]
@@ -250,7 +238,6 @@ end
250238
synchronization is provided by a component outside of backend implementations). *)
251239
module type Backend_device_common = sig
252240
include Device
253-
include Backend_any_common with type buffer_ptr := buffer_ptr
254241

255242
val sync : event -> unit
256243
(** Blocks till the event completes, if it's not done already.

arrayjit/lib/backends.ml

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,10 @@ module Add_device
233233
(Add_scheduler : functor
234234
(Impl : For_add_scheduler)
235235
-> With_scheduler with type buffer_ptr = Impl.buffer_ptr)
236-
(Backend : Lowered_no_device_backend) : Lowered_backend = struct
236+
(Backend : Lowered_no_device_backend)
237+
(Config : sig
238+
val config : config
239+
end) : Lowered_backend = struct
237240
include Backend
238241

239242
type code = { lowered : Low_level.optimized; proc : Backend.procedure } [@@deriving sexp_of]
@@ -252,7 +255,10 @@ module Add_device
252255
let procs = compile_batch ~names bindings lowereds in
253256
{ lowereds; procs }
254257

255-
include Add_scheduler (Backend)
258+
include Add_scheduler (struct
259+
include Backend
260+
include Config
261+
end)
256262

257263
let link context (code : code) ctx_arrays =
258264
let runner_label = get_name context.stream in
@@ -481,9 +487,12 @@ module Make_device_backend_from_lowered
481487
(Add_scheduler : functor
482488
(Impl : For_add_scheduler)
483489
-> With_scheduler with type buffer_ptr = Impl.buffer_ptr)
484-
(Backend_impl : Lowered_no_device_backend) =
490+
(Backend_impl : Lowered_no_device_backend)
491+
(Config : sig
492+
val config : config
493+
end) =
485494
struct
486-
module Lowered_device = Add_device (Add_scheduler) (Backend_impl)
495+
module Lowered_device = Add_device (Add_scheduler) (Backend_impl) (Config)
487496
module Backend_device = Raise_backend (Lowered_device)
488497
include Backend_device
489498
end
@@ -503,23 +512,31 @@ let finalize (type buffer_ptr dev runner event)
503512
&& not (Hashtbl.mem ctx.stream.device.cross_stream_candidates key)
504513
then mem_free ctx.stream data)))
505514

506-
let%track5_sexp fresh_backend ?backend_name () =
515+
let%track5_sexp fresh_backend ?backend_name ?(config = For_parallel_copying) () =
507516
Stdlib.Gc.full_major ();
508517
(* TODO: is running again needed to give time to weak arrays to become empty? *)
509518
Stdlib.Gc.full_major ();
510519
(* Note: we invoke functors from within fresh_backend to fully isolate backends from distinct
511520
calls to fresh_backend. *)
521+
let module Config = struct
522+
let config = config
523+
end in
512524
match
513525
Option.value_or_thunk backend_name ~default:(fun () ->
514526
Utils.get_global_arg ~arg_name:"backend" ~default:"cc")
515527
|> String.lowercase
516528
with
517-
| "cc" -> (module Make_device_backend_from_lowered (Schedulers.Multicore) (Cc_backend) : Backend)
529+
| "cc" ->
530+
(module Make_device_backend_from_lowered (Schedulers.Multicore) (Cc_backend) (Config)
531+
: Backend)
518532
| "gccjit" ->
519-
(module Make_device_backend_from_lowered (Schedulers.Multicore) (Gcc_backend_impl) : Backend)
520-
| "sync_cc" -> (module Make_device_backend_from_lowered (Schedulers.Sync) (Cc_backend) : Backend)
533+
(module Make_device_backend_from_lowered (Schedulers.Multicore) (Gcc_backend_impl) (Config)
534+
: Backend)
535+
| "sync_cc" ->
536+
(module Make_device_backend_from_lowered (Schedulers.Sync) (Cc_backend) (Config) : Backend)
521537
| "sync_gccjit" ->
522-
(module Make_device_backend_from_lowered (Schedulers.Sync) (Gcc_backend_impl) : Backend)
523-
| "cuda" -> (module Raise_backend ((Cuda_backend_impl.Fresh () : Lowered_backend)) : Backend)
524-
| "metal" -> (module Raise_backend ((Metal_backend_impl.Fresh () : Lowered_backend)) : Backend)
538+
(module Make_device_backend_from_lowered (Schedulers.Sync) (Gcc_backend_impl) (Config)
539+
: Backend)
540+
| "cuda" -> (module Raise_backend ((Cuda_backend_impl.Fresh (Config) : Lowered_backend)) : Backend)
541+
| "metal" -> (module Raise_backend ((Metal_backend_impl.Fresh (Config) : Lowered_backend)) : Backend)
525542
| backend -> invalid_arg [%string "Backends.fresh_backend: unknown backend %{backend}"]

arrayjit/lib/backends.mli

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ val finalize :
1919
2020
Note: this type will get simpler with modular explicits. *)
2121

22-
val fresh_backend : ?backend_name:string -> unit -> (module Ir.Backend_intf.Backend)
22+
val fresh_backend :
23+
?backend_name:string -> ?config:Ir.Backend_intf.config -> unit -> (module Ir.Backend_intf.Backend)
2324
(** Creates a new backend corresponding to [backend_name], or if omitted, selected via the global
2425
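[backend] setting. The optional [config] (default [For_parallel_copying]) is passed to the backend when it is created. It should be safe to reinitialize the tensor system before [fresh_backend].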
2526
*)

arrayjit/lib/cc_backend.ml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@ let compiler_command =
3636

3737
module Tn = Tnode
3838

39-
let is_initialized, initialize =
40-
let initialized = ref false in
41-
((fun () -> !initialized), fun _config -> initialized := true)
42-
4339
type library = { lib : (Dl.library[@sexp.opaque]); libname : string } [@@deriving sexp_of]
4440

4541
type procedure = {

arrayjit/lib/cuda_backend.ml

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,10 @@ end
7474
(* [initialized_devices] never forgets its entries. *)
7575
let initialized_devices = Hash_set.create (module Int)
7676

77-
module Fresh () = struct
77+
module Fresh (Config : sig
78+
val config : Ir.Backend_intf.config
79+
end) =
80+
struct
7881
include Backend_impl.Device (Device_stream) (Alloc_buffer)
7982

8083
let use_host_memory = None
@@ -83,17 +86,8 @@ module Fresh () = struct
8386
let will_wait_for context event = Cu.Delimited_event.wait context.stream.runner event
8487
let sync event = Cu.Delimited_event.synchronize event
8588
let all_work stream = Cu.Delimited_event.record stream.runner
86-
let global_config = ref For_parallel_copying
8789
let () = Cu.init ()
8890

89-
let is_initialized, initialize =
90-
let initialized = ref false in
91-
let init (config : config) : unit =
92-
initialized := true;
93-
global_config := config
94-
in
95-
((fun () -> !initialized), init)
96-
9791
let num_devices = Cu.Device.get_count
9892
let devices = ref @@ Array.create ~len:(num_devices ()) None
9993

@@ -162,7 +156,7 @@ module Fresh () = struct
162156
get_props
163157

164158
let suggested_num_streams device =
165-
match !global_config with
159+
match Config.config with
166160
| Only_devices_parallel -> 1
167161
| For_parallel_copying -> 1 + (cuda_properties device).async_engine_count
168162
| Most_parallel_streams -> (cuda_properties device).multiprocessor_count
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1-
module Fresh () = struct
1+
module Fresh (Config : sig
2+
val config : Ir.Backend_intf.config
3+
end) =
4+
struct
5+
let _ = ignore Config.config
6+
27
include Lowered_backend_missing
38
end

arrayjit/lib/cuda_backend_impl.mli

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
module Fresh : functor () -> Ir.Backend_impl.Lowered_backend
1+
module Fresh (_ : sig
2+
val config : Ir.Backend_intf.config
3+
end) : Ir.Backend_impl.Lowered_backend

arrayjit/lib/dune

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@
7272

7373
(library
7474
(name metal_backend)
75+
; TODO: Enable this backend.
76+
(enabled_if false)
7577
(optional)
7678
(modules metal_backend)
7779
(libraries base metal utils ir)

0 commit comments

Comments
 (0)