Skip to content

Commit 102b9f8

Browse files
committed
Fix Tensor.print endlines
1 parent 3c0a07c commit 102b9f8

File tree

3 files changed

+34
-88
lines changed

3 files changed

+34
-88
lines changed

lib/tensor.ml

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -611,11 +611,12 @@ let to_dag ?(single_node = false) ?(embedded_only = false) ?entries_per_axis ~sp
611611
in
612612
let node =
613613
if Lazy.is_val diff.grad.array then
614-
match Lazy.force diff.grad.array with
615-
| Some g_array ->
616-
Tn.do_read diff.grad;
617-
`Box (Nd.render_array ~brief:true ~prefix ?entries_per_axis ~labels ~indices g_array)
618-
| None -> `Text (prefix ^ " " ^ where_located diff.grad)
614+
match Lazy.force diff.grad.array with
615+
| Some g_array ->
616+
Tn.do_read diff.grad;
617+
`Box
618+
(Nd.render_array ~brief:true ~prefix ?entries_per_axis ~labels ~indices g_array)
619+
| None -> `Text (prefix ^ " " ^ where_located diff.grad)
619620
else `Text (prefix ^ " <not-in-yet> " ^ where_located diff.grad)
620621
in
621622
`Subtree_with_ID (id, `Tree (add_shape [ node ], children))
@@ -725,31 +726,34 @@ let to_doc ?(force_read = false) ~with_grad ~with_code ?(with_low_level = false)
725726

726727
let open PPrint in
727728
(* Create document for tensor value *)
729+
let has_grad = with_grad && Option.is_some t.diff in
728730
let value_doc =
729-
if not force_read && not (Lazy.is_val t.value.array) then
730-
string prefix_str ^^ string " <not-in-yet>" ^^ space
731+
if (not force_read) && not (Lazy.is_val t.value.array) then
732+
string prefix_str ^^ string " <not-in-yet>" ^^ break 1
731733
else
732734
match (style, Lazy.force t.value.array) with
733-
| _, None -> string prefix_str ^^ string " <not-hosted>" ^^ space
735+
| _, None ->
736+
string prefix_str ^^ string " <not-hosted>" ^^ if has_grad then break 1 else empty
734737
| `Inline, Some arr ->
735738
Tn.do_read t.value;
736739
string prefix_str ^^ space
737740
^^ Nd.to_doc_inline ~num_batch_axes ~num_input_axes ~num_output_axes ?axes_spec arr
741+
^^ if has_grad then break 1 else empty
738742
| _, Some arr ->
739743
Tn.do_read t.value;
740-
Nd.to_doc ~prefix:prefix_str ~labels ~indices arr
744+
Nd.to_doc ~prefix:prefix_str ~labels ~indices arr ^^ if has_grad then break 1 else empty
741745
in
742746

743747
(* Create document for gradient *)
744748
let grad_doc =
745749
if with_grad then
746750
match t.diff with
747751
| Some diff -> (
748-
if not force_read && not (Lazy.is_val diff.grad.array) then
749-
string (grad_txt diff) ^^ string " <not-in-yet>" ^^ space
752+
if (not force_read) && not (Lazy.is_val diff.grad.array) then
753+
string (grad_txt diff) ^^ string " <not-in-yet>"
750754
else
751755
match Lazy.force diff.grad.array with
752-
| None -> string (grad_txt diff) ^^ string " <not-hosted>" ^^ space
756+
| None -> string (grad_txt diff) ^^ string " <not-hosted>"
753757
| Some arr -> (
754758
match style with
755759
| `Inline ->
@@ -758,11 +762,10 @@ let to_doc ?(force_read = false) ~with_grad ~with_code ?(with_low_level = false)
758762
^^ space
759763
^^ Nd.to_doc_inline ~num_batch_axes ~num_input_axes ~num_output_axes ?axes_spec
760764
arr
761-
^^ string "\n"
762765
| `Default | `N5_layout _ | `Label_layout _ ->
763766
Tn.do_read diff.grad;
764767
let prefix = prefix_str ^ " " ^ grad_txt diff in
765-
Nd.to_doc ~prefix ~labels ~indices arr ^^ string "\n"))
768+
Nd.to_doc ~prefix ~labels ~indices arr))
766769
| None -> empty
767770
else empty
768771
in
@@ -774,15 +777,15 @@ let to_doc ?(force_read = false) ~with_grad ~with_code ?(with_low_level = false)
774777
match t.forward.asgns with
775778
| Noop -> empty
776779
| fwd_code ->
777-
string "@[<v 2>Current forward body:"
778-
^^ hardline ^^ Asgns.to_doc () fwd_code ^^ string "@]" ^^ hardline
780+
group (string "Current forward body:" ^^ nest 2 (hardline ^^ Asgns.to_doc () fwd_code))
781+
^^ hardline
779782
in
780783
let bwd_doc =
781784
match t.diff with
782785
| Some { backprop = { asgns = Noop; _ }; _ } -> empty
783786
| Some { backprop = { asgns = bwd_code; _ }; _ } ->
784-
string "@[<v 2>Current backprop body:"
785-
^^ hardline ^^ Asgns.to_doc () bwd_code ^^ string "@]" ^^ hardline
787+
group (string "Current backprop body:" ^^ nest 2 (hardline ^^ Asgns.to_doc () bwd_code))
788+
^^ hardline
786789
| None -> empty
787790
in
788791
fwd_doc ^^ bwd_doc
@@ -796,27 +799,30 @@ let to_doc ?(force_read = false) ~with_grad ~with_code ?(with_low_level = false)
796799
match t.forward.asgns with
797800
| Noop -> empty
798801
| fwd_code ->
799-
string "@[<v 2>Current forward low-level body:"
802+
group
803+
(string "Current forward low-level body:"
804+
^^ nest 2 (hardline ^^ Ir.Low_level.to_doc () (Asgns.to_low_level fwd_code)))
800805
^^ hardline
801-
^^ Ir.Low_level.to_doc () (Asgns.to_low_level fwd_code)
802-
^^ string "@]" ^^ hardline
803806
in
804807
let bwd_doc =
805808
match t.diff with
806809
| Some { backprop = { asgns = Noop; _ }; _ } -> empty
807810
| Some { backprop = { asgns = bwd_code; _ }; _ } ->
808-
string "@[<v 2>Current backprop low-level body:"
811+
group
812+
(string "Current backprop low-level body:"
813+
^^ nest 2 (hardline ^^ Ir.Low_level.to_doc () (Asgns.to_low_level bwd_code)))
809814
^^ hardline
810-
^^ Ir.Low_level.to_doc () (Asgns.to_low_level bwd_code)
811-
^^ string "@]" ^^ hardline
812815
| None -> empty
813816
in
814817
fwd_doc ^^ bwd_doc
815818
else empty
816819
in
817820

818821
(* Combine all documents and print *)
819-
group (value_doc ^^ break 1 ^^ grad_doc ^^ break 1 ^^ code_doc ^^ break 1 ^^ low_level_doc)
822+
group
823+
(value_doc ^^ grad_doc
824+
^^ (if is_empty value_doc && is_empty grad_doc then empty else hardline)
825+
^^ code_doc ^^ low_level_doc)
820826

821827
let print ?here ?(force_read = false) ~with_grad ~with_code ?(with_low_level = false)
822828
(style : array_print_style) t =

0 commit comments

Comments
 (0)