Commit 2d00d55 (parent d304b2c)

Third pass on bidirectional precision inference: include top-down tensors with user-specified precision in bottom-up propagation.

There'll be a fourth pass, so that instead of forcing default precisions from below, propagation starts with unspecified precisions.

9 files changed: +138 −37 lines

arrayjit/lib/tnode.ml

Lines changed: 3 additions & 0 deletions
```diff
@@ -436,6 +436,9 @@ let update_infer_prec tn delayed_prec =
       tn.delayed_prec_unsafe <-
         Default_spec (lazy (Ops.promote_prec (Lazy.force old_prec) (Lazy.force delayed_prec)))
 
+let get_specified_prec tn =
+  match tn.delayed_prec_unsafe with Specified prec -> Some prec | _ -> None
+
 let exceeds_fp16_cutoff tn c =
   match Utils.settings.check_half_prec_constants_cutoff with
   | None -> false
```

lib/precision_inference.md

Lines changed: 37 additions & 0 deletions
New file:

# Bidirectional precision inference

OCANNL features a rudimentary bidirectional precision inference. It is much less powerful than the constraint-based shape and projections inference, but it is somewhat prominent because it contributes the `top_down_prec` flag to the central `Tensor.t` type. The core algorithm is just a couple dozen lines in the `Tensor.op` function. First, the bottom-up pass:

```ocaml
let default_prec_for default get =
  if top_down_prec then
    (* For top-down precision, don't promote from inputs *)
    lazy default
  else
    (* For bottom-up precision, promote from non-top-down subtensors, and from
       top-down subtensors take only their user-specified precisions *)
    let lazy_v_precs =
      List.filter_map ordered_ts ~f:(fun ti ->
          Option.map (get ti) ~f:(fun v ->
              if ti.top_down_prec then lazy (Tn.get_specified_prec v)
              else lazy (Some (Lazy.force v.prec))))
    in
    lazy
      (List.filter_map lazy_v_precs ~f:Lazy.force
      |> List.reduce ~f:Ir.Ops.promote_prec
      |> Option.value ~default)
in
```

and later the top-down pass, shown here for the value node `v`:

```ocaml
let update_infer_prec tn prec =
  (* Instead of just checking prec, we cross-check with dims (needed for code
     generation), to catch prec-forcing bugs. *)
  if not (Lazy.is_val tn.Tn.dims) then Tn.update_infer_prec tn prec
in
(* Apply delayed top-down precision updates to parameter subtensors *)
List.iter top_down_ts ~f:(fun ti -> update_infer_prec ti.value v.Tn.prec);
```

Tensors that choose `top_down_prec=true` "detach" themselves from their defining tensor expression as far as precision goes. By default, tensors are `top_down_prec=false`; the exceptions are parameter tensors (created via `Tensor.param`) and results of the operation `uint4x32_to_prec_uniform`.
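
For intuition, here is a minimal self-contained model of the bottom-up rule. This is a sketch, not OCANNL's API: the `prec` variants, the ranking inside `promote_prec`, and the `subtensor` record are illustrative stand-ins for `Ir.Ops` and `Tensor.t`, with the laziness dropped.

```ocaml
(* Toy precision lattice; Ir.Ops.promote_prec plays this role in OCANNL.
   The ranking below is an assumption for illustration only. *)
type prec = Half | Bfloat16 | Single | Double

let promote_prec a b =
  let rank = function Half -> 0 | Bfloat16 -> 1 | Single -> 2 | Double -> 3 in
  if rank a >= rank b then a else b

(* Simplified stand-in for a subtensor's precision-relevant state. *)
type subtensor = { top_down_prec : bool; specified : prec option; prec : prec }

(* Bottom-up default: ordinary subtensors always contribute their precision;
   top-down subtensors contribute only a user-specified precision (this
   commit's change). The default applies only when nothing contributes. *)
let bottom_up_default ~default subtensors =
  match
    List.filter_map
      (fun t -> if t.top_down_prec then t.specified else Some t.prec)
      subtensors
  with
  | [] -> default
  | p :: ps -> List.fold_left promote_prec p ps

let () =
  (* Mirrors test/operations/top_down_prec.ml below: a and b are params, so
     top_down_prec = true; only b has a user-specified precision (half). *)
  let a = { top_down_prec = true; specified = None; prec = Single } in
  let b = { top_down_prec = true; specified = Some Half; prec = Half } in
  assert (bottom_up_default ~default:Single [ a; b ] = Half)
```

Under this rule, the test below (`d = (a + b) *. c` with `b` forced to `half`) infers `half` for the intermediate `a + b`, matching the expected low-level code further down.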

lib/tensor.ml

Lines changed: 13 additions & 20 deletions
```diff
@@ -254,20 +254,25 @@ let%track7_sexp op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
   let _session_state_next_id : int = session_state.next_id in
   let shape = make_shape ~debug_name:(Tn.get_debug_name ~id ~label ()) ~id in
   (* Split subtensors by whether they use top-down precision inference *)
-  let top_down_ts, bottom_up_ts = List.partition_tf ordered_ts ~f:(fun t -> t.top_down_prec) in
-  let default_prec =
+  let top_down_ts = List.filter ordered_ts ~f:(fun t -> t.top_down_prec) in
+  let default_prec_for default get =
     if top_down_prec then
       (* For top-down precision, don't promote from inputs *)
-      lazy !default_value_prec
+      lazy default
     else
       (* For bottom-up precision, only promote from non-top-down subtensors *)
-      let lazy_v_precs = List.map bottom_up_ts ~f:(fun ti -> ti.value.prec) in
-      let default = !default_value_prec in
+      let lazy_v_precs =
+        List.filter_map ordered_ts ~f:(fun ti ->
+            Option.map (get ti) ~f:(fun v ->
+                if ti.top_down_prec then lazy (Tn.get_specified_prec v)
+                else lazy (Some (Lazy.force v.prec))))
+      in
       lazy
-        (List.map lazy_v_precs ~f:Lazy.force
+        (List.filter_map lazy_v_precs ~f:Lazy.force
         |> List.reduce ~f:Ir.Ops.promote_prec
         |> Option.value ~default)
   in
+  let default_prec = default_prec_for !default_value_prec (fun t -> Some t.value) in
   let terminal_logic () =
     let open Shape in
     match terminal_op with
@@ -357,20 +362,8 @@ let%track7_sexp op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
       session_state.forward_roots <- Map.add_exn session_state.forward_roots ~key:id ~data:t;
       t)
     else
-      let default_prec =
-        if top_down_prec then
-          (* For top-down precision, don't promote from inputs *)
-          lazy !default_grad_prec
-        else
-          (* For bottom-up precision, only promote from non-top-down subtensors *)
-          let f ti = Option.map ti.diff ~f:(fun d -> d.grad.Tn.prec) in
-          let lazy_g_precs = List.filter_map bottom_up_ts ~f in
-          let default = !default_grad_prec in
-          lazy
-            (List.map lazy_g_precs ~f:Lazy.force
-            |> List.reduce ~f:Ir.Ops.promote_prec
-            |> Option.value ~default)
-      in
+      let get ti = Option.map ti.diff ~f:(fun d -> d.grad) in
+      let default_prec = default_prec_for !default_grad_prec get in
       let grad_id = session_state.next_id in
       session_state.next_id <- session_state.next_id + 1;
       let g =
```

test/operations/dune

Lines changed: 40 additions & 9 deletions
```diff
@@ -66,28 +66,59 @@
  (preprocess
   (pps ppx_here ppx_ocannl)))
 
+(test
+ (name top_down_prec)
+ (modules top_down_prec)
+ (libraries base ocannl)
+ (preprocess
+  (pps ppx_here ppx_ocannl)))
+
 (rule
  (alias runtest)
  (target
   (dir build_files))
  (action
-  (run
-   %{dep:threefry4x32_demo.exe}
-   "--ocannl_output_prec_in_ll_files=true"
-   "--ocannl_output_debug_files_in_build_directory=true")))
+  (progn
+   (run
+    %{dep:threefry4x32_demo.exe}
+    "--ocannl_output_prec_in_ll_files=true"
+    "--ocannl_output_debug_files_in_build_directory=true"
+    "--ocannl_clean_up_artifacts_on_startup=false")
+   (run
+    %{dep:top_down_prec.exe}
+    "--ocannl_output_prec_in_ll_files=true"
+    "--ocannl_output_debug_files_in_build_directory=true"
+    "--ocannl_clean_up_artifacts_on_startup=false"))))
+
+(rule
+ (deps "build_files/n3_fwd-unoptimized.ll")
+ (target "n3_fwd_with_prec-unoptimized.ll.actual")
+ (action
+  (copy
+   "build_files/n3_fwd-unoptimized.ll"
+   "n3_fwd_with_prec-unoptimized.ll.actual")))
+
+(rule
+ (alias runtest)
+ (action
+  (diff
+   "n3_fwd_with_prec-unoptimized.ll.expected"
+   "n3_fwd_with_prec-unoptimized.ll.actual")))
 
 (rule
- (deps "build_files/n3_fwd.ll")
- (target "n3_fwd_with_prec.ll.actual")
+ (deps "build_files/d_fwd-unoptimized.ll")
+ (target "top_down_prec-unoptimized.ll.actual")
  (action
-  (copy "build_files/n3_fwd.ll" "n3_fwd_with_prec.ll.actual")))
+  (copy
+   "build_files/d_fwd-unoptimized.ll"
+   "top_down_prec-unoptimized.ll.actual")))
 
 (rule
  (alias runtest)
  (action
   (diff
-   "n3_fwd_with_prec.ll.expected"
-   "n3_fwd_with_prec.ll.actual")))
+   "top_down_prec-unoptimized.ll.expected"
+   "top_down_prec-unoptimized.ll.actual")))
 
 (test
  (name test_vec_simple)
```
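
A side note on the test wiring: dune's `(diff expected actual)` action fails `dune runtest` when the two files diverge, and `dune promote` copies the freshly generated `.actual` files over the checked-in `.expected` ones, which is the standard way to update these snapshots.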

test/operations/n3_fwd_with_prec-unoptimized.ll.expected

Lines changed: 11 additions & 0 deletions

New file (the name follows from the dune copy/diff rules above):

```

n3_fwd (): /* n3 fwd */
random_seed<uint4x32>[] := 42;
for i2 = 0 to 5 { n1<uint4x32>[i2] := i2; }
for i4 = 0 to 5 {
  threefry4x32<uint4x32>[i4] := (random_seed<uint4x32>[] ^^^^ n1<uint4x32>[i4]);
}
for i6 = 0 to 5 {
  n3<half>[8*i6]<8> := uint4x32_to_prec_uniform(threefry4x32<uint4x32>[i6], <8>);
}
/* end */
```

test/operations/n3_fwd_with_prec.ll.expected

Lines changed: 0 additions & 8 deletions
This file was deleted.

test/operations/top_down_prec-unoptimized.ll.expected

Lines changed: 5 additions & 0 deletions

New file (the name follows from the dune copy/diff rules above):

```

d_fwd (): /* d fwd */
n6<half>[0] := (a<single>[0] + b<half>[0]);
d<bfloat16>[0] := (n6<half>[0] * c<single>[0]);
/* end */
```
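
These precisions illustrate the bottom-up rule above: `a`, `b`, `c` are parameters, hence `top_down_prec=true`, and only `b` has a user-specified precision (forced to `half` in the test below), so `n6 = a + b` is inferred as `half`; `d` is user-specified as `bfloat16` and is left untouched by inference.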

test/operations/top_down_prec.expected

Lines changed: 11 additions & 0 deletions

New file (name assumed: dune compares a test's stdout against `<test name>.expected`):

```
Retrieving commandline, environment, or config file variable ocannl_log_level
Found 0, in the config file
┌────────────────────┐
│[8]: *._d shape 0:1 │
│┌┬──────┐           │
│││axis 0│           │
│├┼──────┤           │
│││ 8.00 │           │
│└┴──────┘           │
└────────────────────┘
grad_*._d <not-hosted>
```

test/operations/top_down_prec.ml

Lines changed: 18 additions & 0 deletions
New file:

```ocaml
(* A simple test of top-down precision inference: precisions are forced on the
   parameter b and on the result d, and the generated low-level code is
   compared against top_down_prec-unoptimized.ll.expected. *)

open Base
module Tensor = Ocannl.Tensor
module Train = Ocannl.Train
module TDSL = Ocannl.Operation.TDSL
module Tn = Ir.Tnode

let () =
  Tensor.unsafe_reinitialize ();
  let module Backend = (val Backends.fresh_backend ()) in
  let%op d = ("a" [2] + "b" [2]) *. "c" [2] in
  Tn.update_prec b.value Ir.Ops.half;
  Tn.update_prec d.value Ir.Ops.bfloat16;
  (* Compile and run *)
  Ocannl.Train.set_hosted d.value;
  ignore (Ocannl.Train.forward_once (module Backend) d);
  Train.printf d
```
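
For the record, the `%op` literals `"a" [2]`, `"b" [2]`, `"c" [2]` appear to bind parameters initialized to 2 (an inference from the output, not checked against the `%op` docs), so the printed value is (2 + 2) * 2 = 8.00, matching `top_down_prec.expected` above; the inferred precisions themselves are checked by the `.ll.expected` diff rule in the dune file.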
