Skip to content

Commit e0515ef

Browse files
committed
Backends: expose Types.stream from the signatures, implement work_for
1 parent 7d020ca commit e0515ef

File tree

3 files changed

+29
-12
lines changed

3 files changed

+29
-12
lines changed

arrayjit/lib/backend_types.ml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,15 +170,24 @@ end
170170
(** Parts shared by both assignments-level and lowered-level backend interfaces providing streams
171171
and devices. *)
172172
module type Backend_device_common = sig
173-
type stream [@@deriving sexp_of]
174-
175-
include Backend_any_common with type init_info := stream and type stream := stream
173+
type buffer_ptr
174+
type device
176175

177176
type event
178177
(** An event tracks if a stream finished computing past a particular point in its schedue. These
179178
values are used internally for scheduling across streams of the backend, and can be used for
180179
explicit scheduling. *)
181180

181+
type stream_state [@@deriving sexp_of]
182+
type runner [@@deriving sexp_of]
183+
type stream = (buffer_ptr, event, device, stream_state, runner) Types.stream [@@deriving sexp_of]
184+
185+
include
186+
Backend_any_common
187+
with type buffer_ptr := buffer_ptr
188+
and type init_info := stream
189+
and type stream := stream
190+
182191
val sync : event -> unit
183192
(** Blocks till the event completes, if it's not done already. *)
184193

@@ -192,8 +201,6 @@ module type Backend_device_common = sig
192201
called internally when necessary. But there is one exception, see {!device_to_device} when
193202
[into_merge_buffer=Streaming]. *)
194203

195-
type device
196-
197204
val get_used_memory : device -> int
198205
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
199206

arrayjit/lib/backends.ml

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,14 @@ struct
2424
type context = Backend.context
2525
type event = Backend.event
2626

27-
(* FIXME: *)
28-
let work_for _context _tn = failwith "NOT IMPLEMENTED YET"
27+
let work_for context tn =
28+
let stream = Backend.get_ctx_stream context in
29+
let default () = Some (Backend.all_work stream) in
30+
if not @@ Map.mem (Backend.ctx_arrays context) tn then None
31+
else
32+
Hashtbl.update_and_return stream.requested_work_for tn ~f:(function
33+
| None | Some None -> default ()
34+
| Some (Some _ as event) -> event)
2935

3036
let%diagn2_l_sexp from_host (ctx : Backend.context) tn =
3137
match (tn, Map.find (Backend.ctx_arrays ctx) tn) with
@@ -109,13 +115,13 @@ module Multicore_backend (Backend : Backend_types.No_device_backend) = struct
109115
}
110116
[@@deriving sexp_of]
111117

112-
type domain = unit Domain.t
118+
type runner = unit Domain.t
113119

114-
let sexp_of_domain (d : domain) = Sexp.Atom ("domain-" ^ Int.to_string (Domain.get_id d :> int))
120+
let sexp_of_runner (d : runner) = Sexp.Atom ("domain-" ^ Int.to_string (Domain.get_id d :> int))
115121

116122
type device = CPU [@@deriving sexp_of]
117123
type event = Not_implemented_yet [@@deriving sexp_of]
118-
type nonrec stream = (buffer_ptr, event, device, stream_state, domain) stream [@@deriving sexp_of]
124+
type nonrec stream = (buffer_ptr, event, device, stream_state, runner) stream [@@deriving sexp_of]
119125

120126
(** TODO: Blocks till the event completes, if it's not done already. *)
121127
let sync Not_implemented_yet = ()
@@ -359,7 +365,9 @@ module Sync_backend (Backend : Backend_types.No_device_backend) = struct
359365
let get_used_memory CPU = Backend.get_used_memory ()
360366
let next_stream = ref 0
361367

362-
type nonrec stream = (buffer_ptr, event, device, unit, unit) stream [@@deriving sexp_of]
368+
type stream_state = unit [@@deriving sexp_of]
369+
type runner = unit [@@deriving sexp_of]
370+
type nonrec stream = (buffer_ptr, event, device, stream_state, runner) stream [@@deriving sexp_of]
363371

364372
let new_stream CPU : stream =
365373
Int.incr next_stream;

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ type device = {
4242
}
4343
[@@deriving sexp_of]
4444

45-
type nonrec stream = (buffer_ptr, event, device, unit, Cu.Stream.t) stream [@@deriving sexp_of]
45+
type stream_state = unit [@@deriving sexp_of]
46+
type runner = Cu.Stream.t [@@deriving sexp_of]
47+
type nonrec stream = (buffer_ptr, event, device, stream_state, runner) stream [@@deriving sexp_of]
4648

4749
type context = {
4850
label : string;

0 commit comments

Comments
 (0)