Skip to content

Commit 03a0d87

Browse files
committed
Fix nullary operation uniform and the default initialization setup to generate properly fresh tensor expressions
1 parent 88f6de2 commit 03a0d87

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

lib/operation.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ let rebatch ~l ndarray =
520520
let output_dims = Ir.Ndarray.dims ndarray |> Array.to_list |> List.tl_exn in
521521
Tensor.term ~init_data:(Reshape ndarray) ~label:[ l ] ~input_dims:[] ~output_dims
522522

523-
let uniform ?grad_spec =
523+
let uniform ?grad_spec () =
524524
uint4x32_to_prec_uniform ?grad_spec
525525
(threefry4x32 (embed_self_id ())
526526
(Tensor.term ~fetch_op:Range_over_offsets ~grad_spec:Prohibit_grad
@@ -538,15 +538,15 @@ module TDSL = struct
538538
let embed_self_id = embed_self_id
539539

540540
(** The default initialization operation for {!param} calls. *)
541-
let default_param_init = ref uniform
541+
let default_param_init = ref (uniform ~grad_spec:Require_grad)
542542

543543
let param ?value ?values =
544544
let t =
545545
match (value, values) with
546546
| Some _, Some _ -> invalid_arg "TDSL.param: both value and values are set"
547547
| Some value, None -> Tensor.param_init [| value |]
548548
| None, Some values -> Tensor.param_init values
549-
| None, None -> !default_param_init ~grad_spec:Require_grad ~batch_dims:[] ?batch_axes:None
549+
| None, None -> !default_param_init () ~batch_dims:[] ?batch_axes:None
550550
in
551551
Tensor.param ~t
552552

0 commit comments

Comments
 (0)