@@ -200,17 +200,36 @@ let get_voidptr_not_managed nd : unit Ctypes.ptr =
200200(* let open Ctypes in coerce (ptr @@ typ_of_bigarray_kind @@ Bigarray.Genarray.kind arr) (ptr void)
201201 (bigarray_start genarray arr) *)
202202
203- let set_from_float arr idx v =
203+ (* * Helper function to adjust indices by adding left padding when padding is specified *)
204+ let adjust_idx_for_padding ?padding idx =
205+ match padding with
206+ | None -> idx
207+ | Some padding_arr ->
208+ Array. mapi idx ~f: (fun i dim_idx ->
209+ if i < Array. length padding_arr then
210+ dim_idx + padding_arr.(i).left
211+ else dim_idx)
212+
213+ (* * Helper function to compute end index for iteration, respecting padding margins *)
214+ let compute_end_idx ?padding dims axis =
215+ match padding with
216+ | None -> dims.(axis) - 1
217+ | Some padding_arr when axis < Array. length padding_arr ->
218+ dims.(axis) - padding_arr.(axis).left - padding_arr.(axis).right - 1
219+ | Some _ -> dims.(axis) - 1
220+
221+ let set_from_float ?padding arr idx v =
222+ let adjusted_idx = adjust_idx_for_padding ?padding idx in
204223 match arr with
205- | Byte_nd arr -> A. set arr idx @@ Char. of_int_exn @@ Int. of_float v
206- | Uint16_nd arr -> A. set arr idx @@ Int. of_float v
207- | Int32_nd arr -> A. set arr idx @@ Int32. of_float v
208- | Uint4x32_nd arr -> A. set arr idx @@ Stdlib.Complex. { re = v; im = 0.0 }
209- | Half_nd arr -> A. set arr idx v
210- | Bfloat16_nd arr -> A. set arr idx @@ float_to_bfloat16 v
211- | Fp8_nd arr -> A. set arr idx @@ Char. of_int_exn @@ float_to_fp8 v
212- | Single_nd arr -> A. set arr idx v
213- | Double_nd arr -> A. set arr idx v
224+ | Byte_nd arr -> A. set arr adjusted_idx @@ Char. of_int_exn @@ Int. of_float v
225+ | Uint16_nd arr -> A. set arr adjusted_idx @@ Int. of_float v
226+ | Int32_nd arr -> A. set arr adjusted_idx @@ Int32. of_float v
227+ | Uint4x32_nd arr -> A. set arr adjusted_idx @@ Stdlib.Complex. { re = v; im = 0.0 }
228+ | Half_nd arr -> A. set arr adjusted_idx v
229+ | Bfloat16_nd arr -> A. set arr adjusted_idx @@ float_to_bfloat16 v
230+ | Fp8_nd arr -> A. set arr adjusted_idx @@ Char. of_int_exn @@ float_to_fp8 v
231+ | Single_nd arr -> A. set arr adjusted_idx v
232+ | Double_nd arr -> A. set arr adjusted_idx v
214233
215234let fill_from_float arr v =
216235 match arr with
@@ -224,13 +243,17 @@ let fill_from_float arr v =
224243 | Single_nd arr -> A. fill arr v
225244 | Double_nd arr -> A. fill arr v
226245
227- let fold_bigarray arr ~init ~f =
246+ let fold_bigarray ? padding arr ~init ~f =
228247 let dims = A. dims arr in
229248 let accu = ref init in
230249 let rec cloop idx col =
231- if col = Array. length idx then accu := f ! accu idx @@ A. get arr idx
250+ if col = Array. length idx then
251+ let adjusted_idx = adjust_idx_for_padding ?padding idx in
252+ accu := f ! accu idx @@ A. get arr adjusted_idx
232253 else
233- for j = 0 to Int. pred dims.(col) do
254+ let end_idx = compute_end_idx ?padding dims col
255+ in
256+ for j = 0 to end_idx do
234257 idx.(col) < - j;
235258 cloop idx (Int. succ col)
236259 done
@@ -239,40 +262,41 @@ let fold_bigarray arr ~init ~f =
239262 cloop (Array. create ~len 0 ) 0 ;
240263 ! accu
241264
242- let fold_as_float ~init ~f arr =
265+ let fold_as_float ? padding ~init ~f arr =
243266 match arr with
244267 | Byte_nd arr ->
245- fold_bigarray ~init ~f: (fun accu idx c -> f accu idx @@ Float. of_int @@ Char. to_int c) arr
246- | Uint16_nd arr -> fold_bigarray ~init ~f: (fun accu idx v -> f accu idx @@ Float. of_int v) arr
247- | Int32_nd arr -> fold_bigarray ~init ~f: (fun accu idx v -> f accu idx @@ Int32. to_float v) arr
248- | Uint4x32_nd arr -> fold_bigarray ~init ~f: (fun accu idx c -> f accu idx c.Stdlib.Complex. re) arr
249- | Half_nd arr -> fold_bigarray ~init ~f arr
268+ fold_bigarray ?padding ~init ~f: (fun accu idx c -> f accu idx @@ Float. of_int @@ Char. to_int c) arr
269+ | Uint16_nd arr -> fold_bigarray ?padding ~init ~f: (fun accu idx v -> f accu idx @@ Float. of_int v) arr
270+ | Int32_nd arr -> fold_bigarray ?padding ~init ~f: (fun accu idx v -> f accu idx @@ Int32. to_float v) arr
271+ | Uint4x32_nd arr -> fold_bigarray ?padding ~init ~f: (fun accu idx c -> f accu idx c.Stdlib.Complex. re) arr
272+ | Half_nd arr -> fold_bigarray ?padding ~init ~f arr
250273 | Bfloat16_nd arr ->
251- fold_bigarray ~init ~f: (fun accu idx v -> f accu idx @@ bfloat16_to_float v) arr
274+ fold_bigarray ?padding ~init ~f: (fun accu idx v -> f accu idx @@ bfloat16_to_float v) arr
252275 | Fp8_nd arr ->
253- fold_bigarray ~init ~f: (fun accu idx c -> f accu idx @@ fp8_to_float @@ Char. to_int c) arr
254- | Single_nd arr -> fold_bigarray ~init ~f arr
255- | Double_nd arr -> fold_bigarray ~init ~f arr
276+ fold_bigarray ?padding ~init ~f: (fun accu idx c -> f accu idx @@ fp8_to_float @@ Char. to_int c) arr
277+ | Single_nd arr -> fold_bigarray ?padding ~init ~f arr
278+ | Double_nd arr -> fold_bigarray ?padding ~init ~f arr
256279
257280let size_in_bytes v =
258281 (* Cheating here because 1 number Bigarray is same size as empty Bigarray: it's more informative
259282 to report the cases differently. *)
260283 let f arr = if Array. is_empty @@ A. dims arr then 0 else A. size_in_bytes arr in
261284 apply { f } v
262285
263- let get_as_float arr idx =
286+ let get_as_float ?padding arr idx =
287+ let adjusted_idx = adjust_idx_for_padding ?padding idx in
264288 match arr with
265- | Byte_nd arr -> Float. of_int @@ Char. to_int @@ A. get arr idx
266- | Uint16_nd arr -> Float. of_int @@ A. get arr idx
267- | Int32_nd arr -> Int32. to_float @@ A. get arr idx
268- | Uint4x32_nd arr -> (A. get arr idx ).Stdlib.Complex. re
269- | Half_nd arr -> A. get arr idx
270- | Bfloat16_nd arr -> bfloat16_to_float @@ A. get arr idx
271- | Fp8_nd arr -> fp8_to_float @@ Char. to_int @@ A. get arr idx
272- | Single_nd arr -> A. get arr idx
273- | Double_nd arr -> A. get arr idx
274-
275- let retrieve_2d_points ?from_axis ~xdim ~ydim arr =
289+ | Byte_nd arr -> Float. of_int @@ Char. to_int @@ A. get arr adjusted_idx
290+ | Uint16_nd arr -> Float. of_int @@ A. get arr adjusted_idx
291+ | Int32_nd arr -> Int32. to_float @@ A. get arr adjusted_idx
292+ | Uint4x32_nd arr -> (A. get arr adjusted_idx ).Stdlib.Complex. re
293+ | Half_nd arr -> A. get arr adjusted_idx
294+ | Bfloat16_nd arr -> bfloat16_to_float @@ A. get arr adjusted_idx
295+ | Fp8_nd arr -> fp8_to_float @@ Char. to_int @@ A. get arr adjusted_idx
296+ | Single_nd arr -> A. get arr adjusted_idx
297+ | Double_nd arr -> A. get arr adjusted_idx
298+
299+ let retrieve_2d_points ?from_axis ? padding ~xdim ~ydim arr =
276300 let dims = dims arr in
277301 if Array. is_empty dims then [||]
278302 else
@@ -284,24 +308,26 @@ let retrieve_2d_points ?from_axis ~xdim ~ydim arr =
284308 if axis = n_axes then
285309 let x =
286310 idx.(from_axis) < - xdim;
287- get_as_float arr idx
311+ get_as_float ?padding arr idx
288312 in
289313 let y =
290314 idx.(from_axis) < - ydim;
291- get_as_float arr idx
315+ get_as_float ?padding arr idx
292316 in
293317 result := (x, y) :: ! result
294318 else if axis = from_axis then iter (axis + 1 )
295319 else
296- for p = 0 to dims.(axis) - 1 do
320+ let end_idx = compute_end_idx ?padding dims axis
321+ in
322+ for p = 0 to end_idx do
297323 idx.(axis) < - p;
298324 iter (axis + 1 )
299325 done
300326 in
301327 iter 0 ;
302328 Array. of_list_rev ! result
303329
304- let retrieve_1d_points ?from_axis ~xdim arr =
330+ let retrieve_1d_points ?from_axis ? padding ~xdim arr =
305331 let dims = dims arr in
306332 if Array. is_empty dims then [||]
307333 else
@@ -313,20 +339,22 @@ let retrieve_1d_points ?from_axis ~xdim arr =
313339 if axis = n_axes then
314340 let x =
315341 idx.(from_axis) < - xdim;
316- get_as_float arr idx
342+ get_as_float ?padding arr idx
317343 in
318344 result := x :: ! result
319345 else if axis = from_axis then iter (axis + 1 )
320346 else
321- for p = 0 to dims.(axis) - 1 do
347+ let end_idx = compute_end_idx ?padding dims axis
348+ in
349+ for p = 0 to end_idx do
322350 idx.(axis) < - p;
323351 iter (axis + 1 )
324352 done
325353 in
326354 iter 0 ;
327355 Array. of_list_rev ! result
328356
329- let retrieve_flat_values arr =
357+ let retrieve_flat_values ? padding arr =
330358 let dims = dims arr in
331359 if Array. is_empty dims then [||]
332360 else
@@ -335,18 +363,40 @@ let retrieve_flat_values arr =
335363 let idx = Array. create ~len: n_axes 0 in
336364 let rec iter axis =
337365 if axis = n_axes then
338- let x = get_as_float arr idx in
366+ let x = get_as_float ?padding arr idx in
339367 result := x :: ! result
340368 else
341- for p = 0 to dims.(axis) - 1 do
369+ let end_idx = compute_end_idx ?padding dims axis
370+ in
371+ for p = 0 to end_idx do
342372 idx.(axis) < - p;
343373 iter (axis + 1 )
344374 done
345375 in
346376 iter 0 ;
347377 Array. of_list_rev ! result
348378
349- let set_flat_values _arr _values = ()
379+ let set_flat_values ?padding arr values =
380+ let dims = dims arr in
381+ if not (Array. is_empty dims) then
382+ let n_axes = Array. length dims in
383+ let idx = Array. create ~len: n_axes 0 in
384+ let values_idx = ref 0 in
385+ let rec iter axis =
386+ if axis = n_axes then (
387+ if ! values_idx < Array. length values then (
388+ set_from_float ?padding arr idx values.(! values_idx);
389+ Int. incr values_idx
390+ ))
391+ else
392+ let end_idx = compute_end_idx ?padding dims axis
393+ in
394+ for p = 0 to end_idx do
395+ idx.(axis) < - p;
396+ iter (axis + 1 )
397+ done
398+ in
399+ iter 0
350400
351401let c_ptr_to_string nd =
352402 let prec = get_prec nd in
@@ -410,12 +460,25 @@ let get_used_memory () = Atomic.get used_memory
410460
411461(* * Dimensions to string, ["x"]-separated, e.g. 1x2x3 for batch dims 1, input dims 3, output dims 2.
412462 Outputs ["-"] for empty dimensions. *)
413- let int_dims_to_string ?(with_axis_numbers = false ) dims =
463+ let int_dims_to_string ?(with_axis_numbers = false ) ? padding dims =
414464 if Array. is_empty dims then " -"
415465 else if with_axis_numbers then
416466 String. concat_array ~sep: " x "
417467 @@ Array. mapi dims ~f: (fun d s -> Int. to_string d ^ " :" ^ Int. to_string s)
418- else String. concat_array ~sep: " x" @@ Array. map dims ~f: Int. to_string
468+ else
469+ let dim_strings = Array. mapi dims ~f: (fun i dim ->
470+ match padding with
471+ | None -> Int. to_string dim
472+ | Some padding_arr when i < Array. length padding_arr ->
473+ let unpadded_dim = dim - padding_arr.(i).left - padding_arr.(i).right in
474+ let total_padding = padding_arr.(i).left + padding_arr.(i).right in
475+ if total_padding > 0 then
476+ Int. to_string unpadded_dim ^ " +" ^ Int. to_string total_padding
477+ else
478+ Int. to_string dim
479+ | Some _ -> Int. to_string dim
480+ ) in
481+ String. concat_array ~sep: " x" dim_strings
419482
420483(* * Logs information about the array on the default ppx_minidebug runtime, if
421484 [from_log_level > Utlis.settings.with_log_level]. *)
0 commit comments