Skip to content

Commit 6634fe5

Browse files
committed
Full support for padding in ndarray.ml, by Claude Sonnet
1 parent 8044a09 commit 6634fe5

File tree

4 files changed

+137
-64
lines changed

4 files changed

+137
-64
lines changed

arrayjit/lib/ndarray.ml

Lines changed: 110 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -200,17 +200,36 @@ let get_voidptr_not_managed nd : unit Ctypes.ptr =
200200
(* let open Ctypes in coerce (ptr @@ typ_of_bigarray_kind @@ Bigarray.Genarray.kind arr) (ptr void)
201201
(bigarray_start genarray arr) *)
202202

203-
let set_from_float arr idx v =
203+
(** Helper function to adjust indices by adding left padding when padding is specified *)
204+
let adjust_idx_for_padding ?padding idx =
205+
match padding with
206+
| None -> idx
207+
| Some padding_arr ->
208+
Array.mapi idx ~f:(fun i dim_idx ->
209+
if i < Array.length padding_arr then
210+
dim_idx + padding_arr.(i).left
211+
else dim_idx)
212+
213+
(** Helper function to compute end index for iteration, respecting padding margins *)
214+
let compute_end_idx ?padding dims axis =
215+
match padding with
216+
| None -> dims.(axis) - 1
217+
| Some padding_arr when axis < Array.length padding_arr ->
218+
dims.(axis) - padding_arr.(axis).left - padding_arr.(axis).right - 1
219+
| Some _ -> dims.(axis) - 1
220+
221+
let set_from_float ?padding arr idx v =
222+
let adjusted_idx = adjust_idx_for_padding ?padding idx in
204223
match arr with
205-
| Byte_nd arr -> A.set arr idx @@ Char.of_int_exn @@ Int.of_float v
206-
| Uint16_nd arr -> A.set arr idx @@ Int.of_float v
207-
| Int32_nd arr -> A.set arr idx @@ Int32.of_float v
208-
| Uint4x32_nd arr -> A.set arr idx @@ Stdlib.Complex.{ re = v; im = 0.0 }
209-
| Half_nd arr -> A.set arr idx v
210-
| Bfloat16_nd arr -> A.set arr idx @@ float_to_bfloat16 v
211-
| Fp8_nd arr -> A.set arr idx @@ Char.of_int_exn @@ float_to_fp8 v
212-
| Single_nd arr -> A.set arr idx v
213-
| Double_nd arr -> A.set arr idx v
224+
| Byte_nd arr -> A.set arr adjusted_idx @@ Char.of_int_exn @@ Int.of_float v
225+
| Uint16_nd arr -> A.set arr adjusted_idx @@ Int.of_float v
226+
| Int32_nd arr -> A.set arr adjusted_idx @@ Int32.of_float v
227+
| Uint4x32_nd arr -> A.set arr adjusted_idx @@ Stdlib.Complex.{ re = v; im = 0.0 }
228+
| Half_nd arr -> A.set arr adjusted_idx v
229+
| Bfloat16_nd arr -> A.set arr adjusted_idx @@ float_to_bfloat16 v
230+
| Fp8_nd arr -> A.set arr adjusted_idx @@ Char.of_int_exn @@ float_to_fp8 v
231+
| Single_nd arr -> A.set arr adjusted_idx v
232+
| Double_nd arr -> A.set arr adjusted_idx v
214233

215234
let fill_from_float arr v =
216235
match arr with
@@ -224,13 +243,17 @@ let fill_from_float arr v =
224243
| Single_nd arr -> A.fill arr v
225244
| Double_nd arr -> A.fill arr v
226245

227-
let fold_bigarray arr ~init ~f =
246+
let fold_bigarray ?padding arr ~init ~f =
228247
let dims = A.dims arr in
229248
let accu = ref init in
230249
let rec cloop idx col =
231-
if col = Array.length idx then accu := f !accu idx @@ A.get arr idx
250+
if col = Array.length idx then
251+
let adjusted_idx = adjust_idx_for_padding ?padding idx in
252+
accu := f !accu idx @@ A.get arr adjusted_idx
232253
else
233-
for j = 0 to Int.pred dims.(col) do
254+
let end_idx = compute_end_idx ?padding dims col
255+
in
256+
for j = 0 to end_idx do
234257
idx.(col) <- j;
235258
cloop idx (Int.succ col)
236259
done
@@ -239,40 +262,41 @@ let fold_bigarray arr ~init ~f =
239262
cloop (Array.create ~len 0) 0;
240263
!accu
241264

