Skip to content

Commit dbff017

Browse files
committed
Fix missing reshape in Tnode.create_with_reshape
1 parent 61cf318 commit dbff017

File tree

2 files changed

+32
-32
lines changed

2 files changed

+32
-32
lines changed

arrayjit/lib/ndarray.ml

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -385,15 +385,14 @@ let hash nd = Nativeint.hash (to_native nd)
385385
let hash_fold_t acc nd = hash_fold_nativeint acc (to_native nd)
386386
let hash_t nd = Nativeint.hash @@ to_native nd
387387

388-
(** C function declarations for efficient copying *)
389-
external copy_with_padding_c :
390-
('a, 'b) bigarray -> ('a, 'b) bigarray -> axis_padding array -> unit
388+
external copy_with_padding_c : ('a, 'b) bigarray -> ('a, 'b) bigarray -> axis_padding array -> unit
391389
= "arrayjit_copy_with_padding"
390+
(** C function declarations for efficient copying *)
392391

392+
(** Copies the whole of [source] onto the parts of [target] skipping over padding margins --
393+
requires that source dimensions + padding = target dimensions. *)
393394
let copy_with_padding ~source ~target ~padding =
394-
let copy_impl source_arr target_arr =
395-
copy_with_padding_c source_arr target_arr padding
396-
in
395+
let copy_impl source_arr target_arr = copy_with_padding_c source_arr target_arr padding in
397396
map2 { f2 = copy_impl } source target
398397

399398
(** {2 *** Creating ***} *)
@@ -420,6 +419,11 @@ let%track7_sexp create_array ~debug:(_debug : string) (prec : Ops.prec) ~(dims :
420419
[%log _debug, ptr_to_string_hum result]]];
421420
result
422421

422+
(** See {!Bigarray.reshape}. *)
423+
let reshape nd dims =
424+
let f prec arr = as_array prec @@ Bigarray.reshape arr dims in
425+
map_with_prec { f } nd
426+
423427
let get_used_memory () = Atomic.get used_memory
424428

425429
(** {2 *** Printing ***} *)

arrayjit/lib/tnode.ml

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -634,35 +634,31 @@ let create_with_reshape ~id ~label ~base_ndarray ~dims ~padding ~from_padded ()
634634
Nd.as_array prec (Bigarray.reshape arr target_dims)
635635
in
636636
Some (Nd.map_with_prec { f = f_reshape_with_prec } base_ndarray)
637-
| Some (padding_spec, _), false ->
637+
| Some (padding, _), false ->
638638
(* Create new bigarray with padding and copy source into non-padding parts *)
639-
let target_array =
640-
Nd.create_array ~debug prec_val ~dims:target_dims ~padding:target_padding
641-
in
639+
let target = Nd.create_array ~debug prec_val ~dims:target_dims ~padding:target_padding in
642640
let source_dims = Nd.dims base_ndarray in
643-
let copy_with_padding () =
644-
(* Calculate actual data dimensions (target dims minus padding) *)
645-
let data_dims =
646-
Array.map2_exn target_dims padding_spec ~f:(fun dim { Nd.left; right } ->
647-
dim - left - right)
648-
in
649-
(* Check total elements match, allowing shape differences *)
650-
let source_total =
651-
if Array.is_empty source_dims then 0 else Array.reduce_exn source_dims ~f:( * )
652-
in
653-
let data_total =
654-
if Array.is_empty data_dims then 0 else Array.reduce_exn data_dims ~f:( * )
655-
in
656-
if source_total <> data_total then
657-
invalid_arg
658-
[%string
659-
"create_with_reshape: source has %{source_total#Int} elements but target data \
660-
area has %{data_total#Int} elements"];
661-
(* Use C function for efficient copying *)
662-
Nd.copy_with_padding ~source:base_ndarray ~target:target_array ~padding:padding_spec
641+
(* Calculate actual data dimensions (target dims minus padding) *)
642+
let data_dims =
643+
Array.map2_exn target_dims padding ~f:(fun dim { Nd.left; right } ->
644+
dim - left - right)
645+
in
646+
(* Check total elements match, allowing shape differences *)
647+
let source_total =
648+
if Array.is_empty source_dims then 0 else Array.reduce_exn source_dims ~f:( * )
649+
in
650+
let data_total =
651+
if Array.is_empty data_dims then 0 else Array.reduce_exn data_dims ~f:( * )
663652
in
664-
copy_with_padding ();
665-
Some target_array)
653+
if source_total <> data_total then
654+
invalid_arg
655+
[%string
656+
"create_with_reshape: source has %{source_total#Int} elements but target data \
657+
area has %{data_total#Int} elements"];
658+
(* Use C function for efficient copying *)
659+
let source = Nd.reshape base_ndarray data_dims in
660+
Nd.copy_with_padding ~source ~target ~padding;
661+
Some target)
666662
and prec = lazy prec_val
667663
and size_in_bytes = lazy (num_elems tn * Ops.prec_in_bytes (Lazy.force tn.prec))
668664
and tn =

0 commit comments

Comments
 (0)