
Commit a75dce1

Support Tensor.params field via the Tensor.param function; a couple early missing-init fixes
More missing-init fixes after %cd syntax is updated to allow inline bindings for non-assignment expressions.
1 parent d2f4435

File tree: 5 files changed, +40 -19 lines changed
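
The commit makes parameter initialization explicit: Tensor.param now records the parameter in its params field and removes it from the forward roots, and the initialization code of a tensor's parameters is collected with Tensor.init_params and compiled as its own routine. A condensed sketch of the resulting workflow, distilled from the test changes below (Backend, ctx, bindings, scalar_loss and learning_rate are assumed to be set up as in those tests):

  let init_params = Tensor.init_params scalar_loss in
  let update = Train.grad_update scalar_loss in
  let sgd = Train.sgd_update ~learning_rate ~weight_decay scalar_loss in
  (* Compile the initialization on its own, then compile the training step
     against the context it produces. *)
  let init = Train.to_routine (module Backend) ctx bindings init_params in
  let sgd_routine =
    Train.to_routine (module Backend) init.context bindings (Asgns.sequence [ update; sgd ])
  in
  (* Run the parameter initialization once, before the training loop. *)
  Train.run init;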

lib/tensor.ml

Lines changed: 25 additions & 9 deletions
@@ -116,10 +116,24 @@ let iter_embedded ~f t =
   Set.iter ~f t.forward.embedded_nodes;
   Option.iter t.diff ~f:(fun diff -> Set.iter ~f diff.backprop.embedded_nodes)
 
-let init_params _t =
-  (* Based on the interface documentation, this should collect forward code of t.params *)
-  (* For now, return empty since the 'params' field is missing from the current implementation *)
-  Asgns.empty_comp
+let rec init_params t =
+  let open Asgns in
+  let rem_embedded = ref @@ Set.empty (module Tn) in
+  let asgns =
+    Block_comment
+      ( "init params for " ^ Tn.debug_name t.value,
+        sequential
+        @@ Set.fold t.params ~init:[] ~f:(fun acc param ->
+               if Set.is_empty param.params then param.forward.asgns :: acc
+               else
+                 let asgns = init_params param in
+                 rem_embedded := Set.union !rem_embedded asgns.embedded_nodes;
+                 Seq (asgns.asgns, param.forward.asgns) :: acc) )
+  in
+  let embedded_nodes =
+    Set.fold ~init:!rem_embedded t.params ~f:(fun acc p -> Set.add acc p.value)
+  in
+  { asgns; embedded_nodes }
 
 let initial_default_prec =
   Ir.Ops.prec_of_string (Utils.get_global_arg ~default:"single" ~arg_name:"default_prec")
@@ -299,7 +313,8 @@ let op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
       session_state.backprop_roots <- Map.remove session_state.backprop_roots ti.id);
   (* The order is not relevant, we keep the same order as in backprop for readability. *)
   let diff = Some { grad = g; zero_grads; backprop } in
-  let tensor = { params = Set.empty (module T); forward; diff; id; value = v; shape; children } in
+  let params = Set.union_list (module T) @@ List.map ordered_ts ~f:(fun ti -> ti.params) in
+  let tensor = { params; forward; diff; id; value = v; shape; children } in
   session_state.forward_roots <- Map.add_exn session_state.forward_roots ~key:id ~data:tensor;
   session_state.backprop_roots <- Map.add_exn session_state.backprop_roots ~key:id ~data:tensor;
   tensor
@@ -409,10 +424,10 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?
     Tn.update_prec ~only_if:is_up_to_fp16 t.value single);
   t
 
-let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ?value ?values
-    label =
+let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ?value
+    ?values label =
   let fetch_op_fn ~v:_ =
-    match values, value with
+    match (values, value) with
     | Some values, None -> Asgns.Constant_fill values
     | None, Some value -> Asgns.Constant value
     | None, None -> Asgns.Range_over_offsets
@@ -429,7 +444,8 @@ let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?
      update computations. *)
   let g = (Option.value_exn ~here:[%here] t.diff).grad in
   Tn.update_memory_mode g Never_virtual 26;
-  t
+  remove_fwd_root t;
+  { t with params = Set.singleton (module T) t }
 
 let debug_name t = Tn.debug_name t.value
 let debug_grad t = Tn.debug_name (Option.value_exn t.diff).grad
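
An informal sketch of how the lib/tensor.ml pieces combine (w, b and loss are placeholder names, not identifiers from this commit): a tensor returned by param carries itself as its only params entry, op unions the params of its subtensors, and init_params walks the resulting set.

  (* Assuming w and b were created with [param] and loss is built from them,
     [op] gives loss.params = {w; b}, and [init_params loss] roughly yields
     (the order within [sequential] follows the set, not the program text):
       { asgns =
           Block_comment
             ( "init params for " ^ Tn.debug_name loss.value,
               sequential [ w.forward.asgns; b.forward.asgns ] );
         embedded_nodes = {w.value; b.value} }
     A parameter that itself has parameters gets its own init_params result
     sequenced in front of its forward code. *)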

lib/tensor.mli

Lines changed: 2 additions & 2 deletions
@@ -47,8 +47,8 @@ and comparator_witness
 val comparator : (t, comparator_witness) Base.Comparator.t
 
 val init_params : t -> comp
-(** [init_params t] simply collects the {!field:forward} code of [t.params] into a single sequence.
-*)
+(** [init_params t] collects into a single sequence the {!field:forward} code of [t.params], and
+    transitively the initializations of the parameters of the parameters. *)
 
 val is_fwd_root : t -> bool
 val remove_fwd_root : t -> unit

lib/train.ml

Lines changed: 4 additions & 2 deletions
@@ -128,8 +128,8 @@ let grad_update ?(disable_rootness_check = false) ?(setup_for_parallel = false)
 
 (** See: https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/optim.py *)
 let sgd_one ~learning_rate ?(momentum = 0.0) ?(weight_decay = 0.0) ?(nesterov = false) p =
-  if not @@ Set.mem p.Tensor.params p then
-    raise @@ Tensor.Session_error ("Train.sgd_one: not a parameter", Some p);
+  if Option.is_none p.Tensor.diff then
+    raise @@ Tensor.Session_error ("Train.sgd_one: not differentiable", Some p);
   [%cd
     ~~(p "param sgd step";
        "sgd_delta" =: p.grad + (!.weight_decay *. p);
@@ -254,6 +254,8 @@ let%track3_sexp parallel_update (type buffer_ptr dev runner event optimize_ctx)
       Array.for_all grad_updates ~f:(fun upd ->
           [%equal: Idx.static_symbol list] bindings @@ List.map ~f:fst upd.bindings))];
   let all_params : Tensor.t array = Set.to_array loss.Tensor.params in
+  if Array.is_empty all_params then
+    raise @@ Tensor.Session_error ("Train.parallel_update: no parameters", Some loss);
   let _occupancies_debug : bool array array = occupancies_dst_src in
   let ctxs = [%debug_notrace Array.map grad_updates ~f:(fun upd -> upd.context)] in
   let occupancy_dst ~dst_n = Array.exists ~f:Fn.id occupancies_dst_src.(dst_n) in

test/einsum/moons_demo_variant.ml

Lines changed: 5 additions & 5 deletions
@@ -54,15 +54,18 @@ let () =
      computation. *)
   let weight_decay = 0.0001 in
   let%op scalar_loss = (margin_loss ++ "...|... => 0") /. !..batch_size in
+  let init_params = Tensor.init_params scalar_loss in
   let update = Train.grad_update scalar_loss in
   let%op learning_rate = 0.1 *. ((2 *. !..steps) - !@step_n) /. !..steps in
   Train.set_hosted learning_rate.value;
   let sgd = Train.sgd_update ~learning_rate ~weight_decay scalar_loss in
+  let init = Train.to_routine (module Backend) ctx bindings init_params in
   let sgd_routine =
-    Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update; sgd ])
+    Train.to_routine (module Backend) init.context bindings (Asgns.sequence [ update; sgd ])
   in
   let step_ref = IDX.find_exn sgd_routine.bindings step_n in
   step_ref := 0;
+  Train.run init;
   for _epoch = 1 to epochs do
     Train.sequential_loop sgd_routine.bindings ~f:(fun () ->
         Train.run sgd_routine;
@@ -77,8 +80,7 @@ let () =
   let points = Tn.points_2d ~xdim:0 ~ydim:1 moons_flat.value in
   let classes = Tn.points_1d ~xdim:0 moons_classes.value in
   let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
-  let%op mlp_result = mlp "point" in
-  Train.set_on_host mlp_result.value;
+  let%cd mlp_result = mlp "point" in
   let result_routine =
     Train.to_routine
       (module Backend)
@@ -89,8 +91,6 @@ let () =
   in
   let callback (x, y) =
     Tn.set_values point.value [| x; y |];
-    (* For the gccjit backend, point is only on host, not on device. For cuda, this will be
-       needed. *)
     Train.run result_routine;
     Float.(mlp_result.@[0] >= 0.)
   in

test/micrograd_demo_logging.ml

Lines changed: 4 additions & 1 deletion
@@ -30,9 +30,12 @@ let () =
   let%op g = f /. 2 in
   let%op g = g + (10. /. f) in
   List.iter ~f:(Option.iter ~f:(fun diff -> Train.set_hosted diff.Tensor.grad)) [ a.diff; b.diff ];
+  let init_params = Tensor.init_params g in
   let update = Train.grad_update g in
-  let step = Train.to_routine (module Backend) ctx IDX.empty update in
+  let init = Train.to_routine (module Backend) ctx IDX.empty init_params in
+  let step = Train.to_routine (module Backend) init.context IDX.empty update in
   Utils.capture_stdout_logs @@ fun () ->
+  Train.run init;
   Train.run step;
   Tensor.print ~with_code:false ~with_grad:false `Default g;
   Tensor.print ~with_code:false ~with_grad:true `Default a;
