
Commit 1fdbc4b

Be lenient about pre-filling params with values but special-case filling with a single value
The multiple-values option was intended to add a shape constraint, but apparently it is leaky.
1 parent 02fabcd commit 1fdbc4b

File tree (5 files changed, +17 −6 lines):

- arrayjit/lib/assignments.ml
- bin/hello_world_op.ml
- lib/tensor.ml
- lib/tensor.mli
- test/hello_world_op.ml


arrayjit/lib/assignments.ml

Lines changed: 3 additions & 1 deletion
@@ -266,8 +266,10 @@ let%diagn2_sexp to_low_level code =
         let offset = Indexing.reflect_projection ~dims ~projection:idcs in
         set array idcs @@ Embed_index offset)
     | Fetch { array; fetch_op = Constant_fill values; dims = (lazy dims) } ->
+        (* TODO: consider failing here and strengthening shape inference. *)
+        let size = Array.length values in
         Low_level.unroll_dims dims ~body:(fun idcs ~offset ->
-            set array idcs @@ Constant values.(offset))
+            set array idcs @@ Constant values.(offset % size))
   in
   loop code
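Aside (not part of the commit): the new [offset % size] indexing wraps around the [values] array, so a short array fills a larger tensor by cycling. A minimal standalone sketch in plain OCaml, with stdlib [mod] standing in for the [%] operator used above and a hypothetical 7-cell buffer:

(* Standalone sketch, independent of arrayjit: fill a 7-cell buffer from a
   3-element values array the way the new indexing does; values repeat cyclically. *)
let () =
  let values = [| 1.; 2.; 3. |] in
  let size = Array.length values in
  let filled = Array.init 7 (fun offset -> values.(offset mod size)) in
  Array.iter (Printf.printf "%g ") filled;
  print_newline ()
  (* prints: 1 2 3 1 2 3 1 *)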

bin/hello_world_op.ml

Lines changed: 1 addition & 1 deletion
@@ -189,7 +189,7 @@ let%track2_sexp _Big_matrix (() : unit) : unit =
   let ctx = Backend.make_context stream in
   Rand.init 0;
   (* Hey is inferred to be a matrix. *)
-  let hey = Tensor.param ~values:[| 0.5 |] "hey" in
+  let hey = Tensor.param ~value:0.5 "hey" in
   let zero_to_twenty = TDSL.range 20 in
   let%op yd = (hey * zero_to_twenty) + zero_to_twenty in
   Train.forward_and_forget backend ctx yd;

lib/tensor.ml

Lines changed: 6 additions & 2 deletions
@@ -409,10 +409,14 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?
       Tn.update_prec ~only_if:is_up_to_fp16 t.value single);
   t

-let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ?values
+let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ?value ?values
     label =
   let fetch_op_fn ~v:_ =
-    match values with Some values -> Asgns.Constant_fill values | None -> Asgns.Range_over_offsets
+    match values, value with
+    | Some values, None -> Asgns.Constant_fill values
+    | None, Some value -> Asgns.Constant value
+    | None, None -> Asgns.Range_over_offsets
+    | Some _, Some _ -> invalid_arg "Tensor.param: both values and value are set"
   in
   let t =
     term ~label:(label :: more_label) ~grad_spec:Require_grad ~batch_dims:[] ?input_dims
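Aside (not part of the commit): the four-way dispatch above can be exercised in isolation. A minimal sketch with a placeholder [fetch_op] type standing in for [Asgns.fetch_op]; the helper name [choose_fetch_op] is illustrative, not part of the library:

(* Placeholder variant mirroring the fetch ops used above. *)
type fetch_op = Constant_fill of float array | Constant of float | Range_over_offsets

(* Same dispatch as the new [fetch_op_fn] body. *)
let choose_fetch_op ?value ?values () =
  match values, value with
  | Some values, None -> Constant_fill values
  | None, Some value -> Constant value
  | None, None -> Range_over_offsets
  | Some _, Some _ -> invalid_arg "Tensor.param: both values and value are set"

let () =
  assert (choose_fetch_op ~value:0.5 () = Constant 0.5);
  assert (choose_fetch_op () = Range_over_offsets);
  (* Passing both is rejected with Invalid_argument. *)
  assert (
    try ignore (choose_fetch_op ~value:0.5 ~values:[| 1. |] ()); false
    with Invalid_argument _ -> true)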

lib/tensor.mli

Lines changed: 5 additions & 1 deletion
@@ -214,12 +214,16 @@ val param :
   ?input_axes:(string * int) list ->
   ?output_axes:(string * int) list ->
   ?deduced:Shape.deduce_within_shape ->
+  ?value:float ->
   ?values:float array ->
   string ->
   t
 (* A tensor with no batch axes; input and output axes are by default inferred. [grad_spec] is set to
    [Require_grad]. The resulting tensor's label is the passed string, appended by [more_label] if
-   any. *)
+   any. If [value] is provided, the tensor is initialized to the given value. If [values] is
+   provided, the tensor is initialized to the given values. At most one of [value] or [values] can
+   be provided. Note: [values] will be looped over if necessary, but shape inference will try
+   incorporating the number of values as tensor size. *)

 val consume_forward_code : t -> comp
 (** A forward root is a tensor that is not (currently) used to compute another tensor.
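Aside (not part of the commit): combining the two sketches above, a standalone plain-OCaml approximation of the fill behavior a caller would observe for a hypothetical 5-cell parameter. Treating [Range_over_offsets] as filling each cell with its own offset is an assumption based on the constructor name; [fill_cells] is an illustrative helper, not library API:

(* Approximate the three documented initialization modes for a flat buffer. *)
let fill_cells ~size ?value ?values () =
  match values, value with
  | Some vs, None -> Array.init size (fun i -> vs.(i mod Array.length vs))  (* cycled *)
  | None, Some v -> Array.make size v                                        (* single constant *)
  | None, None -> Array.init size float_of_int       (* assumed Range_over_offsets semantics *)
  | Some _, Some _ -> invalid_arg "fill_cells: both values and value are set"

let () =
  assert (fill_cells ~size:5 ~value:0.5 () = [| 0.5; 0.5; 0.5; 0.5; 0.5 |]);
  assert (fill_cells ~size:5 ~values:[| 1.; 2. |] () = [| 1.; 2.; 1.; 2.; 1. |]);
  assert (fill_cells ~size:5 () = [| 0.; 1.; 2.; 3.; 4. |])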

test/hello_world_op.ml

Lines changed: 2 additions & 1 deletion
@@ -104,6 +104,7 @@ let%expect_test "Print constant tensor" =

   let%op hey = [ (1, 2, 3); (4, 5, 6) ] in
   Train.forward_and_forget backend ctx hey;
+  (* ignore (failwith @@ Tn.debug_memory_mode hey.value.memory_mode); *)
   Tensor.print ~with_code:false ~with_grad:false `Inline @@ hey;
   [%expect
     {|
@@ -509,7 +510,7 @@ let%expect_test "Big matrix" =
   let ctx = Backend.make_context stream in
   Rand.init 0;
   (* Hey is inferred to be a matrix. *)
-  let hey = Tensor.param ~values:[| 0.5 |] "hey" in
+  let hey = Tensor.param ~value:0.5 "hey" in
   let zero_to_twenty = TDSL.range 20 in
   let y = TDSL.O.((hey * zero_to_twenty) + zero_to_twenty) in
   Train.forward_and_forget backend ctx y;
