Skip to content

Commit 5cafbd8

Browse files
committed
Refactor the arrayjit package into multiple libraries
so each optional backend can be an optional library. Also add scaffolding for the metal backend. Refactor backend integration to use the new IR library. Update dependencies in `arrayjit.opam` and `dune-project` to include `metal`. Adjust references throughout the codebase to align with the new structure.
1 parent a65dd89 commit 5cafbd8

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

61 files changed

+582
-386
lines changed

arrayjit.opam

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,16 +32,19 @@ depends: [
3232
"ppx_string"
3333
"ppx_variants_conv"
3434
"ppx_expect"
35+
"metal"
3536
"ppx_minidebug" {>= "2.2.0"}
3637
"odoc" {with-doc}
3738
]
3839
depopts: [
3940
"cudajit" {>= "0.7.0"}
4041
"gccjit" {>= "0.3.2"}
42+
"metal"
4143
]
4244
conflicts: [
4345
"cudajit" {< "0.7.0"}
4446
"gccjit" {< "0.3.2"}
47+
"metal" {< "0.12"}
4548
]
4649
build: [
4750
["dune" "subst"] {dev}

arrayjit/lib/backends.ml

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
open Base
2+
open Ir
23
module Tn = Tnode
4+
module Schedulers = Schedulers
35
open Backend_intf
46
open Backend_impl
57

@@ -61,12 +63,13 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
6163
| None, _ -> ()
6264
| Some `Host, Assignments.(Node tn | Merge_buffer tn) ->
6365
Hashtbl.update s.device.host_reading_streams tn ~f
64-
| Some (`Src src), (Node tn | Merge_buffer tn) -> Hashtbl.update src.reader_streams tn ~f);
66+
| Some (`Src src), (Assignments.Node tn | Assignments.Merge_buffer tn) ->
67+
Hashtbl.update src.reader_streams tn ~f);
6568
(* Wait for writing to finish before reading. *)
6669
(match (from, tn) with
67-
| _, Merge_buffer _ | Some `Host, _ -> ()
68-
| _, Node tn ->
69-
Tn.prepare_read
70+
| _, Assignments.Merge_buffer _ | Some `Host, _ -> ()
71+
| _, Assignments.Node tn ->
72+
Tnode.prepare_read
7073
~is_done:(fun () -> Backend.is_done e)
7174
~sync:(fun () -> Backend.sync e)
7275
~transfer:(fun () ->
@@ -513,9 +516,10 @@ let%track5_sexp fresh_backend ?backend_name () =
513516
with
514517
| "cc" -> (module Make_device_backend_from_lowered (Schedulers.Multicore) (Cc_backend) : Backend)
515518
| "gccjit" ->
516-
(module Make_device_backend_from_lowered (Schedulers.Multicore) (Gcc_backend) : Backend)
519+
(module Make_device_backend_from_lowered (Schedulers.Multicore) (Gcc_backend_impl) : Backend)
517520
| "sync_cc" -> (module Make_device_backend_from_lowered (Schedulers.Sync) (Cc_backend) : Backend)
518521
| "sync_gccjit" ->
519-
(module Make_device_backend_from_lowered (Schedulers.Sync) (Gcc_backend) : Backend)
520-
| "cuda" -> (module Raise_backend ((Cuda_backend.Fresh () : Lowered_backend)) : Backend)
522+
(module Make_device_backend_from_lowered (Schedulers.Sync) (Gcc_backend_impl) : Backend)
523+
| "cuda" -> (module Raise_backend ((Cuda_backend_impl.Fresh () : Lowered_backend)) : Backend)
524+
| "metal" -> (module Raise_backend ((Metal_backend_impl.Fresh () : Lowered_backend)) : Backend)
521525
| backend -> invalid_arg [%string "Backends.fresh_backend: unknown backend %{backend}"]

arrayjit/lib/backends.mli

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,22 +2,24 @@
22

33
open Base
44

5+
module Schedulers = Schedulers
6+
57
val finalize :
68
'buffer_ptr 'dev 'runner 'event.
7-
(module Backend_intf.Backend
9+
(module Ir.Backend_intf.Backend
810
with type buffer_ptr = 'buffer_ptr
911
and type dev = 'dev
1012
and type event = 'event
1113
and type runner = 'runner) ->
12-
('buffer_ptr, ('buffer_ptr, 'dev, 'runner, 'event) Backend_intf.stream) Backend_intf.context ->
14+
('buffer_ptr, ('buffer_ptr, 'dev, 'runner, 'event) Ir.Backend_intf.stream) Ir.Backend_intf.context ->
1315
unit
1416
(** Frees the arrays that are specific to the context -- not contained in the parent context. Note:
1517
use [finalize] to optimize memory, it is not obligatory because all arrays are freed when their
1618
[buffer_ptr]s are garbage-collected.
1719
1820
Note: this type will get simpler with modular explicits. *)
1921

20-
val fresh_backend : ?backend_name:string -> unit -> (module Backend_intf.Backend)
22+
val fresh_backend : ?backend_name:string -> unit -> (module Ir.Backend_intf.Backend)
2123
(** Creates a new backend corresponding to [backend_name], or if omitted, selected via the global
2224
[backend] setting. It should be safe to reinitialize the tensor system before [fresh_backend].
2325
*)

arrayjit/lib/cc_backend.ml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
open Base
22
module Lazy = Utils.Lazy
3+
open Ir
34

45
let _get_local_debug_runtime = Utils.get_local_debug_runtime
56

arrayjit/lib/cc_backend.mli

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
include Backend_impl.Lowered_no_device_backend
1+
include Ir.Backend_impl.Lowered_no_device_backend
File renamed without changes.

arrayjit/lib/cuda_backend.mli

Lines changed: 0 additions & 1 deletion
This file was deleted.
Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
include Cuda_backend
File renamed without changes.

arrayjit/lib/cuda_backend_impl.mli

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
module Fresh : functor () -> Ir.Backend_impl.Lowered_backend

0 commit comments

Comments
 (0)