Skip to content

Commit 61cf318

Browse files
committed
Proper implementation of arrayjit_copy_with_padding, by Grok
1 parent 72b981e commit 61cf318

File tree

2 files changed

+165
-77
lines changed

2 files changed

+165
-77
lines changed

arrayjit/lib/arrayjit_stubs.c

Lines changed: 163 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
#include <caml/alloc.h>
2+
#include <caml/fail.h>
23
#include <caml/memory.h>
34
#include <caml/mlvalues.h>
45
#include <caml/bigarray.h>
56
#include <math.h>
67
#include <stdint.h>
78
#include <string.h>
9+
#include <stdlib.h>
810

911
/* Pure C conversion functions for use in C backends */
1012

@@ -14,14 +16,14 @@ static inline float bfloat16_to_float(uint16_t bf16)
1416
/* BFloat16 format: 1 sign bit, 8 exponent bits, 7 mantissa bits
1517
To convert to float32, we shift left by 16 bits */
1618
uint32_t f32 = ((uint32_t)bf16) << 16;
17-
return *((float*)&f32);
19+
return *((float *)&f32);
1820
}
1921

2022
/* Float to BFloat16 conversion (C function) */
2123
static inline uint16_t float_to_bfloat16(float f)
2224
{
23-
uint32_t f32 = *((uint32_t*)&f);
24-
25+
uint32_t f32 = *((uint32_t *)&f);
26+
2527
/* Round to nearest even */
2628
uint32_t rounded = f32 + 0x7FFF + ((f32 >> 16) & 1);
2729
return (uint16_t)(rounded >> 16);
@@ -32,88 +34,105 @@ static inline uint16_t float_to_bfloat16(float f)
3234
/* Decode an 8-bit E5M2 value (1 sign bit, 5 exponent bits with bias 15,
   2 mantissa bits) into a float.

   - 0x00 decodes to +0.0f.
   - Exponent 0x1F decodes to +/-INFINITY when the mantissa is zero, NAN otherwise.
   - Exponent 0 is a subnormal: (mantissa / 4) * 2^-14.
   - Anything else is a normal number: (1 + mantissa / 4) * 2^(exponent - 15).

   The normal case assembles the IEEE-754 binary32 bit pattern directly, so
   the result is exact and no libm call is needed. */
static inline float fp8_to_float(uint8_t fp8)
{
  if (fp8 == 0)
  {
    return 0.0f;
  }

  const uint32_t negative = (fp8 >> 7) & 1;
  const uint32_t exponent = (fp8 >> 2) & 0x1F;
  const uint32_t fraction = fp8 & 0x3;

  if (exponent == 0x1F)
  {
    /* All-ones exponent: a non-zero mantissa is NaN, otherwise infinity. */
    if (fraction != 0)
    {
      return NAN;
    }
    return negative ? -INFINITY : INFINITY;
  }

  float magnitude;
  if (exponent == 0)
  {
    /* Subnormal: (fraction / 4) * 2^-14 == fraction * 2^-16. */
    magnitude = (float)fraction * 0x1p-16f;
  }
  else
  {
    /* Re-bias the exponent from 15 to 127 (+112) and place the two mantissa
       bits at the top of the 23-bit binary32 mantissa field. */
    const uint32_t bits = ((exponent + 112u) << 23) | (fraction << 21);
    memcpy(&magnitude, &bits, sizeof magnitude);
  }

  return negative ? -magnitude : magnitude;
}
6575

6676
/* Float to FP8 E5M2 conversion (C function) */
6777
/* Encode a float as an 8-bit E5M2 value (1 sign bit, 5 exponent bits with
   bias 15, 2 mantissa bits).

   Special values:
   - +0.0f and -0.0f both encode as 0x00.
   - NaN encodes as 0x7F (sign bit clear).
   - +/-infinity, and finite values too large for E5M2, encode as signed
     infinity (0x7C / 0xFC).

   Finite values are rounded to the nearest representable E5M2 value; exact
   ties round away from zero. A mantissa that rounds up past its two bits
   carries into the exponent, so 1.9375 encodes as 2.0, and magnitudes of at
   least 61440 become infinity. Magnitudes below 2^-14 are encoded as
   subnormals (multiples of 2^-16); anything that rounds to zero keeps its
   sign (0x00 / 0x80).

   The conversion works on the IEEE-754 binary32 bit pattern, read with
   memcpy to avoid strict-aliasing problems, so it is exact and needs no
   libm call. */
static inline uint8_t float_to_fp8(float f)
{
  uint32_t bits;
  memcpy(&bits, &f, sizeof bits);

  const uint32_t magnitude = bits & 0x7FFFFFFFu;
  const uint32_t sign = (bits >> 24) & 0x80u;

  if (magnitude == 0)
  {
    return 0;
  }
  if (magnitude > 0x7F800000u)
  {
    return 0x7F; /* NaN: exp=0x1F, mant!=0 */
  }
  if (magnitude == 0x7F800000u)
  {
    return (uint8_t)(sign | 0x7Cu); /* Infinity: exp=0x1F, mant=0 */
  }

  /* Re-bias the binary32 exponent (bias 127) to the E5M2 bias (15). */
  const int biased = (int)(magnitude >> 23) - 112;
  const uint32_t mantissa = magnitude & 0x7FFFFFu;

  if (biased > 30)
  {
    /* Overflow to infinity */
    return (uint8_t)(sign | 0x7Cu);
  }

  if (biased <= 0)
  {
    /* Subnormal range: the encoded value is m * 2^-16. The 24-bit
       significand (implicit leading one restored) is shifted right by
       22 - biased, with half of the last dropped bit added for rounding. */
    const int shift = 22 - biased;
    if (shift > 24)
    {
      /* Below half of the smallest subnormal: underflow to signed zero.
         This also keeps the shift count within the width of uint32_t. */
      return (uint8_t)sign;
    }
    const uint32_t significand = mantissa | 0x800000u;
    const uint32_t m = (significand + (1u << (shift - 1))) >> shift;
    /* m == 4 is exactly the bit pattern of the smallest normal (0x04). */
    return (uint8_t)(sign | m);
  }

  /* Normal range: keep the top two mantissa bits, rounding on the rest. */
  const uint32_t m = (mantissa + (1u << 20)) >> 21;

  /* Adding instead of OR-ing lets m == 4 carry into the exponent field;
     with biased == 30 the carry yields 0x7C, i.e. infinity. */
  return (uint8_t)(sign | (((uint32_t)biased << 2) + m));
}
115133

116-
/* Four 32-bit unsigned words grouped into a single value. */
typedef struct
{
  uint32_t v[4];
} uint4x32_t;
119138

@@ -141,7 +160,7 @@ CAMLprim value arrayjit_bfloat16_to_float(value v_bf16)
141160
}
142161

