Skip to content

Commit e586db2

Browse files
committed
Host memory wrapper needs to know the size
1 parent dcb400f commit e586db2

File tree

4 files changed

+12
-8
lines changed

4 files changed

+12
-8
lines changed

arrayjit/lib/backend_impl.ml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ open Backend_intf
1515
module type No_device_buffer_and_copying = sig
1616
include Alloc_buffer with type stream := unit
1717

18-
val use_host_memory : (unit Ctypes.ptr -> buffer_ptr) option
18+
val use_host_memory : (size_in_bytes:int -> unit Ctypes.ptr -> buffer_ptr) option
1919

2020
val get_used_memory : unit -> int
2121
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
@@ -29,7 +29,7 @@ module No_device_buffer_and_copying () :
2929
No_device_buffer_and_copying with type buffer_ptr = unit Ctypes.ptr = struct
3030
type buffer_ptr = unit Ctypes.ptr
3131

32-
let use_host_memory = Some Fn.id
32+
let use_host_memory = Some (fun ~size_in_bytes:_ ptr -> ptr)
3333
let sexp_of_buffer_ptr = Ops.sexp_of_voidptr
3434

3535
include Buffer_types (struct
@@ -145,9 +145,9 @@ end
145145
module type Backend_impl_common = sig
146146
include Backend_intf.Buffer
147147

148-
val use_host_memory : (unit Ctypes.ptr -> buffer_ptr) option
148+
val use_host_memory : (size_in_bytes:int -> unit Ctypes.ptr -> buffer_ptr) option
149149
(** If not [None], the backend will read from and write to the host memory directly whenever
150-
reasonable.
150+
reasonable. [size_in_bytes] is the size of the memory allocated on the host.
151151
152152
[use_host_memory] can only be [Some] on unified memory devices, like CPU and Apple Metal. *)
153153
end

arrayjit/lib/backends.ml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -413,9 +413,11 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
413413
&& Tn.known_shared_cross_streams key && Tn.is_hosted_force key 44
414414
then
415415
Hashtbl.update_and_return device.cross_stream_candidates key ~f:(fun _ ->
416-
get_buffer_ptr @@ Ndarray.get_voidptr_not_managed
417-
@@ Option.value_exn ~here:[%here]
418-
@@ Lazy.force key.array)
416+
get_buffer_ptr
417+
~size_in_bytes:(Lazy.force key.size_in_bytes)
418+
@@ Ndarray.get_voidptr_not_managed
419+
@@ Option.value_exn ~here:[%here]
420+
@@ Lazy.force key.array)
419421
else Hashtbl.find_or_add device.cross_stream_candidates key ~default
420422
in
421423
if Hashtbl.mem device.cross_stream_candidates key then

arrayjit/lib/c_syntax.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ module C_syntax (B : sig
1616

1717
type buffer_ptr
1818

19-
val use_host_memory : (unit Ctypes.ptr -> buffer_ptr) option
19+
val use_host_memory : (size_in_bytes:int -> unit Ctypes.ptr -> buffer_ptr) option
2020
val logs_to_stdout : bool
2121
val main_kernel_prefix : string
2222
val kernel_prep_line : string

arrayjit/lib/dune

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@
7272

7373
(library
7474
(name metal_backend)
75+
; Temporary disabled until we have a working Metal backend
76+
(enabled_if false)
7577
(optional)
7678
(modules metal_backend)
7779
(libraries base metal utils ir)

0 commit comments

Comments
 (0)