Skip to content

Commit 24ee10f

Browse files
committed
Support padding via pre-padded ndarrays (my idea, Claude's code)
1 parent 85b227b commit 24ee10f

File tree

5 files changed

+158
-63
lines changed

5 files changed

+158
-63
lines changed

arrayjit/lib/ndarray.ml

Lines changed: 99 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -129,74 +129,126 @@ let init_bigarray_of_prec (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.precisio
129129
let indices_to_offset ~dims ~idcs =
130130
Array.fold2_exn dims idcs ~init:0 ~f:(fun accu dim idx -> (accu * dim) + idx)
131131

132-
let create_bigarray (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.precision) ~dims
132+
let create_bigarray (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.precision) ~dims ~padding
133133
(init_op : Ops.init_op) : (ocaml, elt_t) bigarray =
134134
Option.iter Utils.settings.fixed_state_for_init ~f:(fun seed -> Rand.Lib.init seed);
135+
136+
(* Handle padding: dims already includes padding, compute unpadded dimensions *)
137+
let unpadded_dims, padding_info = match padding with
138+
| None -> dims, None
139+
| Some (pad_config, pad_value) ->
140+
let unpadded_dims = Array.map2_exn dims pad_config ~f:(fun dim (left, right) ->
141+
dim - left - right) in
142+
(unpadded_dims, Some (pad_config, pad_value))
143+
in
144+
145+
let arr = create_bigarray_of_prec prec dims in
146+
147+
(* Fill with padding value if padding is specified *)
148+
(match padding_info with
149+
| None -> ()
150+
| Some (_, pad_value) ->
151+
(* Fill the entire array with padding value using precision-specific fill *)
152+
(match prec with
153+
| Ops.Byte -> A.fill arr (Char.of_int_exn @@ Int.of_float pad_value)
154+
| Ops.Uint16 -> A.fill arr (Int.of_float pad_value)
155+
| Ops.Int32 -> A.fill arr (Int32.of_float pad_value)
156+
| Ops.Half -> A.fill arr pad_value
157+
| Ops.Bfloat16 -> A.fill arr (float_to_bfloat16 pad_value)
158+
| Ops.Fp8 -> A.fill arr (Char.of_int_exn @@ float_to_fp8 pad_value)
159+
| Ops.Single -> A.fill arr pad_value
160+
| Ops.Double -> A.fill arr pad_value));
161+
162+
(* Helper function to convert unpadded indices to padded indices *)
163+
let unpadded_to_padded_indices idcs =
164+
match padding_info with
165+
| None -> idcs
166+
| Some (pad_config, _) ->
167+
Array.map2_exn idcs pad_config ~f:(fun idx (left, _) -> idx + left)
168+
in
169+
170+
(* Every init operation (constant fill included) writes only the unpadded region, so we iterate over unpadded dimensions *)
171+
let init_unpadded_region init_func =
172+
let rec loop_dims idcs dim_idx =
173+
if dim_idx = Array.length unpadded_dims then
174+
let padded_idcs = unpadded_to_padded_indices idcs in
175+
A.set arr padded_idcs (init_func idcs)
176+
else
177+
for i = 0 to unpadded_dims.(dim_idx) - 1 do
178+
idcs.(dim_idx) <- i;
179+
loop_dims idcs (dim_idx + 1)
180+
done
181+
in
182+
loop_dims (Array.create ~len:(Array.length unpadded_dims) 0) 0
183+
in
184+
135185
let constant_fill_f f values strict =
136186
let len = Array.length values in
137187
if strict then (
138-
let size = Array.fold ~init:1 ~f:( * ) dims in
188+
let size = Array.fold ~init:1 ~f:( * ) unpadded_dims in
139189
if size <> len then
140190
raise
141191
@@ Utils.User_error
142192
[%string
143193
"Ndarray.create_bigarray: Constant_fill: invalid data size %{len#Int}, expected \
144194
%{size#Int}"];
145-
init_bigarray_of_prec prec dims ~f:(fun idcs -> f values.(indices_to_offset ~dims ~idcs)))
195+
init_unpadded_region (fun idcs -> f values.(indices_to_offset ~dims:unpadded_dims ~idcs)))
146196
else
147-
init_bigarray_of_prec prec dims ~f:(fun idcs ->
148-
f values.(indices_to_offset ~dims ~idcs % len))
197+
init_unpadded_region (fun idcs ->
198+
f values.(indices_to_offset ~dims:unpadded_dims ~idcs % len))
149199
in
150200
let constant_fill_float values strict = constant_fill_f Fn.id values strict in
151-
match (prec, init_op) with
201+
202+
(match (prec, init_op) with
152203
| Ops.Byte, Constant_fill { values; strict } ->
153-
constant_fill_f (Fn.compose Char.of_int_exn Int.of_float) values strict
204+
ignore (constant_fill_f (Fn.compose Char.of_int_exn Int.of_float) values strict)
154205
| Ops.Byte, Range_over_offsets ->
155-
init_bigarray_of_prec prec dims ~f:(fun idcs ->
156-
Char.of_int_exn @@ indices_to_offset ~dims ~idcs)
157-
| Ops.Byte, Standard_uniform -> init_bigarray_of_prec prec dims ~f:(fun _ -> Rand.Lib.char ())
158-
| Ops.Uint16, Constant_fill { values; strict } -> constant_fill_f Int.of_float values strict
206+
init_unpadded_region (fun idcs -> Char.of_int_exn @@ indices_to_offset ~dims:unpadded_dims ~idcs)
207+
| Ops.Byte, Standard_uniform -> init_unpadded_region (fun _ -> Rand.Lib.char ())
208+
| Ops.Uint16, Constant_fill { values; strict } ->
209+
ignore (constant_fill_f Int.of_float values strict)
159210
| Ops.Uint16, Range_over_offsets ->
160-
init_bigarray_of_prec prec dims ~f:(fun idcs -> indices_to_offset ~dims ~idcs)
161-
| Ops.Uint16, Standard_uniform -> init_bigarray_of_prec prec dims ~f:(fun _ -> Random.int 65536)
162-
| Ops.Int32, Constant_fill { values; strict } -> constant_fill_f Int32.of_float values strict
211+
init_unpadded_region (fun idcs -> indices_to_offset ~dims:unpadded_dims ~idcs)
212+
| Ops.Uint16, Standard_uniform -> init_unpadded_region (fun _ -> Random.int 65536)
213+
| Ops.Int32, Constant_fill { values; strict } ->
214+
ignore (constant_fill_f Int32.of_float values strict)
163215
| Ops.Int32, Range_over_offsets ->
164-
init_bigarray_of_prec prec dims ~f:(fun idcs ->
165-
Int32.of_int_exn @@ indices_to_offset ~dims ~idcs)
216+
init_unpadded_region (fun idcs -> Int32.of_int_exn @@ indices_to_offset ~dims:unpadded_dims ~idcs)
166217
| Ops.Int32, Standard_uniform ->
167-
init_bigarray_of_prec prec dims ~f:(fun _ -> Random.int32 Int32.max_value)
168-
| Ops.Half, Constant_fill { values; strict } -> constant_fill_float values strict
218+
init_unpadded_region (fun _ -> Random.int32 Int32.max_value)
219+
| Ops.Half, Constant_fill { values; strict } -> ignore (constant_fill_float values strict)
169220
| Ops.Half, Range_over_offsets ->
170-
init_bigarray_of_prec prec dims ~f:(fun idcs -> Float.of_int @@ indices_to_offset ~dims ~idcs)
221+
init_unpadded_region (fun idcs -> Float.of_int @@ indices_to_offset ~dims:unpadded_dims ~idcs)
171222
| Ops.Half, Standard_uniform ->
172-
init_bigarray_of_prec prec dims ~f:(fun _ -> Rand.Lib.float_range 0.0 1.0)
223+
init_unpadded_region (fun _ -> Rand.Lib.float_range 0.0 1.0)
173224
| Ops.Bfloat16, Constant_fill { values; strict } ->
174-
constant_fill_f float_to_bfloat16 values strict
225+
ignore (constant_fill_f float_to_bfloat16 values strict)
175226
| Ops.Bfloat16, Range_over_offsets ->
176-
init_bigarray_of_prec prec dims ~f:(fun idcs ->
177-
float_to_bfloat16 @@ Float.of_int @@ indices_to_offset ~dims ~idcs)
227+
init_unpadded_region (fun idcs ->
228+
float_to_bfloat16 @@ Float.of_int @@ indices_to_offset ~dims:unpadded_dims ~idcs)
178229
| Ops.Bfloat16, Standard_uniform ->
179-
init_bigarray_of_prec prec dims ~f:(fun _ ->
180-
float_to_bfloat16 @@ Rand.Lib.float_range 0.0 1.0)
230+
init_unpadded_region (fun _ -> float_to_bfloat16 @@ Rand.Lib.float_range 0.0 1.0)
181231
| Ops.Fp8, Constant_fill { values; strict } ->
182-
constant_fill_f (Fn.compose Char.of_int_exn float_to_fp8) values strict
232+
ignore (constant_fill_f (Fn.compose Char.of_int_exn float_to_fp8) values strict)
183233
| Ops.Fp8, Range_over_offsets ->
184-
init_bigarray_of_prec prec dims ~f:(fun idcs ->
185-
Char.of_int_exn @@ float_to_fp8 @@ Float.of_int @@ indices_to_offset ~dims ~idcs)
234+
init_unpadded_region (fun idcs ->
235+
Char.of_int_exn @@ float_to_fp8 @@ Float.of_int @@ indices_to_offset ~dims:unpadded_dims ~idcs)
186236
| Ops.Fp8, Standard_uniform ->
187-
init_bigarray_of_prec prec dims ~f:(fun _ ->
188-
Char.of_int_exn @@ float_to_fp8 @@ Rand.Lib.float_range 0.0 1.0)
189-
| Ops.Single, Constant_fill { values; strict } -> constant_fill_float values strict
237+
init_unpadded_region (fun _ -> Char.of_int_exn @@ float_to_fp8 @@ Rand.Lib.float_range 0.0 1.0)
238+
| Ops.Single, Constant_fill { values; strict } -> ignore (constant_fill_float values strict)
190239
| Ops.Single, Range_over_offsets ->
191-
init_bigarray_of_prec prec dims ~f:(fun idcs -> Float.of_int @@ indices_to_offset ~dims ~idcs)
240+
init_unpadded_region (fun idcs -> Float.of_int @@ indices_to_offset ~dims:unpadded_dims ~idcs)
192241
| Ops.Single, Standard_uniform ->
193-
init_bigarray_of_prec prec dims ~f:(fun _ -> Rand.Lib.float_range 0.0 1.0)
194-
| Ops.Double, Constant_fill { values; strict } -> constant_fill_float values strict
242+
init_unpadded_region (fun _ -> Rand.Lib.float_range 0.0 1.0)
243+
| Ops.Double, Constant_fill { values; strict } -> ignore (constant_fill_float values strict)
195244
| Ops.Double, Range_over_offsets ->
196-
init_bigarray_of_prec prec dims ~f:(fun idcs -> Float.of_int @@ indices_to_offset ~dims ~idcs)
245+
init_unpadded_region (fun idcs -> Float.of_int @@ indices_to_offset ~dims:unpadded_dims ~idcs)
197246
| Ops.Double, Standard_uniform ->
198-
init_bigarray_of_prec prec dims ~f:(fun _ -> Rand.Lib.float_range 0.0 1.0)
247+
init_unpadded_region (fun _ -> Rand.Lib.float_range 0.0 1.0)
199248
| _, File_mapped (filename, stored_prec) ->
249+
(* For file mapping, we don't support padding yet - require no padding *)
250+
if Option.is_some padding then
251+
raise @@ Utils.User_error "Ndarray.create_bigarray: File_mapped initialization does not support padding";
200252
(* See: https://github.com/janestreet/torch/blob/master/src/torch/dataset_helper.ml#L3 *)
201253
if not @@ Ops.equal_prec stored_prec (Ops.pack_prec prec) then
202254
raise
@@ -215,12 +267,13 @@ let create_bigarray (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.precision) ~di
215267
[%string
216268
"Ndarray.create_bigarray: File_mapped: invalid file bytes %{len#Int}, expected \
217269
%{size * Ops.prec_in_bytes stored_prec#Int}"]);
218-
let ba =
270+
let file_arr =
219271
Unix.map_file fd (precision_to_bigarray_kind prec) Bigarray.c_layout false dims
220272
~pos:(Int64.of_int 0)
221273
in
222274
Unix.close fd;
223-
ba
275+
A.blit file_arr arr);
276+
arr
224277

225278
(** {2 *** Accessing ***} *)
226279

@@ -331,7 +384,8 @@ let reset_bigarray (init_op : Ops.init_op) (type o b) (prec : (o, b) Ops.precisi
331384
| Ops.Int32, Constant_fill { values; strict } -> constant_set_f Int32.of_float values strict
332385
| Ops.Int32, Range_over_offsets ->
333386
set_bigarray arr ~f:(fun idcs -> Int32.of_int_exn @@ indices_to_offset ~dims ~idcs)
334-
| Ops.Int32, Standard_uniform -> set_bigarray arr ~f:(fun _ -> Random.int32 Int32.max_value)
387+
| Ops.Int32, Standard_uniform ->
388+
set_bigarray arr ~f:(fun _ -> Random.int32 Int32.max_value)
335389
| Ops.Half, Constant_fill { values; strict } -> constant_set_float values strict
336390
| Ops.Half, Range_over_offsets ->
337391
set_bigarray arr ~f:(fun idcs -> Float.of_int @@ indices_to_offset ~dims ~idcs)
@@ -357,7 +411,7 @@ let reset_bigarray (init_op : Ops.init_op) (type o b) (prec : (o, b) Ops.precisi
357411
| Ops.Double, Range_over_offsets ->
358412
set_bigarray arr ~f:(fun idcs -> Float.of_int @@ indices_to_offset ~dims ~idcs)
359413
| Ops.Double, Standard_uniform -> set_bigarray arr ~f:(fun _ -> Rand.Lib.float_range 0.0 1.0)
360-
| _, File_mapped _ -> A.blit (create_bigarray prec ~dims init_op) arr
414+
| _, File_mapped _ -> A.blit (create_bigarray prec ~dims ~padding:None init_op) arr
361415

362416
let reset init_op arr =
363417
let f arr = reset_bigarray init_op arr in
@@ -504,16 +558,17 @@ let hash_t nd = Nativeint.hash @@ to_native nd
504558

505559
let used_memory = Atomic.make 0
506560

507-
let%track7_sexp create_array ~debug:(_debug : string) (prec : Ops.prec) ~(dims : int array) init_op
508-
=
561+
let%track7_sexp create_array ~debug:(_debug : string) (prec : Ops.prec) ~(dims : int array)
562+
~padding init_op =
563+
(* dims already includes padding if padding is specified *)
509564
let size_in_bytes : int =
510565
(if Array.length dims = 0 then 0 else Array.reduce_exn dims ~f:( * )) * Ops.prec_in_bytes prec
511566
in
512567
let%track7_sexp finalizer (_result : t) =
513568
let _ : int = Atomic.fetch_and_add used_memory size_in_bytes in
514569
[%log3 "Deleting", _debug, ptr_to_string_hum _result]
515570
in
516-
let f prec = as_array prec @@ create_bigarray prec ~dims init_op in
571+
let f prec = as_array prec @@ create_bigarray prec ~dims ~padding init_op in
517572
let result = Ops.map_prec { f } prec in
518573
Stdlib.Gc.finalise finalizer result;
519574
let _ : int = Atomic.fetch_and_add used_memory size_in_bytes in
@@ -524,7 +579,7 @@ let%track7_sexp create_array ~debug:(_debug : string) (prec : Ops.prec) ~(dims :
524579
result
525580

526581
let empty_array prec =
527-
create_array prec ~dims:[||] (Constant_fill { values = [| 0.0 |]; strict = false })
582+
create_array prec ~dims:[||] ~padding:None (Constant_fill { values = [| 0.0 |]; strict = false })
528583

529584
let get_used_memory () = Atomic.get used_memory
530585

arrayjit/lib/tnode.ml

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ type t = {
7171
array : Nd.t option Lazy.t;
7272
prec : Ops.prec Lazy.t;
7373
dims : int array Lazy.t;
74+
padding : ((int * int) array * float) option Lazy.t;
75+
(** If the tensor node is pre-padded, this holds the per-axis (left padding, right padding) pairs and the
76+
padding value. *)
7477
size_in_bytes : int Lazy.t;
7578
id : int;
7679
label : string list;
@@ -94,6 +97,15 @@ let num_elems tn =
9497
let dims = Lazy.force tn.dims in
9598
if Array.is_empty dims then 0 else Array.reduce_exn dims ~f:( * )
9699

100+
let dims_without_padding tn =
101+
match Lazy.force tn.padding with
102+
| None -> Lazy.force tn.dims
103+
| Some (padding, _) ->
104+
let dims = Lazy.force tn.dims in
105+
Array.map2_exn dims padding ~f:(fun dim (left, right) -> dim - left - right)
106+
107+
let get_padding tn = Lazy.force tn.padding
108+
97109
let id { id; _ } = "n" ^ Int.to_string id
98110
let label a = String.concat ~sep:"_" a.label
99111
let is_alphanum_ = String.for_all ~f:(fun c -> Char.equal c '_' || Char.is_alphanum c)
@@ -522,12 +534,14 @@ end)
522534

523535
let registry = Registry.create 16
524536

525-
let create ?default_prec ~id ~label ~dims init_op =
537+
let create ?default_prec ~id ~label ~dims ~padding init_op =
526538
let debug = "Host array for " ^ get_debug_name ~id ~label () in
527539
let rec array =
528540
lazy
529541
(if is_hosted_force tn 30 then
530-
Some (Nd.create_array ~debug (Lazy.force prec) ~dims:(Lazy.force dims) init_op)
542+
Some
543+
(Nd.create_array ~debug (Lazy.force prec) ~dims:(Lazy.force dims) ~padding:(Lazy.force padding)
544+
init_op)
531545
else None)
532546
and prec =
533547
lazy
@@ -545,6 +559,7 @@ let create ?default_prec ~id ~label ~dims init_op =
545559
delayed_prec_unsafe;
546560
prec;
547561
dims;
562+
padding;
548563
size_in_bytes;
549564
id;
550565
label;
@@ -572,6 +587,7 @@ let find =
572587
prec = lazy initial_default_prec;
573588
delayed_prec_unsafe = Specified initial_default_prec;
574589
dims = lazy [||];
590+
padding = lazy None;
575591
size_in_bytes = lazy 0;
576592
id = -1;
577593
label = [];

arrayjit/test/test_numerical_types.ml

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ let test_bfloat16_conversions () =
1515

1616
(* Test round-trip through ndarray *)
1717
let arr =
18-
Ndarray.create_array ~debug:"test" Ops.bfloat16 ~dims:[| 3; 2 |]
18+
Ndarray.create_array ~debug:"test" Ops.bfloat16 ~dims:[| 3; 2 |] ~padding:None
1919
(Ops.Constant_fill { values = [| 1.0; 2.0; 3.14; -1.5; 0.125; 1000.0 |]; strict = true })
2020
in
2121

@@ -36,14 +36,43 @@ let test_fp8_conversions () =
3636

3737
(* Test round-trip through ndarray *)
3838
let arr =
39-
Ndarray.create_array ~debug:"test" Ops.fp8 ~dims:[| 2; 2 |]
39+
Ndarray.create_array ~debug:"test" Ops.fp8 ~dims:[| 2; 2 |] ~padding:None
4040
(Ops.Constant_fill { values = [| 1.0; 0.5; 2.0; -1.0 |]; strict = true })
4141
in
4242

4343
Stdio.printf "\nFP8 array values:\n";
4444
let flat_values = Ndarray.retrieve_flat_values arr in
4545
Array.iteri flat_values ~f:(fun i v -> Stdio.printf " [%d] = %.6f\n" i v)
4646

47+
let test_padding () =
48+
Stdio.printf "\n\nTesting padding functionality:\n";
49+
50+
(* Test padding with float32 array *)
51+
let padding_config = [| (1, 1); (2, 1) |] in (* left=1,right=1 for first dim; left=2,right=1 for second dim *)
52+
let padding_value = -999.0 in
53+
54+
let padded_dims = [| 4; 6 |] in (* (2+1+1) x (3+2+1) *)
55+
56+
let arr =
57+
Ndarray.create_array ~debug:"padded_test" Ops.single ~dims:padded_dims
58+
~padding:(Some (padding_config, padding_value))
59+
(Ops.Constant_fill { values = [| 1.0; 2.0; 3.0; 4.0; 5.0; 6.0 |]; strict = true })
60+
in
61+
62+
Stdio.printf "Padded array (dims 4x6, unpadded region 2x3):\n";
63+
let dims = Ndarray.dims arr in
64+
for i = 0 to dims.(0) - 1 do
65+
for j = 0 to dims.(1) - 1 do
66+
let idx = [| i; j |] in
67+
let value = Ndarray.get_as_float arr idx in
68+
Stdio.printf "%8.1f " value;
69+
done;
70+
Stdio.printf "\n"
71+
done;
72+
73+
Stdio.printf "\nExpected: padding value (-999.0) in margins, data values (1.0-6.0) in center region\n"
74+
4775
let () =
4876
test_bfloat16_conversions ();
49-
test_fp8_conversions ()
77+
test_fp8_conversions ();
78+
test_padding ()

lib/row.ml

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -511,16 +511,12 @@ let rec subst_dim env = function
511511
| Some (Solved_dim (Var v2)) when equal_dim_var v v2 -> default
512512
| Some (Solved_dim d) -> subst_dim env d
513513
| _ -> default)
514-
| Affine { solved; unsolved } ->
515-
(* Substitute variables in affine expression *)
516-
let new_unsolved = List.filter_map unsolved ~f:(fun (coeff, v) ->
514+
| Affine { solved = _; unsolved } as init ->
515+
List.fold unsolved ~init ~f:(fun acc (_coeff, v) ->
517516
match Map.find env.dim_env v with
518-
| Some (Solved_dim _) ->
519-
(* Variable is solved, need to incorporate into affine expression *)
520-
None (* Will be handled below *)
521-
| _ -> Some (coeff, v)) in
522-
(* FIXME: properly handle substitution of solved variables into affine expressions *)
523-
Affine { solved; unsolved = new_unsolved }
517+
| Some (Solved_dim d) ->
518+
s_dim_one v ~value:d ~in_:acc
519+
| _ -> acc)
524520

525521
let s_row_one v ~value:{ dims = more_dims; bcast; id = _ } ~in_ =
526522
match in_ with
@@ -614,8 +610,7 @@ let%debug5_sexp rec unify_dim ~stage (eq : dim * dim) (env : environment) :
614610
(* FIXME: For now, we can only unify identical affine expressions *)
615611
if equal_dim dim1 dim2 then ([], env)
616612
else raise @@ Shape_error ("Cannot unify different affine dimensions", [ Dim_mismatch [ dim1; dim2 ] ])
617-
| Affine _, _ | _, Affine _ ->
618-
(* Cannot unify affine with non-affine dimensions *)
613+
| Affine _, Dim _ | Dim _, Affine _ ->
619614
raise @@ Shape_error ("Cannot unify affine dimension with non-affine", [ Dim_mismatch [ dim1; dim2 ] ])
620615
| Var v, dim2 | dim2, Var v ->
621616
let ineqs = ref [] in

0 commit comments

Comments
 (0)