@@ -551,27 +551,27 @@ let to_dag ?(single_node = false) ?(embedded_only = false) ?entries_per_axis ~sp
551551 `Vlist (false , nodes @ [ shape ])
552552 else `Vlist (false , nodes)
553553 in
554- let should_elide, is_non_embedded =
555- if embedded_only then ( not embedded, not embedded)
554+ let should_elide =
555+ if embedded_only then not embedded
556556 else if
557557 (* If this tensor appears embedded anywhere, use embedded logic for consistency *)
558558 Hash_set. mem tensors_with_embedded_occurrence t.id
559- then ( not embedded, not embedded)
559+ then not embedded
560560 else
561561 (* Only use visited tracking for tensors that are never embedded anywhere *)
562562 match visited with
563- | None -> ( not embedded, not embedded)
563+ | None -> not embedded
564564 | Some visited_set ->
565- if Hash_set. mem visited_set t.id then ( true , not embedded)
565+ if Hash_set. mem visited_set t.id then true
566566 else (
567567 Hash_set. add visited_set t.id;
568- ( false , not embedded) )
568+ false )
569569 in
570- let eid = id ^ if is_non_embedded then " non-emb" else " " in
570+ let txt = txt ^ if ( not should_elide) && not embedded then " non-emb" else " " in
571571 match (should_elide, with_value, with_grad, t.diff) with
572- | true , _ , _ , _ -> `Embed_subtree_ID id
572+ | true , _ , _ , _ -> `Embed_subtree_ID txt
573573 | _ , false , false , _ | _ , false , true , None ->
574- `Subtree_with_ID (eid , `Tree (add_shape [ `Text txt ], children))
574+ `Subtree_with_ID (id , `Tree (add_shape [ `Text txt ], children))
575575 | _ , true , false , _ | _ , true , true , None ->
576576 let node =
577577 lazy_optional_payload t.value.array ~spy
@@ -581,17 +581,19 @@ let to_dag ?(single_node = false) ?(embedded_only = false) ?entries_per_axis ~sp
581581 (Nd. render_array ~brief: true ~prefix: txt ?entries_per_axis ~labels ~indices v_array))
582582 ~missing: (fun () -> txt ^ " " ^ where_located t.value)
583583 in
584- `Subtree_with_ID (eid , `Tree (add_shape [ node ], children))
584+ `Subtree_with_ID (id , `Tree (add_shape [ node ], children))
585585 | _ , false , true , Some diff ->
586- let prefix = grad_txt diff in
586+ let prefix =
587+ grad_txt diff ^ if (not should_elide) && not embedded then " non-emb" else " "
588+ in
587589 let node =
588590 match Lazy. force diff.grad.array with
589591 | Some g_array ->
590592 Tn. do_read diff.grad;
591593 `Box (Nd. render_array ~brief: true ~prefix ?entries_per_axis ~labels ~indices g_array)
592594 | None -> `Text (prefix ^ " " ^ where_located diff.grad)
593595 in
594- `Subtree_with_ID (eid , `Tree (add_shape [ node ], children))
596+ `Subtree_with_ID (id , `Tree (add_shape [ node ], children))
595597 | _ , true , true , Some diff ->
596598 let node =
597599 let value =
@@ -614,7 +616,7 @@ let to_dag ?(single_node = false) ?(embedded_only = false) ?entries_per_axis ~sp
614616 in
615617 `Vlist (false , [ value; grad ])
616618 in
617- `Subtree_with_ID (eid , `Tree (add_shape [ node ], children))
619+ `Subtree_with_ID (id , `Tree (add_shape [ node ], children))
618620 in
619621 to_dag { subtensor = t; embedded = true }
620622
0 commit comments