Commit 0a45e7b (parent: 2d00d55)

Fourth pass on bidirectional precision inference: don't force precisions from below from defaults when inferring

File tree: 7 files changed, +49 −45 lines

arrayjit/lib/tnode.ml

Lines changed: 16 additions & 19 deletions
@@ -62,7 +62,7 @@ type memory_mode =
     optional [array] of {!t}). *)
 [@@deriving sexp, compare, equal]
 
-type delayed_prec = Not_specified | Default_spec of Ops.prec Lazy.t | Specified of Ops.prec
+type delayed_prec = Default of Ops.prec | Inferred of Ops.prec Lazy.t | Specified of Ops.prec
 [@@deriving sexp, equal]
 
 type prepare = { is_done : unit -> bool; sync : unit -> unit; transfer : unit -> unit }
@@ -378,7 +378,8 @@ let update_prec ?only_if tn prec =
     | Some cond -> (
         match tn.delayed_prec_unsafe with
         | Specified old_prec -> cond old_prec
-        | Default_spec old_prec when Lazy.is_val old_prec -> cond @@ Lazy.force old_prec
+        | Default old_prec -> cond old_prec
+        | Inferred old_prec when Lazy.is_val old_prec -> cond @@ Lazy.force old_prec
         | _ -> true)
   in
   if do_update then
@@ -409,12 +410,14 @@ let update_prec ?only_if tn prec =
             ", but the precision is already set to ";
             Ops.prec_string (Lazy.force tn.prec);
           ])
-    | Default_spec old_prec, Some cond when not @@ Lazy.is_val old_prec ->
+    | Inferred old_prec, Some cond ->
         tn.delayed_prec_unsafe <-
-          Default_spec
+          Inferred
            (lazy
              (let old = Lazy.force old_prec in
               if cond old then prec else old))
+    | Default old_prec, Some cond ->
+        tn.delayed_prec_unsafe <- (if cond old_prec then Specified prec else Default old_prec)
     | _ -> tn.delayed_prec_unsafe <- Specified prec
 
 let update_infer_prec tn delayed_prec =
@@ -430,11 +433,11 @@ let update_infer_prec tn delayed_prec =
   else
     match tn.delayed_prec_unsafe with
     | Specified _ -> () (* User-specified precision has higher priority *)
-    | Not_specified -> tn.delayed_prec_unsafe <- Default_spec delayed_prec
-    | Default_spec old_prec ->
-        (* Combine with existing default precision via promotion *)
+    | Default _ -> tn.delayed_prec_unsafe <- Inferred delayed_prec
+    | Inferred old_prec ->
+        (* Combine with existing inferred precision via promotion *)
         tn.delayed_prec_unsafe <-
-          Default_spec (lazy (Ops.promote_prec (Lazy.force old_prec) (Lazy.force delayed_prec)))
+          Inferred (lazy (Ops.promote_prec (Lazy.force old_prec) (Lazy.force delayed_prec)))
 
 let get_specified_prec tn =
   match tn.delayed_prec_unsafe with Specified prec -> Some prec | _ -> None
@@ -450,9 +453,8 @@ let exceeds_fp16_cutoff tn c =
     if Lazy.is_val tn.prec then Lazy.force tn.prec
     else
       match tn.delayed_prec_unsafe with
-      | Specified prec -> prec
-      | Default_spec prec -> Lazy.force prec
-      | Not_specified -> Lazy.force tn.prec
+      | Specified prec | Default prec -> prec
+      | Inferred prec -> Lazy.force prec
   in
   Ops.is_up_to_fp16 prec
 
@@ -582,7 +584,7 @@ end)
 
 let registry = Registry.create 16
 
-let create ?default_prec ~id ~label ~dims ~padding () =
+let create delayed_prec ~id ~label ~dims ~padding () =
   let debug = "Host array for " ^ get_debug_name ~id ~label () in
   let rec array =
     lazy
@@ -594,17 +596,12 @@ let create ?default_prec ~id ~label ~dims ~padding () =
   and prec =
     lazy
       (match tn.delayed_prec_unsafe with
-      | Specified prec | Default_spec (lazy prec) -> prec
-      | Not_specified ->
-          raise @@ Utils.User_error "Tnode.update_prec: precision is not specified yet")
+      | Default prec | Specified prec | Inferred (lazy prec) -> prec)
   and size_in_bytes = lazy (num_elems tn * Ops.prec_in_bytes (Lazy.force tn.prec))
   and tn =
-    let delayed_prec_unsafe =
-      match default_prec with None -> Not_specified | Some prec -> Default_spec prec
-    in
     {
       array;
-      delayed_prec_unsafe;
+      delayed_prec_unsafe = delayed_prec;
      prec;
      dims;
      padding;
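
Taken together, these tnode.ml changes replace the overloaded `Not_specified`/`Default_spec` pair with three distinct states. Below is a minimal, self-contained sketch of the resulting semantics, not the actual OCANNL API: the `prec` stub, its promotion order, and the simplified signatures are illustrative assumptions.

```ocaml
(* Stand-in types: the real Ops.prec has more cases and a richer
   promotion lattice; the order single > bfloat16 > half is assumed
   here purely for illustration. *)
type prec = Half | Bfloat16 | Single

let promote_prec a b =
  match (a, b) with
  | Single, _ | _, Single -> Single
  | Bfloat16, _ | _, Bfloat16 -> Bfloat16
  | Half, Half -> Half

type delayed_prec =
  | Default of prec          (* a fallback; never forced onto inference *)
  | Inferred of prec Lazy.t  (* accumulated inference results *)
  | Specified of prec        (* user-set; wins over inference *)

(* The commit's headline behavior: inference replaces a [Default]
   outright instead of promoting into it, so defaults no longer leak
   upward through the expression tree. *)
let update_infer_prec state inferred =
  match state with
  | Specified _ -> state
  | Default _ -> Inferred inferred
  | Inferred old ->
      Inferred (lazy (promote_prec (Lazy.force old) (Lazy.force inferred)))

(* Conditional user updates can now test a [Default] eagerly, while an
   [Inferred] decision must stay lazy until the old value is forced. *)
let update_prec_only_if state cond prec =
  match state with
  | Default old -> if cond old then Specified prec else Default old
  | Inferred old ->
      Inferred (lazy (let o = Lazy.force old in if cond o then prec else o))
  | Specified _ -> state (* simplified: the real code may error or skip *)

(* Resolution is total: every state yields a precision, which is why the
   [Not_specified] error path in [Tnode.create] could be deleted. *)
let resolve = function
  | Default p | Specified p -> p
  | Inferred p -> Lazy.force p
```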

lib/precision_inference.md

Lines changed: 5 additions & 2 deletions
@@ -1,6 +1,10 @@
 # Bidirectional precision inference
 
-OCANNL features a rudimentary bidirectional precision inference. It is much much less powerful than the constraints-based shape and projections inference. It is somewhat prominent because it contributes the `top_down_prec` flag to the central `Tensor.t` type. The core algorithm is just a couple dozen lines in the `Tensor.op` function, first the bottom-up pass:
+OCANNL features a rudimentary bidirectional precision inference. It is much less powerful than the constraints-based shape and projections inference. It is somewhat prominent because it contributes the `top_down_prec` flag to the central `Tensor.t` type.
+
+Tensors that choose `top_down_prec=true` "detach" themselves from their defining tensor expression as far as precision goes. By default tensors are `top_down_prec=false`, except for all the parameter tensors (created via `Tensor.param`) and the results of the operation `uint4x32_to_prec_uniform`. When a tensor's precision is set by the user via `Tnode.update_prec`, this setting takes precedence over any inferences. When a `top_down_prec=true` tensor has its precision set by the user, it contributes this precision to the bottom-up inference (together with all `top_down_prec=false` subtensors).
+
+The core algorithm is just a couple dozen lines in the `Tensor.op` function, first the bottom-up pass:
 
 ```ocaml
 let default_prec_for default get =
@@ -34,4 +38,3 @@ and later the top-down pass, here from the value node `v`:
     List.iter top_down_ts ~f:(fun ti -> update_infer_prec ti.value v.Tn.prec);
 ```
 
-Tensors that choose `top_down_prec=true` "detach" themselves from their defining tensor expression as far as precision goes. By default tensors are `top_down_prec=false`, except for all the parameter tensors (created via `Tensor.param`), and results of the operation `uint4x32_to_prec_uniform`.
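
For a concrete reading of these rules, here is the scenario exercised by `test/operations/top_down_prec.ml` from this same commit, condensed (the full test diff appears further down):

```ocaml
(* Condensed from test/operations/top_down_prec.ml in this commit. *)
let%op d = ("a" [2] + "b" [2]) *. "c" [2] in
Tn.update_prec b.value Ir.Ops.half;
Tn.update_prec d.value Ir.Ops.bfloat16;
(* Bottom-up: [a] joins [b] at half inside the addition.
   Top-down: the parameter [c] has top_down_prec=true, so it receives
   bfloat16 from [d]. The updated d_fwd expected output below records
   exactly these precisions. *)
```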

lib/tensor.ml

Lines changed: 11 additions & 10 deletions
@@ -255,10 +255,10 @@ let%track7_sexp op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
   let shape = make_shape ~debug_name:(Tn.get_debug_name ~id ~label ()) ~id in
   (* Split subtensors by whether they use top-down precision inference *)
   let top_down_ts = List.filter ordered_ts ~f:(fun t -> t.top_down_prec) in
-  let default_prec_for default get =
+  let delayed_prec_for default get =
     if top_down_prec then
       (* For top-down precision, don't promote from inputs *)
-      lazy default
+      Tn.Default default
     else
       (* For bottom-up precision, only promote from non-top-down subtensors *)
       let lazy_v_precs =
@@ -267,12 +267,13 @@ let%track7_sexp op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
             if ti.top_down_prec then lazy (Tn.get_specified_prec v)
             else lazy (Some (Lazy.force v.prec))))
       in
-      lazy
-        (List.filter_map lazy_v_precs ~f:Lazy.force
-        |> List.reduce ~f:Ir.Ops.promote_prec
-        |> Option.value ~default)
+      Tn.Inferred
+        (lazy
+          (List.filter_map lazy_v_precs ~f:Lazy.force
+          |> List.reduce ~f:Ir.Ops.promote_prec
+          |> Option.value ~default))
   in
-  let default_prec = default_prec_for !default_value_prec (fun t -> Some t.value) in
+  let delayed_prec = delayed_prec_for !default_value_prec (fun t -> Some t.value) in
   let terminal_logic () =
     let open Shape in
     match terminal_op with
@@ -291,7 +292,7 @@ let%track7_sexp op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
     | Some (Shape.Data (Asgns.Padded { data; padding = padding_spec; padded_value })) ->
         let padding = Some (padding_spec, padded_value) in
         Tn.create_from_padded ~id ~label ~ndarray:data ~padding ()
-    | Some (Shape.Fetch _) | None -> Tn.create ~default_prec ~id ~label ~dims ~padding ()
+    | Some (Shape.Fetch _) | None -> Tn.create delayed_prec ~id ~label ~dims ~padding ()
   in
   let update_infer_prec tn prec =
     (* Instead of just checking prec, we cross-check with dims (needed for code generation), to
@@ -363,11 +364,11 @@ let%track7_sexp op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
       t)
   else
     let get ti = Option.map ti.diff ~f:(fun d -> d.grad) in
-    let default_prec = default_prec_for !default_grad_prec get in
+    let delayed_prec = delayed_prec_for !default_grad_prec get in
     let grad_id = session_state.next_id in
     session_state.next_id <- session_state.next_id + 1;
     let g =
-      Tn.create ~default_prec ~id:grad_id ~label:("grad" :: label) ~dims
+      Tn.create delayed_prec ~id:grad_id ~label:("grad" :: label) ~dims
        ~padding:(lazy (Shape.to_padding shape))
        ()
    in
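
Viewed in isolation, the renamed `delayed_prec_for` now commits to a `Tn.delayed_prec` constructor up front instead of returning a bare lazy precision. A standalone sketch of its two branches, reusing the stand-in `prec`, `promote_prec`, `Default`, and `Inferred` from the sketch after the tnode.ml diff (stdlib style rather than the Base combinators of the original):

```ocaml
(* Sketch of [delayed_prec_for]'s two branches. [sub_precs] stands for the
   lazily computed precisions of the subtensors; [None] entries are
   top-down subtensors without a user-specified precision. *)
let delayed_prec_for ~top_down_prec ~default
    (sub_precs : prec option Lazy.t list) : delayed_prec =
  if top_down_prec then
    (* Top-down tensors keep the default as a mere fallback, to be
       replaced by whatever the parent pushes down later. *)
    Default default
  else
    (* Bottom-up tensors promote across subtensor precisions, lazily,
       since subtensor precisions may not have been decided yet. *)
    Inferred
      (lazy
        (match List.filter_map Lazy.force sub_precs with
         | [] -> default
         | p :: rest -> List.fold_left promote_prec p rest))
```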

test/operations/dune

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@
    %{dep:threefry4x32_demo.exe}
    "--ocannl_output_prec_in_ll_files=true"
    "--ocannl_output_debug_files_in_build_directory=true"
-   "--ocannl_clean_up_artifacts_on_startup=false")
+   "--ocannl_clean_up_artifacts_on_startup=true")
   (run
    %{dep:top_down_prec.exe}
    "--ocannl_output_prec_in_ll_files=true"
Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 
 d_fwd (): /* d fwd */
-  n6<half>[0] := (a<single>[0] + b<half>[0]);
-  d<bfloat16>[0] := (n6<half>[0] * c<single>[0]);
+  n6<half>[0] := (a<half>[0] + b<half>[0]);
+  d<bfloat16>[0] := (n6<half>[0] * c<bfloat16>[0]);
 /* end */
Lines changed: 12 additions & 9 deletions
@@ -1,11 +1,14 @@
 Retrieving commandline, environment, or config file variable ocannl_log_level
 Found 0, in the config file
-┌────────────────────┐
-│[8]: *._d shape 0:1 │
-│┌┬──────┐           │
-│││axis 0│           │
-│├┼──────┤           │
-│││ 8.00 │           │
-│└┴──────┘           │
-└────────────────────┘
-grad_*._d <not-hosted>
+#8 *._d
+ 8.00
+#9 grad_*._d Virt/30
+ <void>
+#6 + Virt/152        │#4 c non-emb
+ <void>              │ 2.00
+#7 grad_+ Virt/30    │#5 grad_c Local/26030
+ <void>              │ <void>
+#0 a non-emb         │#2 b non-emb         │
+ 2.00                │ 2.00                │
+#1 grad_a Local/26030│#3 grad_b Local/26030│
+ <void>              │ <void>              │

test/operations/top_down_prec.ml

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@ let () =
   let%op d = ("a" [2] + "b" [2]) *. "c" [2] in
   Tn.update_prec b.value Ir.Ops.half;
   Tn.update_prec d.value Ir.Ops.bfloat16;
-  (* Compile and run *)
+  (* Even when the default precision is single, c is bfloat16 and a is half. *)
   Ocannl.Train.set_hosted d.value;
   ignore (Ocannl.Train.forward_once (module Backend) d);
-  Train.printf d
+  Train.printf_tree d
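
The switch from `Train.printf` to `Train.printf_tree` is what turns the boxed single-tensor printout into the whole-tree dump in the expected log above. A possible hardening of this test, sketched under assumptions and not part of the commit: it presumes `Ir.Ops.prec_string` renders precisions as `"half"`/`"bfloat16"` (as the `<half>`/`<bfloat16>` annotations in the generated code suggest) and that `a` and `c` are still in scope.

```ocaml
(* Hypothetical follow-up check, not in the commit: assert inferred
   precisions directly instead of relying on printed output. *)
let assert_prec tn expected =
  (* [prec] on a tnode is its final, lazily resolved precision. *)
  assert (String.equal (Ir.Ops.prec_string (Lazy.force tn.Tn.prec)) expected)
in
assert_prec a.value "half";
assert_prec c.value "bfloat16"
```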
