Skip to content

Commit 65255f6

Browse files
committed
Propagate optimization context (with the computations table for Low_level optimizer)
1 parent aac0de9 commit 65255f6

16 files changed

+203
-97
lines changed

arrayjit/lib/backend_impl.ml

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,16 @@ module Device_types (Device_config : Device_config) = struct
8989

9090
type nonrec device = (buffer_ptr, dev, runner, event) device [@@deriving sexp_of]
9191
type nonrec stream = (buffer_ptr, dev, runner, event) stream [@@deriving sexp_of]
92-
type nonrec context = (buffer_ptr, stream) context [@@deriving sexp_of]
92+
type nonrec context = (buffer_ptr, stream, optimize_ctx) context [@@deriving sexp_of]
93+
end
94+
module Device_types_ll (Device_config : Device_config_common) = struct
95+
include Device_config
96+
type optimize_ctx = Low_level.optimize_ctx [@@deriving sexp_of]
97+
let empty_optimize_ctx = { Low_level.computations = Hashtbl.create (module Tnode) }
98+
99+
type nonrec device = (buffer_ptr, dev, runner, event) device [@@deriving sexp_of]
100+
type nonrec stream = (buffer_ptr, dev, runner, event) stream [@@deriving sexp_of]
101+
type nonrec context = (buffer_ptr, stream, Low_level.optimize_ctx) context [@@deriving sexp_of]
93102
end
94103

95104
let next_global_device_id : Utils.atomic_int = Atomic.make 0
@@ -133,12 +142,13 @@ struct
133142

134143
let get_name stream = [%string "%{name}:%{stream.device.ordinal#Int}:%{stream.stream_id#Int}"]
135144

136-
let make_context ?(ctx_arrays = Map.empty (module Tnode)) stream =
137-
{ stream; parent = None; ctx_arrays; finalized = Atomic.make false }
145+
let make_context ?(ctx_arrays = Map.empty (module Tnode)) ?(optimize_ctx = empty_optimize_ctx) stream =
146+
{ stream; parent = None; ctx_arrays; finalized = Atomic.make false; optimize_ctx }
138147

139-
let make_child ?ctx_arrays parent =
148+
let make_child ?ctx_arrays ?optimize_ctx parent =
140149
let ctx_arrays = Option.value ctx_arrays ~default:parent.ctx_arrays in
141-
{ stream = parent.stream; parent = Some parent; ctx_arrays; finalized = Atomic.make false }
150+
let optimize_ctx = Option.value optimize_ctx ~default:parent.optimize_ctx in
151+
{ stream = parent.stream; parent = Some parent; ctx_arrays; finalized = Atomic.make false; optimize_ctx }
142152
end
143153

144154
(** Parts shared by backend implementations. *)
@@ -213,24 +223,25 @@ module type No_buffer_retrieval_or_syncing = sig
213223
end
214224

215225
(** An intermediate stage for converting {!Lowered_no_device_backend} backends into
216-
{!Lowered_backend}. *)
226+
{!Lowered_backend}. It could potentially be used for assignments-level backends too. *)
217227
module type With_scheduler = sig
218228
include Backend_device_common
219229

220230
val schedule_task : stream -> Task.t -> unit
221231
end
222232

223233
(** Lowered-level backend interface: implementation-facing API for device-based (GPU, or CPU after
224-
adding a scheduler) backends. *)
234+
adding a scheduler) backends based on the {!Low_level} IR. *)
225235
module type Lowered_backend = sig
226-
include Backend_device_common
236+
include Backend_device_common with type optimize_ctx := Low_level.optimize_ctx
227237

228238
include
229239
No_buffer_retrieval_or_syncing
230240
with type buffer_ptr := buffer_ptr
231241
and type dev := dev
232242
and type runner := runner
233243
and type event := event
244+
and type optimize_ctx := Low_level.optimize_ctx
234245

235246
type code [@@deriving sexp_of]
236247
type code_batch [@@deriving sexp_of]

arrayjit/lib/backend_intf.ml

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ type 'context routine = {
6060
}
6161
[@@deriving sexp_of]
6262

63-
module type Device_config = sig
63+
module type Device_config_common = sig
6464
include Buffer
6565

6666
type dev [@@deriving sexp_of]
@@ -73,10 +73,19 @@ module type Device_config = sig
7373
(** An event tracks if a stream finished computing past a particular point in its schedule. These
7474
values are used internally for scheduling across streams of the backend, and can be used for
7575
explicit scheduling. *)
76-
7776
val name : string
7877
end
7978

