OCANNL features rudimentary bidirectional precision inference. It is far less powerful than the constraint-based shape and projections inference, but it is somewhat prominent because it contributes the `top_down_prec` flag to the central `Tensor.t` type. The core algorithm is just a couple dozen lines in the `Tensor.op` function. First, the bottom-up pass:

```ocaml
let default_prec_for default get =
  if top_down_prec then
    (* For top-down precision, don't promote from inputs *)
    lazy default
  else
    (* For bottom-up precision, only promote from non-top-down subtensors *)
    let lazy_v_precs =
      List.filter_map ordered_ts ~f:(fun ti ->
          Option.map (get ti) ~f:(fun v ->
              if ti.top_down_prec then lazy (Tn.get_specified_prec v)
              else lazy (Some (Lazy.force v.prec))))
    in
    lazy
      (List.filter_map lazy_v_precs ~f:Lazy.force
      |> List.reduce ~f:Ir.Ops.promote_prec
      |> Option.value ~default)
in
```

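For intuition, here is a minimal standalone sketch of this promotion fold. It is not OCANNL code: the `precision` type, `promote_prec`, and `infer_bottom_up` below are hypothetical stand-ins for the `Ir.Ops` definitions, written against plain `Stdlib` rather than `Base`:

```ocaml
(* Hypothetical stand-in for Ir.Ops precisions; declaration order makes
   polymorphic compare treat later constructors as wider. *)
type precision = Half | Single | Double

let promote_prec a b = if compare a b >= 0 then a else b

(* Mirrors the bottom-up fold: promote over the input precisions that are
   known, falling back to [default] when none are. *)
let infer_bottom_up ~default inputs =
  match List.filter_map Fun.id inputs with
  | [] -> default
  | p :: ps -> List.fold_left promote_prec p ps

let () =
  (* A Half and a Single input promote to Single; a top-down input with no
     specified precision contributes [None] and is ignored. *)
  assert (infer_bottom_up ~default:Single [ Some Half; None; Some Single ] = Single)
```
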
And later comes the top-down pass, shown here propagating from the value node `v`:

```ocaml
let update_infer_prec tn prec =
  (* Instead of just checking prec, we cross-check with dims (needed for code generation), to
     catch prec forcing bugs. *)
  if not (Lazy.is_val tn.Tn.dims) then Tn.update_infer_prec tn prec
in
(* Apply delayed top-down precision updates to parameter subtensors *)
List.iter top_down_ts ~f:(fun ti -> update_infer_prec ti.value v.Tn.prec);
```

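The `Lazy.is_val` guard is what makes it safe to delay the update: once a node's `dims` have been forced, code generation may already have observed its precision, so further changes must be dropped. A hypothetical toy model of that guard (OCANNL's actual `Tn.t` is more involved):

```ocaml
type precision = Half | Single | Double

(* Hypothetical toy node: precision stays mutable only until dims are forced. *)
type node = { mutable prec : precision; dims : int array Lazy.t }

let update_infer_prec node prec =
  (* The same guard as above: drop the update once dims have been forced. *)
  if not (Lazy.is_val node.dims) then node.prec <- prec

let () =
  let n = { prec = Single; dims = lazy (Array.make 2 0) } in
  update_infer_prec n Double;
  (* Dims not yet forced, so the update applied. *)
  assert (n.prec = Double);
  ignore (Lazy.force n.dims);
  update_infer_prec n Half;
  (* Dims already forced, so the update was dropped. *)
  assert (n.prec = Double)
```
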
Tensors that set `top_down_prec=true` "detach" themselves from their defining tensor expression as far as precision is concerned. By default, tensors are `top_down_prec=false`; the exceptions are parameter tensors (created via `Tensor.param`) and the results of the `uint4x32_to_prec_uniform` operation.
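To make "detaching" concrete, here is a hypothetical end-to-end run of both passes over the toy model (again, not OCANNL's API: `node`, `bottom_up`, and `top_down` are illustrative stand-ins). A parameter-like node is skipped during bottom-up promotion unless its precision was explicitly specified, and it then adopts the result's precision top-down:

```ocaml
type precision = Half | Single | Double

let promote_prec a b = if compare a b >= 0 then a else b

(* Hypothetical nodes; [specified] marks a precision pinned by the user. *)
type node = { mutable prec : precision; specified : bool; top_down_prec : bool }

(* Bottom-up: top-down nodes contribute only a pinned precision. *)
let bottom_up ~default inputs =
  inputs
  |> List.filter_map (fun n ->
         if n.top_down_prec && not n.specified then None else Some n.prec)
  |> function
  | [] -> default
  | p :: ps -> List.fold_left promote_prec p ps

(* Top-down: unpinned top-down inputs adopt the result's precision. *)
let top_down result inputs =
  List.iter
    (fun n -> if n.top_down_prec && not n.specified then n.prec <- result.prec)
    inputs

let () =
  let w = { prec = Single; specified = false; top_down_prec = true } in
  let x = { prec = Double; specified = false; top_down_prec = false } in
  (* [w] is ignored bottom-up, so the result becomes Double rather than
     being influenced by [w]'s placeholder precision. *)
  let v =
    { prec = bottom_up ~default:Single [ w; x ]; specified = false; top_down_prec = false }
  in
  (* [w] then adopts the result's precision top-down. *)
  top_down v [ w; x ];
  assert (v.prec = Double && w.prec = Double)
```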