@@ -517,8 +517,17 @@ let lazy_optional_payload ~spy ~present ~missing v =
517517type array_print_style =
518518 [ `Default | `Inline | `Label_layout of (string * int ) list | `N5_layout of string ]
519519
520- let to_dag ?(single_node = false ) ?entries_per_axis ~spy ~with_shape ~with_id ~with_value ~with_grad
521- t =
520+ let to_dag ?(single_node = false ) ?(embedded_only = false ) ?entries_per_axis ~spy ~with_shape
521+ ~with_id ~with_value ~with_grad t =
522+ (* First scan to identify which tensors appear embedded anywhere *)
523+ let tensors_with_embedded_occurrence = Hash_set. create (module Int ) in
524+ let rec scan_for_embedded { subtensor = t ; embedded } =
525+ if embedded then Hash_set. add tensors_with_embedded_occurrence t.id;
526+ if not single_node then List. iter ~f: scan_for_embedded t.children
527+ in
528+ if not embedded_only then scan_for_embedded { subtensor = t; embedded = true };
529+
530+ let visited = if embedded_only then None else Some (Hash_set. create (module Int )) in
522531 let rec to_dag { subtensor = t ; embedded } : PrintBox_utils.dag =
523532 let id = Int. to_string t.id in
524533 let children = if single_node then [] else List. map ~f: to_dag t.children in
@@ -542,10 +551,27 @@ let to_dag ?(single_node = false) ?entries_per_axis ~spy ~with_shape ~with_id ~w
542551 `Vlist (false , nodes @ [ shape ])
543552 else `Vlist (false , nodes)
544553 in
545- match (not embedded, with_value, with_grad, t.diff) with
546- | true , _ , _ , _ -> `Embed_subtree_ID (Int. to_string t.id)
554+ let should_elide, is_non_embedded =
555+ if embedded_only then (not embedded, not embedded)
556+ else if
557+ (* If this tensor appears embedded anywhere, use embedded logic for consistency *)
558+ Hash_set. mem tensors_with_embedded_occurrence t.id
559+ then (not embedded, not embedded)
560+ else
561+ (* Only use visited tracking for tensors that are never embedded anywhere *)
562+ match visited with
563+ | None -> (not embedded, not embedded)
564+ | Some visited_set ->
565+ if Hash_set. mem visited_set t.id then (true , not embedded)
566+ else (
567+ Hash_set. add visited_set t.id;
568+ (false , not embedded))
569+ in
570+ let eid = id ^ if is_non_embedded then " non-emb" else " " in
571+ match (should_elide, with_value, with_grad, t.diff) with
572+ | true , _ , _ , _ -> `Embed_subtree_ID id
547573 | _ , false , false , _ | _ , false , true , None ->
548- `Subtree_with_ID (id , `Tree (add_shape [ `Text txt ], children))
574+ `Subtree_with_ID (eid , `Tree (add_shape [ `Text txt ], children))
549575 | _ , true , false , _ | _ , true , true , None ->
550576 let node =
551577 lazy_optional_payload t.value.array ~spy
@@ -555,7 +581,7 @@ let to_dag ?(single_node = false) ?entries_per_axis ~spy ~with_shape ~with_id ~w
555581 (Nd. render_array ~brief: true ~prefix: txt ?entries_per_axis ~labels ~indices v_array))
556582 ~missing: (fun () -> txt ^ " " ^ where_located t.value)
557583 in
558- `Subtree_with_ID (id , `Tree (add_shape [ node ], children))
584+ `Subtree_with_ID (eid , `Tree (add_shape [ node ], children))
559585 | _ , false , true , Some diff ->
560586 let prefix = grad_txt diff in
561587 let node =
@@ -565,7 +591,7 @@ let to_dag ?(single_node = false) ?entries_per_axis ~spy ~with_shape ~with_id ~w
565591 `Box (Nd. render_array ~brief: true ~prefix ?entries_per_axis ~labels ~indices g_array)
566592 | None -> `Text (prefix ^ " " ^ where_located diff.grad)
567593 in
568- `Subtree_with_ID (id , `Tree (add_shape [ node ], children))
594+ `Subtree_with_ID (eid , `Tree (add_shape [ node ], children))
569595 | _ , true , true , Some diff ->
570596 let node =
571597 let value =
@@ -588,13 +614,14 @@ let to_dag ?(single_node = false) ?entries_per_axis ~spy ~with_shape ~with_id ~w
588614 in
589615 `Vlist (false , [ value; grad ])
590616 in
591- `Subtree_with_ID (id , `Tree (add_shape [ node ], children))
617+ `Subtree_with_ID (eid , `Tree (add_shape [ node ], children))
592618 in
593619 to_dag { subtensor = t; embedded = true }
594620
595- let to_printbox ?single_node ?entries_per_axis ?(with_id = false ) ?(spy = false )
621+ let to_printbox ?single_node ?embedded_only ? entries_per_axis ?(with_id = false ) ?(spy = false )
596622 ?(with_shape = false ) ?(with_value = true ) ~with_grad ~depth t =
597- to_dag ?single_node ?entries_per_axis ~with_id ~spy ~with_shape ~with_value ~with_grad t
623+ to_dag ?single_node ?embedded_only ?entries_per_axis ~with_id ~spy ~with_shape ~with_value
624+ ~with_grad t
598625 |> PrintBox_utils. reformat_dag depth
599626
600627let log_debug_info ~from_log_level t =
@@ -777,10 +804,10 @@ let print_forward_roots ~with_grad ~with_code (style : array_print_style) =
777804 print ~with_grad ~with_code style root)
778805
779806let print_tree ?here ?entries_per_axis ?(with_backend_info = false ) ?(with_id = true ) ?(spy = false )
780- ?(with_shape = false ) ?(with_value = true ) ~with_grad ~depth t =
807+ ?(with_shape = false ) ?(with_value = true ) ? embedded_only ~with_grad ~depth t =
781808 Option. iter here ~f: (fun here ->
782809 Stdio. printf " HERE: %s\n %!" (Source_code_position. to_string here));
783810 (* FIXME: print backend info *)
784811 ignore with_backend_info;
785812 PrintBox_text. output Stdio. stdout @@ PrintBox_utils. dag_to_box @@ PrintBox_utils. boxify depth
786- @@ to_dag ?entries_per_axis ~with_id ~spy ~with_shape ~with_value ~with_grad t
813+ @@ to_dag ?entries_per_axis ?embedded_only ~with_id ~spy ~with_shape ~with_value ~with_grad t
0 commit comments