79+
module type Device_config = sig
80+
include Device_config_common
81+
82+
type optimize_ctx [@@deriving sexp_of]
83+
(** The optimization context for compiling code, in particular {!Low_level.optimize_ctx} for
84+
low-level backends. *)
85+
86+
val empty_optimize_ctx : optimize_ctx
87+
end
88+
8089
type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
8190
dev : 'dev;
8291
ordinal : int;
@@ -171,13 +180,14 @@ type ('buffer_ptr, 'dev, 'runner, 'event) stream =
171180

172181
let equal_stream = equal_stream_ref
173182

174-
type ('buffer_ptr, 'stream) context = {
183+
type ('buffer_ptr, 'stream, 'optimize_ctx) context = {
175184
stream : 'stream;
176-
parent : ('buffer_ptr, 'stream) context option;
185+
parent : ('buffer_ptr, 'stream, 'optimize_ctx) context option;
177186
ctx_arrays : 'buffer_ptr ctx_arrays;
178187
(** This map contains arrays used in this context or an ancestor context (they might be unique
179188
but might also be cross-stream shared). *)
180189
finalized : Utils.atomic_bool;
190+
optimize_ctx : 'optimize_ctx;
181191
}
182192
[@@deriving sexp_of]
183193

@@ -186,7 +196,7 @@ module type Device_types = sig
186196

187197
type nonrec device = (buffer_ptr, dev, runner, event) device [@@deriving sexp_of]
188198
type nonrec stream = (buffer_ptr, dev, runner, event) stream [@@deriving sexp_of]
189-
type nonrec context = (buffer_ptr, stream) context [@@deriving sexp_of]
199+
type nonrec context = (buffer_ptr, stream, optimize_ctx) context [@@deriving sexp_of]
190200
end
191201

192202
module type Device = sig
@@ -196,10 +206,10 @@ module type Device = sig
196206
val make_device : dev -> ordinal:int -> device
197207
val make_stream : device -> runner -> stream
198208

199-
val make_context : ?ctx_arrays:ctx_arrays -> stream -> context
209+
val make_context : ?ctx_arrays:ctx_arrays -> ?optimize_ctx:optimize_ctx -> stream -> context
200210
(** Returns a context without a parent. *)
201211

202-
val make_child : ?ctx_arrays:ctx_arrays -> context -> context
212+
val make_child : ?ctx_arrays:ctx_arrays -> ?optimize_ctx:optimize_ctx -> context -> context
203213
(** Returns a context with the same {!field:Backend_intf.context.stream}, and
204214
{!field:Backend_intf.context.ctx_arrays} and {!field:Backend_intf.context.optimize_ctx} if omitted, as the given context's, which is also the
205215
{!field:Backend_intf.context.parent}. *)
@@ -213,13 +223,18 @@ module type Backend_common = sig
213223

214224
type code [@@deriving sexp_of]
215225
type code_batch [@@deriving sexp_of]
226+
type optimize_ctx [@@deriving sexp_of]
227+
228+
val empty_optimize_ctx : optimize_ctx
229+
val get_optimize_ctx : code -> optimize_ctx
230+
val get_optimize_ctx_batch : code_batch -> optimize_ctx
216231

217-
val compile : Low_level.optimize_ctx -> ?name:string -> Indexing.unit_bindings -> Assignments.comp -> code
232+
val compile : optimize_ctx -> ?name:string -> Indexing.unit_bindings -> Assignments.comp -> code
218233
(** [name] is used to derive names for compilation artifacts. If omitted, it's derived via
219234
{!Assignments.get_name_exn}. *)
220235

221236
val compile_batch :
222-
Low_level.optimize_ctx ->
237+
optimize_ctx ->
223238
?names:string array ->
224239
?occupancy:(name:string -> src_n:int -> bool) ->
225240
Indexing.unit_bindings ->
@@ -330,7 +345,9 @@ end
330345

331346
module type Backend = sig
332347
include Backend_common
333-
include Backend_device_common with type buffer_ptr := buffer_ptr
348+
349+
include
350+
Backend_device_common with type buffer_ptr := buffer_ptr and type optimize_ctx := optimize_ctx
334351

335352
val link : context -> code -> context routine
336353
(** Returns the routine for the code's procedure, in a new context derived from the given context.

arrayjit/lib/backends.ml

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,9 @@ let lower_batch_assignments optim_ctx ?names ?occupancy bindings asgns_l =
209209
let asgns = asgns_l.(src_n) in
210210
if occupancy ~name ~src_n then
211211
( Some name,
212-
Some (Assignments.lower optim_ctx ~unoptim_ll_source ~ll_source ~cd_source ~name bound asgns) )
212+
Some
213+
(Assignments.lower optim_ctx ~unoptim_ll_source ~ll_source ~cd_source ~name bound
214+
asgns) )
213215
else (None, None))
214216

215217
let%debug3_sexp verify_prior_context ~use_host_memory ~ctx_arrays ~from_prior_context : unit =
@@ -232,13 +234,23 @@ let%debug3_sexp from_prior_context_batch ~use_host_memory (comps : Assignments.c
232234
module Add_device
233235
(Add_scheduler : functor
234236
(Impl : For_add_scheduler)
235-
-> With_scheduler with type buffer_ptr = Impl.buffer_ptr)
237+
->
238+
With_scheduler
239+
with type buffer_ptr = Impl.buffer_ptr
240+
and type optimize_ctx = Low_level.optimize_ctx)
236241
(Backend : Lowered_no_device_backend)
237242
(Config : sig
238243
val config : config
239-
end) : Lowered_backend = struct
244+
end)
245+
(* : Lowered_backend *) =
246+
struct
240247
include Backend
241248

