Skip to content

Commit db206b0

Browse files
committed
Detect FP16 constant overflow
1 parent 90b93f5 commit db206b0

File tree

3 files changed

+55
-8
lines changed

3 files changed

+55
-8
lines changed

arrayjit/lib/low_level.ml

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,47 @@ let simplify_llc llc =
736736
let result = Unop (op, v) in
737737
if equal_float_t llv v then result else loop_float result
738738
in
739-
loop_proc llc
739+
let check_constant =
740+
match Utils.settings.check_half_prec_constants_cutoff with
741+
| None -> fun _prec _c -> ()
742+
| Some cutoff -> (
743+
fun tn c ->
744+
match tn.Tn.prec with
745+
| Ops.Half_prec _ ->
746+
if Float.(abs c >= cutoff) then
747+
raise
748+
@@ Utils.User_error
749+
("Constant " ^ Float.to_string c
750+
^ " is too big for FP16 aka. half precision, risk of overflow; increase \
751+
precision of tensor node " ^ Tn.debug_name tn)
752+
| _ -> ())
753+
in
754+
let rec check_proc llc =
755+
let loop = check_proc in
756+
match llc with
757+
| Seq (c1, c2) ->
758+
loop c1;
759+
loop c2
760+
| For_loop { body; _ } -> loop body
761+
| Zero_out _ -> ()
762+
| Set { tn; llv; _ } -> check_float tn llv
763+
| Set_local (id, llv) -> check_float id.tn llv
764+
| Noop | Comment _ | Staged_compilation _ -> ()
765+
and check_float tn llv =
766+
let loop = check_float tn in
767+
match llv with
768+
| Constant c -> check_constant tn c
769+
| Local_scope { body; _ } -> check_proc body
770+
| Binop (_, v1, v2) ->
771+
loop v1;
772+
loop v2
773+
| Unop (_, v) -> loop v
774+
| Embed_index (Indexing.Fixed_idx i) -> check_constant tn (Float.of_int i)
775+
| Embed_index _ | Get_local _ | Get_global (_, _) | Get (_, _) -> ()
776+
in
777+
let result = loop_proc llc in
778+
if Option.is_some Utils.settings.check_half_prec_constants_cutoff then check_proc result;
779+
result
740780

741781
type traced_store = (Tn.t, traced_array) Base.Hashtbl.t [@@deriving sexp_of]
742782

arrayjit/lib/utils.ml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ type settings = {
4040
mutable fixed_state_for_init : int option;
4141
mutable print_decimals_precision : int;
4242
(** When rendering arrays etc., outputs this many decimal digits. *)
43+
mutable check_half_prec_constants_cutoff : float option;
44+
(** If given, generic code optimization should fail if a half precision FP16 constant exceeds
45+
the cutoff. *)
4346
}
4447
[@@deriving sexp]
4548

@@ -51,6 +54,7 @@ let settings =
5154
output_debug_files_in_build_directory = false;
5255
fixed_state_for_init = None;
5356
print_decimals_precision = 2;
57+
check_half_prec_constants_cutoff = Some (2. **. 14.);
5458
}
5559

5660
let accessed_global_args = Hash_set.create (module String)
@@ -314,7 +318,10 @@ let restore_settings () =
314318
(let seed = get_global_arg ~arg_name:"fixed_state_for_init" ~default:"" in
315319
if String.is_empty seed then None else Some (Int.of_string seed));
316320
settings.print_decimals_precision <-
317-
Int.of_string @@ get_global_arg ~arg_name:"print_decimals_precision" ~default:"2"
321+
Int.of_string @@ get_global_arg ~arg_name:"print_decimals_precision" ~default:"2";
322+
settings.check_half_prec_constants_cutoff <-
323+
Float.of_string_opt
324+
@@ get_global_arg ~arg_name:"check_half_prec_constants_cutoff" ~default:"16384.0"
318325

319326
let () = restore_settings ()
320327
let with_runtime_debug () = settings.output_debug_files_in_build_directory && settings.log_level > 1

bin/moons_benchmark.ml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,14 @@ let _suspend () =
176176
~backend_name:"gccjit" ~value_prec:CDSL.single ~grad_prec:CDSL.double ()
177177

178178
let cuda_benchmarks =
179-
List.concat_map [ 1; 3; 6; 12 (* ; 16; 32; 64 *) ] ~f:(fun num_devices ->
179+
List.concat_map [ (* 1; *) 3 (* ; 6; 12; 16; 32; 64 *) ] ~f:(fun num_devices ->
180180
List.concat_map
181-
[ 3 * 32; 3 * 64; 3 * 128 ]
181+
[ 3 * 32(* ; 3 * 64; 3 * 128 *)]
182182
~f:(fun batch_size ->
183-
List.concat_map [ 0; 1; (* 2; *) 3 ] ~f:(fun inlining_cutoff ->
184-
List.concat_map [ 1; 3 (* ; 7 *) ] ~f:(fun seed ->
185-
List.concat_map [ (* "gccjit" ; "cc"; *) "cuda" ] ~f:(fun backend_name ->
186-
List.concat_map [ (* CDSL.double; CDSL.single; *) CDSL.half ]
183+
List.concat_map [ (* 0; 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
184+
List.concat_map [ 1 (* ; 3 ; 7 *) ] ~f:(fun seed ->
185+
List.concat_map [ (* "gccjit" ; *) "cc";(* "cuda" *) ] ~f:(fun backend_name ->
186+
List.concat_map [ (* CDSL.double; CDSL.single; *) CDSL.half ]
187187
~f:(fun value_prec ->
188188
[
189189
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_devices

0 commit comments

Comments
 (0)