44#include <math.h>
55#include <stdint.h>
66
/* Pure C conversion functions for use in C backends */

/* BFloat16 to Float conversion (C function).
   bf16 is the raw bfloat16 bit pattern: 1 sign bit, 8 exponent bits,
   7 mantissa bits. The returned float has bf16 as its upper 16 bits and
   zeros in its lower 16 bits. */
static inline float bfloat16_to_float(uint16_t bf16)
{
  /* Reinterpret the bits through a union. Reading a uint32_t object through
     a casted float pointer breaks the strict-aliasing rule and is undefined
     behaviour; union punning is well defined in C. */
  union {
    uint32_t bits;
    float value;
  } pun;

  /* To convert to float32, we shift left by 16 bits */
  pun.bits = ((uint32_t)bf16) << 16;
  return pun.value;
}
2017
/* Float to BFloat16 conversion (C function).
   Returns the bfloat16 bit pattern (upper 16 bits of the float32 encoding)
   rounded to nearest, ties to even. Finite values too large for bfloat16
   round up to infinity. NaN inputs always produce a NaN pattern. */
static inline uint16_t float_to_bfloat16(float f)
{
  /* Reinterpret the bits through a union; dereferencing a casted
     uint32_t pointer to a float object is undefined behaviour. */
  union {
    float value;
    uint32_t bits;
  } pun;
  pun.value = f;
  uint32_t f32 = pun.bits;

  /* NaN: exponent all ones and a non-zero mantissa. The rounding add below
     must not see these, because its carry can clear the mantissa (giving
     infinity) or wrap past the sign bit (giving zero). Keep the sign and the
     top payload bits, and force the top mantissa bit so the result is
     still a NaN after truncation. */
  if ((f32 & 0x7FFFFFFFu) > 0x7F800000u) {
    return (uint16_t)((f32 >> 16) | 0x0040u);
  }

  /* Round to nearest even: add 0x7FFF plus the lowest kept bit, so an exact
     tie only rounds up when the kept value is odd. */
  uint32_t rounded = f32 + 0x7FFFu + ((f32 >> 16) & 1u);
  return (uint16_t)(rounded >> 16);
}
3427
35- /* FP8 E5M2 format to Float conversion
28+ /* FP8 E5M2 format to Float conversion (C function)
3629 Format: 1 sign bit, 5 exponent bits, 2 mantissa bits */
37- CAMLprim value arrayjit_fp8_to_float ( value v_fp8 )
30+ static inline float fp8_to_float ( uint8_t fp8 )
3831{
39- CAMLparam1 (v_fp8 );
40- uint8_t fp8 = (uint8_t )Int_val (v_fp8 );
41-
4232 /* Handle zero */
4333 if (fp8 == 0 ) {
44- CAMLreturn ( caml_copy_double ( 0.0 )) ;
34+ return 0.0f ;
4535 }
4636
4737 uint32_t sign = (fp8 >> 7 ) & 1 ;
@@ -51,47 +41,43 @@ CAMLprim value arrayjit_fp8_to_float(value v_fp8)
5141 /* Handle special cases */
5242 if (exp == 0x1F ) { /* Infinity or NaN */
5343 if (mant == 0 ) {
54- float inf = sign ? - INFINITY : INFINITY ;
55- CAMLreturn (caml_copy_double ((double )inf ));
44+ return sign ? - INFINITY : INFINITY ;
5645 } else {
57- CAMLreturn ( caml_copy_double (( double ) NAN )) ;
46+ return NAN ;
5847 }
5948 }
6049
6150 /* Denormalized numbers */
6251 if (exp == 0 ) {
6352 float result = ldexpf ((float )mant / 4.0f , -14 );
6453 if (sign ) result = - result ;
65- CAMLreturn ( caml_copy_double (( double ) result )) ;
54+ return result ;
6655 }
6756
6857 /* Normalized numbers */
6958 float result = (1.0f + (float )mant * 0.25f ) * ldexpf (1.0f , (int )exp - 15 );
7059 if (sign ) result = - result ;
7160
72- CAMLreturn ( caml_copy_double (( double ) result )) ;
61+ return result ;
7362}
7463
75- /* Float to FP8 E5M2 conversion */
76- CAMLprim value arrayjit_float_to_fp8 ( value v_float )
64+ /* Float to FP8 E5M2 conversion (C function) */
65+ static inline uint8_t float_to_fp8 ( float f )
7766{
78- CAMLparam1 (v_float );
79- float f = (float )Double_val (v_float );
80-
8167 /* Handle zero */
8268 if (f == 0.0f ) {
83- CAMLreturn ( Val_int ( 0 )) ;
69+ return 0 ;
8470 }
8571
8672 uint32_t sign = (f < 0 ) ? 1 : 0 ;
8773 f = fabsf (f );
8874
8975 /* Handle special cases */
9076 if (isinf (f )) {
91- CAMLreturn ( Val_int (( sign << 7 ) | 0x7C )) ; /* Infinity: exp=0x1F, mant=0 */
77+ return ( sign << 7 ) | 0x7C ; /* Infinity: exp=0x1F, mant=0 */
9278 }
9379 if (isnan (f )) {
94- CAMLreturn ( Val_int (( sign << 7 ) | 0x7F )) ; /* NaN: exp=0x1F, mant!=0 */
80+ return ( sign << 7 ) | 0x7F ; /* NaN: exp=0x1F, mant!=0 */
9581 }
9682
9783 /* Get exponent and mantissa */
@@ -102,26 +88,63 @@ CAMLprim value arrayjit_float_to_fp8(value v_float)
10288 /* Clamp to representable range */
10389 if (exp < 0 ) {
10490 /* Underflow to zero */
105- CAMLreturn ( Val_int ( sign << 7 )) ;
91+ return sign << 7 ;
10692 }
10793 if (exp > 30 ) {
10894 /* Overflow to infinity */
109- CAMLreturn ( Val_int (( sign << 7 ) | 0x7C )) ;
95+ return ( sign << 7 ) | 0x7C ;
11096 }
11197
11298 /* Handle denormalized numbers */
11399 if (exp == 0 ) {
114100 float denorm_mant = f * ldexpf (1.0f , 14 ) * 4.0f ;
115101 uint32_t mant_bits = (uint32_t )(denorm_mant + 0.5f );
116102 if (mant_bits > 3 ) mant_bits = 3 ;
117- CAMLreturn ( Val_int (( sign << 7 ) | mant_bits )) ;
103+ return ( sign << 7 ) | mant_bits ;
118104 }
119105
120106 /* Normalized numbers: convert mantissa from [0.5, 1) to [0, 0.75] */
121107 mant_f = (mant_f - 0.5f ) * 4.0f ;
122108 uint32_t mant_bits = (uint32_t )(mant_f + 0.5f ); /* Round to nearest */
123109 if (mant_bits > 3 ) mant_bits = 3 ;
124110
125- uint8_t result = (uint8_t )((sign << 7 ) | ((exp & 0x1F ) << 2 ) | (mant_bits & 0x3 ));
126- CAMLreturn (Val_int (result ));
111+ return (uint8_t )((sign << 7 ) | ((exp & 0x1F ) << 2 ) | (mant_bits & 0x3 ));
112+ }
113+
114+ /* OCaml wrapper functions */
115+
116+ /* BFloat16 to Float conversion (OCaml wrapper) */
117+ CAMLprim value arrayjit_bfloat16_to_float (value v_bf16 )
118+ {
119+ CAMLparam1 (v_bf16 );
120+ uint16_t bf16 = (uint16_t )Int_val (v_bf16 );
121+ float result = bfloat16_to_float (bf16 );
122+ CAMLreturn (caml_copy_double ((double )result ));
123+ }
124+
125+ /* Float to BFloat16 conversion (OCaml wrapper) */
126+ CAMLprim value arrayjit_float_to_bfloat16 (value v_float )
127+ {
128+ CAMLparam1 (v_float );
129+ float f = (float )Double_val (v_float );
130+ uint16_t bf16 = float_to_bfloat16 (f );
131+ CAMLreturn (Val_int (bf16 ));
132+ }
133+
134+ /* FP8 E5M2 format to Float conversion (OCaml wrapper) */
135+ CAMLprim value arrayjit_fp8_to_float (value v_fp8 )
136+ {
137+ CAMLparam1 (v_fp8 );
138+ uint8_t fp8 = (uint8_t )Int_val (v_fp8 );
139+ float result = fp8_to_float (fp8 );
140+ CAMLreturn (caml_copy_double ((double )result ));
141+ }
142+
143+ /* Float to FP8 E5M2 conversion (OCaml wrapper) */
144+ CAMLprim value arrayjit_float_to_fp8 (value v_float )
145+ {
146+ CAMLparam1 (v_float );
147+ float f = (float )Double_val (v_float );
148+ uint8_t fp8 = float_to_fp8 (f );
149+ CAMLreturn (Val_int (fp8 ));
127150}
0 commit comments