Skip to content

Commit 042e9c9

Browse files
committed
In progress: factor out the context record type, remove redundant accessors
Broken: CUDA_ERROR_INVALID_HANDLE for moons_demo_parallel_run, but moons_demo_parallel succeeds without crashing.
1 parent d9c6d88 commit 042e9c9

19 files changed

+420
-333
lines changed

arrayjit/lib/backend_impl.ml

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ module Device_types (Device_config : Device_config) = struct
7070
type nonrec stream =
7171
(Device_config.buffer_ptr, Device_config.dev, Device_config.runner, Device_config.event) stream
7272
[@@deriving sexp_of]
73+
74+
type nonrec context = (buffer_ptr, stream) context [@@deriving sexp_of]
7375
end
7476

7577
module Device
@@ -104,17 +106,20 @@ struct
104106
}
105107

106108
let get_name stream = [%string "%{name}:%{stream.device.ordinal#Int}:%{stream.stream_id#Int}"]
109+
110+
let make_context ?(ctx_arrays = Map.empty (module Tnode)) stream =
111+
{ stream; parent = None; ctx_arrays; finalized = Atomic.make false }
112+
113+
let make_child ?ctx_arrays parent =
114+
let ctx_arrays = Option.value ctx_arrays ~default:parent.ctx_arrays in
115+
{ stream = parent.stream; parent = Some parent; ctx_arrays; finalized = Atomic.make false }
107116
end
108117

109118
(** Parts shared by backend implementations excluding what's already in {!Backend_any_common},
110119
except for {!Buffer} which is duplicated for technical reasons. *)
111120
module type Backend_impl_common = sig
112-
type context [@@deriving sexp_of]
113-
114121
include Buffer
115122

116-
val ctx_arrays : context -> ctx_arrays
117-
118123
val is_in_context : Low_level.traced_array -> bool
119124
(** If true, the node is required to be in the contexts linked with code that uses it.
120125
@@ -124,19 +129,25 @@ end
124129

125130
(** An intermediate interface for stream-agnostic (typically CPU) backend implementations. *)
126131
module type No_device_backend = sig
127-
include Backend_common with type init_info := string and type stream := unit
128-
include Backend_impl_common with type context := context and type buffer_ptr := buffer_ptr
132+
include Backend_common
133+
include Backend_impl_common with type buffer_ptr := buffer_ptr
129134

130135
val name : string
131136

132-
val link : merge_buffer:(buffer_ptr * Tnode.t) option ref -> context -> code -> context routine
137+
val link :
138+
merge_buffer:(buffer_ptr * Tnode.t) option ref ->
139+
runner_label:string ->
140+
ctx_arrays ->
141+
code ->
142+
ctx_arrays routine
133143
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)
134144

135145
val link_batch :
136146
merge_buffer:(buffer_ptr * Tnode.t) option ref ->
137-
context ->
147+
runner_label:string ->
148+
ctx_arrays ->
138149
code_batch ->
139-
context * context routine option array
150+
ctx_arrays * ctx_arrays routine option array
140151
(** Returns the routines for the procedures included in the code batch. The returned context is
141152
downstream of all the returned routines (in particular, the routines' contexts are not
142153
independent). *)
@@ -191,13 +202,7 @@ end
191202
(** Lowered-level stream agnostic backend interface: implementation-facing API for CPU backends. *)
192203
module type Lowered_no_device_backend = sig
193204
include Backend_impl_common
194-
195-
include
196-
Backend_any_common
197-
with type context := context
198-
and type stream := unit
199-
and type init_info := string
200-
and type buffer_ptr := buffer_ptr
205+
include Backend_any_common with type buffer_ptr := buffer_ptr
201206

202207
val name : string
203208

@@ -219,16 +224,18 @@ module type Lowered_no_device_backend = sig
219224

220225
val link_compiled :
221226
merge_buffer:(buffer_ptr * Tnode.t) option ref ->
222-
context ->
227+
runner_label:string ->
228+
ctx_arrays ->
223229
procedure ->
224-
context * Indexing.lowered_bindings * Task.t * string
230+
ctx_arrays * Indexing.lowered_bindings * Task.t * string
231+
(** [runner_label] will be [get_name stream] of the stream from which the [ctx_arrays] come from. *)
225232

226233
include No_device_buffer_and_copying with type buffer_ptr := buffer_ptr
227234
end
228235

229236
module type No_buffer_retrieval_or_syncing = sig
230237
include Backend_impl_common
231-
include Backend_device_common with type context := context and type buffer_ptr := buffer_ptr
238+
include Backend_device_common with type buffer_ptr := buffer_ptr
232239

233240
val from_host : dst_ptr:buffer_ptr -> dst:context -> Ndarray.t -> unit
234241
(** Like {!Backend.from_host}, but without synchronization and buffer retrieval. *)

