Commit b5d6104

Get rid of Postponed
In the future, device-config-specific compilation will be handled by laziness and caching.
1 parent 93b427d commit b5d6104

9 files changed: +81 -111 lines changed


arrayjit/lib/anatomy_of_a_backend.md

Lines changed: 4 additions & 4 deletions
@@ -46,11 +46,11 @@ The modules and files of `arrayjit` can loosely be divided into three parts.
 - `reinitialize` a backend,
 - `finalize` a context (freeing all of its arrays that don't come from its parent context).
 
-### Shared (relocatable) compilation, batch compilation
+### Batch compilation; in the future: lazy and cached compilation artifacts
 
-Shared (relocatable) compilation, with `~shared:true`, improves compilation efficiency, because code can be compiled once for use on multiple devices (in multiple contexts). It also improves debugging convenience, by generating fewer debugging artifacts. A potential downside is slightly less efficient computations.
+Batched compilation produces fewer debugging artifacts. The compilation might also be slightly more efficient since the compiler needs to be invoked fewer times. Batched compilation and linking process _many routines for one device/stream_ at once.
 
-Batched compilation has similar benefits, especially in producing fewer debugging artifacts. The compilation might also be slightly more efficient since the compiler needs to be invoked fewer times. While `~shared:true` compiles _one routine for many devices_, batched compilation and linking process _many routines for one device_ at once.
+In the future, when we introduce program search, `compile` functions will return compilation artifact objects. They will manage compilation lazily, caching compilation keyed by (a configuration of) device.
 
 ## Tensor nodes, arrays, memory properties
 
@@ -112,7 +112,7 @@ Contexts track (or store) the on-device arrays corresponding to tensor nodes. Co
 
 ## Typical details of a backend implementation
 
-During the compilation process, the old context cannot be available if the backend supports shared compilation. A backend may for simplicity not suport shared compilation, i.e. ignore `~shared:true` and postpone compilation to the linking phase. Currently, the CUDA backend ignores `~shared:false` and always generates context-and-device-independent kernels, that refer to context (i.e. global) arrays via parameters.
+During the compilation process, the old context cannot be available when `compile` is handled. Currently, all backends generate context-and-device-independent kernels, that refer to context arrays via parameters.
 
 We use keys of the `Low_level.traced_store` containers assuming that they are precisely the tensor nodes used in the compiled code -- and the `Virtual` nodes are the ones optimized-away. The context can contain nodes from the parent context corresponding to tensors only needed by parent or ancestor context's computations. The `get_ident` function (e.g. provided by `C_syntax`) returns a human-readable identifier that's un-ambiguous in the context of the compiled code (shared within `compile_batch`).
 
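The lazy-and-cached scheme the new doc text anticipates ("compilation artifact objects... caching compilation keyed by (a configuration of) device") could look roughly like this. This is a minimal sketch under assumptions: `Artifact`, `device_config`, `compile_now`, and `get` are hypothetical names, not the arrayjit API.

```ocaml
(* Sketch: a compilation artifact that compiles lazily and caches the
   result per device configuration. All names are illustrative. *)
module Artifact = struct
  type device_config = string (* stand-in for a real device descriptor *)

  type 'proc t = {
    compile_now : device_config -> 'proc; (* the expensive step *)
    cache : (device_config, 'proc Lazy.t) Hashtbl.t;
  }

  let create compile_now = { compile_now; cache = Hashtbl.create 4 }

  (* [get] forces compilation at most once per device configuration. *)
  let get t config =
    match Hashtbl.find_opt t.cache config with
    | Some l -> Lazy.force l
    | None ->
        let l = lazy (t.compile_now config) in
        Hashtbl.add t.cache config l;
        Lazy.force l
end

(* Usage: two lookups for the same config trigger one compilation. *)
let calls = ref 0

let art =
  Artifact.create (fun cfg ->
      incr calls;
      "kernel-for-" ^ cfg)

let a = Artifact.get art "sm_80"
let b = Artifact.get art "sm_80" (* cached: no recompilation *)
let c = Artifact.get art "sm_90"
```

With this shape, `link` would ask the artifact for the procedure matching the target device's configuration, paying the compile cost only on first use.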

arrayjit/lib/backend_impl.ml

Lines changed: 1 addition & 2 deletions
@@ -233,10 +233,9 @@ module type Lowered_backend = sig
   type code [@@deriving sexp_of]
   type code_batch [@@deriving sexp_of]
 
-  val compile : ?shared:bool -> name:string -> Indexing.unit_bindings -> Low_level.optimized -> code
+  val compile : name:string -> Indexing.unit_bindings -> Low_level.optimized -> code
 
   val compile_batch :
-    ?shared:bool ->
     names:string option array ->
     Indexing.unit_bindings ->
     Low_level.optimized option array ->

arrayjit/lib/backend_intf.ml

Lines changed: 8 additions & 9 deletions
@@ -220,22 +220,21 @@ module type Backend_common = sig
   type code [@@deriving sexp_of]
   type code_batch [@@deriving sexp_of]
 
-  val compile : ?shared:bool -> ?name:string -> Indexing.unit_bindings -> Assignments.comp -> code
-  (** If [~shared:true] (default [false]), the backend should prefer to do more compile work in a
-      device-and-stream-agnostic way. If [~shared:false], the backend can opt to postpone compiling
-      altogether until [link] is called, to benefit from more optimizations. *)
+  val compile : ?name:string -> Indexing.unit_bindings -> Assignments.comp -> code
+  (** [name] is used to derive names for compilation artifacts. If omitted, it's derived via
+      {!Assignments.get_name_exn}. *)
 
   val compile_batch :
-    ?shared:bool ->
     ?names:string array ->
     ?occupancy:(name:string -> src_n:int -> bool) ->
     Indexing.unit_bindings ->
     Assignments.comp array ->
     code_batch
-  (** Unlike the [~shared] parameter, [compile_batch] vs. [compile] is mostly about improving the
-      compile time and debugging convenience by generating fewer files -- ideally does not affect
-      execution, but there can be backend-specific differences. Only array entries for which
-      [occupancy] returns true are included. *)
+  (** [compile_batch] vs. [compile] is mostly about improving the compile time and debugging
+      convenience by generating fewer files -- ideally does not affect execution, but there can be
+      backend-specific differences. Only array entries for which [occupancy] returns true are
+      included. [names] are used to derive names for compilation artifacts. If omitted, they're
+      derived via {!Assignments.get_name_exn}. *)
 end
 
 (** Parts shared by both assignments-level and lowered-level backend interfaces providing streams
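The `occupancy` semantics documented above ("Only array entries for which [occupancy] returns true are included") can be illustrated with a self-contained stub; `compile_batch_stub` and its string "programs" are hypothetical stand-ins, not the real `compile_batch`.

```ocaml
(* Stub mirroring the [occupancy] filtering of [compile_batch]: entries
   failing the predicate are skipped and left as [None]. *)
let compile_batch_stub ~names ~occupancy (comps : string array) :
    string option array =
  Array.mapi
    (fun src_n comp ->
      if occupancy ~name:names.(src_n) ~src_n then Some ("compiled:" ^ comp)
      else None)
    comps

(* Usage: skip the entry at index 1, compile the rest. *)
let results =
  compile_batch_stub
    ~names:[| "fwd"; "bprop"; "sgd" |]
    ~occupancy:(fun ~name:_ ~src_n -> src_n <> 1)
    [| "f"; "b"; "s" |]
```

The predicate receives both the routine's `name` and its source index `src_n`, matching the `?occupancy:(name:string -> src_n:int -> bool)` shape in the signature.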

arrayjit/lib/backends.ml

Lines changed: 18 additions & 47 deletions
@@ -209,51 +209,28 @@ module Add_device
     (Backend : Lowered_no_device_backend) : Lowered_backend = struct
   include Backend
 
-  type code =
-    | Postponed of {
-        lowered : Low_level.optimized;
-        bindings : Indexing.unit_bindings;
-        name : string;
-      }
-    | Compiled of { lowered : Low_level.optimized; proc : Backend.procedure }
-  [@@deriving sexp_of]
+  type code = { lowered : Low_level.optimized; proc : Backend.procedure } [@@deriving sexp_of]
 
-  type code_batch =
-    | Postponed of {
-        lowereds : Low_level.optimized option array;
-        bindings : Indexing.unit_bindings;
-        names : string option array;
-      }
-    | Compiled of {
-        lowereds : Low_level.optimized option array;
-        procs : Backend.procedure option array;
-      }
+  type code_batch = {
+    lowereds : Low_level.optimized option array;
+    procs : Backend.procedure option array;
+  }
   [@@deriving sexp_of]
 
-  let compile ?(shared = false) ~name bindings lowered : code =
-    if shared then
-      let proc = compile ~name ~opt_ctx_arrays:None bindings lowered in
-      Compiled { lowered; proc }
-    else Postponed { lowered; bindings; name }
+  let compile ~name bindings lowered : code =
+    let proc = compile ~name ~opt_ctx_arrays:None bindings lowered in
+    { lowered; proc }
 
-  let compile_batch ?(shared = false) ~names bindings lowereds : code_batch =
-    if shared then
-      let procs = compile_batch ~names ~opt_ctx_arrays:None bindings lowereds in
-      Compiled { lowereds; procs }
-    else Postponed { lowereds; bindings; names }
+  let compile_batch ~names bindings lowereds : code_batch =
+    let procs = compile_batch ~names ~opt_ctx_arrays:None bindings lowereds in
+    { lowereds; procs }
 
   include Add_scheduler (Backend)
 
   let link context (code : code) ctx_arrays =
     let runner_label = get_name context.stream in
     let merge_buffer = context.stream.merge_buffer in
-    let bindings, to_schedule =
-      match code with
-      | Postponed { lowered; bindings; name } ->
-          let proc = Backend.compile ~name ~opt_ctx_arrays:(Some ctx_arrays) bindings lowered in
-          link_compiled ~merge_buffer ~runner_label ctx_arrays proc
-      | Compiled { proc; _ } -> link_compiled ~merge_buffer ~runner_label ctx_arrays proc
-    in
+    let bindings, to_schedule = link_compiled ~merge_buffer ~runner_label ctx_arrays code.proc in
     let schedule =
       Task.enschedule ~schedule_task ~get_stream_name:get_name context.stream to_schedule
     in
@@ -262,14 +239,8 @@ module Add_device
   let link_batch context (code_batch : code_batch) ctx_arrays =
     let runner_label = get_name context.stream in
     let merge_buffer = context.stream.merge_buffer in
-    let procs =
-      match code_batch with
-      | Postponed { lowereds; bindings; names } ->
-          Backend.compile_batch ~names ~opt_ctx_arrays:(Some ctx_arrays) bindings lowereds
-      | Compiled { procs; _ } -> procs
-    in
     let bindings, schedules =
-      Array.fold_mapi procs ~init:None ~f:(fun i bindings -> function
+      Array.fold_mapi code_batch.procs ~init:None ~f:(fun i bindings -> function
         | Some proc ->
             let ctx_arrays = Option.value_exn ctx_arrays.(i) in
             let bindings', to_schedule =
@@ -348,21 +319,21 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
   }
   [@@deriving sexp_of]
 
-  let%debug3_sexp compile ?shared ?name bindings (comp : Assignments.comp) : code =
+  let%debug3_sexp compile ?name bindings (comp : Assignments.comp) : code =
     let name, lowered = lower_assignments ?name bindings comp.Assignments.asgns in
-    let code = compile ?shared ~name bindings lowered in
+    let code = compile ~name bindings lowered in
     let from_prior_context =
       Set.diff (Assignments.context_nodes ~use_host_memory comp.asgns) comp.embedded_nodes
     in
     { from_prior_context; name; lowered; code; expected_merge_node = lowered.Low_level.merge_node }
 
-  let%debug3_sexp compile_batch ?shared ?names ?occupancy bindings (comps : Assignments.comp array)
-      : code_batch =
+  let%debug3_sexp compile_batch ?names ?occupancy bindings (comps : Assignments.comp array) :
+      code_batch =
     let names, lowereds =
       lower_batch_assignments ?names ?occupancy bindings
       @@ Array.map comps ~f:(fun c -> c.Assignments.asgns)
     in
-    let code_batch = compile_batch ?shared ~names bindings lowereds in
+    let code_batch = compile_batch ~names bindings lowereds in
     let from_prior_context =
       from_prior_context_batch ~use_host_memory
       @@ Array.mapi lowereds ~f:(fun i -> Option.map ~f:(fun _ -> comps.(i)))

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 2 additions & 2 deletions
@@ -337,7 +337,7 @@ struct
     | _ -> ("(" ^ typ_of_prec to_ ^ ")(", ")")
 end
 
-let compile ?shared:_ ~name bindings ({ Low_level.traced_store; _ } as lowered) =
+let compile ~name bindings ({ Low_level.traced_store; _ } as lowered) =
   (* TODO: The following link seems to claim it's better to expand into loops than use memset.
      https://stackoverflow.com/questions/23712558/how-do-i-best-initialize-a-local-memory-array-to-0 *)
   let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
@@ -353,7 +353,7 @@ let compile ?shared:_ ~name bindings ({ Low_level.traced_store; _ } as lowered)
   let ptx = cuda_to_ptx ~name @@ Buffer.contents b in
   { traced_store; ptx; params; bindings; name }
 
-let compile_batch ?shared:_ ~names bindings lowereds =
+let compile_batch ~names bindings lowereds =
   let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
     let procs = Array.filter_map lowereds ~f:(Option.map ~f:(fun lowereds -> (lowereds, None)))
   end)) in

arrayjit/lib/lowered_backend_missing.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ type code_batch
7878

7979
let sexp_of_code_batch _code_batch = failwith "Backend missing -- install the corresponding library"
8080

81-
let compile ?shared:_ ~name:_ _unit_bindings _optimized =
81+
let compile ~name:_ _unit_bindings _optimized =
8282
failwith "Backend missing -- install the corresponding library"
8383

84-
let compile_batch ?shared:_ ~names:_ _unit_bindings _optimizeds =
84+
let compile_batch ~names:_ _unit_bindings _optimizeds =
8585
failwith "Backend missing -- install the corresponding library"
8686

8787
let link _context _code = failwith "Backend missing -- install the corresponding library"

bin/moons_benchmark.ml

Lines changed: 42 additions & 21 deletions
@@ -34,22 +34,25 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
   Tensor.default_grad_prec := grad_prec;
   Utils.settings.output_debug_files_in_build_directory <- true;
   (* This will only log from routines if log-level is high enough. *)
-  Utils.settings.debug_log_from_routines <- true;
+  (* Utils.settings.debug_log_from_routines <- true; *)
   Rand.init (* seed *) 0;
   let hid_dim_1 = 16 in
   let hid_dim_2 = 8 in
   let hid_dim_3 = 4 in
   (* TINY for debugging: *)
   (* let hid_dim = 2 in *)
+  (* let hid_dim = 4 in *)
   let data_len = 3 * 5 * 1024 in
   (* TINY for debugging: *)
   (* let data_len = 3 * 4 in *)
+  (* let data_len = 3 * 16 in *)
   let flat_len = data_len / 2 in
   (* Note: [minibatch_size = batch_size / num_streams] is the actual per-device batch used. *)
-  let epochs = 200 in
+  (* let epochs = 400 in *)
   (* let epochs = 100 in *)
   (* let epochs = 50 in *)
   (* TINY for debugging: *)
+  let epochs = 3 in
   (* let epochs = 2 in *)
   (* let epochs = 1 in *)
   (* let init_lr = 0.1 in *)
@@ -84,12 +87,15 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
   Stdlib.Format.printf "Initial backend global debug info: %a\n%!" Sexp.pp_hum
   @@ Backend.get_global_debug_info ();
   let per_batch_callback ~at_batch:_ ~at_step:_ ~learning_rate:_ ~batch_loss:_ ~epoch_loss:_ =
+    (* Stdio.printf "Batch=%d, step=%d, lr=%f, batch loss=%f, epoch loss=%f\n%!" at_batch at_step
+         learning_rate batch_loss epoch_loss; *)
     if Option.is_none !start_time then start_time := Some (Time_now.nanoseconds_since_unix_epoch ())
   in
   (* Tn.print_accessible_headers (); *)
   let per_epoch_callback ~at_step ~at_epoch ~learning_rate ~epoch_loss =
-    Stdio.printf "Epoch=%d, step=%d, lr=%f, epoch loss=%f\n%!" at_epoch at_step learning_rate
-      epoch_loss
+    if at_epoch % 10 = 9 then
+      Stdio.printf "Epoch=%d, step=%d, lr=%f, epoch loss=%f\n%!" at_epoch at_step learning_rate
+        epoch_loss
   in
 
   Backend.initialize Train.BT.Most_parallel_streams;
@@ -115,22 +121,25 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
   Stdio.print_endline "\n******** mlp_result **********";
   Tensor.print_tree ~with_id:true ~with_grad:false ~depth:9 model_result;
   Stdio.printf "\n********\n%!";
+  Arrayjit.Tnode.print_accessible_headers ();
   let callback (x, y) = Float.((infer_callback [| x; y |]).(0) >= 0.) in
-  let plot_moons =
-    let open PrintBox_utils in
-    plot
-      ~size:(120, 40)
-      (* TINY for debugging: *)
-      (* ~size:(20, 10) *)
-      ~x_label:"ixes" ~y_label:"ygreks"
-      [
-        Scatterplot { points = points1; pixel = "#" };
-        Scatterplot { points = points2; pixel = "%" };
-        Boundary_map { pixel_false = "."; pixel_true = "*"; callback };
-      ]
+  let%track3_sexp plot_moons () =
+    [%log_level
+      0;
+      let open PrintBox_utils in
+      plot
+        ~size:(120, 40)
+        (* TINY for debugging: *)
+        (* ~size:(20, 10) *)
+        ~x_label:"ixes" ~y_label:"ygreks"
+        [
+          Scatterplot { points = points1; pixel = "#" };
+          Scatterplot { points = points2; pixel = "%" };
+          Boundary_map { pixel_false = "."; pixel_true = "*"; callback };
+        ]]
   in
   Stdio.printf "\nHalf-moons scatterplot and decision boundary:\n%!";
-  PrintBox_text.output Stdio.stdout plot_moons;
+  PrintBox_text.output Stdio.stdout @@ plot_moons ();
   Stdio.printf "\nBatch Log-loss:\n%!";
   let plot_loss =
     let open PrintBox_utils in
@@ -181,6 +190,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
     }
   in
   Stdio.printf "\n\n%!";
+  Arrayjit.Tnode.print_accessible_headers ();
   Stdlib.Format.printf "Final backend global debug info: %a\n%!" Sexp.pp_hum
   @@ Backend.get_global_debug_info ();
   result
@@ -211,13 +221,24 @@ let _cuda_benchmarks =
   ]))))))
 
 let _cuda_parallel_benchmarks =
-  List.concat_map [ (* 1; 2; *) (* 3; 4; 5; 6; 8; 10; 12; 16; *) 20 (* 32; 64 *) ] ~f:(fun num_streams ->
+  List.concat_map
+    [
+      (* 1; 2; *)
+      3;
+      (* 4; 5; 6; 8; 10; 12; 16; 20 *)
+      (* 32; 64 *)
+    ] ~f:(fun num_streams ->
       List.concat_map
-        [ 3 * 5 * 16 (* ; 3 * 5 * 32 *) ]
+        [
+          (* TINY for debugging: *)
+          (* 3 * 4 *)
+          3 * 5 * 16 (* ; 3 * 5 * 32 *);
+        ]
         ~f:(fun batch_size ->
-          List.concat_map [ (* 1; *) (* 2; *) 3 ] ~f:(fun inlining_cutoff ->
+          List.concat_map [ (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
             List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed ->
-              List.concat_map [ (* "gccjit" ; "cc"; *) "cuda" ] ~f:(fun backend_name ->
+              List.concat_map [ (* "gccjit"; "cuda" ;"cc"; *) "sync_cc" ]
+                ~f:(fun backend_name ->
                 List.concat_map [ (* CDSL.double; *) CDSL.single (* ; CDSL.half *) ]
                   ~f:(fun value_prec ->
                     [
[

lib/attic.mld

Lines changed: 0 additions & 20 deletions
@@ -476,26 +476,6 @@ Old post-launch code in Cuda_backend.link_proc:
       data.tracking <- Some (Cu.Delimited_event.record context.stream.cu_stream));
 ]}
 
-Obsoleted part of interfaces in backend_impl.ml:
-{[
-  let expected_merge_node : code -> _ = function
-    | Postponed { lowered = Low_level.{ merge_node; _ }; _ }
-    | Compiled { lowered = Low_level.{ merge_node; _ }; _ } ->
-        merge_node
-
-  let expected_merge_nodes : code_batch -> _ = function
-    | Postponed { lowereds; _ } | Compiled { lowereds; _ } ->
-        Array.map lowereds ~f:(fun lowered ->
-            Option.(join @@ map lowered ~f:(fun optim -> optim.merge_node)))
-
-  let get_lowered : code -> _ = function
-    | Postponed { lowered; _ } | Compiled { lowered; _ } -> lowered
-
-  let get_lowereds : code_batch -> _ = function
-    | Postponed { lowereds; _ } -> lowereds
-    | Compiled { lowereds; _ } -> lowereds
-]}
-
 Old context finalizer from the cuda backend:
 {[
   let%track3_sexp finalize (ctx : context) : unit =

lib/train.ml

Lines changed: 4 additions & 4 deletions
@@ -363,7 +363,7 @@ let%track3_sexp parallel_update (type buffer_ptr dev runner event)
     Array.mapi ctxs ~f:(fun dst_n ctx ->
         if occupancy_dst ~dst_n then
           snd
-          @@ Backend.(link_batch ctx @@ compile_batch ~shared:true ~occupancy Idx.Empty grad_merges)
+          @@ Backend.(link_batch ctx @@ compile_batch ~occupancy Idx.Empty grad_merges)
         else [||])
   in
   (* We can cache scheduling, because merging and copying does not depend on static indexing. *)
@@ -441,8 +441,8 @@ let to_routine (type buffer_ptr dev runner event)
     with type buffer_ptr = buffer_ptr
      and type dev = dev
      and type runner = runner
-     and type event = event) (context : Backend.context) ?shared ?name bindings comp =
-  Backend.link context @@ Backend.compile ?shared ?name bindings comp
+     and type event = event) (context : Backend.context) ?name bindings comp =
+  Backend.link context @@ Backend.compile ?name bindings comp
 
 type example_train_result = {
   inputs : Tensor.t;
@@ -500,7 +500,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
      Utils.settings.check_half_prec_constants_cutoff, no need to upcast learning_rate.value. *)
   set_hosted learning_rate.value;
   let sgd = sgd_update ~learning_rate ~weight_decay update in
-  let grad_update = Backend.compile ~shared:true bindings update.fwd_bprop in
+  let grad_update = Backend.compile bindings update.fwd_bprop in
   let grad_updates = Array.map contexts ~f:(fun ctx -> Backend.link ctx grad_update) in
   let sgd_update = to_routine (module Backend) grad_updates.(0).context bindings sgd in
   Tensor.log_debug_info ~from_log_level:2 inputs;
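The `train.ml` change keeps the compile-once, link-per-context pattern: `grad_update` is compiled a single time and then linked into each stream's context. A minimal stub of that pattern (all types and functions here are illustrative stand-ins, not the Backend API):

```ocaml
(* Stubbed compile-once / link-per-context pattern: [compile] models the
   expensive step done once; [link] is the cheap per-context step. *)
type code = { name : string }
type context = { id : int }
type routine = { ctx : context; code : code }

let compile name = { name } (* done once *)
let link ctx code = { ctx; code } (* done per stream/context *)

(* Usage: one compiled artifact linked into three contexts. *)
let contexts = Array.init 3 (fun id -> { id })
let grad_update = compile "fwd_bprop"
let grad_updates = Array.map (fun ctx -> link ctx grad_update) contexts
```

Every resulting routine shares the same compiled `code` value, which is exactly what made the `~shared:true` flag redundant once `Postponed` was removed.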
