Skip to content

Commit 0159bfd

Browse files
Author committed:
Fix: upcast constants that exceed fp16 cutoff config
1 parent e698ef3 commit 0159bfd

File tree

7 files changed

+45
-21
lines changed

7 files changed

+45
-21
lines changed

arrayjit/lib/low_level.ml

Lines changed: 7 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -693,17 +693,13 @@ let simplify_llc llc =
693693
let result = Unop (op, v) in
694694
if equal_float_t llv v then result else loop_float result
695695
in
696-
let check_constant =
697-
match Utils.settings.check_half_prec_constants_cutoff with
698-
| None -> fun _prec _c -> ()
699-
| Some cutoff ->
700-
fun tn c ->
701-
if (Ops.is_fp16 @@ Lazy.force tn.Tn.prec) && Float.(abs c >= cutoff) then
702-
raise
703-
@@ Utils.User_error
704-
("Constant " ^ Float.to_string c
705-
^ " is too big for FP16 aka. half precision, risk of overflow; increase precision \
706-
of tensor node " ^ Tn.debug_name tn)
696+
let check_constant tn c =
697+
if Tn.exceeds_fp16_cutoff tn c then
698+
raise
699+
@@ Utils.User_error
700+
("Constant " ^ Float.to_string c
701+
^ " is too big for FP16 aka. half precision, risk of overflow; increase precision of \
702+
tensor node " ^ Tn.debug_name tn)
707703
in
708704
let rec check_proc llc =
709705
let loop = check_proc in

arrayjit/lib/ops.ml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ let byte = Byte_prec Byte
2828
let half = Half_prec Half
2929
let single = Single_prec Single
3030
let double = Double_prec Double
31-
let is_fp16 = function Half_prec _ -> true | _ -> false
31+
let is_up_to_fp16 = function Half_prec _ | Byte_prec _ -> true | _ -> false
3232

3333
let sexp_of_prec = function
3434
| Void_prec -> Sexp.Atom "Void_prec"

arrayjit/lib/tnode.ml

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,7 @@ type t = {
7676
(** Display information. It is better if the last element of the list is the most narrow or
7777
alphanumeric, e.g. an identifier. *)
7878
mutable delayed_prec_unsafe : delayed_prec;
79+
(** Participates in the computation of {!field-prec}. *)
7980
mutable memory_mode : (memory_mode * int) option;
8081
mutable backend_info : Sexp.t;
8182
mutable code_name : string option;
@@ -374,6 +375,23 @@ let update_prec ?only_if tn prec =
374375
if cond old then prec else old))
375376
| _ -> tn.delayed_prec_unsafe <- Specified prec
376377

378+
let exceeds_fp16_cutoff tn c =
379+
match Utils.settings.check_half_prec_constants_cutoff with
380+
| None -> false
381+
| Some cutoff ->
382+
(* Only force if needed. *)
383+
Float.(abs c >= cutoff)
384+
&&
385+
let prec =
386+
if Lazy.is_val tn.prec then Lazy.force tn.prec
387+
else
388+
match tn.delayed_prec_unsafe with
389+
| Specified prec -> prec
390+
| Default_spec prec -> Lazy.force prec
391+
| Not_specified -> Lazy.force tn.prec
392+
in
393+
Ops.is_up_to_fp16 prec
394+
377395
include Comparator.Make (struct
378396
type nonrec t = t
379397

bin/moons_benchmark.ml

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -214,13 +214,8 @@ let _mem_benchmarks =
214214
~f:(fun batch_size ->
215215
List.concat_map [ 0; (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
216216
List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed ->
217-
List.concat_map
218-
[
219-
(* "gccjit" ; *)
220-
(* "cc"; *)
221-
"cuda";
222-
] ~f:(fun backend_name ->
223-
List.concat_map [ (* CDSL.double; *) CDSL.single (* ; CDSL.half *) ]
217+
List.concat_map [ (* "gccjit" ; *) "cc"; "cuda" ] ~f:(fun backend_name ->
218+
List.concat_map [ (* CDSL.double; *) CDSL.single; CDSL.half ]
224219
~f:(fun value_prec ->
225220
[
226221
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_streams
@@ -242,6 +237,13 @@ let _suspended () =
242237
(* let () = List.map benchmarks ~f:(nth_best 2) |> PrintBox_utils.table |> PrintBox_text.output
243238
Stdio.stdout *)
244239

240+
let _suspended () =
241+
[
242+
classify_moons ~seed:7 ~on_device:true ~inlining_cutoff:0 ~num_streams:3 ~batch_size:240
243+
~backend_name:"cc" ~value_prec:CDSL.half ~grad_prec:CDSL.half ();
244+
]
245+
|> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout
246+
245247
let benchmark benchmarks =
246248
List.map benchmarks ~f:(fun bench -> bench ())
247249
|> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout

lib/tensor.ml

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -314,6 +314,9 @@ let number ?(label = []) ?axis_label ?(grad_spec = Prohibit_grad) c =
314314
| Some axis_label -> t ~output_axes:[ (axis_label, 1) ] ()
315315
in
316316
Tn.update_memory_mode t.value Effectively_constant 24;
317+
Arrayjit.Ops.(
318+
if Tn.exceeds_fp16_cutoff t.value c then
319+
Tn.update_prec ~only_if:is_up_to_fp16 t.value single);
317320
t
318321

319322
let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?output_dims
@@ -356,6 +359,10 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?
356359
()
357360
in
358361
Tn.update_memory_mode t.value Effectively_constant 24;
362+
let max_abs = Array.fold values ~init:0. ~f:(fun acc v -> Float.(max acc @@ abs v)) in
363+
Arrayjit.Ops.(
364+
if Tn.exceeds_fp16_cutoff t.value max_abs then
365+
Tn.update_prec ~only_if:is_up_to_fp16 t.value single);
359366
t
360367

361368
let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced

lib/train.ml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -487,7 +487,8 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
487487
| None -> !.init_lr *. ((2 *. !..steps) - !@step_n) /. !..steps
488488
| Some schedule -> schedule ~batch_n ~step_n
489489
in
490-
Tn.update_prec ~only_if:Ops.is_fp16 learning_rate.value Ops.single;
490+
(* Note: constants at default half-prec are automatically upcasted when they exceed
491+
Utils.settings.check_half_prec_constants_cutoff, no need to upcast learning_rate.value. *)
491492
set_hosted learning_rate.value;
492493
let sgd = sgd_update ~learning_rate ~weight_decay update in
493494
let grad_update = Backend.compile ~shared:true bindings update.fwd_bprop in

todo.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
# This file is for tasks with a smaller granularity than issues, typically immediate tasks.
2-
(B) bin/moons_benchmark with the cc backend crashes with half-prec overflow
2+
(B) bin/moons_benchmark with the cc backend crashes with half-prec overflow {cm:2024-11-24}
33
(B) remove syncing from the data parallel algo: stream-to-stream syncing is now automatic {cm:2024-11-23}
44
(A) cuda backend crashes in bin/moons_benchmark {cm:2024-11-22}
55
(B) figure out why cuda backend parallelism slows down in later epochs

0 commit comments

Comments (0)