Skip to content

Commit 54645ab

Browse files
committed
Cleanup and formatting
1 parent d49bc96 commit 54645ab

File tree

11 files changed

+237
-726
lines changed

11 files changed

+237
-726
lines changed

arrayjit/lib/c_syntax.ml

Lines changed: 111 additions & 91 deletions
Original file line number | Diff line number | Diff line change
@@ -91,62 +91,65 @@ struct
9191
let arg_int_prefix = "const int "
9292
let extra_args = []
9393
let includes = [ "<stdio.h>"; "<stdlib.h>"; "<string.h>"; "<math.h>" ]
94-
let extra_declarations = [
95-
(* BFloat16 conversion functions *)
96-
"static inline float bfloat16_to_float(unsigned short bf16) {";
97-
" unsigned int f32 = ((unsigned int)bf16) << 16;";
98-
" return *((float*)&f32);";
99-
"}";
100-
"";
101-
"static inline unsigned short float_to_bfloat16(float f) {";
102-
" unsigned int f32 = *((unsigned int*)&f);";
103-
" unsigned int rounded = f32 + 0x7FFF + ((f32 >> 16) & 1);";
104-
" return (unsigned short)(rounded >> 16);";
105-
"}";
106-
"";
107-
(* FP8 E5M2 conversion functions *)
108-
"static inline float fp8_to_float(unsigned char fp8) {";
109-
" if (fp8 == 0) return 0.0f;";
110-
" unsigned int sign = (fp8 >> 7) & 1;";
111-
" unsigned int exp = (fp8 >> 2) & 0x1F;";
112-
" unsigned int mant = fp8 & 0x3;";
113-
" if (exp == 0x1F) {";
114-
" if (mant == 0) return sign ? -INFINITY : INFINITY;";
115-
" else return NAN;";
116-
" }";
117-
" if (exp == 0) {";
118-
" float result = ldexpf((float)mant / 4.0f, -14);";
119-
" if (sign) result = -result;";
120-
" return result;";
121-
" }";
122-
" float result = (1.0f + (float)mant * 0.25f) * ldexpf(1.0f, (int)exp - 15);";
123-
" if (sign) result = -result;";
124-
" return result;";
125-
"}";
126-
"";
127-
"static inline unsigned char float_to_fp8(float f) {";
128-
" if (f == 0.0f) return 0;";
129-
" unsigned int sign = (f < 0) ? 1 : 0;";
130-
" f = fabsf(f);";
131-
" if (isinf(f)) return (sign << 7) | 0x7C;";
132-
" if (isnan(f)) return (sign << 7) | 0x7F;";
133-
" int exp_val;";
134-
" float mant_f = frexpf(f, &exp_val);";
135-
" int exp = exp_val + 14;";
136-
" if (exp < 0) return sign << 7;";
137-
" if (exp > 30) return (sign << 7) | 0x7C;";
138-
" if (exp == 0) {";
139-
" float denorm_mant = f * ldexpf(1.0f, 14) * 4.0f;";
140-
" unsigned int mant_bits = (unsigned int)(denorm_mant + 0.5f);";
141-
" if (mant_bits > 3) mant_bits = 3;";
142-
" return (sign << 7) | mant_bits;";
143-
" }";
144-
" mant_f = (mant_f - 0.5f) * 4.0f;";
145-
" unsigned int mant_bits = (unsigned int)(mant_f + 0.5f);";
146-
" if (mant_bits > 3) mant_bits = 3;";
147-
" return (unsigned char)((sign << 7) | ((exp & 0x1F) << 2) | (mant_bits & 0x3));";
148-
"}";
149-
]
94+
95+
let extra_declarations =
96+
[
97+
(* BFloat16 conversion functions *)
98+
"static inline float bfloat16_to_float(unsigned short bf16) {";
99+
" unsigned int f32 = ((unsigned int)bf16) << 16;";
100+
" return *((float*)&f32);";
101+
"}";
102+
"";
103+
"static inline unsigned short float_to_bfloat16(float f) {";
104+
" unsigned int f32 = *((unsigned int*)&f);";
105+
" unsigned int rounded = f32 + 0x7FFF + ((f32 >> 16) & 1);";
106+
" return (unsigned short)(rounded >> 16);";
107+
"}";
108+
"";
109+
(* FP8 E5M2 conversion functions *)
110+
"static inline float fp8_to_float(unsigned char fp8) {";
111+
" if (fp8 == 0) return 0.0f;";
112+
" unsigned int sign = (fp8 >> 7) & 1;";
113+
" unsigned int exp = (fp8 >> 2) & 0x1F;";
114+
" unsigned int mant = fp8 & 0x3;";
115+
" if (exp == 0x1F) {";
116+
" if (mant == 0) return sign ? -INFINITY : INFINITY;";
117+
" else return NAN;";
118+
" }";
119+
" if (exp == 0) {";
120+
" float result = ldexpf((float)mant / 4.0f, -14);";
121+
" if (sign) result = -result;";
122+
" return result;";
123+
" }";
124+
" float result = (1.0f + (float)mant * 0.25f) * ldexpf(1.0f, (int)exp - 15);";
125+
" if (sign) result = -result;";
126+
" return result;";
127+
"}";
128+
"";
129+
"static inline unsigned char float_to_fp8(float f) {";
130+
" if (f == 0.0f) return 0;";
131+
" unsigned int sign = (f < 0) ? 1 : 0;";
132+
" f = fabsf(f);";
133+
" if (isinf(f)) return (sign << 7) | 0x7C;";
134+
" if (isnan(f)) return (sign << 7) | 0x7F;";
135+
" int exp_val;";
136+
" float mant_f = frexpf(f, &exp_val);";
137+
" int exp = exp_val + 14;";
138+
" if (exp < 0) return sign << 7;";
139+
" if (exp > 30) return (sign << 7) | 0x7C;";
140+
" if (exp == 0) {";
141+
" float denorm_mant = f * ldexpf(1.0f, 14) * 4.0f;";
142+
" unsigned int mant_bits = (unsigned int)(denorm_mant + 0.5f);";
143+
" if (mant_bits > 3) mant_bits = 3;";
144+
" return (sign << 7) | mant_bits;";
145+
" }";
146+
" mant_f = (mant_f - 0.5f) * 4.0f;";
147+
" unsigned int mant_bits = (unsigned int)(mant_f + 0.5f);";
148+
" if (mant_bits > 3) mant_bits = 3;";
149+
" return (unsigned char)((sign << 7) | ((exp & 0x1F) << 2) | (mant_bits & 0x3));";
150+
"}";
151+
]
152+
150153
let typ_of_prec = Ops.c_typ_of_prec
151154
let float_log_style = if Input.full_printf_support then "%g" else "%de-3"
152155

