Skip to content

Commit 6683f26

Browse files
committed
Fix my mess-up in previous commit, expose embeddedness
1 parent 46aafff commit 6683f26

File tree

2 files changed

+221
-219
lines changed

2 files changed

+221
-219
lines changed

lib/tensor.ml

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -551,27 +551,27 @@ let to_dag ?(single_node = false) ?(embedded_only = false) ?entries_per_axis ~sp
551551
`Vlist (false, nodes @ [ shape ])
552552
else `Vlist (false, nodes)
553553
in
554-
let should_elide, is_non_embedded =
555-
if embedded_only then (not embedded, not embedded)
554+
let should_elide =
555+
if embedded_only then not embedded
556556
else if
557557
(* If this tensor appears embedded anywhere, use embedded logic for consistency *)
558558
Hash_set.mem tensors_with_embedded_occurrence t.id
559-
then (not embedded, not embedded)
559+
then not embedded
560560
else
561561
(* Only use visited tracking for tensors that are never embedded anywhere *)
562562
match visited with
563-
| None -> (not embedded, not embedded)
563+
| None -> not embedded
564564
| Some visited_set ->
565-
if Hash_set.mem visited_set t.id then (true, not embedded)
565+
if Hash_set.mem visited_set t.id then true
566566
else (
567567
Hash_set.add visited_set t.id;
568-
(false, not embedded))
568+
false)
569569
in
570-
let eid = id ^ if is_non_embedded then " non-emb" else "" in
570+
let txt = txt ^ if (not should_elide) && not embedded then " non-emb" else "" in
571571
match (should_elide, with_value, with_grad, t.diff) with
572-
| true, _, _, _ -> `Embed_subtree_ID id
572+
| true, _, _, _ -> `Embed_subtree_ID txt
573573
| _, false, false, _ | _, false, true, None ->
574-
`Subtree_with_ID (eid, `Tree (add_shape [ `Text txt ], children))
574+
`Subtree_with_ID (id, `Tree (add_shape [ `Text txt ], children))
575575
| _, true, false, _ | _, true, true, None ->
576576
let node =
577577
lazy_optional_payload t.value.array ~spy
@@ -581,17 +581,19 @@ let to_dag ?(single_node = false) ?(embedded_only = false) ?entries_per_axis ~sp
581581
(Nd.render_array ~brief:true ~prefix:txt ?entries_per_axis ~labels ~indices v_array))
582582
~missing:(fun () -> txt ^ " " ^ where_located t.value)
583583
in
584-
`Subtree_with_ID (eid, `Tree (add_shape [ node ], children))
584+
`Subtree_with_ID (id, `Tree (add_shape [ node ], children))
585585
| _, false, true, Some diff ->
586-
let prefix = grad_txt diff in
586+
let prefix =
587+
grad_txt diff ^ if (not should_elide) && not embedded then " non-emb" else ""
588+
in
587589
let node =
588590
match Lazy.force diff.grad.array with
589591
| Some g_array ->
590592
Tn.do_read diff.grad;
591593
`Box (Nd.render_array ~brief:true ~prefix ?entries_per_axis ~labels ~indices g_array)
592594
| None -> `Text (prefix ^ " " ^ where_located diff.grad)
593595
in
594-
`Subtree_with_ID (eid, `Tree (add_shape [ node ], children))
596+
`Subtree_with_ID (id, `Tree (add_shape [ node ], children))
595597
| _, true, true, Some diff ->
596598
let node =
597599
let value =
@@ -614,7 +616,7 @@ let to_dag ?(single_node = false) ?(embedded_only = false) ?entries_per_axis ~sp
614616
in
615617
`Vlist (false, [ value; grad ])
616618
in
617-
`Subtree_with_ID (eid, `Tree (add_shape [ node ], children))
619+
`Subtree_with_ID (id, `Tree (add_shape [ node ], children))
618620
in
619621
to_dag { subtensor = t; embedded = true }
620622

0 commit comments

Comments
 (0)