
Commit f24a1e6

Cleanup to the param interface

1 parent a74ec40

File tree: 6 files changed, +23 -42 lines

bin/hello_world.ml

Lines changed: 1 addition & 1 deletion

@@ -50,7 +50,7 @@ let hello3 () =
   let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
   let ctx = Backend.make_context stream in
   (* Hey is inferred to be a matrix. *)
-  let hey = Tensor.param "hey" in
+  let hey = TDSL.param "hey" in
   let zero_to_twenty = TDSL.range 20 in
   let y = TDSL.O.(( + ) ~label:[ "y" ] (hey * zero_to_twenty) zero_to_twenty) in
   Train.set_hosted hey.value;

lib/operation.ml

Lines changed: 7 additions & 4 deletions

@@ -478,13 +478,16 @@ module TDSL = struct
   let number = Tensor.number ~grad_spec:If_needed
   let ndarray = Tensor.ndarray ~grad_spec:If_needed
 
+  (** The default initialization operation for {!param} calls. *)
+  let default_param_init = ref @@ Tensor.fetch_param_init (Asgns.Constant 0.0)
+
   let param ?value ?values =
     let t =
       match (value, values) with
       | Some _, Some _ -> invalid_arg "TDSL.param: both value and values are set"
       | Some value, None -> Tensor.fetch_param_init (Asgns.Constant value)
       | None, Some values -> Tensor.fetch_param_init (Asgns.Constant_fill values)
-      | None, None -> !Tensor.default_param_init
+      | None, None -> !default_param_init
     in
     Tensor.param ~t
 
@@ -502,15 +505,15 @@ module TDSL = struct
   (** The input and output dimensions will be inferred if omitted. See {!reshape}. *)
   let reshape_param ~l ?i ?o ndarray =
     let t =
-      Tensor.term ~grad_spec:Require_grad ~batch_dims:[] ~batch_axes:[] ~init_data:(Reshape ndarray)
-        ?fetch_op:None
+      Tensor.term ~grad_spec:Require_grad ~batch_dims:[] ?batch_axes:None
+        ~init_data:(Reshape ndarray) ?fetch_op:None
     in
     Tensor.param ?input_dims:i ?output_dims:o ~t l
 
   (** See {!wrap}. *)
   let wrap_param ~l ?i ?o ndarray =
     let t =
-      Tensor.term ~grad_spec:Require_grad ~batch_dims:[] ~batch_axes:[]
+      Tensor.term ~grad_spec:Require_grad ~batch_dims:[] ?batch_axes:None
        ~init_data:(Keep_shape_no_padding ndarray) ?fetch_op:None
     in
     Tensor.param ?input_dims:i ?output_dims:o ~t l
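Taken together, these two hunks move the default initialization for parameters from `Tensor` into `TDSL`, alongside the `?value`/`?values` shorthands. A minimal sketch of the resulting call sites, assuming the value bound by `?values` is a float array (its exact type is not shown in this diff):

```ocaml
(* Falls back to !TDSL.default_param_init, i.e. Asgns.Constant 0.0 unless overridden. *)
let w = TDSL.param "w" in
(* Initializes to the constant 0.1. *)
let b = TDSL.param ~value:0.1 ~output_dims:[ 8 ] "b" in
(* Initializes from the given values via Asgns.Constant_fill. *)
let c = TDSL.param ~values:[| 1.; 2.; 3. |] "c" in
...
```

Because `default_param_init` is a ref, the project-wide default can be swapped in one place, e.g. `TDSL.default_param_init := Tensor.fetch_param_init my_fetch_op` for a hypothetical `my_fetch_op`.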

lib/shape.ml

Lines changed: 2 additions & 4 deletions

@@ -462,10 +462,8 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
             r = [ cur_sh.batch; cur_sh.output; cur_sh.input ];
             constr =
               Exact
-                (Lazy.force tn.dims
-                |> Array.to_list |> List.tl_exn
-                |> List.map ~f:(fun d -> get_dim ~d ())
-                );
+                (Lazy.force tn.dims |> Array.to_list |> List.tl_exn
+                |> List.map ~f:(fun d -> get_dim ~d ()));
           }
           :: mark_terminal () )
       else (Row.dim_map_empty, mark_terminal ())

lib/syntax_extensions.md

Lines changed: 5 additions & 5 deletions

@@ -149,8 +149,8 @@ let interpret_ternop op v1 v2 v3 =
 
 ```ocaml
 let hid_dim = 8 in
-let w = Tensor.param "w" in
-let b = Tensor.param ~output_dims:[ hid_dim ] "b" in
+let w = TDSL.param "w" in
+let b = TDSL.param ~output_dims:[ hid_dim ] "b" in
 let layer x = TDSL.O.( relu(w * x + b) ) in
 ...
 ```
@@ -159,8 +159,8 @@ Since `TDSL.O` is opened for the scope of an extension point `%op`:
 
 ```ocaml
 let hid_dim = 8 in
-let w = Tensor.param "w" in
-let b = Tensor.param ~output_dims:[ hid_dim ] "b" in
+let w = TDSL.param "w" in
+let b = TDSL.param ~output_dims:[ hid_dim ] "b" in
 let%op layer x = relu(w * x + b) in
 ...
 ```
@@ -413,7 +413,7 @@ If you recall, inline declared param tensors get lifted out of functions except
 
 ```ocaml
 let mlp_layer ~config =
-  let w = Tensor.param "w" and b = Tensor.param ~output_dims:[ config.hid_dim ] in
+  let w = TDSL.param "w" and b = TDSL.param ~output_dims:[ config.hid_dim ] in
   fun x -> TDSL.O.(w * x + b)
 ```
lib/tensor.ml

Lines changed: 3 additions & 10 deletions

@@ -440,18 +440,11 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?
   t
 
 let fetch_param_init fetch_op =
-  term ~grad_spec:Require_grad ~batch_dims:[] ~batch_axes:[] ?init_data:None ~fetch_op
+  term ~grad_spec:Require_grad ~batch_dims:[] ?batch_axes:None ?init_data:None ~fetch_op
 
-let default_param_init = ref @@ fetch_param_init (Asgns.Constant 0.0)
-
-let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ?t label =
+let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ~t label =
   let t =
-    match t with
-    | Some t ->
-        t ~label:(label :: more_label) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ()
-    | None ->
-        !default_param_init ~label:(label :: more_label) ?input_dims ?output_dims ?input_axes
-          ?output_axes ?deduced ()
+    t ~label:(label :: more_label) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ()
  in
  let v = t.value in
  (* It is convenient to use the param syntax for volatiles (mutable embedded_nodes). *)
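With `?t` promoted to the required `~t`, `Tensor.param` no longer consults a global default; every caller supplies the initializer explicitly. A hedged sketch of a direct call, mirroring what `TDSL.param` now does internally per the hunk above:

```ocaml
(* Build the initialization closure explicitly, then pass it as the
   now-required ~t argument. *)
let t = Tensor.fetch_param_init (Asgns.Constant 0.0) in
let w = Tensor.param ~t "w" in
...
```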

lib/tensor.mli

Lines changed: 5 additions & 18 deletions

@@ -217,18 +217,6 @@ val ndarray :
     given values must fill the tensor's [value] node precisely; otherwise, the values will be looped
     over to populate the [value] node. *)
 
-val default_param_init :
-  (label:string list ->
-  ?input_dims:int list ->
-  ?output_dims:int list ->
-  ?input_axes:(string * int) list ->
-  ?output_axes:(string * int) list ->
-  ?deduced:Shape.deduce_within_shape ->
-  unit ->
-  t)
-  ref
-(** The default initialization operation for {!param} calls that do not pass a [t]. *)
-
 val fetch_param_init :
   fetch_op ->
   label:string list ->
@@ -239,7 +227,7 @@ val fetch_param_init :
   ?deduced:Shape.deduce_within_shape ->
   unit ->
   t
-(** Helper for {!param} wrappers or to set {!default_param_init}. *)
+(** Helper for {!param} wrappers. *)
 
 val param :
   ?more_label:string list ->
@@ -248,7 +236,7 @@ val param :
   ?input_axes:(string * int) list ->
   ?output_axes:(string * int) list ->
   ?deduced:Shape.deduce_within_shape ->
-  ?t:
+  t:
     (label:string list ->
     ?input_dims:int list ->
     ?output_dims:int list ->
@@ -261,10 +249,9 @@ val param :
   t
 (** For proper parameters, [t] should produce a tensor with no batch axes; input and output axes
     should by default be inferred; [grad_spec] should be [Require_grad]. [t]'s label is the passed
-    string, appended by [more_label] if any, other parameters are forwarded to [t]. If [t] is not
-    provided, {!default_param_init} is used. This function returns [t]'s result with the field
-    {!field:params} replaced by a singleton set containing that result, and it also updates the
-    memory modes. *)
+    string, appended by [more_label] if any, other parameters are forwarded to [t]. This function
+    returns [t]'s result with the field {!field:params} replaced by a singleton set containing that
+    result, and it also updates the memory modes. *)
 
 val consume_forward_code : t -> comp
 (** A forward root is a tensor that is not (currently) used to compute another tensor.
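The updated `val param` contract pins down the shape of a custom `t`. A sketch of a handwritten initializer matching that signature, assuming a hypothetical `my_fetch_op : fetch_op`; per the doc comment it must produce a tensor with no batch axes and use `Require_grad`:

```ocaml
(* Delegates to fetch_param_init, which already satisfies both requirements
   and forwards the shape-related arguments unchanged. *)
let my_init ~label ?input_dims ?output_dims ?input_axes ?output_axes ?deduced () =
  Tensor.fetch_param_init my_fetch_op ~label ?input_dims ?output_dims ?input_axes
    ?output_axes ?deduced ()
in
let w = Tensor.param ~t:my_init "w" in
...
```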
