11#include <caml/alloc.h>
2+ #include <caml/fail.h>
23#include <caml/memory.h>
34#include <caml/mlvalues.h>
45#include <caml/bigarray.h>
56#include <math.h>
67#include <stdint.h>
78#include <string.h>
9+ #include <stdlib.h>
810
911/* Pure C conversion functions for use in C backends */
1012
@@ -14,14 +16,14 @@ static inline float bfloat16_to_float(uint16_t bf16)
1416 /* BFloat16 format: 1 sign bit, 8 exponent bits, 7 mantissa bits
1517 To convert to float32, we shift left by 16 bits */
1618 uint32_t f32 = ((uint32_t )bf16 ) << 16 ;
17- return * ((float * )& f32 );
19+ return * ((float * )& f32 );
1820}
1921
2022/* Float to BFloat16 conversion (C function) */
2123static inline uint16_t float_to_bfloat16 (float f )
2224{
23- uint32_t f32 = * ((uint32_t * )& f );
24-
25+ uint32_t f32 = * ((uint32_t * )& f );
26+
2527 /* Round to nearest even */
2628 uint32_t rounded = f32 + 0x7FFF + ((f32 >> 16 ) & 1 );
2729 return (uint16_t )(rounded >> 16 );
@@ -32,88 +34,105 @@ static inline uint16_t float_to_bfloat16(float f)
3234static inline float fp8_to_float (uint8_t fp8 )
3335{
3436 /* Handle zero */
35- if (fp8 == 0 ) {
37+ if (fp8 == 0 )
38+ {
3639 return 0.0f ;
3740 }
38-
41+
3942 uint32_t sign = (fp8 >> 7 ) & 1 ;
4043 uint32_t exp = (fp8 >> 2 ) & 0x1F ;
4144 uint32_t mant = fp8 & 0x3 ;
42-
45+
4346 /* Handle special cases */
44- if (exp == 0x1F ) { /* Infinity or NaN */
45- if (mant == 0 ) {
47+ if (exp == 0x1F )
48+ { /* Infinity or NaN */
49+ if (mant == 0 )
50+ {
4651 return sign ? - INFINITY : INFINITY ;
47- } else {
52+ }
53+ else
54+ {
4855 return NAN ;
4956 }
5057 }
51-
58+
5259 /* Denormalized numbers */
53- if (exp == 0 ) {
60+ if (exp == 0 )
61+ {
5462 float result = ldexpf ((float )mant / 4.0f , -14 );
55- if (sign ) result = - result ;
63+ if (sign )
64+ result = - result ;
5665 return result ;
5766 }
58-
67+
5968 /* Normalized numbers */
6069 float result = (1.0f + (float )mant * 0.25f ) * ldexpf (1.0f , (int )exp - 15 );
61- if (sign ) result = - result ;
62-
70+ if (sign )
71+ result = - result ;
72+
6373 return result ;
6474}
6575
6676/* Float to FP8 E5M2 conversion (C function) */
6777static inline uint8_t float_to_fp8 (float f )
6878{
6979 /* Handle zero */
70- if (f == 0.0f ) {
80+ if (f == 0.0f )
81+ {
7182 return 0 ;
7283 }
73-
84+
7485 uint32_t sign = (f < 0 ) ? 1 : 0 ;
7586 f = fabsf (f );
76-
87+
7788 /* Handle special cases */
78- if (isinf (f )) {
79- return (sign << 7 ) | 0x7C ; /* Infinity: exp=0x1F, mant=0 */
89+ if (isinf (f ))
90+ {
91+ return (sign << 7 ) | 0x7C ; /* Infinity: exp=0x1F, mant=0 */
8092 }
81- if (isnan (f )) {
82- return (sign << 7 ) | 0x7F ; /* NaN: exp=0x1F, mant!=0 */
93+ if (isnan (f ))
94+ {
95+ return (sign << 7 ) | 0x7F ; /* NaN: exp=0x1F, mant!=0 */
8396 }
84-
97+
8598 /* Get exponent and mantissa */
8699 int exp_val ;
87100 float mant_f = frexpf (f , & exp_val );
88- int exp = exp_val + 14 ; /* Bias is 15, but frexp gives us mantissa in [0.5, 1) */
89-
101+ int exp = exp_val + 14 ; /* Bias is 15, but frexp gives us mantissa in [0.5, 1) */
102+
90103 /* Clamp to representable range */
91- if (exp < 0 ) {
104+ if (exp < 0 )
105+ {
92106 /* Underflow to zero */
93107 return sign << 7 ;
94108 }
95- if (exp > 30 ) {
109+ if (exp > 30 )
110+ {
96111 /* Overflow to infinity */
97112 return (sign << 7 ) | 0x7C ;
98113 }
99-
114+
100115 /* Handle denormalized numbers */
101- if (exp == 0 ) {
116+ if (exp == 0 )
117+ {
102118 float denorm_mant = f * ldexpf (1.0f , 14 ) * 4.0f ;
103119 uint32_t mant_bits = (uint32_t )(denorm_mant + 0.5f );
104- if (mant_bits > 3 ) mant_bits = 3 ;
120+ if (mant_bits > 3 )
121+ mant_bits = 3 ;
105122 return (sign << 7 ) | mant_bits ;
106123 }
107-
124+
108125 /* Normalized numbers: convert mantissa from [0.5, 1) to [0, 0.75] */
109126 mant_f = (mant_f - 0.5f ) * 4.0f ;
110- uint32_t mant_bits = (uint32_t )(mant_f + 0.5f ); /* Round to nearest */
111- if (mant_bits > 3 ) mant_bits = 3 ;
112-
127+ uint32_t mant_bits = (uint32_t )(mant_f + 0.5f ); /* Round to nearest */
128+ if (mant_bits > 3 )
129+ mant_bits = 3 ;
130+
113131 return (uint8_t )((sign << 7 ) | ((exp & 0x1F ) << 2 ) | (mant_bits & 0x3 ));
114132}
115133
116- typedef struct {
134+ typedef struct
135+ {
117136 uint32_t v [4 ];
118137} uint4x32_t ;
119138
@@ -141,7 +160,7 @@ CAMLprim value arrayjit_bfloat16_to_float(value v_bf16)
141160}
142161
143162/* Float to BFloat16 conversion (OCaml wrapper) */
144- CAMLprim value arrayjit_float_to_bfloat16 (value v_float )
163+ CAMLprim value arrayjit_float_to_bfloat16 (value v_float )
145164{
146165 CAMLparam1 (v_float );
147166 float f = (float )Double_val (v_float );
@@ -167,45 +186,115 @@ CAMLprim value arrayjit_float_to_fp8(value v_float)
167186 CAMLreturn (Val_int (fp8 ));
168187}
169188
170- /* Efficient copying with padding support */
171- CAMLprim value arrayjit_copy_with_padding ( value v_source , value v_target ,
172- value v_source_dims , value v_padding )
189+ // TODO: a more efficient approach would involve computing strides once and using memcpy
190+ // for contiguous inner slices, but that adds complexity).
191+ CAMLprim value arrayjit_copy_with_padding ( value v_source , value v_target , value v_padding )
173192{
174- CAMLparam4 (v_source , v_target , v_source_dims , v_padding );
175-
176- /* Get the bigarray data pointers */
177- void * source_data = Caml_ba_data_val (v_source );
178- void * target_data = Caml_ba_data_val (v_target );
179-
180- /* Get element size in bytes */
181- int kind = Caml_ba_kind_val (v_source );
182- size_t elem_size ;
183- switch (kind ) {
184- case CAML_BA_FLOAT32 : elem_size = 4 ; break ;
185- case CAML_BA_FLOAT64 : elem_size = 8 ; break ;
186- case CAML_BA_SINT8 :
187- case CAML_BA_UINT8 : elem_size = 1 ; break ;
188- case CAML_BA_SINT16 :
189- case CAML_BA_UINT16 : elem_size = 2 ; break ;
190- case CAML_BA_INT32 : elem_size = 4 ; break ;
191- case CAML_BA_COMPLEX64 : elem_size = 16 ; break ;
192- default : elem_size = 8 ; break ;
193+ CAMLparam3 (v_source , v_target , v_padding );
194+
195+ struct caml_ba_array * source_ba = Caml_ba_array_val (v_source );
196+ struct caml_ba_array * target_ba = Caml_ba_array_val (v_target );
197+ int ndim = source_ba -> num_dims ;
198+
199+ if (ndim != target_ba -> num_dims )
200+ {
201+ caml_failwith ("Source and target must have the same number of dimensions" );
202+ }
203+
204+ if (ndim == 0 )
205+ {
206+ CAMLreturn (Val_unit );
193207 }
194-
195- /* FIXME: For now, implement a simple flat copy */
196- /* The proper padding-aware copy would require more complex logic */
197- /* but this provides a foundation for optimization */
198- struct caml_ba_array * source_ba = Caml_ba_array_val (v_source );
199- intnat * source_dims_ba = source_ba -> dim ;
200- int source_ndim = source_ba -> num_dims ;
201-
202- size_t source_total = 1 ;
203- for (int i = 0 ; i < source_ndim ; i ++ ) {
204- source_total *= source_dims_ba [i ];
208+
209+ void * source_data = Caml_ba_data_val (v_source );
210+ void * target_data = Caml_ba_data_val (v_target );
211+
212+ size_t elem_size = caml_ba_byte_size (source_ba );
213+
214+ // Use source dimensions directly from bigarray
215+ intnat * source_shape = source_ba -> dim ;
216+
217+ // Extract paddings
218+ if (Wosize_val (v_padding ) != (uintnat )ndim )
219+ {
220+ caml_failwith ("Padding array length mismatch" );
221+ }
222+ intnat * left = malloc (ndim * sizeof (intnat ));
223+ intnat * right = malloc (ndim * sizeof (intnat ));
224+ if (left == NULL || right == NULL )
225+ caml_failwith ("Malloc failed" );
226+ for (int d = 0 ; d < ndim ; d ++ )
227+ {
228+ value pad = Field (v_padding , d );
229+ left [d ] = Long_val (Field (pad , 0 ));
230+ right [d ] = Long_val (Field (pad , 1 ));
231+ if (left [d ] < 0 || right [d ] < 0 )
232+ caml_failwith ("Negative padding" );
205233 }
206-
207- /* FIXME: Simple memcpy for now - must implement proper padding */
208- memcpy (target_data , source_data , source_total * elem_size );
209-
234+
235+ // Verify target dimensions match source + padding
236+ for (int d = 0 ; d < ndim ; d ++ )
237+ {
238+ if (target_ba -> dim [d ] != source_shape [d ] + left [d ] + right [d ])
239+ {
240+ caml_failwith ("Target dimensions do not match source + padding" );
241+ }
242+ }
243+
244+ // Multi-dimensional index loop
245+ intnat * indices = calloc (ndim , sizeof (intnat ));
246+ if (indices == NULL )
247+ caml_failwith ("Calloc failed" );
248+
249+ while (1 )
250+ {
251+ // Compute source flat offset
252+ intnat source_offset = 0 ;
253+ intnat s_stride = 1 ;
254+ for (int d = ndim - 1 ; d >= 0 ; d -- )
255+ {
256+ source_offset += indices [d ] * s_stride ;
257+ s_stride *= source_shape [d ];
258+ }
259+
260+ // Compute target flat offset with padding offset
261+ intnat target_offset = 0 ;
262+ intnat t_stride = 1 ;
263+ for (int d = ndim - 1 ; d >= 0 ; d -- )
264+ {
265+ target_offset += (indices [d ] + left [d ]) * t_stride ;
266+ t_stride *= target_ba -> dim [d ];
267+ }
268+
269+ // Copy the element
270+ memcpy ((char * )target_data + target_offset * elem_size ,
271+ (char * )source_data + source_offset * elem_size ,
272+ elem_size );
273+
274+ // Increment indices (odometer-style)
275+ int carry = 1 ;
276+ for (int d = ndim - 1 ; d >= 0 ; d -- )
277+ {
278+ if (carry == 0 )
279+ break ;
280+ indices [d ] += carry ;
281+ if (indices [d ] < source_shape [d ])
282+ {
283+ carry = 0 ;
284+ }
285+ else
286+ {
287+ indices [d ] = 0 ;
288+ carry = 1 ;
289+ }
290+ }
291+ if (carry == 1 )
292+ break ; // Done
293+ }
294+
295+ free (indices );
296+ free (left );
297+ free (right );
298+
210299 CAMLreturn (Val_unit );
211- }
300+ }
0 commit comments