
Commit 4aef0cc

Rename non_embedded/embedded distinction to inputs/outputs,
defensively fix (make more precise) handling of grad nodes when computing inputs/outputs.
1 parent 5eab2ea commit 4aef0cc
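In caller terms the change looks roughly like the sketch below; the backend module, context ctx, compiled code, and tensor t are hypothetical placeholders, and only the switch from a Tnode.t list with iter_embedded_arrays to a Set.M(Tnode).t with input_nodes / iter_outputs comes from this commit:

    (* Before: from_prior_context was a Tnode.t list, and host transfers iterated
       the "embedded arrays" of the tensor. *)
    let _before ctx code t ~nodes =
      let routine = Backend.link ~from_prior_context:nodes ctx code in
      Tensor.iter_embedded_arrays t ~f:(fun a -> ignore (Backend.from_host routine.context a : bool));
      routine

    (* After: from_prior_context is a Set.M(Tnode).t of the tensor's input nodes, and
       host transfers iterate over its output nodes. *)
    let _after ctx code t =
      let routine = Backend.link ~from_prior_context:(Tensor.input_nodes t) ctx code in
      Tensor.iter_outputs t ~f:(fun a -> ignore (Backend.from_host routine.context a : bool));
      routine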

File tree: 6 files changed, +80 -50 lines changed


arrayjit/lib/backends.ml

Lines changed: 19 additions & 12 deletions
@@ -53,7 +53,7 @@ module type No_device_backend = sig
       [occupancy] returns true are included. *)

   val link :
-    ?from_prior_context:Tnode.t list ->
+    ?from_prior_context:Set.M(Tnode).t ->
     merge_buffer:(buffer_ptr * Tnode.t) option ref ->
     context ->
     code ->
@@ -64,7 +64,7 @@ module type No_device_backend = sig
       context, they must be part of the given context. *)

   val link_batch :
-    ?from_prior_context:Tnode.t list ->
+    ?from_prior_context:Set.M(Tnode).t ->
     merge_buffer:(buffer_ptr * Tnode.t) option ref ->
     context ->
     code_batch ->
@@ -90,13 +90,19 @@ end
 module type Backend = sig
   include No_device_backend

-  val link : ?from_prior_context:Tnode.t list -> context -> code -> routine
-  (** Returns the routine for the code's procedure, in a new context derived from the given context. *)
+  val link : ?from_prior_context:Set.M(Tnode).t -> context -> code -> routine
+  (** Returns the routine for the code's procedure, in a new context derived from the given context.
+
+      The [from_prior_context] nodes must not be added to the resulting context -- if needed in
+      context, they must be part of the given context. *)

   val link_batch :
-    ?from_prior_context:Tnode.t list -> context -> code_batch -> context * routine option array
+    ?from_prior_context:Set.M(Tnode).t -> context -> code_batch -> context * routine option array
   (** Returns the routines for the procedures included in the code batch. The returned context is
-      downstream of all the returned routines. *)
+      downstream of all the returned routines.
+
+      The [from_prior_context] nodes must not be added to the resulting context -- if needed in
+      context, they must be part of the given context. *)

   type event
   (** An event tracks if a device finished computing past a particular point in its schedue. These
@@ -805,7 +811,7 @@ end
 let verify_prior_context ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context traced_stores
     =
   let olds = ctx_arrays prior_context in
-  List.iter from_prior_context ~f:(fun tn ->
+  Set.iter from_prior_context ~f:(fun tn ->
       let node = Array.find_map traced_stores ~f:(fun store -> Hashtbl.find store tn) in
       if
         Option.value_map node ~default:false ~f:(fun node ->
@@ -875,7 +881,8 @@ module Simple_no_device_backend (Backend : Simple_backend) : No_device_backend =
     if shared then Compiled (lowereds, compile_batch ~names ~opt_ctx_arrays:None bindings lowereds)
     else Postponed { lowereds; bindings; names }

-  let link ?(from_prior_context = []) ~merge_buffer (prior_context : context) (code : code) =
+  let link ?(from_prior_context = Set.empty (module Tnode)) ~merge_buffer (prior_context : context)
+      (code : code) =
     Backend.(
       verify_prior_context ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context
         [| get_traced_store code |]);
@@ -890,8 +897,8 @@ module Simple_no_device_backend (Backend : Simple_backend) : No_device_backend =
     in
     { context; schedule; bindings; name }

-  let link_batch ?(from_prior_context = []) ~merge_buffer (prior_context : context)
-      (code_batch : code_batch) =
+  let link_batch ?(from_prior_context = Set.empty (module Tnode)) ~merge_buffer
+      (prior_context : context) (code_batch : code_batch) =
     Backend.(
       verify_prior_context ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context
       @@ get_traced_stores code_batch);
@@ -975,13 +982,13 @@ module Cuda_backend : Backend = struct
           Option.(join @@ map lowered ~f:(fun optim -> optim.Low_level.merge_node)));
     }

-  let link ?(from_prior_context = []) context code =
+  let link ?(from_prior_context = Set.empty (module Tnode)) context code =
     verify_prior_context ~ctx_arrays ~is_in_context ~prior_context:context.ctx ~from_prior_context
       [| code.traced_store |];
     let ctx, bindings, schedule = link context.ctx code.code in
     { context = { ctx; expected_merge_node = code.expected_merge_node }; schedule; bindings; name }

-  let link_batch ?(from_prior_context = []) context code_batch =
+  let link_batch ?(from_prior_context = Set.empty (module Tnode)) context code_batch =
     verify_prior_context ~ctx_arrays ~is_in_context ~prior_context:context.ctx ~from_prior_context
       code_batch.traced_stores;
     let ctx, bindings, schedules = link_batch context.ctx code_batch.code_batch in
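Since from_prior_context is now a set rather than a list, callers build it with Base's set API over the Tnode comparator. A minimal sketch; the bindings below are illustrative and not part of the commit:

    (* The new default, as used by the link implementations above. *)
    let no_prior : Set.M(Tnode).t = Set.empty (module Tnode)

    (* Growing and consuming such a set; Set.iter replaces the former List.iter
       inside verify_prior_context. *)
    let with_node set tn : Set.M(Tnode).t = Set.add set tn
    let check_all (from_prior_context : Set.M(Tnode).t) ~f = Set.iter from_prior_context ~f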

bin/compilation_speed.ml

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ let benchmark_overhead backend () =
   let init_assign_x = link ctx @@ compile ~name:"init_assign_x" IDX.empty mock_update_x in
   let f_routine = link init_assign_x.context @@ compile IDX.empty update_f.fwd_bprop in
   Tensor.print_tree ~with_grad:true ~with_backend_info:true ~depth:9 f;
-  Tensor.iter_embedded_arrays f ~f:(fun a -> ignore (from_host f_routine.context a : bool));
+  Tensor.iter_outputs f ~f:(fun a -> ignore (from_host f_routine.context a : bool));

   let xs = Array.init n_data ~f:Float.(fun i -> of_int i - (of_int n_data /. 2.)) in
   let open Operation.At in

bin/zero2hero_1of7.ml

Lines changed: 5 additions & 5 deletions
@@ -163,9 +163,9 @@ let _suspended () =
   let device = new_virtual_device @@ get_device ~ordinal:0 in
   let update = Train.grad_update l in
   let routine = link (init device) @@ compile IDX.empty @@ update.fwd_bprop in
-  Tensor.iter_embedded_arrays l ~f:(fun a -> ignore (from_host routine.context a : bool));
+  Tensor.iter_outputs l ~f:(fun a -> ignore (from_host routine.context a : bool));
   Train.run routine;
-  Tensor.iter_embedded_arrays l ~f:(fun a -> ignore (to_host routine.context a : bool));
+  Tensor.iter_outputs l ~f:(fun a -> ignore (to_host routine.context a : bool));
   await device;
   Stdio.print_endline
     {|
@@ -177,7 +177,7 @@ let _suspended () =
     link routine.context @@ compile IDX.empty @@ Train.sgd_update ~learning_rate update
   in
   (* learning_rate is virtual so this will not print anything. *)
-  Tensor.iter_embedded_arrays learning_rate ~f:(fun a ->
+  Tensor.iter_outputs learning_rate ~f:(fun a ->
       ignore (from_host routine.context a : bool));
   Stdio.print_endline
     {|
@@ -187,7 +187,7 @@ let _suspended () =
   List.iter [ a.value; b.value; c.value; f.value ] ~f:(fun a ->
       assert (from_host routine.context a));
   Train.run routine;
-  Tensor.iter_embedded_arrays l ~f:(fun a -> ignore (to_host routine.context a : bool));
+  Tensor.iter_outputs l ~f:(fun a -> ignore (to_host routine.context a : bool));
   await device;
   Stdio.print_endline
     {|
@@ -198,7 +198,7 @@ let _suspended () =
   let update = Train.grad_update l in
   let routine = link routine.context @@ compile IDX.empty update.fwd_bprop in
   Train.run routine;
-  Tensor.iter_embedded_arrays l ~f:(fun a -> ignore (to_host routine.context a : bool));
+  Tensor.iter_outputs l ~f:(fun a -> ignore (to_host routine.context a : bool));
   await device;
   Stdio.print_endline
     {|

lib/tensor.ml

Lines changed: 30 additions & 17 deletions
@@ -6,6 +6,7 @@ module Idx = Arrayjit.Indexing
 module Debug_runtime = Arrayjit.Utils.Debug_runtime

 type tn = Tn.t
+type tn_set = Set.M(Arrayjit.Tnode).t
 type asgns = Asgns.t
 type init_op = Arrayjit.Ops.init_op
 type fetch_op = Asgns.fetch_op
@@ -23,9 +24,10 @@ type t = {
   forward : Asgns.t;
   diff : diff option;
   id : int;
-  value : Tn.t;
+  value : tn;
   shape : Shape.t;
   children : subtensor list;
+  non_embedded : tn_set;
 }

 and subtensor = { subtensor : t; embedded : bool }
@@ -147,12 +149,14 @@ let op ~(label : string list) ?(compose_op = Shape.Pointwise_bin)
     ?(transpose_op = Shape.Pointwise_un) ?(init_op = default_init_op) ~op_asn ~grad_asn
     ?(grad_spec = If_needed) make_shape (orig_ts : t list) : t =
   let ordered_ts = List.dedup_and_sort orig_ts ~compare:(fun t1 t2 -> Int.ascending t1.id t2.id) in
+  let non_embedded = ref @@ Set.empty (module Tn) in
   let children =
     List.folding_map orig_ts
       ~init:(Set.empty (module Int))
       ~f:(fun used ti ->
-        ( Set.add used ti.id,
-          { subtensor = ti; embedded = is_fwd_root ti && not (Set.mem used ti.id) } ))
+        let root = is_fwd_root ti in
+        if not root then non_embedded := Set.add !non_embedded ti.value;
+        (Set.add used ti.id, { subtensor = ti; embedded = root && not (Set.mem used ti.id) }))
   in
   let id = session_state.next_id in
   session_state.next_id <- session_state.next_id + 1;
@@ -187,7 +191,9 @@ let op ~(label : string list) ?(compose_op = Shape.Pointwise_bin)
      || Fn.non is_require_grad grad_spec
         && List.for_all orig_ts ~f:(fun ti -> Option.is_none ti.diff)
   then (
-    let tensor = { forward; diff = None; id; value = v; shape; children } in
+    let tensor =
+      { forward; diff = None; id; value = v; shape; children; non_embedded = !non_embedded }
+    in
     session_state.forward_roots <- Map.add_exn session_state.forward_roots ~key:id ~data:tensor;
     tensor)
   else
@@ -216,7 +222,11 @@ let op ~(label : string list) ?(compose_op = Shape.Pointwise_bin)
      that all ancestors of a node are backpropagated before the node is backpropagated, even for
      non-tree DAGs. *)
   let backprop =
-    let bprop = dcode ~f:(fun diff -> diff.backprop) in
+    let bprop =
+      dcode ~f:(fun diff ->
+          non_embedded := Set.add !non_embedded diff.grad;
+          diff.backprop)
+    in
     let bcks =
       List.map ordered_ts ~f:(fun ti -> if is_bck_root ti then bprop ti else Asgns.Noop)
     in
@@ -226,7 +236,7 @@ let op ~(label : string list) ?(compose_op = Shape.Pointwise_bin)
       session_state.backprop_roots <- Map.remove session_state.backprop_roots ti.id);
   (* The order is not relevant, we keep the same order as in backprop for readability. *)
   let diff = Some { grad = g; zero_grads; backprop } in
-  let tensor = { forward; diff; id; value = v; shape; children } in
+  let tensor = { forward; diff; id; value = v; shape; children; non_embedded = !non_embedded } in
   session_state.forward_roots <- Map.add_exn session_state.forward_roots ~key:id ~data:tensor;
   session_state.backprop_roots <- Map.add_exn session_state.backprop_roots ~key:id ~data:tensor;
   tensor
@@ -350,30 +360,33 @@ let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?
   Tn.update_memory_mode g Never_virtual 26;
   t

-let rec iter_embedded_arrays ~f t =
-  f t.value;
-  Option.iter t.diff ~f:(fun diff -> f diff.grad);
-  List.iter ~f:(fun ch -> if ch.embedded then iter_embedded_arrays ~f ch.subtensor) t.children
-
-let rec non_and_embedded_nodes t =
+let rec inputs_and_outputs t =
+  (* TODO: consider either caching here, or as a field of t. *)
+  let opt_grad t = Option.value_map ~default:[] ~f:(fun diff -> [ diff.grad ]) t.diff in
+  let dir_outputs t =
+    Set.of_list (module Tn)
+    @@ List.filter ~f:(fun tn -> not @@ Set.mem t.non_embedded tn)
+    @@ (t.value :: opt_grad t)
+  in
+  let open Arrayjit.Utils.Set_O in
   let non_embedded, embedded =
     List.fold t.children
-      ~init:(Set.empty (module Self), Set.empty (module Self))
+      ~init:(t.non_embedded, Set.of_list (module Tn) (t.value :: opt_grad t))
       ~f:(fun (non_embedded, embedded) ch ->
-        if ch.embedded then (non_embedded, Set.add embedded ch.subtensor)
-        else (Set.add non_embedded ch.subtensor, embedded))
+        (ch.subtensor.non_embedded + non_embedded, dir_outputs ch.subtensor + embedded))
   in
-  let open Arrayjit.Utils.Set_O in
   let non_embedded, embedded =
     List.fold t.children ~init:(non_embedded, embedded)
      ~f:(fun ((non_embedded, embedded) as accu) ch ->
        if ch.embedded then
-          let more_non, more = non_and_embedded_nodes ch.subtensor in
+          let more_non, more = inputs_and_outputs ch.subtensor in
          (non_embedded + more_non, embedded + more)
        else accu)
  in
  (non_embedded - embedded, embedded)

+let iter_outputs ~f t = Set.iter ~f @@ snd @@ inputs_and_outputs t
+let input_nodes t = fst @@ inputs_and_outputs t
 let debug_name t = Tn.debug_name t.value
 let debug_grad t = Tn.debug_name (Option.value_exn t.diff).grad
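The inputs/outputs split above reduces to set algebra: union the non_embedded nodes and the embedded value/grad nodes across the DAG, then subtract. A self-contained illustration of that final subtraction, with strings standing in for tensor nodes (the node names are invented):

    open Base

    (* Nodes consumed from other computations (the gathered non_embedded sets). *)
    let gathered_inputs = Set.of_list (module String) [ "w.value"; "w.grad"; "x.value" ]

    (* Nodes whose forward/backprop computation is embedded here (the gathered outputs). *)
    let gathered_outputs = Set.of_list (module String) [ "loss.value"; "loss.grad"; "w.grad" ]

    (* Mirrors (non_embedded - embedded, embedded): anything also produced here is not an input. *)
    let inputs = Set.diff gathered_inputs gathered_outputs

    (* Prints "w.value" and "x.value", each on its own line. *)
    let () = Set.iter inputs ~f:Stdio.print_endline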

lib/tensor.mli

Lines changed: 18 additions & 4 deletions
@@ -3,6 +3,7 @@
 open Base

 type tn = Arrayjit.Tnode.t
+type tn_set = Set.M(Arrayjit.Tnode).t
 type asgns = Arrayjit.Assignments.t
 type init_op = Arrayjit.Ops.init_op
 type fetch_op = Arrayjit.Assignments.fetch_op
@@ -26,11 +27,19 @@ type t = {
       (** The eventual shape of [t.value] and [t.diff.grad], incorporating the current state of
           shape inference. *)
   children : subtensor list;
+  non_embedded : tn_set;
+      (** These tensor nodes ([value], resp. [grad] of {!diff}) of the children which are not
+          computed by [forward], resp. [backprop] of {!diff}. *)
 }
 [@@deriving sexp_of]
 (** Information needed for compositional code generation. *)

-and subtensor = { subtensor : t; embedded : bool }
+and subtensor = {
+  subtensor : t;
+  embedded : bool;
+      (** A tensor can be an [embedded] child at most once -- that's where its [forward] computation
+          ends up when used as part of a bigger computation. *)
+}

 type comparator_witness

@@ -174,9 +183,6 @@ val param :
     [Require_grad]. The resulting tensor's label is the passed string, appended by [more_label] if
     any. *)

-val iter_embedded_arrays : f:(tn -> unit) -> t -> unit
-val non_and_embedded_nodes : t -> (t, comparator_witness) Set.t * (t, comparator_witness) Set.t
-
 val consume_forward_code : t -> asgns
 (** A forward root is a tensor that is not (currently) used to compute another tensor.
     [consume_forward_code t] ensures [t] is a forward root, removes it from forward roots, and
@@ -188,6 +194,14 @@ val consume_backprop_code : t -> asgns * asgns
     [consume_backprop_code t] ensures [t] is a backprop root, removes it from backprop roots, and
     checks that there are no other backprop roots for tensors with children. *)

+val input_nodes : t -> tn_set
+(** The nodes of descendant tensors whose computation is not embedded by the given tensor. They are
+    "inputs" coming from other computations. *)
+
+val iter_outputs : f:(tn -> unit) -> t -> unit
+(** [iter_outputs t] iterates over all descendant nodes that are embedded, i.e. are not members
+    of [input_nodes t]. *)
+
 val unsafe_reinitialize : unit -> unit
 (** Bring global state to its initialization values. This invalidates any previously defined tensors
     and tensor nodes. Also reinitializes the modules: {!Shape}, {!Arrayjit.Tnode},
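A usage sketch of the two new interface entries; the report wrapper and the printing are made up, while Tnode.debug_name is the accessor already used by lib/tensor.ml above:

    let report (t : Tensor.t) =
      Set.iter (Tensor.input_nodes t) ~f:(fun tn ->
          Stdio.print_endline ("input:  " ^ Arrayjit.Tnode.debug_name tn));
      Tensor.iter_outputs t ~f:(fun tn ->
          Stdio.print_endline ("output: " ^ Arrayjit.Tnode.debug_name tn))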

lib/train.ml

Lines changed: 7 additions & 11 deletions
@@ -262,22 +262,18 @@ let%track3_sexp round_robin_dry_run ~num_devices jitbs ~dry_sync : unit =
 let set_virtual (a : Tn.t) = Tn.update_memory_mode a Virtual 29

 let every_non_literal_on_host =
-  Tensor.iter_embedded_arrays ~f:(fun a ->
+  Tensor.iter_outputs ~f:(fun a ->
       if Tn.mode_is_unspecified a && not (Tn.known_constant a) then set_hosted a)

 let%debug2_sexp all_host_to_device (type context)
     (module Backend : Backend_type with type context = context) context =
   let f tn = ignore (Backend.from_host context tn : bool) in
-  Tensor.iter_embedded_arrays ~f
+  Tensor.iter_outputs ~f

 let%debug2_sexp all_device_to_host (type context)
     (module Backend : Backend_type with type context = context) context =
   let f tn = ignore (Backend.to_host context tn : bool) in
-  Tensor.iter_embedded_arrays ~f
-
-let needs_prior_context t =
-  Tensor.non_and_embedded_nodes t |> fst |> Set.to_list
-  |> List.concat_map ~f:(fun t -> t.value :: Option.(to_list @@ map t.diff ~f:(fun d -> d.grad)))
+  Tensor.iter_outputs ~f

 (** Executes the jitted code and copies arrays embedded in the given tenosor from and to host,
     synchronizes before copying to host. If [looping] is provided, loops over bindings and executes
@@ -352,7 +348,7 @@ let%track3_sexp parallel_update (type context)
   (* We can cache scheduling, because merging and copying does not depend on static indexing. *)
   let loss_merge =
     Backend.(
-      link ~from_prior_context:(needs_prior_context updaten.loss) sgd_update.context
+      link ~from_prior_context:(Tensor.input_nodes updaten.loss) sgd_update.context
       @@ compile Idx.Empty
           [%cd
             ~~("merging" updaten.loss;
@@ -459,7 +455,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
   set_hosted learning_rate.value;
   let sgd = sgd_update ~learning_rate ~weight_decay update in
   let grad_update = Backend.compile ~shared:true bindings update.fwd_bprop in
-  let from_prior_context = needs_prior_context update.loss in
+  let from_prior_context = Tensor.input_nodes update.loss in
   let grad_updates =
     Array.map prior_contexts ~f:(fun ctx -> Backend.link ~from_prior_context ctx grad_update)
   in
@@ -511,7 +507,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
   (* By using sgd_update.context, maybe we don't need to copy the parameters back to the host. *)
   let routine =
     Backend.(
-      link ~from_prior_context:(needs_prior_context model_result) sgd_update.context
+      link ~from_prior_context:(Tensor.input_nodes model_result) sgd_update.context
       @@ compile IDX.empty
       @@ Block_comment ("infer " ^ Tn.debug_name model_result.value, infer_fwd))
   in
@@ -533,7 +529,7 @@ let%track3_sexp forward_and_ctx ?(disable_rootness_check = false) (type context)
     (module Backend : Backend_type with type context = context) ctx ?(bindings = IDX.empty) t =
   let routine =
     Backend.(
-      link ~from_prior_context:(needs_prior_context t) ctx
+      link ~from_prior_context:(Tensor.input_nodes t) ctx
       @@ compile bindings @@ forward ~disable_rootness_check t)
   in
   if not disable_rootness_check then Tensor.remove_bprop_root t;
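For reference, the removed helper next to its replacement; the first body is copied from the deletion above and would no longer compile, since non_and_embedded_nodes is gone from the interface:

    (* Before (removed above): flatten the input tensors into a list of value/grad nodes. *)
    let _before t =
      Tensor.non_and_embedded_nodes t |> fst |> Set.to_list
      |> List.concat_map ~f:(fun t -> t.value :: Option.(to_list @@ map t.diff ~f:(fun d -> d.grad)))

    (* After: Tensor.input_nodes yields the Set.M(Tnode).t directly, with grad nodes already
       tracked at tensor-construction time (the "defensive" fix from the commit message). *)
    let _after t = Tensor.input_nodes t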