@@ -221,28 +224,34 @@ struct
221224
| Ops.Bfloat16_prec _ ->
222225
(* For BFloat16, perform operations in float precision *)
223226
let float_v1 = PPrint.(string "bfloat16_to_float(" ^^ v1 ^^ string ")") in
224-
let float_v2 = PPrint.(string "bfloat16_to_float(" ^^ v2 ^^ string ")") in
227+
let float_v2 = PPrint.(string "bfloat16_to_float(" ^^ v2 ^^ string ")") in
225228
let float_v3 = PPrint.(string "bfloat16_to_float(" ^^ v3 ^^ string ")") in
226229
let op_prefix, op_infix1, op_infix2, op_suffix = Ops.ternop_c_syntax Ops.single op in
227-
let float_result = PPrint.(
228-
group (string op_prefix ^^ float_v1 ^^ string op_infix1
229-
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
230-
^^ string op_infix2
231-
^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
232-
^^ string op_suffix)) in
230+
let float_result =
231+
PPrint.(
232+
group
233+
(string op_prefix ^^ float_v1 ^^ string op_infix1
234+
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
235+
^^ string op_infix2
236+
^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
237+
^^ string op_suffix))
238+
in
233239
PPrint.(string "float_to_bfloat16(" ^^ float_result ^^ string ")")
234240
| Ops.Fp8_prec _ ->
235241
(* For FP8, perform operations in float precision *)
236242
let float_v1 = PPrint.(string "fp8_to_float(" ^^ v1 ^^ string ")") in
237-
let float_v2 = PPrint.(string "fp8_to_float(" ^^ v2 ^^ string ")") in
243+
let float_v2 = PPrint.(string "fp8_to_float(" ^^ v2 ^^ string ")") in
238244
let float_v3 = PPrint.(string "fp8_to_float(" ^^ v3 ^^ string ")") in
239245
let op_prefix, op_infix1, op_infix2, op_suffix = Ops.ternop_c_syntax Ops.single op in
240-
let float_result = PPrint.(
241-
group (string op_prefix ^^ float_v1 ^^ string op_infix1
242-
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
243-
^^ string op_infix2
244-
^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
245-
^^ string op_suffix)) in
246+
let float_result =
247+
PPrint.(
248+
group
249+
(string op_prefix ^^ float_v1 ^^ string op_infix1
250+
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
251+
^^ string op_infix2
252+
^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
253+
^^ string op_suffix))
254+
in
246255
PPrint.(string "float_to_fp8(" ^^ float_result ^^ string ")")
247256
| _ ->
248257
let op_prefix, op_infix1, op_infix2, op_suffix = Ops.ternop_c_syntax prec op in
@@ -268,18 +277,23 @@ struct
268277
^^ string " < 1.0f"))
269278
^^ ifflat
270279
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
271-
^^ string "(" ^^ string (typ_of_prec prec) ^^ string ")0")
280+
^^ string "("
281+
^^ string (typ_of_prec prec)
282+
^^ string ")0")
272283
(nest 2
273284
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
274-
^^ string "(" ^^ string (typ_of_prec prec) ^^ string ")0"))))
285+
^^ string "("
286+
^^ string (typ_of_prec prec)
287+
^^ string ")0"))))
275288
| Ops.Fp8_prec _ ->
276289
let open PPrint in
277290
group
278291
(parens
279292
(group
280293
(parens
281-
(string "fp8_to_float(" ^^ v1 ^^ string ") > 0.0f && fp8_to_float("
282-
^^ v1 ^^ string ") < 1.0f"))
294+
(string "fp8_to_float(" ^^ v1
295+
^^ string ") > 0.0f && fp8_to_float("
296+
^^ v1 ^^ string ") < 1.0f"))
283297
^^ ifflat
284298
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
285299
^^ string "float_to_fp8(0.0f)")
@@ -292,8 +306,9 @@ struct
292306
(parens
293307
(group
294308
(parens
295-
(string "bfloat16_to_float(" ^^ v1 ^^ string ") > 0.0f && bfloat16_to_float("
296-
^^ v1 ^^ string ") < 1.0f"))
309+
(string "bfloat16_to_float(" ^^ v1
310+
^^ string ") > 0.0f && bfloat16_to_float("
311+
^^ v1 ^^ string ") < 1.0f"))
297312
^^ ifflat
298313
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
299314
^^ string "float_to_bfloat16(0.0f)")
@@ -334,40 +349,45 @@ struct
334349
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
335350
^^ string "0.0"))))
336351
| Ops.Void_prec -> invalid_arg "Pure_C_config.binop_syntax: Satur01_gate on Void_prec")
337-
| _ ->
352+
| _ -> (
338353
match prec with
339-
| Ops.Bfloat16_prec _ ->
354+
| Ops.Bfloat16_prec _ -> (
340355
(* For BFloat16, perform all operations in float precision *)
341356
let float_v1 = PPrint.(string "bfloat16_to_float(" ^^ v1 ^^ string ")") in
342357
let float_v2 = PPrint.(string "bfloat16_to_float(" ^^ v2 ^^ string ")") in
343358
let op_prefix, op_infix, op_suffix = Ops.binop_c_syntax Ops.single op in
344-
let float_result = PPrint.(
345-
group (string op_prefix ^^ float_v1 ^^ string op_infix
346-
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
347-
^^ string op_suffix)) in
359+
let float_result =
360+
PPrint.(
361+
group
362+
(string op_prefix ^^ float_v1 ^^ string op_infix
363+
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
364+
^^ string op_suffix))
365+
in
348366
(* For comparison operations, return float result (0.0 or 1.0) converted to BFloat16 *)
349-
(match op with
367+
match op with
350368
| Ops.Cmplt | Ops.Cmpeq | Ops.Cmpne | Ops.Or | Ops.And ->
351-
PPrint.(string "float_to_bfloat16(" ^^ float_result ^^ string ")")
352-
| _ ->
353-
PPrint.(string "float_to_bfloat16(" ^^ float_result ^^ string ")"))
369+
PPrint.(string "float_to_bfloat16(" ^^ float_result ^^ string ")")
370+
| _ -> PPrint.(string "float_to_bfloat16(" ^^ float_result ^^ string ")"))
354371
| Ops.Fp8_prec _ ->
355372
(* For FP8, perform all operations in float precision *)
356373
let float_v1 = PPrint.(string "fp8_to_float(" ^^ v1 ^^ string ")") in
357374
let float_v2 = PPrint.(string "fp8_to_float(" ^^ v2 ^^ string ")") in
358375
let op_prefix, op_infix, op_suffix = Ops.binop_c_syntax Ops.single op in
359-
let float_result = PPrint.(
360-
group (string op_prefix ^^ float_v1 ^^ string op_infix
361-
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
362-
^^ string op_suffix)) in
376+
let float_result =
377+
PPrint.(
378+
group
379+
(string op_prefix ^^ float_v1 ^^ string op_infix
380+
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
381+
^^ string op_suffix))
382+
in
363383
PPrint.(string "float_to_fp8(" ^^ float_result ^^ string ")")
364384
| _ ->
365385
let op_prefix, op_infix, op_suffix = Ops.binop_c_syntax prec op in
366386
let open PPrint in
367387
group
368388
(string op_prefix ^^ v1 ^^ string op_infix
369389
^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
370-
^^ string op_suffix)
390+
^^ string op_suffix))
371391

