@@ -129,74 +129,126 @@ let init_bigarray_of_prec (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.precisio
129129let indices_to_offset ~dims ~idcs =
130130 Array. fold2_exn dims idcs ~init: 0 ~f: (fun accu dim idx -> (accu * dim) + idx)
131131
132- let create_bigarray (type ocaml elt_t ) (prec : (ocaml, elt_t) Ops.precision ) ~dims
132+ let create_bigarray (type ocaml elt_t ) (prec : (ocaml, elt_t) Ops.precision ) ~dims ~ padding
133133 (init_op : Ops.init_op ) : (ocaml, elt_t) bigarray =
134134 Option. iter Utils. settings.fixed_state_for_init ~f: (fun seed -> Rand.Lib. init seed);
135+
136+ (* Handle padding: dims already includes padding, compute unpadded dimensions *)
137+ let unpadded_dims, padding_info = match padding with
138+ | None -> dims, None
139+ | Some (pad_config , pad_value ) ->
140+ let unpadded_dims = Array. map2_exn dims pad_config ~f: (fun dim (left , right ) ->
141+ dim - left - right) in
142+ (unpadded_dims, Some (pad_config, pad_value))
143+ in
144+
145+ let arr = create_bigarray_of_prec prec dims in
146+
147+ (* Fill with padding value if padding is specified *)
148+ (match padding_info with
149+ | None -> ()
150+ | Some (_ , pad_value ) ->
151+ (* Fill the entire array with padding value using precision-specific fill *)
152+ (match prec with
153+ | Ops. Byte -> A. fill arr (Char. of_int_exn @@ Int. of_float pad_value)
154+ | Ops. Uint16 -> A. fill arr (Int. of_float pad_value)
155+ | Ops. Int32 -> A. fill arr (Int32. of_float pad_value)
156+ | Ops. Half -> A. fill arr pad_value
157+ | Ops. Bfloat16 -> A. fill arr (float_to_bfloat16 pad_value)
158+ | Ops. Fp8 -> A. fill arr (Char. of_int_exn @@ float_to_fp8 pad_value)
159+ | Ops. Single -> A. fill arr pad_value
160+ | Ops. Double -> A. fill arr pad_value));
161+
162+ (* Helper function to convert unpadded indices to padded indices *)
163+ let unpadded_to_padded_indices idcs =
164+ match padding_info with
165+ | None -> idcs
166+ | Some (pad_config , _ ) ->
167+ Array. map2_exn idcs pad_config ~f: (fun idx (left , _ ) -> idx + left)
168+ in
169+
170+ (* For non-constant fill operations, we need to iterate over unpadded dimensions *)
171+ let init_unpadded_region init_func =
172+ let rec loop_dims idcs dim_idx =
173+ if dim_idx = Array. length unpadded_dims then
174+ let padded_idcs = unpadded_to_padded_indices idcs in
175+ A. set arr padded_idcs (init_func idcs)
176+ else
177+ for i = 0 to unpadded_dims.(dim_idx) - 1 do
178+ idcs.(dim_idx) < - i;
179+ loop_dims idcs (dim_idx + 1 )
180+ done
181+ in
182+ loop_dims (Array. create ~len: (Array. length unpadded_dims) 0 ) 0
183+ in
184+
135185 let constant_fill_f f values strict =
136186 let len = Array. length values in
137187 if strict then (
138- let size = Array. fold ~init: 1 ~f: ( * ) dims in
188+ let size = Array. fold ~init: 1 ~f: ( * ) unpadded_dims in
139189 if size <> len then
140190 raise
141191 @@ Utils. User_error
142192 [% string
143193 " Ndarray.create_bigarray: Constant_fill: invalid data size %{len#Int}, expected \
144194 %{size#Int}" ];
145- init_bigarray_of_prec prec dims ~f: (fun idcs -> f values.(indices_to_offset ~dims ~idcs )))
195+ init_unpadded_region (fun idcs -> f values.(indices_to_offset ~dims: unpadded_dims ~idcs )))
146196 else
147- init_bigarray_of_prec prec dims ~f: (fun idcs ->
148- f values.(indices_to_offset ~dims ~idcs % len))
197+ init_unpadded_region (fun idcs ->
198+ f values.(indices_to_offset ~dims: unpadded_dims ~idcs % len))
149199 in
150200 let constant_fill_float values strict = constant_fill_f Fn. id values strict in
151- match (prec, init_op) with
201+
202+ (match (prec, init_op) with
152203 | Ops. Byte , Constant_fill { values; strict } ->
153- constant_fill_f (Fn. compose Char. of_int_exn Int. of_float) values strict
204+ ignore ( constant_fill_f (Fn. compose Char. of_int_exn Int. of_float) values strict)
154205 | Ops. Byte , Range_over_offsets ->
155- init_bigarray_of_prec prec dims ~f: (fun idcs ->
156- Char. of_int_exn @@ indices_to_offset ~dims ~idcs )
157- | Ops. Byte , Standard_uniform -> init_bigarray_of_prec prec dims ~f: ( fun _ -> Rand.Lib. char () )
158- | Ops. Uint16 , Constant_fill { values; strict } -> constant_fill_f Int. of_float values strict
206+ init_unpadded_region (fun idcs -> Char. of_int_exn @@ indices_to_offset ~dims: unpadded_dims ~idcs )
207+ | Ops. Byte , Standard_uniform -> init_unpadded_region ( fun _ -> Rand.Lib. char () )
208+ | Ops. Uint16 , Constant_fill { values; strict } ->
209+ ignore ( constant_fill_f Int. of_float values strict)
159210 | Ops. Uint16 , Range_over_offsets ->
160- init_bigarray_of_prec prec dims ~f: (fun idcs -> indices_to_offset ~dims ~idcs )
161- | Ops. Uint16 , Standard_uniform -> init_bigarray_of_prec prec dims ~f: (fun _ -> Random. int 65536 )
162- | Ops. Int32 , Constant_fill { values; strict } -> constant_fill_f Int32. of_float values strict
211+ init_unpadded_region (fun idcs -> indices_to_offset ~dims: unpadded_dims ~idcs )
212+ | Ops. Uint16 , Standard_uniform -> init_unpadded_region (fun _ -> Random. int 65536 )
213+ | Ops. Int32 , Constant_fill { values; strict } ->
214+ ignore (constant_fill_f Int32. of_float values strict)
163215 | Ops. Int32 , Range_over_offsets ->
164- init_bigarray_of_prec prec dims ~f: (fun idcs ->
165- Int32. of_int_exn @@ indices_to_offset ~dims ~idcs )
216+ init_unpadded_region (fun idcs -> Int32. of_int_exn @@ indices_to_offset ~dims: unpadded_dims ~idcs )
166217 | Ops. Int32 , Standard_uniform ->
167- init_bigarray_of_prec prec dims ~f: (fun _ -> Random. int32 Int32. max_value)
168- | Ops. Half , Constant_fill { values; strict } -> constant_fill_float values strict
218+ init_unpadded_region (fun _ -> Random. int32 Int32. max_value)
219+ | Ops. Half , Constant_fill { values; strict } -> ignore ( constant_fill_float values strict)
169220 | Ops. Half , Range_over_offsets ->
170- init_bigarray_of_prec prec dims ~f: (fun idcs -> Float. of_int @@ indices_to_offset ~dims ~idcs )
221+ init_unpadded_region (fun idcs -> Float. of_int @@ indices_to_offset ~dims: unpadded_dims ~idcs )
171222 | Ops. Half , Standard_uniform ->
172- init_bigarray_of_prec prec dims ~f: (fun _ -> Rand.Lib. float_range 0.0 1.0 )
223+ init_unpadded_region (fun _ -> Rand.Lib. float_range 0.0 1.0 )
173224 | Ops. Bfloat16 , Constant_fill { values; strict } ->
174- constant_fill_f float_to_bfloat16 values strict
225+ ignore ( constant_fill_f float_to_bfloat16 values strict)
175226 | Ops. Bfloat16 , Range_over_offsets ->
176- init_bigarray_of_prec prec dims ~f: (fun idcs ->
177- float_to_bfloat16 @@ Float. of_int @@ indices_to_offset ~dims ~idcs )
227+ init_unpadded_region (fun idcs ->
228+ float_to_bfloat16 @@ Float. of_int @@ indices_to_offset ~dims: unpadded_dims ~idcs )
178229 | Ops. Bfloat16 , Standard_uniform ->
179- init_bigarray_of_prec prec dims ~f: (fun _ ->
180- float_to_bfloat16 @@ Rand.Lib. float_range 0.0 1.0 )
230+ init_unpadded_region (fun _ -> float_to_bfloat16 @@ Rand.Lib. float_range 0.0 1.0 )
181231 | Ops. Fp8 , Constant_fill { values; strict } ->
182- constant_fill_f (Fn. compose Char. of_int_exn float_to_fp8) values strict
232+ ignore ( constant_fill_f (Fn. compose Char. of_int_exn float_to_fp8) values strict)
183233 | Ops. Fp8 , Range_over_offsets ->
184- init_bigarray_of_prec prec dims ~f: (fun idcs ->
185- Char. of_int_exn @@ float_to_fp8 @@ Float. of_int @@ indices_to_offset ~dims ~idcs )
234+ init_unpadded_region (fun idcs ->
235+ Char. of_int_exn @@ float_to_fp8 @@ Float. of_int @@ indices_to_offset ~dims: unpadded_dims ~idcs )
186236 | Ops. Fp8 , Standard_uniform ->
187- init_bigarray_of_prec prec dims ~f: (fun _ ->
188- Char. of_int_exn @@ float_to_fp8 @@ Rand.Lib. float_range 0.0 1.0 )
189- | Ops. Single , Constant_fill { values; strict } -> constant_fill_float values strict
237+ init_unpadded_region (fun _ -> Char. of_int_exn @@ float_to_fp8 @@ Rand.Lib. float_range 0.0 1.0 )
238+ | Ops. Single , Constant_fill { values; strict } -> ignore (constant_fill_float values strict)
190239 | Ops. Single , Range_over_offsets ->
191- init_bigarray_of_prec prec dims ~f: (fun idcs -> Float. of_int @@ indices_to_offset ~dims ~idcs )
240+ init_unpadded_region (fun idcs -> Float. of_int @@ indices_to_offset ~dims: unpadded_dims ~idcs )
192241 | Ops. Single , Standard_uniform ->
193- init_bigarray_of_prec prec dims ~f: (fun _ -> Rand.Lib. float_range 0.0 1.0 )
194- | Ops. Double , Constant_fill { values; strict } -> constant_fill_float values strict
242+ init_unpadded_region (fun _ -> Rand.Lib. float_range 0.0 1.0 )
243+ | Ops. Double , Constant_fill { values; strict } -> ignore ( constant_fill_float values strict)
195244 | Ops. Double , Range_over_offsets ->
196- init_bigarray_of_prec prec dims ~f: (fun idcs -> Float. of_int @@ indices_to_offset ~dims ~idcs )
245+ init_unpadded_region (fun idcs -> Float. of_int @@ indices_to_offset ~dims: unpadded_dims ~idcs )
197246 | Ops. Double , Standard_uniform ->
198- init_bigarray_of_prec prec dims ~f: (fun _ -> Rand.Lib. float_range 0.0 1.0 )
247+ init_unpadded_region (fun _ -> Rand.Lib. float_range 0.0 1.0 )
199248 | _ , File_mapped (filename , stored_prec ) ->
249+ (* For file mapping, we don't support padding yet - require no padding *)
250+ if Option. is_some padding then
251+ raise @@ Utils. User_error " Ndarray.create_bigarray: File_mapped initialization does not support padding" ;
200252 (* See: https://github.com/janestreet/torch/blob/master/src/torch/dataset_helper.ml#L3 *)
201253 if not @@ Ops. equal_prec stored_prec (Ops. pack_prec prec) then
202254 raise
@@ -215,12 +267,13 @@ let create_bigarray (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.precision) ~di
215267 [% string
216268 " Ndarray.create_bigarray: File_mapped: invalid file bytes %{len#Int}, expected \
217269 %{size * Ops.prec_in_bytes stored_prec#Int}" ]);
218- let ba =
270+ let file_arr =
219271 Unix. map_file fd (precision_to_bigarray_kind prec) Bigarray. c_layout false dims
220272 ~pos: (Int64. of_int 0 )
221273 in
222274 Unix. close fd;
223- ba
275+ A. blit file_arr arr);
276+ arr
224277
225278(* * {2 *** Accessing ***} *)
226279
@@ -331,7 +384,8 @@ let reset_bigarray (init_op : Ops.init_op) (type o b) (prec : (o, b) Ops.precisi
331384 | Ops. Int32 , Constant_fill { values; strict } -> constant_set_f Int32. of_float values strict
332385 | Ops. Int32 , Range_over_offsets ->
333386 set_bigarray arr ~f: (fun idcs -> Int32. of_int_exn @@ indices_to_offset ~dims ~idcs )
334- | Ops. Int32 , Standard_uniform -> set_bigarray arr ~f: (fun _ -> Random. int32 Int32. max_value)
387+ | Ops. Int32 , Standard_uniform ->
388+ set_bigarray arr ~f: (fun _ -> Random. int32 Int32. max_value)
335389 | Ops. Half , Constant_fill { values; strict } -> constant_set_float values strict
336390 | Ops. Half , Range_over_offsets ->
337391 set_bigarray arr ~f: (fun idcs -> Float. of_int @@ indices_to_offset ~dims ~idcs )
@@ -357,7 +411,7 @@ let reset_bigarray (init_op : Ops.init_op) (type o b) (prec : (o, b) Ops.precisi
357411 | Ops. Double , Range_over_offsets ->
358412 set_bigarray arr ~f: (fun idcs -> Float. of_int @@ indices_to_offset ~dims ~idcs )
359413 | Ops. Double , Standard_uniform -> set_bigarray arr ~f: (fun _ -> Rand.Lib. float_range 0.0 1.0 )
360- | _ , File_mapped _ -> A. blit (create_bigarray prec ~dims init_op) arr
414+ | _ , File_mapped _ -> A. blit (create_bigarray prec ~dims ~padding: None init_op) arr
361415
362416let reset init_op arr =
363417 let f arr = reset_bigarray init_op arr in
@@ -504,16 +558,17 @@ let hash_t nd = Nativeint.hash @@ to_native nd
504558
505559let used_memory = Atomic. make 0
506560
507- let % track7_sexp create_array ~debug: (_debug : string ) (prec : Ops.prec ) ~(dims : int array ) init_op
508- =
561+ let % track7_sexp create_array ~debug: (_debug : string ) (prec : Ops.prec ) ~(dims : int array )
562+ ~padding init_op =
563+ (* dims already includes padding if padding is specified *)
509564 let size_in_bytes : int =
510565 (if Array. length dims = 0 then 0 else Array. reduce_exn dims ~f: ( * )) * Ops. prec_in_bytes prec
511566 in
512567 let % track7_sexp finalizer (_result : t ) =
513568 let _ : int = Atomic. fetch_and_add used_memory size_in_bytes in
514569 [% log3 " Deleting" , _debug, ptr_to_string_hum _result]
515570 in
516- let f prec = as_array prec @@ create_bigarray prec ~dims init_op in
571+ let f prec = as_array prec @@ create_bigarray prec ~dims ~padding init_op in
517572 let result = Ops. map_prec { f } prec in
518573 Stdlib.Gc. finalise finalizer result;
519574 let _ : int = Atomic. fetch_and_add used_memory size_in_bytes in
@@ -524,7 +579,7 @@ let%track7_sexp create_array ~debug:(_debug : string) (prec : Ops.prec) ~(dims :
524579 result
525580
526581let empty_array prec =
527- create_array prec ~dims: [||] (Constant_fill { values = [| 0.0 |]; strict = false })
582+ create_array prec ~dims: [||] ~padding: None (Constant_fill { values = [| 0.0 |]; strict = false })
528583
529584let get_used_memory () = Atomic. get used_memory
530585
0 commit comments