
Commit 94ba839

Untested: (1) restore hosted data initialization; (2) params initialized by arbitrary tensor expressions; (3) fix backprop for params
(1) restores the initialization functionality, but now from an ndarray and mostly without copying. (2) allows wrapping e.g. a random-sampling tensor expression as a param. (3) prevents backprop from descending into the initialization code of params; that code doesn't disappear and can still be run manually.
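
As a hedged sketch of what (2) enables, using only names introduced in the diff below (Asgns.Range_over_offsets stands in for e.g. a random-sampling expression; untested, like the rest of the commit):

  (* A param whose initialization is an arbitrary fetch op rather than a constant. *)
  let w = Tensor.param ~t:(Tensor.fetch_param_init Asgns.Range_over_offsets) "w"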
1 parent b01c287 commit 94ba839

File tree: 6 files changed, +128 −50 lines

bin/hello_world_op.ml

Lines changed: 1 addition & 1 deletion
@@ -189,7 +189,7 @@ let%track2_sexp _Big_matrix (() : unit) : unit =
   let ctx = Backend.make_context stream in
   Rand.init 0;
   (* Hey is inferred to be a matrix. *)
-  let hey = Tensor.param ~value:0.5 "hey" in
+  let hey = TDSL.param ~value:0.5 "hey" in
   let zero_to_twenty = TDSL.range 20 in
   let%op yd = (hey * zero_to_twenty) + zero_to_twenty in
   Train.forward_and_forget backend ctx yd;

lib/operation.ml

Lines changed: 66 additions & 21 deletions
@@ -37,15 +37,6 @@ module Initial_NTDSL = struct
   module O = struct end
 end

-module Initial_TDSL = struct
-  let term = Tensor.term ~grad_spec:If_needed
-  let number = Tensor.number ~grad_spec:If_needed
-  let ndarray = Tensor.ndarray ~grad_spec:If_needed
-  let param = Tensor.param
-
-  module O = struct end
-end
-
 let add ?(label = []) =
   let module NTDSL = Initial_NTDSL in
   let%cd op_asn ~v ~t1 ~t2 ~projections = v =: v1 + v2 in
@@ -452,27 +443,77 @@ module NDO = struct
   let ( <> ) = ne ~grad_spec:Prohibit_grad
 end

+(** The input [i] dimensions default to empty. The batch and output dimensions will be inferred
+    if omitted. Note: the data should have no padding; if padding is inferred, the data will be
+    copied; otherwise, the resulting tensor value shares host memory with the ndarray. *)
+let reshape ~l ?b ?(i = []) ?o ndarray =
+  Tensor.term ~label:[ l ] ?batch_dims:b ~input_dims:i ?output_dims:o ~init_data:(Reshape ndarray)
+    ()
+
+(** The dimensions are taken from the provided ndarray, but the split into axis kinds still needs
+    to be inferred (or provided). Assumes no padding. See also: {!reshape} and {!TDSL.wrap_param}. *)
+let wrap ~l ?b ?(i = []) ?o ndarray =
+  Tensor.term ~label:[ l ] ?batch_dims:b ~input_dims:i ?output_dims:o
+    ~init_data:(Keep_shape_no_padding ndarray) ()
+
+(** Assumes the ndarray is padded as given. This means the dimensions of the ndarray will differ
+    from the dimensions of the tensor by the padding. See also: {!TDSL.wrap}. *)
+let wrap_padded ~l ?b ?(i = []) ?o ~padding ~padded_value ndarray =
+  Tensor.term ~label:[ l ] ?batch_dims:b ~input_dims:i ?output_dims:o
+    ~init_data:(Padded { data = ndarray; padding; padded_value })
+    ()
+
+(** The output dimensions are taken from the provided ndarray, assuming precisely the first axis
+    is a batch axis; assumes no input axes, and the batch dimensions are inferred. Assumes the
+    data has no padding; data is copied if padding is inferred. See also: {!reshape} and {!wrap}. *)
+let rebatch ~l ndarray =
+  let output_dims = Ir.Ndarray.dims ndarray |> Array.to_list |> List.tl_exn in
+  if List.is_empty output_dims then invalid_arg "rebatch: ndarray has just one axis";
+  Tensor.term ~label:[ l ] ~input_dims:[] ~output_dims ~init_data:(Reshape ndarray) ()
+
 module TDSL = struct
-  include Initial_TDSL
   module O = DO

+  let term = Tensor.term ~grad_spec:If_needed
+  let number = Tensor.number ~grad_spec:If_needed
+  let ndarray = Tensor.ndarray ~grad_spec:If_needed
+
+  let param ?value ?values =
+    let t =
+      match (value, values) with
+      | Some _, Some _ -> invalid_arg "TDSL.param: both value and values are set"
+      | Some value, None -> Tensor.fetch_param_init (Asgns.Constant value)
+      | None, Some values -> Tensor.fetch_param_init (Asgns.Constant_fill values)
+      | None, None -> !Tensor.default_param_init
+    in
+    Tensor.param ~t
+
   let einsum = einsum ~grad_spec:If_needed
   let outer_sum = outer_sum ~grad_spec:If_needed
   let einsum1 = einsum1 ~grad_spec:If_needed
   let range = range ~grad_spec:If_needed
   let range_of_shape = range_of_shape ~grad_spec:If_needed
   let stop_gradient = stop_gradient
-
-  (** The input [i] dimensions default to empty. The batch dimensions will be inferred if omitted.
-  *)
-  let init_const ~l ?b ?(i = []) ~o values =
-    Tensor.term ~label:[ l ] ~grad_spec:Prohibit_grad ?batch_dims:b ~input_dims:i ~output_dims:o
-      ~fetch_op:(Asgns.Constant_fill values) ()
-
-  (** It's like `Tensor.param` but without shape inference. *)
-  let init_param ~l ?(b = []) ?(i = []) ?(o = []) values =
-    Tensor.term ~label:[ l ] ~grad_spec:Require_grad ~batch_dims:b ~input_dims:i ~output_dims:o
-      ~fetch_op:(Asgns.Constant_fill values) ()
+  let reshape = reshape ~grad_spec:If_needed
+  let wrap = wrap ~grad_spec:If_needed
+  let wrap_padded = wrap_padded ~grad_spec:If_needed
+  let rebatch = rebatch ~grad_spec:If_needed
+
+  (** The input and output dimensions will be inferred if omitted. See {!reshape}. *)
+  let reshape_param ~l ?i ?o ndarray =
+    let t =
+      Tensor.term ~grad_spec:Require_grad ~batch_dims:[] ~batch_axes:[] ~init_data:(Reshape ndarray)
+        ?fetch_op:None
+    in
+    Tensor.param ?input_dims:i ?output_dims:o ~t l
+
+  (** See {!wrap}. *)
+  let wrap_param ~l ?i ?o ndarray =
+    let t =
+      Tensor.term ~grad_spec:Require_grad ~batch_dims:[] ~batch_axes:[]
+        ~init_data:(Keep_shape_no_padding ndarray) ?fetch_op:None
+    in
+    Tensor.param ?input_dims:i ?output_dims:o ~t l
 end

 module NTDSL = struct
@@ -485,6 +526,10 @@ module NTDSL = struct
   let term = Tensor.term ~grad_spec:Prohibit_grad
   let range = range ~grad_spec:Prohibit_grad
   let range_of_shape = range_of_shape ~grad_spec:Prohibit_grad
+  let reshape = reshape ~grad_spec:Prohibit_grad
+  let wrap = wrap ~grad_spec:Prohibit_grad
+  let wrap_padded = wrap_padded ~grad_spec:Prohibit_grad
+  let rebatch = rebatch ~grad_spec:Prohibit_grad

   let counter ?(label = []) =
     let module NTDSL = Initial_NTDSL in
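
For orientation, a minimal usage sketch of the new hosted-data entry points, assuming nd : Ir.Ndarray.t holds unpadded host data with a suitable number of elements (untested; names come from the diff above):

  let demo (nd : Ir.Ndarray.t) =
    (* Shares host memory with [nd]; the split into axis kinds is inferred. *)
    let xs = TDSL.wrap ~l:"xs" nd in
    (* Fixes the output dims; the data is copied only if padding is inferred. *)
    let w = TDSL.reshape ~l:"w" ~o:[ 2; 10 ] nd in
    (* Treats [nd]'s first axis as a batch axis. *)
    let batch = TDSL.rebatch ~l:"batch" nd in
    (xs, w, batch)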

lib/ppx_op.ml

Lines changed: 2 additions & 2 deletions
@@ -21,8 +21,8 @@ let make_p ~has_config ~loc =

 let make_vb ?value ~has_config ~loc ~str_loc ~ident string =
   let pat = Ast_helper.Pat.var ~loc { loc = str_loc; txt = ident } in
-  let value = match value with Some c -> [%expr Some [| [%e c] |]] | None -> [%expr None] in
-  let v = [%expr [%e make_p ~has_config ~loc] ?values:[%e value] [%e string]] in
+  let value = match value with Some c -> [%expr Some [%e c]] | None -> [%expr None] in
+  let v = [%expr [%e make_p ~has_config ~loc] ?value:[%e value] [%e string]] in
   let vb = Ast_helper.Vb.mk ~loc pat v in
   (pat, vb)
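
To make the effect concrete (a hedged sketch: the %op surface syntax that triggers make_vb is not shown in this diff, and make_p is assumed to resolve to TDSL.param), the generated binding for a param "w" with inline initial value 0.5 changes roughly as:

  (* before this commit: scalar init boxed into a one-element array *)
  let w = TDSL.param ?values:(Some [| 0.5 |]) "w"
  (* after this commit: scalar init passed through the new ?value argument *)
  let w = TDSL.param ?value:(Some 0.5) "w"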

lib/tensor.ml

Lines changed: 19 additions & 17 deletions
@@ -205,9 +205,8 @@ let raw_unop ~initialize_neutral ~accum ~(t : t) ~(lhs_is_grad : bool) ~op ~(t1
 type grad_spec = Require_grad | Prohibit_grad | If_needed [@@deriving sexp, equal, variants]

 let op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
-    ?(compose_op = Shape.Pointwise_bin) ?(transpose_op = Shape.Pointwise_un)
-    ?init_data ?fetch_op ~op_asn ~grad_asn
-    ?(grad_spec = If_needed) make_shape (orig_ts : t list) : t =
+    ?(compose_op = Shape.Pointwise_bin) ?(transpose_op = Shape.Pointwise_un) ?init_data ?fetch_op
+    ~op_asn ~grad_asn ?(grad_spec = If_needed) make_shape (orig_ts : t list) : t =
   (* The code needs to be included in the order it was computed due to potential non-tree DAGs. *)
   let ordered_ts = List.dedup_and_sort orig_ts ~compare:(fun t1 t2 -> Int.ascending t1.id t2.id) in
   let id = session_state.next_id in

@@ -222,7 +221,7 @@ let op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
     |> Option.value ~default)
   in
   let terminal_logic () =
-    match fetch_op, init_data with
+    match (fetch_op, init_data) with
     | None, None -> Shape.Terminal (`Fetch (Asgns.Constant 0.0))
     | Some fetch_op, _ -> Shape.Terminal (`Fetch fetch_op)
     | None, Some init_data -> Shape.Terminal (`Data init_data)

@@ -319,7 +318,8 @@ let op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
         diff.backprop)
   in
   let bcks =
-    List.filter_map ordered_ts ~f:(fun ti -> if is_bck_root ti then bprop ti else None)
+    List.filter_map ordered_ts ~f:(fun ti ->
+        if is_bck_root ti && not (Set.mem t.params ti) then bprop ti else None)
   in
   let backprop = Asgns.sequence @@ (grad_asn ~t ~g ~projections :: bcks) in
   let backprop =

@@ -375,7 +375,8 @@ let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?inp
     Shape.make ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes ?deduced ()
   in
   (* Note: fetch_op in op is used only for shape inference. *)
-  op ~label ?compose_op:None ?transpose_op:None ?init_data ?fetch_op ~op_asn ~grad_asn ~grad_spec make_shape []
+  op ~label ?compose_op:None ?transpose_op:None ?init_data ?fetch_op ~op_asn ~grad_asn ~grad_spec
+    make_shape []

 let float_to_label v = Float.to_string v

@@ -438,18 +439,19 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?
     Tn.update_prec ~only_if:is_up_to_fp16 t.value single);
   t

-let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ?value
-    ?values label =
-  let fetch_op =
-    match (values, value) with
-    | Some values, None -> Asgns.Constant_fill values
-    | None, Some value -> Asgns.Constant value
-    | None, None -> Asgns.Range_over_offsets
-    | Some _, Some _ -> invalid_arg "Tensor.param: both values and value are set"
-  in
+let fetch_param_init fetch_op =
+  term ~grad_spec:Require_grad ~batch_dims:[] ~batch_axes:[] ?init_data:None ~fetch_op
+
+let default_param_init = ref @@ fetch_param_init (Asgns.Constant 0.0)
+
+let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ?t label =
   let t =
-    term ~label:(label :: more_label) ~grad_spec:Require_grad ~batch_dims:[] ?input_dims
-      ?output_dims ?input_axes ?output_axes ?deduced ~fetch_op ()
+    match t with
+    | Some t ->
+        t ~label:(label :: more_label) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ()
+    | None ->
+        !default_param_init ~label:(label :: more_label) ?input_dims ?output_dims ?input_axes
+          ?output_axes ?deduced ()
   in
   let v = t.value in
   (* It is convenient to use the param syntax for volatiles (mutable embedded_nodes). *)
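
A hedged sketch of how the two new hooks compose, using only definitions visible in this file (untested; the Asgns module path is assumed to be in scope as in lib/tensor.ml):

  (* Globally change the initialization used by params created without ~t;
     Range_over_offsets was the old fallback when neither ?value nor ?values was given. *)
  let () = Tensor.default_param_init := Tensor.fetch_param_init Asgns.Range_over_offsets

  (* Per-param override, equivalent to the old ~value:0.5 behavior. *)
  let hey = Tensor.param ~t:(Tensor.fetch_param_init (Asgns.Constant 0.5)) "hey"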

lib/tensor.mli

Lines changed: 39 additions & 8 deletions
@@ -217,23 +217,54 @@ val ndarray :
     given values must fill the tensor's [value] node precisely; otherwise, the values will be looped
     over to populate the [value] node. *)

+val default_param_init :
+  (label:string list ->
+  ?input_dims:int list ->
+  ?output_dims:int list ->
+  ?input_axes:(string * int) list ->
+  ?output_axes:(string * int) list ->
+  ?deduced:Shape.deduce_within_shape ->
+  unit ->
+  t)
+  ref
+(** The default initialization operation for {!param} calls that do not pass a [t]. *)
+
+val fetch_param_init :
+  fetch_op ->
+  label:string list ->
+  ?input_dims:int list ->
+  ?output_dims:int list ->
+  ?input_axes:(string * int) list ->
+  ?output_axes:(string * int) list ->
+  ?deduced:Shape.deduce_within_shape ->
+  unit ->
+  t
+(** Helper for {!param} wrappers, or to set {!default_param_init}. *)
+
 val param :
   ?more_label:string list ->
   ?input_dims:int list ->
   ?output_dims:int list ->
   ?input_axes:(string * int) list ->
   ?output_axes:(string * int) list ->
   ?deduced:Shape.deduce_within_shape ->
-  ?value:float ->
-  ?values:float array ->
+  ?t:
+    (label:string list ->
+    ?input_dims:int list ->
+    ?output_dims:int list ->
+    ?input_axes:(string * int) list ->
+    ?output_axes:(string * int) list ->
+    ?deduced:Shape.deduce_within_shape ->
+    unit ->
+    t) ->
   string ->
   t
-(* A tensor with no batch axes; input and output axes are by default inferred. [grad_spec] is set to
-   [Require_grad]. The resulting tensor's label is the passed string, appended by [more_label] if
-   any. If [value] is provided, the tensor is initialized to the given value. If [values] is
-   provided, the tensor is initialized to the given values. At most one of [value] or [values] can
-   be provided. Note: [values] will be looped over if necessary, but shape inference will try
-   incorporating the number of values as tensor size. *)
+(** For proper parameters, [t] should produce a tensor with no batch axes; input and output axes
+    should by default be inferred; and [grad_spec] should be [Require_grad]. [t]'s label is the
+    passed string, appended with [more_label] if any; other arguments are forwarded to [t]. If [t]
+    is not provided, {!default_param_init} is used. This function returns [t]'s result with the
+    field {!field:params} replaced by a singleton set containing that result, and it also updates
+    the memory modes. *)

 val consume_forward_code : t -> comp
 (** A forward root is a tensor that is not (currently) used to compute another tensor.
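
As a hedged illustration of the new [?t] hook from the signature above (a sketch mirroring TDSL.wrap_param in lib/operation.ml; untested):

  (* A param initialized from host data: [t] is a closure over Tensor.term with
     Require_grad, no batch axes, and the new init_data path. *)
  let param_of_nd name nd =
    Tensor.param
      ~t:(Tensor.term ~grad_spec:Require_grad ~batch_dims:[] ~batch_axes:[]
            ~init_data:(Keep_shape_no_padding nd) ?fetch_op:None)
      name

In effect this matches TDSL.wrap_param, and per (3) in the commit message, initialization code attached this way is kept out of the param's backprop.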

test/hello_world_op.ml

Lines changed: 1 addition & 1 deletion
@@ -510,7 +510,7 @@ let%expect_test "Big matrix" =
   let ctx = Backend.make_context stream in
   Rand.init 0;
   (* Hey is inferred to be a matrix. *)
-  let hey = Tensor.param ~value:0.5 "hey" in
+  let hey = TDSL.param ~value:0.5 "hey" in
   let zero_to_twenty = TDSL.range 20 in
   let y = TDSL.O.((hey * zero_to_twenty) + zero_to_twenty) in
   Train.forward_and_forget backend ctx y;