242-
let fold_as_float ~init ~f arr =
265+
let fold_as_float ?padding ~init ~f arr =
243266
match arr with
244267
| Byte_nd arr ->
245-
fold_bigarray ~init ~f:(fun accu idx c -> f accu idx @@ Float.of_int @@ Char.to_int c) arr
246-
| Uint16_nd arr -> fold_bigarray ~init ~f:(fun accu idx v -> f accu idx @@ Float.of_int v) arr
247-
| Int32_nd arr -> fold_bigarray ~init ~f:(fun accu idx v -> f accu idx @@ Int32.to_float v) arr
248-
| Uint4x32_nd arr -> fold_bigarray ~init ~f:(fun accu idx c -> f accu idx c.Stdlib.Complex.re) arr
249-
| Half_nd arr -> fold_bigarray ~init ~f arr
268+
fold_bigarray ?padding ~init ~f:(fun accu idx c -> f accu idx @@ Float.of_int @@ Char.to_int c) arr
269+
| Uint16_nd arr -> fold_bigarray ?padding ~init ~f:(fun accu idx v -> f accu idx @@ Float.of_int v) arr
270+
| Int32_nd arr -> fold_bigarray ?padding ~init ~f:(fun accu idx v -> f accu idx @@ Int32.to_float v) arr
271+
| Uint4x32_nd arr -> fold_bigarray ?padding ~init ~f:(fun accu idx c -> f accu idx c.Stdlib.Complex.re) arr
272+
| Half_nd arr -> fold_bigarray ?padding ~init ~f arr
250273
| Bfloat16_nd arr ->
251-
fold_bigarray ~init ~f:(fun accu idx v -> f accu idx @@ bfloat16_to_float v) arr
274+
fold_bigarray ?padding ~init ~f:(fun accu idx v -> f accu idx @@ bfloat16_to_float v) arr
252275
| Fp8_nd arr ->
253-
fold_bigarray ~init ~f:(fun accu idx c -> f accu idx @@ fp8_to_float @@ Char.to_int c) arr
254-
| Single_nd arr -> fold_bigarray ~init ~f arr
255-
| Double_nd arr -> fold_bigarray ~init ~f arr
276+
fold_bigarray ?padding ~init ~f:(fun accu idx c -> f accu idx @@ fp8_to_float @@ Char.to_int c) arr
277+
| Single_nd arr -> fold_bigarray ?padding ~init ~f arr
278+
| Double_nd arr -> fold_bigarray ?padding ~init ~f arr
256279

257280
let size_in_bytes v =
258281
(* Cheating here because a 1-number Bigarray is the same size as an empty Bigarray: it's more informative
259282
to report the cases differently. *)
260283
let f arr = if Array.is_empty @@ A.dims arr then 0 else A.size_in_bytes arr in
261284
apply { f } v
262285

