Skip to content

Commit e3c02af

Browse files
lukstafi and claude committed
Fix parameter gradients not embedded after params moved earlier
Commit 47a33fc moved the params computation earlier in Tensor.op, which broke the assumption that t.params was empty when building backprop. As a result, the condition `not (Set.mem t.params ti)` correctly skipped parameter backprop, but it also skipped adding the parameters' gradient nodes to embedded_nodes, causing "context lacks node x.grad" errors.

Fix: still add parameter gradients to embedded_nodes when skipping their backprop code.

Also adds the zero2hero_1of7_exec standalone test for easier debugging.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 510a567 commit e3c02af

File tree

4 files changed

+450
-1
lines changed

4 files changed

+450
-1
lines changed

tensor/tensor.ml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,14 @@ let%track7_sexp op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
409409
in
410410
let bcks =
411411
List.filter_map ordered_ts ~f:(fun ti ->
412-
if is_bck_root ti && not (Set.mem t.params ti) then bprop ti else None)
412+
if is_bck_root ti then
413+
if Set.mem t.params ti then (
414+
(* Parameter's backprop is terminal, but we still need its gradient embedded *)
415+
Option.iter ti.diff ~f:(fun diff ->
416+
embedded_nodes := Set.add !embedded_nodes diff.grad);
417+
None)
418+
else bprop ti
419+
else None)
413420
in
414421
let diff = Some { grad = g; zero_grads; backprop = Asgns.empty_comp } in
415422
let t = { t with diff } in

test/operations/dune

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,17 @@
377377
(preprocess
378378
(pps ppx_here ppx_ocannl ppx_expect)))
379379

380+
(test
381+
(name zero2hero_1of7_exec)
382+
(package neural_nets_lib)
383+
(deps
384+
ocannl_config
385+
(env_var OCANNL_BACKEND))
386+
(modules zero2hero_1of7_exec)
387+
(libraries base ocannl stdio)
388+
(preprocess
389+
(pps ppx_here ppx_ocannl ppx_expect)))
390+
380391
(library
381392
(name operations_tutorials)
382393
(package neural_nets_lib)

0 commit comments

Comments
 (0)