@@ -611,11 +611,12 @@ let to_dag ?(single_node = false) ?(embedded_only = false) ?entries_per_axis ~sp
611611 in
612612 let node =
613613 if Lazy. is_val diff.grad.array then
614- match Lazy. force diff.grad.array with
615- | Some g_array ->
616- Tn. do_read diff.grad;
617- `Box (Nd. render_array ~brief: true ~prefix ?entries_per_axis ~labels ~indices g_array)
618- | None -> `Text (prefix ^ " " ^ where_located diff.grad)
614+ match Lazy. force diff.grad.array with
615+ | Some g_array ->
616+ Tn. do_read diff.grad;
617+ `Box
618+ (Nd. render_array ~brief: true ~prefix ?entries_per_axis ~labels ~indices g_array)
619+ | None -> `Text (prefix ^ " " ^ where_located diff.grad)
619620 else `Text (prefix ^ " <not-in-yet> " ^ where_located diff.grad)
620621 in
621622 `Subtree_with_ID (id, `Tree (add_shape [ node ], children))
@@ -725,31 +726,34 @@ let to_doc ?(force_read = false) ~with_grad ~with_code ?(with_low_level = false)
725726
726727 let open PPrint in
727728 (* Create document for tensor value *)
729+ let has_grad = with_grad && Option. is_some t.diff in
728730 let value_doc =
729- if not force_read && not (Lazy. is_val t.value.array ) then
730- string prefix_str ^^ string " <not-in-yet>" ^^ space
731+ if ( not force_read) && not (Lazy. is_val t.value.array ) then
732+ string prefix_str ^^ string " <not-in-yet>" ^^ break 1
731733 else
732734 match (style, Lazy. force t.value.array ) with
733- | _ , None -> string prefix_str ^^ string " <not-hosted>" ^^ space
735+ | _ , None ->
736+ string prefix_str ^^ string " <not-hosted>" ^^ if has_grad then break 1 else empty
734737 | `Inline , Some arr ->
735738 Tn. do_read t.value;
736739 string prefix_str ^^ space
737740 ^^ Nd. to_doc_inline ~num_batch_axes ~num_input_axes ~num_output_axes ?axes_spec arr
741+ ^^ if has_grad then break 1 else empty
738742 | _ , Some arr ->
739743 Tn. do_read t.value;
740- Nd. to_doc ~prefix: prefix_str ~labels ~indices arr
744+ Nd. to_doc ~prefix: prefix_str ~labels ~indices arr ^^ if has_grad then break 1 else empty
741745 in
742746
743747 (* Create document for gradient *)
744748 let grad_doc =
745749 if with_grad then
746750 match t.diff with
747751 | Some diff -> (
748- if not force_read && not (Lazy. is_val diff.grad.array ) then
749- string (grad_txt diff) ^^ string " <not-in-yet>" ^^ space
752+ if ( not force_read) && not (Lazy. is_val diff.grad.array ) then
753+ string (grad_txt diff) ^^ string " <not-in-yet>"
750754 else
751755 match Lazy. force diff.grad.array with
752- | None -> string (grad_txt diff) ^^ string " <not-hosted>" ^^ space
756+ | None -> string (grad_txt diff) ^^ string " <not-hosted>"
753757 | Some arr -> (
754758 match style with
755759 | `Inline ->
@@ -758,11 +762,10 @@ let to_doc ?(force_read = false) ~with_grad ~with_code ?(with_low_level = false)
758762 ^^ space
759763 ^^ Nd. to_doc_inline ~num_batch_axes ~num_input_axes ~num_output_axes ?axes_spec
760764 arr
761- ^^ string " \n "
762765 | `Default | `N5_layout _ | `Label_layout _ ->
763766 Tn. do_read diff.grad;
764767 let prefix = prefix_str ^ " " ^ grad_txt diff in
765- Nd. to_doc ~prefix ~labels ~indices arr ^^ string " \n " ))
768+ Nd. to_doc ~prefix ~labels ~indices arr))
766769 | None -> empty
767770 else empty
768771 in
@@ -774,15 +777,15 @@ let to_doc ?(force_read = false) ~with_grad ~with_code ?(with_low_level = false)
774777 match t.forward.asgns with
775778 | Noop -> empty
776779 | fwd_code ->
777- string " @[<v 2> Current forward body:"
778- ^^ hardline ^^ Asgns. to_doc () fwd_code ^^ string " @] " ^^ hardline
780+ group ( string " Current forward body:" ^^ nest 2 (hardline ^^ Asgns. to_doc () fwd_code))
781+ ^^ hardline
779782 in
780783 let bwd_doc =
781784 match t.diff with
782785 | Some { backprop = { asgns = Noop ; _ } ; _ } -> empty
783786 | Some { backprop = { asgns = bwd_code ; _ } ; _ } ->
784- string " @[<v 2> Current backprop body:"
785- ^^ hardline ^^ Asgns. to_doc () bwd_code ^^ string " @] " ^^ hardline
787+ group ( string " Current backprop body:" ^^ nest 2 (hardline ^^ Asgns. to_doc () bwd_code))
788+ ^^ hardline
786789 | None -> empty
787790 in
788791 fwd_doc ^^ bwd_doc
@@ -796,27 +799,30 @@ let to_doc ?(force_read = false) ~with_grad ~with_code ?(with_low_level = false)
796799 match t.forward.asgns with
797800 | Noop -> empty
798801 | fwd_code ->
799- string " @[<v 2>Current forward low-level body:"
802+ group
803+ (string " Current forward low-level body:"
804+ ^^ nest 2 (hardline ^^ Ir.Low_level. to_doc () (Asgns. to_low_level fwd_code)))
800805 ^^ hardline
801- ^^ Ir.Low_level. to_doc () (Asgns. to_low_level fwd_code)
802- ^^ string " @]" ^^ hardline
803806 in
804807 let bwd_doc =
805808 match t.diff with
806809 | Some { backprop = { asgns = Noop ; _ } ; _ } -> empty
807810 | Some { backprop = { asgns = bwd_code ; _ } ; _ } ->
808- string " @[<v 2>Current backprop low-level body:"
811+ group
812+ (string " Current backprop low-level body:"
813+ ^^ nest 2 (hardline ^^ Ir.Low_level. to_doc () (Asgns. to_low_level bwd_code)))
809814 ^^ hardline
810- ^^ Ir.Low_level. to_doc () (Asgns. to_low_level bwd_code)
811- ^^ string " @]" ^^ hardline
812815 | None -> empty
813816 in
814817 fwd_doc ^^ bwd_doc
815818 else empty
816819 in
817820
818821 (* Combine all documents and print *)
819- group (value_doc ^^ break 1 ^^ grad_doc ^^ break 1 ^^ code_doc ^^ break 1 ^^ low_level_doc)
822+ group
823+ (value_doc ^^ grad_doc
824+ ^^ (if is_empty value_doc && is_empty grad_doc then empty else hardline)
825+ ^^ code_doc ^^ low_level_doc)
820826
821827let print ?here ?(force_read = false ) ~with_grad ~with_code ?(with_low_level = false )
822828 (style : array_print_style ) t =
0 commit comments