Commit 80b7d04

Don't include zero_grads in consume_backprop_code, to avoid forcing callers to handle zero_grads and backprop together
1 parent 24f71f9 commit 80b7d04

File tree

lib/tensor.ml
lib/tensor.mli
lib/train.ml

3 files changed: +8 -29 lines

lib/tensor.ml

Lines changed: 1 addition & 1 deletion
@@ -577,7 +577,7 @@ let consume_backprop_code t =
             found potentially unsafe roots: %{String.concat ~sep:", " @@ List.map ~f:debug_name unsafe_roots}|}],
           Some t );
   remove_bprop_root t;
-  (diff.zero_grads, diff.backprop)
+  diff.backprop

 let random_seed = ref None

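Note: a minimal sketch of the call-site adaptation implied by the new return type; the helper name consume_zero_grads_and_backprop is hypothetical, and only constructs visible in this diff are assumed.

(* Hypothetical helper: recover the old pairing by fetching zero_grads from the
   diff record, for callers that still want both pieces together. *)
let consume_zero_grads_and_backprop t =
  let bprop = Tensor.consume_backprop_code t in
  let zero_grads = (Option.value_exn ~here:[%here] t.Tensor.diff).zero_grads in
  (zero_grads, bprop)
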
lib/tensor.mli

Lines changed: 5 additions & 8 deletions
@@ -8,15 +8,11 @@ type tn_set = Set.M(Ir.Tnode).t
 type asgns = Ir.Assignments.t
 type comp = Ir.Assignments.comp
 type fetch_op = Ir.Assignments.fetch_op
-type projections = {
-  projections_debug : string;
-  projections : Ir.Indexing.projections Lazy.t;
-}
+type projections = { projections_debug : string; projections : Ir.Indexing.projections Lazy.t }

 type diff = {
   grad : tn;
-  zero_grads : asgns;
-      (** Prepares for backpropagation. Always compile as: [Seq (zero_grads, backprop)]. *)
+  zero_grads : asgns;  (** Prepares for backpropagation. Beware of the "missing zero_grads" bug. *)
   backprop : comp;
       (** Backpropagates for the tensor and its descendants; which typically means adding partial
           gradients to the gradient tensor of the subtensors, then for sub-subtensors etc. *)
@@ -215,11 +211,12 @@ val consume_forward_code : t -> comp
     [consume_forward_code t] ensures [t] is a forward root, removes it from forward roots, and
     checks that there are no other forward roots for tensors with children. *)

-val consume_backprop_code : t -> asgns * comp
+val consume_backprop_code : t -> comp
 (** A backprop root is a tensor with a gradient that is not (currently) receiving gradients from
     another tensor. I.e. it is not currently used to compute a tensor with a gradient.
     [consume_backprop_code t] ensures [t] is a backprop root, removes it from backprop roots, and
-    checks that there are no other backprop roots for tensors with children. *)
+    checks that there are no other backprop roots for tensors with children. It returns the backprop
+    code -- note this does not include the zero_grads code. *)

 val iter_embedded : f:(tn -> unit) -> t -> unit
 (** [iter_embedded t] iterates over all descendant nodes that are embedded, i.e. are members of
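Note: the updated doc comments shift responsibility for running zero_grads onto the caller. Below is a minimal sketch of the intended ordering, assuming the [%cd] syntax and the Asgns alias used in lib/train.ml; the name fwd_bprop_sketch is hypothetical, and grad_update in lib/train.ml (next file) does the same with extra root checks.

(* Hypothetical sketch: zero gradients strictly before backpropagating, since
   consume_backprop_code no longer hands back the zeroing code. *)
let fwd_bprop_sketch loss =
  let fwd = Tensor.consume_forward_code loss in
  let bprop = Tensor.consume_backprop_code loss in
  let zero_grads = (Option.value_exn ~here:[%here] loss.Tensor.diff).zero_grads in
  [%cd
    ~~(loss "gradient update";
       ~~(loss "fwd";
          fwd);
       ~~(loss "zero grads";
          Asgns.to_comp zero_grads);
       loss.grad =: 1;
       ~~(loss "bprop";
          bprop))]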

lib/train.ml

Lines changed: 2 additions & 20 deletions
@@ -79,25 +79,6 @@ let forward t =
   let label = Tn.debug_name t.value in
   { fwd with asgns = Asgns.Block_comment (label ^ " fwd", fwd.asgns) }

-let diff_or_error t provenance =
-  Option.value_or_thunk t.Tensor.diff ~default:(fun () ->
-      raise @@ Tensor.Session_error (provenance ^ ": tensor is not differentiable", Some t))
-
-let grad_update_nochecks loss =
-  let diff = diff_or_error loss "Train.grad_update_nochecks" in
-  let fwd_bprop =
-    [%cd
-      ~~(loss "gradient update";
-         ~~(loss "fwd";
-            loss.forward);
-         ~~(loss "zero grads";
-            Asgns.to_comp diff.zero_grads);
-         loss.grad =: 1;
-         ~~(loss "bprop";
-            diff.backprop))]
-  in
-  fwd_bprop
-
 (** Returns the tensor's forward, zeroing gradients, and backprop code wrapped with label-derived
     comments. Sets the tensor's value as "fully on host". If [setup_for_parallel] is true (false by
     default), sets the parameters and their gradients as "non-local" (on-device). *)
@@ -107,7 +88,8 @@ let grad_update ?(setup_for_parallel = false) loss =
   Set.iter loss.Tensor.params ~f:(fun p ->
       set_materialized (Option.value_exn ~here:[%here] p.diff).grad);
   let fwd = Tensor.consume_forward_code loss in
-  let zero_grads, bprop = Tensor.consume_backprop_code loss in
+  let bprop = Tensor.consume_backprop_code loss in
+  let zero_grads = (Option.value_exn ~here:[%here] loss.diff).zero_grads in
   (* Note: the %cd syntax for [loss.grad] does not modify roots. *)
   [%cd
     ~~(loss "gradient update for" loss;

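Note: with grad_update_nochecks removed and zero_grads no longer bundled into consume_backprop_code, a caller that wants to schedule zeroing independently of backprop (the motivation stated in the commit message) can now build the two computations separately. A rough, hypothetical sketch under the same assumptions as the note above; the name sketch_separate_phases is not part of this change.

(* Hypothetical: expose zeroing and backprop as two separately schedulable
   computations; when and how often each is run is left to the caller. *)
let sketch_separate_phases loss =
  let bprop = Tensor.consume_backprop_code loss in
  let diff = Option.value_exn ~here:[%here] loss.diff in
  let zeroing =
    [%cd
      ~~(loss "zero grads";
         Asgns.to_comp diff.zero_grads)]
  in
  let backward =
    [%cd
      ~~(loss "bprop";
         loss.grad =: 1;
         bprop)]
  in
  (zeroing, backward)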