File tree Expand file tree Collapse file tree 7 files changed +45
-21
lines changed Expand file tree Collapse file tree 7 files changed +45
-21
lines changed Original file line number Diff line number Diff line change @@ -693,17 +693,13 @@ let simplify_llc llc =
693693 let result = Unop (op, v) in
694694 if equal_float_t llv v then result else loop_float result
695695 in
696- let check_constant =
697- match Utils. settings.check_half_prec_constants_cutoff with
698- | None -> fun _prec _c -> ()
699- | Some cutoff ->
700- fun tn c ->
701- if (Ops. is_fp16 @@ Lazy. force tn.Tn. prec) && Float. (abs c > = cutoff) then
702- raise
703- @@ Utils. User_error
704- (" Constant " ^ Float. to_string c
705- ^ " is too big for FP16 aka. half precision, risk of overflow; increase precision \
706- of tensor node " ^ Tn. debug_name tn)
696+ let check_constant tn c =
697+ if Tn. exceeds_fp16_cutoff tn c then
698+ raise
699+ @@ Utils. User_error
700+ (" Constant " ^ Float. to_string c
701+ ^ " is too big for FP16 aka. half precision, risk of overflow; increase precision of \
702+ tensor node " ^ Tn. debug_name tn)
707703 in
708704 let rec check_proc llc =
709705 let loop = check_proc in
Original file line number Diff line number Diff line change @@ -28,7 +28,7 @@ let byte = Byte_prec Byte
2828let half = Half_prec Half
2929let single = Single_prec Single
3030let double = Double_prec Double
31- let is_fp16 = function Half_prec _ -> true | _ -> false
(** [is_up_to_fp16 prec] is [true] for the byte and half precisions and [false] for every other
    precision. *)
let is_up_to_fp16 prec =
  match prec with
  | Byte_prec _ | Half_prec _ -> true
  | _ -> false
3232
3333let sexp_of_prec = function
3434 | Void_prec -> Sexp. Atom " Void_prec"
Original file line number Diff line number Diff line change @@ -76,6 +76,7 @@ type t = {
7676 (* * Display information. It is better if the last element of the list is the most narrow or
7777 alphanumeric, e.g. an identifier. *)
7878 mutable delayed_prec_unsafe : delayed_prec ;
79+ (* * Participates in the computation of {!field-prec}. *)
7980 mutable memory_mode : (memory_mode * int ) option ;
8081 mutable backend_info : Sexp .t ;
8182 mutable code_name : string option ;
@@ -374,6 +375,23 @@ let update_prec ?only_if tn prec =
374375 if cond old then prec else old))
375376 | _ -> tn.delayed_prec_unsafe < - Specified prec
376377
(** [exceeds_fp16_cutoff tn c] is [true] when the absolute value of the constant [c] is at least
    [Utils.settings.check_half_prec_constants_cutoff] and the precision of [tn] satisfies
    [Ops.is_up_to_fp16]. It is always [false] when no cutoff is configured.

    The precision is consulted only after the magnitude test has passed. When [tn.prec] has not
    been forced yet, the precision is taken from [tn.delayed_prec_unsafe] if it is [Specified] or
    [Default_spec]; [tn.prec] itself is forced only in the [Not_specified] case. *)
let exceeds_fp16_cutoff tn c =
  match Utils.settings.check_half_prec_constants_cutoff with
  | None -> false
  | Some cutoff ->
      let current_prec () =
        if Lazy.is_val tn.prec then Lazy.force tn.prec
        else
          match tn.delayed_prec_unsafe with
          | Specified prec -> prec
          | Default_spec prec -> Lazy.force prec
          | Not_specified -> Lazy.force tn.prec
      in
      (* [&&] short-circuits, so nothing is forced for constants below the cutoff. *)
      Float.(abs c >= cutoff) && Ops.is_up_to_fp16 (current_prec ())
