Skip to content

Commit 3ea5f58

Browse files
committed
Untested: mixed precision conversions in C_syntax; related cleanup
1 parent db206b0 commit 3ea5f58

File tree

5 files changed

+49
-33
lines changed

5 files changed

+49
-33
lines changed

CHANGES.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
### Added
44

5-
- The previously-mocked support for half precision.
5+
- Implemented the previously-mocked support for half precision (FP16).
66
- We work around the missing Ctypes coverage by not using `Ctypes.bigarray_start`.
7+
- We check FP16 constants for overflow.
8+
- We output half precision specific code from the CUDA backend.
79

810
### Changed
911

@@ -18,6 +20,8 @@
1820
- `debug_log_from_routines` should only happen when `log_level > 1`.
1921
- Bugs in `Multicore_backend`: `await` was not checking queue emptiness, `worker`'s `Condition.broadcast` was non-atomically guarded (doesn't need to be), possible deadloop due to the lockfree queue -- now replaced with `saturn_lockfree`.
2022
- Reduced busy-waiting inside `c_compile_and_load`, propagating compilation errors now instead of infinite loop on error.
23+
- Fixed loss of significant digits for small numbers when outputting files.
24+
- Added missing mixed-precision conversions in the `C_syntax` backend builder.
2125

2226
## [0.4.0] -- 2024-09-04
2327

arrayjit/lib/backend_utils.ml

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ struct
6666

6767
let pp_index_axis ppf = function
6868
| Indexing.Iterator it -> pp_index ppf it
69+
| Fixed_idx i when i < 0 -> Stdlib.Format.fprintf ppf "(%d)" i
6970
| Fixed_idx i -> Stdlib.Format.fprintf ppf "%d" i
7071