143162
/* Float to BFloat16 conversion (OCaml wrapper) */
144-
CAMLprim value arrayjit_float_to_bfloat16(value v_float)
163+
CAMLprim value arrayjit_float_to_bfloat16(value v_float)
145164
{
146165
CAMLparam1(v_float);
147166
float f = (float)Double_val(v_float);
@@ -167,45 +186,115 @@ CAMLprim value arrayjit_float_to_fp8(value v_float)
167186
CAMLreturn(Val_int(fp8));
168187
}
169188

170-
/* Efficient copying with padding support */
171-
CAMLprim value arrayjit_copy_with_padding(value v_source, value v_target,
172-
value v_source_dims, value v_padding)
189+
// TODO: a more efficient approach would involve computing strides once and using memcpy
190+
// for contiguous inner slices, but that adds complexity).
191+
CAMLprim value arrayjit_copy_with_padding(value v_source, value v_target, value v_padding)
173192
{
174-
CAMLparam4(v_source, v_target, v_source_dims, v_padding);
175-
176-
/* Get the bigarray data pointers */
177-
void* source_data = Caml_ba_data_val(v_source);
178-
void* target_data = Caml_ba_data_val(v_target);
179-
180-
/* Get element size in bytes */
181-
int kind = Caml_ba_kind_val(v_source);
182-
size_t elem_size;
183-
switch(kind) {
184-
case CAML_BA_FLOAT32: elem_size = 4; break;
185-
case CAML_BA_FLOAT64: elem_size = 8; break;
186-
case CAML_BA_SINT8:
187-
case CAML_BA_UINT8: elem_size = 1; break;
188-
case CAML_BA_SINT16:
189-
case CAML_BA_UINT16: elem_size = 2; break;
190-
case CAML_BA_INT32: elem_size = 4; break;
191-
case CAML_BA_COMPLEX64: elem_size = 16; break;
192-
default: elem_size = 8; break;
193+
CAMLparam3(v_source, v_target, v_padding);
194+
195+
struct caml_ba_array *source_ba = Caml_ba_array_val(v_source);
196+
struct caml_ba_array *target_ba = Caml_ba_array_val(v_target);
197+
int ndim = source_ba->num_dims;
198+
199+
if (ndim != target_ba->num_dims)
200+
{
201+
caml_failwith("Source and target must have the same number of dimensions");
202+
}
203+
204+
if (ndim == 0)
205+
{
206+
CAMLreturn(Val_unit);
193207
}
194-
195-
/* FIXME: For now, implement a simple flat copy */
196-
/* The proper padding-aware copy would require more complex logic */
197-
/* but this provides a foundation for optimization */
198-
struct caml_ba_array* source_ba = Caml_ba_array_val(v_source);
199-
intnat* source_dims_ba = source_ba->dim;
200-
int source_ndim = source_ba->num_dims;
201-
202-
size_t source_total = 1;
203-
for(int i = 0; i < source_ndim; i++) {
204-
source_total *= source_dims_ba[i];
208+
209+
void *source_data = Caml_ba_data_val(v_source);
210+
void *target_data = Caml_ba_data_val(v_target);
211+
212+
size_t elem_size = caml_ba_byte_size(source_ba);
213+
214+
// Use source dimensions directly from bigarray
215+
intnat *source_shape = source_ba->dim;
216+
217+
// Extract paddings
218+
if (Wosize_val(v_padding) != (uintnat)ndim)
219+
{
220+
caml_failwith("Padding array length mismatch");
221+
}
222+
intnat *left = malloc(ndim * sizeof(intnat));
223+
intnat *right = malloc(ndim * sizeof(intnat));
224+
if (left == NULL || right == NULL)
225+
caml_failwith("Malloc failed");
226+
for (int d = 0; d < ndim; d++)
227+
{
228+
value pad = Field(v_padding, d);
229+
left[d] = Long_val(Field(pad, 0));
230+
right[d] = Long_val(Field(pad, 1));
231+
if (left[d] < 0 || right[d] < 0)
232+
caml_failwith("Negative padding");
205233
}
206-
207-
/* FIXME: Simple memcpy for now - must implement proper padding */
208-
memcpy(target_data, source_data, source_total * elem_size);
209-
234+
235+
// Verify target dimensions match source + padding
236+
for (int d = 0; d < ndim; d++)
237+
{
238+
if (target_ba->dim[d] != source_shape[d] + left[d] + right[d])
239+
{
240+
caml_failwith("Target dimensions do not match source + padding");
241+
}
242+
}
243+
244+
// Multi-dimensional index loop
245+
intnat *indices = calloc(ndim, sizeof(intnat));
246+
if (indices == NULL)
247+
caml_failwith("Calloc failed");
248+
249+
while (1)
250+
{
251+
// Compute source flat offset
252+
intnat source_offset = 0;
253+
intnat s_stride = 1;
254+
for (int d = ndim - 1; d >= 0; d--)
255+
{
256+
source_offset += indices[d] * s_stride;
257+
s_stride *= source_shape[d];
258+
}
259+
260+
// Compute target flat offset with padding offset
261+
intnat target_offset = 0;
262+
intnat t_stride = 1;
263+
for (int d = ndim - 1; d >= 0; d--)
264+
{
265+
target_offset += (indices[d] + left[d]) * t_stride;
266+
t_stride *= target_ba->dim[d];
267+
}
268+
269+
// Copy the element
270+
memcpy((char *)target_data + target_offset * elem_size,
271+
(char *)source_data + source_offset * elem_size,
272+
elem_size);
273+
274+
// Increment indices (odometer-style)
275+
int carry = 1;
276+
for (int d = ndim - 1; d >= 0; d--)
277+
{
278+
if (carry == 0)
279+
break;
280+
indices[d] += carry;
281+
if (indices[d] < source_shape[d])
282+
{
283+
carry = 0;
284+
}
285+
else
286+
{
287+
indices[d] = 0;
288+
carry = 1;
289+
}
290+
}
291+
if (carry == 1)
292+
break; // Done
293+
}
294+
295+
free(indices);
296+
free(left);
297+
free(right);
298+
210299
CAMLreturn(Val_unit);
211-
}
300+
}

arrayjit/lib/ndarray.ml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -387,13 +387,12 @@ let hash_t nd = Nativeint.hash @@ to_native nd
387387

388388
(** C function declarations for efficient copying *)
389389
(* Binding to the C stub [arrayjit_copy_with_padding]. The arguments are, in order: the source
   bigarray, the target bigarray, and the padding array (the stub expects one entry per axis). *)
external copy_with_padding_c :
  ('a, 'b) bigarray -> ('a, 'b) bigarray -> axis_padding array -> unit
  = "arrayjit_copy_with_padding"
392392

393393
(* Copies [source] into [target] by applying the C stub [copy_with_padding_c], with [padding],
   to the pair of underlying bigarrays through [map2]. *)
let copy_with_padding ~source ~target ~padding =
  map2 { f2 = (fun src_arr dst_arr -> copy_with_padding_c src_arr dst_arr padding) } source target
399398

0 commit comments

Comments
 (0)