Skip to content

Commit 7dd2d35

Browse files
committed
Rename iter_outputs -> iter_embedded to avoid confusion
1 parent 4aef0cc commit 7dd2d35

File tree

5 files changed

+16
-16
lines changed

5 files changed

+16
-16
lines changed

bin/compilation_speed.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ let benchmark_overhead backend () =
3131
let init_assign_x = link ctx @@ compile ~name:"init_assign_x" IDX.empty mock_update_x in
3232
let f_routine = link init_assign_x.context @@ compile IDX.empty update_f.fwd_bprop in
3333
Tensor.print_tree ~with_grad:true ~with_backend_info:true ~depth:9 f;
34-
Tensor.iter_outputs f ~f:(fun a -> ignore (from_host f_routine.context a : bool));
34+
Tensor.iter_embedded f ~f:(fun a -> ignore (from_host f_routine.context a : bool));
3535

3636
let xs = Array.init n_data ~f:Float.(fun i -> of_int i - (of_int n_data /. 2.)) in
3737
let open Operation.At in

bin/zero2hero_1of7.ml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,9 @@ let _suspended () =
163163
let device = new_virtual_device @@ get_device ~ordinal:0 in
164164
let update = Train.grad_update l in
165165
let routine = link (init device) @@ compile IDX.empty @@ update.fwd_bprop in
166-
Tensor.iter_outputs l ~f:(fun a -> ignore (from_host routine.context a : bool));
166+
Tensor.iter_embedded l ~f:(fun a -> ignore (from_host routine.context a : bool));
167167
Train.run routine;
168-
Tensor.iter_outputs l ~f:(fun a -> ignore (to_host routine.context a : bool));
168+
Tensor.iter_embedded l ~f:(fun a -> ignore (to_host routine.context a : bool));
169169
await device;
170170
Stdio.print_endline
171171
{|
@@ -177,7 +177,7 @@ let _suspended () =
177177
link routine.context @@ compile IDX.empty @@ Train.sgd_update ~learning_rate update
178178
in
179179
(* learning_rate is virtual so this will not print anything. *)
180-
Tensor.iter_outputs learning_rate ~f:(fun a ->
180+
Tensor.iter_embedded learning_rate ~f:(fun a ->
181181
ignore (from_host routine.context a : bool));
182182
Stdio.print_endline
183183
{|
@@ -187,7 +187,7 @@ let _suspended () =
187187
List.iter [ a.value; b.value; c.value; f.value ] ~f:(fun a ->
188188
assert (from_host routine.context a));
189189
Train.run routine;
190-
Tensor.iter_outputs l ~f:(fun a -> ignore (to_host routine.context a : bool));
190+
Tensor.iter_embedded l ~f:(fun a -> ignore (to_host routine.context a : bool));
191191
await device;
192192
Stdio.print_endline
193193
{|
@@ -198,7 +198,7 @@ let _suspended () =
198198
let update = Train.grad_update l in
199199
let routine = link routine.context @@ compile IDX.empty update.fwd_bprop in
200200
Train.run routine;
201-
Tensor.iter_outputs l ~f:(fun a -> ignore (to_host routine.context a : bool));
201+
Tensor.iter_embedded l ~f:(fun a -> ignore (to_host routine.context a : bool));
202202
await device;
203203
Stdio.print_endline
204204
{|

lib/tensor.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ let rec inputs_and_outputs t =
385385
in
386386
(non_embedded - embedded, embedded)
387387

388-
let iter_outputs ~f t = Set.iter ~f @@ snd @@ inputs_and_outputs t
388+
let iter_embedded ~f t = Set.iter ~f @@ snd @@ inputs_and_outputs t
389389
let input_nodes t = fst @@ inputs_and_outputs t
390390
let debug_name t = Tn.debug_name t.value
391391
let debug_grad t = Tn.debug_name (Option.value_exn t.diff).grad

lib/tensor.mli

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ type t = {
2828
shape inference. *)
2929
children : subtensor list;
3030
non_embedded : tn_set;
31-
(** These tensor nodes ([value], resp. [grad] of {!diff}) of the children which are not
32-
computed by [forward], resp. [backprop] of {!diff}. *)
31+
(** These tensor nodes ([value], resp. {!grad} of {!field-diff}) of the children which are not
32+
computed by [forward], resp. {!backprop} of {!field-diff}. *)
3333
}
3434
[@@deriving sexp_of]
3535
(** Information needed for compositional code generation. *)
@@ -196,11 +196,11 @@ val consume_backprop_code : t -> asgns * asgns
196196

197197
val input_nodes : t -> tn_set
198198
(** The nodes of descendant tensors whose computation is not embedded by the given tensor. They are
199-
"inputs" coming from other computations. *)
199+
"inputs" coming from other computations. NOTE: this a specific, narrow meaning of "inputs". *)
200200

201-
val iter_outputs : f:(tn -> unit) -> t -> unit
202-
(** [iter_outputs t] iterates over all descendant nodes that are embedded, i.e. are not members
203-
of [input_nodes t]. *)
201+
val iter_embedded : f:(tn -> unit) -> t -> unit
202+
(** [iter_embedded t] iterates over all descendant nodes that are embedded, i.e. are not members of
203+
[input_nodes t] -- see {!input_nodes}. *)
204204

205205
val unsafe_reinitialize : unit -> unit
206206
(** Bring global state to its initialization values. This invalidates any previously defined tensors

lib/train.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,18 +262,18 @@ let%track3_sexp round_robin_dry_run ~num_devices jitbs ~dry_sync : unit =
262262
let set_virtual (a : Tn.t) = Tn.update_memory_mode a Virtual 29
263263

264264
let every_non_literal_on_host =
265-
Tensor.iter_outputs ~f:(fun a ->
265+
Tensor.iter_embedded ~f:(fun a ->
266266
if Tn.mode_is_unspecified a && not (Tn.known_constant a) then set_hosted a)
267267

268268
let%debug2_sexp all_host_to_device (type context)
269269
(module Backend : Backend_type with type context = context) context =
270270
let f tn = ignore (Backend.from_host context tn : bool) in
271-
Tensor.iter_outputs ~f
271+
Tensor.iter_embedded ~f
272272

273273
let%debug2_sexp all_device_to_host (type context)
274274
(module Backend : Backend_type with type context = context) context =
275275
let f tn = ignore (Backend.to_host context tn : bool) in
276-
Tensor.iter_outputs ~f
276+
Tensor.iter_embedded ~f
277277

278278
(** Executes the jitted code and copies arrays embedded in the given tenosor from and to host,
279279
synchronizes before copying to host. If [looping] is provided, loops over bindings and executes

0 commit comments

Comments
 (0)