249+
include Add_scheduler (struct
250+
include Backend
251+
include Config
252+
end)
253+
242254
type code = { lowered : Low_level.optimized; proc : Backend.procedure } [@@deriving sexp_of]
243255

244256
type code_batch = {
@@ -255,11 +267,6 @@ module Add_device
255267
let procs = compile_batch ~names bindings lowereds in
256268
{ lowereds; procs }
257269

258-
include Add_scheduler (struct
259-
include Backend
260-
include Config
261-
end)
262-
263270
let link context (code : code) ctx_arrays =
264271
let runner_label = get_name context.stream in
265272
let merge_buffer = context.stream.merge_buffer in
@@ -330,8 +337,14 @@ module Add_device
330337
end
331338

332339
module Raise_backend (Device : Lowered_backend) : Backend = struct
333-
include Device
334-
include Add_buffer_retrieval_and_syncing (Device)
340+
module Device_with_optimize_ctx = struct
341+
include Device
342+
343+
type optimize_ctx = Low_level.optimize_ctx [@@deriving sexp_of]
344+
end
345+
346+
include Device_with_optimize_ctx
347+
include Add_buffer_retrieval_and_syncing (Device_with_optimize_ctx)
335348

336349
type nonrec code = {
337350
from_prior_context : Set.M(Tnode).t;
@@ -351,6 +364,15 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
351364
}
352365
[@@deriving sexp_of]
353366

367+
type nonrec optimize_ctx = Low_level.optimize_ctx
368+
369+
let empty_optimize_ctx = { Low_level.computations = Hashtbl.create (module Tnode) }
370+
let get_optimize_ctx (code : code) = code.lowered.optimize_ctx
371+
372+
let get_optimize_ctx_batch (code_batch : code_batch) =
373+
Array.find_map code_batch.lowereds ~f:(Option.map ~f:(fun l -> l.Low_level.optimize_ctx))
374+
|> Option.value ~default:empty_optimize_ctx
375+
354376
let%debug3_sexp compile optim_ctx ?name bindings (comp : Assignments.comp) : code =
355377
let (name : string), (lowered : Low_level.optimized) =
356378
lower_assignments optim_ctx ?name bindings comp.asgns
@@ -361,8 +383,8 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
361383
in
362384
{ from_prior_context; name; lowered; code; expected_merge_node = lowered.Low_level.merge_node }
363385

364-
let%debug3_sexp compile_batch optim_ctx ?names ?occupancy bindings (comps : Assignments.comp array) :
365-
code_batch =
386+
let%debug3_sexp compile_batch optim_ctx ?names ?occupancy bindings
387+
(comps : Assignments.comp array) : code_batch =
366388
let names, lowereds =
367389
lower_batch_assignments optim_ctx ?names ?occupancy bindings
368390
@@ Array.map comps ~f:(fun c -> c.asgns)
@@ -487,7 +509,10 @@ end
487509
module Make_device_backend_from_lowered
488510
(Add_scheduler : functor
489511
(Impl : For_add_scheduler)
490-
-> With_scheduler with type buffer_ptr = Impl.buffer_ptr)
512+
->
513+
With_scheduler
514+
with type buffer_ptr = Impl.buffer_ptr
515+
and type optimize_ctx = Low_level.optimize_ctx)
491516
(Backend_impl : Lowered_no_device_backend)
492517
(Config : sig
493518
val config : config
@@ -498,12 +523,13 @@ struct
498523
include Backend_device
499524
end
500525

501-
let finalize (type buffer_ptr dev runner event)
526+
let finalize (type buffer_ptr dev runner event optimize_ctx)
502527
(module Backend : Backend
503528
with type buffer_ptr = buffer_ptr
504529
and type dev = dev
505530
and type runner = runner
506-
and type event = event) (ctx : Backend.context) : unit =
531+
and type event = event
532+
and type optimize_ctx = optimize_ctx) (ctx : Backend.context) : unit =
507533
Option.iter Backend.free_buffer ~f:(fun mem_free ->
508534
if Atomic.compare_and_set ctx.finalized false true then (
509535
Backend.await ctx.stream;

arrayjit/lib/backends.mli

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ open Base
44
module Schedulers = Schedulers
55

66
val finalize :
7-
'buffer_ptr 'dev 'runner 'event.
7+
'buffer_ptr 'dev 'runner 'event 'optimize_ctx.
88
(module Ir.Backend_intf.Backend
99
with type buffer_ptr = 'buffer_ptr
1010
and type dev = 'dev
1111
and type event = 'event
12-
and type runner = 'runner) ->
13-
('buffer_ptr, ('buffer_ptr, 'dev, 'runner, 'event) Ir.Backend_intf.stream) Ir.Backend_intf.context ->
12+
and type runner = 'runner
13+
and type optimize_ctx = 'optimize_ctx) ->
14+
('buffer_ptr, ('buffer_ptr, 'dev, 'runner, 'event) Ir.Backend_intf.stream, 'optimize_ctx) Ir.Backend_intf.context ->
1415
unit
1516
(** Frees the arrays that are specific to the context -- not contained in the parent context. Note:
1617
use [finalize] to optimize memory; it is not obligatory because all arrays are freed when their

arrayjit/lib/low_level.mli

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ type optimize_ctx = {
116116
some assignment, the node cannot be virtual. Currently, we only allow for-loop symbols in
117117
assignment indices of virtual nodes. *)
118118
}
119+
[@@deriving sexp_of]
119120

120121
type optimized = {
121122
traced_store : traced_store;

arrayjit/lib/lowered_backend_missing.ml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ struct
1414
type dev
1515
type runner
1616
type event
17+
type optimize_ctx = Low_level.optimize_ctx [@@deriving sexp_of]
1718

19+
let empty_optimize_ctx = { Low_level.computations = Hashtbl.create (module Tnode) }
1820
let use_host_memory = None
1921

2022
let sexp_of_dev _dev =
@@ -38,7 +40,7 @@ struct
3840
let sexp_of_stream _stream =
3941
failwith @@ "Backend " ^ Config.name ^ " missing -- install the corresponding library"
4042

41-
type nonrec context = (buffer_ptr, stream) Backend_intf.context
43+
type nonrec context = (buffer_ptr, stream, Low_level.optimize_ctx) Backend_intf.context
4244

4345
let sexp_of_context _context =
4446
failwith @@ "Backend " ^ Config.name ^ " missing -- install the corresponding library"
@@ -57,10 +59,10 @@ struct
5759
let make_stream _device =
5860
failwith @@ "Backend " ^ Config.name ^ " missing -- install the corresponding library"
5961

60-
let make_context ?ctx_arrays:_ _stream =
62+
let make_context ?ctx_arrays:_ ?optimize_ctx:_ _stream =
6163
failwith @@ "Backend " ^ Config.name ^ " missing -- install the corresponding library"
6264

63-
let make_child ?ctx_arrays:_ _context =
65+
let make_child ?ctx_arrays:_ ?optimize_ctx:_ _context =
6466
failwith @@ "Backend " ^ Config.name ^ " missing -- install the corresponding library"
6567

6668
let get_name _stream =

arrayjit/lib/metal_backend.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ module Device_config = struct
4848
let name = "metal"
4949
end
5050

51-
module Device_stream = Backend_impl.Device_types (Device_config)
51+
module Device_stream = Backend_impl.Device_types_ll (Device_config)
5252

5353
(* Bring types into scope *)
5454
open Device_config

arrayjit/lib/schedulers.ml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ let _get_local_debug_runtime = Utils.get_local_debug_runtime
99
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
1010

1111
module Multicore (Backend : For_add_scheduler) :
12-
With_scheduler with type buffer_ptr = Backend.buffer_ptr = struct
12+
With_scheduler
13+
with type buffer_ptr = Backend.buffer_ptr
14+
and type optimize_ctx = Ir.Low_level.optimize_ctx = struct
1315
include Backend
1416
module Domain = Domain [@warning "-3"]
1517
(* Currently, Backend.config is not used. *)
@@ -54,7 +56,7 @@ module Multicore (Backend : For_add_scheduler) :
5456
let name = "multicore_" ^ Backend.name
5557
end
5658

57-
module Device_types = Device_types (Device_config)
59+
module Device_types = Device_types_ll (Device_config)
5860
include Device (Device_types) (Alloc_buffer_ignore_stream (Device_types) (Backend))
5961
open Device_config
6062

@@ -250,7 +252,7 @@ module Sync (Backend : For_add_scheduler) = struct
250252
let name = "sync_" ^ Backend.name
251253
end
252254

253-
module Device_types = Device_types (Device_config)
255+
module Device_types = Device_types_ll (Device_config)
254256
include Device (Device_types) (Alloc_buffer_ignore_stream (Device_types) (Backend))
255257
open Device_config
256258

0 commit comments

Comments
 (0)