Skip to content

Commit d49bc96

Browse files
committed
Fixes #204 and #319: fix emulation of FP8 and BF16 via single-precision floats in Pure_C_config numerics
1 parent 367ff3b commit d49bc96

File tree

2 files changed

+112
-23
lines changed

2 files changed

+112
-23
lines changed

CHANGES.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
## [0.5.4] -- current
2+
3+
### Added
4+
5+
- Support for Brain floating point (a.k.a. bfloat16 or BF16) and for FP8.
6+
- TODO: utilities for using Hugging Face Tokenizers.
7+
18
## [0.5.3] -- 2025-05-24
29

310
### Added

arrayjit/lib/c_syntax.ml

Lines changed: 105 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -217,20 +217,48 @@ struct
217217
Set.to_list !functions
218218

219219
let ternop_syntax prec op v1 v2 v3 =
220-
let op_prefix, op_infix1, op_infix2, op_suffix = Ops.ternop_c_syntax prec op in
221-
let open PPrint in
222-
group
223-
(string op_prefix ^^ v1 ^^ string op_infix1
224-
^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
225-
^^ string op_infix2
226-
^^ ifflat (space ^^ v3) (nest 2 (break 1 ^^ v3))
227-
^^ string op_suffix)
220+
match prec with
221+
| Ops.Bfloat16_prec _ ->
222+
(* For BFloat16, perform operations in float precision *)
223+
let float_v1 = PPrint.(string "bfloat16_to_float(" ^^ v1 ^^ string ")") in
224+
let float_v2 = PPrint.(string "bfloat16_to_float(" ^^ v2 ^^ string ")") in
225+
let float_v3 = PPrint.(string "bfloat16_to_float(" ^^ v3 ^^ string ")") in
226+
let op_prefix, op_infix1, op_infix2, op_suffix = Ops.ternop_c_syntax Ops.single op in
227+
let float_result = PPrint.(
228+
group (string op_prefix ^^ float_v1 ^^ string op_infix1
229+
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
230+
^^ string op_infix2
231+
^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
232+
^^ string op_suffix)) in
233+
PPrint.(string "float_to_bfloat16(" ^^ float_result ^^ string ")")
234+
| Ops.Fp8_prec _ ->
235+
(* For FP8, perform operations in float precision *)
236+
let float_v1 = PPrint.(string "fp8_to_float(" ^^ v1 ^^ string ")") in
237+
let float_v2 = PPrint.(string "fp8_to_float(" ^^ v2 ^^ string ")") in
238+
let float_v3 = PPrint.(string "fp8_to_float(" ^^ v3 ^^ string ")") in
239+
let op_prefix, op_infix1, op_infix2, op_suffix = Ops.ternop_c_syntax Ops.single op in
240+
let float_result = PPrint.(
241+
group (string op_prefix ^^ float_v1 ^^ string op_infix1
242+
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
243+
^^ string op_infix2
244+
^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
245+
^^ string op_suffix)) in
246+
PPrint.(string "float_to_fp8(" ^^ float_result ^^ string ")")
247+
| _ ->
248+
let op_prefix, op_infix1, op_infix2, op_suffix = Ops.ternop_c_syntax prec op in
249+
let open PPrint in
250+
group
251+
(string op_prefix ^^ v1 ^^ string op_infix1
252+
^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
253+
^^ string op_infix2
254+
^^ ifflat (space ^^ v3) (nest 2 (break 1 ^^ v3))
255+
^^ string op_suffix)
228256

