Skip to content

Commit 150cef7

Browse files
committed
Don't delay retrieving default precisions for value and gradient nodes
1 parent 61dfc04 commit 150cef7

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

lib/tensor.ml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,13 @@ let op ~(label : string list) ?(compose_op = Shape.Pointwise_bin)
149149
let id = session_state.next_id in
150150
session_state.next_id <- session_state.next_id + 1;
151151
let shape = make_shape ~debug_name:(Tn.get_debug_name ~id ~label ()) ~id in
152-
let lazy_v_precs = List.map orig_ts ~f:(fun ti -> ti.value.prec) in
153152
let default_prec =
153+
let lazy_v_precs = List.map orig_ts ~f:(fun ti -> ti.value.prec) in
154+
let default = !default_value_prec in
154155
lazy
155156
(List.map lazy_v_precs ~f:Lazy.force
156157
|> List.reduce ~f:Arrayjit.Ops.promote_prec
157-
|> Option.value ~default:!default_value_prec)
158+
|> Option.value ~default)
158159
in
159160
let rec shape_logics = function
160161
| [] -> [ Shape.Terminal init_op ]
@@ -185,10 +186,11 @@ let op ~(label : string list) ?(compose_op = Shape.Pointwise_bin)
185186
let default_prec =
186187
let f ti = Option.map ti.diff ~f:(fun d -> d.grad.Tn.prec) in
187188
let lazy_g_precs = List.filter_map orig_ts ~f in
189+
let default = !default_grad_prec in
188190
lazy
189191
(List.map lazy_g_precs ~f:Lazy.force
190192
|> List.reduce ~f:Arrayjit.Ops.promote_prec
191-
|> Option.value ~default:!default_grad_prec)
193+
|> Option.value ~default)
192194
in
193195
let grad_id = session_state.next_id in
194196
session_state.next_id <- session_state.next_id + 1;

lib/tensor.mli

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,14 @@ val with_unchanged_roots : f:(unit -> 'a) -> 'a
4444
val default_value_prec : Arrayjit.Ops.prec ref
4545
(** The default precision for the value node of terminal (i.e. non-composite) tensors.
4646
47-
Note: the precision can be set arbitrarily via {!Tnode.update_precision}. The default
47+
Note: the precision of a node can be set arbitrarily via {!Tnode.update_precision}. The default
4848
precision for value nodes of composite tensors is the maximum of precisions of the value nodes
4949
of sub-tensors. *)
5050

5151
val default_grad_prec : Arrayjit.Ops.prec ref
5252
(** The default precision for the gradient node of terminal (i.e. non-composite) tensors.
5353
54-
Note: the precision can be set arbitrarily via {!Tnode.update_precision}. The default
54+
Note: the precision of a node can be set arbitrarily via {!Tnode.update_precision}. The default
5555
precision for gradient nodes of composite tensors is the maximum of precisions of the gradient
5656
nodes of sub-tensors. *)
5757

0 commit comments

Comments
 (0)