@@ -8,15 +8,11 @@ type tn_set = Set.M(Ir.Tnode).t
 type asgns = Ir.Assignments.t
 type comp = Ir.Assignments.comp
 type fetch_op = Ir.Assignments.fetch_op
-type projections = {
-  projections_debug : string;
-  projections : Ir.Indexing.projections Lazy.t;
-}
+type projections = { projections_debug : string; projections : Ir.Indexing.projections Lazy.t }

 type diff = {
   grad : tn;
-  zero_grads : asgns;
-      (** Prepares for backpropagation. Always compile as: [Seq (zero_grads, backprop)]. *)
+  zero_grads : asgns;  (** Prepares for backpropagation. Beware of the "missing zero_grads" bug. *)
   backprop : comp;
       (** Backpropagates for the tensor and its descendants; which typically means adding partial
           gradients to the gradient tensor of the subtensors, then for sub-subtensors etc. *)
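The removed doc comment spelled out the required compilation order: [Seq (zero_grads, backprop)]. A minimal OCaml sketch of honoring that order under the new wording, assuming [Seq] is the assignments-sequencing constructor the old comment referenced and [asgns_of_comp] is a hypothetical accessor extracting the assignments from a [comp]:

    (* Hedged sketch, not verbatim library API: sequence gradient zeroing
       before backpropagation, as the removed comment prescribed. *)
    let full_backprop (d : diff) : asgns =
      (* Omitting [d.zero_grads] is the "missing zero_grads" bug: gradients
         would accumulate across steps instead of starting from zero. *)
      Seq (d.zero_grads, asgns_of_comp d.backprop)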
@@ -215,11 +211,12 @@ val consume_forward_code : t -> comp
     [consume_forward_code t] ensures [t] is a forward root, removes it from forward roots, and
     checks that there are no other forward roots for tensors with children. *)

-val consume_backprop_code : t -> asgns * comp
+val consume_backprop_code : t -> comp
 (** A backprop root is a tensor with a gradient that is not (currently) receiving gradients from
     another tensor. I.e. it is not currently used to compute a tensor with a gradient.
     [consume_backprop_code t] ensures [t] is a backprop root, removes it from backprop roots, and
-    checks that there are no other backprop roots for tensors with children. *)
+    checks that there are no other backprop roots for tensors with children. It returns the backprop
+    code -- note this does not include the zero_grads code. *)
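Since the returned [comp] no longer bundles gradient zeroing, callers must sequence [zero_grads] themselves. A hedged usage sketch, assuming the tensor type [t] exposes a [diff : diff option] field, and with [run_asgns]/[run_comp] as hypothetical stand-ins for the caller's compile-and-execute path:

    (* Hedged sketch; [run_asgns] and [run_comp] are illustrative, not
       part of this interface. *)
    let backprop_step (loss : t) =
      let bprop = consume_backprop_code loss in
      match loss.diff with
      | Some d ->
          run_asgns d.zero_grads;  (* zero gradients first; no longer included in [bprop] *)
          run_comp bprop
      | None -> invalid_arg "backprop_step: tensor has no gradient"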

 val iter_embedded : f:(tn -> unit) -> t -> unit
 (** [iter_embedded t] iterates over all descendant nodes that are embedded, i.e. are members of