377395include Comparator. Make (struct
378396 type nonrec t = t
379397
Original file line number Diff line number Diff line change @@ -214,13 +214,8 @@ let _mem_benchmarks =
214214 ~f: (fun batch_size ->
215215 List. concat_map [ 0 ; (* 1; 2; *) 3 ] ~f: (fun inlining_cutoff ->
216216 List. concat_map [ (* 1; 3; *) 7 (* *) ] ~f: (fun seed ->
217- List. concat_map
218- [
219- (* "gccjit" ; *)
220- (* "cc"; *)
221- " cuda" ;
222- ] ~f: (fun backend_name ->
223- List. concat_map [ (* CDSL.double; *) CDSL. single (* ; CDSL.half *) ]
217+ List. concat_map [ (* "gccjit" ; *) " cc" ; " cuda" ] ~f: (fun backend_name ->
218+ List. concat_map [ (* CDSL.double; *) CDSL. single; CDSL. half ]
224219 ~f: (fun value_prec ->
225220 [
226221 classify_moons ~seed ~on_device: true ~inlining_cutoff ~num_streams
@@ -242,6 +237,13 @@ let _suspended () =
242237(* let () = List.map benchmarks ~f:(nth_best 2) |> PrintBox_utils.table |> PrintBox_text.output
243238 Stdio.stdout *)
244239
(** Builds a one-element list from a single [classify_moons] call (seed 7, on-device, inlining
    cutoff 0, 3 streams, batch size 240, the ["cc"] backend, half precision for both values and
    gradients), renders it with [PrintBox_utils.table] and writes the result to [Stdio.stdout]. *)
let _suspended () =
  let rows =
    [
      classify_moons ~seed:7 ~on_device:true ~inlining_cutoff:0 ~num_streams:3 ~batch_size:240
        ~backend_name:"cc" ~value_prec:CDSL.half ~grad_prec:CDSL.half ();
    ]
  in
  PrintBox_text.output Stdio.stdout (PrintBox_utils.table rows)
246+
245247let benchmark benchmarks =
246248 List. map benchmarks ~f: (fun bench -> bench () )
247249 |> PrintBox_utils. table |> PrintBox_text. output Stdio. stdout
Original file line number Diff line number Diff line change @@ -314,6 +314,9 @@ let number ?(label = []) ?axis_label ?(grad_spec = Prohibit_grad) c =
314314 | Some axis_label -> t ~output_axes: [ (axis_label, 1 ) ] ()
315315 in
316316 Tn. update_memory_mode t.value Effectively_constant 24 ;
317+ Arrayjit.Ops. (
318+ if Tn. exceeds_fp16_cutoff t.value c then
319+ Tn. update_prec ~only_if: is_up_to_fp16 t.value single);
317320 t
318321
319322let ndarray ?(label = [] ) ?(grad_spec = Prohibit_grad ) ?batch_dims ?input_dims ?output_dims
@@ -356,6 +359,10 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?
356359 ()
357360 in
358361 Tn. update_memory_mode t.value Effectively_constant 24 ;
362+ let max_abs = Array. fold values ~init: 0. ~f: (fun acc v -> Float. (max acc @@ abs v)) in
363+ Arrayjit.Ops. (
364+ if Tn. exceeds_fp16_cutoff t.value max_abs then
365+ Tn. update_prec ~only_if: is_up_to_fp16 t.value single);
359366 t
360367
361368let param ?(more_label = [] ) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced
Original file line number Diff line number Diff line change @@ -487,7 +487,8 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
487487 | None -> ! .init_lr *. ((2 *. ! ..steps) - ! @ step_n) /. ! ..steps
488488 | Some schedule -> schedule ~batch_n ~step_n
489489 in
490- Tn. update_prec ~only_if: Ops. is_fp16 learning_rate.value Ops. single;
490+ (* Note: constants at default half-prec are automatically upcasted when they exceed
491+ Utils.settings.check_half_prec_constants_cutoff, no need to upcast learning_rate.value. *)
491492 set_hosted learning_rate.value;
492493 let sgd = sgd_update ~learning_rate ~weight_decay update in
493494 let grad_update = Backend. compile ~shared: true bindings update.fwd_bprop in
Original file line number Diff line number Diff line change 11# This file is for tasks with a smaller granularity than issues, typically immediate tasks.
2- (B) bin/moons_benchmark with the cc backend crashes with half-prec overflow
2+ (B) bin/moons_benchmark with the cc backend crashes with half-prec overflow {cm:2024-11-24}
33(B) remove syncing from the data parallel algo: stream-to-stream syncing is now automatic {cm:2024-11-23}
44(A) cuda backend crashes in bin/moons_benchmark {cm:2024-11-22}
55(B) figure out why cuda backend parallelism slows down in later epochs
You can’t perform that action at this time.
0 commit comments