Skip to content

Commit 367ff3b

Browse files
committed
Claude's third pass at adding BF16, FP8: proper conversions in pure C
Claude decided to provide full definitions to inline rather than just headers for jitted C sources, for performance.
1 parent 1b56fc4 commit 367ff3b

File tree

3 files changed

+146
-40
lines changed

3 files changed

+146
-40
lines changed

arrayjit/lib/arrayjit_stubs.c

Lines changed: 62 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,44 +4,34 @@
44
#include <math.h>
55
#include <stdint.h>
66

7-
/* BFloat16 to Float conversion */
8-
CAMLprim value arrayjit_bfloat16_to_float(value v_bf16)
7+
/* Pure C conversion functions for use in C backends */
8+
9+
/* BFloat16 to Float conversion (C function) */
10+
static inline float bfloat16_to_float(uint16_t bf16)
911
{
10-
CAMLparam1(v_bf16);
11-
uint16_t bf16 = (uint16_t)Int_val(v_bf16);
12-
1312
/* BFloat16 format: 1 sign bit, 8 exponent bits, 7 mantissa bits
1413
To convert to float32, we shift left by 16 bits */
1514
uint32_t f32 = ((uint32_t)bf16) << 16;
16-
float result = *((float*)&f32);
17-
18-
CAMLreturn(caml_copy_double((double)result));
15+
return *((float*)&f32);
1916
}
2017

21-
/* Float to BFloat16 conversion */
22-
CAMLprim value arrayjit_float_to_bfloat16(value v_float)
18+
/* Float to BFloat16 conversion (C function) */
19+
static inline uint16_t float_to_bfloat16(float f)
2320
{
24-
CAMLparam1(v_float);
25-
float f = (float)Double_val(v_float);
2621
uint32_t f32 = *((uint32_t*)&f);
2722

2823
/* Round to nearest even */
2924
uint32_t rounded = f32 + 0x7FFF + ((f32 >> 16) & 1);
30-
uint16_t bf16 = (uint16_t)(rounded >> 16);
31-
32-
CAMLreturn(Val_int(bf16));
25+
return (uint16_t)(rounded >> 16);
3326
}
3427

35-
/* FP8 E5M2 format to Float conversion
28+
/* FP8 E5M2 format to Float conversion (C function)
3629
Format: 1 sign bit, 5 exponent bits, 2 mantissa bits */
37-
CAMLprim value arrayjit_fp8_to_float(value v_fp8)
30+
static inline float fp8_to_float(uint8_t fp8)
3831
{
39-
CAMLparam1(v_fp8);
40-
uint8_t fp8 = (uint8_t)Int_val(v_fp8);
41-
4232
/* Handle zero */
4333
if (fp8 == 0) {
44-
CAMLreturn(caml_copy_double(0.0));
34+
return 0.0f;
4535
}
4636

4737
uint32_t sign = (fp8 >> 7) & 1;
@@ -51,47 +41,43 @@ CAMLprim value arrayjit_fp8_to_float(value v_fp8)
5141
/* Handle special cases */
5242
if (exp == 0x1F) { /* Infinity or NaN */
5343
if (mant == 0) {
54-
float inf = sign ? -INFINITY : INFINITY;
55-
CAMLreturn(caml_copy_double((double)inf));
44+
return sign ? -INFINITY : INFINITY;
5645
} else {
57-
CAMLreturn(caml_copy_double((double)NAN));
46+
return NAN;
5847
}
5948
}
6049

6150
/* Denormalized numbers */
6251
if (exp == 0) {
6352
float result = ldexpf((float)mant / 4.0f, -14);
6453
if (sign) result = -result;
65-
CAMLreturn(caml_copy_double((double)result));
54+
return result;
6655
}
6756

6857
/* Normalized numbers */
6958
float result = (1.0f + (float)mant * 0.25f) * ldexpf(1.0f, (int)exp - 15);
7059
if (sign) result = -result;
7160

72-
CAMLreturn(caml_copy_double((double)result));
61+
return result;
7362
}
7463