229257
let binop_syntax prec op v1 v2 =
230258
match op with
231259
| Ops.Satur01_gate -> (
232260
match prec with
233-
| Ops.Byte_prec _ | Ops.Uint16_prec _ | Ops.Int32_prec _ | Ops.Fp8_prec _ ->
261+
| Ops.Byte_prec _ | Ops.Uint16_prec _ | Ops.Int32_prec _ ->
234262
let open PPrint in
235263
group
236264
(parens
@@ -244,21 +272,34 @@ struct
244272
(nest 2
245273
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
246274
^^ string "(" ^^ string (typ_of_prec prec) ^^ string ")0"))))
275+
| Ops.Fp8_prec _ ->
276+
let open PPrint in
277+
group
278+
(parens
279+
(group
280+
(parens
281+
(string "fp8_to_float(" ^^ v1 ^^ string ") > 0.0f && fp8_to_float("
282+
^^ v1 ^^ string ") < 1.0f"))
283+
^^ ifflat
284+
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
285+
^^ string "float_to_fp8(0.0f)")
286+
(nest 2
287+
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
288+
^^ string "float_to_fp8(0.0f)"))))
247289
| Ops.Bfloat16_prec _ ->
248-
(* For CC backend, convert to float for computation *)
249290
let open PPrint in
250291
group
251292
(parens
252293
(group
253294
(parens
254-
(string "(float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1
255-
^^ string " < 1.0f"))
295+
(string "bfloat16_to_float(" ^^ v1 ^^ string ") > 0.0f && bfloat16_to_float("
296+
^^ v1 ^^ string ") < 1.0f"))
256297
^^ ifflat
257298
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
258-
^^ string "(unsigned short)0")
299+
^^ string "float_to_bfloat16(0.0f)")
259300
(nest 2
260301
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
261-
^^ string "(unsigned short)0"))))
302+
^^ string "float_to_bfloat16(0.0f)"))))
262303
| Ops.Half_prec _ ->
263304
let open PPrint in
264305
group
@@ -294,17 +335,58 @@ struct
294335
^^ string "0.0"))))
295336
| Ops.Void_prec -> invalid_arg "Pure_C_config.binop_syntax: Satur01_gate on Void_prec")
296337
| _ ->
297-
let op_prefix, op_infix, op_suffix = Ops.binop_c_syntax prec op in
298-
let open PPrint in
299-
group
300-
(string op_prefix ^^ v1 ^^ string op_infix
301-
^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
302-
^^ string op_suffix)
338+
match prec with
339+
| Ops.Bfloat16_prec _ ->
340+
(* For BFloat16, perform all operations in float precision *)
341+
let float_v1 = PPrint.(string "bfloat16_to_float(" ^^ v1 ^^ string ")") in
342+
let float_v2 = PPrint.(string "bfloat16_to_float(" ^^ v2 ^^ string ")") in
343+
let op_prefix, op_infix, op_suffix = Ops.binop_c_syntax Ops.single op in
344+
let float_result = PPrint.(
345+
group (string op_prefix ^^ float_v1 ^^ string op_infix
346+
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
347+
^^ string op_suffix)) in
348+
(* Comparison and logical operations yield a float result (0.0 or 1.0); like every other result, it is converted back to BFloat16, so both match arms below emit the same wrapper *)
349+
(match op with
350+
| Ops.Cmplt | Ops.Cmpeq | Ops.Cmpne | Ops.Or | Ops.And ->
351+
PPrint.(string "float_to_bfloat16(" ^^ float_result ^^ string ")")
352+
| _ ->
353+
PPrint.(string "float_to_bfloat16(" ^^ float_result ^^ string ")"))
354+
| Ops.Fp8_prec _ ->
355+
(* For FP8, perform all operations in float precision *)
356+
let float_v1 = PPrint.(string "fp8_to_float(" ^^ v1 ^^ string ")") in
357+
let float_v2 = PPrint.(string "fp8_to_float(" ^^ v2 ^^ string ")") in
358+
let op_prefix, op_infix, op_suffix = Ops.binop_c_syntax Ops.single op in
359+
let float_result = PPrint.(
360+
group (string op_prefix ^^ float_v1 ^^ string op_infix
361+
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
362+
^^ string op_suffix)) in
363+
PPrint.(string "float_to_fp8(" ^^ float_result ^^ string ")")
364+
| _ ->
365+
let op_prefix, op_infix, op_suffix = Ops.binop_c_syntax prec op in
366+
let open PPrint in
367+
group
368+
(string op_prefix ^^ v1 ^^ string op_infix
369+
^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
370+
^^ string op_suffix)
303371

304372
let unop_syntax prec op v =
305-
let op_prefix, op_suffix = Ops.unop_c_syntax prec op in
306-
let open PPrint in
307-
group (string op_prefix ^^ v ^^ string op_suffix)
373+
match prec with
374+
| Ops.Bfloat16_prec _ ->
375+
(* For BFloat16, perform operations in float precision *)
376+
let float_v = PPrint.(string "bfloat16_to_float(" ^^ v ^^ string ")") in
377+
let op_prefix, op_suffix = Ops.unop_c_syntax Ops.single op in
378+
let float_result = PPrint.(group (string op_prefix ^^ float_v ^^ string op_suffix)) in
379+
PPrint.(string "float_to_bfloat16(" ^^ float_result ^^ string ")")
380+
| Ops.Fp8_prec _ ->
381+
(* For FP8, perform operations in float precision *)
382+
let float_v = PPrint.(string "fp8_to_float(" ^^ v ^^ string ")") in
383+
let op_prefix, op_suffix = Ops.unop_c_syntax Ops.single op in
384+
let float_result = PPrint.(group (string op_prefix ^^ float_v ^^ string op_suffix)) in
385+
PPrint.(string "float_to_fp8(" ^^ float_result ^^ string ")")
386+
| _ ->
387+
let op_prefix, op_suffix = Ops.unop_c_syntax prec op in
388+
let open PPrint in
389+
group (string op_prefix ^^ v ^^ string op_suffix)
308390

309391
let convert_precision = Ops.c_convert_precision
310392
let kernel_log_param = Some ("const char*", "log_file_name")

0 commit comments

Comments
 (0)