File tree Expand file tree Collapse file tree 6 files changed +22
-7
lines changed Expand file tree Collapse file tree 6 files changed +22
-7
lines changed Original file line number Diff line number Diff line change @@ -90,6 +90,9 @@ let prec_string = function
9090 | Single_prec _ -> " single"
9191 | Double_prec _ -> " double"
9292
93+ let prec_of_string s =
94+ prec_of_sexp (Sexp. Atom (String. (capitalize @@ lowercase s) ^ " _prec" ))
95+
9396let equal_prec p1 p2 =
9497 match (p1, p2) with
9598 | Void_prec , Void_prec -> true
Original file line number Diff line number Diff line change @@ -562,12 +562,15 @@ let create ?default_prec ~id ~label ~dims init_op =
562562 Registry. add registry tn;
563563 tn
564564
565+ let initial_default_prec =
566+ Ops. prec_of_string (Utils. get_global_arg ~default: " single" ~arg_name: " default_prec" )
567+
565568let find =
566569 let mock =
567570 {
568571 array = lazy None ;
569- prec = lazy Ops. single ;
570- delayed_prec_unsafe = Specified Ops. single ;
572+ prec = lazy initial_default_prec ;
573+ delayed_prec_unsafe = Specified initial_default_prec ;
571574 dims = lazy [||];
572575 size_in_bytes = lazy 0 ;
573576 id = - 1 ;
Original file line number Diff line number Diff line change @@ -3,4 +3,5 @@ log_main_domain_to_stdout=true
33backend=multicore_cc
44log_level=0
55print_decimals_precision=2
6- prefer_backend_uniformity=true
6+ prefer_backend_uniformity=true
7+ default_prec=bfloat16
Original file line number Diff line number Diff line change @@ -99,8 +99,11 @@ let iter_embedded ~f t =
9999 Set. iter ~f t.forward.embedded_nodes;
100100 Option. iter t.diff ~f: (fun diff -> Set. iter ~f diff.backprop.embedded_nodes)
101101
102- let default_value_prec = ref Ir.Ops. single
103- let default_grad_prec = ref Ir.Ops. single
102+ let initial_default_prec =
103+ Ir.Ops. prec_of_string (Utils. get_global_arg ~default: " single" ~arg_name: " default_prec" )
104+
105+ let default_value_prec = ref initial_default_prec
106+ let default_grad_prec = ref initial_default_prec
104107
105108exception Session_error of string * t option [@@ deriving sexp]
106109
Original file line number Diff line number Diff line change @@ -168,4 +168,8 @@ log_file_stem=debug
168168# It is useful for testing to have outputs more uniform across backends even if that criples
169169# some backends. Currently, this setting only affects logging from routines to accomodate Metal's
170170# shortcoming.
171- prefer_backend_uniformity=false
171+ prefer_backend_uniformity=false
172+
173+ # The initial value for the default precisions for tensors. The default precisions for values and
174+ # gradients can be changed separately via the `Tensor` API.
175+ default_prec=single
Original file line number Diff line number Diff line change @@ -3,4 +3,5 @@ log_main_domain_to_stdout=true
33backend=multicore_cc
44log_level=0
55print_decimals_precision=2
6- prefer_backend_uniformity=true
6+ prefer_backend_uniformity=true
7+ default_prec=bfloat16
You can’t perform that action at this time.
0 commit comments