7172
let pp_array_offset ppf (idcs, dims) =
@@ -223,33 +224,38 @@ struct
223224
| Binop (_, v1, v2) -> pp_top_locals ppf v1 + pp_top_locals ppf v2
224225
| Unop (_, v) -> pp_top_locals ppf v
225226
and pp_float prec ppf value =
226-
let num_typ = B.typ_of_prec prec in
227227
let loop = pp_float prec in
228228
match value with
229229
| Local_scope { id; _ } ->
230230
(* Embedding of Local_scope is done by pp_top_locals. *)
231231
loop ppf @@ Get_local id
232232
| Get_local id ->
233-
let get_typ = B.typ_of_prec id.tn.prec in
234-
if not @@ String.equal num_typ get_typ then fprintf ppf "(%s)" num_typ;
235-
fprintf ppf "v%d" id.scope_id
233+
let prefix, postfix = B.convert_precision ~from:id.tn.prec ~to_:prec in
234+
fprintf ppf "%sv%d%s" prefix id.scope_id postfix
236235
| Get_global (Ops.Merge_buffer { source_node_id }, Some idcs) ->
237236
let tn = Option.value_exn ~here:[%here] @@ Tn.find ~id:source_node_id in
238-
fprintf ppf "@[<2>((%s*)merge_buffer)[%a@;<0 -2>]@]" (B.typ_of_prec prec) pp_array_offset
237+
let prefix, postfix = B.convert_precision ~from:tn.prec ~to_:prec in
238+
fprintf ppf "@[<2>%smerge_buffer[%a@;<0 -2>]%s@]" prefix pp_array_offset
239239
(idcs, Lazy.force tn.dims)
240+
postfix
240241
| Get_global _ -> failwith "C_syntax: Get_global / FFI NOT IMPLEMENTED YET"
241242
| Get (tn, idcs) ->
242-
(* FIXME: implement type casts here and in other places to support mixed precision. *)
243243
Hash_set.add visited tn;
244244
let ident = get_ident tn in
245-
fprintf ppf "@[<2>%s[%a@;<0 -2>]@]" ident pp_array_offset (idcs, Lazy.force tn.dims)
245+
let prefix, postfix = B.convert_precision ~from:tn.prec ~to_:prec in
246+
fprintf ppf "@[<2>%s%s[%a@;<0 -2>]%s@]" prefix ident pp_array_offset
247+
(idcs, Lazy.force tn.dims)
248+
postfix
246249
| Constant c ->
247250
let prefix, postfix = B.convert_precision ~from:Ops.double ~to_:prec in
251+
let prefix, postfix =
252+
if String.is_empty prefix && Float.(c < 0.0) then ("(", ")" ^ postfix)
253+
else (prefix, postfix)
254+
in
248255
fprintf ppf "%s%.16g%s" prefix c postfix
249256
| Embed_index idx ->
250-
if not @@ List.exists ~f:(String.equal num_typ) [ "int"; "size_t" ] then
251-
fprintf ppf "(%s)" num_typ;
252-
pp_index_axis ppf idx
257+
let prefix, postfix = B.convert_precision ~from:Ops.double ~to_:prec in
258+
fprintf ppf "%s%a%s" prefix pp_index_axis idx postfix
253259
| Binop (Arg1, v1, _v2) -> loop ppf v1
254260
| Binop (Arg2, _v1, v2) -> loop ppf v2
255261
| Binop (op, v1, v2) ->
@@ -259,31 +265,39 @@ struct
259265
let prefix, postfix = B.unop_syntax prec op in
260266
fprintf ppf "@[<1>%s%a@]%s" prefix loop v postfix
261267
and debug_float prec (value : Low_level.float_t) : string * 'a list =
262-
let num_typ = B.typ_of_prec prec in
263268
let loop = debug_float prec in
264269
match value with
265270
| Local_scope { id; _ } ->
266271
(* Not printing the inlined definition: (1) code complexity; (2) don't overload the debug
267272
logs. *)
268273
loop @@ Get_local id
269274
| Get_local id ->
270-
let get_typ = B.typ_of_prec id.tn.prec in
271-
let v =
272-
(if not @@ String.equal num_typ get_typ then "(" ^ num_typ ^ ")" else "")
273-
^ "v" ^ Int.to_string id.scope_id
274-
in
275+
let prefix, postfix = B.convert_precision ~from:id.tn.prec ~to_:prec in
276+
let v = String.concat [ prefix; "v"; Int.to_string id.scope_id; postfix ] in
275277
(v ^ "{=%g}", [ `Value v ])
276278
| Get_global (Ops.Merge_buffer { source_node_id }, Some idcs) ->
277279
let tn = Option.value_exn ~here:[%here] @@ Tn.find ~id:source_node_id in
280+
let prefix, postfix = B.convert_precision ~from:tn.prec ~to_:prec in
278281
let dims = Lazy.force tn.dims in
279-
let v = sprintf "@[<2>merge_buffer[%s@;<0 -2>]@]" (array_offset_to_string (idcs, dims)) in
280-
("merge_buffer[%u]{=%g}", [ `Accessor (idcs, dims); `Value v ])
282+
let v =
283+
sprintf "@[<2>%smerge_buffer[%s@;<0 -2>]%s@]" prefix
284+
(array_offset_to_string (idcs, dims))
285+
postfix
286+
in
287+
( String.concat [ prefix; "merge_buffer[%u]"; postfix; "{=%g}" ],
288+
[ `Accessor (idcs, dims); `Value v ] )
281289
| Get_global _ -> failwith "Exec_as_cuda: Get_global / FFI NOT IMPLEMENTED YET"
282290
| Get (tn, idcs) ->
283291
let dims = Lazy.force tn.dims in
284292
let ident = get_ident tn in
285-
let v = sprintf "@[<2>%s[%s@;<0 -2>]@]" ident (array_offset_to_string (idcs, dims)) in
286-
(ident ^ "[%u]{=%g}", [ `Accessor (idcs, dims); `Value v ])
293+
let prefix, postfix = B.convert_precision ~from:tn.prec ~to_:prec in
294+
let v =
295+
sprintf "@[<2>%s%s[%s@;<0 -2>]%s@]" prefix ident
296+
(array_offset_to_string (idcs, dims))
297+
postfix
298+
in
299+
( String.concat [ prefix; ident; "[%u]"; postfix; "{=%g}" ],
300+
[ `Accessor (idcs, dims); `Value v ] )
287301
| Constant c ->
288302
let prefix, postfix = B.convert_precision ~from:Ops.double ~to_:prec in
289303
(prefix ^ Float.to_string c ^ postfix, [])

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ struct
376376
| Half_prec _, Half_prec _
377377
| Byte_prec _, Byte_prec _
378378
| Void_prec, Void_prec ->
379-
("(", ")")
379+
("", "")
380380
| Double_prec _, Half_prec _ -> ("__double2half(", ")")
381381
| Single_prec _, Half_prec _ -> ("__float2half(", ")")
382382
| Byte_prec _, Half_prec _ -> ("__ushort2half_rn((unsigned short int)", ")")

arrayjit/lib/low_level.ml

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -739,17 +739,14 @@ let simplify_llc llc =
739739
let check_constant =
740740
match Utils.settings.check_half_prec_constants_cutoff with
741741
| None -> fun _prec _c -> ()
742-
| Some cutoff -> (
742+
| Some cutoff ->
743743
fun tn c ->
744-
match tn.Tn.prec with
745-
| Ops.Half_prec _ ->
746-
if Float.(abs c >= cutoff) then
747-
raise
748-
@@ Utils.User_error
749-
("Constant " ^ Float.to_string c
750-
^ " is too big for FP16 aka. half precision, risk of overflow; increase \
751-
precision of tensor node " ^ Tn.debug_name tn)
752-
| _ -> ())
744+
if Ops.is_fp16 tn.Tn.prec && Float.(abs c >= cutoff) then
745+
raise
746+
@@ Utils.User_error
747+
("Constant " ^ Float.to_string c
748+
^ " is too big for FP16 aka. half precision, risk of overflow; increase precision \
749+
of tensor node " ^ Tn.debug_name tn)
753750
in
754751
let rec check_proc llc =
755752
let loop = check_proc in

arrayjit/lib/ops.ml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ let byte = Byte_prec Byte
2828
let half = Half_prec Half
2929
let single = Single_prec Single
3030
let double = Double_prec Double
31+
let is_fp16 = function Half_prec _ -> true | _ -> false
3132

3233
let sexp_of_prec = function
3334
| Void_prec -> Sexp.Atom "Void_prec"
@@ -226,7 +227,7 @@ let c_convert_precision ~from ~to_ =
226227
| Half_prec _, Half_prec _
227228
| Byte_prec _, Byte_prec _
228229
| Void_prec, Void_prec ->
229-
("(", ")")
230+
("", "")
230231
| _ -> ("(" ^ c_typ_of_prec to_ ^ ")(", ")")
231232

232233
(** {2 *** Global references ***} *)

0 commit comments

Comments
 (0)