File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -520,7 +520,7 @@ let rebatch ~l ndarray =
520520 let output_dims = Ir.Ndarray. dims ndarray |> Array. to_list |> List. tl_exn in
521521 Tensor. term ~init_data: (Reshape ndarray) ~label: [ l ] ~input_dims: [] ~output_dims
522522
523- let uniform ?grad_spec =
523+ let uniform ?grad_spec () =
524524 uint4x32_to_prec_uniform ?grad_spec
525525 (threefry4x32 (embed_self_id () )
526526 (Tensor. term ~fetch_op: Range_over_offsets ~grad_spec: Prohibit_grad
@@ -538,15 +538,15 @@ module TDSL = struct
538538 let embed_self_id = embed_self_id
539539
540540 (* * The default initialization operation for {!param} calls. *)
541- let default_param_init = ref uniform
541+ let default_param_init = ref ( uniform ~grad_spec: Require_grad )
542542
543543 let param ?value ?values =
544544 let t =
545545 match (value, values) with
546546 | Some _ , Some _ -> invalid_arg " TDSL.param: both value and values are set"
547547 | Some value , None -> Tensor. param_init [| value |]
548548 | None , Some values -> Tensor. param_init values
549- | None , None -> ! default_param_init ~grad_spec: Require_grad ~batch_dims: [] ?batch_axes:None
549+ | None , None -> ! default_param_init () ~batch_dims: [] ?batch_axes:None
550550 in
551551 Tensor. param ~t
552552
You can’t perform that action at this time.
0 commit comments