File tree Expand file tree Collapse file tree 2 files changed +7
-5
lines changed Expand file tree Collapse file tree 2 files changed +7
-5
lines changed Original file line number Diff line number Diff line change @@ -149,12 +149,13 @@ let op ~(label : string list) ?(compose_op = Shape.Pointwise_bin)
149149 let id = session_state.next_id in
150150 session_state.next_id < - session_state.next_id + 1 ;
151151 let shape = make_shape ~debug_name: (Tn. get_debug_name ~id ~label () ) ~id in
152- let lazy_v_precs = List. map orig_ts ~f: (fun ti -> ti.value.prec) in
153152 let default_prec =
153+ let lazy_v_precs = List. map orig_ts ~f: (fun ti -> ti.value.prec) in
154+ let default = ! default_value_prec in
154155 lazy
155156 (List. map lazy_v_precs ~f: Lazy. force
156157 |> List. reduce ~f: Arrayjit.Ops. promote_prec
157- |> Option. value ~default: ! default_value_prec )
158+ |> Option. value ~default )
158159 in
159160 let rec shape_logics = function
160161 | [] -> [ Shape. Terminal init_op ]
@@ -185,10 +186,11 @@ let op ~(label : string list) ?(compose_op = Shape.Pointwise_bin)
185186 let default_prec =
186187 let f ti = Option. map ti.diff ~f: (fun d -> d.grad.Tn. prec) in
187188 let lazy_g_precs = List. filter_map orig_ts ~f in
189+ let default = ! default_grad_prec in
188190 lazy
189191 (List. map lazy_g_precs ~f: Lazy. force
190192 |> List. reduce ~f: Arrayjit.Ops. promote_prec
191- |> Option. value ~default: ! default_grad_prec )
193+ |> Option. value ~default )
192194 in
193195 let grad_id = session_state.next_id in
194196 session_state.next_id < - session_state.next_id + 1 ;
Original file line number Diff line number Diff line change @@ -44,14 +44,14 @@ val with_unchanged_roots : f:(unit -> 'a) -> 'a
4444val default_value_prec : Arrayjit.Ops .prec ref
4545(* * The default precision for the value node of terminal (i.e. non-composite) tensors.
4646
47- Note: the precision can be set arbitrarily via {!Tnode.update_precision}. The default
47+ Note: the precision of a node can be set arbitrarily via {!Tnode.update_precision}. The default
4848 precision for value nodes of composite tensors is the maximum of precisions of the value nodes
4949 of sub-tensors. *)
5050
5151val default_grad_prec : Arrayjit.Ops .prec ref
5252(* * The default precision for the gradient node of terminal (i.e. non-composite) tensors.
5353
54- Note: the precision can be set arbitrarily via {!Tnode.update_precision}. The default
54+ Note: the precision of a node can be set arbitrarily via {!Tnode.update_precision}. The default
5555 precision for gradient nodes of composite tensors is the maximum of precisions of the gradient
5656 nodes of sub-tensors. *)
5757
You can’t perform that action at this time.
0 commit comments