Skip to content

Commit 46aafff

Browse files
committed
Restore printing parameter subtensors in forward code, mostly by Claude Sonnet
They are never
1 parent 37394a6 commit 46aafff

File tree

3 files changed

+181
-148
lines changed

3 files changed

+181
-148
lines changed

lib/tensor.ml

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -517,8 +517,17 @@ let lazy_optional_payload ~spy ~present ~missing v =
517517
type array_print_style =
518518
[ `Default | `Inline | `Label_layout of (string * int) list | `N5_layout of string ]
519519

520-
let to_dag ?(single_node = false) ?entries_per_axis ~spy ~with_shape ~with_id ~with_value ~with_grad
521-
t =
520+
let to_dag ?(single_node = false) ?(embedded_only = false) ?entries_per_axis ~spy ~with_shape
521+
~with_id ~with_value ~with_grad t =
522+
(* First scan to identify which tensors appear embedded anywhere *)
523+
let tensors_with_embedded_occurrence = Hash_set.create (module Int) in
524+
let rec scan_for_embedded { subtensor = t; embedded } =
525+
if embedded then Hash_set.add tensors_with_embedded_occurrence t.id;
526+
if not single_node then List.iter ~f:scan_for_embedded t.children
527+
in
528+
if not embedded_only then scan_for_embedded { subtensor = t; embedded = true };
529+
530+
let visited = if embedded_only then None else Some (Hash_set.create (module Int)) in
522531
let rec to_dag { subtensor = t; embedded } : PrintBox_utils.dag =
523532
let id = Int.to_string t.id in
524533
let children = if single_node then [] else List.map ~f:to_dag t.children in
@@ -542,10 +551,27 @@ let to_dag ?(single_node = false) ?entries_per_axis ~spy ~with_shape ~with_id ~w
542551
`Vlist (false, nodes @ [ shape ])
543552
else `Vlist (false, nodes)
544553
in
545-
match (not embedded, with_value, with_grad, t.diff) with
546-
| true, _, _, _ -> `Embed_subtree_ID (Int.to_string t.id)
554+
let should_elide, is_non_embedded =
555+
if embedded_only then (not embedded, not embedded)
556+
else if
557+
(* If this tensor appears embedded anywhere, use embedded logic for consistency *)
558+
Hash_set.mem tensors_with_embedded_occurrence t.id
559+
then (not embedded, not embedded)
560+
else
561+
(* Only use visited tracking for tensors that are never embedded anywhere *)
562+
match visited with
563+
| None -> (not embedded, not embedded)
564+
| Some visited_set ->
565+
if Hash_set.mem visited_set t.id then (true, not embedded)
566+
else (
567+
Hash_set.add visited_set t.id;
568+
(false, not embedded))
569+
in
570+
let eid = id ^ if is_non_embedded then " non-emb" else "" in
571+
match (should_elide, with_value, with_grad, t.diff) with
572+
| true, _, _, _ -> `Embed_subtree_ID id
547573
| _, false, false, _ | _, false, true, None ->
548-
`Subtree_with_ID (id, `Tree (add_shape [ `Text txt ], children))
574+
`Subtree_with_ID (eid, `Tree (add_shape [ `Text txt ], children))
549575
| _, true, false, _ | _, true, true, None ->
550576
let node =
551577
lazy_optional_payload t.value.array ~spy
@@ -555,7 +581,7 @@ let to_dag ?(single_node = false) ?entries_per_axis ~spy ~with_shape ~with_id ~w
555581
(Nd.render_array ~brief:true ~prefix:txt ?entries_per_axis ~labels ~indices v_array))
556582
~missing:(fun () -> txt ^ " " ^ where_located t.value)
557583
in
558-
`Subtree_with_ID (id, `Tree (add_shape [ node ], children))
584+
`Subtree_with_ID (eid, `Tree (add_shape [ node ], children))
559585
| _, false, true, Some diff ->
560586
let prefix = grad_txt diff in
561587
let node =
@@ -565,7 +591,7 @@ let to_dag ?(single_node = false) ?entries_per_axis ~spy ~with_shape ~with_id ~w
565591
`Box (Nd.render_array ~brief:true ~prefix ?entries_per_axis ~labels ~indices g_array)
566592
| None -> `Text (prefix ^ " " ^ where_located diff.grad)
567593
in
568-
`Subtree_with_ID (id, `Tree (add_shape [ node ], children))
594+
`Subtree_with_ID (eid, `Tree (add_shape [ node ], children))
569595
| _, true, true, Some diff ->
570596
let node =
571597
let value =
@@ -588,13 +614,14 @@ let to_dag ?(single_node = false) ?entries_per_axis ~spy ~with_shape ~with_id ~w
588614
in
589615
`Vlist (false, [ value; grad ])
590616
in
591-
`Subtree_with_ID (id, `Tree (add_shape [ node ], children))
617+
`Subtree_with_ID (eid, `Tree (add_shape [ node ], children))
592618
in
593619
to_dag { subtensor = t; embedded = true }
594620

595-
let to_printbox ?single_node ?entries_per_axis ?(with_id = false) ?(spy = false)
621+
let to_printbox ?single_node ?embedded_only ?entries_per_axis ?(with_id = false) ?(spy = false)
596622
?(with_shape = false) ?(with_value = true) ~with_grad ~depth t =
597-
to_dag ?single_node ?entries_per_axis ~with_id ~spy ~with_shape ~with_value ~with_grad t
623+
to_dag ?single_node ?embedded_only ?entries_per_axis ~with_id ~spy ~with_shape ~with_value
624+
~with_grad t
598625
|> PrintBox_utils.reformat_dag depth
599626

600627
let log_debug_info ~from_log_level t =
@@ -777,10 +804,10 @@ let print_forward_roots ~with_grad ~with_code (style : array_print_style) =
777804
print ~with_grad ~with_code style root)
778805

779806
let print_tree ?here ?entries_per_axis ?(with_backend_info = false) ?(with_id = true) ?(spy = false)
780-
?(with_shape = false) ?(with_value = true) ~with_grad ~depth t =
807+
?(with_shape = false) ?(with_value = true) ?embedded_only ~with_grad ~depth t =
781808
Option.iter here ~f:(fun here ->
782809
Stdio.printf "HERE: %s\n%!" (Source_code_position.to_string here));
783810
(* FIXME: print backend info *)
784811
ignore with_backend_info;
785812
PrintBox_text.output Stdio.stdout @@ PrintBox_utils.dag_to_box @@ PrintBox_utils.boxify depth
786-
@@ to_dag ?entries_per_axis ~with_id ~spy ~with_shape ~with_value ~with_grad t
813+
@@ to_dag ?entries_per_axis ?embedded_only ~with_id ~spy ~with_shape ~with_value ~with_grad t

lib/tensor.mli

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ type array_print_style =
317317

318318
val to_printbox :
319319
?single_node:bool ->
320+
?embedded_only:bool ->
320321
?entries_per_axis:int ->
321322
?with_id:bool ->
322323
?spy:bool ->
@@ -356,6 +357,7 @@ val print_tree :
356357
?spy:bool ->
357358
?with_shape:bool ->
358359
?with_value:bool ->
360+
?embedded_only:bool ->
359361
with_grad:bool ->
360362
depth:int ->
361363
t ->

0 commit comments

Comments
 (0)