372392
let unop_syntax prec op v =
373393
match prec with

arrayjit/lib/cc_backend.ml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,7 @@ struct
100100

101101
(* Override to add our custom type and conversion support *)
102102
let typ_of_prec = typ_of_prec
103-
let extra_declarations = extra_declarations (* Our bfloat16/fp8 conversion functions *)
103+
let extra_declarations = extra_declarations (* Our bfloat16/fp8 conversion functions *)
104104
let convert_precision = convert_precision
105105
end
106106

@@ -209,4 +209,4 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro
209209
context_lifetime = (ctx_arrays, code);
210210
description = "executes " ^ code.name ^ " on " ^ runner_label;
211211
work;
212-
} )
212+
} )

arrayjit/lib/cuda_backend.ml

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -284,8 +284,8 @@ end) : Ir.Backend_impl.Lowered_backend = struct
284284
| Ops.Uint16_prec _ -> "unsigned short"
285285
| Ops.Int32_prec _ -> "int"
286286
| Ops.Half_prec _ -> "__half"
287-
| Ops.Bfloat16_prec _ -> "__nv_bfloat16" (* CUDA bfloat16 type *)
288-
| Ops.Fp8_prec _ -> "__nv_fp8_e5m2" (* CUDA FP8 type (E5M2 format) *)
287+
| Ops.Bfloat16_prec _ -> "__nv_bfloat16" (* CUDA bfloat16 type *)
288+
| Ops.Fp8_prec _ -> "__nv_fp8_e5m2" (* CUDA FP8 type (E5M2 format) *)
289289
| Ops.Single_prec _ -> "float"
290290
| Ops.Double_prec _ -> "double"
291291
| Ops.Void_prec -> "void"
@@ -326,8 +326,8 @@ end) : Ir.Backend_impl.Lowered_backend = struct
326326
| ToPowOf, Bfloat16_prec _ ->
327327
fun v1 v2 ->
328328
group
329-
(string "__float2bfloat16(powf(__bfloat162float(" ^^ v1 ^^ string "), __bfloat162float("
330-
^^ v2 ^^ string ")))")
329+
(string "__float2bfloat16(powf(__bfloat162float("
330+
^^ v1 ^^ string "), __bfloat162float(" ^^ v2 ^^ string ")))")
331331
| Relu_gate, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Fp8_prec _) ->
332332
fun v1 v2 ->
333333
group
@@ -343,15 +343,13 @@ end) : Ir.Backend_impl.Lowered_backend = struct
343343
fun v1 v2 ->
344344
group
345345
(parens
346-
(group
347-
(parens
348-
(string "__bfloat162float(" ^^ v1 ^^ string ") > 0.0f"))
346+
(group (parens (string "__bfloat162float(" ^^ v1 ^^ string ") > 0.0f"))
349347
^^ ifflat
350348
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
351-
^^ string "__float2bfloat16(0.0f)")
349+
^^ string "__float2bfloat16(0.0f)")
352350
(nest 2
353351
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
354-
^^ string "__float2bfloat16(0.0f)"))))
352+
^^ string "__float2bfloat16(0.0f)"))))
355353
| Satur01_gate, Byte_prec _ ->
356354
fun v1 v2 ->
357355
group

