Skip to content

Commit 1b56fc4

Browse files
committed
Configurable "default" default precision
1 parent 423af64 commit 1b56fc4

File tree

6 files changed

+22
-7
lines changed

6 files changed

+22
-7
lines changed

arrayjit/lib/ops.ml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ let prec_string = function
9090
| Single_prec _ -> "single"
9191
| Double_prec _ -> "double"
9292

93+
let prec_of_string s =
94+
prec_of_sexp (Sexp.Atom (String.(capitalize @@ lowercase s) ^ "_prec"))
95+
9396
let equal_prec p1 p2 =
9497
match (p1, p2) with
9598
| Void_prec, Void_prec -> true

arrayjit/lib/tnode.ml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -562,12 +562,15 @@ let create ?default_prec ~id ~label ~dims init_op =
562562
Registry.add registry tn;
563563
tn
564564

565+
let initial_default_prec =
566+
Ops.prec_of_string (Utils.get_global_arg ~default:"single" ~arg_name:"default_prec")
567+
565568
let find =
566569
let mock =
567570
{
568571
array = lazy None;
569-
prec = lazy Ops.single;
570-
delayed_prec_unsafe = Specified Ops.single;
572+
prec = lazy initial_default_prec;
573+
delayed_prec_unsafe = Specified initial_default_prec;
571574
dims = lazy [||];
572575
size_in_bytes = lazy 0;
573576
id = -1;

arrayjit/test/ocannl_config

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ log_main_domain_to_stdout=true
33
backend=multicore_cc
44
log_level=0
55
print_decimals_precision=2
6-
prefer_backend_uniformity=true
6+
prefer_backend_uniformity=true
7+
default_prec=bfloat16

lib/tensor.ml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,11 @@ let iter_embedded ~f t =
9999
Set.iter ~f t.forward.embedded_nodes;
100100
Option.iter t.diff ~f:(fun diff -> Set.iter ~f diff.backprop.embedded_nodes)
101101

102-
let default_value_prec = ref Ir.Ops.single
103-
let default_grad_prec = ref Ir.Ops.single
102+
let initial_default_prec =
103+
Ir.Ops.prec_of_string (Utils.get_global_arg ~default:"single" ~arg_name:"default_prec")
104+
105+
let default_value_prec = ref initial_default_prec
106+
let default_grad_prec = ref initial_default_prec
104107

105108
exception Session_error of string * t option [@@deriving sexp]
106109

ocannl_config.example

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,4 +168,8 @@ log_file_stem=debug
168168
# It is useful for testing to have outputs more uniform across backends even if that criples
169169
# some backends. Currently, this setting only affects logging from routines to accomodate Metal's
170170
# shortcoming.
171-
prefer_backend_uniformity=false
171+
prefer_backend_uniformity=false
172+
173+
# The initial value for the default precisions for tensors. The default precisions for values and
174+
# gradients can be changed separately via the `Tensor` API.
175+
default_prec=single

test/config/ocannl_config

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ log_main_domain_to_stdout=true
33
backend=multicore_cc
44
log_level=0
55
print_decimals_precision=2
6-
prefer_backend_uniformity=true
6+
prefer_backend_uniformity=true
7+
default_prec=bfloat16

0 commit comments

Comments
 (0)