arrayjit/lib/backend_intf.ml

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
open Base
66

77
type 'buffer_ptr buffer = { ptr : 'buffer_ptr; size_in_bytes : int } [@@deriving sexp_of]
8+
type 'buffer_ptr ctx_arrays = 'buffer_ptr Map.M(Tnode).t [@@deriving sexp_of]
89

910
module Buffer_types (Buffer_ptr : sig
1011
type buffer_ptr [@@deriving sexp_of]
1112
end) =
1213
struct
1314
type nonrec buffer = Buffer_ptr.buffer_ptr buffer [@@deriving sexp_of]
14-
type ctx_arrays = Buffer_ptr.buffer_ptr Map.M(Tnode).t [@@deriving sexp_of]
15+
type nonrec ctx_arrays = Buffer_ptr.buffer_ptr ctx_arrays [@@deriving sexp_of]
1516
end
1617

1718
module type Buffer = sig
@@ -105,11 +106,22 @@ type ('buffer_ptr, 'dev, 'runner, 'event) stream = {
105106
}
106107
[@@deriving sexp_of]
107108

109+
type ('buffer_ptr, 'stream) context = {
110+
stream : 'stream;
111+
parent : ('buffer_ptr, 'stream) context option;
112+
ctx_arrays : 'buffer_ptr ctx_arrays;
113+
(** This map contains arrays used in this context or an ancestor context (they might be unique
114+
but might also be cross-stream shared. *)
115+
finalized : Utils.atomic_bool;
116+
}
117+
[@@deriving sexp_of]
118+
108119
module type Device_types = sig
109120
include Device_config
110121

111122
type nonrec device = (buffer_ptr, dev, event) device [@@deriving sexp_of]
112123
type nonrec stream = (buffer_ptr, dev, runner, event) stream [@@deriving sexp_of]
124+
type nonrec context = (buffer_ptr, stream) context [@@deriving sexp_of]
113125
end
114126

115127
module type Device = sig
@@ -118,32 +130,28 @@ module type Device = sig
118130

119131
val make_device : dev -> ordinal:int -> device
120132
val make_stream : device -> runner -> stream_id:int -> stream
133+
134+
val make_context : ?ctx_arrays:ctx_arrays -> stream -> context
135+
(** Returns a context without a parent. *)
136+
137+
val make_child : ?ctx_arrays:ctx_arrays -> context -> context
138+
(** Returns a context with the same {!field-stream}, and {!field-ctx_arrays} if omitted, as the
139+
given context's, which is also the {!field-parent}. *)
140+
121141
val get_name : stream -> string
122142
end
123143

124144
(** Parts shared by both assignments-level and lowered-level backend interfaces. *)
125145
module type Backend_any_common = sig
126146
include Buffer
127147

128-
type context [@@deriving sexp_of]
129-
type stream
130-
131-
type init_info
132-
(** For backends derived via {!No_device_backend}, this is usually the backend name concatenated
133-
with the device or stream number. For {!Backend}, [init_info = stream]. *)
134-
135148
val initialize : config -> unit
136149
(** Initializes a backend before first use. Typically does nothing if the backend is already
137150
initialized, but some backends can do some safe cleanups. *)
138151

139152
val is_initialized : unit -> bool
140153
(** Returns false if there was no previous {!initialize} call. If it returns false, one must call
141154
{!initialize} before using the backend. *)
142-
143-
val init : init_info -> context
144-
145-
val finalize : context -> unit
146-
(** Finalizes (just) the context. *)
147155
end
148156

149157
(** Parts shared by assignments-level backend interfaces. *)
@@ -175,12 +183,7 @@ end
175183
and devices. *)
176184
module type Backend_device_common = sig
177185
include Device
178-
179-
include
180-
Backend_any_common
181-
with type buffer_ptr := buffer_ptr
182-
and type init_info := stream
183-
and type stream := stream
186+
include Backend_any_common with type buffer_ptr := buffer_ptr
184187

185188
val sync : event -> unit
186189
(** Blocks till the event completes, if it's not done already. *)
@@ -216,7 +219,6 @@ module type Backend_device_common = sig
216219
{!No_device_backend.initialize}. *)
217220

218221
val new_stream : device -> stream
219-
val get_ctx_stream : context -> stream
220222
end
221223

222224
module type With_buffer_retrieval_and_syncing = sig
@@ -262,13 +264,7 @@ end
262264

263265
module type Backend = sig
264266
include Backend_device_common
265-
266-
include
267-
Backend_common
268-
with type buffer_ptr := buffer_ptr
269-
and type context := context
270-
and type init_info := stream
271-
and type stream := stream
267+
include Backend_common with type buffer_ptr := buffer_ptr
272268

273269
val link : context -> code -> context routine
274270
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)

0 commit comments

Comments
 (0)