75-
/* Float to FP8 E5M2 conversion */
76-
CAMLprim value arrayjit_float_to_fp8(value v_float)
64+
/* Float to FP8 E5M2 conversion (C function) */
65+
static inline uint8_t float_to_fp8(float f)
7766
{
78-
CAMLparam1(v_float);
79-
float f = (float)Double_val(v_float);
80-
8167
/* Handle zero */
8268
if (f == 0.0f) {
83-
CAMLreturn(Val_int(0));
69+
return 0;
8470
}
8571

8672
uint32_t sign = (f < 0) ? 1 : 0;
8773
f = fabsf(f);
8874

8975
/* Handle special cases */
9076
if (isinf(f)) {
91-
CAMLreturn(Val_int((sign << 7) | 0x7C)); /* Infinity: exp=0x1F, mant=0 */
77+
return (sign << 7) | 0x7C; /* Infinity: exp=0x1F, mant=0 */
9278
}
9379
if (isnan(f)) {
94-
CAMLreturn(Val_int((sign << 7) | 0x7F)); /* NaN: exp=0x1F, mant!=0 */
80+
return (sign << 7) | 0x7F; /* NaN: exp=0x1F, mant!=0 */
9581
}
9682

9783
/* Get exponent and mantissa */
@@ -102,26 +88,63 @@ CAMLprim value arrayjit_float_to_fp8(value v_float)
10288
/* Clamp to representable range */
10389
if (exp < 0) {
10490
/* Underflow to zero */
105-
CAMLreturn(Val_int(sign << 7));
91+
return sign << 7;
10692
}
10793
if (exp > 30) {
10894
/* Overflow to infinity */
109-
CAMLreturn(Val_int((sign << 7) | 0x7C));
95+
return (sign << 7) | 0x7C;
11096
}
11197

11298
/* Handle denormalized numbers */
11399
if (exp == 0) {
114100
float denorm_mant = f * ldexpf(1.0f, 14) * 4.0f;
115101
uint32_t mant_bits = (uint32_t)(denorm_mant + 0.5f);
116102
if (mant_bits > 3) mant_bits = 3;
117-
CAMLreturn(Val_int((sign << 7) | mant_bits));
103+
return (sign << 7) | mant_bits;
118104
}
119105

120106
/* Normalized numbers: convert mantissa from [0.5, 1) to [0, 0.75] */
121107
mant_f = (mant_f - 0.5f) * 4.0f;
122108
uint32_t mant_bits = (uint32_t)(mant_f + 0.5f); /* Round to nearest */
123109
if (mant_bits > 3) mant_bits = 3;
124110

125-
uint8_t result = (uint8_t)((sign << 7) | ((exp & 0x1F) << 2) | (mant_bits & 0x3));
126-
CAMLreturn(Val_int(result));
111+
return (uint8_t)((sign << 7) | ((exp & 0x1F) << 2) | (mant_bits & 0x3));
112+
}
113+
114+
/* OCaml wrapper functions */
115+
116+
/* BFloat16 to Float conversion (OCaml wrapper) */
117+
CAMLprim value arrayjit_bfloat16_to_float(value v_bf16)
118+
{
119+
CAMLparam1(v_bf16);
120+
uint16_t bf16 = (uint16_t)Int_val(v_bf16);
121+
float result = bfloat16_to_float(bf16);
122+
CAMLreturn(caml_copy_double((double)result));
123+
}
124+
125+
/* Float to BFloat16 conversion (OCaml wrapper) */
126+
CAMLprim value arrayjit_float_to_bfloat16(value v_float)
127+
{
128+
CAMLparam1(v_float);
129+
float f = (float)Double_val(v_float);
130+
uint16_t bf16 = float_to_bfloat16(f);
131+
CAMLreturn(Val_int(bf16));
132+
}
133+
134+
/* FP8 E5M2 format to Float conversion (OCaml wrapper) */
135+
CAMLprim value arrayjit_fp8_to_float(value v_fp8)
136+
{
137+
CAMLparam1(v_fp8);
138+
uint8_t fp8 = (uint8_t)Int_val(v_fp8);
139+
float result = fp8_to_float(fp8);
140+
CAMLreturn(caml_copy_double((double)result));
141+
}
142+
143+
/* Float to FP8 E5M2 conversion (OCaml wrapper) */
144+
CAMLprim value arrayjit_float_to_fp8(value v_float)
145+
{
146+
CAMLparam1(v_float);
147+
float f = (float)Double_val(v_float);
148+
uint8_t fp8 = float_to_fp8(f);
149+
CAMLreturn(Val_int(fp8));
127150
}