263-
let get_as_float arr idx =
286+
let get_as_float ?padding arr idx =
287+
let adjusted_idx = adjust_idx_for_padding ?padding idx in
264288
match arr with
265-
| Byte_nd arr -> Float.of_int @@ Char.to_int @@ A.get arr idx
266-
| Uint16_nd arr -> Float.of_int @@ A.get arr idx
267-
| Int32_nd arr -> Int32.to_float @@ A.get arr idx
268-
| Uint4x32_nd arr -> (A.get arr idx).Stdlib.Complex.re
269-
| Half_nd arr -> A.get arr idx
270-
| Bfloat16_nd arr -> bfloat16_to_float @@ A.get arr idx
271-
| Fp8_nd arr -> fp8_to_float @@ Char.to_int @@ A.get arr idx
272-
| Single_nd arr -> A.get arr idx
273-
| Double_nd arr -> A.get arr idx
274-
275-
let retrieve_2d_points ?from_axis ~xdim ~ydim arr =
289+
| Byte_nd arr -> Float.of_int @@ Char.to_int @@ A.get arr adjusted_idx
290+
| Uint16_nd arr -> Float.of_int @@ A.get arr adjusted_idx
291+
| Int32_nd arr -> Int32.to_float @@ A.get arr adjusted_idx
292+
| Uint4x32_nd arr -> (A.get arr adjusted_idx).Stdlib.Complex.re
293+
| Half_nd arr -> A.get arr adjusted_idx
294+
| Bfloat16_nd arr -> bfloat16_to_float @@ A.get arr adjusted_idx
295+
| Fp8_nd arr -> fp8_to_float @@ Char.to_int @@ A.get arr adjusted_idx
296+
| Single_nd arr -> A.get arr adjusted_idx
297+
| Double_nd arr -> A.get arr adjusted_idx
298+
299+
let retrieve_2d_points ?from_axis ?padding ~xdim ~ydim arr =
276300
let dims = dims arr in
277301
if Array.is_empty dims then [||]
278302
else
@@ -284,24 +308,26 @@ let retrieve_2d_points ?from_axis ~xdim ~ydim arr =
284308
if axis = n_axes then
285309
let x =
286310
idx.(from_axis) <- xdim;
287-
get_as_float arr idx
311+
get_as_float ?padding arr idx
288312
in
289313
let y =
290314
idx.(from_axis) <- ydim;
291-
get_as_float arr idx
315+
get_as_float ?padding arr idx
292316
in
293317
result := (x, y) :: !result
294318
else if axis = from_axis then iter (axis + 1)
295319
else
296-
for p = 0 to dims.(axis) - 1 do
320+
let end_idx = compute_end_idx ?padding dims axis
321+
in
322+
for p = 0 to end_idx do
297323
idx.(axis) <- p;
298324
iter (axis + 1)
299325
done
300326
in
301327
iter 0;
302328
Array.of_list_rev !result
303329

304-
let retrieve_1d_points ?from_axis ~xdim arr =
330+
let retrieve_1d_points ?from_axis ?padding ~xdim arr =
305331
let dims = dims arr in
306332
if Array.is_empty dims then [||]
307333
else
@@ -313,20 +339,22 @@ let retrieve_1d_points ?from_axis ~xdim arr =
313339
if axis = n_axes then
314340
let x =
315341
idx.(from_axis) <- xdim;
316-
get_as_float arr idx
342+
get_as_float ?padding arr idx
317343
in
318344
result := x :: !result
319345
else if axis = from_axis then iter (axis + 1)
320346
else
321-
for p = 0 to dims.(axis) - 1 do
347+
let end_idx = compute_end_idx ?padding dims axis
348+
in
349+
for p = 0 to end_idx do
322350
idx.(axis) <- p;
323351
iter (axis + 1)
324352
done
325353
in
326354
iter 0;
327355
Array.of_list_rev !result
328356

329-
let retrieve_flat_values arr =
357+
let retrieve_flat_values ?padding arr =
330358
let dims = dims arr in
331359
if Array.is_empty dims then [||]
332360
else
@@ -335,18 +363,40 @@ let retrieve_flat_values arr =
335363
let idx = Array.create ~len:n_axes 0 in
336364
let rec iter axis =
337365
if axis = n_axes then
338-
let x = get_as_float arr idx in
366+
let x = get_as_float ?padding arr idx in
339367
result := x :: !result
340368
else
341-
for p = 0 to dims.(axis) - 1 do
369+
let end_idx = compute_end_idx ?padding dims axis
370+
in
371+
for p = 0 to end_idx do
342372
idx.(axis) <- p;
343373
iter (axis + 1)
344374
done
345375
in
346376
iter 0;
347377
Array.of_list_rev !result
348378

349-
let set_flat_values _arr _values = ()
379+
let set_flat_values ?padding arr values =
380+
let dims = dims arr in
381+
if not (Array.is_empty dims) then
382+
let n_axes = Array.length dims in
383+
let idx = Array.create ~len:n_axes 0 in
384+
let values_idx = ref 0 in
385+
let rec iter axis =
386+
if axis = n_axes then (
387+
if !values_idx < Array.length values then (
388+
set_from_float ?padding arr idx values.(!values_idx);
389+
Int.incr values_idx
390+
))
391+
else
392+
let end_idx = compute_end_idx ?padding dims axis
393+
in
394+
for p = 0 to end_idx do
395+
idx.(axis) <- p;
396+
iter (axis + 1)
397+
done
398+
in
399+
iter 0
350400