arrayjit/lib/metal_backend.ml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -444,7 +444,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
444444
| Ops.Uint16_prec _ -> "ushort"
445445
| Ops.Int32_prec _ -> "int"
446446
| Ops.Half_prec _ -> "half"
447-
| Ops.Bfloat16_prec _ -> "bfloat" (* Metal supports bfloat16 natively *)
447+
| Ops.Bfloat16_prec _ -> "bfloat" (* Metal supports bfloat16 natively *)
448448
| Ops.Fp8_prec _ -> invalid_arg "Metal backend does not support FP8 precision"
449449
| Ops.Single_prec _ -> "float"
450450
| Ops.Double_prec _ -> "double"
@@ -455,7 +455,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
455455
| Ops.Uint16_prec _ -> ""
456456
| Ops.Int32_prec _ -> ""
457457
| Ops.Half_prec _ -> "h"
458-
| Ops.Bfloat16_prec _ -> "bf" (* TODO: Verify actual Metal suffix for bfloat16 *)
458+
| Ops.Bfloat16_prec _ -> "bf" (* TODO: Verify actual Metal suffix for bfloat16 *)
459459
| Ops.Fp8_prec _ -> invalid_arg "Metal backend does not support FP8 precision"
460460
| Ops.Single_prec _ -> "f"
461461
| Ops.Double_prec _ -> ""

0 commit comments

Comments
 (0)