arrayjit/lib/c_syntax.ml

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,62 @@ struct
9191
let arg_int_prefix = "const int "
9292
let extra_args = []
9393
let includes = [ "<stdio.h>"; "<stdlib.h>"; "<string.h>"; "<math.h>" ]
94-
let extra_declarations = []
94+
let extra_declarations = [
95+
(* BFloat16 conversion functions *)
96+
"static inline float bfloat16_to_float(unsigned short bf16) {";
97+
" unsigned int f32 = ((unsigned int)bf16) << 16;";
98+
" return *((float*)&f32);";
99+
"}";
100+
"";
101+
"static inline unsigned short float_to_bfloat16(float f) {";
102+
" unsigned int f32 = *((unsigned int*)&f);";
103+
" unsigned int rounded = f32 + 0x7FFF + ((f32 >> 16) & 1);";
104+
" return (unsigned short)(rounded >> 16);";
105+
"}";
106+
"";
107+
(* FP8 E5M2 conversion functions *)
108+
"static inline float fp8_to_float(unsigned char fp8) {";
109+
" if (fp8 == 0) return 0.0f;";
110+
" unsigned int sign = (fp8 >> 7) & 1;";
111+
" unsigned int exp = (fp8 >> 2) & 0x1F;";
112+
" unsigned int mant = fp8 & 0x3;";
113+
" if (exp == 0x1F) {";
114+
" if (mant == 0) return sign ? -INFINITY : INFINITY;";
115+
" else return NAN;";
116+
" }";
117+
" if (exp == 0) {";
118+
" float result = ldexpf((float)mant / 4.0f, -14);";
119+
" if (sign) result = -result;";
120+
" return result;";
121+
" }";
122+
" float result = (1.0f + (float)mant * 0.25f) * ldexpf(1.0f, (int)exp - 15);";
123+
" if (sign) result = -result;";
124+
" return result;";
125+
"}";
126+
"";
127+
"static inline unsigned char float_to_fp8(float f) {";
128+
" if (f == 0.0f) return 0;";
129+
" unsigned int sign = (f < 0) ? 1 : 0;";
130+
" f = fabsf(f);";
131+
" if (isinf(f)) return (sign << 7) | 0x7C;";
132+
" if (isnan(f)) return (sign << 7) | 0x7F;";
133+
" int exp_val;";
134+
" float mant_f = frexpf(f, &exp_val);";
135+
" int exp = exp_val + 14;";
136+
" if (exp < 0) return sign << 7;";
137+
" if (exp > 30) return (sign << 7) | 0x7C;";
138+
" if (exp == 0) {";
139+
" float denorm_mant = f * ldexpf(1.0f, 14) * 4.0f;";
140+
" unsigned int mant_bits = (unsigned int)(denorm_mant + 0.5f);";
141+
" if (mant_bits > 3) mant_bits = 3;";
142+
" return (sign << 7) | mant_bits;";
143+
" }";
144+
" mant_f = (mant_f - 0.5f) * 4.0f;";
145+
" unsigned int mant_bits = (unsigned int)(mant_f + 0.5f);";
146+
" if (mant_bits > 3) mant_bits = 3;";
147+
" return (unsigned char)((sign << 7) | ((exp & 0x1F) << 2) | (mant_bits & 0x3));";
148+
"}";
149+
]
95150
let typ_of_prec = Ops.c_typ_of_prec
96151
let float_log_style = if Input.full_printf_support then "%g" else "%de-3"
97152

arrayjit/lib/ops.ml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,34 @@ let c_convert_precision ~from ~to_ =
519519
| Fp8_prec _, Fp8_prec _
520520
| Void_prec, Void_prec ->
521521
("", "")
522+
(* BFloat16 conversions *)
523+
| Bfloat16_prec _, Single_prec _ -> ("bfloat16_to_float(", ")")
524+
| Single_prec _, Bfloat16_prec _ -> ("float_to_bfloat16(", ")")
525+
| Bfloat16_prec _, Double_prec _ -> ("(double)bfloat16_to_float(", ")")
526+
| Double_prec _, Bfloat16_prec _ -> ("float_to_bfloat16((float)", ")")
527+
(* FP8 conversions *)
528+
| Fp8_prec _, Single_prec _ -> ("fp8_to_float(", ")")
529+
| Single_prec _, Fp8_prec _ -> ("float_to_fp8(", ")")
530+
| Fp8_prec _, Double_prec _ -> ("(double)fp8_to_float(", ")")
531+
| Double_prec _, Fp8_prec _ -> ("float_to_fp8((float)", ")")
532+
(* Conversions involving BFloat16 and other types *)
533+
| Bfloat16_prec _, Half_prec _ -> ("(_Float16)bfloat16_to_float(", ")")
534+
| Half_prec _, Bfloat16_prec _ -> ("float_to_bfloat16((float)", ")")
535+
| Bfloat16_prec _, (Byte_prec _ | Uint16_prec _ | Int32_prec _) ->
536+
("(" ^ c_typ_of_prec to_ ^ ")bfloat16_to_float(", ")")
537+
| (Byte_prec _ | Uint16_prec _ | Int32_prec _), Bfloat16_prec _ ->
538+
("float_to_bfloat16((float)", ")")
539+
(* Conversions involving FP8 and other types *)
540+
| Fp8_prec _, Half_prec _ -> ("(_Float16)fp8_to_float(", ")")
541+
| Half_prec _, Fp8_prec _ -> ("float_to_fp8((float)", ")")
542+
| Fp8_prec _, (Byte_prec _ | Uint16_prec _ | Int32_prec _) ->
543+
("(" ^ c_typ_of_prec to_ ^ ")fp8_to_float(", ")")
544+
| (Byte_prec _ | Uint16_prec _ | Int32_prec _), Fp8_prec _ ->
545+
("float_to_fp8((float)", ")")
546+
(* BFloat16 <-> FP8 conversions *)
547+
| Bfloat16_prec _, Fp8_prec _ -> ("float_to_fp8(bfloat16_to_float(", "))")
548+
| Fp8_prec _, Bfloat16_prec _ -> ("float_to_bfloat16(fp8_to_float(", "))")
549+
(* Default case for all other conversions *)
522550
| _ -> ("(" ^ c_typ_of_prec to_ ^ ")(", ")")
523551

524552
(** {2 *** Global references ***} *)

0 commit comments

Comments
 (0)