351401
let c_ptr_to_string nd =
352402
let prec = get_prec nd in
@@ -410,12 +460,25 @@ let get_used_memory () = Atomic.get used_memory
410460

411461
(** Dimensions to string, ["x"]-separated, e.g. 1x2x3 for batch dims 1, input dims 3, output dims 2.
412462
Outputs ["-"] for empty dimensions. *)
413-
let int_dims_to_string ?(with_axis_numbers = false) dims =
463+
let int_dims_to_string ?(with_axis_numbers = false) ?padding dims =
414464
if Array.is_empty dims then "-"
415465
else if with_axis_numbers then
416466
String.concat_array ~sep:" x "
417467
@@ Array.mapi dims ~f:(fun d s -> Int.to_string d ^ ":" ^ Int.to_string s)
418-
else String.concat_array ~sep:"x" @@ Array.map dims ~f:Int.to_string
468+
else
469+
let dim_strings = Array.mapi dims ~f:(fun i dim ->
470+
match padding with
471+
| None -> Int.to_string dim
472+
| Some padding_arr when i < Array.length padding_arr ->
473+
let unpadded_dim = dim - padding_arr.(i).left - padding_arr.(i).right in
474+
let total_padding = padding_arr.(i).left + padding_arr.(i).right in
475+
if total_padding > 0 then
476+
Int.to_string unpadded_dim ^ "+" ^ Int.to_string total_padding
477+
else
478+
Int.to_string dim
479+
| Some _ -> Int.to_string dim
480+
) in
481+
String.concat_array ~sep:"x" dim_strings
419482

420483
(** Logs information about the array on the default ppx_minidebug runtime, if
421484
[from_log_level > Utils.settings.with_log_level]. *)

arrayjit/lib/tnode.ml

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -458,7 +458,9 @@ let has a = match a.array with (lazy (Some _)) -> true | _ -> false
458458

459459
let dims_to_string ?(with_axis_numbers = false) arr =
460460
let dims_s =
461-
if Lazy.is_val arr.dims then Nd.int_dims_to_string ~with_axis_numbers @@ Lazy.force arr.dims
461+
if Lazy.is_val arr.dims then
462+
let padding = Option.map ~f:fst (Lazy.force arr.padding) in
463+
Nd.int_dims_to_string ~with_axis_numbers ?padding @@ Lazy.force arr.dims
462464
else "<not-in-yet>"
463465
in
464466
Ops.prec_string (Lazy.force arr.prec) ^ " prec " ^ dims_s
@@ -723,29 +725,35 @@ let do_write tn =
723725

724726
let points_1d ?from_axis ~xdim tn =
725727
do_read tn;
726-
Option.value_map ~default:[||] ~f:(fun arr -> Nd.retrieve_1d_points ?from_axis ~xdim arr)
728+
let padding = Option.map ~f:fst (Lazy.force tn.padding) in
729+
Option.value_map ~default:[||] ~f:(fun arr -> Nd.retrieve_1d_points ?from_axis ?padding ~xdim arr)
727730
@@ Lazy.force tn.array
728731

729732
let points_2d ?from_axis ~xdim ~ydim tn =
730733
do_read tn;
731-
Option.value_map ~default:[||] ~f:(fun arr -> Nd.retrieve_2d_points ?from_axis ~xdim ~ydim arr)
734+
let padding = Option.map ~f:fst (Lazy.force tn.padding) in
735+
Option.value_map ~default:[||] ~f:(fun arr -> Nd.retrieve_2d_points ?from_axis ?padding ~xdim ~ydim arr)
732736
@@ Lazy.force tn.array
733737

