@@ -736,7 +736,47 @@ let simplify_llc llc =
736736 let result = Unop (op, v) in
737737 if equal_float_t llv v then result else loop_float result
738738 in
739- loop_proc llc
(* [check_constant tn c] rejects constants that risk overflowing half precision.
   The cutoff is read from [Utils.settings.check_half_prec_constants_cutoff] once,
   when this binding is evaluated: with [None] the returned function does nothing;
   with [Some cutoff] it raises [Utils.User_error] when [tn]'s precision is
   [Ops.Half_prec _] and [abs c >= cutoff]. Other precisions are never rejected. *)
let check_constant =
  let guard_half_prec cutoff tn c =
    match tn.Tn.prec with
    | Ops.Half_prec _ ->
        if Float.(abs c >= cutoff) then
          raise
            (Utils.User_error
               (" Constant " ^ Float.to_string c
              ^ " is too big for FP16 aka. half precision, risk of overflow; increase \
                 precision of tensor node " ^ Tn.debug_name tn))
    | _ -> ()
  in
  match Utils.settings.check_half_prec_constants_cutoff with
  | Some cutoff -> guard_half_prec cutoff
  | None -> fun _tn _c -> ()
in
(* [check_proc] traverses a low-level procedure and passes every assigned
   expression to [check_float], together with the tensor node being written. *)
let rec check_proc = function
  | Seq (first, second) ->
      check_proc first;
      check_proc second
  | For_loop { body; _ } -> check_proc body
  | Set { tn; llv; _ } -> check_float tn llv
  | Set_local (id, llv) -> check_float id.tn llv
  | Zero_out _ | Noop | Comment _ | Staged_compilation _ -> ()

(* [check_float tn] traverses a scalar expression and runs [check_constant tn]
   on each literal constant and on each fixed index embedded as a float. *)
and check_float tn = function
  | Constant c -> check_constant tn c
  (* The fixed-index case must precede the catch-all [Embed_index _] below. *)
  | Embed_index (Indexing.Fixed_idx i) -> check_constant tn (Float.of_int i)
  | Local_scope { body; _ } -> check_proc body
  | Unop (_, v) -> check_float tn v
  | Binop (_, v1, v2) ->
      check_float tn v1;
      check_float tn v2
  | Embed_index _ | Get_local _ | Get_global (_, _) | Get (_, _) -> ()
in
(* Simplify first; traverse the simplified code for the half-precision check
   only when a cutoff is configured. *)
let simplified = loop_proc llc in
(match Utils.settings.check_half_prec_constants_cutoff with
| Some _ -> check_proc simplified
| None -> ());
simplified
740780
(** Hash table keyed by [Tn.t] whose values are [traced_array]; an S-expression
    converter ([sexp_of]) is derived for it. *)
741781type traced_store = (Tn .t , traced_array ) Base.Hashtbl .t [@@ deriving sexp_of ]
742782
0 commit comments