734738
let set_value tn =
735739
do_write tn;
736-
Nd.set_from_float @@ Option.value_exn ~here:[%here] @@ Lazy.force tn.array
740+
let padding = Option.map ~f:fst (Lazy.force tn.padding) in
741+
Nd.set_from_float ?padding @@ Option.value_exn ~here:[%here] @@ Lazy.force tn.array
737742

738743
let get_value tn =
739744
do_read tn;
740-
Nd.get_as_float @@ Option.value_exn ~here:[%here] @@ Lazy.force tn.array
745+
let padding = Option.map ~f:fst (Lazy.force tn.padding) in
746+
Nd.get_as_float ?padding @@ Option.value_exn ~here:[%here] @@ Lazy.force tn.array
741747

742748
let set_values tn values =
743749
do_write tn;
744-
Nd.(set_flat_values values @@ Option.value_exn ~here:[%here] @@ Lazy.force tn.array)
750+
let padding = Option.map ~f:fst (Lazy.force tn.padding) in
751+
Nd.(set_flat_values ?padding (Option.value_exn ~here:[%here] @@ Lazy.force tn.array) values)
745752

746753
let get_values tn =
747754
do_read tn;
748-
Nd.(retrieve_flat_values @@ Option.value_exn ~here:[%here] @@ Lazy.force tn.array)
755+
let padding = Option.map ~f:fst (Lazy.force tn.padding) in
756+
Nd.(retrieve_flat_values ?padding @@ Option.value_exn ~here:[%here] @@ Lazy.force tn.array)
749757

750758
let print_accessible_headers () =
751759
Stdio.printf "Tnode: collecting accessible arrays...%!\n";

arrayjit/test/test_numerical_types.expected

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -13,11 +13,11 @@ Testing BFloat16 conversions:
1313
-inf -> 0xff80 -> -inf
1414

1515
BFloat16 array values:
16-
[0] = 1.000000
17-
[1] = 2.000000
18-
[2] = 3.140625
19-
[3] = -1.500000
20-
[4] = 0.125000
16+
[0] = 0.000000
17+
[1] = 1.000000
18+
[2] = -1.000000
19+
[3] = 3.140625
20+
[4] = 0.000999
2121
[5] = 1000.000000
2222

2323

@@ -32,10 +32,10 @@ Testing FP8 conversions:
3232
-0.250000 -> 0xb4 -> -0.250000
3333

3434
FP8 array values:
35-
[0] = 1.000000
36-
[1] = 0.500000
37-
[2] = 2.000000
38-
[3] = -1.000000
35+
[0] = 0.000000
36+
[1] = 1.000000
37+
[2] = -1.000000
38+
[3] = 0.500000
3939

4040

4141
Testing padding functionality:

arrayjit/test/test_numerical_types.ml

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@ let test_bfloat16_conversions () =
1515

1616
(* Test round-trip through ndarray *)
1717
let arr = Ndarray.create_array ~debug:"test" Ops.bfloat16 ~dims:[| 3; 2 |] ~padding:None in
18+
Ndarray.set_flat_values arr (Array.of_list test_values);
1819

1920
Stdio.printf "\nBFloat16 array values:\n";
2021
let flat_values = Ndarray.retrieve_flat_values arr in
@@ -33,6 +34,7 @@ let test_fp8_conversions () =
3334

3435
(* Test round-trip through ndarray *)
3536
let arr = Ndarray.create_array ~debug:"test" Ops.fp8 ~dims:[| 2; 2 |] ~padding:None in
37+
Ndarray.set_flat_values arr (Array.of_list test_values);
3638

3739
Stdio.printf "\nFP8 array values:\n";
3840
let flat_values = Ndarray.retrieve_flat_values arr in
@@ -53,7 +55,7 @@ let test_padding () =
5355
Ndarray.create_array ~debug:"padded_test" Ops.single ~dims:padded_dims
5456
~padding:(Some (padding_config, padding_value))
5557
in
56-
58+
Ndarray.set_flat_values ~padding:padding_config arr [| 1.0; 2.0; 3.0; 4.0; 5.0; 6.0 |];
5759
Stdio.printf "Padded array (dims 4x6, unpadded region 2x3):\n";
5860
let dims = Ndarray.dims arr in
5961
for i = 0 to dims.(0) - 1 do

0 commit comments

Comments
 (0)